Compare commits
332 Commits
debug_disk
...
refactor-a
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0877960547 | ||
|
|
d68af63d05 | ||
|
|
b844aca968 | ||
|
|
85277ba67a | ||
|
|
e9562acdc2 | ||
|
|
7fd3db1ddb | ||
|
|
c1ccc51a75 | ||
|
|
b0239b6e4d | ||
|
|
58556ef44c | ||
|
|
87c930d705 | ||
|
|
86f919a6da | ||
|
|
f8d34663b4 | ||
|
|
568cf597f4 | ||
|
|
baf70dc411 | ||
|
|
7ad2ec39d6 | ||
|
|
31fd3c816a | ||
|
|
1f6c7f2f5a | ||
|
|
c1124eb349 | ||
|
|
274bbb19ea | ||
|
|
8c152c7a31 | ||
|
|
ce77eef13a | ||
|
|
9d77175ac8 | ||
|
|
7fbb6c98ef | ||
|
|
914a248c28 | ||
|
|
55fc5862f9 | ||
|
|
fd97b8dfa8 | ||
|
|
57959947a1 | ||
|
|
cc0c091ca5 | ||
|
|
ff389c7d8d | ||
|
|
6780a8eaba | ||
|
|
984056f126 | ||
|
|
bd4451bf50 | ||
|
|
34e313f64a | ||
|
|
ddc789b231 | ||
|
|
ff1b622bdd | ||
|
|
3cde4fc7b3 | ||
|
|
4e3bcda5fa | ||
|
|
46f6f76fc3 | ||
|
|
19bcc07814 | ||
|
|
8356e3c668 | ||
|
|
08eac5c821 | ||
|
|
4671ed9b36 | ||
|
|
055c086398 | ||
|
|
d505dcc5e3 | ||
|
|
261006c36a | ||
|
|
b2eba23e21 | ||
|
|
e9ee687472 | ||
|
|
6f5d5e4a77 | ||
|
|
5c8921673a | ||
|
|
e9d2d420bd | ||
|
|
ebabfad066 | ||
|
|
e6f612b5e8 | ||
|
|
51c41acd82 | ||
|
|
455f93fb7c | ||
|
|
48207c3b69 | ||
|
|
4de1caa40f | ||
|
|
60eaa8165c | ||
|
|
c1a5d0c624 | ||
|
|
af1790395a | ||
|
|
383c6d8d7e | ||
|
|
bc0d839693 | ||
|
|
8596562de5 | ||
|
|
5d09586853 | ||
|
|
a7cba078dd | ||
|
|
b3e9ee96fa | ||
|
|
8537a6b17e | ||
|
|
7c8d7dc5c2 | ||
|
|
8e23d663e6 | ||
|
|
8a3994bf80 | ||
|
|
8375f601ba | ||
|
|
c87c0fe662 | ||
|
|
73927b68ef | ||
|
|
cc1a62e5aa | ||
|
|
802020cb41 | ||
|
|
cdb92f7cf4 | ||
|
|
dc69bdec00 | ||
|
|
98073e9868 | ||
|
|
cf2ef48967 | ||
|
|
0692bbf7a2 | ||
|
|
52584a171f | ||
|
|
efd6b5324b | ||
|
|
2baaa4549b | ||
|
|
35310ddd52 | ||
|
|
fc9c5cb39d | ||
|
|
8f2a1e87ea | ||
|
|
50caf65f28 | ||
|
|
1b48794ca8 | ||
|
|
4aef1d814e | ||
|
|
75ddcd6158 | ||
|
|
2a4df11f5c | ||
|
|
5eb893c62b | ||
|
|
d91ce2e94d | ||
|
|
5c2ff8a641 | ||
|
|
d4f474c9b7 | ||
|
|
170f7644e9 | ||
|
|
cd8b970eff | ||
|
|
52153bbb69 | ||
|
|
e1ae087207 | ||
|
|
48c5e12ac1 | ||
|
|
f8b5c97190 | ||
|
|
d038c81b8b | ||
|
|
29cbbbd0d6 | ||
|
|
179f30bc36 | ||
|
|
c4a0a68581 | ||
|
|
5c836ad08e | ||
|
|
673fd9b7cd | ||
|
|
84b24b233d | ||
|
|
499cdd7822 | ||
|
|
800d4cf111 | ||
|
|
b6d43f5fd9 | ||
|
|
3603cd5034 | ||
|
|
6df7893173 | ||
|
|
e64b599276 | ||
|
|
2dd59c4ba1 | ||
|
|
166986d5e6 | ||
|
|
a6aec68f32 | ||
|
|
ed27a127d5 | ||
|
|
d8b4ea7564 | ||
|
|
f0a2ef96b4 | ||
|
|
7d73c2c803 | ||
|
|
e8d2ecab03 | ||
|
|
32a374d094 | ||
|
|
d45c013806 | ||
|
|
9000a7083d | ||
|
|
8307555d54 | ||
|
|
20f2aece08 | ||
|
|
43eb4f9a1d | ||
|
|
5461b71d8c | ||
|
|
374db0ebb8 | ||
|
|
cea1f6f87c | ||
|
|
6c0e39372b | ||
|
|
2bec67d2b6 | ||
|
|
133e715832 | ||
|
|
95cf2f16e2 | ||
|
|
47a4c153eb | ||
|
|
faf5ae3533 | ||
|
|
a44dccecac | ||
|
|
9cf9358b9c | ||
|
|
de252fef31 | ||
|
|
9076bc27b8 | ||
|
|
50686c0819 | ||
|
|
1614203786 | ||
|
|
3d4c75a56c | ||
|
|
2684ee71dc | ||
|
|
1d321953ba | ||
|
|
b3cb251369 | ||
|
|
0a17d2c9d8 | ||
|
|
e3defbca84 | ||
|
|
e407f63977 | ||
|
|
7add391b2c | ||
|
|
efd6373b32 | ||
|
|
d502fa24b0 | ||
|
|
258a9a5c7f | ||
|
|
5d41ac6115 | ||
|
|
2a0fdb49b8 | ||
|
|
9d1b7231b6 | ||
|
|
ed3095b478 | ||
|
|
88eca75917 | ||
|
|
42de27e16a | ||
|
|
c083bda5b7 | ||
|
|
e86da38726 | ||
|
|
99076e38bc | ||
|
|
9698c1a02c | ||
|
|
851f0f04c3 | ||
|
|
ae16d9d888 | ||
|
|
6e1af2eb0c | ||
|
|
7695dd0d50 | ||
|
|
c2065473ad | ||
|
|
5f3870564d | ||
|
|
c214b2e33e | ||
|
|
2420c5fd35 | ||
|
|
f48f526f0a | ||
|
|
5dd74982ba | ||
|
|
e07aaf52a7 | ||
|
|
30e5f12616 | ||
|
|
594427bf87 | ||
|
|
a97d3ada1c | ||
|
|
6217bb5638 | ||
|
|
2760e99e18 | ||
|
|
0544f96b79 | ||
|
|
2ebb29de65 | ||
|
|
43762d44c7 | ||
|
|
cdaf0c98be | ||
|
|
aa9a14a917 | ||
|
|
9efcc6d95c | ||
|
|
f3f5d91207 | ||
|
|
6070160959 | ||
|
|
43155d2811 | ||
|
|
d3f85678ec | ||
|
|
2a96d05b21 | ||
|
|
851e888535 | ||
|
|
90120d4dff | ||
|
|
8513471573 | ||
|
|
71e5f1774c | ||
|
|
870a443446 | ||
|
|
cefaa2a4cc | ||
|
|
ab72a2ab9d | ||
|
|
046d457d22 | ||
|
|
7fd0a30fee | ||
|
|
c2f35c8e73 | ||
|
|
573313f0b6 | ||
|
|
f7af6805fa | ||
|
|
966de3a399 | ||
|
|
8a75829f3a | ||
|
|
0f7e34b9e2 | ||
|
|
be0322b616 | ||
|
|
232a525a62 | ||
|
|
587ce65cf6 | ||
|
|
ccf6c8bfd7 | ||
|
|
c112956d2d | ||
|
|
b3970793cf | ||
|
|
727724990e | ||
|
|
530f6e4af5 | ||
|
|
2f224f5793 | ||
|
|
1b6272ce0e | ||
|
|
5259ace111 | ||
|
|
48ea5566e9 | ||
|
|
3f8b6c5bbd | ||
|
|
725b32e74f | ||
|
|
d0b71f393f | ||
|
|
8a92efdae3 | ||
|
|
019cdce2e8 | ||
|
|
b64aa54fac | ||
|
|
c0d040f9d4 | ||
|
|
32364320f8 | ||
|
|
34c71c072d | ||
|
|
6d2149c503 | ||
|
|
043b0bf69d | ||
|
|
9b07e392c6 | ||
|
|
e60fad8c73 | ||
|
|
19c1b182c3 | ||
|
|
49edea780c | ||
|
|
12ef5a1900 | ||
|
|
d21a134b2a | ||
|
|
1cd809aa41 | ||
|
|
e728449b8f | ||
|
|
d0c20b14d5 | ||
|
|
83b7ea5a59 | ||
|
|
0796a52df1 | ||
|
|
85b7ba0168 | ||
|
|
e117743d24 | ||
|
|
aec2291f04 | ||
|
|
335ae003ac | ||
|
|
71c7de9c84 | ||
|
|
1c5fec5565 | ||
|
|
99d439577d | ||
|
|
4f83086788 | ||
|
|
a13c527e39 | ||
|
|
90d9f27383 | ||
|
|
0db81c16cd | ||
|
|
e115e186b7 | ||
|
|
6546b29ef7 | ||
|
|
51255bdffa | ||
|
|
f77c4e38cb | ||
|
|
2a1a152073 | ||
|
|
7b9406a3ea | ||
|
|
c3fb949693 | ||
|
|
ed3f8dbfd6 | ||
|
|
42aa6db170 | ||
|
|
a6591d20ca | ||
|
|
c1bc2603a2 | ||
|
|
e595bbb5fb | ||
|
|
4a2cb914d7 | ||
|
|
b1c93fe178 | ||
|
|
0719458775 | ||
|
|
6a1dc895fb | ||
|
|
125c1f6f25 | ||
|
|
1ceaa7d709 | ||
|
|
dec3ee85fd | ||
|
|
d94a5176dc | ||
|
|
326783f7f1 | ||
|
|
e5a9ca8787 | ||
|
|
f2feccdbd0 | ||
|
|
246a077d64 | ||
|
|
3ba100ff25 | ||
|
|
1e3b571e72 | ||
|
|
b89e56e9c2 | ||
|
|
ed8a02e721 | ||
|
|
baa60b40d1 | ||
|
|
ef01d6997a | ||
|
|
3da5b44d7f | ||
|
|
8b4654921b | ||
|
|
cf1cbafa78 | ||
|
|
c96091744b | ||
|
|
711fb4a775 | ||
|
|
3b5a185e60 | ||
|
|
77ac013a74 | ||
|
|
b8e5728e6a | ||
|
|
d038319d8b | ||
|
|
c611d0f30f | ||
|
|
c17899662f | ||
|
|
c51d5320fa | ||
|
|
6fa9512a64 | ||
|
|
fddc61df5e | ||
|
|
53c58fa755 | ||
|
|
c69afb56e4 | ||
|
|
0fa8a9191f | ||
|
|
48dda1cb5b | ||
|
|
71ef4b7d4c | ||
|
|
ecab43e307 | ||
|
|
88ca09440d | ||
|
|
8e0ab4a28d | ||
|
|
9b8c5041dc | ||
|
|
74ffd7ec64 | ||
|
|
eb6f504789 | ||
|
|
91a026f38b | ||
|
|
595138a0a3 | ||
|
|
19df04095f | ||
|
|
8239bbb48f | ||
|
|
16ee9d0422 | ||
|
|
8a961f8ab3 | ||
|
|
558126c46e | ||
|
|
04c9684488 | ||
|
|
b744faa7e6 | ||
|
|
27b3a26e75 | ||
|
|
41d872504e | ||
|
|
963cd05273 | ||
|
|
09b6e67baf | ||
|
|
dafb2aacab | ||
|
|
a6c400cd4f | ||
|
|
c013e5ccce | ||
|
|
f25a1a3840 | ||
|
|
6497e17671 | ||
|
|
44369a8138 | ||
|
|
dfca00c21b | ||
|
|
637dab379e | ||
|
|
6fc57eb48e | ||
|
|
95a653993a | ||
|
|
af0959818d | ||
|
|
cf17c85607 | ||
|
|
a38bc0a3fc | ||
|
|
449983c937 |
11
.github/workflows/build-and-publish.yml
vendored
Normal file
11
.github/workflows/build-and-publish.yml
vendored
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
name: CI
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ main ]
|
||||||
|
pull_request:
|
||||||
|
branches: [ main ]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
uses: ./.github/workflows/build-reusable.yml
|
||||||
251
.github/workflows/build-reusable.yml
vendored
Normal file
251
.github/workflows/build-reusable.yml
vendored
Normal file
@@ -0,0 +1,251 @@
|
|||||||
|
name: Reusable Build
|
||||||
|
|
||||||
|
on:
|
||||||
|
workflow_call:
|
||||||
|
inputs:
|
||||||
|
ref:
|
||||||
|
description: 'Git ref to build'
|
||||||
|
required: false
|
||||||
|
type: string
|
||||||
|
default: ''
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
lint:
|
||||||
|
name: Lint and Format Check
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
ref: ${{ inputs.ref }}
|
||||||
|
|
||||||
|
- name: Setup Python
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: '3.11'
|
||||||
|
|
||||||
|
- name: Install uv
|
||||||
|
uses: astral-sh/setup-uv@v4
|
||||||
|
|
||||||
|
- name: Install ruff
|
||||||
|
run: |
|
||||||
|
uv tool install ruff
|
||||||
|
|
||||||
|
- name: Run ruff check
|
||||||
|
run: |
|
||||||
|
ruff check .
|
||||||
|
|
||||||
|
- name: Run ruff format check
|
||||||
|
run: |
|
||||||
|
ruff format --check .
|
||||||
|
|
||||||
|
build:
|
||||||
|
needs: lint
|
||||||
|
name: Build ${{ matrix.os }} Python ${{ matrix.python }}
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
include:
|
||||||
|
- os: ubuntu-22.04
|
||||||
|
python: '3.9'
|
||||||
|
- os: ubuntu-22.04
|
||||||
|
python: '3.10'
|
||||||
|
- os: ubuntu-22.04
|
||||||
|
python: '3.11'
|
||||||
|
- os: ubuntu-22.04
|
||||||
|
python: '3.12'
|
||||||
|
- os: ubuntu-22.04
|
||||||
|
python: '3.13'
|
||||||
|
- os: macos-latest
|
||||||
|
python: '3.9'
|
||||||
|
- os: macos-latest
|
||||||
|
python: '3.10'
|
||||||
|
- os: macos-latest
|
||||||
|
python: '3.11'
|
||||||
|
- os: macos-latest
|
||||||
|
python: '3.12'
|
||||||
|
- os: macos-latest
|
||||||
|
python: '3.13'
|
||||||
|
runs-on: ${{ matrix.os }}
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
ref: ${{ inputs.ref }}
|
||||||
|
submodules: recursive
|
||||||
|
|
||||||
|
- name: Setup Python
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python }}
|
||||||
|
|
||||||
|
- name: Install uv
|
||||||
|
uses: astral-sh/setup-uv@v4
|
||||||
|
|
||||||
|
- name: Install system dependencies (Ubuntu)
|
||||||
|
if: runner.os == 'Linux'
|
||||||
|
run: |
|
||||||
|
sudo apt-get update
|
||||||
|
sudo apt-get install -y libomp-dev libboost-all-dev protobuf-compiler libzmq3-dev \
|
||||||
|
pkg-config libopenblas-dev patchelf libabsl-dev libaio-dev libprotobuf-dev
|
||||||
|
|
||||||
|
# Install Intel MKL for DiskANN
|
||||||
|
wget -q https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh
|
||||||
|
sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
|
||||||
|
source /opt/intel/oneapi/setvars.sh
|
||||||
|
echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV
|
||||||
|
echo "LD_LIBRARY_PATH=/opt/intel/oneapi/mkl/latest/lib/intel64:$LD_LIBRARY_PATH" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
- name: Install system dependencies (macOS)
|
||||||
|
if: runner.os == 'macOS'
|
||||||
|
run: |
|
||||||
|
# Don't install LLVM, use system clang for better compatibility
|
||||||
|
brew install libomp boost protobuf zeromq
|
||||||
|
|
||||||
|
- name: Install build dependencies
|
||||||
|
run: |
|
||||||
|
uv pip install --system scikit-build-core numpy swig Cython pybind11
|
||||||
|
if [[ "$RUNNER_OS" == "Linux" ]]; then
|
||||||
|
uv pip install --system auditwheel
|
||||||
|
else
|
||||||
|
uv pip install --system delocate
|
||||||
|
fi
|
||||||
|
|
||||||
|
- name: Build packages
|
||||||
|
run: |
|
||||||
|
# Build core (platform independent)
|
||||||
|
if [[ "${{ matrix.os }}" == ubuntu-* ]]; then
|
||||||
|
cd packages/leann-core
|
||||||
|
uv build
|
||||||
|
cd ../..
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Build HNSW backend
|
||||||
|
cd packages/leann-backend-hnsw
|
||||||
|
if [ "${{ matrix.os }}" == "macos-latest" ]; then
|
||||||
|
# Use system clang instead of homebrew LLVM for better compatibility
|
||||||
|
export CC=clang
|
||||||
|
export CXX=clang++
|
||||||
|
export MACOSX_DEPLOYMENT_TARGET=11.0
|
||||||
|
uv build --wheel --python python
|
||||||
|
else
|
||||||
|
uv build --wheel --python python
|
||||||
|
fi
|
||||||
|
cd ../..
|
||||||
|
|
||||||
|
# Build DiskANN backend
|
||||||
|
cd packages/leann-backend-diskann
|
||||||
|
if [ "${{ matrix.os }}" == "macos-latest" ]; then
|
||||||
|
# Use system clang instead of homebrew LLVM for better compatibility
|
||||||
|
export CC=clang
|
||||||
|
export CXX=clang++
|
||||||
|
# DiskANN requires macOS 13.3+ for sgesdd_ LAPACK function
|
||||||
|
export MACOSX_DEPLOYMENT_TARGET=13.3
|
||||||
|
uv build --wheel --python python
|
||||||
|
else
|
||||||
|
uv build --wheel --python python
|
||||||
|
fi
|
||||||
|
cd ../..
|
||||||
|
|
||||||
|
# Build meta package (platform independent)
|
||||||
|
if [[ "${{ matrix.os }}" == ubuntu-* ]]; then
|
||||||
|
cd packages/leann
|
||||||
|
uv build
|
||||||
|
cd ../..
|
||||||
|
fi
|
||||||
|
|
||||||
|
- name: Repair wheels (Linux)
|
||||||
|
if: runner.os == 'Linux'
|
||||||
|
run: |
|
||||||
|
# Repair HNSW wheel
|
||||||
|
cd packages/leann-backend-hnsw
|
||||||
|
if [ -d dist ]; then
|
||||||
|
auditwheel repair dist/*.whl -w dist_repaired
|
||||||
|
rm -rf dist
|
||||||
|
mv dist_repaired dist
|
||||||
|
fi
|
||||||
|
cd ../..
|
||||||
|
|
||||||
|
# Repair DiskANN wheel
|
||||||
|
cd packages/leann-backend-diskann
|
||||||
|
if [ -d dist ]; then
|
||||||
|
auditwheel repair dist/*.whl -w dist_repaired
|
||||||
|
rm -rf dist
|
||||||
|
mv dist_repaired dist
|
||||||
|
fi
|
||||||
|
cd ../..
|
||||||
|
|
||||||
|
- name: Repair wheels (macOS)
|
||||||
|
if: runner.os == 'macOS'
|
||||||
|
run: |
|
||||||
|
# Repair HNSW wheel
|
||||||
|
cd packages/leann-backend-hnsw
|
||||||
|
if [ -d dist ]; then
|
||||||
|
delocate-wheel -w dist_repaired -v dist/*.whl
|
||||||
|
rm -rf dist
|
||||||
|
mv dist_repaired dist
|
||||||
|
fi
|
||||||
|
cd ../..
|
||||||
|
|
||||||
|
# Repair DiskANN wheel
|
||||||
|
cd packages/leann-backend-diskann
|
||||||
|
if [ -d dist ]; then
|
||||||
|
delocate-wheel -w dist_repaired -v dist/*.whl
|
||||||
|
rm -rf dist
|
||||||
|
mv dist_repaired dist
|
||||||
|
fi
|
||||||
|
cd ../..
|
||||||
|
|
||||||
|
- name: List built packages
|
||||||
|
run: |
|
||||||
|
echo "📦 Built packages:"
|
||||||
|
find packages/*/dist -name "*.whl" -o -name "*.tar.gz" | sort
|
||||||
|
|
||||||
|
- name: Install built packages for testing
|
||||||
|
run: |
|
||||||
|
# Create a virtual environment
|
||||||
|
uv venv
|
||||||
|
source .venv/bin/activate || source .venv/Scripts/activate
|
||||||
|
|
||||||
|
# Install the built wheels
|
||||||
|
# Use --find-links to let uv choose the correct wheel for the platform
|
||||||
|
if [[ "${{ matrix.os }}" == ubuntu-* ]]; then
|
||||||
|
uv pip install leann-core --find-links packages/leann-core/dist
|
||||||
|
uv pip install leann --find-links packages/leann/dist
|
||||||
|
fi
|
||||||
|
uv pip install leann-backend-hnsw --find-links packages/leann-backend-hnsw/dist
|
||||||
|
uv pip install leann-backend-diskann --find-links packages/leann-backend-diskann/dist
|
||||||
|
|
||||||
|
# Install test dependencies using extras
|
||||||
|
uv pip install -e ".[test]"
|
||||||
|
|
||||||
|
- name: Run tests with pytest
|
||||||
|
env:
|
||||||
|
CI: true # Mark as CI environment to skip memory-intensive tests
|
||||||
|
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||||
|
HF_HUB_DISABLE_SYMLINKS: 1
|
||||||
|
TOKENIZERS_PARALLELISM: false
|
||||||
|
PYTORCH_ENABLE_MPS_FALLBACK: 0 # Disable MPS on macOS CI to avoid memory issues
|
||||||
|
OMP_NUM_THREADS: 1 # Disable OpenMP parallelism to avoid libomp crashes
|
||||||
|
MKL_NUM_THREADS: 1 # Single thread for MKL operations
|
||||||
|
run: |
|
||||||
|
# Activate virtual environment
|
||||||
|
source .venv/bin/activate || source .venv/Scripts/activate
|
||||||
|
|
||||||
|
# Run all tests
|
||||||
|
pytest tests/
|
||||||
|
|
||||||
|
- name: Run sanity checks (optional)
|
||||||
|
run: |
|
||||||
|
# Activate virtual environment
|
||||||
|
source .venv/bin/activate || source .venv/Scripts/activate
|
||||||
|
|
||||||
|
# Run distance function tests if available
|
||||||
|
if [ -f test/sanity_checks/test_distance_functions.py ]; then
|
||||||
|
echo "Running distance function sanity checks..."
|
||||||
|
python test/sanity_checks/test_distance_functions.py || echo "⚠️ Distance function test failed, continuing..."
|
||||||
|
fi
|
||||||
|
|
||||||
|
- name: Upload artifacts
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: packages-${{ matrix.os }}-py${{ matrix.python }}
|
||||||
|
path: packages/*/dist/
|
||||||
19
.github/workflows/link-check.yml
vendored
Normal file
19
.github/workflows/link-check.yml
vendored
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
name: Link Check
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ main, master ]
|
||||||
|
pull_request:
|
||||||
|
schedule:
|
||||||
|
- cron: "0 3 * * 1"
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
link-check:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: lycheeverse/lychee-action@v2
|
||||||
|
with:
|
||||||
|
args: --no-progress --insecure README.md docs/ apps/ examples/ benchmarks/
|
||||||
|
env:
|
||||||
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
129
.github/workflows/release-manual.yml
vendored
Normal file
129
.github/workflows/release-manual.yml
vendored
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
name: Release
|
||||||
|
|
||||||
|
on:
|
||||||
|
workflow_dispatch:
|
||||||
|
inputs:
|
||||||
|
version:
|
||||||
|
description: 'Version to release (e.g., 0.1.2)'
|
||||||
|
required: true
|
||||||
|
type: string
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
update-version:
|
||||||
|
name: Update Version
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
permissions:
|
||||||
|
contents: write
|
||||||
|
outputs:
|
||||||
|
commit-sha: ${{ steps.push.outputs.commit-sha }}
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Validate version
|
||||||
|
run: |
|
||||||
|
# Remove 'v' prefix if present for validation
|
||||||
|
VERSION_CLEAN="${{ inputs.version }}"
|
||||||
|
VERSION_CLEAN="${VERSION_CLEAN#v}"
|
||||||
|
if ! [[ "$VERSION_CLEAN" =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
|
||||||
|
echo "❌ Invalid version format. Expected format: X.Y.Z or vX.Y.Z"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
echo "✅ Version format valid: ${{ inputs.version }}"
|
||||||
|
|
||||||
|
- name: Update versions and push
|
||||||
|
id: push
|
||||||
|
run: |
|
||||||
|
# Check current version
|
||||||
|
CURRENT_VERSION=$(grep "^version" packages/leann-core/pyproject.toml | cut -d'"' -f2)
|
||||||
|
echo "Current version: $CURRENT_VERSION"
|
||||||
|
echo "Target version: ${{ inputs.version }}"
|
||||||
|
|
||||||
|
if [ "$CURRENT_VERSION" = "${{ inputs.version }}" ]; then
|
||||||
|
echo "⚠️ Version is already ${{ inputs.version }}, skipping update"
|
||||||
|
COMMIT_SHA=$(git rev-parse HEAD)
|
||||||
|
else
|
||||||
|
./scripts/bump_version.sh ${{ inputs.version }}
|
||||||
|
git config user.name "GitHub Actions"
|
||||||
|
git config user.email "actions@github.com"
|
||||||
|
git add packages/*/pyproject.toml
|
||||||
|
git commit -m "chore: release v${{ inputs.version }}"
|
||||||
|
git push origin main
|
||||||
|
COMMIT_SHA=$(git rev-parse HEAD)
|
||||||
|
echo "✅ Pushed version update: $COMMIT_SHA"
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "commit-sha=$COMMIT_SHA" >> $GITHUB_OUTPUT
|
||||||
|
|
||||||
|
build-packages:
|
||||||
|
name: Build packages
|
||||||
|
needs: update-version
|
||||||
|
uses: ./.github/workflows/build-reusable.yml
|
||||||
|
with:
|
||||||
|
ref: 'main'
|
||||||
|
|
||||||
|
publish:
|
||||||
|
name: Publish and Release
|
||||||
|
needs: [update-version, build-packages]
|
||||||
|
if: always() && needs.update-version.result == 'success' && needs.build-packages.result == 'success'
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
permissions:
|
||||||
|
contents: write
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
ref: 'main'
|
||||||
|
|
||||||
|
- name: Download all artifacts
|
||||||
|
uses: actions/download-artifact@v4
|
||||||
|
with:
|
||||||
|
path: dist-artifacts
|
||||||
|
|
||||||
|
- name: Collect packages
|
||||||
|
run: |
|
||||||
|
mkdir -p dist
|
||||||
|
find dist-artifacts -name "*.whl" -exec cp {} dist/ \;
|
||||||
|
find dist-artifacts -name "*.tar.gz" -exec cp {} dist/ \;
|
||||||
|
|
||||||
|
echo "📦 Packages to publish:"
|
||||||
|
ls -la dist/
|
||||||
|
|
||||||
|
- name: Publish to PyPI
|
||||||
|
env:
|
||||||
|
TWINE_USERNAME: __token__
|
||||||
|
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
|
||||||
|
run: |
|
||||||
|
if [ -z "$TWINE_PASSWORD" ]; then
|
||||||
|
echo "❌ PYPI_API_TOKEN not configured!"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
pip install twine
|
||||||
|
twine upload dist/* --skip-existing --verbose
|
||||||
|
|
||||||
|
echo "✅ Published to PyPI!"
|
||||||
|
|
||||||
|
- name: Create release
|
||||||
|
run: |
|
||||||
|
# Check if tag already exists
|
||||||
|
if git rev-parse "v${{ inputs.version }}" >/dev/null 2>&1; then
|
||||||
|
echo "⚠️ Tag v${{ inputs.version }} already exists, skipping tag creation"
|
||||||
|
else
|
||||||
|
git tag "v${{ inputs.version }}"
|
||||||
|
git push origin "v${{ inputs.version }}"
|
||||||
|
echo "✅ Created and pushed tag v${{ inputs.version }}"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Check if release already exists
|
||||||
|
if gh release view "v${{ inputs.version }}" >/dev/null 2>&1; then
|
||||||
|
echo "⚠️ Release v${{ inputs.version }} already exists, skipping release creation"
|
||||||
|
else
|
||||||
|
gh release create "v${{ inputs.version }}" \
|
||||||
|
--title "Release v${{ inputs.version }}" \
|
||||||
|
--notes "🚀 Released to PyPI: https://pypi.org/project/leann/${{ inputs.version }}/" \
|
||||||
|
--latest
|
||||||
|
echo "✅ Created GitHub release v${{ inputs.version }}"
|
||||||
|
fi
|
||||||
|
env:
|
||||||
|
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
25
.gitignore
vendored
25
.gitignore
vendored
@@ -8,11 +8,16 @@ demo/indices/
|
|||||||
*pycache*
|
*pycache*
|
||||||
outputs/
|
outputs/
|
||||||
*.pkl
|
*.pkl
|
||||||
|
*.pdf
|
||||||
|
*.idx
|
||||||
|
*.map
|
||||||
.history/
|
.history/
|
||||||
scripts/
|
|
||||||
lm_eval.egg-info/
|
lm_eval.egg-info/
|
||||||
demo/experiment_results/**/*.json
|
demo/experiment_results/**/*.json
|
||||||
*.jsonl
|
*.jsonl
|
||||||
|
*.eml
|
||||||
|
*.emlx
|
||||||
|
*.json
|
||||||
*.sh
|
*.sh
|
||||||
*.txt
|
*.txt
|
||||||
!CMakeLists.txt
|
!CMakeLists.txt
|
||||||
@@ -29,7 +34,15 @@ build/
|
|||||||
nprobe_logs/
|
nprobe_logs/
|
||||||
micro/results
|
micro/results
|
||||||
micro/contriever-INT8
|
micro/contriever-INT8
|
||||||
examples/data/
|
data/*
|
||||||
|
!data/2501.14312v1 (1).pdf
|
||||||
|
!data/2506.08276v1.pdf
|
||||||
|
!data/PrideandPrejudice.txt
|
||||||
|
!data/README.md
|
||||||
|
!data/ground_truth/
|
||||||
|
!data/indices/
|
||||||
|
!data/queries/
|
||||||
|
!data/.gitattributes
|
||||||
*.qdstrm
|
*.qdstrm
|
||||||
benchmark_results/
|
benchmark_results/
|
||||||
results/
|
results/
|
||||||
@@ -42,6 +55,7 @@ embedding_comparison_results/
|
|||||||
*.ivecs
|
*.ivecs
|
||||||
*.index
|
*.index
|
||||||
*.bin
|
*.bin
|
||||||
|
*.old
|
||||||
|
|
||||||
read_graph
|
read_graph
|
||||||
analyze_diskann_graph
|
analyze_diskann_graph
|
||||||
@@ -71,3 +85,10 @@ test_indices*/
|
|||||||
test_*.py
|
test_*.py
|
||||||
!tests/**
|
!tests/**
|
||||||
packages/leann-backend-diskann/third_party/DiskANN/_deps/
|
packages/leann-backend-diskann/third_party/DiskANN/_deps/
|
||||||
|
|
||||||
|
*.meta.json
|
||||||
|
*.passages.json
|
||||||
|
|
||||||
|
batchtest.py
|
||||||
|
tests/__pytest_cache__/
|
||||||
|
tests/__pycache__/
|
||||||
|
|||||||
14
.gitmodules
vendored
14
.gitmodules
vendored
@@ -1,6 +1,16 @@
|
|||||||
[submodule "packages/leann-backend-diskann/third_party/DiskANN"]
|
[submodule "packages/leann-backend-diskann/third_party/DiskANN"]
|
||||||
path = packages/leann-backend-diskann/third_party/DiskANN
|
path = packages/leann-backend-diskann/third_party/DiskANN
|
||||||
url = https://github.com/yichuan520030910320/DiskANN.git
|
url = https://github.com/yichuan-w/DiskANN.git
|
||||||
[submodule "packages/leann-backend-hnsw/third_party/faiss"]
|
[submodule "packages/leann-backend-hnsw/third_party/faiss"]
|
||||||
path = packages/leann-backend-hnsw/third_party/faiss
|
path = packages/leann-backend-hnsw/third_party/faiss
|
||||||
url = https://github.com/yichuan520030910320/faiss.git
|
url = https://github.com/yichuan-w/faiss.git
|
||||||
|
[submodule "packages/leann-backend-hnsw/third_party/msgpack-c"]
|
||||||
|
path = packages/leann-backend-hnsw/third_party/msgpack-c
|
||||||
|
url = https://github.com/msgpack/msgpack-c.git
|
||||||
|
branch = cpp_master
|
||||||
|
[submodule "packages/leann-backend-hnsw/third_party/cppzmq"]
|
||||||
|
path = packages/leann-backend-hnsw/third_party/cppzmq
|
||||||
|
url = https://github.com/zeromq/cppzmq.git
|
||||||
|
[submodule "packages/leann-backend-hnsw/third_party/libzmq"]
|
||||||
|
path = packages/leann-backend-hnsw/third_party/libzmq
|
||||||
|
url = https://github.com/zeromq/libzmq.git
|
||||||
|
|||||||
16
.pre-commit-config.yaml
Normal file
16
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
repos:
|
||||||
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
|
rev: v4.5.0
|
||||||
|
hooks:
|
||||||
|
- id: trailing-whitespace
|
||||||
|
- id: end-of-file-fixer
|
||||||
|
- id: check-yaml
|
||||||
|
- id: check-added-large-files
|
||||||
|
- id: check-merge-conflict
|
||||||
|
- id: debug-statements
|
||||||
|
|
||||||
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
|
rev: v0.2.1
|
||||||
|
hooks:
|
||||||
|
- id: ruff
|
||||||
|
- id: ruff-format
|
||||||
9
.vscode/extensions.json
vendored
9
.vscode/extensions.json
vendored
@@ -1,9 +0,0 @@
|
|||||||
{
|
|
||||||
"recommendations": [
|
|
||||||
"llvm-vs-code-extensions.vscode-clangd",
|
|
||||||
"ms-python.python",
|
|
||||||
"ms-vscode.cmake-tools",
|
|
||||||
"vadimcn.vscode-lldb",
|
|
||||||
"eamodio.gitlens",
|
|
||||||
]
|
|
||||||
}
|
|
||||||
283
.vscode/launch.json
vendored
283
.vscode/launch.json
vendored
@@ -1,283 +0,0 @@
|
|||||||
{
|
|
||||||
// Use IntelliSense to learn about possible attributes.
|
|
||||||
// Hover to view descriptions of existing attributes.
|
|
||||||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
|
||||||
"version": "0.2.0",
|
|
||||||
"configurations": [
|
|
||||||
// new emdedder
|
|
||||||
{
|
|
||||||
"name": "New Embedder",
|
|
||||||
"type": "debugpy",
|
|
||||||
"request": "launch",
|
|
||||||
"program": "demo/main.py",
|
|
||||||
"console": "integratedTerminal",
|
|
||||||
"args": [
|
|
||||||
"--search",
|
|
||||||
"--use-original",
|
|
||||||
"--domain",
|
|
||||||
"dpr",
|
|
||||||
"--nprobe",
|
|
||||||
"5000",
|
|
||||||
"--load",
|
|
||||||
"flat",
|
|
||||||
"--embedder",
|
|
||||||
"intfloat/multilingual-e5-small"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
//python /home/ubuntu/Power-RAG/faiss/demo/simple_build.py
|
|
||||||
{
|
|
||||||
"name": "main.py",
|
|
||||||
"type": "debugpy",
|
|
||||||
"request": "launch",
|
|
||||||
"program": "demo/main.py",
|
|
||||||
"console": "integratedTerminal",
|
|
||||||
"cwd": "${workspaceFolder}",
|
|
||||||
"args": [
|
|
||||||
"--query",
|
|
||||||
"1000",
|
|
||||||
"--load",
|
|
||||||
"bm25"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Simple Build",
|
|
||||||
"type": "lldb",
|
|
||||||
"request": "launch",
|
|
||||||
"program": "${workspaceFolder}/.venv/bin/python",
|
|
||||||
"console": "integratedTerminal",
|
|
||||||
"cwd": "${workspaceFolder}",
|
|
||||||
"args": [
|
|
||||||
"faiss/demo/simple_build.py"
|
|
||||||
],
|
|
||||||
"env": {
|
|
||||||
"LD_PRELOAD": "/lib/x86_64-linux-gnu/libmkl_core.so:/lib/x86_64-linux-gnu/libmkl_intel_thread.so:/lib/x86_64-linux-gnu/libmkl_intel_lp64.so:/lib/x86_64-linux-gnu/libiomp5.so"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
//# Fix for Intel MKL error
|
|
||||||
//export LD_PRELOAD=/lib/x86_64-linux-gnu/libmkl_core.so:/lib/x86_64-linux-gnu/libmkl_intel_thread.so:/lib/x86_64-linux-gnu/libmkl_intel_lp64.so:/lib/x86_64-linux-gnu/libiomp5.so
|
|
||||||
//python faiss/demo/build_demo.py
|
|
||||||
{
|
|
||||||
"name": "Build Demo",
|
|
||||||
"type": "lldb",
|
|
||||||
"request": "launch",
|
|
||||||
"program": "${workspaceFolder}/.venv/bin/python",
|
|
||||||
"console": "integratedTerminal",
|
|
||||||
"cwd": "${workspaceFolder}",
|
|
||||||
"args": [
|
|
||||||
"faiss/demo/build_demo.py"
|
|
||||||
],
|
|
||||||
"env": {
|
|
||||||
"LD_PRELOAD": "/lib/x86_64-linux-gnu/libmkl_core.so:/lib/x86_64-linux-gnu/libmkl_intel_thread.so:/lib/x86_64-linux-gnu/libmkl_intel_lp64.so:/lib/x86_64-linux-gnu/libiomp5.so"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "DiskANN Serve",
|
|
||||||
"type": "lldb",
|
|
||||||
"request": "launch",
|
|
||||||
"program": "${workspaceFolder}/.venv/bin/python",
|
|
||||||
"console": "integratedTerminal",
|
|
||||||
"cwd": "${workspaceFolder}",
|
|
||||||
"args": [
|
|
||||||
"demo/main.py",
|
|
||||||
"--mode",
|
|
||||||
"serve",
|
|
||||||
"--engine",
|
|
||||||
"sglang",
|
|
||||||
"--load-indices",
|
|
||||||
"diskann",
|
|
||||||
"--domain",
|
|
||||||
"rpj_wiki",
|
|
||||||
"--lazy-load",
|
|
||||||
"--recompute-beighbor-embeddings",
|
|
||||||
"--port",
|
|
||||||
"8082",
|
|
||||||
"--diskann-search-memory-maximum",
|
|
||||||
"2",
|
|
||||||
"--diskann-graph",
|
|
||||||
"240",
|
|
||||||
"--search-only"
|
|
||||||
],
|
|
||||||
"env": {
|
|
||||||
"PYTHONPATH": "${workspaceFolder}/faiss_repo/build/faiss/python:$PYTHONPATH"
|
|
||||||
},
|
|
||||||
"preLaunchTask": "CMake: build",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "DiskANN Serve MAC",
|
|
||||||
"type": "lldb",
|
|
||||||
"request": "launch",
|
|
||||||
"program": "${workspaceFolder}/.venv/bin/python",
|
|
||||||
"console": "integratedTerminal",
|
|
||||||
"cwd": "${workspaceFolder}",
|
|
||||||
"args": [
|
|
||||||
"demo/main.py",
|
|
||||||
"--mode",
|
|
||||||
"serve",
|
|
||||||
"--engine",
|
|
||||||
"ollama",
|
|
||||||
"--load-indices",
|
|
||||||
"diskann",
|
|
||||||
"--domain",
|
|
||||||
"rpj_wiki",
|
|
||||||
"--lazy-load",
|
|
||||||
"--recompute-beighbor-embeddings"
|
|
||||||
],
|
|
||||||
"preLaunchTask": "CMake: build",
|
|
||||||
"env": {
|
|
||||||
"KMP_DUPLICATE_LIB_OK": "TRUE",
|
|
||||||
"OMP_NUM_THREADS": "1",
|
|
||||||
"MKL_NUM_THREADS": "1",
|
|
||||||
"DYLD_INSERT_LIBRARIES": "/Users/ec2-user/Power-RAG/.venv/lib/python3.10/site-packages/torch/lib/libomp.dylib",
|
|
||||||
"KMP_BLOCKTIME": "0"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Python Debugger: Current File with Arguments",
|
|
||||||
"type": "debugpy",
|
|
||||||
"request": "launch",
|
|
||||||
"program": "ric/main_ric.py",
|
|
||||||
"console": "integratedTerminal",
|
|
||||||
"cwd": "${workspaceFolder}",
|
|
||||||
"args": [
|
|
||||||
"--config-name",
|
|
||||||
"${input:configSelection}"
|
|
||||||
],
|
|
||||||
"justMyCode": false
|
|
||||||
},
|
|
||||||
//python ./demo/validate_equivalence.py sglang
|
|
||||||
{
|
|
||||||
"name": "Validate Equivalence",
|
|
||||||
"type": "debugpy",
|
|
||||||
"request": "launch",
|
|
||||||
"program": "demo/validate_equivalence.py",
|
|
||||||
"console": "integratedTerminal",
|
|
||||||
"args": [
|
|
||||||
"sglang"
|
|
||||||
],
|
|
||||||
},
|
|
||||||
//python demo/retrieval_demo.py --engine sglang --skip-embeddings --domain dpr --load-indices flat ivf_flat
|
|
||||||
{
|
|
||||||
"name": "Retrieval Demo",
|
|
||||||
"type": "debugpy",
|
|
||||||
"request": "launch",
|
|
||||||
"program": "demo/retrieval_demo.py",
|
|
||||||
"console": "integratedTerminal",
|
|
||||||
"args": [
|
|
||||||
"--engine",
|
|
||||||
"vllm",
|
|
||||||
"--skip-embeddings",
|
|
||||||
"--domain",
|
|
||||||
"dpr",
|
|
||||||
"--load-indices",
|
|
||||||
// "flat",
|
|
||||||
"ivf_flat"
|
|
||||||
],
|
|
||||||
},
|
|
||||||
//python demo/retrieval_demo.py --engine sglang --skip-embeddings --domain dpr --load-indices diskann --hnsw-M 64 --hnsw-efConstruction 150 --hnsw-efSearch 128 --hnsw-sq-bits 8
|
|
||||||
{
|
|
||||||
"name": "Retrieval Demo DiskANN",
|
|
||||||
"type": "debugpy",
|
|
||||||
"request": "launch",
|
|
||||||
"program": "demo/retrieval_demo.py",
|
|
||||||
"console": "integratedTerminal",
|
|
||||||
"args": [
|
|
||||||
"--engine",
|
|
||||||
"sglang",
|
|
||||||
"--skip-embeddings",
|
|
||||||
"--domain",
|
|
||||||
"dpr",
|
|
||||||
"--load-indices",
|
|
||||||
"diskann",
|
|
||||||
"--hnsw-M",
|
|
||||||
"64",
|
|
||||||
"--hnsw-efConstruction",
|
|
||||||
"150",
|
|
||||||
"--hnsw-efSearch",
|
|
||||||
"128",
|
|
||||||
"--hnsw-sq-bits",
|
|
||||||
"8"
|
|
||||||
],
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Find Probe",
|
|
||||||
"type": "debugpy",
|
|
||||||
"request": "launch",
|
|
||||||
"program": "find_probe.py",
|
|
||||||
"console": "integratedTerminal",
|
|
||||||
"cwd": "${workspaceFolder}",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Python: Attach",
|
|
||||||
"type": "debugpy",
|
|
||||||
"request": "attach",
|
|
||||||
"processId": "${command:pickProcess}",
|
|
||||||
"justMyCode": true
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Edge RAG",
|
|
||||||
"type": "lldb",
|
|
||||||
"request": "launch",
|
|
||||||
"program": "${workspaceFolder}/.venv/bin/python",
|
|
||||||
"console": "integratedTerminal",
|
|
||||||
"cwd": "${workspaceFolder}",
|
|
||||||
"args": [
|
|
||||||
"edgerag_demo.py"
|
|
||||||
],
|
|
||||||
"env": {
|
|
||||||
"LD_PRELOAD": "/lib/x86_64-linux-gnu/libiomp5.so /lib/x86_64-linux-gnu/libmkl_core.so /lib/x86_64-linux-gnu/libmkl_intel_lp64.so /lib/x86_64-linux-gnu/libmkl_intel_thread.so",
|
|
||||||
"MKL_NUM_THREADS": "1",
|
|
||||||
"OMP_NUM_THREADS": "1",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Launch Embedding Server",
|
|
||||||
"type": "debugpy",
|
|
||||||
"request": "launch",
|
|
||||||
"program": "demo/embedding_server.py",
|
|
||||||
"console": "integratedTerminal",
|
|
||||||
"cwd": "${workspaceFolder}",
|
|
||||||
"args": [
|
|
||||||
"--domain",
|
|
||||||
"rpj_wiki",
|
|
||||||
"--zmq-port",
|
|
||||||
"5556",
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "HNSW Serve",
|
|
||||||
"type": "lldb",
|
|
||||||
"request": "launch",
|
|
||||||
"program": "${workspaceFolder}/.venv/bin/python",
|
|
||||||
"console": "integratedTerminal",
|
|
||||||
"cwd": "${workspaceFolder}",
|
|
||||||
"args": [
|
|
||||||
"demo/main.py",
|
|
||||||
"--domain",
|
|
||||||
"rpj_wiki",
|
|
||||||
"--load",
|
|
||||||
"hnsw",
|
|
||||||
"--mode",
|
|
||||||
"serve",
|
|
||||||
"--search",
|
|
||||||
"--skip-pa",
|
|
||||||
"--recompute",
|
|
||||||
"--hnsw-old"
|
|
||||||
],
|
|
||||||
"env": {
|
|
||||||
"LD_PRELOAD": "/lib/x86_64-linux-gnu/libmkl_core.so:/lib/x86_64-linux-gnu/libmkl_intel_thread.so:/lib/x86_64-linux-gnu/libmkl_intel_lp64.so:/lib/x86_64-linux-gnu/libiomp5.so"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"inputs": [
|
|
||||||
{
|
|
||||||
"id": "configSelection",
|
|
||||||
"type": "pickString",
|
|
||||||
"description": "Select a configuration",
|
|
||||||
"options": [
|
|
||||||
"example_config",
|
|
||||||
"vllm_gritlm"
|
|
||||||
],
|
|
||||||
"default": "example_config"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
43
.vscode/settings.json
vendored
43
.vscode/settings.json
vendored
@@ -1,43 +0,0 @@
|
|||||||
{
|
|
||||||
"python.analysis.extraPaths": [
|
|
||||||
"./sglang_repo/python"
|
|
||||||
],
|
|
||||||
"cmake.sourceDirectory": "${workspaceFolder}/DiskANN",
|
|
||||||
"cmake.configureArgs": [
|
|
||||||
"-DPYBIND=True",
|
|
||||||
"-DUPDATE_EDITABLE_INSTALL=ON",
|
|
||||||
],
|
|
||||||
"cmake.environment": {
|
|
||||||
"PATH": "/Users/ec2-user/Power-RAG/.venv/bin:${env:PATH}"
|
|
||||||
},
|
|
||||||
"cmake.buildDirectory": "${workspaceFolder}/build",
|
|
||||||
"files.associations": {
|
|
||||||
"*.tcc": "cpp",
|
|
||||||
"deque": "cpp",
|
|
||||||
"string": "cpp",
|
|
||||||
"unordered_map": "cpp",
|
|
||||||
"vector": "cpp",
|
|
||||||
"map": "cpp",
|
|
||||||
"unordered_set": "cpp",
|
|
||||||
"atomic": "cpp",
|
|
||||||
"inplace_vector": "cpp",
|
|
||||||
"*.ipp": "cpp",
|
|
||||||
"forward_list": "cpp",
|
|
||||||
"list": "cpp",
|
|
||||||
"any": "cpp",
|
|
||||||
"system_error": "cpp",
|
|
||||||
"__hash_table": "cpp",
|
|
||||||
"__split_buffer": "cpp",
|
|
||||||
"__tree": "cpp",
|
|
||||||
"ios": "cpp",
|
|
||||||
"set": "cpp",
|
|
||||||
"__string": "cpp",
|
|
||||||
"string_view": "cpp",
|
|
||||||
"ranges": "cpp",
|
|
||||||
"iosfwd": "cpp"
|
|
||||||
},
|
|
||||||
"lldb.displayFormat": "auto",
|
|
||||||
"lldb.showDisassembly": "auto",
|
|
||||||
"lldb.dereferencePointers": true,
|
|
||||||
"lldb.consoleMode": "commands",
|
|
||||||
}
|
|
||||||
16
.vscode/tasks.json
vendored
16
.vscode/tasks.json
vendored
@@ -1,16 +0,0 @@
|
|||||||
{
|
|
||||||
"version": "2.0.0",
|
|
||||||
"tasks": [
|
|
||||||
{
|
|
||||||
"type": "cmake",
|
|
||||||
"label": "CMake: build",
|
|
||||||
"command": "build",
|
|
||||||
"targets": [
|
|
||||||
"all"
|
|
||||||
],
|
|
||||||
"group": "build",
|
|
||||||
"problemMatcher": [],
|
|
||||||
"detail": "CMake template build task"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
2
LICENSE
2
LICENSE
@@ -1,6 +1,6 @@
|
|||||||
MIT License
|
MIT License
|
||||||
|
|
||||||
Copyright (c) 2024 Rulin Shao
|
Copyright (c) 2025 LEANN Contributors
|
||||||
|
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
|||||||
717
README.md
717
README.md
@@ -1,172 +1,544 @@
|
|||||||
# 🚀 LEANN: A Low-Storage Vector Index
|
<p align="center">
|
||||||
|
<img src="assets/logo-text.png" alt="LEANN Logo" width="400">
|
||||||
|
</p>
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<img src="https://img.shields.io/badge/Python-3.9%2B-blue.svg" alt="Python 3.9+">
|
<img src="https://img.shields.io/badge/Python-3.9%2B-blue.svg" alt="Python 3.9+">
|
||||||
<img src="https://img.shields.io/badge/License-MIT-green.svg" alt="MIT License">
|
<img src="https://img.shields.io/badge/License-MIT-green.svg" alt="MIT License">
|
||||||
<img src="https://img.shields.io/badge/PRs-welcome-brightgreen.svg" alt="PRs Welcome">
|
<img src="https://img.shields.io/badge/Platform-Linux%20%7C%20macOS-lightgrey" alt="Platform">
|
||||||
<img src="https://img.shields.io/badge/Platform-Linux%20%7C%20macOS%20%7C%20Windows-lightgrey" alt="Platform">
|
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
|
<h2 align="center" tabindex="-1" class="heading-element" dir="auto">
|
||||||
|
The smallest vector index in the world. RAG Everything with LEANN!
|
||||||
|
</h2>
|
||||||
|
|
||||||
|
LEANN is an innovative vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **97% less storage** than traditional solutions **without accuracy loss**.
|
||||||
|
|
||||||
|
LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration Fig →](#️-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
|
||||||
|
|
||||||
|
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)**, or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## Why LEANN?
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<strong>⚡ Real-time embedding computation for large-scale RAG on consumer hardware</strong>
|
<img src="assets/effects.png" alt="LEANN vs Traditional Vector DB Storage Comparison" width="70%">
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
<p align="center">
|
> **The numbers speak for themselves:** Index 60 million Wikipedia chunks in just 6GB instead of 201GB. From emails to browser history, everything fits on your laptop. [See detailed benchmarks for different applications below ↓](#storage-comparison)
|
||||||
<a href="#-quick-start">Quick Start</a> •
|
|
||||||
<a href="#-features">Features</a> •
|
|
||||||
<a href="#-benchmarks">Benchmarks</a> •
|
|
||||||
<a href="#-documentation">Documentation</a> •
|
|
||||||
<a href="#-paper">Paper</a>
|
|
||||||
</p>
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 🌟 What is Leann?
|
🔒 **Privacy:** Your data never leaves your laptop. No OpenAI, no cloud, no "terms of service".
|
||||||
|
|
||||||
**Leann** revolutionizes Retrieval-Augmented Generation (RAG) by eliminating the storage bottleneck of traditional vector databases. Instead of pre-computing and storing billions of embeddings, Leann dynamically computes embeddings at query time using highly optimized graph-based search algorithms.
|
🪶 **Lightweight:** Graph-based recomputation eliminates heavy embedding storage, while smart graph pruning and CSR format minimize graph storage overhead. Always less storage, less memory usage!
|
||||||
|
|
||||||
### 🎯 Why Leann?
|
📦 **Portable:** Transfer your entire knowledge base between devices (even with others) with minimal cost - your personal AI memory travels with you.
|
||||||
|
|
||||||
Traditional RAG systems face a fundamental trade-off:
|
📈 **Scalability:** Handle messy personal data that would crash traditional vector DBs, easily managing your growing personalized data and agent generated memory!
|
||||||
- **💾 Storage**: Storing embeddings for millions of documents requires massive disk space
|
|
||||||
- **🔄 Freshness**: Pre-computed embeddings become stale when documents change
|
|
||||||
- **💰 Cost**: Vector databases are expensive to scale
|
|
||||||
|
|
||||||
**Leann solves this by:**
|
✨ **No Accuracy Loss:** Maintain the same search quality as heavyweight solutions while using 97% less storage.
|
||||||
- ✅ **Zero embedding storage** - Only graph structure is persisted
|
|
||||||
- ✅ **Real-time computation** - Embeddings computed on-demand with ms latency
|
|
||||||
- ✅ **Memory efficient** - Runs on consumer hardware (8GB RAM)
|
|
||||||
- ✅ **Always fresh** - No stale embeddings, ever
|
|
||||||
|
|
||||||
## 🚀 Quick Start
|
## Installation
|
||||||
|
|
||||||
### Installation
|
### 📦 Prerequisites: Install uv
|
||||||
|
|
||||||
|
[Install uv](https://docs.astral.sh/uv/getting-started/installation/#installation-methods) first if you don't have it. Typically, you can install it with:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
git clone git@github.com:yichuan520030910320/LEANN-RAG.git leann
|
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||||
|
```
|
||||||
|
|
||||||
|
### 🚀 Quick Install
|
||||||
|
|
||||||
|
Clone the repository to access all examples and try amazing applications,
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/yichuan-w/LEANN.git leann
|
||||||
|
cd leann
|
||||||
|
```
|
||||||
|
|
||||||
|
and install LEANN from [PyPI](https://pypi.org/project/leann/) to run them immediately:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv venv
|
||||||
|
source .venv/bin/activate
|
||||||
|
uv pip install leann
|
||||||
|
```
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>
|
||||||
|
<strong>🔧 Build from Source (Recommended for development)</strong>
|
||||||
|
</summary>
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/yichuan-w/LEANN.git leann
|
||||||
cd leann
|
cd leann
|
||||||
git submodule update --init --recursive
|
git submodule update --init --recursive
|
||||||
|
```
|
||||||
|
|
||||||
|
**macOS:**
|
||||||
|
```bash
|
||||||
|
brew install llvm libomp boost protobuf zeromq pkgconf
|
||||||
|
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv sync
|
||||||
|
```
|
||||||
|
|
||||||
|
**Linux:**
|
||||||
|
```bash
|
||||||
|
sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler libabsl-dev libmkl-full-dev libaio-dev libzmq3-dev
|
||||||
uv sync
|
uv sync
|
||||||
```
|
```
|
||||||
|
|
||||||
### 30-Second Example
|
</details>
|
||||||
|
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
Our declarative API makes RAG as easy as writing a config file.
|
||||||
|
|
||||||
|
Check out [demo.ipynb](demo.ipynb) or [](https://colab.research.google.com/github/yichuan-w/LEANN/blob/main/demo.ipynb)
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from leann.api import LeannBuilder, LeannSearcher
|
from leann import LeannBuilder, LeannSearcher, LeannChat
|
||||||
|
from pathlib import Path
|
||||||
|
INDEX_PATH = str(Path("./").resolve() / "demo.leann")
|
||||||
|
|
||||||
# 1. Build index (no embeddings stored!)
|
# Build an index
|
||||||
builder = LeannBuilder(backend_name="diskann")
|
builder = LeannBuilder(backend_name="hnsw")
|
||||||
builder.add_text("Python is a powerful programming language")
|
builder.add_text("LEANN saves 97% storage compared to traditional vector databases.")
|
||||||
builder.add_text("Machine learning transforms industries")
|
builder.add_text("Tung Tung Tung Sahur called—they need their banana‑crocodile hybrid back")
|
||||||
builder.add_text("Neural networks process complex data")
|
builder.build_index(INDEX_PATH)
|
||||||
builder.build_index("knowledge.leann")
|
|
||||||
|
|
||||||
# 2. Search with real-time embeddings
|
# Search
|
||||||
searcher = LeannSearcher("knowledge.leann")
|
searcher = LeannSearcher(INDEX_PATH)
|
||||||
results = searcher.search("programming languages", top_k=2)
|
results = searcher.search("fantastical AI-generated creatures", top_k=1)
|
||||||
|
|
||||||
for result in results:
|
# Chat with your data
|
||||||
print(f"Score: {result['score']:.3f} - {result['text']}")
|
chat = LeannChat(INDEX_PATH, llm_config={"type": "hf", "model": "Qwen/Qwen3-0.6B"})
|
||||||
|
response = chat.ask("How much storage does LEANN save?", top_k=1)
|
||||||
```
|
```
|
||||||
|
|
||||||
### Run the Demo
|
## RAG on Everything!
|
||||||
|
|
||||||
|
LEANN supports RAG on various data sources including documents (`.pdf`, `.txt`, `.md`), Apple Mail, Google Search History, WeChat, and more.
|
||||||
|
|
||||||
|
### Generation Model Setup
|
||||||
|
|
||||||
|
LEANN supports multiple LLM providers for text generation (OpenAI API, HuggingFace, Ollama).
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><strong>🔑 OpenAI API Setup (Default)</strong></summary>
|
||||||
|
|
||||||
|
Set your OpenAI API key as an environment variable:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
uv run examples/document_search.py
|
export OPENAI_API_KEY="your-api-key-here"
|
||||||
```
|
```
|
||||||
|
|
||||||
**PDF RAG Demo (using LlamaIndex for document parsing and Leann for indexing/search)**
|
</details>
|
||||||
|
|
||||||
This demo showcases how to build a RAG system for PDF documents using Leann.
|
<details>
|
||||||
1. Place your PDF files (and other supported formats like .docx, .pptx, .xlsx) into the `examples/data/` directory.
|
<summary><strong>🔧 Ollama Setup (Recommended for full privacy)</strong></summary>
|
||||||
2. Ensure you have an `OPENAI_API_KEY` set in your environment variables or in a `.env` file for the LLM to function.
|
|
||||||
|
**macOS:**
|
||||||
|
|
||||||
|
First, [download Ollama for macOS](https://ollama.com/download/mac).
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
uv run examples/main_cli_example.py
|
# Pull a lightweight model (recommended for consumer hardware)
|
||||||
|
ollama pull llama3.2:1b
|
||||||
```
|
```
|
||||||
|
|
||||||
## ✨ Features
|
**Linux:**
|
||||||
|
|
||||||
### 🔥 Core Features
|
```bash
|
||||||
- **📊 Multiple Distance Functions**: L2, Cosine, MIPS (Maximum Inner Product Search)
|
# Install Ollama
|
||||||
- **🏗️ Pluggable Backends**: DiskANN, HNSW/FAISS with unified API
|
curl -fsSL https://ollama.ai/install.sh | sh
|
||||||
- **🔄 Real-time Embeddings**: Dynamic computation using optimized ZMQ servers
|
|
||||||
- **📈 Scalable Architecture**: Handles millions of documents on consumer hardware
|
|
||||||
- **🎯 Graph Pruning**: Advanced techniques for memory-efficient search
|
|
||||||
|
|
||||||
### 🛠️ Technical Highlights
|
# Start Ollama service manually
|
||||||
- **Zero-copy operations** for maximum performance
|
ollama serve &
|
||||||
- **SIMD-optimized** distance computations (AVX2/AVX512)
|
|
||||||
- **Async embedding pipeline** with batched processing
|
|
||||||
- **Memory-mapped indices** for fast startup
|
|
||||||
- **Recompute mode** for highest accuracy scenarios
|
|
||||||
|
|
||||||
### 🎨 Developer Experience
|
# Pull a lightweight model (recommended for consumer hardware)
|
||||||
- **Simple Python API** - Get started in minutes
|
ollama pull llama3.2:1b
|
||||||
- **Extensible backend system** - Easy to add new algorithms
|
|
||||||
- **Comprehensive examples** - From basic usage to production deployment
|
|
||||||
- **Rich debugging tools** - Built-in performance profiling
|
|
||||||
|
|
||||||
## 📊 Benchmarks
|
|
||||||
|
|
||||||
### Memory Usage Comparison
|
|
||||||
|
|
||||||
| System | 1M Documents | 10M Documents | 100M Documents |
|
|
||||||
|--------|-------------|---------------|----------------|
|
|
||||||
| Traditional Vector DB | 3.1 GB | 31 GB | 310 GB |
|
|
||||||
| **Leann** | **180 MB** | **1.2 GB** | **8.4 GB** |
|
|
||||||
| **Reduction** | **94.2%** | **96.1%** | **97.3%** |
|
|
||||||
|
|
||||||
### Query Performance
|
|
||||||
|
|
||||||
| Backend | Index Size | Query Time | Recall@10 |
|
|
||||||
|---------|------------|------------|-----------|
|
|
||||||
| DiskANN | 1M docs | 12ms | 0.95 |
|
|
||||||
| DiskANN + Recompute | 1M docs | 145ms | 0.98 |
|
|
||||||
| HNSW | 1M docs | 8ms | 0.93 |
|
|
||||||
|
|
||||||
*Benchmarks run on AMD Ryzen 7 with 32GB RAM*
|
|
||||||
|
|
||||||
## 🏗️ Architecture
|
|
||||||
|
|
||||||
```
|
|
||||||
┌─────────────────┐ ┌──────────────────┐ ┌─────────────────┐
|
|
||||||
│ Query Text │───▶│ Embedding │───▶│ Graph-based │
|
|
||||||
│ │ │ Computation │ │ Search │
|
|
||||||
└─────────────────┘ └──────────────────┘ └─────────────────┘
|
|
||||||
│ │
|
|
||||||
▼ ▼
|
|
||||||
┌──────────────┐ ┌──────────────┐
|
|
||||||
│ ZMQ Server │ │ Pruned Graph │
|
|
||||||
│ (Cached) │ │ Index │
|
|
||||||
└──────────────┘ └──────────────┘
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Key Components
|
</details>
|
||||||
|
|
||||||
1. **🧠 Embedding Engine**: Real-time transformer inference with caching
|
### Flexible Configuration
|
||||||
2. **📊 Graph Index**: Memory-efficient navigation structures
|
|
||||||
3. **🔄 Search Coordinator**: Orchestrates embedding + graph search
|
|
||||||
4. **⚡ Backend Adapters**: Pluggable algorithm implementations
|
|
||||||
|
|
||||||
## 🎓 Supported Models & Backends
|
LEANN provides flexible parameters for embedding models, search strategies, and data processing to fit your specific needs.
|
||||||
|
|
||||||
### 🤖 Embedding Models
|
<details>
|
||||||
- **sentence-transformers/all-mpnet-base-v2** (default)
|
<summary><strong>📋 Click to expand: Common Parameters (Available in All Examples)</strong></summary>
|
||||||
- **sentence-transformers/all-MiniLM-L6-v2** (lightweight)
|
|
||||||
- Any HuggingFace sentence-transformer model
|
|
||||||
- Custom model support via API
|
|
||||||
|
|
||||||
### 🔧 Search Backends
|
All RAG examples share these common parameters. **Interactive mode** is available in all examples - simply run without `--query` to start a continuous Q&A session where you can ask multiple questions. Type 'quit' to exit.
|
||||||
- **DiskANN**: Microsoft's billion-scale ANN algorithm
|
|
||||||
- **HNSW**: Hierarchical Navigable Small World graphs
|
|
||||||
- **Coming soon**: ScaNN, Faiss-IVF, NGT
|
|
||||||
|
|
||||||
### 📏 Distance Functions
|
```bash
|
||||||
- **L2**: Euclidean distance for precise similarity
|
# Core Parameters (General preprocessing for all examples)
|
||||||
- **Cosine**: Angular similarity for normalized vectors
|
--index-dir DIR # Directory to store the index (default: current directory)
|
||||||
- **MIPS**: Maximum Inner Product Search for recommendation systems
|
--query "YOUR QUESTION" # Single query mode. Omit for interactive chat (type 'quit' to exit), and now you can play with your index interactively
|
||||||
|
--max-items N # Limit data preprocessing (default: -1, process all data)
|
||||||
|
--force-rebuild # Force rebuild index even if it exists
|
||||||
|
|
||||||
|
# Embedding Parameters
|
||||||
|
--embedding-model MODEL # e.g., facebook/contriever, text-embedding-3-small or mlx-community/multilingual-e5-base-mlx
|
||||||
|
--embedding-mode MODE # sentence-transformers, openai, or mlx
|
||||||
|
|
||||||
|
# LLM Parameters (Text generation models)
|
||||||
|
--llm TYPE # LLM backend: openai, ollama, or hf (default: openai)
|
||||||
|
--llm-model MODEL # Model name (default: gpt-4o) e.g., gpt-4o-mini, llama3.2:1b, Qwen/Qwen2.5-1.5B-Instruct
|
||||||
|
|
||||||
|
# Search Parameters
|
||||||
|
--top-k N # Number of results to retrieve (default: 20)
|
||||||
|
--search-complexity N # Search complexity for graph traversal (default: 32)
|
||||||
|
|
||||||
|
# Chunking Parameters
|
||||||
|
--chunk-size N # Size of text chunks (default varies by source: 256 for most, 192 for WeChat)
|
||||||
|
--chunk-overlap N # Overlap between chunks (default varies: 25-128 depending on source)
|
||||||
|
|
||||||
|
# Index Building Parameters
|
||||||
|
--backend-name NAME # Backend to use: hnsw or diskann (default: hnsw)
|
||||||
|
--graph-degree N # Graph degree for index construction (default: 32)
|
||||||
|
--build-complexity N # Build complexity for index construction (default: 64)
|
||||||
|
--no-compact # Disable compact index storage (compact storage IS enabled to save storage by default)
|
||||||
|
--no-recompute # Disable embedding recomputation (recomputation IS enabled to save storage by default)
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
### 📄 Personal Data Manager: Process Any Documents (`.pdf`, `.txt`, `.md`)!
|
||||||
|
|
||||||
|
Ask questions directly about your personal PDFs, documents, and any directory containing your files!
|
||||||
|
|
||||||
|
<p align="center">
|
||||||
|
<img src="videos/paper_clear.gif" alt="LEANN Document Search Demo" width="600">
|
||||||
|
</p>
|
||||||
|
|
||||||
|
The example below asks a question about summarizing our paper (uses default data in `data/`, which is a directory with diverse data sources: two papers, Pride and Prejudice, and a README in Chinese) and this is the **easiest example** to run here:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
source .venv/bin/activate # Don't forget to activate the virtual environment
|
||||||
|
python -m apps.document_rag --query "What are the main techniques LEANN explores?"
|
||||||
|
```
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><strong>📋 Click to expand: Document-Specific Arguments</strong></summary>
|
||||||
|
|
||||||
|
#### Parameters
|
||||||
|
```bash
|
||||||
|
--data-dir DIR # Directory containing documents to process (default: data)
|
||||||
|
--file-types .ext .ext # Filter by specific file types (optional - all LlamaIndex supported types if omitted)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Example Commands
|
||||||
|
```bash
|
||||||
|
# Process all documents with larger chunks for academic papers
|
||||||
|
python -m apps.document_rag --data-dir "~/Documents/Papers" --chunk-size 1024
|
||||||
|
|
||||||
|
# Filter only markdown and Python files with smaller chunks
|
||||||
|
python -m apps.document_rag --data-dir "./docs" --chunk-size 256 --file-types .md .py
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
### 📧 Your Personal Email Secretary: RAG on Apple Mail!
|
||||||
|
|
||||||
|
> **Note:** The examples below currently support macOS only. Windows support coming soon.
|
||||||
|
|
||||||
|
|
||||||
|
<p align="center">
|
||||||
|
<img src="videos/mail_clear.gif" alt="LEANN Email Search Demo" width="600">
|
||||||
|
</p>
|
||||||
|
|
||||||
|
Before running the example below, you need to grant full disk access to your terminal/VS Code in System Preferences → Privacy & Security → Full Disk Access.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m apps.email_rag --query "What's the food I ordered by DoorDash or Uber Eats mostly?"
|
||||||
|
```
|
||||||
|
**780K email chunks → 78MB storage.** Finally, search your email like you search Google.
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><strong>📋 Click to expand: Email-Specific Arguments</strong></summary>
|
||||||
|
|
||||||
|
#### Parameters
|
||||||
|
```bash
|
||||||
|
--mail-path PATH # Path to specific mail directory (auto-detects if omitted)
|
||||||
|
--include-html # Include HTML content in processing (useful for newsletters)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Example Commands
|
||||||
|
```bash
|
||||||
|
# Search work emails from a specific account
|
||||||
|
python -m apps.email_rag --mail-path "~/Library/Mail/V10/WORK_ACCOUNT"
|
||||||
|
|
||||||
|
# Find all receipts and order confirmations (includes HTML)
|
||||||
|
python -m apps.email_rag --query "receipt order confirmation invoice" --include-html
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><strong>📋 Click to expand: Example queries you can try</strong></summary>
|
||||||
|
|
||||||
|
Once the index is built, you can ask questions like:
|
||||||
|
- "Find emails from my boss about deadlines"
|
||||||
|
- "What did John say about the project timeline?"
|
||||||
|
- "Show me emails about travel expenses"
|
||||||
|
</details>
|
||||||
|
|
||||||
|
### 🔍 Time Machine for the Web: RAG Your Entire Chrome Browser History!
|
||||||
|
|
||||||
|
<p align="center">
|
||||||
|
<img src="videos/google_clear.gif" alt="LEANN Browser History Search Demo" width="600">
|
||||||
|
</p>
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m apps.browser_rag --query "Tell me my browser history about machine learning?"
|
||||||
|
```
|
||||||
|
**38K browser entries → 6MB storage.** Your browser history becomes your personal search engine.
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><strong>📋 Click to expand: Browser-Specific Arguments</strong></summary>
|
||||||
|
|
||||||
|
#### Parameters
|
||||||
|
```bash
|
||||||
|
--chrome-profile PATH # Path to Chrome profile directory (auto-detects if omitted)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Example Commands
|
||||||
|
```bash
|
||||||
|
# Search academic research from your browsing history
|
||||||
|
python -m apps.browser_rag --query "arxiv papers machine learning transformer architecture"
|
||||||
|
|
||||||
|
# Track competitor analysis across work profile
|
||||||
|
python -m apps.browser_rag --chrome-profile "~/Library/Application Support/Google/Chrome/Work Profile" --max-items 5000
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><strong>📋 Click to expand: How to find your Chrome profile</strong></summary>
|
||||||
|
|
||||||
|
The default Chrome profile path is configured for a typical macOS setup. If you need to find your specific Chrome profile:
|
||||||
|
|
||||||
|
1. Open Terminal
|
||||||
|
2. Run: `ls ~/Library/Application\ Support/Google/Chrome/`
|
||||||
|
3. Look for folders like "Default", "Profile 1", "Profile 2", etc.
|
||||||
|
4. Use the full path as your `--chrome-profile` argument
|
||||||
|
|
||||||
|
**Common Chrome profile locations:**
|
||||||
|
- macOS: `~/Library/Application Support/Google/Chrome/Default`
|
||||||
|
- Linux: `~/.config/google-chrome/Default`
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><strong>💬 Click to expand: Example queries you can try</strong></summary>
|
||||||
|
|
||||||
|
Once the index is built, you can ask questions like:
|
||||||
|
|
||||||
|
- "What websites did I visit about machine learning?"
|
||||||
|
- "Find my search history about programming"
|
||||||
|
- "What YouTube videos did I watch recently?"
|
||||||
|
- "Show me websites I visited about travel planning"
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
### 💬 WeChat Detective: Unlock Your Golden Memories!
|
||||||
|
|
||||||
|
<p align="center">
|
||||||
|
<img src="videos/wechat_clear.gif" alt="LEANN WeChat Search Demo" width="600">
|
||||||
|
</p>
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m apps.wechat_rag --query "Show me all group chats about weekend plans"
|
||||||
|
```
|
||||||
|
**400K messages → 64MB storage** Search years of chat history in any language.
|
||||||
|
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><strong>🔧 Click to expand: Installation Requirements</strong></summary>
|
||||||
|
|
||||||
|
First, you need to install the [WeChat exporter](https://github.com/sunnyyoung/WeChatTweak-CLI),
|
||||||
|
|
||||||
|
```bash
|
||||||
|
brew install sunnyyoung/repo/wechattweak-cli
|
||||||
|
```
|
||||||
|
|
||||||
|
or install it manually (if you have issues with Homebrew):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
sudo packages/wechat-exporter/wechattweak-cli install
|
||||||
|
```
|
||||||
|
|
||||||
|
**Troubleshooting:**
|
||||||
|
- **Installation issues**: Check the [WeChatTweak-CLI issues page](https://github.com/sunnyyoung/WeChatTweak-CLI/issues/41)
|
||||||
|
- **Export errors**: If you encounter the error below, try restarting WeChat
|
||||||
|
```bash
|
||||||
|
Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed.
|
||||||
|
Failed to find or export WeChat data. Exiting.
|
||||||
|
```
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><strong>📋 Click to expand: WeChat-Specific Arguments</strong></summary>
|
||||||
|
|
||||||
|
#### Parameters
|
||||||
|
```bash
|
||||||
|
--export-dir DIR # Directory to store exported WeChat data (default: wechat_export_direct)
|
||||||
|
--force-export # Force re-export even if data exists
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Example Commands
|
||||||
|
```bash
|
||||||
|
# Search for travel plans discussed in group chats
|
||||||
|
python -m apps.wechat_rag --query "travel plans" --max-items 10000
|
||||||
|
|
||||||
|
# Re-export and search recent chats (useful after new messages)
|
||||||
|
python -m apps.wechat_rag --force-export --query "work schedule"
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><strong>💬 Click to expand: Example queries you can try</strong></summary>
|
||||||
|
|
||||||
|
Once the index is built, you can ask questions like:
|
||||||
|
|
||||||
|
- "我想买魔术师约翰逊的球衣,给我一些对应聊天记录?" (Chinese: Show me chat records about buying Magic Johnson's jersey)
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## 🖥️ Command Line Interface
|
||||||
|
|
||||||
|
LEANN includes a powerful CLI for document processing and search. Perfect for quick document indexing and interactive chat.
|
||||||
|
|
||||||
|
### Installation
|
||||||
|
|
||||||
|
If you followed the Quick Start, `leann` is already installed in your virtual environment:
|
||||||
|
```bash
|
||||||
|
source .venv/bin/activate
|
||||||
|
leann --help
|
||||||
|
```
|
||||||
|
|
||||||
|
**To make it globally available (recommended for daily use):**
|
||||||
|
```bash
|
||||||
|
# Install the LEANN CLI globally using uv tool
|
||||||
|
uv tool install leann
|
||||||
|
|
||||||
|
# Now you can use leann from anywhere without activating venv
|
||||||
|
leann --help
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
### Usage Examples
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Build an index from documents
|
||||||
|
leann build my-docs --docs ./documents
|
||||||
|
|
||||||
|
# Search your documents
|
||||||
|
leann search my-docs "machine learning concepts"
|
||||||
|
|
||||||
|
# Interactive chat with your documents
|
||||||
|
leann ask my-docs --interactive
|
||||||
|
|
||||||
|
# List all your indexes
|
||||||
|
leann list
|
||||||
|
```
|
||||||
|
|
||||||
|
**Key CLI features:**
|
||||||
|
- Auto-detects document formats (PDF, TXT, MD, DOCX)
|
||||||
|
- Smart text chunking with overlap
|
||||||
|
- Multiple LLM providers (Ollama, OpenAI, HuggingFace)
|
||||||
|
- Organized index storage in `~/.leann/indexes/`
|
||||||
|
- Support for advanced search parameters
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><strong>📋 Click to expand: Complete CLI Reference</strong></summary>
|
||||||
|
|
||||||
|
**Build Command:**
|
||||||
|
```bash
|
||||||
|
leann build INDEX_NAME --docs DIRECTORY [OPTIONS]
|
||||||
|
|
||||||
|
Options:
|
||||||
|
--backend {hnsw,diskann} Backend to use (default: hnsw)
|
||||||
|
--embedding-model MODEL Embedding model (default: facebook/contriever)
|
||||||
|
--graph-degree N Graph degree (default: 32)
|
||||||
|
--complexity N Build complexity (default: 64)
|
||||||
|
--force Force rebuild existing index
|
||||||
|
--compact Use compact storage (default: true)
|
||||||
|
--recompute Enable recomputation (default: true)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Search Command:**
|
||||||
|
```bash
|
||||||
|
leann search INDEX_NAME QUERY [OPTIONS]
|
||||||
|
|
||||||
|
Options:
|
||||||
|
--top-k N Number of results (default: 5)
|
||||||
|
--complexity N Search complexity (default: 64)
|
||||||
|
--recompute-embeddings Use recomputation for highest accuracy
|
||||||
|
--pruning-strategy {global,local,proportional}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Ask Command:**
|
||||||
|
```bash
|
||||||
|
leann ask INDEX_NAME [OPTIONS]
|
||||||
|
|
||||||
|
Options:
|
||||||
|
--llm {ollama,openai,hf} LLM provider (default: ollama)
|
||||||
|
--model MODEL Model name (default: qwen3:8b)
|
||||||
|
--interactive Interactive chat mode
|
||||||
|
--top-k N Retrieval count (default: 20)
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
## 🏗️ Architecture & How It Works
|
||||||
|
|
||||||
|
<p align="center">
|
||||||
|
<img src="assets/arch.png" alt="LEANN Architecture" width="800">
|
||||||
|
</p>
|
||||||
|
|
||||||
|
**The magic:** Most vector DBs store every single embedding (expensive). LEANN stores a pruned graph structure (cheap) and recomputes embeddings only when needed (fast).
|
||||||
|
|
||||||
|
**Core techniques:**
|
||||||
|
- **Graph-based selective recomputation:** Only compute embeddings for nodes in the search path
|
||||||
|
- **High-degree preserving pruning:** Keep important "hub" nodes while removing redundant connections
|
||||||
|
- **Dynamic batching:** Efficiently batch embedding computations for GPU utilization
|
||||||
|
- **Two-level search:** Smart graph traversal that prioritizes promising nodes
|
||||||
|
|
||||||
|
**Backends:** DiskANN or HNSW - pick what works for your data size.
|
||||||
|
|
||||||
|
## Benchmarks
|
||||||
|
|
||||||
|
|
||||||
|
**[Simple Example: Compare LEANN vs FAISS →](benchmarks/compare_faiss_vs_leann.py)**
|
||||||
|
### 📊 Storage Comparison
|
||||||
|
|
||||||
|
| System | DPR (2.1M) | Wiki (60M) | Chat (400K) | Email (780K) | Browser (38K) |
|
||||||
|
|--------|-------------|------------|-------------|--------------|---------------|
|
||||||
|
| Traditional vector database (e.g., FAISS) | 3.8 GB | 201 GB | 1.8 GB | 2.4 GB | 130 MB |
|
||||||
|
| LEANN | 324 MB | 6 GB | 64 MB | 79 MB | 6.4 MB |
|
||||||
|
| Savings| 91% | 97% | 97% | 97% | 95% |
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## Reproduce Our Results
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv pip install -e ".[dev]" # Install dev dependencies
|
||||||
|
python benchmarks/run_evaluation.py data/indices/dpr/dpr_diskann # DPR dataset
|
||||||
|
python benchmarks/run_evaluation.py data/indices/rpj_wiki/rpj_wiki.index # Wikipedia
|
||||||
|
```
|
||||||
|
|
||||||
|
The evaluation script downloads data automatically on first run. The last three results were tested with partial personal data, and you can reproduce them with your own data!
|
||||||
## 🔬 Paper
|
## 🔬 Paper
|
||||||
|
|
||||||
If you find Leann useful, please cite:
|
If you find Leann useful, please cite:
|
||||||
@@ -185,110 +557,15 @@ If you find Leann useful, please cite:
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
## 🌍 Use Cases
|
## ✨ [Detailed Features →](docs/features.md)
|
||||||
|
|
||||||
### 💼 Enterprise RAG
|
## 🤝 [CONTRIBUTING →](docs/CONTRIBUTING.md)
|
||||||
```python
|
|
||||||
# Handle millions of documents with limited resources
|
|
||||||
builder = LeannBuilder(
|
|
||||||
backend_name="diskann",
|
|
||||||
distance_metric="cosine",
|
|
||||||
graph_degree=64,
|
|
||||||
memory_budget="4GB"
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
### 🔬 Research & Experimentation
|
|
||||||
```python
|
|
||||||
# Quick prototyping with different algorithms
|
|
||||||
for backend in ["diskann", "hnsw"]:
|
|
||||||
searcher = LeannSearcher(index_path, backend=backend)
|
|
||||||
evaluate_recall(searcher, queries, ground_truth)
|
|
||||||
```
|
|
||||||
|
|
||||||
### 🚀 Real-time Applications
|
|
||||||
```python
|
|
||||||
# Sub-second response times
|
|
||||||
chat = LeannChat("knowledge.leann")
|
|
||||||
response = chat.ask("What is quantum computing?")
|
|
||||||
# Returns in <100ms with recompute mode
|
|
||||||
```
|
|
||||||
|
|
||||||
## 🤝 Contributing
|
|
||||||
|
|
||||||
We welcome contributions! Leann is built by the community, for the community.
|
|
||||||
|
|
||||||
### Ways to Contribute
|
|
||||||
- 🐛 **Bug Reports**: Found an issue? Let us know!
|
|
||||||
- 💡 **Feature Requests**: Have an idea? We'd love to hear it!
|
|
||||||
- 🔧 **Code Contributions**: PRs welcome for all skill levels
|
|
||||||
- 📖 **Documentation**: Help make Leann more accessible
|
|
||||||
- 🧪 **Benchmarks**: Share your performance results
|
|
||||||
|
|
||||||
### Development Setup
|
|
||||||
```bash
|
|
||||||
git clone https://github.com/yourname/leann
|
|
||||||
cd leann
|
|
||||||
uv sync --dev
|
|
||||||
uv run pytest tests/
|
|
||||||
```
|
|
||||||
|
|
||||||
### Quick Tests
|
|
||||||
```bash
|
|
||||||
# Sanity check all distance functions
|
|
||||||
uv run python tests/sanity_checks/test_distance_functions.py
|
|
||||||
|
|
||||||
# Verify L2 implementation
|
|
||||||
uv run python tests/sanity_checks/test_l2_verification.py
|
|
||||||
```
|
|
||||||
## ❓ FAQ
|
|
||||||
|
|
||||||
### Common Issues
|
|
||||||
|
|
||||||
#### NCCL Topology Error
|
|
||||||
**Problem**: You encounter `ncclTopoComputePaths` error during document processing:
|
|
||||||
```
|
|
||||||
ncclTopoComputePaths (system=<optimized out>, comm=comm@entry=0x5555a82fa3c0) at graph/paths.cc:688
|
|
||||||
```
|
|
||||||
|
|
||||||
**Solution**: Set these environment variables before running your script:
|
|
||||||
```bash
|
|
||||||
export NCCL_TOPO_DUMP_FILE=/tmp/nccl_topo.xml
|
|
||||||
export NCCL_DEBUG=INFO
|
|
||||||
export NCCL_DEBUG_SUBSYS=INIT,GRAPH
|
|
||||||
export NCCL_IB_DISABLE=1
|
|
||||||
export NCCL_NET_PLUGIN=none
|
|
||||||
export NCCL_SOCKET_IFNAME=ens5
|
|
||||||
|
|
||||||
|
|
||||||
## 📈 Roadmap
|
## ❓ [FAQ →](docs/faq.md)
|
||||||
|
|
||||||
### 🎯 Q1 2024
|
|
||||||
- [x] DiskANN backend with MIPS/L2/Cosine support
|
|
||||||
- [x] HNSW backend integration
|
|
||||||
- [x] Real-time embedding pipeline
|
|
||||||
- [x] Memory-efficient graph pruning
|
|
||||||
|
|
||||||
### 🚀 Q2 2024
|
## 📈 [Roadmap →](docs/roadmap.md)
|
||||||
- [ ] Distributed search across multiple nodes
|
|
||||||
- [ ] ScaNN backend support
|
|
||||||
- [ ] Advanced caching strategies
|
|
||||||
- [ ] Kubernetes deployment guides
|
|
||||||
|
|
||||||
### 🌟 Q3 2024
|
|
||||||
- [ ] GPU-accelerated embedding computation
|
|
||||||
- [ ] Approximate distance functions
|
|
||||||
- [ ] Integration with LangChain/LlamaIndex
|
|
||||||
- [ ] Visual similarity search
|
|
||||||
|
|
||||||
## 💬 Community
|
|
||||||
|
|
||||||
Join our growing community of researchers and engineers!
|
|
||||||
|
|
||||||
- 🐦 **Twitter**: [@LeannAI](https://twitter.com/LeannAI)
|
|
||||||
- 💬 **Discord**: [Join our server](https://discord.gg/leann)
|
|
||||||
- 📧 **Email**: leann@yourcompany.com
|
|
||||||
- 🐙 **GitHub Discussions**: [Ask questions here](https://github.com/yourname/leann/discussions)
|
|
||||||
|
|
||||||
## 📄 License
|
## 📄 License
|
||||||
|
|
||||||
@@ -296,10 +573,8 @@ MIT License - see [LICENSE](LICENSE) for details.
|
|||||||
|
|
||||||
## 🙏 Acknowledgments
|
## 🙏 Acknowledgments
|
||||||
|
|
||||||
- **Microsoft Research** for the DiskANN algorithm
|
This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.edu/).
|
||||||
- **Meta AI** for FAISS and optimization insights
|
|
||||||
- **HuggingFace** for the transformer ecosystem
|
|
||||||
- **Our amazing contributors** who make this possible
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|||||||
296
apps/base_rag_example.py
Normal file
296
apps/base_rag_example.py
Normal file
@@ -0,0 +1,296 @@
|
|||||||
|
"""
|
||||||
|
Base class for unified RAG examples interface.
|
||||||
|
Provides common parameters and functionality for all RAG examples.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import dotenv
|
||||||
|
from leann.api import LeannBuilder, LeannChat
|
||||||
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
|
||||||
|
dotenv.load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
|
class BaseRAGExample(ABC):
|
||||||
|
"""Base class for all RAG examples with unified interface."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
description: str,
|
||||||
|
default_index_name: str,
|
||||||
|
):
|
||||||
|
self.name = name
|
||||||
|
self.description = description
|
||||||
|
self.default_index_name = default_index_name
|
||||||
|
self.parser = self._create_parser()
|
||||||
|
|
||||||
|
def _create_parser(self) -> argparse.ArgumentParser:
|
||||||
|
"""Create argument parser with common parameters."""
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description=self.description, formatter_class=argparse.RawDescriptionHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
# Core parameters (all examples share these)
|
||||||
|
core_group = parser.add_argument_group("Core Parameters")
|
||||||
|
core_group.add_argument(
|
||||||
|
"--index-dir",
|
||||||
|
type=str,
|
||||||
|
default=f"./{self.default_index_name}",
|
||||||
|
help=f"Directory to store the index (default: ./{self.default_index_name})",
|
||||||
|
)
|
||||||
|
core_group.add_argument(
|
||||||
|
"--query",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Query to run (if not provided, will run in interactive mode)",
|
||||||
|
)
|
||||||
|
# Allow subclasses to override default max_items
|
||||||
|
max_items_default = getattr(self, "max_items_default", -1)
|
||||||
|
core_group.add_argument(
|
||||||
|
"--max-items",
|
||||||
|
type=int,
|
||||||
|
default=max_items_default,
|
||||||
|
help="Maximum number of items to process -1 for all, means index all documents, and you should set it to a reasonable number if you have a large dataset and try at the first time)",
|
||||||
|
)
|
||||||
|
core_group.add_argument(
|
||||||
|
"--force-rebuild", action="store_true", help="Force rebuild index even if it exists"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Embedding parameters
|
||||||
|
embedding_group = parser.add_argument_group("Embedding Parameters")
|
||||||
|
# Allow subclasses to override default embedding_model
|
||||||
|
embedding_model_default = getattr(self, "embedding_model_default", "facebook/contriever")
|
||||||
|
embedding_group.add_argument(
|
||||||
|
"--embedding-model",
|
||||||
|
type=str,
|
||||||
|
default=embedding_model_default,
|
||||||
|
help=f"Embedding model to use (default: {embedding_model_default})",
|
||||||
|
)
|
||||||
|
embedding_group.add_argument(
|
||||||
|
"--embedding-mode",
|
||||||
|
type=str,
|
||||||
|
default="sentence-transformers",
|
||||||
|
choices=["sentence-transformers", "openai", "mlx"],
|
||||||
|
help="Embedding backend mode (default: sentence-transformers)",
|
||||||
|
)
|
||||||
|
|
||||||
|
# LLM parameters
|
||||||
|
llm_group = parser.add_argument_group("LLM Parameters")
|
||||||
|
llm_group.add_argument(
|
||||||
|
"--llm",
|
||||||
|
type=str,
|
||||||
|
default="openai",
|
||||||
|
choices=["openai", "ollama", "hf"],
|
||||||
|
help="LLM backend to use (default: openai)",
|
||||||
|
)
|
||||||
|
llm_group.add_argument(
|
||||||
|
"--llm-model",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="LLM model name (default: gpt-4o for openai, llama3.2:1b for ollama)",
|
||||||
|
)
|
||||||
|
llm_group.add_argument(
|
||||||
|
"--llm-host",
|
||||||
|
type=str,
|
||||||
|
default="http://localhost:11434",
|
||||||
|
help="Host for Ollama API (default: http://localhost:11434)",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Search parameters
|
||||||
|
search_group = parser.add_argument_group("Search Parameters")
|
||||||
|
search_group.add_argument(
|
||||||
|
"--top-k", type=int, default=20, help="Number of results to retrieve (default: 20)"
|
||||||
|
)
|
||||||
|
search_group.add_argument(
|
||||||
|
"--search-complexity",
|
||||||
|
type=int,
|
||||||
|
default=32,
|
||||||
|
help="Search complexity for graph traversal (default: 64)",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Index building parameters
|
||||||
|
index_group = parser.add_argument_group("Index Building Parameters")
|
||||||
|
index_group.add_argument(
|
||||||
|
"--backend-name",
|
||||||
|
type=str,
|
||||||
|
default="hnsw",
|
||||||
|
choices=["hnsw", "diskann"],
|
||||||
|
help="Backend to use for index (default: hnsw)",
|
||||||
|
)
|
||||||
|
index_group.add_argument(
|
||||||
|
"--graph-degree",
|
||||||
|
type=int,
|
||||||
|
default=32,
|
||||||
|
help="Graph degree for index construction (default: 32)",
|
||||||
|
)
|
||||||
|
index_group.add_argument(
|
||||||
|
"--build-complexity",
|
||||||
|
type=int,
|
||||||
|
default=64,
|
||||||
|
help="Build complexity for index construction (default: 64)",
|
||||||
|
)
|
||||||
|
index_group.add_argument(
|
||||||
|
"--no-compact",
|
||||||
|
action="store_true",
|
||||||
|
help="Disable compact index storage",
|
||||||
|
)
|
||||||
|
index_group.add_argument(
|
||||||
|
"--no-recompute",
|
||||||
|
action="store_true",
|
||||||
|
help="Disable embedding recomputation",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add source-specific parameters
|
||||||
|
self._add_specific_arguments(parser)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _add_specific_arguments(self, parser: argparse.ArgumentParser):
|
||||||
|
"""Add source-specific arguments. Override in subclasses."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def load_data(self, args) -> list[str]:
|
||||||
|
"""Load data from the source. Returns list of text chunks."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_llm_config(self, args) -> dict[str, Any]:
|
||||||
|
"""Get LLM configuration based on arguments."""
|
||||||
|
config = {"type": args.llm}
|
||||||
|
|
||||||
|
if args.llm == "openai":
|
||||||
|
config["model"] = args.llm_model or "gpt-4o"
|
||||||
|
elif args.llm == "ollama":
|
||||||
|
config["model"] = args.llm_model or "llama3.2:1b"
|
||||||
|
config["host"] = args.llm_host
|
||||||
|
elif args.llm == "hf":
|
||||||
|
config["model"] = args.llm_model or "Qwen/Qwen2.5-1.5B-Instruct"
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
async def build_index(self, args, texts: list[str]) -> str:
|
||||||
|
"""Build LEANN index from texts."""
|
||||||
|
index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")
|
||||||
|
|
||||||
|
print(f"\n[Building Index] Creating {self.name} index...")
|
||||||
|
print(f"Total text chunks: {len(texts)}")
|
||||||
|
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name=args.backend_name,
|
||||||
|
embedding_model=args.embedding_model,
|
||||||
|
embedding_mode=args.embedding_mode,
|
||||||
|
graph_degree=args.graph_degree,
|
||||||
|
complexity=args.build_complexity,
|
||||||
|
is_compact=not args.no_compact,
|
||||||
|
is_recompute=not args.no_recompute,
|
||||||
|
num_threads=1, # Force single-threaded mode
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add texts in batches for better progress tracking
|
||||||
|
batch_size = 1000
|
||||||
|
for i in range(0, len(texts), batch_size):
|
||||||
|
batch = texts[i : i + batch_size]
|
||||||
|
for text in batch:
|
||||||
|
builder.add_text(text)
|
||||||
|
print(f"Added {min(i + batch_size, len(texts))}/{len(texts)} texts...")
|
||||||
|
|
||||||
|
print("Building index structure...")
|
||||||
|
builder.build_index(index_path)
|
||||||
|
print(f"Index saved to: {index_path}")
|
||||||
|
|
||||||
|
return index_path
|
||||||
|
|
||||||
|
async def run_interactive_chat(self, args, index_path: str):
|
||||||
|
"""Run interactive chat with the index."""
|
||||||
|
chat = LeannChat(
|
||||||
|
index_path,
|
||||||
|
llm_config=self.get_llm_config(args),
|
||||||
|
system_prompt=f"You are a helpful assistant that answers questions about {self.name} data.",
|
||||||
|
complexity=args.search_complexity,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"\n[Interactive Mode] Chat with your {self.name} data!")
|
||||||
|
print("Type 'quit' or 'exit' to stop.\n")
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
query = input("You: ").strip()
|
||||||
|
if query.lower() in ["quit", "exit", "q"]:
|
||||||
|
print("Goodbye!")
|
||||||
|
break
|
||||||
|
|
||||||
|
if not query:
|
||||||
|
continue
|
||||||
|
|
||||||
|
response = chat.ask(query, top_k=args.top_k, complexity=args.search_complexity)
|
||||||
|
print(f"\nAssistant: {response}\n")
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\nGoodbye!")
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error: {e}")
|
||||||
|
|
||||||
|
async def run_single_query(self, args, index_path: str, query: str):
|
||||||
|
"""Run a single query against the index."""
|
||||||
|
chat = LeannChat(
|
||||||
|
index_path,
|
||||||
|
llm_config=self.get_llm_config(args),
|
||||||
|
system_prompt=f"You are a helpful assistant that answers questions about {self.name} data.",
|
||||||
|
complexity=args.search_complexity,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"\n[Query]: \033[36m{query}\033[0m")
|
||||||
|
response = chat.ask(query, top_k=args.top_k, complexity=args.search_complexity)
|
||||||
|
print(f"\n[Response]: \033[36m{response}\033[0m")
|
||||||
|
|
||||||
|
async def run(self):
|
||||||
|
"""Main entry point for the example."""
|
||||||
|
args = self.parser.parse_args()
|
||||||
|
|
||||||
|
# Check if index exists
|
||||||
|
index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")
|
||||||
|
index_exists = Path(args.index_dir).exists()
|
||||||
|
|
||||||
|
if not index_exists or args.force_rebuild:
|
||||||
|
# Load data and build index
|
||||||
|
print(f"\n{'Rebuilding' if index_exists else 'Building'} index...")
|
||||||
|
texts = await self.load_data(args)
|
||||||
|
|
||||||
|
if not texts:
|
||||||
|
print("No data found to index!")
|
||||||
|
return
|
||||||
|
|
||||||
|
index_path = await self.build_index(args, texts)
|
||||||
|
else:
|
||||||
|
print(f"\nUsing existing index in {args.index_dir}")
|
||||||
|
|
||||||
|
# Run query or interactive mode
|
||||||
|
if args.query:
|
||||||
|
await self.run_single_query(args, index_path, args.query)
|
||||||
|
else:
|
||||||
|
await self.run_interactive_chat(args, index_path)
|
||||||
|
|
||||||
|
|
||||||
|
def create_text_chunks(documents, chunk_size=256, chunk_overlap=25) -> list[str]:
|
||||||
|
"""Helper function to create text chunks from documents."""
|
||||||
|
node_parser = SentenceSplitter(
|
||||||
|
chunk_size=chunk_size,
|
||||||
|
chunk_overlap=chunk_overlap,
|
||||||
|
separator=" ",
|
||||||
|
paragraph_separator="\n\n",
|
||||||
|
)
|
||||||
|
|
||||||
|
all_texts = []
|
||||||
|
for doc in documents:
|
||||||
|
nodes = node_parser.get_nodes_from_documents([doc])
|
||||||
|
if nodes:
|
||||||
|
all_texts.extend(node.get_content() for node in nodes)
|
||||||
|
|
||||||
|
return all_texts
|
||||||
170
apps/browser_rag.py
Normal file
170
apps/browser_rag.py
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
"""
|
||||||
|
Browser History RAG example using the unified interface.
|
||||||
|
Supports Chrome browser history.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Add parent directory to path for imports
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
|
|
||||||
|
from base_rag_example import BaseRAGExample, create_text_chunks
|
||||||
|
|
||||||
|
from .history_data.history import ChromeHistoryReader
|
||||||
|
|
||||||
|
|
||||||
|
class BrowserRAG(BaseRAGExample):
|
||||||
|
"""RAG example for Chrome browser history."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
# Set default values BEFORE calling super().__init__
|
||||||
|
self.embedding_model_default = (
|
||||||
|
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
|
||||||
|
)
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
name="Browser History",
|
||||||
|
description="Process and query Chrome browser history with LEANN",
|
||||||
|
default_index_name="google_history_index",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _add_specific_arguments(self, parser):
|
||||||
|
"""Add browser-specific arguments."""
|
||||||
|
browser_group = parser.add_argument_group("Browser Parameters")
|
||||||
|
browser_group.add_argument(
|
||||||
|
"--chrome-profile",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path to Chrome profile directory (auto-detected if not specified)",
|
||||||
|
)
|
||||||
|
browser_group.add_argument(
|
||||||
|
"--auto-find-profiles",
|
||||||
|
action="store_true",
|
||||||
|
default=True,
|
||||||
|
help="Automatically find all Chrome profiles (default: True)",
|
||||||
|
)
|
||||||
|
browser_group.add_argument(
|
||||||
|
"--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
|
||||||
|
)
|
||||||
|
browser_group.add_argument(
|
||||||
|
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_chrome_base_path(self) -> Path:
|
||||||
|
"""Get the base Chrome profile path based on OS."""
|
||||||
|
if sys.platform == "darwin":
|
||||||
|
return Path.home() / "Library" / "Application Support" / "Google" / "Chrome"
|
||||||
|
elif sys.platform.startswith("linux"):
|
||||||
|
return Path.home() / ".config" / "google-chrome"
|
||||||
|
elif sys.platform == "win32":
|
||||||
|
return Path(os.environ["LOCALAPPDATA"]) / "Google" / "Chrome" / "User Data"
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported platform: {sys.platform}")
|
||||||
|
|
||||||
|
def _find_chrome_profiles(self) -> list[Path]:
|
||||||
|
"""Auto-detect all Chrome profiles."""
|
||||||
|
base_path = self._get_chrome_base_path()
|
||||||
|
if not base_path.exists():
|
||||||
|
return []
|
||||||
|
|
||||||
|
profiles = []
|
||||||
|
|
||||||
|
# Check Default profile
|
||||||
|
default_profile = base_path / "Default"
|
||||||
|
if default_profile.exists() and (default_profile / "History").exists():
|
||||||
|
profiles.append(default_profile)
|
||||||
|
|
||||||
|
# Check numbered profiles
|
||||||
|
for item in base_path.iterdir():
|
||||||
|
if item.is_dir() and item.name.startswith("Profile "):
|
||||||
|
if (item / "History").exists():
|
||||||
|
profiles.append(item)
|
||||||
|
|
||||||
|
return profiles
|
||||||
|
|
||||||
|
async def load_data(self, args) -> list[str]:
|
||||||
|
"""Load browser history and convert to text chunks."""
|
||||||
|
# Determine Chrome profiles
|
||||||
|
if args.chrome_profile and not args.auto_find_profiles:
|
||||||
|
profile_dirs = [Path(args.chrome_profile)]
|
||||||
|
else:
|
||||||
|
print("Auto-detecting Chrome profiles...")
|
||||||
|
profile_dirs = self._find_chrome_profiles()
|
||||||
|
|
||||||
|
# If specific profile given, filter to just that one
|
||||||
|
if args.chrome_profile:
|
||||||
|
profile_path = Path(args.chrome_profile)
|
||||||
|
profile_dirs = [p for p in profile_dirs if p == profile_path]
|
||||||
|
|
||||||
|
if not profile_dirs:
|
||||||
|
print("No Chrome profiles found!")
|
||||||
|
print("Please specify --chrome-profile manually")
|
||||||
|
return []
|
||||||
|
|
||||||
|
print(f"Found {len(profile_dirs)} Chrome profiles")
|
||||||
|
|
||||||
|
# Create reader
|
||||||
|
reader = ChromeHistoryReader()
|
||||||
|
|
||||||
|
# Process each profile
|
||||||
|
all_documents = []
|
||||||
|
total_processed = 0
|
||||||
|
|
||||||
|
for i, profile_dir in enumerate(profile_dirs):
|
||||||
|
print(f"\nProcessing profile {i + 1}/{len(profile_dirs)}: {profile_dir.name}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Apply max_items limit per profile
|
||||||
|
max_per_profile = -1
|
||||||
|
if args.max_items > 0:
|
||||||
|
remaining = args.max_items - total_processed
|
||||||
|
if remaining <= 0:
|
||||||
|
break
|
||||||
|
max_per_profile = remaining
|
||||||
|
|
||||||
|
# Load history
|
||||||
|
documents = reader.load_data(
|
||||||
|
chrome_profile_path=str(profile_dir),
|
||||||
|
max_count=max_per_profile,
|
||||||
|
)
|
||||||
|
|
||||||
|
if documents:
|
||||||
|
all_documents.extend(documents)
|
||||||
|
total_processed += len(documents)
|
||||||
|
print(f"Processed {len(documents)} history entries from this profile")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing {profile_dir}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not all_documents:
|
||||||
|
print("No browser history found to process!")
|
||||||
|
return []
|
||||||
|
|
||||||
|
print(f"\nTotal history entries processed: {len(all_documents)}")
|
||||||
|
|
||||||
|
# Convert to text chunks
|
||||||
|
all_texts = create_text_chunks(
|
||||||
|
all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
||||||
|
)
|
||||||
|
|
||||||
|
return all_texts
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
# Example queries for browser history RAG
|
||||||
|
print("\n🌐 Browser History RAG Example")
|
||||||
|
print("=" * 50)
|
||||||
|
print("\nExample queries you can try:")
|
||||||
|
print("- 'What websites did I visit about machine learning?'")
|
||||||
|
print("- 'Find my search history about programming'")
|
||||||
|
print("- 'What YouTube videos did I watch recently?'")
|
||||||
|
print("- 'Show me websites about travel planning'")
|
||||||
|
print("\nNote: Make sure Chrome is closed before running\n")
|
||||||
|
|
||||||
|
rag = BrowserRAG()
|
||||||
|
asyncio.run(rag.run())
|
||||||
106
apps/document_rag.py
Normal file
106
apps/document_rag.py
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
"""
|
||||||
|
Document RAG example using the unified interface.
|
||||||
|
Supports PDF, TXT, MD, and other document formats.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Add parent directory to path for imports
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
|
|
||||||
|
from base_rag_example import BaseRAGExample, create_text_chunks
|
||||||
|
from llama_index.core import SimpleDirectoryReader
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentRAG(BaseRAGExample):
|
||||||
|
"""RAG example for document processing (PDF, TXT, MD, etc.)."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
name="Document",
|
||||||
|
description="Process and query documents (PDF, TXT, MD, etc.) with LEANN",
|
||||||
|
default_index_name="test_doc_files",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _add_specific_arguments(self, parser):
|
||||||
|
"""Add document-specific arguments."""
|
||||||
|
doc_group = parser.add_argument_group("Document Parameters")
|
||||||
|
doc_group.add_argument(
|
||||||
|
"--data-dir",
|
||||||
|
type=str,
|
||||||
|
default="data",
|
||||||
|
help="Directory containing documents to index (default: data)",
|
||||||
|
)
|
||||||
|
doc_group.add_argument(
|
||||||
|
"--file-types",
|
||||||
|
nargs="+",
|
||||||
|
default=None,
|
||||||
|
help="Filter by file types (e.g., .pdf .txt .md). If not specified, all supported types are processed",
|
||||||
|
)
|
||||||
|
doc_group.add_argument(
|
||||||
|
"--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
|
||||||
|
)
|
||||||
|
doc_group.add_argument(
|
||||||
|
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def load_data(self, args) -> list[str]:
|
||||||
|
"""Load documents and convert to text chunks."""
|
||||||
|
print(f"Loading documents from: {args.data_dir}")
|
||||||
|
if args.file_types:
|
||||||
|
print(f"Filtering by file types: {args.file_types}")
|
||||||
|
else:
|
||||||
|
print("Processing all supported file types")
|
||||||
|
|
||||||
|
# Check if data directory exists
|
||||||
|
data_path = Path(args.data_dir)
|
||||||
|
if not data_path.exists():
|
||||||
|
raise ValueError(f"Data directory not found: {args.data_dir}")
|
||||||
|
|
||||||
|
# Load documents
|
||||||
|
reader_kwargs = {
|
||||||
|
"recursive": True,
|
||||||
|
"encoding": "utf-8",
|
||||||
|
}
|
||||||
|
if args.file_types:
|
||||||
|
reader_kwargs["required_exts"] = args.file_types
|
||||||
|
|
||||||
|
documents = SimpleDirectoryReader(args.data_dir, **reader_kwargs).load_data(
|
||||||
|
show_progress=True
|
||||||
|
)
|
||||||
|
|
||||||
|
if not documents:
|
||||||
|
print(f"No documents found in {args.data_dir} with extensions {args.file_types}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
print(f"Loaded {len(documents)} documents")
|
||||||
|
|
||||||
|
# Convert to text chunks
|
||||||
|
all_texts = create_text_chunks(
|
||||||
|
documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply max_items limit if specified
|
||||||
|
if args.max_items > 0 and len(all_texts) > args.max_items:
|
||||||
|
print(f"Limiting to {args.max_items} chunks (from {len(all_texts)})")
|
||||||
|
all_texts = all_texts[: args.max_items]
|
||||||
|
|
||||||
|
return all_texts
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
# Example queries for document RAG
|
||||||
|
print("\n📄 Document RAG Example")
|
||||||
|
print("=" * 50)
|
||||||
|
print("\nExample queries you can try:")
|
||||||
|
print("- 'What are the main techniques LEANN uses?'")
|
||||||
|
print("- 'What is the technique DLPM?'")
|
||||||
|
print("- 'Who does Elizabeth Bennet marry?'")
|
||||||
|
print("- 'What is the problem of developing pan gu model? (盘古大模型开发中遇到什么问题?)'")
|
||||||
|
print("\nOr run without --query for interactive mode\n")
|
||||||
|
|
||||||
|
rag = DocumentRAG()
|
||||||
|
asyncio.run(rag.run())
|
||||||
167
apps/email_data/LEANN_email_reader.py
Normal file
167
apps/email_data/LEANN_email_reader.py
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
import email
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from llama_index.core import Document
|
||||||
|
from llama_index.core.readers.base import BaseReader
|
||||||
|
|
||||||
|
|
||||||
|
def find_all_messages_directories(root: str | None = None) -> list[Path]:
|
||||||
|
"""
|
||||||
|
Recursively find all 'Messages' directories under the given root.
|
||||||
|
Returns a list of Path objects.
|
||||||
|
"""
|
||||||
|
if root is None:
|
||||||
|
# Auto-detect user's mail path
|
||||||
|
home_dir = os.path.expanduser("~")
|
||||||
|
root = os.path.join(home_dir, "Library", "Mail")
|
||||||
|
|
||||||
|
messages_dirs = []
|
||||||
|
for dirpath, _dirnames, _filenames in os.walk(root):
|
||||||
|
if os.path.basename(dirpath) == "Messages":
|
||||||
|
messages_dirs.append(Path(dirpath))
|
||||||
|
return messages_dirs
|
||||||
|
|
||||||
|
|
||||||
|
class EmlxReader(BaseReader):
|
||||||
|
"""
|
||||||
|
Apple Mail .emlx file reader with embedded metadata.
|
||||||
|
|
||||||
|
Reads individual .emlx files from Apple Mail's storage format.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, include_html: bool = False) -> None:
|
||||||
|
"""
|
||||||
|
Initialize.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
include_html: Whether to include HTML content in the email body (default: False)
|
||||||
|
"""
|
||||||
|
self.include_html = include_html
|
||||||
|
|
||||||
|
def load_data(self, input_dir: str, **load_kwargs: Any) -> list[Document]:
|
||||||
|
"""
|
||||||
|
Load data from the input directory containing .emlx files.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_dir: Directory containing .emlx files
|
||||||
|
**load_kwargs:
|
||||||
|
max_count (int): Maximum amount of messages to read.
|
||||||
|
"""
|
||||||
|
docs: list[Document] = []
|
||||||
|
max_count = load_kwargs.get("max_count", 1000)
|
||||||
|
count = 0
|
||||||
|
total_files = 0
|
||||||
|
successful_files = 0
|
||||||
|
failed_files = 0
|
||||||
|
|
||||||
|
print(f"Starting to process directory: {input_dir}")
|
||||||
|
|
||||||
|
# Walk through the directory recursively
|
||||||
|
for dirpath, dirnames, filenames in os.walk(input_dir):
|
||||||
|
# Skip hidden directories
|
||||||
|
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
|
||||||
|
|
||||||
|
for filename in filenames:
|
||||||
|
# Check if we've reached the max count (skip if max_count == -1)
|
||||||
|
if max_count > 0 and count >= max_count:
|
||||||
|
break
|
||||||
|
|
||||||
|
if filename.endswith(".emlx"):
|
||||||
|
total_files += 1
|
||||||
|
filepath = os.path.join(dirpath, filename)
|
||||||
|
try:
|
||||||
|
# Read the .emlx file
|
||||||
|
with open(filepath, encoding="utf-8", errors="ignore") as f:
|
||||||
|
content = f.read()
|
||||||
|
|
||||||
|
# .emlx files have a length prefix followed by the email content
|
||||||
|
# The first line contains the length, followed by the email
|
||||||
|
lines = content.split("\n", 1)
|
||||||
|
if len(lines) >= 2:
|
||||||
|
email_content = lines[1]
|
||||||
|
|
||||||
|
# Parse the email using Python's email module
|
||||||
|
try:
|
||||||
|
msg = email.message_from_string(email_content)
|
||||||
|
|
||||||
|
# Extract email metadata
|
||||||
|
subject = msg.get("Subject", "No Subject")
|
||||||
|
from_addr = msg.get("From", "Unknown")
|
||||||
|
to_addr = msg.get("To", "Unknown")
|
||||||
|
date = msg.get("Date", "Unknown")
|
||||||
|
|
||||||
|
# Extract email body
|
||||||
|
body = ""
|
||||||
|
if msg.is_multipart():
|
||||||
|
for part in msg.walk():
|
||||||
|
if (
|
||||||
|
part.get_content_type() == "text/plain"
|
||||||
|
or part.get_content_type() == "text/html"
|
||||||
|
):
|
||||||
|
if (
|
||||||
|
part.get_content_type() == "text/html"
|
||||||
|
and not self.include_html
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
payload = part.get_payload(decode=True)
|
||||||
|
if payload:
|
||||||
|
body += payload.decode("utf-8", errors="ignore")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error decoding payload: {e}")
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
payload = msg.get_payload(decode=True)
|
||||||
|
if payload:
|
||||||
|
body = payload.decode("utf-8", errors="ignore")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error decoding single part payload: {e}")
|
||||||
|
body = ""
|
||||||
|
|
||||||
|
# Only create document if we have some content
|
||||||
|
if body.strip() or subject != "No Subject":
|
||||||
|
# Create document content with metadata embedded in text
|
||||||
|
doc_content = f"""
|
||||||
|
[File]: {filename}
|
||||||
|
[From]: {from_addr}
|
||||||
|
[To]: {to_addr}
|
||||||
|
[Subject]: {subject}
|
||||||
|
[Date]: {date}
|
||||||
|
[EMAIL BODY Start]:
|
||||||
|
{body}
|
||||||
|
"""
|
||||||
|
|
||||||
|
# No separate metadata - everything is in the text
|
||||||
|
doc = Document(text=doc_content, metadata={})
|
||||||
|
docs.append(doc)
|
||||||
|
count += 1
|
||||||
|
successful_files += 1
|
||||||
|
|
||||||
|
# Print first few successful files for debugging
|
||||||
|
if successful_files <= 3:
|
||||||
|
print(
|
||||||
|
f"Successfully loaded: {filename} - Subject: {subject[:50]}..."
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
failed_files += 1
|
||||||
|
if failed_files <= 5: # Only print first few errors
|
||||||
|
print(f"Error parsing email from {filepath}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
failed_files += 1
|
||||||
|
if failed_files <= 5: # Only print first few errors
|
||||||
|
print(f"Error reading file {filepath}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
print("Processing summary:")
|
||||||
|
print(f" Total .emlx files found: {total_files}")
|
||||||
|
print(f" Successfully loaded: {successful_files}")
|
||||||
|
print(f" Failed to load: {failed_files}")
|
||||||
|
print(f" Final documents: {len(docs)}")
|
||||||
|
|
||||||
|
return docs
|
||||||
186
apps/email_data/email.py
Normal file
186
apps/email_data/email.py
Normal file
@@ -0,0 +1,186 @@
|
|||||||
|
"""
|
||||||
|
Mbox parser.
|
||||||
|
|
||||||
|
Contains simple parser for mbox files.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fsspec import AbstractFileSystem
|
||||||
|
from llama_index.core.readers.base import BaseReader
|
||||||
|
from llama_index.core.schema import Document
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MboxReader(BaseReader):
|
||||||
|
"""
|
||||||
|
Mbox parser.
|
||||||
|
|
||||||
|
Extract messages from mailbox files.
|
||||||
|
Returns string including date, subject, sender, receiver and
|
||||||
|
content for each message.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
DEFAULT_MESSAGE_FORMAT: str = (
|
||||||
|
"Date: {_date}\nFrom: {_from}\nTo: {_to}\nSubject: {_subject}\nContent: {_content}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*args: Any,
|
||||||
|
max_count: int = 0,
|
||||||
|
message_format: str = DEFAULT_MESSAGE_FORMAT,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
"""Init params."""
|
||||||
|
try:
|
||||||
|
from bs4 import BeautifulSoup # noqa
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("`beautifulsoup4` package not found: `pip install beautifulsoup4`")
|
||||||
|
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.max_count = max_count
|
||||||
|
self.message_format = message_format
|
||||||
|
|
||||||
|
def load_data(
|
||||||
|
self,
|
||||||
|
file: Path,
|
||||||
|
extra_info: dict | None = None,
|
||||||
|
fs: AbstractFileSystem | None = None,
|
||||||
|
) -> list[Document]:
|
||||||
|
"""Parse file into string."""
|
||||||
|
# Import required libraries
|
||||||
|
import mailbox
|
||||||
|
from email.parser import BytesParser
|
||||||
|
from email.policy import default
|
||||||
|
|
||||||
|
from bs4 import BeautifulSoup
|
||||||
|
|
||||||
|
if fs:
|
||||||
|
logger.warning(
|
||||||
|
"fs was specified but MboxReader doesn't support loading "
|
||||||
|
"from fsspec filesystems. Will load from local filesystem instead."
|
||||||
|
)
|
||||||
|
|
||||||
|
i = 0
|
||||||
|
results: list[str] = []
|
||||||
|
# Load file using mailbox
|
||||||
|
bytes_parser = BytesParser(policy=default).parse
|
||||||
|
mbox = mailbox.mbox(file, factory=bytes_parser) # type: ignore
|
||||||
|
|
||||||
|
# Iterate through all messages
|
||||||
|
for _, _msg in enumerate(mbox):
|
||||||
|
try:
|
||||||
|
msg: mailbox.mboxMessage = _msg
|
||||||
|
# Parse multipart messages
|
||||||
|
if msg.is_multipart():
|
||||||
|
for part in msg.walk():
|
||||||
|
ctype = part.get_content_type()
|
||||||
|
cdispo = str(part.get("Content-Disposition"))
|
||||||
|
if "attachment" in cdispo:
|
||||||
|
print(f"Attachment found: {part.get_filename()}")
|
||||||
|
if ctype == "text/plain" and "attachment" not in cdispo:
|
||||||
|
content = part.get_payload(decode=True) # decode
|
||||||
|
break
|
||||||
|
# Get plain message payload for non-multipart messages
|
||||||
|
else:
|
||||||
|
content = msg.get_payload(decode=True)
|
||||||
|
|
||||||
|
# Parse message HTML content and remove unneeded whitespace
|
||||||
|
soup = BeautifulSoup(content)
|
||||||
|
stripped_content = " ".join(soup.get_text().split())
|
||||||
|
# Format message to include date, sender, receiver and subject
|
||||||
|
msg_string = self.message_format.format(
|
||||||
|
_date=msg["date"],
|
||||||
|
_from=msg["from"],
|
||||||
|
_to=msg["to"],
|
||||||
|
_subject=msg["subject"],
|
||||||
|
_content=stripped_content,
|
||||||
|
)
|
||||||
|
# Add message string to results
|
||||||
|
results.append(msg_string)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to parse message:\n{_msg}\n with exception {e}")
|
||||||
|
|
||||||
|
# Increment counter and return if max count is met
|
||||||
|
i += 1
|
||||||
|
if self.max_count > 0 and i >= self.max_count:
|
||||||
|
break
|
||||||
|
|
||||||
|
return [Document(text=result, metadata=extra_info or {}) for result in results]
|
||||||
|
|
||||||
|
|
||||||
|
class EmlxMboxReader(MboxReader):
|
||||||
|
"""
|
||||||
|
EmlxMboxReader - Modified MboxReader that handles directories of .emlx files.
|
||||||
|
|
||||||
|
Extends MboxReader to work with Apple Mail's .emlx format by:
|
||||||
|
1. Reading .emlx files from a directory
|
||||||
|
2. Converting them to mbox format in memory
|
||||||
|
3. Using the parent MboxReader's parsing logic
|
||||||
|
"""
|
||||||
|
|
||||||
|
def load_data(
|
||||||
|
self,
|
||||||
|
directory: Path,
|
||||||
|
extra_info: dict | None = None,
|
||||||
|
fs: AbstractFileSystem | None = None,
|
||||||
|
) -> list[Document]:
|
||||||
|
"""Parse .emlx files from directory into strings using MboxReader logic."""
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
if fs:
|
||||||
|
logger.warning(
|
||||||
|
"fs was specified but EmlxMboxReader doesn't support loading "
|
||||||
|
"from fsspec filesystems. Will load from local filesystem instead."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Find all .emlx files in the directory
|
||||||
|
emlx_files = list(directory.glob("*.emlx"))
|
||||||
|
logger.info(f"Found {len(emlx_files)} .emlx files in {directory}")
|
||||||
|
|
||||||
|
if not emlx_files:
|
||||||
|
logger.warning(f"No .emlx files found in {directory}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Create a temporary mbox file
|
||||||
|
with tempfile.NamedTemporaryFile(mode="w", suffix=".mbox", delete=False) as temp_mbox:
|
||||||
|
temp_mbox_path = temp_mbox.name
|
||||||
|
|
||||||
|
# Convert .emlx files to mbox format
|
||||||
|
for emlx_file in emlx_files:
|
||||||
|
try:
|
||||||
|
# Read the .emlx file
|
||||||
|
with open(emlx_file, encoding="utf-8", errors="ignore") as f:
|
||||||
|
content = f.read()
|
||||||
|
|
||||||
|
# .emlx format: first line is length, rest is email content
|
||||||
|
lines = content.split("\n", 1)
|
||||||
|
if len(lines) >= 2:
|
||||||
|
email_content = lines[1] # Skip the length line
|
||||||
|
|
||||||
|
# Write to mbox format (each message starts with "From " and ends with blank line)
|
||||||
|
temp_mbox.write(f"From {emlx_file.name} {email_content}\n\n")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to process {emlx_file}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Close the temporary file so MboxReader can read it
|
||||||
|
temp_mbox.close()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Use the parent MboxReader's logic to parse the mbox file
|
||||||
|
return super().load_data(Path(temp_mbox_path), extra_info, fs)
|
||||||
|
finally:
|
||||||
|
# Clean up temporary file
|
||||||
|
try:
|
||||||
|
os.unlink(temp_mbox_path)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
156
apps/email_rag.py
Normal file
156
apps/email_rag.py
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
"""
|
||||||
|
Email RAG example using the unified interface.
|
||||||
|
Supports Apple Mail on macOS.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Add parent directory to path for imports
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
|
|
||||||
|
from base_rag_example import BaseRAGExample, create_text_chunks
|
||||||
|
|
||||||
|
from .email_data.LEANN_email_reader import EmlxReader
|
||||||
|
|
||||||
|
|
||||||
|
class EmailRAG(BaseRAGExample):
|
||||||
|
"""RAG example for Apple Mail processing."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
# Set default values BEFORE calling super().__init__
|
||||||
|
self.max_items_default = -1 # Process all emails by default
|
||||||
|
self.embedding_model_default = (
|
||||||
|
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
|
||||||
|
)
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
name="Email",
|
||||||
|
description="Process and query Apple Mail emails with LEANN",
|
||||||
|
default_index_name="mail_index",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _add_specific_arguments(self, parser):
|
||||||
|
"""Add email-specific arguments."""
|
||||||
|
email_group = parser.add_argument_group("Email Parameters")
|
||||||
|
email_group.add_argument(
|
||||||
|
"--mail-path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path to Apple Mail directory (auto-detected if not specified)",
|
||||||
|
)
|
||||||
|
email_group.add_argument(
|
||||||
|
"--include-html", action="store_true", help="Include HTML content in email processing"
|
||||||
|
)
|
||||||
|
email_group.add_argument(
|
||||||
|
"--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
|
||||||
|
)
|
||||||
|
email_group.add_argument(
|
||||||
|
"--chunk-overlap", type=int, default=25, help="Text chunk overlap (default: 25)"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _find_mail_directories(self) -> list[Path]:
|
||||||
|
"""Auto-detect all Apple Mail directories."""
|
||||||
|
mail_base = Path.home() / "Library" / "Mail"
|
||||||
|
if not mail_base.exists():
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Find all Messages directories
|
||||||
|
messages_dirs = []
|
||||||
|
for item in mail_base.rglob("Messages"):
|
||||||
|
if item.is_dir():
|
||||||
|
messages_dirs.append(item)
|
||||||
|
|
||||||
|
return messages_dirs
|
||||||
|
|
||||||
|
async def load_data(self, args) -> list[str]:
|
||||||
|
"""Load emails and convert to text chunks."""
|
||||||
|
# Determine mail directories
|
||||||
|
if args.mail_path:
|
||||||
|
messages_dirs = [Path(args.mail_path)]
|
||||||
|
else:
|
||||||
|
print("Auto-detecting Apple Mail directories...")
|
||||||
|
messages_dirs = self._find_mail_directories()
|
||||||
|
|
||||||
|
if not messages_dirs:
|
||||||
|
print("No Apple Mail directories found!")
|
||||||
|
print("Please specify --mail-path manually")
|
||||||
|
return []
|
||||||
|
|
||||||
|
print(f"Found {len(messages_dirs)} mail directories")
|
||||||
|
|
||||||
|
# Create reader
|
||||||
|
reader = EmlxReader(include_html=args.include_html)
|
||||||
|
|
||||||
|
# Process each directory
|
||||||
|
all_documents = []
|
||||||
|
total_processed = 0
|
||||||
|
|
||||||
|
for i, messages_dir in enumerate(messages_dirs):
|
||||||
|
print(f"\nProcessing directory {i + 1}/{len(messages_dirs)}: {messages_dir}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Count emlx files
|
||||||
|
emlx_files = list(messages_dir.glob("*.emlx"))
|
||||||
|
print(f"Found {len(emlx_files)} email files")
|
||||||
|
|
||||||
|
# Apply max_items limit per directory
|
||||||
|
max_per_dir = -1 # Default to process all
|
||||||
|
if args.max_items > 0:
|
||||||
|
remaining = args.max_items - total_processed
|
||||||
|
if remaining <= 0:
|
||||||
|
break
|
||||||
|
max_per_dir = remaining
|
||||||
|
# If args.max_items == -1, max_per_dir stays -1 (process all)
|
||||||
|
|
||||||
|
# Load emails - fix the parameter passing
|
||||||
|
documents = reader.load_data(
|
||||||
|
input_dir=str(messages_dir),
|
||||||
|
max_count=max_per_dir,
|
||||||
|
)
|
||||||
|
|
||||||
|
if documents:
|
||||||
|
all_documents.extend(documents)
|
||||||
|
total_processed += len(documents)
|
||||||
|
print(f"Processed {len(documents)} emails from this directory")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing {messages_dir}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not all_documents:
|
||||||
|
print("No emails found to process!")
|
||||||
|
return []
|
||||||
|
|
||||||
|
print(f"\nTotal emails processed: {len(all_documents)}")
|
||||||
|
print("now starting to split into text chunks ... take some time")
|
||||||
|
|
||||||
|
# Convert to text chunks
|
||||||
|
# Email reader uses chunk_overlap=25 as in original
|
||||||
|
all_texts = create_text_chunks(
|
||||||
|
all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
||||||
|
)
|
||||||
|
|
||||||
|
return all_texts
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
# Check platform
|
||||||
|
if sys.platform != "darwin":
|
||||||
|
print("\n⚠️ Warning: This example is designed for macOS (Apple Mail)")
|
||||||
|
print(" Windows/Linux support coming soon!\n")
|
||||||
|
|
||||||
|
# Example queries for email RAG
|
||||||
|
print("\n📧 Email RAG Example")
|
||||||
|
print("=" * 50)
|
||||||
|
print("\nExample queries you can try:")
|
||||||
|
print("- 'What did my boss say about deadlines?'")
|
||||||
|
print("- 'Find emails about travel expenses'")
|
||||||
|
print("- 'Show me emails from last month about the project'")
|
||||||
|
print("- 'What food did I order from DoorDash?'")
|
||||||
|
print("\nNote: You may need to grant Full Disk Access to your terminal\n")
|
||||||
|
|
||||||
|
rag = EmailRAG()
|
||||||
|
asyncio.run(rag.run())
|
||||||
3
apps/history_data/__init__.py
Normal file
3
apps/history_data/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .history import ChromeHistoryReader
|
||||||
|
|
||||||
|
__all__ = ["ChromeHistoryReader"]
|
||||||
186
apps/history_data/history.py
Normal file
186
apps/history_data/history.py
Normal file
@@ -0,0 +1,186 @@
|
|||||||
|
import os
|
||||||
|
import sqlite3
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from llama_index.core import Document
|
||||||
|
from llama_index.core.readers.base import BaseReader
|
||||||
|
|
||||||
|
|
||||||
|
class ChromeHistoryReader(BaseReader):
|
||||||
|
"""
|
||||||
|
Chrome browser history reader that extracts browsing data from SQLite database.
|
||||||
|
|
||||||
|
Reads Chrome history from the default Chrome profile location and creates documents
|
||||||
|
with embedded metadata similar to the email reader structure.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
"""Initialize."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
|
||||||
|
"""
|
||||||
|
Load Chrome history data from the default Chrome profile location.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_dir: Not used for Chrome history (kept for compatibility)
|
||||||
|
**load_kwargs:
|
||||||
|
max_count (int): Maximum amount of history entries to read.
|
||||||
|
chrome_profile_path (str): Custom path to Chrome profile directory.
|
||||||
|
"""
|
||||||
|
docs: list[Document] = []
|
||||||
|
max_count = load_kwargs.get("max_count", 1000)
|
||||||
|
chrome_profile_path = load_kwargs.get("chrome_profile_path", None)
|
||||||
|
|
||||||
|
# Default Chrome profile path on macOS
|
||||||
|
if chrome_profile_path is None:
|
||||||
|
chrome_profile_path = os.path.expanduser(
|
||||||
|
"~/Library/Application Support/Google/Chrome/Default"
|
||||||
|
)
|
||||||
|
|
||||||
|
history_db_path = os.path.join(chrome_profile_path, "History")
|
||||||
|
|
||||||
|
if not os.path.exists(history_db_path):
|
||||||
|
print(f"Chrome history database not found at: {history_db_path}")
|
||||||
|
return docs
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Connect to the Chrome history database
|
||||||
|
print(f"Connecting to database: {history_db_path}")
|
||||||
|
conn = sqlite3.connect(history_db_path)
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
# Query to get browsing history with metadata (removed created_time column)
|
||||||
|
query = """
|
||||||
|
SELECT
|
||||||
|
datetime(last_visit_time/1000000-11644473600,'unixepoch','localtime') as last_visit,
|
||||||
|
url,
|
||||||
|
title,
|
||||||
|
visit_count,
|
||||||
|
typed_count,
|
||||||
|
hidden
|
||||||
|
FROM urls
|
||||||
|
ORDER BY last_visit_time DESC
|
||||||
|
"""
|
||||||
|
|
||||||
|
print(f"Executing query on database: {history_db_path}")
|
||||||
|
cursor.execute(query)
|
||||||
|
rows = cursor.fetchall()
|
||||||
|
print(f"Query returned {len(rows)} rows")
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
for row in rows:
|
||||||
|
if count >= max_count and max_count > 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
last_visit, url, title, visit_count, typed_count, hidden = row
|
||||||
|
|
||||||
|
# Create document content with metadata embedded in text
|
||||||
|
doc_content = f"""
|
||||||
|
[Title]: {title}
|
||||||
|
[URL of the page]: {url}
|
||||||
|
[Last visited time]: {last_visit}
|
||||||
|
[Visit times]: {visit_count}
|
||||||
|
[Typed times]: {typed_count}
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Create document with embedded metadata
|
||||||
|
doc = Document(text=doc_content, metadata={"title": title[0:150]})
|
||||||
|
# if len(title) > 150:
|
||||||
|
# print(f"Title is too long: {title}")
|
||||||
|
docs.append(doc)
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
conn.close()
|
||||||
|
print(f"Loaded {len(docs)} Chrome history documents")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error reading Chrome history: {e}")
|
||||||
|
# add you may need to close your browser to make the database file available
|
||||||
|
# also highlight in red
|
||||||
|
print(
|
||||||
|
"\033[91mYou may need to close your browser to make the database file available\033[0m"
|
||||||
|
)
|
||||||
|
return docs
|
||||||
|
|
||||||
|
return docs
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def find_chrome_profiles() -> list[Path]:
|
||||||
|
"""
|
||||||
|
Find all Chrome profile directories.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of Path objects pointing to Chrome profile directories
|
||||||
|
"""
|
||||||
|
chrome_base_path = Path(os.path.expanduser("~/Library/Application Support/Google/Chrome"))
|
||||||
|
profile_dirs = []
|
||||||
|
|
||||||
|
if not chrome_base_path.exists():
|
||||||
|
print(f"Chrome directory not found at: {chrome_base_path}")
|
||||||
|
return profile_dirs
|
||||||
|
|
||||||
|
# Find all profile directories
|
||||||
|
for profile_dir in chrome_base_path.iterdir():
|
||||||
|
if profile_dir.is_dir() and profile_dir.name != "System Profile":
|
||||||
|
history_path = profile_dir / "History"
|
||||||
|
if history_path.exists():
|
||||||
|
profile_dirs.append(profile_dir)
|
||||||
|
print(f"Found Chrome profile: {profile_dir}")
|
||||||
|
|
||||||
|
print(f"Found {len(profile_dirs)} Chrome profiles")
|
||||||
|
return profile_dirs
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def export_history_to_file(
|
||||||
|
output_file: str = "chrome_history_export.txt", max_count: int = 1000
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Export Chrome history to a text file using the same SQL query format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output_file: Path to the output file
|
||||||
|
max_count: Maximum number of entries to export
|
||||||
|
"""
|
||||||
|
chrome_profile_path = os.path.expanduser(
|
||||||
|
"~/Library/Application Support/Google/Chrome/Default"
|
||||||
|
)
|
||||||
|
history_db_path = os.path.join(chrome_profile_path, "History")
|
||||||
|
|
||||||
|
if not os.path.exists(history_db_path):
|
||||||
|
print(f"Chrome history database not found at: {history_db_path}")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
conn = sqlite3.connect(history_db_path)
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
query = """
|
||||||
|
SELECT
|
||||||
|
datetime(last_visit_time/1000000-11644473600,'unixepoch','localtime') as last_visit,
|
||||||
|
url,
|
||||||
|
title,
|
||||||
|
visit_count,
|
||||||
|
typed_count,
|
||||||
|
hidden
|
||||||
|
FROM urls
|
||||||
|
ORDER BY last_visit_time DESC
|
||||||
|
LIMIT ?
|
||||||
|
"""
|
||||||
|
|
||||||
|
cursor.execute(query, (max_count,))
|
||||||
|
rows = cursor.fetchall()
|
||||||
|
|
||||||
|
with open(output_file, "w", encoding="utf-8") as f:
|
||||||
|
for row in rows:
|
||||||
|
last_visit, url, title, visit_count, typed_count, hidden = row
|
||||||
|
f.write(
|
||||||
|
f"{last_visit}\t{url}\t{title}\t{visit_count}\t{typed_count}\t{hidden}\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
conn.close()
|
||||||
|
print(f"Exported {len(rows)} history entries to {output_file}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error exporting Chrome history: {e}")
|
||||||
774
apps/history_data/wechat_history.py
Normal file
774
apps/history_data/wechat_history.py
Normal file
@@ -0,0 +1,774 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import subprocess
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from llama_index.core import Document
|
||||||
|
from llama_index.core.readers.base import BaseReader
|
||||||
|
|
||||||
|
|
||||||
|
class WeChatHistoryReader(BaseReader):
|
||||||
|
"""
|
||||||
|
WeChat chat history reader that extracts chat data from exported JSON files.
|
||||||
|
|
||||||
|
Reads WeChat chat history from exported JSON files (from wechat-exporter tool)
|
||||||
|
and creates documents with embedded metadata similar to the Chrome history reader structure.
|
||||||
|
|
||||||
|
Also includes utilities for automatic WeChat chat history export.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
"""Initialize."""
|
||||||
|
self.packages_dir = Path(__file__).parent.parent.parent / "packages"
|
||||||
|
self.wechat_exporter_dir = self.packages_dir / "wechat-exporter"
|
||||||
|
self.wechat_decipher_dir = self.packages_dir / "wechat-decipher-macos"
|
||||||
|
|
||||||
|
def check_wechat_running(self) -> bool:
|
||||||
|
"""Check if WeChat is currently running."""
|
||||||
|
try:
|
||||||
|
result = subprocess.run(["pgrep", "-f", "WeChat"], capture_output=True, text=True)
|
||||||
|
return result.returncode == 0
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def install_wechattweak(self) -> bool:
|
||||||
|
"""Install WeChatTweak CLI tool."""
|
||||||
|
try:
|
||||||
|
# Create wechat-exporter directory if it doesn't exist
|
||||||
|
self.wechat_exporter_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
wechattweak_path = self.wechat_exporter_dir / "wechattweak-cli"
|
||||||
|
if not wechattweak_path.exists():
|
||||||
|
print("Downloading WeChatTweak CLI...")
|
||||||
|
subprocess.run(
|
||||||
|
[
|
||||||
|
"curl",
|
||||||
|
"-L",
|
||||||
|
"-o",
|
||||||
|
str(wechattweak_path),
|
||||||
|
"https://github.com/JettChenT/WeChatTweak-CLI/releases/latest/download/wechattweak-cli",
|
||||||
|
],
|
||||||
|
check=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Make executable
|
||||||
|
wechattweak_path.chmod(0o755)
|
||||||
|
|
||||||
|
# Install WeChatTweak
|
||||||
|
print("Installing WeChatTweak...")
|
||||||
|
subprocess.run(["sudo", str(wechattweak_path), "install"], check=True)
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error installing WeChatTweak: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def restart_wechat(self):
|
||||||
|
"""Restart WeChat to apply WeChatTweak."""
|
||||||
|
try:
|
||||||
|
print("Restarting WeChat...")
|
||||||
|
subprocess.run(["pkill", "-f", "WeChat"], check=False)
|
||||||
|
time.sleep(2)
|
||||||
|
subprocess.run(["open", "-a", "WeChat"], check=True)
|
||||||
|
time.sleep(5) # Wait for WeChat to start
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error restarting WeChat: {e}")
|
||||||
|
|
||||||
|
def check_api_available(self) -> bool:
|
||||||
|
"""Check if WeChatTweak API is available."""
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
["curl", "-s", "http://localhost:48065/wechat/allcontacts"],
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=5,
|
||||||
|
)
|
||||||
|
return result.returncode == 0 and result.stdout.strip()
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _extract_readable_text(self, content: str) -> str:
|
||||||
|
"""
|
||||||
|
Extract readable text from message content, removing XML and system messages.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: The raw message content (can be string or dict)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cleaned, readable text
|
||||||
|
"""
|
||||||
|
if not content:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# Handle dictionary content (like quoted messages)
|
||||||
|
if isinstance(content, dict):
|
||||||
|
# Extract text from dictionary structure
|
||||||
|
text_parts = []
|
||||||
|
if "title" in content:
|
||||||
|
text_parts.append(str(content["title"]))
|
||||||
|
if "quoted" in content:
|
||||||
|
text_parts.append(str(content["quoted"]))
|
||||||
|
if "content" in content:
|
||||||
|
text_parts.append(str(content["content"]))
|
||||||
|
if "text" in content:
|
||||||
|
text_parts.append(str(content["text"]))
|
||||||
|
|
||||||
|
if text_parts:
|
||||||
|
return " | ".join(text_parts)
|
||||||
|
else:
|
||||||
|
# If we can't extract meaningful text from dict, return empty
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# Handle string content
|
||||||
|
if not isinstance(content, str):
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# Remove common prefixes like "wxid_xxx:\n"
|
||||||
|
clean_content = re.sub(r"^wxid_[^:]+:\s*", "", content)
|
||||||
|
clean_content = re.sub(r"^[^:]+:\s*", "", clean_content)
|
||||||
|
|
||||||
|
# If it's just XML or system message, return empty
|
||||||
|
if clean_content.strip().startswith("<") or "recalled a message" in clean_content:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
return clean_content.strip()
|
||||||
|
|
||||||
|
def _is_text_message(self, content: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a message contains readable text content.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: The message content (can be string or dict)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the message contains readable text, False otherwise
|
||||||
|
"""
|
||||||
|
if not content:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Handle dictionary content
|
||||||
|
if isinstance(content, dict):
|
||||||
|
# Check if dict has any readable text fields
|
||||||
|
text_fields = ["title", "quoted", "content", "text"]
|
||||||
|
for field in text_fields:
|
||||||
|
if content.get(field):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Handle string content
|
||||||
|
if not isinstance(content, str):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Skip image messages (contain XML with img tags)
|
||||||
|
if "<img" in content and "cdnurl" in content:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Skip emoji messages (contain emoji XML tags)
|
||||||
|
if "<emoji" in content and "productid" in content:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Skip voice messages
|
||||||
|
if "<voice" in content:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Skip video messages
|
||||||
|
if "<video" in content:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Skip file messages
|
||||||
|
if "<appmsg" in content and "appid" in content:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Skip system messages (like "recalled a message")
|
||||||
|
if "recalled a message" in content:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check if there's actual readable text (not just XML or system messages)
|
||||||
|
# Remove common prefixes like "wxid_xxx:\n" and check for actual content
|
||||||
|
clean_content = re.sub(r"^wxid_[^:]+:\s*", "", content)
|
||||||
|
clean_content = re.sub(r"^[^:]+:\s*", "", clean_content)
|
||||||
|
|
||||||
|
# If after cleaning we have meaningful text, consider it readable
|
||||||
|
if len(clean_content.strip()) > 0 and not clean_content.strip().startswith("<"):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _concatenate_messages(
|
||||||
|
self,
|
||||||
|
messages: list[dict],
|
||||||
|
max_length: int = 128,
|
||||||
|
time_window_minutes: int = 30,
|
||||||
|
overlap_messages: int = 0,
|
||||||
|
) -> list[dict]:
|
||||||
|
"""
|
||||||
|
Concatenate messages based on length and time rules.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: List of message dictionaries
|
||||||
|
max_length: Maximum length for concatenated message groups. Use -1 to disable length constraint.
|
||||||
|
time_window_minutes: Time window in minutes to group messages together. Use -1 to disable time constraint.
|
||||||
|
overlap_messages: Number of messages to overlap between consecutive groups
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of concatenated message groups
|
||||||
|
"""
|
||||||
|
if not messages:
|
||||||
|
return []
|
||||||
|
|
||||||
|
concatenated_groups = []
|
||||||
|
current_group = []
|
||||||
|
current_length = 0
|
||||||
|
last_timestamp = None
|
||||||
|
|
||||||
|
for message in messages:
|
||||||
|
# Extract message info
|
||||||
|
content = message.get("content", "")
|
||||||
|
message_text = message.get("message", "")
|
||||||
|
create_time = message.get("createTime", 0)
|
||||||
|
message.get("fromUser", "")
|
||||||
|
message.get("toUser", "")
|
||||||
|
message.get("isSentFromSelf", False)
|
||||||
|
|
||||||
|
# Extract readable text
|
||||||
|
readable_text = self._extract_readable_text(content)
|
||||||
|
if not readable_text:
|
||||||
|
readable_text = message_text
|
||||||
|
|
||||||
|
# Skip empty messages
|
||||||
|
if not readable_text.strip():
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check time window constraint (only if time_window_minutes != -1)
|
||||||
|
if time_window_minutes != -1 and last_timestamp is not None and create_time > 0:
|
||||||
|
time_diff_minutes = (create_time - last_timestamp) / 60
|
||||||
|
if time_diff_minutes > time_window_minutes:
|
||||||
|
# Time gap too large, start new group
|
||||||
|
if current_group:
|
||||||
|
concatenated_groups.append(
|
||||||
|
{
|
||||||
|
"messages": current_group,
|
||||||
|
"total_length": current_length,
|
||||||
|
"start_time": current_group[0].get("createTime", 0),
|
||||||
|
"end_time": current_group[-1].get("createTime", 0),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# Keep last few messages for overlap
|
||||||
|
if overlap_messages > 0 and len(current_group) > overlap_messages:
|
||||||
|
current_group = current_group[-overlap_messages:]
|
||||||
|
current_length = sum(
|
||||||
|
len(
|
||||||
|
self._extract_readable_text(msg.get("content", ""))
|
||||||
|
or msg.get("message", "")
|
||||||
|
)
|
||||||
|
for msg in current_group
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
current_group = []
|
||||||
|
current_length = 0
|
||||||
|
|
||||||
|
# Check length constraint (only if max_length != -1)
|
||||||
|
message_length = len(readable_text)
|
||||||
|
if max_length != -1 and current_length + message_length > max_length and current_group:
|
||||||
|
# Current group would exceed max length, save it and start new
|
||||||
|
concatenated_groups.append(
|
||||||
|
{
|
||||||
|
"messages": current_group,
|
||||||
|
"total_length": current_length,
|
||||||
|
"start_time": current_group[0].get("createTime", 0),
|
||||||
|
"end_time": current_group[-1].get("createTime", 0),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# Keep last few messages for overlap
|
||||||
|
if overlap_messages > 0 and len(current_group) > overlap_messages:
|
||||||
|
current_group = current_group[-overlap_messages:]
|
||||||
|
current_length = sum(
|
||||||
|
len(
|
||||||
|
self._extract_readable_text(msg.get("content", ""))
|
||||||
|
or msg.get("message", "")
|
||||||
|
)
|
||||||
|
for msg in current_group
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
current_group = []
|
||||||
|
current_length = 0
|
||||||
|
|
||||||
|
# Add message to current group
|
||||||
|
current_group.append(message)
|
||||||
|
current_length += message_length
|
||||||
|
last_timestamp = create_time
|
||||||
|
|
||||||
|
# Add the last group if it exists
|
||||||
|
if current_group:
|
||||||
|
concatenated_groups.append(
|
||||||
|
{
|
||||||
|
"messages": current_group,
|
||||||
|
"total_length": current_length,
|
||||||
|
"start_time": current_group[0].get("createTime", 0),
|
||||||
|
"end_time": current_group[-1].get("createTime", 0),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return concatenated_groups
|
||||||
|
|
||||||
|
def _create_concatenated_content(self, message_group: dict, contact_name: str) -> str:
|
||||||
|
"""
|
||||||
|
Create concatenated content from a group of messages.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message_group: Dictionary containing messages and metadata
|
||||||
|
contact_name: Name of the contact
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted concatenated content
|
||||||
|
"""
|
||||||
|
messages = message_group["messages"]
|
||||||
|
start_time = message_group["start_time"]
|
||||||
|
end_time = message_group["end_time"]
|
||||||
|
|
||||||
|
# Format timestamps
|
||||||
|
if start_time:
|
||||||
|
try:
|
||||||
|
start_timestamp = datetime.fromtimestamp(start_time)
|
||||||
|
start_time_str = start_timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
except (ValueError, OSError):
|
||||||
|
start_time_str = str(start_time)
|
||||||
|
else:
|
||||||
|
start_time_str = "Unknown"
|
||||||
|
|
||||||
|
if end_time:
|
||||||
|
try:
|
||||||
|
end_timestamp = datetime.fromtimestamp(end_time)
|
||||||
|
end_time_str = end_timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
except (ValueError, OSError):
|
||||||
|
end_time_str = str(end_time)
|
||||||
|
else:
|
||||||
|
end_time_str = "Unknown"
|
||||||
|
|
||||||
|
# Build concatenated message content
|
||||||
|
message_parts = []
|
||||||
|
for message in messages:
|
||||||
|
content = message.get("content", "")
|
||||||
|
message_text = message.get("message", "")
|
||||||
|
create_time = message.get("createTime", 0)
|
||||||
|
is_sent_from_self = message.get("isSentFromSelf", False)
|
||||||
|
|
||||||
|
# Extract readable text
|
||||||
|
readable_text = self._extract_readable_text(content)
|
||||||
|
if not readable_text:
|
||||||
|
readable_text = message_text
|
||||||
|
|
||||||
|
# Format individual message
|
||||||
|
if create_time:
|
||||||
|
try:
|
||||||
|
timestamp = datetime.fromtimestamp(create_time)
|
||||||
|
# change to YYYY-MM-DD HH:MM:SS
|
||||||
|
time_str = timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
except (ValueError, OSError):
|
||||||
|
time_str = str(create_time)
|
||||||
|
else:
|
||||||
|
time_str = "Unknown"
|
||||||
|
|
||||||
|
sender = "[Me]" if is_sent_from_self else "[Contact]"
|
||||||
|
message_parts.append(f"({time_str}) {sender}: {readable_text}")
|
||||||
|
|
||||||
|
concatenated_text = "\n".join(message_parts)
|
||||||
|
|
||||||
|
# Create final document content
|
||||||
|
doc_content = f"""
|
||||||
|
Contact: {contact_name}
|
||||||
|
Time Range: {start_time_str} - {end_time_str}
|
||||||
|
Messages ({len(messages)} messages, {message_group["total_length"]} chars):
|
||||||
|
|
||||||
|
{concatenated_text}
|
||||||
|
"""
|
||||||
|
# TODO @yichuan give better format and rich info here!
|
||||||
|
doc_content = f"""
|
||||||
|
{concatenated_text}
|
||||||
|
"""
|
||||||
|
return doc_content, contact_name
|
||||||
|
|
||||||
|
def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
|
||||||
|
"""
|
||||||
|
Load WeChat chat history data from exported JSON files.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_dir: Directory containing exported WeChat JSON files
|
||||||
|
**load_kwargs:
|
||||||
|
max_count (int): Maximum amount of chat entries to read.
|
||||||
|
wechat_export_dir (str): Custom path to WeChat export directory.
|
||||||
|
include_non_text (bool): Whether to include non-text messages (images, emojis, etc.)
|
||||||
|
concatenate_messages (bool): Whether to concatenate messages based on length rules.
|
||||||
|
max_length (int): Maximum length for concatenated message groups (default: 1000).
|
||||||
|
time_window_minutes (int): Time window in minutes to group messages together (default: 30).
|
||||||
|
overlap_messages (int): Number of messages to overlap between consecutive groups (default: 2).
|
||||||
|
"""
|
||||||
|
docs: list[Document] = []
|
||||||
|
max_count = load_kwargs.get("max_count", 1000)
|
||||||
|
wechat_export_dir = load_kwargs.get("wechat_export_dir", None)
|
||||||
|
include_non_text = load_kwargs.get("include_non_text", False)
|
||||||
|
concatenate_messages = load_kwargs.get("concatenate_messages", False)
|
||||||
|
max_length = load_kwargs.get("max_length", 1000)
|
||||||
|
time_window_minutes = load_kwargs.get("time_window_minutes", 30)
|
||||||
|
|
||||||
|
# Default WeChat export path
|
||||||
|
if wechat_export_dir is None:
|
||||||
|
wechat_export_dir = "./wechat_export_test"
|
||||||
|
|
||||||
|
if not os.path.exists(wechat_export_dir):
|
||||||
|
print(f"WeChat export directory not found at: {wechat_export_dir}")
|
||||||
|
return docs
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Find all JSON files in the export directory
|
||||||
|
json_files = list(Path(wechat_export_dir).glob("*.json"))
|
||||||
|
print(f"Found {len(json_files)} WeChat chat history files")
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
for json_file in json_files:
|
||||||
|
if count >= max_count and max_count > 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(json_file, encoding="utf-8") as f:
|
||||||
|
chat_data = json.load(f)
|
||||||
|
|
||||||
|
# Extract contact name from filename
|
||||||
|
contact_name = json_file.stem
|
||||||
|
|
||||||
|
if concatenate_messages:
|
||||||
|
# Filter messages to only include readable text messages
|
||||||
|
readable_messages = []
|
||||||
|
for message in chat_data:
|
||||||
|
try:
|
||||||
|
content = message.get("content", "")
|
||||||
|
if not include_non_text and not self._is_text_message(content):
|
||||||
|
continue
|
||||||
|
|
||||||
|
readable_text = self._extract_readable_text(content)
|
||||||
|
if not readable_text and not include_non_text:
|
||||||
|
continue
|
||||||
|
|
||||||
|
readable_messages.append(message)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing message in {json_file}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Concatenate messages based on rules
|
||||||
|
message_groups = self._concatenate_messages(
|
||||||
|
readable_messages,
|
||||||
|
max_length=max_length,
|
||||||
|
time_window_minutes=time_window_minutes,
|
||||||
|
overlap_messages=0, # No overlap between groups
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create documents from concatenated groups
|
||||||
|
for message_group in message_groups:
|
||||||
|
if count >= max_count and max_count > 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
doc_content, contact_name = self._create_concatenated_content(
|
||||||
|
message_group, contact_name
|
||||||
|
)
|
||||||
|
doc = Document(
|
||||||
|
text=doc_content,
|
||||||
|
metadata={"contact_name": contact_name},
|
||||||
|
)
|
||||||
|
docs.append(doc)
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Created {len(message_groups)} concatenated message groups for {contact_name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Original single-message processing
|
||||||
|
for message in chat_data:
|
||||||
|
if count >= max_count and max_count > 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Extract message information
|
||||||
|
message.get("fromUser", "")
|
||||||
|
message.get("toUser", "")
|
||||||
|
content = message.get("content", "")
|
||||||
|
message_text = message.get("message", "")
|
||||||
|
create_time = message.get("createTime", 0)
|
||||||
|
is_sent_from_self = message.get("isSentFromSelf", False)
|
||||||
|
|
||||||
|
# Handle content that might be dict or string
|
||||||
|
try:
|
||||||
|
# Check if this is a readable text message
|
||||||
|
if not include_non_text and not self._is_text_message(content):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Extract readable text
|
||||||
|
readable_text = self._extract_readable_text(content)
|
||||||
|
if not readable_text and not include_non_text:
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
# Skip messages that cause processing errors
|
||||||
|
print(f"Error processing message in {json_file}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Convert timestamp to readable format
|
||||||
|
if create_time:
|
||||||
|
try:
|
||||||
|
timestamp = datetime.fromtimestamp(create_time)
|
||||||
|
time_str = timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
except (ValueError, OSError):
|
||||||
|
time_str = str(create_time)
|
||||||
|
else:
|
||||||
|
time_str = "Unknown"
|
||||||
|
|
||||||
|
# Create document content with metadata header and contact info
|
||||||
|
doc_content = f"""
|
||||||
|
Contact: {contact_name}
|
||||||
|
Is sent from self: {is_sent_from_self}
|
||||||
|
Time: {time_str}
|
||||||
|
Message: {readable_text if readable_text else message_text}
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Create document with embedded metadata
|
||||||
|
doc = Document(
|
||||||
|
text=doc_content, metadata={"contact_name": contact_name}
|
||||||
|
)
|
||||||
|
docs.append(doc)
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error reading {json_file}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
print(f"Loaded {len(docs)} WeChat chat documents")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error reading WeChat history: {e}")
|
||||||
|
return docs
|
||||||
|
|
||||||
|
return docs
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def find_wechat_export_dirs() -> list[Path]:
|
||||||
|
"""
|
||||||
|
Find all WeChat export directories.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of Path objects pointing to WeChat export directories
|
||||||
|
"""
|
||||||
|
export_dirs = []
|
||||||
|
|
||||||
|
# Look for common export directory names
|
||||||
|
possible_dirs = [
|
||||||
|
Path("./wechat_export"),
|
||||||
|
Path("./wechat_export_direct"),
|
||||||
|
Path("./wechat_chat_history"),
|
||||||
|
Path("./chat_export"),
|
||||||
|
]
|
||||||
|
|
||||||
|
for export_dir in possible_dirs:
|
||||||
|
if export_dir.exists() and export_dir.is_dir():
|
||||||
|
json_files = list(export_dir.glob("*.json"))
|
||||||
|
if json_files:
|
||||||
|
export_dirs.append(export_dir)
|
||||||
|
print(
|
||||||
|
f"Found WeChat export directory: {export_dir} with {len(json_files)} files"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Found {len(export_dirs)} WeChat export directories")
|
||||||
|
return export_dirs
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def export_chat_to_file(
|
||||||
|
output_file: str = "wechat_chat_export.txt",
|
||||||
|
max_count: int = 1000,
|
||||||
|
export_dir: str | None = None,
|
||||||
|
include_non_text: bool = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Export WeChat chat history to a text file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output_file: Path to the output file
|
||||||
|
max_count: Maximum number of entries to export
|
||||||
|
export_dir: Directory containing WeChat JSON files
|
||||||
|
include_non_text: Whether to include non-text messages
|
||||||
|
"""
|
||||||
|
if export_dir is None:
|
||||||
|
export_dir = "./wechat_export_test"
|
||||||
|
|
||||||
|
if not os.path.exists(export_dir):
|
||||||
|
print(f"WeChat export directory not found at: {export_dir}")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
json_files = list(Path(export_dir).glob("*.json"))
|
||||||
|
|
||||||
|
with open(output_file, "w", encoding="utf-8") as f:
|
||||||
|
count = 0
|
||||||
|
for json_file in json_files:
|
||||||
|
if count >= max_count and max_count > 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(json_file, encoding="utf-8") as json_f:
|
||||||
|
chat_data = json.load(json_f)
|
||||||
|
|
||||||
|
contact_name = json_file.stem
|
||||||
|
f.write(f"\n=== Chat with {contact_name} ===\n")
|
||||||
|
|
||||||
|
for message in chat_data:
|
||||||
|
if count >= max_count and max_count > 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
from_user = message.get("fromUser", "")
|
||||||
|
content = message.get("content", "")
|
||||||
|
message_text = message.get("message", "")
|
||||||
|
create_time = message.get("createTime", 0)
|
||||||
|
|
||||||
|
# Skip non-text messages unless requested
|
||||||
|
if not include_non_text:
|
||||||
|
reader = WeChatHistoryReader()
|
||||||
|
if not reader._is_text_message(content):
|
||||||
|
continue
|
||||||
|
readable_text = reader._extract_readable_text(content)
|
||||||
|
if not readable_text:
|
||||||
|
continue
|
||||||
|
message_text = readable_text
|
||||||
|
|
||||||
|
if create_time:
|
||||||
|
try:
|
||||||
|
timestamp = datetime.fromtimestamp(create_time)
|
||||||
|
time_str = timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
except (ValueError, OSError):
|
||||||
|
time_str = str(create_time)
|
||||||
|
else:
|
||||||
|
time_str = "Unknown"
|
||||||
|
|
||||||
|
f.write(f"[{time_str}] {from_user}: {message_text}\n")
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing {json_file}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
print(f"Exported {count} chat entries to {output_file}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error exporting WeChat chat history: {e}")
|
||||||
|
|
||||||
|
def export_wechat_chat_history(self, export_dir: str = "./wechat_export_direct") -> Path | None:
|
||||||
|
"""
|
||||||
|
Export WeChat chat history using wechat-exporter tool.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
export_dir: Directory to save exported chat history
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to export directory if successful, None otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
|
||||||
|
# Create export directory
|
||||||
|
export_path = Path(export_dir)
|
||||||
|
export_path.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
print(f"Exporting WeChat chat history to {export_path}...")
|
||||||
|
|
||||||
|
# Check if wechat-exporter directory exists
|
||||||
|
if not self.wechat_exporter_dir.exists():
|
||||||
|
print(f"wechat-exporter directory not found at: {self.wechat_exporter_dir}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Install requirements if needed
|
||||||
|
requirements_file = self.wechat_exporter_dir / "requirements.txt"
|
||||||
|
if requirements_file.exists():
|
||||||
|
print("Installing wechat-exporter requirements...")
|
||||||
|
subprocess.run(["uv", "pip", "install", "-r", str(requirements_file)], check=True)
|
||||||
|
|
||||||
|
# Run the export command
|
||||||
|
print("Running wechat-exporter...")
|
||||||
|
result = subprocess.run(
|
||||||
|
[
|
||||||
|
sys.executable,
|
||||||
|
str(self.wechat_exporter_dir / "main.py"),
|
||||||
|
"export-all",
|
||||||
|
str(export_path),
|
||||||
|
],
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
check=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Export command output:")
|
||||||
|
print(result.stdout)
|
||||||
|
if result.stderr:
|
||||||
|
print("Export errors:")
|
||||||
|
print(result.stderr)
|
||||||
|
|
||||||
|
# Check if export was successful
|
||||||
|
if export_path.exists() and any(export_path.glob("*.json")):
|
||||||
|
json_files = list(export_path.glob("*.json"))
|
||||||
|
print(
|
||||||
|
f"Successfully exported {len(json_files)} chat history files to {export_path}"
|
||||||
|
)
|
||||||
|
return export_path
|
||||||
|
else:
|
||||||
|
print("Export completed but no JSON files found")
|
||||||
|
return None
|
||||||
|
|
||||||
|
except subprocess.CalledProcessError as e:
|
||||||
|
print(f"Export command failed: {e}")
|
||||||
|
print(f"Command output: {e.stdout}")
|
||||||
|
print(f"Command errors: {e.stderr}")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Export failed: {e}")
|
||||||
|
print("Please ensure WeChat is running and WeChatTweak is installed.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def find_or_export_wechat_data(self, export_dir: str = "./wechat_export_direct") -> list[Path]:
|
||||||
|
"""
|
||||||
|
Find existing WeChat exports or create new ones.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
export_dir: Directory to save exported chat history if needed
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of Path objects pointing to WeChat export directories
|
||||||
|
"""
|
||||||
|
export_dirs = []
|
||||||
|
|
||||||
|
# Look for existing exports in common locations
|
||||||
|
possible_export_dirs = [
|
||||||
|
Path("./wechat_database_export"),
|
||||||
|
Path("./wechat_export_test"),
|
||||||
|
Path("./wechat_export"),
|
||||||
|
Path("./wechat_export_direct"),
|
||||||
|
Path("./wechat_chat_history"),
|
||||||
|
Path("./chat_export"),
|
||||||
|
]
|
||||||
|
|
||||||
|
for export_dir_path in possible_export_dirs:
|
||||||
|
if export_dir_path.exists() and any(export_dir_path.glob("*.json")):
|
||||||
|
export_dirs.append(export_dir_path)
|
||||||
|
print(f"Found existing export: {export_dir_path}")
|
||||||
|
|
||||||
|
# If no existing exports, try to export automatically
|
||||||
|
if not export_dirs:
|
||||||
|
print("No existing WeChat exports found. Starting direct export...")
|
||||||
|
|
||||||
|
# Try to export using wechat-exporter
|
||||||
|
exported_path = self.export_wechat_chat_history(export_dir)
|
||||||
|
if exported_path:
|
||||||
|
export_dirs = [exported_path]
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
"Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed."
|
||||||
|
)
|
||||||
|
|
||||||
|
return export_dirs
|
||||||
189
apps/wechat_rag.py
Normal file
189
apps/wechat_rag.py
Normal file
@@ -0,0 +1,189 @@
|
|||||||
|
"""
|
||||||
|
WeChat History RAG example using the unified interface.
|
||||||
|
Supports WeChat chat history export and search.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Add parent directory to path for imports
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
|
|
||||||
|
from base_rag_example import BaseRAGExample
|
||||||
|
|
||||||
|
from .history_data.wechat_history import WeChatHistoryReader
|
||||||
|
|
||||||
|
|
||||||
|
class WeChatRAG(BaseRAGExample):
|
||||||
|
"""RAG example for WeChat chat history."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
# Set default values BEFORE calling super().__init__
|
||||||
|
self.max_items_default = -1 # Match original default
|
||||||
|
self.embedding_model_default = (
|
||||||
|
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
|
||||||
|
)
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
name="WeChat History",
|
||||||
|
description="Process and query WeChat chat history with LEANN",
|
||||||
|
default_index_name="wechat_history_magic_test_11Debug_new",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _add_specific_arguments(self, parser):
|
||||||
|
"""Add WeChat-specific arguments."""
|
||||||
|
wechat_group = parser.add_argument_group("WeChat Parameters")
|
||||||
|
wechat_group.add_argument(
|
||||||
|
"--export-dir",
|
||||||
|
type=str,
|
||||||
|
default="./wechat_export",
|
||||||
|
help="Directory to store WeChat exports (default: ./wechat_export)",
|
||||||
|
)
|
||||||
|
wechat_group.add_argument(
|
||||||
|
"--force-export",
|
||||||
|
action="store_true",
|
||||||
|
help="Force re-export of WeChat data even if exports exist",
|
||||||
|
)
|
||||||
|
wechat_group.add_argument(
|
||||||
|
"--chunk-size", type=int, default=192, help="Text chunk size (default: 192)"
|
||||||
|
)
|
||||||
|
wechat_group.add_argument(
|
||||||
|
"--chunk-overlap", type=int, default=64, help="Text chunk overlap (default: 64)"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _export_wechat_data(self, export_dir: Path) -> bool:
|
||||||
|
"""Export WeChat data using wechattweak-cli."""
|
||||||
|
print("Exporting WeChat data...")
|
||||||
|
|
||||||
|
# Check if WeChat is running
|
||||||
|
try:
|
||||||
|
result = subprocess.run(["pgrep", "WeChat"], capture_output=True, text=True)
|
||||||
|
if result.returncode != 0:
|
||||||
|
print("WeChat is not running. Please start WeChat first.")
|
||||||
|
return False
|
||||||
|
except Exception:
|
||||||
|
pass # pgrep might not be available on all systems
|
||||||
|
|
||||||
|
# Create export directory
|
||||||
|
export_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Run export command
|
||||||
|
cmd = ["packages/wechat-exporter/wechattweak-cli", "export", str(export_dir)]
|
||||||
|
|
||||||
|
try:
|
||||||
|
print(f"Running: {' '.join(cmd)}")
|
||||||
|
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||||
|
|
||||||
|
if result.returncode == 0:
|
||||||
|
print("WeChat data exported successfully!")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
print(f"Export failed: {result.stderr}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
except FileNotFoundError:
|
||||||
|
print("\nError: wechattweak-cli not found!")
|
||||||
|
print("Please install it first:")
|
||||||
|
print(" sudo packages/wechat-exporter/wechattweak-cli install")
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Export error: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def load_data(self, args) -> list[str]:
|
||||||
|
"""Load WeChat history and convert to text chunks."""
|
||||||
|
# Initialize WeChat reader with export capabilities
|
||||||
|
reader = WeChatHistoryReader()
|
||||||
|
|
||||||
|
# Find existing exports or create new ones using the centralized method
|
||||||
|
export_dirs = reader.find_or_export_wechat_data(args.export_dir)
|
||||||
|
if not export_dirs:
|
||||||
|
print("Failed to find or export WeChat data. Trying to find any existing exports...")
|
||||||
|
# Try to find any existing exports in common locations
|
||||||
|
export_dirs = reader.find_wechat_export_dirs()
|
||||||
|
if not export_dirs:
|
||||||
|
print("No WeChat data found. Please ensure WeChat exports exist.")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Load documents from all found export directories
|
||||||
|
all_documents = []
|
||||||
|
total_processed = 0
|
||||||
|
|
||||||
|
for i, export_dir in enumerate(export_dirs):
|
||||||
|
print(f"\nProcessing WeChat export {i + 1}/{len(export_dirs)}: {export_dir}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Apply max_items limit per export
|
||||||
|
max_per_export = -1
|
||||||
|
if args.max_items > 0:
|
||||||
|
remaining = args.max_items - total_processed
|
||||||
|
if remaining <= 0:
|
||||||
|
break
|
||||||
|
max_per_export = remaining
|
||||||
|
|
||||||
|
documents = reader.load_data(
|
||||||
|
wechat_export_dir=str(export_dir),
|
||||||
|
max_count=max_per_export,
|
||||||
|
concatenate_messages=True, # Enable message concatenation for better context
|
||||||
|
)
|
||||||
|
|
||||||
|
if documents:
|
||||||
|
print(f"Loaded {len(documents)} chat documents from {export_dir}")
|
||||||
|
all_documents.extend(documents)
|
||||||
|
total_processed += len(documents)
|
||||||
|
else:
|
||||||
|
print(f"No documents loaded from {export_dir}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing {export_dir}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not all_documents:
|
||||||
|
print("No documents loaded from any source. Exiting.")
|
||||||
|
return []
|
||||||
|
|
||||||
|
print(f"\nTotal loaded {len(all_documents)} chat documents from {len(export_dirs)} exports")
|
||||||
|
print("now starting to split into text chunks ... take some time")
|
||||||
|
|
||||||
|
# Convert to text chunks with contact information
|
||||||
|
all_texts = []
|
||||||
|
for doc in all_documents:
|
||||||
|
# Split the document into chunks
|
||||||
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
|
||||||
|
text_splitter = SentenceSplitter(
|
||||||
|
chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
||||||
|
)
|
||||||
|
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||||
|
|
||||||
|
for node in nodes:
|
||||||
|
# Add contact information to each chunk
|
||||||
|
contact_name = doc.metadata.get("contact_name", "Unknown")
|
||||||
|
text = f"[Contact] means the message is from: {contact_name}\n" + node.get_content()
|
||||||
|
all_texts.append(text)
|
||||||
|
|
||||||
|
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
|
||||||
|
return all_texts
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
# Check platform
|
||||||
|
if sys.platform != "darwin":
|
||||||
|
print("\n⚠️ Warning: WeChat export is only supported on macOS")
|
||||||
|
print(" You can still query existing exports on other platforms\n")
|
||||||
|
|
||||||
|
# Example queries for WeChat RAG
|
||||||
|
print("\n💬 WeChat History RAG Example")
|
||||||
|
print("=" * 50)
|
||||||
|
print("\nExample queries you can try:")
|
||||||
|
print("- 'Show me conversations about travel plans'")
|
||||||
|
print("- 'Find group chats about weekend activities'")
|
||||||
|
print("- '我想买魔术师约翰逊的球衣,给我一些对应聊天记录?'")
|
||||||
|
print("- 'What did we discuss about the project last month?'")
|
||||||
|
print("\nNote: WeChat must be running for export to work\n")
|
||||||
|
|
||||||
|
rag = WeChatRAG()
|
||||||
|
asyncio.run(rag.run())
|
||||||
BIN
assets/arch.png
Normal file
BIN
assets/arch.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 78 KiB |
BIN
assets/effects.png
Normal file
BIN
assets/effects.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 339 KiB |
BIN
assets/logo-text.png
Normal file
BIN
assets/logo-text.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 818 KiB |
BIN
assets/logo.png
Normal file
BIN
assets/logo.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 276 KiB |
141
benchmarks/benchmark_embeddings.py
Normal file
141
benchmarks/benchmark_embeddings.py
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
import time
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from mlx_lm import load
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
|
# --- Configuration ---
|
||||||
|
MODEL_NAME_TORCH = "Qwen/Qwen3-Embedding-0.6B"
|
||||||
|
MODEL_NAME_MLX = "mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ"
|
||||||
|
BATCH_SIZES = [1, 8, 16, 32, 64, 128]
|
||||||
|
NUM_RUNS = 10 # Number of runs to average for each batch size
|
||||||
|
WARMUP_RUNS = 2 # Number of warm-up runs
|
||||||
|
|
||||||
|
# --- Generate Dummy Data ---
|
||||||
|
DUMMY_SENTENCES = ["This is a test sentence for benchmarking." * 5] * max(BATCH_SIZES)
|
||||||
|
|
||||||
|
# --- Benchmark Functions ---b
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_torch(model, sentences):
|
||||||
|
start_time = time.time()
|
||||||
|
model.encode(sentences, convert_to_numpy=True)
|
||||||
|
end_time = time.time()
|
||||||
|
return (end_time - start_time) * 1000 # Return time in ms
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_mlx(model, tokenizer, sentences):
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# Tokenize sentences using MLX tokenizer
|
||||||
|
tokens = []
|
||||||
|
for sentence in sentences:
|
||||||
|
token_ids = tokenizer.encode(sentence)
|
||||||
|
tokens.append(token_ids)
|
||||||
|
|
||||||
|
# Pad sequences to the same length
|
||||||
|
max_len = max(len(t) for t in tokens)
|
||||||
|
input_ids = []
|
||||||
|
attention_mask = []
|
||||||
|
|
||||||
|
for token_seq in tokens:
|
||||||
|
# Pad sequence
|
||||||
|
padded = token_seq + [tokenizer.eos_token_id] * (max_len - len(token_seq))
|
||||||
|
input_ids.append(padded)
|
||||||
|
# Create attention mask (1 for real tokens, 0 for padding)
|
||||||
|
mask = [1] * len(token_seq) + [0] * (max_len - len(token_seq))
|
||||||
|
attention_mask.append(mask)
|
||||||
|
|
||||||
|
# Convert to MLX arrays
|
||||||
|
input_ids = mx.array(input_ids)
|
||||||
|
attention_mask = mx.array(attention_mask)
|
||||||
|
|
||||||
|
# Get embeddings
|
||||||
|
embeddings = model(input_ids)
|
||||||
|
|
||||||
|
# Mean pooling
|
||||||
|
mask = mx.expand_dims(attention_mask, -1)
|
||||||
|
sum_embeddings = (embeddings * mask).sum(axis=1)
|
||||||
|
sum_mask = mask.sum(axis=1)
|
||||||
|
_ = sum_embeddings / sum_mask
|
||||||
|
|
||||||
|
mx.eval() # Ensure computation is finished
|
||||||
|
end_time = time.time()
|
||||||
|
return (end_time - start_time) * 1000 # Return time in ms
|
||||||
|
|
||||||
|
|
||||||
|
# --- Main Execution ---
|
||||||
|
def main():
|
||||||
|
print("--- Initializing Models ---")
|
||||||
|
# Load PyTorch model
|
||||||
|
print(f"Loading PyTorch model: {MODEL_NAME_TORCH}")
|
||||||
|
device = "mps" if torch.backends.mps.is_available() else "cpu"
|
||||||
|
model_torch = SentenceTransformer(MODEL_NAME_TORCH, device=device)
|
||||||
|
print(f"PyTorch model loaded on: {device}")
|
||||||
|
|
||||||
|
# Load MLX model
|
||||||
|
print(f"Loading MLX model: {MODEL_NAME_MLX}")
|
||||||
|
model_mlx, tokenizer_mlx = load(MODEL_NAME_MLX)
|
||||||
|
print("MLX model loaded.")
|
||||||
|
|
||||||
|
# --- Warm-up ---
|
||||||
|
print("\n--- Performing Warm-up Runs ---")
|
||||||
|
for _ in range(WARMUP_RUNS):
|
||||||
|
benchmark_torch(model_torch, DUMMY_SENTENCES[:1])
|
||||||
|
benchmark_mlx(model_mlx, tokenizer_mlx, DUMMY_SENTENCES[:1])
|
||||||
|
print("Warm-up complete.")
|
||||||
|
|
||||||
|
# --- Benchmarking ---
|
||||||
|
print("\n--- Starting Benchmark ---")
|
||||||
|
results_torch = []
|
||||||
|
results_mlx = []
|
||||||
|
|
||||||
|
for batch_size in BATCH_SIZES:
|
||||||
|
print(f"Benchmarking batch size: {batch_size}")
|
||||||
|
sentences_batch = DUMMY_SENTENCES[:batch_size]
|
||||||
|
|
||||||
|
# Benchmark PyTorch
|
||||||
|
torch_times = [benchmark_torch(model_torch, sentences_batch) for _ in range(NUM_RUNS)]
|
||||||
|
results_torch.append(np.mean(torch_times))
|
||||||
|
|
||||||
|
# Benchmark MLX
|
||||||
|
mlx_times = [
|
||||||
|
benchmark_mlx(model_mlx, tokenizer_mlx, sentences_batch) for _ in range(NUM_RUNS)
|
||||||
|
]
|
||||||
|
results_mlx.append(np.mean(mlx_times))
|
||||||
|
|
||||||
|
print("\n--- Benchmark Results (Average time per batch in ms) ---")
|
||||||
|
print(f"Batch Sizes: {BATCH_SIZES}")
|
||||||
|
print(f"PyTorch (mps): {[f'{t:.2f}' for t in results_torch]}")
|
||||||
|
print(f"MLX: {[f'{t:.2f}' for t in results_mlx]}")
|
||||||
|
|
||||||
|
# --- Plotting ---
|
||||||
|
print("\n--- Generating Plot ---")
|
||||||
|
plt.figure(figsize=(10, 6))
|
||||||
|
plt.plot(
|
||||||
|
BATCH_SIZES,
|
||||||
|
results_torch,
|
||||||
|
marker="o",
|
||||||
|
linestyle="-",
|
||||||
|
label=f"PyTorch ({device})",
|
||||||
|
)
|
||||||
|
plt.plot(BATCH_SIZES, results_mlx, marker="s", linestyle="-", label="MLX")
|
||||||
|
|
||||||
|
plt.title(f"Embedding Performance: MLX vs PyTorch\nModel: {MODEL_NAME_TORCH}")
|
||||||
|
plt.xlabel("Batch Size")
|
||||||
|
plt.ylabel("Average Time per Batch (ms)")
|
||||||
|
plt.xticks(BATCH_SIZES)
|
||||||
|
plt.grid(True)
|
||||||
|
plt.legend()
|
||||||
|
|
||||||
|
# Save the plot
|
||||||
|
output_filename = "embedding_benchmark.png"
|
||||||
|
plt.savefig(output_filename)
|
||||||
|
print(f"Plot saved to {output_filename}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
326
benchmarks/compare_faiss_vs_leann.py
Normal file
326
benchmarks/compare_faiss_vs_leann.py
Normal file
@@ -0,0 +1,326 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Memory comparison between Faiss HNSW and LEANN HNSW backend
|
||||||
|
"""
|
||||||
|
|
||||||
|
import gc
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import psutil
|
||||||
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
|
||||||
|
# Setup logging
|
||||||
|
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_memory_usage():
|
||||||
|
"""Get current memory usage in MB"""
|
||||||
|
process = psutil.Process()
|
||||||
|
return process.memory_info().rss / 1024 / 1024
|
||||||
|
|
||||||
|
|
||||||
|
def print_memory_stats(stage: str, start_mem: float):
|
||||||
|
"""Print memory statistics"""
|
||||||
|
current_mem = get_memory_usage()
|
||||||
|
diff = current_mem - start_mem
|
||||||
|
print(f"[{stage}] Memory: {current_mem:.1f} MB (+{diff:.1f} MB)")
|
||||||
|
return current_mem
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryTracker:
|
||||||
|
def __init__(self, name: str):
|
||||||
|
self.name = name
|
||||||
|
self.start_mem = get_memory_usage()
|
||||||
|
self.stages = []
|
||||||
|
|
||||||
|
def checkpoint(self, stage: str):
|
||||||
|
current_mem = print_memory_stats(f"{self.name} - {stage}", self.start_mem)
|
||||||
|
self.stages.append((stage, current_mem))
|
||||||
|
return current_mem
|
||||||
|
|
||||||
|
def summary(self):
|
||||||
|
print(f"\n=== {self.name} Memory Summary ===")
|
||||||
|
for stage, mem in self.stages:
|
||||||
|
print(f"{stage}: {mem:.1f} MB")
|
||||||
|
peak_mem = max(mem for _, mem in self.stages)
|
||||||
|
print(f"Peak Memory: {peak_mem:.1f} MB")
|
||||||
|
print(f"Total Memory Increase: {peak_mem - self.start_mem:.1f} MB")
|
||||||
|
return peak_mem
|
||||||
|
|
||||||
|
|
||||||
|
def test_faiss_hnsw():
|
||||||
|
"""Test Faiss HNSW Vector Store in subprocess"""
|
||||||
|
print("\n" + "=" * 50)
|
||||||
|
print("TESTING FAISS HNSW VECTOR STORE")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
[sys.executable, "benchmarks/faiss_only.py"],
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=300,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(result.stdout)
|
||||||
|
if result.stderr:
|
||||||
|
print("Stderr:", result.stderr)
|
||||||
|
|
||||||
|
if result.returncode != 0:
|
||||||
|
return {
|
||||||
|
"peak_memory": float("inf"),
|
||||||
|
"error": f"Process failed with code {result.returncode}",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Parse peak memory from output
|
||||||
|
lines = result.stdout.split("\n")
|
||||||
|
peak_memory = 0.0
|
||||||
|
|
||||||
|
for line in lines:
|
||||||
|
if "Peak Memory:" in line:
|
||||||
|
peak_memory = float(line.split("Peak Memory:")[1].split("MB")[0].strip())
|
||||||
|
|
||||||
|
return {"peak_memory": peak_memory}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"peak_memory": float("inf"),
|
||||||
|
"error": str(e),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_leann_hnsw():
|
||||||
|
"""Test LEANN HNSW Search Memory (load existing index)"""
|
||||||
|
print("\n" + "=" * 50)
|
||||||
|
print("TESTING LEANN HNSW SEARCH MEMORY")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
tracker = MemoryTracker("LEANN HNSW Search")
|
||||||
|
|
||||||
|
# Import and setup
|
||||||
|
tracker.checkpoint("Initial")
|
||||||
|
|
||||||
|
from leann.api import LeannSearcher
|
||||||
|
|
||||||
|
tracker.checkpoint("After imports")
|
||||||
|
|
||||||
|
from leann.api import LeannBuilder
|
||||||
|
from llama_index.core import SimpleDirectoryReader
|
||||||
|
|
||||||
|
# Load and parse documents
|
||||||
|
documents = SimpleDirectoryReader(
|
||||||
|
"data",
|
||||||
|
recursive=True,
|
||||||
|
encoding="utf-8",
|
||||||
|
required_exts=[".pdf", ".txt", ".md"],
|
||||||
|
).load_data()
|
||||||
|
|
||||||
|
tracker.checkpoint("After document loading")
|
||||||
|
|
||||||
|
# Parse into chunks
|
||||||
|
node_parser = SentenceSplitter(
|
||||||
|
chunk_size=256, chunk_overlap=20, separator=" ", paragraph_separator="\n\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
all_texts = []
|
||||||
|
for doc in documents:
|
||||||
|
nodes = node_parser.get_nodes_from_documents([doc])
|
||||||
|
for node in nodes:
|
||||||
|
all_texts.append(node.get_content())
|
||||||
|
print(f"Total number of chunks: {len(all_texts)}")
|
||||||
|
|
||||||
|
tracker.checkpoint("After text chunking")
|
||||||
|
|
||||||
|
# Build LEANN index
|
||||||
|
INDEX_DIR = Path("./test_leann_comparison")
|
||||||
|
INDEX_PATH = str(INDEX_DIR / "comparison.leann")
|
||||||
|
|
||||||
|
# Check if index already exists
|
||||||
|
if os.path.exists(INDEX_PATH + ".meta.json"):
|
||||||
|
print("Loading existing LEANN HNSW index...")
|
||||||
|
tracker.checkpoint("After loading existing index")
|
||||||
|
else:
|
||||||
|
print("Building new LEANN HNSW index...")
|
||||||
|
# Clean up previous index
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
if INDEX_DIR.exists():
|
||||||
|
shutil.rmtree(INDEX_DIR)
|
||||||
|
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model="facebook/contriever",
|
||||||
|
graph_degree=32,
|
||||||
|
complexity=64,
|
||||||
|
is_compact=True,
|
||||||
|
is_recompute=True,
|
||||||
|
num_threads=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
tracker.checkpoint("After builder setup")
|
||||||
|
|
||||||
|
print("Building LEANN HNSW index...")
|
||||||
|
|
||||||
|
for chunk_text in all_texts:
|
||||||
|
builder.add_text(chunk_text)
|
||||||
|
|
||||||
|
builder.build_index(INDEX_PATH)
|
||||||
|
del builder
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
tracker.checkpoint("After index building")
|
||||||
|
|
||||||
|
# Find existing LEANN index
|
||||||
|
index_paths = [
|
||||||
|
"./test_leann_comparison/comparison.leann",
|
||||||
|
]
|
||||||
|
index_path = None
|
||||||
|
for path in index_paths:
|
||||||
|
if os.path.exists(path + ".meta.json"):
|
||||||
|
index_path = path
|
||||||
|
break
|
||||||
|
|
||||||
|
if not index_path:
|
||||||
|
print("❌ LEANN index not found. Please build it first")
|
||||||
|
return {"peak_memory": float("inf"), "error": "Index not found"}
|
||||||
|
|
||||||
|
# Measure runtime memory overhead
|
||||||
|
print("\nMeasuring runtime memory overhead...")
|
||||||
|
runtime_start_mem = get_memory_usage()
|
||||||
|
print(f"Before load memory: {runtime_start_mem:.1f} MB")
|
||||||
|
tracker.checkpoint("Before load memory")
|
||||||
|
|
||||||
|
# Load searcher
|
||||||
|
searcher = LeannSearcher(index_path)
|
||||||
|
tracker.checkpoint("After searcher loading")
|
||||||
|
|
||||||
|
print("Running search queries...")
|
||||||
|
queries = [
|
||||||
|
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
|
||||||
|
"What is LEANN and how does it work?",
|
||||||
|
"华为诺亚方舟实验室的主要研究内容",
|
||||||
|
]
|
||||||
|
|
||||||
|
for i, query in enumerate(queries):
|
||||||
|
start_time = time.time()
|
||||||
|
# Use same parameters as Faiss: top_k=20, ef=120 (complexity parameter)
|
||||||
|
_ = searcher.search(query, top_k=20, ef=120)
|
||||||
|
query_time = time.time() - start_time
|
||||||
|
print(f"Query {i + 1} time: {query_time:.3f}s")
|
||||||
|
tracker.checkpoint(f"After query {i + 1}")
|
||||||
|
|
||||||
|
runtime_end_mem = get_memory_usage()
|
||||||
|
runtime_overhead = runtime_end_mem - runtime_start_mem
|
||||||
|
|
||||||
|
peak_memory = tracker.summary()
|
||||||
|
print(f"Runtime Memory Overhead: {runtime_overhead:.1f} MB")
|
||||||
|
|
||||||
|
# Get storage size before cleanup
|
||||||
|
storage_size = 0
|
||||||
|
INDEX_DIR = Path(index_path).parent
|
||||||
|
if INDEX_DIR.exists():
|
||||||
|
total_size = 0
|
||||||
|
for dirpath, _, filenames in os.walk(str(INDEX_DIR)):
|
||||||
|
for filename in filenames:
|
||||||
|
# Only count actual index files, skip text data and backups
|
||||||
|
if filename.endswith((".old", ".tmp", ".bak", ".jsonl", ".json")):
|
||||||
|
continue
|
||||||
|
# Count .index, .idx, .map files (actual index structures)
|
||||||
|
if filename.endswith((".index", ".idx", ".map")):
|
||||||
|
filepath = os.path.join(dirpath, filename)
|
||||||
|
total_size += os.path.getsize(filepath)
|
||||||
|
storage_size = total_size / (1024 * 1024) # Convert to MB
|
||||||
|
|
||||||
|
# Clean up
|
||||||
|
del searcher
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"peak_memory": peak_memory,
|
||||||
|
"storage_size": storage_size,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Run comparison tests"""
|
||||||
|
print("Storage + Search Memory Comparison: Faiss HNSW vs LEANN HNSW")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Test Faiss HNSW
|
||||||
|
faiss_results = test_faiss_hnsw()
|
||||||
|
|
||||||
|
# Force garbage collection
|
||||||
|
gc.collect()
|
||||||
|
time.sleep(2)
|
||||||
|
|
||||||
|
# Test LEANN HNSW
|
||||||
|
leann_results = test_leann_hnsw()
|
||||||
|
|
||||||
|
# Final comparison
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("STORAGE + SEARCH MEMORY COMPARISON")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Get storage sizes
|
||||||
|
faiss_storage_size = 0
|
||||||
|
leann_storage_size = leann_results.get("storage_size", 0)
|
||||||
|
|
||||||
|
# Get Faiss storage size using Python
|
||||||
|
if os.path.exists("./storage_faiss"):
|
||||||
|
total_size = 0
|
||||||
|
for dirpath, _, filenames in os.walk("./storage_faiss"):
|
||||||
|
for filename in filenames:
|
||||||
|
filepath = os.path.join(dirpath, filename)
|
||||||
|
total_size += os.path.getsize(filepath)
|
||||||
|
faiss_storage_size = total_size / (1024 * 1024) # Convert to MB
|
||||||
|
|
||||||
|
print("Faiss HNSW:")
|
||||||
|
if "error" in faiss_results:
|
||||||
|
print(f" ❌ Failed: {faiss_results['error']}")
|
||||||
|
else:
|
||||||
|
print(f" Search Memory: {faiss_results['peak_memory']:.1f} MB")
|
||||||
|
print(f" Storage Size: {faiss_storage_size:.1f} MB")
|
||||||
|
|
||||||
|
print("\nLEANN HNSW:")
|
||||||
|
if "error" in leann_results:
|
||||||
|
print(f" ❌ Failed: {leann_results['error']}")
|
||||||
|
else:
|
||||||
|
print(f" Search Memory: {leann_results['peak_memory']:.1f} MB")
|
||||||
|
print(f" Storage Size: {leann_storage_size:.1f} MB")
|
||||||
|
|
||||||
|
# Calculate improvements only if both tests succeeded
|
||||||
|
if "error" not in faiss_results and "error" not in leann_results:
|
||||||
|
memory_ratio = faiss_results["peak_memory"] / leann_results["peak_memory"]
|
||||||
|
|
||||||
|
print("\nLEANN vs Faiss Performance:")
|
||||||
|
memory_saving = faiss_results["peak_memory"] - leann_results["peak_memory"]
|
||||||
|
print(f" Search Memory: {memory_ratio:.1f}x less ({memory_saving:.1f} MB saved)")
|
||||||
|
|
||||||
|
# Storage comparison
|
||||||
|
if leann_storage_size > faiss_storage_size:
|
||||||
|
storage_ratio = leann_storage_size / faiss_storage_size
|
||||||
|
print(f" Storage Size: {storage_ratio:.1f}x larger (LEANN uses more storage)")
|
||||||
|
elif faiss_storage_size > leann_storage_size:
|
||||||
|
storage_ratio = faiss_storage_size / leann_storage_size
|
||||||
|
print(f" Storage Size: {storage_ratio:.1f}x smaller (LEANN uses less storage)")
|
||||||
|
else:
|
||||||
|
print(" Storage Size: similar")
|
||||||
|
else:
|
||||||
|
if "error" not in leann_results:
|
||||||
|
print("\n✅ LEANN HNSW completed successfully!")
|
||||||
|
print(f"📊 Search Memory: {leann_results['peak_memory']:.1f} MB")
|
||||||
|
print(f"📊 Storage Size: {leann_storage_size:.1f} MB")
|
||||||
|
if "error" not in faiss_results:
|
||||||
|
print("\n✅ Faiss HNSW completed successfully!")
|
||||||
|
print(f"📊 Search Memory: {faiss_results['peak_memory']:.1f} MB")
|
||||||
|
print(f"📊 Storage Size: {faiss_storage_size:.1f} MB")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
151
benchmarks/faiss_only.py
Normal file
151
benchmarks/faiss_only.py
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Test only Faiss HNSW"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
|
||||||
|
import psutil
|
||||||
|
|
||||||
|
|
||||||
|
def get_memory_usage():
|
||||||
|
process = psutil.Process()
|
||||||
|
return process.memory_info().rss / 1024 / 1024
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryTracker:
|
||||||
|
def __init__(self, name: str):
|
||||||
|
self.name = name
|
||||||
|
self.start_mem = get_memory_usage()
|
||||||
|
self.stages = []
|
||||||
|
|
||||||
|
def checkpoint(self, stage: str):
|
||||||
|
current_mem = get_memory_usage()
|
||||||
|
diff = current_mem - self.start_mem
|
||||||
|
print(f"[{self.name} - {stage}] Memory: {current_mem:.1f} MB (+{diff:.1f} MB)")
|
||||||
|
self.stages.append((stage, current_mem))
|
||||||
|
return current_mem
|
||||||
|
|
||||||
|
def summary(self):
|
||||||
|
peak_mem = max(mem for _, mem in self.stages)
|
||||||
|
print(f"Peak Memory: {peak_mem:.1f} MB")
|
||||||
|
return peak_mem
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
try:
|
||||||
|
import faiss
|
||||||
|
except ImportError:
|
||||||
|
print("Faiss is not installed.")
|
||||||
|
print(
|
||||||
|
"Please install it with `uv pip install faiss-cpu` and you can then run this script again"
|
||||||
|
)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
from llama_index.core import (
|
||||||
|
Settings,
|
||||||
|
SimpleDirectoryReader,
|
||||||
|
StorageContext,
|
||||||
|
VectorStoreIndex,
|
||||||
|
)
|
||||||
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
||||||
|
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||||
|
|
||||||
|
tracker = MemoryTracker("Faiss HNSW")
|
||||||
|
tracker.checkpoint("Initial")
|
||||||
|
|
||||||
|
embed_model = HuggingFaceEmbedding(model_name="facebook/contriever")
|
||||||
|
Settings.embed_model = embed_model
|
||||||
|
tracker.checkpoint("After embedding model setup")
|
||||||
|
|
||||||
|
d = 768
|
||||||
|
faiss_index = faiss.IndexHNSWFlat(d, 32)
|
||||||
|
faiss_index.hnsw.efConstruction = 64
|
||||||
|
tracker.checkpoint("After Faiss index creation")
|
||||||
|
|
||||||
|
documents = SimpleDirectoryReader(
|
||||||
|
"data",
|
||||||
|
recursive=True,
|
||||||
|
encoding="utf-8",
|
||||||
|
required_exts=[".pdf", ".txt", ".md"],
|
||||||
|
).load_data()
|
||||||
|
tracker.checkpoint("After document loading")
|
||||||
|
|
||||||
|
# Parse into chunks using the same splitter as LEANN
|
||||||
|
node_parser = SentenceSplitter(
|
||||||
|
chunk_size=256, chunk_overlap=20, separator=" ", paragraph_separator="\n\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
tracker.checkpoint("After text splitter setup")
|
||||||
|
|
||||||
|
# Check if index already exists and try to load it
|
||||||
|
index_loaded = False
|
||||||
|
if os.path.exists("./storage_faiss"):
|
||||||
|
print("Loading existing Faiss HNSW index...")
|
||||||
|
try:
|
||||||
|
# Use the correct Faiss loading pattern from the example
|
||||||
|
vector_store = FaissVectorStore.from_persist_dir("./storage_faiss")
|
||||||
|
storage_context = StorageContext.from_defaults(
|
||||||
|
vector_store=vector_store, persist_dir="./storage_faiss"
|
||||||
|
)
|
||||||
|
from llama_index.core import load_index_from_storage
|
||||||
|
|
||||||
|
index = load_index_from_storage(storage_context=storage_context)
|
||||||
|
print("Index loaded from ./storage_faiss")
|
||||||
|
tracker.checkpoint("After loading existing index")
|
||||||
|
index_loaded = True
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to load existing index: {e}")
|
||||||
|
print("Cleaning up corrupted index and building new one...")
|
||||||
|
# Clean up corrupted index
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
if os.path.exists("./storage_faiss"):
|
||||||
|
shutil.rmtree("./storage_faiss")
|
||||||
|
|
||||||
|
if not index_loaded:
|
||||||
|
print("Building new Faiss HNSW index...")
|
||||||
|
|
||||||
|
# Use the correct Faiss building pattern from the example
|
||||||
|
vector_store = FaissVectorStore(faiss_index=faiss_index)
|
||||||
|
storage_context = StorageContext.from_defaults(vector_store=vector_store)
|
||||||
|
index = VectorStoreIndex.from_documents(
|
||||||
|
documents, storage_context=storage_context, transformations=[node_parser]
|
||||||
|
)
|
||||||
|
tracker.checkpoint("After index building")
|
||||||
|
|
||||||
|
# Save index to disk using the correct pattern
|
||||||
|
index.storage_context.persist(persist_dir="./storage_faiss")
|
||||||
|
tracker.checkpoint("After index saving")
|
||||||
|
|
||||||
|
# Measure runtime memory overhead
|
||||||
|
print("\nMeasuring runtime memory overhead...")
|
||||||
|
runtime_start_mem = get_memory_usage()
|
||||||
|
print(f"Before load memory: {runtime_start_mem:.1f} MB")
|
||||||
|
tracker.checkpoint("Before load memory")
|
||||||
|
|
||||||
|
query_engine = index.as_query_engine(similarity_top_k=20)
|
||||||
|
queries = [
|
||||||
|
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
|
||||||
|
"What is LEANN and how does it work?",
|
||||||
|
"华为诺亚方舟实验室的主要研究内容",
|
||||||
|
]
|
||||||
|
|
||||||
|
for i, query in enumerate(queries):
|
||||||
|
start_time = time.time()
|
||||||
|
_ = query_engine.query(query)
|
||||||
|
query_time = time.time() - start_time
|
||||||
|
print(f"Query {i + 1} time: {query_time:.3f}s")
|
||||||
|
tracker.checkpoint(f"After query {i + 1}")
|
||||||
|
|
||||||
|
runtime_end_mem = get_memory_usage()
|
||||||
|
runtime_overhead = runtime_end_mem - runtime_start_mem
|
||||||
|
|
||||||
|
peak_memory = tracker.summary()
|
||||||
|
print(f"Peak Memory: {peak_memory:.1f} MB")
|
||||||
|
print(f"Runtime Memory Overhead: {runtime_overhead:.1f} MB")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
391
research/micro/embedd_micro.py → benchmarks/micro_tpt.py
Executable file → Normal file
391
research/micro/embedd_micro.py → benchmarks/micro_tpt.py
Executable file → Normal file
@@ -2,21 +2,20 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import time
|
import time
|
||||||
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, List, Optional, Tuple
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torchao import quantize_
|
|
||||||
from transformers import AutoModel, BitsAndBytesConfig
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from contextlib import contextmanager
|
from transformers import AutoModel, BitsAndBytesConfig
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BenchmarkConfig:
|
class BenchmarkConfig:
|
||||||
model_path: str
|
model_path: str
|
||||||
batch_sizes: List[int]
|
batch_sizes: list[int]
|
||||||
seq_length: int
|
seq_length: int
|
||||||
num_runs: int
|
num_runs: int
|
||||||
use_fp16: bool = True
|
use_fp16: bool = True
|
||||||
@@ -27,46 +26,58 @@ class BenchmarkConfig:
|
|||||||
use_linear8bitlt: bool = False
|
use_linear8bitlt: bool = False
|
||||||
|
|
||||||
|
|
||||||
class CUDAGraphContainer:
|
class GraphContainer:
|
||||||
"""Container for managing CUDA graphs for different batch sizes."""
|
"""Container for managing graphs for different batch sizes (CUDA graphs on NVIDIA, regular on others)."""
|
||||||
|
|
||||||
def __init__(self, model: nn.Module, seq_length: int):
|
def __init__(self, model: nn.Module, seq_length: int):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.seq_length = seq_length
|
self.seq_length = seq_length
|
||||||
self.graphs: Dict[int, CUDAGraphWrapper] = {}
|
self.graphs: dict[int, GraphWrapper] = {}
|
||||||
|
|
||||||
def get_or_create(self, batch_size: int) -> 'CUDAGraphWrapper':
|
def get_or_create(self, batch_size: int) -> "GraphWrapper":
|
||||||
if batch_size not in self.graphs:
|
if batch_size not in self.graphs:
|
||||||
self.graphs[batch_size] = CUDAGraphWrapper(
|
self.graphs[batch_size] = GraphWrapper(self.model, batch_size, self.seq_length)
|
||||||
self.model, batch_size, self.seq_length
|
|
||||||
)
|
|
||||||
return self.graphs[batch_size]
|
return self.graphs[batch_size]
|
||||||
|
|
||||||
|
|
||||||
class CUDAGraphWrapper:
|
class GraphWrapper:
|
||||||
"""Wrapper for CUDA graph capture and replay."""
|
"""Wrapper for graph capture and replay (CUDA graphs on NVIDIA, regular on others)."""
|
||||||
|
|
||||||
def __init__(self, model: nn.Module, batch_size: int, seq_length: int):
|
def __init__(self, model: nn.Module, batch_size: int, seq_length: int):
|
||||||
self.model = model
|
self.model = model
|
||||||
|
self.device = self._get_device()
|
||||||
self.static_input = self._create_random_batch(batch_size, seq_length)
|
self.static_input = self._create_random_batch(batch_size, seq_length)
|
||||||
self.static_attention_mask = torch.ones_like(self.static_input)
|
self.static_attention_mask = torch.ones_like(self.static_input)
|
||||||
|
|
||||||
# Warm up
|
# Warm up
|
||||||
self._warmup()
|
self._warmup()
|
||||||
|
|
||||||
# Capture graph
|
# Only use CUDA graphs on NVIDIA GPUs
|
||||||
self.graph = torch.cuda.CUDAGraph()
|
if torch.cuda.is_available() and hasattr(torch.cuda, "CUDAGraph"):
|
||||||
with torch.cuda.graph(self.graph):
|
# Capture graph
|
||||||
self.static_output = self.model(
|
self.graph = torch.cuda.CUDAGraph()
|
||||||
input_ids=self.static_input,
|
with torch.cuda.graph(self.graph):
|
||||||
attention_mask=self.static_attention_mask
|
self.static_output = self.model(
|
||||||
)
|
input_ids=self.static_input,
|
||||||
|
attention_mask=self.static_attention_mask,
|
||||||
|
)
|
||||||
|
self.use_cuda_graph = True
|
||||||
|
else:
|
||||||
|
# For MPS or CPU, just store the model
|
||||||
|
self.use_cuda_graph = False
|
||||||
|
self.static_output = None
|
||||||
|
|
||||||
|
def _get_device(self) -> str:
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
return "cuda"
|
||||||
|
elif torch.backends.mps.is_available():
|
||||||
|
return "mps"
|
||||||
|
else:
|
||||||
|
return "cpu"
|
||||||
|
|
||||||
def _create_random_batch(self, batch_size: int, seq_length: int) -> torch.Tensor:
|
def _create_random_batch(self, batch_size: int, seq_length: int) -> torch.Tensor:
|
||||||
return torch.randint(
|
return torch.randint(
|
||||||
0, 1000, (batch_size, seq_length),
|
0, 1000, (batch_size, seq_length), device=self.device, dtype=torch.long
|
||||||
device="cuda",
|
|
||||||
dtype=torch.long
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _warmup(self, num_warmup: int = 3):
|
def _warmup(self, num_warmup: int = 3):
|
||||||
@@ -74,14 +85,18 @@ class CUDAGraphWrapper:
|
|||||||
for _ in range(num_warmup):
|
for _ in range(num_warmup):
|
||||||
self.model(
|
self.model(
|
||||||
input_ids=self.static_input,
|
input_ids=self.static_input,
|
||||||
attention_mask=self.static_attention_mask
|
attention_mask=self.static_attention_mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
|
def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
|
||||||
self.static_input.copy_(input_ids)
|
if self.use_cuda_graph:
|
||||||
self.static_attention_mask.copy_(attention_mask)
|
self.static_input.copy_(input_ids)
|
||||||
self.graph.replay()
|
self.static_attention_mask.copy_(attention_mask)
|
||||||
return self.static_output
|
self.graph.replay()
|
||||||
|
return self.static_output
|
||||||
|
else:
|
||||||
|
# For MPS/CPU, just run normally
|
||||||
|
return self.model(input_ids=input_ids, attention_mask=attention_mask)
|
||||||
|
|
||||||
|
|
||||||
class ModelOptimizer:
|
class ModelOptimizer:
|
||||||
@@ -95,8 +110,16 @@ class ModelOptimizer:
|
|||||||
raise ValueError("Cannot optimize None model")
|
raise ValueError("Cannot optimize None model")
|
||||||
|
|
||||||
# Move to GPU
|
# Move to GPU
|
||||||
model = model.cuda()
|
if torch.cuda.is_available():
|
||||||
print("- Model moved to GPU")
|
model = model.cuda()
|
||||||
|
device = "cuda"
|
||||||
|
elif torch.backends.mps.is_available():
|
||||||
|
model = model.to("mps")
|
||||||
|
device = "mps"
|
||||||
|
else:
|
||||||
|
model = model.cpu()
|
||||||
|
device = "cpu"
|
||||||
|
print(f"- Model moved to {device}")
|
||||||
|
|
||||||
# FP16
|
# FP16
|
||||||
if config.use_fp16 and not config.use_int4:
|
if config.use_fp16 and not config.use_int4:
|
||||||
@@ -105,17 +128,22 @@ class ModelOptimizer:
|
|||||||
model = torch.compile(model)
|
model = torch.compile(model)
|
||||||
print("- Using FP16 precision")
|
print("- Using FP16 precision")
|
||||||
|
|
||||||
# Check if using SDPA
|
# Check if using SDPA (only on CUDA)
|
||||||
if torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6:
|
if (
|
||||||
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
|
torch.cuda.is_available()
|
||||||
|
and torch.version.cuda
|
||||||
|
and float(torch.version.cuda[:3]) >= 11.6
|
||||||
|
):
|
||||||
|
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
|
||||||
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
|
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
|
||||||
else:
|
else:
|
||||||
print("- PyTorch SDPA not available")
|
print("- PyTorch SDPA not available")
|
||||||
|
|
||||||
# Flash Attention
|
# Flash Attention (only on CUDA)
|
||||||
if config.use_flash_attention:
|
if config.use_flash_attention and torch.cuda.is_available():
|
||||||
try:
|
try:
|
||||||
from flash_attn.flash_attention import FlashAttention
|
from flash_attn.flash_attention import FlashAttention # noqa: F401
|
||||||
|
|
||||||
print("- Flash Attention 2 available")
|
print("- Flash Attention 2 available")
|
||||||
if hasattr(model.config, "attention_mode"):
|
if hasattr(model.config, "attention_mode"):
|
||||||
model.config.attention_mode = "flash_attention_2"
|
model.config.attention_mode = "flash_attention_2"
|
||||||
@@ -123,16 +151,18 @@ class ModelOptimizer:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
print("- Flash Attention not available")
|
print("- Flash Attention not available")
|
||||||
|
|
||||||
# Memory efficient attention
|
# Memory efficient attention (only on CUDA)
|
||||||
try:
|
if torch.cuda.is_available():
|
||||||
from xformers.ops import memory_efficient_attention
|
try:
|
||||||
if hasattr(model, 'enable_xformers_memory_efficient_attention'):
|
from xformers.ops import memory_efficient_attention # noqa: F401
|
||||||
model.enable_xformers_memory_efficient_attention()
|
|
||||||
print("- Enabled xformers memory efficient attention")
|
if hasattr(model, "enable_xformers_memory_efficient_attention"):
|
||||||
else:
|
model.enable_xformers_memory_efficient_attention()
|
||||||
print("- Model doesn't support xformers")
|
print("- Enabled xformers memory efficient attention")
|
||||||
except (ImportError, AttributeError):
|
else:
|
||||||
print("- Xformers not available")
|
print("- Model doesn't support xformers")
|
||||||
|
except (ImportError, AttributeError):
|
||||||
|
print("- Xformers not available")
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
print("- Model set to eval mode")
|
print("- Model set to eval mode")
|
||||||
@@ -141,21 +171,38 @@ class ModelOptimizer:
|
|||||||
|
|
||||||
|
|
||||||
class Timer:
|
class Timer:
|
||||||
"""Handles accurate GPU timing using CUDA events."""
|
"""Handles accurate GPU timing using GPU events or CPU timing."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.start_event = torch.cuda.Event(enable_timing=True)
|
if torch.cuda.is_available():
|
||||||
self.end_event = torch.cuda.Event(enable_timing=True)
|
self.start_event = torch.cuda.Event(enable_timing=True)
|
||||||
|
self.end_event = torch.cuda.Event(enable_timing=True)
|
||||||
|
self.use_gpu_timing = True
|
||||||
|
elif torch.backends.mps.is_available():
|
||||||
|
# MPS doesn't have events, use CPU timing
|
||||||
|
self.use_gpu_timing = False
|
||||||
|
else:
|
||||||
|
# CPU timing
|
||||||
|
self.use_gpu_timing = False
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def timing(self):
|
def timing(self):
|
||||||
self.start_event.record()
|
if self.use_gpu_timing:
|
||||||
yield
|
self.start_event.record()
|
||||||
self.end_event.record()
|
yield
|
||||||
self.end_event.synchronize()
|
self.end_event.record()
|
||||||
|
self.end_event.synchronize()
|
||||||
|
else:
|
||||||
|
# Use CPU timing for MPS/CPU
|
||||||
|
start_time = time.time()
|
||||||
|
yield
|
||||||
|
self.cpu_elapsed = time.time() - start_time
|
||||||
|
|
||||||
def elapsed_time(self) -> float:
|
def elapsed_time(self) -> float:
|
||||||
return self.start_event.elapsed_time(self.end_event) / 1000 # ms to seconds
|
if self.use_gpu_timing:
|
||||||
|
return self.start_event.elapsed_time(self.end_event) / 1000 # ms to seconds
|
||||||
|
else:
|
||||||
|
return self.cpu_elapsed
|
||||||
|
|
||||||
|
|
||||||
class Benchmark:
|
class Benchmark:
|
||||||
@@ -168,14 +215,14 @@ class Benchmark:
|
|||||||
if self.model is None:
|
if self.model is None:
|
||||||
raise ValueError("Model initialization failed - model is None")
|
raise ValueError("Model initialization failed - model is None")
|
||||||
|
|
||||||
self.cuda_graphs = (
|
# Only use CUDA graphs on NVIDIA GPUs
|
||||||
CUDAGraphContainer(self.model, config.seq_length)
|
if config.use_cuda_graphs and torch.cuda.is_available():
|
||||||
if config.use_cuda_graphs
|
self.graphs = GraphContainer(self.model, config.seq_length)
|
||||||
else None
|
else:
|
||||||
)
|
self.graphs = None
|
||||||
self.timer = Timer()
|
self.timer = Timer()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"ERROR in benchmark initialization: {str(e)}")
|
print(f"ERROR in benchmark initialization: {e!s}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _load_model(self) -> nn.Module:
|
def _load_model(self) -> nn.Module:
|
||||||
@@ -185,15 +232,17 @@ class Benchmark:
|
|||||||
# Int4 quantization using HuggingFace integration
|
# Int4 quantization using HuggingFace integration
|
||||||
if self.config.use_int4:
|
if self.config.use_int4:
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
|
|
||||||
print(f"- bitsandbytes version: {bnb.__version__}")
|
print(f"- bitsandbytes version: {bnb.__version__}")
|
||||||
|
|
||||||
# 检查是否使用自定义的8bit量化
|
# Check if using custom 8bit quantization
|
||||||
if hasattr(self.config, 'use_linear8bitlt') and self.config.use_linear8bitlt:
|
if hasattr(self.config, "use_linear8bitlt") and self.config.use_linear8bitlt:
|
||||||
print("- Using custom Linear8bitLt replacement for all linear layers")
|
print("- Using custom Linear8bitLt replacement for all linear layers")
|
||||||
|
|
||||||
# 加载原始模型(不使用量化配置)
|
# Load original model (without quantization config)
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
# set default to half
|
# set default to half
|
||||||
torch.set_default_dtype(torch.float16)
|
torch.set_default_dtype(torch.float16)
|
||||||
compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
|
compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
|
||||||
@@ -202,52 +251,58 @@ class Benchmark:
|
|||||||
torch_dtype=compute_dtype,
|
torch_dtype=compute_dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 定义替换函数
|
# Define replacement function
|
||||||
def replace_linear_with_linear8bitlt(model):
|
def replace_linear_with_linear8bitlt(model):
|
||||||
"""递归地将模型中的所有nn.Linear层替换为Linear8bitLt"""
|
"""Recursively replace all nn.Linear layers with Linear8bitLt"""
|
||||||
for name, module in list(model.named_children()):
|
for name, module in list(model.named_children()):
|
||||||
if isinstance(module, nn.Linear):
|
if isinstance(module, nn.Linear):
|
||||||
# 获取原始线性层的参数
|
# Get original linear layer parameters
|
||||||
in_features = module.in_features
|
in_features = module.in_features
|
||||||
out_features = module.out_features
|
out_features = module.out_features
|
||||||
bias = module.bias is not None
|
bias = module.bias is not None
|
||||||
|
|
||||||
# 创建8bit线性层
|
# Create 8bit linear layer
|
||||||
# print size
|
# print size
|
||||||
print(f"in_features: {in_features}, out_features: {out_features}")
|
print(f"in_features: {in_features}, out_features: {out_features}")
|
||||||
new_module = bnb.nn.Linear8bitLt(
|
new_module = bnb.nn.Linear8bitLt(
|
||||||
in_features,
|
in_features,
|
||||||
out_features,
|
out_features,
|
||||||
bias=bias,
|
bias=bias,
|
||||||
has_fp16_weights=False
|
has_fp16_weights=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 复制权重和偏置
|
# Copy weights and bias
|
||||||
new_module.weight.data = module.weight.data
|
new_module.weight.data = module.weight.data
|
||||||
if bias:
|
if bias:
|
||||||
new_module.bias.data = module.bias.data
|
new_module.bias.data = module.bias.data
|
||||||
|
|
||||||
# 替换模块
|
# Replace module
|
||||||
setattr(model, name, new_module)
|
setattr(model, name, new_module)
|
||||||
else:
|
else:
|
||||||
# 递归处理子模块
|
# Process child modules recursively
|
||||||
replace_linear_with_linear8bitlt(module)
|
replace_linear_with_linear8bitlt(module)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
# 替换所有线性层
|
# Replace all linear layers
|
||||||
model = replace_linear_with_linear8bitlt(model)
|
model = replace_linear_with_linear8bitlt(model)
|
||||||
# add torch compile
|
# add torch compile
|
||||||
model = torch.compile(model)
|
model = torch.compile(model)
|
||||||
|
|
||||||
# 将模型移到GPU(量化发生在这里)
|
# Move model to GPU (quantization happens here)
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
device = (
|
||||||
|
"cuda"
|
||||||
|
if torch.cuda.is_available()
|
||||||
|
else "mps"
|
||||||
|
if torch.backends.mps.is_available()
|
||||||
|
else "cpu"
|
||||||
|
)
|
||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
|
|
||||||
print("- All linear layers replaced with Linear8bitLt")
|
print("- All linear layers replaced with Linear8bitLt")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# 使用原来的Int4量化方法
|
# Use original Int4 quantization method
|
||||||
print("- Using bitsandbytes for Int4 quantization")
|
print("- Using bitsandbytes for Int4 quantization")
|
||||||
|
|
||||||
# Create quantization config
|
# Create quantization config
|
||||||
@@ -257,7 +312,7 @@ class Benchmark:
|
|||||||
load_in_4bit=True,
|
load_in_4bit=True,
|
||||||
bnb_4bit_compute_dtype=compute_dtype,
|
bnb_4bit_compute_dtype=compute_dtype,
|
||||||
bnb_4bit_use_double_quant=True,
|
bnb_4bit_use_double_quant=True,
|
||||||
bnb_4bit_quant_type="nf4"
|
bnb_4bit_quant_type="nf4",
|
||||||
)
|
)
|
||||||
|
|
||||||
print("- Quantization config:", quantization_config)
|
print("- Quantization config:", quantization_config)
|
||||||
@@ -267,7 +322,7 @@ class Benchmark:
|
|||||||
self.config.model_path,
|
self.config.model_path,
|
||||||
quantization_config=quantization_config,
|
quantization_config=quantization_config,
|
||||||
torch_dtype=compute_dtype,
|
torch_dtype=compute_dtype,
|
||||||
device_map="auto" # Let HF decide on device mapping
|
device_map="auto", # Let HF decide on device mapping
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if model loaded successfully
|
# Check if model loaded successfully
|
||||||
@@ -279,7 +334,7 @@ class Benchmark:
|
|||||||
# Apply optimizations directly here
|
# Apply optimizations directly here
|
||||||
print("\nApplying model optimizations:")
|
print("\nApplying model optimizations:")
|
||||||
|
|
||||||
if hasattr(self.config, 'use_linear8bitlt') and self.config.use_linear8bitlt:
|
if hasattr(self.config, "use_linear8bitlt") and self.config.use_linear8bitlt:
|
||||||
print("- Model moved to GPU with Linear8bitLt quantization")
|
print("- Model moved to GPU with Linear8bitLt quantization")
|
||||||
else:
|
else:
|
||||||
# Skip moving to GPU since device_map="auto" already did that
|
# Skip moving to GPU since device_map="auto" already did that
|
||||||
@@ -289,79 +344,55 @@ class Benchmark:
|
|||||||
print(f"- Using {compute_dtype} for compute dtype")
|
print(f"- Using {compute_dtype} for compute dtype")
|
||||||
|
|
||||||
# Check CUDA and SDPA
|
# Check CUDA and SDPA
|
||||||
if torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6:
|
if (
|
||||||
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
|
torch.cuda.is_available()
|
||||||
|
and torch.version.cuda
|
||||||
|
and float(torch.version.cuda[:3]) >= 11.6
|
||||||
|
):
|
||||||
|
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
|
||||||
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
|
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
|
||||||
else:
|
else:
|
||||||
print("- PyTorch SDPA not available")
|
print("- PyTorch SDPA not available")
|
||||||
|
|
||||||
# Try xformers if available
|
# Try xformers if available (only on CUDA)
|
||||||
try:
|
if torch.cuda.is_available():
|
||||||
from xformers.ops import memory_efficient_attention
|
try:
|
||||||
if hasattr(model, 'enable_xformers_memory_efficient_attention'):
|
if hasattr(model, "enable_xformers_memory_efficient_attention"):
|
||||||
model.enable_xformers_memory_efficient_attention()
|
model.enable_xformers_memory_efficient_attention()
|
||||||
print("- Enabled xformers memory efficient attention")
|
print("- Enabled xformers memory efficient attention")
|
||||||
else:
|
else:
|
||||||
print("- Model doesn't support xformers")
|
print("- Model doesn't support xformers")
|
||||||
except (ImportError, AttributeError):
|
except (ImportError, AttributeError):
|
||||||
print("- Xformers not available")
|
print("- Xformers not available")
|
||||||
|
|
||||||
# Set to eval mode
|
# Set to eval mode
|
||||||
model.eval()
|
model.eval()
|
||||||
print("- Model set to eval mode")
|
print("- Model set to eval mode")
|
||||||
# Int8 quantization using HuggingFace integration
|
# Int8 quantization using HuggingFace integration
|
||||||
# Int8 quantization using TorchAO
|
|
||||||
elif self.config.use_int8:
|
elif self.config.use_int8:
|
||||||
print("- Using TorchAO for Int8 dynamic activation and Int8 weight quantization")
|
print("- Using INT8 quantization")
|
||||||
|
# For now, just use standard loading with INT8 config
|
||||||
# Import the quantize_ function and the quantization config
|
compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
|
||||||
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
|
quantization_config = BitsAndBytesConfig(
|
||||||
print("- Successfully imported TorchAO")
|
load_in_8bit=True,
|
||||||
|
llm_int8_threshold=6.0,
|
||||||
# Load model normally first
|
llm_int8_has_fp16_weight=False,
|
||||||
# set default to half
|
|
||||||
import torch
|
|
||||||
torch.set_default_dtype(torch.bfloat16)
|
|
||||||
model = AutoModel.from_pretrained(
|
|
||||||
self.config.model_path,
|
|
||||||
device_map="auto"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
print("- Model loaded in full precision")
|
model = AutoModel.from_pretrained(
|
||||||
|
self.config.model_path,
|
||||||
|
quantization_config=quantization_config,
|
||||||
|
torch_dtype=compute_dtype,
|
||||||
|
device_map="auto",
|
||||||
|
)
|
||||||
|
|
||||||
|
if model is None:
|
||||||
|
raise ValueError("Model loading returned None")
|
||||||
|
|
||||||
print(f"- Model type: {type(model)}")
|
print(f"- Model type: {type(model)}")
|
||||||
|
|
||||||
# Apply quantization - call the function to get the config, then apply it
|
|
||||||
# quantize_(model, int8_dynamic_activation_int8_weight())
|
|
||||||
# from torchao.quantization import quantize_, Int8DynamicActivationInt8WeightConfig,int8_dynamic_activation_int8_semi_sparse_weight,int4_weight_only,Int8DynActInt4WeightGPTQQuantizer,int8_dynamic_activation_int4_weight,Int8DynamicActivationInt4WeightConfig,Int4DynamicActivationInt4WeightConfig
|
|
||||||
from torchao.quantization import quantize_, Int8DynamicActivationInt8WeightConfig
|
|
||||||
quantize_(model, Int8DynamicActivationInt8WeightConfig())
|
|
||||||
print("- Model successfully quantized with int8 weights and int8 activations")
|
|
||||||
# add torch compile
|
|
||||||
model = torch.compile(model)
|
|
||||||
# For older PyTorch versions that have issues with tensor subclasses
|
|
||||||
from torchao.utils import unwrap_tensor_subclass
|
|
||||||
import torch
|
|
||||||
if hasattr(torch, '_version') and not torch.version >= "2.5.0":
|
|
||||||
print("- Unwrapping tensor subclasses for compatibility with older PyTorch")
|
|
||||||
unwrap_tensor_subclass(model)
|
|
||||||
|
|
||||||
# Apply optimizations
|
|
||||||
if torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6:
|
|
||||||
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
|
|
||||||
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
|
|
||||||
else:
|
|
||||||
print("- PyTorch SDPA not available")
|
|
||||||
|
|
||||||
# Set to eval mode
|
|
||||||
model.eval()
|
model.eval()
|
||||||
print("- Model set to eval mode")
|
print("- Model set to eval mode")
|
||||||
|
|
||||||
# For better performance with int8 dynamic quantization
|
|
||||||
torch._inductor.config.force_fuse_int_mm_with_mul = True
|
|
||||||
print("- Enabled fusion of int matmul with mul operations")
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Standard loading for FP16/FP32
|
# Standard loading for FP16/FP32
|
||||||
model = AutoModel.from_pretrained(self.config.model_path)
|
model = AutoModel.from_pretrained(self.config.model_path)
|
||||||
@@ -371,6 +402,7 @@ class Benchmark:
|
|||||||
# Apply standard optimizations
|
# Apply standard optimizations
|
||||||
# set default to half
|
# set default to half
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
torch.set_default_dtype(torch.bfloat16)
|
torch.set_default_dtype(torch.bfloat16)
|
||||||
model = ModelOptimizer.optimize(model, self.config)
|
model = ModelOptimizer.optimize(model, self.config)
|
||||||
model = model.half()
|
model = model.half()
|
||||||
@@ -385,49 +417,60 @@ class Benchmark:
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"ERROR loading model: {str(e)}")
|
print(f"ERROR loading model: {e!s}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _create_random_batch(self, batch_size: int) -> torch.Tensor:
|
def _create_random_batch(self, batch_size: int) -> torch.Tensor:
|
||||||
|
device = (
|
||||||
|
"cuda"
|
||||||
|
if torch.cuda.is_available()
|
||||||
|
else "mps"
|
||||||
|
if torch.backends.mps.is_available()
|
||||||
|
else "cpu"
|
||||||
|
)
|
||||||
return torch.randint(
|
return torch.randint(
|
||||||
0, 1000,
|
0,
|
||||||
|
1000,
|
||||||
(batch_size, self.config.seq_length),
|
(batch_size, self.config.seq_length),
|
||||||
device="cuda",
|
device=device,
|
||||||
dtype=torch.long
|
dtype=torch.long,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _run_inference(
|
def _run_inference(
|
||||||
self,
|
self, input_ids: torch.Tensor, graph_wrapper: GraphWrapper | None = None
|
||||||
input_ids: torch.Tensor,
|
) -> tuple[float, torch.Tensor]:
|
||||||
cuda_graph_wrapper: Optional[CUDAGraphWrapper] = None
|
|
||||||
) -> Tuple[float, torch.Tensor]:
|
|
||||||
attention_mask = torch.ones_like(input_ids)
|
attention_mask = torch.ones_like(input_ids)
|
||||||
|
|
||||||
with torch.no_grad(), self.timer.timing():
|
with torch.no_grad(), self.timer.timing():
|
||||||
if cuda_graph_wrapper is not None:
|
if graph_wrapper is not None:
|
||||||
output = cuda_graph_wrapper(input_ids, attention_mask)
|
output = graph_wrapper(input_ids, attention_mask)
|
||||||
else:
|
else:
|
||||||
output = self.model(input_ids=input_ids, attention_mask=attention_mask)
|
output = self.model(input_ids=input_ids, attention_mask=attention_mask)
|
||||||
|
|
||||||
return self.timer.elapsed_time(), output
|
return self.timer.elapsed_time(), output
|
||||||
|
|
||||||
def run(self) -> Dict[int, Dict[str, float]]:
|
def run(self) -> dict[int, dict[str, float]]:
|
||||||
results = {}
|
results = {}
|
||||||
|
|
||||||
# Reset peak memory stats
|
# Reset peak memory stats
|
||||||
torch.cuda.reset_peak_memory_stats()
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.reset_peak_memory_stats()
|
||||||
|
elif torch.backends.mps.is_available():
|
||||||
|
# MPS doesn't have reset_peak_memory_stats, skip it
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
print("- No GPU memory stats available")
|
||||||
|
|
||||||
for batch_size in self.config.batch_sizes:
|
for batch_size in self.config.batch_sizes:
|
||||||
print(f"\nTesting batch size: {batch_size}")
|
print(f"\nTesting batch size: {batch_size}")
|
||||||
times = []
|
times = []
|
||||||
|
|
||||||
# Get or create CUDA graph for this batch size
|
# Get or create graph for this batch size
|
||||||
cuda_graph_wrapper = (
|
graph_wrapper = (
|
||||||
self.cuda_graphs.get_or_create(batch_size)
|
self.graphs.get_or_create(batch_size) if self.graphs is not None else None
|
||||||
if self.cuda_graphs is not None
|
|
||||||
else None
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Pre-allocate input tensor
|
# Pre-allocate input tensor
|
||||||
@@ -437,7 +480,7 @@ class Benchmark:
|
|||||||
# Run benchmark
|
# Run benchmark
|
||||||
for i in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
|
for i in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
|
||||||
try:
|
try:
|
||||||
elapsed_time, output = self._run_inference(input_ids, cuda_graph_wrapper)
|
elapsed_time, output = self._run_inference(input_ids, graph_wrapper)
|
||||||
if i == 0: # Only print on first run
|
if i == 0: # Only print on first run
|
||||||
print(f"Output shape: {output.last_hidden_state.shape}")
|
print(f"Output shape: {output.last_hidden_state.shape}")
|
||||||
times.append(elapsed_time)
|
times.append(elapsed_time)
|
||||||
@@ -464,8 +507,19 @@ class Benchmark:
|
|||||||
print(f"Throughput: {throughput:.2f} sequences/second")
|
print(f"Throughput: {throughput:.2f} sequences/second")
|
||||||
|
|
||||||
# Log memory usage
|
# Log memory usage
|
||||||
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024 ** 3)
|
if torch.cuda.is_available():
|
||||||
print(f"\nPeak GPU memory usage: {peak_memory_gb:.2f} GB")
|
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024**3)
|
||||||
|
elif torch.backends.mps.is_available():
|
||||||
|
# MPS doesn't have max_memory_allocated, use 0
|
||||||
|
peak_memory_gb = 0.0
|
||||||
|
else:
|
||||||
|
peak_memory_gb = 0.0
|
||||||
|
print("- No GPU memory usage available")
|
||||||
|
|
||||||
|
if peak_memory_gb > 0:
|
||||||
|
print(f"\nPeak GPU memory usage: {peak_memory_gb:.2f} GB")
|
||||||
|
else:
|
||||||
|
print("\n- GPU memory usage not available")
|
||||||
|
|
||||||
# Add memory info to results
|
# Add memory info to results
|
||||||
for batch_size in results:
|
for batch_size in results:
|
||||||
@@ -485,7 +539,7 @@ def main():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--batch_sizes",
|
"--batch_sizes",
|
||||||
type=str,
|
type=str,
|
||||||
default="1,2,4,8,10,16,20,32,40,64,128,256,512,1024,2048,4096,8192",
|
default="1,2,4,8,16,32",
|
||||||
help="Comma-separated list of batch sizes",
|
help="Comma-separated list of batch sizes",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -518,12 +572,12 @@ def main():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--use_cuda_graphs",
|
"--use_cuda_graphs",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Enable CUDA Graphs optimization",
|
help="Enable CUDA Graphs optimization (only on NVIDIA GPUs)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--use_flash_attention",
|
"--use_flash_attention",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Enable Flash Attention 2 if available",
|
help="Enable Flash Attention 2 if available (only on NVIDIA GPUs)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--use_linear8bitlt",
|
"--use_linear8bitlt",
|
||||||
@@ -568,7 +622,15 @@ def main():
|
|||||||
os.makedirs("results", exist_ok=True)
|
os.makedirs("results", exist_ok=True)
|
||||||
|
|
||||||
# Generate filename based on configuration
|
# Generate filename based on configuration
|
||||||
precision_type = "int4" if config.use_int4 else "fp16" if config.use_fp16 else "fp32"
|
precision_type = (
|
||||||
|
"int4"
|
||||||
|
if config.use_int4
|
||||||
|
else "int8"
|
||||||
|
if config.use_int8
|
||||||
|
else "fp16"
|
||||||
|
if config.use_fp16
|
||||||
|
else "fp32"
|
||||||
|
)
|
||||||
model_name = os.path.basename(config.model_path)
|
model_name = os.path.basename(config.model_path)
|
||||||
output_file = f"results/benchmark_{model_name}_{precision_type}.json"
|
output_file = f"results/benchmark_{model_name}_{precision_type}.json"
|
||||||
|
|
||||||
@@ -576,17 +638,20 @@ def main():
|
|||||||
with open(output_file, "w") as f:
|
with open(output_file, "w") as f:
|
||||||
json.dump(
|
json.dump(
|
||||||
{
|
{
|
||||||
"config": {k: str(v) if isinstance(v, list) else v for k, v in vars(config).items()},
|
"config": {
|
||||||
"results": {str(k): v for k, v in results.items()}
|
k: str(v) if isinstance(v, list) else v for k, v in vars(config).items()
|
||||||
|
},
|
||||||
|
"results": {str(k): v for k, v in results.items()},
|
||||||
},
|
},
|
||||||
f,
|
f,
|
||||||
indent=2
|
indent=2,
|
||||||
)
|
)
|
||||||
print(f"Results saved to {output_file}")
|
print(f"Results saved to {output_file}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Benchmark failed: {e}")
|
print(f"Benchmark failed: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|
||||||
359
benchmarks/run_evaluation.py
Normal file
359
benchmarks/run_evaluation.py
Normal file
@@ -0,0 +1,359 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
This script runs a recall evaluation on a given LEANN index.
|
||||||
|
It correctly compares results by fetching the text content for both the new search
|
||||||
|
results and the golden standard results, making the comparison robust to ID changes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from leann.api import LeannBuilder, LeannSearcher
|
||||||
|
|
||||||
|
|
||||||
|
def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
|
||||||
|
"""Checks if the data directory exists, and if not, downloads it from HF Hub."""
|
||||||
|
if not data_root.exists():
|
||||||
|
print(f"Data directory '{data_root}' not found.")
|
||||||
|
print("Downloading evaluation data from Hugging Face Hub... (this may take a moment)")
|
||||||
|
try:
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
if download_embeddings:
|
||||||
|
# Download everything including embeddings (large files)
|
||||||
|
snapshot_download(
|
||||||
|
repo_id="LEANN-RAG/leann-rag-evaluation-data",
|
||||||
|
repo_type="dataset",
|
||||||
|
local_dir=data_root,
|
||||||
|
local_dir_use_symlinks=False,
|
||||||
|
)
|
||||||
|
print("Data download complete (including embeddings)!")
|
||||||
|
else:
|
||||||
|
# Download only specific folders, excluding embeddings
|
||||||
|
allow_patterns = [
|
||||||
|
"ground_truth/**",
|
||||||
|
"indices/**",
|
||||||
|
"queries/**",
|
||||||
|
"*.md",
|
||||||
|
"*.txt",
|
||||||
|
]
|
||||||
|
snapshot_download(
|
||||||
|
repo_id="LEANN-RAG/leann-rag-evaluation-data",
|
||||||
|
repo_type="dataset",
|
||||||
|
local_dir=data_root,
|
||||||
|
local_dir_use_symlinks=False,
|
||||||
|
allow_patterns=allow_patterns,
|
||||||
|
)
|
||||||
|
print("Data download complete (excluding embeddings)!")
|
||||||
|
except ImportError:
|
||||||
|
print(
|
||||||
|
"Error: huggingface_hub is not installed. Please install it to download the data:"
|
||||||
|
)
|
||||||
|
print("uv pip install -e '.[dev]'")
|
||||||
|
sys.exit(1)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"An error occurred during data download: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
def download_embeddings_if_needed(data_root: Path, dataset_type: str | None = None):
|
||||||
|
"""Download embeddings files specifically."""
|
||||||
|
embeddings_dir = data_root / "embeddings"
|
||||||
|
|
||||||
|
if dataset_type:
|
||||||
|
# Check if specific dataset embeddings exist
|
||||||
|
target_file = embeddings_dir / dataset_type / "passages_00.pkl"
|
||||||
|
if target_file.exists():
|
||||||
|
print(f"Embeddings for {dataset_type} already exist")
|
||||||
|
return str(target_file)
|
||||||
|
|
||||||
|
print("Downloading embeddings from HuggingFace Hub...")
|
||||||
|
try:
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
# Download only embeddings folder
|
||||||
|
snapshot_download(
|
||||||
|
repo_id="LEANN-RAG/leann-rag-evaluation-data",
|
||||||
|
repo_type="dataset",
|
||||||
|
local_dir=data_root,
|
||||||
|
local_dir_use_symlinks=False,
|
||||||
|
allow_patterns=["embeddings/**/*.pkl"],
|
||||||
|
)
|
||||||
|
print("Embeddings download complete!")
|
||||||
|
|
||||||
|
if dataset_type:
|
||||||
|
target_file = embeddings_dir / dataset_type / "passages_00.pkl"
|
||||||
|
if target_file.exists():
|
||||||
|
return str(target_file)
|
||||||
|
|
||||||
|
return str(embeddings_dir)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error downloading embeddings: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
# --- Helper Function to get Golden Passages ---
|
||||||
|
def get_golden_texts(searcher: LeannSearcher, golden_ids: list[int]) -> set:
|
||||||
|
"""
|
||||||
|
Retrieves the text for golden passage IDs directly from the LeannSearcher's
|
||||||
|
passage manager.
|
||||||
|
"""
|
||||||
|
golden_texts = set()
|
||||||
|
for gid in golden_ids:
|
||||||
|
try:
|
||||||
|
# PassageManager uses string IDs
|
||||||
|
passage_data = searcher.passage_manager.get_passage(str(gid))
|
||||||
|
golden_texts.add(passage_data["text"])
|
||||||
|
except KeyError:
|
||||||
|
print(f"Warning: Golden passage ID '{gid}' not found in the index's passage data.")
|
||||||
|
return golden_texts
|
||||||
|
|
||||||
|
|
||||||
|
def load_queries(file_path: Path) -> list[str]:
|
||||||
|
queries = []
|
||||||
|
with open(file_path, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
data = json.loads(line)
|
||||||
|
queries.append(data["query"])
|
||||||
|
return queries
|
||||||
|
|
||||||
|
|
||||||
|
def build_index_from_embeddings(embeddings_file: str, output_path: str, backend: str = "hnsw"):
|
||||||
|
"""
|
||||||
|
Build a LEANN index from pre-computed embeddings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embeddings_file: Path to pickle file with (ids, embeddings) tuple
|
||||||
|
output_path: Path where to save the index
|
||||||
|
backend: Backend to use ("hnsw" or "diskann")
|
||||||
|
"""
|
||||||
|
print(f"Building {backend} index from embeddings: {embeddings_file}")
|
||||||
|
|
||||||
|
# Create builder with appropriate parameters
|
||||||
|
if backend == "hnsw":
|
||||||
|
builder_kwargs = {
|
||||||
|
"M": 32, # Graph degree
|
||||||
|
"efConstruction": 256, # Construction complexity
|
||||||
|
"is_compact": True, # Use compact storage
|
||||||
|
"is_recompute": True, # Enable pruning for better recall
|
||||||
|
}
|
||||||
|
elif backend == "diskann":
|
||||||
|
builder_kwargs = {
|
||||||
|
"complexity": 64,
|
||||||
|
"graph_degree": 32,
|
||||||
|
"search_memory_maximum": 8.0, # GB
|
||||||
|
"build_memory_maximum": 16.0, # GB
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
builder_kwargs = {}
|
||||||
|
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name=backend,
|
||||||
|
embedding_model="facebook/contriever-msmarco", # Model used to create embeddings
|
||||||
|
dimensions=768, # Will be auto-detected from embeddings
|
||||||
|
**builder_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build index from precomputed embeddings
|
||||||
|
builder.build_index_from_embeddings(output_path, embeddings_file)
|
||||||
|
print(f"Index saved to: {output_path}")
|
||||||
|
return output_path
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="Run recall evaluation on a LEANN index.")
|
||||||
|
parser.add_argument(
|
||||||
|
"index_path",
|
||||||
|
type=str,
|
||||||
|
nargs="?",
|
||||||
|
help="Path to the LEANN index to evaluate or build (optional).",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--mode",
|
||||||
|
choices=["evaluate", "build"],
|
||||||
|
default="evaluate",
|
||||||
|
help="Mode: 'evaluate' existing index or 'build' from embeddings",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--embeddings-file",
|
||||||
|
type=str,
|
||||||
|
help="Path to embeddings pickle file (optional for build mode)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--backend",
|
||||||
|
choices=["hnsw", "diskann"],
|
||||||
|
default="hnsw",
|
||||||
|
help="Backend to use for building index (default: hnsw)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-queries", type=int, default=10, help="Number of queries to evaluate."
|
||||||
|
)
|
||||||
|
parser.add_argument("--top-k", type=int, default=3, help="The 'k' value for recall@k.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW."
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# --- Path Configuration ---
|
||||||
|
# Assumes a project structure where the script is in 'benchmarks/'
|
||||||
|
# and evaluation data is in 'benchmarks/data/'.
|
||||||
|
script_dir = Path(__file__).resolve().parent
|
||||||
|
data_root = script_dir / "data"
|
||||||
|
|
||||||
|
# Download data based on mode
|
||||||
|
if args.mode == "build":
|
||||||
|
# For building mode, we need embeddings
|
||||||
|
download_data_if_needed(data_root, download_embeddings=False) # Basic data first
|
||||||
|
|
||||||
|
# Auto-detect dataset type and download embeddings
|
||||||
|
if args.embeddings_file:
|
||||||
|
embeddings_file = args.embeddings_file
|
||||||
|
# Try to detect dataset type from embeddings file path
|
||||||
|
if "rpj_wiki" in str(embeddings_file):
|
||||||
|
dataset_type = "rpj_wiki"
|
||||||
|
elif "dpr" in str(embeddings_file):
|
||||||
|
dataset_type = "dpr"
|
||||||
|
else:
|
||||||
|
dataset_type = "dpr" # Default
|
||||||
|
else:
|
||||||
|
# Auto-detect from index path if provided, otherwise default to DPR
|
||||||
|
if args.index_path:
|
||||||
|
index_path_str = str(args.index_path)
|
||||||
|
if "rpj_wiki" in index_path_str:
|
||||||
|
dataset_type = "rpj_wiki"
|
||||||
|
elif "dpr" in index_path_str:
|
||||||
|
dataset_type = "dpr"
|
||||||
|
else:
|
||||||
|
dataset_type = "dpr" # Default to DPR
|
||||||
|
else:
|
||||||
|
dataset_type = "dpr" # Default to DPR
|
||||||
|
|
||||||
|
embeddings_file = download_embeddings_if_needed(data_root, dataset_type)
|
||||||
|
|
||||||
|
# Auto-generate index path if not provided
|
||||||
|
if not args.index_path:
|
||||||
|
indices_dir = data_root / "indices" / dataset_type
|
||||||
|
indices_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
args.index_path = str(indices_dir / f"{dataset_type}_from_embeddings")
|
||||||
|
print(f"Auto-generated index path: {args.index_path}")
|
||||||
|
|
||||||
|
print(f"Building index from embeddings: {embeddings_file}")
|
||||||
|
built_index_path = build_index_from_embeddings(
|
||||||
|
embeddings_file, args.index_path, args.backend
|
||||||
|
)
|
||||||
|
print(f"Index built successfully: {built_index_path}")
|
||||||
|
|
||||||
|
# Ask if user wants to run evaluation
|
||||||
|
eval_response = input("Run evaluation on the built index? (y/n): ").strip().lower()
|
||||||
|
if eval_response != "y":
|
||||||
|
print("Index building complete. Exiting.")
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
# For evaluation mode, don't need embeddings
|
||||||
|
download_data_if_needed(data_root, download_embeddings=False)
|
||||||
|
|
||||||
|
# Auto-detect index path if not provided
|
||||||
|
if not args.index_path:
|
||||||
|
# Default to using downloaded indices
|
||||||
|
indices_dir = data_root / "indices"
|
||||||
|
|
||||||
|
# Try common datasets in order of preference
|
||||||
|
for dataset in ["dpr", "rpj_wiki"]:
|
||||||
|
dataset_dir = indices_dir / dataset
|
||||||
|
if dataset_dir.exists():
|
||||||
|
# Look for index files
|
||||||
|
index_files = list(dataset_dir.glob("*.index")) + list(
|
||||||
|
dataset_dir.glob("*_disk.index")
|
||||||
|
)
|
||||||
|
if index_files:
|
||||||
|
args.index_path = str(
|
||||||
|
index_files[0].with_suffix("")
|
||||||
|
) # Remove .index extension
|
||||||
|
print(f"Using index: {args.index_path}")
|
||||||
|
break
|
||||||
|
|
||||||
|
if not args.index_path:
|
||||||
|
print("No indices found. The data download should have included pre-built indices.")
|
||||||
|
print(
|
||||||
|
"Please check the benchmarks/data/indices/ directory or provide --index-path manually."
|
||||||
|
)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Detect dataset type from index path to select the correct ground truth
|
||||||
|
index_path_str = str(args.index_path)
|
||||||
|
if "rpj_wiki" in index_path_str:
|
||||||
|
dataset_type = "rpj_wiki"
|
||||||
|
elif "dpr" in index_path_str:
|
||||||
|
dataset_type = "dpr"
|
||||||
|
else:
|
||||||
|
# Fallback: try to infer from the index directory name
|
||||||
|
dataset_type = Path(args.index_path).name
|
||||||
|
print(f"WARNING: Could not detect dataset type from path, inferred '{dataset_type}'.")
|
||||||
|
|
||||||
|
queries_file = data_root / "queries" / "nq_open.jsonl"
|
||||||
|
golden_results_file = data_root / "ground_truth" / dataset_type / "flat_results_nq_k3.json"
|
||||||
|
|
||||||
|
print(f"INFO: Detected dataset type: {dataset_type}")
|
||||||
|
print(f"INFO: Using queries file: {queries_file}")
|
||||||
|
print(f"INFO: Using ground truth file: {golden_results_file}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
searcher = LeannSearcher(args.index_path)
|
||||||
|
queries = load_queries(queries_file)
|
||||||
|
|
||||||
|
with open(golden_results_file) as f:
|
||||||
|
golden_results_data = json.load(f)
|
||||||
|
|
||||||
|
num_eval_queries = min(args.num_queries, len(queries))
|
||||||
|
queries = queries[:num_eval_queries]
|
||||||
|
|
||||||
|
print(f"\nRunning evaluation on {num_eval_queries} queries...")
|
||||||
|
recall_scores = []
|
||||||
|
search_times = []
|
||||||
|
|
||||||
|
for i in range(num_eval_queries):
|
||||||
|
start_time = time.time()
|
||||||
|
new_results = searcher.search(queries[i], top_k=args.top_k, ef=args.ef_search)
|
||||||
|
search_times.append(time.time() - start_time)
|
||||||
|
|
||||||
|
# Correct Recall Calculation: Based on TEXT content
|
||||||
|
new_texts = {result.text for result in new_results}
|
||||||
|
|
||||||
|
# Get golden texts directly from the searcher's passage manager
|
||||||
|
golden_ids = golden_results_data["indices"][i][: args.top_k]
|
||||||
|
golden_texts = get_golden_texts(searcher, golden_ids)
|
||||||
|
|
||||||
|
overlap = len(new_texts & golden_texts)
|
||||||
|
recall = overlap / len(golden_texts) if golden_texts else 0
|
||||||
|
recall_scores.append(recall)
|
||||||
|
|
||||||
|
print("\n--- EVALUATION RESULTS ---")
|
||||||
|
print(f"Query: {queries[i]}")
|
||||||
|
print(f"New Results: {new_texts}")
|
||||||
|
print(f"Golden Results: {golden_texts}")
|
||||||
|
print(f"Overlap: {overlap}")
|
||||||
|
print(f"Recall: {recall}")
|
||||||
|
print(f"Search Time: {search_times[-1]:.4f}s")
|
||||||
|
print("--------------------------------")
|
||||||
|
|
||||||
|
avg_recall = np.mean(recall_scores) if recall_scores else 0
|
||||||
|
avg_time = np.mean(search_times) if search_times else 0
|
||||||
|
|
||||||
|
print("\n🎉 --- Evaluation Complete ---")
|
||||||
|
print(f"Avg. Recall@{args.top_k} (efSearch={args.ef_search}): {avg_recall:.4f}")
|
||||||
|
print(f"Avg. Search Time: {avg_time:.4f}s")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n❌ An error occurred during evaluation: {e}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
311
benchmarks/simple_mac_tpt_test.py
Normal file
311
benchmarks/simple_mac_tpt_test.py
Normal file
@@ -0,0 +1,311 @@
|
|||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from tqdm import tqdm
|
||||||
|
from transformers import AutoModel
|
||||||
|
|
||||||
|
# Add MLX imports
|
||||||
|
try:
|
||||||
|
import mlx.core as mx
|
||||||
|
from mlx_lm.utils import load
|
||||||
|
|
||||||
|
MLX_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
print("MLX not available. Install with: uv pip install mlx mlx-lm")
|
||||||
|
MLX_AVAILABLE = False
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BenchmarkConfig:
|
||||||
|
model_path: str = "facebook/contriever"
|
||||||
|
batch_sizes: list[int] = None
|
||||||
|
seq_length: int = 256
|
||||||
|
num_runs: int = 5
|
||||||
|
use_fp16: bool = True
|
||||||
|
use_int4: bool = False
|
||||||
|
use_int8: bool = False
|
||||||
|
use_cuda_graphs: bool = False
|
||||||
|
use_flash_attention: bool = False
|
||||||
|
use_linear8bitlt: bool = False
|
||||||
|
use_mlx: bool = False # New flag for MLX testing
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.batch_sizes is None:
|
||||||
|
self.batch_sizes = [1, 2, 4, 8, 16, 32, 64]
|
||||||
|
|
||||||
|
|
||||||
|
class MLXBenchmark:
|
||||||
|
"""MLX-specific benchmark for embedding models"""
|
||||||
|
|
||||||
|
def __init__(self, config: BenchmarkConfig):
|
||||||
|
self.config = config
|
||||||
|
self.model, self.tokenizer = self._load_model()
|
||||||
|
|
||||||
|
def _load_model(self):
|
||||||
|
"""Load MLX model and tokenizer following the API pattern"""
|
||||||
|
print(f"Loading MLX model from {self.config.model_path}...")
|
||||||
|
try:
|
||||||
|
model, tokenizer = load(self.config.model_path)
|
||||||
|
print("MLX model loaded successfully")
|
||||||
|
return model, tokenizer
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error loading MLX model: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _create_random_batch(self, batch_size: int):
|
||||||
|
"""Create random input batches for MLX testing - same as PyTorch"""
|
||||||
|
return torch.randint(0, 1000, (batch_size, self.config.seq_length), dtype=torch.long)
|
||||||
|
|
||||||
|
def _run_inference(self, input_ids: torch.Tensor) -> float:
|
||||||
|
"""Run MLX inference with same input as PyTorch"""
|
||||||
|
start_time = time.time()
|
||||||
|
try:
|
||||||
|
# Convert PyTorch tensor to MLX array
|
||||||
|
input_ids_mlx = mx.array(input_ids.numpy())
|
||||||
|
|
||||||
|
# Get embeddings
|
||||||
|
embeddings = self.model(input_ids_mlx)
|
||||||
|
|
||||||
|
# Mean pooling (following the API pattern)
|
||||||
|
pooled = embeddings.mean(axis=1)
|
||||||
|
|
||||||
|
# Convert to numpy (following the API pattern)
|
||||||
|
pooled_numpy = np.array(pooled.tolist(), dtype=np.float32)
|
||||||
|
|
||||||
|
# Force computation
|
||||||
|
_ = pooled_numpy.shape
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"MLX inference error: {e}")
|
||||||
|
return float("inf")
|
||||||
|
end_time = time.time()
|
||||||
|
|
||||||
|
return end_time - start_time
|
||||||
|
|
||||||
|
def run(self) -> dict[int, dict[str, float]]:
|
||||||
|
"""Run the MLX benchmark across all batch sizes"""
|
||||||
|
results = {}
|
||||||
|
|
||||||
|
print(f"Starting MLX benchmark with model: {self.config.model_path}")
|
||||||
|
print(f"Testing batch sizes: {self.config.batch_sizes}")
|
||||||
|
|
||||||
|
for batch_size in self.config.batch_sizes:
|
||||||
|
print(f"\n=== Testing MLX batch size: {batch_size} ===")
|
||||||
|
times = []
|
||||||
|
|
||||||
|
# Create input batch (same as PyTorch)
|
||||||
|
input_ids = self._create_random_batch(batch_size)
|
||||||
|
|
||||||
|
# Warm up
|
||||||
|
print("Warming up...")
|
||||||
|
for _ in range(3):
|
||||||
|
try:
|
||||||
|
self._run_inference(input_ids[:2]) # Warm up with smaller batch
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Warmup error: {e}")
|
||||||
|
break
|
||||||
|
|
||||||
|
# Run benchmark
|
||||||
|
for _i in tqdm(range(self.config.num_runs), desc=f"MLX Batch size {batch_size}"):
|
||||||
|
try:
|
||||||
|
elapsed_time = self._run_inference(input_ids)
|
||||||
|
if elapsed_time != float("inf"):
|
||||||
|
times.append(elapsed_time)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error during MLX inference: {e}")
|
||||||
|
break
|
||||||
|
|
||||||
|
if not times:
|
||||||
|
print(f"Skipping batch size {batch_size} due to errors")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Calculate statistics
|
||||||
|
avg_time = np.mean(times)
|
||||||
|
std_time = np.std(times)
|
||||||
|
throughput = batch_size / avg_time
|
||||||
|
|
||||||
|
results[batch_size] = {
|
||||||
|
"avg_time": avg_time,
|
||||||
|
"std_time": std_time,
|
||||||
|
"throughput": throughput,
|
||||||
|
"min_time": np.min(times),
|
||||||
|
"max_time": np.max(times),
|
||||||
|
}
|
||||||
|
|
||||||
|
print(f"MLX Results for batch size {batch_size}:")
|
||||||
|
print(f" Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
|
||||||
|
print(f" Min Time: {np.min(times):.4f}s")
|
||||||
|
print(f" Max Time: {np.max(times):.4f}s")
|
||||||
|
print(f" Throughput: {throughput:.2f} sequences/second")
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
class Benchmark:
|
||||||
|
def __init__(self, config: BenchmarkConfig):
|
||||||
|
self.config = config
|
||||||
|
self.device = (
|
||||||
|
"cuda"
|
||||||
|
if torch.cuda.is_available()
|
||||||
|
else "mps"
|
||||||
|
if torch.backends.mps.is_available()
|
||||||
|
else "cpu"
|
||||||
|
)
|
||||||
|
self.model = self._load_model()
|
||||||
|
|
||||||
|
def _load_model(self) -> nn.Module:
|
||||||
|
print(f"Loading model from {self.config.model_path}...")
|
||||||
|
|
||||||
|
model = AutoModel.from_pretrained(self.config.model_path)
|
||||||
|
if self.config.use_fp16:
|
||||||
|
model = model.half()
|
||||||
|
model = torch.compile(model)
|
||||||
|
model = model.to(self.device)
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
return model
|
||||||
|
|
||||||
|
def _create_random_batch(self, batch_size: int) -> torch.Tensor:
|
||||||
|
return torch.randint(
|
||||||
|
0,
|
||||||
|
1000,
|
||||||
|
(batch_size, self.config.seq_length),
|
||||||
|
device=self.device,
|
||||||
|
dtype=torch.long,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _run_inference(self, input_ids: torch.Tensor) -> float:
|
||||||
|
attention_mask = torch.ones_like(input_ids)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
with torch.no_grad():
|
||||||
|
self.model(input_ids=input_ids, attention_mask=attention_mask)
|
||||||
|
end_time = time.time()
|
||||||
|
|
||||||
|
return end_time - start_time
|
||||||
|
|
||||||
|
def run(self) -> dict[int, dict[str, float]]:
|
||||||
|
results = {}
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.reset_peak_memory_stats()
|
||||||
|
|
||||||
|
for batch_size in self.config.batch_sizes:
|
||||||
|
print(f"\nTesting batch size: {batch_size}")
|
||||||
|
times = []
|
||||||
|
|
||||||
|
input_ids = self._create_random_batch(batch_size)
|
||||||
|
|
||||||
|
for _i in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
|
||||||
|
try:
|
||||||
|
elapsed_time = self._run_inference(input_ids)
|
||||||
|
times.append(elapsed_time)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error during inference: {e}")
|
||||||
|
break
|
||||||
|
|
||||||
|
if not times:
|
||||||
|
continue
|
||||||
|
|
||||||
|
avg_time = np.mean(times)
|
||||||
|
std_time = np.std(times)
|
||||||
|
throughput = batch_size / avg_time
|
||||||
|
|
||||||
|
results[batch_size] = {
|
||||||
|
"avg_time": avg_time,
|
||||||
|
"std_time": std_time,
|
||||||
|
"throughput": throughput,
|
||||||
|
}
|
||||||
|
|
||||||
|
print(f"Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
|
||||||
|
print(f"Throughput: {throughput:.2f} sequences/second")
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024**3)
|
||||||
|
else:
|
||||||
|
peak_memory_gb = 0.0
|
||||||
|
|
||||||
|
for batch_size in results:
|
||||||
|
results[batch_size]["peak_memory_gb"] = peak_memory_gb
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def run_benchmark():
|
||||||
|
"""Main function to run the benchmark with optimized parameters."""
|
||||||
|
config = BenchmarkConfig()
|
||||||
|
|
||||||
|
try:
|
||||||
|
benchmark = Benchmark(config)
|
||||||
|
results = benchmark.run()
|
||||||
|
|
||||||
|
max_throughput = max(results[batch_size]["throughput"] for batch_size in results)
|
||||||
|
avg_throughput = np.mean([results[batch_size]["throughput"] for batch_size in results])
|
||||||
|
|
||||||
|
return {
|
||||||
|
"max_throughput": max_throughput,
|
||||||
|
"avg_throughput": avg_throughput,
|
||||||
|
"results": results,
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Benchmark failed: {e}")
|
||||||
|
return {"max_throughput": 0.0, "avg_throughput": 0.0, "error": str(e)}
|
||||||
|
|
||||||
|
|
||||||
|
def run_mlx_benchmark():
|
||||||
|
"""Run MLX-specific benchmark"""
|
||||||
|
if not MLX_AVAILABLE:
|
||||||
|
print("MLX not available, skipping MLX benchmark")
|
||||||
|
return {
|
||||||
|
"max_throughput": 0.0,
|
||||||
|
"avg_throughput": 0.0,
|
||||||
|
"error": "MLX not available",
|
||||||
|
}
|
||||||
|
|
||||||
|
config = BenchmarkConfig(model_path="mlx-community/all-MiniLM-L6-v2-4bit", use_mlx=True)
|
||||||
|
|
||||||
|
try:
|
||||||
|
benchmark = MLXBenchmark(config)
|
||||||
|
results = benchmark.run()
|
||||||
|
|
||||||
|
if not results:
|
||||||
|
return {
|
||||||
|
"max_throughput": 0.0,
|
||||||
|
"avg_throughput": 0.0,
|
||||||
|
"error": "No valid results",
|
||||||
|
}
|
||||||
|
|
||||||
|
max_throughput = max(results[batch_size]["throughput"] for batch_size in results)
|
||||||
|
avg_throughput = np.mean([results[batch_size]["throughput"] for batch_size in results])
|
||||||
|
|
||||||
|
return {
|
||||||
|
"max_throughput": max_throughput,
|
||||||
|
"avg_throughput": avg_throughput,
|
||||||
|
"results": results,
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"MLX benchmark failed: {e}")
|
||||||
|
return {"max_throughput": 0.0, "avg_throughput": 0.0, "error": str(e)}
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print("=== PyTorch Benchmark ===")
|
||||||
|
pytorch_result = run_benchmark()
|
||||||
|
print(f"PyTorch Max throughput: {pytorch_result['max_throughput']:.2f} sequences/second")
|
||||||
|
print(f"PyTorch Average throughput: {pytorch_result['avg_throughput']:.2f} sequences/second")
|
||||||
|
|
||||||
|
print("\n=== MLX Benchmark ===")
|
||||||
|
mlx_result = run_mlx_benchmark()
|
||||||
|
print(f"MLX Max throughput: {mlx_result['max_throughput']:.2f} sequences/second")
|
||||||
|
print(f"MLX Average throughput: {mlx_result['avg_throughput']:.2f} sequences/second")
|
||||||
|
|
||||||
|
# Compare results
|
||||||
|
if pytorch_result["max_throughput"] > 0 and mlx_result["max_throughput"] > 0:
|
||||||
|
speedup = mlx_result["max_throughput"] / pytorch_result["max_throughput"]
|
||||||
|
print("\n=== Comparison ===")
|
||||||
|
print(f"MLX is {speedup:.2f}x {'faster' if speedup > 1 else 'slower'} than PyTorch")
|
||||||
82
data/.gitattributes
vendored
Normal file
82
data/.gitattributes
vendored
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
*.7z filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.arrow filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.bin filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.ftz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.gz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.h5 filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.joblib filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.lz4 filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.mds filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.model filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.npy filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.npz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.onnx filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.ot filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.parquet filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pb filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pickle filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pkl filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pt filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pth filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.rar filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
||||||
|
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.tar filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.tflite filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.tgz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.wasm filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.xz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.zip filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.zst filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
||||||
|
# Audio files - uncompressed
|
||||||
|
*.pcm filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.sam filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.raw filter=lfs diff=lfs merge=lfs -text
|
||||||
|
# Audio files - compressed
|
||||||
|
*.aac filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.flac filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.mp3 filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.ogg filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.wav filter=lfs diff=lfs merge=lfs -text
|
||||||
|
# Image files - uncompressed
|
||||||
|
*.bmp filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.gif filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.png filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.tiff filter=lfs diff=lfs merge=lfs -text
|
||||||
|
# Image files - compressed
|
||||||
|
*.jpg filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.webp filter=lfs diff=lfs merge=lfs -text
|
||||||
|
# Video files - compressed
|
||||||
|
*.mp4 filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.webm filter=lfs diff=lfs merge=lfs -text
|
||||||
|
ground_truth/dpr/id_map.json filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/dpr/dpr_diskann.passages.idx filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/dpr/dpr_diskann.passages.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/dpr/dpr_diskann_disk.index filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/dpr/leann.labels.map filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/leann.labels.map filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.index filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.0.idx filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.0.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.1.idx filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.1.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.2.idx filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.2.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.3.idx filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.3.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.4.idx filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.4.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.5.idx filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.5.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.6.idx filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.6.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.7.idx filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.7.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||||
BIN
data/2501.14312v1 (1).pdf
Normal file
BIN
data/2501.14312v1 (1).pdf
Normal file
Binary file not shown.
14905
data/PrideandPrejudice.txt
Normal file
14905
data/PrideandPrejudice.txt
Normal file
File diff suppressed because it is too large
Load Diff
82
data/README.md
Normal file
82
data/README.md
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
# 盘古之殇:华为诺亚盘古大模型研发历程的心酸与黑暗
|
||||||
|
|
||||||
|
各位好,
|
||||||
|
|
||||||
|
我是一名盘古大模型团队,华为诺亚方舟实验室的员工。
|
||||||
|
|
||||||
|
首先为自证身份,列举一些细节:
|
||||||
|
|
||||||
|
1. 现诺亚主任,前算法应用部部长,后改名为小模型实验室的主任王云鹤。前诺亚主任:姚骏(大家称姚老师)。几个实验室主任:唐睿明(明哥,明队,已离职),尚利峰,张维(维哥),郝建业(郝老师),刘武龙(称呼为武龙所)等。其他骨干成员和专家陆续有很多人离职。
|
||||||
|
2. 我们隶属于“四野”这个组织。四野下属有许多纵队,基础语言大模型是四纵。王云鹤的小模型是十六纵队。我们参加过苏州的集结,有各种月份的时间节点。在苏州攻关会颁发任务令,需要在节点前达成目标。苏州集结会把各地的人员都集中在苏州研究所,平常住宾馆,比如在甪直的酒店,与家人孩子天各一方。
|
||||||
|
3. 在苏州集结的时候周六默认上班,非常辛苦,不过周六有下午茶,有一次还有小龙虾。在苏州研究所的工位搬迁过一次,从一栋楼换到了另一栋。苏州研究所楼栋都是欧式装修,门口有大坡,里面景色很不错。去苏州集结一般至少要去一周,甚至更久,多的人甚至一两个月都回不了家。
|
||||||
|
4. 诺亚曾经传说是研究型的,但是来了之后因为在四野做大模型项目,项目成员完全变成了交付型的,且充满了例会,评审,汇报。很多时候做实验都要申请。团队需要对接终端小艺,华为云,ICT等诸多业务线,交付压力不小。
|
||||||
|
5. 诺亚研发的盘古模型早期内部代号叫做“盘古智子”,一开始只有内部需要申请试用的网页版,到后续迫于压力在welink上接入和公测开放。
|
||||||
|
|
||||||
|
这些天发生关于质疑盘古大模型抄袭千问的事情闹的沸沸扬扬。作为一个盘古团队的成员,我最近夜夜辗转反侧,难以入眠。盘古的品牌受到如此大的影响,一方面,我自私的为我的职业发展担忧,也为自己过去的努力工作感到不值。另一方面,由于有人开始揭露这些事情我内心又感到大快人心。在多少个日日夜夜,我们对内部某些人一次次靠着造假而又获得了无数利益的行为咬牙切齿而又无能为力。这种压抑和羞辱也逐渐消磨了我对华为的感情,让我在这里的时日逐渐浑浑噩噩,迷茫无措,时常怀疑自己的人生和自我价值。
|
||||||
|
|
||||||
|
我承认我是一个懦弱的人,作为一个小小的打工人,我不仅不敢和王云鹤等内部手眼通天的人做对,更不敢和华为这样的庞然大物做对。我很怕失去我的工作,毕竟我也有家人和孩子,所以我打心眼里很佩服揭露者。但是,看到内部还在试图洗地掩盖事实,蒙蔽公众的时候,我实在不能容忍了。我也希望勇敢一次,顺从自己本心。就算自损八百,我也希望能伤敌一千。我决定把我在这里的所见所闻(部分来自于同事口述)公布出来,关于盘古大模型的“传奇故事”:
|
||||||
|
|
||||||
|
华为确实主要在昇腾卡上训练大模型(小模型实验室有不少英伟达的卡,他们之前也会用来训练,后面转移到昇腾)。曾经我被华为“打造世界第二选择”的决心而折服,我本身也曾经对华为有深厚的感情。我们陪着昇腾一步步摸爬滚打,从充满bug到现在能训出模型,付出了巨大的心血和代价。
|
||||||
|
|
||||||
|
最初我们的算力非常有限,在910A上训练模型。那会只支持fp16,训练的稳定性远不如bf16。盘古的moe开始很早,23年就主要是训练38Bmoe模型和后续的71B dense模型。71B的dense模型通过扩增变成了第一代的135Bdense模型,后面主力模型也逐渐在910B上训练。
|
||||||
|
|
||||||
|
71B和135B模型都有一个巨大的硬伤就是tokenizer。当时使用的tokenizer编码效率极低,每个单个的符号,数字,空格,乃至汉字都会占用一个token。可想而知这会非常浪费算力,且使得模型的效果很差。这时候小模型实验室正好有个自己训的词表。姚老师当时怀疑是不是模型的tokenizer不好(虽然事后来看,他的怀疑是无疑正确的),于是就决定,让71B和135B换tokenizer,因为小模型实验室曾经尝试过。团队缝合了两个tokenizer,开始了tokenizer的更换。71B模型的更换失败了,而135B因为采用了更精细的embedding初始化策略,续训了至少1T的数据后词表总算更换成功,但可想而知,效果并不会变好。
|
||||||
|
|
||||||
|
于此同期,阿里和智谱等国内其他公司在GPU上训练,且已经摸索出了正确的方法,盘古和竞品的差距越来越大。内部一个230B从头训练的dense模型又因为各种原因训练失败,导致项目的状况几乎陷入绝境。面临几个节点的压力以及内部对盘古的强烈质疑时,团队的士气低迷到了极点。团队在算力极其有限的时候,做出了很多努力和挣扎。比如,团队偶然发现当时的38B moe并没有预期moe的效果。于是去掉了moe参数,还原为了13B的dense模型。由于38B的moe源自很早的pangu alpha 13B,架构相对落后,团队进行了一系列的操作,比如切换绝对位置编码到rope,去掉bias,切换为rmsnorm。同时鉴于tokenizer的一些失败和换词表的经验,这个模型的词表也更换为了王云鹤的小模型实验室7B模型所使用的词表。后面这个13B模型进行了扩增续训,变成了第二代38B dense模型(在几个月内这个模型都是主要的盘古中档位模型),曾经具有一定的竞争力。但是,由于更大的135B模型架构落后,且更换词表模型损伤巨大(后续分析发现当时更换的缝合词表有更严重的bug),续训后也与千问等当时国内领先模型存在很大差距。这时由于内部的质疑声和领导的压力也越来越大。团队的状态几乎陷入了绝境。
|
||||||
|
|
||||||
|
在这种情况下,王云鹤和他的小模型实验室出手了。他们声称是从旧的135B参数继承改造而来,通过训练短短的几百B数据,各项指标平均提升了十个点左右。实际上,这就是他们套壳应用到大模型的第一次杰作。华为的外行领导内行,使得领导完全对于这种扯淡的事情没有概念,他们只会觉得肯定是有什么算法创新。经过内部的分析,他们实际上是使用Qwen 1.5 110B续训而来,通过加层,扩增ffn维度,添加盘古pi论文的一些机制得来,凑够了大概135B的参数。实际上,旧的135B有107层,而这个模型只有82层,各种配置也都不一样。新的来路不明的135B训练完很多参数的分布也和Qwen 110B几乎一模一样。连模型代码的类名当时都是Qwen,甚至懒得改名。后续这个模型就是所谓的135B V2。而这个模型当时也提供给了很多下游,甚至包括外部客户。
|
||||||
|
|
||||||
|
这件事对于我们这些认真诚实做事的同事们带来了巨大的冲击,内部很多人其实都知道这件事,甚至包括终端和华为云。我们都戏称以后别叫盘古模型了,叫千古吧。当时团队成员就想向bcg举报了,毕竟这已经是重大的业务造假了。但是后面据说被领导拦了下来,因为更高级别的领导(比如姚老师,以及可能熊总和查老)其实后面也知道了,但是并不管,因为通过套壳拿出好的结果,对他们也是有利的。这件事使得当时团队几位最强的同事开始心灰意冷,离职跑路也逐渐成为挂在嘴边的事。
|
||||||
|
|
||||||
|
此时,盘古似乎迎来了转机。由于前面所述的这些盘古模型基本都是续训和改造而来,当时诺亚完全没有掌握从头训练的技术,何况还是在昇腾的NPU上进行训练。在当时团队的核心成员的极力争取下,盘古开始了第三代模型的训练,付出了巨大的努力后,在数据架构和训练算法方面都与业界逐渐接轨,而这其中的艰辛和小模型实验室的人一点关系都没有。
|
||||||
|
|
||||||
|
一开始团队成员毫无信心,只从一个13B的模型开始训练,但是后面发现效果还不错,于是这个模型后续再次进行了一次参数扩增,变成了第三代的38B,代号38B V3。想必很多产品线的兄弟都对这个模型很熟悉。当时这个模型的tokenizer是基于llama的词表进行扩展的(也是业界常见的做法)。而当时王云鹤的实验室做出来了另一个词表(也就是后续pangu系列的词表)。当时两个词表还被迫进行了一次赛马,最终没有明显的好坏结论。于是,领导当即决定,应该统一词表,使用王云鹤他们的。于是,在后续从头训练的135B V3(也就是对外的Pangu Ultra),便是采用了这个tokenizer。这也解释了很多使用我们模型的兄弟的疑惑,为什么当时同为V3代的两个不同档位的模型,会使用不同的tokenizer。
|
||||||
|
|
||||||
|
|
||||||
|
我们打心眼里觉得,135B V3是我们四纵团队当时的骄傲。这是第一个真正意义上的,华为全栈自研,正经从头训练的千亿级别的模型,且效果与24年同期竞品可比的。写到这里我已经热泪盈眶,太不容易了。当时为了稳定训练,团队做了大量实验对比,并且多次在模型梯度出现异常的时候进行及时回退重启。这个模型真正做到了后面技术报告所说的训练全程没有一个loss spike。我们克服了不知道多少困难,我们做到了,我们愿用生命和荣誉保证这个模型训练的真实性。多少个凌晨,我们为了它的训练而不眠。在被内部心声骂的一文不值的时候,我们有多么不甘,有多少的委屈,我们挺住了。
|
||||||
|
|
||||||
|
我们这帮人是真的在为打磨国产算力底座燃烧自己的青春啊……客居他乡,我们放弃了家庭,放弃了假期,放弃了健康,放弃了娱乐,抛头颅洒热血,其中的艰辛与困苦,寥寥数笔不足以概括其万一。在各种动员大会上,当时口号中喊出的盘古必胜,华为必胜,我们心里是真的深深被感动。
|
||||||
|
|
||||||
|
然而,我们的所有辛苦的成果,经常被小模型实验室轻飘飘的拿走了。数据,直接要走。代码,直接要走,还要求我们配合适配到能一键运行。我们当时戏称小模型实验室为点鼠标实验室。我们付出辛苦,他们取得荣耀。果然应了那句话,你在负重前行是因为有人替你岁月静好。在这种情况下,越来越多的战友再也坚持不下去了,选择了离开。看到身边那些优秀的同事一个个离职,我的内心又感叹又难过。在这种作战一样的环境下,我们比起同事来说更像是战友。他们在技术上也有无数值得我学习的地方,堪称良师。看到他们去了诸如字节Seed,Deepseek,月之暗面,腾讯和快手等等很多出色的团队,我打心眼里为他们高兴和祝福,脱离了这个辛苦却肮脏的地方。我至今还对一位离职同事的话记忆犹新,ta说:“来这里是我技术生涯中的耻辱,在这里再呆每一天都是浪费生命”。话虽难听却让我无言以对。我担心我自己技术方面的积累不足,以及没法适应互联网公司高淘汰的环境,让我多次想离职的心始终没有迈出这一步。
|
||||||
|
|
||||||
|
盘古除了dense模型,后续也启动了moe的探索。一开始训练的是一个224B的moe模型。而与之平行的,小模型实验室也开启了第二次主要的套壳行动(次要的插曲可能还包括一些别的模型,比如math模型),即这次流传甚广的pangu pro moe 72B。这个模型内部自称是从小模型实验室的7B扩增上来的(就算如此,这也与技术报告不符,何况是套壳qwen 2.5的14b续训)。还记得他们训了没几天,内部的评测就立刻追上了当时的38B V3。AI系统实验室很多兄弟因为需要适配模型,都知道他们的套壳行动,只是迫于各种原因,无法伸张正义。实际上,对于后续训了很久很久的这个模型,Honestagi能够分析出这个量级的相似性我已经很诧异了,因为这个模型为了续训洗参数,所付出的算力甚至早就足够从头训一个同档位的模型了。听同事说他们为了洗掉千问的水印,采取了不少办法,甚至包括故意训了脏数据。这也为学术界研究模型血缘提供了一个前所未有的特殊模范吧。以后新的血缘方法提出可以拿出来溜溜。
|
||||||
|
|
||||||
|
24年底和25年初,在Deepseek v3和r1发布之后,由于其惊艳的技术水平,团队受到了巨大的冲击,也受到了更大的质疑。于是为了紧跟潮流,盘古模仿Deepseek的模型尺寸,开启了718B moe的训练。这个时候,小模型实验室再次出手了。他们选择了套壳Deepseekv3续训。他们通过冻住Deepseek加载的参数,进行训练。连任务加载ckpt的目录都是deepseekv3,改都不改,何其嚣张?与之相反,一些有真正技术信仰的同事,在从头训练另一个718B的moe。但其中出现了各种各样的问题。但是很显然,这个模型怎么可能比直接套壳的好呢?如果不是团队leader坚持,早就被叫停了。
|
||||||
|
|
||||||
|
华为的流程管理之繁重,严重拖累了大模型的研发节奏,例如版本管理,模型血缘,各种流程化,各种可追溯。讽刺的是,小模型实验室的模型似乎从来不受这些流程的约束,想套壳就套壳,想续训就续训,算力源源不断的伸手拿走。这种强烈到近乎魔幻的对比,说明了当前流程管理的情况:只许州官放火,不许百姓点灯。何其可笑?何其可悲?何其可恶?何其可耻!
|
||||||
|
|
||||||
|
HonestAGI的事情出来后,内部让大家不停的研讨分析,如何公关和“回应”。诚然,这个原文的分析也许不够有力,给了王云鹤与小模型实验室他们狡辩和颠倒黑白的机会。为此,这两天我内心感到作呕,时时怀疑自己的人生意义以及苍天无眼。我不奉陪了,我要离职了,同时我也在申请从盘古部分技术报告的作者名单中移除。曾经在这些技术报告上署名是我一生都无法抹除的污点。当时我没想到,他们竟然猖狂到敢开源。我没想到,他们敢如此愚弄世人,大肆宣发。当时,我也许是存了侥幸心理,没有拒绝署名。我相信很多扎实做事的战友,也只是被迫上了贼船,或者不知情。但这件事已经无法挽回,我希望我的余生能够坚持扎实做真正有意义的事,为我当时的软弱和不坚定赎罪。
|
||||||
|
|
||||||
|
深夜写到这里,我已经泪流满面,泣不成声。还记得一些出色的同事离职时,我苦笑问他们要不要发个长长的心声惯例帖,揭露一下现状。对方说:不了,浪费时间,而且我也怕揭露出来你们过的更糟。我当时一下黯然神伤,因为曾经共同为了理想奋斗过的战友已经彻底对华为彻底灰心了。当时大家调侃,我们用着当年共产党的小米加步枪,组织却有着堪比当年国民党的作风。
|
||||||
|
|
||||||
|
曾几何时,我为我们用着小米加步枪打败洋枪洋炮而自豪。
|
||||||
|
|
||||||
|
现在,我累了,我想投降。
|
||||||
|
|
||||||
|
其实时至今日,我还是真心希望华为能认真吸取教训,能做好盘古,把盘古做到世界一流,把昇腾变成英伟达的水平。内部的劣币驱逐良币,使得诺亚乃至华为在短时间内急剧流失了大量出色的大模型人才。相信他们也正在如Deepseek等各个团队闪耀着,施展着他们的抱负才华,为中美在AI的激烈竞赛中奉献力量。我时常感叹,华为不是没有人才,而是根本不知道怎么留住人才。如果给这些人合适的环境,合适的资源,更少的枷锁,更少的政治斗争,盘古何愁不成?
|
||||||
|
|
||||||
|
最后:我以生命,人格和荣誉发誓,我写的以上所有内容均为真实(至少在我有限的认知范围内)。我没有那么高的技术水平以及机会去做详尽扎实的分析,也不敢直接用内部记录举证,怕因为信息安全抓到。但是我相信我很多曾经的战友,会为我作证。在华为内部的兄弟,包括我们曾经服务过的产品线兄弟们,相信本文的无数细节能和你们的印象对照,印证我的说法。你们可能也曾经被蒙骗,但这些残酷的真相不会被尘封。我们奋战过的痕迹,也不应该被扭曲和埋葬。
|
||||||
|
|
||||||
|
写了这么多,某些人肯定想把我找出来,抹杀掉。公司搞不好也想让我噤声乃至追责。如果真的这样,我,乃至我的家人的人身乃至生命安全可能都会受到威胁。为了自我保护,我近期每天会跟大家报平安。
|
||||||
|
|
||||||
|
如果我消失了,就当是我为了真理和理想,为了华为乃至中国能够更好地发展算力和AI而牺牲了吧,我愿埋葬于那片曾经奋斗过的地方。
|
||||||
|
|
||||||
|
诺亚,再见
|
||||||
|
|
||||||
|
2025年7月6日凌晨 写于深圳
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
各位好,
|
||||||
|
|
||||||
|
感谢大家的关心与祝福。我目前暂时安全,但公司应该在进行排查与某些名单收集,后续情况未知。
|
||||||
|
|
||||||
|
我补充一些细节,以免某些人继续颠倒黑白。
|
||||||
|
|
||||||
|
关于135B V2,小模型实验室在迅速地完成套壳并拿完所有套壳带来的好处后(比如任务令表彰和及时激励),因为不想继续支撑下游应用和模型迭代,又把这个烫手山芋甩给了四纵。确实技高一筹,直接把四纵的兄弟们拉下水。同事提供过去一个老旧的模型,最终拿回了一个当时一个魔改的先进的千问。做大模型的人,自己做的模型就像自己孩子一样熟悉,不要把别人都当傻子。就像自家儿子出门一趟,回来个别人家孩子。
|
||||||
|
|
||||||
|
盘古report的署名是不符合学术规范的。例如,135B V3有不少有技术贡献的人,因为作者名额数量限制,劳动成果没有得到应有的回报,团队内曾经有不小的意见。这个模型当时是大家智慧和汗水的结晶,甚至是团队当时的精神支柱,支撑着不少兄弟们继续留在诺亚。所谓的名额限制,以及挂名了一些毫无技术贡献的人(如一些小模型实验室的人),让兄弟们何其心寒。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
暂时平安。另外,支持我勇于说出真相的战友们 https://github.com/HW-whistleblower/True-Story-of-Pangu/issues/317
|
||||||
458
demo.ipynb
458
demo.ipynb
@@ -1,362 +1,116 @@
|
|||||||
{
|
{
|
||||||
"cells": [
|
"cells": [
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "markdown",
|
||||||
"execution_count": 1,
|
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Initializing leann-backend-diskann...\n",
|
|
||||||
"INFO: Registering backend 'diskann'\n",
|
|
||||||
"INFO: DiskANN backend loaded successfully\n",
|
|
||||||
"INFO: LeannBuilder initialized with 'diskann' backend.\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"/home/ubuntu/LEANN_clean/leann/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
|
||||||
" from .autonotebook import tqdm as notebook_tqdm\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"INFO: Computing embeddings for 6 chunks using 'sentence-transformers/all-mpnet-base-v2'...\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Batches: 100%|██████████| 1/1 [00:00<00:00, 2.91it/s]\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"INFO: Building DiskANN index for 6 vectors with metric Metric.INNER_PRODUCT...\n",
|
|
||||||
"Using Inner Product search, so need to pre-process base data into temp file. Please ensure there is additional (n*(d+1)*4) bytes for storing pre-processed base vectors, apart from the interim indices created by DiskANN and the final index.\n",
|
|
||||||
"Pre-processing base file by adding extra coordinate\n",
|
|
||||||
"✅ DiskANN index built successfully at 'knowledge'\n",
|
|
||||||
"Writing bin: knowledge_disk.index_max_base_norm.bin\n",
|
|
||||||
"bin: #pts = 1, #dims = 1, size = 12B\n",
|
|
||||||
"Finished writing bin.\n",
|
|
||||||
"Time for preprocessing data for inner product: 0.000172 seconds\n",
|
|
||||||
"Reading max_norm_of_base from knowledge_disk.index_max_base_norm.bin\n",
|
|
||||||
"Reading bin file knowledge_disk.index_max_base_norm.bin ...\n",
|
|
||||||
"Opening bin file knowledge_disk.index_max_base_norm.bin... \n",
|
|
||||||
"Metadata: #pts = 1, #dims = 1...\n",
|
|
||||||
"done.\n",
|
|
||||||
"max_norm_of_base: 1\n",
|
|
||||||
"! Using prepped_base file at knowledge_prepped_base.bin\n",
|
|
||||||
"Starting index build: R=32 L=64 Query RAM budget: 4.02653e+09 Indexing ram budget: 8 T: 8\n",
|
|
||||||
"getting bin metadata\n",
|
|
||||||
"Time for getting bin metadata: 0.000019 seconds\n",
|
|
||||||
"Compressing 769-dimensional data into 512 bytes per vector.\n",
|
|
||||||
"Opened: knowledge_prepped_base.bin, size: 18464, cache_size: 18464\n",
|
|
||||||
"Training data with 6 samples loaded.\n",
|
|
||||||
"Reading bin file knowledge_pq_pivots.bin ...\n",
|
|
||||||
"Opening bin file knowledge_pq_pivots.bin... \n",
|
|
||||||
"Metadata: #pts = 256, #dims = 769...\n",
|
|
||||||
"done.\n",
|
|
||||||
"PQ pivot file exists. Not generating again\n",
|
|
||||||
"Opened: knowledge_prepped_base.bin, size: 18464, cache_size: 18464\n",
|
|
||||||
"Reading bin file knowledge_pq_pivots.bin ...\n",
|
|
||||||
"Opening bin file knowledge_pq_pivots.bin... \n",
|
|
||||||
"Metadata: #pts = 4, #dims = 1...\n",
|
|
||||||
"done.\n",
|
|
||||||
"Reading bin file knowledge_pq_pivots.bin ...\n",
|
|
||||||
"Opening bin file knowledge_pq_pivots.bin... \n",
|
|
||||||
"Metadata: #pts = 256, #dims = 769...\n",
|
|
||||||
"done.\n",
|
|
||||||
"Reading bin file knowledge_pq_pivots.bin ...\n",
|
|
||||||
"Opening bin file knowledge_pq_pivots.bin... \n",
|
|
||||||
"Metadata: #pts = 769, #dims = 1...\n",
|
|
||||||
"done.\n",
|
|
||||||
"Reading bin file knowledge_pq_pivots.bin ...\n",
|
|
||||||
"Opening bin file knowledge_pq_pivots.bin... \n",
|
|
||||||
"Metadata: #pts = 513, #dims = 1...\n",
|
|
||||||
"done.\n",
|
|
||||||
"Loaded PQ pivot information\n",
|
|
||||||
"Processing points [0, 6)...done.\n",
|
|
||||||
"Time for generating quantized data: 0.055587 seconds\n",
|
|
||||||
"Full index fits in RAM budget, should consume at most 2.03973e-05GiBs, so building in one shot\n",
|
|
||||||
"L2: Using AVX2 distance computation DistanceL2Float\n",
|
|
||||||
"Passed, empty search_params while creating index config\n",
|
|
||||||
"Using only first 6 from file.. \n",
|
|
||||||
"Starting index build with 6 points... \n",
|
|
||||||
"0% of index build completed.Starting final cleanup..done. Link time: 0.00011s\n",
|
|
||||||
"Index built with degree: max:5 avg:5 min:5 count(deg<2):0\n",
|
|
||||||
"Not saving tags as they are not enabled.\n",
|
|
||||||
"Time taken for save: 0.000148s.\n",
|
|
||||||
"Time for building merged vamana index: 0.000836 seconds\n",
|
|
||||||
"Opened: knowledge_prepped_base.bin, size: 18464, cache_size: 18464\n",
|
|
||||||
"Vamana index file size=168\n",
|
|
||||||
"Opened: knowledge_disk.index, cache_size: 67108864\n",
|
|
||||||
"medoid: 0B\n",
|
|
||||||
"max_node_len: 3100B\n",
|
|
||||||
"nnodes_per_sector: 1B\n",
|
|
||||||
"# sectors: 6\n",
|
|
||||||
"Sector #0written\n",
|
|
||||||
"Finished writing 28672B\n",
|
|
||||||
"Writing bin: knowledge_disk.index\n",
|
|
||||||
"bin: #pts = 9, #dims = 1, size = 80B\n",
|
|
||||||
"Finished writing bin.\n",
|
|
||||||
"Output disk index file written to knowledge_disk.index\n",
|
|
||||||
"Finished writing 28672B\n",
|
|
||||||
"Time for generating disk layout: 0.040268 seconds\n",
|
|
||||||
"Opened: knowledge_prepped_base.bin, size: 18464, cache_size: 18464\n",
|
|
||||||
"Loading base knowledge_prepped_base.bin. #points: 6. #dim: 769.\n",
|
|
||||||
"Wrote 1 points to sample file: knowledge_sample_data.bin\n",
|
|
||||||
"Indexing time: 0.0970594\n",
|
|
||||||
"INFO: Leann metadata saved to knowledge.leann.meta.json\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Opened file : knowledge_disk.index\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"✅ DiskANN index loaded successfully.\n",
|
|
||||||
"INFO: LeannSearcher initialized with 'diskann' backend using index 'knowledge.leann'.\n",
|
|
||||||
"Since data is floating point, we assume that it has been appropriately pre-processed (normalization for cosine, and convert-to-l2 by adding extra dimension for MIPS). So we shall invoke an l2 distance function.\n",
|
|
||||||
"L2: Using AVX2 distance computation DistanceL2Float\n",
|
|
||||||
"L2: Using AVX2 distance computation DistanceL2Float\n",
|
|
||||||
"Before index load\n",
|
|
||||||
"Reading bin file knowledge_pq_compressed.bin ...\n",
|
|
||||||
"Opening bin file knowledge_pq_compressed.bin... \n",
|
|
||||||
"Metadata: #pts = 6, #dims = 512...\n",
|
|
||||||
"done.\n",
|
|
||||||
"Reading bin file knowledge_pq_pivots.bin ...\n",
|
|
||||||
"Opening bin file knowledge_pq_pivots.bin... \n",
|
|
||||||
"Metadata: #pts = 4, #dims = 1...\n",
|
|
||||||
"done.\n",
|
|
||||||
"Offsets: 4096 791560 794644 796704\n",
|
|
||||||
"Reading bin file knowledge_pq_pivots.bin ...\n",
|
|
||||||
"Opening bin file knowledge_pq_pivots.bin... \n",
|
|
||||||
"Metadata: #pts = 256, #dims = 769...\n",
|
|
||||||
"done.\n",
|
|
||||||
"Reading bin file knowledge_pq_pivots.bin ...\n",
|
|
||||||
"Opening bin file knowledge_pq_pivots.bin... \n",
|
|
||||||
"Metadata: #pts = 769, #dims = 1...\n",
|
|
||||||
"done.\n",
|
|
||||||
"Reading bin file knowledge_pq_pivots.bin ...\n",
|
|
||||||
"Opening bin file knowledge_pq_pivots.bin... \n",
|
|
||||||
"Metadata: #pts = 513, #dims = 1...\n",
|
|
||||||
"done.\n",
|
|
||||||
"Loaded PQ Pivots: #ctrs: 256, #dims: 769, #chunks: 512\n",
|
|
||||||
"Loaded PQ centroids and in-memory compressed vectors. #points: 6 #dim: 769 #aligned_dim: 776 #chunks: 512\n",
|
|
||||||
"Loading index metadata from knowledge_disk.index\n",
|
|
||||||
"Disk-Index File Meta-data: # nodes per sector: 1, max node len (bytes): 3100, max node degree: 5\n",
|
|
||||||
"Disk-Index Meta: nodes per sector: 1, max node len: 3100, max node degree: 5\n",
|
|
||||||
"Setting up thread-specific contexts for nthreads: 8\n",
|
|
||||||
"allocating ctx: 0x7a33f7204000 to thread-id:134367072315200\n",
|
|
||||||
"allocating ctx: 0x7a33f6805000 to thread-id:134355206802368\n",
|
|
||||||
"allocating ctx: 0x7a33f5e72000 to thread-id:134355217288000\n",
|
|
||||||
"allocating ctx: 0x7a33f5e61000 to thread-id:134355227773632\n",
|
|
||||||
"allocating ctx: 0x7a33f5e50000 to thread-id:134355196316736\n",
|
|
||||||
"allocating ctx: 0x7a33f5e3f000 to thread-id:134355164859840\n",
|
|
||||||
"allocating ctx: 0x7a33f5e2e000 to thread-id:134355175345472\n",
|
|
||||||
"allocating ctx: 0x7a33f5e1d000 to thread-id:134355185831104\n",
|
|
||||||
"Loading centroid data from medoids vector data of 1 medoid(s)\n",
|
|
||||||
"Reading bin file knowledge_disk.index_max_base_norm.bin ...\n",
|
|
||||||
"Opening bin file knowledge_disk.index_max_base_norm.bin... \n",
|
|
||||||
"Metadata: #pts = 1, #dims = 1...\n",
|
|
||||||
"done.\n",
|
|
||||||
"Setting re-scaling factor of base vectors to 1\n",
|
|
||||||
"load_from_separate_paths done.\n",
|
|
||||||
"Reading (with alignment) bin file knowledge_sample_data.bin ...Metadata: #pts = 1, #dims = 769, aligned_dim = 776... allocating aligned memory of 3104 bytes... done. Copying data to mem_aligned buffer... done.\n",
|
|
||||||
"reserve ratio: 1\n",
|
|
||||||
"Graph traversal completed, hops: 3\n",
|
|
||||||
"Loading the cache list into memory....done.\n",
|
|
||||||
"After index load\n",
|
|
||||||
"INFO: Computing embeddings for 1 chunks using 'sentence-transformers/all-mpnet-base-v2'...\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Batches: 100%|██████████| 1/1 [00:00<00:00, 60.54it/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"INFO: DiskANN ZMQ mode enabled - ensuring embedding server is running\n",
|
|
||||||
"INFO: Starting session-level embedding server as a background process...\n",
|
|
||||||
"INFO: Running command from project root: /home/ubuntu/LEANN_clean/leann\n",
|
|
||||||
"INFO: Server process started with PID: 424761\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"✅ Embedding server is up and ready for this session.\n",
|
|
||||||
"[EmbeddingServer LOG]: Initializing leann-backend-diskann...\n",
|
|
||||||
"[EmbeddingServer LOG]: WARNING: Could not import DiskANN backend: cannot import name '_diskannpy' from partially initialized module 'packages.leann-backend-diskann.leann_backend_diskann' (most likely due to a circular import) (/home/ubuntu/LEANN_clean/leann/packages/leann-backend-diskann/leann_backend_diskann/__init__.py)\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: Initializing embedding server thread on port 5555\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: Using CUDA device\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: Loading model sentence-transformers/all-mpnet-base-v2\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: Using FP16 precision with model: sentence-transformers/all-mpnet-base-v2\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: Loaded 6 demo documents\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: ZMQ ROUTER server listening on port 5555\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: Embedding server ready to serve requests\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: Received ZMQ request from client 006b8b45, size 3 bytes\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: Request for 1 node embeddings: [0]\n",
|
|
||||||
"[EmbeddingServer LOG]: DEBUG: Node ID range: 0 to 0\n",
|
|
||||||
"[EmbeddingServer LOG]: Time taken for text lookup: 0.000028 seconds\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: Total batch size: 1, max_batch_size: 128\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: Processing batch of size 1\n",
|
|
||||||
"[EmbeddingServer LOG]: Time taken for tokenization (batch): 0.019294 seconds\n",
|
|
||||||
"[EmbeddingServer LOG]: Batch size: 1, Sequence length: 256\n",
|
|
||||||
"[EmbeddingServer LOG]: Time taken for transfer to device (batch): 0.000210 seconds\n",
|
|
||||||
"[EmbeddingServer LOG]: Time taken for embedding (batch): 3.065444 seconds\n",
|
|
||||||
"[EmbeddingServer LOG]: Time taken for mean pooling (batch): 0.041810 seconds\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: Serialize time: 0.000194 seconds\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: ZMQ E2E time: 3.128073 seconds\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: Received ZMQ request from client 006b8b45, size 7 bytes\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: Request for 5 node embeddings: [1, 2, 3, 4, 5]\n",
|
|
||||||
"[EmbeddingServer LOG]: DEBUG: Node ID range: 1 to 5\n",
|
|
||||||
"[EmbeddingServer LOG]: Time taken for text lookup: 0.000042 seconds\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: Total batch size: 5, max_batch_size: 128\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: Processing batch of size 5\n",
|
|
||||||
"[EmbeddingServer LOG]: Time taken for tokenization (batch): 0.001791 seconds\n",
|
|
||||||
"[EmbeddingServer LOG]: Batch size: 5, Sequence length: 256\n",
|
|
||||||
"[EmbeddingServer LOG]: Time taken for transfer to device (batch): 0.000112 seconds\n",
|
|
||||||
"[EmbeddingServer LOG]: Time taken for embedding (batch): 3.674183 seconds\n",
|
|
||||||
"[EmbeddingServer LOG]: Time taken for mean pooling (batch): 0.000372 seconds\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: Serialize time: 0.000177 seconds\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: ZMQ E2E time: 3.677425 seconds\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: Received ZMQ request from client 006b8b45, size 7 bytes\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: Request for 5 node embeddings: [3, 4, 2, 1, 0]\n",
|
|
||||||
"[EmbeddingServer LOG]: DEBUG: Node ID range: 0 to 4\n",
|
|
||||||
"[EmbeddingServer LOG]: Time taken for text lookup: 0.000030 seconds\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: Total batch size: 5, max_batch_size: 128\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: Processing batch of size 5\n",
|
|
||||||
"[EmbeddingServer LOG]: Time taken for tokenization (batch): 0.001550 seconds\n",
|
|
||||||
"[EmbeddingServer LOG]: Batch size: 5, Sequence length: 256\n",
|
|
||||||
"[EmbeddingServer LOG]: Time taken for transfer to device (batch): 0.000097 seconds\n",
|
|
||||||
"[EmbeddingServer LOG]: Time taken for embedding (batch): 0.009335 seconds\n",
|
|
||||||
"[EmbeddingServer LOG]: Time taken for mean pooling (batch): 0.000154 seconds\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: Serialize time: 0.000073 seconds\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: ZMQ E2E time: 0.011773 seconds\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: Received ZMQ request from client 006b8b45, size 7 bytes\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: Request for 5 node embeddings: [0, 1, 2, 4, 5]\n",
|
|
||||||
"[EmbeddingServer LOG]: DEBUG: Node ID range: 0 to 5\n",
|
|
||||||
"[EmbeddingServer LOG]: Time taken for text lookup: 0.000020 seconds\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: Total batch size: 5, max_batch_size: 128\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: Processing batch of size 5\n",
|
|
||||||
"[EmbeddingServer LOG]: Time taken for tokenization (batch): 0.001041 seconds\n",
|
|
||||||
"[EmbeddingServer LOG]: Batch size: 5, Sequence length: 256\n",
|
|
||||||
"[EmbeddingServer LOG]: Time taken for transfer to device (batch): 0.000125 seconds\n",
|
|
||||||
"[EmbeddingServer LOG]: Time taken for embedding (batch): 0.008972 seconds\n",
|
|
||||||
"[EmbeddingServer LOG]: Time taken for mean pooling (batch): 0.000151 seconds\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: Serialize time: 0.000048 seconds\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: ZMQ E2E time: 0.010853 seconds\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: Received ZMQ request from client 006b8b45, size 7 bytes\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: Request for 5 node embeddings: [3, 1, 0, 2, 5]\n",
|
|
||||||
"[EmbeddingServer LOG]: DEBUG: Node ID range: 0 to 5\n",
|
|
||||||
"[EmbeddingServer LOG]: Time taken for text lookup: 0.000020 seconds\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: Total batch size: 5, max_batch_size: 128\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: Processing batch of size 5\n",
|
|
||||||
"[EmbeddingServer LOG]: Time taken for tokenization (batch): 0.001350 seconds\n",
|
|
||||||
"[EmbeddingServer LOG]: Batch size: 5, Sequence length: 256\n",
|
|
||||||
"[EmbeddingServer LOG]: Time taken for transfer to device (batch): 0.000088 seconds\n",
|
|
||||||
"[EmbeddingServer LOG]: Time taken for embedding (batch): 0.008869 seconds\n",
|
|
||||||
"[EmbeddingServer LOG]: Time taken for mean pooling (batch): 0.000146 seconds\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: Serialize time: 0.000063 seconds\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: ZMQ E2E time: 0.011054 seconds\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: Received ZMQ request from client 006b8b45, size 7 bytes\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: Request for 5 node embeddings: [0, 2, 3, 4, 5]\n",
|
|
||||||
"[EmbeddingServer LOG]: DEBUG: Node ID range: 0 to 5\n",
|
|
||||||
"[EmbeddingServer LOG]: Time taken for text lookup: 0.000022 seconds\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: Total batch size: 5, max_batch_size: 128\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: Processing batch of size 5\n",
|
|
||||||
"[EmbeddingServer LOG]: Time taken for tokenization (batch): 0.001195 seconds\n",
|
|
||||||
"[EmbeddingServer LOG]: Batch size: 5, Sequence length: 256\n",
|
|
||||||
"[EmbeddingServer LOG]: Time taken for transfer to device (batch): 0.000087 seconds\n",
|
|
||||||
"[EmbeddingServer LOG]: Time taken for embedding (batch): 0.008903 seconds\n",
|
|
||||||
"[EmbeddingServer LOG]: Time taken for mean pooling (batch): 0.000145 seconds\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: Serialize time: 0.000060 seconds\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: ZMQ E2E time: 0.010921 seconds\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: Received ZMQ request from client 006b8b45, size 7 bytes\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: Request for 5 node embeddings: [1, 0, 3, 4, 5]\n",
|
|
||||||
"[EmbeddingServer LOG]: DEBUG: Node ID range: 0 to 5\n",
|
|
||||||
"[EmbeddingServer LOG]: Time taken for text lookup: 0.000020 seconds\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: Total batch size: 5, max_batch_size: 128\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: Processing batch of size 5\n",
|
|
||||||
"[EmbeddingServer LOG]: Time taken for tokenization (batch): 0.001188 seconds\n",
|
|
||||||
"[EmbeddingServer LOG]: Batch size: 5, Sequence length: 256\n",
|
|
||||||
"[EmbeddingServer LOG]: Time taken for transfer to device (batch): 0.000087 seconds\n",
|
|
||||||
"[EmbeddingServer LOG]: Time taken for embedding (batch): 0.008858 seconds\n",
|
|
||||||
"[EmbeddingServer LOG]: Time taken for mean pooling (batch): 0.000153 seconds\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: Serialize time: 0.000052 seconds\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: ZMQ E2E time: 0.010886 seconds\n",
|
|
||||||
"reserve ratio: Score: -0.481 - C++ is a powerful programming language1\n",
|
|
||||||
"Graph traversal completed, hops: 3\n",
|
|
||||||
"\n",
|
|
||||||
"Score: -1.049 - Java is a powerful programming language\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"[EmbeddingServer LOG]: INFO: ZMQ socket timeout, continuing to listen\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: ZMQ socket timeout, continuing to listen\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: ZMQ socket timeout, continuing to listen\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: ZMQ socket timeout, continuing to listen\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: ZMQ socket timeout, continuing to listen\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: ZMQ socket timeout, continuing to listen\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: ZMQ socket timeout, continuing to listen\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: ZMQ socket timeout, continuing to listen\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: ZMQ socket timeout, continuing to listen\n",
|
|
||||||
"[EmbeddingServer LOG]: INFO: ZMQ socket timeout, continuing to listen\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
"source": [
|
||||||
"from leann.api import LeannBuilder, LeannSearcher\n",
|
"# Quick Start \n",
|
||||||
"import leann_backend_diskann\n",
|
"\n",
|
||||||
"# 1. Build index (no embeddings stored!)\n",
|
"**Home GitHub Repository:** [LEANN on GitHub](https://github.com/yichuan-w/LEANN)\n",
|
||||||
"builder = LeannBuilder(backend_name=\"diskann\")\n",
|
"\n",
|
||||||
"builder.add_text(\"Python is a powerful programming language\")\n",
|
"**Important for Colab users:** Set your runtime type to T4 GPU for optimal performance. Go to Runtime → Change runtime type → Hardware accelerator → T4 GPU."
|
||||||
"builder.add_text(\"Machine learning transforms industries\") \n",
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# install this if you are using colab\n",
|
||||||
|
"! uv pip install leann-core leann-backend-hnsw --no-deps\n",
|
||||||
|
"! uv pip install leann --no-deps\n",
|
||||||
|
"# For Colab environment, we need to set some environment variables\n",
|
||||||
|
"import os\n",
|
||||||
|
"\n",
|
||||||
|
"os.environ[\"LEANN_LOG_LEVEL\"] = \"INFO\" # Enable more detailed logging"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from pathlib import Path\n",
|
||||||
|
"\n",
|
||||||
|
"INDEX_DIR = Path(\"./\").resolve()\n",
|
||||||
|
"INDEX_PATH = str(INDEX_DIR / \"demo.leann\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Build the index"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from leann.api import LeannBuilder\n",
|
||||||
|
"\n",
|
||||||
|
"builder = LeannBuilder(backend_name=\"hnsw\")\n",
|
||||||
|
"builder.add_text(\"C# is a powerful programming language and it is good at game development\")\n",
|
||||||
|
"builder.add_text(\n",
|
||||||
|
" \"Python is a powerful programming language and it is good at machine learning tasks\"\n",
|
||||||
|
")\n",
|
||||||
|
"builder.add_text(\"Machine learning transforms industries\")\n",
|
||||||
"builder.add_text(\"Neural networks process complex data\")\n",
|
"builder.add_text(\"Neural networks process complex data\")\n",
|
||||||
"builder.add_text(\"Java is a powerful programming language\")\n",
|
"builder.add_text(\"Leann is a great storage saving engine for RAG on your MacBook\")\n",
|
||||||
"builder.add_text(\"C++ is a powerful programming language\")\n",
|
"builder.build_index(INDEX_PATH)"
|
||||||
"builder.add_text(\"C# is a powerful programming language\")\n",
|
]
|
||||||
"builder.build_index(\"knowledge.leann\")\n",
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Search with real-time embeddings"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from leann.api import LeannSearcher\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# 2. Search with real-time embeddings\n",
|
"searcher = LeannSearcher(INDEX_PATH)\n",
|
||||||
"searcher = LeannSearcher(\"knowledge.leann\")\n",
|
"results = searcher.search(\"programming languages\", top_k=2)\n",
|
||||||
"results = searcher.search(\"C++ programming languages\", top_k=2,recompute_beighbor_embeddings=True)\n",
|
"results"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Chat with LEANN using retrieved results"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from leann.api import LeannChat\n",
|
||||||
"\n",
|
"\n",
|
||||||
"for result in results:\n",
|
"llm_config = {\n",
|
||||||
" print(f\"Score: {result['score']:.3f} - {result['text']}\")"
|
" \"type\": \"hf\",\n",
|
||||||
|
" \"model\": \"Qwen/Qwen3-0.6B\",\n",
|
||||||
|
"}\n",
|
||||||
|
"\n",
|
||||||
|
"chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)\n",
|
||||||
|
"response = chat.ask(\n",
|
||||||
|
" \"Compare the two retrieved programming languages and tell me their advantages.\",\n",
|
||||||
|
" top_k=2,\n",
|
||||||
|
" llm_kwargs={\"max_tokens\": 128},\n",
|
||||||
|
")\n",
|
||||||
|
"response"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@@ -376,7 +130,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.11.11"
|
"version": "3.11.12"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
|||||||
220
docs/CONTRIBUTING.md
Normal file
220
docs/CONTRIBUTING.md
Normal file
@@ -0,0 +1,220 @@
|
|||||||
|
# 🤝 Contributing
|
||||||
|
|
||||||
|
We welcome contributions! Leann is built by the community, for the community.
|
||||||
|
|
||||||
|
## Ways to Contribute
|
||||||
|
|
||||||
|
- 🐛 **Bug Reports**: Found an issue? Let us know!
|
||||||
|
- 💡 **Feature Requests**: Have an idea? We'd love to hear it!
|
||||||
|
- 🔧 **Code Contributions**: PRs welcome for all skill levels
|
||||||
|
- 📖 **Documentation**: Help make Leann more accessible
|
||||||
|
- 🧪 **Benchmarks**: Share your performance results
|
||||||
|
|
||||||
|
## 🚀 Development Setup
|
||||||
|
|
||||||
|
### Prerequisites
|
||||||
|
|
||||||
|
1. **Install uv** (fast Python package installer):
|
||||||
|
```bash
|
||||||
|
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Clone the repository**:
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/LEANN-RAG/LEANN-RAG.git
|
||||||
|
cd LEANN-RAG
|
||||||
|
```
|
||||||
|
|
||||||
|
3. **Install system dependencies**:
|
||||||
|
|
||||||
|
**macOS:**
|
||||||
|
```bash
|
||||||
|
brew install llvm libomp boost protobuf zeromq pkgconf
|
||||||
|
```
|
||||||
|
|
||||||
|
**Ubuntu/Debian:**
|
||||||
|
```bash
|
||||||
|
sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler \
|
||||||
|
libabsl-dev libmkl-full-dev libaio-dev libzmq3-dev
|
||||||
|
```
|
||||||
|
|
||||||
|
4. **Build from source**:
|
||||||
|
```bash
|
||||||
|
# macOS
|
||||||
|
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv sync
|
||||||
|
|
||||||
|
# Ubuntu/Debian
|
||||||
|
uv sync
|
||||||
|
```
|
||||||
|
|
||||||
|
## 🔨 Pre-commit Hooks
|
||||||
|
|
||||||
|
We use pre-commit hooks to ensure code quality and consistency. This runs automatically before each commit.
|
||||||
|
|
||||||
|
### Setup Pre-commit
|
||||||
|
|
||||||
|
1. **Install pre-commit** (already included when you run `uv sync`):
|
||||||
|
```bash
|
||||||
|
uv pip install pre-commit
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Install the git hooks**:
|
||||||
|
```bash
|
||||||
|
pre-commit install
|
||||||
|
```
|
||||||
|
|
||||||
|
3. **Run pre-commit manually** (optional):
|
||||||
|
```bash
|
||||||
|
pre-commit run --all-files
|
||||||
|
```
|
||||||
|
|
||||||
|
### Pre-commit Checks
|
||||||
|
|
||||||
|
Our pre-commit configuration includes:
|
||||||
|
- **Trailing whitespace removal**
|
||||||
|
- **End-of-file fixing**
|
||||||
|
- **YAML validation**
|
||||||
|
- **Large file prevention**
|
||||||
|
- **Merge conflict detection**
|
||||||
|
- **Debug statement detection**
|
||||||
|
- **Code formatting with ruff**
|
||||||
|
- **Code linting with ruff**
|
||||||
|
|
||||||
|
## 🧪 Testing
|
||||||
|
|
||||||
|
### Running Tests
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Run all tests
|
||||||
|
uv run pytest
|
||||||
|
|
||||||
|
# Run specific test file
|
||||||
|
uv run pytest test/test_filename.py
|
||||||
|
|
||||||
|
# Run with coverage
|
||||||
|
uv run pytest --cov=leann
|
||||||
|
```
|
||||||
|
|
||||||
|
### Writing Tests
|
||||||
|
|
||||||
|
- Place tests in the `test/` directory
|
||||||
|
- Follow the naming convention `test_*.py`
|
||||||
|
- Use descriptive test names that explain what's being tested
|
||||||
|
- Include both positive and negative test cases
|
||||||
|
|
||||||
|
## 📝 Code Style
|
||||||
|
|
||||||
|
We use `ruff` for both linting and formatting to ensure consistent code style.
|
||||||
|
|
||||||
|
### Format Your Code
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Format all files
|
||||||
|
ruff format
|
||||||
|
|
||||||
|
# Check formatting without changing files
|
||||||
|
ruff format --check
|
||||||
|
```
|
||||||
|
|
||||||
|
### Lint Your Code
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Run linter with auto-fix
|
||||||
|
ruff check --fix
|
||||||
|
|
||||||
|
# Just check without fixing
|
||||||
|
ruff check
|
||||||
|
```
|
||||||
|
|
||||||
|
### Style Guidelines
|
||||||
|
|
||||||
|
- Follow PEP 8 conventions
|
||||||
|
- Use descriptive variable names
|
||||||
|
- Add type hints where appropriate
|
||||||
|
- Write docstrings for all public functions and classes
|
||||||
|
- Keep functions focused and single-purpose
|
||||||
|
|
||||||
|
## 🚦 CI/CD
|
||||||
|
|
||||||
|
Our CI pipeline runs automatically on all pull requests. It includes:
|
||||||
|
|
||||||
|
1. **Linting and Formatting**: Ensures code follows our style guidelines
|
||||||
|
2. **Multi-platform builds**: Tests on Ubuntu and macOS
|
||||||
|
3. **Python version matrix**: Tests on Python 3.9-3.13
|
||||||
|
4. **Wheel building**: Ensures packages can be built and distributed
|
||||||
|
|
||||||
|
### CI Commands
|
||||||
|
|
||||||
|
The CI uses the same commands as pre-commit to ensure consistency:
|
||||||
|
```bash
|
||||||
|
# Linting
|
||||||
|
ruff check .
|
||||||
|
|
||||||
|
# Format checking
|
||||||
|
ruff format --check .
|
||||||
|
```
|
||||||
|
|
||||||
|
Make sure your code passes these checks locally before pushing!
|
||||||
|
|
||||||
|
## 🔄 Pull Request Process
|
||||||
|
|
||||||
|
1. **Fork the repository** and create your branch from `main`:
|
||||||
|
```bash
|
||||||
|
git checkout -b feature/your-feature-name
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Make your changes**:
|
||||||
|
- Write clean, documented code
|
||||||
|
- Add tests for new functionality
|
||||||
|
- Update documentation as needed
|
||||||
|
|
||||||
|
3. **Run pre-commit checks**:
|
||||||
|
```bash
|
||||||
|
pre-commit run --all-files
|
||||||
|
```
|
||||||
|
|
||||||
|
4. **Test your changes**:
|
||||||
|
```bash
|
||||||
|
uv run pytest
|
||||||
|
```
|
||||||
|
|
||||||
|
5. **Commit with descriptive messages**:
|
||||||
|
```bash
|
||||||
|
git commit -m "feat: add new search algorithm"
|
||||||
|
```
|
||||||
|
|
||||||
|
Follow [Conventional Commits](https://www.conventionalcommits.org/):
|
||||||
|
- `feat:` for new features
|
||||||
|
- `fix:` for bug fixes
|
||||||
|
- `docs:` for documentation changes
|
||||||
|
- `test:` for test additions/changes
|
||||||
|
- `refactor:` for code refactoring
|
||||||
|
- `perf:` for performance improvements
|
||||||
|
|
||||||
|
6. **Push and create a pull request**:
|
||||||
|
- Provide a clear description of your changes
|
||||||
|
- Reference any related issues
|
||||||
|
- Include examples or screenshots if applicable
|
||||||
|
|
||||||
|
## 📚 Documentation
|
||||||
|
|
||||||
|
When adding new features or making significant changes:
|
||||||
|
|
||||||
|
1. Update relevant documentation in `/docs`
|
||||||
|
2. Add docstrings to new functions/classes
|
||||||
|
3. Update README.md if needed
|
||||||
|
4. Include usage examples
|
||||||
|
|
||||||
|
## 🤔 Getting Help
|
||||||
|
|
||||||
|
- **Discord**: Join our community for discussions
|
||||||
|
- **Issues**: Check existing issues or create a new one
|
||||||
|
- **Discussions**: For general questions and ideas
|
||||||
|
|
||||||
|
## 📄 License
|
||||||
|
|
||||||
|
By contributing, you agree that your contributions will be licensed under the same license as the project (MIT).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
Thank you for contributing to LEANN! Every contribution, no matter how small, helps make the project better for everyone. 🌟
|
||||||
22
docs/RELEASE.md
Normal file
22
docs/RELEASE.md
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
# Release Guide
|
||||||
|
|
||||||
|
## Setup (One-time)
|
||||||
|
|
||||||
|
Add `PYPI_API_TOKEN` to GitHub Secrets:
|
||||||
|
1. Get token: https://pypi.org/manage/account/token/
|
||||||
|
2. Add to secrets: Settings → Secrets → Actions → `PYPI_API_TOKEN`
|
||||||
|
|
||||||
|
## Release (One-click)
|
||||||
|
|
||||||
|
1. Go to: https://github.com/yichuan-w/LEANN/actions/workflows/release-manual.yml
|
||||||
|
2. Click "Run workflow"
|
||||||
|
3. Enter version: `0.1.2`
|
||||||
|
4. Click green "Run workflow" button
|
||||||
|
|
||||||
|
That's it! The workflow will automatically:
|
||||||
|
- ✅ Update version in all packages
|
||||||
|
- ✅ Build all packages
|
||||||
|
- ✅ Publish to PyPI
|
||||||
|
- ✅ Create GitHub tag and release
|
||||||
|
|
||||||
|
Check progress: https://github.com/yichuan-w/LEANN/actions
|
||||||
98
docs/code/embedding_model_compare.py
Normal file
98
docs/code/embedding_model_compare.py
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
"""
|
||||||
|
Comparison between Sentence Transformers and OpenAI embeddings
|
||||||
|
|
||||||
|
This example shows how different embedding models handle complex queries
|
||||||
|
and demonstrates the differences between local and API-based embeddings.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from leann.embedding_compute import compute_embeddings
|
||||||
|
|
||||||
|
# OpenAI API key should be set as environment variable
|
||||||
|
# export OPENAI_API_KEY="your-api-key-here"
|
||||||
|
|
||||||
|
# Test data
|
||||||
|
conference_text = "[Title]: COLING 2025 Conference\n[URL]: https://coling2025.org/"
|
||||||
|
browser_text = "[Title]: Browser Use Tool\n[URL]: https://github.com/browser-use"
|
||||||
|
|
||||||
|
# Two queries with same intent but different wording
|
||||||
|
query1 = "Tell me my browser history about some conference i often visit"
|
||||||
|
query2 = "browser history about conference I often visit"
|
||||||
|
|
||||||
|
texts = [query1, query2, conference_text, browser_text]
|
||||||
|
|
||||||
|
|
||||||
|
def cosine_similarity(a, b):
|
||||||
|
return np.dot(a, b) # Already normalized
|
||||||
|
|
||||||
|
|
||||||
|
def analyze_embeddings(embeddings, model_name):
|
||||||
|
print(f"\n=== {model_name} Results ===")
|
||||||
|
|
||||||
|
# Results for Query 1
|
||||||
|
sim1_conf = cosine_similarity(embeddings[0], embeddings[2])
|
||||||
|
sim1_browser = cosine_similarity(embeddings[0], embeddings[3])
|
||||||
|
|
||||||
|
print(f"Query 1: '{query1}'")
|
||||||
|
print(f" → Conference similarity: {sim1_conf:.4f} {'✓' if sim1_conf > sim1_browser else ''}")
|
||||||
|
print(
|
||||||
|
f" → Browser similarity: {sim1_browser:.4f} {'✓' if sim1_browser > sim1_conf else ''}"
|
||||||
|
)
|
||||||
|
print(f" Winner: {'Conference' if sim1_conf > sim1_browser else 'Browser'}")
|
||||||
|
|
||||||
|
# Results for Query 2
|
||||||
|
sim2_conf = cosine_similarity(embeddings[1], embeddings[2])
|
||||||
|
sim2_browser = cosine_similarity(embeddings[1], embeddings[3])
|
||||||
|
|
||||||
|
print(f"\nQuery 2: '{query2}'")
|
||||||
|
print(f" → Conference similarity: {sim2_conf:.4f} {'✓' if sim2_conf > sim2_browser else ''}")
|
||||||
|
print(
|
||||||
|
f" → Browser similarity: {sim2_browser:.4f} {'✓' if sim2_browser > sim2_conf else ''}"
|
||||||
|
)
|
||||||
|
print(f" Winner: {'Conference' if sim2_conf > sim2_browser else 'Browser'}")
|
||||||
|
|
||||||
|
# Show the impact
|
||||||
|
print("\n=== Impact Analysis ===")
|
||||||
|
print(f"Conference similarity change: {sim2_conf - sim1_conf:+.4f}")
|
||||||
|
print(f"Browser similarity change: {sim2_browser - sim1_browser:+.4f}")
|
||||||
|
|
||||||
|
if sim1_conf > sim1_browser and sim2_browser > sim2_conf:
|
||||||
|
print("❌ FLIP: Adding 'browser history' flips winner from Conference to Browser!")
|
||||||
|
elif sim1_conf > sim1_browser and sim2_conf > sim2_browser:
|
||||||
|
print("✅ STABLE: Conference remains winner in both queries")
|
||||||
|
elif sim1_browser > sim1_conf and sim2_browser > sim2_conf:
|
||||||
|
print("✅ STABLE: Browser remains winner in both queries")
|
||||||
|
else:
|
||||||
|
print("🔄 MIXED: Results vary between queries")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"query1_conf": sim1_conf,
|
||||||
|
"query1_browser": sim1_browser,
|
||||||
|
"query2_conf": sim2_conf,
|
||||||
|
"query2_browser": sim2_browser,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Test Sentence Transformers
|
||||||
|
print("Testing Sentence Transformers (facebook/contriever)...")
|
||||||
|
try:
|
||||||
|
st_embeddings = compute_embeddings(texts, "facebook/contriever", mode="sentence-transformers")
|
||||||
|
st_results = analyze_embeddings(st_embeddings, "Sentence Transformers (facebook/contriever)")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Sentence Transformers failed: {e}")
|
||||||
|
st_results = None
|
||||||
|
|
||||||
|
# Test OpenAI
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("Testing OpenAI (text-embedding-3-small)...")
|
||||||
|
try:
|
||||||
|
openai_embeddings = compute_embeddings(texts, "text-embedding-3-small", mode="openai")
|
||||||
|
openai_results = analyze_embeddings(openai_embeddings, "OpenAI (text-embedding-3-small)")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ OpenAI failed: {e}")
|
||||||
|
openai_results = None
|
||||||
|
|
||||||
|
# Compare results
|
||||||
|
if st_results and openai_results:
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("=== COMPARISON SUMMARY ===")
|
||||||
10
docs/faq.md
Normal file
10
docs/faq.md
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
# FAQ
|
||||||
|
|
||||||
|
## 1. My building time seems long
|
||||||
|
|
||||||
|
You can speed up the process by using a lightweight embedding model. Add this to your arguments:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
--embedding-model sentence-transformers/all-MiniLM-L6-v2
|
||||||
|
```
|
||||||
|
**Model sizes:** `all-MiniLM-L6-v2` (30M parameters), `facebook/contriever` (~100M parameters), `Qwen3-0.6B` (600M parameters)
|
||||||
22
docs/features.md
Normal file
22
docs/features.md
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
# ✨ Detailed Features
|
||||||
|
|
||||||
|
## 🔥 Core Features
|
||||||
|
|
||||||
|
- **🔄 Real-time Embeddings** - Eliminate heavy embedding storage with dynamic computation using optimized ZMQ servers and highly optimized search paradigm (overlapping and batching) with highly optimized embedding engine
|
||||||
|
- **📈 Scalable Architecture** - Handles millions of documents on consumer hardware; the larger your dataset, the more LEANN can save
|
||||||
|
- **🎯 Graph Pruning** - Advanced techniques to minimize the storage overhead of vector search to a limited footprint
|
||||||
|
- **🏗️ Pluggable Backends** - DiskANN, HNSW/FAISS with unified API
|
||||||
|
|
||||||
|
## 🛠️ Technical Highlights
|
||||||
|
- **🔄 Recompute Mode** - Highest accuracy scenarios while eliminating vector storage overhead
|
||||||
|
- **⚡ Zero-copy Operations** - Minimize IPC overhead by transferring distances instead of embeddings
|
||||||
|
- **🚀 High-throughput Embedding Pipeline** - Optimized batched processing for maximum efficiency
|
||||||
|
- **🎯 Two-level Search** - Novel coarse-to-fine search overlap for accelerated query processing (optional)
|
||||||
|
- **💾 Memory-mapped Indices** - Fast startup with raw text mapping to reduce memory overhead
|
||||||
|
- **🚀 MLX Support** - Ultra-fast recompute/build with quantized embedding models, accelerating building and search ([minimal example](../examples/mlx_demo.py))
|
||||||
|
|
||||||
|
## 🎨 Developer Experience
|
||||||
|
|
||||||
|
- **Simple Python API** - Get started in minutes
|
||||||
|
- **Extensible backend system** - Easy to add new algorithms
|
||||||
|
- **Comprehensive examples** - From basic usage to production deployment
|
||||||
75
docs/normalized_embeddings.md
Normal file
75
docs/normalized_embeddings.md
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
# Normalized Embeddings Support in LEANN
|
||||||
|
|
||||||
|
LEANN now automatically detects normalized embedding models and sets the appropriate distance metric for optimal performance.
|
||||||
|
|
||||||
|
## What are Normalized Embeddings?
|
||||||
|
|
||||||
|
Normalized embeddings are vectors with L2 norm = 1 (unit vectors). These embeddings are optimized for cosine similarity rather than Maximum Inner Product Search (MIPS).
|
||||||
|
|
||||||
|
## Automatic Detection
|
||||||
|
|
||||||
|
When you create a `LeannBuilder` instance with a normalized embedding model, LEANN will:
|
||||||
|
|
||||||
|
1. **Automatically set `distance_metric="cosine"`** if not specified
|
||||||
|
2. **Show a warning** if you manually specify a different distance metric
|
||||||
|
3. **Provide optimal search performance** with the correct metric
|
||||||
|
|
||||||
|
## Supported Normalized Embedding Models
|
||||||
|
|
||||||
|
### OpenAI
|
||||||
|
All OpenAI text embedding models are normalized:
|
||||||
|
- `text-embedding-ada-002`
|
||||||
|
- `text-embedding-3-small`
|
||||||
|
- `text-embedding-3-large`
|
||||||
|
|
||||||
|
### Voyage AI
|
||||||
|
All Voyage AI embedding models are normalized:
|
||||||
|
- `voyage-2`
|
||||||
|
- `voyage-3`
|
||||||
|
- `voyage-large-2`
|
||||||
|
- `voyage-multilingual-2`
|
||||||
|
- `voyage-code-2`
|
||||||
|
|
||||||
|
### Cohere
|
||||||
|
All Cohere embedding models are normalized:
|
||||||
|
- `embed-english-v3.0`
|
||||||
|
- `embed-multilingual-v3.0`
|
||||||
|
- `embed-english-light-v3.0`
|
||||||
|
- `embed-multilingual-light-v3.0`
|
||||||
|
|
||||||
|
## Example Usage
|
||||||
|
|
||||||
|
```python
|
||||||
|
from leann.api import LeannBuilder
|
||||||
|
|
||||||
|
# Automatic detection - will use cosine distance
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model="text-embedding-3-small",
|
||||||
|
embedding_mode="openai"
|
||||||
|
)
|
||||||
|
# Warning: Detected normalized embeddings model 'text-embedding-3-small'...
|
||||||
|
# Automatically setting distance_metric='cosine'
|
||||||
|
|
||||||
|
# Manual override (not recommended)
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model="text-embedding-3-small",
|
||||||
|
embedding_mode="openai",
|
||||||
|
distance_metric="mips" # Will show warning
|
||||||
|
)
|
||||||
|
# Warning: Using 'mips' distance metric with normalized embeddings...
|
||||||
|
```
|
||||||
|
|
||||||
|
## Non-Normalized Embeddings
|
||||||
|
|
||||||
|
Models like `facebook/contriever` and other sentence-transformers models that are not normalized will continue to use MIPS by default, which is optimal for them.
|
||||||
|
|
||||||
|
## Why This Matters
|
||||||
|
|
||||||
|
Using the wrong distance metric with normalized embeddings can lead to:
|
||||||
|
- **Poor search quality** due to HNSW's early termination with narrow score ranges
|
||||||
|
- **Incorrect ranking** of search results
|
||||||
|
- **Suboptimal performance** compared to using the correct metric
|
||||||
|
|
||||||
|
For more details on why this happens, see our analysis in the [embedding detection code](../packages/leann-core/src/leann/api.py) which automatically handles normalized embeddings and MIPS distance metric issues.
|
||||||
21
docs/roadmap.md
Normal file
21
docs/roadmap.md
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
# 📈 Roadmap
|
||||||
|
|
||||||
|
## 🎯 Q2 2025
|
||||||
|
|
||||||
|
- [X] DiskANN backend with MIPS/L2/Cosine support
|
||||||
|
- [X] HNSW backend integration
|
||||||
|
- [X] Real-time embedding pipeline
|
||||||
|
- [X] Memory-efficient graph pruning
|
||||||
|
|
||||||
|
## 🚀 Q3 2025
|
||||||
|
|
||||||
|
- [ ] Advanced caching strategies
|
||||||
|
- [ ] Add contextual-retrieval https://www.anthropic.com/news/contextual-retrieval
|
||||||
|
- [ ] Add sleep-time-compute and summarize agent! to summarilze the file on computer!
|
||||||
|
- [ ] Add OpenAI recompute API
|
||||||
|
|
||||||
|
## 🌟 Q4 2025
|
||||||
|
|
||||||
|
- [ ] Integration with LangChain/LlamaIndex
|
||||||
|
- [ ] Visual similarity search
|
||||||
|
- [ ] Query rewrtiting, rerank and expansion
|
||||||
@@ -1,16 +1,23 @@
|
|||||||
"""
|
"""
|
||||||
Simple demo showing basic leann usage
|
Simple demo showing basic leann usage
|
||||||
Run: uv run python examples/simple_demo.py
|
Run: uv run python examples/basic_demo.py
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
from leann import LeannBuilder, LeannSearcher, LeannChat
|
|
||||||
|
from leann import LeannBuilder, LeannChat, LeannSearcher
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(description="Simple demo of Leann with selectable embedding models.")
|
parser = argparse.ArgumentParser(
|
||||||
parser.add_argument("--embedding_model", type=str, default="sentence-transformers/all-mpnet-base-v2",
|
description="Simple demo of Leann with selectable embedding models."
|
||||||
help="The embedding model to use, e.g., 'sentence-transformers/all-mpnet-base-v2' or 'text-embedding-ada-002'.")
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--embedding_model",
|
||||||
|
type=str,
|
||||||
|
default="sentence-transformers/all-mpnet-base-v2",
|
||||||
|
help="The embedding model to use, e.g., 'sentence-transformers/all-mpnet-base-v2' or 'text-embedding-ada-002'.",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
print(f"=== Leann Simple Demo with {args.embedding_model} ===")
|
print(f"=== Leann Simple Demo with {args.embedding_model} ===")
|
||||||
@@ -74,7 +81,7 @@ def main():
|
|||||||
print()
|
print()
|
||||||
|
|
||||||
print("Demo completed! Try running:")
|
print("Demo completed! Try running:")
|
||||||
print(" uv run python examples/document_search.py")
|
print(" uv run python apps/document_rag.py")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
Binary file not shown.
@@ -1,146 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Document search demo with recompute mode
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
import shutil
|
|
||||||
import time
|
|
||||||
|
|
||||||
# Import backend packages to trigger plugin registration
|
|
||||||
try:
|
|
||||||
import leann_backend_diskann
|
|
||||||
import leann_backend_hnsw
|
|
||||||
print("INFO: Backend packages imported successfully.")
|
|
||||||
except ImportError as e:
|
|
||||||
print(f"WARNING: Could not import backend packages. Error: {e}")
|
|
||||||
|
|
||||||
# Import upper-level API from leann-core
|
|
||||||
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
|
||||||
|
|
||||||
|
|
||||||
def load_sample_documents():
|
|
||||||
"""Create sample documents for demonstration"""
|
|
||||||
docs = [
|
|
||||||
{"title": "Intro to Python", "content": "Python is a high-level, interpreted language known for simplicity."},
|
|
||||||
{"title": "ML Basics", "content": "Machine learning builds systems that learn from data."},
|
|
||||||
{"title": "Data Structures", "content": "Data structures like arrays, lists, and graphs organize data."},
|
|
||||||
]
|
|
||||||
return docs
|
|
||||||
|
|
||||||
def main():
|
|
||||||
print("==========================================================")
|
|
||||||
print("=== Leann Document Search Demo (DiskANN + Recompute) ===")
|
|
||||||
print("==========================================================")
|
|
||||||
|
|
||||||
INDEX_DIR = Path("./test_indices")
|
|
||||||
INDEX_PATH = str(INDEX_DIR / "documents.diskann")
|
|
||||||
BACKEND_TO_TEST = "diskann"
|
|
||||||
|
|
||||||
if INDEX_DIR.exists():
|
|
||||||
print(f"--- Cleaning up old index directory: {INDEX_DIR} ---")
|
|
||||||
shutil.rmtree(INDEX_DIR)
|
|
||||||
|
|
||||||
# --- 1. Build index ---
|
|
||||||
print(f"\n[PHASE 1] Building index using '{BACKEND_TO_TEST}' backend...")
|
|
||||||
|
|
||||||
builder = LeannBuilder(
|
|
||||||
backend_name=BACKEND_TO_TEST,
|
|
||||||
graph_degree=32,
|
|
||||||
complexity=64
|
|
||||||
)
|
|
||||||
|
|
||||||
documents = load_sample_documents()
|
|
||||||
print(f"Loaded {len(documents)} sample documents.")
|
|
||||||
for doc in documents:
|
|
||||||
builder.add_text(doc["content"], metadata={"title": doc["title"]})
|
|
||||||
|
|
||||||
builder.build_index(INDEX_PATH)
|
|
||||||
print(f"\nIndex built!")
|
|
||||||
|
|
||||||
# --- 2. Basic search demo ---
|
|
||||||
print(f"\n[PHASE 2] Basic search using '{BACKEND_TO_TEST}' backend...")
|
|
||||||
searcher = LeannSearcher(index_path=INDEX_PATH)
|
|
||||||
|
|
||||||
query = "What is machine learning?"
|
|
||||||
print(f"\nQuery: '{query}'")
|
|
||||||
|
|
||||||
print("\n--- Basic search mode (PQ computation) ---")
|
|
||||||
start_time = time.time()
|
|
||||||
results = searcher.search(query, top_k=2)
|
|
||||||
basic_time = time.time() - start_time
|
|
||||||
|
|
||||||
print(f"⏱️ Basic search time: {basic_time:.3f} seconds")
|
|
||||||
print(">>> Basic search results <<<")
|
|
||||||
for i, res in enumerate(results, 1):
|
|
||||||
print(f" {i}. ID: {res['id']}, Score: {res['score']:.4f}, Text: '{res['text']}', Metadata: {res['metadata']}")
|
|
||||||
|
|
||||||
# --- 3. Recompute search demo ---
|
|
||||||
print(f"\n[PHASE 3] Recompute search using embedding server...")
|
|
||||||
|
|
||||||
print("\n--- Recompute search mode (get real embeddings via network) ---")
|
|
||||||
|
|
||||||
# Configure recompute parameters
|
|
||||||
recompute_params = {
|
|
||||||
"recompute_beighbor_embeddings": True, # Enable network recomputation
|
|
||||||
"USE_DEFERRED_FETCH": False, # Don't use deferred fetch
|
|
||||||
"skip_search_reorder": True, # Skip search reordering
|
|
||||||
"dedup_node_dis": True, # Enable node distance deduplication
|
|
||||||
"prune_ratio": 0.1, # Pruning ratio 10%
|
|
||||||
"batch_recompute": False, # Don't use batch recomputation
|
|
||||||
"global_pruning": False, # Don't use global pruning
|
|
||||||
"zmq_port": 5555, # ZMQ port
|
|
||||||
"embedding_model": "sentence-transformers/all-mpnet-base-v2"
|
|
||||||
}
|
|
||||||
|
|
||||||
print("Recompute parameter configuration:")
|
|
||||||
for key, value in recompute_params.items():
|
|
||||||
print(f" {key}: {value}")
|
|
||||||
|
|
||||||
print(f"\n🔄 Executing Recompute search...")
|
|
||||||
try:
|
|
||||||
start_time = time.time()
|
|
||||||
recompute_results = searcher.search(query, top_k=2, **recompute_params)
|
|
||||||
recompute_time = time.time() - start_time
|
|
||||||
|
|
||||||
print(f"⏱️ Recompute search time: {recompute_time:.3f} seconds")
|
|
||||||
print(">>> Recompute search results <<<")
|
|
||||||
for i, res in enumerate(recompute_results, 1):
|
|
||||||
print(f" {i}. ID: {res['id']}, Score: {res['score']:.4f}, Text: '{res['text']}', Metadata: {res['metadata']}")
|
|
||||||
|
|
||||||
# Compare results
|
|
||||||
print(f"\n--- Result comparison ---")
|
|
||||||
print(f"Basic search time: {basic_time:.3f} seconds")
|
|
||||||
print(f"Recompute time: {recompute_time:.3f} seconds")
|
|
||||||
|
|
||||||
print("\nBasic search vs Recompute results:")
|
|
||||||
for i in range(min(len(results), len(recompute_results))):
|
|
||||||
basic_score = results[i]['score']
|
|
||||||
recompute_score = recompute_results[i]['score']
|
|
||||||
score_diff = abs(basic_score - recompute_score)
|
|
||||||
print(f" Position {i+1}: PQ={basic_score:.4f}, Recompute={recompute_score:.4f}, Difference={score_diff:.4f}")
|
|
||||||
|
|
||||||
if recompute_time > basic_time:
|
|
||||||
print(f"✅ Recompute mode working correctly (more accurate but slower)")
|
|
||||||
else:
|
|
||||||
print(f"ℹ️ Recompute time is unusually fast, network recomputation may not be enabled")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ Recompute search failed: {e}")
|
|
||||||
print("This usually indicates an embedding server connection issue")
|
|
||||||
|
|
||||||
# --- 4. Chat demo ---
|
|
||||||
print(f"\n[PHASE 4] Starting chat session...")
|
|
||||||
chat = LeannChat(index_path=INDEX_PATH)
|
|
||||||
chat_response = chat.ask(query)
|
|
||||||
print(f"You: {query}")
|
|
||||||
print(f"Leann: {chat_response}")
|
|
||||||
|
|
||||||
print("\n==========================================================")
|
|
||||||
print("✅ Demo finished successfully!")
|
|
||||||
print("==========================================================")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,81 +0,0 @@
|
|||||||
import faulthandler
|
|
||||||
faulthandler.enable()
|
|
||||||
|
|
||||||
from llama_index.core import SimpleDirectoryReader, Settings
|
|
||||||
from llama_index.core.readers.base import BaseReader
|
|
||||||
from llama_index.node_parser.docling import DoclingNodeParser
|
|
||||||
from llama_index.readers.docling import DoclingReader
|
|
||||||
from docling_core.transforms.chunker.hybrid_chunker import HybridChunker
|
|
||||||
import asyncio
|
|
||||||
import os
|
|
||||||
import dotenv
|
|
||||||
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
|
||||||
import leann_backend_hnsw # Import to ensure backend registration
|
|
||||||
import shutil
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
dotenv.load_dotenv()
|
|
||||||
|
|
||||||
reader = DoclingReader(export_type=DoclingReader.ExportType.JSON)
|
|
||||||
file_extractor: dict[str, BaseReader] = {
|
|
||||||
".docx": reader,
|
|
||||||
".pptx": reader,
|
|
||||||
".pdf": reader,
|
|
||||||
".xlsx": reader,
|
|
||||||
}
|
|
||||||
node_parser = DoclingNodeParser(
|
|
||||||
chunker=HybridChunker(tokenizer="Qwen/Qwen3-Embedding-4B", max_tokens=64)
|
|
||||||
)
|
|
||||||
print("Loading documents...")
|
|
||||||
documents = SimpleDirectoryReader(
|
|
||||||
"examples/data",
|
|
||||||
recursive=True,
|
|
||||||
file_extractor=file_extractor,
|
|
||||||
encoding="utf-8",
|
|
||||||
required_exts=[".pdf", ".docx", ".pptx", ".xlsx"]
|
|
||||||
).load_data(show_progress=True)
|
|
||||||
print("Documents loaded.")
|
|
||||||
all_texts = []
|
|
||||||
for doc in documents:
|
|
||||||
nodes = node_parser.get_nodes_from_documents([doc])
|
|
||||||
for node in nodes:
|
|
||||||
all_texts.append(node.text)
|
|
||||||
|
|
||||||
INDEX_DIR = Path("./test_pdf_index")
|
|
||||||
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
|
|
||||||
|
|
||||||
if not INDEX_DIR.exists():
|
|
||||||
print(f"--- Index directory not found, building new index ---")
|
|
||||||
|
|
||||||
print(f"\n[PHASE 1] Building Leann index...")
|
|
||||||
|
|
||||||
# CSR compact mode with recompute
|
|
||||||
builder = LeannBuilder(
|
|
||||||
backend_name="hnsw",
|
|
||||||
embedding_model="facebook/contriever",
|
|
||||||
graph_degree=32,
|
|
||||||
complexity=64,
|
|
||||||
is_compact=True,
|
|
||||||
is_recompute=True
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"Loaded {len(all_texts)} text chunks from documents.")
|
|
||||||
for chunk_text in all_texts:
|
|
||||||
builder.add_text(chunk_text)
|
|
||||||
|
|
||||||
builder.build_index(INDEX_PATH)
|
|
||||||
print(f"\nLeann index built at {INDEX_PATH}!")
|
|
||||||
else:
|
|
||||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
|
||||||
|
|
||||||
async def main():
|
|
||||||
print(f"\n[PHASE 2] Starting Leann chat session...")
|
|
||||||
chat = LeannChat(index_path=INDEX_PATH)
|
|
||||||
|
|
||||||
query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
|
|
||||||
print(f"You: {query}")
|
|
||||||
chat_response = chat.ask(query, top_k=20, recompute_beighbor_embeddings=True,embedding_model="facebook/contriever")
|
|
||||||
print(f"Leann: {chat_response}")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
asyncio.run(main())
|
|
||||||
43
examples/mlx_demo.py
Normal file
43
examples/mlx_demo.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
from leann.api import LeannBuilder, LeannChat
|
||||||
|
|
||||||
|
# Define the path for our new MLX-based index
|
||||||
|
INDEX_PATH = "./mlx_diskann_index/leann"
|
||||||
|
|
||||||
|
if os.path.exists(INDEX_PATH + ".meta.json"):
|
||||||
|
print(f"Index already exists at {INDEX_PATH}. Skipping build.")
|
||||||
|
else:
|
||||||
|
print("Initializing LeannBuilder with MLX support...")
|
||||||
|
# 1. Configure LeannBuilder to use MLX
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model="mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ",
|
||||||
|
embedding_mode="mlx",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. Add documents
|
||||||
|
print("Adding documents...")
|
||||||
|
docs = [
|
||||||
|
"MLX is an array framework for machine learning on Apple silicon.",
|
||||||
|
"It was designed by Apple's machine learning research team.",
|
||||||
|
"The mlx-community organization provides pre-trained models in MLX format.",
|
||||||
|
"It supports operations on multi-dimensional arrays.",
|
||||||
|
"Leann can now use MLX for its embedding models.",
|
||||||
|
]
|
||||||
|
for doc in docs:
|
||||||
|
builder.add_text(doc)
|
||||||
|
|
||||||
|
# 3. Build the index
|
||||||
|
print(f"Building the MLX-based index at: {INDEX_PATH}")
|
||||||
|
builder.build_index(INDEX_PATH)
|
||||||
|
print("\nSuccessfully built the index with MLX embeddings!")
|
||||||
|
print(f"Check the metadata file: {INDEX_PATH}.meta.json")
|
||||||
|
|
||||||
|
|
||||||
|
chat = LeannChat(index_path=INDEX_PATH)
|
||||||
|
# add query
|
||||||
|
query = "MLX is an array framework for machine learning on Apple silicon."
|
||||||
|
print(f"Query: {query}")
|
||||||
|
response = chat.ask(query, top_k=3, recompute_beighbor_embeddings=True, complexity=3, beam_width=1)
|
||||||
|
print(f"Response: {response}")
|
||||||
@@ -1,32 +0,0 @@
|
|||||||
{
|
|
||||||
"version": "0.1.0",
|
|
||||||
"backend_name": "diskann",
|
|
||||||
"embedding_model": "sentence-transformers/all-mpnet-base-v2",
|
|
||||||
"num_chunks": 6,
|
|
||||||
"chunks": [
|
|
||||||
{
|
|
||||||
"text": "Python is a powerful programming language",
|
|
||||||
"metadata": {}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"text": "Machine learning transforms industries",
|
|
||||||
"metadata": {}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"text": "Neural networks process complex data",
|
|
||||||
"metadata": {}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"text": "Java is a powerful programming language",
|
|
||||||
"metadata": {}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"text": "C++ is a powerful programming language",
|
|
||||||
"metadata": {}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"text": "C# is a powerful programming language",
|
|
||||||
"metadata": {}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
0
packages/__init__.py
Normal file
0
packages/__init__.py
Normal file
@@ -1,8 +1,8 @@
|
|||||||
# packages/leann-backend-diskann/CMakeLists.txt (最终简化版)
|
# packages/leann-backend-diskann/CMakeLists.txt (simplified version)
|
||||||
|
|
||||||
cmake_minimum_required(VERSION 3.20)
|
cmake_minimum_required(VERSION 3.20)
|
||||||
project(leann_backend_diskann_wrapper)
|
project(leann_backend_diskann_wrapper)
|
||||||
|
|
||||||
# 告诉 CMake 直接进入 DiskANN 子模块并执行它自己的 CMakeLists.txt
|
# Tell CMake to directly enter the DiskANN submodule and execute its own CMakeLists.txt
|
||||||
# DiskANN 会自己处理所有事情,包括编译 Python 绑定
|
# DiskANN will handle everything itself, including compiling Python bindings
|
||||||
add_subdirectory(src/third_party/DiskANN)
|
add_subdirectory(src/third_party/DiskANN)
|
||||||
|
|||||||
1
packages/leann-backend-diskann/__init__.py
Normal file
1
packages/leann-backend-diskann/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
# This file makes the directory a Python package
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
from . import diskann_backend as diskann_backend
|
||||||
|
|||||||
@@ -1,30 +1,70 @@
|
|||||||
import numpy as np
|
|
||||||
import os
|
|
||||||
import json
|
|
||||||
import struct
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Dict
|
|
||||||
import contextlib
|
import contextlib
|
||||||
import threading
|
import logging
|
||||||
import time
|
import os
|
||||||
import atexit
|
import struct
|
||||||
import socket
|
|
||||||
import subprocess
|
|
||||||
import sys
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
from leann.registry import register_backend
|
import numpy as np
|
||||||
from leann.interface import (
|
from leann.interface import (
|
||||||
LeannBackendFactoryInterface,
|
|
||||||
LeannBackendBuilderInterface,
|
LeannBackendBuilderInterface,
|
||||||
LeannBackendSearcherInterface
|
LeannBackendFactoryInterface,
|
||||||
|
LeannBackendSearcherInterface,
|
||||||
)
|
)
|
||||||
from . import _diskannpy as diskannpy
|
from leann.registry import register_backend
|
||||||
|
from leann.searcher_base import BaseSearcher
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def suppress_cpp_output_if_needed():
|
||||||
|
"""Suppress C++ stdout/stderr based on LEANN_LOG_LEVEL"""
|
||||||
|
log_level = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||||
|
|
||||||
|
# Only suppress if log level is WARNING or higher (ERROR, CRITICAL)
|
||||||
|
should_suppress = log_level in ["WARNING", "ERROR", "CRITICAL"]
|
||||||
|
|
||||||
|
if not should_suppress:
|
||||||
|
# Don't suppress, just yield
|
||||||
|
yield
|
||||||
|
return
|
||||||
|
|
||||||
|
# Save original file descriptors
|
||||||
|
stdout_fd = sys.stdout.fileno()
|
||||||
|
stderr_fd = sys.stderr.fileno()
|
||||||
|
|
||||||
|
# Save original stdout/stderr
|
||||||
|
stdout_dup = os.dup(stdout_fd)
|
||||||
|
stderr_dup = os.dup(stderr_fd)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Redirect to /dev/null
|
||||||
|
devnull = os.open(os.devnull, os.O_WRONLY)
|
||||||
|
os.dup2(devnull, stdout_fd)
|
||||||
|
os.dup2(devnull, stderr_fd)
|
||||||
|
os.close(devnull)
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Restore original file descriptors
|
||||||
|
os.dup2(stdout_dup, stdout_fd)
|
||||||
|
os.dup2(stderr_dup, stderr_fd)
|
||||||
|
os.close(stdout_dup)
|
||||||
|
os.close(stderr_dup)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_diskann_metrics():
|
||||||
|
from . import _diskannpy as diskannpy # type: ignore
|
||||||
|
|
||||||
|
return {
|
||||||
|
"mips": diskannpy.Metric.INNER_PRODUCT,
|
||||||
|
"l2": diskannpy.Metric.L2,
|
||||||
|
"cosine": diskannpy.Metric.COSINE,
|
||||||
|
}
|
||||||
|
|
||||||
METRIC_MAP = {
|
|
||||||
"mips": diskannpy.Metric.INNER_PRODUCT,
|
|
||||||
"l2": diskannpy.Metric.L2,
|
|
||||||
"cosine": diskannpy.Metric.COSINE,
|
|
||||||
}
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def chdir(path):
|
def chdir(path):
|
||||||
@@ -35,102 +75,14 @@ def chdir(path):
|
|||||||
finally:
|
finally:
|
||||||
os.chdir(original_dir)
|
os.chdir(original_dir)
|
||||||
|
|
||||||
def _write_vectors_to_bin(data: np.ndarray, file_path: str):
|
|
||||||
|
def _write_vectors_to_bin(data: np.ndarray, file_path: Path):
|
||||||
num_vectors, dim = data.shape
|
num_vectors, dim = data.shape
|
||||||
with open(file_path, 'wb') as f:
|
with open(file_path, "wb") as f:
|
||||||
f.write(struct.pack('I', num_vectors))
|
f.write(struct.pack("I", num_vectors))
|
||||||
f.write(struct.pack('I', dim))
|
f.write(struct.pack("I", dim))
|
||||||
f.write(data.tobytes())
|
f.write(data.tobytes())
|
||||||
|
|
||||||
def _check_port(port: int) -> bool:
|
|
||||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
|
||||||
return s.connect_ex(('localhost', port)) == 0
|
|
||||||
|
|
||||||
class EmbeddingServerManager:
|
|
||||||
def __init__(self):
|
|
||||||
self.server_process = None
|
|
||||||
self.server_port = None
|
|
||||||
atexit.register(self.stop_server)
|
|
||||||
|
|
||||||
def start_server(self, port=5555, model_name="sentence-transformers/all-mpnet-base-v2"):
|
|
||||||
if self.server_process and self.server_process.poll() is None:
|
|
||||||
print(f"INFO: Reusing existing server process for this session (PID {self.server_process.pid})")
|
|
||||||
return True
|
|
||||||
|
|
||||||
# 检查端口是否已被其他无关进程占用
|
|
||||||
if _check_port(port):
|
|
||||||
print(f"WARNING: Port {port} is already in use. Assuming an external server is running and connecting to it.")
|
|
||||||
return True
|
|
||||||
|
|
||||||
print(f"INFO: Starting session-level embedding server as a background process...")
|
|
||||||
|
|
||||||
try:
|
|
||||||
command = [
|
|
||||||
sys.executable,
|
|
||||||
"-m", "packages.leann-backend-diskann.leann_backend_diskann.embedding_server",
|
|
||||||
"--zmq-port", str(port),
|
|
||||||
"--model-name", model_name
|
|
||||||
]
|
|
||||||
project_root = Path(__file__).parent.parent.parent.parent
|
|
||||||
print(f"INFO: Running command from project root: {project_root}")
|
|
||||||
self.server_process = subprocess.Popen(
|
|
||||||
command,
|
|
||||||
cwd=project_root,
|
|
||||||
# stdout=subprocess.PIPE,
|
|
||||||
# stderr=subprocess.PIPE,
|
|
||||||
text=True,
|
|
||||||
encoding='utf-8'
|
|
||||||
)
|
|
||||||
self.server_port = port
|
|
||||||
print(f"INFO: Server process started with PID: {self.server_process.pid}")
|
|
||||||
|
|
||||||
max_wait, wait_interval = 30, 0.5
|
|
||||||
for _ in range(int(max_wait / wait_interval)):
|
|
||||||
if _check_port(port):
|
|
||||||
print(f"✅ Embedding server is up and ready for this session.")
|
|
||||||
log_thread = threading.Thread(target=self._log_monitor, daemon=True)
|
|
||||||
log_thread.start()
|
|
||||||
return True
|
|
||||||
if self.server_process.poll() is not None:
|
|
||||||
print("❌ ERROR: Server process terminated unexpectedly during startup.")
|
|
||||||
self._log_monitor()
|
|
||||||
return False
|
|
||||||
time.sleep(wait_interval)
|
|
||||||
|
|
||||||
print(f"❌ ERROR: Server process failed to start listening within {max_wait} seconds.")
|
|
||||||
self.stop_server()
|
|
||||||
return False
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ ERROR: Failed to start embedding server process: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _log_monitor(self):
|
|
||||||
if not self.server_process:
|
|
||||||
return
|
|
||||||
try:
|
|
||||||
if self.server_process.stdout:
|
|
||||||
for line in iter(self.server_process.stdout.readline, ''):
|
|
||||||
print(f"[EmbeddingServer LOG]: {line.strip()}")
|
|
||||||
self.server_process.stdout.close()
|
|
||||||
if self.server_process.stderr:
|
|
||||||
for line in iter(self.server_process.stderr.readline, ''):
|
|
||||||
print(f"[EmbeddingServer ERROR]: {line.strip()}")
|
|
||||||
self.server_process.stderr.close()
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Log monitor error: {e}")
|
|
||||||
|
|
||||||
def stop_server(self):
|
|
||||||
if self.server_process and self.server_process.poll() is None:
|
|
||||||
print(f"INFO: Terminating session server process (PID: {self.server_process.pid})...")
|
|
||||||
self.server_process.terminate()
|
|
||||||
try:
|
|
||||||
self.server_process.wait(timeout=5)
|
|
||||||
print("INFO: Server process terminated.")
|
|
||||||
except subprocess.TimeoutExpired:
|
|
||||||
print("WARNING: Server process did not terminate gracefully, killing it.")
|
|
||||||
self.server_process.kill()
|
|
||||||
self.server_process = None
|
|
||||||
|
|
||||||
@register_backend("diskann")
|
@register_backend("diskann")
|
||||||
class DiskannBackend(LeannBackendFactoryInterface):
|
class DiskannBackend(LeannBackendFactoryInterface):
|
||||||
@@ -140,134 +92,179 @@ class DiskannBackend(LeannBackendFactoryInterface):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def searcher(index_path: str, **kwargs) -> LeannBackendSearcherInterface:
|
def searcher(index_path: str, **kwargs) -> LeannBackendSearcherInterface:
|
||||||
path = Path(index_path)
|
|
||||||
meta_path = path.parent / f"{path.name}.meta.json"
|
|
||||||
if not meta_path.exists():
|
|
||||||
raise FileNotFoundError(f"Leann metadata file not found at {meta_path}. Cannot infer vector dimension for searcher.")
|
|
||||||
|
|
||||||
with open(meta_path, 'r') as f:
|
|
||||||
meta = json.load(f)
|
|
||||||
|
|
||||||
dimensions = meta.get("dimensions")
|
|
||||||
if not dimensions:
|
|
||||||
raise ValueError("Dimensions not found in Leann metadata. Please rebuild the index with a newer version of Leann.")
|
|
||||||
|
|
||||||
kwargs['dimensions'] = dimensions
|
|
||||||
return DiskannSearcher(index_path, **kwargs)
|
return DiskannSearcher(index_path, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class DiskannBuilder(LeannBackendBuilderInterface):
|
class DiskannBuilder(LeannBackendBuilderInterface):
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
self.build_params = kwargs
|
self.build_params = kwargs
|
||||||
|
|
||||||
def build(self, data: np.ndarray, index_path: str, **kwargs):
|
def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs):
|
||||||
path = Path(index_path)
|
path = Path(index_path)
|
||||||
index_dir = path.parent
|
index_dir = path.parent
|
||||||
index_prefix = path.stem
|
index_prefix = path.stem
|
||||||
|
|
||||||
index_dir.mkdir(parents=True, exist_ok=True)
|
index_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
if data.dtype != np.float32:
|
if data.dtype != np.float32:
|
||||||
|
logger.warning(f"Converting data to float32, shape: {data.shape}")
|
||||||
data = data.astype(np.float32)
|
data = data.astype(np.float32)
|
||||||
if not data.flags['C_CONTIGUOUS']:
|
|
||||||
data = np.ascontiguousarray(data)
|
|
||||||
|
|
||||||
data_filename = f"{index_prefix}_data.bin"
|
data_filename = f"{index_prefix}_data.bin"
|
||||||
_write_vectors_to_bin(data, index_dir / data_filename)
|
_write_vectors_to_bin(data, index_dir / data_filename)
|
||||||
|
|
||||||
build_kwargs = {**self.build_params, **kwargs}
|
build_kwargs = {**self.build_params, **kwargs}
|
||||||
metric_str = build_kwargs.get("distance_metric", "mips").lower()
|
metric_enum = _get_diskann_metrics().get(
|
||||||
metric_enum = METRIC_MAP.get(metric_str)
|
build_kwargs.get("distance_metric", "mips").lower()
|
||||||
|
)
|
||||||
if metric_enum is None:
|
if metric_enum is None:
|
||||||
raise ValueError(f"Unsupported distance_metric '{metric_str}'.")
|
raise ValueError(
|
||||||
|
f"Unsupported distance_metric '{build_kwargs.get('distance_metric', 'unknown')}'."
|
||||||
complexity = build_kwargs.get("complexity", 64)
|
)
|
||||||
graph_degree = build_kwargs.get("graph_degree", 32)
|
|
||||||
final_index_ram_limit = build_kwargs.get("search_memory_maximum", 4.0)
|
|
||||||
indexing_ram_budget = build_kwargs.get("build_memory_maximum", 8.0)
|
|
||||||
num_threads = build_kwargs.get("num_threads", 8)
|
|
||||||
pq_disk_bytes = build_kwargs.get("pq_disk_bytes", 0)
|
|
||||||
codebook_prefix = ""
|
|
||||||
|
|
||||||
print(f"INFO: Building DiskANN index for {data.shape[0]} vectors with metric {metric_enum}...")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
from . import _diskannpy as diskannpy # type: ignore
|
||||||
|
|
||||||
with chdir(index_dir):
|
with chdir(index_dir):
|
||||||
diskannpy.build_disk_float_index(
|
diskannpy.build_disk_float_index(
|
||||||
metric_enum,
|
metric_enum,
|
||||||
data_filename,
|
data_filename,
|
||||||
index_prefix,
|
index_prefix,
|
||||||
complexity,
|
build_kwargs.get("complexity", 64),
|
||||||
graph_degree,
|
build_kwargs.get("graph_degree", 32),
|
||||||
final_index_ram_limit,
|
build_kwargs.get("search_memory_maximum", 4.0),
|
||||||
indexing_ram_budget,
|
build_kwargs.get("build_memory_maximum", 8.0),
|
||||||
num_threads,
|
build_kwargs.get("num_threads", 8),
|
||||||
pq_disk_bytes,
|
build_kwargs.get("pq_disk_bytes", 0),
|
||||||
codebook_prefix
|
"",
|
||||||
)
|
)
|
||||||
print(f"✅ DiskANN index built successfully at '{index_dir / index_prefix}'")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"💥 ERROR: DiskANN index build failed. Exception: {e}")
|
|
||||||
raise
|
|
||||||
finally:
|
finally:
|
||||||
temp_data_file = index_dir / data_filename
|
temp_data_file = index_dir / data_filename
|
||||||
if temp_data_file.exists():
|
if temp_data_file.exists():
|
||||||
os.remove(temp_data_file)
|
os.remove(temp_data_file)
|
||||||
|
logger.debug(f"Cleaned up temporary data file: {temp_data_file}")
|
||||||
|
|
||||||
class DiskannSearcher(LeannBackendSearcherInterface):
|
|
||||||
|
class DiskannSearcher(BaseSearcher):
|
||||||
def __init__(self, index_path: str, **kwargs):
|
def __init__(self, index_path: str, **kwargs):
|
||||||
path = Path(index_path)
|
super().__init__(
|
||||||
index_dir = path.parent
|
index_path,
|
||||||
index_prefix = path.stem
|
backend_module_name="leann_backend_diskann.diskann_embedding_server",
|
||||||
metric_str = kwargs.get("distance_metric", "mips").lower()
|
**kwargs,
|
||||||
metric_enum = METRIC_MAP.get(metric_str)
|
)
|
||||||
if metric_enum is None:
|
|
||||||
raise ValueError(f"Unsupported distance_metric '{metric_str}'.")
|
|
||||||
|
|
||||||
num_threads = kwargs.get("num_threads", 8)
|
# Initialize DiskANN index with suppressed C++ output based on log level
|
||||||
num_nodes_to_cache = kwargs.get("num_nodes_to_cache", 0)
|
with suppress_cpp_output_if_needed():
|
||||||
dimensions = kwargs.get("dimensions")
|
from . import _diskannpy as diskannpy # type: ignore
|
||||||
if not dimensions:
|
|
||||||
raise ValueError("Vector dimension not provided to DiskannSearcher.")
|
|
||||||
|
|
||||||
try:
|
distance_metric = kwargs.get("distance_metric", "mips").lower()
|
||||||
full_index_prefix = str(index_dir / index_prefix)
|
metric_enum = _get_diskann_metrics().get(distance_metric)
|
||||||
self._index = diskannpy.StaticDiskFloatIndex(
|
if metric_enum is None:
|
||||||
metric_enum, full_index_prefix, num_threads, num_nodes_to_cache, 1, "", ""
|
raise ValueError(f"Unsupported distance_metric '{distance_metric}'.")
|
||||||
|
|
||||||
|
self.num_threads = kwargs.get("num_threads", 8)
|
||||||
|
|
||||||
|
# For DiskANN, we need to reinitialize the index when zmq_port changes
|
||||||
|
# Store the initialization parameters for later use
|
||||||
|
full_index_prefix = str(self.index_dir / self.index_path.stem)
|
||||||
|
self._init_params = {
|
||||||
|
"metric_enum": metric_enum,
|
||||||
|
"full_index_prefix": full_index_prefix,
|
||||||
|
"num_threads": self.num_threads,
|
||||||
|
"num_nodes_to_cache": kwargs.get("num_nodes_to_cache", 0),
|
||||||
|
"cache_mechanism": 1,
|
||||||
|
"pq_prefix": "",
|
||||||
|
"partition_prefix": "",
|
||||||
|
}
|
||||||
|
self._diskannpy = diskannpy
|
||||||
|
self._current_zmq_port = None
|
||||||
|
self._index = None
|
||||||
|
logger.debug("DiskANN searcher initialized (index will be loaded on first search)")
|
||||||
|
|
||||||
|
def _ensure_index_loaded(self, zmq_port: int):
|
||||||
|
"""Ensure the index is loaded with the correct zmq_port."""
|
||||||
|
if self._index is None or self._current_zmq_port != zmq_port:
|
||||||
|
# Need to (re)load the index with the correct zmq_port
|
||||||
|
with suppress_cpp_output_if_needed():
|
||||||
|
if self._index is not None:
|
||||||
|
logger.debug(f"Reloading DiskANN index with new zmq_port: {zmq_port}")
|
||||||
|
else:
|
||||||
|
logger.debug(f"Loading DiskANN index with zmq_port: {zmq_port}")
|
||||||
|
|
||||||
|
self._index = self._diskannpy.StaticDiskFloatIndex(
|
||||||
|
self._init_params["metric_enum"],
|
||||||
|
self._init_params["full_index_prefix"],
|
||||||
|
self._init_params["num_threads"],
|
||||||
|
self._init_params["num_nodes_to_cache"],
|
||||||
|
self._init_params["cache_mechanism"],
|
||||||
|
zmq_port,
|
||||||
|
self._init_params["pq_prefix"],
|
||||||
|
self._init_params["partition_prefix"],
|
||||||
|
)
|
||||||
|
self._current_zmq_port = zmq_port
|
||||||
|
|
||||||
|
def search(
|
||||||
|
self,
|
||||||
|
query: np.ndarray,
|
||||||
|
top_k: int,
|
||||||
|
complexity: int = 64,
|
||||||
|
beam_width: int = 1,
|
||||||
|
prune_ratio: float = 0.0,
|
||||||
|
recompute_embeddings: bool = False,
|
||||||
|
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||||
|
zmq_port: int | None = None,
|
||||||
|
batch_recompute: bool = False,
|
||||||
|
dedup_node_dis: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Search for nearest neighbors using DiskANN index.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Query vectors (B, D) where B is batch size, D is dimension
|
||||||
|
top_k: Number of nearest neighbors to return
|
||||||
|
complexity: Search complexity/candidate list size, higher = more accurate but slower
|
||||||
|
beam_width: Number of parallel IO requests per iteration
|
||||||
|
prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
|
||||||
|
recompute_embeddings: Whether to fetch fresh embeddings from server
|
||||||
|
pruning_strategy: PQ candidate selection strategy:
|
||||||
|
- "global": Use global pruning strategy (default)
|
||||||
|
- "local": Use local pruning strategy
|
||||||
|
- "proportional": Not supported in DiskANN, falls back to global
|
||||||
|
zmq_port: ZMQ port for embedding server communication. Must be provided if recompute_embeddings is True.
|
||||||
|
batch_recompute: Whether to batch neighbor recomputation (DiskANN-specific)
|
||||||
|
dedup_node_dis: Whether to cache and reuse distance computations (DiskANN-specific)
|
||||||
|
**kwargs: Additional DiskANN-specific parameters (for legacy compatibility)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with 'labels' (list of lists) and 'distances' (ndarray)
|
||||||
|
"""
|
||||||
|
# Handle zmq_port compatibility: Ensure index is loaded with correct port
|
||||||
|
if recompute_embeddings:
|
||||||
|
if zmq_port is None:
|
||||||
|
raise ValueError("zmq_port must be provided if recompute_embeddings is True")
|
||||||
|
self._ensure_index_loaded(zmq_port)
|
||||||
|
else:
|
||||||
|
# If not recomputing, we still need an index, use a default port
|
||||||
|
if self._index is None:
|
||||||
|
self._ensure_index_loaded(6666) # Default port when not recomputing
|
||||||
|
|
||||||
|
# DiskANN doesn't support "proportional" strategy
|
||||||
|
if pruning_strategy == "proportional":
|
||||||
|
raise NotImplementedError(
|
||||||
|
"DiskANN backend does not support 'proportional' pruning strategy. Use 'global' or 'local' instead."
|
||||||
)
|
)
|
||||||
self.num_threads = num_threads
|
|
||||||
self.embedding_server_manager = EmbeddingServerManager()
|
|
||||||
print("✅ DiskANN index loaded successfully.")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"💥 ERROR: Failed to load DiskANN index. Exception: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, any]:
|
|
||||||
complexity = kwargs.get("complexity", 256)
|
|
||||||
beam_width = kwargs.get("beam_width", 4)
|
|
||||||
|
|
||||||
USE_DEFERRED_FETCH = kwargs.get("USE_DEFERRED_FETCH", False)
|
|
||||||
skip_search_reorder = kwargs.get("skip_search_reorder", False)
|
|
||||||
recompute_beighbor_embeddings = kwargs.get("recompute_beighbor_embeddings", False)
|
|
||||||
dedup_node_dis = kwargs.get("dedup_node_dis", False)
|
|
||||||
prune_ratio = kwargs.get("prune_ratio", 0.0)
|
|
||||||
batch_recompute = kwargs.get("batch_recompute", False)
|
|
||||||
global_pruning = kwargs.get("global_pruning", False)
|
|
||||||
|
|
||||||
if recompute_beighbor_embeddings:
|
|
||||||
print(f"INFO: DiskANN ZMQ mode enabled - ensuring embedding server is running")
|
|
||||||
zmq_port = kwargs.get("zmq_port", 6666)
|
|
||||||
embedding_model = kwargs.get("embedding_model", "sentence-transformers/all-mpnet-base-v2")
|
|
||||||
|
|
||||||
if not self.embedding_server_manager.start_server(zmq_port, embedding_model):
|
|
||||||
print(f"WARNING: Failed to start embedding server, falling back to PQ computation")
|
|
||||||
kwargs['recompute_beighbor_embeddings'] = False
|
|
||||||
|
|
||||||
if query.dtype != np.float32:
|
if query.dtype != np.float32:
|
||||||
query = query.astype(np.float32)
|
query = query.astype(np.float32)
|
||||||
if query.ndim == 1:
|
|
||||||
query = np.expand_dims(query, axis=0)
|
|
||||||
|
|
||||||
try:
|
# Map pruning_strategy to DiskANN's global_pruning parameter
|
||||||
|
if pruning_strategy == "local":
|
||||||
|
use_global_pruning = False
|
||||||
|
else: # "global"
|
||||||
|
use_global_pruning = True
|
||||||
|
|
||||||
|
# Perform search with suppressed C++ output based on log level
|
||||||
|
with suppress_cpp_output_if_needed():
|
||||||
labels, distances = self._index.batch_search(
|
labels, distances = self._index.batch_search(
|
||||||
query,
|
query,
|
||||||
query.shape[0],
|
query.shape[0],
|
||||||
@@ -275,21 +272,15 @@ class DiskannSearcher(LeannBackendSearcherInterface):
|
|||||||
complexity,
|
complexity,
|
||||||
beam_width,
|
beam_width,
|
||||||
self.num_threads,
|
self.num_threads,
|
||||||
USE_DEFERRED_FETCH,
|
kwargs.get("USE_DEFERRED_FETCH", False),
|
||||||
skip_search_reorder,
|
kwargs.get("skip_search_reorder", False),
|
||||||
recompute_beighbor_embeddings,
|
recompute_embeddings,
|
||||||
dedup_node_dis,
|
dedup_node_dis,
|
||||||
prune_ratio,
|
prune_ratio,
|
||||||
batch_recompute,
|
batch_recompute,
|
||||||
global_pruning
|
use_global_pruning,
|
||||||
)
|
)
|
||||||
return {"labels": labels, "distances": distances}
|
|
||||||
except Exception as e:
|
|
||||||
print(f"💥 ERROR: DiskANN search failed. Exception: {e}")
|
|
||||||
batch_size = query.shape[0]
|
|
||||||
return {"labels": np.full((batch_size, top_k), -1, dtype=np.int64),
|
|
||||||
"distances": np.full((batch_size, top_k), float('inf'), dtype=np.float32)}
|
|
||||||
|
|
||||||
def __del__(self):
|
string_labels = [[str(int_label) for int_label in batch_labels] for batch_labels in labels]
|
||||||
if hasattr(self, 'embedding_server_manager'):
|
|
||||||
self.embedding_server_manager.stop_server()
|
return {"labels": string_labels, "distances": distances}
|
||||||
|
|||||||
@@ -0,0 +1,284 @@
|
|||||||
|
"""
|
||||||
|
DiskANN-specific embedding server
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import zmq
|
||||||
|
|
||||||
|
# Set up logging based on environment variable
|
||||||
|
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Force set logger level (don't rely on basicConfig in subprocess)
|
||||||
|
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
|
||||||
|
logger.setLevel(log_level)
|
||||||
|
|
||||||
|
# Ensure we have a handler if none exists
|
||||||
|
if not logger.handlers:
|
||||||
|
handler = logging.StreamHandler()
|
||||||
|
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
|
||||||
|
handler.setFormatter(formatter)
|
||||||
|
logger.addHandler(handler)
|
||||||
|
logger.propagate = False
|
||||||
|
|
||||||
|
|
||||||
|
def create_diskann_embedding_server(
|
||||||
|
passages_file: str | None = None,
|
||||||
|
zmq_port: int = 5555,
|
||||||
|
model_name: str = "sentence-transformers/all-mpnet-base-v2",
|
||||||
|
embedding_mode: str = "sentence-transformers",
|
||||||
|
distance_metric: str = "l2",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Create and start a ZMQ-based embedding server for DiskANN backend.
|
||||||
|
Uses ROUTER socket and protobuf communication as required by DiskANN C++ implementation.
|
||||||
|
"""
|
||||||
|
logger.info(f"Starting DiskANN server on port {zmq_port} with model {model_name}")
|
||||||
|
logger.info(f"Using embedding mode: {embedding_mode}")
|
||||||
|
|
||||||
|
# Add leann-core to path for unified embedding computation
|
||||||
|
current_dir = Path(__file__).parent
|
||||||
|
leann_core_path = current_dir.parent.parent / "leann-core" / "src"
|
||||||
|
sys.path.insert(0, str(leann_core_path))
|
||||||
|
|
||||||
|
try:
|
||||||
|
from leann.api import PassageManager
|
||||||
|
from leann.embedding_compute import compute_embeddings
|
||||||
|
|
||||||
|
logger.info("Successfully imported unified embedding computation module")
|
||||||
|
except ImportError as e:
|
||||||
|
logger.error(f"Failed to import embedding computation module: {e}")
|
||||||
|
return
|
||||||
|
finally:
|
||||||
|
sys.path.pop(0)
|
||||||
|
|
||||||
|
# Check port availability
|
||||||
|
import socket
|
||||||
|
|
||||||
|
def check_port(port):
|
||||||
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||||
|
return s.connect_ex(("localhost", port)) == 0
|
||||||
|
|
||||||
|
if check_port(zmq_port):
|
||||||
|
logger.error(f"Port {zmq_port} is already in use")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Only support metadata file, fail fast for everything else
|
||||||
|
if not passages_file or not passages_file.endswith(".meta.json"):
|
||||||
|
raise ValueError("Only metadata files (.meta.json) are supported")
|
||||||
|
|
||||||
|
# Load metadata to get passage sources
|
||||||
|
with open(passages_file) as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
|
||||||
|
passages = PassageManager(meta["passage_sources"])
|
||||||
|
logger.info(
|
||||||
|
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Import protobuf after ensuring the path is correct
|
||||||
|
try:
|
||||||
|
from . import embedding_pb2
|
||||||
|
except ImportError as e:
|
||||||
|
logger.error(f"Failed to import protobuf module: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
def zmq_server_thread():
|
||||||
|
"""ZMQ server thread using REP socket for universal compatibility"""
|
||||||
|
context = zmq.Context()
|
||||||
|
socket = context.socket(
|
||||||
|
zmq.REP
|
||||||
|
) # REP socket for both BaseSearcher and DiskANN C++ REQ clients
|
||||||
|
socket.bind(f"tcp://*:{zmq_port}")
|
||||||
|
logger.info(f"DiskANN ZMQ REP server listening on port {zmq_port}")
|
||||||
|
|
||||||
|
socket.setsockopt(zmq.RCVTIMEO, 300000)
|
||||||
|
socket.setsockopt(zmq.SNDTIMEO, 300000)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
# REP socket receives single-part messages
|
||||||
|
message = socket.recv()
|
||||||
|
|
||||||
|
# Check for empty messages - REP socket requires response to every request
|
||||||
|
if len(message) == 0:
|
||||||
|
logger.debug("Received empty message, sending empty response")
|
||||||
|
socket.send(b"") # REP socket must respond to every request
|
||||||
|
continue
|
||||||
|
|
||||||
|
logger.debug(f"Received ZMQ request of size {len(message)} bytes")
|
||||||
|
logger.debug(f"Message preview: {message[:50]}") # Show first 50 bytes
|
||||||
|
|
||||||
|
e2e_start = time.time()
|
||||||
|
|
||||||
|
# Try protobuf first (for DiskANN C++ node_ids requests - primary use case)
|
||||||
|
texts = []
|
||||||
|
node_ids = []
|
||||||
|
is_text_request = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
req_proto = embedding_pb2.NodeEmbeddingRequest()
|
||||||
|
req_proto.ParseFromString(message)
|
||||||
|
node_ids = list(req_proto.node_ids)
|
||||||
|
|
||||||
|
if not node_ids:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"PROTOBUF: Received empty node_ids! Message size: {len(message)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"✅ PROTOBUF: Node ID request for {len(node_ids)} node embeddings: {node_ids[:10]}"
|
||||||
|
)
|
||||||
|
except Exception as protobuf_error:
|
||||||
|
logger.debug(f"Protobuf parsing failed: {protobuf_error}")
|
||||||
|
# Fallback to msgpack (for BaseSearcher direct text requests)
|
||||||
|
try:
|
||||||
|
import msgpack
|
||||||
|
|
||||||
|
request = msgpack.unpackb(message)
|
||||||
|
# For BaseSearcher compatibility, request is a list of texts directly
|
||||||
|
if isinstance(request, list) and all(
|
||||||
|
isinstance(item, str) for item in request
|
||||||
|
):
|
||||||
|
texts = request
|
||||||
|
is_text_request = True
|
||||||
|
logger.info(f"✅ MSGPACK: Direct text request for {len(texts)} texts")
|
||||||
|
else:
|
||||||
|
raise ValueError("Not a valid msgpack text request")
|
||||||
|
except Exception as msgpack_error:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Both protobuf and msgpack parsing failed! Protobuf: {protobuf_error}, Msgpack: {msgpack_error}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Look up texts by node IDs (only if not direct text request)
|
||||||
|
if not is_text_request:
|
||||||
|
for nid in node_ids:
|
||||||
|
try:
|
||||||
|
passage_data = passages.get_passage(str(nid))
|
||||||
|
txt = passage_data["text"]
|
||||||
|
if not txt:
|
||||||
|
raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
|
||||||
|
texts.append(txt)
|
||||||
|
except KeyError as e:
|
||||||
|
logger.error(f"Passage ID {nid} not found: {e}")
|
||||||
|
raise e
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Exception looking up passage ID {nid}: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Debug logging
|
||||||
|
logger.debug(f"Processing {len(texts)} texts")
|
||||||
|
logger.debug(f"Text lengths: {[len(t) for t in texts[:5]]}") # Show first 5
|
||||||
|
|
||||||
|
# Process embeddings using unified computation
|
||||||
|
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
||||||
|
logger.info(
|
||||||
|
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare response based on request type
|
||||||
|
if is_text_request:
|
||||||
|
# For BaseSearcher compatibility: return msgpack format
|
||||||
|
import msgpack
|
||||||
|
|
||||||
|
response_data = msgpack.packb(embeddings.tolist())
|
||||||
|
else:
|
||||||
|
# For DiskANN C++ compatibility: return protobuf format
|
||||||
|
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
||||||
|
hidden_contiguous = np.ascontiguousarray(embeddings, dtype=np.float32)
|
||||||
|
|
||||||
|
# Serialize embeddings data
|
||||||
|
resp_proto.embeddings_data = hidden_contiguous.tobytes()
|
||||||
|
resp_proto.dimensions.append(hidden_contiguous.shape[0])
|
||||||
|
resp_proto.dimensions.append(hidden_contiguous.shape[1])
|
||||||
|
|
||||||
|
response_data = resp_proto.SerializeToString()
|
||||||
|
|
||||||
|
# Send response back to the client
|
||||||
|
socket.send(response_data)
|
||||||
|
|
||||||
|
e2e_end = time.time()
|
||||||
|
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
|
||||||
|
|
||||||
|
except zmq.Again:
|
||||||
|
logger.debug("ZMQ socket timeout, continuing to listen")
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in ZMQ server loop: {e}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
raise
|
||||||
|
|
||||||
|
zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
|
||||||
|
zmq_thread.start()
|
||||||
|
logger.info(f"Started DiskANN ZMQ server thread on port {zmq_port}")
|
||||||
|
|
||||||
|
# Keep the main thread alive
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
time.sleep(1)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logger.info("DiskANN Server shutting down...")
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import signal
|
||||||
|
import sys
|
||||||
|
|
||||||
|
def signal_handler(sig, frame):
|
||||||
|
logger.info(f"Received signal {sig}, shutting down gracefully...")
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
# Register signal handlers for graceful shutdown
|
||||||
|
signal.signal(signal.SIGTERM, signal_handler)
|
||||||
|
signal.signal(signal.SIGINT, signal_handler)
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description="DiskANN Embedding service")
|
||||||
|
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
|
||||||
|
parser.add_argument(
|
||||||
|
"--passages-file",
|
||||||
|
type=str,
|
||||||
|
help="Metadata JSON file containing passage sources",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model-name",
|
||||||
|
type=str,
|
||||||
|
default="sentence-transformers/all-mpnet-base-v2",
|
||||||
|
help="Embedding model name",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--embedding-mode",
|
||||||
|
type=str,
|
||||||
|
default="sentence-transformers",
|
||||||
|
choices=["sentence-transformers", "openai", "mlx"],
|
||||||
|
help="Embedding backend mode",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--distance-metric",
|
||||||
|
type=str,
|
||||||
|
default="l2",
|
||||||
|
choices=["l2", "mips", "cosine"],
|
||||||
|
help="Distance metric for similarity computation",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Create and start the DiskANN embedding server
|
||||||
|
create_diskann_embedding_server(
|
||||||
|
passages_file=args.passages_file,
|
||||||
|
zmq_port=args.zmq_port,
|
||||||
|
model_name=args.model_name,
|
||||||
|
embedding_mode=args.embedding_mode,
|
||||||
|
distance_metric=args.distance_metric,
|
||||||
|
)
|
||||||
@@ -1,27 +1,28 @@
|
|||||||
# -*- coding: utf-8 -*-
|
|
||||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||||
# source: embedding.proto
|
# source: embedding.proto
|
||||||
|
# ruff: noqa
|
||||||
"""Generated protocol buffer code."""
|
"""Generated protocol buffer code."""
|
||||||
from google.protobuf.internal import builder as _builder
|
|
||||||
from google.protobuf import descriptor as _descriptor
|
from google.protobuf import descriptor as _descriptor
|
||||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||||
from google.protobuf import symbol_database as _symbol_database
|
from google.protobuf import symbol_database as _symbol_database
|
||||||
|
from google.protobuf.internal import builder as _builder
|
||||||
|
|
||||||
# @@protoc_insertion_point(imports)
|
# @@protoc_insertion_point(imports)
|
||||||
|
|
||||||
_sym_db = _symbol_database.Default()
|
_sym_db = _symbol_database.Default()
|
||||||
|
|
||||||
|
|
||||||
|
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
|
||||||
|
b'\n\x0f\x65mbedding.proto\x12\x0eprotoembedding"(\n\x14NodeEmbeddingRequest\x12\x10\n\x08node_ids\x18\x01 \x03(\r"Y\n\x15NodeEmbeddingResponse\x12\x17\n\x0f\x65mbeddings_data\x18\x01 \x01(\x0c\x12\x12\n\ndimensions\x18\x02 \x03(\x05\x12\x13\n\x0bmissing_ids\x18\x03 \x03(\rb\x06proto3'
|
||||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0f\x65mbedding.proto\x12\x0eprotoembedding\"(\n\x14NodeEmbeddingRequest\x12\x10\n\x08node_ids\x18\x01 \x03(\r\"Y\n\x15NodeEmbeddingResponse\x12\x17\n\x0f\x65mbeddings_data\x18\x01 \x01(\x0c\x12\x12\n\ndimensions\x18\x02 \x03(\x05\x12\x13\n\x0bmissing_ids\x18\x03 \x03(\rb\x06proto3')
|
)
|
||||||
|
|
||||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
|
||||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'embedding_pb2', globals())
|
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "embedding_pb2", globals())
|
||||||
if _descriptor._USE_C_DESCRIPTORS == False:
|
if not _descriptor._USE_C_DESCRIPTORS:
|
||||||
|
DESCRIPTOR._options = None
|
||||||
DESCRIPTOR._options = None
|
_NODEEMBEDDINGREQUEST._serialized_start = 35
|
||||||
_NODEEMBEDDINGREQUEST._serialized_start=35
|
_NODEEMBEDDINGREQUEST._serialized_end = 75
|
||||||
_NODEEMBEDDINGREQUEST._serialized_end=75
|
_NODEEMBEDDINGRESPONSE._serialized_start = 77
|
||||||
_NODEEMBEDDINGRESPONSE._serialized_start=77
|
_NODEEMBEDDINGRESPONSE._serialized_end = 166
|
||||||
_NODEEMBEDDINGRESPONSE._serialized_end=166
|
|
||||||
# @@protoc_insertion_point(module_scope)
|
# @@protoc_insertion_point(module_scope)
|
||||||
|
|||||||
@@ -1,397 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Embedding server for leann-backend-diskann - Fixed ZMQ REQ-REP pattern
|
|
||||||
"""
|
|
||||||
|
|
||||||
import pickle
|
|
||||||
import argparse
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
|
|
||||||
from transformers import AutoTokenizer, AutoModel
|
|
||||||
import os
|
|
||||||
from contextlib import contextmanager
|
|
||||||
import zmq
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
RED = "\033[91m"
|
|
||||||
RESET = "\033[0m"
|
|
||||||
|
|
||||||
# 简化的文档存储 - 替代 LazyPassages
|
|
||||||
class SimpleDocumentStore:
|
|
||||||
"""简化的文档存储,支持任意ID"""
|
|
||||||
def __init__(self, documents: dict = None):
|
|
||||||
self.documents = documents or {}
|
|
||||||
# 默认演示文档
|
|
||||||
self.default_docs = {
|
|
||||||
0: "Python is a high-level, interpreted language known for simplicity.",
|
|
||||||
1: "Machine learning builds systems that learn from data.",
|
|
||||||
2: "Data structures like arrays, lists, and graphs organize data.",
|
|
||||||
}
|
|
||||||
|
|
||||||
def __getitem__(self, doc_id):
|
|
||||||
doc_id = int(doc_id)
|
|
||||||
|
|
||||||
# 优先使用指定的文档
|
|
||||||
if doc_id in self.documents:
|
|
||||||
return {"text": self.documents[doc_id]}
|
|
||||||
|
|
||||||
# 其次使用默认演示文档
|
|
||||||
if doc_id in self.default_docs:
|
|
||||||
return {"text": self.default_docs[doc_id]}
|
|
||||||
|
|
||||||
# 对于任意其他ID,返回通用文档
|
|
||||||
fallback_docs = [
|
|
||||||
"This is a general document about technology and programming concepts.",
|
|
||||||
"This document discusses machine learning and artificial intelligence topics.",
|
|
||||||
"This content covers data structures, algorithms, and computer science fundamentals.",
|
|
||||||
"This is a document about software engineering and development practices.",
|
|
||||||
"This content focuses on databases, data management, and information systems."
|
|
||||||
]
|
|
||||||
|
|
||||||
# 根据ID选择一个fallback文档
|
|
||||||
fallback_text = fallback_docs[doc_id % len(fallback_docs)]
|
|
||||||
return {"text": f"[ID:{doc_id}] {fallback_text}"}
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.documents) + len(self.default_docs)
|
|
||||||
|
|
||||||
def create_embedding_server_thread(
|
|
||||||
zmq_port=5555,
|
|
||||||
model_name="sentence-transformers/all-mpnet-base-v2",
|
|
||||||
max_batch_size=128,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
在当前线程中创建并运行 embedding server
|
|
||||||
这个函数设计为在单独的线程中调用
|
|
||||||
"""
|
|
||||||
print(f"INFO: Initializing embedding server thread on port {zmq_port}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 检查端口是否已被占用
|
|
||||||
import socket
|
|
||||||
def check_port(port):
|
|
||||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
|
||||||
return s.connect_ex(('localhost', port)) == 0
|
|
||||||
|
|
||||||
if check_port(zmq_port):
|
|
||||||
print(f"{RED}Port {zmq_port} is already in use{RESET}")
|
|
||||||
return
|
|
||||||
|
|
||||||
# 初始化模型
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
|
|
||||||
import torch
|
|
||||||
|
|
||||||
# 选择设备
|
|
||||||
mps_available = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
|
|
||||||
cuda_available = torch.cuda.is_available()
|
|
||||||
|
|
||||||
if cuda_available:
|
|
||||||
device = torch.device("cuda")
|
|
||||||
print("INFO: Using CUDA device")
|
|
||||||
elif mps_available:
|
|
||||||
device = torch.device("mps")
|
|
||||||
print("INFO: Using MPS device (Apple Silicon)")
|
|
||||||
else:
|
|
||||||
device = torch.device("cpu")
|
|
||||||
print("INFO: Using CPU device")
|
|
||||||
|
|
||||||
# 加载模型
|
|
||||||
print(f"INFO: Loading model {model_name}")
|
|
||||||
model = AutoModel.from_pretrained(model_name).to(device).eval()
|
|
||||||
|
|
||||||
# 优化模型
|
|
||||||
if cuda_available or mps_available:
|
|
||||||
try:
|
|
||||||
model = model.half()
|
|
||||||
model = torch.compile(model)
|
|
||||||
print(f"INFO: Using FP16 precision with model: {model_name}")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"WARNING: Model optimization failed: {e}")
|
|
||||||
|
|
||||||
# 默认演示文档
|
|
||||||
demo_documents = {
|
|
||||||
0: "Python is a high-level, interpreted language known for simplicity.",
|
|
||||||
1: "Machine learning builds systems that learn from data.",
|
|
||||||
2: "Data structures like arrays, lists, and graphs organize data.",
|
|
||||||
}
|
|
||||||
|
|
||||||
passages = SimpleDocumentStore(demo_documents)
|
|
||||||
print(f"INFO: Loaded {len(passages)} demo documents")
|
|
||||||
|
|
||||||
class DeviceTimer:
|
|
||||||
"""设备计时器"""
|
|
||||||
def __init__(self, name="", device=device):
|
|
||||||
self.name = name
|
|
||||||
self.device = device
|
|
||||||
self.start_time = 0
|
|
||||||
self.end_time = 0
|
|
||||||
|
|
||||||
if cuda_available:
|
|
||||||
self.start_event = torch.cuda.Event(enable_timing=True)
|
|
||||||
self.end_event = torch.cuda.Event(enable_timing=True)
|
|
||||||
else:
|
|
||||||
self.start_event = None
|
|
||||||
self.end_event = None
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def timing(self):
|
|
||||||
self.start()
|
|
||||||
yield
|
|
||||||
self.end()
|
|
||||||
|
|
||||||
def start(self):
|
|
||||||
if cuda_available:
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
self.start_event.record()
|
|
||||||
else:
|
|
||||||
if self.device.type == "mps":
|
|
||||||
torch.mps.synchronize()
|
|
||||||
self.start_time = time.time()
|
|
||||||
|
|
||||||
def end(self):
|
|
||||||
if cuda_available:
|
|
||||||
self.end_event.record()
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
else:
|
|
||||||
if self.device.type == "mps":
|
|
||||||
torch.mps.synchronize()
|
|
||||||
self.end_time = time.time()
|
|
||||||
|
|
||||||
def elapsed_time(self):
|
|
||||||
if cuda_available:
|
|
||||||
return self.start_event.elapsed_time(self.end_event) / 1000.0
|
|
||||||
else:
|
|
||||||
return self.end_time - self.start_time
|
|
||||||
|
|
||||||
def print_elapsed(self):
|
|
||||||
print(f"Time taken for {self.name}: {self.elapsed_time():.6f} seconds")
|
|
||||||
|
|
||||||
def process_batch(texts_batch, ids_batch, missing_ids):
|
|
||||||
"""处理文本批次"""
|
|
||||||
batch_size = len(texts_batch)
|
|
||||||
print(f"INFO: Processing batch of size {batch_size}")
|
|
||||||
|
|
||||||
tokenize_timer = DeviceTimer("tokenization (batch)", device)
|
|
||||||
to_device_timer = DeviceTimer("transfer to device (batch)", device)
|
|
||||||
embed_timer = DeviceTimer("embedding (batch)", device)
|
|
||||||
pool_timer = DeviceTimer("mean pooling (batch)", device)
|
|
||||||
|
|
||||||
with tokenize_timer.timing():
|
|
||||||
encoded_batch = tokenizer.batch_encode_plus(
|
|
||||||
texts_batch,
|
|
||||||
padding="max_length",
|
|
||||||
truncation=True,
|
|
||||||
max_length=256,
|
|
||||||
return_tensors="pt",
|
|
||||||
return_token_type_ids=False,
|
|
||||||
)
|
|
||||||
tokenize_timer.print_elapsed()
|
|
||||||
|
|
||||||
seq_length = encoded_batch["input_ids"].size(1)
|
|
||||||
print(f"Batch size: {batch_size}, Sequence length: {seq_length}")
|
|
||||||
|
|
||||||
with to_device_timer.timing():
|
|
||||||
enc = {k: v.to(device) for k, v in encoded_batch.items()}
|
|
||||||
to_device_timer.print_elapsed()
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
with embed_timer.timing():
|
|
||||||
out = model(enc["input_ids"], enc["attention_mask"])
|
|
||||||
embed_timer.print_elapsed()
|
|
||||||
|
|
||||||
with pool_timer.timing():
|
|
||||||
hidden_states = out.last_hidden_state if hasattr(out, "last_hidden_state") else out
|
|
||||||
mask_expanded = enc["attention_mask"].unsqueeze(-1).expand(hidden_states.size()).float()
|
|
||||||
sum_embeddings = torch.sum(hidden_states * mask_expanded, 1)
|
|
||||||
sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
|
|
||||||
batch_embeddings = sum_embeddings / sum_mask
|
|
||||||
pool_timer.print_elapsed()
|
|
||||||
|
|
||||||
return batch_embeddings.cpu().numpy()
|
|
||||||
|
|
||||||
# ZMQ server 主循环 - 修改为REP套接字
|
|
||||||
context = zmq.Context()
|
|
||||||
socket = context.socket(zmq.ROUTER) # 改为REP套接字
|
|
||||||
socket.bind(f"tcp://127.0.0.1:{zmq_port}")
|
|
||||||
print(f"INFO: ZMQ ROUTER server listening on port {zmq_port}")
|
|
||||||
|
|
||||||
# 设置超时
|
|
||||||
socket.setsockopt(zmq.RCVTIMEO, 5000) # 5秒接收超时
|
|
||||||
socket.setsockopt(zmq.SNDTIMEO, 300000) # 300秒发送超时
|
|
||||||
|
|
||||||
from . import embedding_pb2
|
|
||||||
|
|
||||||
print(f"INFO: Embedding server ready to serve requests")
|
|
||||||
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
parts = socket.recv_multipart()
|
|
||||||
|
|
||||||
# --- 恢复稳健的消息格式判断 ---
|
|
||||||
# 必须检查 parts 的长度,避免 IndexError
|
|
||||||
if len(parts) >= 3:
|
|
||||||
identity = parts[0]
|
|
||||||
# empty = parts[1] # 中间的空帧我们通常不关心
|
|
||||||
message = parts[2]
|
|
||||||
elif len(parts) == 2:
|
|
||||||
# 也能处理没有空帧的情况
|
|
||||||
identity = parts[0]
|
|
||||||
message = parts[1]
|
|
||||||
else:
|
|
||||||
# 如果收到格式错误的消息,打印警告并忽略它,而不是崩溃
|
|
||||||
print(f"WARNING: Received unexpected message format with {len(parts)} parts. Ignoring.")
|
|
||||||
continue
|
|
||||||
print(f"INFO: Received ZMQ request from client {identity.hex()[:8]}, size {len(message)} bytes")
|
|
||||||
|
|
||||||
e2e_start = time.time()
|
|
||||||
lookup_timer = DeviceTimer("text lookup", device)
|
|
||||||
|
|
||||||
# 解析请求
|
|
||||||
req_proto = embedding_pb2.NodeEmbeddingRequest()
|
|
||||||
req_proto.ParseFromString(message)
|
|
||||||
node_ids = req_proto.node_ids
|
|
||||||
print(f"INFO: Request for {len(node_ids)} node embeddings: {list(node_ids)}")
|
|
||||||
|
|
||||||
# 添加调试信息
|
|
||||||
if len(node_ids) > 0:
|
|
||||||
print(f"DEBUG: Node ID range: {min(node_ids)} to {max(node_ids)}")
|
|
||||||
|
|
||||||
# 查找文本
|
|
||||||
texts = []
|
|
||||||
missing_ids = []
|
|
||||||
with lookup_timer.timing():
|
|
||||||
for nid in node_ids:
|
|
||||||
txtinfo = passages[nid]
|
|
||||||
txt = txtinfo["text"]
|
|
||||||
texts.append(txt)
|
|
||||||
lookup_timer.print_elapsed()
|
|
||||||
|
|
||||||
if missing_ids:
|
|
||||||
print(f"WARNING: Missing passages for IDs: {missing_ids}")
|
|
||||||
|
|
||||||
# 处理批次
|
|
||||||
total_size = len(texts)
|
|
||||||
print(f"INFO: Total batch size: {total_size}, max_batch_size: {max_batch_size}")
|
|
||||||
|
|
||||||
all_embeddings = []
|
|
||||||
|
|
||||||
if total_size > max_batch_size:
|
|
||||||
print(f"INFO: Splitting batch of size {total_size} into chunks of {max_batch_size}")
|
|
||||||
for i in range(0, total_size, max_batch_size):
|
|
||||||
end_idx = min(i + max_batch_size, total_size)
|
|
||||||
print(f"INFO: Processing chunk {i//max_batch_size + 1}/{(total_size + max_batch_size - 1)//max_batch_size}: items {i} to {end_idx-1}")
|
|
||||||
|
|
||||||
chunk_texts = texts[i:end_idx]
|
|
||||||
chunk_ids = node_ids[i:end_idx]
|
|
||||||
|
|
||||||
embeddings_chunk = process_batch(chunk_texts, chunk_ids, missing_ids)
|
|
||||||
all_embeddings.append(embeddings_chunk)
|
|
||||||
|
|
||||||
if cuda_available:
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
elif device.type == "mps":
|
|
||||||
torch.mps.empty_cache()
|
|
||||||
|
|
||||||
hidden = np.vstack(all_embeddings)
|
|
||||||
print(f"INFO: Combined embeddings shape: {hidden.shape}")
|
|
||||||
else:
|
|
||||||
hidden = process_batch(texts, node_ids, missing_ids)
|
|
||||||
|
|
||||||
# 序列化响应
|
|
||||||
ser_start = time.time()
|
|
||||||
|
|
||||||
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
|
||||||
hidden_contiguous = np.ascontiguousarray(hidden, dtype=np.float32)
|
|
||||||
resp_proto.embeddings_data = hidden_contiguous.tobytes()
|
|
||||||
resp_proto.dimensions.append(hidden_contiguous.shape[0])
|
|
||||||
resp_proto.dimensions.append(hidden_contiguous.shape[1])
|
|
||||||
resp_proto.missing_ids.extend(missing_ids)
|
|
||||||
|
|
||||||
response_data = resp_proto.SerializeToString()
|
|
||||||
|
|
||||||
# REP 套接字发送单个响应
|
|
||||||
socket.send_multipart([identity, b'', response_data])
|
|
||||||
|
|
||||||
ser_end = time.time()
|
|
||||||
|
|
||||||
print(f"INFO: Serialize time: {ser_end - ser_start:.6f} seconds")
|
|
||||||
|
|
||||||
if device.type == "cuda":
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
elif device.type == "mps":
|
|
||||||
torch.mps.synchronize()
|
|
||||||
e2e_end = time.time()
|
|
||||||
print(f"INFO: ZMQ E2E time: {e2e_end - e2e_start:.6f} seconds")
|
|
||||||
|
|
||||||
except zmq.Again:
|
|
||||||
print("INFO: ZMQ socket timeout, continuing to listen")
|
|
||||||
# REP套接字不需要重新创建,只需要继续监听
|
|
||||||
continue
|
|
||||||
except Exception as e:
|
|
||||||
print(f"ERROR: Error in ZMQ server: {e}")
|
|
||||||
try:
|
|
||||||
# 发送空响应以维持REQ-REP状态
|
|
||||||
empty_resp = embedding_pb2.NodeEmbeddingResponse()
|
|
||||||
socket.send(empty_resp.SerializeToString())
|
|
||||||
except:
|
|
||||||
# 如果发送失败,重新创建socket
|
|
||||||
socket.close()
|
|
||||||
socket = context.socket(zmq.REP)
|
|
||||||
socket.bind(f"tcp://127.0.0.1:{zmq_port}")
|
|
||||||
socket.setsockopt(zmq.RCVTIMEO, 5000)
|
|
||||||
socket.setsockopt(zmq.SNDTIMEO, 300000)
|
|
||||||
print("INFO: ZMQ socket recreated after error")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"ERROR: Failed to start embedding server: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
# 保持原有的 create_embedding_server 函数不变,只添加线程化版本
|
|
||||||
def create_embedding_server(
|
|
||||||
domain="demo",
|
|
||||||
load_passages=True,
|
|
||||||
load_embeddings=False,
|
|
||||||
use_fp16=True,
|
|
||||||
use_int8=False,
|
|
||||||
use_cuda_graphs=False,
|
|
||||||
zmq_port=5555,
|
|
||||||
max_batch_size=128,
|
|
||||||
lazy_load_passages=False,
|
|
||||||
model_name="sentence-transformers/all-mpnet-base-v2",
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
原有的 create_embedding_server 函数保持不变
|
|
||||||
这个是阻塞版本,用于直接运行
|
|
||||||
"""
|
|
||||||
create_embedding_server_thread(zmq_port, model_name, max_batch_size)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser(description="Embedding service")
|
|
||||||
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
|
|
||||||
parser.add_argument("--domain", type=str, default="demo", help="Domain name")
|
|
||||||
parser.add_argument("--load-passages", action="store_true", default=True)
|
|
||||||
parser.add_argument("--load-embeddings", action="store_true", default=False)
|
|
||||||
parser.add_argument("--use-fp16", action="store_true", default=False)
|
|
||||||
parser.add_argument("--use-int8", action="store_true", default=False)
|
|
||||||
parser.add_argument("--use-cuda-graphs", action="store_true", default=False)
|
|
||||||
parser.add_argument("--max-batch-size", type=int, default=128, help="Maximum batch size before splitting")
|
|
||||||
parser.add_argument("--lazy-load-passages", action="store_true", default=True)
|
|
||||||
parser.add_argument("--model-name", type=str, default="sentence-transformers/all-mpnet-base-v2",
|
|
||||||
help="Embedding model name")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
create_embedding_server(
|
|
||||||
domain=args.domain,
|
|
||||||
load_passages=args.load_passages,
|
|
||||||
load_embeddings=args.load_embeddings,
|
|
||||||
use_fp16=args.use_fp16,
|
|
||||||
use_int8=args.use_int8,
|
|
||||||
use_cuda_graphs=args.use_cuda_graphs,
|
|
||||||
zmq_port=args.zmq_port,
|
|
||||||
max_batch_size=args.max_batch_size,
|
|
||||||
lazy_load_passages=args.lazy_load_passages,
|
|
||||||
model_name=args.model_name,
|
|
||||||
)
|
|
||||||
@@ -4,13 +4,16 @@ build-backend = "scikit_build_core.build"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "leann-backend-diskann"
|
name = "leann-backend-diskann"
|
||||||
version = "0.1.0"
|
version = "0.1.16"
|
||||||
dependencies = ["leann-core==0.1.0", "numpy"]
|
dependencies = ["leann-core==0.1.16", "numpy", "protobuf>=3.19.0"]
|
||||||
|
|
||||||
[tool.scikit-build]
|
[tool.scikit-build]
|
||||||
# 关键:简化的 CMake 路径
|
# Key: simplified CMake path
|
||||||
cmake.source-dir = "third_party/DiskANN"
|
cmake.source-dir = "third_party/DiskANN"
|
||||||
# 关键:Python 包在根目录,路径完全匹配
|
# Key: Python package in root directory, paths match exactly
|
||||||
wheel.packages = ["leann_backend_diskann"]
|
wheel.packages = ["leann_backend_diskann"]
|
||||||
# 使用默认的 redirect 模式
|
# Use default redirect mode
|
||||||
editable.mode = "redirect"
|
editable.mode = "redirect"
|
||||||
|
cmake.build-type = "Release"
|
||||||
|
build.verbose = true
|
||||||
|
build.tool-args = ["-j8"]
|
||||||
|
|||||||
Submodule packages/leann-backend-diskann/third_party/DiskANN updated: 015c201141...af2a26481e
@@ -1,7 +1,38 @@
|
|||||||
# 最终简化版
|
|
||||||
cmake_minimum_required(VERSION 3.24)
|
cmake_minimum_required(VERSION 3.24)
|
||||||
project(leann_backend_hnsw_wrapper)
|
project(leann_backend_hnsw_wrapper)
|
||||||
|
set(CMAKE_C_COMPILER_WORKS 1)
|
||||||
|
set(CMAKE_CXX_COMPILER_WORKS 1)
|
||||||
|
|
||||||
|
# Set OpenMP path for macOS
|
||||||
|
if(APPLE)
|
||||||
|
set(OpenMP_C_FLAGS "-Xpreprocessor -fopenmp -I/opt/homebrew/opt/libomp/include")
|
||||||
|
set(OpenMP_CXX_FLAGS "-Xpreprocessor -fopenmp -I/opt/homebrew/opt/libomp/include")
|
||||||
|
set(OpenMP_C_LIB_NAMES "omp")
|
||||||
|
set(OpenMP_CXX_LIB_NAMES "omp")
|
||||||
|
set(OpenMP_omp_LIBRARY "/opt/homebrew/opt/libomp/lib/libomp.dylib")
|
||||||
|
|
||||||
|
# Force use of system libc++ to avoid version mismatch
|
||||||
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -stdlib=libc++")
|
||||||
|
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -stdlib=libc++")
|
||||||
|
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -stdlib=libc++")
|
||||||
|
|
||||||
|
# Set minimum macOS version for better compatibility
|
||||||
|
set(CMAKE_OSX_DEPLOYMENT_TARGET "11.0" CACHE STRING "Minimum macOS version")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# Use system ZeroMQ instead of building from source
|
||||||
|
find_package(PkgConfig REQUIRED)
|
||||||
|
pkg_check_modules(ZMQ REQUIRED libzmq)
|
||||||
|
|
||||||
|
# Add cppzmq headers
|
||||||
|
include_directories(third_party/cppzmq)
|
||||||
|
|
||||||
|
# Configure msgpack-c - disable boost dependency
|
||||||
|
set(MSGPACK_USE_BOOST OFF CACHE BOOL "" FORCE)
|
||||||
|
add_compile_definitions(MSGPACK_NO_BOOST)
|
||||||
|
include_directories(third_party/msgpack-c/include)
|
||||||
|
|
||||||
|
# Faiss configuration - streamlined build
|
||||||
set(FAISS_ENABLE_PYTHON ON CACHE BOOL "" FORCE)
|
set(FAISS_ENABLE_PYTHON ON CACHE BOOL "" FORCE)
|
||||||
set(FAISS_ENABLE_GPU OFF CACHE BOOL "" FORCE)
|
set(FAISS_ENABLE_GPU OFF CACHE BOOL "" FORCE)
|
||||||
set(FAISS_ENABLE_EXTRAS OFF CACHE BOOL "" FORCE)
|
set(FAISS_ENABLE_EXTRAS OFF CACHE BOOL "" FORCE)
|
||||||
@@ -9,4 +40,24 @@ set(BUILD_TESTING OFF CACHE BOOL "" FORCE)
|
|||||||
set(FAISS_ENABLE_C_API OFF CACHE BOOL "" FORCE)
|
set(FAISS_ENABLE_C_API OFF CACHE BOOL "" FORCE)
|
||||||
set(FAISS_OPT_LEVEL "generic" CACHE STRING "" FORCE)
|
set(FAISS_OPT_LEVEL "generic" CACHE STRING "" FORCE)
|
||||||
|
|
||||||
|
# Disable additional SIMD versions to speed up compilation
|
||||||
|
set(FAISS_ENABLE_AVX2 OFF CACHE BOOL "" FORCE)
|
||||||
|
set(FAISS_ENABLE_AVX512 OFF CACHE BOOL "" FORCE)
|
||||||
|
|
||||||
|
# Additional optimization options from INSTALL.md
|
||||||
|
set(CMAKE_BUILD_TYPE "Release" CACHE STRING "" FORCE)
|
||||||
|
set(BUILD_SHARED_LIBS OFF CACHE BOOL "" FORCE) # Static library is faster to build
|
||||||
|
|
||||||
|
# Avoid building demos and benchmarks
|
||||||
|
set(BUILD_DEMOS OFF CACHE BOOL "" FORCE)
|
||||||
|
set(BUILD_BENCHS OFF CACHE BOOL "" FORCE)
|
||||||
|
|
||||||
|
# NEW: Tell Faiss to only build the generic version
|
||||||
|
set(FAISS_BUILD_GENERIC ON CACHE BOOL "" FORCE)
|
||||||
|
set(FAISS_BUILD_AVX2 OFF CACHE BOOL "" FORCE)
|
||||||
|
set(FAISS_BUILD_AVX512 OFF CACHE BOOL "" FORCE)
|
||||||
|
|
||||||
|
# IMPORTANT: Disable building AVX versions to speed up compilation
|
||||||
|
set(FAISS_BUILD_AVX_VERSIONS OFF CACHE BOOL "" FORCE)
|
||||||
|
|
||||||
add_subdirectory(third_party/faiss)
|
add_subdirectory(third_party/faiss)
|
||||||
@@ -1 +1 @@
|
|||||||
from . import hnsw_backend
|
from . import hnsw_backend as hnsw_backend
|
||||||
|
|||||||
@@ -1,87 +1,115 @@
|
|||||||
|
import argparse
|
||||||
|
import gc # Import garbage collector interface
|
||||||
|
import os
|
||||||
import struct
|
import struct
|
||||||
import sys
|
import sys
|
||||||
import numpy as np
|
|
||||||
import os
|
|
||||||
import argparse
|
|
||||||
import gc # Import garbage collector interface
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
# --- FourCCs (add more if needed) ---
|
# --- FourCCs (add more if needed) ---
|
||||||
INDEX_HNSW_FLAT_FOURCC = int.from_bytes(b'IHNf', 'little')
|
INDEX_HNSW_FLAT_FOURCC = int.from_bytes(b"IHNf", "little")
|
||||||
# Add other HNSW fourccs if you expect different storage types inside HNSW
|
# Add other HNSW fourccs if you expect different storage types inside HNSW
|
||||||
# INDEX_HNSW_PQ_FOURCC = int.from_bytes(b'IHNp', 'little')
|
# INDEX_HNSW_PQ_FOURCC = int.from_bytes(b'IHNp', 'little')
|
||||||
# INDEX_HNSW_SQ_FOURCC = int.from_bytes(b'IHNs', 'little')
|
# INDEX_HNSW_SQ_FOURCC = int.from_bytes(b'IHNs', 'little')
|
||||||
# INDEX_HNSW_CAGRA_FOURCC = int.from_bytes(b'IHNc', 'little') # Example
|
# INDEX_HNSW_CAGRA_FOURCC = int.from_bytes(b'IHNc', 'little') # Example
|
||||||
|
|
||||||
EXPECTED_HNSW_FOURCCS = {INDEX_HNSW_FLAT_FOURCC} # Modify if needed
|
EXPECTED_HNSW_FOURCCS = {INDEX_HNSW_FLAT_FOURCC} # Modify if needed
|
||||||
NULL_INDEX_FOURCC = int.from_bytes(b'null', 'little')
|
NULL_INDEX_FOURCC = int.from_bytes(b"null", "little")
|
||||||
|
|
||||||
# --- Helper functions for reading/writing binary data ---
|
# --- Helper functions for reading/writing binary data ---
|
||||||
|
|
||||||
|
|
||||||
def read_struct(f, fmt):
|
def read_struct(f, fmt):
|
||||||
"""Reads data according to the struct format."""
|
"""Reads data according to the struct format."""
|
||||||
size = struct.calcsize(fmt)
|
size = struct.calcsize(fmt)
|
||||||
data = f.read(size)
|
data = f.read(size)
|
||||||
if len(data) != size:
|
if len(data) != size:
|
||||||
raise EOFError(f"File ended unexpectedly reading struct fmt '{fmt}'. Expected {size} bytes, got {len(data)}.")
|
raise EOFError(
|
||||||
|
f"File ended unexpectedly reading struct fmt '{fmt}'. Expected {size} bytes, got {len(data)}."
|
||||||
|
)
|
||||||
return struct.unpack(fmt, data)[0]
|
return struct.unpack(fmt, data)[0]
|
||||||
|
|
||||||
|
|
||||||
def read_vector_raw(f, element_fmt_char):
|
def read_vector_raw(f, element_fmt_char):
|
||||||
"""Reads a vector (size followed by data), returns count and raw bytes."""
|
"""Reads a vector (size followed by data), returns count and raw bytes."""
|
||||||
count = -1 # Initialize count
|
count = -1 # Initialize count
|
||||||
total_bytes = -1 # Initialize total_bytes
|
total_bytes = -1 # Initialize total_bytes
|
||||||
try:
|
try:
|
||||||
count = read_struct(f, '<Q') # size_t usually 64-bit unsigned
|
count = read_struct(f, "<Q") # size_t usually 64-bit unsigned
|
||||||
element_size = struct.calcsize(element_fmt_char)
|
element_size = struct.calcsize(element_fmt_char)
|
||||||
# --- FIX for MemoryError: Check for unreasonably large count ---
|
# --- FIX for MemoryError: Check for unreasonably large count ---
|
||||||
max_reasonable_count = 10 * (10**9) # ~10 billion elements limit
|
max_reasonable_count = 10 * (10**9) # ~10 billion elements limit
|
||||||
if count > max_reasonable_count or count < 0:
|
if count > max_reasonable_count or count < 0:
|
||||||
raise MemoryError(f"Vector count {count} seems unreasonably large, possibly due to file corruption or incorrect format read.")
|
raise MemoryError(
|
||||||
|
f"Vector count {count} seems unreasonably large, possibly due to file corruption or incorrect format read."
|
||||||
|
)
|
||||||
|
|
||||||
total_bytes = count * element_size
|
total_bytes = count * element_size
|
||||||
# --- FIX for MemoryError: Check for huge byte size before allocation ---
|
# --- FIX for MemoryError: Check for huge byte size before allocation ---
|
||||||
max_reasonable_bytes = 50 * (1024**3) # ~50 GB limit
|
max_reasonable_bytes = 50 * (1024**3) # ~50 GB limit
|
||||||
if total_bytes > max_reasonable_bytes or total_bytes < 0: # Check for overflow
|
if total_bytes > max_reasonable_bytes or total_bytes < 0: # Check for overflow
|
||||||
raise MemoryError(f"Attempting to read {total_bytes} bytes ({count} elements * {element_size} bytes/element), which exceeds the safety limit. File might be corrupted or format mismatch.")
|
raise MemoryError(
|
||||||
|
f"Attempting to read {total_bytes} bytes ({count} elements * {element_size} bytes/element), which exceeds the safety limit. File might be corrupted or format mismatch."
|
||||||
|
)
|
||||||
|
|
||||||
data_bytes = f.read(total_bytes)
|
data_bytes = f.read(total_bytes)
|
||||||
|
|
||||||
if len(data_bytes) != total_bytes:
|
if len(data_bytes) != total_bytes:
|
||||||
raise EOFError(f"File ended unexpectedly reading vector data. Expected {total_bytes} bytes, got {len(data_bytes)}.")
|
raise EOFError(
|
||||||
|
f"File ended unexpectedly reading vector data. Expected {total_bytes} bytes, got {len(data_bytes)}."
|
||||||
|
)
|
||||||
return count, data_bytes
|
return count, data_bytes
|
||||||
except (MemoryError, OverflowError) as e:
|
except (MemoryError, OverflowError) as e:
|
||||||
# Add context to the error message
|
# Add context to the error message
|
||||||
print(f"\nError during raw vector read (element_fmt='{element_fmt_char}', count={count}, total_bytes={total_bytes}): {e}", file=sys.stderr)
|
print(
|
||||||
raise e # Re-raise the original error type
|
f"\nError during raw vector read (element_fmt='{element_fmt_char}', count={count}, total_bytes={total_bytes}): {e}",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
raise e # Re-raise the original error type
|
||||||
|
|
||||||
|
|
||||||
def read_numpy_vector(f, np_dtype, struct_fmt_char):
|
def read_numpy_vector(f, np_dtype, struct_fmt_char):
|
||||||
"""Reads a vector into a NumPy array."""
|
"""Reads a vector into a NumPy array."""
|
||||||
count = -1 # Initialize count for robust error handling
|
count = -1 # Initialize count for robust error handling
|
||||||
print(f" Reading vector (dtype={np_dtype}, fmt='{struct_fmt_char}')... ", end='', flush=True)
|
print(
|
||||||
|
f" Reading vector (dtype={np_dtype}, fmt='{struct_fmt_char}')... ",
|
||||||
|
end="",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
count, data_bytes = read_vector_raw(f, struct_fmt_char)
|
count, data_bytes = read_vector_raw(f, struct_fmt_char)
|
||||||
print(f"Count={count}, Bytes={len(data_bytes)}")
|
print(f"Count={count}, Bytes={len(data_bytes)}")
|
||||||
if count > 0 and len(data_bytes) > 0:
|
if count > 0 and len(data_bytes) > 0:
|
||||||
arr = np.frombuffer(data_bytes, dtype=np_dtype)
|
arr = np.frombuffer(data_bytes, dtype=np_dtype)
|
||||||
if arr.size != count:
|
if arr.size != count:
|
||||||
raise ValueError(f"Inconsistent array size after reading. Expected {count}, got {arr.size}")
|
raise ValueError(
|
||||||
|
f"Inconsistent array size after reading. Expected {count}, got {arr.size}"
|
||||||
|
)
|
||||||
return arr
|
return arr
|
||||||
elif count == 0:
|
elif count == 0:
|
||||||
return np.array([], dtype=np_dtype)
|
return np.array([], dtype=np_dtype)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Read zero bytes but count > 0.")
|
raise ValueError("Read zero bytes but count > 0.")
|
||||||
except MemoryError as e:
|
except MemoryError as e:
|
||||||
# Now count should be defined (or -1 if error was in read_struct)
|
# Now count should be defined (or -1 if error was in read_struct)
|
||||||
print(f"\nMemoryError creating NumPy array (dtype={np_dtype}, count={count}). {e}", file=sys.stderr)
|
print(
|
||||||
|
f"\nMemoryError creating NumPy array (dtype={np_dtype}, count={count}). {e}",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
raise e
|
raise e
|
||||||
except Exception as e: # Catch other potential errors like ValueError
|
except Exception as e: # Catch other potential errors like ValueError
|
||||||
print(f"\nError reading numpy vector (dtype={np_dtype}, fmt='{struct_fmt_char}', count={count}): {e}", file=sys.stderr)
|
print(
|
||||||
|
f"\nError reading numpy vector (dtype={np_dtype}, fmt='{struct_fmt_char}', count={count}): {e}",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
def write_numpy_vector(f, arr, struct_fmt_char):
|
def write_numpy_vector(f, arr, struct_fmt_char):
|
||||||
"""Writes a NumPy array as a vector (size followed by data)."""
|
"""Writes a NumPy array as a vector (size followed by data)."""
|
||||||
count = arr.size
|
count = arr.size
|
||||||
f.write(struct.pack('<Q', count))
|
f.write(struct.pack("<Q", count))
|
||||||
try:
|
try:
|
||||||
expected_dtype = np.dtype(struct_fmt_char)
|
expected_dtype = np.dtype(struct_fmt_char)
|
||||||
if arr.dtype != expected_dtype:
|
if arr.dtype != expected_dtype:
|
||||||
@@ -89,23 +117,30 @@ def write_numpy_vector(f, arr, struct_fmt_char):
|
|||||||
else:
|
else:
|
||||||
data_to_write = arr.tobytes()
|
data_to_write = arr.tobytes()
|
||||||
f.write(data_to_write)
|
f.write(data_to_write)
|
||||||
del data_to_write # Hint GC
|
del data_to_write # Hint GC
|
||||||
except MemoryError as e:
|
except MemoryError as e:
|
||||||
print(f"\nMemoryError converting NumPy array to bytes for writing (size={count}, dtype={arr.dtype}). {e}", file=sys.stderr)
|
print(
|
||||||
raise e
|
f"\nMemoryError converting NumPy array to bytes for writing (size={count}, dtype={arr.dtype}). {e}",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
|
||||||
def write_list_vector(f, lst, struct_fmt_char):
|
def write_list_vector(f, lst, struct_fmt_char):
|
||||||
"""Writes a Python list as a vector iteratively."""
|
"""Writes a Python list as a vector iteratively."""
|
||||||
count = len(lst)
|
count = len(lst)
|
||||||
f.write(struct.pack('<Q', count))
|
f.write(struct.pack("<Q", count))
|
||||||
fmt = '<' + struct_fmt_char
|
fmt = "<" + struct_fmt_char
|
||||||
chunk_size = 1024 * 1024
|
chunk_size = 1024 * 1024
|
||||||
element_size = struct.calcsize(fmt)
|
element_size = struct.calcsize(fmt)
|
||||||
# Allocate buffer outside the loop if possible, or handle MemoryError during allocation
|
# Allocate buffer outside the loop if possible, or handle MemoryError during allocation
|
||||||
try:
|
try:
|
||||||
buffer = bytearray(chunk_size * element_size)
|
buffer = bytearray(chunk_size * element_size)
|
||||||
except MemoryError:
|
except MemoryError:
|
||||||
print(f"MemoryError: Cannot allocate buffer for writing list vector chunk (size {chunk_size * element_size} bytes).", file=sys.stderr)
|
print(
|
||||||
|
f"MemoryError: Cannot allocate buffer for writing list vector chunk (size {chunk_size * element_size} bytes).",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
raise
|
raise
|
||||||
buffer_count = 0
|
buffer_count = 0
|
||||||
|
|
||||||
@@ -116,65 +151,79 @@ def write_list_vector(f, lst, struct_fmt_char):
|
|||||||
buffer_count += 1
|
buffer_count += 1
|
||||||
|
|
||||||
if buffer_count == chunk_size or i == count - 1:
|
if buffer_count == chunk_size or i == count - 1:
|
||||||
f.write(buffer[:buffer_count * element_size])
|
f.write(buffer[: buffer_count * element_size])
|
||||||
buffer_count = 0
|
buffer_count = 0
|
||||||
|
|
||||||
except struct.error as e:
|
except struct.error as e:
|
||||||
print(f"\nStruct packing error for item {item} at index {i} with format '{fmt}'. {e}", file=sys.stderr)
|
print(
|
||||||
|
f"\nStruct packing error for item {item} at index {i} with format '{fmt}'. {e}",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
def get_cum_neighbors(cum_nneighbor_per_level_np, level):
|
def get_cum_neighbors(cum_nneighbor_per_level_np, level):
|
||||||
"""Helper to get cumulative neighbors count, matching C++ logic."""
|
"""Helper to get cumulative neighbors count, matching C++ logic."""
|
||||||
if level < 0: return 0
|
if level < 0:
|
||||||
|
return 0
|
||||||
if level < len(cum_nneighbor_per_level_np):
|
if level < len(cum_nneighbor_per_level_np):
|
||||||
return cum_nneighbor_per_level_np[level]
|
return cum_nneighbor_per_level_np[level]
|
||||||
else:
|
else:
|
||||||
return cum_nneighbor_per_level_np[-1] if len(cum_nneighbor_per_level_np) > 0 else 0
|
return cum_nneighbor_per_level_np[-1] if len(cum_nneighbor_per_level_np) > 0 else 0
|
||||||
|
|
||||||
def write_compact_format(f_out, original_hnsw_data, assign_probas_np, cum_nneighbor_per_level_np,
|
|
||||||
levels_np, compact_level_ptr, compact_node_offsets_np,
|
def write_compact_format(
|
||||||
compact_neighbors_data, storage_fourcc, storage_data):
|
f_out,
|
||||||
|
original_hnsw_data,
|
||||||
|
assign_probas_np,
|
||||||
|
cum_nneighbor_per_level_np,
|
||||||
|
levels_np,
|
||||||
|
compact_level_ptr,
|
||||||
|
compact_node_offsets_np,
|
||||||
|
compact_neighbors_data,
|
||||||
|
storage_fourcc,
|
||||||
|
storage_data,
|
||||||
|
):
|
||||||
"""Write HNSW data in compact format following C++ read order exactly."""
|
"""Write HNSW data in compact format following C++ read order exactly."""
|
||||||
# Write IndexHNSW Header
|
# Write IndexHNSW Header
|
||||||
f_out.write(struct.pack('<I', original_hnsw_data['index_fourcc']))
|
f_out.write(struct.pack("<I", original_hnsw_data["index_fourcc"]))
|
||||||
f_out.write(struct.pack('<i', original_hnsw_data['d']))
|
f_out.write(struct.pack("<i", original_hnsw_data["d"]))
|
||||||
f_out.write(struct.pack('<q', original_hnsw_data['ntotal']))
|
f_out.write(struct.pack("<q", original_hnsw_data["ntotal"]))
|
||||||
f_out.write(struct.pack('<q', original_hnsw_data['dummy1']))
|
f_out.write(struct.pack("<q", original_hnsw_data["dummy1"]))
|
||||||
f_out.write(struct.pack('<q', original_hnsw_data['dummy2']))
|
f_out.write(struct.pack("<q", original_hnsw_data["dummy2"]))
|
||||||
f_out.write(struct.pack('<?', original_hnsw_data['is_trained']))
|
f_out.write(struct.pack("<?", original_hnsw_data["is_trained"]))
|
||||||
f_out.write(struct.pack('<i', original_hnsw_data['metric_type']))
|
f_out.write(struct.pack("<i", original_hnsw_data["metric_type"]))
|
||||||
if original_hnsw_data['metric_type'] > 1:
|
if original_hnsw_data["metric_type"] > 1:
|
||||||
f_out.write(struct.pack('<f', original_hnsw_data['metric_arg']))
|
f_out.write(struct.pack("<f", original_hnsw_data["metric_arg"]))
|
||||||
|
|
||||||
# Write HNSW struct parts (standard order)
|
# Write HNSW struct parts (standard order)
|
||||||
write_numpy_vector(f_out, assign_probas_np, 'd')
|
write_numpy_vector(f_out, assign_probas_np, "d")
|
||||||
write_numpy_vector(f_out, cum_nneighbor_per_level_np, 'i')
|
write_numpy_vector(f_out, cum_nneighbor_per_level_np, "i")
|
||||||
write_numpy_vector(f_out, levels_np, 'i')
|
write_numpy_vector(f_out, levels_np, "i")
|
||||||
|
|
||||||
# Write compact format flag
|
# Write compact format flag
|
||||||
f_out.write(struct.pack('<?', True)) # storage_is_compact = True
|
f_out.write(struct.pack("<?", True)) # storage_is_compact = True
|
||||||
|
|
||||||
# Write compact data in CORRECT C++ read order: level_ptr, node_offsets FIRST
|
# Write compact data in CORRECT C++ read order: level_ptr, node_offsets FIRST
|
||||||
if isinstance(compact_level_ptr, np.ndarray):
|
if isinstance(compact_level_ptr, np.ndarray):
|
||||||
write_numpy_vector(f_out, compact_level_ptr, 'Q')
|
write_numpy_vector(f_out, compact_level_ptr, "Q")
|
||||||
else:
|
else:
|
||||||
write_list_vector(f_out, compact_level_ptr, 'Q')
|
write_list_vector(f_out, compact_level_ptr, "Q")
|
||||||
|
|
||||||
write_numpy_vector(f_out, compact_node_offsets_np, 'Q')
|
write_numpy_vector(f_out, compact_node_offsets_np, "Q")
|
||||||
|
|
||||||
# Write HNSW scalar parameters
|
# Write HNSW scalar parameters
|
||||||
f_out.write(struct.pack('<i', original_hnsw_data['entry_point']))
|
f_out.write(struct.pack("<i", original_hnsw_data["entry_point"]))
|
||||||
f_out.write(struct.pack('<i', original_hnsw_data['max_level']))
|
f_out.write(struct.pack("<i", original_hnsw_data["max_level"]))
|
||||||
f_out.write(struct.pack('<i', original_hnsw_data['efConstruction']))
|
f_out.write(struct.pack("<i", original_hnsw_data["efConstruction"]))
|
||||||
f_out.write(struct.pack('<i', original_hnsw_data['efSearch']))
|
f_out.write(struct.pack("<i", original_hnsw_data["efSearch"]))
|
||||||
f_out.write(struct.pack('<i', original_hnsw_data['dummy_upper_beam']))
|
f_out.write(struct.pack("<i", original_hnsw_data["dummy_upper_beam"]))
|
||||||
|
|
||||||
# Write storage fourcc (this determines how to read what follows)
|
# Write storage fourcc (this determines how to read what follows)
|
||||||
f_out.write(struct.pack('<I', storage_fourcc))
|
f_out.write(struct.pack("<I", storage_fourcc))
|
||||||
|
|
||||||
# Write compact neighbors data AFTER storage fourcc
|
# Write compact neighbors data AFTER storage fourcc
|
||||||
write_list_vector(f_out, compact_neighbors_data, 'i')
|
write_list_vector(f_out, compact_neighbors_data, "i")
|
||||||
|
|
||||||
# Write storage data if not NULL (only after neighbors)
|
# Write storage data if not NULL (only after neighbors)
|
||||||
if storage_fourcc != NULL_INDEX_FOURCC and storage_data:
|
if storage_fourcc != NULL_INDEX_FOURCC and storage_data:
|
||||||
@@ -183,6 +232,7 @@ def write_compact_format(f_out, original_hnsw_data, assign_probas_np, cum_nneigh
|
|||||||
|
|
||||||
# --- Main Conversion Logic ---
|
# --- Main Conversion Logic ---
|
||||||
|
|
||||||
|
|
||||||
def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=True):
|
def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=True):
|
||||||
"""
|
"""
|
||||||
Converts an HNSW graph file to the CSR format.
|
Converts an HNSW graph file to the CSR format.
|
||||||
@@ -196,91 +246,115 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
|
|||||||
print(f"Starting conversion: {input_filename} -> {output_filename}")
|
print(f"Starting conversion: {input_filename} -> {output_filename}")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
original_hnsw_data = {}
|
original_hnsw_data = {}
|
||||||
neighbors_np = None # Initialize to allow check in finally block
|
neighbors_np = None # Initialize to allow check in finally block
|
||||||
try:
|
try:
|
||||||
with open(input_filename, 'rb') as f_in, open(output_filename, 'wb') as f_out:
|
with open(input_filename, "rb") as f_in, open(output_filename, "wb") as f_out:
|
||||||
|
|
||||||
# --- Read IndexHNSW FourCC and Header ---
|
# --- Read IndexHNSW FourCC and Header ---
|
||||||
print(f"[{time.time() - start_time:.2f}s] Reading Index HNSW header...")
|
print(f"[{time.time() - start_time:.2f}s] Reading Index HNSW header...")
|
||||||
# ... (Keep the header reading logic as before) ...
|
# ... (Keep the header reading logic as before) ...
|
||||||
hnsw_index_fourcc = read_struct(f_in, '<I')
|
hnsw_index_fourcc = read_struct(f_in, "<I")
|
||||||
if hnsw_index_fourcc not in EXPECTED_HNSW_FOURCCS:
|
if hnsw_index_fourcc not in EXPECTED_HNSW_FOURCCS:
|
||||||
print(f"Error: Expected HNSW Index FourCC ({list(EXPECTED_HNSW_FOURCCS)}), got {hnsw_index_fourcc:08x}.", file=sys.stderr)
|
print(
|
||||||
return False
|
f"Error: Expected HNSW Index FourCC ({list(EXPECTED_HNSW_FOURCCS)}), got {hnsw_index_fourcc:08x}.",
|
||||||
original_hnsw_data['index_fourcc'] = hnsw_index_fourcc
|
file=sys.stderr,
|
||||||
original_hnsw_data['d'] = read_struct(f_in, '<i')
|
)
|
||||||
original_hnsw_data['ntotal'] = read_struct(f_in, '<q')
|
return False
|
||||||
original_hnsw_data['dummy1'] = read_struct(f_in, '<q')
|
original_hnsw_data["index_fourcc"] = hnsw_index_fourcc
|
||||||
original_hnsw_data['dummy2'] = read_struct(f_in, '<q')
|
original_hnsw_data["d"] = read_struct(f_in, "<i")
|
||||||
original_hnsw_data['is_trained'] = read_struct(f_in, '?')
|
original_hnsw_data["ntotal"] = read_struct(f_in, "<q")
|
||||||
original_hnsw_data['metric_type'] = read_struct(f_in, '<i')
|
original_hnsw_data["dummy1"] = read_struct(f_in, "<q")
|
||||||
original_hnsw_data['metric_arg'] = 0.0
|
original_hnsw_data["dummy2"] = read_struct(f_in, "<q")
|
||||||
if original_hnsw_data['metric_type'] > 1:
|
original_hnsw_data["is_trained"] = read_struct(f_in, "?")
|
||||||
original_hnsw_data['metric_arg'] = read_struct(f_in, '<f')
|
original_hnsw_data["metric_type"] = read_struct(f_in, "<i")
|
||||||
print(f"[{time.time() - start_time:.2f}s] Header read: d={original_hnsw_data['d']}, ntotal={original_hnsw_data['ntotal']}")
|
original_hnsw_data["metric_arg"] = 0.0
|
||||||
|
if original_hnsw_data["metric_type"] > 1:
|
||||||
|
original_hnsw_data["metric_arg"] = read_struct(f_in, "<f")
|
||||||
|
print(
|
||||||
|
f"[{time.time() - start_time:.2f}s] Header read: d={original_hnsw_data['d']}, ntotal={original_hnsw_data['ntotal']}"
|
||||||
|
)
|
||||||
|
|
||||||
# --- Read original HNSW struct data ---
|
# --- Read original HNSW struct data ---
|
||||||
print(f"[{time.time() - start_time:.2f}s] Reading HNSW struct vectors...")
|
print(f"[{time.time() - start_time:.2f}s] Reading HNSW struct vectors...")
|
||||||
assign_probas_np = read_numpy_vector(f_in, np.float64, 'd')
|
assign_probas_np = read_numpy_vector(f_in, np.float64, "d")
|
||||||
print(f"[{time.time() - start_time:.2f}s] Read assign_probas ({assign_probas_np.size})")
|
print(
|
||||||
|
f"[{time.time() - start_time:.2f}s] Read assign_probas ({assign_probas_np.size})"
|
||||||
|
)
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
cum_nneighbor_per_level_np = read_numpy_vector(f_in, np.int32, 'i')
|
cum_nneighbor_per_level_np = read_numpy_vector(f_in, np.int32, "i")
|
||||||
print(f"[{time.time() - start_time:.2f}s] Read cum_nneighbor_per_level ({cum_nneighbor_per_level_np.size})")
|
print(
|
||||||
|
f"[{time.time() - start_time:.2f}s] Read cum_nneighbor_per_level ({cum_nneighbor_per_level_np.size})"
|
||||||
|
)
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
levels_np = read_numpy_vector(f_in, np.int32, 'i')
|
levels_np = read_numpy_vector(f_in, np.int32, "i")
|
||||||
print(f"[{time.time() - start_time:.2f}s] Read levels ({levels_np.size})")
|
print(f"[{time.time() - start_time:.2f}s] Read levels ({levels_np.size})")
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
ntotal = len(levels_np)
|
ntotal = len(levels_np)
|
||||||
if ntotal != original_hnsw_data['ntotal']:
|
if ntotal != original_hnsw_data["ntotal"]:
|
||||||
print(f"Warning: ntotal mismatch! Header says {original_hnsw_data['ntotal']}, levels vector size is {ntotal}. Using levels vector size.", file=sys.stderr)
|
print(
|
||||||
original_hnsw_data['ntotal'] = ntotal
|
f"Warning: ntotal mismatch! Header says {original_hnsw_data['ntotal']}, levels vector size is {ntotal}. Using levels vector size.",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
original_hnsw_data["ntotal"] = ntotal
|
||||||
|
|
||||||
# --- Check for compact format flag ---
|
# --- Check for compact format flag ---
|
||||||
print(f"[{time.time() - start_time:.2f}s] Probing for compact storage flag...")
|
print(f"[{time.time() - start_time:.2f}s] Probing for compact storage flag...")
|
||||||
pos_before_compact = f_in.tell()
|
pos_before_compact = f_in.tell()
|
||||||
try:
|
try:
|
||||||
is_compact_flag = read_struct(f_in, '<?')
|
is_compact_flag = read_struct(f_in, "<?")
|
||||||
print(f"[{time.time() - start_time:.2f}s] Found compact flag: {is_compact_flag}")
|
print(f"[{time.time() - start_time:.2f}s] Found compact flag: {is_compact_flag}")
|
||||||
|
|
||||||
if is_compact_flag:
|
if is_compact_flag:
|
||||||
# Input is already in compact format - read compact data
|
# Input is already in compact format - read compact data
|
||||||
print(f"[{time.time() - start_time:.2f}s] Input is already in compact format, reading compact data...")
|
print(
|
||||||
|
f"[{time.time() - start_time:.2f}s] Input is already in compact format, reading compact data..."
|
||||||
|
)
|
||||||
|
|
||||||
compact_level_ptr = read_numpy_vector(f_in, np.uint64, 'Q')
|
compact_level_ptr = read_numpy_vector(f_in, np.uint64, "Q")
|
||||||
print(f"[{time.time() - start_time:.2f}s] Read compact_level_ptr ({compact_level_ptr.size})")
|
print(
|
||||||
|
f"[{time.time() - start_time:.2f}s] Read compact_level_ptr ({compact_level_ptr.size})"
|
||||||
|
)
|
||||||
|
|
||||||
compact_node_offsets_np = read_numpy_vector(f_in, np.uint64, 'Q')
|
compact_node_offsets_np = read_numpy_vector(f_in, np.uint64, "Q")
|
||||||
print(f"[{time.time() - start_time:.2f}s] Read compact_node_offsets ({compact_node_offsets_np.size})")
|
print(
|
||||||
|
f"[{time.time() - start_time:.2f}s] Read compact_node_offsets ({compact_node_offsets_np.size})"
|
||||||
|
)
|
||||||
|
|
||||||
# Read scalar parameters
|
# Read scalar parameters
|
||||||
original_hnsw_data['entry_point'] = read_struct(f_in, '<i')
|
original_hnsw_data["entry_point"] = read_struct(f_in, "<i")
|
||||||
original_hnsw_data['max_level'] = read_struct(f_in, '<i')
|
original_hnsw_data["max_level"] = read_struct(f_in, "<i")
|
||||||
original_hnsw_data['efConstruction'] = read_struct(f_in, '<i')
|
original_hnsw_data["efConstruction"] = read_struct(f_in, "<i")
|
||||||
original_hnsw_data['efSearch'] = read_struct(f_in, '<i')
|
original_hnsw_data["efSearch"] = read_struct(f_in, "<i")
|
||||||
original_hnsw_data['dummy_upper_beam'] = read_struct(f_in, '<i')
|
original_hnsw_data["dummy_upper_beam"] = read_struct(f_in, "<i")
|
||||||
print(f"[{time.time() - start_time:.2f}s] Read scalar params (ep={original_hnsw_data['entry_point']}, max_lvl={original_hnsw_data['max_level']})")
|
print(
|
||||||
|
f"[{time.time() - start_time:.2f}s] Read scalar params (ep={original_hnsw_data['entry_point']}, max_lvl={original_hnsw_data['max_level']})"
|
||||||
|
)
|
||||||
|
|
||||||
# Read storage fourcc
|
# Read storage fourcc
|
||||||
storage_fourcc = read_struct(f_in, '<I')
|
storage_fourcc = read_struct(f_in, "<I")
|
||||||
print(f"[{time.time() - start_time:.2f}s] Found storage fourcc: {storage_fourcc:08x}")
|
print(
|
||||||
|
f"[{time.time() - start_time:.2f}s] Found storage fourcc: {storage_fourcc:08x}"
|
||||||
|
)
|
||||||
|
|
||||||
if prune_embeddings and storage_fourcc != NULL_INDEX_FOURCC:
|
if prune_embeddings and storage_fourcc != NULL_INDEX_FOURCC:
|
||||||
# Read compact neighbors data
|
# Read compact neighbors data
|
||||||
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, 'i')
|
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, "i")
|
||||||
print(f"[{time.time() - start_time:.2f}s] Read compact neighbors data ({compact_neighbors_data_np.size})")
|
print(
|
||||||
|
f"[{time.time() - start_time:.2f}s] Read compact neighbors data ({compact_neighbors_data_np.size})"
|
||||||
|
)
|
||||||
compact_neighbors_data = compact_neighbors_data_np.tolist()
|
compact_neighbors_data = compact_neighbors_data_np.tolist()
|
||||||
del compact_neighbors_data_np
|
del compact_neighbors_data_np
|
||||||
|
|
||||||
# Skip storage data and write with NULL marker
|
# Skip storage data and write with NULL marker
|
||||||
print(f"[{time.time() - start_time:.2f}s] Pruning embeddings: Writing NULL storage marker.")
|
print(
|
||||||
|
f"[{time.time() - start_time:.2f}s] Pruning embeddings: Writing NULL storage marker."
|
||||||
|
)
|
||||||
storage_fourcc = NULL_INDEX_FOURCC
|
storage_fourcc = NULL_INDEX_FOURCC
|
||||||
elif not prune_embeddings:
|
elif not prune_embeddings:
|
||||||
# Read and preserve compact neighbors and storage
|
# Read and preserve compact neighbors and storage
|
||||||
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, 'i')
|
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, "i")
|
||||||
compact_neighbors_data = compact_neighbors_data_np.tolist()
|
compact_neighbors_data = compact_neighbors_data_np.tolist()
|
||||||
del compact_neighbors_data_np
|
del compact_neighbors_data_np
|
||||||
|
|
||||||
@@ -288,16 +362,25 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
|
|||||||
storage_data = f_in.read()
|
storage_data = f_in.read()
|
||||||
else:
|
else:
|
||||||
# Already pruned (NULL storage)
|
# Already pruned (NULL storage)
|
||||||
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, 'i')
|
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, "i")
|
||||||
compact_neighbors_data = compact_neighbors_data_np.tolist()
|
compact_neighbors_data = compact_neighbors_data_np.tolist()
|
||||||
del compact_neighbors_data_np
|
del compact_neighbors_data_np
|
||||||
storage_data = b''
|
storage_data = b""
|
||||||
|
|
||||||
# Write the updated compact format
|
# Write the updated compact format
|
||||||
print(f"[{time.time() - start_time:.2f}s] Writing updated compact format...")
|
print(f"[{time.time() - start_time:.2f}s] Writing updated compact format...")
|
||||||
write_compact_format(f_out, original_hnsw_data, assign_probas_np, cum_nneighbor_per_level_np,
|
write_compact_format(
|
||||||
levels_np, compact_level_ptr, compact_node_offsets_np,
|
f_out,
|
||||||
compact_neighbors_data, storage_fourcc, storage_data if not prune_embeddings else b'')
|
original_hnsw_data,
|
||||||
|
assign_probas_np,
|
||||||
|
cum_nneighbor_per_level_np,
|
||||||
|
levels_np,
|
||||||
|
compact_level_ptr,
|
||||||
|
compact_node_offsets_np,
|
||||||
|
compact_neighbors_data,
|
||||||
|
storage_fourcc,
|
||||||
|
storage_data if not prune_embeddings else b"",
|
||||||
|
)
|
||||||
|
|
||||||
print(f"[{time.time() - start_time:.2f}s] Conversion complete.")
|
print(f"[{time.time() - start_time:.2f}s] Conversion complete.")
|
||||||
return True
|
return True
|
||||||
@@ -305,63 +388,86 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
|
|||||||
else:
|
else:
|
||||||
# is_compact=False, rewind and read original format
|
# is_compact=False, rewind and read original format
|
||||||
f_in.seek(pos_before_compact)
|
f_in.seek(pos_before_compact)
|
||||||
print(f"[{time.time() - start_time:.2f}s] Compact flag is False, reading original format...")
|
print(
|
||||||
|
f"[{time.time() - start_time:.2f}s] Compact flag is False, reading original format..."
|
||||||
|
)
|
||||||
|
|
||||||
except EOFError:
|
except EOFError:
|
||||||
# No compact flag found, assume original format
|
# No compact flag found, assume original format
|
||||||
f_in.seek(pos_before_compact)
|
f_in.seek(pos_before_compact)
|
||||||
print(f"[{time.time() - start_time:.2f}s] No compact flag found, assuming original format...")
|
print(
|
||||||
|
f"[{time.time() - start_time:.2f}s] No compact flag found, assuming original format..."
|
||||||
|
)
|
||||||
|
|
||||||
# --- Handle potential extra byte in original format (like C++ code) ---
|
# --- Handle potential extra byte in original format (like C++ code) ---
|
||||||
print(f"[{time.time() - start_time:.2f}s] Probing for potential extra byte before non-compact offsets...")
|
print(
|
||||||
|
f"[{time.time() - start_time:.2f}s] Probing for potential extra byte before non-compact offsets..."
|
||||||
|
)
|
||||||
pos_before_probe = f_in.tell()
|
pos_before_probe = f_in.tell()
|
||||||
try:
|
try:
|
||||||
suspected_flag = read_struct(f_in, '<B') # Read 1 byte
|
suspected_flag = read_struct(f_in, "<B") # Read 1 byte
|
||||||
if suspected_flag == 0x00:
|
if suspected_flag == 0x00:
|
||||||
print(f"[{time.time() - start_time:.2f}s] Found and consumed an unexpected 0x00 byte.")
|
print(
|
||||||
|
f"[{time.time() - start_time:.2f}s] Found and consumed an unexpected 0x00 byte."
|
||||||
|
)
|
||||||
elif suspected_flag == 0x01:
|
elif suspected_flag == 0x01:
|
||||||
print(f"[{time.time() - start_time:.2f}s] ERROR: Found 0x01 but is_compact should be False")
|
print(
|
||||||
|
f"[{time.time() - start_time:.2f}s] ERROR: Found 0x01 but is_compact should be False"
|
||||||
|
)
|
||||||
raise ValueError("Inconsistent compact flag state")
|
raise ValueError("Inconsistent compact flag state")
|
||||||
else:
|
else:
|
||||||
# Rewind - this byte is part of offsets data
|
# Rewind - this byte is part of offsets data
|
||||||
f_in.seek(pos_before_probe)
|
f_in.seek(pos_before_probe)
|
||||||
print(f"[{time.time() - start_time:.2f}s] Rewound to original position (byte was 0x{suspected_flag:02x})")
|
print(
|
||||||
|
f"[{time.time() - start_time:.2f}s] Rewound to original position (byte was 0x{suspected_flag:02x})"
|
||||||
|
)
|
||||||
except EOFError:
|
except EOFError:
|
||||||
f_in.seek(pos_before_probe)
|
f_in.seek(pos_before_probe)
|
||||||
print(f"[{time.time() - start_time:.2f}s] No extra byte found (EOF), proceeding with offsets read")
|
print(
|
||||||
|
f"[{time.time() - start_time:.2f}s] No extra byte found (EOF), proceeding with offsets read"
|
||||||
|
)
|
||||||
|
|
||||||
# --- Read original format data ---
|
# --- Read original format data ---
|
||||||
offsets_np = read_numpy_vector(f_in, np.uint64, 'Q')
|
offsets_np = read_numpy_vector(f_in, np.uint64, "Q")
|
||||||
print(f"[{time.time() - start_time:.2f}s] Read offsets ({offsets_np.size})")
|
print(f"[{time.time() - start_time:.2f}s] Read offsets ({offsets_np.size})")
|
||||||
if len(offsets_np) != ntotal + 1:
|
if len(offsets_np) != ntotal + 1:
|
||||||
raise ValueError(f"Inconsistent offsets size: len(levels)={ntotal} but len(offsets)={len(offsets_np)}")
|
raise ValueError(
|
||||||
|
f"Inconsistent offsets size: len(levels)={ntotal} but len(offsets)={len(offsets_np)}"
|
||||||
|
)
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
print(f"[{time.time() - start_time:.2f}s] Attempting to read neighbors vector...")
|
print(f"[{time.time() - start_time:.2f}s] Attempting to read neighbors vector...")
|
||||||
neighbors_np = read_numpy_vector(f_in, np.int32, 'i')
|
neighbors_np = read_numpy_vector(f_in, np.int32, "i")
|
||||||
print(f"[{time.time() - start_time:.2f}s] Read neighbors ({neighbors_np.size})")
|
print(f"[{time.time() - start_time:.2f}s] Read neighbors ({neighbors_np.size})")
|
||||||
expected_neighbors_size = offsets_np[-1] if ntotal > 0 else 0
|
expected_neighbors_size = offsets_np[-1] if ntotal > 0 else 0
|
||||||
if neighbors_np.size != expected_neighbors_size:
|
if neighbors_np.size != expected_neighbors_size:
|
||||||
print(f"Warning: neighbors vector size mismatch. Expected {expected_neighbors_size} based on offsets, got {neighbors_np.size}.")
|
print(
|
||||||
|
f"Warning: neighbors vector size mismatch. Expected {expected_neighbors_size} based on offsets, got {neighbors_np.size}."
|
||||||
|
)
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
original_hnsw_data['entry_point'] = read_struct(f_in, '<i')
|
original_hnsw_data["entry_point"] = read_struct(f_in, "<i")
|
||||||
original_hnsw_data['max_level'] = read_struct(f_in, '<i')
|
original_hnsw_data["max_level"] = read_struct(f_in, "<i")
|
||||||
original_hnsw_data['efConstruction'] = read_struct(f_in, '<i')
|
original_hnsw_data["efConstruction"] = read_struct(f_in, "<i")
|
||||||
original_hnsw_data['efSearch'] = read_struct(f_in, '<i')
|
original_hnsw_data["efSearch"] = read_struct(f_in, "<i")
|
||||||
original_hnsw_data['dummy_upper_beam'] = read_struct(f_in, '<i')
|
original_hnsw_data["dummy_upper_beam"] = read_struct(f_in, "<i")
|
||||||
print(f"[{time.time() - start_time:.2f}s] Read scalar params (ep={original_hnsw_data['entry_point']}, max_lvl={original_hnsw_data['max_level']})")
|
print(
|
||||||
|
f"[{time.time() - start_time:.2f}s] Read scalar params (ep={original_hnsw_data['entry_point']}, max_lvl={original_hnsw_data['max_level']})"
|
||||||
|
)
|
||||||
|
|
||||||
print(f"[{time.time() - start_time:.2f}s] Checking for storage data...")
|
print(f"[{time.time() - start_time:.2f}s] Checking for storage data...")
|
||||||
storage_fourcc = None
|
storage_fourcc = None
|
||||||
try:
|
try:
|
||||||
storage_fourcc = read_struct(f_in, '<I')
|
storage_fourcc = read_struct(f_in, "<I")
|
||||||
print(f"[{time.time() - start_time:.2f}s] Found storage fourcc: {storage_fourcc:08x}.")
|
print(
|
||||||
|
f"[{time.time() - start_time:.2f}s] Found storage fourcc: {storage_fourcc:08x}."
|
||||||
|
)
|
||||||
except EOFError:
|
except EOFError:
|
||||||
print(f"[{time.time() - start_time:.2f}s] No storage data found (EOF).")
|
print(f"[{time.time() - start_time:.2f}s] No storage data found (EOF).")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[{time.time() - start_time:.2f}s] Error reading potential storage data: {e}")
|
print(
|
||||||
|
f"[{time.time() - start_time:.2f}s] Error reading potential storage data: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
# --- Perform Conversion ---
|
# --- Perform Conversion ---
|
||||||
print(f"[{time.time() - start_time:.2f}s] Converting to CSR format...")
|
print(f"[{time.time() - start_time:.2f}s] Converting to CSR format...")
|
||||||
@@ -373,17 +479,21 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
|
|||||||
|
|
||||||
current_level_ptr_idx = 0
|
current_level_ptr_idx = 0
|
||||||
current_data_idx = 0
|
current_data_idx = 0
|
||||||
total_valid_neighbors_counted = 0 # For validation
|
total_valid_neighbors_counted = 0 # For validation
|
||||||
|
|
||||||
# Optimize calculation by getting slices once per node if possible
|
# Optimize calculation by getting slices once per node if possible
|
||||||
for i in range(ntotal):
|
for i in range(ntotal):
|
||||||
if i > 0 and i % (ntotal // 100 or 1) == 0: # Log progress roughly every 1%
|
if i > 0 and i % (ntotal // 100 or 1) == 0: # Log progress roughly every 1%
|
||||||
progress = (i / ntotal) * 100
|
progress = (i / ntotal) * 100
|
||||||
elapsed = time.time() - start_time
|
elapsed = time.time() - start_time
|
||||||
print(f"\r[{elapsed:.2f}s] Converting node {i}/{ntotal} ({progress:.1f}%)...", end="")
|
print(
|
||||||
|
f"\r[{elapsed:.2f}s] Converting node {i}/{ntotal} ({progress:.1f}%)...",
|
||||||
|
end="",
|
||||||
|
)
|
||||||
|
|
||||||
node_max_level = levels_np[i] - 1
|
node_max_level = levels_np[i] - 1
|
||||||
if node_max_level < -1: node_max_level = -1
|
if node_max_level < -1:
|
||||||
|
node_max_level = -1
|
||||||
|
|
||||||
node_ptr_start_index = current_level_ptr_idx
|
node_ptr_start_index = current_level_ptr_idx
|
||||||
compact_node_offsets_np[i] = node_ptr_start_index
|
compact_node_offsets_np[i] = node_ptr_start_index
|
||||||
@@ -394,13 +504,17 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
|
|||||||
for level in range(node_max_level + 1):
|
for level in range(node_max_level + 1):
|
||||||
compact_level_ptr.append(current_data_idx)
|
compact_level_ptr.append(current_data_idx)
|
||||||
|
|
||||||
begin_orig_np = original_offset_start + get_cum_neighbors(cum_nneighbor_per_level_np, level)
|
begin_orig_np = original_offset_start + get_cum_neighbors(
|
||||||
end_orig_np = original_offset_start + get_cum_neighbors(cum_nneighbor_per_level_np, level + 1)
|
cum_nneighbor_per_level_np, level
|
||||||
|
)
|
||||||
|
end_orig_np = original_offset_start + get_cum_neighbors(
|
||||||
|
cum_nneighbor_per_level_np, level + 1
|
||||||
|
)
|
||||||
|
|
||||||
begin_orig = int(begin_orig_np)
|
begin_orig = int(begin_orig_np)
|
||||||
end_orig = int(end_orig_np)
|
end_orig = int(end_orig_np)
|
||||||
|
|
||||||
neighbors_len = len(neighbors_np) # Cache length
|
neighbors_len = len(neighbors_np) # Cache length
|
||||||
begin_orig = min(max(0, begin_orig), neighbors_len)
|
begin_orig = min(max(0, begin_orig), neighbors_len)
|
||||||
end_orig = min(max(begin_orig, end_orig), neighbors_len)
|
end_orig = min(max(begin_orig, end_orig), neighbors_len)
|
||||||
|
|
||||||
@@ -413,71 +527,116 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
|
|||||||
|
|
||||||
if num_valid > 0:
|
if num_valid > 0:
|
||||||
# Append valid neighbors
|
# Append valid neighbors
|
||||||
compact_neighbors_data.extend(level_neighbors_slice[valid_neighbors_mask])
|
compact_neighbors_data.extend(
|
||||||
|
level_neighbors_slice[valid_neighbors_mask]
|
||||||
|
)
|
||||||
current_data_idx += num_valid
|
current_data_idx += num_valid
|
||||||
total_valid_neighbors_counted += num_valid
|
total_valid_neighbors_counted += num_valid
|
||||||
|
|
||||||
|
|
||||||
compact_level_ptr.append(current_data_idx)
|
compact_level_ptr.append(current_data_idx)
|
||||||
current_level_ptr_idx += num_pointers_expected
|
current_level_ptr_idx += num_pointers_expected
|
||||||
|
|
||||||
compact_node_offsets_np[ntotal] = current_level_ptr_idx
|
compact_node_offsets_np[ntotal] = current_level_ptr_idx
|
||||||
print(f"\r[{time.time() - start_time:.2f}s] Conversion loop finished. ") # Clear progress line
|
print(
|
||||||
|
f"\r[{time.time() - start_time:.2f}s] Conversion loop finished. "
|
||||||
|
) # Clear progress line
|
||||||
|
|
||||||
# --- Validation Checks ---
|
# --- Validation Checks ---
|
||||||
print(f"[{time.time() - start_time:.2f}s] Running validation checks...")
|
print(f"[{time.time() - start_time:.2f}s] Running validation checks...")
|
||||||
valid_check_passed = True
|
valid_check_passed = True
|
||||||
# Check 1: Total valid neighbors count
|
# Check 1: Total valid neighbors count
|
||||||
print(f" Checking total valid neighbor count...")
|
print(" Checking total valid neighbor count...")
|
||||||
expected_valid_count = np.sum(neighbors_np >= 0)
|
expected_valid_count = np.sum(neighbors_np >= 0)
|
||||||
if total_valid_neighbors_counted != len(compact_neighbors_data):
|
if total_valid_neighbors_counted != len(compact_neighbors_data):
|
||||||
print(f"Error: Mismatch between counted valid neighbors ({total_valid_neighbors_counted}) and final compact_data size ({len(compact_neighbors_data)})!", file=sys.stderr)
|
print(
|
||||||
valid_check_passed = False
|
f"Error: Mismatch between counted valid neighbors ({total_valid_neighbors_counted}) and final compact_data size ({len(compact_neighbors_data)})!",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
valid_check_passed = False
|
||||||
if expected_valid_count != len(compact_neighbors_data):
|
if expected_valid_count != len(compact_neighbors_data):
|
||||||
print(f"Error: Mismatch between NumPy count of valid neighbors ({expected_valid_count}) and final compact_data size ({len(compact_neighbors_data)})!", file=sys.stderr)
|
print(
|
||||||
valid_check_passed = False
|
f"Error: Mismatch between NumPy count of valid neighbors ({expected_valid_count}) and final compact_data size ({len(compact_neighbors_data)})!",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
valid_check_passed = False
|
||||||
else:
|
else:
|
||||||
print(f" OK: Total valid neighbors = {len(compact_neighbors_data)}")
|
print(f" OK: Total valid neighbors = {len(compact_neighbors_data)}")
|
||||||
|
|
||||||
# Check 2: Final pointer indices consistency
|
# Check 2: Final pointer indices consistency
|
||||||
print(f" Checking final pointer indices...")
|
print(" Checking final pointer indices...")
|
||||||
if compact_node_offsets_np[ntotal] != len(compact_level_ptr):
|
if compact_node_offsets_np[ntotal] != len(compact_level_ptr):
|
||||||
print(f"Error: Final node offset ({compact_node_offsets_np[ntotal]}) doesn't match level_ptr size ({len(compact_level_ptr)})!", file=sys.stderr)
|
print(
|
||||||
valid_check_passed = False
|
f"Error: Final node offset ({compact_node_offsets_np[ntotal]}) doesn't match level_ptr size ({len(compact_level_ptr)})!",
|
||||||
if (len(compact_level_ptr) > 0 and compact_level_ptr[-1] != len(compact_neighbors_data)) or \
|
file=sys.stderr,
|
||||||
(len(compact_level_ptr) == 0 and len(compact_neighbors_data) != 0):
|
)
|
||||||
last_ptr = compact_level_ptr[-1] if len(compact_level_ptr) > 0 else -1
|
valid_check_passed = False
|
||||||
print(f"Error: Last level pointer ({last_ptr}) doesn't match compact_data size ({len(compact_neighbors_data)})!", file=sys.stderr)
|
if (
|
||||||
valid_check_passed = False
|
len(compact_level_ptr) > 0 and compact_level_ptr[-1] != len(compact_neighbors_data)
|
||||||
|
) or (len(compact_level_ptr) == 0 and len(compact_neighbors_data) != 0):
|
||||||
|
last_ptr = compact_level_ptr[-1] if len(compact_level_ptr) > 0 else -1
|
||||||
|
print(
|
||||||
|
f"Error: Last level pointer ({last_ptr}) doesn't match compact_data size ({len(compact_neighbors_data)})!",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
valid_check_passed = False
|
||||||
else:
|
else:
|
||||||
print(f" OK: Final pointers match data size.")
|
print(" OK: Final pointers match data size.")
|
||||||
|
|
||||||
if not valid_check_passed:
|
if not valid_check_passed:
|
||||||
print("Error: Validation checks failed. Output file might be incorrect.", file=sys.stderr)
|
print(
|
||||||
|
"Error: Validation checks failed. Output file might be incorrect.",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
# Optional: Exit here if validation fails
|
# Optional: Exit here if validation fails
|
||||||
# return False
|
# return False
|
||||||
|
|
||||||
# --- Explicitly delete large intermediate arrays ---
|
# --- Explicitly delete large intermediate arrays ---
|
||||||
print(f"[{time.time() - start_time:.2f}s] Deleting original neighbors and offsets arrays...")
|
print(
|
||||||
|
f"[{time.time() - start_time:.2f}s] Deleting original neighbors and offsets arrays..."
|
||||||
|
)
|
||||||
del neighbors_np
|
del neighbors_np
|
||||||
del offsets_np
|
del offsets_np
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
print(f" CSR Stats: |data|={len(compact_neighbors_data)}, |level_ptr|={len(compact_level_ptr)}")
|
print(
|
||||||
|
f" CSR Stats: |data|={len(compact_neighbors_data)}, |level_ptr|={len(compact_level_ptr)}"
|
||||||
|
)
|
||||||
|
|
||||||
# --- Write CSR HNSW graph data using unified function ---
|
# --- Write CSR HNSW graph data using unified function ---
|
||||||
print(f"[{time.time() - start_time:.2f}s] Writing CSR HNSW graph data in FAISS-compatible order...")
|
print(
|
||||||
|
f"[{time.time() - start_time:.2f}s] Writing CSR HNSW graph data in FAISS-compatible order..."
|
||||||
|
)
|
||||||
|
|
||||||
# Determine storage fourcc based on prune_embeddings
|
# Determine storage fourcc and data based on prune_embeddings
|
||||||
output_storage_fourcc = NULL_INDEX_FOURCC if prune_embeddings else (storage_fourcc if 'storage_fourcc' in locals() else NULL_INDEX_FOURCC)
|
|
||||||
if prune_embeddings:
|
if prune_embeddings:
|
||||||
print(f" Pruning embeddings: Writing NULL storage marker.")
|
print(" Pruning embeddings: Writing NULL storage marker.")
|
||||||
storage_data = b''
|
output_storage_fourcc = NULL_INDEX_FOURCC
|
||||||
|
storage_data = b""
|
||||||
|
else:
|
||||||
|
# Keep embeddings - read and preserve original storage data
|
||||||
|
if storage_fourcc and storage_fourcc != NULL_INDEX_FOURCC:
|
||||||
|
print(" Preserving embeddings: Reading original storage data...")
|
||||||
|
storage_data = f_in.read() # Read remaining storage data
|
||||||
|
output_storage_fourcc = storage_fourcc
|
||||||
|
print(f" Read {len(storage_data)} bytes of storage data")
|
||||||
|
else:
|
||||||
|
print(" No embeddings found in original file (NULL storage)")
|
||||||
|
output_storage_fourcc = NULL_INDEX_FOURCC
|
||||||
|
storage_data = b""
|
||||||
|
|
||||||
# Use the unified write function
|
# Use the unified write function
|
||||||
write_compact_format(f_out, original_hnsw_data, assign_probas_np, cum_nneighbor_per_level_np,
|
write_compact_format(
|
||||||
levels_np, compact_level_ptr, compact_node_offsets_np,
|
f_out,
|
||||||
compact_neighbors_data, output_storage_fourcc, storage_data if not prune_embeddings else b'')
|
original_hnsw_data,
|
||||||
|
assign_probas_np,
|
||||||
|
cum_nneighbor_per_level_np,
|
||||||
|
levels_np,
|
||||||
|
compact_level_ptr,
|
||||||
|
compact_node_offsets_np,
|
||||||
|
compact_neighbors_data,
|
||||||
|
output_storage_fourcc,
|
||||||
|
storage_data,
|
||||||
|
)
|
||||||
|
|
||||||
# Clean up memory
|
# Clean up memory
|
||||||
del assign_probas_np, cum_nneighbor_per_level_np, levels_np
|
del assign_probas_np, cum_nneighbor_per_level_np, levels_np
|
||||||
@@ -492,40 +651,66 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
|
|||||||
print(f"Error: Input file not found: {input_filename}", file=sys.stderr)
|
print(f"Error: Input file not found: {input_filename}", file=sys.stderr)
|
||||||
return False
|
return False
|
||||||
except MemoryError as e:
|
except MemoryError as e:
|
||||||
print(f"\nFatal MemoryError during conversion: {e}. Insufficient RAM.", file=sys.stderr)
|
print(
|
||||||
# Clean up potentially partially written output file?
|
f"\nFatal MemoryError during conversion: {e}. Insufficient RAM.",
|
||||||
try: os.remove(output_filename)
|
file=sys.stderr,
|
||||||
except OSError: pass
|
)
|
||||||
return False
|
# Clean up potentially partially written output file?
|
||||||
|
try:
|
||||||
|
os.remove(output_filename)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
return False
|
||||||
except EOFError as e:
|
except EOFError as e:
|
||||||
print(f"Error: Reached end of file unexpectedly reading {input_filename}. {e}", file=sys.stderr)
|
print(
|
||||||
try: os.remove(output_filename)
|
f"Error: Reached end of file unexpectedly reading {input_filename}. {e}",
|
||||||
except OSError: pass
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
os.remove(output_filename)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
return False
|
return False
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"An unexpected error occurred during conversion: {e}", file=sys.stderr)
|
print(f"An unexpected error occurred during conversion: {e}", file=sys.stderr)
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
try:
|
try:
|
||||||
os.remove(output_filename)
|
os.remove(output_filename)
|
||||||
except OSError: pass
|
except OSError:
|
||||||
|
pass
|
||||||
return False
|
return False
|
||||||
# Ensure neighbors_np is deleted even if an error occurs after its allocation
|
# Ensure neighbors_np is deleted even if an error occurs after its allocation
|
||||||
finally:
|
finally:
|
||||||
if 'neighbors_np' in locals() and neighbors_np is not None:
|
try:
|
||||||
del neighbors_np
|
if "neighbors_np" in locals() and neighbors_np is not None:
|
||||||
gc.collect()
|
del neighbors_np
|
||||||
|
gc.collect()
|
||||||
|
except NameError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
# --- Script Execution ---
|
# --- Script Execution ---
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Convert a Faiss IndexHNSWFlat file to a CSR-based HNSW graph file.")
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Convert a Faiss IndexHNSWFlat file to a CSR-based HNSW graph file."
|
||||||
|
)
|
||||||
parser.add_argument("input_index_file", help="Path to the input IndexHNSWFlat file")
|
parser.add_argument("input_index_file", help="Path to the input IndexHNSWFlat file")
|
||||||
parser.add_argument("output_csr_graph_file", help="Path to write the output CSR HNSW graph file")
|
parser.add_argument(
|
||||||
parser.add_argument("--prune-embeddings", action="store_true", default=True,
|
"output_csr_graph_file", help="Path to write the output CSR HNSW graph file"
|
||||||
help="Prune embedding storage (write NULL storage marker)")
|
)
|
||||||
parser.add_argument("--keep-embeddings", action="store_true",
|
parser.add_argument(
|
||||||
help="Keep embedding storage (overrides --prune-embeddings)")
|
"--prune-embeddings",
|
||||||
|
action="store_true",
|
||||||
|
default=True,
|
||||||
|
help="Prune embedding storage (write NULL storage marker)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--keep-embeddings",
|
||||||
|
action="store_true",
|
||||||
|
help="Keep embedding storage (overrides --prune-embeddings)",
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@@ -534,10 +719,12 @@ if __name__ == "__main__":
|
|||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
if os.path.abspath(args.input_index_file) == os.path.abspath(args.output_csr_graph_file):
|
if os.path.abspath(args.input_index_file) == os.path.abspath(args.output_csr_graph_file):
|
||||||
print(f"Error: Input and output filenames cannot be the same.", file=sys.stderr)
|
print("Error: Input and output filenames cannot be the same.", file=sys.stderr)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
prune_embeddings = args.prune_embeddings and not args.keep_embeddings
|
prune_embeddings = args.prune_embeddings and not args.keep_embeddings
|
||||||
success = convert_hnsw_graph_to_csr(args.input_index_file, args.output_csr_graph_file, prune_embeddings)
|
success = convert_hnsw_graph_to_csr(
|
||||||
|
args.input_index_file, args.output_csr_graph_file, prune_embeddings
|
||||||
|
)
|
||||||
if not success:
|
if not success:
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
@@ -1,145 +1,38 @@
|
|||||||
import numpy as np
|
import logging
|
||||||
import os
|
import os
|
||||||
import json
|
import shutil
|
||||||
import struct
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Any
|
from typing import Any, Literal
|
||||||
import contextlib
|
|
||||||
import threading
|
import numpy as np
|
||||||
import time
|
from leann.interface import (
|
||||||
import atexit
|
LeannBackendBuilderInterface,
|
||||||
import socket
|
LeannBackendFactoryInterface,
|
||||||
import subprocess
|
LeannBackendSearcherInterface,
|
||||||
import sys
|
)
|
||||||
|
from leann.registry import register_backend
|
||||||
|
from leann.searcher_base import BaseSearcher
|
||||||
|
|
||||||
from .convert_to_csr import convert_hnsw_graph_to_csr
|
from .convert_to_csr import convert_hnsw_graph_to_csr
|
||||||
|
|
||||||
from leann.registry import register_backend
|
logger = logging.getLogger(__name__)
|
||||||
from leann.interface import (
|
|
||||||
LeannBackendFactoryInterface,
|
|
||||||
LeannBackendBuilderInterface,
|
|
||||||
LeannBackendSearcherInterface
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_metric_map():
|
def get_metric_map():
|
||||||
from . import faiss
|
from . import faiss # type: ignore
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"mips": faiss.METRIC_INNER_PRODUCT,
|
"mips": faiss.METRIC_INNER_PRODUCT,
|
||||||
"l2": faiss.METRIC_L2,
|
"l2": faiss.METRIC_L2,
|
||||||
"cosine": faiss.METRIC_INNER_PRODUCT,
|
"cosine": faiss.METRIC_INNER_PRODUCT,
|
||||||
}
|
}
|
||||||
|
|
||||||
def _check_port(port: int) -> bool:
|
|
||||||
"""Check if a port is in use"""
|
|
||||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
|
||||||
return s.connect_ex(('localhost', port)) == 0
|
|
||||||
|
|
||||||
class HNSWEmbeddingServerManager:
|
def normalize_l2(data: np.ndarray) -> np.ndarray:
|
||||||
"""
|
norms = np.linalg.norm(data, axis=1, keepdims=True)
|
||||||
HNSW-specific embedding server manager that handles the lifecycle of the embedding server process.
|
norms[norms == 0] = 1 # Avoid division by zero
|
||||||
Mirrors the DiskANN EmbeddingServerManager architecture.
|
return data / norms
|
||||||
"""
|
|
||||||
def __init__(self):
|
|
||||||
self.server_process = None
|
|
||||||
self.server_port = None
|
|
||||||
atexit.register(self.stop_server)
|
|
||||||
|
|
||||||
def start_server(self, port=5556, model_name="sentence-transformers/all-mpnet-base-v2", passages_file=None, distance_metric="mips"):
|
|
||||||
"""
|
|
||||||
Start the HNSW embedding server process.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
port: ZMQ port for the server
|
|
||||||
model_name: Name of the embedding model to use
|
|
||||||
passages_file: Optional path to passages JSON file
|
|
||||||
distance_metric: The distance metric to use
|
|
||||||
"""
|
|
||||||
if self.server_process and self.server_process.poll() is None:
|
|
||||||
print(f"INFO: Reusing existing HNSW server process for this session (PID {self.server_process.pid})")
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Check if port is already in use
|
|
||||||
if _check_port(port):
|
|
||||||
print(f"WARNING: Port {port} is already in use. Assuming an external HNSW server is running and connecting to it.")
|
|
||||||
return True
|
|
||||||
|
|
||||||
print(f"INFO: Starting session-level HNSW embedding server as a background process...")
|
|
||||||
|
|
||||||
try:
|
|
||||||
command = [
|
|
||||||
sys.executable,
|
|
||||||
"-m", "leann_backend_hnsw.hnsw_embedding_server",
|
|
||||||
"--zmq-port", str(port),
|
|
||||||
"--model-name", model_name,
|
|
||||||
"--distance-metric", distance_metric
|
|
||||||
]
|
|
||||||
|
|
||||||
if passages_file:
|
|
||||||
command.extend(["--passages-file", str(passages_file)])
|
|
||||||
|
|
||||||
project_root = Path(__file__).parent.parent.parent.parent
|
|
||||||
print(f"INFO: Running HNSW command from project root: {project_root}")
|
|
||||||
|
|
||||||
self.server_process = subprocess.Popen(
|
|
||||||
command,
|
|
||||||
cwd=project_root,
|
|
||||||
stdout=subprocess.PIPE,
|
|
||||||
stderr=subprocess.PIPE,
|
|
||||||
text=True,
|
|
||||||
encoding='utf-8'
|
|
||||||
)
|
|
||||||
self.server_port = port
|
|
||||||
print(f"INFO: HNSW server process started with PID: {self.server_process.pid}")
|
|
||||||
|
|
||||||
max_wait, wait_interval = 30, 0.5
|
|
||||||
for _ in range(int(max_wait / wait_interval)):
|
|
||||||
if _check_port(port):
|
|
||||||
print(f"✅ HNSW embedding server is up and ready for this session.")
|
|
||||||
log_thread = threading.Thread(target=self._log_monitor, daemon=True)
|
|
||||||
log_thread.start()
|
|
||||||
return True
|
|
||||||
if self.server_process.poll() is not None:
|
|
||||||
print("❌ ERROR: HNSW server process terminated unexpectedly during startup.")
|
|
||||||
self._log_monitor()
|
|
||||||
return False
|
|
||||||
time.sleep(wait_interval)
|
|
||||||
|
|
||||||
print(f"❌ ERROR: HNSW server process failed to start listening within {max_wait} seconds.")
|
|
||||||
self.stop_server()
|
|
||||||
return False
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ ERROR: Failed to start HNSW embedding server process: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _log_monitor(self):
|
|
||||||
"""Monitor server logs"""
|
|
||||||
if not self.server_process:
|
|
||||||
return
|
|
||||||
try:
|
|
||||||
if self.server_process.stdout:
|
|
||||||
for line in iter(self.server_process.stdout.readline, ''):
|
|
||||||
print(f"[HNSWEmbeddingServer LOG]: {line.strip()}")
|
|
||||||
self.server_process.stdout.close()
|
|
||||||
if self.server_process.stderr:
|
|
||||||
for line in iter(self.server_process.stderr.readline, ''):
|
|
||||||
print(f"[HNSWEmbeddingServer ERROR]: {line.strip()}")
|
|
||||||
self.server_process.stderr.close()
|
|
||||||
except Exception as e:
|
|
||||||
print(f"HNSW Log monitor error: {e}")
|
|
||||||
|
|
||||||
def stop_server(self):
|
|
||||||
"""Stop the HNSW embedding server process"""
|
|
||||||
if self.server_process and self.server_process.poll() is None:
|
|
||||||
print(f"INFO: Terminating HNSW session server process (PID: {self.server_process.pid})...")
|
|
||||||
self.server_process.terminate()
|
|
||||||
try:
|
|
||||||
self.server_process.wait(timeout=5)
|
|
||||||
print("INFO: HNSW server process terminated.")
|
|
||||||
except subprocess.TimeoutExpired:
|
|
||||||
print("WARNING: HNSW server process did not terminate gracefully, killing it.")
|
|
||||||
self.server_process.kill()
|
|
||||||
self.server_process = None
|
|
||||||
|
|
||||||
@register_backend("hnsw")
|
@register_backend("hnsw")
|
||||||
class HNSWBackend(LeannBackendFactoryInterface):
|
class HNSWBackend(LeannBackendFactoryInterface):
|
||||||
@@ -149,372 +42,206 @@ class HNSWBackend(LeannBackendFactoryInterface):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def searcher(index_path: str, **kwargs) -> LeannBackendSearcherInterface:
|
def searcher(index_path: str, **kwargs) -> LeannBackendSearcherInterface:
|
||||||
path = Path(index_path)
|
|
||||||
meta_path = path.parent / f"{path.name}.meta.json"
|
|
||||||
if not meta_path.exists():
|
|
||||||
raise FileNotFoundError(f"Leann metadata file not found at {meta_path}. Cannot infer vector dimension for searcher.")
|
|
||||||
|
|
||||||
with open(meta_path, 'r') as f:
|
|
||||||
meta = json.load(f)
|
|
||||||
|
|
||||||
dimensions = meta.get("dimensions")
|
|
||||||
if not dimensions:
|
|
||||||
raise ValueError("Dimensions not found in Leann metadata. Please rebuild the index with a newer version of Leann.")
|
|
||||||
|
|
||||||
kwargs['dimensions'] = dimensions
|
|
||||||
return HNSWSearcher(index_path, **kwargs)
|
return HNSWSearcher(index_path, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class HNSWBuilder(LeannBackendBuilderInterface):
|
class HNSWBuilder(LeannBackendBuilderInterface):
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
self.build_params = kwargs.copy()
|
self.build_params = kwargs.copy()
|
||||||
|
|
||||||
# --- Configuration defaults with standardized names ---
|
|
||||||
self.is_compact = self.build_params.setdefault("is_compact", True)
|
self.is_compact = self.build_params.setdefault("is_compact", True)
|
||||||
self.is_recompute = self.build_params.setdefault("is_recompute", True)
|
self.is_recompute = self.build_params.setdefault("is_recompute", True)
|
||||||
|
|
||||||
# --- Additional Options ---
|
|
||||||
self.is_skip_neighbors = self.build_params.setdefault("is_skip_neighbors", False)
|
|
||||||
self.disk_cache_ratio = self.build_params.setdefault("disk_cache_ratio", 0.0)
|
|
||||||
self.external_storage_path = self.build_params.get("external_storage_path", None)
|
|
||||||
|
|
||||||
# --- Standard HNSW parameters ---
|
|
||||||
self.M = self.build_params.setdefault("M", 32)
|
self.M = self.build_params.setdefault("M", 32)
|
||||||
self.efConstruction = self.build_params.setdefault("efConstruction", 200)
|
self.efConstruction = self.build_params.setdefault("efConstruction", 200)
|
||||||
self.distance_metric = self.build_params.setdefault("distance_metric", "mips")
|
self.distance_metric = self.build_params.setdefault("distance_metric", "mips")
|
||||||
self.dimensions = self.build_params.get("dimensions")
|
self.dimensions = self.build_params.get("dimensions")
|
||||||
|
if not self.is_recompute:
|
||||||
|
if self.is_compact:
|
||||||
|
# TODO: support this case @andy
|
||||||
|
raise ValueError(
|
||||||
|
"is_recompute is False, but is_compact is True. This is not compatible now. change is compact to False and you can use the original HNSW index."
|
||||||
|
)
|
||||||
|
|
||||||
if self.is_skip_neighbors and not self.is_compact:
|
def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs):
|
||||||
raise ValueError("is_skip_neighbors can only be used with is_compact=True")
|
from . import faiss # type: ignore
|
||||||
|
|
||||||
if self.is_recompute and not self.is_compact:
|
|
||||||
raise ValueError("is_recompute requires is_compact=True for efficiency")
|
|
||||||
|
|
||||||
def build(self, data: np.ndarray, index_path: str, **kwargs):
|
|
||||||
"""Build HNSW index using FAISS"""
|
|
||||||
from . import faiss
|
|
||||||
|
|
||||||
path = Path(index_path)
|
path = Path(index_path)
|
||||||
index_dir = path.parent
|
index_dir = path.parent
|
||||||
index_prefix = path.stem
|
index_prefix = path.stem
|
||||||
|
|
||||||
index_dir.mkdir(parents=True, exist_ok=True)
|
index_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
if data.dtype != np.float32:
|
if data.dtype != np.float32:
|
||||||
|
logger.warning(f"Converting data to float32, shape: {data.shape}")
|
||||||
data = data.astype(np.float32)
|
data = data.astype(np.float32)
|
||||||
if not data.flags['C_CONTIGUOUS']:
|
|
||||||
data = np.ascontiguousarray(data)
|
|
||||||
|
|
||||||
metric_str = self.distance_metric.lower()
|
metric_enum = get_metric_map().get(self.distance_metric.lower())
|
||||||
metric_enum = get_metric_map().get(metric_str)
|
|
||||||
if metric_enum is None:
|
if metric_enum is None:
|
||||||
raise ValueError(f"Unsupported distance_metric '{metric_str}'.")
|
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
|
||||||
|
|
||||||
M = self.M
|
dim = self.dimensions or data.shape[1]
|
||||||
efConstruction = self.efConstruction
|
index = faiss.IndexHNSWFlat(dim, self.M, metric_enum)
|
||||||
dim = self.dimensions
|
index.hnsw.efConstruction = self.efConstruction
|
||||||
if not dim:
|
|
||||||
dim = data.shape[1]
|
|
||||||
|
|
||||||
print(f"INFO: Building HNSW index for {data.shape[0]} vectors with metric {metric_enum}...")
|
if self.distance_metric.lower() == "cosine":
|
||||||
|
data = normalize_l2(data)
|
||||||
|
|
||||||
try:
|
index.add(data.shape[0], faiss.swig_ptr(data))
|
||||||
index = faiss.IndexHNSWFlat(dim, M, metric_enum)
|
index_file = index_dir / f"{index_prefix}.index"
|
||||||
index.hnsw.efConstruction = efConstruction
|
faiss.write_index(index, str(index_file))
|
||||||
|
|
||||||
if metric_str == "cosine":
|
if self.is_compact:
|
||||||
faiss.normalize_L2(data)
|
self._convert_to_csr(index_file)
|
||||||
|
|
||||||
index.add(data.shape[0], faiss.swig_ptr(data))
|
|
||||||
|
|
||||||
index_file = index_dir / f"{index_prefix}.index"
|
|
||||||
faiss.write_index(index, str(index_file))
|
|
||||||
|
|
||||||
print(f"✅ HNSW index built successfully at '{index_file}'")
|
|
||||||
|
|
||||||
if self.is_compact:
|
|
||||||
self._convert_to_csr(index_file)
|
|
||||||
|
|
||||||
if self.is_recompute:
|
|
||||||
self._generate_passages_file(index_dir, index_prefix, **kwargs)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"💥 ERROR: HNSW index build failed. Exception: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def _convert_to_csr(self, index_file: Path):
|
def _convert_to_csr(self, index_file: Path):
|
||||||
"""Convert built index to CSR format"""
|
"""Convert built index to CSR format"""
|
||||||
try:
|
mode_str = "CSR-pruned" if self.is_recompute else "CSR-standard"
|
||||||
mode_str = "CSR-pruned" if self.is_recompute else "CSR-standard"
|
logger.info(f"INFO: Converting HNSW index to {mode_str} format...")
|
||||||
print(f"INFO: Converting HNSW index to {mode_str} format...")
|
|
||||||
|
|
||||||
csr_temp_file = index_file.with_suffix(".csr.tmp")
|
csr_temp_file = index_file.with_suffix(".csr.tmp")
|
||||||
|
|
||||||
success = convert_hnsw_graph_to_csr(
|
success = convert_hnsw_graph_to_csr(
|
||||||
str(index_file),
|
str(index_file), str(csr_temp_file), prune_embeddings=self.is_recompute
|
||||||
str(csr_temp_file),
|
)
|
||||||
prune_embeddings=self.is_recompute
|
|
||||||
)
|
|
||||||
|
|
||||||
if success:
|
if success:
|
||||||
print("✅ CSR conversion successful.")
|
logger.info("✅ CSR conversion successful.")
|
||||||
import shutil
|
# index_file_old = index_file.with_suffix(".old")
|
||||||
shutil.move(str(csr_temp_file), str(index_file))
|
# shutil.move(str(index_file), str(index_file_old))
|
||||||
print(f"INFO: Replaced original index with {mode_str} version at '{index_file}'")
|
shutil.move(str(csr_temp_file), str(index_file))
|
||||||
else:
|
logger.info(f"INFO: Replaced original index with {mode_str} version at '{index_file}'")
|
||||||
# Clean up and fail fast
|
else:
|
||||||
if csr_temp_file.exists():
|
# Clean up and fail fast
|
||||||
os.remove(csr_temp_file)
|
if csr_temp_file.exists():
|
||||||
raise RuntimeError("CSR conversion failed - cannot proceed with compact format")
|
os.remove(csr_temp_file)
|
||||||
|
raise RuntimeError("CSR conversion failed - cannot proceed with compact format")
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"💥 ERROR: CSR conversion failed. Exception: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def _generate_passages_file(self, index_dir: Path, index_prefix: str, **kwargs):
|
|
||||||
"""Generate passages file for recompute mode"""
|
|
||||||
try:
|
|
||||||
chunks = kwargs.get('chunks', [])
|
|
||||||
if not chunks:
|
|
||||||
print("INFO: No chunks data provided, skipping passages file generation")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Generate node_id to text mapping
|
|
||||||
passages_data = {}
|
|
||||||
for node_id, chunk in enumerate(chunks):
|
|
||||||
passages_data[str(node_id)] = chunk["text"]
|
|
||||||
|
|
||||||
# Save passages file
|
|
||||||
passages_file = index_dir / f"{index_prefix}.passages.json"
|
|
||||||
with open(passages_file, 'w', encoding='utf-8') as f:
|
|
||||||
json.dump(passages_data, f, ensure_ascii=False, indent=2)
|
|
||||||
|
|
||||||
print(f"✅ Generated passages file for recompute mode at '{passages_file}' ({len(passages_data)} passages)")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"💥 ERROR: Failed to generate passages file. Exception: {e}")
|
|
||||||
# Don't raise - this is not critical for index building
|
|
||||||
pass
|
|
||||||
|
|
||||||
class HNSWSearcher(LeannBackendSearcherInterface):
|
|
||||||
def _get_index_storage_status(self, index_file: Path) -> tuple[bool, bool]:
|
|
||||||
"""
|
|
||||||
Robustly determines the index's storage status by parsing the file.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A tuple (is_compact, is_pruned).
|
|
||||||
"""
|
|
||||||
if not index_file.exists():
|
|
||||||
return False, False
|
|
||||||
|
|
||||||
with open(index_file, 'rb') as f:
|
|
||||||
try:
|
|
||||||
def read_struct(fmt):
|
|
||||||
size = struct.calcsize(fmt)
|
|
||||||
data = f.read(size)
|
|
||||||
if len(data) != size:
|
|
||||||
raise EOFError(f"File ended unexpectedly reading struct fmt '{fmt}'.")
|
|
||||||
return struct.unpack(fmt, data)[0]
|
|
||||||
|
|
||||||
def skip_vector(element_size):
|
|
||||||
count = read_struct('<Q')
|
|
||||||
f.seek(count * element_size, 1)
|
|
||||||
|
|
||||||
# 1. Read up to the compact flag
|
|
||||||
read_struct('<I'); read_struct('<i'); read_struct('<q');
|
|
||||||
read_struct('<q'); read_struct('<q'); read_struct('<?')
|
|
||||||
metric_type = read_struct('<i')
|
|
||||||
if metric_type > 1: read_struct('<f')
|
|
||||||
skip_vector(8); skip_vector(4); skip_vector(4)
|
|
||||||
|
|
||||||
# 2. Check if there's a compact flag byte
|
|
||||||
# Try to read the compact flag, but handle both old and new formats
|
|
||||||
pos_before_compact = f.tell()
|
|
||||||
try:
|
|
||||||
is_compact = read_struct('<?')
|
|
||||||
print(f"INFO: Detected is_compact flag as: {is_compact}")
|
|
||||||
except (EOFError, struct.error):
|
|
||||||
# Old format without compact flag - assume non-compact
|
|
||||||
f.seek(pos_before_compact)
|
|
||||||
is_compact = False
|
|
||||||
print(f"INFO: No compact flag found, assuming is_compact=False")
|
|
||||||
|
|
||||||
# 3. Read storage FourCC to determine if pruned
|
|
||||||
is_pruned = False
|
|
||||||
try:
|
|
||||||
if is_compact:
|
|
||||||
# For compact, we need to skip pointers and scalars to get to the storage FourCC
|
|
||||||
skip_vector(8) # level_ptr
|
|
||||||
skip_vector(8) # node_offsets
|
|
||||||
read_struct('<i'); read_struct('<i'); read_struct('<i');
|
|
||||||
read_struct('<i'); read_struct('<i')
|
|
||||||
storage_fourcc = read_struct('<I')
|
|
||||||
else:
|
|
||||||
# For non-compact, we need to read the flag probe, then skip offsets and neighbors
|
|
||||||
pos_before_probe = f.tell()
|
|
||||||
flag_byte = f.read(1)
|
|
||||||
if not (flag_byte and flag_byte == b'\x00'):
|
|
||||||
f.seek(pos_before_probe)
|
|
||||||
skip_vector(8); skip_vector(4) # offsets, neighbors
|
|
||||||
read_struct('<i'); read_struct('<i'); read_struct('<i');
|
|
||||||
read_struct('<i'); read_struct('<i')
|
|
||||||
# Now we are at the storage. The entire rest is storage blob.
|
|
||||||
storage_fourcc = struct.unpack('<I', f.read(4))[0]
|
|
||||||
|
|
||||||
NULL_INDEX_FOURCC = int.from_bytes(b'null', 'little')
|
|
||||||
if storage_fourcc == NULL_INDEX_FOURCC:
|
|
||||||
is_pruned = True
|
|
||||||
except (EOFError, struct.error):
|
|
||||||
# Cannot determine pruning status, assume not pruned
|
|
||||||
pass
|
|
||||||
|
|
||||||
print(f"INFO: Detected is_pruned as: {is_pruned}")
|
|
||||||
return is_compact, is_pruned
|
|
||||||
|
|
||||||
except (EOFError, struct.error) as e:
|
|
||||||
print(f"WARNING: Could not parse index file to detect format: {e}. Assuming standard, not pruned.")
|
|
||||||
return False, False
|
|
||||||
|
|
||||||
|
class HNSWSearcher(BaseSearcher):
|
||||||
def __init__(self, index_path: str, **kwargs):
|
def __init__(self, index_path: str, **kwargs):
|
||||||
from . import faiss
|
super().__init__(
|
||||||
path = Path(index_path)
|
index_path,
|
||||||
index_dir = path.parent
|
backend_module_name="leann_backend_hnsw.hnsw_embedding_server",
|
||||||
index_prefix = path.stem
|
**kwargs,
|
||||||
|
)
|
||||||
|
from . import faiss # type: ignore
|
||||||
|
|
||||||
# Store configuration and paths for later use
|
self.distance_metric = (
|
||||||
self.config = kwargs.copy()
|
self.meta.get("backend_kwargs", {}).get("distance_metric", "mips").lower()
|
||||||
self.config["index_path"] = index_path
|
)
|
||||||
self.index_dir = index_dir
|
metric_enum = get_metric_map().get(self.distance_metric)
|
||||||
self.index_prefix = index_prefix
|
|
||||||
|
|
||||||
metric_str = self.config.get("distance_metric", "mips").lower()
|
|
||||||
metric_enum = get_metric_map().get(metric_str)
|
|
||||||
if metric_enum is None:
|
if metric_enum is None:
|
||||||
raise ValueError(f"Unsupported distance_metric '{metric_str}'.")
|
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
|
||||||
|
|
||||||
dimensions = self.config.get("dimensions")
|
self.is_compact, self.is_pruned = (
|
||||||
if not dimensions:
|
self.meta.get("is_compact", True),
|
||||||
raise ValueError("Vector dimension not provided to HNSWSearcher.")
|
self.meta.get("is_pruned", True),
|
||||||
|
)
|
||||||
|
|
||||||
index_file = index_dir / f"{index_prefix}.index"
|
index_file = self.index_dir / f"{self.index_path.stem}.index"
|
||||||
if not index_file.exists():
|
if not index_file.exists():
|
||||||
raise FileNotFoundError(f"HNSW index file not found at {index_file}")
|
raise FileNotFoundError(f"HNSW index file not found at {index_file}")
|
||||||
|
|
||||||
self.is_compact, self.is_pruned = self._get_index_storage_status(index_file)
|
|
||||||
|
|
||||||
# Validate configuration constraints
|
|
||||||
if not self.is_compact and self.config.get("is_skip_neighbors", False):
|
|
||||||
raise ValueError("is_skip_neighbors can only be used with is_compact=True")
|
|
||||||
|
|
||||||
if self.config.get("is_recompute", False) and self.config.get("external_storage_path"):
|
|
||||||
raise ValueError("Cannot use both is_recompute and external_storage_path simultaneously")
|
|
||||||
|
|
||||||
hnsw_config = faiss.HNSWIndexConfig()
|
hnsw_config = faiss.HNSWIndexConfig()
|
||||||
hnsw_config.is_compact = self.is_compact
|
hnsw_config.is_compact = self.is_compact
|
||||||
|
hnsw_config.is_recompute = (
|
||||||
# Apply additional configuration options with strict validation
|
self.is_pruned
|
||||||
hnsw_config.is_skip_neighbors = self.config.get("is_skip_neighbors", False)
|
) # In C++ code, it's called is_recompute, but it's only for loading IIUC.
|
||||||
hnsw_config.is_recompute = self.is_pruned or self.config.get("is_recompute", False)
|
|
||||||
hnsw_config.disk_cache_ratio = self.config.get("disk_cache_ratio", 0.0)
|
|
||||||
hnsw_config.external_storage_path = self.config.get("external_storage_path")
|
|
||||||
hnsw_config.zmq_port = self.config.get("zmq_port", 5557)
|
|
||||||
|
|
||||||
if self.is_pruned and not hnsw_config.is_recompute:
|
|
||||||
raise RuntimeError("Index is pruned (embeddings removed) but recompute is disabled. This is impossible - recompute must be enabled for pruned indices.")
|
|
||||||
|
|
||||||
print(f"INFO: Loading index with is_compact={self.is_compact}, is_pruned={self.is_pruned}")
|
|
||||||
print(f"INFO: Config - skip_neighbors={hnsw_config.is_skip_neighbors}, recompute={hnsw_config.is_recompute}")
|
|
||||||
|
|
||||||
self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)
|
self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)
|
||||||
|
|
||||||
if self.is_compact:
|
def search(
|
||||||
print("✅ Compact CSR format HNSW index loaded successfully.")
|
self,
|
||||||
else:
|
query: np.ndarray,
|
||||||
print("✅ Standard HNSW index loaded successfully.")
|
top_k: int,
|
||||||
|
zmq_port: int | None = None,
|
||||||
|
complexity: int = 64,
|
||||||
|
beam_width: int = 1,
|
||||||
|
prune_ratio: float = 0.0,
|
||||||
|
recompute_embeddings: bool = True,
|
||||||
|
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||||
|
batch_size: int = 0,
|
||||||
|
**kwargs,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Search for nearest neighbors using HNSW index.
|
||||||
|
|
||||||
self.metric_str = metric_str
|
Args:
|
||||||
self.embedding_server_manager = HNSWEmbeddingServerManager()
|
query: Query vectors (B, D) where B is batch size, D is dimension
|
||||||
|
top_k: Number of nearest neighbors to return
|
||||||
|
complexity: Search complexity/efSearch, higher = more accurate but slower
|
||||||
|
beam_width: Number of parallel search paths/beam_size
|
||||||
|
prune_ratio: Ratio of neighbors to prune via PQ (0.0-1.0)
|
||||||
|
recompute_embeddings: Whether to fetch fresh embeddings from server
|
||||||
|
pruning_strategy: PQ candidate selection strategy:
|
||||||
|
- "global": Use global PQ queue size for selection (default)
|
||||||
|
- "local": Local pruning, sort and select best candidates
|
||||||
|
- "proportional": Base selection on new neighbor count ratio
|
||||||
|
zmq_port: ZMQ port for embedding server communication. Must be provided if recompute_embeddings is True.
|
||||||
|
batch_size: Neighbor processing batch size, 0=disabled (HNSW-specific)
|
||||||
|
**kwargs: Additional HNSW-specific parameters (for legacy compatibility)
|
||||||
|
|
||||||
def _get_index_file(self, index_dir: Path, index_prefix: str) -> Path:
|
Returns:
|
||||||
"""Get the appropriate index file path based on format"""
|
Dict with 'labels' (list of lists) and 'distances' (ndarray)
|
||||||
# We always use the same filename now, format is detected internally
|
"""
|
||||||
return index_dir / f"{index_prefix}.index"
|
from . import faiss # type: ignore
|
||||||
|
|
||||||
def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]:
|
if not recompute_embeddings:
|
||||||
"""Search using HNSW index with optional recompute functionality"""
|
if self.is_pruned:
|
||||||
from . import faiss
|
raise RuntimeError("Recompute is required for pruned index.")
|
||||||
# Merge config with search-time kwargs
|
if recompute_embeddings:
|
||||||
search_config = self.config.copy()
|
if zmq_port is None:
|
||||||
search_config.update(kwargs)
|
raise ValueError("zmq_port must be provided if recompute_embeddings is True")
|
||||||
|
|
||||||
ef = search_config.get("ef", 200) # Size of the dynamic candidate list for search
|
|
||||||
|
|
||||||
# Recompute parameters
|
|
||||||
zmq_port = search_config.get("zmq_port", 5557)
|
|
||||||
embedding_model = search_config.get("embedding_model", "sentence-transformers/all-mpnet-base-v2")
|
|
||||||
passages_file = search_config.get("passages_file", None)
|
|
||||||
|
|
||||||
# For recompute mode, try to find the passages file automatically
|
|
||||||
if self.is_pruned and not passages_file:
|
|
||||||
potential_passages_file = self.index_dir / f"{self.index_prefix}.passages.json"
|
|
||||||
print(f"DEBUG: Checking for passages file at: {potential_passages_file}")
|
|
||||||
if potential_passages_file.exists():
|
|
||||||
passages_file = str(potential_passages_file)
|
|
||||||
print(f"INFO: Found passages file for recompute mode: {passages_file}")
|
|
||||||
else:
|
|
||||||
print(f"WARNING: No passages file found for recompute mode at {potential_passages_file}")
|
|
||||||
|
|
||||||
# If index is pruned (embeddings removed), we MUST start embedding server for recompute
|
|
||||||
if self.is_pruned:
|
|
||||||
print(f"INFO: Index is pruned - starting embedding server for recompute")
|
|
||||||
|
|
||||||
# CRITICAL: Check passages file exists - fail fast if not
|
|
||||||
if not passages_file:
|
|
||||||
raise RuntimeError(f"FATAL: Index is pruned but no passages file found. Cannot proceed with recompute mode.")
|
|
||||||
|
|
||||||
# Check if server is already running first
|
|
||||||
if _check_port(zmq_port):
|
|
||||||
print(f"INFO: Embedding server already running on port {zmq_port}")
|
|
||||||
else:
|
|
||||||
if not self.embedding_server_manager.start_server(zmq_port, embedding_model, passages_file, self.metric_str):
|
|
||||||
raise RuntimeError(f"Failed to start HNSW embedding server on port {zmq_port}")
|
|
||||||
|
|
||||||
# Give server extra time to fully initialize
|
|
||||||
print(f"INFO: Waiting for embedding server to fully initialize...")
|
|
||||||
time.sleep(3)
|
|
||||||
|
|
||||||
# Final verification
|
|
||||||
if not _check_port(zmq_port):
|
|
||||||
raise RuntimeError(f"Embedding server failed to start listening on port {zmq_port}")
|
|
||||||
else:
|
|
||||||
print(f"INFO: Index has embeddings stored - no recompute needed")
|
|
||||||
|
|
||||||
if query.dtype != np.float32:
|
if query.dtype != np.float32:
|
||||||
query = query.astype(np.float32)
|
query = query.astype(np.float32)
|
||||||
if query.ndim == 1:
|
if self.distance_metric == "cosine":
|
||||||
query = np.expand_dims(query, axis=0)
|
query = normalize_l2(query)
|
||||||
|
|
||||||
# Normalize query if using cosine similarity
|
params = faiss.SearchParametersHNSW()
|
||||||
if self.metric_str == "cosine":
|
if zmq_port is not None:
|
||||||
faiss.normalize_L2(query)
|
params.zmq_port = zmq_port # C++ code won't use this if recompute_embeddings is False
|
||||||
|
params.efSearch = complexity
|
||||||
|
params.beam_size = beam_width
|
||||||
|
|
||||||
try:
|
# For OpenAI embeddings with cosine distance, disable relative distance check
|
||||||
# Set search parameter
|
# This prevents early termination when all scores are in a narrow range
|
||||||
self._index.hnsw.efSearch = ef
|
embedding_model = self.meta.get("embedding_model", "").lower()
|
||||||
|
if self.distance_metric == "cosine" and any(
|
||||||
|
openai_model in embedding_model for openai_model in ["text-embedding", "openai"]
|
||||||
|
):
|
||||||
|
params.check_relative_distance = False
|
||||||
|
else:
|
||||||
|
params.check_relative_distance = True
|
||||||
|
|
||||||
# Prepare output arrays for the older FAISS SWIG API
|
# PQ pruning: direct mapping to HNSW's pq_pruning_ratio
|
||||||
batch_size = query.shape[0]
|
params.pq_pruning_ratio = prune_ratio
|
||||||
distances = np.empty((batch_size, top_k), dtype=np.float32)
|
|
||||||
labels = np.empty((batch_size, top_k), dtype=np.int64)
|
|
||||||
|
|
||||||
# Use standard FAISS search - recompute is handled internally by FAISS
|
# Map pruning_strategy to HNSW parameters
|
||||||
self._index.search(query.shape[0], faiss.swig_ptr(query), top_k, faiss.swig_ptr(distances), faiss.swig_ptr(labels))
|
if pruning_strategy == "local":
|
||||||
|
params.local_prune = True
|
||||||
|
params.send_neigh_times_ratio = 0.0
|
||||||
|
elif pruning_strategy == "proportional":
|
||||||
|
params.local_prune = False
|
||||||
|
params.send_neigh_times_ratio = 1.0 # Any value > 1e-6 triggers proportional mode
|
||||||
|
else: # "global"
|
||||||
|
params.local_prune = False
|
||||||
|
params.send_neigh_times_ratio = 0.0
|
||||||
|
|
||||||
return {"labels": labels, "distances": distances}
|
# HNSW-specific batch processing parameter
|
||||||
|
params.batch_size = batch_size
|
||||||
|
|
||||||
except Exception as e:
|
batch_size_query = query.shape[0]
|
||||||
print(f"💥 ERROR: HNSW search failed. Exception: {e}")
|
distances = np.empty((batch_size_query, top_k), dtype=np.float32)
|
||||||
raise
|
labels = np.empty((batch_size_query, top_k), dtype=np.int64)
|
||||||
|
|
||||||
def __del__(self):
|
self._index.search(
|
||||||
if hasattr(self, 'embedding_server_manager'):
|
query.shape[0],
|
||||||
self.embedding_server_manager.stop_server()
|
faiss.swig_ptr(query),
|
||||||
|
top_k,
|
||||||
|
faiss.swig_ptr(distances),
|
||||||
|
faiss.swig_ptr(labels),
|
||||||
|
params,
|
||||||
|
)
|
||||||
|
|
||||||
|
string_labels = [[str(int_label) for int_label in batch_labels] for batch_labels in labels]
|
||||||
|
|
||||||
|
return {"labels": string_labels, "distances": distances}
|
||||||
|
|||||||
@@ -1,351 +1,109 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
"""
|
||||||
HNSW-specific embedding server with removed config.py dependencies
|
HNSW-specific embedding server
|
||||||
Based on DiskANN embedding server architecture
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import pickle
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from transformers import AutoTokenizer, AutoModel
|
|
||||||
import os
|
|
||||||
from contextlib import contextmanager
|
|
||||||
import zmq
|
|
||||||
import numpy as np
|
|
||||||
import msgpack
|
|
||||||
import json
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Any, Optional, Union
|
|
||||||
|
|
||||||
RED = "\033[91m"
|
import msgpack
|
||||||
RESET = "\033[0m"
|
import numpy as np
|
||||||
|
import zmq
|
||||||
|
|
||||||
def is_similarity_metric():
|
# Set up logging based on environment variable
|
||||||
"""
|
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||||
Check if the metric type is similarity-based (like inner product).
|
logger = logging.getLogger(__name__)
|
||||||
0 = L2 (distance metric), 1 = Inner Product (similarity metric)
|
|
||||||
"""
|
|
||||||
return True # 1 is METRIC_INNER_PRODUCT in FAISS
|
|
||||||
|
|
||||||
# Function for E5-style average pooling
|
# Force set logger level (don't rely on basicConfig in subprocess)
|
||||||
import torch
|
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
|
||||||
from torch import Tensor
|
logger.setLevel(log_level)
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
def e5_average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
|
# Ensure we have a handler if none exists
|
||||||
last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
|
if not logger.handlers:
|
||||||
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
|
handler = logging.StreamHandler()
|
||||||
|
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
|
||||||
|
handler.setFormatter(formatter)
|
||||||
|
logger.addHandler(handler)
|
||||||
|
logger.propagate = False
|
||||||
|
|
||||||
class SimplePassageLoader:
|
|
||||||
"""
|
|
||||||
Simple passage loader that replaces config.py dependencies
|
|
||||||
"""
|
|
||||||
def __init__(self, passages_data: Optional[Dict[str, Any]] = None):
|
|
||||||
self.passages_data = passages_data or {}
|
|
||||||
|
|
||||||
def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
|
|
||||||
"""Get passage by ID"""
|
|
||||||
str_id = str(passage_id)
|
|
||||||
if str_id in self.passages_data:
|
|
||||||
return {"text": self.passages_data[str_id]}
|
|
||||||
else:
|
|
||||||
# Return empty text for missing passages
|
|
||||||
return {"text": ""}
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
return len(self.passages_data)
|
|
||||||
|
|
||||||
def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
|
|
||||||
"""
|
|
||||||
Load passages from a JSON file
|
|
||||||
Expected format: {"passage_id": "passage_text", ...}
|
|
||||||
"""
|
|
||||||
if not os.path.exists(passages_file):
|
|
||||||
print(f"Warning: Passages file {passages_file} not found. Using empty loader.")
|
|
||||||
return SimplePassageLoader()
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(passages_file, 'r', encoding='utf-8') as f:
|
|
||||||
passages_data = json.load(f)
|
|
||||||
print(f"Loaded {len(passages_data)} passages from {passages_file}")
|
|
||||||
return SimplePassageLoader(passages_data)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error loading passages from {passages_file}: {e}")
|
|
||||||
return SimplePassageLoader()
|
|
||||||
|
|
||||||
def create_hnsw_embedding_server(
|
def create_hnsw_embedding_server(
|
||||||
passages_file: Optional[str] = None,
|
passages_file: str | None = None,
|
||||||
passages_data: Optional[Dict[str, str]] = None,
|
|
||||||
embeddings_file: Optional[str] = None,
|
|
||||||
use_fp16: bool = True,
|
|
||||||
use_int8: bool = False,
|
|
||||||
use_cuda_graphs: bool = False,
|
|
||||||
zmq_port: int = 5555,
|
zmq_port: int = 5555,
|
||||||
max_batch_size: int = 128,
|
|
||||||
model_name: str = "sentence-transformers/all-mpnet-base-v2",
|
model_name: str = "sentence-transformers/all-mpnet-base-v2",
|
||||||
custom_max_length_param: Optional[int] = None,
|
|
||||||
distance_metric: str = "mips",
|
distance_metric: str = "mips",
|
||||||
|
embedding_mode: str = "sentence-transformers",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Create and start a ZMQ-based embedding server for HNSW backend.
|
Create and start a ZMQ-based embedding server for HNSW backend.
|
||||||
|
Simplified version using unified embedding computation module.
|
||||||
Args:
|
|
||||||
passages_file: Path to JSON file containing passage ID -> text mapping
|
|
||||||
passages_data: Direct passage data dict (alternative to passages_file)
|
|
||||||
embeddings_file: Path to pre-computed embeddings file (optional)
|
|
||||||
use_fp16: Whether to use FP16 precision
|
|
||||||
use_int8: Whether to use INT8 quantization
|
|
||||||
use_cuda_graphs: Whether to use CUDA graphs
|
|
||||||
zmq_port: ZMQ port to bind to
|
|
||||||
max_batch_size: Maximum batch size for processing
|
|
||||||
model_name: Transformer model name
|
|
||||||
custom_max_length_param: Custom max sequence length
|
|
||||||
distance_metric: The distance metric to use
|
|
||||||
"""
|
"""
|
||||||
print(f"Loading tokenizer for {model_name}...")
|
logger.info(f"Starting HNSW server on port {zmq_port} with model {model_name}")
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
|
logger.info(f"Using embedding mode: {embedding_mode}")
|
||||||
print(f"Tokenizer loaded successfully!")
|
|
||||||
|
|
||||||
# Device setup
|
# Add leann-core to path for unified embedding computation
|
||||||
mps_available = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
|
current_dir = Path(__file__).parent
|
||||||
cuda_available = torch.cuda.is_available()
|
leann_core_path = current_dir.parent.parent / "leann-core" / "src"
|
||||||
|
sys.path.insert(0, str(leann_core_path))
|
||||||
|
|
||||||
print(f"MPS available: {mps_available}")
|
try:
|
||||||
print(f"CUDA available: {cuda_available}")
|
from leann.api import PassageManager
|
||||||
|
from leann.embedding_compute import compute_embeddings
|
||||||
|
|
||||||
if cuda_available:
|
logger.info("Successfully imported unified embedding computation module")
|
||||||
device = torch.device("cuda")
|
except ImportError as e:
|
||||||
print("Using CUDA device")
|
logger.error(f"Failed to import embedding computation module: {e}")
|
||||||
elif mps_available:
|
return
|
||||||
device = torch.device("mps")
|
finally:
|
||||||
print("Using MPS device (Apple Silicon)")
|
sys.path.pop(0)
|
||||||
else:
|
|
||||||
device = torch.device("cpu")
|
|
||||||
print("Using CPU device (no GPU acceleration available)")
|
|
||||||
|
|
||||||
# Load model to the appropriate device
|
|
||||||
print(f"Starting HNSW server on port {zmq_port} with model {model_name}")
|
|
||||||
print(f"Loading model {model_name}... (this may take a while if downloading)")
|
|
||||||
model = AutoModel.from_pretrained(model_name).to(device).eval()
|
|
||||||
print(f"Model {model_name} loaded successfully!")
|
|
||||||
|
|
||||||
# Check port availability
|
# Check port availability
|
||||||
import socket
|
import socket
|
||||||
|
|
||||||
def check_port(port):
|
def check_port(port):
|
||||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||||
return s.connect_ex(('localhost', port)) == 0
|
return s.connect_ex(("localhost", port)) == 0
|
||||||
|
|
||||||
if check_port(zmq_port):
|
if check_port(zmq_port):
|
||||||
print(f"{RED}Port {zmq_port} is already in use{RESET}")
|
logger.error(f"Port {zmq_port} is already in use")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Apply model optimizations (similar to DiskANN version)
|
# Only support metadata file, fail fast for everything else
|
||||||
if use_fp16 and (cuda_available or mps_available):
|
if not passages_file or not passages_file.endswith(".meta.json"):
|
||||||
model = model.half()
|
raise ValueError("Only metadata files (.meta.json) are supported")
|
||||||
model = torch.compile(model)
|
|
||||||
print(f"Using FP16 precision with model: {model_name}")
|
|
||||||
elif use_int8:
|
|
||||||
print("- Using TorchAO for Int8 dynamic activation and Int8 weight quantization")
|
|
||||||
from torchao.quantization import quantize_, Int8DynamicActivationInt8WeightConfig
|
|
||||||
quantize_(model, Int8DynamicActivationInt8WeightConfig())
|
|
||||||
model = torch.compile(model)
|
|
||||||
model.eval()
|
|
||||||
print("- Model successfully quantized and compiled")
|
|
||||||
|
|
||||||
# Load passages
|
# Load metadata to get passage sources
|
||||||
if passages_data:
|
with open(passages_file) as f:
|
||||||
passages = SimplePassageLoader(passages_data)
|
meta = json.load(f)
|
||||||
print(f"Using provided passages data: {len(passages)} passages")
|
|
||||||
elif passages_file:
|
|
||||||
passages = load_passages_from_file(passages_file)
|
|
||||||
else:
|
|
||||||
passages = SimplePassageLoader()
|
|
||||||
print("No passages provided, using empty loader")
|
|
||||||
|
|
||||||
# Load embeddings if provided
|
# Convert relative paths to absolute paths based on metadata file location
|
||||||
_embeddings = None
|
metadata_dir = Path(passages_file).parent.parent # Go up one level from the metadata file
|
||||||
if embeddings_file and os.path.exists(embeddings_file):
|
passage_sources = []
|
||||||
try:
|
for source in meta["passage_sources"]:
|
||||||
with open(embeddings_file, "rb") as f:
|
source_copy = source.copy()
|
||||||
_embeddings = pickle.load(f)
|
# Convert relative paths to absolute paths
|
||||||
print(f"Loaded embeddings from {embeddings_file}")
|
if not Path(source_copy["path"]).is_absolute():
|
||||||
except Exception as e:
|
source_copy["path"] = str(metadata_dir / source_copy["path"])
|
||||||
print(f"Error loading embeddings: {e}")
|
if not Path(source_copy["index_path"]).is_absolute():
|
||||||
|
source_copy["index_path"] = str(metadata_dir / source_copy["index_path"])
|
||||||
|
passage_sources.append(source_copy)
|
||||||
|
|
||||||
class DeviceTimer:
|
passages = PassageManager(passage_sources)
|
||||||
"""Device event-based timer for accurate timing."""
|
logger.info(
|
||||||
def __init__(self, name="", device=device):
|
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
|
||||||
self.name = name
|
)
|
||||||
self.device = device
|
|
||||||
self.start_time = 0
|
|
||||||
self.end_time = 0
|
|
||||||
|
|
||||||
if cuda_available:
|
|
||||||
self.start_event = torch.cuda.Event(enable_timing=True)
|
|
||||||
self.end_event = torch.cuda.Event(enable_timing=True)
|
|
||||||
else:
|
|
||||||
self.start_event = None
|
|
||||||
self.end_event = None
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def timing(self):
|
|
||||||
self.start()
|
|
||||||
yield
|
|
||||||
self.end()
|
|
||||||
|
|
||||||
def start(self):
|
|
||||||
if cuda_available:
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
self.start_event.record()
|
|
||||||
else:
|
|
||||||
if self.device.type == "mps":
|
|
||||||
torch.mps.synchronize()
|
|
||||||
self.start_time = time.time()
|
|
||||||
|
|
||||||
def end(self):
|
|
||||||
if cuda_available:
|
|
||||||
self.end_event.record()
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
else:
|
|
||||||
if self.device.type == "mps":
|
|
||||||
torch.mps.synchronize()
|
|
||||||
self.end_time = time.time()
|
|
||||||
|
|
||||||
def elapsed_time(self):
|
|
||||||
if cuda_available:
|
|
||||||
return self.start_event.elapsed_time(self.end_event) / 1000.0
|
|
||||||
else:
|
|
||||||
return self.end_time - self.start_time
|
|
||||||
|
|
||||||
def print_elapsed(self):
|
|
||||||
return # Disabled for now
|
|
||||||
|
|
||||||
def process_batch(texts_batch, ids_batch, missing_ids):
|
|
||||||
"""Process a batch of texts and return embeddings"""
|
|
||||||
_is_e5_model = "e5" in model_name.lower()
|
|
||||||
_is_bge_model = "bge" in model_name.lower()
|
|
||||||
batch_size = len(texts_batch)
|
|
||||||
|
|
||||||
# E5 model preprocessing
|
|
||||||
if _is_e5_model:
|
|
||||||
processed_texts_batch = [f"passage: {text}" for text in texts_batch]
|
|
||||||
else:
|
|
||||||
processed_texts_batch = texts_batch
|
|
||||||
|
|
||||||
# Set max length
|
|
||||||
if _is_e5_model:
|
|
||||||
current_max_length = custom_max_length_param if custom_max_length_param is not None else 512
|
|
||||||
else:
|
|
||||||
current_max_length = custom_max_length_param if custom_max_length_param is not None else 256
|
|
||||||
|
|
||||||
tokenize_timer = DeviceTimer("tokenization (batch)", device)
|
|
||||||
to_device_timer = DeviceTimer("transfer to device (batch)", device)
|
|
||||||
embed_timer = DeviceTimer("embedding (batch)", device)
|
|
||||||
pool_timer = DeviceTimer("pooling (batch)", device)
|
|
||||||
norm_timer = DeviceTimer("normalization (batch)", device)
|
|
||||||
|
|
||||||
with tokenize_timer.timing():
|
|
||||||
encoded_batch = tokenizer(
|
|
||||||
processed_texts_batch,
|
|
||||||
padding="max_length",
|
|
||||||
truncation=True,
|
|
||||||
max_length=current_max_length,
|
|
||||||
return_tensors="pt",
|
|
||||||
return_token_type_ids=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
seq_length = encoded_batch["input_ids"].size(1)
|
|
||||||
|
|
||||||
with to_device_timer.timing():
|
|
||||||
enc = {k: v.to(device) for k, v in encoded_batch.items()}
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
with embed_timer.timing():
|
|
||||||
out = model(enc["input_ids"], enc["attention_mask"])
|
|
||||||
|
|
||||||
with pool_timer.timing():
|
|
||||||
if _is_bge_model:
|
|
||||||
pooled_embeddings = out.last_hidden_state[:, 0]
|
|
||||||
elif not hasattr(out, 'last_hidden_state'):
|
|
||||||
if isinstance(out, torch.Tensor) and len(out.shape) == 2:
|
|
||||||
pooled_embeddings = out
|
|
||||||
else:
|
|
||||||
print(f"{RED}ERROR: Cannot determine how to pool. Output shape: {out.shape if isinstance(out, torch.Tensor) else 'N/A'}{RESET}")
|
|
||||||
hidden_dim = getattr(model.config, 'hidden_size', 384 if _is_e5_model else 768)
|
|
||||||
pooled_embeddings = torch.zeros((batch_size, hidden_dim), device=device, dtype=enc["input_ids"].dtype if hasattr(enc["input_ids"], "dtype") else torch.float32)
|
|
||||||
elif _is_e5_model:
|
|
||||||
pooled_embeddings = e5_average_pool(out.last_hidden_state, enc['attention_mask'])
|
|
||||||
else:
|
|
||||||
hidden_states = out.last_hidden_state
|
|
||||||
mask_expanded = enc["attention_mask"].unsqueeze(-1).expand(hidden_states.size()).float()
|
|
||||||
sum_embeddings = torch.sum(hidden_states * mask_expanded, 1)
|
|
||||||
sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
|
|
||||||
pooled_embeddings = sum_embeddings / sum_mask
|
|
||||||
|
|
||||||
final_embeddings = pooled_embeddings
|
|
||||||
if _is_e5_model or _is_bge_model:
|
|
||||||
with norm_timer.timing():
|
|
||||||
final_embeddings = F.normalize(pooled_embeddings, p=2, dim=1)
|
|
||||||
|
|
||||||
if torch.isnan(final_embeddings).any() or torch.isinf(final_embeddings).any():
|
|
||||||
print(f"{RED}!!! In process_batch: NaN or Inf detected in final_embeddings! "
|
|
||||||
f"Model: {model_name}, E5: {_is_e5_model}. IDs (sample): {ids_batch[:5]}...{RESET}")
|
|
||||||
dim_size = final_embeddings.shape[-1]
|
|
||||||
error_output = torch.zeros((batch_size, dim_size), device='cpu', dtype=torch.float32).numpy()
|
|
||||||
print(f"{RED}Returning zero embeddings of shape ({batch_size}, {dim_size}) due to NaN/Inf.{RESET}")
|
|
||||||
return error_output
|
|
||||||
|
|
||||||
return final_embeddings.cpu().numpy()
|
|
||||||
|
|
||||||
def client_warmup(zmq_port):
|
|
||||||
"""Perform client-side warmup"""
|
|
||||||
time.sleep(2)
|
|
||||||
print(f"Performing client-side warmup with model {model_name}...")
|
|
||||||
sample_ids = ["1", "2", "3", "4", "5"]
|
|
||||||
|
|
||||||
try:
|
|
||||||
context = zmq.Context()
|
|
||||||
socket = context.socket(zmq.REQ)
|
|
||||||
socket.connect(f"tcp://localhost:{zmq_port}")
|
|
||||||
socket.setsockopt(zmq.RCVTIMEO, 30000)
|
|
||||||
socket.setsockopt(zmq.SNDTIMEO, 30000)
|
|
||||||
|
|
||||||
try:
|
|
||||||
ids_to_send = [int(x) for x in sample_ids]
|
|
||||||
except ValueError:
|
|
||||||
ids_to_send = []
|
|
||||||
|
|
||||||
if not ids_to_send:
|
|
||||||
print("Skipping warmup send.")
|
|
||||||
return
|
|
||||||
|
|
||||||
request_payload = [ids_to_send]
|
|
||||||
request_bytes = msgpack.packb(request_payload)
|
|
||||||
|
|
||||||
for i in range(3):
|
|
||||||
print(f"Sending warmup request {i+1}/3 via ZMQ (MessagePack)...")
|
|
||||||
socket.send(request_bytes)
|
|
||||||
response_bytes = socket.recv()
|
|
||||||
|
|
||||||
response_payload = msgpack.unpackb(response_bytes)
|
|
||||||
dimensions = response_payload[0]
|
|
||||||
embeddings_count = dimensions[0] if dimensions and len(dimensions) > 0 else 0
|
|
||||||
print(f"Warmup request {i+1}/3 successful, received {embeddings_count} embeddings")
|
|
||||||
time.sleep(0.1)
|
|
||||||
|
|
||||||
print("Client-side MessagePack ZMQ warmup complete")
|
|
||||||
socket.close()
|
|
||||||
context.term()
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error during MessagePack ZMQ warmup: {e}")
|
|
||||||
|
|
||||||
def zmq_server_thread():
|
def zmq_server_thread():
|
||||||
"""ZMQ server thread"""
|
"""ZMQ server thread"""
|
||||||
context = zmq.Context()
|
context = zmq.Context()
|
||||||
socket = context.socket(zmq.REP)
|
socket = context.socket(zmq.REP)
|
||||||
socket.bind(f"tcp://*:{zmq_port}")
|
socket.bind(f"tcp://*:{zmq_port}")
|
||||||
print(f"HNSW ZMQ server listening on port {zmq_port}")
|
logger.info(f"HNSW ZMQ server listening on port {zmq_port}")
|
||||||
|
|
||||||
socket.setsockopt(zmq.RCVTIMEO, 300000)
|
socket.setsockopt(zmq.RCVTIMEO, 300000)
|
||||||
socket.setsockopt(zmq.SNDTIMEO, 300000)
|
socket.setsockopt(zmq.SNDTIMEO, 300000)
|
||||||
@@ -353,244 +111,201 @@ def create_hnsw_embedding_server(
|
|||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
message_bytes = socket.recv()
|
message_bytes = socket.recv()
|
||||||
print(f"Received ZMQ request of size {len(message_bytes)} bytes")
|
logger.debug(f"Received ZMQ request of size {len(message_bytes)} bytes")
|
||||||
|
|
||||||
e2e_start = time.time()
|
e2e_start = time.time()
|
||||||
lookup_timer = DeviceTimer("text lookup", device)
|
request_payload = msgpack.unpackb(message_bytes)
|
||||||
|
|
||||||
try:
|
# Handle direct text embedding request
|
||||||
request_payload = msgpack.unpackb(message_bytes)
|
if isinstance(request_payload, list) and len(request_payload) > 0:
|
||||||
|
# Check if this is a direct text request (list of strings)
|
||||||
|
if all(isinstance(item, str) for item in request_payload):
|
||||||
|
logger.info(
|
||||||
|
f"Processing direct text embedding request for {len(request_payload)} texts in {embedding_mode} mode"
|
||||||
|
)
|
||||||
|
|
||||||
# Handle distance calculation requests
|
# Use unified embedding computation (now with model caching)
|
||||||
if isinstance(request_payload, list) and len(request_payload) == 2 and isinstance(request_payload[0], list) and isinstance(request_payload[1], list):
|
embeddings = compute_embeddings(
|
||||||
node_ids = request_payload[0]
|
request_payload, model_name, mode=embedding_mode
|
||||||
query_vector = np.array(request_payload[1], dtype=np.float32)
|
)
|
||||||
|
|
||||||
print(f"Request for distance calculation: {len(node_ids)} nodes, query vector dim: {len(query_vector)}")
|
response = embeddings.tolist()
|
||||||
|
socket.send(msgpack.packb(response))
|
||||||
# Get embeddings for node IDs
|
|
||||||
texts = []
|
|
||||||
missing_ids = []
|
|
||||||
with lookup_timer.timing():
|
|
||||||
for nid in node_ids:
|
|
||||||
try:
|
|
||||||
txtinfo = passages[nid]
|
|
||||||
if txtinfo is None or txtinfo["text"] == "":
|
|
||||||
raise RuntimeError(f"FATAL: Passage with ID {nid} not found - failing fast")
|
|
||||||
else:
|
|
||||||
txt = txtinfo["text"]
|
|
||||||
except (KeyError, IndexError):
|
|
||||||
raise RuntimeError(f"FATAL: Passage with ID {nid} not found - failing fast")
|
|
||||||
texts.append(txt)
|
|
||||||
lookup_timer.print_elapsed()
|
|
||||||
|
|
||||||
# Process embeddings in chunks if needed
|
|
||||||
all_node_embeddings = []
|
|
||||||
total_size = len(texts)
|
|
||||||
|
|
||||||
if total_size > max_batch_size:
|
|
||||||
for i in range(0, total_size, max_batch_size):
|
|
||||||
end_idx = min(i + max_batch_size, total_size)
|
|
||||||
chunk_texts = texts[i:end_idx]
|
|
||||||
chunk_ids = node_ids[i:end_idx]
|
|
||||||
|
|
||||||
embeddings_chunk = process_batch(chunk_texts, chunk_ids, missing_ids)
|
|
||||||
all_node_embeddings.append(embeddings_chunk)
|
|
||||||
|
|
||||||
if cuda_available:
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
elif device.type == "mps":
|
|
||||||
torch.mps.empty_cache()
|
|
||||||
|
|
||||||
node_embeddings = np.vstack(all_node_embeddings)
|
|
||||||
else:
|
|
||||||
node_embeddings = process_batch(texts, node_ids, missing_ids)
|
|
||||||
|
|
||||||
# Calculate distances
|
|
||||||
query_tensor = torch.tensor(query_vector, device=device).float()
|
|
||||||
node_embeddings_tensor = torch.tensor(node_embeddings, device=device).float()
|
|
||||||
|
|
||||||
calc_timer = DeviceTimer("distance calculation", device)
|
|
||||||
with calc_timer.timing():
|
|
||||||
with torch.no_grad():
|
|
||||||
if distance_metric == "l2":
|
|
||||||
node_embeddings_np = node_embeddings_tensor.cpu().numpy().astype(np.float32)
|
|
||||||
query_np = query_tensor.cpu().numpy().astype(np.float32)
|
|
||||||
distances = np.sum(np.square(node_embeddings_np - query_np.reshape(1, -1)), axis=1)
|
|
||||||
else: # mips or cosine
|
|
||||||
node_embeddings_np = node_embeddings_tensor.cpu().numpy()
|
|
||||||
query_np = query_tensor.cpu().numpy()
|
|
||||||
distances = -np.dot(node_embeddings_np, query_np)
|
|
||||||
calc_timer.print_elapsed()
|
|
||||||
|
|
||||||
try:
|
|
||||||
response_payload = distances.flatten().tolist()
|
|
||||||
response_bytes = msgpack.packb([response_payload], use_single_float=True)
|
|
||||||
print(f"Sending distance response with {len(distances)} distances")
|
|
||||||
except Exception as pack_error:
|
|
||||||
print(f"Error packing MessagePack distance response: {pack_error}")
|
|
||||||
response_bytes = msgpack.packb([[]])
|
|
||||||
|
|
||||||
socket.send(response_bytes)
|
|
||||||
|
|
||||||
if device.type == "cuda":
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
elif device.type == "mps":
|
|
||||||
torch.mps.synchronize()
|
|
||||||
e2e_end = time.time()
|
e2e_end = time.time()
|
||||||
print(f"Distance calculation E2E time: {e2e_end - e2e_start:.6f} seconds")
|
logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
|
||||||
continue
|
|
||||||
|
|
||||||
# Standard embedding request
|
|
||||||
if not isinstance(request_payload, list) or len(request_payload) != 1 or not isinstance(request_payload[0], list):
|
|
||||||
print(f"Error: Invalid MessagePack request format. Expected [[ids...]], got: {type(request_payload)}")
|
|
||||||
socket.send(msgpack.packb([[], []]))
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Handle distance calculation requests
|
||||||
|
if (
|
||||||
|
isinstance(request_payload, list)
|
||||||
|
and len(request_payload) == 2
|
||||||
|
and isinstance(request_payload[0], list)
|
||||||
|
and isinstance(request_payload[1], list)
|
||||||
|
):
|
||||||
node_ids = request_payload[0]
|
node_ids = request_payload[0]
|
||||||
print(f"Request for {len(node_ids)} node embeddings")
|
query_vector = np.array(request_payload[1], dtype=np.float32)
|
||||||
|
|
||||||
except Exception as unpack_error:
|
logger.debug("Distance calculation request received")
|
||||||
print(f"Error unpacking MessagePack request: {unpack_error}")
|
logger.debug(f" Node IDs: {node_ids}")
|
||||||
|
logger.debug(f" Query vector dim: {len(query_vector)}")
|
||||||
|
|
||||||
|
# Get embeddings for node IDs
|
||||||
|
texts = []
|
||||||
|
for nid in node_ids:
|
||||||
|
try:
|
||||||
|
passage_data = passages.get_passage(str(nid))
|
||||||
|
txt = passage_data["text"]
|
||||||
|
texts.append(txt)
|
||||||
|
except KeyError:
|
||||||
|
logger.error(f"Passage ID {nid} not found")
|
||||||
|
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Exception looking up passage ID {nid}: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Process embeddings
|
||||||
|
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
||||||
|
logger.info(
|
||||||
|
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Calculate distances
|
||||||
|
if distance_metric == "l2":
|
||||||
|
distances = np.sum(
|
||||||
|
np.square(embeddings - query_vector.reshape(1, -1)), axis=1
|
||||||
|
)
|
||||||
|
else: # mips or cosine
|
||||||
|
distances = -np.dot(embeddings, query_vector)
|
||||||
|
|
||||||
|
response_payload = distances.flatten().tolist()
|
||||||
|
response_bytes = msgpack.packb([response_payload], use_single_float=True)
|
||||||
|
logger.debug(f"Sending distance response with {len(distances)} distances")
|
||||||
|
|
||||||
|
socket.send(response_bytes)
|
||||||
|
e2e_end = time.time()
|
||||||
|
logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Standard embedding request (passage ID lookup)
|
||||||
|
if (
|
||||||
|
not isinstance(request_payload, list)
|
||||||
|
or len(request_payload) != 1
|
||||||
|
or not isinstance(request_payload[0], list)
|
||||||
|
):
|
||||||
|
logger.error(
|
||||||
|
f"Invalid MessagePack request format. Expected [[ids...]] or [texts...], got: {type(request_payload)}"
|
||||||
|
)
|
||||||
socket.send(msgpack.packb([[], []]))
|
socket.send(msgpack.packb([[], []]))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
node_ids = request_payload[0]
|
||||||
|
logger.debug(f"Request for {len(node_ids)} node embeddings")
|
||||||
|
|
||||||
# Look up texts by node IDs
|
# Look up texts by node IDs
|
||||||
texts = []
|
texts = []
|
||||||
missing_ids = []
|
for nid in node_ids:
|
||||||
with lookup_timer.timing():
|
try:
|
||||||
for nid in node_ids:
|
passage_data = passages.get_passage(str(nid))
|
||||||
try:
|
txt = passage_data["text"]
|
||||||
txtinfo = passages[nid]
|
if not txt:
|
||||||
if txtinfo is None or txtinfo["text"] == "":
|
raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
|
||||||
raise RuntimeError(f"FATAL: Passage with ID {nid} not found - failing fast")
|
|
||||||
else:
|
|
||||||
txt = txtinfo["text"]
|
|
||||||
except (KeyError, IndexError):
|
|
||||||
raise RuntimeError(f"FATAL: Passage with ID {nid} not found - failing fast")
|
|
||||||
texts.append(txt)
|
texts.append(txt)
|
||||||
lookup_timer.print_elapsed()
|
except KeyError:
|
||||||
|
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Exception looking up passage ID {nid}: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
if missing_ids:
|
# Process embeddings
|
||||||
print(f"Missing passages for IDs: {missing_ids}")
|
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
||||||
|
logger.info(
|
||||||
# Process in chunks
|
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||||
total_size = len(texts)
|
)
|
||||||
print(f"Total batch size: {total_size}, max_batch_size: {max_batch_size}")
|
|
||||||
|
|
||||||
all_embeddings = []
|
|
||||||
|
|
||||||
if total_size > max_batch_size:
|
|
||||||
print(f"Splitting batch of size {total_size} into chunks of {max_batch_size}")
|
|
||||||
for i in range(0, total_size, max_batch_size):
|
|
||||||
end_idx = min(i + max_batch_size, total_size)
|
|
||||||
print(f"Processing chunk {i//max_batch_size + 1}/{(total_size + max_batch_size - 1)//max_batch_size}: items {i} to {end_idx-1}")
|
|
||||||
|
|
||||||
chunk_texts = texts[i:end_idx]
|
|
||||||
chunk_ids = node_ids[i:end_idx]
|
|
||||||
|
|
||||||
embeddings_chunk = process_batch(chunk_texts, chunk_ids, missing_ids)
|
|
||||||
all_embeddings.append(embeddings_chunk)
|
|
||||||
|
|
||||||
if cuda_available:
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
elif device.type == "mps":
|
|
||||||
torch.mps.empty_cache()
|
|
||||||
|
|
||||||
hidden = np.vstack(all_embeddings)
|
|
||||||
print(f"Combined embeddings shape: {hidden.shape}")
|
|
||||||
else:
|
|
||||||
hidden = process_batch(texts, node_ids, missing_ids)
|
|
||||||
|
|
||||||
# Serialization and response
|
# Serialization and response
|
||||||
ser_start = time.time()
|
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
||||||
|
logger.error(
|
||||||
|
f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
|
||||||
|
)
|
||||||
|
raise AssertionError()
|
||||||
|
|
||||||
print(f"DEBUG zmq_server_thread: Final 'hidden' array | Shape: {hidden.shape} | Dtype: {hidden.dtype} | Has NaN/Inf: {np.isnan(hidden).any() or np.isinf(hidden).any()}")
|
hidden_contiguous_f32 = np.ascontiguousarray(embeddings, dtype=np.float32)
|
||||||
if np.isnan(hidden).any() or np.isinf(hidden).any():
|
response_payload = [
|
||||||
print(f"{RED}!!! ERROR: NaN or Inf detected in final 'hidden' numpy array BEFORE sending! "
|
list(hidden_contiguous_f32.shape),
|
||||||
f"Requested IDs (sample): {node_ids[:5]}...{RESET}")
|
hidden_contiguous_f32.flatten().tolist(),
|
||||||
assert False
|
]
|
||||||
|
response_bytes = msgpack.packb(response_payload, use_single_float=True)
|
||||||
try:
|
|
||||||
hidden_contiguous_f32 = np.ascontiguousarray(hidden, dtype=np.float32)
|
|
||||||
response_payload = [
|
|
||||||
list(hidden_contiguous_f32.shape),
|
|
||||||
hidden_contiguous_f32.flatten().tolist()
|
|
||||||
]
|
|
||||||
response_bytes = msgpack.packb(response_payload, use_single_float=True)
|
|
||||||
except Exception as pack_error:
|
|
||||||
print(f"Error packing MessagePack response: {pack_error}")
|
|
||||||
response_bytes = msgpack.packb([[], []])
|
|
||||||
|
|
||||||
socket.send(response_bytes)
|
socket.send(response_bytes)
|
||||||
ser_end = time.time()
|
|
||||||
|
|
||||||
print(f"Serialize time: {ser_end - ser_start:.6f} seconds")
|
|
||||||
|
|
||||||
if device.type == "cuda":
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
elif device.type == "mps":
|
|
||||||
torch.mps.synchronize()
|
|
||||||
e2e_end = time.time()
|
e2e_end = time.time()
|
||||||
print(f"ZMQ E2E time: {e2e_end - e2e_start:.6f} seconds")
|
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
|
||||||
|
|
||||||
except zmq.Again:
|
except zmq.Again:
|
||||||
print("ZMQ socket timeout, continuing to listen")
|
logger.debug("ZMQ socket timeout, continuing to listen")
|
||||||
continue
|
continue
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error in ZMQ server loop: {e}")
|
logger.error(f"Error in ZMQ server loop: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
traceback.print_exc()
|
|
||||||
try:
|
|
||||||
socket.send(msgpack.packb([[], []]))
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Start warmup and server threads
|
traceback.print_exc()
|
||||||
if len(passages) > 0:
|
socket.send(msgpack.packb([[], []]))
|
||||||
warmup_thread = threading.Thread(target=client_warmup, args=(zmq_port,))
|
|
||||||
warmup_thread.daemon = True
|
|
||||||
warmup_thread.start()
|
|
||||||
|
|
||||||
zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
|
zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
|
||||||
zmq_thread.start()
|
zmq_thread.start()
|
||||||
print(f"Started HNSW ZMQ server thread on port {zmq_port}")
|
logger.info(f"Started HNSW ZMQ server thread on port {zmq_port}")
|
||||||
|
|
||||||
# Keep the main thread alive
|
# Keep the main thread alive
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
print("HNSW Server shutting down...")
|
logger.info("HNSW Server shutting down...")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
import signal
|
||||||
|
import sys
|
||||||
|
|
||||||
|
def signal_handler(sig, frame):
|
||||||
|
logger.info(f"Received signal {sig}, shutting down gracefully...")
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
# Register signal handlers for graceful shutdown
|
||||||
|
signal.signal(signal.SIGTERM, signal_handler)
|
||||||
|
signal.signal(signal.SIGINT, signal_handler)
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="HNSW Embedding service")
|
parser = argparse.ArgumentParser(description="HNSW Embedding service")
|
||||||
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
|
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
|
||||||
parser.add_argument("--passages-file", type=str, help="JSON file containing passage ID to text mapping")
|
parser.add_argument(
|
||||||
parser.add_argument("--embeddings-file", type=str, help="Pickle file containing pre-computed embeddings")
|
"--passages-file",
|
||||||
parser.add_argument("--use-fp16", action="store_true", default=False)
|
type=str,
|
||||||
parser.add_argument("--use-int8", action="store_true", default=False)
|
help="JSON file containing passage ID to text mapping",
|
||||||
parser.add_argument("--use-cuda-graphs", action="store_true", default=False)
|
)
|
||||||
parser.add_argument("--max-batch-size", type=int, default=128, help="Maximum batch size before splitting")
|
parser.add_argument(
|
||||||
parser.add_argument("--model-name", type=str, default="sentence-transformers/all-mpnet-base-v2",
|
"--model-name",
|
||||||
help="Embedding model name")
|
type=str,
|
||||||
parser.add_argument("--custom-max-length", type=int, default=None, help="Override model's default max sequence length")
|
default="sentence-transformers/all-mpnet-base-v2",
|
||||||
parser.add_argument("--distance-metric", type=str, default="mips", help="Distance metric to use")
|
help="Embedding model name",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--distance-metric", type=str, default="mips", help="Distance metric to use"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--embedding-mode",
|
||||||
|
type=str,
|
||||||
|
default="sentence-transformers",
|
||||||
|
choices=["sentence-transformers", "openai", "mlx"],
|
||||||
|
help="Embedding backend mode",
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Create and start the HNSW embedding server
|
# Create and start the HNSW embedding server
|
||||||
create_hnsw_embedding_server(
|
create_hnsw_embedding_server(
|
||||||
passages_file=args.passages_file,
|
passages_file=args.passages_file,
|
||||||
embeddings_file=args.embeddings_file,
|
|
||||||
use_fp16=args.use_fp16,
|
|
||||||
use_int8=args.use_int8,
|
|
||||||
use_cuda_graphs=args.use_cuda_graphs,
|
|
||||||
zmq_port=args.zmq_port,
|
zmq_port=args.zmq_port,
|
||||||
max_batch_size=args.max_batch_size,
|
|
||||||
model_name=args.model_name,
|
model_name=args.model_name,
|
||||||
custom_max_length_param=args.custom_max_length,
|
|
||||||
distance_metric=args.distance_metric,
|
distance_metric=args.distance_metric,
|
||||||
|
embedding_mode=args.embedding_mode,
|
||||||
)
|
)
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
# 文件: packages/leann-backend-hnsw/pyproject.toml
|
# packages/leann-backend-hnsw/pyproject.toml
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["scikit-build-core>=0.10", "numpy", "swig"]
|
requires = ["scikit-build-core>=0.10", "numpy", "swig"]
|
||||||
@@ -6,13 +6,22 @@ build-backend = "scikit_build_core.build"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "leann-backend-hnsw"
|
name = "leann-backend-hnsw"
|
||||||
version = "0.1.0"
|
version = "0.1.16"
|
||||||
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
|
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
|
||||||
dependencies = ["leann-core==0.1.0", "numpy"]
|
dependencies = [
|
||||||
|
"leann-core==0.1.16",
|
||||||
|
"numpy",
|
||||||
|
"pyzmq>=23.0.0",
|
||||||
|
"msgpack>=1.0.0",
|
||||||
|
]
|
||||||
|
|
||||||
# 回归到最标准的 scikit-build-core 配置
|
|
||||||
[tool.scikit-build]
|
[tool.scikit-build]
|
||||||
wheel.packages = ["leann_backend_hnsw"]
|
wheel.packages = ["leann_backend_hnsw"]
|
||||||
editable.mode = "redirect"
|
editable.mode = "redirect"
|
||||||
cmake.build-type = "Debug"
|
cmake.build-type = "Release"
|
||||||
build.verbose = true
|
build.verbose = true
|
||||||
|
build.tool-args = ["-j8"]
|
||||||
|
|
||||||
|
# CMake definitions to optimize compilation
|
||||||
|
[tool.scikit-build.cmake.define]
|
||||||
|
CMAKE_BUILD_PARALLEL_LEVEL = "8"
|
||||||
|
|||||||
1
packages/leann-backend-hnsw/third_party/cppzmq
vendored
Submodule
1
packages/leann-backend-hnsw/third_party/cppzmq
vendored
Submodule
Submodule packages/leann-backend-hnsw/third_party/cppzmq added at 3bcbd9dad2
Submodule packages/leann-backend-hnsw/third_party/faiss updated: 2365db59a7...ff22e2c86b
1
packages/leann-backend-hnsw/third_party/libzmq
vendored
Submodule
1
packages/leann-backend-hnsw/third_party/libzmq
vendored
Submodule
Submodule packages/leann-backend-hnsw/third_party/libzmq added at 3e5ce5c1cd
1
packages/leann-backend-hnsw/third_party/msgpack-c
vendored
Submodule
1
packages/leann-backend-hnsw/third_party/msgpack-c
vendored
Submodule
Submodule packages/leann-backend-hnsw/third_party/msgpack-c added at a0b2ec09da
@@ -4,15 +4,46 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "leann-core"
|
name = "leann-core"
|
||||||
version = "0.1.0"
|
version = "0.1.16"
|
||||||
description = "Core API and plugin system for Leann."
|
description = "Core API and plugin system for LEANN"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.9"
|
requires-python = ">=3.9"
|
||||||
license = { text = "MIT" }
|
license = { text = "MIT" }
|
||||||
|
|
||||||
|
# All required dependencies included
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"numpy>=1.20.0"
|
"numpy>=1.20.0",
|
||||||
|
"tqdm>=4.60.0",
|
||||||
|
"psutil>=5.8.0",
|
||||||
|
"pyzmq>=23.0.0",
|
||||||
|
"msgpack>=1.0.0",
|
||||||
|
"torch>=2.0.0",
|
||||||
|
"sentence-transformers>=2.2.0",
|
||||||
|
"llama-index-core>=0.12.0",
|
||||||
|
"llama-index-readers-file>=0.4.0", # Essential for document reading
|
||||||
|
"llama-index-embeddings-huggingface>=0.5.5", # For embeddings
|
||||||
|
"python-dotenv>=1.0.0",
|
||||||
|
"openai>=1.0.0",
|
||||||
|
"huggingface-hub>=0.20.0",
|
||||||
|
"transformers>=4.30.0",
|
||||||
|
"requests>=2.25.0",
|
||||||
|
"accelerate>=0.20.0",
|
||||||
|
"PyPDF2>=3.0.0",
|
||||||
|
"pymupdf>=1.23.0",
|
||||||
|
"pdfplumber>=0.10.0",
|
||||||
|
"mlx>=0.26.3; sys_platform == 'darwin'",
|
||||||
|
"mlx-lm>=0.26.0; sys_platform == 'darwin'",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
colab = [
|
||||||
|
"torch>=2.0.0,<3.0.0", # Limit torch version to avoid conflicts
|
||||||
|
"transformers>=4.30.0,<5.0.0", # Limit transformers version
|
||||||
|
"accelerate>=0.20.0,<1.0.0", # Limit accelerate version
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.scripts]
|
||||||
|
leann = "leann.cli:main"
|
||||||
|
|
||||||
[tool.setuptools.packages.find]
|
[tool.setuptools.packages.find]
|
||||||
where = ["src"]
|
where = ["src"]
|
||||||
@@ -1,17 +1,21 @@
|
|||||||
# This file makes the 'leann' directory a Python package.
|
# packages/leann-core/src/leann/__init__.py
|
||||||
|
import os
|
||||||
|
import platform
|
||||||
|
|
||||||
from .api import LeannBuilder, LeannSearcher, LeannChat, SearchResult
|
# Fix OpenMP threading issues on macOS ARM64
|
||||||
|
if platform.system() == "Darwin":
|
||||||
|
os.environ["OMP_NUM_THREADS"] = "1"
|
||||||
|
os.environ["MKL_NUM_THREADS"] = "1"
|
||||||
|
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
||||||
|
os.environ["KMP_BLOCKTIME"] = "0"
|
||||||
|
# Additional fixes for PyTorch/sentence-transformers on macOS ARM64 only in CI
|
||||||
|
if os.environ.get("CI") == "true":
|
||||||
|
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "0"
|
||||||
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
# Import backends to ensure they are registered
|
from .api import LeannBuilder, LeannChat, LeannSearcher
|
||||||
try:
|
from .registry import BACKEND_REGISTRY, autodiscover_backends
|
||||||
import leann_backend_hnsw
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
try:
|
autodiscover_backends()
|
||||||
import leann_backend_diskann
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
__all__ = ["BACKEND_REGISTRY", "LeannBuilder", "LeannChat", "LeannSearcher"]
|
||||||
__all__ = ['LeannBuilder', 'LeannSearcher', 'LeannChat', 'SearchResult']
|
|
||||||
|
|||||||
@@ -1,244 +1,650 @@
|
|||||||
from .registry import BACKEND_REGISTRY
|
"""
|
||||||
from .interface import LeannBackendFactoryInterface
|
This file contains the core API for the LEANN project, now definitively updated
|
||||||
from typing import List, Dict, Any, Optional
|
with the correct, original embedding logic from the user's reference code.
|
||||||
import numpy as np
|
"""
|
||||||
import os
|
|
||||||
import json
|
import json
|
||||||
from pathlib import Path
|
import logging
|
||||||
import openai
|
import pickle
|
||||||
|
import time
|
||||||
|
import warnings
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
# --- Helper Functions for Embeddings ---
|
import numpy as np
|
||||||
|
|
||||||
def _get_openai_client():
|
from leann.interface import LeannBackendSearcherInterface
|
||||||
"""Initializes and returns an OpenAI client, ensuring the API key is set."""
|
|
||||||
api_key = os.getenv("OPENAI_API_KEY")
|
|
||||||
if not api_key:
|
|
||||||
raise ValueError("OPENAI_API_KEY environment variable not set, which is required for OpenAI models.")
|
|
||||||
return openai.OpenAI(api_key=api_key)
|
|
||||||
|
|
||||||
def _is_openai_model(model_name: str) -> bool:
|
from .chat import get_llm
|
||||||
"""Checks if the model is likely an OpenAI embedding model."""
|
from .interface import LeannBackendFactoryInterface
|
||||||
# This is a simple check, can be improved with a more robust list.
|
from .registry import BACKEND_REGISTRY
|
||||||
return "ada" in model_name or "babbage" in model_name or model_name.startswith("text-embedding-")
|
|
||||||
|
|
||||||
def _compute_embeddings(chunks: List[str], model_name: str) -> np.ndarray:
|
logger = logging.getLogger(__name__)
|
||||||
"""Computes embeddings for a list of text chunks using either SentenceTransformers or OpenAI."""
|
|
||||||
if _is_openai_model(model_name):
|
|
||||||
print(f"INFO: Computing embeddings for {len(chunks)} chunks using OpenAI model '{model_name}'...")
|
def get_registered_backends() -> list[str]:
|
||||||
client = _get_openai_client()
|
"""Get list of registered backend names."""
|
||||||
response = client.embeddings.create(model=model_name, input=chunks)
|
return list(BACKEND_REGISTRY.keys())
|
||||||
embeddings = [item.embedding for item in response.data]
|
|
||||||
|
|
||||||
|
def compute_embeddings(
|
||||||
|
chunks: list[str],
|
||||||
|
model_name: str,
|
||||||
|
mode: str = "sentence-transformers",
|
||||||
|
use_server: bool = True,
|
||||||
|
port: int | None = None,
|
||||||
|
is_build=False,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Computes embeddings using different backends.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chunks: List of text chunks to embed
|
||||||
|
model_name: Name of the embedding model
|
||||||
|
mode: Embedding backend mode. Options:
|
||||||
|
- "sentence-transformers": Use sentence-transformers library (default)
|
||||||
|
- "mlx": Use MLX backend for Apple Silicon
|
||||||
|
- "openai": Use OpenAI embedding API
|
||||||
|
use_server: Whether to use embedding server (True for search, False for build)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
numpy array of embeddings
|
||||||
|
"""
|
||||||
|
if use_server:
|
||||||
|
# Use embedding server (for search/query)
|
||||||
|
if port is None:
|
||||||
|
raise ValueError("port is required when use_server is True")
|
||||||
|
return compute_embeddings_via_server(chunks, model_name, port=port)
|
||||||
else:
|
else:
|
||||||
from sentence_transformers import SentenceTransformer
|
# Use direct computation (for build_index)
|
||||||
model = SentenceTransformer(model_name)
|
from .embedding_compute import (
|
||||||
print(f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}'...")
|
compute_embeddings as compute_embeddings_direct,
|
||||||
embeddings = model.encode(chunks, show_progress_bar=True)
|
)
|
||||||
|
|
||||||
return np.asarray(embeddings, dtype=np.float32)
|
return compute_embeddings_direct(
|
||||||
|
chunks,
|
||||||
|
model_name,
|
||||||
|
mode=mode,
|
||||||
|
is_build=is_build,
|
||||||
|
)
|
||||||
|
|
||||||
def _get_embedding_dimensions(model_name: str) -> int:
|
|
||||||
"""Gets the embedding dimensions for a given model."""
|
def compute_embeddings_via_server(chunks: list[str], model_name: str, port: int) -> np.ndarray:
|
||||||
print(f"INFO: Calculating dimensions for model '{model_name}'...")
|
"""Computes embeddings using sentence-transformers.
|
||||||
if _is_openai_model(model_name):
|
|
||||||
client = _get_openai_client()
|
Args:
|
||||||
response = client.embeddings.create(model=model_name, input=["dummy text"])
|
chunks: List of text chunks to embed
|
||||||
return len(response.data[0].embedding)
|
model_name: Name of the sentence transformer model
|
||||||
else:
|
"""
|
||||||
from sentence_transformers import SentenceTransformer
|
logger.info(
|
||||||
model = SentenceTransformer(model_name)
|
f"Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (via embedding server)..."
|
||||||
dimension = model.get_sentence_embedding_dimension()
|
)
|
||||||
if dimension is None:
|
import msgpack
|
||||||
raise ValueError(f"Model '{model_name}' does not have a valid embedding dimension.")
|
import numpy as np
|
||||||
return dimension
|
import zmq
|
||||||
|
|
||||||
|
# Connect to embedding server
|
||||||
|
context = zmq.Context()
|
||||||
|
socket = context.socket(zmq.REQ)
|
||||||
|
socket.connect(f"tcp://localhost:{port}")
|
||||||
|
|
||||||
|
# Send chunks to server for embedding computation
|
||||||
|
request = chunks
|
||||||
|
socket.send(msgpack.packb(request))
|
||||||
|
|
||||||
|
# Receive embeddings from server
|
||||||
|
response = socket.recv()
|
||||||
|
embeddings_list = msgpack.unpackb(response)
|
||||||
|
|
||||||
|
# Convert back to numpy array
|
||||||
|
embeddings = np.array(embeddings_list, dtype=np.float32)
|
||||||
|
|
||||||
|
socket.close()
|
||||||
|
context.term()
|
||||||
|
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SearchResult:
|
class SearchResult:
|
||||||
"""Represents a single search result."""
|
id: str
|
||||||
id: int
|
|
||||||
score: float
|
score: float
|
||||||
text: str
|
text: str
|
||||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
metadata: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class PassageManager:
|
||||||
|
def __init__(self, passage_sources: list[dict[str, Any]]):
|
||||||
|
self.offset_maps = {}
|
||||||
|
self.passage_files = {}
|
||||||
|
self.global_offset_map = {} # Combined map for fast lookup
|
||||||
|
|
||||||
|
for source in passage_sources:
|
||||||
|
assert source["type"] == "jsonl", "only jsonl is supported"
|
||||||
|
passage_file = source["path"]
|
||||||
|
index_file = source["index_path"] # .idx file
|
||||||
|
|
||||||
|
# Fix path resolution for Colab and other environments
|
||||||
|
if not Path(index_file).is_absolute():
|
||||||
|
# If relative path, try to resolve it properly
|
||||||
|
index_file = str(Path(index_file).resolve())
|
||||||
|
|
||||||
|
if not Path(index_file).exists():
|
||||||
|
raise FileNotFoundError(f"Passage index file not found: {index_file}")
|
||||||
|
|
||||||
|
with open(index_file, "rb") as f:
|
||||||
|
offset_map = pickle.load(f)
|
||||||
|
self.offset_maps[passage_file] = offset_map
|
||||||
|
self.passage_files[passage_file] = passage_file
|
||||||
|
|
||||||
|
# Build global map for O(1) lookup
|
||||||
|
for passage_id, offset in offset_map.items():
|
||||||
|
self.global_offset_map[passage_id] = (passage_file, offset)
|
||||||
|
|
||||||
|
def get_passage(self, passage_id: str) -> dict[str, Any]:
|
||||||
|
if passage_id in self.global_offset_map:
|
||||||
|
passage_file, offset = self.global_offset_map[passage_id]
|
||||||
|
# Lazy file opening - only open when needed
|
||||||
|
with open(passage_file, encoding="utf-8") as f:
|
||||||
|
f.seek(offset)
|
||||||
|
return json.loads(f.readline())
|
||||||
|
raise KeyError(f"Passage ID not found: {passage_id}")
|
||||||
|
|
||||||
# --- Core Classes ---
|
|
||||||
|
|
||||||
class LeannBuilder:
|
class LeannBuilder:
|
||||||
"""
|
def __init__(
|
||||||
The builder is responsible for building the index, it will compute the embeddings and then build the index.
|
self,
|
||||||
It will also save the metadata of the index.
|
backend_name: str,
|
||||||
"""
|
embedding_model: str = "facebook/contriever",
|
||||||
def __init__(self, backend_name: str, embedding_model: str = "sentence-transformers/all-mpnet-base-v2", dimensions: Optional[int] = None, **backend_kwargs):
|
dimensions: int | None = None,
|
||||||
|
embedding_mode: str = "sentence-transformers",
|
||||||
|
**backend_kwargs,
|
||||||
|
):
|
||||||
self.backend_name = backend_name
|
self.backend_name = backend_name
|
||||||
backend_factory: LeannBackendFactoryInterface | None = BACKEND_REGISTRY.get(backend_name)
|
backend_factory: LeannBackendFactoryInterface | None = BACKEND_REGISTRY.get(backend_name)
|
||||||
if backend_factory is None:
|
if backend_factory is None:
|
||||||
raise ValueError(f"Backend '{backend_name}' not found or not registered.")
|
raise ValueError(f"Backend '{backend_name}' not found or not registered.")
|
||||||
self.backend_factory = backend_factory
|
self.backend_factory = backend_factory
|
||||||
|
|
||||||
self.embedding_model = embedding_model
|
self.embedding_model = embedding_model
|
||||||
self.dimensions = dimensions
|
self.dimensions = dimensions
|
||||||
self.backend_kwargs = backend_kwargs
|
self.embedding_mode = embedding_mode
|
||||||
self.chunks: List[Dict[str, Any]] = []
|
|
||||||
print(f"INFO: LeannBuilder initialized with '{backend_name}' backend.")
|
|
||||||
|
|
||||||
def add_text(self, text: str, metadata: Optional[Dict[str, Any]] = None):
|
# Check if we need to use cosine distance for normalized embeddings
|
||||||
self.chunks.append({"text": text, "metadata": metadata or {}})
|
normalized_embeddings_models = {
|
||||||
|
# OpenAI models
|
||||||
|
("openai", "text-embedding-ada-002"),
|
||||||
|
("openai", "text-embedding-3-small"),
|
||||||
|
("openai", "text-embedding-3-large"),
|
||||||
|
# Voyage AI models
|
||||||
|
("voyage", "voyage-2"),
|
||||||
|
("voyage", "voyage-3"),
|
||||||
|
("voyage", "voyage-large-2"),
|
||||||
|
("voyage", "voyage-multilingual-2"),
|
||||||
|
("voyage", "voyage-code-2"),
|
||||||
|
# Cohere models
|
||||||
|
("cohere", "embed-english-v3.0"),
|
||||||
|
("cohere", "embed-multilingual-v3.0"),
|
||||||
|
("cohere", "embed-english-light-v3.0"),
|
||||||
|
("cohere", "embed-multilingual-light-v3.0"),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Also check for patterns in model names
|
||||||
|
is_normalized = False
|
||||||
|
current_model_lower = embedding_model.lower()
|
||||||
|
current_mode_lower = embedding_mode.lower()
|
||||||
|
|
||||||
|
# Check exact matches
|
||||||
|
for mode, model in normalized_embeddings_models:
|
||||||
|
if (current_mode_lower == mode and current_model_lower == model) or (
|
||||||
|
mode in current_mode_lower and model in current_model_lower
|
||||||
|
):
|
||||||
|
is_normalized = True
|
||||||
|
break
|
||||||
|
|
||||||
|
# Check patterns
|
||||||
|
if not is_normalized:
|
||||||
|
# OpenAI patterns
|
||||||
|
if "openai" in current_mode_lower or "openai" in current_model_lower:
|
||||||
|
if any(
|
||||||
|
pattern in current_model_lower
|
||||||
|
for pattern in ["text-embedding", "ada", "3-small", "3-large"]
|
||||||
|
):
|
||||||
|
is_normalized = True
|
||||||
|
# Voyage patterns
|
||||||
|
elif "voyage" in current_mode_lower or "voyage" in current_model_lower:
|
||||||
|
is_normalized = True
|
||||||
|
# Cohere patterns
|
||||||
|
elif "cohere" in current_mode_lower or "cohere" in current_model_lower:
|
||||||
|
if "embed" in current_model_lower:
|
||||||
|
is_normalized = True
|
||||||
|
|
||||||
|
# Handle distance metric
|
||||||
|
if is_normalized and "distance_metric" not in backend_kwargs:
|
||||||
|
backend_kwargs["distance_metric"] = "cosine"
|
||||||
|
warnings.warn(
|
||||||
|
f"Detected normalized embeddings model '{embedding_model}' with mode '{embedding_mode}'. "
|
||||||
|
f"Automatically setting distance_metric='cosine' for optimal performance. "
|
||||||
|
f"Normalized embeddings (L2 norm = 1) should use cosine similarity instead of MIPS.",
|
||||||
|
UserWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
elif is_normalized and backend_kwargs.get("distance_metric", "").lower() != "cosine":
|
||||||
|
current_metric = backend_kwargs.get("distance_metric", "mips")
|
||||||
|
warnings.warn(
|
||||||
|
f"Warning: Using '{current_metric}' distance metric with normalized embeddings model "
|
||||||
|
f"'{embedding_model}' may lead to suboptimal search results. "
|
||||||
|
f"Consider using 'cosine' distance metric for better performance.",
|
||||||
|
UserWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.backend_kwargs = backend_kwargs
|
||||||
|
self.chunks: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
def add_text(self, text: str, metadata: dict[str, Any] | None = None):
|
||||||
|
if metadata is None:
|
||||||
|
metadata = {}
|
||||||
|
passage_id = metadata.get("id", str(len(self.chunks)))
|
||||||
|
chunk_data = {"id": passage_id, "text": text, "metadata": metadata}
|
||||||
|
self.chunks.append(chunk_data)
|
||||||
|
|
||||||
def build_index(self, index_path: str):
|
def build_index(self, index_path: str):
|
||||||
if not self.chunks:
|
if not self.chunks:
|
||||||
raise ValueError("No chunks added. Use add_text() first.")
|
raise ValueError("No chunks added.")
|
||||||
|
|
||||||
if self.dimensions is None:
|
if self.dimensions is None:
|
||||||
self.dimensions = _get_embedding_dimensions(self.embedding_model)
|
self.dimensions = len(
|
||||||
print(f"INFO: Auto-detected dimensions for '{self.embedding_model}': {self.dimensions}")
|
compute_embeddings(
|
||||||
|
["dummy"],
|
||||||
|
self.embedding_model,
|
||||||
|
self.embedding_mode,
|
||||||
|
use_server=False,
|
||||||
|
)[0]
|
||||||
|
)
|
||||||
|
path = Path(index_path)
|
||||||
|
index_dir = path.parent
|
||||||
|
index_name = path.name
|
||||||
|
index_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
passages_file = index_dir / f"{index_name}.passages.jsonl"
|
||||||
|
offset_file = index_dir / f"{index_name}.passages.idx"
|
||||||
|
offset_map = {}
|
||||||
|
with open(passages_file, "w", encoding="utf-8") as f:
|
||||||
|
try:
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
chunk_iterator = tqdm(self.chunks, desc="Writing passages", unit="chunk")
|
||||||
|
except ImportError:
|
||||||
|
chunk_iterator = self.chunks
|
||||||
|
|
||||||
|
for chunk in chunk_iterator:
|
||||||
|
offset = f.tell()
|
||||||
|
json.dump(
|
||||||
|
{
|
||||||
|
"id": chunk["id"],
|
||||||
|
"text": chunk["text"],
|
||||||
|
"metadata": chunk["metadata"],
|
||||||
|
},
|
||||||
|
f,
|
||||||
|
ensure_ascii=False,
|
||||||
|
)
|
||||||
|
f.write("\n")
|
||||||
|
offset_map[chunk["id"]] = offset
|
||||||
|
with open(offset_file, "wb") as f:
|
||||||
|
pickle.dump(offset_map, f)
|
||||||
texts_to_embed = [c["text"] for c in self.chunks]
|
texts_to_embed = [c["text"] for c in self.chunks]
|
||||||
embeddings = _compute_embeddings(texts_to_embed, self.embedding_model)
|
embeddings = compute_embeddings(
|
||||||
|
texts_to_embed,
|
||||||
current_backend_kwargs = self.backend_kwargs.copy()
|
self.embedding_model,
|
||||||
current_backend_kwargs['dimensions'] = self.dimensions
|
self.embedding_mode,
|
||||||
|
use_server=False,
|
||||||
|
is_build=True,
|
||||||
|
)
|
||||||
|
string_ids = [chunk["id"] for chunk in self.chunks]
|
||||||
|
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
|
||||||
builder_instance = self.backend_factory.builder(**current_backend_kwargs)
|
builder_instance = self.backend_factory.builder(**current_backend_kwargs)
|
||||||
|
builder_instance.build(embeddings, string_ids, index_path, **current_backend_kwargs)
|
||||||
build_kwargs = current_backend_kwargs.copy()
|
leann_meta_path = index_dir / f"{index_name}.meta.json"
|
||||||
build_kwargs['chunks'] = self.chunks
|
|
||||||
builder_instance.build(embeddings, index_path, **build_kwargs)
|
|
||||||
|
|
||||||
index_dir = Path(index_path).parent
|
|
||||||
leann_meta_path = index_dir / f"{Path(index_path).name}.meta.json"
|
|
||||||
|
|
||||||
meta_data = {
|
meta_data = {
|
||||||
"version": "0.1.0",
|
"version": "1.0",
|
||||||
"backend_name": self.backend_name,
|
"backend_name": self.backend_name,
|
||||||
"embedding_model": self.embedding_model,
|
"embedding_model": self.embedding_model,
|
||||||
"dimensions": self.dimensions,
|
"dimensions": self.dimensions,
|
||||||
"backend_kwargs": self.backend_kwargs,
|
"backend_kwargs": self.backend_kwargs,
|
||||||
"num_chunks": len(self.chunks),
|
"embedding_mode": self.embedding_mode,
|
||||||
"chunks": self.chunks,
|
"passage_sources": [
|
||||||
|
{
|
||||||
|
"type": "jsonl",
|
||||||
|
"path": str(passages_file),
|
||||||
|
"index_path": str(offset_file),
|
||||||
|
}
|
||||||
|
],
|
||||||
}
|
}
|
||||||
with open(leann_meta_path, 'w', encoding='utf-8') as f:
|
|
||||||
|
# Add storage status flags for HNSW backend
|
||||||
|
if self.backend_name == "hnsw":
|
||||||
|
is_compact = self.backend_kwargs.get("is_compact", True)
|
||||||
|
is_recompute = self.backend_kwargs.get("is_recompute", True)
|
||||||
|
meta_data["is_compact"] = is_compact
|
||||||
|
meta_data["is_pruned"] = (
|
||||||
|
is_compact and is_recompute
|
||||||
|
) # Pruned only if compact and recompute
|
||||||
|
with open(leann_meta_path, "w", encoding="utf-8") as f:
|
||||||
json.dump(meta_data, f, indent=2)
|
json.dump(meta_data, f, indent=2)
|
||||||
print(f"INFO: Leann metadata saved to {leann_meta_path}")
|
|
||||||
|
def build_index_from_embeddings(self, index_path: str, embeddings_file: str):
|
||||||
|
"""
|
||||||
|
Build an index from pre-computed embeddings stored in a pickle file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index_path: Path where the index will be saved
|
||||||
|
embeddings_file: Path to pickle file containing (ids, embeddings) tuple
|
||||||
|
"""
|
||||||
|
# Load pre-computed embeddings
|
||||||
|
with open(embeddings_file, "rb") as f:
|
||||||
|
data = pickle.load(f)
|
||||||
|
|
||||||
|
if not isinstance(data, tuple) or len(data) != 2:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid embeddings file format. Expected tuple with 2 elements, got {type(data)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
ids, embeddings = data
|
||||||
|
|
||||||
|
if not isinstance(embeddings, np.ndarray):
|
||||||
|
raise ValueError(f"Expected embeddings to be numpy array, got {type(embeddings)}")
|
||||||
|
|
||||||
|
if len(ids) != embeddings.shape[0]:
|
||||||
|
raise ValueError(
|
||||||
|
f"Mismatch between number of IDs ({len(ids)}) and embeddings ({embeddings.shape[0]})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate/set dimensions
|
||||||
|
embedding_dim = embeddings.shape[1]
|
||||||
|
if self.dimensions is None:
|
||||||
|
self.dimensions = embedding_dim
|
||||||
|
elif self.dimensions != embedding_dim:
|
||||||
|
raise ValueError(f"Dimension mismatch: expected {self.dimensions}, got {embedding_dim}")
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Building index from precomputed embeddings: {len(ids)} items, {embedding_dim} dimensions"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ensure we have text data for each embedding
|
||||||
|
if len(self.chunks) != len(ids):
|
||||||
|
# If no text chunks provided, create placeholder text entries
|
||||||
|
if not self.chunks:
|
||||||
|
logger.info("No text chunks provided, creating placeholder entries...")
|
||||||
|
for id_val in ids:
|
||||||
|
self.add_text(
|
||||||
|
f"Document {id_val}",
|
||||||
|
metadata={"id": str(id_val), "from_embeddings": True},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Number of text chunks ({len(self.chunks)}) doesn't match number of embeddings ({len(ids)})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build file structure
|
||||||
|
path = Path(index_path)
|
||||||
|
index_dir = path.parent
|
||||||
|
index_name = path.name
|
||||||
|
index_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
passages_file = index_dir / f"{index_name}.passages.jsonl"
|
||||||
|
offset_file = index_dir / f"{index_name}.passages.idx"
|
||||||
|
|
||||||
|
# Write passages and create offset map
|
||||||
|
offset_map = {}
|
||||||
|
with open(passages_file, "w", encoding="utf-8") as f:
|
||||||
|
for chunk in self.chunks:
|
||||||
|
offset = f.tell()
|
||||||
|
json.dump(
|
||||||
|
{
|
||||||
|
"id": chunk["id"],
|
||||||
|
"text": chunk["text"],
|
||||||
|
"metadata": chunk["metadata"],
|
||||||
|
},
|
||||||
|
f,
|
||||||
|
ensure_ascii=False,
|
||||||
|
)
|
||||||
|
f.write("\n")
|
||||||
|
offset_map[chunk["id"]] = offset
|
||||||
|
|
||||||
|
with open(offset_file, "wb") as f:
|
||||||
|
pickle.dump(offset_map, f)
|
||||||
|
|
||||||
|
# Build the vector index using precomputed embeddings
|
||||||
|
string_ids = [str(id_val) for id_val in ids]
|
||||||
|
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
|
||||||
|
builder_instance = self.backend_factory.builder(**current_backend_kwargs)
|
||||||
|
builder_instance.build(embeddings, string_ids, index_path)
|
||||||
|
|
||||||
|
# Create metadata file
|
||||||
|
leann_meta_path = index_dir / f"{index_name}.meta.json"
|
||||||
|
meta_data = {
|
||||||
|
"version": "1.0",
|
||||||
|
"backend_name": self.backend_name,
|
||||||
|
"embedding_model": self.embedding_model,
|
||||||
|
"dimensions": self.dimensions,
|
||||||
|
"backend_kwargs": self.backend_kwargs,
|
||||||
|
"embedding_mode": self.embedding_mode,
|
||||||
|
"passage_sources": [
|
||||||
|
{
|
||||||
|
"type": "jsonl",
|
||||||
|
"path": str(passages_file),
|
||||||
|
"index_path": str(offset_file),
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"built_from_precomputed_embeddings": True,
|
||||||
|
"embeddings_source": str(embeddings_file),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add storage status flags for HNSW backend
|
||||||
|
if self.backend_name == "hnsw":
|
||||||
|
is_compact = self.backend_kwargs.get("is_compact", True)
|
||||||
|
is_recompute = self.backend_kwargs.get("is_recompute", True)
|
||||||
|
meta_data["is_compact"] = is_compact
|
||||||
|
meta_data["is_pruned"] = is_compact and is_recompute
|
||||||
|
|
||||||
|
with open(leann_meta_path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(meta_data, f, indent=2)
|
||||||
|
|
||||||
|
logger.info(f"Index built successfully from precomputed embeddings: {index_path}")
|
||||||
|
|
||||||
|
|
||||||
class LeannSearcher:
|
class LeannSearcher:
|
||||||
"""
|
def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
|
||||||
The searcher is responsible for loading the index and performing the search.
|
# Fix path resolution for Colab and other environments
|
||||||
It will also load the metadata of the index.
|
if not Path(index_path).is_absolute():
|
||||||
"""
|
index_path = str(Path(index_path).resolve())
|
||||||
def __init__(self, index_path: str, **backend_kwargs):
|
|
||||||
leann_meta_path = Path(index_path).parent / f"{Path(index_path).name}.meta.json"
|
|
||||||
if not leann_meta_path.exists():
|
|
||||||
raise FileNotFoundError(f"Leann metadata file not found at {leann_meta_path}. Was the index built with LeannBuilder?")
|
|
||||||
|
|
||||||
with open(leann_meta_path, 'r', encoding='utf-8') as f:
|
self.meta_path_str = f"{index_path}.meta.json"
|
||||||
|
if not Path(self.meta_path_str).exists():
|
||||||
|
parent_dir = Path(index_path).parent
|
||||||
|
print(
|
||||||
|
f"Leann metadata file not found at {self.meta_path_str}, and you may need to rm -rf {parent_dir}"
|
||||||
|
)
|
||||||
|
# highlight in red the filenotfound error
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"Leann metadata file not found at {self.meta_path_str}, \033[91m you may need to rm -rf {parent_dir}\033[0m"
|
||||||
|
)
|
||||||
|
with open(self.meta_path_str, encoding="utf-8") as f:
|
||||||
self.meta_data = json.load(f)
|
self.meta_data = json.load(f)
|
||||||
|
backend_name = self.meta_data["backend_name"]
|
||||||
backend_name = self.meta_data['backend_name']
|
self.embedding_model = self.meta_data["embedding_model"]
|
||||||
self.embedding_model = self.meta_data['embedding_model']
|
# Support both old and new format
|
||||||
|
self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers")
|
||||||
|
self.passage_manager = PassageManager(self.meta_data.get("passage_sources", []))
|
||||||
backend_factory = BACKEND_REGISTRY.get(backend_name)
|
backend_factory = BACKEND_REGISTRY.get(backend_name)
|
||||||
if backend_factory is None:
|
if backend_factory is None:
|
||||||
raise ValueError(f"Backend '{backend_name}' (from index file) not found or not registered.")
|
raise ValueError(f"Backend '{backend_name}' not found.")
|
||||||
|
final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
|
||||||
|
final_kwargs["enable_warmup"] = enable_warmup
|
||||||
|
self.backend_impl: LeannBackendSearcherInterface = backend_factory.searcher(
|
||||||
|
index_path, **final_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
final_kwargs = self.meta_data.get("backend_kwargs", {})
|
def search(
|
||||||
final_kwargs.update(backend_kwargs)
|
self,
|
||||||
if 'dimensions' not in final_kwargs:
|
query: str,
|
||||||
final_kwargs['dimensions'] = self.meta_data.get('dimensions')
|
top_k: int = 5,
|
||||||
|
complexity: int = 64,
|
||||||
|
beam_width: int = 1,
|
||||||
|
prune_ratio: float = 0.0,
|
||||||
|
recompute_embeddings: bool = True,
|
||||||
|
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||||
|
expected_zmq_port: int = 5557,
|
||||||
|
**kwargs,
|
||||||
|
) -> list[SearchResult]:
|
||||||
|
logger.info("🔍 LeannSearcher.search() called:")
|
||||||
|
logger.info(f" Query: '{query}'")
|
||||||
|
logger.info(f" Top_k: {top_k}")
|
||||||
|
logger.info(f" Additional kwargs: {kwargs}")
|
||||||
|
|
||||||
self.backend_impl = backend_factory.searcher(index_path, **final_kwargs)
|
# Smart top_k detection and adjustment
|
||||||
print(f"INFO: LeannSearcher initialized with '{backend_name}' backend using index '{index_path}'.")
|
total_docs = len(self.passage_manager.global_offset_map)
|
||||||
|
original_top_k = top_k
|
||||||
|
if top_k > total_docs:
|
||||||
|
top_k = total_docs
|
||||||
|
logger.warning(
|
||||||
|
f" ⚠️ Requested top_k ({original_top_k}) exceeds total documents ({total_docs})"
|
||||||
|
)
|
||||||
|
logger.warning(f" ✅ Auto-adjusted top_k to {top_k} to match available documents")
|
||||||
|
|
||||||
def search(self, query: str, top_k: int = 5, **search_kwargs):
|
zmq_port = None
|
||||||
query_embedding = _compute_embeddings([query], self.embedding_model)
|
|
||||||
|
|
||||||
search_kwargs['embedding_model'] = self.embedding_model
|
start_time = time.time()
|
||||||
results = self.backend_impl.search(query_embedding, top_k, **search_kwargs)
|
if recompute_embeddings:
|
||||||
|
zmq_port = self.backend_impl._ensure_server_running(
|
||||||
|
self.meta_path_str,
|
||||||
|
port=expected_zmq_port,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
del expected_zmq_port
|
||||||
|
zmq_time = time.time() - start_time
|
||||||
|
logger.info(f" Launching server time: {zmq_time} seconds")
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
query_embedding = self.backend_impl.compute_query_embedding(
|
||||||
|
query,
|
||||||
|
use_server_if_available=recompute_embeddings,
|
||||||
|
zmq_port=zmq_port,
|
||||||
|
)
|
||||||
|
# logger.info(f" Generated embedding shape: {query_embedding.shape}")
|
||||||
|
time.time() - start_time
|
||||||
|
# logger.info(f" Embedding time: {embedding_time} seconds")
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
results = self.backend_impl.search(
|
||||||
|
query_embedding,
|
||||||
|
top_k,
|
||||||
|
complexity=complexity,
|
||||||
|
beam_width=beam_width,
|
||||||
|
prune_ratio=prune_ratio,
|
||||||
|
recompute_embeddings=recompute_embeddings,
|
||||||
|
pruning_strategy=pruning_strategy,
|
||||||
|
zmq_port=zmq_port,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
time.time() - start_time
|
||||||
|
# logger.info(f" Search time: {search_time} seconds")
|
||||||
|
logger.info(f" Backend returned: labels={len(results.get('labels', [[]])[0])} results")
|
||||||
|
|
||||||
enriched_results = []
|
enriched_results = []
|
||||||
for label, dist in zip(results['labels'][0], results['distances'][0]):
|
if "labels" in results and "distances" in results:
|
||||||
if label < len(self.meta_data['chunks']):
|
logger.info(f" Processing {len(results['labels'][0])} passage IDs:")
|
||||||
chunk_info = self.meta_data['chunks'][label]
|
for i, (string_id, dist) in enumerate(
|
||||||
enriched_results.append(SearchResult(
|
zip(results["labels"][0], results["distances"][0], strict=False)
|
||||||
id=label,
|
):
|
||||||
score=dist,
|
try:
|
||||||
text=chunk_info['text'],
|
passage_data = self.passage_manager.get_passage(string_id)
|
||||||
metadata=chunk_info.get('metadata', {})
|
enriched_results.append(
|
||||||
))
|
SearchResult(
|
||||||
|
id=string_id,
|
||||||
|
score=dist,
|
||||||
|
text=passage_data["text"],
|
||||||
|
metadata=passage_data.get("metadata", {}),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Color codes for better logging
|
||||||
|
GREEN = "\033[92m"
|
||||||
|
BLUE = "\033[94m"
|
||||||
|
YELLOW = "\033[93m"
|
||||||
|
RESET = "\033[0m"
|
||||||
|
|
||||||
|
# Truncate text for display (first 100 chars)
|
||||||
|
display_text = passage_data["text"]
|
||||||
|
logger.info(
|
||||||
|
f" {GREEN}✓{RESET} {BLUE}[{i + 1:2d}]{RESET} {YELLOW}ID:{RESET} '{string_id}' {YELLOW}Score:{RESET} {dist:.4f} {YELLOW}Text:{RESET} {display_text}"
|
||||||
|
)
|
||||||
|
except KeyError:
|
||||||
|
RED = "\033[91m"
|
||||||
|
logger.error(
|
||||||
|
f" {RED}✗{RESET} [{i + 1:2d}] ID: '{string_id}' -> {RED}ERROR: Passage not found!{RESET}"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f" {GREEN}✓ Final enriched results: {len(enriched_results)} passages{RESET}")
|
||||||
return enriched_results
|
return enriched_results
|
||||||
|
|
||||||
|
|
||||||
class LeannChat:
|
class LeannChat:
|
||||||
"""
|
def __init__(
|
||||||
The chat is responsible for the conversation with the LLM.
|
self,
|
||||||
It will use the searcher to get the results and then use the LLM to generate the response.
|
index_path: str,
|
||||||
"""
|
llm_config: dict[str, Any] | None = None,
|
||||||
def __init__(self, index_path: str, backend_name: Optional[str] = None, llm_model: str = "gpt-4o", **kwargs):
|
enable_warmup: bool = False,
|
||||||
if backend_name is None:
|
**kwargs,
|
||||||
leann_meta_path = Path(index_path).parent / f"{Path(index_path).name}.meta.json"
|
):
|
||||||
if not leann_meta_path.exists():
|
self.searcher = LeannSearcher(index_path, enable_warmup=enable_warmup, **kwargs)
|
||||||
raise FileNotFoundError(f"Leann metadata file not found at {leann_meta_path}.")
|
self.llm = get_llm(llm_config)
|
||||||
with open(leann_meta_path, 'r', encoding='utf-8') as f:
|
|
||||||
meta_data = json.load(f)
|
|
||||||
backend_name = meta_data['backend_name']
|
|
||||||
|
|
||||||
self.searcher = LeannSearcher(index_path, **kwargs)
|
def ask(
|
||||||
self.llm_model = llm_model
|
self,
|
||||||
|
question: str,
|
||||||
def ask(self, question: str, top_k=5, **kwargs):
|
top_k: int = 5,
|
||||||
"""
|
complexity: int = 64,
|
||||||
Additional keyword arguments (kwargs) for advanced search customization. Example usage:
|
beam_width: int = 1,
|
||||||
chat.ask(
|
prune_ratio: float = 0.0,
|
||||||
"What is ANN?",
|
recompute_embeddings: bool = True,
|
||||||
top_k=10,
|
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||||
complexity=64,
|
llm_kwargs: dict[str, Any] | None = None,
|
||||||
beam_width=8,
|
expected_zmq_port: int = 5557,
|
||||||
USE_DEFERRED_FETCH=True,
|
**search_kwargs,
|
||||||
skip_search_reorder=True,
|
):
|
||||||
recompute_beighbor_embeddings=True,
|
if llm_kwargs is None:
|
||||||
dedup_node_dis=True,
|
llm_kwargs = {}
|
||||||
prune_ratio=0.1,
|
search_time = time.time()
|
||||||
batch_recompute=True,
|
results = self.searcher.search(
|
||||||
global_pruning=True
|
question,
|
||||||
)
|
top_k=top_k,
|
||||||
|
complexity=complexity,
|
||||||
Supported kwargs:
|
beam_width=beam_width,
|
||||||
- complexity (int): Search complexity parameter (default: 32)
|
prune_ratio=prune_ratio,
|
||||||
- beam_width (int): Beam width for search (default: 4)
|
recompute_embeddings=recompute_embeddings,
|
||||||
- USE_DEFERRED_FETCH (bool): Enable deferred fetch mode (default: False)
|
pruning_strategy=pruning_strategy,
|
||||||
- skip_search_reorder (bool): Skip search reorder step (default: False)
|
expected_zmq_port=expected_zmq_port,
|
||||||
- recompute_beighbor_embeddings (bool): Enable ZMQ embedding server for neighbor recomputation (default: False)
|
**search_kwargs,
|
||||||
- dedup_node_dis (bool): Deduplicate nodes by distance (default: False)
|
)
|
||||||
- prune_ratio (float): Pruning ratio for search (default: 0.0)
|
search_time = time.time() - search_time
|
||||||
- batch_recompute (bool): Enable batch recomputation (default: False)
|
# logger.info(f" Search time: {search_time} seconds")
|
||||||
- global_pruning (bool): Enable global pruning (default: False)
|
|
||||||
"""
|
|
||||||
|
|
||||||
results = self.searcher.search(question, top_k=top_k, **kwargs)
|
|
||||||
context = "\n\n".join([r.text for r in results])
|
context = "\n\n".join([r.text for r in results])
|
||||||
|
prompt = (
|
||||||
|
"Here is some retrieved context that might help answer your question:\n\n"
|
||||||
|
f"{context}\n\n"
|
||||||
|
f"Question: {question}\n\n"
|
||||||
|
"Please provide the best answer you can based on this context and your knowledge."
|
||||||
|
)
|
||||||
|
|
||||||
prompt = f"Context:\n{context}\n\nQuestion: {question}\n\nAnswer:"
|
ans = self.llm.ask(prompt, **llm_kwargs)
|
||||||
|
return ans
|
||||||
print(f"DEBUG: Calling LLM with prompt: {prompt}...")
|
|
||||||
try:
|
|
||||||
client = _get_openai_client()
|
|
||||||
response = client.chat.completions.create(
|
|
||||||
model=self.llm_model,
|
|
||||||
messages=[
|
|
||||||
{"role": "system", "content": "You are a helpful assistant that answers questions based on the provided context."},
|
|
||||||
{"role": "user", "content": prompt}
|
|
||||||
]
|
|
||||||
)
|
|
||||||
return response.choices[0].message.content
|
|
||||||
except Exception as e:
|
|
||||||
print(f"ERROR: Failed to call OpenAI API: {e}")
|
|
||||||
return f"Error: Could not get a response from the LLM. {e}"
|
|
||||||
|
|
||||||
def start_interactive(self):
|
def start_interactive(self):
|
||||||
print("\nLeann Chat started (type 'quit' to exit)")
|
print("\nLeann Chat started (type 'quit' to exit)")
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
user_input = input("You: ").strip()
|
user_input = input("You: ").strip()
|
||||||
if user_input.lower() in ['quit', 'exit']:
|
if user_input.lower() in ["quit", "exit"]:
|
||||||
break
|
break
|
||||||
if not user_input:
|
if not user_input:
|
||||||
continue
|
continue
|
||||||
|
|||||||
717
packages/leann-core/src/leann/chat.py
Normal file
717
packages/leann-core/src/leann/chat.py
Normal file
@@ -0,0 +1,717 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
This file contains the chat generation logic for the LEANN project,
|
||||||
|
supporting different backends like Ollama, Hugging Face Transformers, and a simulation mode.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import difflib
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# Configure logging
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def check_ollama_models() -> list[str]:
|
||||||
|
"""Check available Ollama models and return a list"""
|
||||||
|
try:
|
||||||
|
import requests
|
||||||
|
|
||||||
|
response = requests.get("http://localhost:11434/api/tags", timeout=5)
|
||||||
|
if response.status_code == 200:
|
||||||
|
data = response.json()
|
||||||
|
return [model["name"] for model in data.get("models", [])]
|
||||||
|
return []
|
||||||
|
except Exception:
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def check_ollama_model_exists_remotely(model_name: str) -> tuple[bool, list[str]]:
|
||||||
|
"""Check if a model exists in Ollama's remote library and return available tags
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(model_exists, available_tags): bool and list of matching tags
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import re
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
# Split model name and tag
|
||||||
|
if ":" in model_name:
|
||||||
|
base_model, requested_tag = model_name.split(":", 1)
|
||||||
|
else:
|
||||||
|
base_model, requested_tag = model_name, None
|
||||||
|
|
||||||
|
# First check if base model exists in library
|
||||||
|
library_response = requests.get("https://ollama.com/library", timeout=8)
|
||||||
|
if library_response.status_code != 200:
|
||||||
|
return True, [] # Assume exists if can't check
|
||||||
|
|
||||||
|
# Extract model names from library page
|
||||||
|
models_in_library = re.findall(r'href="/library/([^"]+)"', library_response.text)
|
||||||
|
|
||||||
|
if base_model not in models_in_library:
|
||||||
|
return False, [] # Base model doesn't exist
|
||||||
|
|
||||||
|
# If base model exists, get available tags
|
||||||
|
tags_response = requests.get(f"https://ollama.com/library/{base_model}/tags", timeout=8)
|
||||||
|
if tags_response.status_code != 200:
|
||||||
|
return True, [] # Base model exists but can't get tags
|
||||||
|
|
||||||
|
# Extract tags for this model - be more specific to avoid HTML artifacts
|
||||||
|
tag_pattern = rf"{re.escape(base_model)}:[a-zA-Z0-9\.\-_]+"
|
||||||
|
raw_tags = re.findall(tag_pattern, tags_response.text)
|
||||||
|
|
||||||
|
# Clean up tags - remove HTML artifacts and duplicates
|
||||||
|
available_tags = []
|
||||||
|
seen = set()
|
||||||
|
for tag in raw_tags:
|
||||||
|
# Skip if it looks like HTML (contains < or >)
|
||||||
|
if "<" in tag or ">" in tag:
|
||||||
|
continue
|
||||||
|
if tag not in seen:
|
||||||
|
seen.add(tag)
|
||||||
|
available_tags.append(tag)
|
||||||
|
|
||||||
|
# Check if exact model exists
|
||||||
|
if requested_tag is None:
|
||||||
|
# User just requested base model, suggest tags
|
||||||
|
return True, available_tags[:10] # Return up to 10 tags
|
||||||
|
else:
|
||||||
|
exact_match = model_name in available_tags
|
||||||
|
return exact_match, available_tags[:10]
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# If scraping fails, assume model might exist (don't block user)
|
||||||
|
return True, []
|
||||||
|
|
||||||
|
|
||||||
|
def search_ollama_models_fuzzy(query: str, available_models: list[str]) -> list[str]:
|
||||||
|
"""Use intelligent fuzzy search for Ollama models"""
|
||||||
|
if not available_models:
|
||||||
|
return []
|
||||||
|
|
||||||
|
query_lower = query.lower()
|
||||||
|
suggestions = []
|
||||||
|
|
||||||
|
# 1. Exact matches first
|
||||||
|
exact_matches = [m for m in available_models if query_lower == m.lower()]
|
||||||
|
suggestions.extend(exact_matches)
|
||||||
|
|
||||||
|
# 2. Starts with query
|
||||||
|
starts_with = [
|
||||||
|
m for m in available_models if m.lower().startswith(query_lower) and m not in suggestions
|
||||||
|
]
|
||||||
|
suggestions.extend(starts_with)
|
||||||
|
|
||||||
|
# 3. Contains query
|
||||||
|
contains = [m for m in available_models if query_lower in m.lower() and m not in suggestions]
|
||||||
|
suggestions.extend(contains)
|
||||||
|
|
||||||
|
# 4. Base model name matching (remove version numbers)
|
||||||
|
def get_base_name(model_name: str) -> str:
|
||||||
|
"""Extract base name without version (e.g., 'llama3:8b' -> 'llama3')"""
|
||||||
|
return model_name.split(":")[0].split("-")[0]
|
||||||
|
|
||||||
|
query_base = get_base_name(query_lower)
|
||||||
|
base_matches = [
|
||||||
|
m
|
||||||
|
for m in available_models
|
||||||
|
if get_base_name(m.lower()) == query_base and m not in suggestions
|
||||||
|
]
|
||||||
|
suggestions.extend(base_matches)
|
||||||
|
|
||||||
|
# 5. Family/variant matching
|
||||||
|
model_families = {
|
||||||
|
"llama": ["llama2", "llama3", "alpaca", "vicuna", "codellama"],
|
||||||
|
"qwen": ["qwen", "qwen2", "qwen3"],
|
||||||
|
"gemma": ["gemma", "gemma2"],
|
||||||
|
"phi": ["phi", "phi2", "phi3"],
|
||||||
|
"mistral": ["mistral", "mixtral", "openhermes"],
|
||||||
|
"dolphin": ["dolphin", "openchat"],
|
||||||
|
"deepseek": ["deepseek", "deepseek-coder"],
|
||||||
|
}
|
||||||
|
|
||||||
|
query_family = None
|
||||||
|
for family, variants in model_families.items():
|
||||||
|
if any(variant in query_lower for variant in variants):
|
||||||
|
query_family = family
|
||||||
|
break
|
||||||
|
|
||||||
|
if query_family:
|
||||||
|
family_variants = model_families[query_family]
|
||||||
|
family_matches = [
|
||||||
|
m
|
||||||
|
for m in available_models
|
||||||
|
if any(variant in m.lower() for variant in family_variants) and m not in suggestions
|
||||||
|
]
|
||||||
|
suggestions.extend(family_matches)
|
||||||
|
|
||||||
|
# 6. Use difflib for remaining fuzzy matches
|
||||||
|
remaining_models = [m for m in available_models if m not in suggestions]
|
||||||
|
difflib_matches = difflib.get_close_matches(query_lower, remaining_models, n=3, cutoff=0.4)
|
||||||
|
suggestions.extend(difflib_matches)
|
||||||
|
|
||||||
|
return suggestions[:8] # Return top 8 suggestions
|
||||||
|
|
||||||
|
|
||||||
|
# Remove this function entirely - we don't need external API calls for Ollama
|
||||||
|
|
||||||
|
|
||||||
|
# Remove this too - no need for fallback
|
||||||
|
|
||||||
|
|
||||||
|
def suggest_similar_models(invalid_model: str, available_models: list[str]) -> list[str]:
|
||||||
|
"""Use difflib to find similar model names"""
|
||||||
|
if not available_models:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Get close matches using fuzzy matching
|
||||||
|
suggestions = difflib.get_close_matches(invalid_model, available_models, n=3, cutoff=0.3)
|
||||||
|
return suggestions
|
||||||
|
|
||||||
|
|
||||||
|
def check_hf_model_exists(model_name: str) -> bool:
|
||||||
|
"""Quick check if HuggingFace model exists without downloading"""
|
||||||
|
try:
|
||||||
|
from huggingface_hub import model_info
|
||||||
|
|
||||||
|
model_info(model_name)
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def get_popular_hf_models() -> list[str]:
|
||||||
|
"""Return a list of popular HuggingFace models for suggestions"""
|
||||||
|
try:
|
||||||
|
from huggingface_hub import list_models
|
||||||
|
|
||||||
|
# Get popular text-generation models, sorted by downloads
|
||||||
|
models = list_models(
|
||||||
|
filter="text-generation",
|
||||||
|
sort="downloads",
|
||||||
|
direction=-1,
|
||||||
|
limit=20, # Get top 20 most downloaded
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract model names and filter for chat/conversation models
|
||||||
|
model_names = []
|
||||||
|
chat_keywords = ["chat", "instruct", "dialog", "conversation", "assistant"]
|
||||||
|
|
||||||
|
for model in models:
|
||||||
|
model_name = model.id if hasattr(model, "id") else str(model)
|
||||||
|
# Prioritize models with chat-related keywords
|
||||||
|
if any(keyword in model_name.lower() for keyword in chat_keywords):
|
||||||
|
model_names.append(model_name)
|
||||||
|
elif len(model_names) < 10: # Fill up with other popular models
|
||||||
|
model_names.append(model_name)
|
||||||
|
|
||||||
|
return model_names[:10] if model_names else _get_fallback_hf_models()
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
# Fallback to static list if API call fails
|
||||||
|
return _get_fallback_hf_models()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_fallback_hf_models() -> list[str]:
|
||||||
|
"""Fallback list of popular HuggingFace models"""
|
||||||
|
return [
|
||||||
|
"microsoft/DialoGPT-medium",
|
||||||
|
"microsoft/DialoGPT-large",
|
||||||
|
"facebook/blenderbot-400M-distill",
|
||||||
|
"microsoft/phi-2",
|
||||||
|
"deepseek-ai/deepseek-llm-7b-chat",
|
||||||
|
"microsoft/DialoGPT-small",
|
||||||
|
"facebook/blenderbot_small-90M",
|
||||||
|
"microsoft/phi-1_5",
|
||||||
|
"facebook/opt-350m",
|
||||||
|
"EleutherAI/gpt-neo-1.3B",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def search_hf_models_fuzzy(query: str, limit: int = 10) -> list[str]:
|
||||||
|
"""Use HuggingFace Hub's native fuzzy search for model suggestions"""
|
||||||
|
try:
|
||||||
|
from huggingface_hub import list_models
|
||||||
|
|
||||||
|
# HF Hub's search is already fuzzy! It handles typos and partial matches
|
||||||
|
models = list_models(
|
||||||
|
search=query,
|
||||||
|
filter="text-generation",
|
||||||
|
sort="downloads",
|
||||||
|
direction=-1,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
|
||||||
|
model_names = [model.id if hasattr(model, "id") else str(model) for model in models]
|
||||||
|
|
||||||
|
# If direct search doesn't return enough results, try some variations
|
||||||
|
if len(model_names) < 3:
|
||||||
|
# Try searching for partial matches or common variations
|
||||||
|
variations = []
|
||||||
|
|
||||||
|
# Extract base name (e.g., "gpt3" from "gpt-3.5")
|
||||||
|
base_query = query.lower().replace("-", "").replace(".", "").replace("_", "")
|
||||||
|
if base_query != query.lower():
|
||||||
|
variations.append(base_query)
|
||||||
|
|
||||||
|
# Try common model name patterns
|
||||||
|
if "gpt" in query.lower():
|
||||||
|
variations.extend(["gpt2", "gpt-neo", "gpt-j", "dialoGPT"])
|
||||||
|
elif "llama" in query.lower():
|
||||||
|
variations.extend(["llama2", "alpaca", "vicuna"])
|
||||||
|
elif "bert" in query.lower():
|
||||||
|
variations.extend(["roberta", "distilbert", "albert"])
|
||||||
|
|
||||||
|
# Search with variations
|
||||||
|
for var in variations[:2]: # Limit to 2 variations to avoid too many API calls
|
||||||
|
try:
|
||||||
|
var_models = list_models(
|
||||||
|
search=var,
|
||||||
|
filter="text-generation",
|
||||||
|
sort="downloads",
|
||||||
|
direction=-1,
|
||||||
|
limit=3,
|
||||||
|
)
|
||||||
|
var_names = [
|
||||||
|
model.id if hasattr(model, "id") else str(model) for model in var_models
|
||||||
|
]
|
||||||
|
model_names.extend(var_names)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Remove duplicates while preserving order
|
||||||
|
seen = set()
|
||||||
|
unique_models = []
|
||||||
|
for model in model_names:
|
||||||
|
if model not in seen:
|
||||||
|
seen.add(model)
|
||||||
|
unique_models.append(model)
|
||||||
|
|
||||||
|
return unique_models[:limit]
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
# If search fails, return empty list
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def search_hf_models(query: str, limit: int = 10) -> list[str]:
|
||||||
|
"""Simple search for HuggingFace models based on query (kept for backward compatibility)"""
|
||||||
|
return search_hf_models_fuzzy(query, limit)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_model_and_suggest(model_name: str, llm_type: str) -> str | None:
|
||||||
|
"""Validate model name and provide suggestions if invalid"""
|
||||||
|
if llm_type == "ollama":
|
||||||
|
available_models = check_ollama_models()
|
||||||
|
if available_models and model_name not in available_models:
|
||||||
|
error_msg = f"Model '{model_name}' not found in your local Ollama installation."
|
||||||
|
|
||||||
|
# Check if the model exists remotely and get available tags
|
||||||
|
model_exists_remotely, available_tags = check_ollama_model_exists_remotely(model_name)
|
||||||
|
|
||||||
|
if model_exists_remotely and model_name in available_tags:
|
||||||
|
# Exact model exists remotely - suggest pulling it
|
||||||
|
error_msg += "\n\nTo install the requested model:\n"
|
||||||
|
error_msg += f" ollama pull {model_name}\n"
|
||||||
|
|
||||||
|
# Show local alternatives
|
||||||
|
suggestions = search_ollama_models_fuzzy(model_name, available_models)
|
||||||
|
if suggestions:
|
||||||
|
error_msg += "\nOr use one of these similar installed models:\n"
|
||||||
|
for i, suggestion in enumerate(suggestions, 1):
|
||||||
|
error_msg += f" {i}. {suggestion}\n"
|
||||||
|
|
||||||
|
elif model_exists_remotely and available_tags:
|
||||||
|
# Base model exists but requested tag doesn't - suggest correct tags
|
||||||
|
base_model = model_name.split(":")[0]
|
||||||
|
requested_tag = model_name.split(":", 1)[1] if ":" in model_name else None
|
||||||
|
|
||||||
|
error_msg += (
|
||||||
|
f"\n\nModel '{base_model}' exists, but tag '{requested_tag}' is not available."
|
||||||
|
)
|
||||||
|
error_msg += f"\n\nAvailable {base_model} models you can install:\n"
|
||||||
|
for i, tag in enumerate(available_tags[:8], 1):
|
||||||
|
error_msg += f" {i}. ollama pull {tag}\n"
|
||||||
|
if len(available_tags) > 8:
|
||||||
|
error_msg += f" ... and {len(available_tags) - 8} more variants\n"
|
||||||
|
|
||||||
|
# Also show local alternatives
|
||||||
|
suggestions = search_ollama_models_fuzzy(model_name, available_models)
|
||||||
|
if suggestions:
|
||||||
|
error_msg += "\nOr use one of these similar installed models:\n"
|
||||||
|
for i, suggestion in enumerate(suggestions, 1):
|
||||||
|
error_msg += f" {i}. {suggestion}\n"
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Model doesn't exist remotely - show fuzzy suggestions
|
||||||
|
suggestions = search_ollama_models_fuzzy(model_name, available_models)
|
||||||
|
error_msg += f"\n\nModel '{model_name}' was not found in Ollama's library."
|
||||||
|
|
||||||
|
if suggestions:
|
||||||
|
error_msg += "\n\nDid you mean one of these installed models?\n"
|
||||||
|
for i, suggestion in enumerate(suggestions, 1):
|
||||||
|
error_msg += f" {i}. {suggestion}\n"
|
||||||
|
else:
|
||||||
|
error_msg += "\n\nYour installed models:\n"
|
||||||
|
for i, model in enumerate(available_models[:8], 1):
|
||||||
|
error_msg += f" {i}. {model}\n"
|
||||||
|
if len(available_models) > 8:
|
||||||
|
error_msg += f" ... and {len(available_models) - 8} more\n"
|
||||||
|
|
||||||
|
error_msg += "\n\nCommands:"
|
||||||
|
error_msg += "\n ollama list # List installed models"
|
||||||
|
if model_exists_remotely and available_tags:
|
||||||
|
if model_name in available_tags:
|
||||||
|
error_msg += f"\n ollama pull {model_name} # Install requested model"
|
||||||
|
else:
|
||||||
|
error_msg += (
|
||||||
|
f"\n ollama pull {available_tags[0]} # Install recommended variant"
|
||||||
|
)
|
||||||
|
error_msg += "\n https://ollama.com/library # Browse available models"
|
||||||
|
return error_msg
|
||||||
|
|
||||||
|
elif llm_type == "hf":
|
||||||
|
# For HF models, we can do a quick existence check
|
||||||
|
if not check_hf_model_exists(model_name):
|
||||||
|
# Use HF Hub's native fuzzy search directly
|
||||||
|
search_suggestions = search_hf_models_fuzzy(model_name, limit=8)
|
||||||
|
|
||||||
|
error_msg = f"Model '{model_name}' not found on HuggingFace Hub."
|
||||||
|
if search_suggestions:
|
||||||
|
error_msg += "\n\nDid you mean one of these?\n"
|
||||||
|
for i, suggestion in enumerate(search_suggestions, 1):
|
||||||
|
error_msg += f" {i}. {suggestion}\n"
|
||||||
|
else:
|
||||||
|
# Fallback to popular models if search returns nothing
|
||||||
|
popular_models = get_popular_hf_models()
|
||||||
|
error_msg += "\n\nPopular chat models:\n"
|
||||||
|
for i, model in enumerate(popular_models[:5], 1):
|
||||||
|
error_msg += f" {i}. {model}\n"
|
||||||
|
|
||||||
|
error_msg += f"\nSearch more: https://huggingface.co/models?search={model_name}&pipeline_tag=text-generation"
|
||||||
|
return error_msg
|
||||||
|
|
||||||
|
return None # Model is valid or we can't check
|
||||||
|
|
||||||
|
|
||||||
|
class LLMInterface(ABC):
|
||||||
|
"""Abstract base class for a generic Language Model (LLM) interface."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def ask(self, prompt: str, **kwargs) -> str:
|
||||||
|
"""
|
||||||
|
Additional keyword arguments (kwargs) for advanced search customization. Example usage:
|
||||||
|
chat.ask(
|
||||||
|
"What is ANN?",
|
||||||
|
top_k=10,
|
||||||
|
complexity=64,
|
||||||
|
beam_width=8,
|
||||||
|
USE_DEFERRED_FETCH=True,
|
||||||
|
skip_search_reorder=True,
|
||||||
|
recompute_beighbor_embeddings=True,
|
||||||
|
dedup_node_dis=True,
|
||||||
|
prune_ratio=0.1,
|
||||||
|
batch_recompute=True,
|
||||||
|
global_pruning=True
|
||||||
|
)
|
||||||
|
|
||||||
|
Supported kwargs:
|
||||||
|
- complexity (int): Search complexity parameter (default: 32)
|
||||||
|
- beam_width (int): Beam width for search (default: 4)
|
||||||
|
- USE_DEFERRED_FETCH (bool): Enable deferred fetch mode (default: False)
|
||||||
|
- skip_search_reorder (bool): Skip search reorder step (default: False)
|
||||||
|
- recompute_beighbor_embeddings (bool): Enable ZMQ embedding server for neighbor recomputation (default: False)
|
||||||
|
- dedup_node_dis (bool): Deduplicate nodes by distance (default: False)
|
||||||
|
- prune_ratio (float): Pruning ratio for search (default: 0.0)
|
||||||
|
- batch_recompute (bool): Enable batch recomputation (default: False)
|
||||||
|
- global_pruning (bool): Enable global pruning (default: False)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# """
|
||||||
|
# Sends a prompt to the LLM and returns the generated text.
|
||||||
|
|
||||||
|
# Args:
|
||||||
|
# prompt: The input prompt for the LLM.
|
||||||
|
# **kwargs: Additional keyword arguments for the LLM backend.
|
||||||
|
|
||||||
|
# Returns:
|
||||||
|
# The response string from the LLM.
|
||||||
|
# """
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaChat(LLMInterface):
|
||||||
|
"""LLM interface for Ollama models."""
|
||||||
|
|
||||||
|
def __init__(self, model: str = "llama3:8b", host: str = "http://localhost:11434"):
|
||||||
|
self.model = model
|
||||||
|
self.host = host
|
||||||
|
logger.info(f"Initializing OllamaChat with model='{model}' and host='{host}'")
|
||||||
|
try:
|
||||||
|
import requests
|
||||||
|
|
||||||
|
# Check if the Ollama server is responsive
|
||||||
|
if host:
|
||||||
|
requests.get(host)
|
||||||
|
|
||||||
|
# Pre-check model availability with helpful suggestions
|
||||||
|
model_error = validate_model_and_suggest(model, "ollama")
|
||||||
|
if model_error:
|
||||||
|
raise ValueError(model_error)
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"The 'requests' library is required for Ollama. Please install it with 'pip install requests'."
|
||||||
|
)
|
||||||
|
except requests.exceptions.ConnectionError:
|
||||||
|
logger.error(f"Could not connect to Ollama at {host}. Please ensure Ollama is running.")
|
||||||
|
raise ConnectionError(
|
||||||
|
f"Could not connect to Ollama at {host}. Please ensure Ollama is running."
|
||||||
|
)
|
||||||
|
|
||||||
|
def ask(self, prompt: str, **kwargs) -> str:
|
||||||
|
import json
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
full_url = f"{self.host}/api/generate"
|
||||||
|
payload = {
|
||||||
|
"model": self.model,
|
||||||
|
"prompt": prompt,
|
||||||
|
"stream": False, # Keep it simple for now
|
||||||
|
"options": kwargs,
|
||||||
|
}
|
||||||
|
logger.debug(f"Sending request to Ollama: {payload}")
|
||||||
|
try:
|
||||||
|
logger.info("Sending request to Ollama and waiting for response...")
|
||||||
|
response = requests.post(full_url, data=json.dumps(payload))
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
# The response from Ollama can be a stream of JSON objects, handle this
|
||||||
|
response_parts = response.text.strip().split("\n")
|
||||||
|
full_response = ""
|
||||||
|
for part in response_parts:
|
||||||
|
if part:
|
||||||
|
json_part = json.loads(part)
|
||||||
|
full_response += json_part.get("response", "")
|
||||||
|
if json_part.get("done"):
|
||||||
|
break
|
||||||
|
return full_response
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
logger.error(f"Error communicating with Ollama: {e}")
|
||||||
|
return f"Error: Could not get a response from Ollama. Details: {e}"
|
||||||
|
|
||||||
|
|
||||||
|
class HFChat(LLMInterface):
|
||||||
|
"""LLM interface for local Hugging Face Transformers models with proper chat templates."""
|
||||||
|
|
||||||
|
def __init__(self, model_name: str = "deepseek-ai/deepseek-llm-7b-chat"):
|
||||||
|
logger.info(f"Initializing HFChat with model='{model_name}'")
|
||||||
|
|
||||||
|
# Pre-check model availability with helpful suggestions
|
||||||
|
model_error = validate_model_and_suggest(model_name, "hf")
|
||||||
|
if model_error:
|
||||||
|
raise ValueError(model_error)
|
||||||
|
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"The 'transformers' and 'torch' libraries are required for Hugging Face models. Please install them with 'pip install transformers torch'."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Auto-detect device
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
self.device = "cuda"
|
||||||
|
logger.info("CUDA is available. Using GPU.")
|
||||||
|
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||||
|
self.device = "mps"
|
||||||
|
logger.info("MPS is available. Using Apple Silicon GPU.")
|
||||||
|
else:
|
||||||
|
self.device = "cpu"
|
||||||
|
logger.info("No GPU detected. Using CPU.")
|
||||||
|
|
||||||
|
# Load tokenizer and model
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
|
self.model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
model_name,
|
||||||
|
torch_dtype=torch.float16 if self.device != "cpu" else torch.float32,
|
||||||
|
device_map="auto" if self.device != "cpu" else None,
|
||||||
|
trust_remote_code=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Move model to device if not using device_map
|
||||||
|
if self.device != "cpu" and "device_map" not in str(self.model):
|
||||||
|
self.model = self.model.to(self.device)
|
||||||
|
|
||||||
|
# Set pad token if not present
|
||||||
|
if self.tokenizer.pad_token is None:
|
||||||
|
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||||
|
|
||||||
|
def ask(self, prompt: str, **kwargs) -> str:
|
||||||
|
print("kwargs in HF: ", kwargs)
|
||||||
|
# Check if this is a Qwen model and add /no_think by default
|
||||||
|
is_qwen_model = "qwen" in self.model.config._name_or_path.lower()
|
||||||
|
|
||||||
|
# For Qwen models, automatically add /no_think to the prompt
|
||||||
|
if is_qwen_model and "/no_think" not in prompt and "/think" not in prompt:
|
||||||
|
prompt = prompt + " /no_think"
|
||||||
|
|
||||||
|
# Prepare chat template
|
||||||
|
messages = [{"role": "user", "content": prompt}]
|
||||||
|
|
||||||
|
# Apply chat template if available
|
||||||
|
if hasattr(self.tokenizer, "apply_chat_template"):
|
||||||
|
try:
|
||||||
|
formatted_prompt = self.tokenizer.apply_chat_template(
|
||||||
|
messages, tokenize=False, add_generation_prompt=True
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Chat template failed, using raw prompt: {e}")
|
||||||
|
formatted_prompt = prompt
|
||||||
|
else:
|
||||||
|
# Fallback for models without chat template
|
||||||
|
formatted_prompt = prompt
|
||||||
|
|
||||||
|
# Tokenize input
|
||||||
|
inputs = self.tokenizer(
|
||||||
|
formatted_prompt,
|
||||||
|
return_tensors="pt",
|
||||||
|
padding=True,
|
||||||
|
truncation=True,
|
||||||
|
max_length=2048,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Move inputs to device
|
||||||
|
if self.device != "cpu":
|
||||||
|
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
||||||
|
|
||||||
|
# Set generation parameters
|
||||||
|
generation_config = {
|
||||||
|
"max_new_tokens": kwargs.get("max_tokens", kwargs.get("max_new_tokens", 512)),
|
||||||
|
"temperature": kwargs.get("temperature", 0.7),
|
||||||
|
"top_p": kwargs.get("top_p", 0.9),
|
||||||
|
"do_sample": kwargs.get("temperature", 0.7) > 0,
|
||||||
|
"pad_token_id": self.tokenizer.eos_token_id,
|
||||||
|
"eos_token_id": self.tokenizer.eos_token_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Handle temperature=0 for greedy decoding
|
||||||
|
if generation_config["temperature"] == 0.0:
|
||||||
|
generation_config["do_sample"] = False
|
||||||
|
generation_config.pop("temperature")
|
||||||
|
|
||||||
|
logger.info(f"Generating with HuggingFace model, config: {generation_config}")
|
||||||
|
|
||||||
|
# Generate
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = self.model.generate(**inputs, **generation_config)
|
||||||
|
|
||||||
|
# Decode response
|
||||||
|
generated_tokens = outputs[0][inputs["input_ids"].shape[1] :]
|
||||||
|
response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
|
||||||
|
|
||||||
|
return response.strip()
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIChat(LLMInterface):
|
||||||
|
"""LLM interface for OpenAI models."""
|
||||||
|
|
||||||
|
def __init__(self, model: str = "gpt-4o", api_key: str | None = None):
|
||||||
|
self.model = model
|
||||||
|
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
|
||||||
|
|
||||||
|
if not self.api_key:
|
||||||
|
raise ValueError(
|
||||||
|
"OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter."
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Initializing OpenAI Chat with model='{model}'")
|
||||||
|
|
||||||
|
try:
|
||||||
|
import openai
|
||||||
|
|
||||||
|
self.client = openai.OpenAI(api_key=self.api_key)
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"The 'openai' library is required for OpenAI models. Please install it with 'pip install openai'."
|
||||||
|
)
|
||||||
|
|
||||||
|
def ask(self, prompt: str, **kwargs) -> str:
|
||||||
|
# Default parameters for OpenAI
|
||||||
|
params = {
|
||||||
|
"model": self.model,
|
||||||
|
"messages": [{"role": "user", "content": prompt}],
|
||||||
|
"max_tokens": kwargs.get("max_tokens", 1000),
|
||||||
|
"temperature": kwargs.get("temperature", 0.7),
|
||||||
|
**{k: v for k, v in kwargs.items() if k not in ["max_tokens", "temperature"]},
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(f"Sending request to OpenAI with model {self.model}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = self.client.chat.completions.create(**params)
|
||||||
|
return response.choices[0].message.content.strip()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error communicating with OpenAI: {e}")
|
||||||
|
return f"Error: Could not get a response from OpenAI. Details: {e}"
|
||||||
|
|
||||||
|
|
||||||
|
class SimulatedChat(LLMInterface):
|
||||||
|
"""A simple simulated chat for testing and development."""
|
||||||
|
|
||||||
|
def ask(self, prompt: str, **kwargs) -> str:
|
||||||
|
logger.info("Simulating LLM call...")
|
||||||
|
print("Prompt sent to LLM (simulation):", prompt[:500] + "...")
|
||||||
|
return "This is a simulated answer from the LLM based on the retrieved context."
|
||||||
|
|
||||||
|
|
||||||
|
def get_llm(llm_config: dict[str, Any] | None = None) -> LLMInterface:
|
||||||
|
"""
|
||||||
|
Factory function to get an LLM interface based on configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
llm_config: A dictionary specifying the LLM type and its parameters.
|
||||||
|
Example: {"type": "ollama", "model": "llama3"}
|
||||||
|
{"type": "hf", "model": "distilgpt2"}
|
||||||
|
None (for simulation mode)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An instance of an LLMInterface subclass.
|
||||||
|
"""
|
||||||
|
if llm_config is None:
|
||||||
|
llm_config = {
|
||||||
|
"type": "openai",
|
||||||
|
"model": "gpt-4o",
|
||||||
|
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||||
|
}
|
||||||
|
|
||||||
|
llm_type = llm_config.get("type", "openai")
|
||||||
|
model = llm_config.get("model")
|
||||||
|
|
||||||
|
logger.info(f"Attempting to create LLM of type='{llm_type}' with model='{model}'")
|
||||||
|
|
||||||
|
if llm_type == "ollama":
|
||||||
|
return OllamaChat(
|
||||||
|
model=model or "llama3:8b",
|
||||||
|
host=llm_config.get("host", "http://localhost:11434"),
|
||||||
|
)
|
||||||
|
elif llm_type == "hf":
|
||||||
|
return HFChat(model_name=model or "deepseek-ai/deepseek-llm-7b-chat")
|
||||||
|
elif llm_type == "openai":
|
||||||
|
return OpenAIChat(model=model or "gpt-4o", api_key=llm_config.get("api_key"))
|
||||||
|
elif llm_type == "simulated":
|
||||||
|
return SimulatedChat()
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown LLM type: '{llm_type}'")
|
||||||
367
packages/leann-core/src/leann/cli.py
Normal file
367
packages/leann-core/src/leann/cli.py
Normal file
@@ -0,0 +1,367 @@
|
|||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from llama_index.core import SimpleDirectoryReader
|
||||||
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
|
||||||
|
from .api import LeannBuilder, LeannChat, LeannSearcher
|
||||||
|
|
||||||
|
|
||||||
|
def extract_pdf_text_with_pymupdf(file_path: str) -> str:
|
||||||
|
"""Extract text from PDF using PyMuPDF for better quality."""
|
||||||
|
try:
|
||||||
|
import fitz # PyMuPDF
|
||||||
|
|
||||||
|
doc = fitz.open(file_path)
|
||||||
|
text = ""
|
||||||
|
for page in doc:
|
||||||
|
text += page.get_text()
|
||||||
|
doc.close()
|
||||||
|
return text
|
||||||
|
except ImportError:
|
||||||
|
# Fallback to default reader
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def extract_pdf_text_with_pdfplumber(file_path: str) -> str:
|
||||||
|
"""Extract text from PDF using pdfplumber for better quality."""
|
||||||
|
try:
|
||||||
|
import pdfplumber
|
||||||
|
|
||||||
|
text = ""
|
||||||
|
with pdfplumber.open(file_path) as pdf:
|
||||||
|
for page in pdf.pages:
|
||||||
|
text += page.extract_text() or ""
|
||||||
|
return text
|
||||||
|
except ImportError:
|
||||||
|
# Fallback to default reader
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class LeannCLI:
|
||||||
|
def __init__(self):
|
||||||
|
self.indexes_dir = Path.home() / ".leann" / "indexes"
|
||||||
|
self.indexes_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
self.node_parser = SentenceSplitter(
|
||||||
|
chunk_size=256, chunk_overlap=128, separator=" ", paragraph_separator="\n\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_index_path(self, index_name: str) -> str:
|
||||||
|
index_dir = self.indexes_dir / index_name
|
||||||
|
return str(index_dir / "documents.leann")
|
||||||
|
|
||||||
|
def index_exists(self, index_name: str) -> bool:
|
||||||
|
index_dir = self.indexes_dir / index_name
|
||||||
|
meta_file = index_dir / "documents.leann.meta.json"
|
||||||
|
return meta_file.exists()
|
||||||
|
|
||||||
|
def create_parser(self) -> argparse.ArgumentParser:
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
prog="leann",
|
||||||
|
description="LEANN - Local Enhanced AI Navigation",
|
||||||
|
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||||
|
epilog="""
|
||||||
|
Examples:
|
||||||
|
leann build my-docs --docs ./documents # Build index named my-docs
|
||||||
|
leann search my-docs "query" # Search in my-docs index
|
||||||
|
leann ask my-docs "question" # Ask my-docs index
|
||||||
|
leann list # List all stored indexes
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
subparsers = parser.add_subparsers(dest="command", help="Available commands")
|
||||||
|
|
||||||
|
# Build command
|
||||||
|
build_parser = subparsers.add_parser("build", help="Build document index")
|
||||||
|
build_parser.add_argument("index_name", help="Index name")
|
||||||
|
build_parser.add_argument("--docs", type=str, required=True, help="Documents directory")
|
||||||
|
build_parser.add_argument(
|
||||||
|
"--backend", type=str, default="hnsw", choices=["hnsw", "diskann"]
|
||||||
|
)
|
||||||
|
build_parser.add_argument("--embedding-model", type=str, default="facebook/contriever")
|
||||||
|
build_parser.add_argument("--force", "-f", action="store_true", help="Force rebuild")
|
||||||
|
build_parser.add_argument("--graph-degree", type=int, default=32)
|
||||||
|
build_parser.add_argument("--complexity", type=int, default=64)
|
||||||
|
build_parser.add_argument("--num-threads", type=int, default=1)
|
||||||
|
build_parser.add_argument("--compact", action="store_true", default=True)
|
||||||
|
build_parser.add_argument("--recompute", action="store_true", default=True)
|
||||||
|
|
||||||
|
# Search command
|
||||||
|
search_parser = subparsers.add_parser("search", help="Search documents")
|
||||||
|
search_parser.add_argument("index_name", help="Index name")
|
||||||
|
search_parser.add_argument("query", help="Search query")
|
||||||
|
search_parser.add_argument("--top-k", type=int, default=5)
|
||||||
|
search_parser.add_argument("--complexity", type=int, default=64)
|
||||||
|
search_parser.add_argument("--beam-width", type=int, default=1)
|
||||||
|
search_parser.add_argument("--prune-ratio", type=float, default=0.0)
|
||||||
|
search_parser.add_argument("--recompute-embeddings", action="store_true")
|
||||||
|
search_parser.add_argument(
|
||||||
|
"--pruning-strategy",
|
||||||
|
choices=["global", "local", "proportional"],
|
||||||
|
default="global",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ask command
|
||||||
|
ask_parser = subparsers.add_parser("ask", help="Ask questions")
|
||||||
|
ask_parser.add_argument("index_name", help="Index name")
|
||||||
|
ask_parser.add_argument(
|
||||||
|
"--llm",
|
||||||
|
type=str,
|
||||||
|
default="ollama",
|
||||||
|
choices=["simulated", "ollama", "hf", "openai"],
|
||||||
|
)
|
||||||
|
ask_parser.add_argument("--model", type=str, default="qwen3:8b")
|
||||||
|
ask_parser.add_argument("--host", type=str, default="http://localhost:11434")
|
||||||
|
ask_parser.add_argument("--interactive", "-i", action="store_true")
|
||||||
|
ask_parser.add_argument("--top-k", type=int, default=20)
|
||||||
|
ask_parser.add_argument("--complexity", type=int, default=32)
|
||||||
|
ask_parser.add_argument("--beam-width", type=int, default=1)
|
||||||
|
ask_parser.add_argument("--prune-ratio", type=float, default=0.0)
|
||||||
|
ask_parser.add_argument("--recompute-embeddings", action="store_true")
|
||||||
|
ask_parser.add_argument(
|
||||||
|
"--pruning-strategy",
|
||||||
|
choices=["global", "local", "proportional"],
|
||||||
|
default="global",
|
||||||
|
)
|
||||||
|
|
||||||
|
# List command
|
||||||
|
subparsers.add_parser("list", help="List all indexes")
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
def list_indexes(self):
|
||||||
|
print("Stored LEANN indexes:")
|
||||||
|
|
||||||
|
if not self.indexes_dir.exists():
|
||||||
|
print("No indexes found. Use 'leann build <name> --docs <dir>' to create one.")
|
||||||
|
return
|
||||||
|
|
||||||
|
index_dirs = [d for d in self.indexes_dir.iterdir() if d.is_dir()]
|
||||||
|
|
||||||
|
if not index_dirs:
|
||||||
|
print("No indexes found. Use 'leann build <name> --docs <dir>' to create one.")
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f"Found {len(index_dirs)} indexes:")
|
||||||
|
for i, index_dir in enumerate(index_dirs, 1):
|
||||||
|
index_name = index_dir.name
|
||||||
|
status = "✓" if self.index_exists(index_name) else "✗"
|
||||||
|
|
||||||
|
print(f" {i}. {index_name} [{status}]")
|
||||||
|
if self.index_exists(index_name):
|
||||||
|
index_dir / "documents.leann.meta.json"
|
||||||
|
size_mb = sum(f.stat().st_size for f in index_dir.iterdir() if f.is_file()) / (
|
||||||
|
1024 * 1024
|
||||||
|
)
|
||||||
|
print(f" Size: {size_mb:.1f} MB")
|
||||||
|
|
||||||
|
if index_dirs:
|
||||||
|
example_name = index_dirs[0].name
|
||||||
|
print("\nUsage:")
|
||||||
|
print(f' leann search {example_name} "your query"')
|
||||||
|
print(f" leann ask {example_name} --interactive")
|
||||||
|
|
||||||
|
def load_documents(self, docs_dir: str):
|
||||||
|
print(f"Loading documents from {docs_dir}...")
|
||||||
|
|
||||||
|
# Try to use better PDF parsers first
|
||||||
|
documents = []
|
||||||
|
docs_path = Path(docs_dir)
|
||||||
|
|
||||||
|
for file_path in docs_path.rglob("*.pdf"):
|
||||||
|
print(f"Processing PDF: {file_path}")
|
||||||
|
|
||||||
|
# Try PyMuPDF first (best quality)
|
||||||
|
text = extract_pdf_text_with_pymupdf(str(file_path))
|
||||||
|
if text is None:
|
||||||
|
# Try pdfplumber
|
||||||
|
text = extract_pdf_text_with_pdfplumber(str(file_path))
|
||||||
|
|
||||||
|
if text:
|
||||||
|
# Create a simple document structure
|
||||||
|
from llama_index.core import Document
|
||||||
|
|
||||||
|
doc = Document(text=text, metadata={"source": str(file_path)})
|
||||||
|
documents.append(doc)
|
||||||
|
else:
|
||||||
|
# Fallback to default reader
|
||||||
|
print(f"Using default reader for {file_path}")
|
||||||
|
default_docs = SimpleDirectoryReader(
|
||||||
|
str(file_path.parent),
|
||||||
|
filename_as_id=True,
|
||||||
|
required_exts=[file_path.suffix],
|
||||||
|
).load_data()
|
||||||
|
documents.extend(default_docs)
|
||||||
|
|
||||||
|
# Load other file types with default reader
|
||||||
|
other_docs = SimpleDirectoryReader(
|
||||||
|
docs_dir,
|
||||||
|
recursive=True,
|
||||||
|
encoding="utf-8",
|
||||||
|
required_exts=[".txt", ".md", ".docx"],
|
||||||
|
).load_data(show_progress=True)
|
||||||
|
documents.extend(other_docs)
|
||||||
|
|
||||||
|
all_texts = []
|
||||||
|
for doc in documents:
|
||||||
|
nodes = self.node_parser.get_nodes_from_documents([doc])
|
||||||
|
for node in nodes:
|
||||||
|
all_texts.append(node.get_content())
|
||||||
|
|
||||||
|
print(f"Loaded {len(documents)} documents, {len(all_texts)} chunks")
|
||||||
|
return all_texts
|
||||||
|
|
||||||
|
async def build_index(self, args):
|
||||||
|
docs_dir = args.docs
|
||||||
|
index_name = args.index_name
|
||||||
|
index_dir = self.indexes_dir / index_name
|
||||||
|
index_path = self.get_index_path(index_name)
|
||||||
|
|
||||||
|
if index_dir.exists() and not args.force:
|
||||||
|
print(f"Index '{index_name}' already exists. Use --force to rebuild.")
|
||||||
|
return
|
||||||
|
|
||||||
|
all_texts = self.load_documents(docs_dir)
|
||||||
|
if not all_texts:
|
||||||
|
print("No documents found")
|
||||||
|
return
|
||||||
|
|
||||||
|
index_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
print(f"Building index '{index_name}' with {args.backend} backend...")
|
||||||
|
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name=args.backend,
|
||||||
|
embedding_model=args.embedding_model,
|
||||||
|
graph_degree=args.graph_degree,
|
||||||
|
complexity=args.complexity,
|
||||||
|
is_compact=args.compact,
|
||||||
|
is_recompute=args.recompute,
|
||||||
|
num_threads=args.num_threads,
|
||||||
|
)
|
||||||
|
|
||||||
|
for chunk_text in all_texts:
|
||||||
|
builder.add_text(chunk_text)
|
||||||
|
|
||||||
|
builder.build_index(index_path)
|
||||||
|
print(f"Index built at {index_path}")
|
||||||
|
|
||||||
|
async def search_documents(self, args):
|
||||||
|
index_name = args.index_name
|
||||||
|
query = args.query
|
||||||
|
index_path = self.get_index_path(index_name)
|
||||||
|
|
||||||
|
if not self.index_exists(index_name):
|
||||||
|
print(
|
||||||
|
f"Index '{index_name}' not found. Use 'leann build {index_name} --docs <dir>' to create it."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
searcher = LeannSearcher(index_path=index_path)
|
||||||
|
results = searcher.search(
|
||||||
|
query,
|
||||||
|
top_k=args.top_k,
|
||||||
|
complexity=args.complexity,
|
||||||
|
beam_width=args.beam_width,
|
||||||
|
prune_ratio=args.prune_ratio,
|
||||||
|
recompute_embeddings=args.recompute_embeddings,
|
||||||
|
pruning_strategy=args.pruning_strategy,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Search results for '{query}' (top {len(results)}):")
|
||||||
|
for i, result in enumerate(results, 1):
|
||||||
|
print(f"{i}. Score: {result.score:.3f}")
|
||||||
|
print(f" {result.text[:200]}...")
|
||||||
|
print()
|
||||||
|
|
||||||
|
async def ask_questions(self, args):
|
||||||
|
index_name = args.index_name
|
||||||
|
index_path = self.get_index_path(index_name)
|
||||||
|
|
||||||
|
if not self.index_exists(index_name):
|
||||||
|
print(
|
||||||
|
f"Index '{index_name}' not found. Use 'leann build {index_name} --docs <dir>' to create it."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f"Starting chat with index '{index_name}'...")
|
||||||
|
print(f"Using {args.model} ({args.llm})")
|
||||||
|
|
||||||
|
llm_config = {"type": args.llm, "model": args.model}
|
||||||
|
if args.llm == "ollama":
|
||||||
|
llm_config["host"] = args.host
|
||||||
|
|
||||||
|
chat = LeannChat(index_path=index_path, llm_config=llm_config)
|
||||||
|
|
||||||
|
if args.interactive:
|
||||||
|
print("LEANN Assistant ready! Type 'quit' to exit")
|
||||||
|
print("=" * 40)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
user_input = input("\nYou: ").strip()
|
||||||
|
if user_input.lower() in ["quit", "exit", "q"]:
|
||||||
|
print("Goodbye!")
|
||||||
|
break
|
||||||
|
|
||||||
|
if not user_input:
|
||||||
|
continue
|
||||||
|
|
||||||
|
response = chat.ask(
|
||||||
|
user_input,
|
||||||
|
top_k=args.top_k,
|
||||||
|
complexity=args.complexity,
|
||||||
|
beam_width=args.beam_width,
|
||||||
|
prune_ratio=args.prune_ratio,
|
||||||
|
recompute_embeddings=args.recompute_embeddings,
|
||||||
|
pruning_strategy=args.pruning_strategy,
|
||||||
|
)
|
||||||
|
print(f"LEANN: {response}")
|
||||||
|
else:
|
||||||
|
query = input("Enter your question: ").strip()
|
||||||
|
if query:
|
||||||
|
response = chat.ask(
|
||||||
|
query,
|
||||||
|
top_k=args.top_k,
|
||||||
|
complexity=args.complexity,
|
||||||
|
beam_width=args.beam_width,
|
||||||
|
prune_ratio=args.prune_ratio,
|
||||||
|
recompute_embeddings=args.recompute_embeddings,
|
||||||
|
pruning_strategy=args.pruning_strategy,
|
||||||
|
)
|
||||||
|
print(f"LEANN: {response}")
|
||||||
|
|
||||||
|
async def run(self, args=None):
|
||||||
|
parser = self.create_parser()
|
||||||
|
|
||||||
|
if args is None:
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if not args.command:
|
||||||
|
parser.print_help()
|
||||||
|
return
|
||||||
|
|
||||||
|
if args.command == "list":
|
||||||
|
self.list_indexes()
|
||||||
|
elif args.command == "build":
|
||||||
|
await self.build_index(args)
|
||||||
|
elif args.command == "search":
|
||||||
|
await self.search_documents(args)
|
||||||
|
elif args.command == "ask":
|
||||||
|
await self.ask_questions(args)
|
||||||
|
else:
|
||||||
|
parser.print_help()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
import dotenv
|
||||||
|
|
||||||
|
dotenv.load_dotenv()
|
||||||
|
|
||||||
|
cli = LeannCLI()
|
||||||
|
asyncio.run(cli.run())
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
367
packages/leann-core/src/leann/embedding_compute.py
Normal file
367
packages/leann-core/src/leann/embedding_compute.py
Normal file
@@ -0,0 +1,367 @@
|
|||||||
|
"""
|
||||||
|
Unified embedding computation module
|
||||||
|
Consolidates all embedding computation logic using SentenceTransformer
|
||||||
|
Preserves all optimization parameters to ensure performance
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# Set up logger with proper level
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||||
|
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
|
||||||
|
logger.setLevel(log_level)
|
||||||
|
|
||||||
|
# Global model cache to avoid repeated loading
|
||||||
|
_model_cache: dict[str, Any] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def compute_embeddings(
|
||||||
|
texts: list[str],
|
||||||
|
model_name: str,
|
||||||
|
mode: str = "sentence-transformers",
|
||||||
|
is_build: bool = False,
|
||||||
|
batch_size: int = 32,
|
||||||
|
adaptive_optimization: bool = True,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Unified embedding computation entry point
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: List of texts to compute embeddings for
|
||||||
|
model_name: Model name
|
||||||
|
mode: Computation mode ('sentence-transformers', 'openai', 'mlx')
|
||||||
|
is_build: Whether this is a build operation (shows progress bar)
|
||||||
|
batch_size: Batch size for processing
|
||||||
|
adaptive_optimization: Whether to use adaptive optimization based on batch size
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Normalized embeddings array, shape: (len(texts), embedding_dim)
|
||||||
|
"""
|
||||||
|
if mode == "sentence-transformers":
|
||||||
|
return compute_embeddings_sentence_transformers(
|
||||||
|
texts,
|
||||||
|
model_name,
|
||||||
|
is_build=is_build,
|
||||||
|
batch_size=batch_size,
|
||||||
|
adaptive_optimization=adaptive_optimization,
|
||||||
|
)
|
||||||
|
elif mode == "openai":
|
||||||
|
return compute_embeddings_openai(texts, model_name)
|
||||||
|
elif mode == "mlx":
|
||||||
|
return compute_embeddings_mlx(texts, model_name)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported embedding mode: {mode}")
|
||||||
|
|
||||||
|
|
||||||
|
def compute_embeddings_sentence_transformers(
|
||||||
|
texts: list[str],
|
||||||
|
model_name: str,
|
||||||
|
use_fp16: bool = True,
|
||||||
|
device: str = "auto",
|
||||||
|
batch_size: int = 32,
|
||||||
|
is_build: bool = False,
|
||||||
|
adaptive_optimization: bool = True,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Compute embeddings using SentenceTransformer with model caching and adaptive optimization
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: List of texts to compute embeddings for
|
||||||
|
model_name: Model name
|
||||||
|
use_fp16: Whether to use FP16 precision
|
||||||
|
device: Device to use ('auto', 'cuda', 'mps', 'cpu')
|
||||||
|
batch_size: Batch size for processing
|
||||||
|
is_build: Whether this is a build operation (shows progress bar)
|
||||||
|
adaptive_optimization: Whether to use adaptive optimization based on batch size
|
||||||
|
"""
|
||||||
|
# Handle empty input
|
||||||
|
if not texts:
|
||||||
|
raise ValueError("Cannot compute embeddings for empty text list")
|
||||||
|
logger.info(
|
||||||
|
f"Computing embeddings for {len(texts)} texts using SentenceTransformer, model: '{model_name}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Auto-detect device
|
||||||
|
if device == "auto":
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = "cuda"
|
||||||
|
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||||
|
device = "mps"
|
||||||
|
else:
|
||||||
|
device = "cpu"
|
||||||
|
|
||||||
|
# Apply optimizations based on benchmark results
|
||||||
|
if adaptive_optimization:
|
||||||
|
# Use optimal batch_size constants for different devices based on benchmark results
|
||||||
|
if device == "mps":
|
||||||
|
batch_size = 128 # MPS optimal batch size from benchmark
|
||||||
|
if model_name == "Qwen/Qwen3-Embedding-0.6B":
|
||||||
|
batch_size = 32
|
||||||
|
elif device == "cuda":
|
||||||
|
batch_size = 256 # CUDA optimal batch size
|
||||||
|
# Keep original batch_size for CPU
|
||||||
|
|
||||||
|
# Create cache key
|
||||||
|
cache_key = f"sentence_transformers_{model_name}_{device}_{use_fp16}_optimized"
|
||||||
|
|
||||||
|
# Check if model is already cached
|
||||||
|
if cache_key in _model_cache:
|
||||||
|
logger.info(f"Using cached optimized model: {model_name}")
|
||||||
|
model = _model_cache[cache_key]
|
||||||
|
else:
|
||||||
|
logger.info(f"Loading and caching optimized SentenceTransformer model: {model_name}")
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
|
logger.info(f"Using device: {device}")
|
||||||
|
|
||||||
|
# Apply hardware optimizations
|
||||||
|
if device == "cuda":
|
||||||
|
# TODO: Haven't tested this yet
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
|
torch.backends.cudnn.benchmark = True
|
||||||
|
torch.backends.cudnn.deterministic = False
|
||||||
|
torch.cuda.set_per_process_memory_fraction(0.9)
|
||||||
|
elif device == "mps":
|
||||||
|
try:
|
||||||
|
if hasattr(torch.mps, "set_per_process_memory_fraction"):
|
||||||
|
torch.mps.set_per_process_memory_fraction(0.9)
|
||||||
|
except AttributeError:
|
||||||
|
logger.warning("Some MPS optimizations not available in this PyTorch version")
|
||||||
|
elif device == "cpu":
|
||||||
|
# TODO: Haven't tested this yet
|
||||||
|
torch.set_num_threads(min(8, os.cpu_count() or 4))
|
||||||
|
try:
|
||||||
|
torch.backends.mkldnn.enabled = True
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Prepare optimized model and tokenizer parameters
|
||||||
|
model_kwargs = {
|
||||||
|
"torch_dtype": torch.float16 if use_fp16 else torch.float32,
|
||||||
|
"low_cpu_mem_usage": True,
|
||||||
|
"_fast_init": True,
|
||||||
|
"attn_implementation": "eager", # Use eager attention for speed
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenizer_kwargs = {
|
||||||
|
"use_fast": True,
|
||||||
|
"padding": True,
|
||||||
|
"truncation": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Try local loading first
|
||||||
|
model_kwargs["local_files_only"] = True
|
||||||
|
tokenizer_kwargs["local_files_only"] = True
|
||||||
|
|
||||||
|
model = SentenceTransformer(
|
||||||
|
model_name,
|
||||||
|
device=device,
|
||||||
|
model_kwargs=model_kwargs,
|
||||||
|
tokenizer_kwargs=tokenizer_kwargs,
|
||||||
|
local_files_only=True,
|
||||||
|
)
|
||||||
|
logger.info("Model loaded successfully! (local + optimized)")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Local loading failed ({e}), trying network download...")
|
||||||
|
# Fallback to network loading
|
||||||
|
model_kwargs["local_files_only"] = False
|
||||||
|
tokenizer_kwargs["local_files_only"] = False
|
||||||
|
|
||||||
|
model = SentenceTransformer(
|
||||||
|
model_name,
|
||||||
|
device=device,
|
||||||
|
model_kwargs=model_kwargs,
|
||||||
|
tokenizer_kwargs=tokenizer_kwargs,
|
||||||
|
local_files_only=False,
|
||||||
|
)
|
||||||
|
logger.info("Model loaded successfully! (network + optimized)")
|
||||||
|
|
||||||
|
# Apply additional optimizations based on mode
|
||||||
|
if use_fp16 and device in ["cuda", "mps"]:
|
||||||
|
try:
|
||||||
|
model = model.half()
|
||||||
|
logger.info(f"Applied FP16 precision: {model_name}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"FP16 optimization failed: {e}")
|
||||||
|
|
||||||
|
# Apply torch.compile optimization
|
||||||
|
if device in ["cuda", "mps"]:
|
||||||
|
try:
|
||||||
|
model = torch.compile(model, mode="reduce-overhead", dynamic=True)
|
||||||
|
logger.info(f"Applied torch.compile optimization: {model_name}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"torch.compile optimization failed: {e}")
|
||||||
|
|
||||||
|
# Set model to eval mode and disable gradients for inference
|
||||||
|
model.eval()
|
||||||
|
for param in model.parameters():
|
||||||
|
param.requires_grad_(False)
|
||||||
|
|
||||||
|
# Cache the model
|
||||||
|
_model_cache[cache_key] = model
|
||||||
|
logger.info(f"Model cached: {cache_key}")
|
||||||
|
|
||||||
|
# Compute embeddings with optimized inference mode
|
||||||
|
logger.info(f"Starting embedding computation... (batch_size: {batch_size})")
|
||||||
|
|
||||||
|
# Use torch.inference_mode for optimal performance
|
||||||
|
with torch.inference_mode():
|
||||||
|
embeddings = model.encode(
|
||||||
|
texts,
|
||||||
|
batch_size=batch_size,
|
||||||
|
show_progress_bar=is_build, # Don't show progress bar in server environment
|
||||||
|
convert_to_numpy=True,
|
||||||
|
normalize_embeddings=False,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
|
||||||
|
|
||||||
|
# Validate results
|
||||||
|
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
||||||
|
raise RuntimeError(f"Detected NaN or Inf values in embeddings, model: {model_name}")
|
||||||
|
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
|
def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray:
|
||||||
|
# TODO: @yichuan-w add progress bar only in build mode
|
||||||
|
"""Compute embeddings using OpenAI API"""
|
||||||
|
try:
|
||||||
|
import os
|
||||||
|
|
||||||
|
import openai
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(f"OpenAI package not installed: {e}")
|
||||||
|
|
||||||
|
api_key = os.getenv("OPENAI_API_KEY")
|
||||||
|
if not api_key:
|
||||||
|
raise RuntimeError("OPENAI_API_KEY environment variable not set")
|
||||||
|
|
||||||
|
# Cache OpenAI client
|
||||||
|
cache_key = "openai_client"
|
||||||
|
if cache_key in _model_cache:
|
||||||
|
client = _model_cache[cache_key]
|
||||||
|
else:
|
||||||
|
client = openai.OpenAI(api_key=api_key)
|
||||||
|
_model_cache[cache_key] = client
|
||||||
|
logger.info("OpenAI client cached")
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Computing embeddings for {len(texts)} texts using OpenAI API, model: '{model_name}'"
|
||||||
|
)
|
||||||
|
print(f"len of texts: {len(texts)}")
|
||||||
|
|
||||||
|
# OpenAI has limits on batch size and input length
|
||||||
|
max_batch_size = 1000 # Conservative batch size
|
||||||
|
all_embeddings = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
total_batches = (len(texts) + max_batch_size - 1) // max_batch_size
|
||||||
|
batch_range = range(0, len(texts), max_batch_size)
|
||||||
|
batch_iterator = tqdm(
|
||||||
|
batch_range, desc="Computing embeddings", unit="batch", total=total_batches
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
# Fallback when tqdm is not available
|
||||||
|
batch_iterator = range(0, len(texts), max_batch_size)
|
||||||
|
|
||||||
|
for i in batch_iterator:
|
||||||
|
batch_texts = texts[i : i + max_batch_size]
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = client.embeddings.create(model=model_name, input=batch_texts)
|
||||||
|
batch_embeddings = [embedding.embedding for embedding in response.data]
|
||||||
|
all_embeddings.extend(batch_embeddings)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Batch {i} failed: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
embeddings = np.array(all_embeddings, dtype=np.float32)
|
||||||
|
logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
|
||||||
|
print(f"len of embeddings: {len(embeddings)}")
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
|
def compute_embeddings_mlx(chunks: list[str], model_name: str, batch_size: int = 16) -> np.ndarray:
|
||||||
|
# TODO: @yichuan-w add progress bar only in build mode
|
||||||
|
"""Computes embeddings using an MLX model."""
|
||||||
|
try:
|
||||||
|
import mlx.core as mx
|
||||||
|
from mlx_lm.utils import load
|
||||||
|
except ImportError as e:
|
||||||
|
raise RuntimeError(
|
||||||
|
"MLX or related libraries not available. Install with: uv pip install mlx mlx-lm"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Computing embeddings for {len(chunks)} chunks using MLX model '{model_name}' with batch_size={batch_size}..."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Cache MLX model and tokenizer
|
||||||
|
cache_key = f"mlx_{model_name}"
|
||||||
|
if cache_key in _model_cache:
|
||||||
|
logger.info(f"Using cached MLX model: {model_name}")
|
||||||
|
model, tokenizer = _model_cache[cache_key]
|
||||||
|
else:
|
||||||
|
logger.info(f"Loading and caching MLX model: {model_name}")
|
||||||
|
model, tokenizer = load(model_name)
|
||||||
|
_model_cache[cache_key] = (model, tokenizer)
|
||||||
|
logger.info(f"MLX model cached: {cache_key}")
|
||||||
|
|
||||||
|
# Process chunks in batches with progress bar
|
||||||
|
all_embeddings = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
batch_iterator = tqdm(
|
||||||
|
range(0, len(chunks), batch_size), desc="Computing embeddings", unit="batch"
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
batch_iterator = range(0, len(chunks), batch_size)
|
||||||
|
|
||||||
|
for i in batch_iterator:
|
||||||
|
batch_chunks = chunks[i : i + batch_size]
|
||||||
|
|
||||||
|
# Tokenize all chunks in the batch
|
||||||
|
batch_token_ids = []
|
||||||
|
for chunk in batch_chunks:
|
||||||
|
token_ids = tokenizer.encode(chunk) # type: ignore
|
||||||
|
batch_token_ids.append(token_ids)
|
||||||
|
|
||||||
|
# Pad sequences to the same length for batch processing
|
||||||
|
max_length = max(len(ids) for ids in batch_token_ids)
|
||||||
|
padded_token_ids = []
|
||||||
|
for token_ids in batch_token_ids:
|
||||||
|
# Pad with tokenizer.pad_token_id or 0
|
||||||
|
padded = token_ids + [0] * (max_length - len(token_ids))
|
||||||
|
padded_token_ids.append(padded)
|
||||||
|
|
||||||
|
# Convert to MLX array with batch dimension
|
||||||
|
input_ids = mx.array(padded_token_ids)
|
||||||
|
|
||||||
|
# Get embeddings for the batch
|
||||||
|
embeddings = model(input_ids)
|
||||||
|
|
||||||
|
# Mean pooling for each sequence in the batch
|
||||||
|
pooled = embeddings.mean(axis=1) # Shape: (batch_size, hidden_size)
|
||||||
|
|
||||||
|
# Convert batch embeddings to numpy
|
||||||
|
for j in range(len(batch_chunks)):
|
||||||
|
pooled_list = pooled[j].tolist() # Convert to list
|
||||||
|
pooled_numpy = np.array(pooled_list, dtype=np.float32)
|
||||||
|
all_embeddings.append(pooled_numpy)
|
||||||
|
|
||||||
|
# Stack numpy arrays
|
||||||
|
return np.stack(all_embeddings)
|
||||||
413
packages/leann-core/src/leann/embedding_server_manager.py
Normal file
413
packages/leann-core/src/leann/embedding_server_manager.py
Normal file
@@ -0,0 +1,413 @@
|
|||||||
|
import atexit
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import socket
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import psutil
|
||||||
|
|
||||||
|
# Set up logging based on environment variable
|
||||||
|
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||||
|
logging.basicConfig(
|
||||||
|
level=getattr(logging, LOG_LEVEL, logging.INFO),
|
||||||
|
format="%(levelname)s - %(name)s - %(message)s",
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_colab_environment() -> bool:
|
||||||
|
"""Check if we're running in Google Colab environment."""
|
||||||
|
return "COLAB_GPU" in os.environ or "COLAB_TPU" in os.environ
|
||||||
|
|
||||||
|
|
||||||
|
def _get_available_port(start_port: int = 5557) -> int:
|
||||||
|
"""Get an available port starting from start_port."""
|
||||||
|
port = start_port
|
||||||
|
while port < start_port + 100: # Try up to 100 ports
|
||||||
|
try:
|
||||||
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||||
|
s.bind(("localhost", port))
|
||||||
|
return port
|
||||||
|
except OSError:
|
||||||
|
port += 1
|
||||||
|
raise RuntimeError(f"No available ports found in range {start_port}-{start_port + 100}")
|
||||||
|
|
||||||
|
|
||||||
|
def _check_port(port: int) -> bool:
|
||||||
|
"""Check if a port is in use"""
|
||||||
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||||
|
return s.connect_ex(("localhost", port)) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def _check_process_matches_config(
|
||||||
|
port: int, expected_model: str, expected_passages_file: str
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the process using the port matches our expected model and passages file.
|
||||||
|
Returns True if matches, False otherwise.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
for proc in psutil.process_iter(["pid", "cmdline"]):
|
||||||
|
if not _is_process_listening_on_port(proc, port):
|
||||||
|
continue
|
||||||
|
|
||||||
|
cmdline = proc.info["cmdline"]
|
||||||
|
if not cmdline:
|
||||||
|
continue
|
||||||
|
|
||||||
|
return _check_cmdline_matches_config(
|
||||||
|
cmdline, port, expected_model, expected_passages_file
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(f"No process found listening on port {port}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Could not check process on port {port}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _is_process_listening_on_port(proc, port: int) -> bool:
|
||||||
|
"""Check if a process is listening on the given port."""
|
||||||
|
try:
|
||||||
|
connections = proc.net_connections()
|
||||||
|
for conn in connections:
|
||||||
|
if conn.laddr.port == port and conn.status == psutil.CONN_LISTEN:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _check_cmdline_matches_config(
|
||||||
|
cmdline: list, port: int, expected_model: str, expected_passages_file: str
|
||||||
|
) -> bool:
|
||||||
|
"""Check if command line matches our expected configuration."""
|
||||||
|
cmdline_str = " ".join(cmdline)
|
||||||
|
logger.debug(f"Found process on port {port}: {cmdline_str}")
|
||||||
|
|
||||||
|
# Check if it's our embedding server
|
||||||
|
is_embedding_server = any(
|
||||||
|
server_type in cmdline_str
|
||||||
|
for server_type in [
|
||||||
|
"embedding_server",
|
||||||
|
"leann_backend_diskann.embedding_server",
|
||||||
|
"leann_backend_hnsw.hnsw_embedding_server",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
if not is_embedding_server:
|
||||||
|
logger.debug(f"Process on port {port} is not our embedding server")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check model name
|
||||||
|
model_matches = _check_model_in_cmdline(cmdline, expected_model)
|
||||||
|
|
||||||
|
# Check passages file if provided
|
||||||
|
passages_matches = _check_passages_in_cmdline(cmdline, expected_passages_file)
|
||||||
|
|
||||||
|
result = model_matches and passages_matches
|
||||||
|
logger.debug(
|
||||||
|
f"model_matches: {model_matches}, passages_matches: {passages_matches}, overall: {result}"
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _check_model_in_cmdline(cmdline: list, expected_model: str) -> bool:
|
||||||
|
"""Check if the command line contains the expected model."""
|
||||||
|
if "--model-name" not in cmdline:
|
||||||
|
return False
|
||||||
|
|
||||||
|
model_idx = cmdline.index("--model-name")
|
||||||
|
if model_idx + 1 >= len(cmdline):
|
||||||
|
return False
|
||||||
|
|
||||||
|
actual_model = cmdline[model_idx + 1]
|
||||||
|
return actual_model == expected_model
|
||||||
|
|
||||||
|
|
||||||
|
def _check_passages_in_cmdline(cmdline: list, expected_passages_file: str) -> bool:
|
||||||
|
"""Check if the command line contains the expected passages file."""
|
||||||
|
if "--passages-file" not in cmdline:
|
||||||
|
return False # Expected but not found
|
||||||
|
|
||||||
|
passages_idx = cmdline.index("--passages-file")
|
||||||
|
if passages_idx + 1 >= len(cmdline):
|
||||||
|
return False
|
||||||
|
|
||||||
|
actual_passages = cmdline[passages_idx + 1]
|
||||||
|
expected_path = Path(expected_passages_file).resolve()
|
||||||
|
actual_path = Path(actual_passages).resolve()
|
||||||
|
return actual_path == expected_path
|
||||||
|
|
||||||
|
|
||||||
|
def _find_compatible_port_or_next_available(
|
||||||
|
start_port: int, model_name: str, passages_file: str, max_attempts: int = 100
|
||||||
|
) -> tuple[int, bool]:
|
||||||
|
"""
|
||||||
|
Find a port that either has a compatible server or is available.
|
||||||
|
Returns (port, is_compatible) where is_compatible indicates if we found a matching server.
|
||||||
|
"""
|
||||||
|
for port in range(start_port, start_port + max_attempts):
|
||||||
|
if not _check_port(port):
|
||||||
|
# Port is available
|
||||||
|
return port, False
|
||||||
|
|
||||||
|
# Port is in use, check if it's compatible
|
||||||
|
if _check_process_matches_config(port, model_name, passages_file):
|
||||||
|
logger.info(f"Found compatible server on port {port}")
|
||||||
|
return port, True
|
||||||
|
else:
|
||||||
|
logger.info(f"Port {port} has incompatible server, trying next port...")
|
||||||
|
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Could not find compatible or available port in range {start_port}-{start_port + max_attempts}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingServerManager:
|
||||||
|
"""
|
||||||
|
A simplified manager for embedding server processes that avoids complex update mechanisms.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, backend_module_name: str):
|
||||||
|
"""
|
||||||
|
Initializes the manager for a specific backend.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
backend_module_name (str): The full module name of the backend's server script.
|
||||||
|
e.g., "leann_backend_diskann.embedding_server"
|
||||||
|
"""
|
||||||
|
self.backend_module_name = backend_module_name
|
||||||
|
self.server_process: subprocess.Popen | None = None
|
||||||
|
self.server_port: int | None = None
|
||||||
|
self._atexit_registered = False
|
||||||
|
|
||||||
|
def start_server(
|
||||||
|
self,
|
||||||
|
port: int,
|
||||||
|
model_name: str,
|
||||||
|
embedding_mode: str = "sentence-transformers",
|
||||||
|
**kwargs,
|
||||||
|
) -> tuple[bool, int]:
|
||||||
|
"""Start the embedding server."""
|
||||||
|
passages_file = kwargs.get("passages_file")
|
||||||
|
|
||||||
|
# Check if we have a compatible server already running
|
||||||
|
if self._has_compatible_running_server(model_name, passages_file):
|
||||||
|
logger.info("Found compatible running server!")
|
||||||
|
return True, port
|
||||||
|
|
||||||
|
# For Colab environment, use a different strategy
|
||||||
|
if _is_colab_environment():
|
||||||
|
logger.info("Detected Colab environment, using alternative startup strategy")
|
||||||
|
return self._start_server_colab(port, model_name, embedding_mode, **kwargs)
|
||||||
|
|
||||||
|
# Find a compatible port or next available
|
||||||
|
actual_port, is_compatible = _find_compatible_port_or_next_available(
|
||||||
|
port, model_name, passages_file
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_compatible:
|
||||||
|
logger.info(f"Found compatible server on port {actual_port}")
|
||||||
|
return True, actual_port
|
||||||
|
|
||||||
|
# Start a new server
|
||||||
|
return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs)
|
||||||
|
|
||||||
|
def _start_server_colab(
|
||||||
|
self,
|
||||||
|
port: int,
|
||||||
|
model_name: str,
|
||||||
|
embedding_mode: str = "sentence-transformers",
|
||||||
|
**kwargs,
|
||||||
|
) -> tuple[bool, int]:
|
||||||
|
"""Start server with Colab-specific configuration."""
|
||||||
|
# Try to find an available port
|
||||||
|
try:
|
||||||
|
actual_port = _get_available_port(port)
|
||||||
|
except RuntimeError:
|
||||||
|
logger.error("No available ports found")
|
||||||
|
return False, port
|
||||||
|
|
||||||
|
logger.info(f"Starting server on port {actual_port} for Colab environment")
|
||||||
|
|
||||||
|
# Use a simpler startup strategy for Colab
|
||||||
|
command = self._build_server_command(actual_port, model_name, embedding_mode, **kwargs)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# In Colab, we'll use a more direct approach
|
||||||
|
self._launch_server_process_colab(command, actual_port)
|
||||||
|
return self._wait_for_server_ready_colab(actual_port)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to start embedding server in Colab: {e}")
|
||||||
|
return False, actual_port
|
||||||
|
|
||||||
|
def _has_compatible_running_server(self, model_name: str, passages_file: str) -> bool:
|
||||||
|
"""Check if we have a compatible running server."""
|
||||||
|
if not (self.server_process and self.server_process.poll() is None and self.server_port):
|
||||||
|
return False
|
||||||
|
|
||||||
|
if _check_process_matches_config(self.server_port, model_name, passages_file):
|
||||||
|
logger.info(f"Existing server process (PID {self.server_process.pid}) is compatible")
|
||||||
|
return True
|
||||||
|
|
||||||
|
logger.info("Existing server process is incompatible. Should start a new server.")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _start_new_server(
|
||||||
|
self, port: int, model_name: str, embedding_mode: str, **kwargs
|
||||||
|
) -> tuple[bool, int]:
|
||||||
|
"""Start a new embedding server on the given port."""
|
||||||
|
logger.info(f"Starting embedding server on port {port}...")
|
||||||
|
|
||||||
|
command = self._build_server_command(port, model_name, embedding_mode, **kwargs)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._launch_server_process(command, port)
|
||||||
|
return self._wait_for_server_ready(port)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to start embedding server: {e}")
|
||||||
|
return False, port
|
||||||
|
|
||||||
|
def _build_server_command(
|
||||||
|
self, port: int, model_name: str, embedding_mode: str, **kwargs
|
||||||
|
) -> list:
|
||||||
|
"""Build the command to start the embedding server."""
|
||||||
|
command = [
|
||||||
|
sys.executable,
|
||||||
|
"-m",
|
||||||
|
self.backend_module_name,
|
||||||
|
"--zmq-port",
|
||||||
|
str(port),
|
||||||
|
"--model-name",
|
||||||
|
model_name,
|
||||||
|
]
|
||||||
|
|
||||||
|
if kwargs.get("passages_file"):
|
||||||
|
# Convert to absolute path to ensure subprocess can find the file
|
||||||
|
passages_file = Path(kwargs["passages_file"]).resolve()
|
||||||
|
command.extend(["--passages-file", str(passages_file)])
|
||||||
|
if embedding_mode != "sentence-transformers":
|
||||||
|
command.extend(["--embedding-mode", embedding_mode])
|
||||||
|
if kwargs.get("distance_metric"):
|
||||||
|
command.extend(["--distance-metric", kwargs["distance_metric"]])
|
||||||
|
|
||||||
|
return command
|
||||||
|
|
||||||
|
def _launch_server_process(self, command: list, port: int) -> None:
|
||||||
|
"""Launch the server process."""
|
||||||
|
project_root = Path(__file__).parent.parent.parent.parent.parent
|
||||||
|
logger.info(f"Command: {' '.join(command)}")
|
||||||
|
|
||||||
|
# Let server output go directly to console
|
||||||
|
# The server will respect LEANN_LOG_LEVEL environment variable
|
||||||
|
self.server_process = subprocess.Popen(
|
||||||
|
command,
|
||||||
|
cwd=project_root,
|
||||||
|
stdout=None, # Direct to console
|
||||||
|
stderr=None, # Direct to console
|
||||||
|
)
|
||||||
|
self.server_port = port
|
||||||
|
logger.info(f"Server process started with PID: {self.server_process.pid}")
|
||||||
|
|
||||||
|
# Register atexit callback only when we actually start a process
|
||||||
|
if not self._atexit_registered:
|
||||||
|
# Use a lambda to avoid issues with bound methods
|
||||||
|
atexit.register(lambda: self.stop_server() if self.server_process else None)
|
||||||
|
self._atexit_registered = True
|
||||||
|
|
||||||
|
def _wait_for_server_ready(self, port: int) -> tuple[bool, int]:
|
||||||
|
"""Wait for the server to be ready."""
|
||||||
|
max_wait, wait_interval = 120, 0.5
|
||||||
|
for _ in range(int(max_wait / wait_interval)):
|
||||||
|
if _check_port(port):
|
||||||
|
logger.info("Embedding server is ready!")
|
||||||
|
return True, port
|
||||||
|
|
||||||
|
if self.server_process and self.server_process.poll() is not None:
|
||||||
|
logger.error("Server terminated during startup.")
|
||||||
|
return False, port
|
||||||
|
|
||||||
|
time.sleep(wait_interval)
|
||||||
|
|
||||||
|
logger.error(f"Server failed to start within {max_wait} seconds.")
|
||||||
|
self.stop_server()
|
||||||
|
return False, port
|
||||||
|
|
||||||
|
def stop_server(self):
|
||||||
|
"""Stops the embedding server process if it's running."""
|
||||||
|
if not self.server_process:
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.server_process.poll() is not None:
|
||||||
|
# Process already terminated
|
||||||
|
self.server_process = None
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}..."
|
||||||
|
)
|
||||||
|
self.server_process.terminate()
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.server_process.wait(timeout=5)
|
||||||
|
logger.info(f"Server process {self.server_process.pid} terminated.")
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
logger.warning(
|
||||||
|
f"Server process {self.server_process.pid} did not terminate gracefully, killing it."
|
||||||
|
)
|
||||||
|
self.server_process.kill()
|
||||||
|
|
||||||
|
# Clean up process resources to prevent resource tracker warnings
|
||||||
|
try:
|
||||||
|
self.server_process.wait() # Ensure process is fully cleaned up
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
self.server_process = None
|
||||||
|
|
||||||
|
def _launch_server_process_colab(self, command: list, port: int) -> None:
|
||||||
|
"""Launch the server process with Colab-specific settings."""
|
||||||
|
logger.info(f"Colab Command: {' '.join(command)}")
|
||||||
|
|
||||||
|
# In Colab, we need to be more careful about process management
|
||||||
|
self.server_process = subprocess.Popen(
|
||||||
|
command,
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.PIPE,
|
||||||
|
text=True,
|
||||||
|
)
|
||||||
|
self.server_port = port
|
||||||
|
logger.info(f"Colab server process started with PID: {self.server_process.pid}")
|
||||||
|
|
||||||
|
# Register atexit callback
|
||||||
|
if not self._atexit_registered:
|
||||||
|
atexit.register(lambda: self.stop_server() if self.server_process else None)
|
||||||
|
self._atexit_registered = True
|
||||||
|
|
||||||
|
def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]:
|
||||||
|
"""Wait for the server to be ready with Colab-specific timeout."""
|
||||||
|
max_wait, wait_interval = 30, 0.5 # Shorter timeout for Colab
|
||||||
|
|
||||||
|
for _ in range(int(max_wait / wait_interval)):
|
||||||
|
if _check_port(port):
|
||||||
|
logger.info("Colab embedding server is ready!")
|
||||||
|
return True, port
|
||||||
|
|
||||||
|
if self.server_process and self.server_process.poll() is not None:
|
||||||
|
# Check for error output
|
||||||
|
stdout, stderr = self.server_process.communicate()
|
||||||
|
logger.error("Colab server terminated during startup.")
|
||||||
|
logger.error(f"stdout: {stdout}")
|
||||||
|
logger.error(f"stderr: {stderr}")
|
||||||
|
return False, port
|
||||||
|
|
||||||
|
time.sleep(wait_interval)
|
||||||
|
|
||||||
|
logger.error(f"Colab server failed to start within {max_wait} seconds.")
|
||||||
|
self.stop_server()
|
||||||
|
return False, port
|
||||||
@@ -1,59 +1,105 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from typing import Dict, Any
|
|
||||||
|
|
||||||
class LeannBackendBuilderInterface(ABC):
|
class LeannBackendBuilderInterface(ABC):
|
||||||
"""用于构建索引的后端接口"""
|
"""Backend interface for building indexes"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def build(self, data: np.ndarray, index_path: str, **kwargs) -> None:
|
def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs) -> None:
|
||||||
"""构建索引
|
"""Build index
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data: 向量数据 (N, D)
|
data: Vector data (N, D)
|
||||||
index_path: 索引保存路径
|
ids: List of string IDs for each vector
|
||||||
**kwargs: 后端特定的构建参数
|
index_path: Path to save index
|
||||||
|
**kwargs: Backend-specific build parameters
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class LeannBackendSearcherInterface(ABC):
|
class LeannBackendSearcherInterface(ABC):
|
||||||
"""用于搜索的后端接口"""
|
"""Backend interface for searching"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __init__(self, index_path: str, **kwargs):
|
def __init__(self, index_path: str, **kwargs):
|
||||||
"""初始化搜索器
|
"""Initialize searcher
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
index_path: 索引文件路径
|
index_path: Path to index file
|
||||||
**kwargs: 后端特定的加载参数
|
**kwargs: Backend-specific loading parameters
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]:
|
def _ensure_server_running(self, passages_source_file: str, port: int | None, **kwargs) -> int:
|
||||||
"""搜索最近邻
|
"""Ensure server is running"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def search(
|
||||||
|
self,
|
||||||
|
query: np.ndarray,
|
||||||
|
top_k: int,
|
||||||
|
complexity: int = 64,
|
||||||
|
beam_width: int = 1,
|
||||||
|
prune_ratio: float = 0.0,
|
||||||
|
recompute_embeddings: bool = False,
|
||||||
|
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||||
|
zmq_port: int | None = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Search for nearest neighbors
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query: 查询向量 (1, D) 或 (B, D)
|
query: Query vectors (B, D) where B is batch size, D is dimension
|
||||||
top_k: 返回的最近邻数量
|
top_k: Number of nearest neighbors to return
|
||||||
**kwargs: 搜索参数
|
complexity: Search complexity/candidate list size, higher = more accurate but slower
|
||||||
|
beam_width: Number of parallel search paths/IO requests per iteration
|
||||||
|
prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
|
||||||
|
recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored PQ codes
|
||||||
|
pruning_strategy: PQ candidate selection strategy - "global" (default), "local", or "proportional"
|
||||||
|
zmq_port: ZMQ port for embedding server communication. Must be provided if recompute_embeddings is True.
|
||||||
|
**kwargs: Backend-specific parameters
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
{"labels": [...], "distances": [...]}
|
{"labels": [...], "distances": [...]}
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def compute_query_embedding(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
use_server_if_available: bool = True,
|
||||||
|
zmq_port: int | None = None,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""Compute embedding for a query string
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The query string to embed
|
||||||
|
zmq_port: ZMQ port for embedding server
|
||||||
|
use_server_if_available: Whether to try using embedding server first
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Query embedding as numpy array with shape (1, D)
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class LeannBackendFactoryInterface(ABC):
|
class LeannBackendFactoryInterface(ABC):
|
||||||
"""后端工厂接口"""
|
"""Backend factory interface"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def builder(**kwargs) -> LeannBackendBuilderInterface:
|
def builder(**kwargs) -> LeannBackendBuilderInterface:
|
||||||
"""创建 Builder 实例"""
|
"""Create Builder instance"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def searcher(index_path: str, **kwargs) -> LeannBackendSearcherInterface:
|
def searcher(index_path: str, **kwargs) -> LeannBackendSearcherInterface:
|
||||||
"""创建 Searcher 实例"""
|
"""Create Searcher instance"""
|
||||||
pass
|
pass
|
||||||
@@ -1,15 +1,41 @@
|
|||||||
# packages/leann-core/src/leann/registry.py
|
# packages/leann-core/src/leann/registry.py
|
||||||
|
|
||||||
from typing import Dict, TYPE_CHECKING
|
import importlib
|
||||||
|
import importlib.metadata
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from leann.interface import LeannBackendFactoryInterface
|
from leann.interface import LeannBackendFactoryInterface
|
||||||
|
|
||||||
BACKEND_REGISTRY: Dict[str, 'LeannBackendFactoryInterface'] = {}
|
BACKEND_REGISTRY: dict[str, "LeannBackendFactoryInterface"] = {}
|
||||||
|
|
||||||
|
|
||||||
def register_backend(name: str):
|
def register_backend(name: str):
|
||||||
"""A decorator to register a new backend class."""
|
"""A decorator to register a new backend class."""
|
||||||
|
|
||||||
def decorator(cls):
|
def decorator(cls):
|
||||||
print(f"INFO: Registering backend '{name}'")
|
print(f"INFO: Registering backend '{name}'")
|
||||||
BACKEND_REGISTRY[name] = cls
|
BACKEND_REGISTRY[name] = cls
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def autodiscover_backends():
|
||||||
|
"""Automatically discovers and imports all 'leann-backend-*' packages."""
|
||||||
|
# print("INFO: Starting backend auto-discovery...")
|
||||||
|
discovered_backends = []
|
||||||
|
for dist in importlib.metadata.distributions():
|
||||||
|
dist_name = dist.metadata["name"]
|
||||||
|
if dist_name.startswith("leann-backend-"):
|
||||||
|
backend_module_name = dist_name.replace("-", "_")
|
||||||
|
discovered_backends.append(backend_module_name)
|
||||||
|
|
||||||
|
for backend_module_name in sorted(discovered_backends): # sort for deterministic loading
|
||||||
|
try:
|
||||||
|
importlib.import_module(backend_module_name)
|
||||||
|
# Registration message is printed by the decorator
|
||||||
|
except ImportError:
|
||||||
|
# print(f"WARN: Could not import backend module '{backend_module_name}': {e}")
|
||||||
|
pass
|
||||||
|
# print("INFO: Backend auto-discovery finished.")
|
||||||
|
|||||||
197
packages/leann-core/src/leann/searcher_base.py
Normal file
197
packages/leann-core/src/leann/searcher_base.py
Normal file
@@ -0,0 +1,197 @@
|
|||||||
|
import json
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from .embedding_server_manager import EmbeddingServerManager
|
||||||
|
from .interface import LeannBackendSearcherInterface
|
||||||
|
|
||||||
|
|
||||||
|
class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||||
|
"""
|
||||||
|
Abstract base class for Leann searchers, containing common logic for
|
||||||
|
loading metadata, managing embedding servers, and handling file paths.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, index_path: str, backend_module_name: str, **kwargs):
|
||||||
|
"""
|
||||||
|
Initializes the BaseSearcher.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index_path: Path to the Leann index file (e.g., '.../my_index.leann').
|
||||||
|
backend_module_name: The specific embedding server module to use
|
||||||
|
(e.g., 'leann_backend_hnsw.hnsw_embedding_server').
|
||||||
|
**kwargs: Additional keyword arguments.
|
||||||
|
"""
|
||||||
|
self.index_path = Path(index_path)
|
||||||
|
self.index_dir = self.index_path.parent
|
||||||
|
self.meta = kwargs.get("meta", self._load_meta())
|
||||||
|
|
||||||
|
if not self.meta:
|
||||||
|
raise ValueError("Searcher requires metadata from .meta.json.")
|
||||||
|
|
||||||
|
self.dimensions = self.meta.get("dimensions")
|
||||||
|
if not self.dimensions:
|
||||||
|
raise ValueError("Dimensions not found in Leann metadata.")
|
||||||
|
|
||||||
|
self.embedding_model = self.meta.get("embedding_model")
|
||||||
|
if not self.embedding_model:
|
||||||
|
print("WARNING: embedding_model not found in meta.json. Recompute will fail.")
|
||||||
|
|
||||||
|
self.embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
|
||||||
|
|
||||||
|
self.embedding_server_manager = EmbeddingServerManager(
|
||||||
|
backend_module_name=backend_module_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _load_meta(self) -> dict[str, Any]:
|
||||||
|
"""Loads the metadata file associated with the index."""
|
||||||
|
# This is the corrected logic for finding the meta file.
|
||||||
|
meta_path = self.index_dir / f"{self.index_path.name}.meta.json"
|
||||||
|
if not meta_path.exists():
|
||||||
|
raise FileNotFoundError(f"Leann metadata file not found at {meta_path}")
|
||||||
|
with open(meta_path, encoding="utf-8") as f:
|
||||||
|
return json.load(f)
|
||||||
|
|
||||||
|
def _ensure_server_running(self, passages_source_file: str, port: int, **kwargs) -> int:
|
||||||
|
"""
|
||||||
|
Ensures the embedding server is running if recompute is needed.
|
||||||
|
This is a helper for subclasses.
|
||||||
|
"""
|
||||||
|
if not self.embedding_model:
|
||||||
|
raise ValueError("Cannot use recompute mode without 'embedding_model' in meta.json.")
|
||||||
|
|
||||||
|
# Get distance_metric from meta if not provided in kwargs
|
||||||
|
distance_metric = (
|
||||||
|
kwargs.get("distance_metric")
|
||||||
|
or self.meta.get("backend_kwargs", {}).get("distance_metric")
|
||||||
|
or "mips"
|
||||||
|
)
|
||||||
|
|
||||||
|
server_started, actual_port = self.embedding_server_manager.start_server(
|
||||||
|
port=port,
|
||||||
|
model_name=self.embedding_model,
|
||||||
|
embedding_mode=self.embedding_mode,
|
||||||
|
passages_file=passages_source_file,
|
||||||
|
distance_metric=distance_metric,
|
||||||
|
enable_warmup=kwargs.get("enable_warmup", False),
|
||||||
|
)
|
||||||
|
if not server_started:
|
||||||
|
raise RuntimeError(f"Failed to start embedding server on port {actual_port}")
|
||||||
|
|
||||||
|
return actual_port
|
||||||
|
|
||||||
|
def compute_query_embedding(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
use_server_if_available: bool = True,
|
||||||
|
zmq_port: int = 5557,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Compute embedding for a query string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The query string to embed
|
||||||
|
zmq_port: ZMQ port for embedding server
|
||||||
|
use_server_if_available: Whether to try using embedding server first
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Query embedding as numpy array
|
||||||
|
"""
|
||||||
|
# Try to use embedding server if available and requested
|
||||||
|
if use_server_if_available:
|
||||||
|
try:
|
||||||
|
# TODO: Maybe we can directly use this port here?
|
||||||
|
# For this internal method, it's ok to assume that the server is running
|
||||||
|
# on that port?
|
||||||
|
|
||||||
|
# Ensure we have a server with passages_file for compatibility
|
||||||
|
passages_source_file = self.index_dir / f"{self.index_path.name}.meta.json"
|
||||||
|
# Convert to absolute path to ensure server can find it
|
||||||
|
zmq_port = self._ensure_server_running(
|
||||||
|
str(passages_source_file.resolve()), zmq_port
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._compute_embedding_via_server([query], zmq_port)[
|
||||||
|
0:1
|
||||||
|
] # Return (1, D) shape
|
||||||
|
except Exception as e:
|
||||||
|
print(f"⚠️ Embedding server failed: {e}")
|
||||||
|
print("⏭️ Falling back to direct model loading...")
|
||||||
|
|
||||||
|
# Fallback to direct computation
|
||||||
|
from .embedding_compute import compute_embeddings
|
||||||
|
|
||||||
|
embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
|
||||||
|
return compute_embeddings([query], self.embedding_model, embedding_mode)
|
||||||
|
|
||||||
|
def _compute_embedding_via_server(self, chunks: list, zmq_port: int) -> np.ndarray:
|
||||||
|
"""Compute embeddings using the ZMQ embedding server."""
|
||||||
|
import msgpack
|
||||||
|
import zmq
|
||||||
|
|
||||||
|
try:
|
||||||
|
context = zmq.Context()
|
||||||
|
socket = context.socket(zmq.REQ)
|
||||||
|
socket.setsockopt(zmq.RCVTIMEO, 30000) # 30 second timeout
|
||||||
|
socket.connect(f"tcp://localhost:{zmq_port}")
|
||||||
|
|
||||||
|
# Send embedding request
|
||||||
|
request = chunks
|
||||||
|
request_bytes = msgpack.packb(request)
|
||||||
|
socket.send(request_bytes)
|
||||||
|
|
||||||
|
# Wait for response
|
||||||
|
response_bytes = socket.recv()
|
||||||
|
response = msgpack.unpackb(response_bytes)
|
||||||
|
|
||||||
|
socket.close()
|
||||||
|
context.term()
|
||||||
|
|
||||||
|
# Convert response to numpy array
|
||||||
|
if isinstance(response, list) and len(response) > 0:
|
||||||
|
return np.array(response, dtype=np.float32)
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Invalid response from embedding server")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"Failed to compute embeddings via server: {e}")
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def search(
|
||||||
|
self,
|
||||||
|
query: np.ndarray,
|
||||||
|
top_k: int,
|
||||||
|
complexity: int = 64,
|
||||||
|
beam_width: int = 1,
|
||||||
|
prune_ratio: float = 0.0,
|
||||||
|
recompute_embeddings: bool = False,
|
||||||
|
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||||
|
zmq_port: int | None = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Search for the top_k nearest neighbors of the query vector.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Query vectors (B, D) where B is batch size, D is dimension
|
||||||
|
top_k: Number of nearest neighbors to return
|
||||||
|
complexity: Search complexity/candidate list size, higher = more accurate but slower
|
||||||
|
beam_width: Number of parallel search paths/IO requests per iteration
|
||||||
|
prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
|
||||||
|
recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored PQ codes
|
||||||
|
pruning_strategy: PQ candidate selection strategy - "global" (default), "local", or "proportional"
|
||||||
|
zmq_port: ZMQ port for embedding server communication. Must be provided if recompute_embeddings is True.
|
||||||
|
**kwargs: Backend-specific parameters (e.g., batch_size, dedup_node_dis, etc.)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with 'labels' (list of lists) and 'distances' (ndarray)
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
"""Ensures the embedding server is stopped when the searcher is destroyed."""
|
||||||
|
if hasattr(self, "embedding_server_manager"):
|
||||||
|
self.embedding_server_manager.stop_server()
|
||||||
39
packages/leann/README.md
Normal file
39
packages/leann/README.md
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
# LEANN - The smallest vector index in the world
|
||||||
|
|
||||||
|
LEANN is a revolutionary vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **97% less storage** than traditional solutions **without accuracy loss**.
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Default installation (HNSW backend, recommended)
|
||||||
|
uv pip install leann
|
||||||
|
|
||||||
|
# With DiskANN backend (for large-scale deployments)
|
||||||
|
uv pip install leann[diskann]
|
||||||
|
```
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
```python
|
||||||
|
from leann import LeannBuilder, LeannSearcher, LeannChat
|
||||||
|
from pathlib import Path
|
||||||
|
INDEX_PATH = str(Path("./").resolve() / "demo.leann")
|
||||||
|
|
||||||
|
# Build an index
|
||||||
|
builder = LeannBuilder(backend_name="hnsw")
|
||||||
|
builder.add_text("LEANN saves 97% storage compared to traditional vector databases.")
|
||||||
|
builder.add_text("Tung Tung Tung Sahur called—they need their banana‑crocodile hybrid back")
|
||||||
|
builder.build_index(INDEX_PATH)
|
||||||
|
|
||||||
|
# Search
|
||||||
|
searcher = LeannSearcher(INDEX_PATH)
|
||||||
|
results = searcher.search("fantastical AI-generated creatures", top_k=1)
|
||||||
|
|
||||||
|
# Chat with your data
|
||||||
|
chat = LeannChat(INDEX_PATH, llm_config={"type": "hf", "model": "Qwen/Qwen3-0.6B"})
|
||||||
|
response = chat.ask("How much storage does LEANN save?", top_k=1)
|
||||||
|
```
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
MIT License
|
||||||
12
packages/leann/__init__.py
Normal file
12
packages/leann/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
"""
|
||||||
|
LEANN - Low-storage Embedding Approximation for Neural Networks
|
||||||
|
|
||||||
|
A revolutionary vector database that democratizes personal AI.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__version__ = "0.1.0"
|
||||||
|
|
||||||
|
# Re-export main API from leann-core
|
||||||
|
from leann_core import LeannBuilder, LeannChat, LeannSearcher
|
||||||
|
|
||||||
|
__all__ = ["LeannBuilder", "LeannChat", "LeannSearcher"]
|
||||||
40
packages/leann/pyproject.toml
Normal file
40
packages/leann/pyproject.toml
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
[build-system]
|
||||||
|
requires = ["setuptools>=61.0"]
|
||||||
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
|
[project]
|
||||||
|
name = "leann"
|
||||||
|
version = "0.1.16"
|
||||||
|
description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!"
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.9"
|
||||||
|
license = { text = "MIT" }
|
||||||
|
authors = [
|
||||||
|
{ name = "LEANN Team" }
|
||||||
|
]
|
||||||
|
keywords = ["vector-database", "rag", "embeddings", "search", "ai"]
|
||||||
|
classifiers = [
|
||||||
|
"Development Status :: 4 - Beta",
|
||||||
|
"Intended Audience :: Developers",
|
||||||
|
"License :: OSI Approved :: MIT License",
|
||||||
|
"Programming Language :: Python :: 3",
|
||||||
|
"Programming Language :: Python :: 3.9",
|
||||||
|
"Programming Language :: Python :: 3.10",
|
||||||
|
"Programming Language :: Python :: 3.11",
|
||||||
|
"Programming Language :: Python :: 3.12",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Default installation: core + hnsw
|
||||||
|
dependencies = [
|
||||||
|
"leann-core>=0.1.0",
|
||||||
|
"leann-backend-hnsw>=0.1.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
diskann = [
|
||||||
|
"leann-backend-diskann>=0.1.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.urls]
|
||||||
|
Repository = "https://github.com/yichuan-w/LEANN"
|
||||||
|
Issues = "https://github.com/yichuan-w/LEANN/issues"
|
||||||
140
packages/wechat-exporter/main.py
Normal file
140
packages/wechat-exporter/main.py
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
import json
|
||||||
|
import sqlite3
|
||||||
|
import xml.etree.ElementTree as ElementTree
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
import requests
|
||||||
|
import typer
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
app = typer.Typer()
|
||||||
|
|
||||||
|
|
||||||
|
def get_safe_path(s: str) -> str:
|
||||||
|
"""
|
||||||
|
Remove invalid characters to sanitize a path.
|
||||||
|
:param s: str to sanitize
|
||||||
|
:returns: sanitized str
|
||||||
|
"""
|
||||||
|
ban_chars = "\\ / : * ? \" ' < > | $ \r \n".replace(" ", "")
|
||||||
|
for i in ban_chars:
|
||||||
|
s = s.replace(i, "")
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
def process_history(history: str):
|
||||||
|
if history.startswith("<?xml") or history.startswith("<msg>"):
|
||||||
|
try:
|
||||||
|
root = ElementTree.fromstring(history)
|
||||||
|
title = root.find(".//title").text if root.find(".//title") is not None else None
|
||||||
|
quoted = (
|
||||||
|
root.find(".//refermsg/content").text
|
||||||
|
if root.find(".//refermsg/content") is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
if title and quoted:
|
||||||
|
return {"title": title, "quoted": process_history(quoted)}
|
||||||
|
if title:
|
||||||
|
return title
|
||||||
|
except Exception:
|
||||||
|
return history
|
||||||
|
return history
|
||||||
|
|
||||||
|
|
||||||
|
def get_message(history: dict | str):
|
||||||
|
if isinstance(history, dict):
|
||||||
|
if "title" in history:
|
||||||
|
return history["title"]
|
||||||
|
else:
|
||||||
|
return history
|
||||||
|
|
||||||
|
|
||||||
|
def export_chathistory(user_id: str):
|
||||||
|
res = requests.get(
|
||||||
|
"http://localhost:48065/wechat/chatlog",
|
||||||
|
params={"userId": user_id, "count": 100000},
|
||||||
|
).json()
|
||||||
|
for i in range(len(res["chatLogs"])):
|
||||||
|
res["chatLogs"][i]["content"] = process_history(res["chatLogs"][i]["content"])
|
||||||
|
res["chatLogs"][i]["message"] = get_message(res["chatLogs"][i]["content"])
|
||||||
|
return res["chatLogs"]
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def export_all(dest: Annotated[Path, typer.Argument(help="Destination path to export to.")]):
|
||||||
|
"""
|
||||||
|
Export all users' chat history to json files.
|
||||||
|
"""
|
||||||
|
if not dest.is_dir():
|
||||||
|
if not dest.exists():
|
||||||
|
inp = typer.prompt("Destination path does not exist, create it? (y/n)")
|
||||||
|
if inp.lower() == "y":
|
||||||
|
dest.mkdir(parents=True)
|
||||||
|
else:
|
||||||
|
typer.echo("Aborted.", err=True)
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
typer.echo("Destination path is not a directory!", err=True)
|
||||||
|
return
|
||||||
|
all_users = requests.get("http://localhost:48065/wechat/allcontacts").json()
|
||||||
|
|
||||||
|
exported_count = 0
|
||||||
|
for user in tqdm(all_users):
|
||||||
|
try:
|
||||||
|
usr_chatlog = export_chathistory(user["arg"])
|
||||||
|
|
||||||
|
# Only write file if there are messages
|
||||||
|
if len(usr_chatlog) > 0:
|
||||||
|
out_path = dest / get_safe_path((user["title"] or "") + "-" + user["arg"] + ".json")
|
||||||
|
with open(out_path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(usr_chatlog, f, ensure_ascii=False, indent=2)
|
||||||
|
exported_count += 1
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error exporting {user.get('title', 'Unknown')}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
print(f"Exported {exported_count} users' chat history to {dest} in json.")
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def export_sqlite(
|
||||||
|
dest: Annotated[Path, typer.Argument(help="Destination path to export to.")] = Path(
|
||||||
|
"chatlog.db"
|
||||||
|
),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Export all users' chat history to a sqlite database.
|
||||||
|
"""
|
||||||
|
connection = sqlite3.connect(dest)
|
||||||
|
cursor = connection.cursor()
|
||||||
|
cursor.execute(
|
||||||
|
"CREATE TABLE IF NOT EXISTS chatlog (id INTEGER PRIMARY KEY AUTOINCREMENT, with_id TEXT, from_user TEXT, to_user TEXT, message TEXT, timest DATETIME, auxiliary TEXT)"
|
||||||
|
)
|
||||||
|
cursor.execute("CREATE INDEX IF NOT EXISTS chatlog_with_id_index ON chatlog (with_id)")
|
||||||
|
cursor.execute("CREATE TABLE iF NOT EXISTS users (id TEXT PRIMARY KEY, name TEXT)")
|
||||||
|
|
||||||
|
all_users = requests.get("http://localhost:48065/wechat/allcontacts").json()
|
||||||
|
for user in tqdm(all_users):
|
||||||
|
cursor.execute(
|
||||||
|
"INSERT OR IGNORE INTO users (id, name) VALUES (?, ?)",
|
||||||
|
(user["arg"], user["title"]),
|
||||||
|
)
|
||||||
|
usr_chatlog = export_chathistory(user["arg"])
|
||||||
|
for msg in usr_chatlog:
|
||||||
|
cursor.execute(
|
||||||
|
"INSERT INTO chatlog (with_id, from_user, to_user, message, timest, auxiliary) VALUES (?, ?, ?, ?, ?, ?)",
|
||||||
|
(
|
||||||
|
user["arg"],
|
||||||
|
msg["fromUser"],
|
||||||
|
msg["toUser"],
|
||||||
|
msg["message"],
|
||||||
|
msg["createTime"],
|
||||||
|
str(msg["content"]),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
connection.commit()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
app()
|
||||||
BIN
packages/wechat-exporter/wechattweak-cli
Executable file
BIN
packages/wechat-exporter/wechattweak-cli
Executable file
Binary file not shown.
121
pyproject.toml
121
pyproject.toml
@@ -5,11 +5,10 @@ build-backend = "setuptools.build_meta"
|
|||||||
[project]
|
[project]
|
||||||
name = "leann-workspace"
|
name = "leann-workspace"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.9"
|
||||||
|
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"leann-core",
|
"leann-core",
|
||||||
"leann-backend-diskann",
|
|
||||||
"leann-backend-hnsw",
|
"leann-backend-hnsw",
|
||||||
"numpy>=1.26.0",
|
"numpy>=1.26.0",
|
||||||
"torch",
|
"torch",
|
||||||
@@ -21,26 +20,62 @@ dependencies = [
|
|||||||
"colorama",
|
"colorama",
|
||||||
"boto3",
|
"boto3",
|
||||||
"protobuf==4.25.3",
|
"protobuf==4.25.3",
|
||||||
"sglang[all]",
|
"sglang",
|
||||||
"ollama",
|
"ollama",
|
||||||
"requests>=2.25.0",
|
"requests>=2.25.0",
|
||||||
"sentence-transformers>=2.2.0",
|
"sentence-transformers>=2.2.0",
|
||||||
"openai>=1.0.0",
|
"openai>=1.0.0",
|
||||||
|
# PDF parsing dependencies - essential for document processing
|
||||||
"PyPDF2>=3.0.0",
|
"PyPDF2>=3.0.0",
|
||||||
|
"pdfplumber>=0.11.0",
|
||||||
|
"pymupdf>=1.26.0",
|
||||||
|
"pypdfium2>=4.30.0",
|
||||||
|
# LlamaIndex core and readers - updated versions
|
||||||
"llama-index>=0.12.44",
|
"llama-index>=0.12.44",
|
||||||
"llama-index-readers-docling",
|
"llama-index-readers-file>=0.4.0", # Essential for PDF parsing
|
||||||
"llama-index-node-parser-docling",
|
# "llama-index-readers-docling", # Requires Python >= 3.10
|
||||||
|
# "llama-index-node-parser-docling", # Requires Python >= 3.10
|
||||||
|
"llama-index-vector-stores-faiss>=0.4.0",
|
||||||
|
"llama-index-embeddings-huggingface>=0.5.5",
|
||||||
|
# Other dependencies
|
||||||
"ipykernel==6.29.5",
|
"ipykernel==6.29.5",
|
||||||
"msgpack>=1.1.1",
|
"msgpack>=1.1.1",
|
||||||
|
"mlx>=0.26.3; sys_platform == 'darwin'",
|
||||||
|
"mlx-lm>=0.26.0; sys_platform == 'darwin'",
|
||||||
|
"psutil>=5.8.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
dev = [
|
dev = [
|
||||||
"pytest>=7.0",
|
"pytest>=7.0",
|
||||||
"pytest-cov>=4.0",
|
"pytest-cov>=4.0",
|
||||||
|
"pytest-xdist>=3.0", # For parallel test execution
|
||||||
"black>=23.0",
|
"black>=23.0",
|
||||||
"ruff>=0.1.0",
|
"ruff>=0.1.0",
|
||||||
"matplotlib",
|
"matplotlib",
|
||||||
|
"huggingface-hub>=0.20.0",
|
||||||
|
"pre-commit>=3.5.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
test = [
|
||||||
|
"pytest>=7.0",
|
||||||
|
"pytest-timeout>=2.0",
|
||||||
|
"llama-index-core>=0.12.0",
|
||||||
|
"llama-index-readers-file>=0.4.0",
|
||||||
|
"python-dotenv>=1.0.0",
|
||||||
|
"sentence-transformers>=2.2.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
diskann = [
|
||||||
|
"leann-backend-diskann",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Add a new optional dependency group for document processing
|
||||||
|
documents = [
|
||||||
|
"beautifulsoup4>=4.13.0", # For HTML parsing
|
||||||
|
"python-docx>=0.8.11", # For Word documents
|
||||||
|
"openpyxl>=3.1.0", # For Excel files
|
||||||
|
"pandas>=2.2.0", # For data processing
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.setuptools]
|
[tool.setuptools]
|
||||||
@@ -51,3 +86,79 @@ py-modules = []
|
|||||||
leann-core = { path = "packages/leann-core", editable = true }
|
leann-core = { path = "packages/leann-core", editable = true }
|
||||||
leann-backend-diskann = { path = "packages/leann-backend-diskann", editable = true }
|
leann-backend-diskann = { path = "packages/leann-backend-diskann", editable = true }
|
||||||
leann-backend-hnsw = { path = "packages/leann-backend-hnsw", editable = true }
|
leann-backend-hnsw = { path = "packages/leann-backend-hnsw", editable = true }
|
||||||
|
|
||||||
|
[tool.ruff]
|
||||||
|
target-version = "py310"
|
||||||
|
line-length = 100
|
||||||
|
extend-exclude = [
|
||||||
|
"third_party",
|
||||||
|
"*.egg-info",
|
||||||
|
"__pycache__",
|
||||||
|
".git",
|
||||||
|
".venv",
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.ruff.lint]
|
||||||
|
select = [
|
||||||
|
"E", # pycodestyle errors
|
||||||
|
"W", # pycodestyle warnings
|
||||||
|
"F", # pyflakes
|
||||||
|
"I", # isort
|
||||||
|
"B", # flake8-bugbear
|
||||||
|
"C4", # flake8-comprehensions
|
||||||
|
"UP", # pyupgrade
|
||||||
|
"N", # pep8-naming
|
||||||
|
"RUF", # ruff-specific rules
|
||||||
|
]
|
||||||
|
ignore = [
|
||||||
|
"E501", # line too long (handled by formatter)
|
||||||
|
"B008", # do not perform function calls in argument defaults
|
||||||
|
"B904", # raise without from
|
||||||
|
"N812", # lowercase imported as non-lowercase
|
||||||
|
"N806", # variable in function should be lowercase
|
||||||
|
"RUF012", # mutable class attributes should be annotated with typing.ClassVar
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.ruff.lint.per-file-ignores]
|
||||||
|
"test/**/*.py" = ["E402"] # module level import not at top of file (common in tests)
|
||||||
|
"examples/**/*.py" = ["E402"] # module level import not at top of file (common in examples)
|
||||||
|
|
||||||
|
[tool.ruff.format]
|
||||||
|
quote-style = "double"
|
||||||
|
indent-style = "space"
|
||||||
|
skip-magic-trailing-comma = false
|
||||||
|
line-ending = "auto"
|
||||||
|
|
||||||
|
[dependency-groups]
|
||||||
|
dev = [
|
||||||
|
"ruff>=0.12.4",
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.lychee]
|
||||||
|
accept = ["200", "403", "429", "503"]
|
||||||
|
timeout = 20
|
||||||
|
max_retries = 2
|
||||||
|
exclude = ["localhost", "127.0.0.1", "example.com"]
|
||||||
|
exclude_path = [".git/", ".venv/", "__pycache__/", "third_party/"]
|
||||||
|
scheme = ["https", "http"]
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
testpaths = ["tests"]
|
||||||
|
python_files = ["test_*.py"]
|
||||||
|
python_classes = ["Test*"]
|
||||||
|
python_functions = ["test_*"]
|
||||||
|
markers = [
|
||||||
|
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
|
||||||
|
"openai: marks tests that require OpenAI API key",
|
||||||
|
]
|
||||||
|
timeout = 600
|
||||||
|
addopts = [
|
||||||
|
"-v",
|
||||||
|
"--tb=short",
|
||||||
|
"--strict-markers",
|
||||||
|
"--disable-warnings",
|
||||||
|
]
|
||||||
|
env = [
|
||||||
|
"HF_HUB_DISABLE_SYMLINKS=1",
|
||||||
|
"TOKENIZERS_PARALLELISM=false",
|
||||||
|
]
|
||||||
|
|||||||
@@ -1,12 +0,0 @@
|
|||||||
import faiss
|
|
||||||
hnsw_index = faiss.read_index("/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/hnsw/hnsw_IP_M30_efC128.index", faiss.IO_FLAG_ONDISK_SAME_DIR)
|
|
||||||
|
|
||||||
# print total number of nodes
|
|
||||||
print(hnsw_index.ntotal)
|
|
||||||
|
|
||||||
# print stats of the graph
|
|
||||||
print(hnsw_index.hnsw.print_neighbor_stats(0))
|
|
||||||
|
|
||||||
|
|
||||||
# save_degree_distribution
|
|
||||||
hnsw_index.hnsw.save_degree_distribution(0, "degree_distribution_HNSW_M30.txt")
|
|
||||||
@@ -1,11 +0,0 @@
|
|||||||
import faiss
|
|
||||||
nsg_index = faiss.read_index("/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/nsg_R16.index", faiss.IO_FLAG_ONDISK_SAME_DIR)
|
|
||||||
|
|
||||||
# print total number of nodes
|
|
||||||
print(nsg_index.ntotal)
|
|
||||||
|
|
||||||
# print stats of the graph
|
|
||||||
print(nsg_index.nsg.print_neighbor_stats(0))
|
|
||||||
|
|
||||||
# save degree distribution
|
|
||||||
nsg_index.nsg.save_degree_distribution("degree_distribution_NSG_R60.txt")
|
|
||||||
@@ -1,63 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import time
|
|
||||||
|
|
||||||
# import bitsandbytes as bnb
|
|
||||||
from bitsandbytes.nn import Linear8bitLt
|
|
||||||
|
|
||||||
# set default to half
|
|
||||||
import torch
|
|
||||||
torch.set_default_dtype(torch.float16)
|
|
||||||
|
|
||||||
M = 2048
|
|
||||||
N = 2048
|
|
||||||
|
|
||||||
bsz = 2048
|
|
||||||
import torch_int
|
|
||||||
from torch_int.nn.linear import W8A8BFP32OFP32Linear, W8A8B8O8Linear, W8A8B8O8LinearReLU
|
|
||||||
|
|
||||||
fp16_model = nn.Sequential(
|
|
||||||
nn.Linear(M, N),
|
|
||||||
# nn.Linear(2048, 2048)
|
|
||||||
)
|
|
||||||
|
|
||||||
int8_model = nn.Sequential(
|
|
||||||
Linear8bitLt(M, N, has_fp16_weights=False),
|
|
||||||
# Linear8bitLt(2048, 2048, has_fp16_weights=False)
|
|
||||||
)
|
|
||||||
|
|
||||||
int8_model.load_state_dict(fp16_model.state_dict())
|
|
||||||
int8_model = int8_model.to(0) # Quantization happens here
|
|
||||||
fp16_model = fp16_model.to(0) # Move fp16 model to GPU as well
|
|
||||||
|
|
||||||
# Create random input tensor
|
|
||||||
input_tensor = torch.randn(bsz, M, device=0) # Batch of 1000 vectors
|
|
||||||
|
|
||||||
# Speed test function
|
|
||||||
def speed_test(model, input_tensor, name, num_iterations=100):
|
|
||||||
# Warmup
|
|
||||||
for _ in range(10):
|
|
||||||
_ = model(input_tensor)
|
|
||||||
|
|
||||||
# Actual timing
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
start_time = time.time()
|
|
||||||
|
|
||||||
for _ in range(num_iterations):
|
|
||||||
_ = model(input_tensor)
|
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
end_time = time.time()
|
|
||||||
|
|
||||||
avg_time = (end_time - start_time) / num_iterations
|
|
||||||
print(f"{name} model: {avg_time:.6f} seconds per iteration")
|
|
||||||
return avg_time
|
|
||||||
|
|
||||||
# Run speed tests
|
|
||||||
with torch.no_grad(): # Disable gradient calculation for inference
|
|
||||||
fp16_time = speed_test(fp16_model, input_tensor, "FP16")
|
|
||||||
int8_time = speed_test(int8_model, input_tensor, "INT8")
|
|
||||||
|
|
||||||
# Calculate speedup
|
|
||||||
speedup = fp16_time / int8_time
|
|
||||||
print(f"INT8 is {speedup:.2f}x faster than FP16")
|
|
||||||
@@ -1,89 +0,0 @@
|
|||||||
n,d,seqlen,bs,latency,h,flop,io,intensity,throughput,series
|
|
||||||
3,256,256,2048,0.009623501679245285,768,618475290624,167.48502132816208,3692720015.912285,64267177503366.266,dense
|
|
||||||
3,256,256,1024,0.004853848615384615,768,309237645312,166.15392854317415,1861151572.059558,63709783682138.234,dense
|
|
||||||
3,256,256,512,0.0024687246971962615,768,154618822656,163.57953256539062,945221081.3366361,62631051097597.516,dense
|
|
||||||
3,256,256,256,0.0012845360838052097,768,77309411328,157.64931990085577,490388486.1451936,60184694149645.54,dense
|
|
||||||
3,256,256,128,0.0006901147179878049,768,38654705664,147.57393422494675,261934506.70684624,56012000116019.945,dense
|
|
||||||
3,256,256,64,0.0003363830693015702,768,19327352832,153.1328437752606,126212981.84970059,57456378146882.51,dense
|
|
||||||
3,256,256,32,0.00018671159748991485,768,9663676416,141.10249365427362,68486928.65540518,51757237075334.75,dense
|
|
||||||
3,256,256,16,0.00012353640857142858,768,4831838208,111.40488993609125,43371868.24359184,39112665358133.98,dense
|
|
||||||
3,256,256,8,9.774760007849294e-05,768,2415919104,76.43260800265766,31608487.09906635,24715891766754.14,dense
|
|
||||||
3,256,256,4,6.672271167474822e-05,768,1207959552,64.82614227498455,18633833.660438772,18104173551704.773,dense
|
|
||||||
3,256,256,2,4.9758770289855074e-05,768,603979776,55.317122669351576,10918495.880745342,12138157202874.861,dense
|
|
||||||
3,256,1,2048,9.785507940251571e-05,768,2415919104,76.34865809334705,31643242.518371396,24688745017132.86,dense
|
|
||||||
3,256,1,1024,6.692813470149253e-05,768,1207959552,64.62717090938949,18691202.70936228,18048606275785.867,dense
|
|
||||||
3,256,1,512,4.9680950036205655e-05,768,603979776,55.40377142534654,10901419.893658841,12157170415618.898,dense
|
|
||||||
3,256,1,256,4.2781118741058655e-05,768,301989888,45.95672244805227,6571179.83862661,7058952568020.829,dense
|
|
||||||
3,256,1,128,5.0662328255350016e-05,768,150994944,31.046026784880404,4863583.512513602,2980418571348.519,dense
|
|
||||||
3,256,1,64,4.475009253945481e-05,768,75497472,30.75426042497223,2454862.219307235,1687090857598.4766,dense
|
|
||||||
3,256,1,32,4.51682671454219e-05,768,37748736,28.29313765537115,1334201.1218340008,835735758435.5786,dense
|
|
||||||
3,256,1,16,5.03585186661834e-05,768,18874368,24.401035466223117,773506.846712577,374799904761.1871,dense
|
|
||||||
3,256,1,8,5.023459565217391e-05,768,9437184,23.972005435021096,393675.19858030166,187862246674.45105,dense
|
|
||||||
3,256,1,4,5.053219391083726e-05,768,4718592,23.58765586356967,200044.97383259286,93377936614.54384,dense
|
|
||||||
3,256,1,2,4.4607398995335484e-05,768,2359296,26.58285456464288,88752.54515134107,52890239133.797226,dense
|
|
||||||
12,256,256,2048,0.14480779847058822,3072,9895604649984,44.620009282941716,221775046868.20184,68336130750540.26,dense
|
|
||||||
12,256,256,1024,0.07254347629166667,3072,4947802324992,44.664248332585096,110777691547.58836,68204648824643.82,dense
|
|
||||||
12,256,256,512,0.036310761444444443,3072,2473901162496,44.876147984203506,55127306456.13385,68131349056975.164,dense
|
|
||||||
12,256,256,256,0.01821551906896552,3072,1236950581248,45.24607467289738,27338295977.947884,67906414116709.98,dense
|
|
||||||
12,256,256,128,0.009229417903030302,3072,618475290624,45.67217092440895,13541622351.335684,67011299859001.46,dense
|
|
||||||
12,256,256,64,0.004754550595394737,3072,309237645312,46.31372736116993,6677019167.566916,65040352207320.695,dense
|
|
||||||
12,256,256,32,0.002405752659340659,3072,154618822656,49.68826015254682,3111777755.5766335,64270456921525.82,dense
|
|
||||||
12,256,256,16,0.0012287219045005488,3072,77309411328,56.323579604557374,1372594069.3184311,62918558743709.18,dense
|
|
||||||
12,256,256,8,0.0006206816149425287,3072,38654705664,70.95456179103653,544781120.315271,62277832520589.78,dense
|
|
||||||
12,256,256,4,0.0003875502697142857,3072,19327352832,81.16954743236613,238110885.71245712,49870569942445.75,dense
|
|
||||||
12,256,256,2,0.00027502018627941914,3072,9663676416,91.50537035282076,105607751.53129694,35138062215483.168,dense
|
|
||||||
12,256,1,2048,0.0006202853873290136,3072,38654705664,70.99988634205897,544433345.6784943,62317614526515.766,dense
|
|
||||||
12,256,1,1024,0.00038721467732724153,3072,19327352832,81.2398957010995,237904697.74985722,49913791918755.53,dense
|
|
||||||
12,256,1,512,0.000274364799,3072,9663676416,91.72395326121995,105356082.81599998,35221998052308.45,dense
|
|
||||||
12,256,1,256,0.00012488918589482266,3072,4831838208,176.31707535146046,27404255.647778228,38689003962834.75,dense
|
|
||||||
12,256,1,128,8.976711102514506e-05,3072,2415919104,227.78088507574267,10606329.425740216,26913187652026.21,dense
|
|
||||||
12,256,1,64,8.715176287471176e-05,3072,1207959552,225.59268282689945,5354604.31102229,13860414432884.701,dense
|
|
||||||
12,256,1,32,8.523013435114503e-05,3072,603979776,226.06539514085782,2671703.8033338524,7086458100741.991,dense
|
|
||||||
12,256,1,16,7.901561645904116e-05,3072,301989888,241.35704882952732,1251216.3595988373,3821901309300.556,dense
|
|
||||||
12,256,1,8,7.827949114210329e-05,3072,150994944,242.37091635608994,622991.1833900034,1928920867994.581,dense
|
|
||||||
12,256,1,4,7.779445951035782e-05,3072,75497472,243.25022783249054,310369.58391664835,970473636235.5986,dense
|
|
||||||
12,256,1,2,7.758845406626506e-05,3072,37748736,243.57933441822672,154975.11761480253,486525172518.07056,dense
|
|
||||||
3,256,256,2048,0.00507974918466899,768,206158430208,475.59810852303485,433471930.42508715,40584371927298.98,qk_init
|
|
||||||
3,256,256,1024,0.0025616677649325623,768,103079215104,471.5519977009198,218595649.27424532,40239103803811.82,qk_init
|
|
||||||
3,256,256,512,0.0013029336670480549,768,51539607552,463.55374128015677,111183672.92143403,39556585922573.38,qk_init
|
|
||||||
3,256,256,256,0.0006738189029345373,768,25769803776,448.1766342333362,57499213.050413854,38244406121244.69,qk_init
|
|
||||||
3,256,256,128,0.000358254672959467,768,12884901888,421.47375986100144,30571065.425874516,35965760841472.125,qk_init
|
|
||||||
3,256,256,64,0.0002007051105022831,768,6442450944,376.1611839930762,17126836.096194826,32099087700742.5,qk_init
|
|
||||||
3,256,256,32,0.00012189697230142565,768,3221225472,309.6773881032524,10401874.969721656,26425803784810.87,qk_init
|
|
||||||
3,256,256,16,8.453561698040722e-05,768,1610612736,223.2711923587723,7213705.982328083,19052475081281.902,qk_init
|
|
||||||
3,256,256,8,6.407660705009276e-05,768,805306368,147.2797083750448,5467870.468274581,12567868448003.822,qk_init
|
|
||||||
3,256,256,4,5.036328747284576e-05,768,402653184,93.69110391262903,4297667.197682838,7994974200544.344,qk_init
|
|
||||||
3,256,256,2,4.5488761135057476e-05,768,201326592,51.865470527877875,3881707.616858238,4425853485045.578,qk_init
|
|
||||||
12,256,256,2048,0.020202365999999996,3072,824633720832,478.3437947812648,1723935231.9999998,40818670488001.266,qk_init
|
|
||||||
12,256,256,1024,0.010124155888157895,3072,412316860416,477.2583770318811,863927969.1228071,40726048173387.19,qk_init
|
|
||||||
12,256,256,512,0.005085633937062937,3072,206158430208,475.04777848703077,433974095.9627039,40537410430893.29,qk_init
|
|
||||||
12,256,256,256,0.0025654916853281853,3072,103079215104,470.84913933193053,218921957.14800516,40179126556324.74,qk_init
|
|
||||||
12,256,256,128,0.0013045765704467354,3072,51539607552,462.9699702434292,111323867.34478809,39506770794105.96,qk_init
|
|
||||||
12,256,256,64,0.0006742801519939804,3072,25769803776,447.87005387442576,57538572.970153,38218244597284.33,qk_init
|
|
||||||
12,256,256,32,0.00035831976790671853,3072,12884901888,421.3971919051604,30576620.194706645,35959227042573.69,qk_init
|
|
||||||
12,256,256,16,0.0002005369068918302,3072,6442450944,376.4766953382971,17112482.721436176,32126011335534.68,qk_init
|
|
||||||
12,256,256,8,0.00012179187250509165,3072,3221225472,309.94462293386505,10392906.453767821,26448607823689.82,qk_init
|
|
||||||
12,256,256,4,8.452507263643351e-05,3072,1610612736,223.2990450204527,7212806.198308992,19054851841745.297,qk_init
|
|
||||||
12,256,256,2,6.412381767545489e-05,3072,805306368,147.17127491946468,5471899.108305484,12558615459794.32,qk_init
|
|
||||||
3,256,256,2048,0.0016183739398395718,768,805306368,811597824.0,0.9922480620155039,1265467.7325087283,qk_ar
|
|
||||||
3,256,256,1024,0.0008322699728813558,768,402653184,405798912.0,0.9922480620155039,1230369.9921491416,qk_ar
|
|
||||||
3,256,256,512,0.00043886859397590365,768,201326592,202899456.0,0.9922480620155039,1166636.2255762408,qk_ar
|
|
||||||
3,256,256,256,0.00024185948322147648,768,100663296,101449728.0,0.9922480620155039,1058465.8355760013,qk_ar
|
|
||||||
3,256,256,128,0.00014308985100166944,768,50331648,50724864.0,0.9922480620155039,894542.82818777,qk_ar
|
|
||||||
3,256,256,64,9.382939365815932e-05,768,25165824,25362432.0,0.9922480620155039,682089.028872613,qk_ar
|
|
||||||
3,256,256,32,6.856070612244899e-05,768,12582912,12681216.0,0.9922480620155039,466739.6503012703,qk_ar
|
|
||||||
3,256,256,16,5.452260553129549e-05,768,6291456,6340608.0,0.9922480620155039,293456.26174846216,qk_ar
|
|
||||||
3,256,256,8,4.608557533261417e-05,768,3145728,3170304.0,0.9922480620155039,173590.1080166944,qk_ar
|
|
||||||
3,256,256,4,4.386146957766642e-05,768,1572864,1585152.0,0.9922480620155039,91196.21477609445,qk_ar
|
|
||||||
3,256,256,2,4.330941094420601e-05,768,786432,792576.0,0.9922480620155039,46179.33969539622,qk_ar
|
|
||||||
12,256,256,2048,0.006347041645299144,3072,3221225472,3246391296.0,0.9922480620155039,322670.011392918,qk_ar
|
|
||||||
12,256,256,1024,0.0031943104467592586,3072,1610612736,1623195648.0,0.9922480620155039,320569.96872013,qk_ar
|
|
||||||
12,256,256,512,0.0016183416350267381,3072,805306368,811597824.0,0.9922480620155039,316373.2483416833,qk_ar
|
|
||||||
12,256,256,256,0.0008325934893977947,3072,402653184,405798912.0,0.9922480620155039,307472.9784221131,qk_ar
|
|
||||||
12,256,256,128,0.0004389725746987952,3072,201326592,202899456.0,0.9922480620155039,291589.9702568624,qk_ar
|
|
||||||
12,256,256,64,0.00024191767449664432,3072,100663296,101449728.0,0.9922480620155039,264552.8076159138,qk_ar
|
|
||||||
12,256,256,32,0.0001431546143572621,3072,50331648,50724864.0,0.9922480620155039,223534.53392804778,qk_ar
|
|
||||||
12,256,256,16,9.404283597678917e-05,3072,25165824,25362432.0,0.9922480620155039,170135.23501087292,qk_ar
|
|
||||||
12,256,256,8,6.855550037091989e-05,3072,12582912,12681216.0,0.9922480620155039,116693.773026467,qk_ar
|
|
||||||
12,256,256,4,5.4802094978165945e-05,3072,6291456,6340608.0,0.9922480620155039,72989.91036006316,qk_ar
|
|
||||||
12,256,256,2,4.608510707869206e-05,3072,3145728,3170304.0,0.9922480620155039,43397.96795057727,qk_ar
|
|
||||||
|
Binary file not shown.
|
Before Width: | Height: | Size: 45 KiB |
@@ -1,376 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import time
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Dict, List, Optional, Tuple
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
from transformers import AutoModel
|
|
||||||
from tqdm import tqdm
|
|
||||||
from contextlib import contextmanager
|
|
||||||
import math
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class BenchmarkConfig:
|
|
||||||
model_path: str
|
|
||||||
batch_sizes: List[int]
|
|
||||||
seq_length: int
|
|
||||||
num_runs: int
|
|
||||||
use_fp16: bool = True
|
|
||||||
use_cuda_graphs: bool = False
|
|
||||||
use_flash_attention: bool = False
|
|
||||||
max_batch_size: int = 256 # Maximum batch size before splitting
|
|
||||||
|
|
||||||
|
|
||||||
class CUDAGraphContainer:
|
|
||||||
"""Container for managing CUDA graphs for different batch sizes."""
|
|
||||||
|
|
||||||
def __init__(self, model: nn.Module, seq_length: int, max_batch_size: int):
|
|
||||||
self.model = model
|
|
||||||
self.seq_length = seq_length
|
|
||||||
self.max_batch_size = max_batch_size
|
|
||||||
self.graphs: Dict[int, CUDAGraphWrapper] = {}
|
|
||||||
|
|
||||||
def get_or_create(self, batch_size: int) -> 'CUDAGraphWrapper':
|
|
||||||
# For CUDA graphs, we always use the actual batch size or max_batch_size
|
|
||||||
effective_batch_size = min(batch_size, self.max_batch_size)
|
|
||||||
|
|
||||||
if effective_batch_size not in self.graphs:
|
|
||||||
self.graphs[effective_batch_size] = CUDAGraphWrapper(
|
|
||||||
self.model, effective_batch_size, self.seq_length
|
|
||||||
)
|
|
||||||
return self.graphs[effective_batch_size]
|
|
||||||
|
|
||||||
|
|
||||||
class CUDAGraphWrapper:
|
|
||||||
"""Wrapper for CUDA graph capture and replay."""
|
|
||||||
|
|
||||||
def __init__(self, model: nn.Module, batch_size: int, seq_length: int):
|
|
||||||
self.model = model
|
|
||||||
self.static_input = self._create_random_batch(batch_size, seq_length)
|
|
||||||
self.static_attention_mask = torch.ones_like(self.static_input)
|
|
||||||
|
|
||||||
# Warm up
|
|
||||||
self._warmup()
|
|
||||||
|
|
||||||
# Capture graph
|
|
||||||
self.graph = torch.cuda.CUDAGraph()
|
|
||||||
with torch.cuda.graph(self.graph):
|
|
||||||
self.static_output = self.model(
|
|
||||||
input_ids=self.static_input,
|
|
||||||
attention_mask=self.static_attention_mask
|
|
||||||
)
|
|
||||||
|
|
||||||
def _create_random_batch(self, batch_size: int, seq_length: int) -> torch.Tensor:
|
|
||||||
return torch.randint(
|
|
||||||
0, 1000, (batch_size, seq_length),
|
|
||||||
device="cuda",
|
|
||||||
dtype=torch.long
|
|
||||||
)
|
|
||||||
|
|
||||||
def _warmup(self, num_warmup: int = 3):
|
|
||||||
with torch.no_grad():
|
|
||||||
for _ in range(num_warmup):
|
|
||||||
self.model(
|
|
||||||
input_ids=self.static_input,
|
|
||||||
attention_mask=self.static_attention_mask
|
|
||||||
)
|
|
||||||
|
|
||||||
def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
|
|
||||||
self.static_input.copy_(input_ids)
|
|
||||||
self.static_attention_mask.copy_(attention_mask)
|
|
||||||
self.graph.replay()
|
|
||||||
return self.static_output
|
|
||||||
|
|
||||||
|
|
||||||
class ModelOptimizer:
|
|
||||||
"""Applies various optimizations to the model."""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def optimize(model: nn.Module, config: BenchmarkConfig) -> nn.Module:
|
|
||||||
print("\nApplying model optimizations:")
|
|
||||||
|
|
||||||
# Move to GPU
|
|
||||||
model = model.cuda()
|
|
||||||
print("- Model moved to GPU")
|
|
||||||
|
|
||||||
# FP16
|
|
||||||
if config.use_fp16:
|
|
||||||
model = model.half()
|
|
||||||
print("- Using FP16 precision")
|
|
||||||
|
|
||||||
# Check if using SDPA
|
|
||||||
if torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6:
|
|
||||||
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
|
|
||||||
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
|
|
||||||
# No need to do anything as it's automatically enabled
|
|
||||||
else:
|
|
||||||
print("- PyTorch SDPA not available")
|
|
||||||
|
|
||||||
# Flash Attention
|
|
||||||
if config.use_flash_attention:
|
|
||||||
try:
|
|
||||||
from flash_attn.flash_attention import FlashAttention
|
|
||||||
print("- Flash Attention 2 available")
|
|
||||||
if hasattr(model.config, "attention_mode"):
|
|
||||||
model.config.attention_mode = "flash_attention_2"
|
|
||||||
print(" - Enabled Flash Attention 2 mode")
|
|
||||||
except ImportError:
|
|
||||||
print("- Flash Attention not available")
|
|
||||||
|
|
||||||
# Optimize LayerNorm
|
|
||||||
try:
|
|
||||||
num_layernorms = 0
|
|
||||||
for module in model.modules():
|
|
||||||
if isinstance(module, torch.nn.LayerNorm):
|
|
||||||
module.forward = torch.jit.script(module.forward)
|
|
||||||
num_layernorms += 1
|
|
||||||
if num_layernorms > 0:
|
|
||||||
print(f"- Optimized {num_layernorms} LayerNorm modules with TorchScript")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"- LayerNorm optimization failed: {e}")
|
|
||||||
|
|
||||||
# Memory efficient attention
|
|
||||||
try:
|
|
||||||
from xformers.ops import memory_efficient_attention
|
|
||||||
model.enable_xformers_memory_efficient_attention()
|
|
||||||
print("- Enabled xformers memory efficient attention")
|
|
||||||
except (ImportError, AttributeError):
|
|
||||||
print("- Xformers not available")
|
|
||||||
|
|
||||||
model.eval()
|
|
||||||
print("- Model set to eval mode")
|
|
||||||
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
class Timer:
|
|
||||||
"""Handles accurate GPU timing using CUDA events."""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.start_event = torch.cuda.Event(enable_timing=True)
|
|
||||||
self.end_event = torch.cuda.Event(enable_timing=True)
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def timing(self):
|
|
||||||
self.start_event.record()
|
|
||||||
yield
|
|
||||||
self.end_event.record()
|
|
||||||
self.end_event.synchronize()
|
|
||||||
|
|
||||||
def elapsed_time(self) -> float:
|
|
||||||
return self.start_event.elapsed_time(self.end_event) / 1000 # ms to seconds
|
|
||||||
|
|
||||||
|
|
||||||
class Benchmark:
|
|
||||||
"""Main benchmark runner."""
|
|
||||||
|
|
||||||
def __init__(self, config: BenchmarkConfig):
|
|
||||||
self.config = config
|
|
||||||
self.model = self._load_model()
|
|
||||||
self.cuda_graphs = (
|
|
||||||
CUDAGraphContainer(self.model, config.seq_length, config.max_batch_size)
|
|
||||||
if config.use_cuda_graphs
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
self.timer = Timer()
|
|
||||||
|
|
||||||
def _load_model(self) -> nn.Module:
|
|
||||||
print(f"Loading model from {self.config.model_path}...")
|
|
||||||
model = AutoModel.from_pretrained(self.config.model_path)
|
|
||||||
return ModelOptimizer.optimize(model, self.config)
|
|
||||||
|
|
||||||
def _create_random_batch(self, batch_size: int) -> torch.Tensor:
|
|
||||||
return torch.randint(
|
|
||||||
0, 1000,
|
|
||||||
(batch_size, self.config.seq_length),
|
|
||||||
device="cuda",
|
|
||||||
dtype=torch.long
|
|
||||||
)
|
|
||||||
|
|
||||||
def _run_inference(
|
|
||||||
self,
|
|
||||||
input_ids: torch.Tensor,
|
|
||||||
cuda_graph_wrapper: Optional[CUDAGraphWrapper] = None
|
|
||||||
) -> Tuple[float, torch.Tensor]:
|
|
||||||
attention_mask = torch.ones_like(input_ids)
|
|
||||||
original_batch_size = input_ids.shape[0]
|
|
||||||
print(f"Original input_ids shape: {input_ids.shape}")
|
|
||||||
|
|
||||||
# Split large batches to avoid OOM
|
|
||||||
max_batch_size = self.config.max_batch_size
|
|
||||||
if original_batch_size > max_batch_size:
|
|
||||||
print(f"Splitting batch of size {original_batch_size} into chunks of {max_batch_size}")
|
|
||||||
total_time = 0
|
|
||||||
outputs = []
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
for i in range(0, original_batch_size, max_batch_size):
|
|
||||||
end_idx = min(i + max_batch_size, original_batch_size)
|
|
||||||
batch_slice = input_ids[i:end_idx]
|
|
||||||
mask_slice = attention_mask[i:end_idx]
|
|
||||||
|
|
||||||
print(f"Processing chunk {i//max_batch_size + 1}: shape {batch_slice.shape}")
|
|
||||||
|
|
||||||
# Use CUDA graph if available (with the smaller batch size)
|
|
||||||
chunk_cuda_graph = None
|
|
||||||
if cuda_graph_wrapper is not None:
|
|
||||||
chunk_cuda_graph = self.cuda_graphs.get_or_create(batch_slice.shape[0])
|
|
||||||
|
|
||||||
with self.timer.timing():
|
|
||||||
if chunk_cuda_graph is not None:
|
|
||||||
chunk_output = chunk_cuda_graph(batch_slice, mask_slice)
|
|
||||||
else:
|
|
||||||
chunk_output = self.model(input_ids=batch_slice, attention_mask=mask_slice)
|
|
||||||
|
|
||||||
total_time += self.timer.elapsed_time()
|
|
||||||
outputs.append(chunk_output.last_hidden_state)
|
|
||||||
|
|
||||||
# Combine outputs
|
|
||||||
combined_output = torch.cat(outputs, dim=0)
|
|
||||||
print(f"Combined output shape: {combined_output.shape}")
|
|
||||||
|
|
||||||
# Create a wrapper object similar to model output to maintain consistency
|
|
||||||
class DummyOutput:
|
|
||||||
def __init__(self, hidden_states):
|
|
||||||
self.last_hidden_state = hidden_states
|
|
||||||
|
|
||||||
output = DummyOutput(combined_output)
|
|
||||||
return total_time, output
|
|
||||||
else:
|
|
||||||
# Process normally for small batches
|
|
||||||
with torch.no_grad(), self.timer.timing():
|
|
||||||
if cuda_graph_wrapper is not None:
|
|
||||||
output = cuda_graph_wrapper(input_ids, attention_mask)
|
|
||||||
else:
|
|
||||||
output = self.model(input_ids=input_ids, attention_mask=attention_mask)
|
|
||||||
|
|
||||||
print(f"Output shape: {output.last_hidden_state.shape}")
|
|
||||||
return self.timer.elapsed_time(), output
|
|
||||||
|
|
||||||
def run(self) -> Dict[int, Dict[str, float]]:
|
|
||||||
results = {}
|
|
||||||
|
|
||||||
for batch_size in self.config.batch_sizes:
|
|
||||||
print(f"\nTesting batch size: {batch_size}")
|
|
||||||
times = []
|
|
||||||
|
|
||||||
# Get or create CUDA graph for this batch size
|
|
||||||
cuda_graph_wrapper = None
|
|
||||||
if self.cuda_graphs is not None:
|
|
||||||
if batch_size <= self.config.max_batch_size:
|
|
||||||
cuda_graph_wrapper = self.cuda_graphs.get_or_create(batch_size)
|
|
||||||
else:
|
|
||||||
# For large batches, we'll use the max_batch_size graph in chunks
|
|
||||||
cuda_graph_wrapper = True # Just a flag to indicate we want to use CUDA graphs
|
|
||||||
|
|
||||||
# Pre-allocate input tensor
|
|
||||||
input_ids = self._create_random_batch(batch_size)
|
|
||||||
|
|
||||||
# Run benchmark
|
|
||||||
for run_idx in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
|
|
||||||
elapsed_time, _ = self._run_inference(input_ids, cuda_graph_wrapper)
|
|
||||||
times.append(elapsed_time)
|
|
||||||
print(f"Run {run_idx+1}: {elapsed_time:.4f}s")
|
|
||||||
|
|
||||||
# Calculate statistics
|
|
||||||
avg_time = np.mean(times)
|
|
||||||
std_time = np.std(times)
|
|
||||||
throughput = batch_size / avg_time
|
|
||||||
|
|
||||||
results[batch_size] = {
|
|
||||||
"avg_time": avg_time,
|
|
||||||
"std_time": std_time,
|
|
||||||
"throughput": throughput,
|
|
||||||
}
|
|
||||||
|
|
||||||
print(f"Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
|
|
||||||
print(f"Throughput: {throughput:.2f} sequences/second")
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(description="Model Inference Benchmark")
|
|
||||||
parser.add_argument(
|
|
||||||
"--model_path",
|
|
||||||
type=str,
|
|
||||||
default="facebook/contriever",
|
|
||||||
help="Path to the model",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--batch_sizes",
|
|
||||||
type=str,
|
|
||||||
default="1,2,4,8,16,32,64,128,256,512,1024,2048,4096",
|
|
||||||
help="Comma-separated list of batch sizes",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--seq_length",
|
|
||||||
type=int,
|
|
||||||
default=256,
|
|
||||||
help="Sequence length for input",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--num_runs",
|
|
||||||
type=int,
|
|
||||||
default=5,
|
|
||||||
help="Number of runs for each batch size",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--no_fp16",
|
|
||||||
action="store_true",
|
|
||||||
help="Disable FP16 inference",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--use_cuda_graphs",
|
|
||||||
action="store_true",
|
|
||||||
help="Enable CUDA Graphs optimization",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--use_flash_attention",
|
|
||||||
action="store_true",
|
|
||||||
help="Enable Flash Attention 2 if available",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--max_batch_size",
|
|
||||||
type=int,
|
|
||||||
default=256,
|
|
||||||
help="Maximum batch size before splitting to prevent OOM",
|
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
config = BenchmarkConfig(
|
|
||||||
model_path=args.model_path,
|
|
||||||
batch_sizes=[int(bs) for bs in args.batch_sizes.split(",")],
|
|
||||||
seq_length=args.seq_length,
|
|
||||||
num_runs=args.num_runs,
|
|
||||||
use_fp16=not args.no_fp16,
|
|
||||||
use_cuda_graphs=args.use_cuda_graphs,
|
|
||||||
use_flash_attention=args.use_flash_attention,
|
|
||||||
max_batch_size=args.max_batch_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
benchmark = Benchmark(config)
|
|
||||||
results = benchmark.run()
|
|
||||||
|
|
||||||
# Print overall summary
|
|
||||||
print("\n===== BENCHMARK SUMMARY =====")
|
|
||||||
print(f"Model: {config.model_path}")
|
|
||||||
print(f"Sequence Length: {config.seq_length}")
|
|
||||||
print(f"FP16: {config.use_fp16}")
|
|
||||||
print(f"CUDA Graphs: {config.use_cuda_graphs}")
|
|
||||||
print(f"Flash Attention: {config.use_flash_attention}")
|
|
||||||
print(f"Max Batch Size: {config.max_batch_size}")
|
|
||||||
print("\nResults:")
|
|
||||||
|
|
||||||
print("\nBatch Size | Avg Time (s) | Throughput (seq/s)")
|
|
||||||
print("-" * 50)
|
|
||||||
for bs in sorted(results.keys()):
|
|
||||||
r = results[bs]
|
|
||||||
print(f"{bs:^10} | {r['avg_time']:^12.4f} | {r['throughput']:^17.2f}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,218 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import time
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
# Import necessary functions from the quantize.py file
|
|
||||||
def get_group_qparams(w, n_bit=4, groupsize=128):
|
|
||||||
# needed for GPTQ with padding
|
|
||||||
if groupsize > w.shape[-1]:
|
|
||||||
groupsize = w.shape[-1]
|
|
||||||
assert groupsize > 1
|
|
||||||
assert w.shape[-1] % groupsize == 0
|
|
||||||
assert w.dim() == 2
|
|
||||||
|
|
||||||
to_quant = w.reshape(-1, groupsize)
|
|
||||||
assert torch.isnan(to_quant).sum() == 0
|
|
||||||
|
|
||||||
max_val = to_quant.amax(dim=1, keepdim=True)
|
|
||||||
min_val = to_quant.amin(dim=1, keepdim=True)
|
|
||||||
max_int = 2**n_bit - 1
|
|
||||||
scales = (max_val - min_val).clamp(min=1e-6) / max_int
|
|
||||||
zeros = min_val + scales * (2 ** (n_bit - 1))
|
|
||||||
return scales.to(torch.bfloat16).reshape(w.shape[0], -1), zeros.to(
|
|
||||||
torch.bfloat16
|
|
||||||
).reshape(w.shape[0], -1)
|
|
||||||
|
|
||||||
def pack_scales_and_zeros(scales, zeros):
|
|
||||||
assert scales.shape == zeros.shape
|
|
||||||
assert scales.dtype == torch.bfloat16
|
|
||||||
assert zeros.dtype == torch.bfloat16
|
|
||||||
return (
|
|
||||||
torch.cat(
|
|
||||||
[
|
|
||||||
scales.reshape(scales.size(0), scales.size(1), 1),
|
|
||||||
zeros.reshape(zeros.size(0), zeros.size(1), 1),
|
|
||||||
],
|
|
||||||
2,
|
|
||||||
)
|
|
||||||
.transpose(0, 1)
|
|
||||||
.contiguous()
|
|
||||||
)
|
|
||||||
|
|
||||||
def group_quantize_tensor(w, n_bit=4, groupsize=128):
|
|
||||||
scales, zeros = get_group_qparams(w, n_bit, groupsize)
|
|
||||||
w_int32 = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize)
|
|
||||||
scales_and_zeros = pack_scales_and_zeros(scales, zeros)
|
|
||||||
return w_int32, scales_and_zeros
|
|
||||||
|
|
||||||
def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128):
|
|
||||||
assert groupsize > 1
|
|
||||||
# needed for GPTQ single column quantize
|
|
||||||
if groupsize > w.shape[-1] and scales.shape[-1] == 1:
|
|
||||||
groupsize = w.shape[-1]
|
|
||||||
|
|
||||||
assert w.shape[-1] % groupsize == 0
|
|
||||||
assert w.dim() == 2
|
|
||||||
|
|
||||||
to_quant = w.reshape(-1, groupsize)
|
|
||||||
assert torch.isnan(to_quant).sum() == 0
|
|
||||||
|
|
||||||
scales = scales.reshape(-1, 1)
|
|
||||||
zeros = zeros.reshape(-1, 1)
|
|
||||||
min_val = zeros - scales * (2 ** (n_bit - 1))
|
|
||||||
max_int = 2**n_bit - 1
|
|
||||||
min_int = 0
|
|
||||||
w_int32 = (
|
|
||||||
to_quant.sub(min_val)
|
|
||||||
.div(scales)
|
|
||||||
.round()
|
|
||||||
.clamp_(min_int, max_int)
|
|
||||||
.to(torch.int32)
|
|
||||||
.reshape_as(w)
|
|
||||||
)
|
|
||||||
|
|
||||||
return w_int32
|
|
||||||
|
|
||||||
def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
|
|
||||||
weight_int32, scales_and_zeros = group_quantize_tensor(
|
|
||||||
weight_bf16, n_bit=4, groupsize=groupsize
|
|
||||||
)
|
|
||||||
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles)
|
|
||||||
return weight_int4pack, scales_and_zeros
|
|
||||||
|
|
||||||
def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
|
|
||||||
origin_x_size = x.size()
|
|
||||||
x = x.reshape(-1, origin_x_size[-1])
|
|
||||||
c = torch.ops.aten._weight_int4pack_mm(x, weight_int4pack, groupsize, scales_and_zeros)
|
|
||||||
new_shape = origin_x_size[:-1] + (out_features,)
|
|
||||||
c = c.reshape(new_shape)
|
|
||||||
return c
|
|
||||||
|
|
||||||
class WeightOnlyInt4Linear(torch.nn.Module):
|
|
||||||
__constants__ = ['in_features', 'out_features']
|
|
||||||
in_features: int
|
|
||||||
out_features: int
|
|
||||||
weight: torch.Tensor
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, in_features: int, out_features: int,
|
|
||||||
bias=False, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.in_features = in_features
|
|
||||||
self.out_features = out_features
|
|
||||||
self.groupsize = groupsize
|
|
||||||
self.inner_k_tiles = inner_k_tiles
|
|
||||||
|
|
||||||
assert out_features % 8 == 0, "require out_features % 8 == 0"
|
|
||||||
assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0"
|
|
||||||
self.register_buffer(
|
|
||||||
"weight",
|
|
||||||
torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
|
|
||||||
)
|
|
||||||
self.register_buffer(
|
|
||||||
"scales_and_zeros",
|
|
||||||
torch.empty((in_features // groupsize, out_features, 2), dtype=torch.bfloat16)
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
|
||||||
input = input.to(torch.bfloat16)
|
|
||||||
return linear_forward_int4(
|
|
||||||
input,
|
|
||||||
self.weight, self.scales_and_zeros, self.out_features, self.groupsize
|
|
||||||
)
|
|
||||||
|
|
||||||
# Define dimensions that satisfy the requirements for INT4 quantization
|
|
||||||
# in_features must be divisible by inner_k_tiles * 16
|
|
||||||
# out_features must be divisible by 8
|
|
||||||
in_features = 1024 # Must be divisible by inner_k_tiles * 16
|
|
||||||
out_features = 2048 # Must be divisible by 8
|
|
||||||
groupsize = 128
|
|
||||||
inner_k_tiles = 8
|
|
||||||
|
|
||||||
# Create models
|
|
||||||
fp16_model = nn.Sequential(
|
|
||||||
nn.Linear(in_features, out_features, bias=False)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create INT4 model
|
|
||||||
int4_model = nn.Sequential(
|
|
||||||
WeightOnlyInt4Linear(in_features, out_features, bias=False,
|
|
||||||
groupsize=groupsize, inner_k_tiles=inner_k_tiles)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Quantize the weights and set up the INT4 model
|
|
||||||
with torch.no_grad():
|
|
||||||
# Convert FP16 weights to INT4
|
|
||||||
fp16_weight = fp16_model[0].weight.data.to(torch.bfloat16)
|
|
||||||
weight_int4pack, scales_and_zeros = prepare_int4_weight_and_scales_and_zeros(
|
|
||||||
fp16_weight, groupsize, inner_k_tiles
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set the quantized weights in the INT4 model
|
|
||||||
int4_model[0].weight.copy_(weight_int4pack)
|
|
||||||
int4_model[0].scales_and_zeros.copy_(scales_and_zeros)
|
|
||||||
|
|
||||||
# Move models to GPU
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
||||||
fp16_model = fp16_model.to(device)
|
|
||||||
int4_model = int4_model.to(device)
|
|
||||||
|
|
||||||
# Create random input tensor
|
|
||||||
batch_size = 1024
|
|
||||||
input_tensor = torch.randn(batch_size, in_features, device=device)
|
|
||||||
input_tensor_bf16 = input_tensor.to(torch.bfloat16)
|
|
||||||
|
|
||||||
# Speed test function
|
|
||||||
def speed_test(model, input_tensor, name, num_iterations=100):
|
|
||||||
# Warmup
|
|
||||||
for _ in range(10):
|
|
||||||
_ = model(input_tensor)
|
|
||||||
|
|
||||||
# Actual timing
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
start_time = time.time()
|
|
||||||
|
|
||||||
for _ in range(num_iterations):
|
|
||||||
_ = model(input_tensor)
|
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
end_time = time.time()
|
|
||||||
|
|
||||||
avg_time = (end_time - start_time) / num_iterations
|
|
||||||
print(f"{name} model: {avg_time:.6f} seconds per iteration")
|
|
||||||
return avg_time
|
|
||||||
|
|
||||||
# Run speed tests
|
|
||||||
with torch.no_grad(): # Disable gradient calculation for inference
|
|
||||||
print(f"Running benchmark with batch_size={batch_size}, in_features={in_features}, out_features={out_features}")
|
|
||||||
print(f"INT4 parameters: groupsize={groupsize}, inner_k_tiles={inner_k_tiles}")
|
|
||||||
|
|
||||||
fp16_time = speed_test(fp16_model, input_tensor_bf16, "FP16")
|
|
||||||
int4_time = speed_test(int4_model, input_tensor, "INT4")
|
|
||||||
|
|
||||||
# Calculate speedup
|
|
||||||
speedup = fp16_time / int4_time
|
|
||||||
print(f"INT4 is {speedup:.2f}x faster than FP16")
|
|
||||||
|
|
||||||
# Calculate memory savings
|
|
||||||
fp16_memory = fp16_model[0].weight.nelement() * fp16_model[0].weight.element_size()
|
|
||||||
int4_memory = (int4_model[0].weight.nelement() * int4_model[0].weight.element_size() +
|
|
||||||
int4_model[0].scales_and_zeros.nelement() * int4_model[0].scales_and_zeros.element_size())
|
|
||||||
|
|
||||||
memory_reduction = fp16_memory / int4_memory
|
|
||||||
print(f"Memory reduction: {memory_reduction:.2f}x ({fp16_memory/1024/1024:.2f} MB vs {int4_memory/1024/1024:.2f} MB)")
|
|
||||||
|
|
||||||
# Check accuracy
|
|
||||||
with torch.no_grad():
|
|
||||||
fp16_output = fp16_model(input_tensor_bf16)
|
|
||||||
int4_output = int4_model(input_tensor)
|
|
||||||
|
|
||||||
# Calculate error metrics
|
|
||||||
abs_error = torch.abs(fp16_output - int4_output)
|
|
||||||
rel_error = abs_error / (torch.abs(fp16_output) + 1e-7)
|
|
||||||
|
|
||||||
print(f"Mean absolute error: {abs_error.mean().item():.6f}")
|
|
||||||
print(f"Max absolute error: {abs_error.max().item():.6f}")
|
|
||||||
print(f"Mean relative error: {rel_error.mean().item():.6f}")
|
|
||||||
@@ -1,83 +0,0 @@
|
|||||||
import torch
|
|
||||||
import nvmath.bindings.cublas
|
|
||||||
import ctypes
|
|
||||||
|
|
||||||
# 创建 CUBLAS 句柄
|
|
||||||
handle = nvmath.bindings.cublas.create()
|
|
||||||
|
|
||||||
# 准备数据 - 使用 uint8 类型,并确保内存连续
|
|
||||||
m, n, k = 64, 32, 48
|
|
||||||
a = (torch.rand(m, k, device="cuda") * 255).to(torch.uint8).contiguous()
|
|
||||||
b = (torch.rand(k, n, device="cuda") * 255).to(torch.uint8).contiguous()
|
|
||||||
c = torch.zeros(m, n, device="cuda", dtype=torch.uint8).contiguous()
|
|
||||||
|
|
||||||
# 确保张量在 CUDA 上
|
|
||||||
assert a.is_cuda and b.is_cuda and c.is_cuda
|
|
||||||
# 确保张量是连续的
|
|
||||||
assert a.is_contiguous() and b.is_contiguous() and c.is_contiguous()
|
|
||||||
|
|
||||||
# 获取指针
|
|
||||||
a_ptr = a.data_ptr()
|
|
||||||
b_ptr = b.data_ptr()
|
|
||||||
c_ptr = c.data_ptr()
|
|
||||||
|
|
||||||
# 设置参数
|
|
||||||
transa = 0 # CUBLAS_OP_N (不转置)
|
|
||||||
transb = 0 # CUBLAS_OP_N (不转置)
|
|
||||||
transc = 0 # CUBLAS_OP_N (不转置)
|
|
||||||
|
|
||||||
# 设置偏置值
|
|
||||||
a_bias = 0
|
|
||||||
b_bias = 0
|
|
||||||
c_bias = 0
|
|
||||||
|
|
||||||
# 设置正确的 leading dimensions
|
|
||||||
lda = k # A 的 leading dimension
|
|
||||||
ldb = n # B 的 leading dimension
|
|
||||||
ldc = n # C 的 leading dimension
|
|
||||||
|
|
||||||
c_mult = 1
|
|
||||||
c_shift = 0
|
|
||||||
|
|
||||||
# 打印调试信息
|
|
||||||
print(f"a shape: {a.shape}, a_ptr: {a_ptr}")
|
|
||||||
print(f"b shape: {b.shape}, b_ptr: {b_ptr}")
|
|
||||||
print(f"c shape: {c.shape}, c_ptr: {c_ptr}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 调用 uint8gemm_bias
|
|
||||||
nvmath.bindings.cublas.uint8gemm_bias(
|
|
||||||
handle,
|
|
||||||
transa, transb, transc,
|
|
||||||
m, n, k,
|
|
||||||
a_ptr, a_bias, lda,
|
|
||||||
b_ptr, b_bias, ldb,
|
|
||||||
c_ptr, c_bias, ldc,
|
|
||||||
c_mult, c_shift
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error: {e}")
|
|
||||||
# 尝试使用 ctypes 转换指针
|
|
||||||
a_ptr_c = ctypes.c_void_p(a_ptr).value
|
|
||||||
b_ptr_c = ctypes.c_void_p(b_ptr).value
|
|
||||||
c_ptr_c = ctypes.c_void_p(c_ptr).value
|
|
||||||
|
|
||||||
print(f"Using ctypes: a_ptr: {a_ptr_c}, b_ptr: {b_ptr_c}, c_ptr: {c_ptr_c}")
|
|
||||||
|
|
||||||
# 再次尝试调用
|
|
||||||
nvmath.bindings.cublas.uint8gemm_bias(
|
|
||||||
handle,
|
|
||||||
transa, transb, transc,
|
|
||||||
m, n, k,
|
|
||||||
a_ptr_c, a_bias, lda,
|
|
||||||
b_ptr_c, b_bias, ldb,
|
|
||||||
c_ptr_c, c_bias, ldc,
|
|
||||||
c_mult, c_shift
|
|
||||||
)
|
|
||||||
|
|
||||||
# 销毁 CUBLAS 句柄
|
|
||||||
nvmath.bindings.cublas.destroy(handle)
|
|
||||||
|
|
||||||
# 打印结果
|
|
||||||
print("Result:")
|
|
||||||
print(c)
|
|
||||||
@@ -1,23 +0,0 @@
|
|||||||
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
|
|
||||||
from llmcompressor.modifiers.quantization import GPTQModifier
|
|
||||||
from llmcompressor import oneshot
|
|
||||||
|
|
||||||
# Select quantization algorithm. In this case, we:
|
|
||||||
# * apply SmoothQuant to make the activations easier to quantize
|
|
||||||
# * quantize the weights to int8 with GPTQ (static per channel)
|
|
||||||
# * quantize the activations to int8 (dynamic per token)
|
|
||||||
recipe = [
|
|
||||||
SmoothQuantModifier(smoothing_strength=0.8),
|
|
||||||
GPTQModifier(scheme="W8A8", targets="Linear", ignore=["lm_head"]),
|
|
||||||
]
|
|
||||||
|
|
||||||
# Apply quantization using the built in open_platypus dataset.
|
|
||||||
# * See examples for demos showing how to pass a custom calibration set
|
|
||||||
oneshot(
|
|
||||||
model="facebook/contriever",
|
|
||||||
dataset="open_platypus",
|
|
||||||
recipe=recipe,
|
|
||||||
output_dir="contriever-INT4",
|
|
||||||
max_seq_length=2048,
|
|
||||||
num_calibration_samples=512,
|
|
||||||
)
|
|
||||||
@@ -1,41 +0,0 @@
|
|||||||
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
|
|
||||||
#
|
|
||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
|
|
||||||
"""
|
|
||||||
This example demonstrates basic matrix multiplication of FP8 tensors.
|
|
||||||
|
|
||||||
In narrow-precision operations, quantization scales must be provided for each tensor. These
|
|
||||||
scales are used to dequantize input operands and quantize the result. Without proper
|
|
||||||
scaling, the results of FP8 operations will likely exceed the type's range.
|
|
||||||
|
|
||||||
FP8 is only supported with cuBLAS 12.8 or newer and on devices with compute
|
|
||||||
capability 8.9 or higher.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
import nvmath
|
|
||||||
|
|
||||||
# Prepare sample input data. Note that N, M and K must be divisible by 16 for FP8.
|
|
||||||
# cuBLAS requires B to be column-major, so we first create a row-major tensor and then
|
|
||||||
# transpose it.
|
|
||||||
m, n, k = 64, 32, 48
|
|
||||||
a = (torch.rand(m, k, device="cuda") * 10).type(torch.float8_e4m3fn)
|
|
||||||
b = (torch.rand(n, k, device="cuda") * 10).type(torch.float8_e4m3fn).T
|
|
||||||
|
|
||||||
# Prepare quantization scales. The scales must allow the result to fit within the dynamic
|
|
||||||
# range of the data type used. Scales can be provided either as a dictionary or as a
|
|
||||||
# MatmulQuantizationScales object. Note that scales are only allowed for FP8 operands.
|
|
||||||
scales = {"a": 1, "b": 1, "d": 0.1}
|
|
||||||
|
|
||||||
# Perform the multiplication. The result of the multiplication will be:
|
|
||||||
# (scales.a * A) @ (scales.b * B) * scales.d
|
|
||||||
result = nvmath.linalg.advanced.matmul(a, b, quantization_scales=scales)
|
|
||||||
|
|
||||||
# Check how scaling helped to fit into the dynamic range of float8_e4m3fn type.
|
|
||||||
result_without_scaling = nvmath.linalg.advanced.matmul(a, b, quantization_scales={"a": 1, "b": 1, "d": 1})
|
|
||||||
print("Without scaling, most of the elements were clamped to the maximum value of float8_e4m3fn type (448):")
|
|
||||||
print(result_without_scaling)
|
|
||||||
print(f"\nWith D scale set to {scales['d']}, they were scaled down to fit into the dynamic range of float8_e4m3fn:")
|
|
||||||
print(result)
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user