Compare commits
334 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5e97916608 | ||
|
|
8b9c2be8c9 | ||
|
|
3ff5aac8e0 | ||
|
|
67fef60466 | ||
|
|
b6ab6f1993 | ||
|
|
9f2e82a838 | ||
|
|
0b2b799d5a | ||
|
|
0f790fbbd9 | ||
|
|
387ae21eba | ||
|
|
3cc329c3e7 | ||
|
|
5567302316 | ||
|
|
075d4bd167 | ||
|
|
e4bcc76f88 | ||
|
|
710e83b1fd | ||
|
|
c96d653072 | ||
|
|
8b22d2b5d3 | ||
|
|
4cb544ee38 | ||
|
|
f94ce63d51 | ||
|
|
4271ff9d84 | ||
|
|
0d448c4a41 | ||
|
|
af5599e33c | ||
|
|
efdf6d917a | ||
|
|
dd71ac8d71 | ||
|
|
8bee1d4100 | ||
|
|
33521d6d00 | ||
|
|
8899734952 | ||
|
|
54df6310c5 | ||
|
|
19bcc07814 | ||
|
|
8356e3c668 | ||
|
|
08eac5c821 | ||
|
|
4671ed9b36 | ||
|
|
055c086398 | ||
|
|
d505dcc5e3 | ||
|
|
261006c36a | ||
|
|
b2eba23e21 | ||
|
|
e9ee687472 | ||
|
|
6f5d5e4a77 | ||
|
|
5c8921673a | ||
|
|
e9d2d420bd | ||
|
|
ebabfad066 | ||
|
|
e6f612b5e8 | ||
|
|
51c41acd82 | ||
|
|
455f93fb7c | ||
|
|
48207c3b69 | ||
|
|
4de1caa40f | ||
|
|
60eaa8165c | ||
|
|
c1a5d0c624 | ||
|
|
af1790395a | ||
|
|
383c6d8d7e | ||
|
|
bc0d839693 | ||
|
|
8596562de5 | ||
|
|
5d09586853 | ||
|
|
a7cba078dd | ||
|
|
b3e9ee96fa | ||
|
|
8537a6b17e | ||
|
|
7c8d7dc5c2 | ||
|
|
8e23d663e6 | ||
|
|
8a3994bf80 | ||
|
|
8375f601ba | ||
|
|
c87c0fe662 | ||
|
|
73927b68ef | ||
|
|
cc1a62e5aa | ||
|
|
802020cb41 | ||
|
|
cdb92f7cf4 | ||
|
|
dc69bdec00 | ||
|
|
98073e9868 | ||
|
|
cf2ef48967 | ||
|
|
0692bbf7a2 | ||
|
|
52584a171f | ||
|
|
efd6b5324b | ||
|
|
2baaa4549b | ||
|
|
35310ddd52 | ||
|
|
fc9c5cb39d | ||
|
|
8f2a1e87ea | ||
|
|
50caf65f28 | ||
|
|
1b48794ca8 | ||
|
|
4aef1d814e | ||
|
|
75ddcd6158 | ||
|
|
2a4df11f5c | ||
|
|
5eb893c62b | ||
|
|
d91ce2e94d | ||
|
|
5c2ff8a641 | ||
|
|
d4f474c9b7 | ||
|
|
170f7644e9 | ||
|
|
cd8b970eff | ||
|
|
52153bbb69 | ||
|
|
e1ae087207 | ||
|
|
48c5e12ac1 | ||
|
|
f8b5c97190 | ||
|
|
d038c81b8b | ||
|
|
29cbbbd0d6 | ||
|
|
179f30bc36 | ||
|
|
c4a0a68581 | ||
|
|
5c836ad08e | ||
|
|
673fd9b7cd | ||
|
|
84b24b233d | ||
|
|
499cdd7822 | ||
|
|
800d4cf111 | ||
|
|
b6d43f5fd9 | ||
|
|
3603cd5034 | ||
|
|
6df7893173 | ||
|
|
e64b599276 | ||
|
|
2dd59c4ba1 | ||
|
|
166986d5e6 | ||
|
|
a6aec68f32 | ||
|
|
ed27a127d5 | ||
|
|
d8b4ea7564 | ||
|
|
f0a2ef96b4 | ||
|
|
7d73c2c803 | ||
|
|
e8d2ecab03 | ||
|
|
32a374d094 | ||
|
|
d45c013806 | ||
|
|
9000a7083d | ||
|
|
8307555d54 | ||
|
|
20f2aece08 | ||
|
|
43eb4f9a1d | ||
|
|
5461b71d8c | ||
|
|
374db0ebb8 | ||
|
|
cea1f6f87c | ||
|
|
6c0e39372b | ||
|
|
2bec67d2b6 | ||
|
|
133e715832 | ||
|
|
95cf2f16e2 | ||
|
|
47a4c153eb | ||
|
|
faf5ae3533 | ||
|
|
a44dccecac | ||
|
|
9cf9358b9c | ||
|
|
de252fef31 | ||
|
|
9076bc27b8 | ||
|
|
50686c0819 | ||
|
|
1614203786 | ||
|
|
3d4c75a56c | ||
|
|
2684ee71dc | ||
|
|
1d321953ba | ||
|
|
b3cb251369 | ||
|
|
0a17d2c9d8 | ||
|
|
e3defbca84 | ||
|
|
e407f63977 | ||
|
|
7add391b2c | ||
|
|
efd6373b32 | ||
|
|
d502fa24b0 | ||
|
|
258a9a5c7f | ||
|
|
5d41ac6115 | ||
|
|
2a0fdb49b8 | ||
|
|
9d1b7231b6 | ||
|
|
ed3095b478 | ||
|
|
88eca75917 | ||
|
|
42de27e16a | ||
|
|
c083bda5b7 | ||
|
|
e86da38726 | ||
|
|
99076e38bc | ||
|
|
9698c1a02c | ||
|
|
851f0f04c3 | ||
|
|
ae16d9d888 | ||
|
|
6e1af2eb0c | ||
|
|
7695dd0d50 | ||
|
|
c2065473ad | ||
|
|
5f3870564d | ||
|
|
c214b2e33e | ||
|
|
2420c5fd35 | ||
|
|
f48f526f0a | ||
|
|
5dd74982ba | ||
|
|
e07aaf52a7 | ||
|
|
30e5f12616 | ||
|
|
594427bf87 | ||
|
|
a97d3ada1c | ||
|
|
6217bb5638 | ||
|
|
2760e99e18 | ||
|
|
0544f96b79 | ||
|
|
2ebb29de65 | ||
|
|
43762d44c7 | ||
|
|
cdaf0c98be | ||
|
|
aa9a14a917 | ||
|
|
9efcc6d95c | ||
|
|
f3f5d91207 | ||
|
|
6070160959 | ||
|
|
43155d2811 | ||
|
|
d3f85678ec | ||
|
|
2a96d05b21 | ||
|
|
851e888535 | ||
|
|
90120d4dff | ||
|
|
8513471573 | ||
|
|
71e5f1774c | ||
|
|
870a443446 | ||
|
|
cefaa2a4cc | ||
|
|
ab72a2ab9d | ||
|
|
046d457d22 | ||
|
|
7fd0a30fee | ||
|
|
c2f35c8e73 | ||
|
|
573313f0b6 | ||
|
|
f7af6805fa | ||
|
|
966de3a399 | ||
|
|
8a75829f3a | ||
|
|
0f7e34b9e2 | ||
|
|
be0322b616 | ||
|
|
232a525a62 | ||
|
|
587ce65cf6 | ||
|
|
ccf6c8bfd7 | ||
|
|
c112956d2d | ||
|
|
b3970793cf | ||
|
|
727724990e | ||
|
|
530f6e4af5 | ||
|
|
2f224f5793 | ||
|
|
1b6272ce0e | ||
|
|
5259ace111 | ||
|
|
48ea5566e9 | ||
|
|
3f8b6c5bbd | ||
|
|
725b32e74f | ||
|
|
d0b71f393f | ||
|
|
8a92efdae3 | ||
|
|
019cdce2e8 | ||
|
|
b64aa54fac | ||
|
|
c0d040f9d4 | ||
|
|
32364320f8 | ||
|
|
34c71c072d | ||
|
|
6d2149c503 | ||
|
|
043b0bf69d | ||
|
|
9b07e392c6 | ||
|
|
e60fad8c73 | ||
|
|
19c1b182c3 | ||
|
|
49edea780c | ||
|
|
12ef5a1900 | ||
|
|
d21a134b2a | ||
|
|
1cd809aa41 | ||
|
|
e728449b8f | ||
|
|
d0c20b14d5 | ||
|
|
83b7ea5a59 | ||
|
|
0796a52df1 | ||
|
|
85b7ba0168 | ||
|
|
e117743d24 | ||
|
|
aec2291f04 | ||
|
|
335ae003ac | ||
|
|
71c7de9c84 | ||
|
|
1c5fec5565 | ||
|
|
99d439577d | ||
|
|
4f83086788 | ||
|
|
a13c527e39 | ||
|
|
90d9f27383 | ||
|
|
0db81c16cd | ||
|
|
e115e186b7 | ||
|
|
6546b29ef7 | ||
|
|
51255bdffa | ||
|
|
f77c4e38cb | ||
|
|
2a1a152073 | ||
|
|
7b9406a3ea | ||
|
|
c3fb949693 | ||
|
|
ed3f8dbfd6 | ||
|
|
42aa6db170 | ||
|
|
a6591d20ca | ||
|
|
c1bc2603a2 | ||
|
|
e595bbb5fb | ||
|
|
4a2cb914d7 | ||
|
|
b1c93fe178 | ||
|
|
0719458775 | ||
|
|
6a1dc895fb | ||
|
|
125c1f6f25 | ||
|
|
1ceaa7d709 | ||
|
|
dec3ee85fd | ||
|
|
d94a5176dc | ||
|
|
326783f7f1 | ||
|
|
e5a9ca8787 | ||
|
|
f2feccdbd0 | ||
|
|
246a077d64 | ||
|
|
3ba100ff25 | ||
|
|
1e3b571e72 | ||
|
|
b89e56e9c2 | ||
|
|
ed8a02e721 | ||
|
|
baa60b40d1 | ||
|
|
ef01d6997a | ||
|
|
3da5b44d7f | ||
|
|
8b4654921b | ||
|
|
cf1cbafa78 | ||
|
|
c96091744b | ||
|
|
711fb4a775 | ||
|
|
3b5a185e60 | ||
|
|
77ac013a74 | ||
|
|
b8e5728e6a | ||
|
|
d038319d8b | ||
|
|
c611d0f30f | ||
|
|
c17899662f | ||
|
|
c51d5320fa | ||
|
|
6fa9512a64 | ||
|
|
fddc61df5e | ||
|
|
53c58fa755 | ||
|
|
c69afb56e4 | ||
|
|
0fa8a9191f | ||
|
|
48dda1cb5b | ||
|
|
71ef4b7d4c | ||
|
|
ecab43e307 | ||
|
|
88ca09440d | ||
|
|
8e0ab4a28d | ||
|
|
9b8c5041dc | ||
|
|
74ffd7ec64 | ||
|
|
eb6f504789 | ||
|
|
91a026f38b | ||
|
|
595138a0a3 | ||
|
|
19df04095f | ||
|
|
8239bbb48f | ||
|
|
16ee9d0422 | ||
|
|
8a961f8ab3 | ||
|
|
558126c46e | ||
|
|
04c9684488 | ||
|
|
b744faa7e6 | ||
|
|
27b3a26e75 | ||
|
|
41d872504e | ||
|
|
963cd05273 | ||
|
|
09b6e67baf | ||
|
|
dafb2aacab | ||
|
|
a6c400cd4f | ||
|
|
c013e5ccce | ||
|
|
f25a1a3840 | ||
|
|
6497e17671 | ||
|
|
44369a8138 | ||
|
|
dfca00c21b | ||
|
|
637dab379e | ||
|
|
6fc57eb48e | ||
|
|
95a653993a | ||
|
|
af0959818d | ||
|
|
cf17c85607 | ||
|
|
a38bc0a3fc | ||
|
|
449983c937 | ||
|
|
df63526503 | ||
|
|
e92deee1e8 | ||
|
|
910927a405 | ||
|
|
0aa84e147b | ||
|
|
368474d036 | ||
|
|
a627abe794 | ||
|
|
44815ee7fd | ||
|
|
371e3de04e | ||
|
|
b81b5d0f86 | ||
|
|
ee507bfe7a | ||
|
|
30898814ae | ||
|
|
a075fd6f47 | ||
|
|
303ff6fe1d |
11
.github/workflows/build-and-publish.yml
vendored
Normal file
11
.github/workflows/build-and-publish.yml
vendored
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
name: CI
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ main ]
|
||||||
|
pull_request:
|
||||||
|
branches: [ main ]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
uses: ./.github/workflows/build-reusable.yml
|
||||||
251
.github/workflows/build-reusable.yml
vendored
Normal file
251
.github/workflows/build-reusable.yml
vendored
Normal file
@@ -0,0 +1,251 @@
|
|||||||
|
name: Reusable Build
|
||||||
|
|
||||||
|
on:
|
||||||
|
workflow_call:
|
||||||
|
inputs:
|
||||||
|
ref:
|
||||||
|
description: 'Git ref to build'
|
||||||
|
required: false
|
||||||
|
type: string
|
||||||
|
default: ''
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
lint:
|
||||||
|
name: Lint and Format Check
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
ref: ${{ inputs.ref }}
|
||||||
|
|
||||||
|
- name: Setup Python
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: '3.11'
|
||||||
|
|
||||||
|
- name: Install uv
|
||||||
|
uses: astral-sh/setup-uv@v4
|
||||||
|
|
||||||
|
- name: Install ruff
|
||||||
|
run: |
|
||||||
|
uv tool install ruff
|
||||||
|
|
||||||
|
- name: Run ruff check
|
||||||
|
run: |
|
||||||
|
ruff check .
|
||||||
|
|
||||||
|
- name: Run ruff format check
|
||||||
|
run: |
|
||||||
|
ruff format --check .
|
||||||
|
|
||||||
|
build:
|
||||||
|
needs: lint
|
||||||
|
name: Build ${{ matrix.os }} Python ${{ matrix.python }}
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
include:
|
||||||
|
- os: ubuntu-22.04
|
||||||
|
python: '3.9'
|
||||||
|
- os: ubuntu-22.04
|
||||||
|
python: '3.10'
|
||||||
|
- os: ubuntu-22.04
|
||||||
|
python: '3.11'
|
||||||
|
- os: ubuntu-22.04
|
||||||
|
python: '3.12'
|
||||||
|
- os: ubuntu-22.04
|
||||||
|
python: '3.13'
|
||||||
|
- os: macos-latest
|
||||||
|
python: '3.9'
|
||||||
|
- os: macos-latest
|
||||||
|
python: '3.10'
|
||||||
|
- os: macos-latest
|
||||||
|
python: '3.11'
|
||||||
|
- os: macos-latest
|
||||||
|
python: '3.12'
|
||||||
|
- os: macos-latest
|
||||||
|
python: '3.13'
|
||||||
|
runs-on: ${{ matrix.os }}
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
ref: ${{ inputs.ref }}
|
||||||
|
submodules: recursive
|
||||||
|
|
||||||
|
- name: Setup Python
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python }}
|
||||||
|
|
||||||
|
- name: Install uv
|
||||||
|
uses: astral-sh/setup-uv@v4
|
||||||
|
|
||||||
|
- name: Install system dependencies (Ubuntu)
|
||||||
|
if: runner.os == 'Linux'
|
||||||
|
run: |
|
||||||
|
sudo apt-get update
|
||||||
|
sudo apt-get install -y libomp-dev libboost-all-dev protobuf-compiler libzmq3-dev \
|
||||||
|
pkg-config libopenblas-dev patchelf libabsl-dev libaio-dev libprotobuf-dev
|
||||||
|
|
||||||
|
# Install Intel MKL for DiskANN
|
||||||
|
wget -q https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh
|
||||||
|
sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
|
||||||
|
source /opt/intel/oneapi/setvars.sh
|
||||||
|
echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV
|
||||||
|
echo "LD_LIBRARY_PATH=/opt/intel/oneapi/mkl/latest/lib/intel64:$LD_LIBRARY_PATH" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
- name: Install system dependencies (macOS)
|
||||||
|
if: runner.os == 'macOS'
|
||||||
|
run: |
|
||||||
|
# Don't install LLVM, use system clang for better compatibility
|
||||||
|
brew install libomp boost protobuf zeromq
|
||||||
|
|
||||||
|
- name: Install build dependencies
|
||||||
|
run: |
|
||||||
|
uv pip install --system scikit-build-core numpy swig Cython pybind11
|
||||||
|
if [[ "$RUNNER_OS" == "Linux" ]]; then
|
||||||
|
uv pip install --system auditwheel
|
||||||
|
else
|
||||||
|
uv pip install --system delocate
|
||||||
|
fi
|
||||||
|
|
||||||
|
- name: Build packages
|
||||||
|
run: |
|
||||||
|
# Build core (platform independent)
|
||||||
|
if [[ "${{ matrix.os }}" == ubuntu-* ]]; then
|
||||||
|
cd packages/leann-core
|
||||||
|
uv build
|
||||||
|
cd ../..
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Build HNSW backend
|
||||||
|
cd packages/leann-backend-hnsw
|
||||||
|
if [ "${{ matrix.os }}" == "macos-latest" ]; then
|
||||||
|
# Use system clang instead of homebrew LLVM for better compatibility
|
||||||
|
export CC=clang
|
||||||
|
export CXX=clang++
|
||||||
|
export MACOSX_DEPLOYMENT_TARGET=11.0
|
||||||
|
uv build --wheel --python python
|
||||||
|
else
|
||||||
|
uv build --wheel --python python
|
||||||
|
fi
|
||||||
|
cd ../..
|
||||||
|
|
||||||
|
# Build DiskANN backend
|
||||||
|
cd packages/leann-backend-diskann
|
||||||
|
if [ "${{ matrix.os }}" == "macos-latest" ]; then
|
||||||
|
# Use system clang instead of homebrew LLVM for better compatibility
|
||||||
|
export CC=clang
|
||||||
|
export CXX=clang++
|
||||||
|
# DiskANN requires macOS 13.3+ for sgesdd_ LAPACK function
|
||||||
|
export MACOSX_DEPLOYMENT_TARGET=13.3
|
||||||
|
uv build --wheel --python python
|
||||||
|
else
|
||||||
|
uv build --wheel --python python
|
||||||
|
fi
|
||||||
|
cd ../..
|
||||||
|
|
||||||
|
# Build meta package (platform independent)
|
||||||
|
if [[ "${{ matrix.os }}" == ubuntu-* ]]; then
|
||||||
|
cd packages/leann
|
||||||
|
uv build
|
||||||
|
cd ../..
|
||||||
|
fi
|
||||||
|
|
||||||
|
- name: Repair wheels (Linux)
|
||||||
|
if: runner.os == 'Linux'
|
||||||
|
run: |
|
||||||
|
# Repair HNSW wheel
|
||||||
|
cd packages/leann-backend-hnsw
|
||||||
|
if [ -d dist ]; then
|
||||||
|
auditwheel repair dist/*.whl -w dist_repaired
|
||||||
|
rm -rf dist
|
||||||
|
mv dist_repaired dist
|
||||||
|
fi
|
||||||
|
cd ../..
|
||||||
|
|
||||||
|
# Repair DiskANN wheel
|
||||||
|
cd packages/leann-backend-diskann
|
||||||
|
if [ -d dist ]; then
|
||||||
|
auditwheel repair dist/*.whl -w dist_repaired
|
||||||
|
rm -rf dist
|
||||||
|
mv dist_repaired dist
|
||||||
|
fi
|
||||||
|
cd ../..
|
||||||
|
|
||||||
|
- name: Repair wheels (macOS)
|
||||||
|
if: runner.os == 'macOS'
|
||||||
|
run: |
|
||||||
|
# Repair HNSW wheel
|
||||||
|
cd packages/leann-backend-hnsw
|
||||||
|
if [ -d dist ]; then
|
||||||
|
delocate-wheel -w dist_repaired -v dist/*.whl
|
||||||
|
rm -rf dist
|
||||||
|
mv dist_repaired dist
|
||||||
|
fi
|
||||||
|
cd ../..
|
||||||
|
|
||||||
|
# Repair DiskANN wheel
|
||||||
|
cd packages/leann-backend-diskann
|
||||||
|
if [ -d dist ]; then
|
||||||
|
delocate-wheel -w dist_repaired -v dist/*.whl
|
||||||
|
rm -rf dist
|
||||||
|
mv dist_repaired dist
|
||||||
|
fi
|
||||||
|
cd ../..
|
||||||
|
|
||||||
|
- name: List built packages
|
||||||
|
run: |
|
||||||
|
echo "📦 Built packages:"
|
||||||
|
find packages/*/dist -name "*.whl" -o -name "*.tar.gz" | sort
|
||||||
|
|
||||||
|
- name: Install built packages for testing
|
||||||
|
run: |
|
||||||
|
# Create a virtual environment
|
||||||
|
uv venv
|
||||||
|
source .venv/bin/activate || source .venv/Scripts/activate
|
||||||
|
|
||||||
|
# Install the built wheels
|
||||||
|
# Use --find-links to let uv choose the correct wheel for the platform
|
||||||
|
if [[ "${{ matrix.os }}" == ubuntu-* ]]; then
|
||||||
|
uv pip install leann-core --find-links packages/leann-core/dist
|
||||||
|
uv pip install leann --find-links packages/leann/dist
|
||||||
|
fi
|
||||||
|
uv pip install leann-backend-hnsw --find-links packages/leann-backend-hnsw/dist
|
||||||
|
uv pip install leann-backend-diskann --find-links packages/leann-backend-diskann/dist
|
||||||
|
|
||||||
|
# Install test dependencies using extras
|
||||||
|
uv pip install -e ".[test]"
|
||||||
|
|
||||||
|
- name: Run tests with pytest
|
||||||
|
env:
|
||||||
|
CI: true # Mark as CI environment to skip memory-intensive tests
|
||||||
|
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||||
|
HF_HUB_DISABLE_SYMLINKS: 1
|
||||||
|
TOKENIZERS_PARALLELISM: false
|
||||||
|
PYTORCH_ENABLE_MPS_FALLBACK: 0 # Disable MPS on macOS CI to avoid memory issues
|
||||||
|
OMP_NUM_THREADS: 1 # Disable OpenMP parallelism to avoid libomp crashes
|
||||||
|
MKL_NUM_THREADS: 1 # Single thread for MKL operations
|
||||||
|
run: |
|
||||||
|
# Activate virtual environment
|
||||||
|
source .venv/bin/activate || source .venv/Scripts/activate
|
||||||
|
|
||||||
|
# Run all tests
|
||||||
|
pytest tests/
|
||||||
|
|
||||||
|
- name: Run sanity checks (optional)
|
||||||
|
run: |
|
||||||
|
# Activate virtual environment
|
||||||
|
source .venv/bin/activate || source .venv/Scripts/activate
|
||||||
|
|
||||||
|
# Run distance function tests if available
|
||||||
|
if [ -f test/sanity_checks/test_distance_functions.py ]; then
|
||||||
|
echo "Running distance function sanity checks..."
|
||||||
|
python test/sanity_checks/test_distance_functions.py || echo "⚠️ Distance function test failed, continuing..."
|
||||||
|
fi
|
||||||
|
|
||||||
|
- name: Upload artifacts
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: packages-${{ matrix.os }}-py${{ matrix.python }}
|
||||||
|
path: packages/*/dist/
|
||||||
19
.github/workflows/link-check.yml
vendored
Normal file
19
.github/workflows/link-check.yml
vendored
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
name: Link Check
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ main, master ]
|
||||||
|
pull_request:
|
||||||
|
schedule:
|
||||||
|
- cron: "0 3 * * 1"
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
link-check:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: lycheeverse/lychee-action@v2
|
||||||
|
with:
|
||||||
|
args: --no-progress --insecure README.md docs/ apps/ examples/ benchmarks/
|
||||||
|
env:
|
||||||
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
129
.github/workflows/release-manual.yml
vendored
Normal file
129
.github/workflows/release-manual.yml
vendored
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
name: Release
|
||||||
|
|
||||||
|
on:
|
||||||
|
workflow_dispatch:
|
||||||
|
inputs:
|
||||||
|
version:
|
||||||
|
description: 'Version to release (e.g., 0.1.2)'
|
||||||
|
required: true
|
||||||
|
type: string
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
update-version:
|
||||||
|
name: Update Version
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
permissions:
|
||||||
|
contents: write
|
||||||
|
outputs:
|
||||||
|
commit-sha: ${{ steps.push.outputs.commit-sha }}
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Validate version
|
||||||
|
run: |
|
||||||
|
# Remove 'v' prefix if present for validation
|
||||||
|
VERSION_CLEAN="${{ inputs.version }}"
|
||||||
|
VERSION_CLEAN="${VERSION_CLEAN#v}"
|
||||||
|
if ! [[ "$VERSION_CLEAN" =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
|
||||||
|
echo "❌ Invalid version format. Expected format: X.Y.Z or vX.Y.Z"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
echo "✅ Version format valid: ${{ inputs.version }}"
|
||||||
|
|
||||||
|
- name: Update versions and push
|
||||||
|
id: push
|
||||||
|
run: |
|
||||||
|
# Check current version
|
||||||
|
CURRENT_VERSION=$(grep "^version" packages/leann-core/pyproject.toml | cut -d'"' -f2)
|
||||||
|
echo "Current version: $CURRENT_VERSION"
|
||||||
|
echo "Target version: ${{ inputs.version }}"
|
||||||
|
|
||||||
|
if [ "$CURRENT_VERSION" = "${{ inputs.version }}" ]; then
|
||||||
|
echo "⚠️ Version is already ${{ inputs.version }}, skipping update"
|
||||||
|
COMMIT_SHA=$(git rev-parse HEAD)
|
||||||
|
else
|
||||||
|
./scripts/bump_version.sh ${{ inputs.version }}
|
||||||
|
git config user.name "GitHub Actions"
|
||||||
|
git config user.email "actions@github.com"
|
||||||
|
git add packages/*/pyproject.toml
|
||||||
|
git commit -m "chore: release v${{ inputs.version }}"
|
||||||
|
git push origin main
|
||||||
|
COMMIT_SHA=$(git rev-parse HEAD)
|
||||||
|
echo "✅ Pushed version update: $COMMIT_SHA"
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "commit-sha=$COMMIT_SHA" >> $GITHUB_OUTPUT
|
||||||
|
|
||||||
|
build-packages:
|
||||||
|
name: Build packages
|
||||||
|
needs: update-version
|
||||||
|
uses: ./.github/workflows/build-reusable.yml
|
||||||
|
with:
|
||||||
|
ref: 'main'
|
||||||
|
|
||||||
|
publish:
|
||||||
|
name: Publish and Release
|
||||||
|
needs: [update-version, build-packages]
|
||||||
|
if: always() && needs.update-version.result == 'success' && needs.build-packages.result == 'success'
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
permissions:
|
||||||
|
contents: write
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
ref: 'main'
|
||||||
|
|
||||||
|
- name: Download all artifacts
|
||||||
|
uses: actions/download-artifact@v4
|
||||||
|
with:
|
||||||
|
path: dist-artifacts
|
||||||
|
|
||||||
|
- name: Collect packages
|
||||||
|
run: |
|
||||||
|
mkdir -p dist
|
||||||
|
find dist-artifacts -name "*.whl" -exec cp {} dist/ \;
|
||||||
|
find dist-artifacts -name "*.tar.gz" -exec cp {} dist/ \;
|
||||||
|
|
||||||
|
echo "📦 Packages to publish:"
|
||||||
|
ls -la dist/
|
||||||
|
|
||||||
|
- name: Publish to PyPI
|
||||||
|
env:
|
||||||
|
TWINE_USERNAME: __token__
|
||||||
|
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
|
||||||
|
run: |
|
||||||
|
if [ -z "$TWINE_PASSWORD" ]; then
|
||||||
|
echo "❌ PYPI_API_TOKEN not configured!"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
pip install twine
|
||||||
|
twine upload dist/* --skip-existing --verbose
|
||||||
|
|
||||||
|
echo "✅ Published to PyPI!"
|
||||||
|
|
||||||
|
- name: Create release
|
||||||
|
run: |
|
||||||
|
# Check if tag already exists
|
||||||
|
if git rev-parse "v${{ inputs.version }}" >/dev/null 2>&1; then
|
||||||
|
echo "⚠️ Tag v${{ inputs.version }} already exists, skipping tag creation"
|
||||||
|
else
|
||||||
|
git tag "v${{ inputs.version }}"
|
||||||
|
git push origin "v${{ inputs.version }}"
|
||||||
|
echo "✅ Created and pushed tag v${{ inputs.version }}"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Check if release already exists
|
||||||
|
if gh release view "v${{ inputs.version }}" >/dev/null 2>&1; then
|
||||||
|
echo "⚠️ Release v${{ inputs.version }} already exists, skipping release creation"
|
||||||
|
else
|
||||||
|
gh release create "v${{ inputs.version }}" \
|
||||||
|
--title "Release v${{ inputs.version }}" \
|
||||||
|
--notes "🚀 Released to PyPI: https://pypi.org/project/leann/${{ inputs.version }}/" \
|
||||||
|
--latest
|
||||||
|
echo "✅ Created GitHub release v${{ inputs.version }}"
|
||||||
|
fi
|
||||||
|
env:
|
||||||
|
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
24
.gitignore
vendored
24
.gitignore
vendored
@@ -8,11 +8,16 @@ demo/indices/
|
|||||||
*pycache*
|
*pycache*
|
||||||
outputs/
|
outputs/
|
||||||
*.pkl
|
*.pkl
|
||||||
|
*.pdf
|
||||||
|
*.idx
|
||||||
|
*.map
|
||||||
.history/
|
.history/
|
||||||
scripts/
|
|
||||||
lm_eval.egg-info/
|
lm_eval.egg-info/
|
||||||
demo/experiment_results/**/*.json
|
demo/experiment_results/**/*.json
|
||||||
*.jsonl
|
*.jsonl
|
||||||
|
*.eml
|
||||||
|
*.emlx
|
||||||
|
*.json
|
||||||
*.sh
|
*.sh
|
||||||
*.txt
|
*.txt
|
||||||
!CMakeLists.txt
|
!CMakeLists.txt
|
||||||
@@ -29,6 +34,15 @@ build/
|
|||||||
nprobe_logs/
|
nprobe_logs/
|
||||||
micro/results
|
micro/results
|
||||||
micro/contriever-INT8
|
micro/contriever-INT8
|
||||||
|
data/*
|
||||||
|
!data/2501.14312v1 (1).pdf
|
||||||
|
!data/2506.08276v1.pdf
|
||||||
|
!data/PrideandPrejudice.txt
|
||||||
|
!data/huawei_pangu.md
|
||||||
|
!data/ground_truth/
|
||||||
|
!data/indices/
|
||||||
|
!data/queries/
|
||||||
|
!data/.gitattributes
|
||||||
*.qdstrm
|
*.qdstrm
|
||||||
benchmark_results/
|
benchmark_results/
|
||||||
results/
|
results/
|
||||||
@@ -41,6 +55,7 @@ embedding_comparison_results/
|
|||||||
*.ivecs
|
*.ivecs
|
||||||
*.index
|
*.index
|
||||||
*.bin
|
*.bin
|
||||||
|
*.old
|
||||||
|
|
||||||
read_graph
|
read_graph
|
||||||
analyze_diskann_graph
|
analyze_diskann_graph
|
||||||
@@ -70,3 +85,10 @@ test_indices*/
|
|||||||
test_*.py
|
test_*.py
|
||||||
!tests/**
|
!tests/**
|
||||||
packages/leann-backend-diskann/third_party/DiskANN/_deps/
|
packages/leann-backend-diskann/third_party/DiskANN/_deps/
|
||||||
|
|
||||||
|
*.meta.json
|
||||||
|
*.passages.json
|
||||||
|
|
||||||
|
batchtest.py
|
||||||
|
tests/__pytest_cache__/
|
||||||
|
tests/__pycache__/
|
||||||
|
|||||||
14
.gitmodules
vendored
14
.gitmodules
vendored
@@ -1,6 +1,16 @@
|
|||||||
[submodule "packages/leann-backend-diskann/third_party/DiskANN"]
|
[submodule "packages/leann-backend-diskann/third_party/DiskANN"]
|
||||||
path = packages/leann-backend-diskann/third_party/DiskANN
|
path = packages/leann-backend-diskann/third_party/DiskANN
|
||||||
url = https://github.com/yichuan520030910320/DiskANN.git
|
url = https://github.com/yichuan-w/DiskANN.git
|
||||||
[submodule "packages/leann-backend-hnsw/third_party/faiss"]
|
[submodule "packages/leann-backend-hnsw/third_party/faiss"]
|
||||||
path = packages/leann-backend-hnsw/third_party/faiss
|
path = packages/leann-backend-hnsw/third_party/faiss
|
||||||
url = https://github.com/yichuan520030910320/faiss.git
|
url = https://github.com/yichuan-w/faiss.git
|
||||||
|
[submodule "packages/leann-backend-hnsw/third_party/msgpack-c"]
|
||||||
|
path = packages/leann-backend-hnsw/third_party/msgpack-c
|
||||||
|
url = https://github.com/msgpack/msgpack-c.git
|
||||||
|
branch = cpp_master
|
||||||
|
[submodule "packages/leann-backend-hnsw/third_party/cppzmq"]
|
||||||
|
path = packages/leann-backend-hnsw/third_party/cppzmq
|
||||||
|
url = https://github.com/zeromq/cppzmq.git
|
||||||
|
[submodule "packages/leann-backend-hnsw/third_party/libzmq"]
|
||||||
|
path = packages/leann-backend-hnsw/third_party/libzmq
|
||||||
|
url = https://github.com/zeromq/libzmq.git
|
||||||
|
|||||||
16
.pre-commit-config.yaml
Normal file
16
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
repos:
|
||||||
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
|
rev: v4.5.0
|
||||||
|
hooks:
|
||||||
|
- id: trailing-whitespace
|
||||||
|
- id: end-of-file-fixer
|
||||||
|
- id: check-yaml
|
||||||
|
- id: check-added-large-files
|
||||||
|
- id: check-merge-conflict
|
||||||
|
- id: debug-statements
|
||||||
|
|
||||||
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
|
rev: v0.2.1
|
||||||
|
hooks:
|
||||||
|
- id: ruff
|
||||||
|
- id: ruff-format
|
||||||
9
.vscode/extensions.json
vendored
9
.vscode/extensions.json
vendored
@@ -1,9 +0,0 @@
|
|||||||
{
|
|
||||||
"recommendations": [
|
|
||||||
"llvm-vs-code-extensions.vscode-clangd",
|
|
||||||
"ms-python.python",
|
|
||||||
"ms-vscode.cmake-tools",
|
|
||||||
"vadimcn.vscode-lldb",
|
|
||||||
"eamodio.gitlens",
|
|
||||||
]
|
|
||||||
}
|
|
||||||
283
.vscode/launch.json
vendored
283
.vscode/launch.json
vendored
@@ -1,283 +0,0 @@
|
|||||||
{
|
|
||||||
// Use IntelliSense to learn about possible attributes.
|
|
||||||
// Hover to view descriptions of existing attributes.
|
|
||||||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
|
||||||
"version": "0.2.0",
|
|
||||||
"configurations": [
|
|
||||||
// new emdedder
|
|
||||||
{
|
|
||||||
"name": "New Embedder",
|
|
||||||
"type": "debugpy",
|
|
||||||
"request": "launch",
|
|
||||||
"program": "demo/main.py",
|
|
||||||
"console": "integratedTerminal",
|
|
||||||
"args": [
|
|
||||||
"--search",
|
|
||||||
"--use-original",
|
|
||||||
"--domain",
|
|
||||||
"dpr",
|
|
||||||
"--nprobe",
|
|
||||||
"5000",
|
|
||||||
"--load",
|
|
||||||
"flat",
|
|
||||||
"--embedder",
|
|
||||||
"intfloat/multilingual-e5-small"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
//python /home/ubuntu/Power-RAG/faiss/demo/simple_build.py
|
|
||||||
{
|
|
||||||
"name": "main.py",
|
|
||||||
"type": "debugpy",
|
|
||||||
"request": "launch",
|
|
||||||
"program": "demo/main.py",
|
|
||||||
"console": "integratedTerminal",
|
|
||||||
"cwd": "${workspaceFolder}",
|
|
||||||
"args": [
|
|
||||||
"--query",
|
|
||||||
"1000",
|
|
||||||
"--load",
|
|
||||||
"bm25"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Simple Build",
|
|
||||||
"type": "lldb",
|
|
||||||
"request": "launch",
|
|
||||||
"program": "${workspaceFolder}/.venv/bin/python",
|
|
||||||
"console": "integratedTerminal",
|
|
||||||
"cwd": "${workspaceFolder}",
|
|
||||||
"args": [
|
|
||||||
"faiss/demo/simple_build.py"
|
|
||||||
],
|
|
||||||
"env": {
|
|
||||||
"LD_PRELOAD": "/lib/x86_64-linux-gnu/libmkl_core.so:/lib/x86_64-linux-gnu/libmkl_intel_thread.so:/lib/x86_64-linux-gnu/libmkl_intel_lp64.so:/lib/x86_64-linux-gnu/libiomp5.so"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
//# Fix for Intel MKL error
|
|
||||||
//export LD_PRELOAD=/lib/x86_64-linux-gnu/libmkl_core.so:/lib/x86_64-linux-gnu/libmkl_intel_thread.so:/lib/x86_64-linux-gnu/libmkl_intel_lp64.so:/lib/x86_64-linux-gnu/libiomp5.so
|
|
||||||
//python faiss/demo/build_demo.py
|
|
||||||
{
|
|
||||||
"name": "Build Demo",
|
|
||||||
"type": "lldb",
|
|
||||||
"request": "launch",
|
|
||||||
"program": "${workspaceFolder}/.venv/bin/python",
|
|
||||||
"console": "integratedTerminal",
|
|
||||||
"cwd": "${workspaceFolder}",
|
|
||||||
"args": [
|
|
||||||
"faiss/demo/build_demo.py"
|
|
||||||
],
|
|
||||||
"env": {
|
|
||||||
"LD_PRELOAD": "/lib/x86_64-linux-gnu/libmkl_core.so:/lib/x86_64-linux-gnu/libmkl_intel_thread.so:/lib/x86_64-linux-gnu/libmkl_intel_lp64.so:/lib/x86_64-linux-gnu/libiomp5.so"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "DiskANN Serve",
|
|
||||||
"type": "lldb",
|
|
||||||
"request": "launch",
|
|
||||||
"program": "${workspaceFolder}/.venv/bin/python",
|
|
||||||
"console": "integratedTerminal",
|
|
||||||
"cwd": "${workspaceFolder}",
|
|
||||||
"args": [
|
|
||||||
"demo/main.py",
|
|
||||||
"--mode",
|
|
||||||
"serve",
|
|
||||||
"--engine",
|
|
||||||
"sglang",
|
|
||||||
"--load-indices",
|
|
||||||
"diskann",
|
|
||||||
"--domain",
|
|
||||||
"rpj_wiki",
|
|
||||||
"--lazy-load",
|
|
||||||
"--recompute-beighbor-embeddings",
|
|
||||||
"--port",
|
|
||||||
"8082",
|
|
||||||
"--diskann-search-memory-maximum",
|
|
||||||
"2",
|
|
||||||
"--diskann-graph",
|
|
||||||
"240",
|
|
||||||
"--search-only"
|
|
||||||
],
|
|
||||||
"env": {
|
|
||||||
"PYTHONPATH": "${workspaceFolder}/faiss_repo/build/faiss/python:$PYTHONPATH"
|
|
||||||
},
|
|
||||||
"preLaunchTask": "CMake: build",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "DiskANN Serve MAC",
|
|
||||||
"type": "lldb",
|
|
||||||
"request": "launch",
|
|
||||||
"program": "${workspaceFolder}/.venv/bin/python",
|
|
||||||
"console": "integratedTerminal",
|
|
||||||
"cwd": "${workspaceFolder}",
|
|
||||||
"args": [
|
|
||||||
"demo/main.py",
|
|
||||||
"--mode",
|
|
||||||
"serve",
|
|
||||||
"--engine",
|
|
||||||
"ollama",
|
|
||||||
"--load-indices",
|
|
||||||
"diskann",
|
|
||||||
"--domain",
|
|
||||||
"rpj_wiki",
|
|
||||||
"--lazy-load",
|
|
||||||
"--recompute-beighbor-embeddings"
|
|
||||||
],
|
|
||||||
"preLaunchTask": "CMake: build",
|
|
||||||
"env": {
|
|
||||||
"KMP_DUPLICATE_LIB_OK": "TRUE",
|
|
||||||
"OMP_NUM_THREADS": "1",
|
|
||||||
"MKL_NUM_THREADS": "1",
|
|
||||||
"DYLD_INSERT_LIBRARIES": "/Users/ec2-user/Power-RAG/.venv/lib/python3.10/site-packages/torch/lib/libomp.dylib",
|
|
||||||
"KMP_BLOCKTIME": "0"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Python Debugger: Current File with Arguments",
|
|
||||||
"type": "debugpy",
|
|
||||||
"request": "launch",
|
|
||||||
"program": "ric/main_ric.py",
|
|
||||||
"console": "integratedTerminal",
|
|
||||||
"cwd": "${workspaceFolder}",
|
|
||||||
"args": [
|
|
||||||
"--config-name",
|
|
||||||
"${input:configSelection}"
|
|
||||||
],
|
|
||||||
"justMyCode": false
|
|
||||||
},
|
|
||||||
//python ./demo/validate_equivalence.py sglang
|
|
||||||
{
|
|
||||||
"name": "Validate Equivalence",
|
|
||||||
"type": "debugpy",
|
|
||||||
"request": "launch",
|
|
||||||
"program": "demo/validate_equivalence.py",
|
|
||||||
"console": "integratedTerminal",
|
|
||||||
"args": [
|
|
||||||
"sglang"
|
|
||||||
],
|
|
||||||
},
|
|
||||||
//python demo/retrieval_demo.py --engine sglang --skip-embeddings --domain dpr --load-indices flat ivf_flat
|
|
||||||
{
|
|
||||||
"name": "Retrieval Demo",
|
|
||||||
"type": "debugpy",
|
|
||||||
"request": "launch",
|
|
||||||
"program": "demo/retrieval_demo.py",
|
|
||||||
"console": "integratedTerminal",
|
|
||||||
"args": [
|
|
||||||
"--engine",
|
|
||||||
"vllm",
|
|
||||||
"--skip-embeddings",
|
|
||||||
"--domain",
|
|
||||||
"dpr",
|
|
||||||
"--load-indices",
|
|
||||||
// "flat",
|
|
||||||
"ivf_flat"
|
|
||||||
],
|
|
||||||
},
|
|
||||||
//python demo/retrieval_demo.py --engine sglang --skip-embeddings --domain dpr --load-indices diskann --hnsw-M 64 --hnsw-efConstruction 150 --hnsw-efSearch 128 --hnsw-sq-bits 8
|
|
||||||
{
|
|
||||||
"name": "Retrieval Demo DiskANN",
|
|
||||||
"type": "debugpy",
|
|
||||||
"request": "launch",
|
|
||||||
"program": "demo/retrieval_demo.py",
|
|
||||||
"console": "integratedTerminal",
|
|
||||||
"args": [
|
|
||||||
"--engine",
|
|
||||||
"sglang",
|
|
||||||
"--skip-embeddings",
|
|
||||||
"--domain",
|
|
||||||
"dpr",
|
|
||||||
"--load-indices",
|
|
||||||
"diskann",
|
|
||||||
"--hnsw-M",
|
|
||||||
"64",
|
|
||||||
"--hnsw-efConstruction",
|
|
||||||
"150",
|
|
||||||
"--hnsw-efSearch",
|
|
||||||
"128",
|
|
||||||
"--hnsw-sq-bits",
|
|
||||||
"8"
|
|
||||||
],
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Find Probe",
|
|
||||||
"type": "debugpy",
|
|
||||||
"request": "launch",
|
|
||||||
"program": "find_probe.py",
|
|
||||||
"console": "integratedTerminal",
|
|
||||||
"cwd": "${workspaceFolder}",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Python: Attach",
|
|
||||||
"type": "debugpy",
|
|
||||||
"request": "attach",
|
|
||||||
"processId": "${command:pickProcess}",
|
|
||||||
"justMyCode": true
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Edge RAG",
|
|
||||||
"type": "lldb",
|
|
||||||
"request": "launch",
|
|
||||||
"program": "${workspaceFolder}/.venv/bin/python",
|
|
||||||
"console": "integratedTerminal",
|
|
||||||
"cwd": "${workspaceFolder}",
|
|
||||||
"args": [
|
|
||||||
"edgerag_demo.py"
|
|
||||||
],
|
|
||||||
"env": {
|
|
||||||
"LD_PRELOAD": "/lib/x86_64-linux-gnu/libiomp5.so /lib/x86_64-linux-gnu/libmkl_core.so /lib/x86_64-linux-gnu/libmkl_intel_lp64.so /lib/x86_64-linux-gnu/libmkl_intel_thread.so",
|
|
||||||
"MKL_NUM_THREADS": "1",
|
|
||||||
"OMP_NUM_THREADS": "1",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Launch Embedding Server",
|
|
||||||
"type": "debugpy",
|
|
||||||
"request": "launch",
|
|
||||||
"program": "demo/embedding_server.py",
|
|
||||||
"console": "integratedTerminal",
|
|
||||||
"cwd": "${workspaceFolder}",
|
|
||||||
"args": [
|
|
||||||
"--domain",
|
|
||||||
"rpj_wiki",
|
|
||||||
"--zmq-port",
|
|
||||||
"5556",
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "HNSW Serve",
|
|
||||||
"type": "lldb",
|
|
||||||
"request": "launch",
|
|
||||||
"program": "${workspaceFolder}/.venv/bin/python",
|
|
||||||
"console": "integratedTerminal",
|
|
||||||
"cwd": "${workspaceFolder}",
|
|
||||||
"args": [
|
|
||||||
"demo/main.py",
|
|
||||||
"--domain",
|
|
||||||
"rpj_wiki",
|
|
||||||
"--load",
|
|
||||||
"hnsw",
|
|
||||||
"--mode",
|
|
||||||
"serve",
|
|
||||||
"--search",
|
|
||||||
"--skip-pa",
|
|
||||||
"--recompute",
|
|
||||||
"--hnsw-old"
|
|
||||||
],
|
|
||||||
"env": {
|
|
||||||
"LD_PRELOAD": "/lib/x86_64-linux-gnu/libmkl_core.so:/lib/x86_64-linux-gnu/libmkl_intel_thread.so:/lib/x86_64-linux-gnu/libmkl_intel_lp64.so:/lib/x86_64-linux-gnu/libiomp5.so"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"inputs": [
|
|
||||||
{
|
|
||||||
"id": "configSelection",
|
|
||||||
"type": "pickString",
|
|
||||||
"description": "Select a configuration",
|
|
||||||
"options": [
|
|
||||||
"example_config",
|
|
||||||
"vllm_gritlm"
|
|
||||||
],
|
|
||||||
"default": "example_config"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
43
.vscode/settings.json
vendored
43
.vscode/settings.json
vendored
@@ -1,43 +0,0 @@
|
|||||||
{
|
|
||||||
"python.analysis.extraPaths": [
|
|
||||||
"./sglang_repo/python"
|
|
||||||
],
|
|
||||||
"cmake.sourceDirectory": "${workspaceFolder}/DiskANN",
|
|
||||||
"cmake.configureArgs": [
|
|
||||||
"-DPYBIND=True",
|
|
||||||
"-DUPDATE_EDITABLE_INSTALL=ON",
|
|
||||||
],
|
|
||||||
"cmake.environment": {
|
|
||||||
"PATH": "/Users/ec2-user/Power-RAG/.venv/bin:${env:PATH}"
|
|
||||||
},
|
|
||||||
"cmake.buildDirectory": "${workspaceFolder}/build",
|
|
||||||
"files.associations": {
|
|
||||||
"*.tcc": "cpp",
|
|
||||||
"deque": "cpp",
|
|
||||||
"string": "cpp",
|
|
||||||
"unordered_map": "cpp",
|
|
||||||
"vector": "cpp",
|
|
||||||
"map": "cpp",
|
|
||||||
"unordered_set": "cpp",
|
|
||||||
"atomic": "cpp",
|
|
||||||
"inplace_vector": "cpp",
|
|
||||||
"*.ipp": "cpp",
|
|
||||||
"forward_list": "cpp",
|
|
||||||
"list": "cpp",
|
|
||||||
"any": "cpp",
|
|
||||||
"system_error": "cpp",
|
|
||||||
"__hash_table": "cpp",
|
|
||||||
"__split_buffer": "cpp",
|
|
||||||
"__tree": "cpp",
|
|
||||||
"ios": "cpp",
|
|
||||||
"set": "cpp",
|
|
||||||
"__string": "cpp",
|
|
||||||
"string_view": "cpp",
|
|
||||||
"ranges": "cpp",
|
|
||||||
"iosfwd": "cpp"
|
|
||||||
},
|
|
||||||
"lldb.displayFormat": "auto",
|
|
||||||
"lldb.showDisassembly": "auto",
|
|
||||||
"lldb.dereferencePointers": true,
|
|
||||||
"lldb.consoleMode": "commands",
|
|
||||||
}
|
|
||||||
16
.vscode/tasks.json
vendored
16
.vscode/tasks.json
vendored
@@ -1,16 +0,0 @@
|
|||||||
{
|
|
||||||
"version": "2.0.0",
|
|
||||||
"tasks": [
|
|
||||||
{
|
|
||||||
"type": "cmake",
|
|
||||||
"label": "CMake: build",
|
|
||||||
"command": "build",
|
|
||||||
"targets": [
|
|
||||||
"all"
|
|
||||||
],
|
|
||||||
"group": "build",
|
|
||||||
"problemMatcher": [],
|
|
||||||
"detail": "CMake template build task"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
2
LICENSE
2
LICENSE
@@ -1,6 +1,6 @@
|
|||||||
MIT License
|
MIT License
|
||||||
|
|
||||||
Copyright (c) 2024 Rulin Shao
|
Copyright (c) 2025 LEANN Contributors
|
||||||
|
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
|||||||
735
README.md
735
README.md
@@ -1,171 +1,571 @@
|
|||||||
# 🚀 LEANN: A Low-Storage Vector Index
|
<p align="center">
|
||||||
|
<img src="assets/logo-text.png" alt="LEANN Logo" width="400">
|
||||||
|
</p>
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<img src="https://img.shields.io/badge/Python-3.9%2B-blue.svg" alt="Python 3.9+">
|
<img src="https://img.shields.io/badge/Python-3.9%2B-blue.svg" alt="Python 3.9+">
|
||||||
<img src="https://img.shields.io/badge/License-MIT-green.svg" alt="MIT License">
|
<img src="https://img.shields.io/badge/License-MIT-green.svg" alt="MIT License">
|
||||||
<img src="https://img.shields.io/badge/PRs-welcome-brightgreen.svg" alt="PRs Welcome">
|
<img src="https://img.shields.io/badge/Platform-Linux%20%7C%20macOS-lightgrey" alt="Platform">
|
||||||
<img src="https://img.shields.io/badge/Platform-Linux%20%7C%20macOS%20%7C%20Windows-lightgrey" alt="Platform">
|
<img src="https://img.shields.io/badge/MCP-Native%20Integration-blue?style=flat-square" alt="MCP Integration">
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
|
<h2 align="center" tabindex="-1" class="heading-element" dir="auto">
|
||||||
|
The smallest vector index in the world. RAG Everything with LEANN!
|
||||||
|
</h2>
|
||||||
|
|
||||||
|
LEANN is an innovative vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **97% less storage** than traditional solutions **without accuracy loss**.
|
||||||
|
|
||||||
|
LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration Fig →](#️-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
|
||||||
|
|
||||||
|
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can semantic search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)**, **[codebase](#-claude-code-integration-transform-your-development-workflow)**\* , or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
|
||||||
|
|
||||||
|
|
||||||
|
\* Claude Code only supports basic `grep`-style keyword search. **LEANN** is a drop-in **semantic search MCP service fully compatible with Claude Code**, unlocking intelligent retrieval without changing your workflow. 🔥 Check out [the easy setup →](packages/leann-mcp/README.md)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## Why LEANN?
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<strong>⚡ Real-time embedding computation for large-scale RAG on consumer hardware</strong>
|
<img src="assets/effects.png" alt="LEANN vs Traditional Vector DB Storage Comparison" width="70%">
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
<p align="center">
|
> **The numbers speak for themselves:** Index 60 million text chunks in just 6GB instead of 201GB. From emails to browser history, everything fits on your laptop. [See detailed benchmarks for different applications below ↓](#storage-comparison)
|
||||||
<a href="#-quick-start">Quick Start</a> •
|
|
||||||
<a href="#-features">Features</a> •
|
|
||||||
<a href="#-benchmarks">Benchmarks</a> •
|
|
||||||
<a href="#-documentation">Documentation</a> •
|
|
||||||
<a href="#-paper">Paper</a>
|
|
||||||
</p>
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 🌟 What is Leann?
|
🔒 **Privacy:** Your data never leaves your laptop. No OpenAI, no cloud, no "terms of service".
|
||||||
|
|
||||||
**Leann** revolutionizes Retrieval-Augmented Generation (RAG) by eliminating the storage bottleneck of traditional vector databases. Instead of pre-computing and storing billions of embeddings, Leann dynamically computes embeddings at query time using highly optimized graph-based search algorithms.
|
🪶 **Lightweight:** Graph-based recomputation eliminates heavy embedding storage, while smart graph pruning and CSR format minimize graph storage overhead. Always less storage, less memory usage!
|
||||||
|
|
||||||
### 🎯 Why Leann?
|
📦 **Portable:** Transfer your entire knowledge base between devices (even with others) with minimal cost - your personal AI memory travels with you.
|
||||||
|
|
||||||
Traditional RAG systems face a fundamental trade-off:
|
📈 **Scalability:** Handle messy personal data that would crash traditional vector DBs, easily managing your growing personalized data and agent generated memory!
|
||||||
- **💾 Storage**: Storing embeddings for millions of documents requires massive disk space
|
|
||||||
- **🔄 Freshness**: Pre-computed embeddings become stale when documents change
|
|
||||||
- **💰 Cost**: Vector databases are expensive to scale
|
|
||||||
|
|
||||||
**Leann solves this by:**
|
✨ **No Accuracy Loss:** Maintain the same search quality as heavyweight solutions while using 97% less storage.
|
||||||
- ✅ **Zero embedding storage** - Only graph structure is persisted
|
|
||||||
- ✅ **Real-time computation** - Embeddings computed on-demand with ms latency
|
|
||||||
- ✅ **Memory efficient** - Runs on consumer hardware (8GB RAM)
|
|
||||||
- ✅ **Always fresh** - No stale embeddings, ever
|
|
||||||
|
|
||||||
## 🚀 Quick Start
|
## Installation
|
||||||
|
|
||||||
### Installation
|
### 📦 Prerequisites: Install uv
|
||||||
|
|
||||||
|
[Install uv](https://docs.astral.sh/uv/getting-started/installation/#installation-methods) first if you don't have it. Typically, you can install it with:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
git clone https://github.com/yichuan520030910320/Power-RAG.git leann
|
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||||
|
```
|
||||||
|
|
||||||
|
### 🚀 Quick Install
|
||||||
|
|
||||||
|
Clone the repository to access all examples and try amazing applications,
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/yichuan-w/LEANN.git leann
|
||||||
cd leann
|
cd leann
|
||||||
|
```
|
||||||
|
|
||||||
|
and install LEANN from [PyPI](https://pypi.org/project/leann/) to run them immediately:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv venv
|
||||||
|
source .venv/bin/activate
|
||||||
|
uv pip install leann
|
||||||
|
```
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>
|
||||||
|
<strong>🔧 Build from Source (Recommended for development)</strong>
|
||||||
|
</summary>
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/yichuan-w/LEANN.git leann
|
||||||
|
cd leann
|
||||||
|
git submodule update --init --recursive
|
||||||
|
```
|
||||||
|
|
||||||
|
**macOS:**
|
||||||
|
```bash
|
||||||
|
brew install llvm libomp boost protobuf zeromq pkgconf
|
||||||
|
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv sync
|
||||||
|
```
|
||||||
|
|
||||||
|
**Linux:**
|
||||||
|
```bash
|
||||||
|
sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler libabsl-dev libmkl-full-dev libaio-dev libzmq3-dev
|
||||||
uv sync
|
uv sync
|
||||||
```
|
```
|
||||||
|
|
||||||
### 30-Second Example
|
</details>
|
||||||
|
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
Our declarative API makes RAG as easy as writing a config file.
|
||||||
|
|
||||||
|
Check out [demo.ipynb](demo.ipynb) or [](https://colab.research.google.com/github/yichuan-w/LEANN/blob/main/demo.ipynb)
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from leann.api import LeannBuilder, LeannSearcher
|
from leann import LeannBuilder, LeannSearcher, LeannChat
|
||||||
|
from pathlib import Path
|
||||||
|
INDEX_PATH = str(Path("./").resolve() / "demo.leann")
|
||||||
|
|
||||||
# 1. Build index (no embeddings stored!)
|
# Build an index
|
||||||
builder = LeannBuilder(backend_name="diskann")
|
builder = LeannBuilder(backend_name="hnsw")
|
||||||
builder.add_text("Python is a powerful programming language")
|
builder.add_text("LEANN saves 97% storage compared to traditional vector databases.")
|
||||||
builder.add_text("Machine learning transforms industries")
|
builder.add_text("Tung Tung Tung Sahur called—they need their banana‑crocodile hybrid back")
|
||||||
builder.add_text("Neural networks process complex data")
|
builder.build_index(INDEX_PATH)
|
||||||
builder.build_index("knowledge.leann")
|
|
||||||
|
|
||||||
# 2. Search with real-time embeddings
|
# Search
|
||||||
searcher = LeannSearcher("knowledge.leann")
|
searcher = LeannSearcher(INDEX_PATH)
|
||||||
results = searcher.search("programming languages", top_k=2)
|
results = searcher.search("fantastical AI-generated creatures", top_k=1)
|
||||||
|
|
||||||
for result in results:
|
# Chat with your data
|
||||||
print(f"Score: {result['score']:.3f} - {result['text']}")
|
chat = LeannChat(INDEX_PATH, llm_config={"type": "hf", "model": "Qwen/Qwen3-0.6B"})
|
||||||
|
response = chat.ask("How much storage does LEANN save?", top_k=1)
|
||||||
```
|
```
|
||||||
|
|
||||||
### Run the Demo
|
## RAG on Everything!
|
||||||
|
|
||||||
|
LEANN supports RAG on various data sources including documents (`.pdf`, `.txt`, `.md`), Apple Mail, Google Search History, WeChat, and more.
|
||||||
|
|
||||||
|
### Generation Model Setup
|
||||||
|
|
||||||
|
LEANN supports multiple LLM providers for text generation (OpenAI API, HuggingFace, Ollama).
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><strong>🔑 OpenAI API Setup (Default)</strong></summary>
|
||||||
|
|
||||||
|
Set your OpenAI API key as an environment variable:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
uv run examples/document_search.py
|
export OPENAI_API_KEY="your-api-key-here"
|
||||||
```
|
```
|
||||||
|
|
||||||
**PDF RAG Demo (using LlamaIndex for document parsing and Leann for indexing/search)**
|
</details>
|
||||||
|
|
||||||
This demo showcases how to build a RAG system for PDF documents using Leann.
|
<details>
|
||||||
1. Place your PDF files (and other supported formats like .docx, .pptx, .xlsx) into the `examples/data/` directory.
|
<summary><strong>🔧 Ollama Setup (Recommended for full privacy)</strong></summary>
|
||||||
2. Ensure you have an `OPENAI_API_KEY` set in your environment variables or in a `.env` file for the LLM to function.
|
|
||||||
|
**macOS:**
|
||||||
|
|
||||||
|
First, [download Ollama for macOS](https://ollama.com/download/mac).
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
uv run examples/main_cli_example.py
|
# Pull a lightweight model (recommended for consumer hardware)
|
||||||
|
ollama pull llama3.2:1b
|
||||||
```
|
```
|
||||||
|
|
||||||
## ✨ Features
|
**Linux:**
|
||||||
|
|
||||||
### 🔥 Core Features
|
```bash
|
||||||
- **📊 Multiple Distance Functions**: L2, Cosine, MIPS (Maximum Inner Product Search)
|
# Install Ollama
|
||||||
- **🏗️ Pluggable Backends**: DiskANN, HNSW/FAISS with unified API
|
curl -fsSL https://ollama.ai/install.sh | sh
|
||||||
- **🔄 Real-time Embeddings**: Dynamic computation using optimized ZMQ servers
|
|
||||||
- **📈 Scalable Architecture**: Handles millions of documents on consumer hardware
|
|
||||||
- **🎯 Graph Pruning**: Advanced techniques for memory-efficient search
|
|
||||||
|
|
||||||
### 🛠️ Technical Highlights
|
# Start Ollama service manually
|
||||||
- **Zero-copy operations** for maximum performance
|
ollama serve &
|
||||||
- **SIMD-optimized** distance computations (AVX2/AVX512)
|
|
||||||
- **Async embedding pipeline** with batched processing
|
|
||||||
- **Memory-mapped indices** for fast startup
|
|
||||||
- **Recompute mode** for highest accuracy scenarios
|
|
||||||
|
|
||||||
### 🎨 Developer Experience
|
# Pull a lightweight model (recommended for consumer hardware)
|
||||||
- **Simple Python API** - Get started in minutes
|
ollama pull llama3.2:1b
|
||||||
- **Extensible backend system** - Easy to add new algorithms
|
|
||||||
- **Comprehensive examples** - From basic usage to production deployment
|
|
||||||
- **Rich debugging tools** - Built-in performance profiling
|
|
||||||
|
|
||||||
## 📊 Benchmarks
|
|
||||||
|
|
||||||
### Memory Usage Comparison
|
|
||||||
|
|
||||||
| System | 1M Documents | 10M Documents | 100M Documents |
|
|
||||||
|--------|-------------|---------------|----------------|
|
|
||||||
| Traditional Vector DB | 3.1 GB | 31 GB | 310 GB |
|
|
||||||
| **Leann** | **180 MB** | **1.2 GB** | **8.4 GB** |
|
|
||||||
| **Reduction** | **94.2%** | **96.1%** | **97.3%** |
|
|
||||||
|
|
||||||
### Query Performance
|
|
||||||
|
|
||||||
| Backend | Index Size | Query Time | Recall@10 |
|
|
||||||
|---------|------------|------------|-----------|
|
|
||||||
| DiskANN | 1M docs | 12ms | 0.95 |
|
|
||||||
| DiskANN + Recompute | 1M docs | 145ms | 0.98 |
|
|
||||||
| HNSW | 1M docs | 8ms | 0.93 |
|
|
||||||
|
|
||||||
*Benchmarks run on AMD Ryzen 7 with 32GB RAM*
|
|
||||||
|
|
||||||
## 🏗️ Architecture
|
|
||||||
|
|
||||||
```
|
|
||||||
┌─────────────────┐ ┌──────────────────┐ ┌─────────────────┐
|
|
||||||
│ Query Text │───▶│ Embedding │───▶│ Graph-based │
|
|
||||||
│ │ │ Computation │ │ Search │
|
|
||||||
└─────────────────┘ └──────────────────┘ └─────────────────┘
|
|
||||||
│ │
|
|
||||||
▼ ▼
|
|
||||||
┌──────────────┐ ┌──────────────┐
|
|
||||||
│ ZMQ Server │ │ Pruned Graph │
|
|
||||||
│ (Cached) │ │ Index │
|
|
||||||
└──────────────┘ └──────────────┘
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Key Components
|
</details>
|
||||||
|
|
||||||
1. **🧠 Embedding Engine**: Real-time transformer inference with caching
|
### ⭐ Flexible Configuration
|
||||||
2. **📊 Graph Index**: Memory-efficient navigation structures
|
|
||||||
3. **🔄 Search Coordinator**: Orchestrates embedding + graph search
|
|
||||||
4. **⚡ Backend Adapters**: Pluggable algorithm implementations
|
|
||||||
|
|
||||||
## 🎓 Supported Models & Backends
|
LEANN provides flexible parameters for embedding models, search strategies, and data processing to fit your specific needs.
|
||||||
|
|
||||||
### 🤖 Embedding Models
|
📚 **Need configuration best practices?** Check our [Configuration Guide](docs/configuration-guide.md) for detailed optimization tips, model selection advice, and solutions to common issues like slow embeddings or poor search quality.
|
||||||
- **sentence-transformers/all-mpnet-base-v2** (default)
|
|
||||||
- **sentence-transformers/all-MiniLM-L6-v2** (lightweight)
|
|
||||||
- Any HuggingFace sentence-transformer model
|
|
||||||
- Custom model support via API
|
|
||||||
|
|
||||||
### 🔧 Search Backends
|
<details>
|
||||||
- **DiskANN**: Microsoft's billion-scale ANN algorithm
|
<summary><strong>📋 Click to expand: Common Parameters (Available in All Examples)</strong></summary>
|
||||||
- **HNSW**: Hierarchical Navigable Small World graphs
|
|
||||||
- **Coming soon**: ScaNN, Faiss-IVF, NGT
|
|
||||||
|
|
||||||
### 📏 Distance Functions
|
All RAG examples share these common parameters. **Interactive mode** is available in all examples - simply run without `--query` to start a continuous Q&A session where you can ask multiple questions. Type 'quit' to exit.
|
||||||
- **L2**: Euclidean distance for precise similarity
|
|
||||||
- **Cosine**: Angular similarity for normalized vectors
|
|
||||||
- **MIPS**: Maximum Inner Product Search for recommendation systems
|
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Core Parameters (General preprocessing for all examples)
|
||||||
|
--index-dir DIR # Directory to store the index (default: current directory)
|
||||||
|
--query "YOUR QUESTION" # Single query mode. Omit for interactive chat (type 'quit' to exit), and now you can play with your index interactively
|
||||||
|
--max-items N # Limit data preprocessing (default: -1, process all data)
|
||||||
|
--force-rebuild # Force rebuild index even if it exists
|
||||||
|
|
||||||
|
# Embedding Parameters
|
||||||
|
--embedding-model MODEL # e.g., facebook/contriever, text-embedding-3-small, nomic-embed-text, or mlx-community/multilingual-e5-base-mlx
|
||||||
|
--embedding-mode MODE # sentence-transformers, openai, mlx, or ollama
|
||||||
|
|
||||||
|
# LLM Parameters (Text generation models)
|
||||||
|
--llm TYPE # LLM backend: openai, ollama, or hf (default: openai)
|
||||||
|
--llm-model MODEL # Model name (default: gpt-4o) e.g., gpt-4o-mini, llama3.2:1b, Qwen/Qwen2.5-1.5B-Instruct
|
||||||
|
--thinking-budget LEVEL # Thinking budget for reasoning models: low/medium/high (supported by o3, o3-mini, GPT-Oss:20b, and other reasoning models)
|
||||||
|
|
||||||
|
# Search Parameters
|
||||||
|
--top-k N # Number of results to retrieve (default: 20)
|
||||||
|
--search-complexity N # Search complexity for graph traversal (default: 32)
|
||||||
|
|
||||||
|
# Chunking Parameters
|
||||||
|
--chunk-size N # Size of text chunks (default varies by source: 256 for most, 192 for WeChat)
|
||||||
|
--chunk-overlap N # Overlap between chunks (default varies: 25-128 depending on source)
|
||||||
|
|
||||||
|
# Index Building Parameters
|
||||||
|
--backend-name NAME # Backend to use: hnsw or diskann (default: hnsw)
|
||||||
|
--graph-degree N # Graph degree for index construction (default: 32)
|
||||||
|
--build-complexity N # Build complexity for index construction (default: 64)
|
||||||
|
--no-compact # Disable compact index storage (compact storage IS enabled to save storage by default)
|
||||||
|
--no-recompute # Disable embedding recomputation (recomputation IS enabled to save storage by default)
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
### 📄 Personal Data Manager: Process Any Documents (`.pdf`, `.txt`, `.md`)!
|
||||||
|
|
||||||
|
Ask questions directly about your personal PDFs, documents, and any directory containing your files!
|
||||||
|
|
||||||
|
<p align="center">
|
||||||
|
<img src="videos/paper_clear.gif" alt="LEANN Document Search Demo" width="600">
|
||||||
|
</p>
|
||||||
|
|
||||||
|
The example below asks a question about summarizing our paper (uses default data in `data/`, which is a directory with diverse data sources: two papers, Pride and Prejudice, and a Technical report about LLM in Huawei in Chinese), and this is the **easiest example** to run here:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
source .venv/bin/activate # Don't forget to activate the virtual environment
|
||||||
|
python -m apps.document_rag --query "What are the main techniques LEANN explores?"
|
||||||
|
```
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><strong>📋 Click to expand: Document-Specific Arguments</strong></summary>
|
||||||
|
|
||||||
|
#### Parameters
|
||||||
|
```bash
|
||||||
|
--data-dir DIR # Directory containing documents to process (default: data)
|
||||||
|
--file-types .ext .ext # Filter by specific file types (optional - all LlamaIndex supported types if omitted)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Example Commands
|
||||||
|
```bash
|
||||||
|
# Process all documents with larger chunks for academic papers
|
||||||
|
python -m apps.document_rag --data-dir "~/Documents/Papers" --chunk-size 1024
|
||||||
|
|
||||||
|
# Filter only markdown and Python files with smaller chunks
|
||||||
|
python -m apps.document_rag --data-dir "./docs" --chunk-size 256 --file-types .md .py
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
### 📧 Your Personal Email Secretary: RAG on Apple Mail!
|
||||||
|
|
||||||
|
> **Note:** The examples below currently support macOS only. Windows support coming soon.
|
||||||
|
|
||||||
|
|
||||||
|
<p align="center">
|
||||||
|
<img src="videos/mail_clear.gif" alt="LEANN Email Search Demo" width="600">
|
||||||
|
</p>
|
||||||
|
|
||||||
|
Before running the example below, you need to grant full disk access to your terminal/VS Code in System Preferences → Privacy & Security → Full Disk Access.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m apps.email_rag --query "What's the food I ordered by DoorDash or Uber Eats mostly?"
|
||||||
|
```
|
||||||
|
**780K email chunks → 78MB storage.** Finally, search your email like you search Google.
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><strong>📋 Click to expand: Email-Specific Arguments</strong></summary>
|
||||||
|
|
||||||
|
#### Parameters
|
||||||
|
```bash
|
||||||
|
--mail-path PATH # Path to specific mail directory (auto-detects if omitted)
|
||||||
|
--include-html # Include HTML content in processing (useful for newsletters)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Example Commands
|
||||||
|
```bash
|
||||||
|
# Search work emails from a specific account
|
||||||
|
python -m apps.email_rag --mail-path "~/Library/Mail/V10/WORK_ACCOUNT"
|
||||||
|
|
||||||
|
# Find all receipts and order confirmations (includes HTML)
|
||||||
|
python -m apps.email_rag --query "receipt order confirmation invoice" --include-html
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><strong>📋 Click to expand: Example queries you can try</strong></summary>
|
||||||
|
|
||||||
|
Once the index is built, you can ask questions like:
|
||||||
|
- "Find emails from my boss about deadlines"
|
||||||
|
- "What did John say about the project timeline?"
|
||||||
|
- "Show me emails about travel expenses"
|
||||||
|
</details>
|
||||||
|
|
||||||
|
### 🔍 Time Machine for the Web: RAG Your Entire Chrome Browser History!
|
||||||
|
|
||||||
|
<p align="center">
|
||||||
|
<img src="videos/google_clear.gif" alt="LEANN Browser History Search Demo" width="600">
|
||||||
|
</p>
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m apps.browser_rag --query "Tell me my browser history about machine learning?"
|
||||||
|
```
|
||||||
|
**38K browser entries → 6MB storage.** Your browser history becomes your personal search engine.
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><strong>📋 Click to expand: Browser-Specific Arguments</strong></summary>
|
||||||
|
|
||||||
|
#### Parameters
|
||||||
|
```bash
|
||||||
|
--chrome-profile PATH # Path to Chrome profile directory (auto-detects if omitted)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Example Commands
|
||||||
|
```bash
|
||||||
|
# Search academic research from your browsing history
|
||||||
|
python -m apps.browser_rag --query "arxiv papers machine learning transformer architecture"
|
||||||
|
|
||||||
|
# Track competitor analysis across work profile
|
||||||
|
python -m apps.browser_rag --chrome-profile "~/Library/Application Support/Google/Chrome/Work Profile" --max-items 5000
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><strong>📋 Click to expand: How to find your Chrome profile</strong></summary>
|
||||||
|
|
||||||
|
The default Chrome profile path is configured for a typical macOS setup. If you need to find your specific Chrome profile:
|
||||||
|
|
||||||
|
1. Open Terminal
|
||||||
|
2. Run: `ls ~/Library/Application\ Support/Google/Chrome/`
|
||||||
|
3. Look for folders like "Default", "Profile 1", "Profile 2", etc.
|
||||||
|
4. Use the full path as your `--chrome-profile` argument
|
||||||
|
|
||||||
|
**Common Chrome profile locations:**
|
||||||
|
- macOS: `~/Library/Application Support/Google/Chrome/Default`
|
||||||
|
- Linux: `~/.config/google-chrome/Default`
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><strong>💬 Click to expand: Example queries you can try</strong></summary>
|
||||||
|
|
||||||
|
Once the index is built, you can ask questions like:
|
||||||
|
|
||||||
|
- "What websites did I visit about machine learning?"
|
||||||
|
- "Find my search history about programming"
|
||||||
|
- "What YouTube videos did I watch recently?"
|
||||||
|
- "Show me websites I visited about travel planning"
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
### 💬 WeChat Detective: Unlock Your Golden Memories!
|
||||||
|
|
||||||
|
<p align="center">
|
||||||
|
<img src="videos/wechat_clear.gif" alt="LEANN WeChat Search Demo" width="600">
|
||||||
|
</p>
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m apps.wechat_rag --query "Show me all group chats about weekend plans"
|
||||||
|
```
|
||||||
|
**400K messages → 64MB storage** Search years of chat history in any language.
|
||||||
|
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><strong>🔧 Click to expand: Installation Requirements</strong></summary>
|
||||||
|
|
||||||
|
First, you need to install the [WeChat exporter](https://github.com/sunnyyoung/WeChatTweak-CLI),
|
||||||
|
|
||||||
|
```bash
|
||||||
|
brew install sunnyyoung/repo/wechattweak-cli
|
||||||
|
```
|
||||||
|
|
||||||
|
or install it manually (if you have issues with Homebrew):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
sudo packages/wechat-exporter/wechattweak-cli install
|
||||||
|
```
|
||||||
|
|
||||||
|
**Troubleshooting:**
|
||||||
|
- **Installation issues**: Check the [WeChatTweak-CLI issues page](https://github.com/sunnyyoung/WeChatTweak-CLI/issues/41)
|
||||||
|
- **Export errors**: If you encounter the error below, try restarting WeChat
|
||||||
|
```bash
|
||||||
|
Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed.
|
||||||
|
Failed to find or export WeChat data. Exiting.
|
||||||
|
```
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><strong>📋 Click to expand: WeChat-Specific Arguments</strong></summary>
|
||||||
|
|
||||||
|
#### Parameters
|
||||||
|
```bash
|
||||||
|
--export-dir DIR # Directory to store exported WeChat data (default: wechat_export_direct)
|
||||||
|
--force-export # Force re-export even if data exists
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Example Commands
|
||||||
|
```bash
|
||||||
|
# Search for travel plans discussed in group chats
|
||||||
|
python -m apps.wechat_rag --query "travel plans" --max-items 10000
|
||||||
|
|
||||||
|
# Re-export and search recent chats (useful after new messages)
|
||||||
|
python -m apps.wechat_rag --force-export --query "work schedule"
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><strong>💬 Click to expand: Example queries you can try</strong></summary>
|
||||||
|
|
||||||
|
Once the index is built, you can ask questions like:
|
||||||
|
|
||||||
|
- "我想买魔术师约翰逊的球衣,给我一些对应聊天记录?" (Chinese: Show me chat records about buying Magic Johnson's jersey)
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
### 🚀 Claude Code Integration: Transform Your Development Workflow!
|
||||||
|
|
||||||
|
**The future of code assistance is here.** Transform your development workflow with LEANN's native MCP integration for Claude Code. Index your entire codebase and get intelligent code assistance directly in your IDE.
|
||||||
|
|
||||||
|
**Key features:**
|
||||||
|
- 🔍 **Semantic code search** across your entire project
|
||||||
|
- 📚 **Context-aware assistance** for debugging and development
|
||||||
|
- 🚀 **Zero-config setup** with automatic language detection
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Install LEANN globally for MCP integration
|
||||||
|
uv tool install leann-core
|
||||||
|
|
||||||
|
# Setup is automatic - just start using Claude Code!
|
||||||
|
```
|
||||||
|
Try our fully agentic pipeline with auto query rewriting, semantic search planning, and more:
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
**Ready to supercharge your coding?** [Complete Setup Guide →](packages/leann-mcp/README.md)
|
||||||
|
|
||||||
|
## 🖥️ Command Line Interface
|
||||||
|
|
||||||
|
LEANN includes a powerful CLI for document processing and search. Perfect for quick document indexing and interactive chat.
|
||||||
|
|
||||||
|
### Installation
|
||||||
|
|
||||||
|
If you followed the Quick Start, `leann` is already installed in your virtual environment:
|
||||||
|
```bash
|
||||||
|
source .venv/bin/activate
|
||||||
|
leann --help
|
||||||
|
```
|
||||||
|
|
||||||
|
**To make it globally available:**
|
||||||
|
```bash
|
||||||
|
# Install the LEANN CLI globally using uv tool
|
||||||
|
uv tool install leann
|
||||||
|
|
||||||
|
# Now you can use leann from anywhere without activating venv
|
||||||
|
leann --help
|
||||||
|
```
|
||||||
|
|
||||||
|
> **Note**: Global installation is required for Claude Code integration. The `leann_mcp` server depends on the globally available `leann` command.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
### Usage Examples
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# build from a specific directory, and my_docs is the index name
|
||||||
|
leann build my-docs --docs ./your_documents
|
||||||
|
|
||||||
|
# Search your documents
|
||||||
|
leann search my-docs "machine learning concepts"
|
||||||
|
|
||||||
|
# Interactive chat with your documents
|
||||||
|
leann ask my-docs --interactive
|
||||||
|
|
||||||
|
# List all your indexes
|
||||||
|
leann list
|
||||||
|
```
|
||||||
|
|
||||||
|
**Key CLI features:**
|
||||||
|
- Auto-detects document formats (PDF, TXT, MD, DOCX)
|
||||||
|
- Smart text chunking with overlap
|
||||||
|
- Multiple LLM providers (Ollama, OpenAI, HuggingFace)
|
||||||
|
- Organized index storage in `~/.leann/indexes/`
|
||||||
|
- Support for advanced search parameters
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><strong>📋 Click to expand: Complete CLI Reference</strong></summary>
|
||||||
|
|
||||||
|
**Build Command:**
|
||||||
|
```bash
|
||||||
|
leann build INDEX_NAME --docs DIRECTORY [OPTIONS]
|
||||||
|
|
||||||
|
Options:
|
||||||
|
--backend {hnsw,diskann} Backend to use (default: hnsw)
|
||||||
|
--embedding-model MODEL Embedding model (default: facebook/contriever)
|
||||||
|
--graph-degree N Graph degree (default: 32)
|
||||||
|
--complexity N Build complexity (default: 64)
|
||||||
|
--force Force rebuild existing index
|
||||||
|
--compact Use compact storage (default: true)
|
||||||
|
--recompute Enable recomputation (default: true)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Search Command:**
|
||||||
|
```bash
|
||||||
|
leann search INDEX_NAME QUERY [OPTIONS]
|
||||||
|
|
||||||
|
Options:
|
||||||
|
--top-k N Number of results (default: 5)
|
||||||
|
--complexity N Search complexity (default: 64)
|
||||||
|
--recompute-embeddings Use recomputation for highest accuracy
|
||||||
|
--pruning-strategy {global,local,proportional}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Ask Command:**
|
||||||
|
```bash
|
||||||
|
leann ask INDEX_NAME [OPTIONS]
|
||||||
|
|
||||||
|
Options:
|
||||||
|
--llm {ollama,openai,hf} LLM provider (default: ollama)
|
||||||
|
--model MODEL Model name (default: qwen3:8b)
|
||||||
|
--interactive Interactive chat mode
|
||||||
|
--top-k N Retrieval count (default: 20)
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
## 🏗️ Architecture & How It Works
|
||||||
|
|
||||||
|
<p align="center">
|
||||||
|
<img src="assets/arch.png" alt="LEANN Architecture" width="800">
|
||||||
|
</p>
|
||||||
|
|
||||||
|
**The magic:** Most vector DBs store every single embedding (expensive). LEANN stores a pruned graph structure (cheap) and recomputes embeddings only when needed (fast).
|
||||||
|
|
||||||
|
**Core techniques:**
|
||||||
|
- **Graph-based selective recomputation:** Only compute embeddings for nodes in the search path
|
||||||
|
- **High-degree preserving pruning:** Keep important "hub" nodes while removing redundant connections
|
||||||
|
- **Dynamic batching:** Efficiently batch embedding computations for GPU utilization
|
||||||
|
- **Two-level search:** Smart graph traversal that prioritizes promising nodes
|
||||||
|
|
||||||
|
**Backends:** HNSW (default) for most use cases, with optional DiskANN support for billion-scale datasets.
|
||||||
|
|
||||||
|
## Benchmarks
|
||||||
|
|
||||||
|
|
||||||
|
**[Simple Example: Compare LEANN vs FAISS →](benchmarks/compare_faiss_vs_leann.py)**
|
||||||
|
### 📊 Storage Comparison
|
||||||
|
|
||||||
|
| System | DPR (2.1M) | Wiki (60M) | Chat (400K) | Email (780K) | Browser (38K) |
|
||||||
|
|--------|-------------|------------|-------------|--------------|---------------|
|
||||||
|
| Traditional vector database (e.g., FAISS) | 3.8 GB | 201 GB | 1.8 GB | 2.4 GB | 130 MB |
|
||||||
|
| LEANN | 324 MB | 6 GB | 64 MB | 79 MB | 6.4 MB |
|
||||||
|
| Savings| 91% | 97% | 97% | 97% | 95% |
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## Reproduce Our Results
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv pip install -e ".[dev]" # Install dev dependencies
|
||||||
|
python benchmarks/run_evaluation.py # Will auto-download evaluation data and run benchmarks
|
||||||
|
```
|
||||||
|
|
||||||
|
The evaluation script downloads data automatically on first run. The last three results were tested with partial personal data, and you can reproduce them with your own data!
|
||||||
## 🔬 Paper
|
## 🔬 Paper
|
||||||
|
|
||||||
If you find Leann useful, please cite:
|
If you find Leann useful, please cite:
|
||||||
@@ -174,101 +574,25 @@ If you find Leann useful, please cite:
|
|||||||
|
|
||||||
```bibtex
|
```bibtex
|
||||||
@misc{wang2025leannlowstoragevectorindex,
|
@misc{wang2025leannlowstoragevectorindex,
|
||||||
title={LEANN: A Low-Storage Vector Index},
|
title={LEANN: A Low-Storage Vector Index},
|
||||||
author={Yichuan Wang and Shu Liu and Zhifei Li and Yongji Wu and Ziming Mao and Yilong Zhao and Xiao Yan and Zhiying Xu and Yang Zhou and Ion Stoica and Sewon Min and Matei Zaharia and Joseph E. Gonzalez},
|
author={Yichuan Wang and Shu Liu and Zhifei Li and Yongji Wu and Ziming Mao and Yilong Zhao and Xiao Yan and Zhiying Xu and Yang Zhou and Ion Stoica and Sewon Min and Matei Zaharia and Joseph E. Gonzalez},
|
||||||
year={2025},
|
year={2025},
|
||||||
eprint={2506.08276},
|
eprint={2506.08276},
|
||||||
archivePrefix={arXiv},
|
archivePrefix={arXiv},
|
||||||
primaryClass={cs.DB},
|
primaryClass={cs.DB},
|
||||||
url={https://arxiv.org/abs/2506.08276},
|
url={https://arxiv.org/abs/2506.08276},
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
## 🌍 Use Cases
|
## ✨ [Detailed Features →](docs/features.md)
|
||||||
|
|
||||||
### 💼 Enterprise RAG
|
## 🤝 [CONTRIBUTING →](docs/CONTRIBUTING.md)
|
||||||
```python
|
|
||||||
# Handle millions of documents with limited resources
|
|
||||||
builder = LeannBuilder(
|
|
||||||
backend_name="diskann",
|
|
||||||
distance_metric="cosine",
|
|
||||||
graph_degree=64,
|
|
||||||
memory_budget="4GB"
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
### 🔬 Research & Experimentation
|
|
||||||
```python
|
|
||||||
# Quick prototyping with different algorithms
|
|
||||||
for backend in ["diskann", "hnsw"]:
|
|
||||||
searcher = LeannSearcher(index_path, backend=backend)
|
|
||||||
evaluate_recall(searcher, queries, ground_truth)
|
|
||||||
```
|
|
||||||
|
|
||||||
### 🚀 Real-time Applications
|
## ❓ [FAQ →](docs/faq.md)
|
||||||
```python
|
|
||||||
# Sub-second response times
|
|
||||||
chat = LeannChat("knowledge.leann")
|
|
||||||
response = chat.ask("What is quantum computing?")
|
|
||||||
# Returns in <100ms with recompute mode
|
|
||||||
```
|
|
||||||
|
|
||||||
## 🤝 Contributing
|
|
||||||
|
|
||||||
We welcome contributions! Leann is built by the community, for the community.
|
## 📈 [Roadmap →](docs/roadmap.md)
|
||||||
|
|
||||||
### Ways to Contribute
|
|
||||||
- 🐛 **Bug Reports**: Found an issue? Let us know!
|
|
||||||
- 💡 **Feature Requests**: Have an idea? We'd love to hear it!
|
|
||||||
- 🔧 **Code Contributions**: PRs welcome for all skill levels
|
|
||||||
- 📖 **Documentation**: Help make Leann more accessible
|
|
||||||
- 🧪 **Benchmarks**: Share your performance results
|
|
||||||
|
|
||||||
### Development Setup
|
|
||||||
```bash
|
|
||||||
git clone https://github.com/yourname/leann
|
|
||||||
cd leann
|
|
||||||
uv sync --dev
|
|
||||||
uv run pytest tests/
|
|
||||||
```
|
|
||||||
|
|
||||||
### Quick Tests
|
|
||||||
```bash
|
|
||||||
# Sanity check all distance functions
|
|
||||||
uv run python tests/sanity_checks/test_distance_functions.py
|
|
||||||
|
|
||||||
# Verify L2 implementation
|
|
||||||
uv run python tests/sanity_checks/test_l2_verification.py
|
|
||||||
```
|
|
||||||
|
|
||||||
## 📈 Roadmap
|
|
||||||
|
|
||||||
### 🎯 Q1 2024
|
|
||||||
- [x] DiskANN backend with MIPS/L2/Cosine support
|
|
||||||
- [x] HNSW backend integration
|
|
||||||
- [x] Real-time embedding pipeline
|
|
||||||
- [x] Memory-efficient graph pruning
|
|
||||||
|
|
||||||
### 🚀 Q2 2024
|
|
||||||
- [ ] Distributed search across multiple nodes
|
|
||||||
- [ ] ScaNN backend support
|
|
||||||
- [ ] Advanced caching strategies
|
|
||||||
- [ ] Kubernetes deployment guides
|
|
||||||
|
|
||||||
### 🌟 Q3 2024
|
|
||||||
- [ ] GPU-accelerated embedding computation
|
|
||||||
- [ ] Approximate distance functions
|
|
||||||
- [ ] Integration with LangChain/LlamaIndex
|
|
||||||
- [ ] Visual similarity search
|
|
||||||
|
|
||||||
## 💬 Community
|
|
||||||
|
|
||||||
Join our growing community of researchers and engineers!
|
|
||||||
|
|
||||||
- 🐦 **Twitter**: [@LeannAI](https://twitter.com/LeannAI)
|
|
||||||
- 💬 **Discord**: [Join our server](https://discord.gg/leann)
|
|
||||||
- 📧 **Email**: leann@yourcompany.com
|
|
||||||
- 🐙 **GitHub Discussions**: [Ask questions here](https://github.com/yourname/leann/discussions)
|
|
||||||
|
|
||||||
## 📄 License
|
## 📄 License
|
||||||
|
|
||||||
@@ -276,10 +600,11 @@ MIT License - see [LICENSE](LICENSE) for details.
|
|||||||
|
|
||||||
## 🙏 Acknowledgments
|
## 🙏 Acknowledgments
|
||||||
|
|
||||||
- **Microsoft Research** for the DiskANN algorithm
|
Core Contributors: [Yichuan Wang](https://yichuan-w.github.io/) & [Zhifei Li](https://github.com/andylizf).
|
||||||
- **Meta AI** for FAISS and optimization insights
|
|
||||||
- **HuggingFace** for the transformer ecosystem
|
We welcome more contributors! Feel free to open issues or submit PRs.
|
||||||
- **Our amazing contributors** who make this possible
|
|
||||||
|
This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.edu/).
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -289,4 +614,4 @@ MIT License - see [LICENSE](LICENSE) for details.
|
|||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
Made with ❤️ by the Leann team
|
Made with ❤️ by the Leann team
|
||||||
</p>
|
</p>
|
||||||
|
|||||||
321
apps/base_rag_example.py
Normal file
321
apps/base_rag_example.py
Normal file
@@ -0,0 +1,321 @@
|
|||||||
|
"""
|
||||||
|
Base class for unified RAG examples interface.
|
||||||
|
Provides common parameters and functionality for all RAG examples.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import dotenv
|
||||||
|
from leann.api import LeannBuilder, LeannChat
|
||||||
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
|
||||||
|
dotenv.load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
|
class BaseRAGExample(ABC):
|
||||||
|
"""Base class for all RAG examples with unified interface."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
description: str,
|
||||||
|
default_index_name: str,
|
||||||
|
):
|
||||||
|
self.name = name
|
||||||
|
self.description = description
|
||||||
|
self.default_index_name = default_index_name
|
||||||
|
self.parser = self._create_parser()
|
||||||
|
|
||||||
|
def _create_parser(self) -> argparse.ArgumentParser:
|
||||||
|
"""Create argument parser with common parameters."""
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description=self.description, formatter_class=argparse.RawDescriptionHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
# Core parameters (all examples share these)
|
||||||
|
core_group = parser.add_argument_group("Core Parameters")
|
||||||
|
core_group.add_argument(
|
||||||
|
"--index-dir",
|
||||||
|
type=str,
|
||||||
|
default=f"./{self.default_index_name}",
|
||||||
|
help=f"Directory to store the index (default: ./{self.default_index_name})",
|
||||||
|
)
|
||||||
|
core_group.add_argument(
|
||||||
|
"--query",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Query to run (if not provided, will run in interactive mode)",
|
||||||
|
)
|
||||||
|
# Allow subclasses to override default max_items
|
||||||
|
max_items_default = getattr(self, "max_items_default", -1)
|
||||||
|
core_group.add_argument(
|
||||||
|
"--max-items",
|
||||||
|
type=int,
|
||||||
|
default=max_items_default,
|
||||||
|
help="Maximum number of items to process -1 for all, means index all documents, and you should set it to a reasonable number if you have a large dataset and try at the first time)",
|
||||||
|
)
|
||||||
|
core_group.add_argument(
|
||||||
|
"--force-rebuild", action="store_true", help="Force rebuild index even if it exists"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Embedding parameters
|
||||||
|
embedding_group = parser.add_argument_group("Embedding Parameters")
|
||||||
|
# Allow subclasses to override default embedding_model
|
||||||
|
embedding_model_default = getattr(self, "embedding_model_default", "facebook/contriever")
|
||||||
|
embedding_group.add_argument(
|
||||||
|
"--embedding-model",
|
||||||
|
type=str,
|
||||||
|
default=embedding_model_default,
|
||||||
|
help=f"Embedding model to use (default: {embedding_model_default})",
|
||||||
|
)
|
||||||
|
embedding_group.add_argument(
|
||||||
|
"--embedding-mode",
|
||||||
|
type=str,
|
||||||
|
default="sentence-transformers",
|
||||||
|
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
||||||
|
help="Embedding backend mode (default: sentence-transformers)",
|
||||||
|
)
|
||||||
|
|
||||||
|
# LLM parameters
|
||||||
|
llm_group = parser.add_argument_group("LLM Parameters")
|
||||||
|
llm_group.add_argument(
|
||||||
|
"--llm",
|
||||||
|
type=str,
|
||||||
|
default="openai",
|
||||||
|
choices=["openai", "ollama", "hf", "simulated"],
|
||||||
|
help="LLM backend to use (default: openai)",
|
||||||
|
)
|
||||||
|
llm_group.add_argument(
|
||||||
|
"--llm-model",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="LLM model name (default: gpt-4o for openai, llama3.2:1b for ollama)",
|
||||||
|
)
|
||||||
|
llm_group.add_argument(
|
||||||
|
"--llm-host",
|
||||||
|
type=str,
|
||||||
|
default="http://localhost:11434",
|
||||||
|
help="Host for Ollama API (default: http://localhost:11434)",
|
||||||
|
)
|
||||||
|
llm_group.add_argument(
|
||||||
|
"--thinking-budget",
|
||||||
|
type=str,
|
||||||
|
choices=["low", "medium", "high"],
|
||||||
|
default=None,
|
||||||
|
help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Search parameters
|
||||||
|
search_group = parser.add_argument_group("Search Parameters")
|
||||||
|
search_group.add_argument(
|
||||||
|
"--top-k", type=int, default=20, help="Number of results to retrieve (default: 20)"
|
||||||
|
)
|
||||||
|
search_group.add_argument(
|
||||||
|
"--search-complexity",
|
||||||
|
type=int,
|
||||||
|
default=32,
|
||||||
|
help="Search complexity for graph traversal (default: 64)",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Index building parameters
|
||||||
|
index_group = parser.add_argument_group("Index Building Parameters")
|
||||||
|
index_group.add_argument(
|
||||||
|
"--backend-name",
|
||||||
|
type=str,
|
||||||
|
default="hnsw",
|
||||||
|
choices=["hnsw", "diskann"],
|
||||||
|
help="Backend to use for index (default: hnsw)",
|
||||||
|
)
|
||||||
|
index_group.add_argument(
|
||||||
|
"--graph-degree",
|
||||||
|
type=int,
|
||||||
|
default=32,
|
||||||
|
help="Graph degree for index construction (default: 32)",
|
||||||
|
)
|
||||||
|
index_group.add_argument(
|
||||||
|
"--build-complexity",
|
||||||
|
type=int,
|
||||||
|
default=64,
|
||||||
|
help="Build complexity for index construction (default: 64)",
|
||||||
|
)
|
||||||
|
index_group.add_argument(
|
||||||
|
"--no-compact",
|
||||||
|
action="store_true",
|
||||||
|
help="Disable compact index storage",
|
||||||
|
)
|
||||||
|
index_group.add_argument(
|
||||||
|
"--no-recompute",
|
||||||
|
action="store_true",
|
||||||
|
help="Disable embedding recomputation",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add source-specific parameters
|
||||||
|
self._add_specific_arguments(parser)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _add_specific_arguments(self, parser: argparse.ArgumentParser):
|
||||||
|
"""Add source-specific arguments. Override in subclasses."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def load_data(self, args) -> list[str]:
|
||||||
|
"""Load data from the source. Returns list of text chunks."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_llm_config(self, args) -> dict[str, Any]:
|
||||||
|
"""Get LLM configuration based on arguments."""
|
||||||
|
config = {"type": args.llm}
|
||||||
|
|
||||||
|
if args.llm == "openai":
|
||||||
|
config["model"] = args.llm_model or "gpt-4o"
|
||||||
|
elif args.llm == "ollama":
|
||||||
|
config["model"] = args.llm_model or "llama3.2:1b"
|
||||||
|
config["host"] = args.llm_host
|
||||||
|
elif args.llm == "hf":
|
||||||
|
config["model"] = args.llm_model or "Qwen/Qwen2.5-1.5B-Instruct"
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
async def build_index(self, args, texts: list[str]) -> str:
|
||||||
|
"""Build LEANN index from texts."""
|
||||||
|
index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")
|
||||||
|
|
||||||
|
print(f"\n[Building Index] Creating {self.name} index...")
|
||||||
|
print(f"Total text chunks: {len(texts)}")
|
||||||
|
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name=args.backend_name,
|
||||||
|
embedding_model=args.embedding_model,
|
||||||
|
embedding_mode=args.embedding_mode,
|
||||||
|
graph_degree=args.graph_degree,
|
||||||
|
complexity=args.build_complexity,
|
||||||
|
is_compact=not args.no_compact,
|
||||||
|
is_recompute=not args.no_recompute,
|
||||||
|
num_threads=1, # Force single-threaded mode
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add texts in batches for better progress tracking
|
||||||
|
batch_size = 1000
|
||||||
|
for i in range(0, len(texts), batch_size):
|
||||||
|
batch = texts[i : i + batch_size]
|
||||||
|
for text in batch:
|
||||||
|
builder.add_text(text)
|
||||||
|
print(f"Added {min(i + batch_size, len(texts))}/{len(texts)} texts...")
|
||||||
|
|
||||||
|
print("Building index structure...")
|
||||||
|
builder.build_index(index_path)
|
||||||
|
print(f"Index saved to: {index_path}")
|
||||||
|
|
||||||
|
return index_path
|
||||||
|
|
||||||
|
async def run_interactive_chat(self, args, index_path: str):
|
||||||
|
"""Run interactive chat with the index."""
|
||||||
|
chat = LeannChat(
|
||||||
|
index_path,
|
||||||
|
llm_config=self.get_llm_config(args),
|
||||||
|
system_prompt=f"You are a helpful assistant that answers questions about {self.name} data.",
|
||||||
|
complexity=args.search_complexity,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"\n[Interactive Mode] Chat with your {self.name} data!")
|
||||||
|
print("Type 'quit' or 'exit' to stop.\n")
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
query = input("You: ").strip()
|
||||||
|
if query.lower() in ["quit", "exit", "q"]:
|
||||||
|
print("Goodbye!")
|
||||||
|
break
|
||||||
|
|
||||||
|
if not query:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Prepare LLM kwargs with thinking budget if specified
|
||||||
|
llm_kwargs = {}
|
||||||
|
if hasattr(args, "thinking_budget") and args.thinking_budget:
|
||||||
|
llm_kwargs["thinking_budget"] = args.thinking_budget
|
||||||
|
|
||||||
|
response = chat.ask(
|
||||||
|
query,
|
||||||
|
top_k=args.top_k,
|
||||||
|
complexity=args.search_complexity,
|
||||||
|
llm_kwargs=llm_kwargs,
|
||||||
|
)
|
||||||
|
print(f"\nAssistant: {response}\n")
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\nGoodbye!")
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error: {e}")
|
||||||
|
|
||||||
|
async def run_single_query(self, args, index_path: str, query: str):
|
||||||
|
"""Run a single query against the index."""
|
||||||
|
chat = LeannChat(
|
||||||
|
index_path,
|
||||||
|
llm_config=self.get_llm_config(args),
|
||||||
|
system_prompt=f"You are a helpful assistant that answers questions about {self.name} data.",
|
||||||
|
complexity=args.search_complexity,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"\n[Query]: \033[36m{query}\033[0m")
|
||||||
|
|
||||||
|
# Prepare LLM kwargs with thinking budget if specified
|
||||||
|
llm_kwargs = {}
|
||||||
|
if hasattr(args, "thinking_budget") and args.thinking_budget:
|
||||||
|
llm_kwargs["thinking_budget"] = args.thinking_budget
|
||||||
|
|
||||||
|
response = chat.ask(
|
||||||
|
query, top_k=args.top_k, complexity=args.search_complexity, llm_kwargs=llm_kwargs
|
||||||
|
)
|
||||||
|
print(f"\n[Response]: \033[36m{response}\033[0m")
|
||||||
|
|
||||||
|
async def run(self):
|
||||||
|
"""Main entry point for the example."""
|
||||||
|
args = self.parser.parse_args()
|
||||||
|
|
||||||
|
# Check if index exists
|
||||||
|
index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")
|
||||||
|
index_exists = Path(args.index_dir).exists()
|
||||||
|
|
||||||
|
if not index_exists or args.force_rebuild:
|
||||||
|
# Load data and build index
|
||||||
|
print(f"\n{'Rebuilding' if index_exists else 'Building'} index...")
|
||||||
|
texts = await self.load_data(args)
|
||||||
|
|
||||||
|
if not texts:
|
||||||
|
print("No data found to index!")
|
||||||
|
return
|
||||||
|
|
||||||
|
index_path = await self.build_index(args, texts)
|
||||||
|
else:
|
||||||
|
print(f"\nUsing existing index in {args.index_dir}")
|
||||||
|
|
||||||
|
# Run query or interactive mode
|
||||||
|
if args.query:
|
||||||
|
await self.run_single_query(args, index_path, args.query)
|
||||||
|
else:
|
||||||
|
await self.run_interactive_chat(args, index_path)
|
||||||
|
|
||||||
|
|
||||||
|
def create_text_chunks(documents, chunk_size=256, chunk_overlap=25) -> list[str]:
|
||||||
|
"""Helper function to create text chunks from documents."""
|
||||||
|
node_parser = SentenceSplitter(
|
||||||
|
chunk_size=chunk_size,
|
||||||
|
chunk_overlap=chunk_overlap,
|
||||||
|
separator=" ",
|
||||||
|
paragraph_separator="\n\n",
|
||||||
|
)
|
||||||
|
|
||||||
|
all_texts = []
|
||||||
|
for doc in documents:
|
||||||
|
nodes = node_parser.get_nodes_from_documents([doc])
|
||||||
|
if nodes:
|
||||||
|
all_texts.extend(node.get_content() for node in nodes)
|
||||||
|
|
||||||
|
return all_texts
|
||||||
170
apps/browser_rag.py
Normal file
170
apps/browser_rag.py
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
"""
|
||||||
|
Browser History RAG example using the unified interface.
|
||||||
|
Supports Chrome browser history.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Add parent directory to path for imports
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
|
|
||||||
|
from base_rag_example import BaseRAGExample, create_text_chunks
|
||||||
|
|
||||||
|
from .history_data.history import ChromeHistoryReader
|
||||||
|
|
||||||
|
|
||||||
|
class BrowserRAG(BaseRAGExample):
|
||||||
|
"""RAG example for Chrome browser history."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
# Set default values BEFORE calling super().__init__
|
||||||
|
self.embedding_model_default = (
|
||||||
|
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
|
||||||
|
)
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
name="Browser History",
|
||||||
|
description="Process and query Chrome browser history with LEANN",
|
||||||
|
default_index_name="google_history_index",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _add_specific_arguments(self, parser):
|
||||||
|
"""Add browser-specific arguments."""
|
||||||
|
browser_group = parser.add_argument_group("Browser Parameters")
|
||||||
|
browser_group.add_argument(
|
||||||
|
"--chrome-profile",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path to Chrome profile directory (auto-detected if not specified)",
|
||||||
|
)
|
||||||
|
browser_group.add_argument(
|
||||||
|
"--auto-find-profiles",
|
||||||
|
action="store_true",
|
||||||
|
default=True,
|
||||||
|
help="Automatically find all Chrome profiles (default: True)",
|
||||||
|
)
|
||||||
|
browser_group.add_argument(
|
||||||
|
"--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
|
||||||
|
)
|
||||||
|
browser_group.add_argument(
|
||||||
|
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_chrome_base_path(self) -> Path:
|
||||||
|
"""Get the base Chrome profile path based on OS."""
|
||||||
|
if sys.platform == "darwin":
|
||||||
|
return Path.home() / "Library" / "Application Support" / "Google" / "Chrome"
|
||||||
|
elif sys.platform.startswith("linux"):
|
||||||
|
return Path.home() / ".config" / "google-chrome"
|
||||||
|
elif sys.platform == "win32":
|
||||||
|
return Path(os.environ["LOCALAPPDATA"]) / "Google" / "Chrome" / "User Data"
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported platform: {sys.platform}")
|
||||||
|
|
||||||
|
def _find_chrome_profiles(self) -> list[Path]:
|
||||||
|
"""Auto-detect all Chrome profiles."""
|
||||||
|
base_path = self._get_chrome_base_path()
|
||||||
|
if not base_path.exists():
|
||||||
|
return []
|
||||||
|
|
||||||
|
profiles = []
|
||||||
|
|
||||||
|
# Check Default profile
|
||||||
|
default_profile = base_path / "Default"
|
||||||
|
if default_profile.exists() and (default_profile / "History").exists():
|
||||||
|
profiles.append(default_profile)
|
||||||
|
|
||||||
|
# Check numbered profiles
|
||||||
|
for item in base_path.iterdir():
|
||||||
|
if item.is_dir() and item.name.startswith("Profile "):
|
||||||
|
if (item / "History").exists():
|
||||||
|
profiles.append(item)
|
||||||
|
|
||||||
|
return profiles
|
||||||
|
|
||||||
|
async def load_data(self, args) -> list[str]:
|
||||||
|
"""Load browser history and convert to text chunks."""
|
||||||
|
# Determine Chrome profiles
|
||||||
|
if args.chrome_profile and not args.auto_find_profiles:
|
||||||
|
profile_dirs = [Path(args.chrome_profile)]
|
||||||
|
else:
|
||||||
|
print("Auto-detecting Chrome profiles...")
|
||||||
|
profile_dirs = self._find_chrome_profiles()
|
||||||
|
|
||||||
|
# If specific profile given, filter to just that one
|
||||||
|
if args.chrome_profile:
|
||||||
|
profile_path = Path(args.chrome_profile)
|
||||||
|
profile_dirs = [p for p in profile_dirs if p == profile_path]
|
||||||
|
|
||||||
|
if not profile_dirs:
|
||||||
|
print("No Chrome profiles found!")
|
||||||
|
print("Please specify --chrome-profile manually")
|
||||||
|
return []
|
||||||
|
|
||||||
|
print(f"Found {len(profile_dirs)} Chrome profiles")
|
||||||
|
|
||||||
|
# Create reader
|
||||||
|
reader = ChromeHistoryReader()
|
||||||
|
|
||||||
|
# Process each profile
|
||||||
|
all_documents = []
|
||||||
|
total_processed = 0
|
||||||
|
|
||||||
|
for i, profile_dir in enumerate(profile_dirs):
|
||||||
|
print(f"\nProcessing profile {i + 1}/{len(profile_dirs)}: {profile_dir.name}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Apply max_items limit per profile
|
||||||
|
max_per_profile = -1
|
||||||
|
if args.max_items > 0:
|
||||||
|
remaining = args.max_items - total_processed
|
||||||
|
if remaining <= 0:
|
||||||
|
break
|
||||||
|
max_per_profile = remaining
|
||||||
|
|
||||||
|
# Load history
|
||||||
|
documents = reader.load_data(
|
||||||
|
chrome_profile_path=str(profile_dir),
|
||||||
|
max_count=max_per_profile,
|
||||||
|
)
|
||||||
|
|
||||||
|
if documents:
|
||||||
|
all_documents.extend(documents)
|
||||||
|
total_processed += len(documents)
|
||||||
|
print(f"Processed {len(documents)} history entries from this profile")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing {profile_dir}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not all_documents:
|
||||||
|
print("No browser history found to process!")
|
||||||
|
return []
|
||||||
|
|
||||||
|
print(f"\nTotal history entries processed: {len(all_documents)}")
|
||||||
|
|
||||||
|
# Convert to text chunks
|
||||||
|
all_texts = create_text_chunks(
|
||||||
|
all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
||||||
|
)
|
||||||
|
|
||||||
|
return all_texts
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
# Example queries for browser history RAG
|
||||||
|
print("\n🌐 Browser History RAG Example")
|
||||||
|
print("=" * 50)
|
||||||
|
print("\nExample queries you can try:")
|
||||||
|
print("- 'What websites did I visit about machine learning?'")
|
||||||
|
print("- 'Find my search history about programming'")
|
||||||
|
print("- 'What YouTube videos did I watch recently?'")
|
||||||
|
print("- 'Show me websites about travel planning'")
|
||||||
|
print("\nNote: Make sure Chrome is closed before running\n")
|
||||||
|
|
||||||
|
rag = BrowserRAG()
|
||||||
|
asyncio.run(rag.run())
|
||||||
108
apps/document_rag.py
Normal file
108
apps/document_rag.py
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
"""
|
||||||
|
Document RAG example using the unified interface.
|
||||||
|
Supports PDF, TXT, MD, and other document formats.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Add parent directory to path for imports
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
|
|
||||||
|
from base_rag_example import BaseRAGExample, create_text_chunks
|
||||||
|
from llama_index.core import SimpleDirectoryReader
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentRAG(BaseRAGExample):
    """RAG example for document processing (PDF, TXT, MD, etc.)."""

    def __init__(self):
        super().__init__(
            name="Document",
            description="Process and query documents (PDF, TXT, MD, etc.) with LEANN",
            default_index_name="test_doc_files",
        )

    def _add_specific_arguments(self, parser):
        """Register the document-specific command line options."""
        group = parser.add_argument_group("Document Parameters")
        group.add_argument(
            "--data-dir",
            type=str,
            default="data",
            help="Directory containing documents to index (default: data)",
        )
        group.add_argument(
            "--file-types",
            nargs="+",
            default=None,
            help="Filter by file types (e.g., .pdf .txt .md). If not specified, all supported types are processed",
        )
        group.add_argument(
            "--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
        )
        group.add_argument(
            "--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
        )

    async def load_data(self, args) -> list[str]:
        """Read documents from args.data_dir and return them as text chunks."""
        print(f"Loading documents from: {args.data_dir}")
        if args.file_types:
            print(f"Filtering by file types: {args.file_types}")
        else:
            print("Processing all supported file types")

        # Fail fast when the directory is missing.
        if not Path(args.data_dir).exists():
            raise ValueError(f"Data directory not found: {args.data_dir}")

        # Build reader options; only constrain extensions when asked to.
        reader_kwargs = {"recursive": True, "encoding": "utf-8"}
        if args.file_types:
            reader_kwargs["required_exts"] = args.file_types

        reader = SimpleDirectoryReader(args.data_dir, **reader_kwargs)
        documents = reader.load_data(show_progress=True)

        if not documents:
            print(f"No documents found in {args.data_dir} with extensions {args.file_types}")
            return []

        print(f"Loaded {len(documents)} documents")

        # Convert to text chunks
        all_texts = create_text_chunks(
            documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
        )

        # Honor the global --max-items cap, if one was given.
        if args.max_items > 0 and len(all_texts) > args.max_items:
            print(f"Limiting to {args.max_items} chunks (from {len(all_texts)})")
            all_texts = all_texts[: args.max_items]

        return all_texts
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
# Example queries for document RAG
|
||||||
|
print("\n📄 Document RAG Example")
|
||||||
|
print("=" * 50)
|
||||||
|
print("\nExample queries you can try:")
|
||||||
|
print("- 'What are the main techniques LEANN uses?'")
|
||||||
|
print("- 'What is the technique DLPM?'")
|
||||||
|
print("- 'Who does Elizabeth Bennet marry?'")
|
||||||
|
print(
|
||||||
|
"- 'What is the problem of developing pan gu model Huawei meets? (盘古大模型开发中遇到什么问题?)'"
|
||||||
|
)
|
||||||
|
print("\nOr run without --query for interactive mode\n")
|
||||||
|
|
||||||
|
rag = DocumentRAG()
|
||||||
|
asyncio.run(rag.run())
|
||||||
167
apps/email_data/LEANN_email_reader.py
Normal file
167
apps/email_data/LEANN_email_reader.py
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
import email
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from llama_index.core import Document
|
||||||
|
from llama_index.core.readers.base import BaseReader
|
||||||
|
|
||||||
|
|
||||||
|
def find_all_messages_directories(root: str | None = None) -> list[Path]:
|
||||||
|
"""
|
||||||
|
Recursively find all 'Messages' directories under the given root.
|
||||||
|
Returns a list of Path objects.
|
||||||
|
"""
|
||||||
|
if root is None:
|
||||||
|
# Auto-detect user's mail path
|
||||||
|
home_dir = os.path.expanduser("~")
|
||||||
|
root = os.path.join(home_dir, "Library", "Mail")
|
||||||
|
|
||||||
|
messages_dirs = []
|
||||||
|
for dirpath, _dirnames, _filenames in os.walk(root):
|
||||||
|
if os.path.basename(dirpath) == "Messages":
|
||||||
|
messages_dirs.append(Path(dirpath))
|
||||||
|
return messages_dirs
|
||||||
|
|
||||||
|
|
||||||
|
class EmlxReader(BaseReader):
    """
    Apple Mail .emlx file reader with embedded metadata.

    Reads individual .emlx files from Apple Mail's storage format.
    Each message becomes one Document whose text carries the From/To/
    Subject/Date headers inline, with an empty metadata dict.
    """

    def __init__(self, include_html: bool = False) -> None:
        """
        Initialize.

        Args:
            include_html: Whether to include HTML content in the email body (default: False)
        """
        self.include_html = include_html

    def load_data(self, input_dir: str, **load_kwargs: Any) -> list[Document]:
        """
        Load data from the input directory containing .emlx files.

        Args:
            input_dir: Directory containing .emlx files (searched recursively,
                skipping hidden directories).
            **load_kwargs:
                max_count (int): Maximum amount of messages to read.
                    Values <= 0 (e.g. -1) mean "no limit". Default: 1000.

        Returns:
            One Document per successfully parsed, non-empty email.
        """
        docs: list[Document] = []
        max_count = load_kwargs.get("max_count", 1000)
        count = 0            # documents actually produced (drives max_count)
        total_files = 0      # .emlx files seen
        successful_files = 0 # files that produced a Document
        failed_files = 0     # files that raised during read/parse

        print(f"Starting to process directory: {input_dir}")

        # Walk through the directory recursively
        for dirpath, dirnames, filenames in os.walk(input_dir):
            # Skip hidden directories (in-place prune so os.walk won't descend)
            dirnames[:] = [d for d in dirnames if not d.startswith(".")]

            for filename in filenames:
                # Check if we've reached the max count (skip if max_count == -1)
                # NOTE(review): this only breaks the inner filename loop, so the
                # walk still visits remaining directories (harmless, just slower).
                if max_count > 0 and count >= max_count:
                    break

                if filename.endswith(".emlx"):
                    total_files += 1
                    filepath = os.path.join(dirpath, filename)
                    try:
                        # Read the .emlx file
                        with open(filepath, encoding="utf-8", errors="ignore") as f:
                            content = f.read()

                        # .emlx files have a length prefix followed by the email content
                        # The first line contains the length, followed by the email
                        lines = content.split("\n", 1)
                        if len(lines) >= 2:
                            email_content = lines[1]

                            # Parse the email using Python's email module
                            try:
                                msg = email.message_from_string(email_content)

                                # Extract email metadata (fall back to sentinels
                                # when a header is absent)
                                subject = msg.get("Subject", "No Subject")
                                from_addr = msg.get("From", "Unknown")
                                to_addr = msg.get("To", "Unknown")
                                date = msg.get("Date", "Unknown")

                                # Extract email body: concatenate text parts,
                                # skipping HTML unless include_html was set.
                                body = ""
                                if msg.is_multipart():
                                    for part in msg.walk():
                                        if (
                                            part.get_content_type() == "text/plain"
                                            or part.get_content_type() == "text/html"
                                        ):
                                            if (
                                                part.get_content_type() == "text/html"
                                                and not self.include_html
                                            ):
                                                continue
                                            try:
                                                payload = part.get_payload(decode=True)
                                                if payload:
                                                    body += payload.decode("utf-8", errors="ignore")
                                            except Exception as e:
                                                print(f"Error decoding payload: {e}")
                                                continue
                                else:
                                    try:
                                        payload = msg.get_payload(decode=True)
                                        if payload:
                                            body = payload.decode("utf-8", errors="ignore")
                                    except Exception as e:
                                        print(f"Error decoding single part payload: {e}")
                                        body = ""

                                # Only create document if we have some content
                                if body.strip() or subject != "No Subject":
                                    # Create document content with metadata embedded in text
                                    # NOTE(review): "(unknown)" below looks like a lost
                                    # placeholder (probably the file path) — confirm
                                    # against the upstream source.
                                    doc_content = f"""
[File]: (unknown)
[From]: {from_addr}
[To]: {to_addr}
[Subject]: {subject}
[Date]: {date}
[EMAIL BODY Start]:
{body}
"""

                                    # No separate metadata - everything is in the text
                                    doc = Document(text=doc_content, metadata={})
                                    docs.append(doc)
                                    count += 1
                                    successful_files += 1

                                    # Print first few successful files for debugging
                                    if successful_files <= 3:
                                        print(
                                            f"Successfully loaded: (unknown) - Subject: {subject[:50]}..."
                                        )

                            except Exception as e:
                                failed_files += 1
                                if failed_files <= 5:  # Only print first few errors
                                    print(f"Error parsing email from {filepath}: {e}")
                                continue

                    except Exception as e:
                        failed_files += 1
                        if failed_files <= 5:  # Only print first few errors
                            print(f"Error reading file {filepath}: {e}")
                        continue

        print("Processing summary:")
        print(f"  Total .emlx files found: {total_files}")
        print(f"  Successfully loaded: {successful_files}")
        print(f"  Failed to load: {failed_files}")
        print(f"  Final documents: {len(docs)}")

        return docs
|
||||||
186
apps/email_data/email.py
Normal file
186
apps/email_data/email.py
Normal file
@@ -0,0 +1,186 @@
|
|||||||
|
"""
|
||||||
|
Mbox parser.
|
||||||
|
|
||||||
|
Contains simple parser for mbox files.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fsspec import AbstractFileSystem
|
||||||
|
from llama_index.core.readers.base import BaseReader
|
||||||
|
from llama_index.core.schema import Document
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MboxReader(BaseReader):
    """
    Mbox parser.

    Extract messages from mailbox files.
    Returns string including date, subject, sender, receiver and
    content for each message.

    """

    # Template applied to every message; placeholders are filled from the
    # message headers and the whitespace-normalized body.
    DEFAULT_MESSAGE_FORMAT: str = (
        "Date: {_date}\nFrom: {_from}\nTo: {_to}\nSubject: {_subject}\nContent: {_content}"
    )

    def __init__(
        self,
        *args: Any,
        max_count: int = 0,
        message_format: str = DEFAULT_MESSAGE_FORMAT,
        **kwargs: Any,
    ) -> None:
        """Init params.

        Args:
            max_count: Stop after this many messages; 0 (default) reads all.
            message_format: Template used to render each message; see
                DEFAULT_MESSAGE_FORMAT for the available placeholders.

        Raises:
            ImportError: If beautifulsoup4 is not installed.
        """
        # Fail early with a clear message when the optional bs4 dependency
        # is missing, since load_data relies on it to strip HTML.
        try:
            from bs4 import BeautifulSoup  # noqa
        except ImportError:
            raise ImportError("`beautifulsoup4` package not found: `pip install beautifulsoup4`")

        super().__init__(*args, **kwargs)
        self.max_count = max_count
        self.message_format = message_format

    def load_data(
        self,
        file: Path,
        extra_info: dict | None = None,
        fs: AbstractFileSystem | None = None,
    ) -> list[Document]:
        """Parse file into string.

        Args:
            file: Path to an mbox file on the local filesystem.
            extra_info: Optional metadata attached to every resulting Document.
            fs: Unsupported; a warning is logged and the local filesystem is used.

        Returns:
            One Document per successfully parsed message.
        """
        # Import required libraries
        import mailbox
        from email.parser import BytesParser
        from email.policy import default

        from bs4 import BeautifulSoup

        if fs:
            logger.warning(
                "fs was specified but MboxReader doesn't support loading "
                "from fsspec filesystems. Will load from local filesystem instead."
            )

        i = 0
        results: list[str] = []
        # Load file using mailbox
        bytes_parser = BytesParser(policy=default).parse
        mbox = mailbox.mbox(file, factory=bytes_parser)  # type: ignore

        # Iterate through all messages
        for _, _msg in enumerate(mbox):
            try:
                msg: mailbox.mboxMessage = _msg
                # Parse multipart messages: take the first non-attachment
                # text/plain part as the content.
                if msg.is_multipart():
                    for part in msg.walk():
                        ctype = part.get_content_type()
                        cdispo = str(part.get("Content-Disposition"))
                        if "attachment" in cdispo:
                            print(f"Attachment found: {part.get_filename()}")
                        if ctype == "text/plain" and "attachment" not in cdispo:
                            content = part.get_payload(decode=True)  # decode
                            break
                    # NOTE(review): if no text/plain part exists, `content` is
                    # unbound (or stale from a previous message); the resulting
                    # error is swallowed by the except below — confirm intended.
                # Get plain message payload for non-multipart messages
                else:
                    content = msg.get_payload(decode=True)

                # Parse message HTML content and remove unneeded whitespace
                # (BeautifulSoup is called without an explicit parser, so it
                # uses the best available one on this machine).
                soup = BeautifulSoup(content)
                stripped_content = " ".join(soup.get_text().split())
                # Format message to include date, sender, receiver and subject
                msg_string = self.message_format.format(
                    _date=msg["date"],
                    _from=msg["from"],
                    _to=msg["to"],
                    _subject=msg["subject"],
                    _content=stripped_content,
                )
                # Add message string to results
                results.append(msg_string)
            except Exception as e:
                logger.warning(f"Failed to parse message:\n{_msg}\n with exception {e}")

            # Increment counter and return if max count is met
            # (runs outside the try, so failed messages count toward the cap).
            i += 1
            if self.max_count > 0 and i >= self.max_count:
                break

        return [Document(text=result, metadata=extra_info or {}) for result in results]
|
||||||
|
|
||||||
|
|
||||||
|
class EmlxMboxReader(MboxReader):
    """
    EmlxMboxReader - Modified MboxReader that handles directories of .emlx files.

    Extends MboxReader to work with Apple Mail's .emlx format by:
    1. Reading .emlx files from a directory
    2. Converting them to a temporary mbox file
    3. Using the parent MboxReader's parsing logic
    """

    def load_data(
        self,
        directory: Path,
        extra_info: dict | None = None,
        fs: AbstractFileSystem | None = None,
    ) -> list[Document]:
        """Parse .emlx files from directory into strings using MboxReader logic.

        Args:
            directory: Directory directly containing ``*.emlx`` files
                (non-recursive glob).
            extra_info: Optional metadata attached to every resulting Document.
            fs: Unsupported; a warning is logged and the local filesystem is used.

        Returns:
            Documents produced by the parent MboxReader from the converted mbox.
        """
        import os
        import tempfile

        if fs:
            logger.warning(
                "fs was specified but EmlxMboxReader doesn't support loading "
                "from fsspec filesystems. Will load from local filesystem instead."
            )

        # Find all .emlx files in the directory
        emlx_files = list(directory.glob("*.emlx"))
        logger.info(f"Found {len(emlx_files)} .emlx files in {directory}")

        if not emlx_files:
            logger.warning(f"No .emlx files found in {directory}")
            return []

        # Create a temporary mbox file (delete=False so it survives the
        # `with` block and can be reopened by mailbox.mbox).
        with tempfile.NamedTemporaryFile(mode="w", suffix=".mbox", delete=False) as temp_mbox:
            temp_mbox_path = temp_mbox.name

            # Convert .emlx files to mbox format
            for emlx_file in emlx_files:
                try:
                    # Read the .emlx file
                    with open(emlx_file, encoding="utf-8", errors="ignore") as f:
                        content = f.read()

                    # .emlx format: first line is a length prefix, the rest is
                    # the raw message.
                    lines = content.split("\n", 1)
                    if len(lines) >= 2:
                        email_content = lines[1]  # Skip the length line

                        # Escape message lines that begin with "From " so the
                        # mailbox module does not treat them as separators.
                        escaped = "\n".join(
                            ">" + line if line.startswith("From ") else line
                            for line in email_content.split("\n")
                        )

                        # BUG FIX: the mbox "From " separator must be a line of
                        # its own with the message starting on the NEXT line.
                        # Previously the whole message was appended to the
                        # separator line itself, which swallowed the first
                        # header line of every email.
                        temp_mbox.write(f"From {emlx_file.name}\n{escaped}\n\n")

                except Exception as e:
                    logger.warning(f"Failed to process {emlx_file}: {e}")
                    continue

        # The handle is closed on leaving the `with` block, so MboxReader can
        # safely open the file now.
        try:
            # Use the parent MboxReader's logic to parse the mbox file
            return super().load_data(Path(temp_mbox_path), extra_info, fs)
        finally:
            # Clean up temporary file
            try:
                os.unlink(temp_mbox_path)
            except OSError:
                pass
|
||||||
156
apps/email_rag.py
Normal file
156
apps/email_rag.py
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
"""
|
||||||
|
Email RAG example using the unified interface.
|
||||||
|
Supports Apple Mail on macOS.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Add parent directory to path for imports
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
|
|
||||||
|
from base_rag_example import BaseRAGExample, create_text_chunks
|
||||||
|
|
||||||
|
from .email_data.LEANN_email_reader import EmlxReader
|
||||||
|
|
||||||
|
|
||||||
|
class EmailRAG(BaseRAGExample):
    """RAG example for Apple Mail processing."""

    def __init__(self):
        # These defaults must exist BEFORE BaseRAGExample.__init__ runs.
        self.embedding_model_default = (
            "sentence-transformers/all-MiniLM-L6-v2"  # Fast 384-dim model
        )
        self.max_items_default = -1  # Process all emails by default

        super().__init__(
            name="Email",
            description="Process and query Apple Mail emails with LEANN",
            default_index_name="mail_index",
        )

    def _add_specific_arguments(self, parser):
        """Register the email-specific command line options."""
        group = parser.add_argument_group("Email Parameters")
        group.add_argument(
            "--mail-path",
            type=str,
            default=None,
            help="Path to Apple Mail directory (auto-detected if not specified)",
        )
        group.add_argument(
            "--include-html", action="store_true", help="Include HTML content in email processing"
        )
        group.add_argument(
            "--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
        )
        group.add_argument(
            "--chunk-overlap", type=int, default=25, help="Text chunk overlap (default: 25)"
        )

    def _find_mail_directories(self) -> list[Path]:
        """Auto-detect all Apple Mail 'Messages' directories."""
        mail_base = Path.home() / "Library" / "Mail"
        if not mail_base.exists():
            return []
        # Every directory named "Messages" under ~/Library/Mail holds .emlx files.
        return [item for item in mail_base.rglob("Messages") if item.is_dir()]

    async def load_data(self, args) -> list[str]:
        """Load emails from Apple Mail and split them into text chunks."""
        # Work out which Messages directories to scan.
        if args.mail_path:
            messages_dirs = [Path(args.mail_path)]
        else:
            print("Auto-detecting Apple Mail directories...")
            messages_dirs = self._find_mail_directories()

        if not messages_dirs:
            print("No Apple Mail directories found!")
            print("Please specify --mail-path manually")
            return []

        print(f"Found {len(messages_dirs)} mail directories")

        # Reader shared by every directory.
        reader = EmlxReader(include_html=args.include_html)

        all_documents = []
        total_processed = 0

        for idx, mail_dir in enumerate(messages_dirs, start=1):
            print(f"\nProcessing directory {idx}/{len(messages_dirs)}: {mail_dir}")

            try:
                found = list(mail_dir.glob("*.emlx"))
                print(f"Found {len(found)} email files")

                # Per-directory budget; -1 means "no limit".
                per_dir_cap = -1
                if args.max_items > 0:
                    budget = args.max_items - total_processed
                    if budget <= 0:
                        break
                    per_dir_cap = budget

                batch = reader.load_data(
                    input_dir=str(mail_dir),
                    max_count=per_dir_cap,
                )

                if batch:
                    all_documents.extend(batch)
                    total_processed += len(batch)
                    print(f"Processed {len(batch)} emails from this directory")

            except Exception as e:
                print(f"Error processing {mail_dir}: {e}")
                continue

        if not all_documents:
            print("No emails found to process!")
            return []

        print(f"\nTotal emails processed: {len(all_documents)}")
        print("now starting to split into text chunks ... take some time")

        # Convert to text chunks (the 25-token overlap default mirrors the
        # original email reader behavior).
        return create_text_chunks(
            all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
        )
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
# Check platform
|
||||||
|
if sys.platform != "darwin":
|
||||||
|
print("\n⚠️ Warning: This example is designed for macOS (Apple Mail)")
|
||||||
|
print(" Windows/Linux support coming soon!\n")
|
||||||
|
|
||||||
|
# Example queries for email RAG
|
||||||
|
print("\n📧 Email RAG Example")
|
||||||
|
print("=" * 50)
|
||||||
|
print("\nExample queries you can try:")
|
||||||
|
print("- 'What did my boss say about deadlines?'")
|
||||||
|
print("- 'Find emails about travel expenses'")
|
||||||
|
print("- 'Show me emails from last month about the project'")
|
||||||
|
print("- 'What food did I order from DoorDash?'")
|
||||||
|
print("\nNote: You may need to grant Full Disk Access to your terminal\n")
|
||||||
|
|
||||||
|
rag = EmailRAG()
|
||||||
|
asyncio.run(rag.run())
|
||||||
3
apps/history_data/__init__.py
Normal file
3
apps/history_data/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
# Re-export the Chrome history reader as this package's public API.
from .history import ChromeHistoryReader

# Explicit public surface for `from history_data import *`.
__all__ = ["ChromeHistoryReader"]
|
||||||
186
apps/history_data/history.py
Normal file
186
apps/history_data/history.py
Normal file
@@ -0,0 +1,186 @@
|
|||||||
|
import os
|
||||||
|
import sqlite3
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from llama_index.core import Document
|
||||||
|
from llama_index.core.readers.base import BaseReader
|
||||||
|
|
||||||
|
|
||||||
|
class ChromeHistoryReader(BaseReader):
    """
    Chrome browser history reader that extracts browsing data from SQLite database.

    Reads Chrome history from the default Chrome profile location and creates documents
    with embedded metadata similar to the email reader structure.
    """

    def __init__(self) -> None:
        """Initialize."""
        pass

    def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
        """
        Load Chrome history data from the default Chrome profile location.

        Args:
            input_dir: Not used for Chrome history (kept for compatibility)
            **load_kwargs:
                max_count (int): Maximum amount of history entries to read.
                    Values <= 0 mean "no limit". Default: 1000.
                chrome_profile_path (str): Custom path to Chrome profile directory.

        Returns:
            One Document per history row, newest first.
        """
        # Local imports keep the module-level dependency surface unchanged.
        import shutil
        import tempfile

        docs: list[Document] = []
        max_count = load_kwargs.get("max_count", 1000)
        chrome_profile_path = load_kwargs.get("chrome_profile_path", None)

        # Default Chrome profile path on macOS
        if chrome_profile_path is None:
            chrome_profile_path = os.path.expanduser(
                "~/Library/Application Support/Google/Chrome/Default"
            )

        history_db_path = os.path.join(chrome_profile_path, "History")

        if not os.path.exists(history_db_path):
            print(f"Chrome history database not found at: {history_db_path}")
            return docs

        conn = None
        tmp_copy = None
        try:
            # BUG FIX: Chrome keeps the History database locked while the
            # browser is running; query a private temp copy so reads work
            # whether or not Chrome is open.
            fd, tmp_copy = tempfile.mkstemp(suffix=".sqlite")
            os.close(fd)
            shutil.copy2(history_db_path, tmp_copy)

            # Connect to the (copied) Chrome history database
            print(f"Connecting to database: {history_db_path}")
            conn = sqlite3.connect(tmp_copy)
            cursor = conn.cursor()

            # Chrome stores last_visit_time as microseconds since 1601-01-01
            # (Windows epoch); the arithmetic converts it to local Unix time.
            query = """
            SELECT
                datetime(last_visit_time/1000000-11644473600,'unixepoch','localtime') as last_visit,
                url,
                title,
                visit_count,
                typed_count,
                hidden
            FROM urls
            ORDER BY last_visit_time DESC
            """

            print(f"Executing query on database: {history_db_path}")
            cursor.execute(query)
            rows = cursor.fetchall()
            print(f"Query returned {len(rows)} rows")

            count = 0
            for row in rows:
                if count >= max_count and max_count > 0:
                    break

                last_visit, url, title, visit_count, typed_count, hidden = row

                # Create document content with metadata embedded in text
                doc_content = f"""
[Title]: {title}
[URL of the page]: {url}
[Last visited time]: {last_visit}
[Visit times]: {visit_count}
[Typed times]: {typed_count}
"""

                # BUG FIX: `title` may be NULL in the urls table — slicing
                # None raised TypeError; guard before truncating to 150 chars.
                doc = Document(text=doc_content, metadata={"title": (title or "")[:150]})
                docs.append(doc)
                count += 1

            print(f"Loaded {len(docs)} Chrome history documents")

        except Exception as e:
            print(f"Error reading Chrome history: {e}")
            # Highlighted in red: even with the temp-copy approach, a running
            # browser can occasionally interfere with the read.
            print(
                "\033[91mYou may need to close your browser to make the database file available\033[0m"
            )
            return docs
        finally:
            # Always release the connection and the temp copy.
            if conn is not None:
                conn.close()
            if tmp_copy is not None:
                try:
                    os.unlink(tmp_copy)
                except OSError:
                    pass

        return docs

    @staticmethod
    def find_chrome_profiles() -> list[Path]:
        """
        Find all Chrome profile directories.

        Returns:
            List of Path objects pointing to Chrome profile directories
        """
        chrome_base_path = Path(os.path.expanduser("~/Library/Application Support/Google/Chrome"))
        profile_dirs = []

        if not chrome_base_path.exists():
            print(f"Chrome directory not found at: {chrome_base_path}")
            return profile_dirs

        # Find all profile directories (a directory counts as a profile when
        # it contains a History database; "System Profile" is excluded).
        for profile_dir in chrome_base_path.iterdir():
            if profile_dir.is_dir() and profile_dir.name != "System Profile":
                history_path = profile_dir / "History"
                if history_path.exists():
                    profile_dirs.append(profile_dir)
                    print(f"Found Chrome profile: {profile_dir}")

        print(f"Found {len(profile_dirs)} Chrome profiles")
        return profile_dirs

    @staticmethod
    def export_history_to_file(
        output_file: str = "chrome_history_export.txt", max_count: int = 1000
    ):
        """
        Export Chrome history to a text file using the same SQL query format.

        Args:
            output_file: Path to the output file
            max_count: Maximum number of entries to export
        """
        chrome_profile_path = os.path.expanduser(
            "~/Library/Application Support/Google/Chrome/Default"
        )
        history_db_path = os.path.join(chrome_profile_path, "History")

        if not os.path.exists(history_db_path):
            print(f"Chrome history database not found at: {history_db_path}")
            return

        conn = None
        try:
            conn = sqlite3.connect(history_db_path)
            cursor = conn.cursor()

            query = """
            SELECT
                datetime(last_visit_time/1000000-11644473600,'unixepoch','localtime') as last_visit,
                url,
                title,
                visit_count,
                typed_count,
                hidden
            FROM urls
            ORDER BY last_visit_time DESC
            LIMIT ?
            """

            cursor.execute(query, (max_count,))
            rows = cursor.fetchall()

            # One tab-separated line per history row.
            with open(output_file, "w", encoding="utf-8") as f:
                for row in rows:
                    last_visit, url, title, visit_count, typed_count, hidden = row
                    f.write(
                        f"{last_visit}\t{url}\t{title}\t{visit_count}\t{typed_count}\t{hidden}\n"
                    )

            print(f"Exported {len(rows)} history entries to {output_file}")

        except Exception as e:
            print(f"Error exporting Chrome history: {e}")
        finally:
            # BUG FIX: previously the connection leaked on any exception.
            if conn is not None:
                conn.close()
|
||||||
774
apps/history_data/wechat_history.py
Normal file
774
apps/history_data/wechat_history.py
Normal file
@@ -0,0 +1,774 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import subprocess
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from llama_index.core import Document
|
||||||
|
from llama_index.core.readers.base import BaseReader
|
||||||
|
|
||||||
|
|
||||||
|
class WeChatHistoryReader(BaseReader):
|
||||||
|
"""
|
||||||
|
WeChat chat history reader that extracts chat data from exported JSON files.
|
||||||
|
|
||||||
|
Reads WeChat chat history from exported JSON files (from wechat-exporter tool)
|
||||||
|
and creates documents with embedded metadata similar to the Chrome history reader structure.
|
||||||
|
|
||||||
|
Also includes utilities for automatic WeChat chat history export.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
"""Initialize."""
|
||||||
|
self.packages_dir = Path(__file__).parent.parent.parent / "packages"
|
||||||
|
self.wechat_exporter_dir = self.packages_dir / "wechat-exporter"
|
||||||
|
self.wechat_decipher_dir = self.packages_dir / "wechat-decipher-macos"
|
||||||
|
|
||||||
|
def check_wechat_running(self) -> bool:
|
||||||
|
"""Check if WeChat is currently running."""
|
||||||
|
try:
|
||||||
|
result = subprocess.run(["pgrep", "-f", "WeChat"], capture_output=True, text=True)
|
||||||
|
return result.returncode == 0
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
    def install_wechattweak(self) -> bool:
        """Install WeChatTweak CLI tool.

        Downloads the wechattweak-cli binary into ``self.wechat_exporter_dir``
        (if not already present) and runs its ``install`` subcommand via sudo.

        Returns:
            True on success, False if any step raised an exception.
        """
        try:
            # Create wechat-exporter directory if it doesn't exist
            self.wechat_exporter_dir.mkdir(parents=True, exist_ok=True)

            wechattweak_path = self.wechat_exporter_dir / "wechattweak-cli"
            if not wechattweak_path.exists():
                print("Downloading WeChatTweak CLI...")
                # check=True: a failed download raises and falls through to the
                # except below instead of installing a partial binary.
                subprocess.run(
                    [
                        "curl",
                        "-L",
                        "-o",
                        str(wechattweak_path),
                        "https://github.com/JettChenT/WeChatTweak-CLI/releases/latest/download/wechattweak-cli",
                    ],
                    check=True,
                )

                # Make executable
                wechattweak_path.chmod(0o755)

            # Install WeChatTweak (requires sudo; may prompt for a password)
            print("Installing WeChatTweak...")
            subprocess.run(["sudo", str(wechattweak_path), "install"], check=True)
            return True
        except Exception as e:
            print(f"Error installing WeChatTweak: {e}")
            return False
    def restart_wechat(self):
        """Restart WeChat to apply WeChatTweak.

        Kills any running WeChat process, waits briefly, and relaunches the
        app via ``open -a``. Errors are printed, not raised.
        """
        try:
            print("Restarting WeChat...")
            # check=False: pkill exits non-zero when no process matched,
            # which is fine here.
            subprocess.run(["pkill", "-f", "WeChat"], check=False)
            time.sleep(2)
            subprocess.run(["open", "-a", "WeChat"], check=True)
            time.sleep(5)  # Wait for WeChat to start
        except Exception as e:
            print(f"Error restarting WeChat: {e}")
def check_api_available(self) -> bool:
|
||||||
|
"""Check if WeChatTweak API is available."""
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
["curl", "-s", "http://localhost:48065/wechat/allcontacts"],
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=5,
|
||||||
|
)
|
||||||
|
return result.returncode == 0 and result.stdout.strip()
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _extract_readable_text(self, content: str) -> str:
|
||||||
|
"""
|
||||||
|
Extract readable text from message content, removing XML and system messages.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: The raw message content (can be string or dict)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cleaned, readable text
|
||||||
|
"""
|
||||||
|
if not content:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# Handle dictionary content (like quoted messages)
|
||||||
|
if isinstance(content, dict):
|
||||||
|
# Extract text from dictionary structure
|
||||||
|
text_parts = []
|
||||||
|
if "title" in content:
|
||||||
|
text_parts.append(str(content["title"]))
|
||||||
|
if "quoted" in content:
|
||||||
|
text_parts.append(str(content["quoted"]))
|
||||||
|
if "content" in content:
|
||||||
|
text_parts.append(str(content["content"]))
|
||||||
|
if "text" in content:
|
||||||
|
text_parts.append(str(content["text"]))
|
||||||
|
|
||||||
|
if text_parts:
|
||||||
|
return " | ".join(text_parts)
|
||||||
|
else:
|
||||||
|
# If we can't extract meaningful text from dict, return empty
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# Handle string content
|
||||||
|
if not isinstance(content, str):
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# Remove common prefixes like "wxid_xxx:\n"
|
||||||
|
clean_content = re.sub(r"^wxid_[^:]+:\s*", "", content)
|
||||||
|
clean_content = re.sub(r"^[^:]+:\s*", "", clean_content)
|
||||||
|
|
||||||
|
# If it's just XML or system message, return empty
|
||||||
|
if clean_content.strip().startswith("<") or "recalled a message" in clean_content:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
return clean_content.strip()
|
||||||
|
|
||||||
|
def _is_text_message(self, content: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a message contains readable text content.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: The message content (can be string or dict)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the message contains readable text, False otherwise
|
||||||
|
"""
|
||||||
|
if not content:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Handle dictionary content
|
||||||
|
if isinstance(content, dict):
|
||||||
|
# Check if dict has any readable text fields
|
||||||
|
text_fields = ["title", "quoted", "content", "text"]
|
||||||
|
for field in text_fields:
|
||||||
|
if content.get(field):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Handle string content
|
||||||
|
if not isinstance(content, str):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Skip image messages (contain XML with img tags)
|
||||||
|
if "<img" in content and "cdnurl" in content:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Skip emoji messages (contain emoji XML tags)
|
||||||
|
if "<emoji" in content and "productid" in content:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Skip voice messages
|
||||||
|
if "<voice" in content:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Skip video messages
|
||||||
|
if "<video" in content:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Skip file messages
|
||||||
|
if "<appmsg" in content and "appid" in content:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Skip system messages (like "recalled a message")
|
||||||
|
if "recalled a message" in content:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check if there's actual readable text (not just XML or system messages)
|
||||||
|
# Remove common prefixes like "wxid_xxx:\n" and check for actual content
|
||||||
|
clean_content = re.sub(r"^wxid_[^:]+:\s*", "", content)
|
||||||
|
clean_content = re.sub(r"^[^:]+:\s*", "", clean_content)
|
||||||
|
|
||||||
|
# If after cleaning we have meaningful text, consider it readable
|
||||||
|
if len(clean_content.strip()) > 0 and not clean_content.strip().startswith("<"):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
    def _concatenate_messages(
        self,
        messages: list[dict],
        max_length: int = 128,
        time_window_minutes: int = 30,
        overlap_messages: int = 0,
    ) -> list[dict]:
        """
        Concatenate messages based on length and time rules.

        Walks the messages in order, accumulating them into a "current group"
        that is flushed to the output whenever a time gap exceeds the window
        or adding the next message would exceed ``max_length``.

        Args:
            messages: List of message dictionaries
            max_length: Maximum length for concatenated message groups. Use -1 to disable length constraint.
            time_window_minutes: Time window in minutes to group messages together. Use -1 to disable time constraint.
            overlap_messages: Number of messages to overlap between consecutive groups

        Returns:
            List of group dicts with keys "messages", "total_length",
            "start_time", "end_time".
        """
        if not messages:
            return []

        concatenated_groups = []
        current_group = []     # messages accumulated into the group being built
        current_length = 0     # total readable-text length of current_group
        last_timestamp = None  # createTime of the previously kept message

        for message in messages:
            # Extract message info
            content = message.get("content", "")
            message_text = message.get("message", "")
            create_time = message.get("createTime", 0)
            # NOTE(review): the three .get() calls below discard their results;
            # they appear to be leftovers from removed variables.
            message.get("fromUser", "")
            message.get("toUser", "")
            message.get("isSentFromSelf", False)

            # Extract readable text; fall back to the plain "message" field.
            readable_text = self._extract_readable_text(content)
            if not readable_text:
                readable_text = message_text

            # Skip empty messages
            if not readable_text.strip():
                continue

            # Check time window constraint (only if time_window_minutes != -1)
            if time_window_minutes != -1 and last_timestamp is not None and create_time > 0:
                time_diff_minutes = (create_time - last_timestamp) / 60
                if time_diff_minutes > time_window_minutes:
                    # Time gap too large, start new group
                    if current_group:
                        concatenated_groups.append(
                            {
                                "messages": current_group,
                                "total_length": current_length,
                                "start_time": current_group[0].get("createTime", 0),
                                "end_time": current_group[-1].get("createTime", 0),
                            }
                        )
                        # Keep last few messages for overlap
                        if overlap_messages > 0 and len(current_group) > overlap_messages:
                            current_group = current_group[-overlap_messages:]
                            # Recompute the carried-over length from scratch.
                            current_length = sum(
                                len(
                                    self._extract_readable_text(msg.get("content", ""))
                                    or msg.get("message", "")
                                )
                                for msg in current_group
                            )
                        else:
                            current_group = []
                            current_length = 0

            # Check length constraint (only if max_length != -1)
            message_length = len(readable_text)
            if max_length != -1 and current_length + message_length > max_length and current_group:
                # Current group would exceed max length, save it and start new
                concatenated_groups.append(
                    {
                        "messages": current_group,
                        "total_length": current_length,
                        "start_time": current_group[0].get("createTime", 0),
                        "end_time": current_group[-1].get("createTime", 0),
                    }
                )
                # Keep last few messages for overlap
                if overlap_messages > 0 and len(current_group) > overlap_messages:
                    current_group = current_group[-overlap_messages:]
                    current_length = sum(
                        len(
                            self._extract_readable_text(msg.get("content", ""))
                            or msg.get("message", "")
                        )
                        for msg in current_group
                    )
                else:
                    current_group = []
                    current_length = 0

            # Add message to current group
            current_group.append(message)
            current_length += message_length
            last_timestamp = create_time

        # Add the last group if it exists
        if current_group:
            concatenated_groups.append(
                {
                    "messages": current_group,
                    "total_length": current_length,
                    "start_time": current_group[0].get("createTime", 0),
                    "end_time": current_group[-1].get("createTime", 0),
                }
            )

        return concatenated_groups
def _create_concatenated_content(self, message_group: dict, contact_name: str) -> str:
|
||||||
|
"""
|
||||||
|
Create concatenated content from a group of messages.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message_group: Dictionary containing messages and metadata
|
||||||
|
contact_name: Name of the contact
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted concatenated content
|
||||||
|
"""
|
||||||
|
messages = message_group["messages"]
|
||||||
|
start_time = message_group["start_time"]
|
||||||
|
end_time = message_group["end_time"]
|
||||||
|
|
||||||
|
# Format timestamps
|
||||||
|
if start_time:
|
||||||
|
try:
|
||||||
|
start_timestamp = datetime.fromtimestamp(start_time)
|
||||||
|
start_time_str = start_timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
except (ValueError, OSError):
|
||||||
|
start_time_str = str(start_time)
|
||||||
|
else:
|
||||||
|
start_time_str = "Unknown"
|
||||||
|
|
||||||
|
if end_time:
|
||||||
|
try:
|
||||||
|
end_timestamp = datetime.fromtimestamp(end_time)
|
||||||
|
end_time_str = end_timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
except (ValueError, OSError):
|
||||||
|
end_time_str = str(end_time)
|
||||||
|
else:
|
||||||
|
end_time_str = "Unknown"
|
||||||
|
|
||||||
|
# Build concatenated message content
|
||||||
|
message_parts = []
|
||||||
|
for message in messages:
|
||||||
|
content = message.get("content", "")
|
||||||
|
message_text = message.get("message", "")
|
||||||
|
create_time = message.get("createTime", 0)
|
||||||
|
is_sent_from_self = message.get("isSentFromSelf", False)
|
||||||
|
|
||||||
|
# Extract readable text
|
||||||
|
readable_text = self._extract_readable_text(content)
|
||||||
|
if not readable_text:
|
||||||
|
readable_text = message_text
|
||||||
|
|
||||||
|
# Format individual message
|
||||||
|
if create_time:
|
||||||
|
try:
|
||||||
|
timestamp = datetime.fromtimestamp(create_time)
|
||||||
|
# change to YYYY-MM-DD HH:MM:SS
|
||||||
|
time_str = timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
except (ValueError, OSError):
|
||||||
|
time_str = str(create_time)
|
||||||
|
else:
|
||||||
|
time_str = "Unknown"
|
||||||
|
|
||||||
|
sender = "[Me]" if is_sent_from_self else "[Contact]"
|
||||||
|
message_parts.append(f"({time_str}) {sender}: {readable_text}")
|
||||||
|
|
||||||
|
concatenated_text = "\n".join(message_parts)
|
||||||
|
|
||||||
|
# Create final document content
|
||||||
|
doc_content = f"""
|
||||||
|
Contact: {contact_name}
|
||||||
|
Time Range: {start_time_str} - {end_time_str}
|
||||||
|
Messages ({len(messages)} messages, {message_group["total_length"]} chars):
|
||||||
|
|
||||||
|
{concatenated_text}
|
||||||
|
"""
|
||||||
|
# TODO @yichuan give better format and rich info here!
|
||||||
|
doc_content = f"""
|
||||||
|
{concatenated_text}
|
||||||
|
"""
|
||||||
|
return doc_content, contact_name
|
||||||
|
|
||||||
|
    def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
        """
        Load WeChat chat history data from exported JSON files.

        Args:
            input_dir: Directory containing exported WeChat JSON files
                (currently unused; ``wechat_export_dir`` in load_kwargs wins)
            **load_kwargs:
                max_count (int): Maximum amount of chat entries to read.
                wechat_export_dir (str): Custom path to WeChat export directory.
                include_non_text (bool): Whether to include non-text messages (images, emojis, etc.)
                concatenate_messages (bool): Whether to concatenate messages based on length rules.
                max_length (int): Maximum length for concatenated message groups (default: 1000).
                time_window_minutes (int): Time window in minutes to group messages together (default: 30).
                overlap_messages (int): Number of messages to overlap between consecutive groups (default: 2).

        Returns:
            List of Document objects, one per message (or per concatenated
            group when concatenate_messages is True).
        """
        docs: list[Document] = []
        max_count = load_kwargs.get("max_count", 1000)
        wechat_export_dir = load_kwargs.get("wechat_export_dir", None)
        include_non_text = load_kwargs.get("include_non_text", False)
        concatenate_messages = load_kwargs.get("concatenate_messages", False)
        max_length = load_kwargs.get("max_length", 1000)
        time_window_minutes = load_kwargs.get("time_window_minutes", 30)

        # Default WeChat export path
        if wechat_export_dir is None:
            wechat_export_dir = "./wechat_export_test"

        if not os.path.exists(wechat_export_dir):
            print(f"WeChat export directory not found at: {wechat_export_dir}")
            return docs

        try:
            # Find all JSON files in the export directory (one per contact).
            json_files = list(Path(wechat_export_dir).glob("*.json"))
            print(f"Found {len(json_files)} WeChat chat history files")

            count = 0
            for json_file in json_files:
                # NOTE: max_count <= 0 means "no limit".
                if count >= max_count and max_count > 0:
                    break

                try:
                    with open(json_file, encoding="utf-8") as f:
                        chat_data = json.load(f)

                    # Extract contact name from filename
                    contact_name = json_file.stem

                    if concatenate_messages:
                        # Filter messages to only include readable text messages
                        readable_messages = []
                        for message in chat_data:
                            try:
                                content = message.get("content", "")
                                if not include_non_text and not self._is_text_message(content):
                                    continue

                                readable_text = self._extract_readable_text(content)
                                if not readable_text and not include_non_text:
                                    continue

                                readable_messages.append(message)
                            except Exception as e:
                                print(f"Error processing message in {json_file}: {e}")
                                continue

                        # Concatenate messages based on rules
                        message_groups = self._concatenate_messages(
                            readable_messages,
                            max_length=max_length,
                            time_window_minutes=time_window_minutes,
                            overlap_messages=0,  # No overlap between groups
                        )

                        # Create documents from concatenated groups
                        for message_group in message_groups:
                            if count >= max_count and max_count > 0:
                                break

                            doc_content, contact_name = self._create_concatenated_content(
                                message_group, contact_name
                            )
                            doc = Document(
                                text=doc_content,
                                metadata={"contact_name": contact_name},
                            )
                            docs.append(doc)
                            count += 1

                        print(
                            f"Created {len(message_groups)} concatenated message groups for {contact_name}"
                        )

                    else:
                        # Original single-message processing
                        for message in chat_data:
                            if count >= max_count and max_count > 0:
                                break

                            # Extract message information
                            # NOTE(review): the two .get() calls below discard
                            # their results; apparent leftovers.
                            message.get("fromUser", "")
                            message.get("toUser", "")
                            content = message.get("content", "")
                            message_text = message.get("message", "")
                            create_time = message.get("createTime", 0)
                            is_sent_from_self = message.get("isSentFromSelf", False)

                            # Handle content that might be dict or string
                            try:
                                # Check if this is a readable text message
                                if not include_non_text and not self._is_text_message(content):
                                    continue

                                # Extract readable text
                                readable_text = self._extract_readable_text(content)
                                if not readable_text and not include_non_text:
                                    continue
                            except Exception as e:
                                # Skip messages that cause processing errors
                                print(f"Error processing message in {json_file}: {e}")
                                continue

                            # Convert timestamp to readable format
                            if create_time:
                                try:
                                    timestamp = datetime.fromtimestamp(create_time)
                                    time_str = timestamp.strftime("%Y-%m-%d %H:%M:%S")
                                except (ValueError, OSError):
                                    time_str = str(create_time)
                            else:
                                time_str = "Unknown"

                            # Create document content with metadata header and contact info
                            doc_content = f"""
Contact: {contact_name}
Is sent from self: {is_sent_from_self}
Time: {time_str}
Message: {readable_text if readable_text else message_text}
"""

                            # Create document with embedded metadata
                            doc = Document(
                                text=doc_content, metadata={"contact_name": contact_name}
                            )
                            docs.append(doc)
                            count += 1

                except Exception as e:
                    print(f"Error reading {json_file}: {e}")
                    continue

            print(f"Loaded {len(docs)} WeChat chat documents")

        except Exception as e:
            print(f"Error reading WeChat history: {e}")
            return docs

        return docs
@staticmethod
|
||||||
|
def find_wechat_export_dirs() -> list[Path]:
|
||||||
|
"""
|
||||||
|
Find all WeChat export directories.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of Path objects pointing to WeChat export directories
|
||||||
|
"""
|
||||||
|
export_dirs = []
|
||||||
|
|
||||||
|
# Look for common export directory names
|
||||||
|
possible_dirs = [
|
||||||
|
Path("./wechat_export"),
|
||||||
|
Path("./wechat_export_direct"),
|
||||||
|
Path("./wechat_chat_history"),
|
||||||
|
Path("./chat_export"),
|
||||||
|
]
|
||||||
|
|
||||||
|
for export_dir in possible_dirs:
|
||||||
|
if export_dir.exists() and export_dir.is_dir():
|
||||||
|
json_files = list(export_dir.glob("*.json"))
|
||||||
|
if json_files:
|
||||||
|
export_dirs.append(export_dir)
|
||||||
|
print(
|
||||||
|
f"Found WeChat export directory: {export_dir} with {len(json_files)} files"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Found {len(export_dirs)} WeChat export directories")
|
||||||
|
return export_dirs
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def export_chat_to_file(
|
||||||
|
output_file: str = "wechat_chat_export.txt",
|
||||||
|
max_count: int = 1000,
|
||||||
|
export_dir: str | None = None,
|
||||||
|
include_non_text: bool = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Export WeChat chat history to a text file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output_file: Path to the output file
|
||||||
|
max_count: Maximum number of entries to export
|
||||||
|
export_dir: Directory containing WeChat JSON files
|
||||||
|
include_non_text: Whether to include non-text messages
|
||||||
|
"""
|
||||||
|
if export_dir is None:
|
||||||
|
export_dir = "./wechat_export_test"
|
||||||
|
|
||||||
|
if not os.path.exists(export_dir):
|
||||||
|
print(f"WeChat export directory not found at: {export_dir}")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
json_files = list(Path(export_dir).glob("*.json"))
|
||||||
|
|
||||||
|
with open(output_file, "w", encoding="utf-8") as f:
|
||||||
|
count = 0
|
||||||
|
for json_file in json_files:
|
||||||
|
if count >= max_count and max_count > 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(json_file, encoding="utf-8") as json_f:
|
||||||
|
chat_data = json.load(json_f)
|
||||||
|
|
||||||
|
contact_name = json_file.stem
|
||||||
|
f.write(f"\n=== Chat with {contact_name} ===\n")
|
||||||
|
|
||||||
|
for message in chat_data:
|
||||||
|
if count >= max_count and max_count > 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
from_user = message.get("fromUser", "")
|
||||||
|
content = message.get("content", "")
|
||||||
|
message_text = message.get("message", "")
|
||||||
|
create_time = message.get("createTime", 0)
|
||||||
|
|
||||||
|
# Skip non-text messages unless requested
|
||||||
|
if not include_non_text:
|
||||||
|
reader = WeChatHistoryReader()
|
||||||
|
if not reader._is_text_message(content):
|
||||||
|
continue
|
||||||
|
readable_text = reader._extract_readable_text(content)
|
||||||
|
if not readable_text:
|
||||||
|
continue
|
||||||
|
message_text = readable_text
|
||||||
|
|
||||||
|
if create_time:
|
||||||
|
try:
|
||||||
|
timestamp = datetime.fromtimestamp(create_time)
|
||||||
|
time_str = timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
except (ValueError, OSError):
|
||||||
|
time_str = str(create_time)
|
||||||
|
else:
|
||||||
|
time_str = "Unknown"
|
||||||
|
|
||||||
|
f.write(f"[{time_str}] {from_user}: {message_text}\n")
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing {json_file}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
print(f"Exported {count} chat entries to {output_file}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error exporting WeChat chat history: {e}")
|
||||||
|
|
||||||
|
    def export_wechat_chat_history(self, export_dir: str = "./wechat_export_direct") -> Path | None:
        """
        Export WeChat chat history using wechat-exporter tool.

        Installs the tool's requirements (via ``uv pip install``) if a
        requirements.txt is present, then runs its ``export-all`` command and
        verifies that JSON files were produced.

        Args:
            export_dir: Directory to save exported chat history

        Returns:
            Path to export directory if successful, None otherwise
        """
        try:
            # Local imports; subprocess/sys are only needed on this path.
            import subprocess
            import sys

            # Create export directory
            export_path = Path(export_dir)
            export_path.mkdir(exist_ok=True)

            print(f"Exporting WeChat chat history to {export_path}...")

            # Check if wechat-exporter directory exists
            if not self.wechat_exporter_dir.exists():
                print(f"wechat-exporter directory not found at: {self.wechat_exporter_dir}")
                return None

            # Install requirements if needed
            requirements_file = self.wechat_exporter_dir / "requirements.txt"
            if requirements_file.exists():
                print("Installing wechat-exporter requirements...")
                subprocess.run(["uv", "pip", "install", "-r", str(requirements_file)], check=True)

            # Run the export command with the current interpreter.
            print("Running wechat-exporter...")
            result = subprocess.run(
                [
                    sys.executable,
                    str(self.wechat_exporter_dir / "main.py"),
                    "export-all",
                    str(export_path),
                ],
                capture_output=True,
                text=True,
                check=True,
            )

            print("Export command output:")
            print(result.stdout)
            if result.stderr:
                print("Export errors:")
                print(result.stderr)

            # Check if export was successful (at least one JSON file written).
            if export_path.exists() and any(export_path.glob("*.json")):
                json_files = list(export_path.glob("*.json"))
                print(
                    f"Successfully exported {len(json_files)} chat history files to {export_path}"
                )
                return export_path
            else:
                print("Export completed but no JSON files found")
                return None

        except subprocess.CalledProcessError as e:
            # check=True above raises this when a command exits non-zero.
            print(f"Export command failed: {e}")
            print(f"Command output: {e.stdout}")
            print(f"Command errors: {e.stderr}")
            return None
        except Exception as e:
            print(f"Export failed: {e}")
            print("Please ensure WeChat is running and WeChatTweak is installed.")
            return None
def find_or_export_wechat_data(self, export_dir: str = "./wechat_export_direct") -> list[Path]:
|
||||||
|
"""
|
||||||
|
Find existing WeChat exports or create new ones.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
export_dir: Directory to save exported chat history if needed
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of Path objects pointing to WeChat export directories
|
||||||
|
"""
|
||||||
|
export_dirs = []
|
||||||
|
|
||||||
|
# Look for existing exports in common locations
|
||||||
|
possible_export_dirs = [
|
||||||
|
Path("./wechat_database_export"),
|
||||||
|
Path("./wechat_export_test"),
|
||||||
|
Path("./wechat_export"),
|
||||||
|
Path("./wechat_export_direct"),
|
||||||
|
Path("./wechat_chat_history"),
|
||||||
|
Path("./chat_export"),
|
||||||
|
]
|
||||||
|
|
||||||
|
for export_dir_path in possible_export_dirs:
|
||||||
|
if export_dir_path.exists() and any(export_dir_path.glob("*.json")):
|
||||||
|
export_dirs.append(export_dir_path)
|
||||||
|
print(f"Found existing export: {export_dir_path}")
|
||||||
|
|
||||||
|
# If no existing exports, try to export automatically
|
||||||
|
if not export_dirs:
|
||||||
|
print("No existing WeChat exports found. Starting direct export...")
|
||||||
|
|
||||||
|
# Try to export using wechat-exporter
|
||||||
|
exported_path = self.export_wechat_chat_history(export_dir)
|
||||||
|
if exported_path:
|
||||||
|
export_dirs = [exported_path]
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
"Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed."
|
||||||
|
)
|
||||||
|
|
||||||
|
return export_dirs
|
||||||
189
apps/wechat_rag.py
Normal file
189
apps/wechat_rag.py
Normal file
@@ -0,0 +1,189 @@
|
|||||||
|
"""
|
||||||
|
WeChat History RAG example using the unified interface.
|
||||||
|
Supports WeChat chat history export and search.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Add parent directory to path for imports
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
|
|
||||||
|
from base_rag_example import BaseRAGExample
|
||||||
|
|
||||||
|
from .history_data.wechat_history import WeChatHistoryReader
|
||||||
|
|
||||||
|
|
||||||
|
class WeChatRAG(BaseRAGExample):
|
||||||
|
"""RAG example for WeChat chat history."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
# Set default values BEFORE calling super().__init__
|
||||||
|
self.max_items_default = -1 # Match original default
|
||||||
|
self.embedding_model_default = (
|
||||||
|
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
|
||||||
|
)
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
name="WeChat History",
|
||||||
|
description="Process and query WeChat chat history with LEANN",
|
||||||
|
default_index_name="wechat_history_magic_test_11Debug_new",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _add_specific_arguments(self, parser):
|
||||||
|
"""Add WeChat-specific arguments."""
|
||||||
|
wechat_group = parser.add_argument_group("WeChat Parameters")
|
||||||
|
wechat_group.add_argument(
|
||||||
|
"--export-dir",
|
||||||
|
type=str,
|
||||||
|
default="./wechat_export",
|
||||||
|
help="Directory to store WeChat exports (default: ./wechat_export)",
|
||||||
|
)
|
||||||
|
wechat_group.add_argument(
|
||||||
|
"--force-export",
|
||||||
|
action="store_true",
|
||||||
|
help="Force re-export of WeChat data even if exports exist",
|
||||||
|
)
|
||||||
|
wechat_group.add_argument(
|
||||||
|
"--chunk-size", type=int, default=192, help="Text chunk size (default: 192)"
|
||||||
|
)
|
||||||
|
wechat_group.add_argument(
|
||||||
|
"--chunk-overlap", type=int, default=64, help="Text chunk overlap (default: 64)"
|
||||||
|
)
|
||||||
|
|
||||||
|
    def _export_wechat_data(self, export_dir: Path) -> bool:
        """Export WeChat data using wechattweak-cli.

        Verifies WeChat is running (best effort), creates the export
        directory, and invokes the bundled wechattweak-cli binary.

        Args:
            export_dir: Destination directory for the export.

        Returns:
            True when the export command exits 0, False otherwise.
        """
        print("Exporting WeChat data...")

        # Check if WeChat is running
        try:
            result = subprocess.run(["pgrep", "WeChat"], capture_output=True, text=True)
            if result.returncode != 0:
                print("WeChat is not running. Please start WeChat first.")
                return False
        except Exception:
            pass  # pgrep might not be available on all systems

        # Create export directory
        export_dir.mkdir(parents=True, exist_ok=True)

        # Run export command (path is relative to the repo root / CWD).
        cmd = ["packages/wechat-exporter/wechattweak-cli", "export", str(export_dir)]

        try:
            print(f"Running: {' '.join(cmd)}")
            result = subprocess.run(cmd, capture_output=True, text=True)

            if result.returncode == 0:
                print("WeChat data exported successfully!")
                return True
            else:
                print(f"Export failed: {result.stderr}")
                return False

        except FileNotFoundError:
            # The CLI binary is missing entirely.
            print("\nError: wechattweak-cli not found!")
            print("Please install it first:")
            print("  sudo packages/wechat-exporter/wechattweak-cli install")
            return False
        except Exception as e:
            print(f"Export error: {e}")
            return False
|
async def load_data(self, args) -> list[str]:
|
||||||
|
"""Load WeChat history and convert to text chunks."""
|
||||||
|
# Initialize WeChat reader with export capabilities
|
||||||
|
reader = WeChatHistoryReader()
|
||||||
|
|
||||||
|
# Find existing exports or create new ones using the centralized method
|
||||||
|
export_dirs = reader.find_or_export_wechat_data(args.export_dir)
|
||||||
|
if not export_dirs:
|
||||||
|
print("Failed to find or export WeChat data. Trying to find any existing exports...")
|
||||||
|
# Try to find any existing exports in common locations
|
||||||
|
export_dirs = reader.find_wechat_export_dirs()
|
||||||
|
if not export_dirs:
|
||||||
|
print("No WeChat data found. Please ensure WeChat exports exist.")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Load documents from all found export directories
|
||||||
|
all_documents = []
|
||||||
|
total_processed = 0
|
||||||
|
|
||||||
|
for i, export_dir in enumerate(export_dirs):
|
||||||
|
print(f"\nProcessing WeChat export {i + 1}/{len(export_dirs)}: {export_dir}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Apply max_items limit per export
|
||||||
|
max_per_export = -1
|
||||||
|
if args.max_items > 0:
|
||||||
|
remaining = args.max_items - total_processed
|
||||||
|
if remaining <= 0:
|
||||||
|
break
|
||||||
|
max_per_export = remaining
|
||||||
|
|
||||||
|
documents = reader.load_data(
|
||||||
|
wechat_export_dir=str(export_dir),
|
||||||
|
max_count=max_per_export,
|
||||||
|
concatenate_messages=True, # Enable message concatenation for better context
|
||||||
|
)
|
||||||
|
|
||||||
|
if documents:
|
||||||
|
print(f"Loaded {len(documents)} chat documents from {export_dir}")
|
||||||
|
all_documents.extend(documents)
|
||||||
|
total_processed += len(documents)
|
||||||
|
else:
|
||||||
|
print(f"No documents loaded from {export_dir}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing {export_dir}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not all_documents:
|
||||||
|
print("No documents loaded from any source. Exiting.")
|
||||||
|
return []
|
||||||
|
|
||||||
|
print(f"\nTotal loaded {len(all_documents)} chat documents from {len(export_dirs)} exports")
|
||||||
|
print("now starting to split into text chunks ... take some time")
|
||||||
|
|
||||||
|
# Convert to text chunks with contact information
|
||||||
|
all_texts = []
|
||||||
|
for doc in all_documents:
|
||||||
|
# Split the document into chunks
|
||||||
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
|
||||||
|
text_splitter = SentenceSplitter(
|
||||||
|
chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
||||||
|
)
|
||||||
|
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||||
|
|
||||||
|
for node in nodes:
|
||||||
|
# Add contact information to each chunk
|
||||||
|
contact_name = doc.metadata.get("contact_name", "Unknown")
|
||||||
|
text = f"[Contact] means the message is from: {contact_name}\n" + node.get_content()
|
||||||
|
all_texts.append(text)
|
||||||
|
|
||||||
|
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
|
||||||
|
return all_texts
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
# Check platform
|
||||||
|
if sys.platform != "darwin":
|
||||||
|
print("\n⚠️ Warning: WeChat export is only supported on macOS")
|
||||||
|
print(" You can still query existing exports on other platforms\n")
|
||||||
|
|
||||||
|
# Example queries for WeChat RAG
|
||||||
|
print("\n💬 WeChat History RAG Example")
|
||||||
|
print("=" * 50)
|
||||||
|
print("\nExample queries you can try:")
|
||||||
|
print("- 'Show me conversations about travel plans'")
|
||||||
|
print("- 'Find group chats about weekend activities'")
|
||||||
|
print("- '我想买魔术师约翰逊的球衣,给我一些对应聊天记录?'")
|
||||||
|
print("- 'What did we discuss about the project last month?'")
|
||||||
|
print("\nNote: WeChat must be running for export to work\n")
|
||||||
|
|
||||||
|
rag = WeChatRAG()
|
||||||
|
asyncio.run(rag.run())
|
||||||
BIN
assets/arch.png
Normal file
BIN
assets/arch.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 78 KiB |
BIN
assets/claude_code_leann.png
Normal file
BIN
assets/claude_code_leann.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 73 KiB |
BIN
assets/effects.png
Normal file
BIN
assets/effects.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 339 KiB |
BIN
assets/logo-text.png
Normal file
BIN
assets/logo-text.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 818 KiB |
BIN
assets/logo.png
Normal file
BIN
assets/logo.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 276 KiB |
BIN
assets/mcp_leann.png
Normal file
BIN
assets/mcp_leann.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 224 KiB |
@@ -7,7 +7,7 @@ This directory contains comprehensive sanity checks for the Leann system, ensuri
|
|||||||
### `test_distance_functions.py`
|
### `test_distance_functions.py`
|
||||||
Tests all supported distance functions across DiskANN backend:
|
Tests all supported distance functions across DiskANN backend:
|
||||||
- ✅ **MIPS** (Maximum Inner Product Search)
|
- ✅ **MIPS** (Maximum Inner Product Search)
|
||||||
- ✅ **L2** (Euclidean Distance)
|
- ✅ **L2** (Euclidean Distance)
|
||||||
- ✅ **Cosine** (Cosine Similarity)
|
- ✅ **Cosine** (Cosine Similarity)
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
@@ -27,7 +27,7 @@ uv run python tests/sanity_checks/test_l2_verification.py
|
|||||||
### `test_sanity_check.py`
|
### `test_sanity_check.py`
|
||||||
Comprehensive end-to-end verification including:
|
Comprehensive end-to-end verification including:
|
||||||
- Distance function testing
|
- Distance function testing
|
||||||
- Embedding model compatibility
|
- Embedding model compatibility
|
||||||
- Search result correctness validation
|
- Search result correctness validation
|
||||||
- Backend integration testing
|
- Backend integration testing
|
||||||
|
|
||||||
@@ -64,7 +64,7 @@ When all tests pass, you should see:
|
|||||||
```
|
```
|
||||||
📊 测试结果总结:
|
📊 测试结果总结:
|
||||||
mips : ✅ 通过
|
mips : ✅ 通过
|
||||||
l2 : ✅ 通过
|
l2 : ✅ 通过
|
||||||
cosine : ✅ 通过
|
cosine : ✅ 通过
|
||||||
|
|
||||||
🎉 测试完成!
|
🎉 测试完成!
|
||||||
@@ -98,7 +98,7 @@ pkill -f "embedding_server"
|
|||||||
|
|
||||||
### Typical Timing (3 documents, consumer hardware):
|
### Typical Timing (3 documents, consumer hardware):
|
||||||
- **Index Building**: 2-5 seconds per distance function
|
- **Index Building**: 2-5 seconds per distance function
|
||||||
- **Search Query**: 50-200ms
|
- **Search Query**: 50-200ms
|
||||||
- **Recompute Mode**: 5-15 seconds (higher accuracy)
|
- **Recompute Mode**: 5-15 seconds (higher accuracy)
|
||||||
|
|
||||||
### Memory Usage:
|
### Memory Usage:
|
||||||
@@ -117,4 +117,4 @@ These tests are designed to be run in automated environments:
|
|||||||
uv run python tests/sanity_checks/test_l2_verification.py
|
uv run python tests/sanity_checks/test_l2_verification.py
|
||||||
```
|
```
|
||||||
|
|
||||||
The tests are deterministic and should produce consistent results across different platforms.
|
The tests are deterministic and should produce consistent results across different platforms.
|
||||||
141
benchmarks/benchmark_embeddings.py
Normal file
141
benchmarks/benchmark_embeddings.py
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
import time
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from mlx_lm import load
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
|
# --- Configuration ---
|
||||||
|
MODEL_NAME_TORCH = "Qwen/Qwen3-Embedding-0.6B"
|
||||||
|
MODEL_NAME_MLX = "mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ"
|
||||||
|
BATCH_SIZES = [1, 8, 16, 32, 64, 128]
|
||||||
|
NUM_RUNS = 10 # Number of runs to average for each batch size
|
||||||
|
WARMUP_RUNS = 2 # Number of warm-up runs
|
||||||
|
|
||||||
|
# --- Generate Dummy Data ---
|
||||||
|
DUMMY_SENTENCES = ["This is a test sentence for benchmarking." * 5] * max(BATCH_SIZES)
|
||||||
|
|
||||||
|
# --- Benchmark Functions ---b
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_torch(model, sentences):
|
||||||
|
start_time = time.time()
|
||||||
|
model.encode(sentences, convert_to_numpy=True)
|
||||||
|
end_time = time.time()
|
||||||
|
return (end_time - start_time) * 1000 # Return time in ms
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_mlx(model, tokenizer, sentences):
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# Tokenize sentences using MLX tokenizer
|
||||||
|
tokens = []
|
||||||
|
for sentence in sentences:
|
||||||
|
token_ids = tokenizer.encode(sentence)
|
||||||
|
tokens.append(token_ids)
|
||||||
|
|
||||||
|
# Pad sequences to the same length
|
||||||
|
max_len = max(len(t) for t in tokens)
|
||||||
|
input_ids = []
|
||||||
|
attention_mask = []
|
||||||
|
|
||||||
|
for token_seq in tokens:
|
||||||
|
# Pad sequence
|
||||||
|
padded = token_seq + [tokenizer.eos_token_id] * (max_len - len(token_seq))
|
||||||
|
input_ids.append(padded)
|
||||||
|
# Create attention mask (1 for real tokens, 0 for padding)
|
||||||
|
mask = [1] * len(token_seq) + [0] * (max_len - len(token_seq))
|
||||||
|
attention_mask.append(mask)
|
||||||
|
|
||||||
|
# Convert to MLX arrays
|
||||||
|
input_ids = mx.array(input_ids)
|
||||||
|
attention_mask = mx.array(attention_mask)
|
||||||
|
|
||||||
|
# Get embeddings
|
||||||
|
embeddings = model(input_ids)
|
||||||
|
|
||||||
|
# Mean pooling
|
||||||
|
mask = mx.expand_dims(attention_mask, -1)
|
||||||
|
sum_embeddings = (embeddings * mask).sum(axis=1)
|
||||||
|
sum_mask = mask.sum(axis=1)
|
||||||
|
_ = sum_embeddings / sum_mask
|
||||||
|
|
||||||
|
mx.eval() # Ensure computation is finished
|
||||||
|
end_time = time.time()
|
||||||
|
return (end_time - start_time) * 1000 # Return time in ms
|
||||||
|
|
||||||
|
|
||||||
|
# --- Main Execution ---
|
||||||
|
def main():
|
||||||
|
print("--- Initializing Models ---")
|
||||||
|
# Load PyTorch model
|
||||||
|
print(f"Loading PyTorch model: {MODEL_NAME_TORCH}")
|
||||||
|
device = "mps" if torch.backends.mps.is_available() else "cpu"
|
||||||
|
model_torch = SentenceTransformer(MODEL_NAME_TORCH, device=device)
|
||||||
|
print(f"PyTorch model loaded on: {device}")
|
||||||
|
|
||||||
|
# Load MLX model
|
||||||
|
print(f"Loading MLX model: {MODEL_NAME_MLX}")
|
||||||
|
model_mlx, tokenizer_mlx = load(MODEL_NAME_MLX)
|
||||||
|
print("MLX model loaded.")
|
||||||
|
|
||||||
|
# --- Warm-up ---
|
||||||
|
print("\n--- Performing Warm-up Runs ---")
|
||||||
|
for _ in range(WARMUP_RUNS):
|
||||||
|
benchmark_torch(model_torch, DUMMY_SENTENCES[:1])
|
||||||
|
benchmark_mlx(model_mlx, tokenizer_mlx, DUMMY_SENTENCES[:1])
|
||||||
|
print("Warm-up complete.")
|
||||||
|
|
||||||
|
# --- Benchmarking ---
|
||||||
|
print("\n--- Starting Benchmark ---")
|
||||||
|
results_torch = []
|
||||||
|
results_mlx = []
|
||||||
|
|
||||||
|
for batch_size in BATCH_SIZES:
|
||||||
|
print(f"Benchmarking batch size: {batch_size}")
|
||||||
|
sentences_batch = DUMMY_SENTENCES[:batch_size]
|
||||||
|
|
||||||
|
# Benchmark PyTorch
|
||||||
|
torch_times = [benchmark_torch(model_torch, sentences_batch) for _ in range(NUM_RUNS)]
|
||||||
|
results_torch.append(np.mean(torch_times))
|
||||||
|
|
||||||
|
# Benchmark MLX
|
||||||
|
mlx_times = [
|
||||||
|
benchmark_mlx(model_mlx, tokenizer_mlx, sentences_batch) for _ in range(NUM_RUNS)
|
||||||
|
]
|
||||||
|
results_mlx.append(np.mean(mlx_times))
|
||||||
|
|
||||||
|
print("\n--- Benchmark Results (Average time per batch in ms) ---")
|
||||||
|
print(f"Batch Sizes: {BATCH_SIZES}")
|
||||||
|
print(f"PyTorch (mps): {[f'{t:.2f}' for t in results_torch]}")
|
||||||
|
print(f"MLX: {[f'{t:.2f}' for t in results_mlx]}")
|
||||||
|
|
||||||
|
# --- Plotting ---
|
||||||
|
print("\n--- Generating Plot ---")
|
||||||
|
plt.figure(figsize=(10, 6))
|
||||||
|
plt.plot(
|
||||||
|
BATCH_SIZES,
|
||||||
|
results_torch,
|
||||||
|
marker="o",
|
||||||
|
linestyle="-",
|
||||||
|
label=f"PyTorch ({device})",
|
||||||
|
)
|
||||||
|
plt.plot(BATCH_SIZES, results_mlx, marker="s", linestyle="-", label="MLX")
|
||||||
|
|
||||||
|
plt.title(f"Embedding Performance: MLX vs PyTorch\nModel: {MODEL_NAME_TORCH}")
|
||||||
|
plt.xlabel("Batch Size")
|
||||||
|
plt.ylabel("Average Time per Batch (ms)")
|
||||||
|
plt.xticks(BATCH_SIZES)
|
||||||
|
plt.grid(True)
|
||||||
|
plt.legend()
|
||||||
|
|
||||||
|
# Save the plot
|
||||||
|
output_filename = "embedding_benchmark.png"
|
||||||
|
plt.savefig(output_filename)
|
||||||
|
print(f"Plot saved to {output_filename}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
326
benchmarks/compare_faiss_vs_leann.py
Normal file
326
benchmarks/compare_faiss_vs_leann.py
Normal file
@@ -0,0 +1,326 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Memory comparison between Faiss HNSW and LEANN HNSW backend
|
||||||
|
"""
|
||||||
|
|
||||||
|
import gc
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import psutil
|
||||||
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
|
||||||
|
# Setup logging
|
||||||
|
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_memory_usage():
|
||||||
|
"""Get current memory usage in MB"""
|
||||||
|
process = psutil.Process()
|
||||||
|
return process.memory_info().rss / 1024 / 1024
|
||||||
|
|
||||||
|
|
||||||
|
def print_memory_stats(stage: str, start_mem: float):
|
||||||
|
"""Print memory statistics"""
|
||||||
|
current_mem = get_memory_usage()
|
||||||
|
diff = current_mem - start_mem
|
||||||
|
print(f"[{stage}] Memory: {current_mem:.1f} MB (+{diff:.1f} MB)")
|
||||||
|
return current_mem
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryTracker:
|
||||||
|
def __init__(self, name: str):
|
||||||
|
self.name = name
|
||||||
|
self.start_mem = get_memory_usage()
|
||||||
|
self.stages = []
|
||||||
|
|
||||||
|
def checkpoint(self, stage: str):
|
||||||
|
current_mem = print_memory_stats(f"{self.name} - {stage}", self.start_mem)
|
||||||
|
self.stages.append((stage, current_mem))
|
||||||
|
return current_mem
|
||||||
|
|
||||||
|
def summary(self):
|
||||||
|
print(f"\n=== {self.name} Memory Summary ===")
|
||||||
|
for stage, mem in self.stages:
|
||||||
|
print(f"{stage}: {mem:.1f} MB")
|
||||||
|
peak_mem = max(mem for _, mem in self.stages)
|
||||||
|
print(f"Peak Memory: {peak_mem:.1f} MB")
|
||||||
|
print(f"Total Memory Increase: {peak_mem - self.start_mem:.1f} MB")
|
||||||
|
return peak_mem
|
||||||
|
|
||||||
|
|
||||||
|
def test_faiss_hnsw():
|
||||||
|
"""Test Faiss HNSW Vector Store in subprocess"""
|
||||||
|
print("\n" + "=" * 50)
|
||||||
|
print("TESTING FAISS HNSW VECTOR STORE")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
[sys.executable, "benchmarks/faiss_only.py"],
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=300,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(result.stdout)
|
||||||
|
if result.stderr:
|
||||||
|
print("Stderr:", result.stderr)
|
||||||
|
|
||||||
|
if result.returncode != 0:
|
||||||
|
return {
|
||||||
|
"peak_memory": float("inf"),
|
||||||
|
"error": f"Process failed with code {result.returncode}",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Parse peak memory from output
|
||||||
|
lines = result.stdout.split("\n")
|
||||||
|
peak_memory = 0.0
|
||||||
|
|
||||||
|
for line in lines:
|
||||||
|
if "Peak Memory:" in line:
|
||||||
|
peak_memory = float(line.split("Peak Memory:")[1].split("MB")[0].strip())
|
||||||
|
|
||||||
|
return {"peak_memory": peak_memory}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"peak_memory": float("inf"),
|
||||||
|
"error": str(e),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_leann_hnsw():
|
||||||
|
"""Test LEANN HNSW Search Memory (load existing index)"""
|
||||||
|
print("\n" + "=" * 50)
|
||||||
|
print("TESTING LEANN HNSW SEARCH MEMORY")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
tracker = MemoryTracker("LEANN HNSW Search")
|
||||||
|
|
||||||
|
# Import and setup
|
||||||
|
tracker.checkpoint("Initial")
|
||||||
|
|
||||||
|
from leann.api import LeannSearcher
|
||||||
|
|
||||||
|
tracker.checkpoint("After imports")
|
||||||
|
|
||||||
|
from leann.api import LeannBuilder
|
||||||
|
from llama_index.core import SimpleDirectoryReader
|
||||||
|
|
||||||
|
# Load and parse documents
|
||||||
|
documents = SimpleDirectoryReader(
|
||||||
|
"data",
|
||||||
|
recursive=True,
|
||||||
|
encoding="utf-8",
|
||||||
|
required_exts=[".pdf", ".txt", ".md"],
|
||||||
|
).load_data()
|
||||||
|
|
||||||
|
tracker.checkpoint("After document loading")
|
||||||
|
|
||||||
|
# Parse into chunks
|
||||||
|
node_parser = SentenceSplitter(
|
||||||
|
chunk_size=256, chunk_overlap=20, separator=" ", paragraph_separator="\n\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
all_texts = []
|
||||||
|
for doc in documents:
|
||||||
|
nodes = node_parser.get_nodes_from_documents([doc])
|
||||||
|
for node in nodes:
|
||||||
|
all_texts.append(node.get_content())
|
||||||
|
print(f"Total number of chunks: {len(all_texts)}")
|
||||||
|
|
||||||
|
tracker.checkpoint("After text chunking")
|
||||||
|
|
||||||
|
# Build LEANN index
|
||||||
|
INDEX_DIR = Path("./test_leann_comparison")
|
||||||
|
INDEX_PATH = str(INDEX_DIR / "comparison.leann")
|
||||||
|
|
||||||
|
# Check if index already exists
|
||||||
|
if os.path.exists(INDEX_PATH + ".meta.json"):
|
||||||
|
print("Loading existing LEANN HNSW index...")
|
||||||
|
tracker.checkpoint("After loading existing index")
|
||||||
|
else:
|
||||||
|
print("Building new LEANN HNSW index...")
|
||||||
|
# Clean up previous index
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
if INDEX_DIR.exists():
|
||||||
|
shutil.rmtree(INDEX_DIR)
|
||||||
|
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model="facebook/contriever",
|
||||||
|
graph_degree=32,
|
||||||
|
complexity=64,
|
||||||
|
is_compact=True,
|
||||||
|
is_recompute=True,
|
||||||
|
num_threads=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
tracker.checkpoint("After builder setup")
|
||||||
|
|
||||||
|
print("Building LEANN HNSW index...")
|
||||||
|
|
||||||
|
for chunk_text in all_texts:
|
||||||
|
builder.add_text(chunk_text)
|
||||||
|
|
||||||
|
builder.build_index(INDEX_PATH)
|
||||||
|
del builder
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
tracker.checkpoint("After index building")
|
||||||
|
|
||||||
|
# Find existing LEANN index
|
||||||
|
index_paths = [
|
||||||
|
"./test_leann_comparison/comparison.leann",
|
||||||
|
]
|
||||||
|
index_path = None
|
||||||
|
for path in index_paths:
|
||||||
|
if os.path.exists(path + ".meta.json"):
|
||||||
|
index_path = path
|
||||||
|
break
|
||||||
|
|
||||||
|
if not index_path:
|
||||||
|
print("❌ LEANN index not found. Please build it first")
|
||||||
|
return {"peak_memory": float("inf"), "error": "Index not found"}
|
||||||
|
|
||||||
|
# Measure runtime memory overhead
|
||||||
|
print("\nMeasuring runtime memory overhead...")
|
||||||
|
runtime_start_mem = get_memory_usage()
|
||||||
|
print(f"Before load memory: {runtime_start_mem:.1f} MB")
|
||||||
|
tracker.checkpoint("Before load memory")
|
||||||
|
|
||||||
|
# Load searcher
|
||||||
|
searcher = LeannSearcher(index_path)
|
||||||
|
tracker.checkpoint("After searcher loading")
|
||||||
|
|
||||||
|
print("Running search queries...")
|
||||||
|
queries = [
|
||||||
|
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
|
||||||
|
"What is LEANN and how does it work?",
|
||||||
|
"华为诺亚方舟实验室的主要研究内容",
|
||||||
|
]
|
||||||
|
|
||||||
|
for i, query in enumerate(queries):
|
||||||
|
start_time = time.time()
|
||||||
|
# Use same parameters as Faiss: top_k=20, ef=120 (complexity parameter)
|
||||||
|
_ = searcher.search(query, top_k=20, ef=120)
|
||||||
|
query_time = time.time() - start_time
|
||||||
|
print(f"Query {i + 1} time: {query_time:.3f}s")
|
||||||
|
tracker.checkpoint(f"After query {i + 1}")
|
||||||
|
|
||||||
|
runtime_end_mem = get_memory_usage()
|
||||||
|
runtime_overhead = runtime_end_mem - runtime_start_mem
|
||||||
|
|
||||||
|
peak_memory = tracker.summary()
|
||||||
|
print(f"Runtime Memory Overhead: {runtime_overhead:.1f} MB")
|
||||||
|
|
||||||
|
# Get storage size before cleanup
|
||||||
|
storage_size = 0
|
||||||
|
INDEX_DIR = Path(index_path).parent
|
||||||
|
if INDEX_DIR.exists():
|
||||||
|
total_size = 0
|
||||||
|
for dirpath, _, filenames in os.walk(str(INDEX_DIR)):
|
||||||
|
for filename in filenames:
|
||||||
|
# Only count actual index files, skip text data and backups
|
||||||
|
if filename.endswith((".old", ".tmp", ".bak", ".jsonl", ".json")):
|
||||||
|
continue
|
||||||
|
# Count .index, .idx, .map files (actual index structures)
|
||||||
|
if filename.endswith((".index", ".idx", ".map")):
|
||||||
|
filepath = os.path.join(dirpath, filename)
|
||||||
|
total_size += os.path.getsize(filepath)
|
||||||
|
storage_size = total_size / (1024 * 1024) # Convert to MB
|
||||||
|
|
||||||
|
# Clean up
|
||||||
|
del searcher
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"peak_memory": peak_memory,
|
||||||
|
"storage_size": storage_size,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Run comparison tests"""
|
||||||
|
print("Storage + Search Memory Comparison: Faiss HNSW vs LEANN HNSW")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Test Faiss HNSW
|
||||||
|
faiss_results = test_faiss_hnsw()
|
||||||
|
|
||||||
|
# Force garbage collection
|
||||||
|
gc.collect()
|
||||||
|
time.sleep(2)
|
||||||
|
|
||||||
|
# Test LEANN HNSW
|
||||||
|
leann_results = test_leann_hnsw()
|
||||||
|
|
||||||
|
# Final comparison
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("STORAGE + SEARCH MEMORY COMPARISON")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Get storage sizes
|
||||||
|
faiss_storage_size = 0
|
||||||
|
leann_storage_size = leann_results.get("storage_size", 0)
|
||||||
|
|
||||||
|
# Get Faiss storage size using Python
|
||||||
|
if os.path.exists("./storage_faiss"):
|
||||||
|
total_size = 0
|
||||||
|
for dirpath, _, filenames in os.walk("./storage_faiss"):
|
||||||
|
for filename in filenames:
|
||||||
|
filepath = os.path.join(dirpath, filename)
|
||||||
|
total_size += os.path.getsize(filepath)
|
||||||
|
faiss_storage_size = total_size / (1024 * 1024) # Convert to MB
|
||||||
|
|
||||||
|
print("Faiss HNSW:")
|
||||||
|
if "error" in faiss_results:
|
||||||
|
print(f" ❌ Failed: {faiss_results['error']}")
|
||||||
|
else:
|
||||||
|
print(f" Search Memory: {faiss_results['peak_memory']:.1f} MB")
|
||||||
|
print(f" Storage Size: {faiss_storage_size:.1f} MB")
|
||||||
|
|
||||||
|
print("\nLEANN HNSW:")
|
||||||
|
if "error" in leann_results:
|
||||||
|
print(f" ❌ Failed: {leann_results['error']}")
|
||||||
|
else:
|
||||||
|
print(f" Search Memory: {leann_results['peak_memory']:.1f} MB")
|
||||||
|
print(f" Storage Size: {leann_storage_size:.1f} MB")
|
||||||
|
|
||||||
|
# Calculate improvements only if both tests succeeded
|
||||||
|
if "error" not in faiss_results and "error" not in leann_results:
|
||||||
|
memory_ratio = faiss_results["peak_memory"] / leann_results["peak_memory"]
|
||||||
|
|
||||||
|
print("\nLEANN vs Faiss Performance:")
|
||||||
|
memory_saving = faiss_results["peak_memory"] - leann_results["peak_memory"]
|
||||||
|
print(f" Search Memory: {memory_ratio:.1f}x less ({memory_saving:.1f} MB saved)")
|
||||||
|
|
||||||
|
# Storage comparison
|
||||||
|
if leann_storage_size > faiss_storage_size:
|
||||||
|
storage_ratio = leann_storage_size / faiss_storage_size
|
||||||
|
print(f" Storage Size: {storage_ratio:.1f}x larger (LEANN uses more storage)")
|
||||||
|
elif faiss_storage_size > leann_storage_size:
|
||||||
|
storage_ratio = faiss_storage_size / leann_storage_size
|
||||||
|
print(f" Storage Size: {storage_ratio:.1f}x smaller (LEANN uses less storage)")
|
||||||
|
else:
|
||||||
|
print(" Storage Size: similar")
|
||||||
|
else:
|
||||||
|
if "error" not in leann_results:
|
||||||
|
print("\n✅ LEANN HNSW completed successfully!")
|
||||||
|
print(f"📊 Search Memory: {leann_results['peak_memory']:.1f} MB")
|
||||||
|
print(f"📊 Storage Size: {leann_storage_size:.1f} MB")
|
||||||
|
if "error" not in faiss_results:
|
||||||
|
print("\n✅ Faiss HNSW completed successfully!")
|
||||||
|
print(f"📊 Search Memory: {faiss_results['peak_memory']:.1f} MB")
|
||||||
|
print(f"📊 Storage Size: {faiss_storage_size:.1f} MB")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
82
benchmarks/data/.gitattributes
vendored
Normal file
82
benchmarks/data/.gitattributes
vendored
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
*.7z filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.arrow filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.bin filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.ftz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.gz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.h5 filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.joblib filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.lz4 filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.mds filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.model filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.npy filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.npz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.onnx filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.ot filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.parquet filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pb filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pickle filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pkl filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pt filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pth filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.rar filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
||||||
|
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.tar filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.tflite filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.tgz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.wasm filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.xz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.zip filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.zst filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
||||||
|
# Audio files - uncompressed
|
||||||
|
*.pcm filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.sam filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.raw filter=lfs diff=lfs merge=lfs -text
|
||||||
|
# Audio files - compressed
|
||||||
|
*.aac filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.flac filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.mp3 filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.ogg filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.wav filter=lfs diff=lfs merge=lfs -text
|
||||||
|
# Image files - uncompressed
|
||||||
|
*.bmp filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.gif filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.png filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.tiff filter=lfs diff=lfs merge=lfs -text
|
||||||
|
# Image files - compressed
|
||||||
|
*.jpg filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.webp filter=lfs diff=lfs merge=lfs -text
|
||||||
|
# Video files - compressed
|
||||||
|
*.mp4 filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.webm filter=lfs diff=lfs merge=lfs -text
|
||||||
|
ground_truth/dpr/id_map.json filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/dpr/dpr_diskann.passages.idx filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/dpr/dpr_diskann.passages.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/dpr/dpr_diskann_disk.index filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/dpr/leann.labels.map filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/leann.labels.map filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.index filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.0.idx filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.0.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.1.idx filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.1.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.2.idx filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.2.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.3.idx filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.3.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.4.idx filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.4.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.5.idx filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.5.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.6.idx filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.6.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.7.idx filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.7.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||||
151
benchmarks/faiss_only.py
Normal file
151
benchmarks/faiss_only.py
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Test only Faiss HNSW"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
|
||||||
|
import psutil
|
||||||
|
|
||||||
|
|
||||||
|
def get_memory_usage():
|
||||||
|
process = psutil.Process()
|
||||||
|
return process.memory_info().rss / 1024 / 1024
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryTracker:
|
||||||
|
def __init__(self, name: str):
|
||||||
|
self.name = name
|
||||||
|
self.start_mem = get_memory_usage()
|
||||||
|
self.stages = []
|
||||||
|
|
||||||
|
def checkpoint(self, stage: str):
|
||||||
|
current_mem = get_memory_usage()
|
||||||
|
diff = current_mem - self.start_mem
|
||||||
|
print(f"[{self.name} - {stage}] Memory: {current_mem:.1f} MB (+{diff:.1f} MB)")
|
||||||
|
self.stages.append((stage, current_mem))
|
||||||
|
return current_mem
|
||||||
|
|
||||||
|
def summary(self):
|
||||||
|
peak_mem = max(mem for _, mem in self.stages)
|
||||||
|
print(f"Peak Memory: {peak_mem:.1f} MB")
|
||||||
|
return peak_mem
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
try:
|
||||||
|
import faiss
|
||||||
|
except ImportError:
|
||||||
|
print("Faiss is not installed.")
|
||||||
|
print(
|
||||||
|
"Please install it with `uv pip install faiss-cpu` and you can then run this script again"
|
||||||
|
)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
from llama_index.core import (
|
||||||
|
Settings,
|
||||||
|
SimpleDirectoryReader,
|
||||||
|
StorageContext,
|
||||||
|
VectorStoreIndex,
|
||||||
|
)
|
||||||
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
||||||
|
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||||
|
|
||||||
|
tracker = MemoryTracker("Faiss HNSW")
|
||||||
|
tracker.checkpoint("Initial")
|
||||||
|
|
||||||
|
embed_model = HuggingFaceEmbedding(model_name="facebook/contriever")
|
||||||
|
Settings.embed_model = embed_model
|
||||||
|
tracker.checkpoint("After embedding model setup")
|
||||||
|
|
||||||
|
d = 768
|
||||||
|
faiss_index = faiss.IndexHNSWFlat(d, 32)
|
||||||
|
faiss_index.hnsw.efConstruction = 64
|
||||||
|
tracker.checkpoint("After Faiss index creation")
|
||||||
|
|
||||||
|
documents = SimpleDirectoryReader(
|
||||||
|
"data",
|
||||||
|
recursive=True,
|
||||||
|
encoding="utf-8",
|
||||||
|
required_exts=[".pdf", ".txt", ".md"],
|
||||||
|
).load_data()
|
||||||
|
tracker.checkpoint("After document loading")
|
||||||
|
|
||||||
|
# Parse into chunks using the same splitter as LEANN
|
||||||
|
node_parser = SentenceSplitter(
|
||||||
|
chunk_size=256, chunk_overlap=20, separator=" ", paragraph_separator="\n\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
tracker.checkpoint("After text splitter setup")
|
||||||
|
|
||||||
|
# Check if index already exists and try to load it
|
||||||
|
index_loaded = False
|
||||||
|
if os.path.exists("./storage_faiss"):
|
||||||
|
print("Loading existing Faiss HNSW index...")
|
||||||
|
try:
|
||||||
|
# Use the correct Faiss loading pattern from the example
|
||||||
|
vector_store = FaissVectorStore.from_persist_dir("./storage_faiss")
|
||||||
|
storage_context = StorageContext.from_defaults(
|
||||||
|
vector_store=vector_store, persist_dir="./storage_faiss"
|
||||||
|
)
|
||||||
|
from llama_index.core import load_index_from_storage
|
||||||
|
|
||||||
|
index = load_index_from_storage(storage_context=storage_context)
|
||||||
|
print("Index loaded from ./storage_faiss")
|
||||||
|
tracker.checkpoint("After loading existing index")
|
||||||
|
index_loaded = True
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to load existing index: {e}")
|
||||||
|
print("Cleaning up corrupted index and building new one...")
|
||||||
|
# Clean up corrupted index
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
if os.path.exists("./storage_faiss"):
|
||||||
|
shutil.rmtree("./storage_faiss")
|
||||||
|
|
||||||
|
if not index_loaded:
|
||||||
|
print("Building new Faiss HNSW index...")
|
||||||
|
|
||||||
|
# Use the correct Faiss building pattern from the example
|
||||||
|
vector_store = FaissVectorStore(faiss_index=faiss_index)
|
||||||
|
storage_context = StorageContext.from_defaults(vector_store=vector_store)
|
||||||
|
index = VectorStoreIndex.from_documents(
|
||||||
|
documents, storage_context=storage_context, transformations=[node_parser]
|
||||||
|
)
|
||||||
|
tracker.checkpoint("After index building")
|
||||||
|
|
||||||
|
# Save index to disk using the correct pattern
|
||||||
|
index.storage_context.persist(persist_dir="./storage_faiss")
|
||||||
|
tracker.checkpoint("After index saving")
|
||||||
|
|
||||||
|
# Measure runtime memory overhead
|
||||||
|
print("\nMeasuring runtime memory overhead...")
|
||||||
|
runtime_start_mem = get_memory_usage()
|
||||||
|
print(f"Before load memory: {runtime_start_mem:.1f} MB")
|
||||||
|
tracker.checkpoint("Before load memory")
|
||||||
|
|
||||||
|
query_engine = index.as_query_engine(similarity_top_k=20)
|
||||||
|
queries = [
|
||||||
|
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
|
||||||
|
"What is LEANN and how does it work?",
|
||||||
|
"华为诺亚方舟实验室的主要研究内容",
|
||||||
|
]
|
||||||
|
|
||||||
|
for i, query in enumerate(queries):
|
||||||
|
start_time = time.time()
|
||||||
|
_ = query_engine.query(query)
|
||||||
|
query_time = time.time() - start_time
|
||||||
|
print(f"Query {i + 1} time: {query_time:.3f}s")
|
||||||
|
tracker.checkpoint(f"After query {i + 1}")
|
||||||
|
|
||||||
|
runtime_end_mem = get_memory_usage()
|
||||||
|
runtime_overhead = runtime_end_mem - runtime_start_mem
|
||||||
|
|
||||||
|
peak_memory = tracker.summary()
|
||||||
|
print(f"Peak Memory: {peak_memory:.1f} MB")
|
||||||
|
print(f"Runtime Memory Overhead: {runtime_overhead:.1f} MB")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
561
research/micro/embedd_micro.py → benchmarks/micro_tpt.py
Executable file → Normal file
561
research/micro/embedd_micro.py → benchmarks/micro_tpt.py
Executable file → Normal file
@@ -2,21 +2,20 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import time
|
import time
|
||||||
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, List, Optional, Tuple
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torchao import quantize_
|
|
||||||
from transformers import AutoModel, BitsAndBytesConfig
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from contextlib import contextmanager
|
from transformers import AutoModel, BitsAndBytesConfig
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BenchmarkConfig:
|
class BenchmarkConfig:
|
||||||
model_path: str
|
model_path: str
|
||||||
batch_sizes: List[int]
|
batch_sizes: list[int]
|
||||||
seq_length: int
|
seq_length: int
|
||||||
num_runs: int
|
num_runs: int
|
||||||
use_fp16: bool = True
|
use_fp16: bool = True
|
||||||
@@ -27,173 +26,223 @@ class BenchmarkConfig:
|
|||||||
use_linear8bitlt: bool = False
|
use_linear8bitlt: bool = False
|
||||||
|
|
||||||
|
|
||||||
class CUDAGraphContainer:
|
class GraphContainer:
|
||||||
"""Container for managing CUDA graphs for different batch sizes."""
|
"""Container for managing graphs for different batch sizes (CUDA graphs on NVIDIA, regular on others)."""
|
||||||
|
|
||||||
def __init__(self, model: nn.Module, seq_length: int):
|
def __init__(self, model: nn.Module, seq_length: int):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.seq_length = seq_length
|
self.seq_length = seq_length
|
||||||
self.graphs: Dict[int, CUDAGraphWrapper] = {}
|
self.graphs: dict[int, GraphWrapper] = {}
|
||||||
|
|
||||||
def get_or_create(self, batch_size: int) -> 'CUDAGraphWrapper':
|
def get_or_create(self, batch_size: int) -> "GraphWrapper":
|
||||||
if batch_size not in self.graphs:
|
if batch_size not in self.graphs:
|
||||||
self.graphs[batch_size] = CUDAGraphWrapper(
|
self.graphs[batch_size] = GraphWrapper(self.model, batch_size, self.seq_length)
|
||||||
self.model, batch_size, self.seq_length
|
|
||||||
)
|
|
||||||
return self.graphs[batch_size]
|
return self.graphs[batch_size]
|
||||||
|
|
||||||
|
|
||||||
class CUDAGraphWrapper:
|
class GraphWrapper:
|
||||||
"""Wrapper for CUDA graph capture and replay."""
|
"""Wrapper for graph capture and replay (CUDA graphs on NVIDIA, regular on others)."""
|
||||||
|
|
||||||
def __init__(self, model: nn.Module, batch_size: int, seq_length: int):
|
def __init__(self, model: nn.Module, batch_size: int, seq_length: int):
|
||||||
self.model = model
|
self.model = model
|
||||||
|
self.device = self._get_device()
|
||||||
self.static_input = self._create_random_batch(batch_size, seq_length)
|
self.static_input = self._create_random_batch(batch_size, seq_length)
|
||||||
self.static_attention_mask = torch.ones_like(self.static_input)
|
self.static_attention_mask = torch.ones_like(self.static_input)
|
||||||
|
|
||||||
# Warm up
|
# Warm up
|
||||||
self._warmup()
|
self._warmup()
|
||||||
|
|
||||||
# Capture graph
|
# Only use CUDA graphs on NVIDIA GPUs
|
||||||
self.graph = torch.cuda.CUDAGraph()
|
if torch.cuda.is_available() and hasattr(torch.cuda, "CUDAGraph"):
|
||||||
with torch.cuda.graph(self.graph):
|
# Capture graph
|
||||||
self.static_output = self.model(
|
self.graph = torch.cuda.CUDAGraph()
|
||||||
input_ids=self.static_input,
|
with torch.cuda.graph(self.graph):
|
||||||
attention_mask=self.static_attention_mask
|
self.static_output = self.model(
|
||||||
)
|
input_ids=self.static_input,
|
||||||
|
attention_mask=self.static_attention_mask,
|
||||||
|
)
|
||||||
|
self.use_cuda_graph = True
|
||||||
|
else:
|
||||||
|
# For MPS or CPU, just store the model
|
||||||
|
self.use_cuda_graph = False
|
||||||
|
self.static_output = None
|
||||||
|
|
||||||
|
def _get_device(self) -> str:
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
return "cuda"
|
||||||
|
elif torch.backends.mps.is_available():
|
||||||
|
return "mps"
|
||||||
|
else:
|
||||||
|
return "cpu"
|
||||||
|
|
||||||
def _create_random_batch(self, batch_size: int, seq_length: int) -> torch.Tensor:
|
def _create_random_batch(self, batch_size: int, seq_length: int) -> torch.Tensor:
|
||||||
return torch.randint(
|
return torch.randint(
|
||||||
0, 1000, (batch_size, seq_length),
|
0, 1000, (batch_size, seq_length), device=self.device, dtype=torch.long
|
||||||
device="cuda",
|
|
||||||
dtype=torch.long
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _warmup(self, num_warmup: int = 3):
|
def _warmup(self, num_warmup: int = 3):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for _ in range(num_warmup):
|
for _ in range(num_warmup):
|
||||||
self.model(
|
self.model(
|
||||||
input_ids=self.static_input,
|
input_ids=self.static_input,
|
||||||
attention_mask=self.static_attention_mask
|
attention_mask=self.static_attention_mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
|
def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
|
||||||
self.static_input.copy_(input_ids)
|
if self.use_cuda_graph:
|
||||||
self.static_attention_mask.copy_(attention_mask)
|
self.static_input.copy_(input_ids)
|
||||||
self.graph.replay()
|
self.static_attention_mask.copy_(attention_mask)
|
||||||
return self.static_output
|
self.graph.replay()
|
||||||
|
return self.static_output
|
||||||
|
else:
|
||||||
|
# For MPS/CPU, just run normally
|
||||||
|
return self.model(input_ids=input_ids, attention_mask=attention_mask)
|
||||||
|
|
||||||
|
|
||||||
class ModelOptimizer:
|
class ModelOptimizer:
|
||||||
"""Applies various optimizations to the model."""
|
"""Applies various optimizations to the model."""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def optimize(model: nn.Module, config: BenchmarkConfig) -> nn.Module:
|
def optimize(model: nn.Module, config: BenchmarkConfig) -> nn.Module:
|
||||||
print("\nApplying model optimizations:")
|
print("\nApplying model optimizations:")
|
||||||
|
|
||||||
if model is None:
|
if model is None:
|
||||||
raise ValueError("Cannot optimize None model")
|
raise ValueError("Cannot optimize None model")
|
||||||
|
|
||||||
# Move to GPU
|
# Move to GPU
|
||||||
model = model.cuda()
|
if torch.cuda.is_available():
|
||||||
print("- Model moved to GPU")
|
model = model.cuda()
|
||||||
|
device = "cuda"
|
||||||
|
elif torch.backends.mps.is_available():
|
||||||
|
model = model.to("mps")
|
||||||
|
device = "mps"
|
||||||
|
else:
|
||||||
|
model = model.cpu()
|
||||||
|
device = "cpu"
|
||||||
|
print(f"- Model moved to {device}")
|
||||||
|
|
||||||
# FP16
|
# FP16
|
||||||
if config.use_fp16 and not config.use_int4:
|
if config.use_fp16 and not config.use_int4:
|
||||||
model = model.half()
|
model = model.half()
|
||||||
# use torch compile
|
# use torch compile
|
||||||
model = torch.compile(model)
|
model = torch.compile(model)
|
||||||
print("- Using FP16 precision")
|
print("- Using FP16 precision")
|
||||||
|
|
||||||
# Check if using SDPA
|
# Check if using SDPA (only on CUDA)
|
||||||
if torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6:
|
if (
|
||||||
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
|
torch.cuda.is_available()
|
||||||
|
and torch.version.cuda
|
||||||
|
and float(torch.version.cuda[:3]) >= 11.6
|
||||||
|
):
|
||||||
|
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
|
||||||
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
|
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
|
||||||
else:
|
else:
|
||||||
print("- PyTorch SDPA not available")
|
print("- PyTorch SDPA not available")
|
||||||
|
|
||||||
# Flash Attention
|
# Flash Attention (only on CUDA)
|
||||||
if config.use_flash_attention:
|
if config.use_flash_attention and torch.cuda.is_available():
|
||||||
try:
|
try:
|
||||||
from flash_attn.flash_attention import FlashAttention
|
from flash_attn.flash_attention import FlashAttention # noqa: F401
|
||||||
|
|
||||||
print("- Flash Attention 2 available")
|
print("- Flash Attention 2 available")
|
||||||
if hasattr(model.config, "attention_mode"):
|
if hasattr(model.config, "attention_mode"):
|
||||||
model.config.attention_mode = "flash_attention_2"
|
model.config.attention_mode = "flash_attention_2"
|
||||||
print(" - Enabled Flash Attention 2 mode")
|
print(" - Enabled Flash Attention 2 mode")
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print("- Flash Attention not available")
|
print("- Flash Attention not available")
|
||||||
|
|
||||||
# Memory efficient attention
|
# Memory efficient attention (only on CUDA)
|
||||||
try:
|
if torch.cuda.is_available():
|
||||||
from xformers.ops import memory_efficient_attention
|
try:
|
||||||
if hasattr(model, 'enable_xformers_memory_efficient_attention'):
|
from xformers.ops import memory_efficient_attention # noqa: F401
|
||||||
model.enable_xformers_memory_efficient_attention()
|
|
||||||
print("- Enabled xformers memory efficient attention")
|
if hasattr(model, "enable_xformers_memory_efficient_attention"):
|
||||||
else:
|
model.enable_xformers_memory_efficient_attention()
|
||||||
print("- Model doesn't support xformers")
|
print("- Enabled xformers memory efficient attention")
|
||||||
except (ImportError, AttributeError):
|
else:
|
||||||
print("- Xformers not available")
|
print("- Model doesn't support xformers")
|
||||||
|
except (ImportError, AttributeError):
|
||||||
|
print("- Xformers not available")
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
print("- Model set to eval mode")
|
print("- Model set to eval mode")
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
class Timer:
|
class Timer:
|
||||||
"""Handles accurate GPU timing using CUDA events."""
|
"""Handles accurate GPU timing using GPU events or CPU timing."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.start_event = torch.cuda.Event(enable_timing=True)
|
if torch.cuda.is_available():
|
||||||
self.end_event = torch.cuda.Event(enable_timing=True)
|
self.start_event = torch.cuda.Event(enable_timing=True)
|
||||||
|
self.end_event = torch.cuda.Event(enable_timing=True)
|
||||||
|
self.use_gpu_timing = True
|
||||||
|
elif torch.backends.mps.is_available():
|
||||||
|
# MPS doesn't have events, use CPU timing
|
||||||
|
self.use_gpu_timing = False
|
||||||
|
else:
|
||||||
|
# CPU timing
|
||||||
|
self.use_gpu_timing = False
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def timing(self):
|
def timing(self):
|
||||||
self.start_event.record()
|
if self.use_gpu_timing:
|
||||||
yield
|
self.start_event.record()
|
||||||
self.end_event.record()
|
yield
|
||||||
self.end_event.synchronize()
|
self.end_event.record()
|
||||||
|
self.end_event.synchronize()
|
||||||
|
else:
|
||||||
|
# Use CPU timing for MPS/CPU
|
||||||
|
start_time = time.time()
|
||||||
|
yield
|
||||||
|
self.cpu_elapsed = time.time() - start_time
|
||||||
|
|
||||||
def elapsed_time(self) -> float:
|
def elapsed_time(self) -> float:
|
||||||
return self.start_event.elapsed_time(self.end_event) / 1000 # ms to seconds
|
if self.use_gpu_timing:
|
||||||
|
return self.start_event.elapsed_time(self.end_event) / 1000 # ms to seconds
|
||||||
|
else:
|
||||||
|
return self.cpu_elapsed
|
||||||
|
|
||||||
|
|
||||||
class Benchmark:
|
class Benchmark:
|
||||||
"""Main benchmark runner."""
|
"""Main benchmark runner."""
|
||||||
|
|
||||||
def __init__(self, config: BenchmarkConfig):
|
def __init__(self, config: BenchmarkConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
try:
|
try:
|
||||||
self.model = self._load_model()
|
self.model = self._load_model()
|
||||||
if self.model is None:
|
if self.model is None:
|
||||||
raise ValueError("Model initialization failed - model is None")
|
raise ValueError("Model initialization failed - model is None")
|
||||||
|
|
||||||
self.cuda_graphs = (
|
# Only use CUDA graphs on NVIDIA GPUs
|
||||||
CUDAGraphContainer(self.model, config.seq_length)
|
if config.use_cuda_graphs and torch.cuda.is_available():
|
||||||
if config.use_cuda_graphs
|
self.graphs = GraphContainer(self.model, config.seq_length)
|
||||||
else None
|
else:
|
||||||
)
|
self.graphs = None
|
||||||
self.timer = Timer()
|
self.timer = Timer()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"ERROR in benchmark initialization: {str(e)}")
|
print(f"ERROR in benchmark initialization: {e!s}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _load_model(self) -> nn.Module:
|
def _load_model(self) -> nn.Module:
|
||||||
print(f"Loading model from {self.config.model_path}...")
|
print(f"Loading model from {self.config.model_path}...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Int4 quantization using HuggingFace integration
|
# Int4 quantization using HuggingFace integration
|
||||||
if self.config.use_int4:
|
if self.config.use_int4:
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
|
|
||||||
print(f"- bitsandbytes version: {bnb.__version__}")
|
print(f"- bitsandbytes version: {bnb.__version__}")
|
||||||
|
|
||||||
# 检查是否使用自定义的8bit量化
|
# Check if using custom 8bit quantization
|
||||||
if hasattr(self.config, 'use_linear8bitlt') and self.config.use_linear8bitlt:
|
if hasattr(self.config, "use_linear8bitlt") and self.config.use_linear8bitlt:
|
||||||
print("- Using custom Linear8bitLt replacement for all linear layers")
|
print("- Using custom Linear8bitLt replacement for all linear layers")
|
||||||
|
|
||||||
# 加载原始模型(不使用量化配置)
|
# Load original model (without quantization config)
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
# set default to half
|
# set default to half
|
||||||
torch.set_default_dtype(torch.float16)
|
torch.set_default_dtype(torch.float16)
|
||||||
compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
|
compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
|
||||||
@@ -201,276 +250,281 @@ class Benchmark:
|
|||||||
self.config.model_path,
|
self.config.model_path,
|
||||||
torch_dtype=compute_dtype,
|
torch_dtype=compute_dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 定义替换函数
|
# Define replacement function
|
||||||
def replace_linear_with_linear8bitlt(model):
|
def replace_linear_with_linear8bitlt(model):
|
||||||
"""递归地将模型中的所有nn.Linear层替换为Linear8bitLt"""
|
"""Recursively replace all nn.Linear layers with Linear8bitLt"""
|
||||||
for name, module in list(model.named_children()):
|
for name, module in list(model.named_children()):
|
||||||
if isinstance(module, nn.Linear):
|
if isinstance(module, nn.Linear):
|
||||||
# 获取原始线性层的参数
|
# Get original linear layer parameters
|
||||||
in_features = module.in_features
|
in_features = module.in_features
|
||||||
out_features = module.out_features
|
out_features = module.out_features
|
||||||
bias = module.bias is not None
|
bias = module.bias is not None
|
||||||
|
|
||||||
# 创建8bit线性层
|
# Create 8bit linear layer
|
||||||
# print size
|
# print size
|
||||||
print(f"in_features: {in_features}, out_features: {out_features}")
|
print(f"in_features: {in_features}, out_features: {out_features}")
|
||||||
new_module = bnb.nn.Linear8bitLt(
|
new_module = bnb.nn.Linear8bitLt(
|
||||||
in_features,
|
in_features,
|
||||||
out_features,
|
out_features,
|
||||||
bias=bias,
|
bias=bias,
|
||||||
has_fp16_weights=False
|
has_fp16_weights=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 复制权重和偏置
|
# Copy weights and bias
|
||||||
new_module.weight.data = module.weight.data
|
new_module.weight.data = module.weight.data
|
||||||
if bias:
|
if bias:
|
||||||
new_module.bias.data = module.bias.data
|
new_module.bias.data = module.bias.data
|
||||||
|
|
||||||
# 替换模块
|
# Replace module
|
||||||
setattr(model, name, new_module)
|
setattr(model, name, new_module)
|
||||||
else:
|
else:
|
||||||
# 递归处理子模块
|
# Process child modules recursively
|
||||||
replace_linear_with_linear8bitlt(module)
|
replace_linear_with_linear8bitlt(module)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
# 替换所有线性层
|
# Replace all linear layers
|
||||||
model = replace_linear_with_linear8bitlt(model)
|
model = replace_linear_with_linear8bitlt(model)
|
||||||
# add torch compile
|
# add torch compile
|
||||||
model = torch.compile(model)
|
model = torch.compile(model)
|
||||||
|
|
||||||
# 将模型移到GPU(量化发生在这里)
|
# Move model to GPU (quantization happens here)
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
device = (
|
||||||
|
"cuda"
|
||||||
|
if torch.cuda.is_available()
|
||||||
|
else "mps"
|
||||||
|
if torch.backends.mps.is_available()
|
||||||
|
else "cpu"
|
||||||
|
)
|
||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
|
|
||||||
print("- All linear layers replaced with Linear8bitLt")
|
print("- All linear layers replaced with Linear8bitLt")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# 使用原来的Int4量化方法
|
# Use original Int4 quantization method
|
||||||
print("- Using bitsandbytes for Int4 quantization")
|
print("- Using bitsandbytes for Int4 quantization")
|
||||||
|
|
||||||
# Create quantization config
|
# Create quantization config
|
||||||
|
|
||||||
compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
|
compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
|
||||||
quantization_config = BitsAndBytesConfig(
|
quantization_config = BitsAndBytesConfig(
|
||||||
load_in_4bit=True,
|
load_in_4bit=True,
|
||||||
bnb_4bit_compute_dtype=compute_dtype,
|
bnb_4bit_compute_dtype=compute_dtype,
|
||||||
bnb_4bit_use_double_quant=True,
|
bnb_4bit_use_double_quant=True,
|
||||||
bnb_4bit_quant_type="nf4"
|
bnb_4bit_quant_type="nf4",
|
||||||
)
|
)
|
||||||
|
|
||||||
print("- Quantization config:", quantization_config)
|
print("- Quantization config:", quantization_config)
|
||||||
|
|
||||||
# Load model directly with quantization config
|
# Load model directly with quantization config
|
||||||
model = AutoModel.from_pretrained(
|
model = AutoModel.from_pretrained(
|
||||||
self.config.model_path,
|
self.config.model_path,
|
||||||
quantization_config=quantization_config,
|
quantization_config=quantization_config,
|
||||||
torch_dtype=compute_dtype,
|
torch_dtype=compute_dtype,
|
||||||
device_map="auto" # Let HF decide on device mapping
|
device_map="auto", # Let HF decide on device mapping
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if model loaded successfully
|
# Check if model loaded successfully
|
||||||
if model is None:
|
if model is None:
|
||||||
raise ValueError("Model loading returned None")
|
raise ValueError("Model loading returned None")
|
||||||
|
|
||||||
print(f"- Model type: {type(model)}")
|
print(f"- Model type: {type(model)}")
|
||||||
|
|
||||||
# Apply optimizations directly here
|
# Apply optimizations directly here
|
||||||
print("\nApplying model optimizations:")
|
print("\nApplying model optimizations:")
|
||||||
|
|
||||||
if hasattr(self.config, 'use_linear8bitlt') and self.config.use_linear8bitlt:
|
if hasattr(self.config, "use_linear8bitlt") and self.config.use_linear8bitlt:
|
||||||
print("- Model moved to GPU with Linear8bitLt quantization")
|
print("- Model moved to GPU with Linear8bitLt quantization")
|
||||||
else:
|
else:
|
||||||
# Skip moving to GPU since device_map="auto" already did that
|
# Skip moving to GPU since device_map="auto" already did that
|
||||||
print("- Model already on GPU due to device_map='auto'")
|
print("- Model already on GPU due to device_map='auto'")
|
||||||
|
|
||||||
# Skip FP16 conversion since we specified compute_dtype
|
# Skip FP16 conversion since we specified compute_dtype
|
||||||
print(f"- Using {compute_dtype} for compute dtype")
|
print(f"- Using {compute_dtype} for compute dtype")
|
||||||
|
|
||||||
# Check CUDA and SDPA
|
# Check CUDA and SDPA
|
||||||
if torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6:
|
if (
|
||||||
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
|
torch.cuda.is_available()
|
||||||
|
and torch.version.cuda
|
||||||
|
and float(torch.version.cuda[:3]) >= 11.6
|
||||||
|
):
|
||||||
|
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
|
||||||
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
|
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
|
||||||
else:
|
else:
|
||||||
print("- PyTorch SDPA not available")
|
print("- PyTorch SDPA not available")
|
||||||
|
|
||||||
# Try xformers if available
|
# Try xformers if available (only on CUDA)
|
||||||
try:
|
if torch.cuda.is_available():
|
||||||
from xformers.ops import memory_efficient_attention
|
try:
|
||||||
if hasattr(model, 'enable_xformers_memory_efficient_attention'):
|
if hasattr(model, "enable_xformers_memory_efficient_attention"):
|
||||||
model.enable_xformers_memory_efficient_attention()
|
model.enable_xformers_memory_efficient_attention()
|
||||||
print("- Enabled xformers memory efficient attention")
|
print("- Enabled xformers memory efficient attention")
|
||||||
else:
|
else:
|
||||||
print("- Model doesn't support xformers")
|
print("- Model doesn't support xformers")
|
||||||
except (ImportError, AttributeError):
|
except (ImportError, AttributeError):
|
||||||
print("- Xformers not available")
|
print("- Xformers not available")
|
||||||
|
|
||||||
# Set to eval mode
|
# Set to eval mode
|
||||||
model.eval()
|
model.eval()
|
||||||
print("- Model set to eval mode")
|
print("- Model set to eval mode")
|
||||||
# Int8 quantization using HuggingFace integration
|
# Int8 quantization using HuggingFace integration
|
||||||
# Int8 quantization using TorchAO
|
|
||||||
elif self.config.use_int8:
|
elif self.config.use_int8:
|
||||||
print("- Using TorchAO for Int8 dynamic activation and Int8 weight quantization")
|
print("- Using INT8 quantization")
|
||||||
|
# For now, just use standard loading with INT8 config
|
||||||
# Import the quantize_ function and the quantization config
|
compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
|
||||||
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
|
quantization_config = BitsAndBytesConfig(
|
||||||
print("- Successfully imported TorchAO")
|
load_in_8bit=True,
|
||||||
|
llm_int8_threshold=6.0,
|
||||||
# Load model normally first
|
llm_int8_has_fp16_weight=False,
|
||||||
# set default to half
|
)
|
||||||
import torch
|
|
||||||
torch.set_default_dtype(torch.bfloat16)
|
|
||||||
model = AutoModel.from_pretrained(
|
model = AutoModel.from_pretrained(
|
||||||
self.config.model_path,
|
self.config.model_path,
|
||||||
device_map="auto"
|
quantization_config=quantization_config,
|
||||||
|
torch_dtype=compute_dtype,
|
||||||
|
device_map="auto",
|
||||||
)
|
)
|
||||||
|
|
||||||
print("- Model loaded in full precision")
|
if model is None:
|
||||||
|
raise ValueError("Model loading returned None")
|
||||||
|
|
||||||
print(f"- Model type: {type(model)}")
|
print(f"- Model type: {type(model)}")
|
||||||
|
|
||||||
# Apply quantization - call the function to get the config, then apply it
|
|
||||||
# quantize_(model, int8_dynamic_activation_int8_weight())
|
|
||||||
# from torchao.quantization import quantize_, Int8DynamicActivationInt8WeightConfig,int8_dynamic_activation_int8_semi_sparse_weight,int4_weight_only,Int8DynActInt4WeightGPTQQuantizer,int8_dynamic_activation_int4_weight,Int8DynamicActivationInt4WeightConfig,Int4DynamicActivationInt4WeightConfig
|
|
||||||
from torchao.quantization import quantize_, Int8DynamicActivationInt8WeightConfig
|
|
||||||
quantize_(model, Int8DynamicActivationInt8WeightConfig())
|
|
||||||
print("- Model successfully quantized with int8 weights and int8 activations")
|
|
||||||
# add torch compile
|
|
||||||
model = torch.compile(model)
|
|
||||||
# For older PyTorch versions that have issues with tensor subclasses
|
|
||||||
from torchao.utils import unwrap_tensor_subclass
|
|
||||||
import torch
|
|
||||||
if hasattr(torch, '_version') and not torch.version >= "2.5.0":
|
|
||||||
print("- Unwrapping tensor subclasses for compatibility with older PyTorch")
|
|
||||||
unwrap_tensor_subclass(model)
|
|
||||||
|
|
||||||
# Apply optimizations
|
|
||||||
if torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6:
|
|
||||||
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
|
|
||||||
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
|
|
||||||
else:
|
|
||||||
print("- PyTorch SDPA not available")
|
|
||||||
|
|
||||||
# Set to eval mode
|
|
||||||
model.eval()
|
model.eval()
|
||||||
print("- Model set to eval mode")
|
print("- Model set to eval mode")
|
||||||
|
|
||||||
# For better performance with int8 dynamic quantization
|
|
||||||
torch._inductor.config.force_fuse_int_mm_with_mul = True
|
|
||||||
print("- Enabled fusion of int matmul with mul operations")
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Standard loading for FP16/FP32
|
# Standard loading for FP16/FP32
|
||||||
model = AutoModel.from_pretrained(self.config.model_path)
|
model = AutoModel.from_pretrained(self.config.model_path)
|
||||||
print("- Model loaded in standard precision")
|
print("- Model loaded in standard precision")
|
||||||
print(f"- Model type: {type(model)}")
|
print(f"- Model type: {type(model)}")
|
||||||
|
|
||||||
# Apply standard optimizations
|
# Apply standard optimizations
|
||||||
# set default to half
|
# set default to half
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
torch.set_default_dtype(torch.bfloat16)
|
torch.set_default_dtype(torch.bfloat16)
|
||||||
model = ModelOptimizer.optimize(model, self.config)
|
model = ModelOptimizer.optimize(model, self.config)
|
||||||
model = model.half()
|
model = model.half()
|
||||||
# add torch compile
|
# add torch compile
|
||||||
model = torch.compile(model)
|
model = torch.compile(model)
|
||||||
|
|
||||||
# Final check to ensure model is not None
|
# Final check to ensure model is not None
|
||||||
if model is None:
|
if model is None:
|
||||||
raise ValueError("Model is None after optimization")
|
raise ValueError("Model is None after optimization")
|
||||||
|
|
||||||
print(f"- Final model type: {type(model)}")
|
print(f"- Final model type: {type(model)}")
|
||||||
return model
|
return model
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"ERROR loading model: {str(e)}")
|
print(f"ERROR loading model: {e!s}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _create_random_batch(self, batch_size: int) -> torch.Tensor:
|
def _create_random_batch(self, batch_size: int) -> torch.Tensor:
|
||||||
return torch.randint(
|
device = (
|
||||||
0, 1000,
|
"cuda"
|
||||||
(batch_size, self.config.seq_length),
|
if torch.cuda.is_available()
|
||||||
device="cuda",
|
else "mps"
|
||||||
dtype=torch.long
|
if torch.backends.mps.is_available()
|
||||||
|
else "cpu"
|
||||||
)
|
)
|
||||||
|
return torch.randint(
|
||||||
|
0,
|
||||||
|
1000,
|
||||||
|
(batch_size, self.config.seq_length),
|
||||||
|
device=device,
|
||||||
|
dtype=torch.long,
|
||||||
|
)
|
||||||
|
|
||||||
def _run_inference(
|
def _run_inference(
|
||||||
self,
|
self, input_ids: torch.Tensor, graph_wrapper: GraphWrapper | None = None
|
||||||
input_ids: torch.Tensor,
|
) -> tuple[float, torch.Tensor]:
|
||||||
cuda_graph_wrapper: Optional[CUDAGraphWrapper] = None
|
|
||||||
) -> Tuple[float, torch.Tensor]:
|
|
||||||
attention_mask = torch.ones_like(input_ids)
|
attention_mask = torch.ones_like(input_ids)
|
||||||
|
|
||||||
with torch.no_grad(), self.timer.timing():
|
with torch.no_grad(), self.timer.timing():
|
||||||
if cuda_graph_wrapper is not None:
|
if graph_wrapper is not None:
|
||||||
output = cuda_graph_wrapper(input_ids, attention_mask)
|
output = graph_wrapper(input_ids, attention_mask)
|
||||||
else:
|
else:
|
||||||
output = self.model(input_ids=input_ids, attention_mask=attention_mask)
|
output = self.model(input_ids=input_ids, attention_mask=attention_mask)
|
||||||
|
|
||||||
return self.timer.elapsed_time(), output
|
return self.timer.elapsed_time(), output
|
||||||
|
|
||||||
def run(self) -> Dict[int, Dict[str, float]]:
|
def run(self) -> dict[int, dict[str, float]]:
|
||||||
results = {}
|
results = {}
|
||||||
|
|
||||||
# Reset peak memory stats
|
# Reset peak memory stats
|
||||||
torch.cuda.reset_peak_memory_stats()
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.reset_peak_memory_stats()
|
||||||
|
elif torch.backends.mps.is_available():
|
||||||
|
# MPS doesn't have reset_peak_memory_stats, skip it
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
print("- No GPU memory stats available")
|
||||||
|
|
||||||
for batch_size in self.config.batch_sizes:
|
for batch_size in self.config.batch_sizes:
|
||||||
print(f"\nTesting batch size: {batch_size}")
|
print(f"\nTesting batch size: {batch_size}")
|
||||||
times = []
|
times = []
|
||||||
|
|
||||||
# Get or create CUDA graph for this batch size
|
# Get or create graph for this batch size
|
||||||
cuda_graph_wrapper = (
|
graph_wrapper = (
|
||||||
self.cuda_graphs.get_or_create(batch_size)
|
self.graphs.get_or_create(batch_size) if self.graphs is not None else None
|
||||||
if self.cuda_graphs is not None
|
|
||||||
else None
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Pre-allocate input tensor
|
# Pre-allocate input tensor
|
||||||
input_ids = self._create_random_batch(batch_size)
|
input_ids = self._create_random_batch(batch_size)
|
||||||
print(f"Input shape: {input_ids.shape}")
|
print(f"Input shape: {input_ids.shape}")
|
||||||
|
|
||||||
# Run benchmark
|
# Run benchmark
|
||||||
for i in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
|
for i in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
|
||||||
try:
|
try:
|
||||||
elapsed_time, output = self._run_inference(input_ids, cuda_graph_wrapper)
|
elapsed_time, output = self._run_inference(input_ids, graph_wrapper)
|
||||||
if i == 0: # Only print on first run
|
if i == 0: # Only print on first run
|
||||||
print(f"Output shape: {output.last_hidden_state.shape}")
|
print(f"Output shape: {output.last_hidden_state.shape}")
|
||||||
times.append(elapsed_time)
|
times.append(elapsed_time)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error during inference: {e}")
|
print(f"Error during inference: {e}")
|
||||||
break
|
break
|
||||||
|
|
||||||
if not times:
|
if not times:
|
||||||
print(f"No successful runs for batch size {batch_size}, skipping")
|
print(f"No successful runs for batch size {batch_size}, skipping")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Calculate statistics
|
# Calculate statistics
|
||||||
avg_time = np.mean(times)
|
avg_time = np.mean(times)
|
||||||
std_time = np.std(times)
|
std_time = np.std(times)
|
||||||
throughput = batch_size / avg_time
|
throughput = batch_size / avg_time
|
||||||
|
|
||||||
results[batch_size] = {
|
results[batch_size] = {
|
||||||
"avg_time": avg_time,
|
"avg_time": avg_time,
|
||||||
"std_time": std_time,
|
"std_time": std_time,
|
||||||
"throughput": throughput,
|
"throughput": throughput,
|
||||||
}
|
}
|
||||||
|
|
||||||
print(f"Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
|
print(f"Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
|
||||||
print(f"Throughput: {throughput:.2f} sequences/second")
|
print(f"Throughput: {throughput:.2f} sequences/second")
|
||||||
|
|
||||||
# Log memory usage
|
# Log memory usage
|
||||||
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024 ** 3)
|
if torch.cuda.is_available():
|
||||||
print(f"\nPeak GPU memory usage: {peak_memory_gb:.2f} GB")
|
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024**3)
|
||||||
|
elif torch.backends.mps.is_available():
|
||||||
|
# MPS doesn't have max_memory_allocated, use 0
|
||||||
|
peak_memory_gb = 0.0
|
||||||
|
else:
|
||||||
|
peak_memory_gb = 0.0
|
||||||
|
print("- No GPU memory usage available")
|
||||||
|
|
||||||
|
if peak_memory_gb > 0:
|
||||||
|
print(f"\nPeak GPU memory usage: {peak_memory_gb:.2f} GB")
|
||||||
|
else:
|
||||||
|
print("\n- GPU memory usage not available")
|
||||||
|
|
||||||
# Add memory info to results
|
# Add memory info to results
|
||||||
for batch_size in results:
|
for batch_size in results:
|
||||||
results[batch_size]["peak_memory_gb"] = peak_memory_gb
|
results[batch_size]["peak_memory_gb"] = peak_memory_gb
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
@@ -485,7 +539,7 @@ def main():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--batch_sizes",
|
"--batch_sizes",
|
||||||
type=str,
|
type=str,
|
||||||
default="1,2,4,8,10,16,20,32,40,64,128,256,512,1024,2048,4096,8192",
|
default="1,2,4,8,16,32",
|
||||||
help="Comma-separated list of batch sizes",
|
help="Comma-separated list of batch sizes",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -518,26 +572,26 @@ def main():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--use_cuda_graphs",
|
"--use_cuda_graphs",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Enable CUDA Graphs optimization",
|
help="Enable CUDA Graphs optimization (only on NVIDIA GPUs)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--use_flash_attention",
|
"--use_flash_attention",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Enable Flash Attention 2 if available",
|
help="Enable Flash Attention 2 if available (only on NVIDIA GPUs)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--use_linear8bitlt",
|
"--use_linear8bitlt",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Enable Linear8bitLt quantization for all linear layers",
|
help="Enable Linear8bitLt quantization for all linear layers",
|
||||||
)
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Print arguments for debugging
|
# Print arguments for debugging
|
||||||
print("\nCommand line arguments:")
|
print("\nCommand line arguments:")
|
||||||
for arg, value in vars(args).items():
|
for arg, value in vars(args).items():
|
||||||
print(f"- {arg}: {value}")
|
print(f"- {arg}: {value}")
|
||||||
|
|
||||||
config = BenchmarkConfig(
|
config = BenchmarkConfig(
|
||||||
model_path=args.model_path,
|
model_path=args.model_path,
|
||||||
batch_sizes=[int(bs) for bs in args.batch_sizes.split(",")],
|
batch_sizes=[int(bs) for bs in args.batch_sizes.split(",")],
|
||||||
@@ -550,45 +604,56 @@ def main():
|
|||||||
use_flash_attention=args.use_flash_attention,
|
use_flash_attention=args.use_flash_attention,
|
||||||
use_linear8bitlt=args.use_linear8bitlt,
|
use_linear8bitlt=args.use_linear8bitlt,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Print configuration for debugging
|
# Print configuration for debugging
|
||||||
print("\nBenchmark configuration:")
|
print("\nBenchmark configuration:")
|
||||||
for field, value in vars(config).items():
|
for field, value in vars(config).items():
|
||||||
print(f"- {field}: {value}")
|
print(f"- {field}: {value}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
benchmark = Benchmark(config)
|
benchmark = Benchmark(config)
|
||||||
results = benchmark.run()
|
results = benchmark.run()
|
||||||
|
|
||||||
# Save results to file
|
# Save results to file
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
|
||||||
# Create results directory if it doesn't exist
|
# Create results directory if it doesn't exist
|
||||||
os.makedirs("results", exist_ok=True)
|
os.makedirs("results", exist_ok=True)
|
||||||
|
|
||||||
# Generate filename based on configuration
|
# Generate filename based on configuration
|
||||||
precision_type = "int4" if config.use_int4 else "fp16" if config.use_fp16 else "fp32"
|
precision_type = (
|
||||||
|
"int4"
|
||||||
|
if config.use_int4
|
||||||
|
else "int8"
|
||||||
|
if config.use_int8
|
||||||
|
else "fp16"
|
||||||
|
if config.use_fp16
|
||||||
|
else "fp32"
|
||||||
|
)
|
||||||
model_name = os.path.basename(config.model_path)
|
model_name = os.path.basename(config.model_path)
|
||||||
output_file = f"results/benchmark_{model_name}_{precision_type}.json"
|
output_file = f"results/benchmark_{model_name}_{precision_type}.json"
|
||||||
|
|
||||||
# Save results
|
# Save results
|
||||||
with open(output_file, "w") as f:
|
with open(output_file, "w") as f:
|
||||||
json.dump(
|
json.dump(
|
||||||
{
|
{
|
||||||
"config": {k: str(v) if isinstance(v, list) else v for k, v in vars(config).items()},
|
"config": {
|
||||||
"results": {str(k): v for k, v in results.items()}
|
k: str(v) if isinstance(v, list) else v for k, v in vars(config).items()
|
||||||
},
|
},
|
||||||
f,
|
"results": {str(k): v for k, v in results.items()},
|
||||||
indent=2
|
},
|
||||||
|
f,
|
||||||
|
indent=2,
|
||||||
)
|
)
|
||||||
print(f"Results saved to {output_file}")
|
print(f"Results saved to {output_file}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Benchmark failed: {e}")
|
print(f"Benchmark failed: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
359
benchmarks/run_evaluation.py
Normal file
359
benchmarks/run_evaluation.py
Normal file
@@ -0,0 +1,359 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
This script runs a recall evaluation on a given LEANN index.
|
||||||
|
It correctly compares results by fetching the text content for both the new search
|
||||||
|
results and the golden standard results, making the comparison robust to ID changes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from leann.api import LeannBuilder, LeannSearcher
|
||||||
|
|
||||||
|
|
||||||
|
def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
|
||||||
|
"""Checks if the data directory exists, and if not, downloads it from HF Hub."""
|
||||||
|
if not data_root.exists():
|
||||||
|
print(f"Data directory '{data_root}' not found.")
|
||||||
|
print("Downloading evaluation data from Hugging Face Hub... (this may take a moment)")
|
||||||
|
try:
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
if download_embeddings:
|
||||||
|
# Download everything including embeddings (large files)
|
||||||
|
snapshot_download(
|
||||||
|
repo_id="LEANN-RAG/leann-rag-evaluation-data",
|
||||||
|
repo_type="dataset",
|
||||||
|
local_dir=data_root,
|
||||||
|
local_dir_use_symlinks=False,
|
||||||
|
)
|
||||||
|
print("Data download complete (including embeddings)!")
|
||||||
|
else:
|
||||||
|
# Download only specific folders, excluding embeddings
|
||||||
|
allow_patterns = [
|
||||||
|
"ground_truth/**",
|
||||||
|
"indices/**",
|
||||||
|
"queries/**",
|
||||||
|
"*.md",
|
||||||
|
"*.txt",
|
||||||
|
]
|
||||||
|
snapshot_download(
|
||||||
|
repo_id="LEANN-RAG/leann-rag-evaluation-data",
|
||||||
|
repo_type="dataset",
|
||||||
|
local_dir=data_root,
|
||||||
|
local_dir_use_symlinks=False,
|
||||||
|
allow_patterns=allow_patterns,
|
||||||
|
)
|
||||||
|
print("Data download complete (excluding embeddings)!")
|
||||||
|
except ImportError:
|
||||||
|
print(
|
||||||
|
"Error: huggingface_hub is not installed. Please install it to download the data:"
|
||||||
|
)
|
||||||
|
print("uv pip install -e '.[dev]'")
|
||||||
|
sys.exit(1)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"An error occurred during data download: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
def download_embeddings_if_needed(data_root: Path, dataset_type: str | None = None):
|
||||||
|
"""Download embeddings files specifically."""
|
||||||
|
embeddings_dir = data_root / "embeddings"
|
||||||
|
|
||||||
|
if dataset_type:
|
||||||
|
# Check if specific dataset embeddings exist
|
||||||
|
target_file = embeddings_dir / dataset_type / "passages_00.pkl"
|
||||||
|
if target_file.exists():
|
||||||
|
print(f"Embeddings for {dataset_type} already exist")
|
||||||
|
return str(target_file)
|
||||||
|
|
||||||
|
print("Downloading embeddings from HuggingFace Hub...")
|
||||||
|
try:
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
# Download only embeddings folder
|
||||||
|
snapshot_download(
|
||||||
|
repo_id="LEANN-RAG/leann-rag-evaluation-data",
|
||||||
|
repo_type="dataset",
|
||||||
|
local_dir=data_root,
|
||||||
|
local_dir_use_symlinks=False,
|
||||||
|
allow_patterns=["embeddings/**/*.pkl"],
|
||||||
|
)
|
||||||
|
print("Embeddings download complete!")
|
||||||
|
|
||||||
|
if dataset_type:
|
||||||
|
target_file = embeddings_dir / dataset_type / "passages_00.pkl"
|
||||||
|
if target_file.exists():
|
||||||
|
return str(target_file)
|
||||||
|
|
||||||
|
return str(embeddings_dir)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error downloading embeddings: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
# --- Helper Function to get Golden Passages ---
|
||||||
|
def get_golden_texts(searcher: LeannSearcher, golden_ids: list[int]) -> set:
|
||||||
|
"""
|
||||||
|
Retrieves the text for golden passage IDs directly from the LeannSearcher's
|
||||||
|
passage manager.
|
||||||
|
"""
|
||||||
|
golden_texts = set()
|
||||||
|
for gid in golden_ids:
|
||||||
|
try:
|
||||||
|
# PassageManager uses string IDs
|
||||||
|
passage_data = searcher.passage_manager.get_passage(str(gid))
|
||||||
|
golden_texts.add(passage_data["text"])
|
||||||
|
except KeyError:
|
||||||
|
print(f"Warning: Golden passage ID '{gid}' not found in the index's passage data.")
|
||||||
|
return golden_texts
|
||||||
|
|
||||||
|
|
||||||
|
def load_queries(file_path: Path) -> list[str]:
|
||||||
|
queries = []
|
||||||
|
with open(file_path, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
data = json.loads(line)
|
||||||
|
queries.append(data["query"])
|
||||||
|
return queries
|
||||||
|
|
||||||
|
|
||||||
|
def build_index_from_embeddings(embeddings_file: str, output_path: str, backend: str = "hnsw"):
|
||||||
|
"""
|
||||||
|
Build a LEANN index from pre-computed embeddings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embeddings_file: Path to pickle file with (ids, embeddings) tuple
|
||||||
|
output_path: Path where to save the index
|
||||||
|
backend: Backend to use ("hnsw" or "diskann")
|
||||||
|
"""
|
||||||
|
print(f"Building {backend} index from embeddings: {embeddings_file}")
|
||||||
|
|
||||||
|
# Create builder with appropriate parameters
|
||||||
|
if backend == "hnsw":
|
||||||
|
builder_kwargs = {
|
||||||
|
"M": 32, # Graph degree
|
||||||
|
"efConstruction": 256, # Construction complexity
|
||||||
|
"is_compact": True, # Use compact storage
|
||||||
|
"is_recompute": True, # Enable pruning for better recall
|
||||||
|
}
|
||||||
|
elif backend == "diskann":
|
||||||
|
builder_kwargs = {
|
||||||
|
"complexity": 64,
|
||||||
|
"graph_degree": 32,
|
||||||
|
"search_memory_maximum": 8.0, # GB
|
||||||
|
"build_memory_maximum": 16.0, # GB
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
builder_kwargs = {}
|
||||||
|
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name=backend,
|
||||||
|
embedding_model="facebook/contriever-msmarco", # Model used to create embeddings
|
||||||
|
dimensions=768, # Will be auto-detected from embeddings
|
||||||
|
**builder_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build index from precomputed embeddings
|
||||||
|
builder.build_index_from_embeddings(output_path, embeddings_file)
|
||||||
|
print(f"Index saved to: {output_path}")
|
||||||
|
return output_path
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="Run recall evaluation on a LEANN index.")
|
||||||
|
parser.add_argument(
|
||||||
|
"index_path",
|
||||||
|
type=str,
|
||||||
|
nargs="?",
|
||||||
|
help="Path to the LEANN index to evaluate or build (optional).",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--mode",
|
||||||
|
choices=["evaluate", "build"],
|
||||||
|
default="evaluate",
|
||||||
|
help="Mode: 'evaluate' existing index or 'build' from embeddings",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--embeddings-file",
|
||||||
|
type=str,
|
||||||
|
help="Path to embeddings pickle file (optional for build mode)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--backend",
|
||||||
|
choices=["hnsw", "diskann"],
|
||||||
|
default="hnsw",
|
||||||
|
help="Backend to use for building index (default: hnsw)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-queries", type=int, default=10, help="Number of queries to evaluate."
|
||||||
|
)
|
||||||
|
parser.add_argument("--top-k", type=int, default=3, help="The 'k' value for recall@k.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW."
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# --- Path Configuration ---
|
||||||
|
# Assumes a project structure where the script is in 'benchmarks/'
|
||||||
|
# and evaluation data is in 'benchmarks/data/'.
|
||||||
|
script_dir = Path(__file__).resolve().parent
|
||||||
|
data_root = script_dir / "data"
|
||||||
|
|
||||||
|
# Download data based on mode
|
||||||
|
if args.mode == "build":
|
||||||
|
# For building mode, we need embeddings
|
||||||
|
download_data_if_needed(data_root, download_embeddings=False) # Basic data first
|
||||||
|
|
||||||
|
# Auto-detect dataset type and download embeddings
|
||||||
|
if args.embeddings_file:
|
||||||
|
embeddings_file = args.embeddings_file
|
||||||
|
# Try to detect dataset type from embeddings file path
|
||||||
|
if "rpj_wiki" in str(embeddings_file):
|
||||||
|
dataset_type = "rpj_wiki"
|
||||||
|
elif "dpr" in str(embeddings_file):
|
||||||
|
dataset_type = "dpr"
|
||||||
|
else:
|
||||||
|
dataset_type = "dpr" # Default
|
||||||
|
else:
|
||||||
|
# Auto-detect from index path if provided, otherwise default to DPR
|
||||||
|
if args.index_path:
|
||||||
|
index_path_str = str(args.index_path)
|
||||||
|
if "rpj_wiki" in index_path_str:
|
||||||
|
dataset_type = "rpj_wiki"
|
||||||
|
elif "dpr" in index_path_str:
|
||||||
|
dataset_type = "dpr"
|
||||||
|
else:
|
||||||
|
dataset_type = "dpr" # Default to DPR
|
||||||
|
else:
|
||||||
|
dataset_type = "dpr" # Default to DPR
|
||||||
|
|
||||||
|
embeddings_file = download_embeddings_if_needed(data_root, dataset_type)
|
||||||
|
|
||||||
|
# Auto-generate index path if not provided
|
||||||
|
if not args.index_path:
|
||||||
|
indices_dir = data_root / "indices" / dataset_type
|
||||||
|
indices_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
args.index_path = str(indices_dir / f"{dataset_type}_from_embeddings")
|
||||||
|
print(f"Auto-generated index path: {args.index_path}")
|
||||||
|
|
||||||
|
print(f"Building index from embeddings: {embeddings_file}")
|
||||||
|
built_index_path = build_index_from_embeddings(
|
||||||
|
embeddings_file, args.index_path, args.backend
|
||||||
|
)
|
||||||
|
print(f"Index built successfully: {built_index_path}")
|
||||||
|
|
||||||
|
# Ask if user wants to run evaluation
|
||||||
|
eval_response = input("Run evaluation on the built index? (y/n): ").strip().lower()
|
||||||
|
if eval_response != "y":
|
||||||
|
print("Index building complete. Exiting.")
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
# For evaluation mode, don't need embeddings
|
||||||
|
download_data_if_needed(data_root, download_embeddings=False)
|
||||||
|
|
||||||
|
# Auto-detect index path if not provided
|
||||||
|
if not args.index_path:
|
||||||
|
# Default to using downloaded indices
|
||||||
|
indices_dir = data_root / "indices"
|
||||||
|
|
||||||
|
# Try common datasets in order of preference
|
||||||
|
for dataset in ["dpr", "rpj_wiki"]:
|
||||||
|
dataset_dir = indices_dir / dataset
|
||||||
|
if dataset_dir.exists():
|
||||||
|
# Look for index files
|
||||||
|
index_files = list(dataset_dir.glob("*.index")) + list(
|
||||||
|
dataset_dir.glob("*_disk.index")
|
||||||
|
)
|
||||||
|
if index_files:
|
||||||
|
args.index_path = str(
|
||||||
|
index_files[0].with_suffix("")
|
||||||
|
) # Remove .index extension
|
||||||
|
print(f"Using index: {args.index_path}")
|
||||||
|
break
|
||||||
|
|
||||||
|
if not args.index_path:
|
||||||
|
print("No indices found. The data download should have included pre-built indices.")
|
||||||
|
print(
|
||||||
|
"Please check the benchmarks/data/indices/ directory or provide --index-path manually."
|
||||||
|
)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Detect dataset type from index path to select the correct ground truth
|
||||||
|
index_path_str = str(args.index_path)
|
||||||
|
if "rpj_wiki" in index_path_str:
|
||||||
|
dataset_type = "rpj_wiki"
|
||||||
|
elif "dpr" in index_path_str:
|
||||||
|
dataset_type = "dpr"
|
||||||
|
else:
|
||||||
|
# Fallback: try to infer from the index directory name
|
||||||
|
dataset_type = Path(args.index_path).name
|
||||||
|
print(f"WARNING: Could not detect dataset type from path, inferred '{dataset_type}'.")
|
||||||
|
|
||||||
|
queries_file = data_root / "queries" / "nq_open.jsonl"
|
||||||
|
golden_results_file = data_root / "ground_truth" / dataset_type / "flat_results_nq_k3.json"
|
||||||
|
|
||||||
|
print(f"INFO: Detected dataset type: {dataset_type}")
|
||||||
|
print(f"INFO: Using queries file: {queries_file}")
|
||||||
|
print(f"INFO: Using ground truth file: {golden_results_file}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
searcher = LeannSearcher(args.index_path)
|
||||||
|
queries = load_queries(queries_file)
|
||||||
|
|
||||||
|
with open(golden_results_file) as f:
|
||||||
|
golden_results_data = json.load(f)
|
||||||
|
|
||||||
|
num_eval_queries = min(args.num_queries, len(queries))
|
||||||
|
queries = queries[:num_eval_queries]
|
||||||
|
|
||||||
|
print(f"\nRunning evaluation on {num_eval_queries} queries...")
|
||||||
|
recall_scores = []
|
||||||
|
search_times = []
|
||||||
|
|
||||||
|
for i in range(num_eval_queries):
|
||||||
|
start_time = time.time()
|
||||||
|
new_results = searcher.search(queries[i], top_k=args.top_k, ef=args.ef_search)
|
||||||
|
search_times.append(time.time() - start_time)
|
||||||
|
|
||||||
|
# Correct Recall Calculation: Based on TEXT content
|
||||||
|
new_texts = {result.text for result in new_results}
|
||||||
|
|
||||||
|
# Get golden texts directly from the searcher's passage manager
|
||||||
|
golden_ids = golden_results_data["indices"][i][: args.top_k]
|
||||||
|
golden_texts = get_golden_texts(searcher, golden_ids)
|
||||||
|
|
||||||
|
overlap = len(new_texts & golden_texts)
|
||||||
|
recall = overlap / len(golden_texts) if golden_texts else 0
|
||||||
|
recall_scores.append(recall)
|
||||||
|
|
||||||
|
print("\n--- EVALUATION RESULTS ---")
|
||||||
|
print(f"Query: {queries[i]}")
|
||||||
|
print(f"New Results: {new_texts}")
|
||||||
|
print(f"Golden Results: {golden_texts}")
|
||||||
|
print(f"Overlap: {overlap}")
|
||||||
|
print(f"Recall: {recall}")
|
||||||
|
print(f"Search Time: {search_times[-1]:.4f}s")
|
||||||
|
print("--------------------------------")
|
||||||
|
|
||||||
|
avg_recall = np.mean(recall_scores) if recall_scores else 0
|
||||||
|
avg_time = np.mean(search_times) if search_times else 0
|
||||||
|
|
||||||
|
print("\n🎉 --- Evaluation Complete ---")
|
||||||
|
print(f"Avg. Recall@{args.top_k} (efSearch={args.ef_search}): {avg_recall:.4f}")
|
||||||
|
print(f"Avg. Search Time: {avg_time:.4f}s")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n❌ An error occurred during evaluation: {e}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
311
benchmarks/simple_mac_tpt_test.py
Normal file
311
benchmarks/simple_mac_tpt_test.py
Normal file
@@ -0,0 +1,311 @@
|
|||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from tqdm import tqdm
|
||||||
|
from transformers import AutoModel
|
||||||
|
|
||||||
|
# Add MLX imports
|
||||||
|
try:
|
||||||
|
import mlx.core as mx
|
||||||
|
from mlx_lm.utils import load
|
||||||
|
|
||||||
|
MLX_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
print("MLX not available. Install with: uv pip install mlx mlx-lm")
|
||||||
|
MLX_AVAILABLE = False
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BenchmarkConfig:
|
||||||
|
model_path: str = "facebook/contriever"
|
||||||
|
batch_sizes: list[int] = None
|
||||||
|
seq_length: int = 256
|
||||||
|
num_runs: int = 5
|
||||||
|
use_fp16: bool = True
|
||||||
|
use_int4: bool = False
|
||||||
|
use_int8: bool = False
|
||||||
|
use_cuda_graphs: bool = False
|
||||||
|
use_flash_attention: bool = False
|
||||||
|
use_linear8bitlt: bool = False
|
||||||
|
use_mlx: bool = False # New flag for MLX testing
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.batch_sizes is None:
|
||||||
|
self.batch_sizes = [1, 2, 4, 8, 16, 32, 64]
|
||||||
|
|
||||||
|
|
||||||
|
class MLXBenchmark:
|
||||||
|
"""MLX-specific benchmark for embedding models"""
|
||||||
|
|
||||||
|
def __init__(self, config: BenchmarkConfig):
|
||||||
|
self.config = config
|
||||||
|
self.model, self.tokenizer = self._load_model()
|
||||||
|
|
||||||
|
def _load_model(self):
|
||||||
|
"""Load MLX model and tokenizer following the API pattern"""
|
||||||
|
print(f"Loading MLX model from {self.config.model_path}...")
|
||||||
|
try:
|
||||||
|
model, tokenizer = load(self.config.model_path)
|
||||||
|
print("MLX model loaded successfully")
|
||||||
|
return model, tokenizer
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error loading MLX model: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _create_random_batch(self, batch_size: int):
|
||||||
|
"""Create random input batches for MLX testing - same as PyTorch"""
|
||||||
|
return torch.randint(0, 1000, (batch_size, self.config.seq_length), dtype=torch.long)
|
||||||
|
|
||||||
|
def _run_inference(self, input_ids: torch.Tensor) -> float:
|
||||||
|
"""Run MLX inference with same input as PyTorch"""
|
||||||
|
start_time = time.time()
|
||||||
|
try:
|
||||||
|
# Convert PyTorch tensor to MLX array
|
||||||
|
input_ids_mlx = mx.array(input_ids.numpy())
|
||||||
|
|
||||||
|
# Get embeddings
|
||||||
|
embeddings = self.model(input_ids_mlx)
|
||||||
|
|
||||||
|
# Mean pooling (following the API pattern)
|
||||||
|
pooled = embeddings.mean(axis=1)
|
||||||
|
|
||||||
|
# Convert to numpy (following the API pattern)
|
||||||
|
pooled_numpy = np.array(pooled.tolist(), dtype=np.float32)
|
||||||
|
|
||||||
|
# Force computation
|
||||||
|
_ = pooled_numpy.shape
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"MLX inference error: {e}")
|
||||||
|
return float("inf")
|
||||||
|
end_time = time.time()
|
||||||
|
|
||||||
|
return end_time - start_time
|
||||||
|
|
||||||
|
def run(self) -> dict[int, dict[str, float]]:
|
||||||
|
"""Run the MLX benchmark across all batch sizes"""
|
||||||
|
results = {}
|
||||||
|
|
||||||
|
print(f"Starting MLX benchmark with model: {self.config.model_path}")
|
||||||
|
print(f"Testing batch sizes: {self.config.batch_sizes}")
|
||||||
|
|
||||||
|
for batch_size in self.config.batch_sizes:
|
||||||
|
print(f"\n=== Testing MLX batch size: {batch_size} ===")
|
||||||
|
times = []
|
||||||
|
|
||||||
|
# Create input batch (same as PyTorch)
|
||||||
|
input_ids = self._create_random_batch(batch_size)
|
||||||
|
|
||||||
|
# Warm up
|
||||||
|
print("Warming up...")
|
||||||
|
for _ in range(3):
|
||||||
|
try:
|
||||||
|
self._run_inference(input_ids[:2]) # Warm up with smaller batch
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Warmup error: {e}")
|
||||||
|
break
|
||||||
|
|
||||||
|
# Run benchmark
|
||||||
|
for _i in tqdm(range(self.config.num_runs), desc=f"MLX Batch size {batch_size}"):
|
||||||
|
try:
|
||||||
|
elapsed_time = self._run_inference(input_ids)
|
||||||
|
if elapsed_time != float("inf"):
|
||||||
|
times.append(elapsed_time)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error during MLX inference: {e}")
|
||||||
|
break
|
||||||
|
|
||||||
|
if not times:
|
||||||
|
print(f"Skipping batch size {batch_size} due to errors")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Calculate statistics
|
||||||
|
avg_time = np.mean(times)
|
||||||
|
std_time = np.std(times)
|
||||||
|
throughput = batch_size / avg_time
|
||||||
|
|
||||||
|
results[batch_size] = {
|
||||||
|
"avg_time": avg_time,
|
||||||
|
"std_time": std_time,
|
||||||
|
"throughput": throughput,
|
||||||
|
"min_time": np.min(times),
|
||||||
|
"max_time": np.max(times),
|
||||||
|
}
|
||||||
|
|
||||||
|
print(f"MLX Results for batch size {batch_size}:")
|
||||||
|
print(f" Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
|
||||||
|
print(f" Min Time: {np.min(times):.4f}s")
|
||||||
|
print(f" Max Time: {np.max(times):.4f}s")
|
||||||
|
print(f" Throughput: {throughput:.2f} sequences/second")
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
class Benchmark:
    """PyTorch inference benchmark over several batch sizes.

    Loads the configured model once, then times forward passes for each
    batch size in ``config.batch_sizes`` and reports latency statistics,
    throughput, and (on CUDA) per-batch peak memory.
    """

    def __init__(self, config: BenchmarkConfig):
        self.config = config
        # Prefer CUDA, then Apple MPS, falling back to CPU.
        self.device = (
            "cuda"
            if torch.cuda.is_available()
            else "mps"
            if torch.backends.mps.is_available()
            else "cpu"
        )
        self.model = self._load_model()

    def _load_model(self) -> nn.Module:
        """Load the model, optionally cast to fp16, compile it, move to device.

        Returns:
            The model in eval mode on ``self.device``.
        """
        print(f"Loading model from {self.config.model_path}...")

        model = AutoModel.from_pretrained(self.config.model_path)
        if self.config.use_fp16:
            model = model.half()
        model = torch.compile(model)
        model = model.to(self.device)

        model.eval()
        return model

    def _create_random_batch(self, batch_size: int) -> torch.Tensor:
        """Create a random token-id batch of shape (batch_size, seq_length).

        Token ids are drawn uniformly from [0, 1000); presumably within the
        model's vocabulary — TODO confirm against the tokenizer.
        """
        return torch.randint(
            0,
            1000,
            (batch_size, self.config.seq_length),
            device=self.device,
            dtype=torch.long,
        )

    def _synchronize(self) -> None:
        """Block until all queued device work has finished.

        CUDA and MPS kernels launch asynchronously, so wall-clock timing
        without a synchronization point measures only launch overhead,
        not the actual inference time.
        """
        if self.device == "cuda":
            torch.cuda.synchronize()
        elif self.device == "mps":
            torch.mps.synchronize()

    def _run_inference(self, input_ids: torch.Tensor) -> float:
        """Run one timed forward pass and return the elapsed seconds.

        Uses perf_counter (monotonic, high resolution) rather than
        time.time, and synchronizes the device on both sides of the
        timed region so the measurement covers the full forward pass.
        """
        attention_mask = torch.ones_like(input_ids)

        self._synchronize()  # don't let previously queued work leak into the timing
        start_time = time.perf_counter()
        with torch.no_grad():
            self.model(input_ids=input_ids, attention_mask=attention_mask)
        self._synchronize()  # wait for the forward pass to actually finish
        end_time = time.perf_counter()

        return end_time - start_time

    def run(self) -> dict[int, dict[str, float]]:
        """Benchmark every configured batch size.

        Returns:
            Mapping of batch size to a stats dict with keys ``avg_time``,
            ``std_time``, ``throughput`` and ``peak_memory_gb``. Batch
            sizes whose runs all failed are omitted.
        """
        results = {}

        for batch_size in self.config.batch_sizes:
            print(f"\nTesting batch size: {batch_size}")
            times = []

            input_ids = self._create_random_batch(batch_size)

            # Warm-up pass so one-time costs (torch.compile tracing,
            # kernel autotuning, allocator growth) do not pollute the
            # measured runs — the first compiled call can be orders of
            # magnitude slower than steady state.
            try:
                self._run_inference(input_ids)
            except Exception as e:
                print(f"Error during warmup: {e}")
                continue

            # Reset per batch size so peak_memory_gb reflects this batch,
            # not the largest batch seen anywhere in the sweep.
            if torch.cuda.is_available():
                torch.cuda.reset_peak_memory_stats()

            for _i in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
                try:
                    elapsed_time = self._run_inference(input_ids)
                    times.append(elapsed_time)
                except Exception as e:
                    print(f"Error during inference: {e}")
                    break

            if not times:
                continue

            avg_time = np.mean(times)
            std_time = np.std(times)
            throughput = batch_size / avg_time

            results[batch_size] = {
                "avg_time": avg_time,
                "std_time": std_time,
                "throughput": throughput,
            }

            if torch.cuda.is_available():
                results[batch_size]["peak_memory_gb"] = torch.cuda.max_memory_allocated() / (1024**3)
            else:
                results[batch_size]["peak_memory_gb"] = 0.0

            print(f"Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
            print(f"Throughput: {throughput:.2f} sequences/second")

        return results
|
||||||
|
|
||||||
|
|
||||||
|
def run_benchmark():
    """Main function to run the benchmark with optimized parameters.

    Returns:
        Dict with ``max_throughput`` and ``avg_throughput``, plus either
        the per-batch-size ``results`` mapping on success or an ``error``
        message when the benchmark fails or produces no valid results.
    """
    config = BenchmarkConfig()

    try:
        benchmark = Benchmark(config)
        results = benchmark.run()

        # Mirror run_mlx_benchmark: an empty results dict means every batch
        # size failed. Report that explicitly instead of letting max() raise
        # ValueError on an empty sequence and masking the real cause.
        if not results:
            return {
                "max_throughput": 0.0,
                "avg_throughput": 0.0,
                "error": "No valid results",
            }

        max_throughput = max(results[batch_size]["throughput"] for batch_size in results)
        avg_throughput = np.mean([results[batch_size]["throughput"] for batch_size in results])

        return {
            "max_throughput": max_throughput,
            "avg_throughput": avg_throughput,
            "results": results,
        }

    except Exception as e:
        print(f"Benchmark failed: {e}")
        return {"max_throughput": 0.0, "avg_throughput": 0.0, "error": str(e)}
|
||||||
|
|
||||||
|
|
||||||
|
def run_mlx_benchmark():
    """Run MLX-specific benchmark"""
    # Bail out early when the MLX runtime is not importable on this machine.
    if not MLX_AVAILABLE:
        print("MLX not available, skipping MLX benchmark")
        return {
            "max_throughput": 0.0,
            "avg_throughput": 0.0,
            "error": "MLX not available",
        }

    config = BenchmarkConfig(model_path="mlx-community/all-MiniLM-L6-v2-4bit", use_mlx=True)

    try:
        results = MLXBenchmark(config).run()

        # An empty mapping means no batch size produced usable timings.
        if not results:
            return {
                "max_throughput": 0.0,
                "avg_throughput": 0.0,
                "error": "No valid results",
            }

        # Collect throughputs once and summarize.
        throughputs = [stats["throughput"] for stats in results.values()]
        return {
            "max_throughput": max(throughputs),
            "avg_throughput": np.mean(throughputs),
            "results": results,
        }

    except Exception as e:
        print(f"MLX benchmark failed: {e}")
        return {"max_throughput": 0.0, "avg_throughput": 0.0, "error": str(e)}
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: run the PyTorch benchmark, then the MLX benchmark,
# print each one's throughput summary, and compare peak throughput.
if __name__ == "__main__":
    print("=== PyTorch Benchmark ===")
    pytorch_result = run_benchmark()
    print(f"PyTorch Max throughput: {pytorch_result['max_throughput']:.2f} sequences/second")
    print(f"PyTorch Average throughput: {pytorch_result['avg_throughput']:.2f} sequences/second")

    print("\n=== MLX Benchmark ===")
    mlx_result = run_mlx_benchmark()
    print(f"MLX Max throughput: {mlx_result['max_throughput']:.2f} sequences/second")
    print(f"MLX Average throughput: {mlx_result['avg_throughput']:.2f} sequences/second")

    # Compare results
    # Only compare when both benchmarks produced a nonzero throughput;
    # a 0.0 value signals that the corresponding benchmark failed or was
    # skipped (both failure paths return max_throughput == 0.0).
    if pytorch_result["max_throughput"] > 0 and mlx_result["max_throughput"] > 0:
        speedup = mlx_result["max_throughput"] / pytorch_result["max_throughput"]
        print("\n=== Comparison ===")
        print(f"MLX is {speedup:.2f}x {'faster' if speedup > 1 else 'slower'} than PyTorch")
|
||||||
BIN
data/2501.14312v1 (1).pdf
Normal file
BIN
data/2501.14312v1 (1).pdf
Normal file
Binary file not shown.
14905
data/PrideandPrejudice.txt
Normal file
14905
data/PrideandPrejudice.txt
Normal file
File diff suppressed because it is too large
Load Diff
322
demo.ipynb
322
demo.ipynb
@@ -1,226 +1,116 @@
|
|||||||
{
|
{
|
||||||
"cells": [
|
"cells": [
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "markdown",
|
||||||
"execution_count": 3,
|
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"INFO: LeannBuilder initialized with 'diskann' backend.\n",
|
|
||||||
"INFO: Computing embeddings for 6 chunks using 'sentence-transformers/all-mpnet-base-v2'...\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Batches: 100%|██████████| 1/1 [00:00<00:00, 77.61it/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"INFO: Building DiskANN index for 6 vectors with metric Metric.INNER_PRODUCT...\n",
|
|
||||||
"Using Inner Product search, so need to pre-process base data into temp file. Please ensure there is additional (n*(d+1)*4) bytes for storing pre-processed base vectors, apart from the interim indices created by DiskANN and the final index.\n",
|
|
||||||
"Pre-processing base file by adding extra coordinate\n",
|
|
||||||
"✅ DiskANN index built successfully at 'knowledge'\n",
|
|
||||||
"Writing bin: knowledge_disk.index_max_base_norm.bin\n",
|
|
||||||
"bin: #pts = 1, #dims = 1, size = 12B\n",
|
|
||||||
"Finished writing bin.\n",
|
|
||||||
"Time for preprocessing data for inner product: 0.000165 seconds\n",
|
|
||||||
"Reading max_norm_of_base from knowledge_disk.index_max_base_norm.bin\n",
|
|
||||||
"Reading bin file knowledge_disk.index_max_base_norm.bin ...\n",
|
|
||||||
"Opening bin file knowledge_disk.index_max_base_norm.bin... \n",
|
|
||||||
"Metadata: #pts = 1, #dims = 1...\n",
|
|
||||||
"done.\n",
|
|
||||||
"max_norm_of_base: 1\n",
|
|
||||||
"! Using prepped_base file at knowledge_prepped_base.bin\n",
|
|
||||||
"Starting index build: R=32 L=64 Query RAM budget: 4.02653e+09 Indexing ram budget: 8 T: 8\n",
|
|
||||||
"getting bin metadata\n",
|
|
||||||
"Time for getting bin metadata: 0.000008 seconds\n",
|
|
||||||
"Compressing 769-dimensional data into 512 bytes per vector.\n",
|
|
||||||
"Opened: knowledge_prepped_base.bin, size: 18464, cache_size: 18464\n",
|
|
||||||
"Training data with 6 samples loaded.\n",
|
|
||||||
"Reading bin file knowledge_pq_pivots.bin ...\n",
|
|
||||||
"Opening bin file knowledge_pq_pivots.bin... \n",
|
|
||||||
"Metadata: #pts = 256, #dims = 769...\n",
|
|
||||||
"done.\n",
|
|
||||||
"PQ pivot file exists. Not generating again\n",
|
|
||||||
"Opened: knowledge_prepped_base.bin, size: 18464, cache_size: 18464\n",
|
|
||||||
"Reading bin file knowledge_pq_pivots.bin ...\n",
|
|
||||||
"Opening bin file knowledge_pq_pivots.bin... \n",
|
|
||||||
"Metadata: #pts = 4, #dims = 1...\n",
|
|
||||||
"done.\n",
|
|
||||||
"Reading bin file knowledge_pq_pivots.bin ...\n",
|
|
||||||
"Opening bin file knowledge_pq_pivots.bin... \n",
|
|
||||||
"Metadata: #pts = 256, #dims = 769...\n",
|
|
||||||
"done.\n",
|
|
||||||
"Reading bin file knowledge_pq_pivots.bin ...\n",
|
|
||||||
"Opening bin file knowledge_pq_pivots.bin... \n",
|
|
||||||
"Metadata: #pts = 769, #dims = 1...\n",
|
|
||||||
"done.\n",
|
|
||||||
"Reading bin file knowledge_pq_pivots.bin ...\n",
|
|
||||||
"Opening bin file knowledge_pq_pivots.bin... \n",
|
|
||||||
"Metadata: #pts = 513, #dims = 1...\n",
|
|
||||||
"done.\n",
|
|
||||||
"Loaded PQ pivot information\n",
|
|
||||||
"Processing points [0, 6)...done.\n",
|
|
||||||
"Time for generating quantized data: 0.023918 seconds\n",
|
|
||||||
"Full index fits in RAM budget, should consume at most 2.03973e-05GiBs, so building in one shot\n",
|
|
||||||
"L2: Using AVX2 distance computation DistanceL2Float\n",
|
|
||||||
"Passed, empty search_params while creating index config\n",
|
|
||||||
"Using only first 6 from file.. \n",
|
|
||||||
"Starting index build with 6 points... \n",
|
|
||||||
"0% of index build completed.Starting final cleanup..done. Link time: 9e-05s\n",
|
|
||||||
"Index built with degree: max:5 avg:5 min:5 count(deg<2):0\n",
|
|
||||||
"Not saving tags as they are not enabled.\n",
|
|
||||||
"Time taken for save: 0.000178s.\n",
|
|
||||||
"Time for building merged vamana index: 0.000579 seconds\n",
|
|
||||||
"Opened: knowledge_prepped_base.bin, size: 18464, cache_size: 18464\n",
|
|
||||||
"Vamana index file size=168\n",
|
|
||||||
"Opened: knowledge_disk.index, cache_size: 67108864\n",
|
|
||||||
"medoid: 0B\n",
|
|
||||||
"max_node_len: 3100B\n",
|
|
||||||
"nnodes_per_sector: 1B\n",
|
|
||||||
"# sectors: 6\n",
|
|
||||||
"Sector #0written\n",
|
|
||||||
"Finished writing 28672B\n",
|
|
||||||
"Writing bin: knowledge_disk.index\n",
|
|
||||||
"bin: #pts = 9, #dims = 1, size = 80B\n",
|
|
||||||
"Finished writing bin.\n",
|
|
||||||
"Output disk index file written to knowledge_disk.index\n",
|
|
||||||
"Finished writing 28672B\n",
|
|
||||||
"Time for generating disk layout: 0.043488 seconds\n",
|
|
||||||
"Opened: knowledge_prepped_base.bin, size: 18464, cache_size: 18464\n",
|
|
||||||
"Loading base knowledge_prepped_base.bin. #points: 6. #dim: 769.\n",
|
|
||||||
"Wrote 1 points to sample file: knowledge_sample_data.bin\n",
|
|
||||||
"Indexing time: 0.0684344\n",
|
|
||||||
"INFO: Leann metadata saved to knowledge.leann.meta.json\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"\n",
|
|
||||||
"Opened file : knowledge_disk.index\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Since data is floating point, we assume that it has been appropriately pre-processed (normalization for cosine, and convert-to-l2 by adding extra dimension for MIPS). So we shall invoke an l2 distance function.\n",
|
|
||||||
"L2: Using AVX2 distance computation DistanceL2Float\n",
|
|
||||||
"L2: Using AVX2 distance computation DistanceL2Float\n",
|
|
||||||
"Before index load\n",
|
|
||||||
"✅ DiskANN index loaded successfully.\n",
|
|
||||||
"INFO: LeannSearcher initialized with 'diskann' backend using index 'knowledge.leann'.\n",
|
|
||||||
"Reading bin file knowledge_pq_compressed.bin ...\n",
|
|
||||||
"Opening bin file knowledge_pq_compressed.bin... \n",
|
|
||||||
"Metadata: #pts = 6, #dims = 512...\n",
|
|
||||||
"done.\n",
|
|
||||||
"Reading bin file knowledge_pq_pivots.bin ...\n",
|
|
||||||
"Opening bin file knowledge_pq_pivots.bin... \n",
|
|
||||||
"Metadata: #pts = 4, #dims = 1...\n",
|
|
||||||
"done.\n",
|
|
||||||
"Offsets: 4096 791560 794644 796704\n",
|
|
||||||
"Reading bin file knowledge_pq_pivots.bin ...\n",
|
|
||||||
"Opening bin file knowledge_pq_pivots.bin... \n",
|
|
||||||
"Metadata: #pts = 256, #dims = 769...\n",
|
|
||||||
"done.\n",
|
|
||||||
"Reading bin file knowledge_pq_pivots.bin ...\n",
|
|
||||||
"Opening bin file knowledge_pq_pivots.bin... \n",
|
|
||||||
"Metadata: #pts = 769, #dims = 1...\n",
|
|
||||||
"done.\n",
|
|
||||||
"Reading bin file knowledge_pq_pivots.bin ...\n",
|
|
||||||
"Opening bin file knowledge_pq_pivots.bin... \n",
|
|
||||||
"Metadata: #pts = 513, #dims = 1...\n",
|
|
||||||
"done.\n",
|
|
||||||
"Loaded PQ Pivots: #ctrs: 256, #dims: 769, #chunks: 512\n",
|
|
||||||
"Loaded PQ centroids and in-memory compressed vectors. #points: 6 #dim: 769 #aligned_dim: 776 #chunks: 512\n",
|
|
||||||
"Loading index metadata from knowledge_disk.index\n",
|
|
||||||
"Disk-Index File Meta-data: # nodes per sector: 1, max node len (bytes): 3100, max node degree: 5\n",
|
|
||||||
"Disk-Index Meta: nodes per sector: 1, max node len: 3100, max node degree: 5\n",
|
|
||||||
"Setting up thread-specific contexts for nthreads: 8\n",
|
|
||||||
"allocating ctx: 0x78348f4de000 to thread-id:132170359560000\n",
|
|
||||||
"allocating ctx: 0x78348f4cd000 to thread-id:132158431693760\n",
|
|
||||||
"allocating ctx: 0x78348f4bc000 to thread-id:132158442179392\n",
|
|
||||||
"allocating ctx: 0x78348f4ab000 to thread-id:132158421208128\n",
|
|
||||||
"allocating ctx: 0x78348f49a000 to thread-id:132158452665024\n",
|
|
||||||
"allocating ctx: 0x78348f489000 to thread-id:132158389751232\n",
|
|
||||||
"allocating ctx: 0x78348f478000 to thread-id:132158410722496\n",
|
|
||||||
"allocating ctx: 0x78348f467000 to thread-id:132158400236864\n",
|
|
||||||
"Loading centroid data from medoids vector data of 1 medoid(s)\n",
|
|
||||||
"Reading bin file knowledge_disk.index_max_base_norm.bin ...\n",
|
|
||||||
"Opening bin file knowledge_disk.index_max_base_norm.bin... \n",
|
|
||||||
"Metadata: #pts = 1, #dims = 1...\n",
|
|
||||||
"done.\n",
|
|
||||||
"Setting re-scaling factor of base vectors to 1\n",
|
|
||||||
"load_from_separate_paths done.\n",
|
|
||||||
"Reading (with alignment) bin file knowledge_sample_data.bin ...Metadata: #pts = 1, #dims = 769, aligned_dim = 776... allocating aligned memory of 3104 bytes... done. Copying data to mem_aligned buffer... done.\n",
|
|
||||||
"reserve ratio: 1\n",
|
|
||||||
"Graph traversal completed, hops: 3\n",
|
|
||||||
"Loading the cache list into memory....done.\n",
|
|
||||||
"After index load\n",
|
|
||||||
"Clearing scratch\n",
|
|
||||||
"INFO: Computing embeddings for 1 chunks using 'sentence-transformers/all-mpnet-base-v2'...\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Batches: 100%|██████████| 1/1 [00:00<00:00, 92.66it/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Score: -0.481 - C++ is a powerful programming language\n",
|
|
||||||
"Score: -1.049 - Java is a powerful programming language\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"reserve ratio: 1\n",
|
|
||||||
"Graph traversal completed, hops: 3\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
"source": [
|
||||||
"from leann.api import LeannBuilder, LeannSearcher\n",
|
"# Quick Start \n",
|
||||||
"import leann_backend_diskann\n",
|
"\n",
|
||||||
"# 1. Build index (no embeddings stored!)\n",
|
"**Home GitHub Repository:** [LEANN on GitHub](https://github.com/yichuan-w/LEANN)\n",
|
||||||
"builder = LeannBuilder(backend_name=\"diskann\")\n",
|
"\n",
|
||||||
"builder.add_text(\"Python is a powerful programming language\")\n",
|
"**Important for Colab users:** Set your runtime type to T4 GPU for optimal performance. Go to Runtime → Change runtime type → Hardware accelerator → T4 GPU."
|
||||||
"builder.add_text(\"Machine learning transforms industries\") \n",
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# install this if you are using colab\n",
|
||||||
|
"! uv pip install leann-core leann-backend-hnsw --no-deps\n",
|
||||||
|
"! uv pip install leann --no-deps\n",
|
||||||
|
"# For Colab environment, we need to set some environment variables\n",
|
||||||
|
"import os\n",
|
||||||
|
"\n",
|
||||||
|
"os.environ[\"LEANN_LOG_LEVEL\"] = \"INFO\" # Enable more detailed logging"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from pathlib import Path\n",
|
||||||
|
"\n",
|
||||||
|
"INDEX_DIR = Path(\"./\").resolve()\n",
|
||||||
|
"INDEX_PATH = str(INDEX_DIR / \"demo.leann\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Build the index"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from leann.api import LeannBuilder\n",
|
||||||
|
"\n",
|
||||||
|
"builder = LeannBuilder(backend_name=\"hnsw\")\n",
|
||||||
|
"builder.add_text(\"C# is a powerful programming language and it is good at game development\")\n",
|
||||||
|
"builder.add_text(\n",
|
||||||
|
" \"Python is a powerful programming language and it is good at machine learning tasks\"\n",
|
||||||
|
")\n",
|
||||||
|
"builder.add_text(\"Machine learning transforms industries\")\n",
|
||||||
"builder.add_text(\"Neural networks process complex data\")\n",
|
"builder.add_text(\"Neural networks process complex data\")\n",
|
||||||
"builder.add_text(\"Java is a powerful programming language\")\n",
|
"builder.add_text(\"Leann is a great storage saving engine for RAG on your MacBook\")\n",
|
||||||
"builder.add_text(\"C++ is a powerful programming language\")\n",
|
"builder.build_index(INDEX_PATH)"
|
||||||
"builder.add_text(\"C# is a powerful programming language\")\n",
|
]
|
||||||
"builder.build_index(\"knowledge.leann\")\n",
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Search with real-time embeddings"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from leann.api import LeannSearcher\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# 2. Search with real-time embeddings\n",
|
"searcher = LeannSearcher(INDEX_PATH)\n",
|
||||||
"searcher = LeannSearcher(\"knowledge.leann\")\n",
|
"results = searcher.search(\"programming languages\", top_k=2)\n",
|
||||||
"results = searcher.search(\"C++ programming languages\", top_k=2)\n",
|
"results"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Chat with LEANN using retrieved results"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from leann.api import LeannChat\n",
|
||||||
"\n",
|
"\n",
|
||||||
"for result in results:\n",
|
"llm_config = {\n",
|
||||||
" print(f\"Score: {result['score']:.3f} - {result['text']}\")"
|
" \"type\": \"hf\",\n",
|
||||||
|
" \"model\": \"Qwen/Qwen3-0.6B\",\n",
|
||||||
|
"}\n",
|
||||||
|
"\n",
|
||||||
|
"chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)\n",
|
||||||
|
"response = chat.ask(\n",
|
||||||
|
" \"Compare the two retrieved programming languages and tell me their advantages.\",\n",
|
||||||
|
" top_k=2,\n",
|
||||||
|
" llm_kwargs={\"max_tokens\": 128},\n",
|
||||||
|
")\n",
|
||||||
|
"response"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@@ -240,7 +130,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.11.11"
|
"version": "3.11.12"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
|||||||
220
docs/CONTRIBUTING.md
Normal file
220
docs/CONTRIBUTING.md
Normal file
@@ -0,0 +1,220 @@
|
|||||||
|
# 🤝 Contributing
|
||||||
|
|
||||||
|
We welcome contributions! Leann is built by the community, for the community.
|
||||||
|
|
||||||
|
## Ways to Contribute
|
||||||
|
|
||||||
|
- 🐛 **Bug Reports**: Found an issue? Let us know!
|
||||||
|
- 💡 **Feature Requests**: Have an idea? We'd love to hear it!
|
||||||
|
- 🔧 **Code Contributions**: PRs welcome for all skill levels
|
||||||
|
- 📖 **Documentation**: Help make Leann more accessible
|
||||||
|
- 🧪 **Benchmarks**: Share your performance results
|
||||||
|
|
||||||
|
## 🚀 Development Setup
|
||||||
|
|
||||||
|
### Prerequisites
|
||||||
|
|
||||||
|
1. **Install uv** (fast Python package installer):
|
||||||
|
```bash
|
||||||
|
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Clone the repository**:
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/LEANN-RAG/LEANN-RAG.git
|
||||||
|
cd LEANN-RAG
|
||||||
|
```
|
||||||
|
|
||||||
|
3. **Install system dependencies**:
|
||||||
|
|
||||||
|
**macOS:**
|
||||||
|
```bash
|
||||||
|
brew install llvm libomp boost protobuf zeromq pkgconf
|
||||||
|
```
|
||||||
|
|
||||||
|
**Ubuntu/Debian:**
|
||||||
|
```bash
|
||||||
|
sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler \
|
||||||
|
libabsl-dev libmkl-full-dev libaio-dev libzmq3-dev
|
||||||
|
```
|
||||||
|
|
||||||
|
4. **Build from source**:
|
||||||
|
```bash
|
||||||
|
# macOS
|
||||||
|
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv sync
|
||||||
|
|
||||||
|
# Ubuntu/Debian
|
||||||
|
uv sync
|
||||||
|
```
|
||||||
|
|
||||||
|
## 🔨 Pre-commit Hooks
|
||||||
|
|
||||||
|
We use pre-commit hooks to ensure code quality and consistency. This runs automatically before each commit.
|
||||||
|
|
||||||
|
### Setup Pre-commit
|
||||||
|
|
||||||
|
1. **Install pre-commit** (already included when you run `uv sync`):
|
||||||
|
```bash
|
||||||
|
uv pip install pre-commit
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Install the git hooks**:
|
||||||
|
```bash
|
||||||
|
pre-commit install
|
||||||
|
```
|
||||||
|
|
||||||
|
3. **Run pre-commit manually** (optional):
|
||||||
|
```bash
|
||||||
|
pre-commit run --all-files
|
||||||
|
```
|
||||||
|
|
||||||
|
### Pre-commit Checks
|
||||||
|
|
||||||
|
Our pre-commit configuration includes:
|
||||||
|
- **Trailing whitespace removal**
|
||||||
|
- **End-of-file fixing**
|
||||||
|
- **YAML validation**
|
||||||
|
- **Large file prevention**
|
||||||
|
- **Merge conflict detection**
|
||||||
|
- **Debug statement detection**
|
||||||
|
- **Code formatting with ruff**
|
||||||
|
- **Code linting with ruff**
|
||||||
|
|
||||||
|
## 🧪 Testing
|
||||||
|
|
||||||
|
### Running Tests
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Run all tests
|
||||||
|
uv run pytest
|
||||||
|
|
||||||
|
# Run specific test file
|
||||||
|
uv run pytest test/test_filename.py
|
||||||
|
|
||||||
|
# Run with coverage
|
||||||
|
uv run pytest --cov=leann
|
||||||
|
```
|
||||||
|
|
||||||
|
### Writing Tests
|
||||||
|
|
||||||
|
- Place tests in the `test/` directory
|
||||||
|
- Follow the naming convention `test_*.py`
|
||||||
|
- Use descriptive test names that explain what's being tested
|
||||||
|
- Include both positive and negative test cases
|
||||||
|
|
||||||
|
## 📝 Code Style
|
||||||
|
|
||||||
|
We use `ruff` for both linting and formatting to ensure consistent code style.
|
||||||
|
|
||||||
|
### Format Your Code
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Format all files
|
||||||
|
ruff format
|
||||||
|
|
||||||
|
# Check formatting without changing files
|
||||||
|
ruff format --check
|
||||||
|
```
|
||||||
|
|
||||||
|
### Lint Your Code
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Run linter with auto-fix
|
||||||
|
ruff check --fix
|
||||||
|
|
||||||
|
# Just check without fixing
|
||||||
|
ruff check
|
||||||
|
```
|
||||||
|
|
||||||
|
### Style Guidelines
|
||||||
|
|
||||||
|
- Follow PEP 8 conventions
|
||||||
|
- Use descriptive variable names
|
||||||
|
- Add type hints where appropriate
|
||||||
|
- Write docstrings for all public functions and classes
|
||||||
|
- Keep functions focused and single-purpose
|
||||||
|
|
||||||
|
## 🚦 CI/CD
|
||||||
|
|
||||||
|
Our CI pipeline runs automatically on all pull requests. It includes:
|
||||||
|
|
||||||
|
1. **Linting and Formatting**: Ensures code follows our style guidelines
|
||||||
|
2. **Multi-platform builds**: Tests on Ubuntu and macOS
|
||||||
|
3. **Python version matrix**: Tests on Python 3.9-3.13
|
||||||
|
4. **Wheel building**: Ensures packages can be built and distributed
|
||||||
|
|
||||||
|
### CI Commands
|
||||||
|
|
||||||
|
The CI uses the same commands as pre-commit to ensure consistency:
|
||||||
|
```bash
|
||||||
|
# Linting
|
||||||
|
ruff check .
|
||||||
|
|
||||||
|
# Format checking
|
||||||
|
ruff format --check .
|
||||||
|
```
|
||||||
|
|
||||||
|
Make sure your code passes these checks locally before pushing!
|
||||||
|
|
||||||
|
## 🔄 Pull Request Process
|
||||||
|
|
||||||
|
1. **Fork the repository** and create your branch from `main`:
|
||||||
|
```bash
|
||||||
|
git checkout -b feature/your-feature-name
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Make your changes**:
|
||||||
|
- Write clean, documented code
|
||||||
|
- Add tests for new functionality
|
||||||
|
- Update documentation as needed
|
||||||
|
|
||||||
|
3. **Run pre-commit checks**:
|
||||||
|
```bash
|
||||||
|
pre-commit run --all-files
|
||||||
|
```
|
||||||
|
|
||||||
|
4. **Test your changes**:
|
||||||
|
```bash
|
||||||
|
uv run pytest
|
||||||
|
```
|
||||||
|
|
||||||
|
5. **Commit with descriptive messages**:
|
||||||
|
```bash
|
||||||
|
git commit -m "feat: add new search algorithm"
|
||||||
|
```
|
||||||
|
|
||||||
|
Follow [Conventional Commits](https://www.conventionalcommits.org/):
|
||||||
|
- `feat:` for new features
|
||||||
|
- `fix:` for bug fixes
|
||||||
|
- `docs:` for documentation changes
|
||||||
|
- `test:` for test additions/changes
|
||||||
|
- `refactor:` for code refactoring
|
||||||
|
- `perf:` for performance improvements
|
||||||
|
|
||||||
|
6. **Push and create a pull request**:
|
||||||
|
- Provide a clear description of your changes
|
||||||
|
- Reference any related issues
|
||||||
|
- Include examples or screenshots if applicable
|
||||||
|
|
||||||
|
## 📚 Documentation
|
||||||
|
|
||||||
|
When adding new features or making significant changes:
|
||||||
|
|
||||||
|
1. Update relevant documentation in `/docs`
|
||||||
|
2. Add docstrings to new functions/classes
|
||||||
|
3. Update README.md if needed
|
||||||
|
4. Include usage examples
|
||||||
|
|
||||||
|
## 🤔 Getting Help
|
||||||
|
|
||||||
|
- **Discord**: Join our community for discussions
|
||||||
|
- **Issues**: Check existing issues or create a new one
|
||||||
|
- **Discussions**: For general questions and ideas
|
||||||
|
|
||||||
|
## 📄 License
|
||||||
|
|
||||||
|
By contributing, you agree that your contributions will be licensed under the same license as the project (MIT).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
Thank you for contributing to LEANN! Every contribution, no matter how small, helps make the project better for everyone. 🌟
|
||||||
22
docs/RELEASE.md
Normal file
22
docs/RELEASE.md
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
# Release Guide
|
||||||
|
|
||||||
|
## Setup (One-time)
|
||||||
|
|
||||||
|
Add `PYPI_API_TOKEN` to GitHub Secrets:
|
||||||
|
1. Get token: https://pypi.org/manage/account/token/
|
||||||
|
2. Add to secrets: Settings → Secrets → Actions → `PYPI_API_TOKEN`
|
||||||
|
|
||||||
|
## Release (One-click)
|
||||||
|
|
||||||
|
1. Go to: https://github.com/yichuan-w/LEANN/actions/workflows/release-manual.yml
|
||||||
|
2. Click "Run workflow"
|
||||||
|
3. Enter version: `0.1.2`
|
||||||
|
4. Click green "Run workflow" button
|
||||||
|
|
||||||
|
That's it! The workflow will automatically:
|
||||||
|
- ✅ Update version in all packages
|
||||||
|
- ✅ Build all packages
|
||||||
|
- ✅ Publish to PyPI
|
||||||
|
- ✅ Create GitHub tag and release
|
||||||
|
|
||||||
|
Check progress: https://github.com/yichuan-w/LEANN/actions
|
||||||
123
docs/THINKING_BUDGET_FEATURE.md
Normal file
123
docs/THINKING_BUDGET_FEATURE.md
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
# Thinking Budget Feature Implementation
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
This document describes the implementation of the **thinking budget** feature for LEANN, which allows users to control the computational effort for reasoning models like GPT-Oss:20b.
|
||||||
|
|
||||||
|
## Feature Description
|
||||||
|
|
||||||
|
The thinking budget feature provides three levels of computational effort for reasoning models:
|
||||||
|
- **`low`**: Fast responses, basic reasoning (default for simple queries)
|
||||||
|
- **`medium`**: Balanced speed and reasoning depth
|
||||||
|
- **`high`**: Maximum reasoning effort, best for complex analytical questions
|
||||||
|
|
||||||
|
## Implementation Details
|
||||||
|
|
||||||
|
### 1. Command Line Interface
|
||||||
|
|
||||||
|
Added `--thinking-budget` parameter to both CLI and RAG examples:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# LEANN CLI
|
||||||
|
leann ask my-index --llm ollama --model gpt-oss:20b --thinking-budget high
|
||||||
|
|
||||||
|
# RAG Examples
|
||||||
|
python apps/email_rag.py --llm ollama --llm-model gpt-oss:20b --thinking-budget high
|
||||||
|
python apps/document_rag.py --llm openai --llm-model o3 --thinking-budget medium
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. LLM Backend Support
|
||||||
|
|
||||||
|
#### Ollama Backend (`packages/leann-core/src/leann/chat.py`)
|
||||||
|
|
||||||
|
```python
|
||||||
|
def ask(self, prompt: str, **kwargs) -> str:
|
||||||
|
# Handle thinking budget for reasoning models
|
||||||
|
options = kwargs.copy()
|
||||||
|
thinking_budget = kwargs.get("thinking_budget")
|
||||||
|
if thinking_budget:
|
||||||
|
options.pop("thinking_budget", None)
|
||||||
|
if thinking_budget in ["low", "medium", "high"]:
|
||||||
|
options["reasoning"] = {"effort": thinking_budget, "exclude": False}
|
||||||
|
```
|
||||||
|
|
||||||
|
**API Format**: Uses Ollama's `reasoning` parameter with `effort` and `exclude` fields.
|
||||||
|
|
||||||
|
#### OpenAI Backend (`packages/leann-core/src/leann/chat.py`)
|
||||||
|
|
||||||
|
```python
|
||||||
|
def ask(self, prompt: str, **kwargs) -> str:
|
||||||
|
# Handle thinking budget for reasoning models
|
||||||
|
thinking_budget = kwargs.get("thinking_budget")
|
||||||
|
if thinking_budget and thinking_budget in ["low", "medium", "high"]:
|
||||||
|
# Check if this is an o-series model
|
||||||
|
o_series_models = ["o3", "o3-mini", "o4-mini", "o1", "o3-pro", "o3-deep-research"]
|
||||||
|
if any(model in self.model for model in o_series_models):
|
||||||
|
params["reasoning_effort"] = thinking_budget
|
||||||
|
```
|
||||||
|
|
||||||
|
**API Format**: Uses OpenAI's `reasoning_effort` parameter for o-series models.
|
||||||
|
|
||||||
|
### 3. Parameter Propagation
|
||||||
|
|
||||||
|
The thinking budget parameter is properly propagated through the LEANN architecture:
|
||||||
|
|
||||||
|
1. **CLI** (`packages/leann-core/src/leann/cli.py`): Captures `--thinking-budget` argument
|
||||||
|
2. **Base RAG** (`apps/base_rag_example.py`): Adds parameter to argument parser
|
||||||
|
3. **LeannChat** (`packages/leann-core/src/leann/api.py`): Passes `llm_kwargs` to LLM
|
||||||
|
4. **LLM Interface**: Handles the parameter in backend-specific implementations
|
||||||
|
|
||||||
|
## Files Modified
|
||||||
|
|
||||||
|
### Core Implementation
|
||||||
|
- `packages/leann-core/src/leann/chat.py`: Added thinking budget support to OllamaChat and OpenAIChat
|
||||||
|
- `packages/leann-core/src/leann/cli.py`: Added `--thinking-budget` argument
|
||||||
|
- `apps/base_rag_example.py`: Added thinking budget parameter to RAG examples
|
||||||
|
|
||||||
|
### Documentation
|
||||||
|
- `README.md`: Added thinking budget parameter to usage examples
|
||||||
|
- `docs/configuration-guide.md`: Added detailed documentation and usage guidelines
|
||||||
|
|
||||||
|
### Examples
|
||||||
|
- `examples/thinking_budget_demo.py`: Comprehensive demo script with usage examples
|
||||||
|
|
||||||
|
## Usage Examples
|
||||||
|
|
||||||
|
### Basic Usage
|
||||||
|
```bash
|
||||||
|
# High reasoning effort for complex questions
|
||||||
|
leann ask my-index --llm ollama --model gpt-oss:20b --thinking-budget high
|
||||||
|
|
||||||
|
# Medium reasoning for balanced performance
|
||||||
|
leann ask my-index --llm openai --model o3 --thinking-budget medium
|
||||||
|
|
||||||
|
# Low reasoning for fast responses
|
||||||
|
leann ask my-index --llm ollama --model gpt-oss:20b --thinking-budget low
|
||||||
|
```
|
||||||
|
|
||||||
|
### RAG Examples
|
||||||
|
```bash
|
||||||
|
# Email RAG with high reasoning
|
||||||
|
python apps/email_rag.py --llm ollama --llm-model gpt-oss:20b --thinking-budget high
|
||||||
|
|
||||||
|
# Document RAG with medium reasoning
|
||||||
|
python apps/document_rag.py --llm openai --llm-model o3 --thinking-budget medium
|
||||||
|
```
|
||||||
|
|
||||||
|
## Supported Models
|
||||||
|
|
||||||
|
### Ollama Models
|
||||||
|
- **GPT-Oss:20b**: Primary target model with reasoning capabilities
|
||||||
|
- **Other reasoning models**: Any Ollama model that supports the `reasoning` parameter
|
||||||
|
|
||||||
|
### OpenAI Models
|
||||||
|
- **o3, o3-mini, o4-mini, o1**: o-series reasoning models with `reasoning_effort` parameter
|
||||||
|
- **GPT-OSS models**: Models that support reasoning capabilities
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
|
||||||
|
The implementation includes comprehensive testing:
|
||||||
|
- Parameter handling verification
|
||||||
|
- Backend-specific API format validation
|
||||||
|
- CLI argument parsing tests
|
||||||
|
- Integration with existing LEANN architecture
|
||||||
98
docs/code/embedding_model_compare.py
Normal file
98
docs/code/embedding_model_compare.py
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
"""
|
||||||
|
Comparison between Sentence Transformers and OpenAI embeddings
|
||||||
|
|
||||||
|
This example shows how different embedding models handle complex queries
|
||||||
|
and demonstrates the differences between local and API-based embeddings.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from leann.embedding_compute import compute_embeddings
|
||||||
|
|
||||||
|
# OpenAI API key should be set as environment variable
|
||||||
|
# export OPENAI_API_KEY="your-api-key-here"
|
||||||
|
|
||||||
|
# Test data
|
||||||
|
conference_text = "[Title]: COLING 2025 Conference\n[URL]: https://coling2025.org/"
|
||||||
|
browser_text = "[Title]: Browser Use Tool\n[URL]: https://github.com/browser-use"
|
||||||
|
|
||||||
|
# Two queries with same intent but different wording
|
||||||
|
query1 = "Tell me my browser history about some conference i often visit"
|
||||||
|
query2 = "browser history about conference I often visit"
|
||||||
|
|
||||||
|
texts = [query1, query2, conference_text, browser_text]
|
||||||
|
|
||||||
|
|
||||||
|
def cosine_similarity(a, b):
|
||||||
|
return np.dot(a, b) # Already normalized
|
||||||
|
|
||||||
|
|
||||||
|
def analyze_embeddings(embeddings, model_name):
|
||||||
|
print(f"\n=== {model_name} Results ===")
|
||||||
|
|
||||||
|
# Results for Query 1
|
||||||
|
sim1_conf = cosine_similarity(embeddings[0], embeddings[2])
|
||||||
|
sim1_browser = cosine_similarity(embeddings[0], embeddings[3])
|
||||||
|
|
||||||
|
print(f"Query 1: '{query1}'")
|
||||||
|
print(f" → Conference similarity: {sim1_conf:.4f} {'✓' if sim1_conf > sim1_browser else ''}")
|
||||||
|
print(
|
||||||
|
f" → Browser similarity: {sim1_browser:.4f} {'✓' if sim1_browser > sim1_conf else ''}"
|
||||||
|
)
|
||||||
|
print(f" Winner: {'Conference' if sim1_conf > sim1_browser else 'Browser'}")
|
||||||
|
|
||||||
|
# Results for Query 2
|
||||||
|
sim2_conf = cosine_similarity(embeddings[1], embeddings[2])
|
||||||
|
sim2_browser = cosine_similarity(embeddings[1], embeddings[3])
|
||||||
|
|
||||||
|
print(f"\nQuery 2: '{query2}'")
|
||||||
|
print(f" → Conference similarity: {sim2_conf:.4f} {'✓' if sim2_conf > sim2_browser else ''}")
|
||||||
|
print(
|
||||||
|
f" → Browser similarity: {sim2_browser:.4f} {'✓' if sim2_browser > sim2_conf else ''}"
|
||||||
|
)
|
||||||
|
print(f" Winner: {'Conference' if sim2_conf > sim2_browser else 'Browser'}")
|
||||||
|
|
||||||
|
# Show the impact
|
||||||
|
print("\n=== Impact Analysis ===")
|
||||||
|
print(f"Conference similarity change: {sim2_conf - sim1_conf:+.4f}")
|
||||||
|
print(f"Browser similarity change: {sim2_browser - sim1_browser:+.4f}")
|
||||||
|
|
||||||
|
if sim1_conf > sim1_browser and sim2_browser > sim2_conf:
|
||||||
|
print("❌ FLIP: Adding 'browser history' flips winner from Conference to Browser!")
|
||||||
|
elif sim1_conf > sim1_browser and sim2_conf > sim2_browser:
|
||||||
|
print("✅ STABLE: Conference remains winner in both queries")
|
||||||
|
elif sim1_browser > sim1_conf and sim2_browser > sim2_conf:
|
||||||
|
print("✅ STABLE: Browser remains winner in both queries")
|
||||||
|
else:
|
||||||
|
print("🔄 MIXED: Results vary between queries")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"query1_conf": sim1_conf,
|
||||||
|
"query1_browser": sim1_browser,
|
||||||
|
"query2_conf": sim2_conf,
|
||||||
|
"query2_browser": sim2_browser,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Test Sentence Transformers
|
||||||
|
print("Testing Sentence Transformers (facebook/contriever)...")
|
||||||
|
try:
|
||||||
|
st_embeddings = compute_embeddings(texts, "facebook/contriever", mode="sentence-transformers")
|
||||||
|
st_results = analyze_embeddings(st_embeddings, "Sentence Transformers (facebook/contriever)")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Sentence Transformers failed: {e}")
|
||||||
|
st_results = None
|
||||||
|
|
||||||
|
# Test OpenAI
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("Testing OpenAI (text-embedding-3-small)...")
|
||||||
|
try:
|
||||||
|
openai_embeddings = compute_embeddings(texts, "text-embedding-3-small", mode="openai")
|
||||||
|
openai_results = analyze_embeddings(openai_embeddings, "OpenAI (text-embedding-3-small)")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ OpenAI failed: {e}")
|
||||||
|
openai_results = None
|
||||||
|
|
||||||
|
# Compare results
|
||||||
|
if st_results and openai_results:
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("=== COMPARISON SUMMARY ===")
|
||||||
279
docs/configuration-guide.md
Normal file
279
docs/configuration-guide.md
Normal file
@@ -0,0 +1,279 @@
|
|||||||
|
# LEANN Configuration Guide
|
||||||
|
|
||||||
|
This guide helps you optimize LEANN for different use cases and understand the trade-offs between various configuration options.
|
||||||
|
|
||||||
|
## Getting Started: Simple is Better
|
||||||
|
|
||||||
|
When first trying LEANN, start with a small dataset to quickly validate your approach:
|
||||||
|
|
||||||
|
**For document RAG**: The default `data/` directory works perfectly - includes 2 AI research papers, Pride and Prejudice literature, and a technical report
|
||||||
|
```bash
|
||||||
|
python -m apps.document_rag --query "What techniques does LEANN use?"
|
||||||
|
```
|
||||||
|
|
||||||
|
**For other data sources**: Limit the dataset size for quick testing
|
||||||
|
```bash
|
||||||
|
# WeChat: Test with recent messages only
|
||||||
|
python -m apps.wechat_rag --max-items 100 --query "What did we discuss about the project timeline?"
|
||||||
|
|
||||||
|
# Browser history: Last few days
|
||||||
|
python -m apps.browser_rag --max-items 500 --query "Find documentation about vector databases"
|
||||||
|
|
||||||
|
# Email: Recent inbox
|
||||||
|
python -m apps.email_rag --max-items 200 --query "Who sent updates about the deployment status?"
|
||||||
|
```
|
||||||
|
|
||||||
|
Once validated, scale up gradually:
|
||||||
|
- 100 documents → 1,000 → 10,000 → full dataset (`--max-items -1`)
|
||||||
|
- This helps identify issues early before committing to long processing times
|
||||||
|
|
||||||
|
## Embedding Model Selection: Understanding the Trade-offs
|
||||||
|
|
||||||
|
Based on our experience developing LEANN, embedding models fall into three categories:
|
||||||
|
|
||||||
|
### Small Models (< 100M parameters)
|
||||||
|
**Example**: `sentence-transformers/all-MiniLM-L6-v2` (22M params)
|
||||||
|
- **Pros**: Lightweight, fast for both indexing and inference
|
||||||
|
- **Cons**: Lower semantic understanding, may miss nuanced relationships
|
||||||
|
- **Use when**: Speed is critical, handling simple queries, interactive mode, or just experimenting with LEANN. If time is not a constraint, consider using a larger/better embedding model
|
||||||
|
|
||||||
|
### Medium Models (100M-500M parameters)
|
||||||
|
**Example**: `facebook/contriever` (110M params), `BAAI/bge-base-en-v1.5` (110M params)
|
||||||
|
- **Pros**: Balanced performance, good multilingual support, reasonable speed
|
||||||
|
- **Cons**: Requires more compute than small models
|
||||||
|
- **Use when**: Need quality results without extreme compute requirements, general-purpose RAG applications
|
||||||
|
|
||||||
|
### Large Models (500M+ parameters)
|
||||||
|
**Example**: `Qwen/Qwen3-Embedding-0.6B` (600M params), `intfloat/multilingual-e5-large` (560M params)
|
||||||
|
- **Pros**: Best semantic understanding, captures complex relationships, excellent multilingual support. **Qwen3-Embedding-0.6B achieves nearly OpenAI API performance!**
|
||||||
|
- **Cons**: Slower inference, longer index build times
|
||||||
|
- **Use when**: Quality is paramount and you have sufficient compute resources. **Highly recommended** for production use
|
||||||
|
|
||||||
|
### Quick Start: Cloud and Local Embedding Options
|
||||||
|
|
||||||
|
**OpenAI Embeddings (Fastest Setup)**
|
||||||
|
For immediate testing without local model downloads:
|
||||||
|
```bash
|
||||||
|
# Set OpenAI embeddings (requires OPENAI_API_KEY)
|
||||||
|
--embedding-mode openai --embedding-model text-embedding-3-small
|
||||||
|
```
|
||||||
|
|
||||||
|
**Ollama Embeddings (Privacy-Focused)**
|
||||||
|
For local embeddings with complete privacy:
|
||||||
|
```bash
|
||||||
|
# First, pull an embedding model
|
||||||
|
ollama pull nomic-embed-text
|
||||||
|
|
||||||
|
# Use Ollama embeddings
|
||||||
|
--embedding-mode ollama --embedding-model nomic-embed-text
|
||||||
|
```
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><strong>Cloud vs Local Trade-offs</strong></summary>
|
||||||
|
|
||||||
|
**OpenAI Embeddings** (`text-embedding-3-small/large`)
|
||||||
|
- **Pros**: No local compute needed, consistently fast, high quality
|
||||||
|
- **Cons**: Requires API key, costs money, data leaves your system, [known limitations with certain languages](https://yichuan-w.github.io/blog/lessons_learned_in_dev_leann/)
|
||||||
|
- **When to use**: Prototyping, non-sensitive data, need immediate results
|
||||||
|
|
||||||
|
**Local Embeddings**
|
||||||
|
- **Pros**: Complete privacy, no ongoing costs, full control, can sometimes outperform OpenAI embeddings
|
||||||
|
- **Cons**: Slower than cloud APIs, requires local compute resources
|
||||||
|
- **When to use**: Production systems, sensitive data, cost-sensitive applications
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
## Index Selection: Matching Your Scale
|
||||||
|
|
||||||
|
### HNSW (Hierarchical Navigable Small World)
|
||||||
|
**Best for**: Small to medium datasets (< 10M vectors) - **Default and recommended for extreme low storage**
|
||||||
|
- Full recomputation required
|
||||||
|
- High memory usage during build phase
|
||||||
|
- Excellent recall (95%+)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Optimal for most use cases
|
||||||
|
--backend-name hnsw --graph-degree 32 --build-complexity 64
|
||||||
|
```
|
||||||
|
|
||||||
|
### DiskANN
|
||||||
|
**Best for**: Large datasets (> 10M vectors, 10GB+ index size) - **⚠️ Beta version, still in active development**
|
||||||
|
- Uses Product Quantization (PQ) for coarse filtering during graph traversal
|
||||||
|
- Novel approach: stores only PQ codes, performs rerank with exact computation in final step
|
||||||
|
- Implements a corner case of double-queue: prunes all neighbors and recomputes at the end
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# For billion-scale deployments
|
||||||
|
--backend-name diskann --graph-degree 64 --build-complexity 128
|
||||||
|
```
|
||||||
|
|
||||||
|
## LLM Selection: Engine and Model Comparison
|
||||||
|
|
||||||
|
### LLM Engines
|
||||||
|
|
||||||
|
**OpenAI** (`--llm openai`)
|
||||||
|
- **Pros**: Best quality, consistent performance, no local resources needed
|
||||||
|
- **Cons**: Costs money ($0.15-2.5 per million tokens), requires internet, data privacy concerns
|
||||||
|
- **Models**: `gpt-4o-mini` (fast, cheap), `gpt-4o` (best quality), `o3` (reasoning), `o3-mini` (reasoning, cheaper)
|
||||||
|
- **Thinking Budget**: Use `--thinking-budget low/medium/high` for o-series reasoning models (o3, o3-mini, o4-mini)
|
||||||
|
- **Note**: Our current default, but we recommend switching to Ollama for most use cases
|
||||||
|
|
||||||
|
**Ollama** (`--llm ollama`)
|
||||||
|
- **Pros**: Fully local, free, privacy-preserving, good model variety
|
||||||
|
- **Cons**: Requires local GPU/CPU resources, slower than cloud APIs, need to install extra [ollama app](https://github.com/ollama/ollama?tab=readme-ov-file#ollama) and pre-download models by `ollama pull`
|
||||||
|
- **Models**: `qwen3:0.6b` (ultra-fast), `qwen3:1.7b` (balanced), `qwen3:4b` (good quality), `qwen3:7b` (high quality), `deepseek-r1:1.5b` (reasoning)
|
||||||
|
- **Thinking Budget**: Use `--thinking-budget low/medium/high` for reasoning models like GPT-Oss:20b
|
||||||
|
|
||||||
|
**HuggingFace** (`--llm hf`)
|
||||||
|
- **Pros**: Free tier available, huge model selection, direct model loading (vs Ollama's server-based approach)
|
||||||
|
- **Cons**: More complex initial setup
|
||||||
|
- **Models**: `Qwen/Qwen3-1.7B-FP8`
|
||||||
|
|
||||||
|
## Parameter Tuning Guide
|
||||||
|
|
||||||
|
### Search Complexity Parameters
|
||||||
|
|
||||||
|
**`--build-complexity`** (index building)
|
||||||
|
- Controls thoroughness during index construction
|
||||||
|
- Higher = better recall but slower build
|
||||||
|
- Recommendations:
|
||||||
|
- 32: Quick prototyping
|
||||||
|
- 64: Balanced (default)
|
||||||
|
- 128: Production systems
|
||||||
|
- 256: Maximum quality
|
||||||
|
|
||||||
|
**`--search-complexity`** (query time)
|
||||||
|
- Controls search thoroughness
|
||||||
|
- Higher = better results but slower
|
||||||
|
- Recommendations:
|
||||||
|
- 16: Fast/Interactive search
|
||||||
|
- 32: High quality with diversity
|
||||||
|
- 64+: Maximum accuracy
|
||||||
|
|
||||||
|
### Top-K Selection
|
||||||
|
|
||||||
|
**`--top-k`** (number of retrieved chunks)
|
||||||
|
- More chunks = better context but slower LLM processing
|
||||||
|
- Should always be smaller than `--search-complexity`
|
||||||
|
- Guidelines:
|
||||||
|
- 10-20: General questions (default: 20)
|
||||||
|
- 30+: Complex multi-hop reasoning requiring comprehensive context
|
||||||
|
|
||||||
|
**Trade-off formula**:
|
||||||
|
- Retrieval time ∝ log(n) × search_complexity
|
||||||
|
- LLM processing time ∝ top_k × chunk_size
|
||||||
|
- Total context = top_k × chunk_size tokens
|
||||||
|
|
||||||
|
### Thinking Budget for Reasoning Models
|
||||||
|
|
||||||
|
**`--thinking-budget`** (reasoning effort level)
|
||||||
|
- Controls the computational effort for reasoning models
|
||||||
|
- Options: `low`, `medium`, `high`
|
||||||
|
- Guidelines:
|
||||||
|
- `low`: Fast responses, basic reasoning (default for simple queries)
|
||||||
|
- `medium`: Balanced speed and reasoning depth
|
||||||
|
- `high`: Maximum reasoning effort, best for complex analytical questions
|
||||||
|
- **Supported Models**:
|
||||||
|
- **Ollama**: `gpt-oss:20b`, `gpt-oss:120b`
|
||||||
|
- **OpenAI**: `o3`, `o3-mini`, `o4-mini`, `o1` (o-series reasoning models)
|
||||||
|
- **Note**: Models without reasoning support will show a warning and proceed without reasoning parameters
|
||||||
|
- **Example**: `--thinking-budget high` for complex analytical questions
|
||||||
|
|
||||||
|
**📖 For detailed usage examples and implementation details, check out [Thinking Budget Documentation](THINKING_BUDGET_FEATURE.md)**
|
||||||
|
|
||||||
|
**💡 Quick Examples:**
|
||||||
|
```bash
|
||||||
|
# OpenAI o-series reasoning model
|
||||||
|
python apps/document_rag.py --query "What are the main techniques LEANN explores?" \
|
||||||
|
--index-dir hnswbuild --backend hnsw \
|
||||||
|
--llm openai --llm-model o3 --thinking-budget medium
|
||||||
|
|
||||||
|
# Ollama reasoning model
|
||||||
|
python apps/document_rag.py --query "What are the main techniques LEANN explores?" \
|
||||||
|
--index-dir hnswbuild --backend hnsw \
|
||||||
|
--llm ollama --llm-model gpt-oss:20b --thinking-budget high
|
||||||
|
```
|
||||||
|
|
||||||
|
### Graph Degree (HNSW/DiskANN)
|
||||||
|
|
||||||
|
**`--graph-degree`**
|
||||||
|
- Number of connections per node in the graph
|
||||||
|
- Higher = better recall but more memory
|
||||||
|
- HNSW: 16-32 (default: 32)
|
||||||
|
- DiskANN: 32-128 (default: 64)
|
||||||
|
|
||||||
|
|
||||||
|
## Performance Optimization Checklist
|
||||||
|
|
||||||
|
### If Embedding is Too Slow
|
||||||
|
|
||||||
|
1. **Switch to smaller model**:
|
||||||
|
```bash
|
||||||
|
# From large model
|
||||||
|
--embedding-model Qwen/Qwen3-Embedding-0.6B
|
||||||
|
# To small model
|
||||||
|
--embedding-model sentence-transformers/all-MiniLM-L6-v2
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Limit dataset size for testing**:
|
||||||
|
```bash
|
||||||
|
--max-items 1000 # Process first 1k items only
|
||||||
|
```
|
||||||
|
|
||||||
|
3. **Use MLX on Apple Silicon** (optional optimization):
|
||||||
|
```bash
|
||||||
|
--embedding-mode mlx --embedding-model mlx-community/multilingual-e5-base-mlx
|
||||||
|
```
|
||||||
|
|
||||||
|
### If Search Quality is Poor
|
||||||
|
|
||||||
|
1. **Increase retrieval count**:
|
||||||
|
```bash
|
||||||
|
--top-k 30 # Retrieve more candidates
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Upgrade embedding model**:
|
||||||
|
```bash
|
||||||
|
# For English
|
||||||
|
--embedding-model BAAI/bge-base-en-v1.5
|
||||||
|
# For multilingual
|
||||||
|
--embedding-model intfloat/multilingual-e5-large
|
||||||
|
```
|
||||||
|
|
||||||
|
## Understanding the Trade-offs
|
||||||
|
|
||||||
|
Every configuration choice involves trade-offs:
|
||||||
|
|
||||||
|
| Factor | Small/Fast | Large/Quality |
|
||||||
|
|--------|------------|---------------|
|
||||||
|
| Embedding Model | `all-MiniLM-L6-v2` | `Qwen/Qwen3-Embedding-0.6B` |
|
||||||
|
| Chunk Size | 512 tokens | 128 tokens |
|
||||||
|
| Index Type | HNSW | DiskANN |
|
||||||
|
| LLM | `qwen3:1.7b` | `gpt-4o` |
|
||||||
|
|
||||||
|
The key is finding the right balance for your specific use case. Start small and simple, measure performance, then scale up only where needed.
|
||||||
|
|
||||||
|
## Deep Dive: Critical Configuration Decisions
|
||||||
|
|
||||||
|
### When to Disable Recomputation
|
||||||
|
|
||||||
|
LEANN's recomputation feature provides exact distance calculations but can be disabled for extreme QPS requirements:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
--no-recompute # Disable selective recomputation
|
||||||
|
```
|
||||||
|
|
||||||
|
**Trade-offs**:
|
||||||
|
- **With recomputation** (default): Exact distances, best quality, higher latency, minimal storage (only stores metadata, recomputes embeddings on-demand)
|
||||||
|
- **Without recomputation**: Must store full embeddings, significantly higher memory and storage usage (10-100x more), but faster search
|
||||||
|
|
||||||
|
**Disable when**:
|
||||||
|
- You have abundant storage and memory
|
||||||
|
- Need extremely low latency (< 100ms)
|
||||||
|
- Running a read-heavy workload where storage cost is acceptable
|
||||||
|
|
||||||
|
## Further Reading
|
||||||
|
|
||||||
|
- [Lessons Learned Developing LEANN](https://yichuan-w.github.io/blog/lessons_learned_in_dev_leann/)
|
||||||
|
- [LEANN Technical Paper](https://arxiv.org/abs/2506.08276)
|
||||||
|
- [DiskANN Original Paper](https://papers.nips.cc/paper/2019/file/09853c7fb1d3f8ee67a61b6bf4a7f8e6-Paper.pdf)
|
||||||
10
docs/faq.md
Normal file
10
docs/faq.md
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
# FAQ
|
||||||
|
|
||||||
|
## 1. My building time seems long
|
||||||
|
|
||||||
|
You can speed up the process by using a lightweight embedding model. Add this to your arguments:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
--embedding-model sentence-transformers/all-MiniLM-L6-v2
|
||||||
|
```
|
||||||
|
**Model sizes:** `all-MiniLM-L6-v2` (22M parameters), `facebook/contriever` (~100M parameters), `Qwen3-0.6B` (600M parameters)
|
||||||
22
docs/features.md
Normal file
22
docs/features.md
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
# ✨ Detailed Features
|
||||||
|
|
||||||
|
## 🔥 Core Features
|
||||||
|
|
||||||
|
- **🔄 Real-time Embeddings** - Eliminate heavy embedding storage with dynamic computation, using optimized ZMQ servers, a search paradigm that overlaps and batches work, and a highly optimized embedding engine
|
||||||
|
- **📈 Scalable Architecture** - Handles millions of documents on consumer hardware; the larger your dataset, the more LEANN can save
|
||||||
|
- **🎯 Graph Pruning** - Advanced techniques to minimize the storage overhead of vector search to a limited footprint
|
||||||
|
- **🏗️ Pluggable Backends** - HNSW/FAISS (default), with optional DiskANN for large-scale deployments
|
||||||
|
|
||||||
|
## 🛠️ Technical Highlights
|
||||||
|
- **🔄 Recompute Mode** - Highest accuracy scenarios while eliminating vector storage overhead
|
||||||
|
- **⚡ Zero-copy Operations** - Minimize IPC overhead by transferring distances instead of embeddings
|
||||||
|
- **🚀 High-throughput Embedding Pipeline** - Optimized batched processing for maximum efficiency
|
||||||
|
- **🎯 Two-level Search** - Novel coarse-to-fine search overlap for accelerated query processing (optional)
|
||||||
|
- **💾 Memory-mapped Indices** - Fast startup with raw text mapping to reduce memory overhead
|
||||||
|
- **🚀 MLX Support** - Ultra-fast recompute/build with quantized embedding models, accelerating building and search ([minimal example](../examples/mlx_demo.py))
|
||||||
|
|
||||||
|
## 🎨 Developer Experience
|
||||||
|
|
||||||
|
- **Simple Python API** - Get started in minutes
|
||||||
|
- **Extensible backend system** - Easy to add new algorithms
|
||||||
|
- **Comprehensive examples** - From basic usage to production deployment
|
||||||
75
docs/normalized_embeddings.md
Normal file
75
docs/normalized_embeddings.md
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
# Normalized Embeddings Support in LEANN
|
||||||
|
|
||||||
|
LEANN now automatically detects normalized embedding models and sets the appropriate distance metric for optimal performance.
|
||||||
|
|
||||||
|
## What are Normalized Embeddings?
|
||||||
|
|
||||||
|
Normalized embeddings are vectors with L2 norm = 1 (unit vectors). These embeddings are optimized for cosine similarity rather than Maximum Inner Product Search (MIPS).
|
||||||
|
|
||||||
|
## Automatic Detection
|
||||||
|
|
||||||
|
When you create a `LeannBuilder` instance with a normalized embedding model, LEANN will:
|
||||||
|
|
||||||
|
1. **Automatically set `distance_metric="cosine"`** if not specified
|
||||||
|
2. **Show a warning** if you manually specify a different distance metric
|
||||||
|
3. **Provide optimal search performance** with the correct metric
|
||||||
|
|
||||||
|
## Supported Normalized Embedding Models
|
||||||
|
|
||||||
|
### OpenAI
|
||||||
|
All OpenAI text embedding models are normalized:
|
||||||
|
- `text-embedding-ada-002`
|
||||||
|
- `text-embedding-3-small`
|
||||||
|
- `text-embedding-3-large`
|
||||||
|
|
||||||
|
### Voyage AI
|
||||||
|
All Voyage AI embedding models are normalized:
|
||||||
|
- `voyage-2`
|
||||||
|
- `voyage-3`
|
||||||
|
- `voyage-large-2`
|
||||||
|
- `voyage-multilingual-2`
|
||||||
|
- `voyage-code-2`
|
||||||
|
|
||||||
|
### Cohere
|
||||||
|
All Cohere embedding models are normalized:
|
||||||
|
- `embed-english-v3.0`
|
||||||
|
- `embed-multilingual-v3.0`
|
||||||
|
- `embed-english-light-v3.0`
|
||||||
|
- `embed-multilingual-light-v3.0`
|
||||||
|
|
||||||
|
## Example Usage
|
||||||
|
|
||||||
|
```python
|
||||||
|
from leann.api import LeannBuilder
|
||||||
|
|
||||||
|
# Automatic detection - will use cosine distance
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model="text-embedding-3-small",
|
||||||
|
embedding_mode="openai"
|
||||||
|
)
|
||||||
|
# Warning: Detected normalized embeddings model 'text-embedding-3-small'...
|
||||||
|
# Automatically setting distance_metric='cosine'
|
||||||
|
|
||||||
|
# Manual override (not recommended)
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model="text-embedding-3-small",
|
||||||
|
embedding_mode="openai",
|
||||||
|
distance_metric="mips" # Will show warning
|
||||||
|
)
|
||||||
|
# Warning: Using 'mips' distance metric with normalized embeddings...
|
||||||
|
```
|
||||||
|
|
||||||
|
## Non-Normalized Embeddings
|
||||||
|
|
||||||
|
Models like `facebook/contriever` and other sentence-transformers models that are not normalized will continue to use MIPS by default, which is optimal for them.
|
||||||
|
|
||||||
|
## Why This Matters
|
||||||
|
|
||||||
|
Using the wrong distance metric with normalized embeddings can lead to:
|
||||||
|
- **Poor search quality** due to HNSW's early termination with narrow score ranges
|
||||||
|
- **Incorrect ranking** of search results
|
||||||
|
- **Suboptimal performance** compared to using the correct metric
|
||||||
|
|
||||||
|
For more details on why this happens, see our analysis in the [embedding detection code](../packages/leann-core/src/leann/api.py) which automatically handles normalized embeddings and MIPS distance metric issues.
|
||||||
21
docs/roadmap.md
Normal file
21
docs/roadmap.md
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
# 📈 Roadmap
|
||||||
|
|
||||||
|
## 🎯 Q2 2025
|
||||||
|
|
||||||
|
- [X] HNSW backend integration
|
||||||
|
- [X] DiskANN backend with MIPS/L2/Cosine support
|
||||||
|
- [X] Real-time embedding pipeline
|
||||||
|
- [X] Memory-efficient graph pruning
|
||||||
|
|
||||||
|
## 🚀 Q3 2025
|
||||||
|
|
||||||
|
- [ ] Advanced caching strategies
|
||||||
|
- [ ] Add contextual-retrieval https://www.anthropic.com/news/contextual-retrieval
|
||||||
|
- [ ] Add sleep-time compute and a summarization agent to summarize files on the computer
|
||||||
|
- [ ] Add OpenAI recompute API
|
||||||
|
|
||||||
|
## 🌟 Q4 2025
|
||||||
|
|
||||||
|
- [ ] Integration with LangChain/LlamaIndex
|
||||||
|
- [ ] Visual similarity search
|
||||||
|
- [ ] Query rewriting, reranking, and expansion
|
||||||
@@ -1,15 +1,28 @@
|
|||||||
"""
|
"""
|
||||||
Simple demo showing basic leann usage
|
Simple demo showing basic leann usage
|
||||||
Run: uv run python examples/simple_demo.py
|
Run: uv run python examples/basic_demo.py
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from leann import LeannBuilder, LeannSearcher, LeannChat
|
import argparse
|
||||||
|
|
||||||
|
from leann import LeannBuilder, LeannChat, LeannSearcher
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
print("=== Leann Simple Demo ===")
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Simple demo of Leann with selectable embedding models."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--embedding_model",
|
||||||
|
type=str,
|
||||||
|
default="sentence-transformers/all-mpnet-base-v2",
|
||||||
|
help="The embedding model to use, e.g., 'sentence-transformers/all-mpnet-base-v2' or 'text-embedding-ada-002'.",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
print(f"=== Leann Simple Demo with {args.embedding_model} ===")
|
||||||
print()
|
print()
|
||||||
|
|
||||||
# Sample knowledge base
|
# Sample knowledge base
|
||||||
chunks = [
|
chunks = [
|
||||||
"Machine learning is a subset of artificial intelligence that enables computers to learn without being explicitly programmed.",
|
"Machine learning is a subset of artificial intelligence that enables computers to learn without being explicitly programmed.",
|
||||||
@@ -21,61 +34,55 @@ def main():
|
|||||||
"Big data refers to extremely large datasets that require special tools and techniques to process.",
|
"Big data refers to extremely large datasets that require special tools and techniques to process.",
|
||||||
"Cloud computing provides on-demand access to computing resources over the internet.",
|
"Cloud computing provides on-demand access to computing resources over the internet.",
|
||||||
]
|
]
|
||||||
|
|
||||||
print("1. Building index (no embeddings stored)...")
|
print("1. Building index (no embeddings stored)...")
|
||||||
builder = LeannBuilder(
|
builder = LeannBuilder(
|
||||||
embedding_model="sentence-transformers/all-mpnet-base-v2",
|
embedding_model=args.embedding_model,
|
||||||
prune_ratio=0.7, # Keep 30% of connections
|
backend_name="hnsw",
|
||||||
)
|
)
|
||||||
builder.add_chunks(chunks)
|
for chunk in chunks:
|
||||||
|
builder.add_text(chunk)
|
||||||
builder.build_index("demo_knowledge.leann")
|
builder.build_index("demo_knowledge.leann")
|
||||||
print()
|
print()
|
||||||
|
|
||||||
print("2. Searching with real-time embeddings...")
|
print("2. Searching with real-time embeddings...")
|
||||||
searcher = LeannSearcher("demo_knowledge.leann")
|
searcher = LeannSearcher("demo_knowledge.leann")
|
||||||
|
|
||||||
queries = [
|
queries = [
|
||||||
"What is machine learning?",
|
"What is machine learning?",
|
||||||
"How does neural network work?",
|
"How does neural network work?",
|
||||||
"Tell me about data processing",
|
"Tell me about data processing",
|
||||||
]
|
]
|
||||||
|
|
||||||
for query in queries:
|
for query in queries:
|
||||||
print(f"Query: {query}")
|
print(f"Query: {query}")
|
||||||
results = searcher.search(query, top_k=2)
|
results = searcher.search(query, top_k=2)
|
||||||
|
|
||||||
for i, result in enumerate(results, 1):
|
for i, result in enumerate(results, 1):
|
||||||
print(f" {i}. Score: {result.score:.3f}")
|
print(f" {i}. Score: {result.score:.3f}")
|
||||||
print(f" Text: {result.text[:100]}...")
|
print(f" Text: {result.text[:100]}...")
|
||||||
print()
|
print()
|
||||||
|
|
||||||
print("3. Memory stats:")
|
print("3. Interactive chat demo:")
|
||||||
stats = searcher.get_memory_stats()
|
|
||||||
print(f" Cache size: {stats.embedding_cache_size}")
|
|
||||||
print(f" Cache memory: {stats.embedding_cache_memory_mb:.1f} MB")
|
|
||||||
print(f" Total chunks: {stats.total_chunks}")
|
|
||||||
print()
|
|
||||||
|
|
||||||
print("4. Interactive chat demo:")
|
|
||||||
print(" (Note: Requires OpenAI API key for real responses)")
|
print(" (Note: Requires OpenAI API key for real responses)")
|
||||||
|
|
||||||
chat = LeannChat("demo_knowledge.leann")
|
chat = LeannChat("demo_knowledge.leann")
|
||||||
|
|
||||||
# Demo questions
|
# Demo questions
|
||||||
demo_questions: list[str] = [
|
demo_questions: list[str] = [
|
||||||
"What is the difference between machine learning and deep learning?",
|
"What is the difference between machine learning and deep learning?",
|
||||||
"How is data science related to big data?",
|
"How is data science related to big data?",
|
||||||
]
|
]
|
||||||
|
|
||||||
for question in demo_questions:
|
for question in demo_questions:
|
||||||
print(f" Q: {question}")
|
print(f" Q: {question}")
|
||||||
response = chat.ask(question)
|
response = chat.ask(question)
|
||||||
print(f" A: {response}")
|
print(f" A: {response}")
|
||||||
print()
|
print()
|
||||||
|
|
||||||
print("Demo completed! Try running:")
|
print("Demo completed! Try running:")
|
||||||
print(" uv run python examples/document_search.py")
|
print(" uv run python apps/document_rag.py")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
@@ -1,146 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Document search demo with recompute mode
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
import shutil
|
|
||||||
import time
|
|
||||||
|
|
||||||
# Import backend packages to trigger plugin registration
|
|
||||||
try:
|
|
||||||
import leann_backend_diskann
|
|
||||||
import leann_backend_hnsw
|
|
||||||
print("INFO: Backend packages imported successfully.")
|
|
||||||
except ImportError as e:
|
|
||||||
print(f"WARNING: Could not import backend packages. Error: {e}")
|
|
||||||
|
|
||||||
# Import upper-level API from leann-core
|
|
||||||
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
|
||||||
|
|
||||||
|
|
||||||
def load_sample_documents():
|
|
||||||
"""Create sample documents for demonstration"""
|
|
||||||
docs = [
|
|
||||||
{"title": "Intro to Python", "content": "Python is a high-level, interpreted language known for simplicity."},
|
|
||||||
{"title": "ML Basics", "content": "Machine learning builds systems that learn from data."},
|
|
||||||
{"title": "Data Structures", "content": "Data structures like arrays, lists, and graphs organize data."},
|
|
||||||
]
|
|
||||||
return docs
|
|
||||||
|
|
||||||
def main():
|
|
||||||
print("==========================================================")
|
|
||||||
print("=== Leann Document Search Demo (DiskANN + Recompute) ===")
|
|
||||||
print("==========================================================")
|
|
||||||
|
|
||||||
INDEX_DIR = Path("./test_indices")
|
|
||||||
INDEX_PATH = str(INDEX_DIR / "documents.diskann")
|
|
||||||
BACKEND_TO_TEST = "diskann"
|
|
||||||
|
|
||||||
if INDEX_DIR.exists():
|
|
||||||
print(f"--- Cleaning up old index directory: {INDEX_DIR} ---")
|
|
||||||
shutil.rmtree(INDEX_DIR)
|
|
||||||
|
|
||||||
# --- 1. Build index ---
|
|
||||||
print(f"\n[PHASE 1] Building index using '{BACKEND_TO_TEST}' backend...")
|
|
||||||
|
|
||||||
builder = LeannBuilder(
|
|
||||||
backend_name=BACKEND_TO_TEST,
|
|
||||||
graph_degree=32,
|
|
||||||
complexity=64
|
|
||||||
)
|
|
||||||
|
|
||||||
documents = load_sample_documents()
|
|
||||||
print(f"Loaded {len(documents)} sample documents.")
|
|
||||||
for doc in documents:
|
|
||||||
builder.add_text(doc["content"], metadata={"title": doc["title"]})
|
|
||||||
|
|
||||||
builder.build_index(INDEX_PATH)
|
|
||||||
print(f"\nIndex built!")
|
|
||||||
|
|
||||||
# --- 2. Basic search demo ---
|
|
||||||
print(f"\n[PHASE 2] Basic search using '{BACKEND_TO_TEST}' backend...")
|
|
||||||
searcher = LeannSearcher(index_path=INDEX_PATH)
|
|
||||||
|
|
||||||
query = "What is machine learning?"
|
|
||||||
print(f"\nQuery: '{query}'")
|
|
||||||
|
|
||||||
print("\n--- Basic search mode (PQ computation) ---")
|
|
||||||
start_time = time.time()
|
|
||||||
results = searcher.search(query, top_k=2)
|
|
||||||
basic_time = time.time() - start_time
|
|
||||||
|
|
||||||
print(f"⏱️ Basic search time: {basic_time:.3f} seconds")
|
|
||||||
print(">>> Basic search results <<<")
|
|
||||||
for i, res in enumerate(results, 1):
|
|
||||||
print(f" {i}. ID: {res['id']}, Score: {res['score']:.4f}, Text: '{res['text']}', Metadata: {res['metadata']}")
|
|
||||||
|
|
||||||
# --- 3. Recompute search demo ---
|
|
||||||
print(f"\n[PHASE 3] Recompute search using embedding server...")
|
|
||||||
|
|
||||||
print("\n--- Recompute search mode (get real embeddings via network) ---")
|
|
||||||
|
|
||||||
# Configure recompute parameters
|
|
||||||
recompute_params = {
|
|
||||||
"recompute_beighbor_embeddings": True, # Enable network recomputation
|
|
||||||
"USE_DEFERRED_FETCH": False, # Don't use deferred fetch
|
|
||||||
"skip_search_reorder": True, # Skip search reordering
|
|
||||||
"dedup_node_dis": True, # Enable node distance deduplication
|
|
||||||
"prune_ratio": 0.1, # Pruning ratio 10%
|
|
||||||
"batch_recompute": False, # Don't use batch recomputation
|
|
||||||
"global_pruning": False, # Don't use global pruning
|
|
||||||
"zmq_port": 5555, # ZMQ port
|
|
||||||
"embedding_model": "sentence-transformers/all-mpnet-base-v2"
|
|
||||||
}
|
|
||||||
|
|
||||||
print("Recompute parameter configuration:")
|
|
||||||
for key, value in recompute_params.items():
|
|
||||||
print(f" {key}: {value}")
|
|
||||||
|
|
||||||
print(f"\n🔄 Executing Recompute search...")
|
|
||||||
try:
|
|
||||||
start_time = time.time()
|
|
||||||
recompute_results = searcher.search(query, top_k=2, **recompute_params)
|
|
||||||
recompute_time = time.time() - start_time
|
|
||||||
|
|
||||||
print(f"⏱️ Recompute search time: {recompute_time:.3f} seconds")
|
|
||||||
print(">>> Recompute search results <<<")
|
|
||||||
for i, res in enumerate(recompute_results, 1):
|
|
||||||
print(f" {i}. ID: {res['id']}, Score: {res['score']:.4f}, Text: '{res['text']}', Metadata: {res['metadata']}")
|
|
||||||
|
|
||||||
# Compare results
|
|
||||||
print(f"\n--- Result comparison ---")
|
|
||||||
print(f"Basic search time: {basic_time:.3f} seconds")
|
|
||||||
print(f"Recompute time: {recompute_time:.3f} seconds")
|
|
||||||
|
|
||||||
print("\nBasic search vs Recompute results:")
|
|
||||||
for i in range(min(len(results), len(recompute_results))):
|
|
||||||
basic_score = results[i]['score']
|
|
||||||
recompute_score = recompute_results[i]['score']
|
|
||||||
score_diff = abs(basic_score - recompute_score)
|
|
||||||
print(f" Position {i+1}: PQ={basic_score:.4f}, Recompute={recompute_score:.4f}, Difference={score_diff:.4f}")
|
|
||||||
|
|
||||||
if recompute_time > basic_time:
|
|
||||||
print(f"✅ Recompute mode working correctly (more accurate but slower)")
|
|
||||||
else:
|
|
||||||
print(f"ℹ️ Recompute time is unusually fast, network recomputation may not be enabled")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ Recompute search failed: {e}")
|
|
||||||
print("This usually indicates an embedding server connection issue")
|
|
||||||
|
|
||||||
# --- 4. Chat demo ---
|
|
||||||
print(f"\n[PHASE 4] Starting chat session...")
|
|
||||||
chat = LeannChat(index_path=INDEX_PATH)
|
|
||||||
chat_response = chat.ask(query)
|
|
||||||
print(f"You: {query}")
|
|
||||||
print(f"Leann: {chat_response}")
|
|
||||||
|
|
||||||
print("\n==========================================================")
|
|
||||||
print("✅ Demo finished successfully!")
|
|
||||||
print("==========================================================")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,76 +0,0 @@
|
|||||||
from llama_index.core import SimpleDirectoryReader, Settings
|
|
||||||
from llama_index.core.readers.base import BaseReader
|
|
||||||
from llama_index.node_parser.docling import DoclingNodeParser
|
|
||||||
from llama_index.readers.docling import DoclingReader
|
|
||||||
from docling_core.transforms.chunker.hybrid_chunker import HybridChunker
|
|
||||||
import asyncio
|
|
||||||
import os
|
|
||||||
import dotenv
|
|
||||||
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
|
||||||
import leann_backend_diskann # Import to ensure backend registration
|
|
||||||
import shutil
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
dotenv.load_dotenv()
|
|
||||||
|
|
||||||
reader = DoclingReader(export_type=DoclingReader.ExportType.JSON)
|
|
||||||
file_extractor: dict[str, BaseReader] = {
|
|
||||||
".docx": reader,
|
|
||||||
".pptx": reader,
|
|
||||||
".pdf": reader,
|
|
||||||
".xlsx": reader,
|
|
||||||
}
|
|
||||||
node_parser = DoclingNodeParser(
|
|
||||||
chunker=HybridChunker(tokenizer="Qwen/Qwen3-Embedding-4B", max_tokens=10240)
|
|
||||||
)
|
|
||||||
|
|
||||||
documents = SimpleDirectoryReader(
|
|
||||||
"examples/data",
|
|
||||||
recursive=True,
|
|
||||||
file_extractor=file_extractor,
|
|
||||||
encoding="utf-8",
|
|
||||||
required_exts=[".pdf", ".docx", ".pptx", ".xlsx"]
|
|
||||||
).load_data(show_progress=True)
|
|
||||||
|
|
||||||
# Extract text from documents and prepare for Leann
|
|
||||||
all_texts = []
|
|
||||||
for doc in documents:
|
|
||||||
# DoclingNodeParser returns Node objects, which have a text attribute
|
|
||||||
nodes = node_parser.get_nodes_from_documents([doc])
|
|
||||||
for node in nodes:
|
|
||||||
all_texts.append(node.text)
|
|
||||||
|
|
||||||
INDEX_DIR = Path("./test_pdf_index")
|
|
||||||
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
|
|
||||||
|
|
||||||
if INDEX_DIR.exists():
|
|
||||||
print(f"--- Cleaning up old index directory: {INDEX_DIR} ---")
|
|
||||||
shutil.rmtree(INDEX_DIR)
|
|
||||||
|
|
||||||
print(f"\n[PHASE 1] Building Leann index...")
|
|
||||||
|
|
||||||
builder = LeannBuilder(
|
|
||||||
backend_name="diskann",
|
|
||||||
embedding_model="sentence-transformers/all-mpnet-base-v2", # Using a common sentence transformer model
|
|
||||||
graph_degree=32,
|
|
||||||
complexity=64
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"Loaded {len(all_texts)} text chunks from documents.")
|
|
||||||
for chunk_text in all_texts:
|
|
||||||
builder.add_text(chunk_text)
|
|
||||||
|
|
||||||
builder.build_index(INDEX_PATH)
|
|
||||||
print(f"\nLeann index built at {INDEX_PATH}!")
|
|
||||||
|
|
||||||
async def main():
|
|
||||||
print(f"\n[PHASE 2] Starting Leann chat session...")
|
|
||||||
chat = LeannChat(index_path=INDEX_PATH)
|
|
||||||
|
|
||||||
query = "Based on the paper, what are the two main techniques LEANN uses to achieve low storage overhead and high retrieval accuracy?"
|
|
||||||
print(f"You: {query}")
|
|
||||||
chat_response = chat.ask(query, recompute_beighbor_embeddings=True)
|
|
||||||
print(f"Leann: {chat_response}")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
asyncio.run(main())
|
|
||||||
43
examples/mlx_demo.py
Normal file
43
examples/mlx_demo.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
from leann.api import LeannBuilder, LeannChat
|
||||||
|
|
||||||
|
# Define the path for our new MLX-based index
|
||||||
|
INDEX_PATH = "./mlx_diskann_index/leann"
|
||||||
|
|
||||||
|
if os.path.exists(INDEX_PATH + ".meta.json"):
|
||||||
|
print(f"Index already exists at {INDEX_PATH}. Skipping build.")
|
||||||
|
else:
|
||||||
|
print("Initializing LeannBuilder with MLX support...")
|
||||||
|
# 1. Configure LeannBuilder to use MLX
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model="mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ",
|
||||||
|
embedding_mode="mlx",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. Add documents
|
||||||
|
print("Adding documents...")
|
||||||
|
docs = [
|
||||||
|
"MLX is an array framework for machine learning on Apple silicon.",
|
||||||
|
"It was designed by Apple's machine learning research team.",
|
||||||
|
"The mlx-community organization provides pre-trained models in MLX format.",
|
||||||
|
"It supports operations on multi-dimensional arrays.",
|
||||||
|
"Leann can now use MLX for its embedding models.",
|
||||||
|
]
|
||||||
|
for doc in docs:
|
||||||
|
builder.add_text(doc)
|
||||||
|
|
||||||
|
# 3. Build the index
|
||||||
|
print(f"Building the MLX-based index at: {INDEX_PATH}")
|
||||||
|
builder.build_index(INDEX_PATH)
|
||||||
|
print("\nSuccessfully built the index with MLX embeddings!")
|
||||||
|
print(f"Check the metadata file: {INDEX_PATH}.meta.json")
|
||||||
|
|
||||||
|
|
||||||
|
chat = LeannChat(index_path=INDEX_PATH)
|
||||||
|
# add query
|
||||||
|
query = "MLX is an array framework for machine learning on Apple silicon."
|
||||||
|
print(f"Query: {query}")
|
||||||
|
response = chat.ask(query, top_k=3, recompute_beighbor_embeddings=True, complexity=3, beam_width=1)
|
||||||
|
print(f"Response: {response}")
|
||||||
@@ -1,32 +0,0 @@
|
|||||||
{
|
|
||||||
"version": "0.1.0",
|
|
||||||
"backend_name": "diskann",
|
|
||||||
"embedding_model": "sentence-transformers/all-mpnet-base-v2",
|
|
||||||
"num_chunks": 6,
|
|
||||||
"chunks": [
|
|
||||||
{
|
|
||||||
"text": "Python is a powerful programming language",
|
|
||||||
"metadata": {}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"text": "Machine learning transforms industries",
|
|
||||||
"metadata": {}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"text": "Neural networks process complex data",
|
|
||||||
"metadata": {}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"text": "Java is a powerful programming language",
|
|
||||||
"metadata": {}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"text": "C++ is a powerful programming language",
|
|
||||||
"metadata": {}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"text": "C# is a powerful programming language",
|
|
||||||
"metadata": {}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
@@ -1,8 +1,8 @@
|
|||||||
# packages/leann-backend-diskann/CMakeLists.txt (最终简化版)
|
# packages/leann-backend-diskann/CMakeLists.txt (simplified version)
|
||||||
|
|
||||||
cmake_minimum_required(VERSION 3.20)
|
cmake_minimum_required(VERSION 3.20)
|
||||||
project(leann_backend_diskann_wrapper)
|
project(leann_backend_diskann_wrapper)
|
||||||
|
|
||||||
# 告诉 CMake 直接进入 DiskANN 子模块并执行它自己的 CMakeLists.txt
|
# Tell CMake to directly enter the DiskANN submodule and execute its own CMakeLists.txt
|
||||||
# DiskANN 会自己处理所有事情,包括编译 Python 绑定
|
# DiskANN will handle everything itself, including compiling Python bindings
|
||||||
add_subdirectory(src/third_party/DiskANN)
|
add_subdirectory(src/third_party/DiskANN)
|
||||||
|
|||||||
1
packages/leann-backend-diskann/__init__.py
Normal file
1
packages/leann-backend-diskann/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
# This file makes the directory a Python package
|
||||||
@@ -1,7 +1 @@
|
|||||||
print("Initializing leann-backend-diskann...")
|
from . import diskann_backend as diskann_backend
|
||||||
|
|
||||||
try:
|
|
||||||
from .diskann_backend import DiskannBackend
|
|
||||||
print("INFO: DiskANN backend loaded successfully")
|
|
||||||
except ImportError as e:
|
|
||||||
print(f"WARNING: Could not import DiskANN backend: {e}")
|
|
||||||
|
|||||||
@@ -1,30 +1,71 @@
|
|||||||
import numpy as np
|
|
||||||
import os
|
|
||||||
import json
|
|
||||||
import struct
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Dict
|
|
||||||
import contextlib
|
import contextlib
|
||||||
import threading
|
import logging
|
||||||
import time
|
import os
|
||||||
import atexit
|
import struct
|
||||||
import socket
|
|
||||||
import subprocess
|
|
||||||
import sys
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
from leann.registry import register_backend
|
import numpy as np
|
||||||
|
import psutil
|
||||||
from leann.interface import (
|
from leann.interface import (
|
||||||
LeannBackendFactoryInterface,
|
|
||||||
LeannBackendBuilderInterface,
|
LeannBackendBuilderInterface,
|
||||||
LeannBackendSearcherInterface
|
LeannBackendFactoryInterface,
|
||||||
|
LeannBackendSearcherInterface,
|
||||||
)
|
)
|
||||||
from . import _diskannpy as diskannpy
|
from leann.registry import register_backend
|
||||||
|
from leann.searcher_base import BaseSearcher
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def suppress_cpp_output_if_needed():
|
||||||
|
"""Suppress C++ stdout/stderr based on LEANN_LOG_LEVEL"""
|
||||||
|
log_level = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||||
|
|
||||||
|
# Only suppress if log level is WARNING or higher (ERROR, CRITICAL)
|
||||||
|
should_suppress = log_level in ["WARNING", "ERROR", "CRITICAL"]
|
||||||
|
|
||||||
|
if not should_suppress:
|
||||||
|
# Don't suppress, just yield
|
||||||
|
yield
|
||||||
|
return
|
||||||
|
|
||||||
|
# Save original file descriptors
|
||||||
|
stdout_fd = sys.stdout.fileno()
|
||||||
|
stderr_fd = sys.stderr.fileno()
|
||||||
|
|
||||||
|
# Save original stdout/stderr
|
||||||
|
stdout_dup = os.dup(stdout_fd)
|
||||||
|
stderr_dup = os.dup(stderr_fd)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Redirect to /dev/null
|
||||||
|
devnull = os.open(os.devnull, os.O_WRONLY)
|
||||||
|
os.dup2(devnull, stdout_fd)
|
||||||
|
os.dup2(devnull, stderr_fd)
|
||||||
|
os.close(devnull)
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Restore original file descriptors
|
||||||
|
os.dup2(stdout_dup, stdout_fd)
|
||||||
|
os.dup2(stderr_dup, stderr_fd)
|
||||||
|
os.close(stdout_dup)
|
||||||
|
os.close(stderr_dup)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_diskann_metrics():
|
||||||
|
from . import _diskannpy as diskannpy # type: ignore
|
||||||
|
|
||||||
|
return {
|
||||||
|
"mips": diskannpy.Metric.INNER_PRODUCT,
|
||||||
|
"l2": diskannpy.Metric.L2,
|
||||||
|
"cosine": diskannpy.Metric.COSINE,
|
||||||
|
}
|
||||||
|
|
||||||
METRIC_MAP = {
|
|
||||||
"mips": diskannpy.Metric.INNER_PRODUCT,
|
|
||||||
"l2": diskannpy.Metric.L2,
|
|
||||||
"cosine": diskannpy.Metric.COSINE,
|
|
||||||
}
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def chdir(path):
|
def chdir(path):
|
||||||
@@ -35,102 +76,51 @@ def chdir(path):
|
|||||||
finally:
|
finally:
|
||||||
os.chdir(original_dir)
|
os.chdir(original_dir)
|
||||||
|
|
||||||
def _write_vectors_to_bin(data: np.ndarray, file_path: str):
|
|
||||||
|
def _write_vectors_to_bin(data: np.ndarray, file_path: Path):
|
||||||
num_vectors, dim = data.shape
|
num_vectors, dim = data.shape
|
||||||
with open(file_path, 'wb') as f:
|
with open(file_path, "wb") as f:
|
||||||
f.write(struct.pack('I', num_vectors))
|
f.write(struct.pack("I", num_vectors))
|
||||||
f.write(struct.pack('I', dim))
|
f.write(struct.pack("I", dim))
|
||||||
f.write(data.tobytes())
|
f.write(data.tobytes())
|
||||||
|
|
||||||
def _check_port(port: int) -> bool:
|
|
||||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
|
||||||
return s.connect_ex(('localhost', port)) == 0
|
|
||||||
|
|
||||||
class EmbeddingServerManager:
|
def _calculate_smart_memory_config(data: np.ndarray) -> tuple[float, float]:
|
||||||
def __init__(self):
|
"""
|
||||||
self.server_process = None
|
Calculate smart memory configuration for DiskANN based on data size and system specs.
|
||||||
self.server_port = None
|
|
||||||
atexit.register(self.stop_server)
|
|
||||||
|
|
||||||
def start_server(self, port=5555, model_name="sentence-transformers/all-mpnet-base-v2"):
|
Args:
|
||||||
if self.server_process and self.server_process.poll() is None:
|
data: The embedding data array
|
||||||
print(f"INFO: Reusing existing server process for this session (PID {self.server_process.pid})")
|
|
||||||
return True
|
|
||||||
|
|
||||||
# 检查端口是否已被其他无关进程占用
|
|
||||||
if _check_port(port):
|
|
||||||
print(f"WARNING: Port {port} is already in use. Assuming an external server is running and connecting to it.")
|
|
||||||
return True
|
|
||||||
|
|
||||||
print(f"INFO: Starting session-level embedding server as a background process...")
|
|
||||||
|
|
||||||
try:
|
|
||||||
command = [
|
|
||||||
sys.executable,
|
|
||||||
"-m", "packages.leann-backend-diskann.leann_backend_diskann.embedding_server",
|
|
||||||
"--zmq-port", str(port),
|
|
||||||
"--model-name", model_name
|
|
||||||
]
|
|
||||||
project_root = Path(__file__).parent.parent.parent.parent
|
|
||||||
print(f"INFO: Running command from project root: {project_root}")
|
|
||||||
self.server_process = subprocess.Popen(
|
|
||||||
command,
|
|
||||||
cwd=project_root,
|
|
||||||
stdout=subprocess.PIPE,
|
|
||||||
stderr=subprocess.PIPE,
|
|
||||||
text=True,
|
|
||||||
encoding='utf-8'
|
|
||||||
)
|
|
||||||
self.server_port = port
|
|
||||||
print(f"INFO: Server process started with PID: {self.server_process.pid}")
|
|
||||||
|
|
||||||
max_wait, wait_interval = 30, 0.5
|
Returns:
|
||||||
for _ in range(int(max_wait / wait_interval)):
|
tuple: (search_memory_maximum, build_memory_maximum) in GB
|
||||||
if _check_port(port):
|
"""
|
||||||
print(f"✅ Embedding server is up and ready for this session.")
|
num_vectors, dim = data.shape
|
||||||
log_thread = threading.Thread(target=self._log_monitor, daemon=True)
|
|
||||||
log_thread.start()
|
|
||||||
return True
|
|
||||||
if self.server_process.poll() is not None:
|
|
||||||
print("❌ ERROR: Server process terminated unexpectedly during startup.")
|
|
||||||
self._log_monitor()
|
|
||||||
return False
|
|
||||||
time.sleep(wait_interval)
|
|
||||||
|
|
||||||
print(f"❌ ERROR: Server process failed to start listening within {max_wait} seconds.")
|
|
||||||
self.stop_server()
|
|
||||||
return False
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ ERROR: Failed to start embedding server process: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _log_monitor(self):
|
# Calculate embedding storage size
|
||||||
if not self.server_process:
|
embedding_size_bytes = num_vectors * dim * 4 # float32 = 4 bytes
|
||||||
return
|
embedding_size_gb = embedding_size_bytes / (1024**3)
|
||||||
try:
|
|
||||||
if self.server_process.stdout:
|
# search_memory_maximum: 1/10 of embedding size for optimal PQ compression
|
||||||
for line in iter(self.server_process.stdout.readline, ''):
|
# This controls Product Quantization size - smaller means more compression
|
||||||
print(f"[EmbeddingServer LOG]: {line.strip()}")
|
search_memory_gb = max(0.1, embedding_size_gb / 10) # At least 100MB
|
||||||
self.server_process.stdout.close()
|
|
||||||
if self.server_process.stderr:
|
# build_memory_maximum: Based on available system RAM for sharding control
|
||||||
for line in iter(self.server_process.stderr.readline, ''):
|
# This controls how much memory DiskANN uses during index construction
|
||||||
print(f"[EmbeddingServer ERROR]: {line.strip()}")
|
available_memory_gb = psutil.virtual_memory().available / (1024**3)
|
||||||
self.server_process.stderr.close()
|
total_memory_gb = psutil.virtual_memory().total / (1024**3)
|
||||||
except Exception as e:
|
|
||||||
print(f"Log monitor error: {e}")
|
# Use 50% of available memory, but at least 2GB and at most 75% of total
|
||||||
|
build_memory_gb = max(2.0, min(available_memory_gb * 0.5, total_memory_gb * 0.75))
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Smart memory config - Data: {embedding_size_gb:.2f}GB, "
|
||||||
|
f"Search mem: {search_memory_gb:.2f}GB (PQ control), "
|
||||||
|
f"Build mem: {build_memory_gb:.2f}GB (sharding control)"
|
||||||
|
)
|
||||||
|
|
||||||
|
return search_memory_gb, build_memory_gb
|
||||||
|
|
||||||
def stop_server(self):
|
|
||||||
if self.server_process and self.server_process.poll() is None:
|
|
||||||
print(f"INFO: Terminating session server process (PID: {self.server_process.pid})...")
|
|
||||||
self.server_process.terminate()
|
|
||||||
try:
|
|
||||||
self.server_process.wait(timeout=5)
|
|
||||||
print("INFO: Server process terminated.")
|
|
||||||
except subprocess.TimeoutExpired:
|
|
||||||
print("WARNING: Server process did not terminate gracefully, killing it.")
|
|
||||||
self.server_process.kill()
|
|
||||||
self.server_process = None
|
|
||||||
|
|
||||||
@register_backend("diskann")
|
@register_backend("diskann")
|
||||||
class DiskannBackend(LeannBackendFactoryInterface):
|
class DiskannBackend(LeannBackendFactoryInterface):
|
||||||
@@ -140,138 +130,191 @@ class DiskannBackend(LeannBackendFactoryInterface):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def searcher(index_path: str, **kwargs) -> LeannBackendSearcherInterface:
|
def searcher(index_path: str, **kwargs) -> LeannBackendSearcherInterface:
|
||||||
path = Path(index_path)
|
|
||||||
meta_path = path.parent / f"{path.name}.meta.json"
|
|
||||||
if not meta_path.exists():
|
|
||||||
raise FileNotFoundError(f"Leann metadata file not found at {meta_path}. Cannot infer vector dimension for searcher.")
|
|
||||||
with open(meta_path, 'r') as f:
|
|
||||||
meta = json.load(f)
|
|
||||||
|
|
||||||
try:
|
|
||||||
from sentence_transformers import SentenceTransformer
|
|
||||||
model = SentenceTransformer(meta.get("embedding_model"))
|
|
||||||
dimensions = model.get_sentence_embedding_dimension()
|
|
||||||
kwargs['dimensions'] = dimensions
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError("sentence-transformers is required to infer embedding dimensions. Please install it.")
|
|
||||||
except Exception as e:
|
|
||||||
raise RuntimeError(f"Could not load SentenceTransformer model to get dimension: {e}")
|
|
||||||
|
|
||||||
return DiskannSearcher(index_path, **kwargs)
|
return DiskannSearcher(index_path, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class DiskannBuilder(LeannBackendBuilderInterface):
|
class DiskannBuilder(LeannBackendBuilderInterface):
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
self.build_params = kwargs
|
self.build_params = kwargs
|
||||||
|
|
||||||
def build(self, data: np.ndarray, index_path: str, **kwargs):
|
def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs):
|
||||||
path = Path(index_path)
|
path = Path(index_path)
|
||||||
index_dir = path.parent
|
index_dir = path.parent
|
||||||
index_prefix = path.stem
|
index_prefix = path.stem
|
||||||
|
|
||||||
index_dir.mkdir(parents=True, exist_ok=True)
|
index_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
if data.dtype != np.float32:
|
if data.dtype != np.float32:
|
||||||
|
logger.warning(f"Converting data to float32, shape: {data.shape}")
|
||||||
data = data.astype(np.float32)
|
data = data.astype(np.float32)
|
||||||
if not data.flags['C_CONTIGUOUS']:
|
|
||||||
data = np.ascontiguousarray(data)
|
|
||||||
|
|
||||||
data_filename = f"{index_prefix}_data.bin"
|
data_filename = f"{index_prefix}_data.bin"
|
||||||
_write_vectors_to_bin(data, index_dir / data_filename)
|
_write_vectors_to_bin(data, index_dir / data_filename)
|
||||||
|
|
||||||
build_kwargs = {**self.build_params, **kwargs}
|
build_kwargs = {**self.build_params, **kwargs}
|
||||||
metric_str = build_kwargs.get("distance_metric", "mips").lower()
|
metric_enum = _get_diskann_metrics().get(
|
||||||
metric_enum = METRIC_MAP.get(metric_str)
|
build_kwargs.get("distance_metric", "mips").lower()
|
||||||
|
)
|
||||||
if metric_enum is None:
|
if metric_enum is None:
|
||||||
raise ValueError(f"Unsupported distance_metric '{metric_str}'.")
|
raise ValueError(
|
||||||
|
f"Unsupported distance_metric '{build_kwargs.get('distance_metric', 'unknown')}'."
|
||||||
|
)
|
||||||
|
|
||||||
complexity = build_kwargs.get("complexity", 64)
|
# Calculate smart memory configuration if not explicitly provided
|
||||||
graph_degree = build_kwargs.get("graph_degree", 32)
|
if (
|
||||||
final_index_ram_limit = build_kwargs.get("search_memory_maximum", 4.0)
|
"search_memory_maximum" not in build_kwargs
|
||||||
indexing_ram_budget = build_kwargs.get("build_memory_maximum", 8.0)
|
or "build_memory_maximum" not in build_kwargs
|
||||||
num_threads = build_kwargs.get("num_threads", 8)
|
):
|
||||||
pq_disk_bytes = build_kwargs.get("pq_disk_bytes", 0)
|
smart_search_mem, smart_build_mem = _calculate_smart_memory_config(data)
|
||||||
codebook_prefix = ""
|
else:
|
||||||
|
smart_search_mem = build_kwargs.get("search_memory_maximum", 4.0)
|
||||||
|
smart_build_mem = build_kwargs.get("build_memory_maximum", 8.0)
|
||||||
|
|
||||||
print(f"INFO: Building DiskANN index for {data.shape[0]} vectors with metric {metric_enum}...")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
from . import _diskannpy as diskannpy # type: ignore
|
||||||
|
|
||||||
with chdir(index_dir):
|
with chdir(index_dir):
|
||||||
diskannpy.build_disk_float_index(
|
diskannpy.build_disk_float_index(
|
||||||
metric_enum,
|
metric_enum,
|
||||||
data_filename,
|
data_filename,
|
||||||
index_prefix,
|
index_prefix,
|
||||||
complexity,
|
build_kwargs.get("complexity", 64),
|
||||||
graph_degree,
|
build_kwargs.get("graph_degree", 32),
|
||||||
final_index_ram_limit,
|
build_kwargs.get("search_memory_maximum", smart_search_mem),
|
||||||
indexing_ram_budget,
|
build_kwargs.get("build_memory_maximum", smart_build_mem),
|
||||||
num_threads,
|
build_kwargs.get("num_threads", 8),
|
||||||
pq_disk_bytes,
|
build_kwargs.get("pq_disk_bytes", 0),
|
||||||
codebook_prefix
|
"",
|
||||||
)
|
)
|
||||||
print(f"✅ DiskANN index built successfully at '{index_dir / index_prefix}'")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"💥 ERROR: DiskANN index build failed. Exception: {e}")
|
|
||||||
raise
|
|
||||||
finally:
|
finally:
|
||||||
temp_data_file = index_dir / data_filename
|
temp_data_file = index_dir / data_filename
|
||||||
if temp_data_file.exists():
|
if temp_data_file.exists():
|
||||||
os.remove(temp_data_file)
|
os.remove(temp_data_file)
|
||||||
|
logger.debug(f"Cleaned up temporary data file: {temp_data_file}")
|
||||||
|
|
||||||
class DiskannSearcher(LeannBackendSearcherInterface):
|
|
||||||
|
class DiskannSearcher(BaseSearcher):
|
||||||
def __init__(self, index_path: str, **kwargs):
|
def __init__(self, index_path: str, **kwargs):
|
||||||
path = Path(index_path)
|
super().__init__(
|
||||||
index_dir = path.parent
|
index_path,
|
||||||
index_prefix = path.stem
|
backend_module_name="leann_backend_diskann.diskann_embedding_server",
|
||||||
metric_str = kwargs.get("distance_metric", "mips").lower()
|
**kwargs,
|
||||||
metric_enum = METRIC_MAP.get(metric_str)
|
)
|
||||||
if metric_enum is None:
|
|
||||||
raise ValueError(f"Unsupported distance_metric '{metric_str}'.")
|
# Initialize DiskANN index with suppressed C++ output based on log level
|
||||||
|
with suppress_cpp_output_if_needed():
|
||||||
num_threads = kwargs.get("num_threads", 8)
|
from . import _diskannpy as diskannpy # type: ignore
|
||||||
num_nodes_to_cache = kwargs.get("num_nodes_to_cache", 0)
|
|
||||||
dimensions = kwargs.get("dimensions")
|
distance_metric = kwargs.get("distance_metric", "mips").lower()
|
||||||
if not dimensions:
|
metric_enum = _get_diskann_metrics().get(distance_metric)
|
||||||
raise ValueError("Vector dimension not provided to DiskannSearcher.")
|
if metric_enum is None:
|
||||||
|
raise ValueError(f"Unsupported distance_metric '{distance_metric}'.")
|
||||||
try:
|
|
||||||
full_index_prefix = str(index_dir / index_prefix)
|
self.num_threads = kwargs.get("num_threads", 8)
|
||||||
self._index = diskannpy.StaticDiskFloatIndex(
|
|
||||||
metric_enum, full_index_prefix, num_threads, num_nodes_to_cache, 1, "", ""
|
# For DiskANN, we need to reinitialize the index when zmq_port changes
|
||||||
)
|
# Store the initialization parameters for later use
|
||||||
self.num_threads = num_threads
|
full_index_prefix = str(self.index_dir / self.index_path.stem)
|
||||||
self.embedding_server_manager = EmbeddingServerManager()
|
self._init_params = {
|
||||||
print("✅ DiskANN index loaded successfully.")
|
"metric_enum": metric_enum,
|
||||||
except Exception as e:
|
"full_index_prefix": full_index_prefix,
|
||||||
print(f"💥 ERROR: Failed to load DiskANN index. Exception: {e}")
|
"num_threads": self.num_threads,
|
||||||
raise
|
"num_nodes_to_cache": kwargs.get("num_nodes_to_cache", 0),
|
||||||
|
"cache_mechanism": 1,
|
||||||
|
"pq_prefix": "",
|
||||||
|
"partition_prefix": "",
|
||||||
|
}
|
||||||
|
self._diskannpy = diskannpy
|
||||||
|
self._current_zmq_port = None
|
||||||
|
self._index = None
|
||||||
|
logger.debug("DiskANN searcher initialized (index will be loaded on first search)")
|
||||||
|
|
||||||
|
def _ensure_index_loaded(self, zmq_port: int):
|
||||||
|
"""Ensure the index is loaded with the correct zmq_port."""
|
||||||
|
if self._index is None or self._current_zmq_port != zmq_port:
|
||||||
|
# Need to (re)load the index with the correct zmq_port
|
||||||
|
with suppress_cpp_output_if_needed():
|
||||||
|
if self._index is not None:
|
||||||
|
logger.debug(f"Reloading DiskANN index with new zmq_port: {zmq_port}")
|
||||||
|
else:
|
||||||
|
logger.debug(f"Loading DiskANN index with zmq_port: {zmq_port}")
|
||||||
|
|
||||||
|
self._index = self._diskannpy.StaticDiskFloatIndex(
|
||||||
|
self._init_params["metric_enum"],
|
||||||
|
self._init_params["full_index_prefix"],
|
||||||
|
self._init_params["num_threads"],
|
||||||
|
self._init_params["num_nodes_to_cache"],
|
||||||
|
self._init_params["cache_mechanism"],
|
||||||
|
zmq_port,
|
||||||
|
self._init_params["pq_prefix"],
|
||||||
|
self._init_params["partition_prefix"],
|
||||||
|
)
|
||||||
|
self._current_zmq_port = zmq_port
|
||||||
|
|
||||||
|
def search(
|
||||||
|
self,
|
||||||
|
query: np.ndarray,
|
||||||
|
top_k: int,
|
||||||
|
complexity: int = 64,
|
||||||
|
beam_width: int = 1,
|
||||||
|
prune_ratio: float = 0.0,
|
||||||
|
recompute_embeddings: bool = False,
|
||||||
|
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||||
|
zmq_port: int | None = None,
|
||||||
|
batch_recompute: bool = False,
|
||||||
|
dedup_node_dis: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Search for nearest neighbors using DiskANN index.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Query vectors (B, D) where B is batch size, D is dimension
|
||||||
|
top_k: Number of nearest neighbors to return
|
||||||
|
complexity: Search complexity/candidate list size, higher = more accurate but slower
|
||||||
|
beam_width: Number of parallel IO requests per iteration
|
||||||
|
prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
|
||||||
|
recompute_embeddings: Whether to fetch fresh embeddings from server
|
||||||
|
pruning_strategy: PQ candidate selection strategy:
|
||||||
|
- "global": Use global pruning strategy (default)
|
||||||
|
- "local": Use local pruning strategy
|
||||||
|
- "proportional": Not supported in DiskANN, falls back to global
|
||||||
|
zmq_port: ZMQ port for embedding server communication. Must be provided if recompute_embeddings is True.
|
||||||
|
batch_recompute: Whether to batch neighbor recomputation (DiskANN-specific)
|
||||||
|
dedup_node_dis: Whether to cache and reuse distance computations (DiskANN-specific)
|
||||||
|
**kwargs: Additional DiskANN-specific parameters (for legacy compatibility)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with 'labels' (list of lists) and 'distances' (ndarray)
|
||||||
|
"""
|
||||||
|
# Handle zmq_port compatibility: Ensure index is loaded with correct port
|
||||||
|
if recompute_embeddings:
|
||||||
|
if zmq_port is None:
|
||||||
|
raise ValueError("zmq_port must be provided if recompute_embeddings is True")
|
||||||
|
self._ensure_index_loaded(zmq_port)
|
||||||
|
else:
|
||||||
|
# If not recomputing, we still need an index, use a default port
|
||||||
|
if self._index is None:
|
||||||
|
self._ensure_index_loaded(6666) # Default port when not recomputing
|
||||||
|
|
||||||
|
# DiskANN doesn't support "proportional" strategy
|
||||||
|
if pruning_strategy == "proportional":
|
||||||
|
raise NotImplementedError(
|
||||||
|
"DiskANN backend does not support 'proportional' pruning strategy. Use 'global' or 'local' instead."
|
||||||
|
)
|
||||||
|
|
||||||
def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, any]:
|
|
||||||
complexity = kwargs.get("complexity", 100)
|
|
||||||
beam_width = kwargs.get("beam_width", 4)
|
|
||||||
|
|
||||||
USE_DEFERRED_FETCH = kwargs.get("USE_DEFERRED_FETCH", False)
|
|
||||||
skip_search_reorder = kwargs.get("skip_search_reorder", False)
|
|
||||||
recompute_beighbor_embeddings = kwargs.get("recompute_beighbor_embeddings", False)
|
|
||||||
dedup_node_dis = kwargs.get("dedup_node_dis", False)
|
|
||||||
prune_ratio = kwargs.get("prune_ratio", 0.0)
|
|
||||||
batch_recompute = kwargs.get("batch_recompute", False)
|
|
||||||
global_pruning = kwargs.get("global_pruning", False)
|
|
||||||
|
|
||||||
if recompute_beighbor_embeddings:
|
|
||||||
print(f"INFO: DiskANN ZMQ mode enabled - ensuring embedding server is running")
|
|
||||||
zmq_port = kwargs.get("zmq_port", 5555)
|
|
||||||
embedding_model = kwargs.get("embedding_model", "sentence-transformers/all-mpnet-base-v2")
|
|
||||||
|
|
||||||
if not self.embedding_server_manager.start_server(zmq_port, embedding_model):
|
|
||||||
print(f"WARNING: Failed to start embedding server, falling back to PQ computation")
|
|
||||||
kwargs['recompute_beighbor_embeddings'] = False
|
|
||||||
|
|
||||||
if query.dtype != np.float32:
|
if query.dtype != np.float32:
|
||||||
query = query.astype(np.float32)
|
query = query.astype(np.float32)
|
||||||
if query.ndim == 1:
|
|
||||||
query = np.expand_dims(query, axis=0)
|
# Map pruning_strategy to DiskANN's global_pruning parameter
|
||||||
|
if pruning_strategy == "local":
|
||||||
try:
|
use_global_pruning = False
|
||||||
|
else: # "global"
|
||||||
|
use_global_pruning = True
|
||||||
|
|
||||||
|
# Perform search with suppressed C++ output based on log level
|
||||||
|
use_deferred_fetch = kwargs.get("USE_DEFERRED_FETCH", True)
|
||||||
|
recompute_neighors = False
|
||||||
|
with suppress_cpp_output_if_needed():
|
||||||
labels, distances = self._index.batch_search(
|
labels, distances = self._index.batch_search(
|
||||||
query,
|
query,
|
||||||
query.shape[0],
|
query.shape[0],
|
||||||
@@ -279,21 +322,15 @@ class DiskannSearcher(LeannBackendSearcherInterface):
|
|||||||
complexity,
|
complexity,
|
||||||
beam_width,
|
beam_width,
|
||||||
self.num_threads,
|
self.num_threads,
|
||||||
USE_DEFERRED_FETCH,
|
use_deferred_fetch,
|
||||||
skip_search_reorder,
|
kwargs.get("skip_search_reorder", False),
|
||||||
recompute_beighbor_embeddings,
|
recompute_neighors,
|
||||||
dedup_node_dis,
|
dedup_node_dis,
|
||||||
prune_ratio,
|
prune_ratio,
|
||||||
batch_recompute,
|
batch_recompute,
|
||||||
global_pruning
|
use_global_pruning,
|
||||||
)
|
)
|
||||||
return {"labels": labels, "distances": distances}
|
|
||||||
except Exception as e:
|
string_labels = [[str(int_label) for int_label in batch_labels] for batch_labels in labels]
|
||||||
print(f"💥 ERROR: DiskANN search failed. Exception: {e}")
|
|
||||||
batch_size = query.shape[0]
|
return {"labels": string_labels, "distances": distances}
|
||||||
return {"labels": np.full((batch_size, top_k), -1, dtype=np.int64),
|
|
||||||
"distances": np.full((batch_size, top_k), float('inf'), dtype=np.float32)}
|
|
||||||
|
|
||||||
def __del__(self):
|
|
||||||
if hasattr(self, 'embedding_server_manager'):
|
|
||||||
self.embedding_server_manager.stop_server()
|
|
||||||
|
|||||||
@@ -0,0 +1,284 @@
|
|||||||
|
"""
|
||||||
|
DiskANN-specific embedding server
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import zmq
|
||||||
|
|
||||||
|
# Set up logging based on environment variable
|
||||||
|
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Force set logger level (don't rely on basicConfig in subprocess)
|
||||||
|
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
|
||||||
|
logger.setLevel(log_level)
|
||||||
|
|
||||||
|
# Ensure we have a handler if none exists
|
||||||
|
if not logger.handlers:
|
||||||
|
handler = logging.StreamHandler()
|
||||||
|
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
|
||||||
|
handler.setFormatter(formatter)
|
||||||
|
logger.addHandler(handler)
|
||||||
|
logger.propagate = False
|
||||||
|
|
||||||
|
|
||||||
|
def create_diskann_embedding_server(
|
||||||
|
passages_file: str | None = None,
|
||||||
|
zmq_port: int = 5555,
|
||||||
|
model_name: str = "sentence-transformers/all-mpnet-base-v2",
|
||||||
|
embedding_mode: str = "sentence-transformers",
|
||||||
|
distance_metric: str = "l2",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Create and start a ZMQ-based embedding server for DiskANN backend.
|
||||||
|
Uses ROUTER socket and protobuf communication as required by DiskANN C++ implementation.
|
||||||
|
"""
|
||||||
|
logger.info(f"Starting DiskANN server on port {zmq_port} with model {model_name}")
|
||||||
|
logger.info(f"Using embedding mode: {embedding_mode}")
|
||||||
|
|
||||||
|
# Add leann-core to path for unified embedding computation
|
||||||
|
current_dir = Path(__file__).parent
|
||||||
|
leann_core_path = current_dir.parent.parent / "leann-core" / "src"
|
||||||
|
sys.path.insert(0, str(leann_core_path))
|
||||||
|
|
||||||
|
try:
|
||||||
|
from leann.api import PassageManager
|
||||||
|
from leann.embedding_compute import compute_embeddings
|
||||||
|
|
||||||
|
logger.info("Successfully imported unified embedding computation module")
|
||||||
|
except ImportError as e:
|
||||||
|
logger.error(f"Failed to import embedding computation module: {e}")
|
||||||
|
return
|
||||||
|
finally:
|
||||||
|
sys.path.pop(0)
|
||||||
|
|
||||||
|
# Check port availability
|
||||||
|
import socket
|
||||||
|
|
||||||
|
def check_port(port):
|
||||||
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||||
|
return s.connect_ex(("localhost", port)) == 0
|
||||||
|
|
||||||
|
if check_port(zmq_port):
|
||||||
|
logger.error(f"Port {zmq_port} is already in use")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Only support metadata file, fail fast for everything else
|
||||||
|
if not passages_file or not passages_file.endswith(".meta.json"):
|
||||||
|
raise ValueError("Only metadata files (.meta.json) are supported")
|
||||||
|
|
||||||
|
# Load metadata to get passage sources
|
||||||
|
with open(passages_file) as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
|
||||||
|
passages = PassageManager(meta["passage_sources"])
|
||||||
|
logger.info(
|
||||||
|
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Import protobuf after ensuring the path is correct
|
||||||
|
try:
|
||||||
|
from . import embedding_pb2
|
||||||
|
except ImportError as e:
|
||||||
|
logger.error(f"Failed to import protobuf module: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
def zmq_server_thread():
|
||||||
|
"""ZMQ server thread using REP socket for universal compatibility"""
|
||||||
|
context = zmq.Context()
|
||||||
|
socket = context.socket(
|
||||||
|
zmq.REP
|
||||||
|
) # REP socket for both BaseSearcher and DiskANN C++ REQ clients
|
||||||
|
socket.bind(f"tcp://*:{zmq_port}")
|
||||||
|
logger.info(f"DiskANN ZMQ REP server listening on port {zmq_port}")
|
||||||
|
|
||||||
|
socket.setsockopt(zmq.RCVTIMEO, 300000)
|
||||||
|
socket.setsockopt(zmq.SNDTIMEO, 300000)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
# REP socket receives single-part messages
|
||||||
|
message = socket.recv()
|
||||||
|
|
||||||
|
# Check for empty messages - REP socket requires response to every request
|
||||||
|
if len(message) == 0:
|
||||||
|
logger.debug("Received empty message, sending empty response")
|
||||||
|
socket.send(b"") # REP socket must respond to every request
|
||||||
|
continue
|
||||||
|
|
||||||
|
logger.debug(f"Received ZMQ request of size {len(message)} bytes")
|
||||||
|
logger.debug(f"Message preview: {message[:50]}") # Show first 50 bytes
|
||||||
|
|
||||||
|
e2e_start = time.time()
|
||||||
|
|
||||||
|
# Try protobuf first (for DiskANN C++ node_ids requests - primary use case)
|
||||||
|
texts = []
|
||||||
|
node_ids = []
|
||||||
|
is_text_request = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
req_proto = embedding_pb2.NodeEmbeddingRequest()
|
||||||
|
req_proto.ParseFromString(message)
|
||||||
|
node_ids = list(req_proto.node_ids)
|
||||||
|
|
||||||
|
if not node_ids:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"PROTOBUF: Received empty node_ids! Message size: {len(message)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"✅ PROTOBUF: Node ID request for {len(node_ids)} node embeddings: {node_ids[:10]}"
|
||||||
|
)
|
||||||
|
except Exception as protobuf_error:
|
||||||
|
logger.debug(f"Protobuf parsing failed: {protobuf_error}")
|
||||||
|
# Fallback to msgpack (for BaseSearcher direct text requests)
|
||||||
|
try:
|
||||||
|
import msgpack
|
||||||
|
|
||||||
|
request = msgpack.unpackb(message)
|
||||||
|
# For BaseSearcher compatibility, request is a list of texts directly
|
||||||
|
if isinstance(request, list) and all(
|
||||||
|
isinstance(item, str) for item in request
|
||||||
|
):
|
||||||
|
texts = request
|
||||||
|
is_text_request = True
|
||||||
|
logger.info(f"✅ MSGPACK: Direct text request for {len(texts)} texts")
|
||||||
|
else:
|
||||||
|
raise ValueError("Not a valid msgpack text request")
|
||||||
|
except Exception as msgpack_error:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Both protobuf and msgpack parsing failed! Protobuf: {protobuf_error}, Msgpack: {msgpack_error}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Look up texts by node IDs (only if not direct text request)
|
||||||
|
if not is_text_request:
|
||||||
|
for nid in node_ids:
|
||||||
|
try:
|
||||||
|
passage_data = passages.get_passage(str(nid))
|
||||||
|
txt = passage_data["text"]
|
||||||
|
if not txt:
|
||||||
|
raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
|
||||||
|
texts.append(txt)
|
||||||
|
except KeyError as e:
|
||||||
|
logger.error(f"Passage ID {nid} not found: {e}")
|
||||||
|
raise e
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Exception looking up passage ID {nid}: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Debug logging
|
||||||
|
logger.debug(f"Processing {len(texts)} texts")
|
||||||
|
logger.debug(f"Text lengths: {[len(t) for t in texts[:5]]}") # Show first 5
|
||||||
|
|
||||||
|
# Process embeddings using unified computation
|
||||||
|
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
||||||
|
logger.info(
|
||||||
|
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare response based on request type
|
||||||
|
if is_text_request:
|
||||||
|
# For BaseSearcher compatibility: return msgpack format
|
||||||
|
import msgpack
|
||||||
|
|
||||||
|
response_data = msgpack.packb(embeddings.tolist())
|
||||||
|
else:
|
||||||
|
# For DiskANN C++ compatibility: return protobuf format
|
||||||
|
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
||||||
|
hidden_contiguous = np.ascontiguousarray(embeddings, dtype=np.float32)
|
||||||
|
|
||||||
|
# Serialize embeddings data
|
||||||
|
resp_proto.embeddings_data = hidden_contiguous.tobytes()
|
||||||
|
resp_proto.dimensions.append(hidden_contiguous.shape[0])
|
||||||
|
resp_proto.dimensions.append(hidden_contiguous.shape[1])
|
||||||
|
|
||||||
|
response_data = resp_proto.SerializeToString()
|
||||||
|
|
||||||
|
# Send response back to the client
|
||||||
|
socket.send(response_data)
|
||||||
|
|
||||||
|
e2e_end = time.time()
|
||||||
|
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
|
||||||
|
|
||||||
|
except zmq.Again:
|
||||||
|
logger.debug("ZMQ socket timeout, continuing to listen")
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in ZMQ server loop: {e}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
raise
|
||||||
|
|
||||||
|
zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
|
||||||
|
zmq_thread.start()
|
||||||
|
logger.info(f"Started DiskANN ZMQ server thread on port {zmq_port}")
|
||||||
|
|
||||||
|
# Keep the main thread alive
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
time.sleep(1)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logger.info("DiskANN Server shutting down...")
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import signal
|
||||||
|
import sys
|
||||||
|
|
||||||
|
def signal_handler(sig, frame):
|
||||||
|
logger.info(f"Received signal {sig}, shutting down gracefully...")
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
# Register signal handlers for graceful shutdown
|
||||||
|
signal.signal(signal.SIGTERM, signal_handler)
|
||||||
|
signal.signal(signal.SIGINT, signal_handler)
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description="DiskANN Embedding service")
|
||||||
|
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
|
||||||
|
parser.add_argument(
|
||||||
|
"--passages-file",
|
||||||
|
type=str,
|
||||||
|
help="Metadata JSON file containing passage sources",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model-name",
|
||||||
|
type=str,
|
||||||
|
default="sentence-transformers/all-mpnet-base-v2",
|
||||||
|
help="Embedding model name",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--embedding-mode",
|
||||||
|
type=str,
|
||||||
|
default="sentence-transformers",
|
||||||
|
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
||||||
|
help="Embedding backend mode",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--distance-metric",
|
||||||
|
type=str,
|
||||||
|
default="l2",
|
||||||
|
choices=["l2", "mips", "cosine"],
|
||||||
|
help="Distance metric for similarity computation",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Create and start the DiskANN embedding server
|
||||||
|
create_diskann_embedding_server(
|
||||||
|
passages_file=args.passages_file,
|
||||||
|
zmq_port=args.zmq_port,
|
||||||
|
model_name=args.model_name,
|
||||||
|
embedding_mode=args.embedding_mode,
|
||||||
|
distance_metric=args.distance_metric,
|
||||||
|
)
|
||||||
@@ -1,27 +1,28 @@
|
|||||||
# -*- coding: utf-8 -*-
|
|
||||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||||
# source: embedding.proto
|
# source: embedding.proto
|
||||||
|
# ruff: noqa
|
||||||
"""Generated protocol buffer code."""
|
"""Generated protocol buffer code."""
|
||||||
from google.protobuf.internal import builder as _builder
|
|
||||||
from google.protobuf import descriptor as _descriptor
|
from google.protobuf import descriptor as _descriptor
|
||||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||||
from google.protobuf import symbol_database as _symbol_database
|
from google.protobuf import symbol_database as _symbol_database
|
||||||
|
from google.protobuf.internal import builder as _builder
|
||||||
|
|
||||||
# @@protoc_insertion_point(imports)
|
# @@protoc_insertion_point(imports)
|
||||||
|
|
||||||
_sym_db = _symbol_database.Default()
|
_sym_db = _symbol_database.Default()
|
||||||
|
|
||||||
|
|
||||||
|
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
|
||||||
|
b'\n\x0f\x65mbedding.proto\x12\x0eprotoembedding"(\n\x14NodeEmbeddingRequest\x12\x10\n\x08node_ids\x18\x01 \x03(\r"Y\n\x15NodeEmbeddingResponse\x12\x17\n\x0f\x65mbeddings_data\x18\x01 \x01(\x0c\x12\x12\n\ndimensions\x18\x02 \x03(\x05\x12\x13\n\x0bmissing_ids\x18\x03 \x03(\rb\x06proto3'
|
||||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0f\x65mbedding.proto\x12\x0eprotoembedding\"(\n\x14NodeEmbeddingRequest\x12\x10\n\x08node_ids\x18\x01 \x03(\r\"Y\n\x15NodeEmbeddingResponse\x12\x17\n\x0f\x65mbeddings_data\x18\x01 \x01(\x0c\x12\x12\n\ndimensions\x18\x02 \x03(\x05\x12\x13\n\x0bmissing_ids\x18\x03 \x03(\rb\x06proto3')
|
)
|
||||||
|
|
||||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
|
||||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'embedding_pb2', globals())
|
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "embedding_pb2", globals())
|
||||||
if _descriptor._USE_C_DESCRIPTORS == False:
|
if not _descriptor._USE_C_DESCRIPTORS:
|
||||||
|
DESCRIPTOR._options = None
|
||||||
DESCRIPTOR._options = None
|
_NODEEMBEDDINGREQUEST._serialized_start = 35
|
||||||
_NODEEMBEDDINGREQUEST._serialized_start=35
|
_NODEEMBEDDINGREQUEST._serialized_end = 75
|
||||||
_NODEEMBEDDINGREQUEST._serialized_end=75
|
_NODEEMBEDDINGRESPONSE._serialized_start = 77
|
||||||
_NODEEMBEDDINGRESPONSE._serialized_start=77
|
_NODEEMBEDDINGRESPONSE._serialized_end = 166
|
||||||
_NODEEMBEDDINGRESPONSE._serialized_end=166
|
|
||||||
# @@protoc_insertion_point(module_scope)
|
# @@protoc_insertion_point(module_scope)
|
||||||
|
|||||||
@@ -1,397 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Embedding server for leann-backend-diskann - Fixed ZMQ REQ-REP pattern
|
|
||||||
"""
|
|
||||||
|
|
||||||
import pickle
|
|
||||||
import argparse
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
|
|
||||||
from transformers import AutoTokenizer, AutoModel
|
|
||||||
import os
|
|
||||||
from contextlib import contextmanager
|
|
||||||
import zmq
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
RED = "\033[91m"
|
|
||||||
RESET = "\033[0m"
|
|
||||||
|
|
||||||
# 简化的文档存储 - 替代 LazyPassages
|
|
||||||
class SimpleDocumentStore:
|
|
||||||
"""简化的文档存储,支持任意ID"""
|
|
||||||
def __init__(self, documents: dict = None):
|
|
||||||
self.documents = documents or {}
|
|
||||||
# 默认演示文档
|
|
||||||
self.default_docs = {
|
|
||||||
0: "Python is a high-level, interpreted language known for simplicity.",
|
|
||||||
1: "Machine learning builds systems that learn from data.",
|
|
||||||
2: "Data structures like arrays, lists, and graphs organize data.",
|
|
||||||
}
|
|
||||||
|
|
||||||
def __getitem__(self, doc_id):
|
|
||||||
doc_id = int(doc_id)
|
|
||||||
|
|
||||||
# 优先使用指定的文档
|
|
||||||
if doc_id in self.documents:
|
|
||||||
return {"text": self.documents[doc_id]}
|
|
||||||
|
|
||||||
# 其次使用默认演示文档
|
|
||||||
if doc_id in self.default_docs:
|
|
||||||
return {"text": self.default_docs[doc_id]}
|
|
||||||
|
|
||||||
# 对于任意其他ID,返回通用文档
|
|
||||||
fallback_docs = [
|
|
||||||
"This is a general document about technology and programming concepts.",
|
|
||||||
"This document discusses machine learning and artificial intelligence topics.",
|
|
||||||
"This content covers data structures, algorithms, and computer science fundamentals.",
|
|
||||||
"This is a document about software engineering and development practices.",
|
|
||||||
"This content focuses on databases, data management, and information systems."
|
|
||||||
]
|
|
||||||
|
|
||||||
# 根据ID选择一个fallback文档
|
|
||||||
fallback_text = fallback_docs[doc_id % len(fallback_docs)]
|
|
||||||
return {"text": f"[ID:{doc_id}] {fallback_text}"}
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.documents) + len(self.default_docs)
|
|
||||||
|
|
||||||
def create_embedding_server_thread(
|
|
||||||
zmq_port=5555,
|
|
||||||
model_name="sentence-transformers/all-mpnet-base-v2",
|
|
||||||
max_batch_size=128,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
在当前线程中创建并运行 embedding server
|
|
||||||
这个函数设计为在单独的线程中调用
|
|
||||||
"""
|
|
||||||
print(f"INFO: Initializing embedding server thread on port {zmq_port}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 检查端口是否已被占用
|
|
||||||
import socket
|
|
||||||
def check_port(port):
|
|
||||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
|
||||||
return s.connect_ex(('localhost', port)) == 0
|
|
||||||
|
|
||||||
if check_port(zmq_port):
|
|
||||||
print(f"{RED}Port {zmq_port} is already in use{RESET}")
|
|
||||||
return
|
|
||||||
|
|
||||||
# 初始化模型
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
|
|
||||||
import torch
|
|
||||||
|
|
||||||
# 选择设备
|
|
||||||
mps_available = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
|
|
||||||
cuda_available = torch.cuda.is_available()
|
|
||||||
|
|
||||||
if cuda_available:
|
|
||||||
device = torch.device("cuda")
|
|
||||||
print("INFO: Using CUDA device")
|
|
||||||
elif mps_available:
|
|
||||||
device = torch.device("mps")
|
|
||||||
print("INFO: Using MPS device (Apple Silicon)")
|
|
||||||
else:
|
|
||||||
device = torch.device("cpu")
|
|
||||||
print("INFO: Using CPU device")
|
|
||||||
|
|
||||||
# 加载模型
|
|
||||||
print(f"INFO: Loading model {model_name}")
|
|
||||||
model = AutoModel.from_pretrained(model_name).to(device).eval()
|
|
||||||
|
|
||||||
# 优化模型
|
|
||||||
if cuda_available or mps_available:
|
|
||||||
try:
|
|
||||||
model = model.half()
|
|
||||||
model = torch.compile(model)
|
|
||||||
print(f"INFO: Using FP16 precision with model: {model_name}")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"WARNING: Model optimization failed: {e}")
|
|
||||||
|
|
||||||
# 默认演示文档
|
|
||||||
demo_documents = {
|
|
||||||
0: "Python is a high-level, interpreted language known for simplicity.",
|
|
||||||
1: "Machine learning builds systems that learn from data.",
|
|
||||||
2: "Data structures like arrays, lists, and graphs organize data.",
|
|
||||||
}
|
|
||||||
|
|
||||||
passages = SimpleDocumentStore(demo_documents)
|
|
||||||
print(f"INFO: Loaded {len(passages)} demo documents")
|
|
||||||
|
|
||||||
class DeviceTimer:
|
|
||||||
"""设备计时器"""
|
|
||||||
def __init__(self, name="", device=device):
|
|
||||||
self.name = name
|
|
||||||
self.device = device
|
|
||||||
self.start_time = 0
|
|
||||||
self.end_time = 0
|
|
||||||
|
|
||||||
if cuda_available:
|
|
||||||
self.start_event = torch.cuda.Event(enable_timing=True)
|
|
||||||
self.end_event = torch.cuda.Event(enable_timing=True)
|
|
||||||
else:
|
|
||||||
self.start_event = None
|
|
||||||
self.end_event = None
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def timing(self):
|
|
||||||
self.start()
|
|
||||||
yield
|
|
||||||
self.end()
|
|
||||||
|
|
||||||
def start(self):
|
|
||||||
if cuda_available:
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
self.start_event.record()
|
|
||||||
else:
|
|
||||||
if self.device.type == "mps":
|
|
||||||
torch.mps.synchronize()
|
|
||||||
self.start_time = time.time()
|
|
||||||
|
|
||||||
def end(self):
|
|
||||||
if cuda_available:
|
|
||||||
self.end_event.record()
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
else:
|
|
||||||
if self.device.type == "mps":
|
|
||||||
torch.mps.synchronize()
|
|
||||||
self.end_time = time.time()
|
|
||||||
|
|
||||||
def elapsed_time(self):
|
|
||||||
if cuda_available:
|
|
||||||
return self.start_event.elapsed_time(self.end_event) / 1000.0
|
|
||||||
else:
|
|
||||||
return self.end_time - self.start_time
|
|
||||||
|
|
||||||
def print_elapsed(self):
|
|
||||||
print(f"Time taken for {self.name}: {self.elapsed_time():.6f} seconds")
|
|
||||||
|
|
||||||
def process_batch(texts_batch, ids_batch, missing_ids):
|
|
||||||
"""处理文本批次"""
|
|
||||||
batch_size = len(texts_batch)
|
|
||||||
print(f"INFO: Processing batch of size {batch_size}")
|
|
||||||
|
|
||||||
tokenize_timer = DeviceTimer("tokenization (batch)", device)
|
|
||||||
to_device_timer = DeviceTimer("transfer to device (batch)", device)
|
|
||||||
embed_timer = DeviceTimer("embedding (batch)", device)
|
|
||||||
pool_timer = DeviceTimer("mean pooling (batch)", device)
|
|
||||||
|
|
||||||
with tokenize_timer.timing():
|
|
||||||
encoded_batch = tokenizer.batch_encode_plus(
|
|
||||||
texts_batch,
|
|
||||||
padding="max_length",
|
|
||||||
truncation=True,
|
|
||||||
max_length=256,
|
|
||||||
return_tensors="pt",
|
|
||||||
return_token_type_ids=False,
|
|
||||||
)
|
|
||||||
tokenize_timer.print_elapsed()
|
|
||||||
|
|
||||||
seq_length = encoded_batch["input_ids"].size(1)
|
|
||||||
print(f"Batch size: {batch_size}, Sequence length: {seq_length}")
|
|
||||||
|
|
||||||
with to_device_timer.timing():
|
|
||||||
enc = {k: v.to(device) for k, v in encoded_batch.items()}
|
|
||||||
to_device_timer.print_elapsed()
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
with embed_timer.timing():
|
|
||||||
out = model(enc["input_ids"], enc["attention_mask"])
|
|
||||||
embed_timer.print_elapsed()
|
|
||||||
|
|
||||||
with pool_timer.timing():
|
|
||||||
hidden_states = out.last_hidden_state if hasattr(out, "last_hidden_state") else out
|
|
||||||
mask_expanded = enc["attention_mask"].unsqueeze(-1).expand(hidden_states.size()).float()
|
|
||||||
sum_embeddings = torch.sum(hidden_states * mask_expanded, 1)
|
|
||||||
sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
|
|
||||||
batch_embeddings = sum_embeddings / sum_mask
|
|
||||||
pool_timer.print_elapsed()
|
|
||||||
|
|
||||||
return batch_embeddings.cpu().numpy()
|
|
||||||
|
|
||||||
# ZMQ server 主循环 - 修改为REP套接字
|
|
||||||
context = zmq.Context()
|
|
||||||
socket = context.socket(zmq.ROUTER) # 改为REP套接字
|
|
||||||
socket.bind(f"tcp://127.0.0.1:{zmq_port}")
|
|
||||||
print(f"INFO: ZMQ ROUTER server listening on port {zmq_port}")
|
|
||||||
|
|
||||||
# 设置超时
|
|
||||||
socket.setsockopt(zmq.RCVTIMEO, 5000) # 5秒接收超时
|
|
||||||
socket.setsockopt(zmq.SNDTIMEO, 300000) # 300秒发送超时
|
|
||||||
|
|
||||||
from . import embedding_pb2
|
|
||||||
|
|
||||||
print(f"INFO: Embedding server ready to serve requests")
|
|
||||||
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
parts = socket.recv_multipart()
|
|
||||||
|
|
||||||
# --- 恢复稳健的消息格式判断 ---
|
|
||||||
# 必须检查 parts 的长度,避免 IndexError
|
|
||||||
if len(parts) >= 3:
|
|
||||||
identity = parts[0]
|
|
||||||
# empty = parts[1] # 中间的空帧我们通常不关心
|
|
||||||
message = parts[2]
|
|
||||||
elif len(parts) == 2:
|
|
||||||
# 也能处理没有空帧的情况
|
|
||||||
identity = parts[0]
|
|
||||||
message = parts[1]
|
|
||||||
else:
|
|
||||||
# 如果收到格式错误的消息,打印警告并忽略它,而不是崩溃
|
|
||||||
print(f"WARNING: Received unexpected message format with {len(parts)} parts. Ignoring.")
|
|
||||||
continue
|
|
||||||
print(f"INFO: Received ZMQ request from client {identity.hex()[:8]}, size {len(message)} bytes")
|
|
||||||
|
|
||||||
e2e_start = time.time()
|
|
||||||
lookup_timer = DeviceTimer("text lookup", device)
|
|
||||||
|
|
||||||
# 解析请求
|
|
||||||
req_proto = embedding_pb2.NodeEmbeddingRequest()
|
|
||||||
req_proto.ParseFromString(message)
|
|
||||||
node_ids = req_proto.node_ids
|
|
||||||
print(f"INFO: Request for {len(node_ids)} node embeddings: {list(node_ids)}")
|
|
||||||
|
|
||||||
# 添加调试信息
|
|
||||||
if len(node_ids) > 0:
|
|
||||||
print(f"DEBUG: Node ID range: {min(node_ids)} to {max(node_ids)}")
|
|
||||||
|
|
||||||
# 查找文本
|
|
||||||
texts = []
|
|
||||||
missing_ids = []
|
|
||||||
with lookup_timer.timing():
|
|
||||||
for nid in node_ids:
|
|
||||||
txtinfo = passages[nid]
|
|
||||||
txt = txtinfo["text"]
|
|
||||||
texts.append(txt)
|
|
||||||
lookup_timer.print_elapsed()
|
|
||||||
|
|
||||||
if missing_ids:
|
|
||||||
print(f"WARNING: Missing passages for IDs: {missing_ids}")
|
|
||||||
|
|
||||||
# 处理批次
|
|
||||||
total_size = len(texts)
|
|
||||||
print(f"INFO: Total batch size: {total_size}, max_batch_size: {max_batch_size}")
|
|
||||||
|
|
||||||
all_embeddings = []
|
|
||||||
|
|
||||||
if total_size > max_batch_size:
|
|
||||||
print(f"INFO: Splitting batch of size {total_size} into chunks of {max_batch_size}")
|
|
||||||
for i in range(0, total_size, max_batch_size):
|
|
||||||
end_idx = min(i + max_batch_size, total_size)
|
|
||||||
print(f"INFO: Processing chunk {i//max_batch_size + 1}/{(total_size + max_batch_size - 1)//max_batch_size}: items {i} to {end_idx-1}")
|
|
||||||
|
|
||||||
chunk_texts = texts[i:end_idx]
|
|
||||||
chunk_ids = node_ids[i:end_idx]
|
|
||||||
|
|
||||||
embeddings_chunk = process_batch(chunk_texts, chunk_ids, missing_ids)
|
|
||||||
all_embeddings.append(embeddings_chunk)
|
|
||||||
|
|
||||||
if cuda_available:
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
elif device.type == "mps":
|
|
||||||
torch.mps.empty_cache()
|
|
||||||
|
|
||||||
hidden = np.vstack(all_embeddings)
|
|
||||||
print(f"INFO: Combined embeddings shape: {hidden.shape}")
|
|
||||||
else:
|
|
||||||
hidden = process_batch(texts, node_ids, missing_ids)
|
|
||||||
|
|
||||||
# 序列化响应
|
|
||||||
ser_start = time.time()
|
|
||||||
|
|
||||||
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
|
||||||
hidden_contiguous = np.ascontiguousarray(hidden, dtype=np.float32)
|
|
||||||
resp_proto.embeddings_data = hidden_contiguous.tobytes()
|
|
||||||
resp_proto.dimensions.append(hidden_contiguous.shape[0])
|
|
||||||
resp_proto.dimensions.append(hidden_contiguous.shape[1])
|
|
||||||
resp_proto.missing_ids.extend(missing_ids)
|
|
||||||
|
|
||||||
response_data = resp_proto.SerializeToString()
|
|
||||||
|
|
||||||
# REP 套接字发送单个响应
|
|
||||||
socket.send_multipart([identity, b'', response_data])
|
|
||||||
|
|
||||||
ser_end = time.time()
|
|
||||||
|
|
||||||
print(f"INFO: Serialize time: {ser_end - ser_start:.6f} seconds")
|
|
||||||
|
|
||||||
if device.type == "cuda":
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
elif device.type == "mps":
|
|
||||||
torch.mps.synchronize()
|
|
||||||
e2e_end = time.time()
|
|
||||||
print(f"INFO: ZMQ E2E time: {e2e_end - e2e_start:.6f} seconds")
|
|
||||||
|
|
||||||
except zmq.Again:
|
|
||||||
print("INFO: ZMQ socket timeout, continuing to listen")
|
|
||||||
# REP套接字不需要重新创建,只需要继续监听
|
|
||||||
continue
|
|
||||||
except Exception as e:
|
|
||||||
print(f"ERROR: Error in ZMQ server: {e}")
|
|
||||||
try:
|
|
||||||
# 发送空响应以维持REQ-REP状态
|
|
||||||
empty_resp = embedding_pb2.NodeEmbeddingResponse()
|
|
||||||
socket.send(empty_resp.SerializeToString())
|
|
||||||
except:
|
|
||||||
# 如果发送失败,重新创建socket
|
|
||||||
socket.close()
|
|
||||||
socket = context.socket(zmq.REP)
|
|
||||||
socket.bind(f"tcp://127.0.0.1:{zmq_port}")
|
|
||||||
socket.setsockopt(zmq.RCVTIMEO, 5000)
|
|
||||||
socket.setsockopt(zmq.SNDTIMEO, 300000)
|
|
||||||
print("INFO: ZMQ socket recreated after error")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"ERROR: Failed to start embedding server: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
# 保持原有的 create_embedding_server 函数不变,只添加线程化版本
|
|
||||||
def create_embedding_server(
    domain="demo",
    load_passages=True,
    load_embeddings=False,
    use_fp16=True,
    use_int8=False,
    use_cuda_graphs=False,
    zmq_port=5555,
    max_batch_size=128,
    lazy_load_passages=False,
    model_name="sentence-transformers/all-mpnet-base-v2",
):
    """
    Blocking entry point for the embedding server, kept for backward
    compatibility with the original (non-threaded) API. Intended for
    running the server directly; it delegates to
    ``create_embedding_server_thread``.

    NOTE(review): only ``zmq_port``, ``model_name`` and ``max_batch_size``
    are forwarded to the delegate call below; ``domain``,
    ``load_passages``, ``load_embeddings``, ``use_fp16``, ``use_int8``,
    ``use_cuda_graphs`` and ``lazy_load_passages`` are accepted but not
    used here — confirm whether the delegate is supposed to receive them.
    """
    create_embedding_server_thread(zmq_port, model_name, max_batch_size)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # CLI wrapper around create_embedding_server().
    parser = argparse.ArgumentParser(description="Embedding service")
    parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
    parser.add_argument("--domain", type=str, default="demo", help="Domain name")
    # BUG FIX: the original used action="store_true" with default=True for
    # --load-passages and --lazy-load-passages, which made the flags no-ops
    # (the value could never be turned off from the command line). Keep the
    # positive flags and defaults unchanged for backward compatibility, and
    # add explicit --no-* flags sharing the same dest so the options can
    # actually be disabled.
    parser.add_argument("--load-passages", dest="load_passages", action="store_true", default=True)
    parser.add_argument("--no-load-passages", dest="load_passages", action="store_false",
                        help="Disable passage loading")
    parser.add_argument("--load-embeddings", action="store_true", default=False)
    # NOTE(review): the CLI default for --use-fp16 is False while the
    # create_embedding_server() signature defaults use_fp16=True — confirm
    # which default is intended.
    parser.add_argument("--use-fp16", action="store_true", default=False)
    parser.add_argument("--use-int8", action="store_true", default=False)
    parser.add_argument("--use-cuda-graphs", action="store_true", default=False)
    parser.add_argument("--max-batch-size", type=int, default=128,
                        help="Maximum batch size before splitting")
    parser.add_argument("--lazy-load-passages", dest="lazy_load_passages", action="store_true", default=True)
    parser.add_argument("--no-lazy-load-passages", dest="lazy_load_passages", action="store_false",
                        help="Disable lazy passage loading")
    parser.add_argument("--model-name", type=str, default="sentence-transformers/all-mpnet-base-v2",
                        help="Embedding model name")
    args = parser.parse_args()

    # Blocking call: runs the ZMQ embedding server until interrupted.
    create_embedding_server(
        domain=args.domain,
        load_passages=args.load_passages,
        load_embeddings=args.load_embeddings,
        use_fp16=args.use_fp16,
        use_int8=args.use_int8,
        use_cuda_graphs=args.use_cuda_graphs,
        zmq_port=args.zmq_port,
        max_batch_size=args.max_batch_size,
        lazy_load_passages=args.lazy_load_passages,
        model_name=args.model_name,
    )
|
|
||||||
@@ -4,13 +4,16 @@ build-backend = "scikit_build_core.build"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "leann-backend-diskann"
|
name = "leann-backend-diskann"
|
||||||
version = "0.1.0"
|
version = "0.2.6"
|
||||||
dependencies = ["leann-core==0.1.0", "numpy"]
|
dependencies = ["leann-core==0.2.6", "numpy", "protobuf>=3.19.0"]
|
||||||
|
|
||||||
[tool.scikit-build]
|
[tool.scikit-build]
|
||||||
# 关键:简化的 CMake 路径
|
# Key: simplified CMake path
|
||||||
cmake.source-dir = "third_party/DiskANN"
|
cmake.source-dir = "third_party/DiskANN"
|
||||||
# 关键:Python 包在根目录,路径完全匹配
|
# Key: Python package in root directory, paths match exactly
|
||||||
wheel.packages = ["leann_backend_diskann"]
|
wheel.packages = ["leann_backend_diskann"]
|
||||||
# 使用默认的 redirect 模式
|
# Use default redirect mode
|
||||||
editable.mode = "redirect"
|
editable.mode = "redirect"
|
||||||
|
cmake.build-type = "Release"
|
||||||
|
build.verbose = true
|
||||||
|
build.tool-args = ["-j8"]
|
||||||
|
|||||||
1
packages/leann-backend-diskann/third_party/DiskANN
vendored
Submodule
1
packages/leann-backend-diskann/third_party/DiskANN
vendored
Submodule
Submodule packages/leann-backend-diskann/third_party/DiskANN added at b2dc4ea2c7
@@ -1,6 +0,0 @@
|
|||||||
---
|
|
||||||
BasedOnStyle: Microsoft
|
|
||||||
---
|
|
||||||
Language: Cpp
|
|
||||||
SortIncludes: false
|
|
||||||
...
|
|
||||||
@@ -1,14 +0,0 @@
|
|||||||
# Set the default behavior, in case people don't have core.autocrlf set.
|
|
||||||
* text=auto
|
|
||||||
|
|
||||||
# Explicitly declare text files you want to always be normalized and converted
|
|
||||||
# to native line endings on checkout.
|
|
||||||
*.c text
|
|
||||||
*.h text
|
|
||||||
|
|
||||||
# Declare files that will always have CRLF line endings on checkout.
|
|
||||||
*.sln text eol=crlf
|
|
||||||
|
|
||||||
# Denote all files that are truly binary and should not be modified.
|
|
||||||
*.png binary
|
|
||||||
*.jpg binary
|
|
||||||
@@ -1,40 +0,0 @@
|
|||||||
---
|
|
||||||
name: Bug report
|
|
||||||
about: Bug reports help us improve! Thanks for submitting yours!
|
|
||||||
title: "[BUG] "
|
|
||||||
labels: bug
|
|
||||||
assignees: ''
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Expected Behavior
|
|
||||||
Tell us what should happen
|
|
||||||
|
|
||||||
## Actual Behavior
|
|
||||||
Tell us what happens instead
|
|
||||||
|
|
||||||
## Example Code
|
|
||||||
Please see [How to create a Minimal, Reproducible example](https://stackoverflow.com/help/minimal-reproducible-example) for some guidance on creating the best possible example of the problem
|
|
||||||
```bash
|
|
||||||
|
|
||||||
```
|
|
||||||
|
|
||||||
## Dataset Description
|
|
||||||
Please tell us about the shape and datatype of your data, (e.g. 128 dimensions, 12.3 billion points, floats)
|
|
||||||
- Dimensions:
|
|
||||||
- Number of Points:
|
|
||||||
- Data type:
|
|
||||||
|
|
||||||
## Error
|
|
||||||
```
|
|
||||||
Paste the full error, with any sensitive information minimally redacted and marked $$REDACTED$$
|
|
||||||
|
|
||||||
```
|
|
||||||
|
|
||||||
## Your Environment
|
|
||||||
* Operating system (e.g. Windows 11 Pro, Ubuntu 22.04.1 LTS)
|
|
||||||
* DiskANN version (or commit built from)
|
|
||||||
|
|
||||||
## Additional Details
|
|
||||||
Any other contextual information you might feel is important.
|
|
||||||
|
|
||||||
@@ -1,2 +0,0 @@
|
|||||||
blank_issues_enabled: false
|
|
||||||
|
|
||||||
@@ -1,25 +0,0 @@
|
|||||||
---
|
|
||||||
name: Feature request
|
|
||||||
about: Suggest an idea for this project
|
|
||||||
title: ''
|
|
||||||
labels: enhancement
|
|
||||||
assignees: ''
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Is your feature request related to a problem? Please describe.
|
|
||||||
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
|
|
||||||
|
|
||||||
## Describe the solution you'd like
|
|
||||||
A clear and concise description of what you want to happen.
|
|
||||||
|
|
||||||
## Describe alternatives you've considered
|
|
||||||
A clear and concise description of any alternative solutions or features you've considered.
|
|
||||||
|
|
||||||
## Provide references (if applicable)
|
|
||||||
If your feature request is related to a published algorithm/idea, please provide links to
|
|
||||||
any relevant articles or webpages.
|
|
||||||
|
|
||||||
## Additional context
|
|
||||||
Add any other context or screenshots about the feature request here.
|
|
||||||
|
|
||||||
@@ -1,11 +0,0 @@
|
|||||||
---
|
|
||||||
name: Usage Question
|
|
||||||
about: Ask us a question about DiskANN!
|
|
||||||
title: "[Question]"
|
|
||||||
labels: question
|
|
||||||
assignees: ''
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
This is our forum for asking whatever DiskANN question you'd like! No need to feel shy - we're happy to talk about use cases and optimal tuning strategies!
|
|
||||||
|
|
||||||
@@ -1,22 +0,0 @@
|
|||||||
<!--
|
|
||||||
Thanks for contributing a pull request! Please ensure you have taken a look at
|
|
||||||
the contribution guidelines: https://github.com/microsoft/DiskANN/blob/main/CONTRIBUTING.md
|
|
||||||
-->
|
|
||||||
- [ ] Does this PR have a descriptive title that could go in our release notes?
|
|
||||||
- [ ] Does this PR add any new dependencies?
|
|
||||||
- [ ] Does this PR modify any existing APIs?
|
|
||||||
- [ ] Is the change to the API backwards compatible?
|
|
||||||
- [ ] Should this result in any changes to our documentation, either updating existing docs or adding new ones?
|
|
||||||
|
|
||||||
#### Reference Issues/PRs
|
|
||||||
<!--
|
|
||||||
Example: Fixes #1234. See also #3456.
|
|
||||||
Please use keywords (e.g., Fixes) to create link to the issues or pull requests
|
|
||||||
you resolved, so that they will automatically be closed when your pull request
|
|
||||||
is merged. See https://github.com/blog/1506-closing-issues-via-pull-requests
|
|
||||||
-->
|
|
||||||
|
|
||||||
#### What does this implement/fix? Briefly explain your changes.
|
|
||||||
|
|
||||||
#### Any other comments?
|
|
||||||
|
|
||||||
@@ -1,39 +0,0 @@
|
|||||||
name: 'DiskANN Build Bootstrap'
|
|
||||||
description: 'Prepares DiskANN build environment and executes build'
|
|
||||||
runs:
|
|
||||||
using: "composite"
|
|
||||||
steps:
|
|
||||||
# ------------ Linux Build ---------------
|
|
||||||
- name: Prepare and Execute Build
|
|
||||||
if: ${{ runner.os == 'Linux' }}
|
|
||||||
run: |
|
|
||||||
sudo scripts/dev/install-dev-deps-ubuntu.bash
|
|
||||||
cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DUNIT_TEST=True
|
|
||||||
cmake --build build -- -j
|
|
||||||
cmake --install build --prefix="dist"
|
|
||||||
shell: bash
|
|
||||||
# ------------ End Linux Build ---------------
|
|
||||||
# ------------ Windows Build ---------------
|
|
||||||
- name: Add VisualStudio command line tools into path
|
|
||||||
if: runner.os == 'Windows'
|
|
||||||
uses: ilammy/msvc-dev-cmd@v1
|
|
||||||
- name: Run configure and build for Windows
|
|
||||||
if: runner.os == 'Windows'
|
|
||||||
run: |
|
|
||||||
mkdir build && cd build && cmake .. -DUNIT_TEST=True && msbuild diskann.sln /m /nologo /t:Build /p:Configuration="Release" /property:Platform="x64" -consoleloggerparameters:"ErrorsOnly;Summary"
|
|
||||||
cd ..
|
|
||||||
mkdir dist
|
|
||||||
mklink /j .\dist\bin .\x64\Release\
|
|
||||||
shell: cmd
|
|
||||||
# ------------ End Windows Build ---------------
|
|
||||||
# ------------ Windows Build With EXEC_ENV_OLS and USE_BING_INFRA ---------------
|
|
||||||
- name: Add VisualStudio command line tools into path
|
|
||||||
if: runner.os == 'Windows'
|
|
||||||
uses: ilammy/msvc-dev-cmd@v1
|
|
||||||
- name: Run configure and build for Windows with Bing feature flags
|
|
||||||
if: runner.os == 'Windows'
|
|
||||||
run: |
|
|
||||||
mkdir build_bing && cd build_bing && cmake .. -DEXEC_ENV_OLS=1 -DUSE_BING_INFRA=1 -DUNIT_TEST=True && msbuild diskann.sln /m /nologo /t:Build /p:Configuration="Release" /property:Platform="x64" -consoleloggerparameters:"ErrorsOnly;Summary"
|
|
||||||
cd ..
|
|
||||||
shell: cmd
|
|
||||||
# ------------ End Windows Build ---------------
|
|
||||||
@@ -1,13 +0,0 @@
|
|||||||
name: 'Checking code formatting...'
|
|
||||||
description: 'Ensures code complies with code formatting rules'
|
|
||||||
runs:
|
|
||||||
using: "composite"
|
|
||||||
steps:
|
|
||||||
- name: Checking code formatting...
|
|
||||||
run: |
|
|
||||||
sudo apt install clang-format
|
|
||||||
find include -name '*.h' -type f -print0 | xargs -0 -P 16 /usr/bin/clang-format --Werror --dry-run
|
|
||||||
find src -name '*.cpp' -type f -print0 | xargs -0 -P 16 /usr/bin/clang-format --Werror --dry-run
|
|
||||||
find apps -name '*.cpp' -type f -print0 | xargs -0 -P 16 /usr/bin/clang-format --Werror --dry-run
|
|
||||||
find python -name '*.cpp' -type f -print0 | xargs -0 -P 16 /usr/bin/clang-format --Werror --dry-run
|
|
||||||
shell: bash
|
|
||||||
@@ -1,28 +0,0 @@
|
|||||||
name: 'Generating Random Data (Basic)'
|
|
||||||
description: 'Generates the random data files used in acceptance tests'
|
|
||||||
runs:
|
|
||||||
using: "composite"
|
|
||||||
steps:
|
|
||||||
- name: Generate Random Data (Basic)
|
|
||||||
run: |
|
|
||||||
mkdir data
|
|
||||||
|
|
||||||
echo "Generating random 1020,1024,1536D float and 4096 int8 vectors for index"
|
|
||||||
dist/bin/rand_data_gen --data_type float --output_file data/rand_float_1020D_5K_norm1.0.bin -D 1020 -N 5000 --norm 1.0
|
|
||||||
#dist/bin/rand_data_gen --data_type float --output_file data/rand_float_1024D_5K_norm1.0.bin -D 1024 -N 5000 --norm 1.0
|
|
||||||
dist/bin/rand_data_gen --data_type float --output_file data/rand_float_1536D_5K_norm1.0.bin -D 1536 -N 5000 --norm 1.0
|
|
||||||
dist/bin/rand_data_gen --data_type int8 --output_file data/rand_int8_4096D_5K_norm1.0.bin -D 4096 -N 5000 --norm 1.0
|
|
||||||
|
|
||||||
echo "Generating random 1020,1024,1536D float and 4096D int8 avectors for query"
|
|
||||||
dist/bin/rand_data_gen --data_type float --output_file data/rand_float_1020D_1K_norm1.0.bin -D 1020 -N 1000 --norm 1.0
|
|
||||||
#dist/bin/rand_data_gen --data_type float --output_file data/rand_float_1024D_1K_norm1.0.bin -D 1024 -N 1000 --norm 1.0
|
|
||||||
dist/bin/rand_data_gen --data_type float --output_file data/rand_float_1536D_1K_norm1.0.bin -D 1536 -N 1000 --norm 1.0
|
|
||||||
dist/bin/rand_data_gen --data_type int8 --output_file data/rand_int8_4096D_1K_norm1.0.bin -D 4096 -N 1000 --norm 1.0
|
|
||||||
|
|
||||||
echo "Computing ground truth for 1020,1024,1536D float and 4096D int8 avectors for query"
|
|
||||||
dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/rand_float_1020D_5K_norm1.0.bin --query_file data/rand_float_1020D_1K_norm1.0.bin --gt_file data/l2_rand_float_1020D_5K_norm1.0_1020D_1K_norm1.0_gt100 --K 100
|
|
||||||
#dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/rand_float_1024D_5K_norm1.0.bin --query_file data/rand_float_1024D_1K_norm1.0.bin --gt_file data/l2_rand_float_1024D_5K_norm1.0_1024D_1K_norm1.0_gt100 --K 100
|
|
||||||
dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/rand_float_1536D_5K_norm1.0.bin --query_file data/rand_float_1536D_1K_norm1.0.bin --gt_file data/l2_rand_float_1536D_5K_norm1.0_1536D_1K_norm1.0_gt100 --K 100
|
|
||||||
dist/bin/compute_groundtruth --data_type int8 --dist_fn l2 --base_file data/rand_int8_4096D_5K_norm1.0.bin --query_file data/rand_int8_4096D_1K_norm1.0.bin --gt_file data/l2_rand_int8_4096D_5K_norm1.0_4096D_1K_norm1.0_gt100 --K 100
|
|
||||||
|
|
||||||
shell: bash
|
|
||||||
@@ -1,38 +0,0 @@
|
|||||||
name: 'Generating Random Data (Basic)'
|
|
||||||
description: 'Generates the random data files used in acceptance tests'
|
|
||||||
runs:
|
|
||||||
using: "composite"
|
|
||||||
steps:
|
|
||||||
- name: Generate Random Data (Basic)
|
|
||||||
run: |
|
|
||||||
mkdir data
|
|
||||||
|
|
||||||
echo "Generating random vectors for index"
|
|
||||||
dist/bin/rand_data_gen --data_type float --output_file data/rand_float_10D_10K_norm1.0.bin -D 10 -N 10000 --norm 1.0
|
|
||||||
dist/bin/rand_data_gen --data_type float --output_file data/rand_float_10D_10K_unnorm.bin -D 10 -N 10000 --rand_scaling 2.0
|
|
||||||
dist/bin/rand_data_gen --data_type int8 --output_file data/rand_int8_10D_10K_norm50.0.bin -D 10 -N 10000 --norm 50.0
|
|
||||||
dist/bin/rand_data_gen --data_type uint8 --output_file data/rand_uint8_10D_10K_norm50.0.bin -D 10 -N 10000 --norm 50.0
|
|
||||||
|
|
||||||
echo "Generating random vectors for query"
|
|
||||||
dist/bin/rand_data_gen --data_type float --output_file data/rand_float_10D_1K_norm1.0.bin -D 10 -N 1000 --norm 1.0
|
|
||||||
dist/bin/rand_data_gen --data_type float --output_file data/rand_float_10D_1K_unnorm.bin -D 10 -N 1000 --rand_scaling 2.0
|
|
||||||
dist/bin/rand_data_gen --data_type int8 --output_file data/rand_int8_10D_1K_norm50.0.bin -D 10 -N 1000 --norm 50.0
|
|
||||||
dist/bin/rand_data_gen --data_type uint8 --output_file data/rand_uint8_10D_1K_norm50.0.bin -D 10 -N 1000 --norm 50.0
|
|
||||||
|
|
||||||
echo "Computing ground truth for floats across l2, mips, and cosine distance functions"
|
|
||||||
dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/rand_float_10D_10K_norm1.0.bin --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --K 100
|
|
||||||
dist/bin/compute_groundtruth --data_type float --dist_fn mips --base_file data/rand_float_10D_10K_norm1.0.bin --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/mips_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --K 100
|
|
||||||
dist/bin/compute_groundtruth --data_type float --dist_fn cosine --base_file data/rand_float_10D_10K_norm1.0.bin --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/cosine_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --K 100
|
|
||||||
dist/bin/compute_groundtruth --data_type float --dist_fn cosine --base_file data/rand_float_10D_10K_unnorm.bin --query_file data/rand_float_10D_1K_unnorm.bin --gt_file data/cosine_rand_float_10D_10K_unnorm_10D_1K_unnorm_gt100 --K 100
|
|
||||||
|
|
||||||
echo "Computing ground truth for int8s across l2, mips, and cosine distance functions"
|
|
||||||
dist/bin/compute_groundtruth --data_type int8 --dist_fn l2 --base_file data/rand_int8_10D_10K_norm50.0.bin --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100
|
|
||||||
dist/bin/compute_groundtruth --data_type int8 --dist_fn mips --base_file data/rand_int8_10D_10K_norm50.0.bin --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/mips_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100
|
|
||||||
dist/bin/compute_groundtruth --data_type int8 --dist_fn cosine --base_file data/rand_int8_10D_10K_norm50.0.bin --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/cosine_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100
|
|
||||||
|
|
||||||
echo "Computing ground truth for uint8s across l2, mips, and cosine distance functions"
|
|
||||||
dist/bin/compute_groundtruth --data_type uint8 --dist_fn l2 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100
|
|
||||||
dist/bin/compute_groundtruth --data_type uint8 --dist_fn mips --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/mips_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100
|
|
||||||
dist/bin/compute_groundtruth --data_type uint8 --dist_fn cosine --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/cosine_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100
|
|
||||||
|
|
||||||
shell: bash
|
|
||||||
@@ -1,22 +0,0 @@
|
|||||||
name: Build Python Wheel
|
|
||||||
description: Builds a python wheel with cibuildwheel
|
|
||||||
inputs:
|
|
||||||
cibw-identifier:
|
|
||||||
description: "CI build wheel identifier to build"
|
|
||||||
required: true
|
|
||||||
runs:
|
|
||||||
using: "composite"
|
|
||||||
steps:
|
|
||||||
- uses: actions/setup-python@v3
|
|
||||||
- name: Install cibuildwheel
|
|
||||||
run: python -m pip install cibuildwheel==2.11.3
|
|
||||||
shell: bash
|
|
||||||
- name: Building Python ${{inputs.cibw-identifier}} Wheel
|
|
||||||
run: python -m cibuildwheel --output-dir dist
|
|
||||||
env:
|
|
||||||
CIBW_BUILD: ${{inputs.cibw-identifier}}
|
|
||||||
shell: bash
|
|
||||||
- uses: actions/upload-artifact@v3
|
|
||||||
with:
|
|
||||||
name: wheels
|
|
||||||
path: ./dist/*.whl
|
|
||||||
@@ -1,81 +0,0 @@
|
|||||||
name: DiskANN Build PDoc Documentation
|
|
||||||
on: [workflow_call]
|
|
||||||
jobs:
|
|
||||||
build-reference-documentation:
|
|
||||||
permissions:
|
|
||||||
contents: write
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- name: Checkout repository
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
with:
|
|
||||||
fetch-depth: 1
|
|
||||||
- name: Set up Python 3.9
|
|
||||||
uses: actions/setup-python@v2
|
|
||||||
with:
|
|
||||||
python-version: 3.9
|
|
||||||
- name: Install python build
|
|
||||||
run: python -m pip install build
|
|
||||||
shell: bash
|
|
||||||
# Install required dependencies
|
|
||||||
- name: Prepare Linux environment
|
|
||||||
run: |
|
|
||||||
sudo scripts/dev/install-dev-deps-ubuntu.bash
|
|
||||||
shell: bash
|
|
||||||
# We need to build the wheel in order to run pdoc. pdoc does not seem to work if you just point it at
|
|
||||||
# our source directory.
|
|
||||||
- name: Building Python Wheel for documentation generation
|
|
||||||
run: python -m build --wheel --outdir documentation_dist
|
|
||||||
shell: bash
|
|
||||||
- name: "Run Reference Documentation Generation"
|
|
||||||
run: |
|
|
||||||
pip install pdoc pipdeptree
|
|
||||||
pip install documentation_dist/*.whl
|
|
||||||
echo "documentation" > dependencies_documentation.txt
|
|
||||||
pipdeptree >> dependencies_documentation.txt
|
|
||||||
pdoc -o docs/python/html diskannpy
|
|
||||||
- name: Create version environment variable
|
|
||||||
run: |
|
|
||||||
echo "DISKANN_VERSION=$(python <<EOF
|
|
||||||
from importlib.metadata import version
|
|
||||||
v = version('diskannpy')
|
|
||||||
print(v)
|
|
||||||
EOF
|
|
||||||
)" >> $GITHUB_ENV
|
|
||||||
- name: Archive documentation version artifact
|
|
||||||
uses: actions/upload-artifact@v4
|
|
||||||
with:
|
|
||||||
name: dependencies
|
|
||||||
path: |
|
|
||||||
${{ github.run_id }}-dependencies_documentation.txt
|
|
||||||
overwrite: true
|
|
||||||
- name: Archive documentation artifacts
|
|
||||||
uses: actions/upload-artifact@v4
|
|
||||||
with:
|
|
||||||
name: documentation-site
|
|
||||||
path: |
|
|
||||||
docs/python/html
|
|
||||||
# Publish to /dev if we are on the "main" branch
|
|
||||||
- name: Publish reference docs for latest development version (main branch)
|
|
||||||
uses: peaceiris/actions-gh-pages@v3
|
|
||||||
if: github.ref == 'refs/heads/main'
|
|
||||||
with:
|
|
||||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
publish_dir: docs/python/html
|
|
||||||
destination_dir: docs/python/dev
|
|
||||||
# Publish to /<version> if we are releasing
|
|
||||||
- name: Publish reference docs by version (main branch)
|
|
||||||
uses: peaceiris/actions-gh-pages@v3
|
|
||||||
if: github.event_name == 'release'
|
|
||||||
with:
|
|
||||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
publish_dir: docs/python/html
|
|
||||||
destination_dir: docs/python/${{ env.DISKANN_VERSION }}
|
|
||||||
# Publish to /latest if we are releasing
|
|
||||||
- name: Publish latest reference docs (main branch)
|
|
||||||
uses: peaceiris/actions-gh-pages@v3
|
|
||||||
if: github.event_name == 'release'
|
|
||||||
with:
|
|
||||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
publish_dir: docs/python/html
|
|
||||||
destination_dir: docs/python/latest
|
|
||||||
@@ -1,42 +0,0 @@
|
|||||||
name: DiskANN Build Python Wheel
|
|
||||||
on: [workflow_call]
|
|
||||||
jobs:
|
|
||||||
linux-build:
|
|
||||||
name: Python - Ubuntu - ${{matrix.cibw-identifier}}
|
|
||||||
strategy:
|
|
||||||
fail-fast: false
|
|
||||||
matrix:
|
|
||||||
cibw-identifier: ["cp39-manylinux_x86_64", "cp310-manylinux_x86_64", "cp311-manylinux_x86_64"]
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
defaults:
|
|
||||||
run:
|
|
||||||
shell: bash
|
|
||||||
steps:
|
|
||||||
- name: Checkout repository
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
with:
|
|
||||||
fetch-depth: 1
|
|
||||||
- name: Building python wheel ${{matrix.cibw-identifier}}
|
|
||||||
uses: ./.github/actions/python-wheel
|
|
||||||
with:
|
|
||||||
cibw-identifier: ${{matrix.cibw-identifier}}
|
|
||||||
windows-build:
|
|
||||||
name: Python - Windows - ${{matrix.cibw-identifier}}
|
|
||||||
strategy:
|
|
||||||
fail-fast: false
|
|
||||||
matrix:
|
|
||||||
cibw-identifier: ["cp39-win_amd64", "cp310-win_amd64", "cp311-win_amd64"]
|
|
||||||
runs-on: windows-latest
|
|
||||||
defaults:
|
|
||||||
run:
|
|
||||||
shell: bash
|
|
||||||
steps:
|
|
||||||
- name: Checkout repository
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
with:
|
|
||||||
submodules: true
|
|
||||||
fetch-depth: 1
|
|
||||||
- name: Building python wheel ${{matrix.cibw-identifier}}
|
|
||||||
uses: ./.github/actions/python-wheel
|
|
||||||
with:
|
|
||||||
cibw-identifier: ${{matrix.cibw-identifier}}
|
|
||||||
@@ -1,28 +0,0 @@
|
|||||||
name: DiskANN Common Checks
|
|
||||||
# common means common to both pr-test and push-test
|
|
||||||
on: [workflow_call]
|
|
||||||
jobs:
|
|
||||||
formatting-check:
|
|
||||||
strategy:
|
|
||||||
fail-fast: true
|
|
||||||
name: Code Formatting Test
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- name: Checkout repository
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
with:
|
|
||||||
fetch-depth: 1
|
|
||||||
- name: Checking code formatting...
|
|
||||||
uses: ./.github/actions/format-check
|
|
||||||
docker-container-build:
|
|
||||||
name: Docker Container Build
|
|
||||||
needs: [formatting-check]
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- name: Checkout repository
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
with:
|
|
||||||
fetch-depth: 1
|
|
||||||
- name: Docker build
|
|
||||||
run: |
|
|
||||||
docker build .
|
|
||||||
@@ -1,117 +0,0 @@
|
|||||||
name: Disk With PQ
|
|
||||||
on: [workflow_call]
|
|
||||||
jobs:
|
|
||||||
acceptance-tests-disk-pq:
|
|
||||||
name: Disk, PQ
|
|
||||||
strategy:
|
|
||||||
fail-fast: false
|
|
||||||
matrix:
|
|
||||||
os: [ubuntu-latest, windows-2019, windows-latest]
|
|
||||||
runs-on: ${{matrix.os}}
|
|
||||||
defaults:
|
|
||||||
run:
|
|
||||||
shell: bash
|
|
||||||
steps:
|
|
||||||
- name: Checkout repository
|
|
||||||
if: ${{ runner.os == 'Linux' }}
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
with:
|
|
||||||
fetch-depth: 1
|
|
||||||
- name: Checkout repository
|
|
||||||
if: ${{ runner.os == 'Windows' }}
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
with:
|
|
||||||
fetch-depth: 1
|
|
||||||
submodules: true
|
|
||||||
- name: DiskANN Build CLI Applications
|
|
||||||
uses: ./.github/actions/build
|
|
||||||
|
|
||||||
- name: Generate Data
|
|
||||||
uses: ./.github/actions/generate-random
|
|
||||||
|
|
||||||
- name: build and search disk index (one shot graph build, L2, no diskPQ) (float)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_disk_index --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskfull_oneshot -R 16 -L 32 -B 0.00003 -M 1
|
|
||||||
dist/bin/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskfull_oneshot --result_path /tmp/res --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
|
||||||
- name: build and search disk index (one shot graph build, cosine, no diskPQ) (float)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_disk_index --data_type float --dist_fn cosine --data_path data/rand_float_10D_10K_unnorm.bin --index_path_prefix data/disk_index_cosine_rand_float_10D_10K_unnorm_diskfull_oneshot -R 16 -L 32 -B 0.00003 -M 1
|
|
||||||
dist/bin/search_disk_index --data_type float --dist_fn cosine --fail_if_recall_below 70 --index_path_prefix data/disk_index_cosine_rand_float_10D_10K_unnorm_diskfull_oneshot --result_path /tmp/res --query_file data/rand_float_10D_1K_unnorm.bin --gt_file data/cosine_rand_float_10D_10K_unnorm_10D_1K_unnorm_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
|
||||||
- name: build and search disk index (one shot graph build, L2, no diskPQ) (int8)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_disk_index --data_type int8 --dist_fn l2 --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/disk_index_l2_rand_int8_10D_10K_norm50.0_diskfull_oneshot -R 16 -L 32 -B 0.00003 -M 1
|
|
||||||
dist/bin/search_disk_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_int8_10D_10K_norm50.0_diskfull_oneshot --result_path /tmp/res --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
|
||||||
- name: build and search disk index (one shot graph build, L2, no diskPQ) (uint8)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_disk_index --data_type uint8 --dist_fn l2 --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_oneshot -R 16 -L 32 -B 0.00003 -M 1
|
|
||||||
dist/bin/search_disk_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_oneshot --result_path /tmp/res --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
|
||||||
|
|
||||||
- name: build and search disk index (one shot graph build, L2, no diskPQ, build with PQ distance comparisons) (float)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_disk_index --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskfull_oneshot_buildpq5 -R 16 -L 32 -B 0.00003 -M 1 --build_PQ_bytes 5
|
|
||||||
dist/bin/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskfull_oneshot_buildpq5 --result_path /tmp/res --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
|
||||||
- name: build and search disk index (one shot graph build, L2, no diskPQ, build with PQ distance comparisons) (int8)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_disk_index --data_type int8 --dist_fn l2 --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/disk_index_l2_rand_int8_10D_10K_norm50.0_diskfull_oneshot_buildpq5 -R 16 -L 32 -B 0.00003 -M 1 --build_PQ_bytes 5
|
|
||||||
dist/bin/search_disk_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_int8_10D_10K_norm50.0_diskfull_oneshot_buildpq5 --result_path /tmp/res --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16\
|
|
||||||
- name: build and search disk index (one shot graph build, L2, no diskPQ, build with PQ distance comparisons) (uint8)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_disk_index --data_type uint8 --dist_fn l2 --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_oneshot_buildpq5 -R 16 -L 32 -B 0.00003 -M 1 --build_PQ_bytes 5
|
|
||||||
dist/bin/search_disk_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_oneshot_buildpq5 --result_path /tmp/res --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
|
||||||
|
|
||||||
- name: build and search disk index (sharded graph build, L2, no diskPQ) (float)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_disk_index --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskfull_sharded -R 16 -L 32 -B 0.00003 -M 0.00006
|
|
||||||
dist/bin/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskfull_sharded --result_path /tmp/res --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
|
||||||
- name: build and search disk index (sharded graph build, cosine, no diskPQ) (float)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_disk_index --data_type float --dist_fn cosine --data_path data/rand_float_10D_10K_unnorm.bin --index_path_prefix data/disk_index_cosine_rand_float_10D_10K_unnorm_diskfull_sharded -R 16 -L 32 -B 0.00003 -M 0.00006
|
|
||||||
dist/bin/search_disk_index --data_type float --dist_fn cosine --fail_if_recall_below 70 --index_path_prefix data/disk_index_cosine_rand_float_10D_10K_unnorm_diskfull_sharded --result_path /tmp/res --query_file data/rand_float_10D_1K_unnorm.bin --gt_file data/cosine_rand_float_10D_10K_unnorm_10D_1K_unnorm_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
|
||||||
- name: build and search disk index (sharded graph build, L2, no diskPQ) (int8)
|
|
||||||
run: |
|
|
||||||
dist/bin/build_disk_index --data_type int8 --dist_fn l2 --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/disk_index_l2_rand_int8_10D_10K_norm50.0_diskfull_sharded -R 16 -L 32 -B 0.00003 -M 0.00006
|
|
||||||
dist/bin/search_disk_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_int8_10D_10K_norm50.0_diskfull_sharded --result_path /tmp/res --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
|
||||||
- name: build and search disk index (sharded graph build, L2, no diskPQ) (uint8)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_disk_index --data_type uint8 --dist_fn l2 --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_sharded -R 16 -L 32 -B 0.00003 -M 0.00006
|
|
||||||
dist/bin/search_disk_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_sharded --result_path /tmp/res --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
|
||||||
|
|
||||||
- name: build and search disk index (one shot graph build, L2, diskPQ) (float)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_disk_index --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskpq_oneshot -R 16 -L 32 -B 0.00003 -M 1 --PQ_disk_bytes 5
|
|
||||||
dist/bin/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskpq_oneshot --result_path /tmp/res --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
|
||||||
- name: build and search disk index (one shot graph build, L2, diskPQ) (int8)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_disk_index --data_type int8 --dist_fn l2 --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/disk_index_l2_rand_int8_10D_10K_norm50.0_diskpq_oneshot -R 16 -L 32 -B 0.00003 -M 1 --PQ_disk_bytes 5
|
|
||||||
dist/bin/search_disk_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_int8_10D_10K_norm50.0_diskpq_oneshot --result_path /tmp/res --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
|
||||||
- name: build and search disk index (one shot graph build, L2, diskPQ) (uint8)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_disk_index --data_type uint8 --dist_fn l2 --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50.0_diskpq_oneshot -R 16 -L 32 -B 0.00003 -M 1 --PQ_disk_bytes 5
|
|
||||||
dist/bin/search_disk_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50.0_diskpq_oneshot --result_path /tmp/res --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
|
||||||
|
|
||||||
- name: build and search disk index (sharded graph build, MIPS, diskPQ) (float)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_disk_index --data_type float --dist_fn mips --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/disk_index_mips_rand_float_10D_10K_norm1.0_diskpq_sharded -R 16 -L 32 -B 0.00003 -M 0.00006 --PQ_disk_bytes 5
|
|
||||||
dist/bin/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_mips_rand_float_10D_10K_norm1.0_diskpq_sharded --result_path /tmp/res --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/mips_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
|
||||||
|
|
||||||
- name: upload data and bin
|
|
||||||
uses: actions/upload-artifact@v4
|
|
||||||
with:
|
|
||||||
name: disk-pq-${{matrix.os}}
|
|
||||||
path: |
|
|
||||||
./dist/**
|
|
||||||
./data/**
|
|
||||||
@@ -1,102 +0,0 @@
|
|||||||
name: Dynamic-Labels
|
|
||||||
on: [workflow_call]
|
|
||||||
jobs:
|
|
||||||
acceptance-tests-dynamic:
|
|
||||||
name: Dynamic-Labels
|
|
||||||
strategy:
|
|
||||||
fail-fast: false
|
|
||||||
matrix:
|
|
||||||
os: [ubuntu-latest, windows-2019, windows-latest]
|
|
||||||
runs-on: ${{matrix.os}}
|
|
||||||
defaults:
|
|
||||||
run:
|
|
||||||
shell: bash
|
|
||||||
steps:
|
|
||||||
- name: Checkout repository
|
|
||||||
if: ${{ runner.os == 'Linux' }}
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
with:
|
|
||||||
fetch-depth: 1
|
|
||||||
- name: Checkout repository
|
|
||||||
if: ${{ runner.os == 'Windows' }}
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
with:
|
|
||||||
fetch-depth: 1
|
|
||||||
submodules: true
|
|
||||||
- name: DiskANN Build CLI Applications
|
|
||||||
uses: ./.github/actions/build
|
|
||||||
|
|
||||||
- name: Generate Data
|
|
||||||
uses: ./.github/actions/generate-random
|
|
||||||
|
|
||||||
- name: Generate Labels
|
|
||||||
run: |
|
|
||||||
echo "Generating synthetic labels and computing ground truth for filtered search with universal label"
|
|
||||||
dist/bin/generate_synthetic_labels --num_labels 50 --num_points 10000 --output_file data/rand_labels_50_10K.txt --distribution_type random
|
|
||||||
|
|
||||||
echo "Generating synthetic labels with a zipf distribution and computing ground truth for filtered search with universal label"
|
|
||||||
dist/bin/generate_synthetic_labels --num_labels 50 --num_points 10000 --output_file data/zipf_labels_50_10K.txt --distribution_type zipf
|
|
||||||
|
|
||||||
- name: Test a streaming index (float) with labels (Zipf distributed)
|
|
||||||
run: |
|
|
||||||
dist/bin/test_streaming_scenario --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --universal_label 0 --label_file data/zipf_labels_50_10K.txt --index_path_prefix data/index_zipf_stream -R 64 --FilteredLbuild 200 -L 50 --alpha 1.2 --insert_threads 8 --consolidate_threads 8 --max_points_to_insert 10000 --active_window 4000 --consolidate_interval 2000 --start_point_norm 3.2 --unique_labels_supported 51
|
|
||||||
|
|
||||||
echo "Computing groundtruth with filter"
|
|
||||||
dist/bin/compute_groundtruth_for_filters --data_type float --universal_label 0 --filter_label 1 --dist_fn l2 --base_file data/index_zipf_stream.after-streaming-act4000-cons2000-max10000.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_zipf_base-act4000-cons2000-max10000_1 --label_file data/index_zipf_stream.after-streaming-act4000-cons2000-max10000_raw_labels.txt --tags_file data/index_zipf_stream.after-streaming-act4000-cons2000-max10000.tags
|
|
||||||
echo "Searching with filter"
|
|
||||||
dist/bin/search_memory_index --data_type float --dist_fn l2 --filter_label 1 --fail_if_recall_below 40 --index_path_prefix data/index_zipf_stream.after-streaming-act4000-cons2000-max10000 --result_path data/res_stream --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_zipf_base-act4000-cons2000-max10000_1 -K 10 -L 20 40 60 80 100 150 -T 64 --dynamic true --tags 1
|
|
||||||
|
|
||||||
echo "Computing groundtruth w/o filter"
|
|
||||||
dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/index_zipf_stream.after-streaming-act4000-cons2000-max10000.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_zipf_base-act4000-cons2000-max10000
|
|
||||||
echo "Searching without filter"
|
|
||||||
dist/bin/search_memory_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_zipf_stream.after-streaming-act4000-cons2000-max10000 --result_path res_stream --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_zipf_base-act4000-cons2000-max10000 -K 10 -L 20 40 60 80 100 -T 64
|
|
||||||
|
|
||||||
- name: Test a streaming index (float) with labels (random distributed)
|
|
||||||
run: |
|
|
||||||
dist/bin/test_streaming_scenario --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --universal_label 0 --label_file data/rand_labels_50_10K.txt --index_path_prefix data/index_rand_stream -R 64 --FilteredLbuild 200 -L 50 --alpha 1.2 --insert_threads 8 --consolidate_threads 8 --max_points_to_insert 10000 --active_window 4000 --consolidate_interval 2000 --start_point_norm 3.2 --unique_labels_supported 51
|
|
||||||
|
|
||||||
echo "Computing groundtruth with filter"
|
|
||||||
dist/bin/compute_groundtruth_for_filters --data_type float --universal_label 0 --filter_label 1 --dist_fn l2 --base_file data/index_rand_stream.after-streaming-act4000-cons2000-max10000.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_rand_base-act4000-cons2000-max10000_1 --label_file data/index_rand_stream.after-streaming-act4000-cons2000-max10000_raw_labels.txt --tags_file data/index_rand_stream.after-streaming-act4000-cons2000-max10000.tags
|
|
||||||
echo "Searching with filter"
|
|
||||||
dist/bin/search_memory_index --data_type float --dist_fn l2 --filter_label 1 --fail_if_recall_below 40 --index_path_prefix data/index_rand_stream.after-streaming-act4000-cons2000-max10000 --result_path data/res_stream --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_rand_base-act4000-cons2000-max10000_1 -K 10 -L 20 40 60 80 100 150 -T 64 --dynamic true --tags 1
|
|
||||||
|
|
||||||
echo "Computing groundtruth w/o filter"
|
|
||||||
dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/index_rand_stream.after-streaming-act4000-cons2000-max10000.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_rand_base-act4000-cons2000-max10000
|
|
||||||
echo "Searching without filter"
|
|
||||||
dist/bin/search_memory_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_rand_stream.after-streaming-act4000-cons2000-max10000 --result_path res_stream --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_rand_base-act4000-cons2000-max10000 -K 10 -L 20 40 60 80 100 -T 64
|
|
||||||
|
|
||||||
- name: Test Insert Delete Consolidate (float) with labels (zipf distributed)
|
|
||||||
run: |
|
|
||||||
dist/bin/test_insert_deletes_consolidate --data_type float --dist_fn l2 --universal_label 0 --label_file data/zipf_labels_50_10K.txt --FilteredLbuild 70 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/index_zipf_ins_del -R 64 -L 10 --alpha 1.2 --points_to_skip 0 --max_points_to_insert 7500 --beginning_index_size 0 --points_per_checkpoint 1000 --checkpoints_per_snapshot 0 --points_to_delete_from_beginning 2500 --start_deletes_after 5000 --do_concurrent true --start_point_norm 3.2 --unique_labels_supported 51
|
|
||||||
|
|
||||||
echo "Computing groundtruth with filter"
|
|
||||||
dist/bin/compute_groundtruth_for_filters --data_type float --filter_label 5 --universal_label 0 --dist_fn l2 --base_file data/index_zipf_ins_del.after-concurrent-delete-del2500-7500.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_zipf_random10D_1K_wlabel_5 --label_file data/index_zipf_ins_del.after-concurrent-delete-del2500-7500_raw_labels.txt --tags_file data/index_zipf_ins_del.after-concurrent-delete-del2500-7500.tags
|
|
||||||
echo "Searching with filter"
|
|
||||||
dist/bin/search_memory_index --data_type float --dist_fn l2 --filter_label 5 --fail_if_recall_below 10 --index_path_prefix data/index_zipf_ins_del.after-concurrent-delete-del2500-7500 --result_path data/res_zipf_stream --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_zipf_random10D_1K_wlabel_5 -K 10 -L 20 40 60 80 100 150 -T 64 --dynamic true --tags 1
|
|
||||||
|
|
||||||
echo "Computing groundtruth w/o filter"
|
|
||||||
dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/index_zipf_ins_del.after-concurrent-delete-del2500-7500.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_zipf_random10D_1K
|
|
||||||
echo "Searching without filter"
|
|
||||||
dist/bin/search_memory_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_zipf_ins_del.after-concurrent-delete-del2500-7500 --result_path res_stream --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_zipf_random10D_1K -K 10 -L 20 40 60 80 100 -T 64
|
|
||||||
|
|
||||||
- name: Test Insert Delete Consolidate (float) with labels (random distributed)
|
|
||||||
run: |
|
|
||||||
dist/bin/test_insert_deletes_consolidate --data_type float --dist_fn l2 --universal_label 0 --label_file data/rand_labels_50_10K.txt --FilteredLbuild 70 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/index_rand_ins_del -R 64 -L 10 --alpha 1.2 --points_to_skip 0 --max_points_to_insert 7500 --beginning_index_size 0 --points_per_checkpoint 1000 --checkpoints_per_snapshot 0 --points_to_delete_from_beginning 2500 --start_deletes_after 5000 --do_concurrent true --start_point_norm 3.2 --unique_labels_supported 51
|
|
||||||
|
|
||||||
echo "Computing groundtruth with filter"
|
|
||||||
dist/bin/compute_groundtruth_for_filters --data_type float --filter_label 5 --universal_label 0 --dist_fn l2 --base_file data/index_rand_ins_del.after-concurrent-delete-del2500-7500.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_rand_random10D_1K_wlabel_5 --label_file data/index_rand_ins_del.after-concurrent-delete-del2500-7500_raw_labels.txt --tags_file data/index_rand_ins_del.after-concurrent-delete-del2500-7500.tags
|
|
||||||
echo "Searching with filter"
|
|
||||||
dist/bin/search_memory_index --data_type float --dist_fn l2 --filter_label 5 --fail_if_recall_below 40 --index_path_prefix data/index_rand_ins_del.after-concurrent-delete-del2500-7500 --result_path data/res_rand_stream --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_rand_random10D_1K_wlabel_5 -K 10 -L 20 40 60 80 100 150 -T 64 --dynamic true --tags 1
|
|
||||||
|
|
||||||
echo "Computing groundtruth w/o filter"
|
|
||||||
dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/index_rand_ins_del.after-concurrent-delete-del2500-7500.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_rand_random10D_1K
|
|
||||||
echo "Searching without filter"
|
|
||||||
dist/bin/search_memory_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_rand_ins_del.after-concurrent-delete-del2500-7500 --result_path res_stream --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_rand_random10D_1K -K 10 -L 20 40 60 80 100 -T 64
|
|
||||||
|
|
||||||
- name: upload data and bin
|
|
||||||
uses: actions/upload-artifact@v4
|
|
||||||
with:
|
|
||||||
name: dynamic-labels-${{matrix.os}}
|
|
||||||
path: |
|
|
||||||
./dist/**
|
|
||||||
./data/**
|
|
||||||
@@ -1,75 +0,0 @@
|
|||||||
name: Dynamic
|
|
||||||
on: [workflow_call]
|
|
||||||
jobs:
|
|
||||||
acceptance-tests-dynamic:
|
|
||||||
name: Dynamic
|
|
||||||
strategy:
|
|
||||||
fail-fast: false
|
|
||||||
matrix:
|
|
||||||
os: [ubuntu-latest, windows-2019, windows-latest]
|
|
||||||
runs-on: ${{matrix.os}}
|
|
||||||
defaults:
|
|
||||||
run:
|
|
||||||
shell: bash
|
|
||||||
steps:
|
|
||||||
- name: Checkout repository
|
|
||||||
if: ${{ runner.os == 'Linux' }}
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
with:
|
|
||||||
fetch-depth: 1
|
|
||||||
- name: Checkout repository
|
|
||||||
if: ${{ runner.os == 'Windows' }}
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
with:
|
|
||||||
fetch-depth: 1
|
|
||||||
submodules: true
|
|
||||||
- name: DiskANN Build CLI Applications
|
|
||||||
uses: ./.github/actions/build
|
|
||||||
|
|
||||||
- name: Generate Data
|
|
||||||
uses: ./.github/actions/generate-random
|
|
||||||
|
|
||||||
- name: test a streaming index (float)
|
|
||||||
run: |
|
|
||||||
dist/bin/test_streaming_scenario --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/index_stream -R 64 -L 600 --alpha 1.2 --insert_threads 4 --consolidate_threads 4 --max_points_to_insert 10000 --active_window 4000 --consolidate_interval 2000 --start_point_norm 3.2
|
|
||||||
dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/index_stream.after-streaming-act4000-cons2000-max10000.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_base-act4000-cons2000-max10000 --tags_file data/index_stream.after-streaming-act4000-cons2000-max10000.tags
|
|
||||||
dist/bin/search_memory_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_stream.after-streaming-act4000-cons2000-max10000 --result_path data/res_stream --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_base-act4000-cons2000-max10000 -K 10 -L 20 40 60 80 100 -T 64 --dynamic true --tags 1
|
|
||||||
- name: test a streaming index (int8)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/test_streaming_scenario --data_type int8 --dist_fn l2 --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/index_stream -R 64 -L 600 --alpha 1.2 --insert_threads 4 --consolidate_threads 4 --max_points_to_insert 10000 --active_window 4000 --consolidate_interval 2000 --start_point_norm 200
|
|
||||||
dist/bin/compute_groundtruth --data_type int8 --dist_fn l2 --base_file data/index_stream.after-streaming-act4000-cons2000-max10000.data --query_file data/rand_int8_10D_1K_norm50.0.bin --K 100 --gt_file data/gt100_base-act4000-cons2000-max10000 --tags_file data/index_stream.after-streaming-act4000-cons2000-max10000.tags
|
|
||||||
dist/bin/search_memory_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_stream.after-streaming-act4000-cons2000-max10000 --result_path res_stream --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/gt100_base-act4000-cons2000-max10000 -K 10 -L 20 40 60 80 100 -T 64 --dynamic true --tags 1
|
|
||||||
- name: test a streaming index
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/test_streaming_scenario --data_type uint8 --dist_fn l2 --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/index_stream -R 64 -L 600 --alpha 1.2 --insert_threads 4 --consolidate_threads 4 --max_points_to_insert 10000 --active_window 4000 --consolidate_interval 2000 --start_point_norm 200
|
|
||||||
dist/bin/compute_groundtruth --data_type uint8 --dist_fn l2 --base_file data/index_stream.after-streaming-act4000-cons2000-max10000.data --query_file data/rand_uint8_10D_1K_norm50.0.bin --K 100 --gt_file data/gt100_base-act4000-cons2000-max10000 --tags_file data/index_stream.after-streaming-act4000-cons2000-max10000.tags
|
|
||||||
dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_stream.after-streaming-act4000-cons2000-max10000 --result_path data/res_stream --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/gt100_base-act4000-cons2000-max10000 -K 10 -L 20 40 60 80 100 -T 64 --dynamic true --tags 1
|
|
||||||
|
|
||||||
- name: build and search an incremental index (float)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/test_insert_deletes_consolidate --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/index_ins_del -R 64 -L 300 --alpha 1.2 -T 8 --points_to_skip 0 --max_points_to_insert 7500 --beginning_index_size 0 --points_per_checkpoint 1000 --checkpoints_per_snapshot 0 --points_to_delete_from_beginning 2500 --start_deletes_after 5000 --do_concurrent true --start_point_norm 3.2;
|
|
||||||
dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/index_ins_del.after-concurrent-delete-del2500-7500.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_random10D_1K-conc-2500-7500 --tags_file data/index_ins_del.after-concurrent-delete-del2500-7500.tags
|
|
||||||
dist/bin/search_memory_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_ins_del.after-concurrent-delete-del2500-7500 --result_path data/res_ins_del --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_random10D_1K-conc-2500-7500 -K 10 -L 20 40 60 80 100 -T 8 --dynamic true --tags 1
|
|
||||||
- name: build and search an incremental index (int8)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/test_insert_deletes_consolidate --data_type int8 --dist_fn l2 --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/index_ins_del -R 64 -L 300 --alpha 1.2 -T 8 --points_to_skip 0 --max_points_to_insert 7500 --beginning_index_size 0 --points_per_checkpoint 1000 --checkpoints_per_snapshot 0 --points_to_delete_from_beginning 2500 --start_deletes_after 5000 --do_concurrent true --start_point_norm 200
|
|
||||||
dist/bin/compute_groundtruth --data_type int8 --dist_fn l2 --base_file data/index_ins_del.after-concurrent-delete-del2500-7500.data --query_file data/rand_int8_10D_1K_norm50.0.bin --K 100 --gt_file data/gt100_random10D_1K-conc-2500-7500 --tags_file data/index_ins_del.after-concurrent-delete-del2500-7500.tags
|
|
||||||
dist/bin/search_memory_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_ins_del.after-concurrent-delete-del2500-7500 --result_path data/res_ins_del --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/gt100_random10D_1K-conc-2500-7500 -K 10 -L 20 40 60 80 100 -T 8 --dynamic true --tags 1
|
|
||||||
- name: build and search an incremental index (uint8)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/test_insert_deletes_consolidate --data_type uint8 --dist_fn l2 --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/index_ins_del -R 64 -L 300 --alpha 1.2 -T 8 --points_to_skip 0 --max_points_to_insert 7500 --beginning_index_size 0 --points_per_checkpoint 1000 --checkpoints_per_snapshot 0 --points_to_delete_from_beginning 2500 --start_deletes_after 5000 --do_concurrent true --start_point_norm 200
|
|
||||||
dist/bin/compute_groundtruth --data_type uint8 --dist_fn l2 --base_file data/index_ins_del.after-concurrent-delete-del2500-7500.data --query_file data/rand_uint8_10D_1K_norm50.0.bin --K 100 --gt_file data/gt100_random10D_10K-conc-2500-7500 --tags_file data/index_ins_del.after-concurrent-delete-del2500-7500.tags
|
|
||||||
dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_ins_del.after-concurrent-delete-del2500-7500 --result_path data/res_ins_del --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/gt100_random10D_10K-conc-2500-7500 -K 10 -L 20 40 60 80 100 -T 8 --dynamic true --tags 1
|
|
||||||
|
|
||||||
- name: upload data and bin
|
|
||||||
uses: actions/upload-artifact@v4
|
|
||||||
with:
|
|
||||||
name: dynamic-${{matrix.os}}
|
|
||||||
path: |
|
|
||||||
./dist/**
|
|
||||||
./data/**
|
|
||||||
@@ -1,81 +0,0 @@
|
|||||||
name: In-Memory Without PQ
|
|
||||||
on: [workflow_call]
|
|
||||||
jobs:
|
|
||||||
acceptance-tests-mem-no-pq:
|
|
||||||
name: In-Mem, Without PQ
|
|
||||||
strategy:
|
|
||||||
fail-fast: false
|
|
||||||
matrix:
|
|
||||||
os: [ubuntu-latest, windows-2019, windows-latest]
|
|
||||||
runs-on: ${{matrix.os}}
|
|
||||||
defaults:
|
|
||||||
run:
|
|
||||||
shell: bash
|
|
||||||
steps:
|
|
||||||
- name: Checkout repository
|
|
||||||
if: ${{ runner.os == 'Linux' }}
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
with:
|
|
||||||
fetch-depth: 1
|
|
||||||
- name: Checkout repository
|
|
||||||
if: ${{ runner.os == 'Windows' }}
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
with:
|
|
||||||
fetch-depth: 1
|
|
||||||
submodules: true
|
|
||||||
- name: DiskANN Build CLI Applications
|
|
||||||
uses: ./.github/actions/build
|
|
||||||
|
|
||||||
- name: Generate Data
|
|
||||||
uses: ./.github/actions/generate-random
|
|
||||||
|
|
||||||
- name: build and search in-memory index with L2 metrics (float)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_memory_index --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/index_l2_rand_float_10D_10K_norm1.0
|
|
||||||
dist/bin/search_memory_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_float_10D_10K_norm1.0 --query_file data/rand_float_10D_1K_norm1.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 -L 16 32
|
|
||||||
- name: build and search in-memory index with L2 metrics (int8)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_memory_index --data_type int8 --dist_fn l2 --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/index_l2_rand_int8_10D_10K_norm50.0
|
|
||||||
dist/bin/search_memory_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_int8_10D_10K_norm50.0 --query_file data/rand_int8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 16 32
|
|
||||||
- name: build and search in-memory index with L2 metrics (uint8)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_memory_index --data_type uint8 --dist_fn l2 --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50.0
|
|
||||||
dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50.0 --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 16 32
|
|
||||||
|
|
||||||
- name: Searching with fast_l2 distance function (float)
|
|
||||||
if: runner.os != 'Windows' && (success() || failure())
|
|
||||||
run: |
|
|
||||||
dist/bin/search_memory_index --data_type float --dist_fn fast_l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_float_10D_10K_norm1.0 --query_file data/rand_float_10D_1K_norm1.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 -L 16 32
|
|
||||||
|
|
||||||
- name: build and search in-memory index with MIPS metric (float)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_memory_index --data_type float --dist_fn mips --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/index_mips_rand_float_10D_10K_norm1.0
|
|
||||||
dist/bin/search_memory_index --data_type float --dist_fn mips --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_float_10D_10K_norm1.0 --query_file data/rand_float_10D_1K_norm1.0.bin --recall_at 10 --result_path temp --gt_file data/mips_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 -L 16 32
|
|
||||||
|
|
||||||
- name: build and search in-memory index with cosine metric (float)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_memory_index --data_type float --dist_fn cosine --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/index_cosine_rand_float_10D_10K_norm1.0
|
|
||||||
dist/bin/search_memory_index --data_type float --dist_fn cosine --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_float_10D_10K_norm1.0 --query_file data/rand_float_10D_1K_norm1.0.bin --recall_at 10 --result_path temp --gt_file data/cosine_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 -L 16 32
|
|
||||||
- name: build and search in-memory index with cosine metric (int8)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_memory_index --data_type int8 --dist_fn cosine --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/index_cosine_rand_int8_10D_10K_norm50.0
|
|
||||||
dist/bin/search_memory_index --data_type int8 --dist_fn cosine --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_int8_10D_10K_norm50.0 --query_file data/rand_int8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/cosine_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 16 32
|
|
||||||
- name: build and search in-memory index with cosine metric
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_memory_index --data_type uint8 --dist_fn cosine --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/index_cosine_rand_uint8_10D_10K_norm50.0
|
|
||||||
dist/bin/search_memory_index --data_type uint8 --dist_fn cosine --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50.0 --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/cosine_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 16 32
|
|
||||||
|
|
||||||
- name: upload data and bin
|
|
||||||
uses: actions/upload-artifact@v4
|
|
||||||
with:
|
|
||||||
name: in-memory-no-pq-${{matrix.os}}
|
|
||||||
path: |
|
|
||||||
./dist/**
|
|
||||||
./data/**
|
|
||||||
@@ -1,56 +0,0 @@
|
|||||||
name: In-Memory With PQ
|
|
||||||
on: [workflow_call]
|
|
||||||
jobs:
|
|
||||||
acceptance-tests-mem-pq:
|
|
||||||
name: In-Mem, PQ
|
|
||||||
strategy:
|
|
||||||
fail-fast: false
|
|
||||||
matrix:
|
|
||||||
os: [ubuntu-latest, windows-2019, windows-latest]
|
|
||||||
runs-on: ${{matrix.os}}
|
|
||||||
defaults:
|
|
||||||
run:
|
|
||||||
shell: bash
|
|
||||||
steps:
|
|
||||||
- name: Checkout repository
|
|
||||||
if: ${{ runner.os == 'Linux' }}
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
with:
|
|
||||||
fetch-depth: 1
|
|
||||||
- name: Checkout repository
|
|
||||||
if: ${{ runner.os == 'Windows' }}
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
with:
|
|
||||||
fetch-depth: 1
|
|
||||||
submodules: true
|
|
||||||
- name: DiskANN Build CLI Applications
|
|
||||||
uses: ./.github/actions/build
|
|
||||||
|
|
||||||
- name: Generate Data
|
|
||||||
uses: ./.github/actions/generate-random
|
|
||||||
|
|
||||||
- name: build and search in-memory index with L2 metric with PQ based distance comparisons (float)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_memory_index --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/index_l2_rand_float_10D_10K_norm1.0_buildpq5 --build_PQ_bytes 5
|
|
||||||
dist/bin/search_memory_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_float_10D_10K_norm1.0_buildpq5 --query_file data/rand_float_10D_1K_norm1.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 -L 16 32
|
|
||||||
|
|
||||||
- name: build and search in-memory index with L2 metrics with PQ base distance comparisons (int8)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_memory_index --data_type int8 --dist_fn l2 --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/index_l2_rand_int8_10D_10K_norm50.0_buildpq5 --build_PQ_bytes 5
|
|
||||||
dist/bin/search_memory_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_int8_10D_10K_norm50.0_buildpq5 --query_file data/rand_int8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 16 32
|
|
||||||
|
|
||||||
- name: build and search in-memory index with L2 metrics with PQ base distance comparisons (uint8)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_memory_index --data_type uint8 --dist_fn l2 --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50.0_buildpq5 --build_PQ_bytes 5
|
|
||||||
dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50.0_buildpq5 --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 16 32
|
|
||||||
|
|
||||||
- name: upload data and bin
|
|
||||||
uses: actions/upload-artifact@v4
|
|
||||||
with:
|
|
||||||
name: in-memory-pq-${{matrix.os}}
|
|
||||||
path: |
|
|
||||||
./dist/**
|
|
||||||
./data/**
|
|
||||||
@@ -1,120 +0,0 @@
|
|||||||
name: Labels
|
|
||||||
on: [workflow_call]
|
|
||||||
jobs:
|
|
||||||
acceptance-tests-labels:
|
|
||||||
name: Labels
|
|
||||||
strategy:
|
|
||||||
fail-fast: false
|
|
||||||
matrix:
|
|
||||||
os: [ubuntu-latest, windows-2019, windows-latest]
|
|
||||||
runs-on: ${{matrix.os}}
|
|
||||||
defaults:
|
|
||||||
run:
|
|
||||||
shell: bash
|
|
||||||
steps:
|
|
||||||
- name: Checkout repository
|
|
||||||
if: ${{ runner.os == 'Linux' }}
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
with:
|
|
||||||
fetch-depth: 1
|
|
||||||
- name: Checkout repository
|
|
||||||
if: ${{ runner.os == 'Windows' }}
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
with:
|
|
||||||
fetch-depth: 1
|
|
||||||
submodules: true
|
|
||||||
- name: DiskANN Build CLI Applications
|
|
||||||
uses: ./.github/actions/build
|
|
||||||
|
|
||||||
- name: Generate Data
|
|
||||||
uses: ./.github/actions/generate-random
|
|
||||||
|
|
||||||
- name: Generate Labels
|
|
||||||
run: |
|
|
||||||
echo "Generating synthetic labels and computing ground truth for filtered search with universal label"
|
|
||||||
dist/bin/generate_synthetic_labels --num_labels 50 --num_points 10000 --output_file data/rand_labels_50_10K.txt --distribution_type random
|
|
||||||
dist/bin/compute_groundtruth_for_filters --data_type uint8 --dist_fn l2 --universal_label 0 --filter_label 10 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --label_file data/rand_labels_50_10K.txt --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --K 100
|
|
||||||
dist/bin/compute_groundtruth_for_filters --data_type uint8 --dist_fn mips --universal_label 0 --filter_label 10 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --label_file data/rand_labels_50_10K.txt --gt_file data/mips_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --K 100
|
|
||||||
dist/bin/compute_groundtruth_for_filters --data_type uint8 --dist_fn cosine --universal_label 0 --filter_label 10 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --label_file data/rand_labels_50_10K.txt --gt_file data/cosine_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --K 100
|
|
||||||
|
|
||||||
echo "Generating synthetic labels with a zipf distribution and computing ground truth for filtered search with universal label"
|
|
||||||
dist/bin/generate_synthetic_labels --num_labels 50 --num_points 10000 --output_file data/zipf_labels_50_10K.txt --distribution_type zipf
|
|
||||||
dist/bin/compute_groundtruth_for_filters --data_type uint8 --dist_fn l2 --universal_label 0 --filter_label 5 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --K 100
|
|
||||||
dist/bin/compute_groundtruth_for_filters --data_type uint8 --dist_fn mips --universal_label 0 --filter_label 5 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --gt_file data/mips_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --K 100
|
|
||||||
dist/bin/compute_groundtruth_for_filters --data_type uint8 --dist_fn cosine --universal_label 0 --filter_label 5 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --gt_file data/cosine_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --K 100
|
|
||||||
|
|
||||||
echo "Generating synthetic labels and computing ground truth for filtered search without a universal label"
|
|
||||||
dist/bin/compute_groundtruth_for_filters --data_type uint8 --dist_fn l2 --filter_label 5 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel_nouniversal --K 100
|
|
||||||
dist/bin/generate_synthetic_labels --num_labels 10 --num_points 1000 --output_file data/query_labels_1K.txt --distribution_type one_per_point
|
|
||||||
dist/bin/compute_groundtruth_for_filters --data_type uint8 --dist_fn l2 --universal_label 0 --filter_label_file data/query_labels_1K.txt --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --gt_file data/combined_l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --K 100
|
|
||||||
|
|
||||||
- name: build and search in-memory index with labels using L2 and Cosine metrics (random distributed labels)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_memory_index --data_type uint8 --dist_fn l2 --FilteredLbuild 90 --universal_label 0 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/rand_labels_50_10K.txt --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50_wlabel
|
|
||||||
dist/bin/build_memory_index --data_type uint8 --dist_fn cosine --FilteredLbuild 90 --universal_label 0 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/rand_labels_50_10K.txt --index_path_prefix data/index_cosine_rand_uint8_10D_10K_norm50_wlabel
|
|
||||||
dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --filter_label 10 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -L 16 32
|
|
||||||
dist/bin/search_memory_index --data_type uint8 --dist_fn cosine --filter_label 10 --fail_if_recall_below 70 --index_path_prefix data/index_cosine_rand_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/cosine_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -L 16 32
|
|
||||||
|
|
||||||
echo "Searching without filters"
|
|
||||||
dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 32 64
|
|
||||||
dist/bin/search_memory_index --data_type uint8 --dist_fn cosine --fail_if_recall_below 70 --index_path_prefix data/index_cosine_rand_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/cosine_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 32 64
|
|
||||||
|
|
||||||
- name: build and search disk index with labels using L2 and Cosine metrics (random distributed labels)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_disk_index --data_type uint8 --dist_fn l2 --universal_label 0 --FilteredLbuild 90 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/rand_labels_50_10K.txt --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50_wlabel -R 32 -L 5 -B 0.00003 -M 1
|
|
||||||
dist/bin/search_disk_index --data_type uint8 --dist_fn l2 --filter_label 10 --fail_if_recall_below 50 --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50_wlabel --result_path temp --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
|
||||||
- name: build and search in-memory index with labels using L2 and Cosine metrics (zipf distributed labels)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_memory_index --data_type uint8 --dist_fn l2 --FilteredLbuild 90 --universal_label 0 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --index_path_prefix data/index_l2_zipf_uint8_10D_10K_norm50_wlabel
|
|
||||||
dist/bin/build_memory_index --data_type uint8 --dist_fn cosine --FilteredLbuild 90 --universal_label 0 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --index_path_prefix data/index_cosine_zipf_uint8_10D_10K_norm50_wlabel
|
|
||||||
dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --filter_label 5 --fail_if_recall_below 70 --index_path_prefix data/index_l2_zipf_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -L 16 32
|
|
||||||
dist/bin/search_memory_index --data_type uint8 --dist_fn cosine --filter_label 5 --fail_if_recall_below 70 --index_path_prefix data/index_cosine_zipf_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/cosine_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -L 16 32
|
|
||||||
|
|
||||||
echo "Searching without filters"
|
|
||||||
dist/bin/compute_groundtruth --data_type uint8 --dist_fn l2 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100
|
|
||||||
dist/bin/compute_groundtruth --data_type uint8 --dist_fn cosine --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/cosine_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100
|
|
||||||
dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_zipf_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 32 64
|
|
||||||
dist/bin/search_memory_index --data_type uint8 --dist_fn cosine --fail_if_recall_below 70 --index_path_prefix data/index_cosine_zipf_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/cosine_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 32 64
|
|
||||||
|
|
||||||
- name: build and search disk index with labels using L2 and Cosine metrics (zipf distributed labels)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_disk_index --data_type uint8 --dist_fn l2 --universal_label 0 --FilteredLbuild 90 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --index_path_prefix data/disk_index_l2_zipf_uint8_10D_10K_norm50_wlabel -R 32 -L 5 -B 0.00003 -M 1
|
|
||||||
dist/bin/search_disk_index --data_type uint8 --dist_fn l2 --filter_label 5 --fail_if_recall_below 50 --index_path_prefix data/disk_index_l2_zipf_uint8_10D_10K_norm50_wlabel --result_path temp --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
|
||||||
|
|
||||||
- name : build and search in-memory and disk index (without universal label, zipf distributed)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_memory_index --data_type uint8 --dist_fn l2 --FilteredLbuild 90 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --index_path_prefix data/index_l2_zipf_uint8_10D_10K_norm50_wlabel_nouniversal
|
|
||||||
dist/bin/build_disk_index --data_type uint8 --dist_fn l2 --FilteredLbuild 90 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --index_path_prefix data/disk_index_l2_zipf_uint8_10D_10K_norm50_wlabel_nouniversal -R 32 -L 5 -B 0.00003 -M 1
|
|
||||||
dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --filter_label 5 --fail_if_recall_below 70 --index_path_prefix data/index_l2_zipf_uint8_10D_10K_norm50_wlabel_nouniversal --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel_nouniversal -L 16 32
|
|
||||||
dist/bin/search_disk_index --data_type uint8 --dist_fn l2 --filter_label 5 --index_path_prefix data/disk_index_l2_zipf_uint8_10D_10K_norm50_wlabel_nouniversal --result_path temp --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel_nouniversal --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
|
||||||
- name: Generate combined GT for each query with a separate label and search
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_memory_index --data_type uint8 --dist_fn l2 --FilteredLbuild 90 --universal_label 0 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --index_path_prefix data/index_l2_zipf_uint8_10D_10K_norm50_wlabel
|
|
||||||
dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --query_filters_file data/query_labels_1K.txt --fail_if_recall_below 70 --index_path_prefix data/index_l2_zipf_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/combined_l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -L 16 32
|
|
||||||
- name: build and search in-memory index with pq_dist of 5 with 10 dimensions
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_memory_index --data_type uint8 --dist_fn l2 --FilteredLbuild 90 --universal_label 0 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/rand_labels_50_10K.txt --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50_wlabel --build_PQ_bytes 5
|
|
||||||
dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --filter_label 10 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -L 16 32
|
|
||||||
- name: Build and search stitched vamana with random and zipf distributed labels
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_stitched_index --num_threads 48 --data_type uint8 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/rand_labels_50_10K.txt -R 32 -L 100 --alpha 1.2 --stitched_R 64 --index_path_prefix data/stit_rand_32_100_64_new --universal_label 0
|
|
||||||
dist/bin/build_stitched_index --num_threads 48 --data_type uint8 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt -R 32 -L 100 --alpha 1.2 --stitched_R 64 --index_path_prefix data/stit_zipf_32_100_64_new --universal_label 0
|
|
||||||
dist/bin/search_memory_index --num_threads 48 --data_type uint8 --dist_fn l2 --filter_label 10 --index_path_prefix data/stit_rand_32_100_64_new --query_file data/rand_uint8_10D_1K_norm50.0.bin --result_path data/rand_stit_96_10_90_new --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -K 10 -L 16 32 150
|
|
||||||
dist/bin/search_memory_index --num_threads 48 --data_type uint8 --dist_fn l2 --filter_label 5 --index_path_prefix data/stit_zipf_32_100_64_new --query_file data/rand_uint8_10D_1K_norm50.0.bin --result_path data/zipf_stit_96_10_90_new --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -K 10 -L 16 32 150
|
|
||||||
|
|
||||||
- name: upload data and bin
|
|
||||||
if: success() || failure()
|
|
||||||
uses: actions/upload-artifact@v4
|
|
||||||
with:
|
|
||||||
name: labels-${{matrix.os}}
|
|
||||||
path: |
|
|
||||||
./dist/**
|
|
||||||
./data/**
|
|
||||||
@@ -1,60 +0,0 @@
|
|||||||
name: Disk With PQ
|
|
||||||
on: [workflow_call]
|
|
||||||
jobs:
|
|
||||||
acceptance-tests-disk-pq:
|
|
||||||
name: Disk, PQ
|
|
||||||
strategy:
|
|
||||||
fail-fast: false
|
|
||||||
matrix:
|
|
||||||
os: [ubuntu-latest, windows-2019, windows-latest]
|
|
||||||
runs-on: ${{matrix.os}}
|
|
||||||
defaults:
|
|
||||||
run:
|
|
||||||
shell: bash
|
|
||||||
steps:
|
|
||||||
- name: Checkout repository
|
|
||||||
if: ${{ runner.os == 'Linux' }}
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
with:
|
|
||||||
fetch-depth: 1
|
|
||||||
- name: Checkout repository
|
|
||||||
if: ${{ runner.os == 'Windows' }}
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
with:
|
|
||||||
fetch-depth: 1
|
|
||||||
submodules: true
|
|
||||||
- name: DiskANN Build CLI Applications
|
|
||||||
uses: ./.github/actions/build
|
|
||||||
|
|
||||||
- name: Generate Data
|
|
||||||
uses: ./.github/actions/generate-high-dim-random
|
|
||||||
|
|
||||||
- name: build and search disk index (1020D, one shot graph build, L2, no diskPQ) (float)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_disk_index --data_type float --dist_fn l2 --data_path data/rand_float_1020D_5K_norm1.0.bin --index_path_prefix data/disk_index_l2_rand_float_1020D_5K_norm1.0_diskfull_oneshot -R 32 -L 500 -B 0.003 -M 1
|
|
||||||
dist/bin/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_float_1020D_5K_norm1.0_diskfull_oneshot --result_path /tmp/res --query_file data/rand_float_1020D_1K_norm1.0.bin --gt_file data/l2_rand_float_1020D_5K_norm1.0_1020D_1K_norm1.0_gt100 --recall_at 5 -L 250 -W 2 --num_nodes_to_cache 100 -T 16
|
|
||||||
#- name: build and search disk index (1024D, one shot graph build, L2, no diskPQ) (float)
|
|
||||||
# if: success() || failure()
|
|
||||||
# run: |
|
|
||||||
# dist/bin/build_disk_index --data_type float --dist_fn l2 --data_path data/rand_float_1024D_5K_norm1.0.bin --index_path_prefix data/disk_index_l2_rand_float_1024D_5K_norm1.0_diskfull_oneshot -R 32 -L 500 -B 0.003 -M 1
|
|
||||||
# dist/bin/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_float_1024D_5K_norm1.0_diskfull_oneshot --result_path /tmp/res --query_file data/rand_float_1024D_1K_norm1.0.bin --gt_file data/l2_rand_float_1024D_5K_norm1.0_1024D_1K_norm1.0_gt100 --recall_at 5 -L 250 -W 2 --num_nodes_to_cache 100 -T 16
|
|
||||||
- name: build and search disk index (1536D, one shot graph build, L2, no diskPQ) (float)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_disk_index --data_type float --dist_fn l2 --data_path data/rand_float_1536D_5K_norm1.0.bin --index_path_prefix data/disk_index_l2_rand_float_1536D_5K_norm1.0_diskfull_oneshot -R 32 -L 500 -B 0.003 -M 1
|
|
||||||
dist/bin/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_float_1536D_5K_norm1.0_diskfull_oneshot --result_path /tmp/res --query_file data/rand_float_1536D_1K_norm1.0.bin --gt_file data/l2_rand_float_1536D_5K_norm1.0_1536D_1K_norm1.0_gt100 --recall_at 5 -L 250 -W 2 --num_nodes_to_cache 100 -T 16
|
|
||||||
|
|
||||||
- name: build and search disk index (4096D, one shot graph build, L2, no diskPQ) (int8)
|
|
||||||
if: success() || failure()
|
|
||||||
run: |
|
|
||||||
dist/bin/build_disk_index --data_type int8 --dist_fn l2 --data_path data/rand_int8_4096D_5K_norm1.0.bin --index_path_prefix data/disk_index_l2_rand_int8_4096D_5K_norm1.0_diskfull_oneshot -R 32 -L 500 -B 0.003 -M 1
|
|
||||||
dist/bin/search_disk_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_int8_4096D_5K_norm1.0_diskfull_oneshot --result_path /tmp/res --query_file data/rand_int8_4096D_1K_norm1.0.bin --gt_file data/l2_rand_int8_4096D_5K_norm1.0_4096D_1K_norm1.0_gt100 --recall_at 5 -L 250 -W 2 --num_nodes_to_cache 100 -T 16
|
|
||||||
|
|
||||||
- name: upload data and bin
|
|
||||||
uses: actions/upload-artifact@v4
|
|
||||||
with:
|
|
||||||
name: multi-sector-disk-pq-${{matrix.os}}
|
|
||||||
path: |
|
|
||||||
./dist/**
|
|
||||||
./data/**
|
|
||||||
@@ -1,26 +0,0 @@
|
|||||||
name: DiskANN Nightly Performance Metrics
|
|
||||||
on:
|
|
||||||
schedule:
|
|
||||||
- cron: "41 14 * * *" # 14:41 UTC, 7:41 PDT, 8:41 PST, 08:11 IST
|
|
||||||
jobs:
|
|
||||||
perf-test:
|
|
||||||
name: Run Perf Test from main
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- name: Checkout repository
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
with:
|
|
||||||
fetch-depth: 1
|
|
||||||
- name: Build Perf Container
|
|
||||||
run: |
|
|
||||||
docker build --build-arg GIT_COMMIT_ISH="$GITHUB_SHA" -t perf -f scripts/perf/Dockerfile scripts
|
|
||||||
- name: Performance Tests
|
|
||||||
run: |
|
|
||||||
mkdir metrics
|
|
||||||
docker run -v ./metrics:/app/logs perf &> ./metrics/combined_stdouterr.log
|
|
||||||
- name: Upload Metrics Logs
|
|
||||||
uses: actions/upload-artifact@v4
|
|
||||||
with:
|
|
||||||
name: metrics-${{matrix.os}}
|
|
||||||
path: |
|
|
||||||
./metrics/**
|
|
||||||
@@ -1,35 +0,0 @@
|
|||||||
name: DiskANN Pull Request Build and Test
|
|
||||||
on: [pull_request]
|
|
||||||
jobs:
|
|
||||||
common:
|
|
||||||
strategy:
|
|
||||||
fail-fast: true
|
|
||||||
name: DiskANN Common Build Checks
|
|
||||||
uses: ./.github/workflows/common.yml
|
|
||||||
unit-tests:
|
|
||||||
name: Unit tests
|
|
||||||
uses: ./.github/workflows/unit-tests.yml
|
|
||||||
in-mem-pq:
|
|
||||||
name: In-Memory with PQ
|
|
||||||
uses: ./.github/workflows/in-mem-pq.yml
|
|
||||||
in-mem-no-pq:
|
|
||||||
name: In-Memory without PQ
|
|
||||||
uses: ./.github/workflows/in-mem-no-pq.yml
|
|
||||||
disk-pq:
|
|
||||||
name: Disk with PQ
|
|
||||||
uses: ./.github/workflows/disk-pq.yml
|
|
||||||
multi-sector-disk-pq:
|
|
||||||
name: Multi-sector Disk with PQ
|
|
||||||
uses: ./.github/workflows/multi-sector-disk-pq.yml
|
|
||||||
labels:
|
|
||||||
name: Labels
|
|
||||||
uses: ./.github/workflows/labels.yml
|
|
||||||
dynamic:
|
|
||||||
name: Dynamic
|
|
||||||
uses: ./.github/workflows/dynamic.yml
|
|
||||||
dynamic-labels:
|
|
||||||
name: Dynamic Labels
|
|
||||||
uses: ./.github/workflows/dynamic-labels.yml
|
|
||||||
python:
|
|
||||||
name: Python
|
|
||||||
uses: ./.github/workflows/build-python.yml
|
|
||||||
@@ -1,50 +0,0 @@
|
|||||||
name: DiskANN Push Build
|
|
||||||
on: [push]
|
|
||||||
jobs:
|
|
||||||
common:
|
|
||||||
strategy:
|
|
||||||
fail-fast: true
|
|
||||||
name: DiskANN Common Build Checks
|
|
||||||
uses: ./.github/workflows/common.yml
|
|
||||||
build-documentation:
|
|
||||||
permissions:
|
|
||||||
contents: write
|
|
||||||
strategy:
|
|
||||||
fail-fast: true
|
|
||||||
name: DiskANN Build Documentation
|
|
||||||
uses: ./.github/workflows/build-python-pdoc.yml
|
|
||||||
build:
|
|
||||||
strategy:
|
|
||||||
fail-fast: false
|
|
||||||
matrix:
|
|
||||||
os: [ ubuntu-latest, windows-2019, windows-latest ]
|
|
||||||
name: Build for ${{matrix.os}}
|
|
||||||
runs-on: ${{matrix.os}}
|
|
||||||
defaults:
|
|
||||||
run:
|
|
||||||
shell: bash
|
|
||||||
steps:
|
|
||||||
- name: Checkout repository
|
|
||||||
if: ${{ runner.os == 'Linux' }}
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
with:
|
|
||||||
fetch-depth: 1
|
|
||||||
- name: Checkout repository
|
|
||||||
if: ${{ runner.os == 'Windows' }}
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
with:
|
|
||||||
fetch-depth: 1
|
|
||||||
submodules: true
|
|
||||||
- name: Build diskannpy dependency tree
|
|
||||||
run: |
|
|
||||||
pip install diskannpy pipdeptree
|
|
||||||
echo "dependencies" > dependencies_${{ matrix.os }}.txt
|
|
||||||
pipdeptree >> dependencies_${{ matrix.os }}.txt
|
|
||||||
- name: Archive diskannpy dependencies artifact
|
|
||||||
uses: actions/upload-artifact@v4
|
|
||||||
with:
|
|
||||||
name: dependencies_${{ matrix.os }}
|
|
||||||
path: |
|
|
||||||
dependencies_${{ matrix.os }}.txt
|
|
||||||
- name: DiskANN Build CLI Applications
|
|
||||||
uses: ./.github/actions/build
|
|
||||||
@@ -1,43 +0,0 @@
|
|||||||
name: Build and Release Python Wheels
|
|
||||||
on:
|
|
||||||
release:
|
|
||||||
types: [published]
|
|
||||||
jobs:
|
|
||||||
python-release-wheels:
|
|
||||||
name: Python
|
|
||||||
uses: ./.github/workflows/build-python.yml
|
|
||||||
build-documentation:
|
|
||||||
strategy:
|
|
||||||
fail-fast: true
|
|
||||||
name: DiskANN Build Documentation
|
|
||||||
uses: ./.github/workflows/build-python-pdoc.yml
|
|
||||||
release:
|
|
||||||
permissions:
|
|
||||||
contents: write
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
needs: python-release-wheels
|
|
||||||
steps:
|
|
||||||
- uses: actions/download-artifact@v3
|
|
||||||
with:
|
|
||||||
name: wheels
|
|
||||||
path: dist/
|
|
||||||
- name: Generate SHA256 files for each wheel
|
|
||||||
run: |
|
|
||||||
sha256sum dist/*.whl > checksums.txt
|
|
||||||
cat checksums.txt
|
|
||||||
- uses: actions/setup-python@v3
|
|
||||||
- name: Install twine
|
|
||||||
run: python -m pip install twine
|
|
||||||
- name: Publish with twine
|
|
||||||
env:
|
|
||||||
TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
|
|
||||||
TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
|
|
||||||
run: |
|
|
||||||
twine upload dist/*.whl
|
|
||||||
- name: Update release with SHA256 and Artifacts
|
|
||||||
uses: softprops/action-gh-release@v1
|
|
||||||
with:
|
|
||||||
token: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
files: |
|
|
||||||
dist/*.whl
|
|
||||||
checksums.txt
|
|
||||||
@@ -1,32 +0,0 @@
|
|||||||
name: Unit Tests
|
|
||||||
on: [workflow_call]
|
|
||||||
jobs:
|
|
||||||
acceptance-tests-labels:
|
|
||||||
name: Unit Tests
|
|
||||||
strategy:
|
|
||||||
fail-fast: false
|
|
||||||
matrix:
|
|
||||||
os: [ubuntu-latest, windows-2019, windows-latest]
|
|
||||||
runs-on: ${{matrix.os}}
|
|
||||||
defaults:
|
|
||||||
run:
|
|
||||||
shell: bash
|
|
||||||
steps:
|
|
||||||
- name: Checkout repository
|
|
||||||
if: ${{ runner.os == 'Linux' }}
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
with:
|
|
||||||
fetch-depth: 1
|
|
||||||
- name: Checkout repository
|
|
||||||
if: ${{ runner.os == 'Windows' }}
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
with:
|
|
||||||
fetch-depth: 1
|
|
||||||
submodules: true
|
|
||||||
- name: DiskANN Build CLI Applications
|
|
||||||
uses: ./.github/actions/build
|
|
||||||
|
|
||||||
- name: Run Unit Tests
|
|
||||||
run: |
|
|
||||||
cd build
|
|
||||||
ctest -C Release
|
|
||||||
@@ -1,384 +0,0 @@
|
|||||||
## Ignore Visual Studio temporary files, build results, and
|
|
||||||
## files generated by popular Visual Studio add-ons.
|
|
||||||
##
|
|
||||||
## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore
|
|
||||||
|
|
||||||
# User-specific files
|
|
||||||
*.rsuser
|
|
||||||
*.suo
|
|
||||||
*.user
|
|
||||||
*.userosscache
|
|
||||||
*.sln.docstates
|
|
||||||
|
|
||||||
# User-specific files (MonoDevelop/Xamarin Studio)
|
|
||||||
*.userprefs
|
|
||||||
|
|
||||||
# Mono auto generated files
|
|
||||||
mono_crash.*
|
|
||||||
|
|
||||||
# Build results
|
|
||||||
[Dd]ebug/
|
|
||||||
[Dd]ebugPublic/
|
|
||||||
[Rr]elease/
|
|
||||||
[Rr]eleases/
|
|
||||||
x64/
|
|
||||||
x86/
|
|
||||||
[Aa][Rr][Mm]/
|
|
||||||
[Aa][Rr][Mm]64/
|
|
||||||
bld/
|
|
||||||
[Bb]in/
|
|
||||||
[Oo]bj/
|
|
||||||
[Ll]og/
|
|
||||||
[Ll]ogs/
|
|
||||||
|
|
||||||
# Visual Studio 2015/2017 cache/options directory
|
|
||||||
.vs/
|
|
||||||
# Uncomment if you have tasks that create the project's static files in wwwroot
|
|
||||||
#wwwroot/
|
|
||||||
|
|
||||||
# Visual Studio 2017 auto generated files
|
|
||||||
Generated\ Files/
|
|
||||||
|
|
||||||
# MSTest test Results
|
|
||||||
[Tt]est[Rr]esult*/
|
|
||||||
[Bb]uild[Ll]og.*
|
|
||||||
|
|
||||||
# NUnit
|
|
||||||
*.VisualState.xml
|
|
||||||
TestResult.xml
|
|
||||||
nunit-*.xml
|
|
||||||
|
|
||||||
# Build Results of an ATL Project
|
|
||||||
[Dd]ebugPS/
|
|
||||||
[Rr]eleasePS/
|
|
||||||
dlldata.c
|
|
||||||
|
|
||||||
# Benchmark Results
|
|
||||||
BenchmarkDotNet.Artifacts/
|
|
||||||
|
|
||||||
# .NET Core
|
|
||||||
project.lock.json
|
|
||||||
project.fragment.lock.json
|
|
||||||
artifacts/
|
|
||||||
|
|
||||||
# StyleCop
|
|
||||||
StyleCopReport.xml
|
|
||||||
|
|
||||||
# Files built by Visual Studio
|
|
||||||
*_i.c
|
|
||||||
*_p.c
|
|
||||||
*_h.h
|
|
||||||
*.ilk
|
|
||||||
*.meta
|
|
||||||
*.obj
|
|
||||||
*.iobj
|
|
||||||
*.pch
|
|
||||||
*.pdb
|
|
||||||
*.ipdb
|
|
||||||
*.pgc
|
|
||||||
*.pgd
|
|
||||||
*.rsp
|
|
||||||
*.sbr
|
|
||||||
*.tlb
|
|
||||||
*.tli
|
|
||||||
*.tlh
|
|
||||||
*.tmp
|
|
||||||
*.tmp_proj
|
|
||||||
*_wpftmp.csproj
|
|
||||||
*.log
|
|
||||||
*.vspscc
|
|
||||||
*.vssscc
|
|
||||||
.builds
|
|
||||||
*.pidb
|
|
||||||
*.svclog
|
|
||||||
*.scc
|
|
||||||
|
|
||||||
# Chutzpah Test files
|
|
||||||
_Chutzpah*
|
|
||||||
|
|
||||||
# Visual C++ cache files
|
|
||||||
ipch/
|
|
||||||
*.aps
|
|
||||||
*.ncb
|
|
||||||
*.opendb
|
|
||||||
*.opensdf
|
|
||||||
*.sdf
|
|
||||||
*.cachefile
|
|
||||||
*.VC.db
|
|
||||||
*.VC.VC.opendb
|
|
||||||
|
|
||||||
# Visual Studio profiler
|
|
||||||
*.psess
|
|
||||||
*.vsp
|
|
||||||
*.vspx
|
|
||||||
*.sap
|
|
||||||
|
|
||||||
# Visual Studio Trace Files
|
|
||||||
*.e2e
|
|
||||||
|
|
||||||
# TFS 2012 Local Workspace
|
|
||||||
$tf/
|
|
||||||
|
|
||||||
# Guidance Automation Toolkit
|
|
||||||
*.gpState
|
|
||||||
|
|
||||||
# ReSharper is a .NET coding add-in
|
|
||||||
_ReSharper*/
|
|
||||||
*.[Rr]e[Ss]harper
|
|
||||||
*.DotSettings.user
|
|
||||||
|
|
||||||
# TeamCity is a build add-in
|
|
||||||
_TeamCity*
|
|
||||||
|
|
||||||
# DotCover is a Code Coverage Tool
|
|
||||||
*.dotCover
|
|
||||||
|
|
||||||
# AxoCover is a Code Coverage Tool
|
|
||||||
.axoCover/*
|
|
||||||
!.axoCover/settings.json
|
|
||||||
|
|
||||||
# Visual Studio code coverage results
|
|
||||||
*.coverage
|
|
||||||
*.coveragexml
|
|
||||||
|
|
||||||
# NCrunch
|
|
||||||
_NCrunch_*
|
|
||||||
.*crunch*.local.xml
|
|
||||||
nCrunchTemp_*
|
|
||||||
|
|
||||||
# MightyMoose
|
|
||||||
*.mm.*
|
|
||||||
AutoTest.Net/
|
|
||||||
|
|
||||||
# Web workbench (sass)
|
|
||||||
.sass-cache/
|
|
||||||
|
|
||||||
# Installshield output folder
|
|
||||||
[Ee]xpress/
|
|
||||||
|
|
||||||
# DocProject is a documentation generator add-in
|
|
||||||
DocProject/buildhelp/
|
|
||||||
DocProject/Help/*.HxT
|
|
||||||
DocProject/Help/*.HxC
|
|
||||||
DocProject/Help/*.hhc
|
|
||||||
DocProject/Help/*.hhk
|
|
||||||
DocProject/Help/*.hhp
|
|
||||||
DocProject/Help/Html2
|
|
||||||
DocProject/Help/html
|
|
||||||
|
|
||||||
# Click-Once directory
|
|
||||||
publish/
|
|
||||||
|
|
||||||
# Publish Web Output
|
|
||||||
*.[Pp]ublish.xml
|
|
||||||
*.azurePubxml
|
|
||||||
# Note: Comment the next line if you want to checkin your web deploy settings,
|
|
||||||
# but database connection strings (with potential passwords) will be unencrypted
|
|
||||||
*.pubxml
|
|
||||||
*.publishproj
|
|
||||||
|
|
||||||
# Microsoft Azure Web App publish settings. Comment the next line if you want to
|
|
||||||
# checkin your Azure Web App publish settings, but sensitive information contained
|
|
||||||
# in these scripts will be unencrypted
|
|
||||||
PublishScripts/
|
|
||||||
|
|
||||||
# NuGet Packages
|
|
||||||
*.nupkg
|
|
||||||
# NuGet Symbol Packages
|
|
||||||
*.snupkg
|
|
||||||
# The packages folder can be ignored because of Package Restore
|
|
||||||
**/[Pp]ackages/*
|
|
||||||
# except build/, which is used as an MSBuild target.
|
|
||||||
!**/[Pp]ackages/build/
|
|
||||||
# Uncomment if necessary however generally it will be regenerated when needed
|
|
||||||
#!**/[Pp]ackages/repositories.config
|
|
||||||
# NuGet v3's project.json files produces more ignorable files
|
|
||||||
*.nuget.props
|
|
||||||
*.nuget.targets
|
|
||||||
|
|
||||||
# Microsoft Azure Build Output
|
|
||||||
csx/
|
|
||||||
*.build.csdef
|
|
||||||
|
|
||||||
# Microsoft Azure Emulator
|
|
||||||
ecf/
|
|
||||||
rcf/
|
|
||||||
|
|
||||||
# Windows Store app package directories and files
|
|
||||||
AppPackages/
|
|
||||||
BundleArtifacts/
|
|
||||||
Package.StoreAssociation.xml
|
|
||||||
_pkginfo.txt
|
|
||||||
*.appx
|
|
||||||
*.appxbundle
|
|
||||||
*.appxupload
|
|
||||||
|
|
||||||
# Visual Studio cache files
|
|
||||||
# files ending in .cache can be ignored
|
|
||||||
*.[Cc]ache
|
|
||||||
# but keep track of directories ending in .cache
|
|
||||||
!?*.[Cc]ache/
|
|
||||||
|
|
||||||
# Others
|
|
||||||
ClientBin/
|
|
||||||
~$*
|
|
||||||
*~
|
|
||||||
*.dbmdl
|
|
||||||
*.dbproj.schemaview
|
|
||||||
*.jfm
|
|
||||||
*.pfx
|
|
||||||
*.publishsettings
|
|
||||||
orleans.codegen.cs
|
|
||||||
|
|
||||||
# Including strong name files can present a security risk
|
|
||||||
# (https://github.com/github/gitignore/pull/2483#issue-259490424)
|
|
||||||
#*.snk
|
|
||||||
|
|
||||||
# Since there are multiple workflows, uncomment next line to ignore bower_components
|
|
||||||
# (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
|
|
||||||
#bower_components/
|
|
||||||
|
|
||||||
# RIA/Silverlight projects
|
|
||||||
Generated_Code/
|
|
||||||
|
|
||||||
# Backup & report files from converting an old project file
|
|
||||||
# to a newer Visual Studio version. Backup files are not needed,
|
|
||||||
# because we have git ;-)
|
|
||||||
_UpgradeReport_Files/
|
|
||||||
Backup*/
|
|
||||||
UpgradeLog*.XML
|
|
||||||
UpgradeLog*.htm
|
|
||||||
ServiceFabricBackup/
|
|
||||||
*.rptproj.bak
|
|
||||||
|
|
||||||
# SQL Server files
|
|
||||||
*.mdf
|
|
||||||
*.ldf
|
|
||||||
*.ndf
|
|
||||||
|
|
||||||
# Business Intelligence projects
|
|
||||||
*.rdl.data
|
|
||||||
*.bim.layout
|
|
||||||
*.bim_*.settings
|
|
||||||
*.rptproj.rsuser
|
|
||||||
*- [Bb]ackup.rdl
|
|
||||||
*- [Bb]ackup ([0-9]).rdl
|
|
||||||
*- [Bb]ackup ([0-9][0-9]).rdl
|
|
||||||
|
|
||||||
# Microsoft Fakes
|
|
||||||
FakesAssemblies/
|
|
||||||
|
|
||||||
# GhostDoc plugin setting file
|
|
||||||
*.GhostDoc.xml
|
|
||||||
|
|
||||||
# Node.js Tools for Visual Studio
|
|
||||||
.ntvs_analysis.dat
|
|
||||||
node_modules/
|
|
||||||
|
|
||||||
# Visual Studio 6 build log
|
|
||||||
*.plg
|
|
||||||
|
|
||||||
# Visual Studio 6 workspace options file
|
|
||||||
*.opt
|
|
||||||
|
|
||||||
# Visual Studio 6 auto-generated workspace file (contains which files were open etc.)
|
|
||||||
*.vbw
|
|
||||||
|
|
||||||
# Visual Studio LightSwitch build output
|
|
||||||
**/*.HTMLClient/GeneratedArtifacts
|
|
||||||
**/*.DesktopClient/GeneratedArtifacts
|
|
||||||
**/*.DesktopClient/ModelManifest.xml
|
|
||||||
**/*.Server/GeneratedArtifacts
|
|
||||||
**/*.Server/ModelManifest.xml
|
|
||||||
_Pvt_Extensions
|
|
||||||
|
|
||||||
# Paket dependency manager
|
|
||||||
.paket/paket.exe
|
|
||||||
paket-files/
|
|
||||||
|
|
||||||
# FAKE - F# Make
|
|
||||||
.fake/
|
|
||||||
|
|
||||||
# CodeRush personal settings
|
|
||||||
.cr/personal
|
|
||||||
|
|
||||||
# Python Tools for Visual Studio (PTVS)
|
|
||||||
__pycache__/
|
|
||||||
*.pyc
|
|
||||||
|
|
||||||
# Cake - Uncomment if you are using it
|
|
||||||
# tools/**
|
|
||||||
# !tools/packages.config
|
|
||||||
|
|
||||||
# Tabs Studio
|
|
||||||
*.tss
|
|
||||||
|
|
||||||
# Telerik's JustMock configuration file
|
|
||||||
*.jmconfig
|
|
||||||
|
|
||||||
# BizTalk build output
|
|
||||||
*.btp.cs
|
|
||||||
*.btm.cs
|
|
||||||
*.odx.cs
|
|
||||||
*.xsd.cs
|
|
||||||
|
|
||||||
# OpenCover UI analysis results
|
|
||||||
OpenCover/
|
|
||||||
|
|
||||||
# Azure Stream Analytics local run output
|
|
||||||
ASALocalRun/
|
|
||||||
|
|
||||||
# MSBuild Binary and Structured Log
|
|
||||||
*.binlog
|
|
||||||
|
|
||||||
# NVidia Nsight GPU debugger configuration file
|
|
||||||
*.nvuser
|
|
||||||
|
|
||||||
# MFractors (Xamarin productivity tool) working folder
|
|
||||||
.mfractor/
|
|
||||||
|
|
||||||
# Local History for Visual Studio
|
|
||||||
.localhistory/
|
|
||||||
|
|
||||||
# BeatPulse healthcheck temp database
|
|
||||||
healthchecksdb
|
|
||||||
|
|
||||||
# Backup folder for Package Reference Convert tool in Visual Studio 2017
|
|
||||||
MigrationBackup/
|
|
||||||
|
|
||||||
# Ionide (cross platform F# VS Code tools) working folder
|
|
||||||
.ionide/
|
|
||||||
|
|
||||||
/vcproj/nsg/x64/Debug/nsg.Build.CppClean.log
|
|
||||||
/vcproj/test_recall/x64/Debug/test_recall.Build.CppClean.log
|
|
||||||
/vcproj/test_recall/test_recall.vcxproj.user
|
|
||||||
/.vs
|
|
||||||
/out/build/x64-Debug
|
|
||||||
cscope*
|
|
||||||
|
|
||||||
build/
|
|
||||||
build_linux/
|
|
||||||
!.github/actions/build
|
|
||||||
|
|
||||||
# jetbrains specific stuff
|
|
||||||
.idea/
|
|
||||||
cmake-build-debug/
|
|
||||||
|
|
||||||
#python extension module ignores
|
|
||||||
python/diskannpy.egg-info/
|
|
||||||
python/dist/
|
|
||||||
|
|
||||||
**/*.egg-info
|
|
||||||
wheelhouse/*
|
|
||||||
dist/*
|
|
||||||
venv*/**
|
|
||||||
*.swp
|
|
||||||
|
|
||||||
gperftools
|
|
||||||
|
|
||||||
# Rust
|
|
||||||
rust/target
|
|
||||||
|
|
||||||
python/src/*.so
|
|
||||||
|
|
||||||
compile_commands.json
|
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
[submodule "gperftools"]
|
|
||||||
path = gperftools
|
|
||||||
url = https://github.com/gperftools/gperftools.git
|
|
||||||
@@ -1,563 +0,0 @@
|
|||||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
||||||
# Licensed under the MIT license.
|
|
||||||
|
|
||||||
# Parameters:
|
|
||||||
#
|
|
||||||
# BOOST_ROOT:
|
|
||||||
# Specify root of the Boost library if Boost cannot be auto-detected. On Windows, a fallback to a
|
|
||||||
# downloaded nuget version will be used if Boost cannot be found.
|
|
||||||
#
|
|
||||||
# DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS:
|
|
||||||
# This is a work-in-progress feature, not completed yet. The core DiskANN library will be split into
|
|
||||||
# build-related and search-related functionality. In build-related functionality, when using tcmalloc,
|
|
||||||
# it's possible to release memory that's free but reserved by tcmalloc. Setting this to true enables
|
|
||||||
# such behavior.
|
|
||||||
# Contact for this feature: gopalrs.
|
|
||||||
|
|
||||||
|
|
||||||
# Some variables like MSVC are defined only after project(), so put that first.
|
|
||||||
cmake_minimum_required(VERSION 3.20)
|
|
||||||
project(diskann)
|
|
||||||
|
|
||||||
#Set option to use tcmalloc
|
|
||||||
option(USE_TCMALLOC "Use tcmalloc from gperftools" ON)
|
|
||||||
|
|
||||||
# set tcmalloc to false when on macos
|
|
||||||
if(APPLE)
|
|
||||||
set(USE_TCMALLOC OFF)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
option(PYBIND "Build with Python bindings" ON)
|
|
||||||
|
|
||||||
if(PYBIND)
|
|
||||||
# Find Python
|
|
||||||
find_package(Python 3.6 COMPONENTS Interpreter Development REQUIRED)
|
|
||||||
execute_process(
|
|
||||||
COMMAND "${Python_EXECUTABLE}" -c "import pybind11; print(pybind11.get_cmake_dir())"
|
|
||||||
OUTPUT_VARIABLE pybind11_DIR
|
|
||||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
|
||||||
)
|
|
||||||
find_package(pybind11 CONFIG REQUIRED)
|
|
||||||
|
|
||||||
message(STATUS "Python include dirs: ${Python_INCLUDE_DIRS}")
|
|
||||||
message(STATUS "Pybind11 include dirs: ${pybind11_INCLUDE_DIRS}")
|
|
||||||
|
|
||||||
# Add pybind11 include directories
|
|
||||||
include_directories(SYSTEM ${pybind11_INCLUDE_DIRS} ${Python_INCLUDE_DIRS})
|
|
||||||
|
|
||||||
# Add compilation definitions
|
|
||||||
add_definitions(-DPYBIND11_EMBEDDED)
|
|
||||||
|
|
||||||
# Set visibility flags
|
|
||||||
if(NOT MSVC)
|
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=hidden")
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
set(CMAKE_STANDARD 17)
|
|
||||||
set(CMAKE_CXX_STANDARD 17)
|
|
||||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
|
||||||
|
|
||||||
# if(NOT MSVC)
|
|
||||||
# set(CMAKE_CXX_COMPILER g++)
|
|
||||||
# endif()
|
|
||||||
|
|
||||||
set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake;${CMAKE_MODULE_PATH}")
|
|
||||||
|
|
||||||
# Install nuget packages for dependencies.
|
|
||||||
if (MSVC)
|
|
||||||
find_program(NUGET_EXE NAMES nuget)
|
|
||||||
|
|
||||||
if (NOT NUGET_EXE)
|
|
||||||
message(FATAL_ERROR "Cannot find nuget command line tool.\nPlease install it from e.g. https://www.nuget.org/downloads")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
set(DISKANN_MSVC_PACKAGES_CONFIG ${CMAKE_BINARY_DIR}/packages.config)
|
|
||||||
set(DISKANN_MSVC_PACKAGES ${CMAKE_BINARY_DIR}/packages)
|
|
||||||
|
|
||||||
message(STATUS "Invoking nuget to download Boost, OpenMP and MKL dependencies...")
|
|
||||||
configure_file(${PROJECT_SOURCE_DIR}/windows/packages.config.in ${DISKANN_MSVC_PACKAGES_CONFIG})
|
|
||||||
exec_program(${NUGET_EXE} ARGS install \"${DISKANN_MSVC_PACKAGES_CONFIG}\" -ExcludeVersion -OutputDirectory \"${DISKANN_MSVC_PACKAGES}\")
|
|
||||||
if (RESTAPI)
|
|
||||||
set(DISKANN_MSVC_RESTAPI_PACKAGES_CONFIG ${CMAKE_BINARY_DIR}/restapi/packages.config)
|
|
||||||
configure_file(${PROJECT_SOURCE_DIR}/windows/packages_restapi.config.in ${DISKANN_MSVC_RESTAPI_PACKAGES_CONFIG})
|
|
||||||
exec_program(${NUGET_EXE} ARGS install \"${DISKANN_MSVC_RESTAPI_PACKAGES_CONFIG}\" -ExcludeVersion -OutputDirectory \"${DISKANN_MSVC_PACKAGES}\")
|
|
||||||
endif()
|
|
||||||
message(STATUS "Finished setting up nuget dependencies")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
include_directories(${PROJECT_SOURCE_DIR}/include)
|
|
||||||
|
|
||||||
include(FetchContent)
|
|
||||||
|
|
||||||
if(USE_TCMALLOC)
|
|
||||||
FetchContent_Declare(
|
|
||||||
tcmalloc
|
|
||||||
GIT_REPOSITORY https://github.com/google/tcmalloc.git
|
|
||||||
GIT_TAG origin/master # or specify a particular version or commit
|
|
||||||
)
|
|
||||||
|
|
||||||
FetchContent_MakeAvailable(tcmalloc)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if(NOT PYBIND)
|
|
||||||
set(DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS ON)
|
|
||||||
endif()
|
|
||||||
# It's necessary to include tcmalloc headers only if calling into MallocExtension interface.
|
|
||||||
# For using tcmalloc in DiskANN tools, it's enough to just link with tcmalloc.
|
|
||||||
if (DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS)
|
|
||||||
include_directories(${tcmalloc_SOURCE_DIR}/src)
|
|
||||||
if (MSVC)
|
|
||||||
include_directories(${tcmalloc_SOURCE_DIR}/src/windows)
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
#OpenMP
|
|
||||||
if (MSVC)
|
|
||||||
# Do not use find_package here since it would use VisualStudio's built-in OpenMP, but MKL libraries
|
|
||||||
# refer to Intel's OpenMP.
|
|
||||||
#
|
|
||||||
# No extra settings are needed for compilation: it only needs /openmp flag which is set further below,
|
|
||||||
# in the common MSVC compiler options block.
|
|
||||||
include_directories(BEFORE "${DISKANN_MSVC_PACKAGES}/intelopenmp.devel.win/lib/native/include")
|
|
||||||
link_libraries("${DISKANN_MSVC_PACKAGES}/intelopenmp.devel.win/lib/native/win-x64/libiomp5md.lib")
|
|
||||||
|
|
||||||
set(OPENMP_WINDOWS_RUNTIME_FILES
|
|
||||||
"${DISKANN_MSVC_PACKAGES}/intelopenmp.redist.win/runtimes/win-x64/native/libiomp5md.dll"
|
|
||||||
"${DISKANN_MSVC_PACKAGES}/intelopenmp.redist.win/runtimes/win-x64/native/libiomp5md.pdb")
|
|
||||||
elseif(APPLE)
|
|
||||||
# Check if we're building Python bindings
|
|
||||||
if(PYBIND)
|
|
||||||
# First look for PyTorch's OpenMP to avoid conflicts
|
|
||||||
execute_process(
|
|
||||||
COMMAND ${Python_EXECUTABLE} -c "import os; import torch; print(os.path.join(os.path.dirname(torch.__file__), 'lib', 'libomp.dylib'))"
|
|
||||||
RESULT_VARIABLE TORCH_PATH_RESULT
|
|
||||||
OUTPUT_VARIABLE TORCH_LIBOMP_PATH
|
|
||||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
|
||||||
ERROR_QUIET
|
|
||||||
)
|
|
||||||
|
|
||||||
execute_process(
|
|
||||||
COMMAND brew --prefix libomp
|
|
||||||
OUTPUT_VARIABLE LIBOMP_ROOT
|
|
||||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
|
||||||
)
|
|
||||||
|
|
||||||
if(EXISTS "${TORCH_LIBOMP_PATH}")
|
|
||||||
message(STATUS "Found PyTorch's libomp: ${TORCH_LIBOMP_PATH}")
|
|
||||||
set(OpenMP_CXX_FLAGS "-Xclang -fopenmp")
|
|
||||||
set(OpenMP_C_FLAGS "-Xclang -fopenmp")
|
|
||||||
set(OpenMP_CXX_LIBRARIES "${TORCH_LIBOMP_PATH}")
|
|
||||||
set(OpenMP_C_LIBRARIES "${TORCH_LIBOMP_PATH}")
|
|
||||||
set(OpenMP_FOUND TRUE)
|
|
||||||
|
|
||||||
include_directories(${LIBOMP_ROOT}/include)
|
|
||||||
|
|
||||||
# Set compiler flags and link libraries
|
|
||||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
|
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
|
|
||||||
link_libraries("${TORCH_LIBOMP_PATH}")
|
|
||||||
else()
|
|
||||||
message(STATUS "No PyTorch's libomp found, falling back to normal OpenMP detection")
|
|
||||||
# Fallback to normal OpenMP detection
|
|
||||||
execute_process(
|
|
||||||
COMMAND brew --prefix libomp
|
|
||||||
OUTPUT_VARIABLE LIBOMP_ROOT
|
|
||||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
|
||||||
)
|
|
||||||
|
|
||||||
set(OpenMP_ROOT "${LIBOMP_ROOT}")
|
|
||||||
find_package(OpenMP)
|
|
||||||
|
|
||||||
if (OPENMP_FOUND)
|
|
||||||
set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
|
|
||||||
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
|
|
||||||
link_libraries(OpenMP::OpenMP_CXX)
|
|
||||||
else()
|
|
||||||
message(FATAL_ERROR "No OpenMP support")
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
else()
|
|
||||||
# Regular OpenMP setup for non-Python builds
|
|
||||||
execute_process(
|
|
||||||
COMMAND brew --prefix libomp
|
|
||||||
OUTPUT_VARIABLE LIBOMP_ROOT
|
|
||||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
|
||||||
)
|
|
||||||
set(OpenMP_ROOT "${LIBOMP_ROOT}")
|
|
||||||
find_package(OpenMP)
|
|
||||||
|
|
||||||
if (OPENMP_FOUND)
|
|
||||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
|
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
|
|
||||||
link_libraries(OpenMP::OpenMP_CXX)
|
|
||||||
else()
|
|
||||||
message(FATAL_ERROR "No OpenMP support")
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
else()
|
|
||||||
find_package(OpenMP)
|
|
||||||
|
|
||||||
if (OPENMP_FOUND)
|
|
||||||
set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
|
|
||||||
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
|
|
||||||
else()
|
|
||||||
message(FATAL_ERROR "No OpenMP support")
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
# DiskANN core uses header-only libraries. Only DiskANN tools need program_options which has a linker library,
|
|
||||||
# but its size is small. Reduce number of dependent DLLs by linking statically.
|
|
||||||
if (MSVC)
|
|
||||||
set(Boost_USE_STATIC_LIBS ON)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if(NOT MSVC)
|
|
||||||
find_package(Boost COMPONENTS program_options)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
# For Windows, fall back to nuget version if find_package didn't find it.
|
|
||||||
if (MSVC AND NOT Boost_FOUND)
|
|
||||||
set(DISKANN_BOOST_INCLUDE "${DISKANN_MSVC_PACKAGES}/boost/lib/native/include")
|
|
||||||
# Multi-threaded static library.
|
|
||||||
set(PROGRAM_OPTIONS_LIB_PATTERN "${DISKANN_MSVC_PACKAGES}/boost_program_options-vc${MSVC_TOOLSET_VERSION}/lib/native/libboost_program_options-vc${MSVC_TOOLSET_VERSION}-mt-x64-*.lib")
|
|
||||||
file(GLOB DISKANN_BOOST_PROGRAM_OPTIONS_LIB ${PROGRAM_OPTIONS_LIB_PATTERN})
|
|
||||||
|
|
||||||
set(PROGRAM_OPTIONS_DLIB_PATTERN "${DISKANN_MSVC_PACKAGES}/boost_program_options-vc${MSVC_TOOLSET_VERSION}/lib/native/libboost_program_options-vc${MSVC_TOOLSET_VERSION}-mt-gd-x64-*.lib")
|
|
||||||
file(GLOB DISKANN_BOOST_PROGRAM_OPTIONS_DLIB ${PROGRAM_OPTIONS_DLIB_PATTERN})
|
|
||||||
|
|
||||||
if (EXISTS ${DISKANN_BOOST_INCLUDE} AND EXISTS ${DISKANN_BOOST_PROGRAM_OPTIONS_LIB} AND EXISTS ${DISKANN_BOOST_PROGRAM_OPTIONS_DLIB})
|
|
||||||
set(Boost_FOUND ON)
|
|
||||||
set(Boost_INCLUDE_DIR ${DISKANN_BOOST_INCLUDE})
|
|
||||||
add_library(Boost::program_options STATIC IMPORTED)
|
|
||||||
set_target_properties(Boost::program_options PROPERTIES IMPORTED_LOCATION_RELEASE "${DISKANN_BOOST_PROGRAM_OPTIONS_LIB}")
|
|
||||||
set_target_properties(Boost::program_options PROPERTIES IMPORTED_LOCATION_DEBUG "${DISKANN_BOOST_PROGRAM_OPTIONS_DLIB}")
|
|
||||||
message(STATUS "Falling back to using Boost from the nuget package")
|
|
||||||
else()
|
|
||||||
message(WARNING "Couldn't find Boost. Was looking for ${DISKANN_BOOST_INCLUDE} and ${PROGRAM_OPTIONS_LIB_PATTERN}")
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if (NOT Boost_FOUND)
|
|
||||||
message(FATAL_ERROR "Couldn't find Boost dependency")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
include_directories(${Boost_INCLUDE_DIR})
|
|
||||||
|
|
||||||
#MKL Config
|
|
||||||
if (MSVC)
|
|
||||||
# Only the DiskANN DLL and one of the tools need MKL libraries. Additionally, only a small part of MKL is used.
|
|
||||||
# Given that and given that MKL DLLs are huge, use static linking to end up with no MKL DLL dependencies and with
|
|
||||||
# significantly smaller disk footprint.
|
|
||||||
#
|
|
||||||
# The compile options are not modified as there's already an unconditional -DMKL_ILP64 define below
|
|
||||||
# for all architectures, which is all that's needed.
|
|
||||||
set(DISKANN_MKL_INCLUDE_DIRECTORIES "${DISKANN_MSVC_PACKAGES}/intelmkl.static.win-x64/lib/native/include")
|
|
||||||
set(DISKANN_MKL_LIB_PATH "${DISKANN_MSVC_PACKAGES}/intelmkl.static.win-x64/lib/native/win-x64")
|
|
||||||
|
|
||||||
set(DISKANN_MKL_LINK_LIBRARIES
|
|
||||||
"${DISKANN_MKL_LIB_PATH}/mkl_intel_ilp64.lib"
|
|
||||||
"${DISKANN_MKL_LIB_PATH}/mkl_core.lib"
|
|
||||||
"${DISKANN_MKL_LIB_PATH}/mkl_intel_thread.lib")
|
|
||||||
elseif(APPLE)
|
|
||||||
# no mkl on non-intel devices
|
|
||||||
find_library(ACCELERATE_LIBRARY Accelerate)
|
|
||||||
message(STATUS "Found Accelerate (${ACCELERATE_LIBRARY})")
|
|
||||||
set(DISKANN_ACCEL_LINK_OPTIONS ${ACCELERATE_LIBRARY})
|
|
||||||
add_compile_definitions(ACCELERATE_NEW_LAPACK)
|
|
||||||
else()
|
|
||||||
# expected path for manual intel mkl installs
|
|
||||||
set(POSSIBLE_OMP_PATHS "/opt/intel/oneapi/compiler/2025.0/lib/libiomp5.so;/opt/intel/oneapi/compiler/latest/linux/compiler/lib/intel64_lin/libiomp5.so;/usr/lib/x86_64-linux-gnu/libiomp5.so;/opt/intel/lib/intel64_lin/libiomp5.so")
|
|
||||||
foreach(POSSIBLE_OMP_PATH ${POSSIBLE_OMP_PATHS})
|
|
||||||
if (EXISTS ${POSSIBLE_OMP_PATH})
|
|
||||||
get_filename_component(OMP_PATH ${POSSIBLE_OMP_PATH} DIRECTORY)
|
|
||||||
endif()
|
|
||||||
endforeach()
|
|
||||||
|
|
||||||
if(NOT OMP_PATH)
|
|
||||||
message(FATAL_ERROR "Could not find Intel OMP in standard locations; use -DOMP_PATH to specify the install location for your environment")
|
|
||||||
endif()
|
|
||||||
link_directories(${OMP_PATH})
|
|
||||||
|
|
||||||
set(POSSIBLE_MKL_LIB_PATHS "/opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_core.so;/usr/lib/x86_64-linux-gnu/libmkl_core.so;/opt/intel/mkl/lib/intel64/libmkl_core.so")
|
|
||||||
foreach(POSSIBLE_MKL_LIB_PATH ${POSSIBLE_MKL_LIB_PATHS})
|
|
||||||
if (EXISTS ${POSSIBLE_MKL_LIB_PATH})
|
|
||||||
get_filename_component(MKL_PATH ${POSSIBLE_MKL_LIB_PATH} DIRECTORY)
|
|
||||||
endif()
|
|
||||||
endforeach()
|
|
||||||
|
|
||||||
set(POSSIBLE_MKL_INCLUDE_PATHS "/opt/intel/oneapi/mkl/latest/include;/usr/include/mkl;/opt/intel/mkl/include/;")
|
|
||||||
foreach(POSSIBLE_MKL_INCLUDE_PATH ${POSSIBLE_MKL_INCLUDE_PATHS})
|
|
||||||
if (EXISTS ${POSSIBLE_MKL_INCLUDE_PATH})
|
|
||||||
set(MKL_INCLUDE_PATH ${POSSIBLE_MKL_INCLUDE_PATH})
|
|
||||||
endif()
|
|
||||||
endforeach()
|
|
||||||
if(NOT MKL_PATH)
|
|
||||||
message(FATAL_ERROR "Could not find Intel MKL in standard locations; use -DMKL_PATH to specify the install location for your environment")
|
|
||||||
elseif(NOT MKL_INCLUDE_PATH)
|
|
||||||
message(FATAL_ERROR "Could not find Intel MKL in standard locations; use -DMKL_INCLUDE_PATH to specify the install location for headers for your environment")
|
|
||||||
endif()
|
|
||||||
if (EXISTS ${MKL_PATH}/libmkl_def.so.2)
|
|
||||||
set(MKL_DEF_SO ${MKL_PATH}/libmkl_def.so.2)
|
|
||||||
elseif(EXISTS ${MKL_PATH}/libmkl_def.so)
|
|
||||||
set(MKL_DEF_SO ${MKL_PATH}/libmkl_def.so)
|
|
||||||
else()
|
|
||||||
message(FATAL_ERROR "Despite finding MKL, libmkl_def.so was not found in expected locations.")
|
|
||||||
endif()
|
|
||||||
link_directories(${MKL_PATH})
|
|
||||||
include_directories(${MKL_INCLUDE_PATH})
|
|
||||||
|
|
||||||
# compile flags and link libraries
|
|
||||||
# if gcc/g++
|
|
||||||
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
|
|
||||||
add_compile_options(-m64 -Wl,--no-as-needed)
|
|
||||||
endif()
|
|
||||||
if (NOT PYBIND)
|
|
||||||
link_libraries(mkl_intel_ilp64 mkl_intel_thread mkl_core iomp5 pthread m dl)
|
|
||||||
else()
|
|
||||||
# static linking for python so as to minimize customer dependency issues
|
|
||||||
if (CMAKE_BUILD_TYPE STREQUAL "Debug")
|
|
||||||
# In debug mode, use dynamic linking to ensure all symbols are available
|
|
||||||
link_libraries(mkl_intel_ilp64 mkl_intel_thread mkl_core ${MKL_DEF_SO} iomp5 pthread m dl)
|
|
||||||
else()
|
|
||||||
# In release mode, use static linking to minimize dependencies
|
|
||||||
link_libraries(
|
|
||||||
${MKL_PATH}/libmkl_intel_ilp64.a
|
|
||||||
${MKL_PATH}/libmkl_intel_thread.a
|
|
||||||
${MKL_PATH}/libmkl_core.a
|
|
||||||
${MKL_DEF_SO}
|
|
||||||
iomp5
|
|
||||||
pthread
|
|
||||||
m
|
|
||||||
dl
|
|
||||||
)
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
add_definitions(-DMKL_ILP64)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
|
|
||||||
# Section for tcmalloc. The DiskANN tools are always linked to tcmalloc. For Windows, they also need to
|
|
||||||
# force-include the _tcmalloc symbol for enabling tcmalloc.
|
|
||||||
#
|
|
||||||
# The DLL itself needs to be linked to tcmalloc only if DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS
|
|
||||||
# is enabled.
|
|
||||||
if(USE_TCMALLOC)
|
|
||||||
if (MSVC)
|
|
||||||
if (NOT EXISTS "${PROJECT_SOURCE_DIR}/gperftools/gperftools.sln")
|
|
||||||
message(FATAL_ERROR "The gperftools submodule was not found. "
|
|
||||||
"Please check-out git submodules by doing 'git submodule init' followed by 'git submodule update'")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
set(TCMALLOC_LINK_LIBRARY "${PROJECT_SOURCE_DIR}/gperftools/x64/Release-Patch/libtcmalloc_minimal.lib")
|
|
||||||
set(TCMALLOC_WINDOWS_RUNTIME_FILES
|
|
||||||
"${PROJECT_SOURCE_DIR}/gperftools/x64/Release-Patch/libtcmalloc_minimal.dll"
|
|
||||||
"${PROJECT_SOURCE_DIR}/gperftools/x64/Release-Patch/libtcmalloc_minimal.pdb")
|
|
||||||
|
|
||||||
# Tell CMake how to build the tcmalloc linker library from the submodule.
|
|
||||||
add_custom_target(build_libtcmalloc_minimal DEPENDS ${TCMALLOC_LINK_LIBRARY})
|
|
||||||
add_custom_command(OUTPUT ${TCMALLOC_LINK_LIBRARY}
|
|
||||||
COMMAND ${CMAKE_VS_MSBUILD_COMMAND} gperftools.sln /m /nologo
|
|
||||||
/t:libtcmalloc_minimal /p:Configuration="Release-Patch"
|
|
||||||
/property:Platform="x64"
|
|
||||||
/p:PlatformToolset=v${MSVC_TOOLSET_VERSION}
|
|
||||||
/p:WindowsTargetPlatformVersion=${CMAKE_VS_WINDOWS_TARGET_PLATFORM_VERSION}
|
|
||||||
WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}/gperftools)
|
|
||||||
|
|
||||||
add_library(libtcmalloc_minimal_for_exe STATIC IMPORTED)
|
|
||||||
add_library(libtcmalloc_minimal_for_dll STATIC IMPORTED)
|
|
||||||
|
|
||||||
set_target_properties(libtcmalloc_minimal_for_dll PROPERTIES
|
|
||||||
IMPORTED_LOCATION "${TCMALLOC_LINK_LIBRARY}")
|
|
||||||
|
|
||||||
set_target_properties(libtcmalloc_minimal_for_exe PROPERTIES
|
|
||||||
IMPORTED_LOCATION "${TCMALLOC_LINK_LIBRARY}"
|
|
||||||
INTERFACE_LINK_OPTIONS /INCLUDE:_tcmalloc)
|
|
||||||
|
|
||||||
# Ensure libtcmalloc_minimal is built before it's being used.
|
|
||||||
add_dependencies(libtcmalloc_minimal_for_dll build_libtcmalloc_minimal)
|
|
||||||
add_dependencies(libtcmalloc_minimal_for_exe build_libtcmalloc_minimal)
|
|
||||||
|
|
||||||
set(DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS libtcmalloc_minimal_for_exe)
|
|
||||||
elseif(APPLE) # ! Inherited from #474, not been adjusted for TCMalloc Removal
|
|
||||||
execute_process(
|
|
||||||
COMMAND brew --prefix gperftools
|
|
||||||
OUTPUT_VARIABLE GPERFTOOLS_PREFIX
|
|
||||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
|
||||||
)
|
|
||||||
set(DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS "-L${GPERFTOOLS_PREFIX}/lib -ltcmalloc")
|
|
||||||
elseif(NOT PYBIND)
|
|
||||||
set(DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS "-ltcmalloc")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if (DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS)
|
|
||||||
add_definitions(-DRELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS)
|
|
||||||
|
|
||||||
if (MSVC)
|
|
||||||
set(DISKANN_DLL_TCMALLOC_LINK_OPTIONS libtcmalloc_minimal_for_dll)
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if (NOT MSVC AND NOT APPLE)
|
|
||||||
set(DISKANN_ASYNC_LIB aio)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
#Main compiler/linker settings
|
|
||||||
if(MSVC)
|
|
||||||
#language options
|
|
||||||
add_compile_options(/permissive- /openmp:experimental /Zc:twoPhase- /Zc:inline /WX- /std:c++17 /Gd /W3 /MP /Zi /FC /nologo)
|
|
||||||
#code generation options
|
|
||||||
add_compile_options(/arch:AVX2 /fp:fast /fp:except- /EHsc /GS- /Gy)
|
|
||||||
#optimization options
|
|
||||||
add_compile_options(/Ot /Oy /Oi)
|
|
||||||
#path options
|
|
||||||
add_definitions(-DUSE_AVX2 -DUSE_ACCELERATED_PQ -D_WINDOWS -DNOMINMAX -DUNICODE)
|
|
||||||
# Linker options. Exclude VCOMP/VCOMPD.LIB which contain VisualStudio's version of OpenMP.
|
|
||||||
# MKL was linked against Intel's OpenMP and depends on the corresponding DLL.
|
|
||||||
add_link_options(/NODEFAULTLIB:VCOMP.LIB /NODEFAULTLIB:VCOMPD.LIB /DEBUG:FULL /OPT:REF /OPT:ICF)
|
|
||||||
|
|
||||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG ${PROJECT_SOURCE_DIR}/x64/Debug)
|
|
||||||
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_DEBUG ${PROJECT_SOURCE_DIR}/x64/Debug)
|
|
||||||
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY_DEBUG ${PROJECT_SOURCE_DIR}/x64/Debug)
|
|
||||||
|
|
||||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE ${PROJECT_SOURCE_DIR}/x64/Release)
|
|
||||||
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_RELEASE ${PROJECT_SOURCE_DIR}/x64/Release)
|
|
||||||
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY_RELEASE ${PROJECT_SOURCE_DIR}/x64/Release)
|
|
||||||
elseif(APPLE)
|
|
||||||
set(ENV{TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD} 500000000000)
|
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -ftree-vectorize -fno-builtin-malloc -fno-builtin-calloc -fno-builtin-realloc -fno-builtin-free -Xclang -fopenmp -fopenmp-simd -funroll-loops -Wfatal-errors -Wno-inconsistent-missing-override -Wno-return-type")
|
|
||||||
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -g -DDEBUG")
|
|
||||||
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -Ofast -DNDEBUG -ftree-vectorize")
|
|
||||||
if (NOT PYBIND)
|
|
||||||
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -DNDEBUG -Ofast")
|
|
||||||
if (NOT PORTABLE)
|
|
||||||
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -mtune=native")
|
|
||||||
endif()
|
|
||||||
else()
|
|
||||||
# -Ofast is not supported in a python extension module
|
|
||||||
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -DNDEBUG -fPIC")
|
|
||||||
endif()
|
|
||||||
else()
|
|
||||||
set(ENV{TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD} 500000000000)
|
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx2 -mfma -msse2 -ftree-vectorize -fopenmp -fopenmp-simd -funroll-loops -Wfatal-errors -DUSE_AVX2 -fPIC")
|
|
||||||
if(USE_TCMALLOC)
|
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-builtin-malloc -fno-builtin-calloc -fno-builtin-realloc -fno-builtin-free")
|
|
||||||
endif()
|
|
||||||
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -g -DDEBUG")
|
|
||||||
if (NOT PYBIND)
|
|
||||||
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -DNDEBUG -Ofast")
|
|
||||||
if (NOT PORTABLE)
|
|
||||||
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -march=native -mtune=native")
|
|
||||||
endif()
|
|
||||||
else()
|
|
||||||
# -Ofast is not supported in a python extension module
|
|
||||||
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -DNDEBUG")
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
add_subdirectory(src)
|
|
||||||
if (NOT PYBIND)
|
|
||||||
add_subdirectory(apps)
|
|
||||||
add_subdirectory(apps/utils)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if (UNIT_TEST)
|
|
||||||
enable_testing()
|
|
||||||
add_subdirectory(tests)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if (MSVC)
|
|
||||||
message(STATUS "The ${PROJECT_NAME}.sln has been created, opened it from VisualStudio to build Release or Debug configurations.\n"
|
|
||||||
"Alternatively, use MSBuild to build:\n\n"
|
|
||||||
"msbuild.exe ${PROJECT_NAME}.sln /m /nologo /t:Build /p:Configuration=\"Release\" /property:Platform=\"x64\"\n")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if (RESTAPI)
|
|
||||||
if (MSVC)
|
|
||||||
set(DISKANN_CPPRESTSDK "${DISKANN_MSVC_PACKAGES}/cpprestsdk.v142/build/native")
|
|
||||||
# expected path for apt packaged intel mkl installs
|
|
||||||
link_libraries("${DISKANN_CPPRESTSDK}/x64/lib/cpprest142_2_10.lib")
|
|
||||||
include_directories("${DISKANN_CPPRESTSDK}/include")
|
|
||||||
endif()
|
|
||||||
add_subdirectory(apps/restapi)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
include(clang-format.cmake)
|
|
||||||
|
|
||||||
if(PYBIND)
|
|
||||||
add_subdirectory(python)
|
|
||||||
|
|
||||||
install(TARGETS _diskannpy
|
|
||||||
DESTINATION leann_backend_diskann
|
|
||||||
COMPONENT python_modules
|
|
||||||
)
|
|
||||||
|
|
||||||
endif()
|
|
||||||
###############################################################################
|
|
||||||
# PROTOBUF SECTION - Corrected to use CONFIG mode explicitly
|
|
||||||
###############################################################################
|
|
||||||
set(Protobuf_USE_STATIC_LIBS OFF)
|
|
||||||
|
|
||||||
find_package(ZLIB REQUIRED)
|
|
||||||
|
|
||||||
find_package(Protobuf REQUIRED)
|
|
||||||
|
|
||||||
message(STATUS "Protobuf found: ${Protobuf_VERSION}")
|
|
||||||
message(STATUS "Protobuf include dirs: ${Protobuf_INCLUDE_DIRS}")
|
|
||||||
message(STATUS "Protobuf libraries: ${Protobuf_LIBRARIES}")
|
|
||||||
message(STATUS "Protobuf protoc executable: ${Protobuf_PROTOC_EXECUTABLE}")
|
|
||||||
|
|
||||||
include_directories(${Protobuf_INCLUDE_DIRS})
|
|
||||||
|
|
||||||
set(PROTO_FILE "${CMAKE_CURRENT_SOURCE_DIR}/../embedding.proto")
|
|
||||||
protobuf_generate_cpp(PROTO_SRCS PROTO_HDRS ${PROTO_FILE})
|
|
||||||
set(generated_proto_sources ${PROTO_SRCS})
|
|
||||||
|
|
||||||
|
|
||||||
add_library(proto_embeddings STATIC ${generated_proto_sources})
|
|
||||||
target_link_libraries(proto_embeddings PUBLIC protobuf::libprotobuf)
|
|
||||||
target_include_directories(proto_embeddings PUBLIC
|
|
||||||
${CMAKE_CURRENT_BINARY_DIR}
|
|
||||||
${Protobuf_INCLUDE_DIRS}
|
|
||||||
)
|
|
||||||
|
|
||||||
target_link_libraries(diskann PRIVATE proto_embeddings protobuf::libprotobuf)
|
|
||||||
target_include_directories(diskann PRIVATE
|
|
||||||
${CMAKE_CURRENT_BINARY_DIR}
|
|
||||||
${Protobuf_INCLUDE_DIRS}
|
|
||||||
)
|
|
||||||
|
|
||||||
target_link_libraries(diskann_s PRIVATE proto_embeddings protobuf::libprotobuf)
|
|
||||||
target_include_directories(diskann_s PRIVATE
|
|
||||||
${CMAKE_CURRENT_BINARY_DIR}
|
|
||||||
${Protobuf_INCLUDE_DIRS}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
###############################################################################
|
|
||||||
# ZEROMQ SECTION - REQUIRED
|
|
||||||
###############################################################################
|
|
||||||
|
|
||||||
find_package(ZeroMQ QUIET)
|
|
||||||
if(NOT ZeroMQ_FOUND)
|
|
||||||
find_path(ZeroMQ_INCLUDE_DIR zmq.h)
|
|
||||||
find_library(ZeroMQ_LIBRARY zmq)
|
|
||||||
if(ZeroMQ_INCLUDE_DIR AND ZeroMQ_LIBRARY)
|
|
||||||
set(ZeroMQ_FOUND TRUE)
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if(ZeroMQ_FOUND)
|
|
||||||
message(STATUS "Found ZeroMQ: ${ZeroMQ_LIBRARY}")
|
|
||||||
include_directories(${ZeroMQ_INCLUDE_DIR})
|
|
||||||
target_link_libraries(diskann PRIVATE ${ZeroMQ_LIBRARY})
|
|
||||||
target_link_libraries(diskann_s PRIVATE ${ZeroMQ_LIBRARY})
|
|
||||||
add_definitions(-DUSE_ZEROMQ)
|
|
||||||
else()
|
|
||||||
message(FATAL_ERROR "ZeroMQ is required but not found. Please install ZeroMQ and try again.")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
target_link_libraries(diskann ${PYBIND11_LIBRARIES})
|
|
||||||
target_link_libraries(diskann_s ${PYBIND11_LIBRARIES})
|
|
||||||
@@ -1,28 +0,0 @@
|
|||||||
{
|
|
||||||
"configurations": [
|
|
||||||
{
|
|
||||||
"name": "x64-Release",
|
|
||||||
"generator": "Ninja",
|
|
||||||
"configurationType": "Release",
|
|
||||||
"inheritEnvironments": [ "msvc_x64" ],
|
|
||||||
"buildRoot": "${projectDir}\\out\\build\\${name}",
|
|
||||||
"installRoot": "${projectDir}\\out\\install\\${name}",
|
|
||||||
"cmakeCommandArgs": "",
|
|
||||||
"buildCommandArgs": "",
|
|
||||||
"ctestCommandArgs": ""
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "WSL-GCC-Release",
|
|
||||||
"generator": "Ninja",
|
|
||||||
"configurationType": "RelWithDebInfo",
|
|
||||||
"buildRoot": "${projectDir}\\out\\build\\${name}",
|
|
||||||
"installRoot": "${projectDir}\\out\\install\\${name}",
|
|
||||||
"cmakeExecutable": "cmake",
|
|
||||||
"cmakeCommandArgs": "",
|
|
||||||
"buildCommandArgs": "",
|
|
||||||
"ctestCommandArgs": "",
|
|
||||||
"inheritEnvironments": [ "linux_x64" ],
|
|
||||||
"wslPath": "${defaultWSLPath}"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
@@ -1,9 +0,0 @@
|
|||||||
# Microsoft Open Source Code of Conduct
|
|
||||||
|
|
||||||
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
|
|
||||||
|
|
||||||
Resources:
|
|
||||||
|
|
||||||
- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
|
|
||||||
- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
|
|
||||||
- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
|
|
||||||
@@ -1,9 +0,0 @@
|
|||||||
# Contributing
|
|
||||||
|
|
||||||
This project welcomes contributions and suggestions. Most contributions require you to agree to a
|
|
||||||
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
|
|
||||||
the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
|
|
||||||
|
|
||||||
When you submit a pull request, a CLA bot will automatically determine whether you need to provide
|
|
||||||
a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
|
|
||||||
provided by the bot. You will only need to do this once across all repos using our CLA.
|
|
||||||
@@ -1,17 +0,0 @@
|
|||||||
#Copyright(c) Microsoft Corporation.All rights reserved.
|
|
||||||
#Licensed under the MIT license.
|
|
||||||
|
|
||||||
FROM ubuntu:jammy
|
|
||||||
|
|
||||||
RUN apt update
|
|
||||||
RUN apt install -y software-properties-common
|
|
||||||
RUN add-apt-repository -y ppa:git-core/ppa
|
|
||||||
RUN apt update
|
|
||||||
RUN DEBIAN_FRONTEND=noninteractive apt install -y git make cmake g++ libaio-dev libgoogle-perftools-dev libunwind-dev clang-format libboost-dev libboost-program-options-dev libmkl-full-dev libcpprest-dev python3.10
|
|
||||||
|
|
||||||
WORKDIR /app
|
|
||||||
RUN git clone https://github.com/microsoft/DiskANN.git
|
|
||||||
WORKDIR /app/DiskANN
|
|
||||||
RUN mkdir build
|
|
||||||
RUN cmake -S . -B build -DCMAKE_BUILD_TYPE=Release
|
|
||||||
RUN cmake --build build -- -j
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user