Compare commits
266 Commits
perf-build
...
dynamic-ad
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d02aee6901 | ||
|
|
43894ff605 | ||
|
|
10311cc611 | ||
|
|
62a5d7b31d | ||
|
|
ad0d2faabc | ||
|
|
e93c0dec6f | ||
|
|
c5a29f849a | ||
|
|
0a69118f87 | ||
|
|
880a039e1d | ||
|
|
4a39b40e72 | ||
|
|
ed5fd88a85 | ||
|
|
8f4f2b4873 | ||
|
|
6a06bd893a | ||
|
|
3b8dc6368e | ||
|
|
e309f292de | ||
|
|
0d9f92ea0f | ||
|
|
b0b353d279 | ||
|
|
4dffdfedbe | ||
|
|
d41e467df9 | ||
|
|
4ca0489cb1 | ||
|
|
e83a671918 | ||
|
|
4e5b73ce7b | ||
|
|
31b4973141 | ||
|
|
dde2221513 | ||
|
|
6d11e86e71 | ||
|
|
13bb561aad | ||
|
|
0174ba5571 | ||
|
|
03af82d695 | ||
|
|
738f1dbab8 | ||
|
|
37d990d51c | ||
|
|
a6f07a54f1 | ||
|
|
46905e0687 | ||
|
|
838ade231e | ||
|
|
da6540decd | ||
|
|
39e18a7c11 | ||
|
|
6bde28584b | ||
|
|
f62632c41f | ||
|
|
27708243ca | ||
|
|
9a1e4652ca | ||
|
|
14e84d9e2d | ||
|
|
2dcfca19ff | ||
|
|
bee2167ee3 | ||
|
|
ef980d70b3 | ||
|
|
db3c63c441 | ||
|
|
00eeadb9dd | ||
|
|
42c8370709 | ||
|
|
fafdf8fcbe | ||
|
|
21f7d8e031 | ||
|
|
46565b9249 | ||
|
|
3dad76126a | ||
|
|
18e28bda32 | ||
|
|
609fa62fd5 | ||
|
|
eab13434ef | ||
|
|
b2390ccc14 | ||
|
|
e8fca2c84a | ||
|
|
790ae14f69 | ||
|
|
ac363072e6 | ||
|
|
93465af46c | ||
|
|
792ece67dc | ||
|
|
239e35e2e6 | ||
|
|
2fac0c6fbf | ||
|
|
9801aa581b | ||
|
|
5e97916608 | ||
|
|
8b9c2be8c9 | ||
|
|
3ff5aac8e0 | ||
|
|
67fef60466 | ||
|
|
b6ab6f1993 | ||
|
|
9f2e82a838 | ||
|
|
0b2b799d5a | ||
|
|
0f790fbbd9 | ||
|
|
387ae21eba | ||
|
|
3cc329c3e7 | ||
|
|
5567302316 | ||
|
|
075d4bd167 | ||
|
|
e4bcc76f88 | ||
|
|
710e83b1fd | ||
|
|
c96d653072 | ||
|
|
8b22d2b5d3 | ||
|
|
4cb544ee38 | ||
|
|
f94ce63d51 | ||
|
|
4271ff9d84 | ||
|
|
0d448c4a41 | ||
|
|
af5599e33c | ||
|
|
efdf6d917a | ||
|
|
dd71ac8d71 | ||
|
|
8bee1d4100 | ||
|
|
33521d6d00 | ||
|
|
8899734952 | ||
|
|
54df6310c5 | ||
|
|
19bcc07814 | ||
|
|
8356e3c668 | ||
|
|
08eac5c821 | ||
|
|
4671ed9b36 | ||
|
|
055c086398 | ||
|
|
d505dcc5e3 | ||
|
|
261006c36a | ||
|
|
b2eba23e21 | ||
|
|
e9ee687472 | ||
|
|
6f5d5e4a77 | ||
|
|
5c8921673a | ||
|
|
e9d2d420bd | ||
|
|
ebabfad066 | ||
|
|
e6f612b5e8 | ||
|
|
51c41acd82 | ||
|
|
455f93fb7c | ||
|
|
48207c3b69 | ||
|
|
4de1caa40f | ||
|
|
60eaa8165c | ||
|
|
c1a5d0c624 | ||
|
|
af1790395a | ||
|
|
383c6d8d7e | ||
|
|
bc0d839693 | ||
|
|
8596562de5 | ||
|
|
5d09586853 | ||
|
|
a7cba078dd | ||
|
|
b3e9ee96fa | ||
|
|
8537a6b17e | ||
|
|
7c8d7dc5c2 | ||
|
|
8e23d663e6 | ||
|
|
8a3994bf80 | ||
|
|
8375f601ba | ||
|
|
c87c0fe662 | ||
|
|
73927b68ef | ||
|
|
cc1a62e5aa | ||
|
|
802020cb41 | ||
|
|
cdb92f7cf4 | ||
|
|
dc69bdec00 | ||
|
|
98073e9868 | ||
|
|
cf2ef48967 | ||
|
|
0692bbf7a2 | ||
|
|
52584a171f | ||
|
|
efd6b5324b | ||
|
|
2baaa4549b | ||
|
|
35310ddd52 | ||
|
|
fc9c5cb39d | ||
|
|
8f2a1e87ea | ||
|
|
50caf65f28 | ||
|
|
1b48794ca8 | ||
|
|
4aef1d814e | ||
|
|
75ddcd6158 | ||
|
|
2a4df11f5c | ||
|
|
5eb893c62b | ||
|
|
d91ce2e94d | ||
|
|
5c2ff8a641 | ||
|
|
d4f474c9b7 | ||
|
|
170f7644e9 | ||
|
|
cd8b970eff | ||
|
|
52153bbb69 | ||
|
|
e1ae087207 | ||
|
|
48c5e12ac1 | ||
|
|
f8b5c97190 | ||
|
|
d038c81b8b | ||
|
|
29cbbbd0d6 | ||
|
|
179f30bc36 | ||
|
|
c4a0a68581 | ||
|
|
5c836ad08e | ||
|
|
673fd9b7cd | ||
|
|
84b24b233d | ||
|
|
499cdd7822 | ||
|
|
800d4cf111 | ||
|
|
b6d43f5fd9 | ||
|
|
3603cd5034 | ||
|
|
6df7893173 | ||
|
|
e64b599276 | ||
|
|
2dd59c4ba1 | ||
|
|
166986d5e6 | ||
|
|
a6aec68f32 | ||
|
|
ed27a127d5 | ||
|
|
d8b4ea7564 | ||
|
|
f0a2ef96b4 | ||
|
|
7d73c2c803 | ||
|
|
e8d2ecab03 | ||
|
|
32a374d094 | ||
|
|
d45c013806 | ||
|
|
9000a7083d | ||
|
|
8307555d54 | ||
|
|
20f2aece08 | ||
|
|
43eb4f9a1d | ||
|
|
5461b71d8c | ||
|
|
374db0ebb8 | ||
|
|
cea1f6f87c | ||
|
|
6c0e39372b | ||
|
|
2bec67d2b6 | ||
|
|
133e715832 | ||
|
|
95cf2f16e2 | ||
|
|
47a4c153eb | ||
|
|
faf5ae3533 | ||
|
|
a44dccecac | ||
|
|
9cf9358b9c | ||
|
|
de252fef31 | ||
|
|
9076bc27b8 | ||
|
|
50686c0819 | ||
|
|
1614203786 | ||
|
|
3d4c75a56c | ||
|
|
2684ee71dc | ||
|
|
1d321953ba | ||
|
|
b3cb251369 | ||
|
|
0a17d2c9d8 | ||
|
|
e3defbca84 | ||
|
|
e407f63977 | ||
|
|
7add391b2c | ||
|
|
efd6373b32 | ||
|
|
d502fa24b0 | ||
|
|
258a9a5c7f | ||
|
|
5d41ac6115 | ||
|
|
2a0fdb49b8 | ||
|
|
9d1b7231b6 | ||
|
|
ed3095b478 | ||
|
|
88eca75917 | ||
|
|
42de27e16a | ||
|
|
c083bda5b7 | ||
|
|
e86da38726 | ||
|
|
99076e38bc | ||
|
|
9698c1a02c | ||
|
|
851f0f04c3 | ||
|
|
ae16d9d888 | ||
|
|
6e1af2eb0c | ||
|
|
7695dd0d50 | ||
|
|
c2065473ad | ||
|
|
5f3870564d | ||
|
|
c214b2e33e | ||
|
|
2420c5fd35 | ||
|
|
f48f526f0a | ||
|
|
5dd74982ba | ||
|
|
e07aaf52a7 | ||
|
|
30e5f12616 | ||
|
|
594427bf87 | ||
|
|
a97d3ada1c | ||
|
|
6217bb5638 | ||
|
|
2760e99e18 | ||
|
|
0544f96b79 | ||
|
|
2ebb29de65 | ||
|
|
43762d44c7 | ||
|
|
cdaf0c98be | ||
|
|
aa9a14a917 | ||
|
|
9efcc6d95c | ||
|
|
f3f5d91207 | ||
|
|
6070160959 | ||
|
|
43155d2811 | ||
|
|
d3f85678ec | ||
|
|
2a96d05b21 | ||
|
|
851e888535 | ||
|
|
90120d4dff | ||
|
|
8513471573 | ||
|
|
71e5f1774c | ||
|
|
870a443446 | ||
|
|
cefaa2a4cc | ||
|
|
ab72a2ab9d | ||
|
|
046d457d22 | ||
|
|
7fd0a30fee | ||
|
|
c2f35c8e73 | ||
|
|
573313f0b6 | ||
|
|
f7af6805fa | ||
|
|
966de3a399 | ||
|
|
8a75829f3a | ||
|
|
0f7e34b9e2 | ||
|
|
be0322b616 | ||
|
|
232a525a62 | ||
|
|
587ce65cf6 | ||
|
|
ccf6c8bfd7 | ||
|
|
c112956d2d | ||
|
|
b3970793cf | ||
|
|
727724990e | ||
|
|
530f6e4af5 | ||
|
|
2f224f5793 | ||
|
|
1b6272ce0e |
1
.gitattributes
vendored
1
.gitattributes
vendored
@@ -1 +0,0 @@
|
||||
paper_plot/data/big_graph_degree_data.npz filter=lfs diff=lfs merge=lfs -text
|
||||
50
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
Normal file
50
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
Normal file
@@ -0,0 +1,50 @@
|
||||
name: Bug Report
|
||||
description: Report a bug in LEANN
|
||||
labels: ["bug"]
|
||||
|
||||
body:
|
||||
- type: textarea
|
||||
id: description
|
||||
attributes:
|
||||
label: What happened?
|
||||
description: A clear description of the bug
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: reproduce
|
||||
attributes:
|
||||
label: How to reproduce
|
||||
placeholder: |
|
||||
1. Install with...
|
||||
2. Run command...
|
||||
3. See error
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: error
|
||||
attributes:
|
||||
label: Error message
|
||||
description: Paste any error messages
|
||||
render: shell
|
||||
|
||||
- type: input
|
||||
id: version
|
||||
attributes:
|
||||
label: LEANN Version
|
||||
placeholder: "0.1.0"
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: dropdown
|
||||
id: os
|
||||
attributes:
|
||||
label: Operating System
|
||||
options:
|
||||
- macOS
|
||||
- Linux
|
||||
- Windows
|
||||
- Docker
|
||||
validations:
|
||||
required: true
|
||||
8
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
8
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
@@ -0,0 +1,8 @@
|
||||
blank_issues_enabled: true
|
||||
contact_links:
|
||||
- name: Documentation
|
||||
url: https://github.com/LEANN-RAG/LEANN-RAG/tree/main/docs
|
||||
about: Read the docs first
|
||||
- name: Discussions
|
||||
url: https://github.com/LEANN-RAG/LEANN-RAG/discussions
|
||||
about: Ask questions and share ideas
|
||||
27
.github/ISSUE_TEMPLATE/feature_request.yml
vendored
Normal file
27
.github/ISSUE_TEMPLATE/feature_request.yml
vendored
Normal file
@@ -0,0 +1,27 @@
|
||||
name: Feature Request
|
||||
description: Suggest a new feature for LEANN
|
||||
labels: ["enhancement"]
|
||||
|
||||
body:
|
||||
- type: textarea
|
||||
id: problem
|
||||
attributes:
|
||||
label: What problem does this solve?
|
||||
description: Describe the problem or need
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: solution
|
||||
attributes:
|
||||
label: Proposed solution
|
||||
description: How would you like this to work?
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: example
|
||||
attributes:
|
||||
label: Example usage
|
||||
description: Show how the API might look
|
||||
render: python
|
||||
13
.github/pull_request_template.md
vendored
Normal file
13
.github/pull_request_template.md
vendored
Normal file
@@ -0,0 +1,13 @@
|
||||
## What does this PR do?
|
||||
|
||||
<!-- Brief description of your changes -->
|
||||
|
||||
## Related Issues
|
||||
|
||||
Fixes #
|
||||
|
||||
## Checklist
|
||||
|
||||
- [ ] Tests pass (`uv run pytest`)
|
||||
- [ ] Code formatted (`ruff format` and `ruff check`)
|
||||
- [ ] Pre-commit hooks pass (`pre-commit run --all-files`)
|
||||
12
.github/workflows/build-and-publish.yml
vendored
Normal file
12
.github/workflows/build-and-publish.yml
vendored
Normal file
@@ -0,0 +1,12 @@
|
||||
name: CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
build:
|
||||
uses: ./.github/workflows/build-reusable.yml
|
||||
402
.github/workflows/build-reusable.yml
vendored
Normal file
402
.github/workflows/build-reusable.yml
vendored
Normal file
@@ -0,0 +1,402 @@
|
||||
name: Reusable Build
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
ref:
|
||||
description: 'Git ref to build'
|
||||
required: false
|
||||
type: string
|
||||
default: ''
|
||||
|
||||
jobs:
|
||||
lint:
|
||||
name: Lint and Format Check
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ inputs.ref }}
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v4
|
||||
|
||||
- name: Install ruff
|
||||
run: |
|
||||
uv tool install ruff
|
||||
|
||||
- name: Run ruff check
|
||||
run: |
|
||||
ruff check .
|
||||
|
||||
- name: Run ruff format check
|
||||
run: |
|
||||
ruff format --check .
|
||||
|
||||
build:
|
||||
needs: lint
|
||||
name: Build ${{ matrix.os }} Python ${{ matrix.python }}
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- os: ubuntu-22.04
|
||||
python: '3.9'
|
||||
- os: ubuntu-22.04
|
||||
python: '3.10'
|
||||
- os: ubuntu-22.04
|
||||
python: '3.11'
|
||||
- os: ubuntu-22.04
|
||||
python: '3.12'
|
||||
- os: ubuntu-22.04
|
||||
python: '3.13'
|
||||
# ARM64 Linux builds
|
||||
- os: ubuntu-24.04-arm
|
||||
python: '3.9'
|
||||
- os: ubuntu-24.04-arm
|
||||
python: '3.10'
|
||||
- os: ubuntu-24.04-arm
|
||||
python: '3.11'
|
||||
- os: ubuntu-24.04-arm
|
||||
python: '3.12'
|
||||
- os: ubuntu-24.04-arm
|
||||
python: '3.13'
|
||||
- os: macos-14
|
||||
python: '3.9'
|
||||
- os: macos-14
|
||||
python: '3.10'
|
||||
- os: macos-14
|
||||
python: '3.11'
|
||||
- os: macos-14
|
||||
python: '3.12'
|
||||
- os: macos-14
|
||||
python: '3.13'
|
||||
- os: macos-15
|
||||
python: '3.9'
|
||||
- os: macos-15
|
||||
python: '3.10'
|
||||
- os: macos-15
|
||||
python: '3.11'
|
||||
- os: macos-15
|
||||
python: '3.12'
|
||||
- os: macos-15
|
||||
python: '3.13'
|
||||
- os: macos-13
|
||||
python: '3.9'
|
||||
- os: macos-13
|
||||
python: '3.10'
|
||||
- os: macos-13
|
||||
python: '3.11'
|
||||
- os: macos-13
|
||||
python: '3.12'
|
||||
# Note: macos-13 + Python 3.13 excluded due to PyTorch compatibility
|
||||
# (PyTorch 2.5+ supports Python 3.13 but not Intel Mac x86_64)
|
||||
runs-on: ${{ matrix.os }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
with:
|
||||
ref: ${{ inputs.ref }}
|
||||
submodules: recursive
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python }}
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
|
||||
- name: Install system dependencies (Ubuntu)
|
||||
if: runner.os == 'Linux'
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y libomp-dev libboost-all-dev protobuf-compiler libzmq3-dev \
|
||||
pkg-config libabsl-dev libaio-dev libprotobuf-dev \
|
||||
patchelf
|
||||
|
||||
# Debug: Show system information
|
||||
echo "🔍 System Information:"
|
||||
echo "Architecture: $(uname -m)"
|
||||
echo "OS: $(uname -a)"
|
||||
echo "CPU info: $(lscpu | head -5)"
|
||||
|
||||
# Install math library based on architecture
|
||||
ARCH=$(uname -m)
|
||||
echo "🔍 Setting up math library for architecture: $ARCH"
|
||||
|
||||
if [[ "$ARCH" == "x86_64" ]]; then
|
||||
# Install Intel MKL for DiskANN on x86_64
|
||||
echo "📦 Installing Intel MKL for x86_64..."
|
||||
wget -q https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh
|
||||
sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
|
||||
source /opt/intel/oneapi/setvars.sh
|
||||
echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV
|
||||
echo "LD_LIBRARY_PATH=/opt/intel/oneapi/compiler/latest/linux/compiler/lib/intel64_lin" >> $GITHUB_ENV
|
||||
echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/intel/oneapi/mkl/latest/lib/intel64" >> $GITHUB_ENV
|
||||
echo "✅ Intel MKL installed for x86_64"
|
||||
|
||||
# Debug: Check MKL installation
|
||||
echo "🔍 MKL Installation Check:"
|
||||
ls -la /opt/intel/oneapi/mkl/latest/ || echo "MKL directory not found"
|
||||
ls -la /opt/intel/oneapi/mkl/latest/lib/ || echo "MKL lib directory not found"
|
||||
|
||||
elif [[ "$ARCH" == "aarch64" ]]; then
|
||||
# Use OpenBLAS for ARM64 (MKL installer not compatible with ARM64)
|
||||
echo "📦 Installing OpenBLAS for ARM64..."
|
||||
sudo apt-get install -y libopenblas-dev liblapack-dev liblapacke-dev
|
||||
echo "✅ OpenBLAS installed for ARM64"
|
||||
|
||||
# Debug: Check OpenBLAS installation
|
||||
echo "🔍 OpenBLAS Installation Check:"
|
||||
dpkg -l | grep openblas || echo "OpenBLAS package not found"
|
||||
ls -la /usr/lib/aarch64-linux-gnu/openblas/ || echo "OpenBLAS directory not found"
|
||||
fi
|
||||
|
||||
# Debug: Show final library paths
|
||||
echo "🔍 Final LD_LIBRARY_PATH: $LD_LIBRARY_PATH"
|
||||
|
||||
- name: Install system dependencies (macOS)
|
||||
if: runner.os == 'macOS'
|
||||
run: |
|
||||
# Don't install LLVM, use system clang for better compatibility
|
||||
brew install libomp boost protobuf zeromq
|
||||
|
||||
- name: Install build dependencies
|
||||
run: |
|
||||
uv pip install --system scikit-build-core numpy swig Cython pybind11
|
||||
if [[ "$RUNNER_OS" == "Linux" ]]; then
|
||||
uv pip install --system auditwheel
|
||||
else
|
||||
uv pip install --system delocate
|
||||
fi
|
||||
|
||||
- name: Set macOS environment variables
|
||||
if: runner.os == 'macOS'
|
||||
run: |
|
||||
# Use brew --prefix to automatically detect Homebrew installation path
|
||||
HOMEBREW_PREFIX=$(brew --prefix)
|
||||
echo "HOMEBREW_PREFIX=${HOMEBREW_PREFIX}" >> $GITHUB_ENV
|
||||
echo "OpenMP_ROOT=${HOMEBREW_PREFIX}/opt/libomp" >> $GITHUB_ENV
|
||||
|
||||
# Set CMAKE_PREFIX_PATH to let CMake find all packages automatically
|
||||
echo "CMAKE_PREFIX_PATH=${HOMEBREW_PREFIX}" >> $GITHUB_ENV
|
||||
|
||||
# Set compiler flags for OpenMP (required for both backends)
|
||||
echo "LDFLAGS=-L${HOMEBREW_PREFIX}/opt/libomp/lib" >> $GITHUB_ENV
|
||||
echo "CPPFLAGS=-I${HOMEBREW_PREFIX}/opt/libomp/include" >> $GITHUB_ENV
|
||||
|
||||
- name: Build packages
|
||||
run: |
|
||||
# Build core (platform independent)
|
||||
cd packages/leann-core
|
||||
uv build
|
||||
cd ../..
|
||||
|
||||
# Build HNSW backend
|
||||
cd packages/leann-backend-hnsw
|
||||
if [[ "${{ matrix.os }}" == macos-* ]]; then
|
||||
# Use system clang for better compatibility
|
||||
export CC=clang
|
||||
export CXX=clang++
|
||||
# Homebrew libraries on each macOS version require matching minimum version
|
||||
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=13.0
|
||||
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=14.0
|
||||
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=15.0
|
||||
fi
|
||||
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
|
||||
else
|
||||
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
|
||||
fi
|
||||
cd ../..
|
||||
|
||||
# Build DiskANN backend
|
||||
cd packages/leann-backend-diskann
|
||||
if [[ "${{ matrix.os }}" == macos-* ]]; then
|
||||
# Use system clang for better compatibility
|
||||
export CC=clang
|
||||
export CXX=clang++
|
||||
# DiskANN requires macOS 13.3+ for sgesdd_ LAPACK function
|
||||
# But Homebrew libraries on each macOS version require matching minimum version
|
||||
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=13.3
|
||||
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=14.0
|
||||
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=15.0
|
||||
fi
|
||||
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
|
||||
else
|
||||
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
|
||||
fi
|
||||
cd ../..
|
||||
|
||||
# Build meta package (platform independent)
|
||||
cd packages/leann
|
||||
uv build
|
||||
cd ../..
|
||||
|
||||
- name: Repair wheels (Linux)
|
||||
if: runner.os == 'Linux'
|
||||
run: |
|
||||
# Repair HNSW wheel
|
||||
cd packages/leann-backend-hnsw
|
||||
if [ -d dist ]; then
|
||||
auditwheel repair dist/*.whl -w dist_repaired
|
||||
rm -rf dist
|
||||
mv dist_repaired dist
|
||||
fi
|
||||
cd ../..
|
||||
|
||||
# Repair DiskANN wheel
|
||||
cd packages/leann-backend-diskann
|
||||
if [ -d dist ]; then
|
||||
auditwheel repair dist/*.whl -w dist_repaired
|
||||
rm -rf dist
|
||||
mv dist_repaired dist
|
||||
fi
|
||||
cd ../..
|
||||
|
||||
- name: Repair wheels (macOS)
|
||||
if: runner.os == 'macOS'
|
||||
run: |
|
||||
# Determine deployment target based on runner OS
|
||||
# Must match the Homebrew libraries for each macOS version
|
||||
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
|
||||
HNSW_TARGET="13.0"
|
||||
DISKANN_TARGET="13.3"
|
||||
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
|
||||
HNSW_TARGET="14.0"
|
||||
DISKANN_TARGET="14.0"
|
||||
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
|
||||
HNSW_TARGET="15.0"
|
||||
DISKANN_TARGET="15.0"
|
||||
fi
|
||||
|
||||
# Repair HNSW wheel
|
||||
cd packages/leann-backend-hnsw
|
||||
if [ -d dist ]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=$HNSW_TARGET
|
||||
delocate-wheel -w dist_repaired -v --require-target-macos-version $HNSW_TARGET dist/*.whl
|
||||
rm -rf dist
|
||||
mv dist_repaired dist
|
||||
fi
|
||||
cd ../..
|
||||
|
||||
# Repair DiskANN wheel
|
||||
cd packages/leann-backend-diskann
|
||||
if [ -d dist ]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=$DISKANN_TARGET
|
||||
delocate-wheel -w dist_repaired -v --require-target-macos-version $DISKANN_TARGET dist/*.whl
|
||||
rm -rf dist
|
||||
mv dist_repaired dist
|
||||
fi
|
||||
cd ../..
|
||||
|
||||
- name: List built packages
|
||||
run: |
|
||||
echo "📦 Built packages:"
|
||||
find packages/*/dist -name "*.whl" -o -name "*.tar.gz" | sort
|
||||
|
||||
|
||||
- name: Install built packages for testing
|
||||
run: |
|
||||
# Create a virtual environment with the correct Python version
|
||||
uv venv --python ${{ matrix.python }}
|
||||
source .venv/bin/activate || source .venv/Scripts/activate
|
||||
|
||||
# Install packages using --find-links to prioritize local builds
|
||||
uv pip install --find-links packages/leann-core/dist --find-links packages/leann-backend-hnsw/dist --find-links packages/leann-backend-diskann/dist packages/leann-core/dist/*.whl || uv pip install --find-links packages/leann-core/dist packages/leann-core/dist/*.tar.gz
|
||||
uv pip install --find-links packages/leann-core/dist packages/leann-backend-hnsw/dist/*.whl
|
||||
uv pip install --find-links packages/leann-core/dist packages/leann-backend-diskann/dist/*.whl
|
||||
uv pip install packages/leann/dist/*.whl || uv pip install packages/leann/dist/*.tar.gz
|
||||
|
||||
# Install test dependencies using extras
|
||||
uv pip install -e ".[test]"
|
||||
|
||||
- name: Run tests with pytest
|
||||
env:
|
||||
CI: true
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
HF_HUB_DISABLE_SYMLINKS: 1
|
||||
TOKENIZERS_PARALLELISM: false
|
||||
PYTORCH_ENABLE_MPS_FALLBACK: 0
|
||||
OMP_NUM_THREADS: 1
|
||||
MKL_NUM_THREADS: 1
|
||||
run: |
|
||||
source .venv/bin/activate || source .venv/Scripts/activate
|
||||
pytest tests/ -v --tb=short
|
||||
|
||||
- name: Run sanity checks (optional)
|
||||
run: |
|
||||
# Activate virtual environment
|
||||
source .venv/bin/activate || source .venv/Scripts/activate
|
||||
|
||||
# Run distance function tests if available
|
||||
if [ -f test/sanity_checks/test_distance_functions.py ]; then
|
||||
echo "Running distance function sanity checks..."
|
||||
python test/sanity_checks/test_distance_functions.py || echo "⚠️ Distance function test failed, continuing..."
|
||||
fi
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: packages-${{ matrix.os }}-py${{ matrix.python }}
|
||||
path: packages/*/dist/
|
||||
|
||||
|
||||
arch-smoke:
|
||||
name: Arch Linux smoke test (install & import)
|
||||
needs: build
|
||||
runs-on: ubuntu-latest
|
||||
container:
|
||||
image: archlinux:latest
|
||||
|
||||
steps:
|
||||
- name: Prepare system
|
||||
run: |
|
||||
pacman -Syu --noconfirm
|
||||
pacman -S --noconfirm python python-pip gcc git zlib openssl
|
||||
|
||||
- name: Download ALL wheel artifacts from this run
|
||||
uses: actions/download-artifact@v5
|
||||
with:
|
||||
# Don't specify name, download all artifacts
|
||||
path: ./wheels
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
|
||||
- name: Create virtual environment and install wheels
|
||||
run: |
|
||||
uv venv
|
||||
source .venv/bin/activate || source .venv/Scripts/activate
|
||||
uv pip install --find-links wheels leann-core
|
||||
uv pip install --find-links wheels leann-backend-hnsw
|
||||
uv pip install --find-links wheels leann-backend-diskann
|
||||
uv pip install --find-links wheels leann
|
||||
|
||||
- name: Import & tiny runtime check
|
||||
env:
|
||||
OMP_NUM_THREADS: 1
|
||||
MKL_NUM_THREADS: 1
|
||||
run: |
|
||||
source .venv/bin/activate || source .venv/Scripts/activate
|
||||
python - <<'PY'
|
||||
import leann
|
||||
import leann_backend_hnsw as h
|
||||
import leann_backend_diskann as d
|
||||
from leann import LeannBuilder, LeannSearcher
|
||||
b = LeannBuilder(backend_name="hnsw")
|
||||
b.add_text("hello arch")
|
||||
b.build_index("arch_demo.leann")
|
||||
s = LeannSearcher("arch_demo.leann")
|
||||
print("search:", s.search("hello", top_k=1))
|
||||
PY
|
||||
19
.github/workflows/link-check.yml
vendored
Normal file
19
.github/workflows/link-check.yml
vendored
Normal file
@@ -0,0 +1,19 @@
|
||||
name: Link Check
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main, master ]
|
||||
pull_request:
|
||||
schedule:
|
||||
- cron: "0 3 * * 1"
|
||||
|
||||
jobs:
|
||||
link-check:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: lycheeverse/lychee-action@v2
|
||||
with:
|
||||
args: --no-progress --insecure --user-agent 'curl/7.68.0' README.md docs/ apps/ examples/ benchmarks/
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
129
.github/workflows/release-manual.yml
vendored
Normal file
129
.github/workflows/release-manual.yml
vendored
Normal file
@@ -0,0 +1,129 @@
|
||||
name: Release
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
version:
|
||||
description: 'Version to release (e.g., 0.1.2)'
|
||||
required: true
|
||||
type: string
|
||||
|
||||
jobs:
|
||||
update-version:
|
||||
name: Update Version
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: write
|
||||
outputs:
|
||||
commit-sha: ${{ steps.push.outputs.commit-sha }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Validate version
|
||||
run: |
|
||||
# Remove 'v' prefix if present for validation
|
||||
VERSION_CLEAN="${{ inputs.version }}"
|
||||
VERSION_CLEAN="${VERSION_CLEAN#v}"
|
||||
if ! [[ "$VERSION_CLEAN" =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
|
||||
echo "❌ Invalid version format. Expected format: X.Y.Z or vX.Y.Z"
|
||||
exit 1
|
||||
fi
|
||||
echo "✅ Version format valid: ${{ inputs.version }}"
|
||||
|
||||
- name: Update versions and push
|
||||
id: push
|
||||
run: |
|
||||
# Check current version
|
||||
CURRENT_VERSION=$(grep "^version" packages/leann-core/pyproject.toml | cut -d'"' -f2)
|
||||
echo "Current version: $CURRENT_VERSION"
|
||||
echo "Target version: ${{ inputs.version }}"
|
||||
|
||||
if [ "$CURRENT_VERSION" = "${{ inputs.version }}" ]; then
|
||||
echo "⚠️ Version is already ${{ inputs.version }}, skipping update"
|
||||
COMMIT_SHA=$(git rev-parse HEAD)
|
||||
else
|
||||
./scripts/bump_version.sh ${{ inputs.version }}
|
||||
git config user.name "GitHub Actions"
|
||||
git config user.email "actions@github.com"
|
||||
git add packages/*/pyproject.toml
|
||||
git commit -m "chore: release v${{ inputs.version }}"
|
||||
git push origin main
|
||||
COMMIT_SHA=$(git rev-parse HEAD)
|
||||
echo "✅ Pushed version update: $COMMIT_SHA"
|
||||
fi
|
||||
|
||||
echo "commit-sha=$COMMIT_SHA" >> $GITHUB_OUTPUT
|
||||
|
||||
build-packages:
|
||||
name: Build packages
|
||||
needs: update-version
|
||||
uses: ./.github/workflows/build-reusable.yml
|
||||
with:
|
||||
ref: 'main'
|
||||
|
||||
publish:
|
||||
name: Publish and Release
|
||||
needs: [update-version, build-packages]
|
||||
if: always() && needs.update-version.result == 'success' && needs.build-packages.result == 'success'
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
ref: 'main'
|
||||
|
||||
- name: Download all artifacts
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
path: dist-artifacts
|
||||
|
||||
- name: Collect packages
|
||||
run: |
|
||||
mkdir -p dist
|
||||
find dist-artifacts -name "*.whl" -exec cp {} dist/ \;
|
||||
find dist-artifacts -name "*.tar.gz" -exec cp {} dist/ \;
|
||||
|
||||
echo "📦 Packages to publish:"
|
||||
ls -la dist/
|
||||
|
||||
- name: Publish to PyPI
|
||||
env:
|
||||
TWINE_USERNAME: __token__
|
||||
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
|
||||
run: |
|
||||
if [ -z "$TWINE_PASSWORD" ]; then
|
||||
echo "❌ PYPI_API_TOKEN not configured!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
pip install twine
|
||||
twine upload dist/* --skip-existing --verbose
|
||||
|
||||
echo "✅ Published to PyPI!"
|
||||
|
||||
- name: Create release
|
||||
run: |
|
||||
# Check if tag already exists
|
||||
if git rev-parse "v${{ inputs.version }}" >/dev/null 2>&1; then
|
||||
echo "⚠️ Tag v${{ inputs.version }} already exists, skipping tag creation"
|
||||
else
|
||||
git tag "v${{ inputs.version }}"
|
||||
git push origin "v${{ inputs.version }}"
|
||||
echo "✅ Created and pushed tag v${{ inputs.version }}"
|
||||
fi
|
||||
|
||||
# Check if release already exists
|
||||
if gh release view "v${{ inputs.version }}" >/dev/null 2>&1; then
|
||||
echo "⚠️ Release v${{ inputs.version }} already exists, skipping release creation"
|
||||
else
|
||||
gh release create "v${{ inputs.version }}" \
|
||||
--title "Release v${{ inputs.version }}" \
|
||||
--notes "🚀 Released to PyPI: https://pypi.org/project/leann/${{ inputs.version }}/" \
|
||||
--latest
|
||||
echo "✅ Created GitHub release v${{ inputs.version }}"
|
||||
fi
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
33
.gitignore
vendored
33
.gitignore
vendored
@@ -9,19 +9,20 @@ demo/indices/
|
||||
outputs/
|
||||
*.pkl
|
||||
*.pdf
|
||||
*.idx
|
||||
*.idx
|
||||
*.map
|
||||
.history/
|
||||
scripts/
|
||||
lm_eval.egg-info/
|
||||
demo/experiment_results/**/*.json
|
||||
*.jsonl
|
||||
*.eml
|
||||
*.emlx
|
||||
*.json
|
||||
!.vscode/*.json
|
||||
*.sh
|
||||
*.txt
|
||||
!CMakeLists.txt
|
||||
!llms.txt
|
||||
latency_breakdown*.json
|
||||
experiment_results/eval_results/diskann/*.json
|
||||
aws/
|
||||
@@ -35,11 +36,15 @@ build/
|
||||
nprobe_logs/
|
||||
micro/results
|
||||
micro/contriever-INT8
|
||||
examples/data/*
|
||||
!examples/data/2501.14312v1 (1).pdf
|
||||
!examples/data/2506.08276v1.pdf
|
||||
!examples/data/PrideandPrejudice.txt
|
||||
!examples/data/README.md
|
||||
data/*
|
||||
!data/2501.14312v1 (1).pdf
|
||||
!data/2506.08276v1.pdf
|
||||
!data/PrideandPrejudice.txt
|
||||
!data/huawei_pangu.md
|
||||
!data/ground_truth/
|
||||
!data/indices/
|
||||
!data/queries/
|
||||
!data/.gitattributes
|
||||
*.qdstrm
|
||||
benchmark_results/
|
||||
results/
|
||||
@@ -84,4 +89,16 @@ test_*.py
|
||||
packages/leann-backend-diskann/third_party/DiskANN/_deps/
|
||||
|
||||
*.meta.json
|
||||
*.passages.json
|
||||
*.passages.json
|
||||
|
||||
batchtest.py
|
||||
tests/__pytest_cache__/
|
||||
tests/__pycache__/
|
||||
paru-bin/
|
||||
|
||||
CLAUDE.md
|
||||
CLAUDE.local.md
|
||||
.claude/*.local.*
|
||||
.claude/local/*
|
||||
benchmarks/data/
|
||||
test_add/*
|
||||
|
||||
3
.gitmodules
vendored
3
.gitmodules
vendored
@@ -14,3 +14,6 @@
|
||||
[submodule "packages/leann-backend-hnsw/third_party/libzmq"]
|
||||
path = packages/leann-backend-hnsw/third_party/libzmq
|
||||
url = https://github.com/zeromq/libzmq.git
|
||||
[submodule "packages/astchunk-leann"]
|
||||
path = packages/astchunk-leann
|
||||
url = https://github.com/yichuan-w/astchunk-leann.git
|
||||
|
||||
17
.pre-commit-config.yaml
Normal file
17
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,17 @@
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v5.0.0
|
||||
hooks:
|
||||
- id: trailing-whitespace
|
||||
- id: end-of-file-fixer
|
||||
- id: check-yaml
|
||||
- id: check-added-large-files
|
||||
- id: check-merge-conflict
|
||||
- id: debug-statements
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.12.7 # Fixed version to match pyproject.toml
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix, --exit-non-zero-on-fix]
|
||||
- id: ruff-format
|
||||
8
.vscode/extensions.json
vendored
8
.vscode/extensions.json
vendored
@@ -1,9 +1,5 @@
|
||||
{
|
||||
"recommendations": [
|
||||
"llvm-vs-code-extensions.vscode-clangd",
|
||||
"ms-python.python",
|
||||
"ms-vscode.cmake-tools",
|
||||
"vadimcn.vscode-lldb",
|
||||
"eamodio.gitlens",
|
||||
"charliermarsh.ruff",
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
283
.vscode/launch.json
vendored
283
.vscode/launch.json
vendored
@@ -1,283 +0,0 @@
|
||||
{
|
||||
// Use IntelliSense to learn about possible attributes.
|
||||
// Hover to view descriptions of existing attributes.
|
||||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
// new emdedder
|
||||
{
|
||||
"name": "New Embedder",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "demo/main.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": [
|
||||
"--search",
|
||||
"--use-original",
|
||||
"--domain",
|
||||
"dpr",
|
||||
"--nprobe",
|
||||
"5000",
|
||||
"--load",
|
||||
"flat",
|
||||
"--embedder",
|
||||
"intfloat/multilingual-e5-small"
|
||||
]
|
||||
}
|
||||
//python /home/ubuntu/Power-RAG/faiss/demo/simple_build.py
|
||||
{
|
||||
"name": "main.py",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "demo/main.py",
|
||||
"console": "integratedTerminal",
|
||||
"cwd": "${workspaceFolder}",
|
||||
"args": [
|
||||
"--query",
|
||||
"1000",
|
||||
"--load",
|
||||
"bm25"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Simple Build",
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"program": "${workspaceFolder}/.venv/bin/python",
|
||||
"console": "integratedTerminal",
|
||||
"cwd": "${workspaceFolder}",
|
||||
"args": [
|
||||
"faiss/demo/simple_build.py"
|
||||
],
|
||||
"env": {
|
||||
"LD_PRELOAD": "/lib/x86_64-linux-gnu/libmkl_core.so:/lib/x86_64-linux-gnu/libmkl_intel_thread.so:/lib/x86_64-linux-gnu/libmkl_intel_lp64.so:/lib/x86_64-linux-gnu/libiomp5.so"
|
||||
}
|
||||
},
|
||||
//# Fix for Intel MKL error
|
||||
//export LD_PRELOAD=/lib/x86_64-linux-gnu/libmkl_core.so:/lib/x86_64-linux-gnu/libmkl_intel_thread.so:/lib/x86_64-linux-gnu/libmkl_intel_lp64.so:/lib/x86_64-linux-gnu/libiomp5.so
|
||||
//python faiss/demo/build_demo.py
|
||||
{
|
||||
"name": "Build Demo",
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"program": "${workspaceFolder}/.venv/bin/python",
|
||||
"console": "integratedTerminal",
|
||||
"cwd": "${workspaceFolder}",
|
||||
"args": [
|
||||
"faiss/demo/build_demo.py"
|
||||
],
|
||||
"env": {
|
||||
"LD_PRELOAD": "/lib/x86_64-linux-gnu/libmkl_core.so:/lib/x86_64-linux-gnu/libmkl_intel_thread.so:/lib/x86_64-linux-gnu/libmkl_intel_lp64.so:/lib/x86_64-linux-gnu/libiomp5.so"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "DiskANN Serve",
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"program": "${workspaceFolder}/.venv/bin/python",
|
||||
"console": "integratedTerminal",
|
||||
"cwd": "${workspaceFolder}",
|
||||
"args": [
|
||||
"demo/main.py",
|
||||
"--mode",
|
||||
"serve",
|
||||
"--engine",
|
||||
"sglang",
|
||||
"--load-indices",
|
||||
"diskann",
|
||||
"--domain",
|
||||
"rpj_wiki",
|
||||
"--lazy-load",
|
||||
"--recompute-beighbor-embeddings",
|
||||
"--port",
|
||||
"8082",
|
||||
"--diskann-search-memory-maximum",
|
||||
"2",
|
||||
"--diskann-graph",
|
||||
"240",
|
||||
"--search-only"
|
||||
],
|
||||
"env": {
|
||||
"PYTHONPATH": "${workspaceFolder}/faiss_repo/build/faiss/python:$PYTHONPATH"
|
||||
},
|
||||
"preLaunchTask": "CMake: build",
|
||||
},
|
||||
{
|
||||
"name": "DiskANN Serve MAC",
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"program": "${workspaceFolder}/.venv/bin/python",
|
||||
"console": "integratedTerminal",
|
||||
"cwd": "${workspaceFolder}",
|
||||
"args": [
|
||||
"demo/main.py",
|
||||
"--mode",
|
||||
"serve",
|
||||
"--engine",
|
||||
"ollama",
|
||||
"--load-indices",
|
||||
"diskann",
|
||||
"--domain",
|
||||
"rpj_wiki",
|
||||
"--lazy-load",
|
||||
"--recompute-beighbor-embeddings"
|
||||
],
|
||||
"preLaunchTask": "CMake: build",
|
||||
"env": {
|
||||
"KMP_DUPLICATE_LIB_OK": "TRUE",
|
||||
"OMP_NUM_THREADS": "1",
|
||||
"MKL_NUM_THREADS": "1",
|
||||
"DYLD_INSERT_LIBRARIES": "/Users/ec2-user/Power-RAG/.venv/lib/python3.10/site-packages/torch/lib/libomp.dylib",
|
||||
"KMP_BLOCKTIME": "0"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Python Debugger: Current File with Arguments",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "ric/main_ric.py",
|
||||
"console": "integratedTerminal",
|
||||
"cwd": "${workspaceFolder}",
|
||||
"args": [
|
||||
"--config-name",
|
||||
"${input:configSelection}"
|
||||
],
|
||||
"justMyCode": false
|
||||
},
|
||||
//python ./demo/validate_equivalence.py sglang
|
||||
{
|
||||
"name": "Validate Equivalence",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "demo/validate_equivalence.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": [
|
||||
"sglang"
|
||||
],
|
||||
},
|
||||
//python demo/retrieval_demo.py --engine sglang --skip-embeddings --domain dpr --load-indices flat ivf_flat
|
||||
{
|
||||
"name": "Retrieval Demo",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "demo/retrieval_demo.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": [
|
||||
"--engine",
|
||||
"vllm",
|
||||
"--skip-embeddings",
|
||||
"--domain",
|
||||
"dpr",
|
||||
"--load-indices",
|
||||
// "flat",
|
||||
"ivf_flat"
|
||||
],
|
||||
},
|
||||
//python demo/retrieval_demo.py --engine sglang --skip-embeddings --domain dpr --load-indices diskann --hnsw-M 64 --hnsw-efConstruction 150 --hnsw-efSearch 128 --hnsw-sq-bits 8
|
||||
{
|
||||
"name": "Retrieval Demo DiskANN",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "demo/retrieval_demo.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": [
|
||||
"--engine",
|
||||
"sglang",
|
||||
"--skip-embeddings",
|
||||
"--domain",
|
||||
"dpr",
|
||||
"--load-indices",
|
||||
"diskann",
|
||||
"--hnsw-M",
|
||||
"64",
|
||||
"--hnsw-efConstruction",
|
||||
"150",
|
||||
"--hnsw-efSearch",
|
||||
"128",
|
||||
"--hnsw-sq-bits",
|
||||
"8"
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "Find Probe",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "find_probe.py",
|
||||
"console": "integratedTerminal",
|
||||
"cwd": "${workspaceFolder}",
|
||||
},
|
||||
{
|
||||
"name": "Python: Attach",
|
||||
"type": "debugpy",
|
||||
"request": "attach",
|
||||
"processId": "${command:pickProcess}",
|
||||
"justMyCode": true
|
||||
},
|
||||
{
|
||||
"name": "Edge RAG",
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"program": "${workspaceFolder}/.venv/bin/python",
|
||||
"console": "integratedTerminal",
|
||||
"cwd": "${workspaceFolder}",
|
||||
"args": [
|
||||
"edgerag_demo.py"
|
||||
],
|
||||
"env": {
|
||||
"LD_PRELOAD": "/lib/x86_64-linux-gnu/libiomp5.so /lib/x86_64-linux-gnu/libmkl_core.so /lib/x86_64-linux-gnu/libmkl_intel_lp64.so /lib/x86_64-linux-gnu/libmkl_intel_thread.so",
|
||||
"MKL_NUM_THREADS": "1",
|
||||
"OMP_NUM_THREADS": "1",
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Launch Embedding Server",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "demo/embedding_server.py",
|
||||
"console": "integratedTerminal",
|
||||
"cwd": "${workspaceFolder}",
|
||||
"args": [
|
||||
"--domain",
|
||||
"rpj_wiki",
|
||||
"--zmq-port",
|
||||
"5556",
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "HNSW Serve",
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"program": "${workspaceFolder}/.venv/bin/python",
|
||||
"console": "integratedTerminal",
|
||||
"cwd": "${workspaceFolder}",
|
||||
"args": [
|
||||
"demo/main.py",
|
||||
"--domain",
|
||||
"rpj_wiki",
|
||||
"--load",
|
||||
"hnsw",
|
||||
"--mode",
|
||||
"serve",
|
||||
"--search",
|
||||
"--skip-pa",
|
||||
"--recompute",
|
||||
"--hnsw-old"
|
||||
],
|
||||
"env": {
|
||||
"LD_PRELOAD": "/lib/x86_64-linux-gnu/libmkl_core.so:/lib/x86_64-linux-gnu/libmkl_intel_thread.so:/lib/x86_64-linux-gnu/libmkl_intel_lp64.so:/lib/x86_64-linux-gnu/libiomp5.so"
|
||||
}
|
||||
},
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"id": "configSelection",
|
||||
"type": "pickString",
|
||||
"description": "Select a configuration",
|
||||
"options": [
|
||||
"example_config",
|
||||
"vllm_gritlm"
|
||||
],
|
||||
"default": "example_config"
|
||||
}
|
||||
],
|
||||
}
|
||||
61
.vscode/settings.json
vendored
Executable file → Normal file
61
.vscode/settings.json
vendored
Executable file → Normal file
@@ -1,43 +1,22 @@
|
||||
{
|
||||
"python.analysis.extraPaths": [
|
||||
"./sglang_repo/python"
|
||||
],
|
||||
"cmake.sourceDirectory": "${workspaceFolder}/DiskANN",
|
||||
"cmake.configureArgs": [
|
||||
"-DPYBIND=True",
|
||||
"-DUPDATE_EDITABLE_INSTALL=ON",
|
||||
],
|
||||
"cmake.environment": {
|
||||
"PATH": "/Users/ec2-user/Power-RAG/.venv/bin:${env:PATH}"
|
||||
"python.defaultInterpreterPath": ".venv/bin/python",
|
||||
"python.terminal.activateEnvironment": true,
|
||||
"[python]": {
|
||||
"editor.defaultFormatter": "charliermarsh.ruff",
|
||||
"editor.formatOnSave": true,
|
||||
"editor.codeActionsOnSave": {
|
||||
"source.organizeImports": "explicit",
|
||||
"source.fixAll": "explicit"
|
||||
},
|
||||
"cmake.buildDirectory": "${workspaceFolder}/build",
|
||||
"files.associations": {
|
||||
"*.tcc": "cpp",
|
||||
"deque": "cpp",
|
||||
"string": "cpp",
|
||||
"unordered_map": "cpp",
|
||||
"vector": "cpp",
|
||||
"map": "cpp",
|
||||
"unordered_set": "cpp",
|
||||
"atomic": "cpp",
|
||||
"inplace_vector": "cpp",
|
||||
"*.ipp": "cpp",
|
||||
"forward_list": "cpp",
|
||||
"list": "cpp",
|
||||
"any": "cpp",
|
||||
"system_error": "cpp",
|
||||
"__hash_table": "cpp",
|
||||
"__split_buffer": "cpp",
|
||||
"__tree": "cpp",
|
||||
"ios": "cpp",
|
||||
"set": "cpp",
|
||||
"__string": "cpp",
|
||||
"string_view": "cpp",
|
||||
"ranges": "cpp",
|
||||
"iosfwd": "cpp"
|
||||
},
|
||||
"lldb.displayFormat": "auto",
|
||||
"lldb.showDisassembly": "auto",
|
||||
"lldb.dereferencePointers": true,
|
||||
"lldb.consoleMode": "commands",
|
||||
}
|
||||
"editor.insertSpaces": true,
|
||||
"editor.tabSize": 4
|
||||
},
|
||||
"ruff.enable": true,
|
||||
"files.watcherExclude": {
|
||||
"**/.venv/**": true,
|
||||
"**/__pycache__/**": true,
|
||||
"**/*.egg-info/**": true,
|
||||
"**/build/**": true,
|
||||
"**/dist/**": true
|
||||
}
|
||||
}
|
||||
|
||||
16
.vscode/tasks.json
vendored
16
.vscode/tasks.json
vendored
@@ -1,16 +0,0 @@
|
||||
{
|
||||
"version": "2.0.0",
|
||||
"tasks": [
|
||||
{
|
||||
"type": "cmake",
|
||||
"label": "CMake: build",
|
||||
"command": "build",
|
||||
"targets": [
|
||||
"all"
|
||||
],
|
||||
"group": "build",
|
||||
"problemMatcher": [],
|
||||
"detail": "CMake template build task"
|
||||
}
|
||||
]
|
||||
}
|
||||
784
README.md
784
README.md
@@ -3,20 +3,27 @@
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<img src="https://img.shields.io/badge/Python-3.9%2B-blue.svg" alt="Python 3.9+">
|
||||
<img src="https://img.shields.io/badge/Python-3.9%20%7C%203.10%20%7C%203.11%20%7C%203.12%20%7C%203.13-blue.svg" alt="Python Versions">
|
||||
<img src="https://github.com/yichuan-w/LEANN/actions/workflows/build-and-publish.yml/badge.svg" alt="CI Status">
|
||||
<img src="https://img.shields.io/badge/Platform-Ubuntu%20%26%20Arch%20%26%20WSL%20%7C%20macOS%20(ARM64%2FIntel)-lightgrey" alt="Platform">
|
||||
<img src="https://img.shields.io/badge/License-MIT-green.svg" alt="MIT License">
|
||||
<img src="https://img.shields.io/badge/Platform-Linux%20%7C%20macOS-lightgrey" alt="Platform">
|
||||
<img src="https://img.shields.io/badge/MCP-Native%20Integration-blue" alt="MCP Integration">
|
||||
<a href="https://join.slack.com/t/leann-e2u9779/shared_invite/zt-3ckd2f6w1-OX08~NN4gkWhh10PRVBj1Q"><img src="https://img.shields.io/badge/Slack-Join-4A154B?logo=slack&logoColor=white" alt="Join Slack">
|
||||
<a href="assets/wechat_user_group.JPG" title="Join WeChat group"><img src="https://img.shields.io/badge/WeChat-Join-2DC100?logo=wechat&logoColor=white" alt="Join WeChat group"></a>
|
||||
</p>
|
||||
|
||||
<h2 align="center" tabindex="-1" class="heading-element" dir="auto">
|
||||
The smallest vector index in the world. RAG Everything with LEANN!
|
||||
</h2>
|
||||
|
||||
LEANN is a revolutionary vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **[97% less storage]** than traditional solutions **without accuracy loss**.
|
||||
LEANN is an innovative vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **97% less storage** than traditional solutions **without accuracy loss**.
|
||||
|
||||
LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration →](#️-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
|
||||
LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration Fig →](#️-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
|
||||
|
||||
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can search your **[file system](#process-any-documents-pdf-txt-md)**, **[emails](#search-your-entire-life)**, **[browser history](#time-machine-for-the-web)**, **[chat history](#wechat-detective)**, or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
|
||||
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can semantic search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)**, **[codebase](#-claude-code-integration-transform-your-development-workflow)**\* , or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
|
||||
|
||||
|
||||
\* Claude Code only supports basic `grep`-style keyword search. **LEANN** is a drop-in **semantic search MCP service fully compatible with Claude Code**, unlocking intelligent retrieval without changing your workflow. 🔥 Check out [the easy setup →](packages/leann-mcp/README.md)
|
||||
|
||||
|
||||
|
||||
@@ -26,55 +33,172 @@ LEANN achieves this through *graph-based selective recomputation* with *high-deg
|
||||
<img src="assets/effects.png" alt="LEANN vs Traditional Vector DB Storage Comparison" width="70%">
|
||||
</p>
|
||||
|
||||
**The numbers speak for themselves:** Index 60 million Wikipedia articles in just 6GB instead of 201GB. From emails to browser history, everything fits on your laptop. [See detailed benchmarks below ↓](#storage-usage-comparison)
|
||||
> **The numbers speak for themselves:** Index 60 million text chunks in just 6GB instead of 201GB. From emails to browser history, everything fits on your laptop. [See detailed benchmarks for different applications below ↓](#-storage-comparison)
|
||||
|
||||
## Why This Matters
|
||||
|
||||
🔒 **Privacy:** Your data never leaves your laptop. No OpenAI, no cloud, no "terms of service".
|
||||
|
||||
🪶 **Lightweight:** Graph-based recomputation eliminates heavy embedding storage, while smart graph pruning and CSR format minimize graph storage overhead. Always less storage, less memory usage!
|
||||
|
||||
📦 **Portable:** Transfer your entire knowledge base between devices (even with others) with minimal cost - your personal AI memory travels with you.
|
||||
|
||||
📈 **Scalability:** Handle messy personal data that would crash traditional vector DBs, easily managing your growing personalized data and agent generated memory!
|
||||
|
||||
✨ **No Accuracy Loss:** Maintain the same search quality as heavyweight solutions while using 97% less storage.
|
||||
|
||||
## Quick Start in 1 minute
|
||||
## Installation
|
||||
|
||||
### 📦 Prerequisites: Install uv
|
||||
|
||||
[Install uv](https://docs.astral.sh/uv/getting-started/installation/#installation-methods) first if you don't have it. Typically, you can install it with:
|
||||
|
||||
```bash
|
||||
git clone git@github.com:yichuan-w/LEANN.git leann
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
```
|
||||
|
||||
### 🚀 Quick Install
|
||||
|
||||
Clone the repository to access all examples and try amazing applications,
|
||||
|
||||
```bash
|
||||
git clone https://github.com/yichuan-w/LEANN.git leann
|
||||
cd leann
|
||||
```
|
||||
|
||||
and install LEANN from [PyPI](https://pypi.org/project/leann/) to run them immediately:
|
||||
|
||||
```bash
|
||||
uv venv
|
||||
source .venv/bin/activate
|
||||
uv pip install leann
|
||||
```
|
||||
<!--
|
||||
> Low-resource? See “Low-resource setups” in the [Configuration Guide](docs/configuration-guide.md#low-resource-setups). -->
|
||||
|
||||
<details>
|
||||
<summary>
|
||||
<strong>🔧 Build from Source (Recommended for development)</strong>
|
||||
</summary>
|
||||
|
||||
|
||||
|
||||
```bash
|
||||
git clone https://github.com/yichuan-w/LEANN.git leann
|
||||
cd leann
|
||||
git submodule update --init --recursive
|
||||
```
|
||||
|
||||
**macOS:**
|
||||
|
||||
Note: DiskANN requires MacOS 13.3 or later.
|
||||
|
||||
```bash
|
||||
brew install llvm libomp boost protobuf
|
||||
export CC=$(brew --prefix llvm)/bin/clang
|
||||
export CXX=$(brew --prefix llvm)/bin/clang++
|
||||
|
||||
# Install with HNSW backend (default, recommended for most users)
|
||||
uv sync
|
||||
|
||||
# Or add DiskANN backend if you want to test more options
|
||||
brew install libomp boost protobuf zeromq pkgconf
|
||||
uv sync --extra diskann
|
||||
```
|
||||
|
||||
**Linux (Ubuntu/Debian):**
|
||||
|
||||
Note: On Ubuntu 20.04, you may need to build a newer Abseil and pin Protobuf (e.g., v3.20.x) for building DiskANN. See [Issue #30](https://github.com/yichuan-w/LEANN/issues/30) for a step-by-step note.
|
||||
|
||||
You can manually install [Intel oneAPI MKL](https://www.intel.com/content/www/us/en/developer/tools/oneapi/onemkl.html) instead of `libmkl-full-dev` for DiskANN. You can also use `libopenblas-dev` for building HNSW only, by removing `--extra diskann` in the command below.
|
||||
|
||||
```bash
|
||||
sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler libabsl-dev libmkl-full-dev libaio-dev
|
||||
sudo apt-get update && sudo apt-get install -y \
|
||||
libomp-dev libboost-all-dev protobuf-compiler libzmq3-dev \
|
||||
pkg-config libabsl-dev libaio-dev libprotobuf-dev \
|
||||
libmkl-full-dev
|
||||
|
||||
# Install with HNSW backend (default, recommended for most users)
|
||||
uv sync
|
||||
|
||||
# Or add DiskANN backend if you want to test more options
|
||||
uv sync --extra diskann
|
||||
```
|
||||
|
||||
**Ollama Setup (Optional for Local LLM):**
|
||||
**Linux (Arch Linux):**
|
||||
|
||||
*We support both hf-transformers and Ollama for local LLMs. Ollama is recommended for faster performance.*
|
||||
```bash
|
||||
sudo pacman -Syu && sudo pacman -S --needed base-devel cmake pkgconf git gcc \
|
||||
boost boost-libs protobuf abseil-cpp libaio zeromq
|
||||
|
||||
*macOS:*
|
||||
# For MKL in DiskANN
|
||||
sudo pacman -S --needed base-devel git
|
||||
git clone https://aur.archlinux.org/paru-bin.git
|
||||
cd paru-bin && makepkg -si
|
||||
paru -S intel-oneapi-mkl intel-oneapi-compiler
|
||||
source /opt/intel/oneapi/setvars.sh
|
||||
|
||||
uv sync --extra diskann
|
||||
```
|
||||
|
||||
**Linux (RHEL / CentOS Stream / Oracle / Rocky / AlmaLinux):**
|
||||
|
||||
See [Issue #50](https://github.com/yichuan-w/LEANN/issues/50) for more details.
|
||||
|
||||
```bash
|
||||
sudo dnf groupinstall -y "Development Tools"
|
||||
sudo dnf install -y libomp-devel boost-devel protobuf-compiler protobuf-devel \
|
||||
abseil-cpp-devel libaio-devel zeromq-devel pkgconf-pkg-config
|
||||
|
||||
# For MKL in DiskANN
|
||||
sudo dnf install -y intel-oneapi-mkl intel-oneapi-mkl-devel \
|
||||
intel-oneapi-openmp || sudo dnf install -y intel-oneapi-compiler
|
||||
source /opt/intel/oneapi/setvars.sh
|
||||
|
||||
uv sync --extra diskann
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
## Quick Start
|
||||
|
||||
Our declarative API makes RAG as easy as writing a config file.
|
||||
|
||||
Check out [demo.ipynb](demo.ipynb) or [](https://colab.research.google.com/github/yichuan-w/LEANN/blob/main/demo.ipynb)
|
||||
|
||||
```python
|
||||
from leann import LeannBuilder, LeannSearcher, LeannChat
|
||||
from pathlib import Path
|
||||
INDEX_PATH = str(Path("./").resolve() / "demo.leann")
|
||||
|
||||
# Build an index
|
||||
builder = LeannBuilder(backend_name="hnsw")
|
||||
builder.add_text("LEANN saves 97% storage compared to traditional vector databases.")
|
||||
builder.add_text("Tung Tung Tung Sahur called—they need their banana‑crocodile hybrid back")
|
||||
builder.build_index(INDEX_PATH)
|
||||
|
||||
# Search
|
||||
searcher = LeannSearcher(INDEX_PATH)
|
||||
results = searcher.search("fantastical AI-generated creatures", top_k=1)
|
||||
|
||||
# Chat with your data
|
||||
chat = LeannChat(INDEX_PATH, llm_config={"type": "hf", "model": "Qwen/Qwen3-0.6B"})
|
||||
response = chat.ask("How much storage does LEANN save?", top_k=1)
|
||||
```
|
||||
|
||||
## RAG on Everything!
|
||||
|
||||
LEANN supports RAG on various data sources including documents (`.pdf`, `.txt`, `.md`), Apple Mail, Google Search History, WeChat, and more.
|
||||
|
||||
|
||||
|
||||
### Generation Model Setup
|
||||
|
||||
LEANN supports multiple LLM providers for text generation (OpenAI API, HuggingFace, Ollama).
|
||||
|
||||
<details>
|
||||
<summary><strong>🔑 OpenAI API Setup (Default)</strong></summary>
|
||||
|
||||
Set your OpenAI API key as an environment variable:
|
||||
|
||||
```bash
|
||||
export OPENAI_API_KEY="your-api-key-here"
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><strong>🔧 Ollama Setup (Recommended for full privacy)</strong></summary>
|
||||
|
||||
**macOS:**
|
||||
|
||||
First, [download Ollama for macOS](https://ollama.com/download/mac).
|
||||
|
||||
@@ -83,7 +207,8 @@ First, [download Ollama for macOS](https://ollama.com/download/mac).
|
||||
ollama pull llama3.2:1b
|
||||
```
|
||||
|
||||
*Linux:*
|
||||
**Linux:**
|
||||
|
||||
```bash
|
||||
# Install Ollama
|
||||
curl -fsSL https://ollama.ai/install.sh | sh
|
||||
@@ -95,81 +220,127 @@ ollama serve &
|
||||
ollama pull llama3.2:1b
|
||||
```
|
||||
|
||||
You can also replace `llama3.2:1b` to `deepseek-r1:1.5b` or `qwen3:4b` for better performance but higher memory usage.
|
||||
</details>
|
||||
|
||||
## Dead Simple API
|
||||
|
||||
Just 3 lines of code. Our declarative API makes RAG as easy as writing a config file:
|
||||
## ⭐ Flexible Configuration
|
||||
|
||||
```python
|
||||
from leann.api import LeannBuilder, LeannSearcher
|
||||
LEANN provides flexible parameters for embedding models, search strategies, and data processing to fit your specific needs.
|
||||
|
||||
# 1. Build index (no embeddings stored!)
|
||||
builder = LeannBuilder(backend_name="hnsw")
|
||||
builder.add_text("C# is a powerful programming language")
|
||||
builder.add_text("Python is a powerful programming language")
|
||||
builder.add_text("Machine learning transforms industries")
|
||||
builder.add_text("Neural networks process complex data")
|
||||
builder.add_text("Leann is a great storage saving engine for RAG on your macbook")
|
||||
builder.build_index("knowledge.leann")
|
||||
|
||||
# 2. Search with real-time embeddings
|
||||
searcher = LeannSearcher("knowledge.leann")
|
||||
results = searcher.search("C++ programming languages", top_k=2, recompute_beighbor_embeddings=True)
|
||||
print(results)
|
||||
```
|
||||
|
||||
**That's it.** No cloud setup, no API keys, no "fine-tuning". Just your data, your questions, your laptop.
|
||||
|
||||
[Try the interactive demo →](demo.ipynb)
|
||||
|
||||
## Wild Things You Can Do
|
||||
|
||||
LEANN supports RAGing a lot of data sources, like .pdf, .txt, .md, and also supports RAGing your WeChat, Google Search History, and more.
|
||||
|
||||
### Process Any Documents (.pdf, .txt, .md)
|
||||
|
||||
Above we showed the Python API, while this CLI script demonstrates the same concepts while directly processing PDFs and documents.
|
||||
|
||||
```bash
|
||||
# Drop your PDFs, .txt, .md files into examples/data/
|
||||
uv run ./examples/main_cli_example.py
|
||||
|
||||
# Or use python directly
|
||||
source .venv/bin/activate
|
||||
python ./examples/main_cli_example.py
|
||||
```
|
||||
|
||||
Uses Ollama `qwen3:8b` by default. For other models: `--llm openai --model gpt-4o` (requires `OPENAI_API_KEY` environment variable) or `--llm hf --model Qwen/Qwen3-4B`.
|
||||
|
||||
**Works with any text format** - research papers, personal notes, presentations. Built with LlamaIndex for document parsing.
|
||||
|
||||
### Search Your Entire Life
|
||||
```bash
|
||||
python examples/mail_reader_leann.py
|
||||
# "What did my boss say about the Christmas party last year?"
|
||||
# "Find all emails from my mom about birthday plans"
|
||||
```
|
||||
**90K emails → 14MB.** Finally, search your email like you search Google.
|
||||
📚 **Need configuration best practices?** Check our [Configuration Guide](docs/configuration-guide.md) for detailed optimization tips, model selection advice, and solutions to common issues like slow embeddings or poor search quality.
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: Command Examples</strong></summary>
|
||||
<summary><strong>📋 Click to expand: Common Parameters (Available in All Examples)</strong></summary>
|
||||
|
||||
All RAG examples share these common parameters. **Interactive mode** is available in all examples - simply run without `--query` to start a continuous Q&A session where you can ask multiple questions. Type 'quit' to exit.
|
||||
|
||||
```bash
|
||||
# Use default mail path (works for most macOS setups)
|
||||
python examples/mail_reader_leann.py
|
||||
# Core Parameters (General preprocessing for all examples)
|
||||
--index-dir DIR # Directory to store the index (default: current directory)
|
||||
--query "YOUR QUESTION" # Single query mode. Omit for interactive chat (type 'quit' to exit), and now you can play with your index interactively
|
||||
--max-items N # Limit data preprocessing (default: -1, process all data)
|
||||
--force-rebuild # Force rebuild index even if it exists
|
||||
|
||||
# Run with custom index directory
|
||||
python examples/mail_reader_leann.py --index-dir "./my_mail_index"
|
||||
# Embedding Parameters
|
||||
--embedding-model MODEL # e.g., facebook/contriever, text-embedding-3-small, mlx-community/Qwen3-Embedding-0.6B-8bit or nomic-embed-text
|
||||
--embedding-mode MODE # sentence-transformers, openai, mlx, or ollama
|
||||
|
||||
# Process all emails (may take time but indexes everything)
|
||||
python examples/mail_reader_leann.py --max-emails -1
|
||||
# LLM Parameters (Text generation models)
|
||||
--llm TYPE # LLM backend: openai, ollama, or hf (default: openai)
|
||||
--llm-model MODEL # Model name (default: gpt-4o) e.g., gpt-4o-mini, llama3.2:1b, Qwen/Qwen2.5-1.5B-Instruct
|
||||
--thinking-budget LEVEL # Thinking budget for reasoning models: low/medium/high (supported by o3, o3-mini, GPT-Oss:20b, and other reasoning models)
|
||||
|
||||
# Limit number of emails processed (useful for testing)
|
||||
python examples/mail_reader_leann.py --max-emails 1000
|
||||
# Search Parameters
|
||||
--top-k N # Number of results to retrieve (default: 20)
|
||||
--search-complexity N # Search complexity for graph traversal (default: 32)
|
||||
|
||||
# Run a single query
|
||||
python examples/mail_reader_leann.py --query "What did my boss say about deadlines?"
|
||||
# Chunking Parameters
|
||||
--chunk-size N # Size of text chunks (default varies by source: 256 for most, 192 for WeChat)
|
||||
--chunk-overlap N # Overlap between chunks (default varies: 25-128 depending on source)
|
||||
|
||||
# Index Building Parameters
|
||||
--backend-name NAME # Backend to use: hnsw or diskann (default: hnsw)
|
||||
--graph-degree N # Graph degree for index construction (default: 32)
|
||||
--build-complexity N # Build complexity for index construction (default: 64)
|
||||
--compact / --no-compact # Use compact storage (default: true). Must be `no-compact` for `no-recompute` build.
|
||||
--recompute / --no-recompute # Enable/disable embedding recomputation (default: enabled). Should not do a `no-recompute` search in a `recompute` build.
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
### 📄 Personal Data Manager: Process Any Documents (`.pdf`, `.txt`, `.md`)!
|
||||
|
||||
Ask questions directly about your personal PDFs, documents, and any directory containing your files!
|
||||
|
||||
<p align="center">
|
||||
<img src="videos/paper_clear.gif" alt="LEANN Document Search Demo" width="600">
|
||||
</p>
|
||||
|
||||
The example below asks a question about summarizing our paper (uses default data in `data/`, which is a directory with diverse data sources: two papers, Pride and Prejudice, and a Technical report about LLM in Huawei in Chinese), and this is the **easiest example** to run here:
|
||||
|
||||
```bash
|
||||
source .venv/bin/activate # Don't forget to activate the virtual environment
|
||||
python -m apps.document_rag --query "What are the main techniques LEANN explores?"
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: Document-Specific Arguments</strong></summary>
|
||||
|
||||
#### Parameters
|
||||
```bash
|
||||
--data-dir DIR # Directory containing documents to process (default: data)
|
||||
--file-types .ext .ext # Filter by specific file types (optional - all LlamaIndex supported types if omitted)
|
||||
```
|
||||
|
||||
#### Example Commands
|
||||
```bash
|
||||
# Process all documents with larger chunks for academic papers
|
||||
python -m apps.document_rag --data-dir "~/Documents/Papers" --chunk-size 1024
|
||||
|
||||
# Filter only markdown and Python files with smaller chunks
|
||||
python -m apps.document_rag --data-dir "./docs" --chunk-size 256 --file-types .md .py
|
||||
|
||||
# Enable AST-aware chunking for code files
|
||||
python -m apps.document_rag --enable-code-chunking --data-dir "./my_project"
|
||||
|
||||
# Or use the specialized code RAG for better code understanding
|
||||
python -m apps.code_rag --repo-dir "./my_codebase" --query "How does authentication work?"
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
### 📧 Your Personal Email Secretary: RAG on Apple Mail!
|
||||
|
||||
> **Note:** The examples below currently support macOS only. Windows support coming soon.
|
||||
|
||||
|
||||
<p align="center">
|
||||
<img src="videos/mail_clear.gif" alt="LEANN Email Search Demo" width="600">
|
||||
</p>
|
||||
|
||||
Before running the example below, you need to grant full disk access to your terminal/VS Code in System Preferences → Privacy & Security → Full Disk Access.
|
||||
|
||||
```bash
|
||||
python -m apps.email_rag --query "What's the food I ordered by DoorDash or Uber Eats mostly?"
|
||||
```
|
||||
**780K email chunks → 78MB storage.** Finally, search your email like you search Google.
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: Email-Specific Arguments</strong></summary>
|
||||
|
||||
#### Parameters
|
||||
```bash
|
||||
--mail-path PATH # Path to specific mail directory (auto-detects if omitted)
|
||||
--include-html # Include HTML content in processing (useful for newsletters)
|
||||
```
|
||||
|
||||
#### Example Commands
|
||||
```bash
|
||||
# Search work emails from a specific account
|
||||
python -m apps.email_rag --mail-path "~/Library/Mail/V10/WORK_ACCOUNT"
|
||||
|
||||
# Find all receipts and order confirmations (includes HTML)
|
||||
python -m apps.email_rag --query "receipt order confirmation invoice" --include-html
|
||||
```
|
||||
|
||||
</details>
|
||||
@@ -183,29 +354,32 @@ Once the index is built, you can ask questions like:
|
||||
- "Show me emails about travel expenses"
|
||||
</details>
|
||||
|
||||
### Time Machine for the Web
|
||||
### 🔍 Time Machine for the Web: RAG Your Entire Chrome Browser History!
|
||||
|
||||
<p align="center">
|
||||
<img src="videos/google_clear.gif" alt="LEANN Browser History Search Demo" width="600">
|
||||
</p>
|
||||
|
||||
```bash
|
||||
python examples/google_history_reader_leann.py
|
||||
# "What was that AI paper I read last month?"
|
||||
# "Show me all the cooking videos I watched"
|
||||
python -m apps.browser_rag --query "Tell me my browser history about machine learning?"
|
||||
```
|
||||
**38K browser entries → 6MB.** Your browser history becomes your personal search engine.
|
||||
**38K browser entries → 6MB storage.** Your browser history becomes your personal search engine.
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: Command Examples</strong></summary>
|
||||
<summary><strong>📋 Click to expand: Browser-Specific Arguments</strong></summary>
|
||||
|
||||
#### Parameters
|
||||
```bash
|
||||
# Use default Chrome profile (auto-finds all profiles)
|
||||
python examples/google_history_reader_leann.py
|
||||
--chrome-profile PATH # Path to Chrome profile directory (auto-detects if omitted)
|
||||
```
|
||||
|
||||
# Run with custom index directory
|
||||
python examples/google_history_reader_leann.py --index-dir "./my_chrome_index"
|
||||
#### Example Commands
|
||||
```bash
|
||||
# Search academic research from your browsing history
|
||||
python -m apps.browser_rag --query "arxiv papers machine learning transformer architecture"
|
||||
|
||||
# Limit number of history entries processed (useful for testing)
|
||||
python examples/google_history_reader_leann.py --max-entries 500
|
||||
|
||||
# Run a single query
|
||||
python examples/google_history_reader_leann.py --query "What websites did I visit about machine learning?"
|
||||
# Track competitor analysis across work profile
|
||||
python -m apps.browser_rag --chrome-profile "~/Library/Application Support/Google/Chrome/Work Profile" --max-items 5000
|
||||
```
|
||||
|
||||
</details>
|
||||
@@ -238,44 +412,58 @@ Once the index is built, you can ask questions like:
|
||||
|
||||
</details>
|
||||
|
||||
### WeChat Detective
|
||||
### 💬 WeChat Detective: Unlock Your Golden Memories!
|
||||
|
||||
<p align="center">
|
||||
<img src="videos/wechat_clear.gif" alt="LEANN WeChat Search Demo" width="600">
|
||||
</p>
|
||||
|
||||
```bash
|
||||
python examples/wechat_history_reader_leann.py
|
||||
# "Show me all group chats about weekend plans"
|
||||
python -m apps.wechat_rag --query "Show me all group chats about weekend plans"
|
||||
```
|
||||
**400K messages → 64MB.** Search years of chat history in any language.
|
||||
**400K messages → 64MB storage** Search years of chat history in any language.
|
||||
|
||||
|
||||
<details>
|
||||
<summary><strong>🔧 Click to expand: Installation Requirements</strong></summary>
|
||||
|
||||
First, you need to install the WeChat exporter:
|
||||
First, you need to install the [WeChat exporter](https://github.com/sunnyyoung/WeChatTweak-CLI),
|
||||
|
||||
```bash
|
||||
brew install sunnyyoung/repo/wechattweak-cli
|
||||
```
|
||||
|
||||
or install it manually (if you have issues with Homebrew):
|
||||
|
||||
```bash
|
||||
sudo packages/wechat-exporter/wechattweak-cli install
|
||||
```
|
||||
|
||||
**Troubleshooting**: If you encounter installation issues, check the [WeChatTweak-CLI issues page](https://github.com/sunnyyoung/WeChatTweak-CLI/issues/41).
|
||||
**Troubleshooting:**
|
||||
- **Installation issues**: Check the [WeChatTweak-CLI issues page](https://github.com/sunnyyoung/WeChatTweak-CLI/issues/41)
|
||||
- **Export errors**: If you encounter the error below, try restarting WeChat
|
||||
```bash
|
||||
Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed.
|
||||
Failed to find or export WeChat data. Exiting.
|
||||
```
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: Command Examples</strong></summary>
|
||||
<summary><strong>📋 Click to expand: WeChat-Specific Arguments</strong></summary>
|
||||
|
||||
#### Parameters
|
||||
```bash
|
||||
# Use default settings (recommended for first run)
|
||||
python examples/wechat_history_reader_leann.py
|
||||
--export-dir DIR # Directory to store exported WeChat data (default: wechat_export_direct)
|
||||
--force-export # Force re-export even if data exists
|
||||
```
|
||||
|
||||
# Run with custom export directory and wehn we run the first time, LEANN will export all chat history automatically for you
|
||||
python examples/wechat_history_reader_leann.py --export-dir "./my_wechat_exports"
|
||||
#### Example Commands
|
||||
```bash
|
||||
# Search for travel plans discussed in group chats
|
||||
python -m apps.wechat_rag --query "travel plans" --max-items 10000
|
||||
|
||||
# Run with custom index directory
|
||||
python examples/wechat_history_reader_leann.py --index-dir "./my_wechat_index"
|
||||
|
||||
# Limit number of chat entries processed (useful for testing)
|
||||
python examples/wechat_history_reader_leann.py --max-entries 1000
|
||||
|
||||
# Run a single query
|
||||
python examples/wechat_history_reader_leann.py --query "Show me conversations about travel plans"
|
||||
# Re-export and search recent chats (useful after new messages)
|
||||
python -m apps.wechat_rag --force-export --query "work schedule"
|
||||
```
|
||||
|
||||
</details>
|
||||
@@ -289,6 +477,197 @@ Once the index is built, you can ask questions like:
|
||||
|
||||
</details>
|
||||
|
||||
### 🚀 Claude Code Integration: Transform Your Development Workflow!
|
||||
|
||||
<details>
|
||||
<summary><strong>NEW!! AST‑Aware Code Chunking</strong></summary>
|
||||
|
||||
LEANN features intelligent code chunking that preserves semantic boundaries (functions, classes, methods) for Python, Java, C#, and TypeScript, improving code understanding compared to text-based chunking.
|
||||
|
||||
📖 Read the [AST Chunking Guide →](docs/ast_chunking_guide.md)
|
||||
|
||||
</details>
|
||||
|
||||
**The future of code assistance is here.** Transform your development workflow with LEANN's native MCP integration for Claude Code. Index your entire codebase and get intelligent code assistance directly in your IDE.
|
||||
|
||||
**Key features:**
|
||||
- 🔍 **Semantic code search** across your entire project, fully local index and lightweight
|
||||
- 🧠 **AST-aware chunking** preserves code structure (functions, classes)
|
||||
- 📚 **Context-aware assistance** for debugging and development
|
||||
- 🚀 **Zero-config setup** with automatic language detection
|
||||
|
||||
```bash
|
||||
# Install LEANN globally for MCP integration
|
||||
uv tool install leann-core --with leann
|
||||
claude mcp add --scope user leann-server -- leann_mcp
|
||||
# Setup is automatic - just start using Claude Code!
|
||||
```
|
||||
Try our fully agentic pipeline with auto query rewriting, semantic search planning, and more:
|
||||
|
||||

|
||||
|
||||
**🔥 Ready to supercharge your coding?** [Complete Setup Guide →](packages/leann-mcp/README.md)
|
||||
|
||||
## 🖥️ Command Line Interface
|
||||
|
||||
LEANN includes a powerful CLI for document processing and search. Perfect for quick document indexing and interactive chat.
|
||||
|
||||
### Installation
|
||||
|
||||
If you followed the Quick Start, `leann` is already installed in your virtual environment:
|
||||
```bash
|
||||
source .venv/bin/activate
|
||||
leann --help
|
||||
```
|
||||
|
||||
**To make it globally available:**
|
||||
```bash
|
||||
# Install the LEANN CLI globally using uv tool
|
||||
uv tool install leann-core --with leann
|
||||
|
||||
|
||||
# Now you can use leann from anywhere without activating venv
|
||||
leann --help
|
||||
```
|
||||
|
||||
> **Note**: Global installation is required for Claude Code integration. The `leann_mcp` server depends on the globally available `leann` command.
|
||||
|
||||
|
||||
|
||||
### Usage Examples
|
||||
|
||||
```bash
|
||||
# build from a specific directory, and my_docs is the index name(Here you can also build from multiple dict or multiple files)
|
||||
leann build my-docs --docs ./your_documents
|
||||
|
||||
# Search your documents
|
||||
leann search my-docs "machine learning concepts"
|
||||
|
||||
# Interactive chat with your documents
|
||||
leann ask my-docs --interactive
|
||||
|
||||
# List all your indexes
|
||||
leann list
|
||||
|
||||
# Remove an index
|
||||
leann remove my-docs
|
||||
```
|
||||
|
||||
**Key CLI features:**
|
||||
- Auto-detects document formats (PDF, TXT, MD, DOCX, PPTX + code files)
|
||||
- **🧠 AST-aware chunking** for Python, Java, C#, TypeScript files
|
||||
- Smart text chunking with overlap for all other content
|
||||
- Multiple LLM providers (Ollama, OpenAI, HuggingFace)
|
||||
- Organized index storage in `.leann/indexes/` (project-local)
|
||||
- Support for advanced search parameters
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: Complete CLI Reference</strong></summary>
|
||||
|
||||
You can use `leann --help`, or `leann build --help`, `leann search --help`, `leann ask --help`, `leann list --help`, `leann remove --help` to get the complete CLI reference.
|
||||
|
||||
**Build Command:**
|
||||
```bash
|
||||
leann build INDEX_NAME --docs DIRECTORY|FILE [DIRECTORY|FILE ...] [OPTIONS]
|
||||
|
||||
Options:
|
||||
--backend {hnsw,diskann} Backend to use (default: hnsw)
|
||||
--embedding-model MODEL Embedding model (default: facebook/contriever)
|
||||
--graph-degree N Graph degree (default: 32)
|
||||
--complexity N Build complexity (default: 64)
|
||||
--force Force rebuild existing index
|
||||
--compact / --no-compact Use compact storage (default: true). Must be `no-compact` for `no-recompute` build.
|
||||
--recompute / --no-recompute Enable recomputation (default: true)
|
||||
```
|
||||
|
||||
**Search Command:**
|
||||
```bash
|
||||
leann search INDEX_NAME QUERY [OPTIONS]
|
||||
|
||||
Options:
|
||||
--top-k N Number of results (default: 5)
|
||||
--complexity N Search complexity (default: 64)
|
||||
--recompute / --no-recompute Enable/disable embedding recomputation (default: enabled). Should not do a `no-recompute` search in a `recompute` build.
|
||||
--pruning-strategy {global,local,proportional}
|
||||
```
|
||||
|
||||
**Ask Command:**
|
||||
```bash
|
||||
leann ask INDEX_NAME [OPTIONS]
|
||||
|
||||
Options:
|
||||
--llm {ollama,openai,hf} LLM provider (default: ollama)
|
||||
--model MODEL Model name (default: qwen3:8b)
|
||||
--interactive Interactive chat mode
|
||||
--top-k N Retrieval count (default: 20)
|
||||
```
|
||||
|
||||
**List Command:**
|
||||
```bash
|
||||
leann list
|
||||
|
||||
# Lists all indexes across all projects with status indicators:
|
||||
# ✅ - Index is complete and ready to use
|
||||
# ❌ - Index is incomplete or corrupted
|
||||
# 📁 - CLI-created index (in .leann/indexes/)
|
||||
# 📄 - App-created index (*.leann.meta.json files)
|
||||
```
|
||||
|
||||
**Remove Command:**
|
||||
```bash
|
||||
leann remove INDEX_NAME [OPTIONS]
|
||||
|
||||
Options:
|
||||
--force, -f Force removal without confirmation
|
||||
|
||||
# Smart removal: automatically finds and safely removes indexes
|
||||
# - Shows all matching indexes across projects
|
||||
# - Requires confirmation for cross-project removal
|
||||
# - Interactive selection when multiple matches found
|
||||
# - Supports both CLI and app-created indexes
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## 🚀 Advanced Features
|
||||
|
||||
### 🎯 Metadata Filtering
|
||||
|
||||
LEANN supports a simple metadata filtering system to enable sophisticated use cases like document filtering by date/type, code search by file extension, and content management based on custom criteria.
|
||||
|
||||
```python
|
||||
# Add metadata during indexing
|
||||
builder.add_text(
|
||||
"def authenticate_user(token): ...",
|
||||
metadata={"file_extension": ".py", "lines_of_code": 25}
|
||||
)
|
||||
|
||||
# Search with filters
|
||||
results = searcher.search(
|
||||
query="authentication function",
|
||||
metadata_filters={
|
||||
"file_extension": {"==": ".py"},
|
||||
"lines_of_code": {"<": 100}
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
**Supported operators**: `==`, `!=`, `<`, `<=`, `>`, `>=`, `in`, `not_in`, `contains`, `starts_with`, `ends_with`, `is_true`, `is_false`
|
||||
|
||||
📖 **[Complete Metadata filtering guide →](docs/metadata_filtering.md)**
|
||||
|
||||
### 🔍 Grep Search
|
||||
|
||||
For exact text matching instead of semantic search, use the `use_grep` parameter:
|
||||
|
||||
```python
|
||||
# Exact text search
|
||||
results = searcher.search("banana‑crocodile", use_grep=True, top_k=1)
|
||||
```
|
||||
|
||||
**Use cases**: Finding specific code patterns, error messages, function names, or exact phrases where semantic similarity isn't needed.
|
||||
|
||||
📖 **[Complete grep search guide →](docs/grep_search.md)**
|
||||
|
||||
## 🏗️ Architecture & How It Works
|
||||
|
||||
@@ -300,62 +679,39 @@ Once the index is built, you can ask questions like:
|
||||
|
||||
**Core techniques:**
|
||||
- **Graph-based selective recomputation:** Only compute embeddings for nodes in the search path
|
||||
- **High-degree preserving pruning:** Keep important "hub" nodes while removing redundant connections
|
||||
- **High-degree preserving pruning:** Keep important "hub" nodes while removing redundant connections
|
||||
- **Dynamic batching:** Efficiently batch embedding computations for GPU utilization
|
||||
- **Two-level search:** Smart graph traversal that prioritizes promising nodes
|
||||
|
||||
**Backends:** DiskANN or HNSW - pick what works for your data size.
|
||||
**Backends:**
|
||||
- **HNSW** (default): Ideal for most datasets with maximum storage savings through full recomputation
|
||||
- **DiskANN**: Advanced option with superior search performance, using PQ-based graph traversal with real-time reranking for the best speed-accuracy trade-off
|
||||
|
||||
## Benchmarks
|
||||
|
||||
Run the comparison yourself:
|
||||
```bash
|
||||
python examples/compare_faiss_vs_leann.py
|
||||
```
|
||||
**[DiskANN vs HNSW Performance Comparison →](benchmarks/diskann_vs_hnsw_speed_comparison.py)** - Compare search performance between both backends
|
||||
|
||||
**[Simple Example: Compare LEANN vs FAISS →](benchmarks/compare_faiss_vs_leann.py)** - See storage savings in action
|
||||
|
||||
### 📊 Storage Comparison
|
||||
|
||||
| System | DPR (2.1M) | Wiki (60M) | Chat (400K) | Email (780K) | Browser (38K) |
|
||||
|--------|-------------|------------|-------------|--------------|---------------|
|
||||
| Traditional vector database (e.g., FAISS) | 3.8 GB | 201 GB | 1.8 GB | 2.4 GB | 130 MB |
|
||||
| LEANN | 324 MB | 6 GB | 64 MB | 79 MB | 6.4 MB |
|
||||
| Savings| 91% | 97% | 97% | 97% | 95% |
|
||||
|
||||
| System | Storage |
|
||||
|--------|---------|
|
||||
| FAISS HNSW | 5.5 MB |
|
||||
| LEANN | 0.5 MB |
|
||||
| **Savings** | **91%** |
|
||||
|
||||
Same dataset, same hardware, same embedding model. LEANN just works better.
|
||||
|
||||
## Reproduce Our Results
|
||||
|
||||
```bash
|
||||
uv pip install -e ".[dev]" # Install dev dependencies
|
||||
python examples/run_evaluation.py data/indices/dpr/dpr_diskann # DPR dataset
|
||||
python examples/run_evaluation.py data/indices/rpj_wiki/rpj_wiki.index # Wikipedia
|
||||
python benchmarks/run_evaluation.py # Will auto-download evaluation data and run benchmarks
|
||||
python benchmarks/run_evaluation.py benchmarks/data/indices/rpj_wiki/rpj_wiki --num-queries 2000 # After downloading data, you can run the benchmark with our biggest index
|
||||
```
|
||||
|
||||
The evaluation script downloads data automatically on first run.
|
||||
|
||||
### Storage Usage Comparison
|
||||
|
||||
| System | DPR (2.1M chunks) | RPJ-wiki (60M chunks) | Chat history (400K messages) | Apple emails (90K messages chunks) |Google Search History (38K entries)
|
||||
|-----------------------|------------------|------------------------|-----------------------------|------------------------------|------------------------------|
|
||||
| Traditional Vector DB(FAISS) | 3.8 GB | 201 GB | 1.8G | 305.8 MB |130.4 MB |
|
||||
| **LEANN** | **324 MB** | **6 GB** | **64 MB** | **14.8 MB** |**6.4MB** |
|
||||
| **Reduction** | **91% smaller** | **97% smaller** | **97% smaller** | **95% smaller** |**95% smaller** |
|
||||
|
||||
<!-- ### Memory Usage Comparison
|
||||
|
||||
| System j | DPR(2M docs) | RPJ-wiki(60M docs) | Chat history() |
|
||||
| --------------------- | ---------------- | ---------------- | ---------------- |
|
||||
| Traditional Vector DB(LLamaindex faiss) | x GB | x GB | x GB |
|
||||
| **Leann** | **xx MB** | **x GB** | **x GB** |
|
||||
| **Reduction** | **x%** | **x%** | **x%** |
|
||||
|
||||
### Query Performance of LEANN
|
||||
|
||||
| Backend | Index Size | Query Time | Recall@3 |
|
||||
| ------------------- | ---------- | ---------- | --------- |
|
||||
| DiskANN | 1M docs | xms | 0.95 |
|
||||
| HNSW | 1M docs | xms | 0.95 | -->
|
||||
|
||||
*Benchmarks run on Apple M3 Pro 36 GB*
|
||||
|
||||
The evaluation script downloads data automatically on first run. The last three results were tested with partial personal data, and you can reproduce them with your own data!
|
||||
## 🔬 Paper
|
||||
|
||||
If you find Leann useful, please cite:
|
||||
@@ -364,97 +720,25 @@ If you find Leann useful, please cite:
|
||||
|
||||
```bibtex
|
||||
@misc{wang2025leannlowstoragevectorindex,
|
||||
title={LEANN: A Low-Storage Vector Index},
|
||||
title={LEANN: A Low-Storage Vector Index},
|
||||
author={Yichuan Wang and Shu Liu and Zhifei Li and Yongji Wu and Ziming Mao and Yilong Zhao and Xiao Yan and Zhiying Xu and Yang Zhou and Ion Stoica and Sewon Min and Matei Zaharia and Joseph E. Gonzalez},
|
||||
year={2025},
|
||||
eprint={2506.08276},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.DB},
|
||||
url={https://arxiv.org/abs/2506.08276},
|
||||
url={https://arxiv.org/abs/2506.08276},
|
||||
}
|
||||
```
|
||||
|
||||
## ✨ Features
|
||||
## ✨ [Detailed Features →](docs/features.md)
|
||||
|
||||
### 🔥 Core Features
|
||||
|
||||
- **🔄 Real-time Embeddings** - Eliminate heavy embedding storage with dynamic computation using optimized ZMQ servers and highly optimized search paradigm (overlapping and batching) with highly optimized embedding engine
|
||||
- **📈 Scalable Architecture** - Handles millions of documents on consumer hardware; the larger your dataset, the more LEANN can save
|
||||
- **🎯 Graph Pruning** - Advanced techniques to minimize the storage overhead of vector search to a limited footprint
|
||||
- **🏗️ Pluggable Backends** - DiskANN, HNSW/FAISS with unified API
|
||||
|
||||
### 🛠️ Technical Highlights
|
||||
- **🔄 Recompute Mode** - Highest accuracy scenarios while eliminating vector storage overhead
|
||||
- **⚡ Zero-copy Operations** - Minimize IPC overhead by transferring distances instead of embeddings
|
||||
- **🚀 High-throughput Embedding Pipeline** - Optimized batched processing for maximum efficiency
|
||||
- **🎯 Two-level Search** - Novel coarse-to-fine search overlap for accelerated query processing (optional)
|
||||
- **💾 Memory-mapped Indices** - Fast startup with raw text mapping to reduce memory overhead
|
||||
- **🚀 MLX Support** - Ultra-fast recompute/build with quantized embedding models, accelerating building and search ([minimal example](test/build_mlx_index.py))
|
||||
|
||||
### 🎨 Developer Experience
|
||||
|
||||
- **Simple Python API** - Get started in minutes
|
||||
- **Extensible backend system** - Easy to add new algorithms
|
||||
- **Comprehensive examples** - From basic usage to production deployment
|
||||
|
||||
## 🤝 Contributing
|
||||
|
||||
We welcome contributions! Leann is built by the community, for the community.
|
||||
|
||||
### Ways to Contribute
|
||||
|
||||
- 🐛 **Bug Reports**: Found an issue? Let us know!
|
||||
- 💡 **Feature Requests**: Have an idea? We'd love to hear it!
|
||||
- 🔧 **Code Contributions**: PRs welcome for all skill levels
|
||||
- 📖 **Documentation**: Help make Leann more accessible
|
||||
- 🧪 **Benchmarks**: Share your performance results
|
||||
## 🤝 [CONTRIBUTING →](docs/CONTRIBUTING.md)
|
||||
|
||||
|
||||
<!-- ## ❓ FAQ
|
||||
|
||||
### Common Issues
|
||||
|
||||
#### NCCL Topology Error
|
||||
|
||||
**Problem**: You encounter `ncclTopoComputePaths` error during document processing:
|
||||
|
||||
```
|
||||
ncclTopoComputePaths (system=<optimized out>, comm=comm@entry=0x5555a82fa3c0) at graph/paths.cc:688
|
||||
```
|
||||
|
||||
**Solution**: Set these environment variables before running your script:
|
||||
|
||||
```bash
|
||||
export NCCL_TOPO_DUMP_FILE=/tmp/nccl_topo.xml
|
||||
export NCCL_DEBUG=INFO
|
||||
export NCCL_DEBUG_SUBSYS=INIT,GRAPH
|
||||
export NCCL_IB_DISABLE=1
|
||||
export NCCL_NET_PLUGIN=none
|
||||
export NCCL_SOCKET_IFNAME=ens5
|
||||
``` -->
|
||||
|
||||
## 📈 Roadmap
|
||||
|
||||
### 🎯 Q2 2025
|
||||
|
||||
- [X] DiskANN backend with MIPS/L2/Cosine support
|
||||
- [X] HNSW backend integration
|
||||
- [X] Real-time embedding pipeline
|
||||
- [X] Memory-efficient graph pruning
|
||||
|
||||
### 🚀 Q3 2025
|
||||
## ❓ [FAQ →](docs/faq.md)
|
||||
|
||||
|
||||
- [ ] Advanced caching strategies
|
||||
- [ ] Add contextual-retrieval https://www.anthropic.com/news/contextual-retrieval
|
||||
- [ ] Add sleep-time-compute and summarize agent! to summarilze the file on computer!
|
||||
- [ ] Add OpenAI recompute API
|
||||
|
||||
### 🌟 Q4 2025
|
||||
|
||||
- [ ] Integration with LangChain/LlamaIndex
|
||||
- [ ] Visual similarity search
|
||||
- [ ] Query rewrtiting, rerank and expansion
|
||||
## 📈 [Roadmap →](docs/roadmap.md)
|
||||
|
||||
## 📄 License
|
||||
|
||||
@@ -462,13 +746,18 @@ MIT License - see [LICENSE](LICENSE) for details.
|
||||
|
||||
## 🙏 Acknowledgments
|
||||
|
||||
- **Microsoft Research** for the DiskANN algorithm
|
||||
- **Meta AI** for FAISS and optimization insights
|
||||
- **HuggingFace** for the transformer ecosystem
|
||||
- **Our amazing contributors** who make this possible
|
||||
Core Contributors: [Yichuan Wang](https://yichuan-w.github.io/) & [Zhifei Li](https://github.com/andylizf).
|
||||
|
||||
---
|
||||
Active Contributors: [Gabriel Dehan](https://github.com/gabriel-dehan)
|
||||
|
||||
|
||||
We welcome more contributors! Feel free to open issues or submit PRs.
|
||||
|
||||
This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.edu/).
|
||||
|
||||
## Star History
|
||||
|
||||
[](https://www.star-history.com/#yichuan-w/LEANN&Date)
|
||||
<p align="center">
|
||||
<strong>⭐ Star us on GitHub if Leann is useful for your research or applications!</strong>
|
||||
</p>
|
||||
@@ -476,4 +765,3 @@ MIT License - see [LICENSE](LICENSE) for details.
|
||||
<p align="center">
|
||||
Made with ❤️ by the Leann team
|
||||
</p>
|
||||
|
||||
|
||||
342
apps/base_rag_example.py
Normal file
342
apps/base_rag_example.py
Normal file
@@ -0,0 +1,342 @@
|
||||
"""
|
||||
Base class for unified RAG examples interface.
|
||||
Provides common parameters and functionality for all RAG examples.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import dotenv
|
||||
from leann.api import LeannBuilder, LeannChat
|
||||
from leann.registry import register_project_directory
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
|
||||
class BaseRAGExample(ABC):
|
||||
"""Base class for all RAG examples with unified interface."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
default_index_name: str,
|
||||
):
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.default_index_name = default_index_name
|
||||
self.parser = self._create_parser()
|
||||
|
||||
def _create_parser(self) -> argparse.ArgumentParser:
|
||||
"""Create argument parser with common parameters."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description=self.description, formatter_class=argparse.RawDescriptionHelpFormatter
|
||||
)
|
||||
|
||||
# Core parameters (all examples share these)
|
||||
core_group = parser.add_argument_group("Core Parameters")
|
||||
core_group.add_argument(
|
||||
"--index-dir",
|
||||
type=str,
|
||||
default=f"./{self.default_index_name}",
|
||||
help=f"Directory to store the index (default: ./{self.default_index_name})",
|
||||
)
|
||||
core_group.add_argument(
|
||||
"--query",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Query to run (if not provided, will run in interactive mode)",
|
||||
)
|
||||
# Allow subclasses to override default max_items
|
||||
max_items_default = getattr(self, "max_items_default", -1)
|
||||
core_group.add_argument(
|
||||
"--max-items",
|
||||
type=int,
|
||||
default=max_items_default,
|
||||
help="Maximum number of items to process -1 for all, means index all documents, and you should set it to a reasonable number if you have a large dataset and try at the first time)",
|
||||
)
|
||||
core_group.add_argument(
|
||||
"--force-rebuild", action="store_true", help="Force rebuild index even if it exists"
|
||||
)
|
||||
|
||||
# Embedding parameters
|
||||
embedding_group = parser.add_argument_group("Embedding Parameters")
|
||||
# Allow subclasses to override default embedding_model
|
||||
embedding_model_default = getattr(self, "embedding_model_default", "facebook/contriever")
|
||||
embedding_group.add_argument(
|
||||
"--embedding-model",
|
||||
type=str,
|
||||
default=embedding_model_default,
|
||||
help=f"Embedding model to use (default: {embedding_model_default}), we provide facebook/contriever, text-embedding-3-small,mlx-community/Qwen3-Embedding-0.6B-8bit or nomic-embed-text",
|
||||
)
|
||||
embedding_group.add_argument(
|
||||
"--embedding-mode",
|
||||
type=str,
|
||||
default="sentence-transformers",
|
||||
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
||||
help="Embedding backend mode (default: sentence-transformers), we provide sentence-transformers, openai, mlx, or ollama",
|
||||
)
|
||||
|
||||
# LLM parameters
|
||||
llm_group = parser.add_argument_group("LLM Parameters")
|
||||
llm_group.add_argument(
|
||||
"--llm",
|
||||
type=str,
|
||||
default="openai",
|
||||
choices=["openai", "ollama", "hf", "simulated"],
|
||||
help="LLM backend: openai, ollama, or hf (default: openai)",
|
||||
)
|
||||
llm_group.add_argument(
|
||||
"--llm-model",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Model name (default: gpt-4o) e.g., gpt-4o-mini, llama3.2:1b, Qwen/Qwen2.5-1.5B-Instruct",
|
||||
)
|
||||
llm_group.add_argument(
|
||||
"--llm-host",
|
||||
type=str,
|
||||
default="http://localhost:11434",
|
||||
help="Host for Ollama API (default: http://localhost:11434)",
|
||||
)
|
||||
llm_group.add_argument(
|
||||
"--thinking-budget",
|
||||
type=str,
|
||||
choices=["low", "medium", "high"],
|
||||
default=None,
|
||||
help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.",
|
||||
)
|
||||
|
||||
# AST Chunking parameters
|
||||
ast_group = parser.add_argument_group("AST Chunking Parameters")
|
||||
ast_group.add_argument(
|
||||
"--use-ast-chunking",
|
||||
action="store_true",
|
||||
help="Enable AST-aware chunking for code files (requires astchunk)",
|
||||
)
|
||||
ast_group.add_argument(
|
||||
"--ast-chunk-size",
|
||||
type=int,
|
||||
default=512,
|
||||
help="Maximum characters per AST chunk (default: 512)",
|
||||
)
|
||||
ast_group.add_argument(
|
||||
"--ast-chunk-overlap",
|
||||
type=int,
|
||||
default=64,
|
||||
help="Overlap between AST chunks (default: 64)",
|
||||
)
|
||||
ast_group.add_argument(
|
||||
"--code-file-extensions",
|
||||
nargs="+",
|
||||
default=None,
|
||||
help="Additional code file extensions to process with AST chunking (e.g., .py .java .cs .ts)",
|
||||
)
|
||||
ast_group.add_argument(
|
||||
"--ast-fallback-traditional",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Fall back to traditional chunking if AST chunking fails (default: True)",
|
||||
)
|
||||
|
||||
# Search parameters
|
||||
search_group = parser.add_argument_group("Search Parameters")
|
||||
search_group.add_argument(
|
||||
"--top-k", type=int, default=20, help="Number of results to retrieve (default: 20)"
|
||||
)
|
||||
search_group.add_argument(
|
||||
"--search-complexity",
|
||||
type=int,
|
||||
default=32,
|
||||
help="Search complexity for graph traversal (default: 64)",
|
||||
)
|
||||
|
||||
# Index building parameters
|
||||
index_group = parser.add_argument_group("Index Building Parameters")
|
||||
index_group.add_argument(
|
||||
"--backend-name",
|
||||
type=str,
|
||||
default="hnsw",
|
||||
choices=["hnsw", "diskann"],
|
||||
help="Backend to use for index (default: hnsw)",
|
||||
)
|
||||
index_group.add_argument(
|
||||
"--graph-degree",
|
||||
type=int,
|
||||
default=32,
|
||||
help="Graph degree for index construction (default: 32)",
|
||||
)
|
||||
index_group.add_argument(
|
||||
"--build-complexity",
|
||||
type=int,
|
||||
default=64,
|
||||
help="Build complexity for index construction (default: 64)",
|
||||
)
|
||||
index_group.add_argument(
|
||||
"--no-compact",
|
||||
action="store_true",
|
||||
help="Disable compact index storage",
|
||||
)
|
||||
index_group.add_argument(
|
||||
"--no-recompute",
|
||||
action="store_true",
|
||||
help="Disable embedding recomputation",
|
||||
)
|
||||
|
||||
# Add source-specific parameters
|
||||
self._add_specific_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
@abstractmethod
|
||||
def _add_specific_arguments(self, parser: argparse.ArgumentParser):
|
||||
"""Add source-specific arguments. Override in subclasses."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def load_data(self, args) -> list[str]:
|
||||
"""Load data from the source. Returns list of text chunks."""
|
||||
pass
|
||||
|
||||
def get_llm_config(self, args) -> dict[str, Any]:
|
||||
"""Get LLM configuration based on arguments."""
|
||||
config = {"type": args.llm}
|
||||
|
||||
if args.llm == "openai":
|
||||
config["model"] = args.llm_model or "gpt-4o"
|
||||
elif args.llm == "ollama":
|
||||
config["model"] = args.llm_model or "llama3.2:1b"
|
||||
config["host"] = args.llm_host
|
||||
elif args.llm == "hf":
|
||||
config["model"] = args.llm_model or "Qwen/Qwen2.5-1.5B-Instruct"
|
||||
elif args.llm == "simulated":
|
||||
# Simulated LLM doesn't need additional configuration
|
||||
pass
|
||||
|
||||
return config
|
||||
|
||||
async def build_index(self, args, texts: list[str]) -> str:
|
||||
"""Build LEANN index from texts."""
|
||||
index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")
|
||||
|
||||
print(f"\n[Building Index] Creating {self.name} index...")
|
||||
print(f"Total text chunks: {len(texts)}")
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name=args.backend_name,
|
||||
embedding_model=args.embedding_model,
|
||||
embedding_mode=args.embedding_mode,
|
||||
graph_degree=args.graph_degree,
|
||||
complexity=args.build_complexity,
|
||||
is_compact=not args.no_compact,
|
||||
is_recompute=not args.no_recompute,
|
||||
num_threads=1, # Force single-threaded mode
|
||||
)
|
||||
|
||||
# Add texts in batches for better progress tracking
|
||||
batch_size = 1000
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch = texts[i : i + batch_size]
|
||||
for text in batch:
|
||||
builder.add_text(text)
|
||||
print(f"Added {min(i + batch_size, len(texts))}/{len(texts)} texts...")
|
||||
|
||||
print("Building index structure...")
|
||||
builder.build_index(index_path)
|
||||
print(f"Index saved to: {index_path}")
|
||||
|
||||
# Register project directory so leann list can discover this index
|
||||
# The index is saved as args.index_dir/index_name.leann
|
||||
# We want to register the current working directory where the app is run
|
||||
register_project_directory(Path.cwd())
|
||||
|
||||
return index_path
|
||||
|
||||
async def run_interactive_chat(self, args, index_path: str):
|
||||
"""Run interactive chat with the index."""
|
||||
chat = LeannChat(
|
||||
index_path,
|
||||
llm_config=self.get_llm_config(args),
|
||||
system_prompt=f"You are a helpful assistant that answers questions about {self.name} data.",
|
||||
complexity=args.search_complexity,
|
||||
)
|
||||
|
||||
print(f"\n[Interactive Mode] Chat with your {self.name} data!")
|
||||
print("Type 'quit' or 'exit' to stop.\n")
|
||||
|
||||
while True:
|
||||
try:
|
||||
query = input("You: ").strip()
|
||||
if query.lower() in ["quit", "exit", "q"]:
|
||||
print("Goodbye!")
|
||||
break
|
||||
|
||||
if not query:
|
||||
continue
|
||||
|
||||
# Prepare LLM kwargs with thinking budget if specified
|
||||
llm_kwargs = {}
|
||||
if hasattr(args, "thinking_budget") and args.thinking_budget:
|
||||
llm_kwargs["thinking_budget"] = args.thinking_budget
|
||||
|
||||
response = chat.ask(
|
||||
query,
|
||||
top_k=args.top_k,
|
||||
complexity=args.search_complexity,
|
||||
llm_kwargs=llm_kwargs,
|
||||
)
|
||||
print(f"\nAssistant: {response}\n")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\nGoodbye!")
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
|
||||
async def run_single_query(self, args, index_path: str, query: str):
|
||||
"""Run a single query against the index."""
|
||||
chat = LeannChat(
|
||||
index_path,
|
||||
llm_config=self.get_llm_config(args),
|
||||
complexity=args.search_complexity,
|
||||
)
|
||||
|
||||
print(f"\n[Query]: \033[36m{query}\033[0m")
|
||||
|
||||
# Prepare LLM kwargs with thinking budget if specified
|
||||
llm_kwargs = {}
|
||||
if hasattr(args, "thinking_budget") and args.thinking_budget:
|
||||
llm_kwargs["thinking_budget"] = args.thinking_budget
|
||||
|
||||
response = chat.ask(
|
||||
query, top_k=args.top_k, complexity=args.search_complexity, llm_kwargs=llm_kwargs
|
||||
)
|
||||
print(f"\n[Response]: \033[36m{response}\033[0m")
|
||||
|
||||
async def run(self):
|
||||
"""Main entry point for the example."""
|
||||
args = self.parser.parse_args()
|
||||
|
||||
# Check if index exists
|
||||
index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")
|
||||
index_exists = Path(args.index_dir).exists()
|
||||
|
||||
if not index_exists or args.force_rebuild:
|
||||
# Load data and build index
|
||||
print(f"\n{'Rebuilding' if index_exists else 'Building'} index...")
|
||||
texts = await self.load_data(args)
|
||||
|
||||
if not texts:
|
||||
print("No data found to index!")
|
||||
return
|
||||
|
||||
index_path = await self.build_index(args, texts)
|
||||
else:
|
||||
print(f"\nUsing existing index in {args.index_dir}")
|
||||
|
||||
# Run query or interactive mode
|
||||
if args.query:
|
||||
await self.run_single_query(args, index_path, args.query)
|
||||
else:
|
||||
await self.run_interactive_chat(args, index_path)
|
||||
171
apps/browser_rag.py
Normal file
171
apps/browser_rag.py
Normal file
@@ -0,0 +1,171 @@
|
||||
"""
|
||||
Browser History RAG example using the unified interface.
|
||||
Supports Chrome browser history.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from base_rag_example import BaseRAGExample
|
||||
from chunking import create_text_chunks
|
||||
|
||||
from .history_data.history import ChromeHistoryReader
|
||||
|
||||
|
||||
class BrowserRAG(BaseRAGExample):
|
||||
"""RAG example for Chrome browser history."""
|
||||
|
||||
def __init__(self):
|
||||
# Set default values BEFORE calling super().__init__
|
||||
self.embedding_model_default = (
|
||||
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
name="Browser History",
|
||||
description="Process and query Chrome browser history with LEANN",
|
||||
default_index_name="google_history_index",
|
||||
)
|
||||
|
||||
def _add_specific_arguments(self, parser):
|
||||
"""Add browser-specific arguments."""
|
||||
browser_group = parser.add_argument_group("Browser Parameters")
|
||||
browser_group.add_argument(
|
||||
"--chrome-profile",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to Chrome profile directory (auto-detected if not specified)",
|
||||
)
|
||||
browser_group.add_argument(
|
||||
"--auto-find-profiles",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Automatically find all Chrome profiles (default: True)",
|
||||
)
|
||||
browser_group.add_argument(
|
||||
"--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
|
||||
)
|
||||
browser_group.add_argument(
|
||||
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
|
||||
)
|
||||
|
||||
def _get_chrome_base_path(self) -> Path:
|
||||
"""Get the base Chrome profile path based on OS."""
|
||||
if sys.platform == "darwin":
|
||||
return Path.home() / "Library" / "Application Support" / "Google" / "Chrome"
|
||||
elif sys.platform.startswith("linux"):
|
||||
return Path.home() / ".config" / "google-chrome"
|
||||
elif sys.platform == "win32":
|
||||
return Path(os.environ["LOCALAPPDATA"]) / "Google" / "Chrome" / "User Data"
|
||||
else:
|
||||
raise ValueError(f"Unsupported platform: {sys.platform}")
|
||||
|
||||
def _find_chrome_profiles(self) -> list[Path]:
|
||||
"""Auto-detect all Chrome profiles."""
|
||||
base_path = self._get_chrome_base_path()
|
||||
if not base_path.exists():
|
||||
return []
|
||||
|
||||
profiles = []
|
||||
|
||||
# Check Default profile
|
||||
default_profile = base_path / "Default"
|
||||
if default_profile.exists() and (default_profile / "History").exists():
|
||||
profiles.append(default_profile)
|
||||
|
||||
# Check numbered profiles
|
||||
for item in base_path.iterdir():
|
||||
if item.is_dir() and item.name.startswith("Profile "):
|
||||
if (item / "History").exists():
|
||||
profiles.append(item)
|
||||
|
||||
return profiles
|
||||
|
||||
async def load_data(self, args) -> list[str]:
|
||||
"""Load browser history and convert to text chunks."""
|
||||
# Determine Chrome profiles
|
||||
if args.chrome_profile and not args.auto_find_profiles:
|
||||
profile_dirs = [Path(args.chrome_profile)]
|
||||
else:
|
||||
print("Auto-detecting Chrome profiles...")
|
||||
profile_dirs = self._find_chrome_profiles()
|
||||
|
||||
# If specific profile given, filter to just that one
|
||||
if args.chrome_profile:
|
||||
profile_path = Path(args.chrome_profile)
|
||||
profile_dirs = [p for p in profile_dirs if p == profile_path]
|
||||
|
||||
if not profile_dirs:
|
||||
print("No Chrome profiles found!")
|
||||
print("Please specify --chrome-profile manually")
|
||||
return []
|
||||
|
||||
print(f"Found {len(profile_dirs)} Chrome profiles")
|
||||
|
||||
# Create reader
|
||||
reader = ChromeHistoryReader()
|
||||
|
||||
# Process each profile
|
||||
all_documents = []
|
||||
total_processed = 0
|
||||
|
||||
for i, profile_dir in enumerate(profile_dirs):
|
||||
print(f"\nProcessing profile {i + 1}/{len(profile_dirs)}: {profile_dir.name}")
|
||||
|
||||
try:
|
||||
# Apply max_items limit per profile
|
||||
max_per_profile = -1
|
||||
if args.max_items > 0:
|
||||
remaining = args.max_items - total_processed
|
||||
if remaining <= 0:
|
||||
break
|
||||
max_per_profile = remaining
|
||||
|
||||
# Load history
|
||||
documents = reader.load_data(
|
||||
chrome_profile_path=str(profile_dir),
|
||||
max_count=max_per_profile,
|
||||
)
|
||||
|
||||
if documents:
|
||||
all_documents.extend(documents)
|
||||
total_processed += len(documents)
|
||||
print(f"Processed {len(documents)} history entries from this profile")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing {profile_dir}: {e}")
|
||||
continue
|
||||
|
||||
if not all_documents:
|
||||
print("No browser history found to process!")
|
||||
return []
|
||||
|
||||
print(f"\nTotal history entries processed: {len(all_documents)}")
|
||||
|
||||
# Convert to text chunks
|
||||
all_texts = create_text_chunks(
|
||||
all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
||||
)
|
||||
|
||||
return all_texts
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
# Example queries for browser history RAG
|
||||
print("\n🌐 Browser History RAG Example")
|
||||
print("=" * 50)
|
||||
print("\nExample queries you can try:")
|
||||
print("- 'What websites did I visit about machine learning?'")
|
||||
print("- 'Find my search history about programming'")
|
||||
print("- 'What YouTube videos did I watch recently?'")
|
||||
print("- 'Show me websites about travel planning'")
|
||||
print("\nNote: Make sure Chrome is closed before running\n")
|
||||
|
||||
rag = BrowserRAG()
|
||||
asyncio.run(rag.run())
|
||||
44
apps/chunking/__init__.py
Normal file
44
apps/chunking/__init__.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""Unified chunking utilities facade.
|
||||
|
||||
This module re-exports the packaged utilities from `leann.chunking_utils` so
|
||||
that both repo apps (importing `chunking`) and installed wheels share one
|
||||
single implementation. When running from the repo without installation, it
|
||||
adds the `packages/leann-core/src` directory to `sys.path` as a fallback.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
try:
|
||||
from leann.chunking_utils import (
|
||||
CODE_EXTENSIONS,
|
||||
create_ast_chunks,
|
||||
create_text_chunks,
|
||||
create_traditional_chunks,
|
||||
detect_code_files,
|
||||
get_language_from_extension,
|
||||
)
|
||||
except Exception: # pragma: no cover - best-effort fallback for dev environment
|
||||
repo_root = Path(__file__).resolve().parents[2]
|
||||
leann_src = repo_root / "packages" / "leann-core" / "src"
|
||||
if leann_src.exists():
|
||||
sys.path.insert(0, str(leann_src))
|
||||
from leann.chunking_utils import (
|
||||
CODE_EXTENSIONS,
|
||||
create_ast_chunks,
|
||||
create_text_chunks,
|
||||
create_traditional_chunks,
|
||||
detect_code_files,
|
||||
get_language_from_extension,
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
||||
__all__ = [
|
||||
"CODE_EXTENSIONS",
|
||||
"create_ast_chunks",
|
||||
"create_text_chunks",
|
||||
"create_traditional_chunks",
|
||||
"detect_code_files",
|
||||
"get_language_from_extension",
|
||||
]
|
||||
211
apps/code_rag.py
Normal file
211
apps/code_rag.py
Normal file
@@ -0,0 +1,211 @@
|
||||
"""
|
||||
Code RAG example using AST-aware chunking for optimal code understanding.
|
||||
Specialized for code repositories with automatic language detection and
|
||||
optimized chunking parameters.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from base_rag_example import BaseRAGExample
|
||||
from chunking import CODE_EXTENSIONS, create_text_chunks
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
|
||||
|
||||
class CodeRAG(BaseRAGExample):
|
||||
"""Specialized RAG example for code repositories with AST-aware chunking."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="Code",
|
||||
description="Process and query code repositories with AST-aware chunking",
|
||||
default_index_name="code_index",
|
||||
)
|
||||
# Override defaults for code-specific usage
|
||||
self.embedding_model_default = "facebook/contriever" # Good for code
|
||||
self.max_items_default = -1 # Process all code files by default
|
||||
|
||||
def _add_specific_arguments(self, parser):
|
||||
"""Add code-specific arguments."""
|
||||
code_group = parser.add_argument_group("Code Repository Parameters")
|
||||
|
||||
code_group.add_argument(
|
||||
"--repo-dir",
|
||||
type=str,
|
||||
default=".",
|
||||
help="Code repository directory to index (default: current directory)",
|
||||
)
|
||||
code_group.add_argument(
|
||||
"--include-extensions",
|
||||
nargs="+",
|
||||
default=list(CODE_EXTENSIONS.keys()),
|
||||
help="File extensions to include (default: supported code extensions)",
|
||||
)
|
||||
code_group.add_argument(
|
||||
"--exclude-dirs",
|
||||
nargs="+",
|
||||
default=[
|
||||
".git",
|
||||
"__pycache__",
|
||||
"node_modules",
|
||||
"venv",
|
||||
".venv",
|
||||
"build",
|
||||
"dist",
|
||||
"target",
|
||||
],
|
||||
help="Directories to exclude from indexing",
|
||||
)
|
||||
code_group.add_argument(
|
||||
"--max-file-size",
|
||||
type=int,
|
||||
default=1000000, # 1MB
|
||||
help="Maximum file size in bytes to process (default: 1MB)",
|
||||
)
|
||||
code_group.add_argument(
|
||||
"--include-comments",
|
||||
action="store_true",
|
||||
help="Include comments in chunking (useful for documentation)",
|
||||
)
|
||||
code_group.add_argument(
|
||||
"--preserve-imports",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Try to preserve import statements in chunks (default: True)",
|
||||
)
|
||||
|
||||
async def load_data(self, args) -> list[str]:
|
||||
"""Load code files and convert to AST-aware chunks."""
|
||||
print(f"🔍 Scanning code repository: {args.repo_dir}")
|
||||
print(f"📁 Including extensions: {args.include_extensions}")
|
||||
print(f"🚫 Excluding directories: {args.exclude_dirs}")
|
||||
|
||||
# Check if repository directory exists
|
||||
repo_path = Path(args.repo_dir)
|
||||
if not repo_path.exists():
|
||||
raise ValueError(f"Repository directory not found: {args.repo_dir}")
|
||||
|
||||
# Load code files with filtering
|
||||
reader_kwargs = {
|
||||
"recursive": True,
|
||||
"encoding": "utf-8",
|
||||
"required_exts": args.include_extensions,
|
||||
"exclude_hidden": True,
|
||||
}
|
||||
|
||||
# Create exclusion filter
|
||||
def file_filter(file_path: str) -> bool:
|
||||
"""Filter out unwanted files and directories."""
|
||||
path = Path(file_path)
|
||||
|
||||
# Check file size
|
||||
try:
|
||||
if path.stat().st_size > args.max_file_size:
|
||||
print(f"⚠️ Skipping large file: {path.name} ({path.stat().st_size} bytes)")
|
||||
return False
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
# Check if in excluded directory
|
||||
for exclude_dir in args.exclude_dirs:
|
||||
if exclude_dir in path.parts:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
try:
|
||||
# Load documents with file filtering
|
||||
documents = SimpleDirectoryReader(
|
||||
args.repo_dir,
|
||||
file_extractor=None, # Use default extractors
|
||||
**reader_kwargs,
|
||||
).load_data(show_progress=True)
|
||||
|
||||
# Apply custom filtering
|
||||
filtered_docs = []
|
||||
for doc in documents:
|
||||
file_path = doc.metadata.get("file_path", "")
|
||||
if file_filter(file_path):
|
||||
filtered_docs.append(doc)
|
||||
|
||||
documents = filtered_docs
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error loading code files: {e}")
|
||||
return []
|
||||
|
||||
if not documents:
|
||||
print(
|
||||
f"❌ No code files found in {args.repo_dir} with extensions {args.include_extensions}"
|
||||
)
|
||||
return []
|
||||
|
||||
print(f"✅ Loaded {len(documents)} code files")
|
||||
|
||||
# Show breakdown by language/extension
|
||||
ext_counts = {}
|
||||
for doc in documents:
|
||||
file_path = doc.metadata.get("file_path", "")
|
||||
if file_path:
|
||||
ext = Path(file_path).suffix.lower()
|
||||
ext_counts[ext] = ext_counts.get(ext, 0) + 1
|
||||
|
||||
print("📊 Files by extension:")
|
||||
for ext, count in sorted(ext_counts.items()):
|
||||
print(f" {ext}: {count} files")
|
||||
|
||||
# Use AST-aware chunking by default for code
|
||||
print(
|
||||
f"🧠 Using AST-aware chunking (chunk_size: {args.ast_chunk_size}, overlap: {args.ast_chunk_overlap})"
|
||||
)
|
||||
|
||||
all_texts = create_text_chunks(
|
||||
documents,
|
||||
chunk_size=256, # Fallback for non-code files
|
||||
chunk_overlap=64,
|
||||
use_ast_chunking=True, # Always use AST for code RAG
|
||||
ast_chunk_size=args.ast_chunk_size,
|
||||
ast_chunk_overlap=args.ast_chunk_overlap,
|
||||
code_file_extensions=args.include_extensions,
|
||||
ast_fallback_traditional=True,
|
||||
)
|
||||
|
||||
# Apply max_items limit if specified
|
||||
if args.max_items > 0 and len(all_texts) > args.max_items:
|
||||
print(f"⏳ Limiting to {args.max_items} chunks (from {len(all_texts)})")
|
||||
all_texts = all_texts[: args.max_items]
|
||||
|
||||
print(f"✅ Generated {len(all_texts)} code chunks")
|
||||
return all_texts
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
# Example queries for code RAG
|
||||
print("\n💻 Code RAG Example")
|
||||
print("=" * 50)
|
||||
print("\nExample queries you can try:")
|
||||
print("- 'How does the embedding computation work?'")
|
||||
print("- 'What are the main classes in this codebase?'")
|
||||
print("- 'Show me the search implementation'")
|
||||
print("- 'How is error handling implemented?'")
|
||||
print("- 'What design patterns are used?'")
|
||||
print("- 'Explain the chunking logic'")
|
||||
print("\n🚀 Features:")
|
||||
print("- ✅ AST-aware chunking preserves code structure")
|
||||
print("- ✅ Automatic language detection")
|
||||
print("- ✅ Smart filtering of large files and common excludes")
|
||||
print("- ✅ Optimized for code understanding")
|
||||
print("\nUsage examples:")
|
||||
print(" python -m apps.code_rag --repo-dir ./my_project")
|
||||
print(
|
||||
" python -m apps.code_rag --include-extensions .py .js --query 'How does authentication work?'"
|
||||
)
|
||||
print("\nOr run without --query for interactive mode\n")
|
||||
|
||||
rag = CodeRAG()
|
||||
asyncio.run(rag.run())
|
||||
131
apps/document_rag.py
Normal file
131
apps/document_rag.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""
|
||||
Document RAG example using the unified interface.
|
||||
Supports PDF, TXT, MD, and other document formats.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from base_rag_example import BaseRAGExample
|
||||
from chunking import create_text_chunks
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
|
||||
|
||||
class DocumentRAG(BaseRAGExample):
|
||||
"""RAG example for document processing (PDF, TXT, MD, etc.)."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="Document",
|
||||
description="Process and query documents (PDF, TXT, MD, etc.) with LEANN",
|
||||
default_index_name="test_doc_files",
|
||||
)
|
||||
|
||||
def _add_specific_arguments(self, parser):
|
||||
"""Add document-specific arguments."""
|
||||
doc_group = parser.add_argument_group("Document Parameters")
|
||||
doc_group.add_argument(
|
||||
"--data-dir",
|
||||
type=str,
|
||||
default="data",
|
||||
help="Directory containing documents to index (default: data)",
|
||||
)
|
||||
doc_group.add_argument(
|
||||
"--file-types",
|
||||
nargs="+",
|
||||
default=None,
|
||||
help="Filter by file types (e.g., .pdf .txt .md). If not specified, all supported types are processed",
|
||||
)
|
||||
doc_group.add_argument(
|
||||
"--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
|
||||
)
|
||||
doc_group.add_argument(
|
||||
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
|
||||
)
|
||||
doc_group.add_argument(
|
||||
"--enable-code-chunking",
|
||||
action="store_true",
|
||||
help="Enable AST-aware chunking for code files in the data directory",
|
||||
)
|
||||
|
||||
async def load_data(self, args) -> list[str]:
|
||||
"""Load documents and convert to text chunks."""
|
||||
print(f"Loading documents from: {args.data_dir}")
|
||||
if args.file_types:
|
||||
print(f"Filtering by file types: {args.file_types}")
|
||||
else:
|
||||
print("Processing all supported file types")
|
||||
|
||||
# Check if data directory exists
|
||||
data_path = Path(args.data_dir)
|
||||
if not data_path.exists():
|
||||
raise ValueError(f"Data directory not found: {args.data_dir}")
|
||||
|
||||
# Load documents
|
||||
reader_kwargs = {
|
||||
"recursive": True,
|
||||
"encoding": "utf-8",
|
||||
}
|
||||
if args.file_types:
|
||||
reader_kwargs["required_exts"] = args.file_types
|
||||
|
||||
documents = SimpleDirectoryReader(args.data_dir, **reader_kwargs).load_data(
|
||||
show_progress=True
|
||||
)
|
||||
|
||||
if not documents:
|
||||
print(f"No documents found in {args.data_dir} with extensions {args.file_types}")
|
||||
return []
|
||||
|
||||
print(f"Loaded {len(documents)} documents")
|
||||
|
||||
# Determine chunking strategy
|
||||
use_ast = args.enable_code_chunking or getattr(args, "use_ast_chunking", False)
|
||||
|
||||
if use_ast:
|
||||
print("Using AST-aware chunking for code files")
|
||||
|
||||
# Convert to text chunks with optional AST support
|
||||
all_texts = create_text_chunks(
|
||||
documents,
|
||||
chunk_size=args.chunk_size,
|
||||
chunk_overlap=args.chunk_overlap,
|
||||
use_ast_chunking=use_ast,
|
||||
ast_chunk_size=getattr(args, "ast_chunk_size", 512),
|
||||
ast_chunk_overlap=getattr(args, "ast_chunk_overlap", 64),
|
||||
code_file_extensions=getattr(args, "code_file_extensions", None),
|
||||
ast_fallback_traditional=getattr(args, "ast_fallback_traditional", True),
|
||||
)
|
||||
|
||||
# Apply max_items limit if specified
|
||||
if args.max_items > 0 and len(all_texts) > args.max_items:
|
||||
print(f"Limiting to {args.max_items} chunks (from {len(all_texts)})")
|
||||
all_texts = all_texts[: args.max_items]
|
||||
|
||||
return all_texts
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
# Example queries for document RAG
|
||||
print("\n📄 Document RAG Example")
|
||||
print("=" * 50)
|
||||
print("\nExample queries you can try:")
|
||||
print("- 'What are the main techniques LEANN uses?'")
|
||||
print("- 'What is the technique DLPM?'")
|
||||
print("- 'Who does Elizabeth Bennet marry?'")
|
||||
print(
|
||||
"- 'What is the problem of developing pan gu model Huawei meets? (盘古大模型开发中遇到什么问题?)'"
|
||||
)
|
||||
print("\n🚀 NEW: Code-aware chunking available!")
|
||||
print("- Use --enable-code-chunking to enable AST-aware chunking for code files")
|
||||
print("- Supports Python, Java, C#, TypeScript files")
|
||||
print("- Better semantic understanding of code structure")
|
||||
print("\nOr run without --query for interactive mode\n")
|
||||
|
||||
rag = DocumentRAG()
|
||||
asyncio.run(rag.run())
|
||||
167
apps/email_data/LEANN_email_reader.py
Normal file
167
apps/email_data/LEANN_email_reader.py
Normal file
@@ -0,0 +1,167 @@
|
||||
import email
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from llama_index.core import Document
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
|
||||
|
||||
def find_all_messages_directories(root: str | None = None) -> list[Path]:
|
||||
"""
|
||||
Recursively find all 'Messages' directories under the given root.
|
||||
Returns a list of Path objects.
|
||||
"""
|
||||
if root is None:
|
||||
# Auto-detect user's mail path
|
||||
home_dir = os.path.expanduser("~")
|
||||
root = os.path.join(home_dir, "Library", "Mail")
|
||||
|
||||
messages_dirs = []
|
||||
for dirpath, _dirnames, _filenames in os.walk(root):
|
||||
if os.path.basename(dirpath) == "Messages":
|
||||
messages_dirs.append(Path(dirpath))
|
||||
return messages_dirs
|
||||
|
||||
|
||||
class EmlxReader(BaseReader):
|
||||
"""
|
||||
Apple Mail .emlx file reader with embedded metadata.
|
||||
|
||||
Reads individual .emlx files from Apple Mail's storage format.
|
||||
"""
|
||||
|
||||
def __init__(self, include_html: bool = False) -> None:
|
||||
"""
|
||||
Initialize.
|
||||
|
||||
Args:
|
||||
include_html: Whether to include HTML content in the email body (default: False)
|
||||
"""
|
||||
self.include_html = include_html
|
||||
|
||||
def load_data(self, input_dir: str, **load_kwargs: Any) -> list[Document]:
|
||||
"""
|
||||
Load data from the input directory containing .emlx files.
|
||||
|
||||
Args:
|
||||
input_dir: Directory containing .emlx files
|
||||
**load_kwargs:
|
||||
max_count (int): Maximum amount of messages to read.
|
||||
"""
|
||||
docs: list[Document] = []
|
||||
max_count = load_kwargs.get("max_count", 1000)
|
||||
count = 0
|
||||
total_files = 0
|
||||
successful_files = 0
|
||||
failed_files = 0
|
||||
|
||||
print(f"Starting to process directory: {input_dir}")
|
||||
|
||||
# Walk through the directory recursively
|
||||
for dirpath, dirnames, filenames in os.walk(input_dir):
|
||||
# Skip hidden directories
|
||||
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
|
||||
|
||||
for filename in filenames:
|
||||
# Check if we've reached the max count (skip if max_count == -1)
|
||||
if max_count > 0 and count >= max_count:
|
||||
break
|
||||
|
||||
if filename.endswith(".emlx"):
|
||||
total_files += 1
|
||||
filepath = os.path.join(dirpath, filename)
|
||||
try:
|
||||
# Read the .emlx file
|
||||
with open(filepath, encoding="utf-8", errors="ignore") as f:
|
||||
content = f.read()
|
||||
|
||||
# .emlx files have a length prefix followed by the email content
|
||||
# The first line contains the length, followed by the email
|
||||
lines = content.split("\n", 1)
|
||||
if len(lines) >= 2:
|
||||
email_content = lines[1]
|
||||
|
||||
# Parse the email using Python's email module
|
||||
try:
|
||||
msg = email.message_from_string(email_content)
|
||||
|
||||
# Extract email metadata
|
||||
subject = msg.get("Subject", "No Subject")
|
||||
from_addr = msg.get("From", "Unknown")
|
||||
to_addr = msg.get("To", "Unknown")
|
||||
date = msg.get("Date", "Unknown")
|
||||
|
||||
# Extract email body
|
||||
body = ""
|
||||
if msg.is_multipart():
|
||||
for part in msg.walk():
|
||||
if (
|
||||
part.get_content_type() == "text/plain"
|
||||
or part.get_content_type() == "text/html"
|
||||
):
|
||||
if (
|
||||
part.get_content_type() == "text/html"
|
||||
and not self.include_html
|
||||
):
|
||||
continue
|
||||
try:
|
||||
payload = part.get_payload(decode=True)
|
||||
if payload:
|
||||
body += payload.decode("utf-8", errors="ignore")
|
||||
except Exception as e:
|
||||
print(f"Error decoding payload: {e}")
|
||||
continue
|
||||
else:
|
||||
try:
|
||||
payload = msg.get_payload(decode=True)
|
||||
if payload:
|
||||
body = payload.decode("utf-8", errors="ignore")
|
||||
except Exception as e:
|
||||
print(f"Error decoding single part payload: {e}")
|
||||
body = ""
|
||||
|
||||
# Only create document if we have some content
|
||||
if body.strip() or subject != "No Subject":
|
||||
# Create document content with metadata embedded in text
|
||||
doc_content = f"""
|
||||
[File]: {filename}
|
||||
[From]: {from_addr}
|
||||
[To]: {to_addr}
|
||||
[Subject]: {subject}
|
||||
[Date]: {date}
|
||||
[EMAIL BODY Start]:
|
||||
{body}
|
||||
"""
|
||||
|
||||
# No separate metadata - everything is in the text
|
||||
doc = Document(text=doc_content, metadata={})
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
successful_files += 1
|
||||
|
||||
# Print first few successful files for debugging
|
||||
if successful_files <= 3:
|
||||
print(
|
||||
f"Successfully loaded: {filename} - Subject: {subject[:50]}..."
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
failed_files += 1
|
||||
if failed_files <= 5: # Only print first few errors
|
||||
print(f"Error parsing email from {filepath}: {e}")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
failed_files += 1
|
||||
if failed_files <= 5: # Only print first few errors
|
||||
print(f"Error reading file {filepath}: {e}")
|
||||
continue
|
||||
|
||||
print("Processing summary:")
|
||||
print(f" Total .emlx files found: {total_files}")
|
||||
print(f" Successfully loaded: {successful_files}")
|
||||
print(f" Failed to load: {failed_files}")
|
||||
print(f" Final documents: {len(docs)}")
|
||||
|
||||
return docs
|
||||
@@ -7,9 +7,9 @@ Contains simple parser for mbox files.
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
from fsspec import AbstractFileSystem
|
||||
from typing import Any
|
||||
|
||||
from fsspec import AbstractFileSystem
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
from llama_index.core.schema import Document
|
||||
|
||||
@@ -27,11 +27,7 @@ class MboxReader(BaseReader):
|
||||
"""
|
||||
|
||||
DEFAULT_MESSAGE_FORMAT: str = (
|
||||
"Date: {_date}\n"
|
||||
"From: {_from}\n"
|
||||
"To: {_to}\n"
|
||||
"Subject: {_subject}\n"
|
||||
"Content: {_content}"
|
||||
"Date: {_date}\nFrom: {_from}\nTo: {_to}\nSubject: {_subject}\nContent: {_content}"
|
||||
)
|
||||
|
||||
def __init__(
|
||||
@@ -45,9 +41,7 @@ class MboxReader(BaseReader):
|
||||
try:
|
||||
from bs4 import BeautifulSoup # noqa
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"`beautifulsoup4` package not found: `pip install beautifulsoup4`"
|
||||
)
|
||||
raise ImportError("`beautifulsoup4` package not found: `pip install beautifulsoup4`")
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
self.max_count = max_count
|
||||
@@ -56,9 +50,9 @@ class MboxReader(BaseReader):
|
||||
def load_data(
|
||||
self,
|
||||
file: Path,
|
||||
extra_info: Optional[Dict] = None,
|
||||
fs: Optional[AbstractFileSystem] = None,
|
||||
) -> List[Document]:
|
||||
extra_info: dict | None = None,
|
||||
fs: AbstractFileSystem | None = None,
|
||||
) -> list[Document]:
|
||||
"""Parse file into string."""
|
||||
# Import required libraries
|
||||
import mailbox
|
||||
@@ -74,7 +68,7 @@ class MboxReader(BaseReader):
|
||||
)
|
||||
|
||||
i = 0
|
||||
results: List[str] = []
|
||||
results: list[str] = []
|
||||
# Load file using mailbox
|
||||
bytes_parser = BytesParser(policy=default).parse
|
||||
mbox = mailbox.mbox(file, factory=bytes_parser) # type: ignore
|
||||
@@ -124,7 +118,7 @@ class MboxReader(BaseReader):
|
||||
class EmlxMboxReader(MboxReader):
|
||||
"""
|
||||
EmlxMboxReader - Modified MboxReader that handles directories of .emlx files.
|
||||
|
||||
|
||||
Extends MboxReader to work with Apple Mail's .emlx format by:
|
||||
1. Reading .emlx files from a directory
|
||||
2. Converting them to mbox format in memory
|
||||
@@ -134,13 +128,13 @@ class EmlxMboxReader(MboxReader):
|
||||
def load_data(
|
||||
self,
|
||||
directory: Path,
|
||||
extra_info: Optional[Dict] = None,
|
||||
fs: Optional[AbstractFileSystem] = None,
|
||||
) -> List[Document]:
|
||||
extra_info: dict | None = None,
|
||||
fs: AbstractFileSystem | None = None,
|
||||
) -> list[Document]:
|
||||
"""Parse .emlx files from directory into strings using MboxReader logic."""
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
import tempfile
|
||||
|
||||
if fs:
|
||||
logger.warning(
|
||||
"fs was specified but EmlxMboxReader doesn't support loading "
|
||||
@@ -150,37 +144,37 @@ class EmlxMboxReader(MboxReader):
|
||||
# Find all .emlx files in the directory
|
||||
emlx_files = list(directory.glob("*.emlx"))
|
||||
logger.info(f"Found {len(emlx_files)} .emlx files in {directory}")
|
||||
|
||||
|
||||
if not emlx_files:
|
||||
logger.warning(f"No .emlx files found in {directory}")
|
||||
return []
|
||||
|
||||
# Create a temporary mbox file
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.mbox', delete=False) as temp_mbox:
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".mbox", delete=False) as temp_mbox:
|
||||
temp_mbox_path = temp_mbox.name
|
||||
|
||||
|
||||
# Convert .emlx files to mbox format
|
||||
for emlx_file in emlx_files:
|
||||
try:
|
||||
# Read the .emlx file
|
||||
with open(emlx_file, 'r', encoding='utf-8', errors='ignore') as f:
|
||||
with open(emlx_file, encoding="utf-8", errors="ignore") as f:
|
||||
content = f.read()
|
||||
|
||||
|
||||
# .emlx format: first line is length, rest is email content
|
||||
lines = content.split('\n', 1)
|
||||
lines = content.split("\n", 1)
|
||||
if len(lines) >= 2:
|
||||
email_content = lines[1] # Skip the length line
|
||||
|
||||
|
||||
# Write to mbox format (each message starts with "From " and ends with blank line)
|
||||
temp_mbox.write(f"From {emlx_file.name} {email_content}\n\n")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to process {emlx_file}: {e}")
|
||||
continue
|
||||
|
||||
|
||||
# Close the temporary file so MboxReader can read it
|
||||
temp_mbox.close()
|
||||
|
||||
|
||||
try:
|
||||
# Use the parent MboxReader's logic to parse the mbox file
|
||||
return super().load_data(Path(temp_mbox_path), extra_info, fs)
|
||||
@@ -188,5 +182,5 @@ class EmlxMboxReader(MboxReader):
|
||||
# Clean up temporary file
|
||||
try:
|
||||
os.unlink(temp_mbox_path)
|
||||
except:
|
||||
pass
|
||||
except OSError:
|
||||
pass
|
||||
157
apps/email_rag.py
Normal file
157
apps/email_rag.py
Normal file
@@ -0,0 +1,157 @@
|
||||
"""
|
||||
Email RAG example using the unified interface.
|
||||
Supports Apple Mail on macOS.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from base_rag_example import BaseRAGExample
|
||||
from chunking import create_text_chunks
|
||||
|
||||
from .email_data.LEANN_email_reader import EmlxReader
|
||||
|
||||
|
||||
class EmailRAG(BaseRAGExample):
|
||||
"""RAG example for Apple Mail processing."""
|
||||
|
||||
def __init__(self):
|
||||
# Set default values BEFORE calling super().__init__
|
||||
self.max_items_default = -1 # Process all emails by default
|
||||
self.embedding_model_default = (
|
||||
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
name="Email",
|
||||
description="Process and query Apple Mail emails with LEANN",
|
||||
default_index_name="mail_index",
|
||||
)
|
||||
|
||||
def _add_specific_arguments(self, parser):
|
||||
"""Add email-specific arguments."""
|
||||
email_group = parser.add_argument_group("Email Parameters")
|
||||
email_group.add_argument(
|
||||
"--mail-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to Apple Mail directory (auto-detected if not specified)",
|
||||
)
|
||||
email_group.add_argument(
|
||||
"--include-html", action="store_true", help="Include HTML content in email processing"
|
||||
)
|
||||
email_group.add_argument(
|
||||
"--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
|
||||
)
|
||||
email_group.add_argument(
|
||||
"--chunk-overlap", type=int, default=25, help="Text chunk overlap (default: 25)"
|
||||
)
|
||||
|
||||
def _find_mail_directories(self) -> list[Path]:
|
||||
"""Auto-detect all Apple Mail directories."""
|
||||
mail_base = Path.home() / "Library" / "Mail"
|
||||
if not mail_base.exists():
|
||||
return []
|
||||
|
||||
# Find all Messages directories
|
||||
messages_dirs = []
|
||||
for item in mail_base.rglob("Messages"):
|
||||
if item.is_dir():
|
||||
messages_dirs.append(item)
|
||||
|
||||
return messages_dirs
|
||||
|
||||
async def load_data(self, args) -> list[str]:
|
||||
"""Load emails and convert to text chunks."""
|
||||
# Determine mail directories
|
||||
if args.mail_path:
|
||||
messages_dirs = [Path(args.mail_path)]
|
||||
else:
|
||||
print("Auto-detecting Apple Mail directories...")
|
||||
messages_dirs = self._find_mail_directories()
|
||||
|
||||
if not messages_dirs:
|
||||
print("No Apple Mail directories found!")
|
||||
print("Please specify --mail-path manually")
|
||||
return []
|
||||
|
||||
print(f"Found {len(messages_dirs)} mail directories")
|
||||
|
||||
# Create reader
|
||||
reader = EmlxReader(include_html=args.include_html)
|
||||
|
||||
# Process each directory
|
||||
all_documents = []
|
||||
total_processed = 0
|
||||
|
||||
for i, messages_dir in enumerate(messages_dirs):
|
||||
print(f"\nProcessing directory {i + 1}/{len(messages_dirs)}: {messages_dir}")
|
||||
|
||||
try:
|
||||
# Count emlx files
|
||||
emlx_files = list(messages_dir.glob("*.emlx"))
|
||||
print(f"Found {len(emlx_files)} email files")
|
||||
|
||||
# Apply max_items limit per directory
|
||||
max_per_dir = -1 # Default to process all
|
||||
if args.max_items > 0:
|
||||
remaining = args.max_items - total_processed
|
||||
if remaining <= 0:
|
||||
break
|
||||
max_per_dir = remaining
|
||||
# If args.max_items == -1, max_per_dir stays -1 (process all)
|
||||
|
||||
# Load emails - fix the parameter passing
|
||||
documents = reader.load_data(
|
||||
input_dir=str(messages_dir),
|
||||
max_count=max_per_dir,
|
||||
)
|
||||
|
||||
if documents:
|
||||
all_documents.extend(documents)
|
||||
total_processed += len(documents)
|
||||
print(f"Processed {len(documents)} emails from this directory")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing {messages_dir}: {e}")
|
||||
continue
|
||||
|
||||
if not all_documents:
|
||||
print("No emails found to process!")
|
||||
return []
|
||||
|
||||
print(f"\nTotal emails processed: {len(all_documents)}")
|
||||
print("now starting to split into text chunks ... take some time")
|
||||
|
||||
# Convert to text chunks
|
||||
# Email reader uses chunk_overlap=25 as in original
|
||||
all_texts = create_text_chunks(
|
||||
all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
||||
)
|
||||
|
||||
return all_texts
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
# Check platform
|
||||
if sys.platform != "darwin":
|
||||
print("\n⚠️ Warning: This example is designed for macOS (Apple Mail)")
|
||||
print(" Windows/Linux support coming soon!\n")
|
||||
|
||||
# Example queries for email RAG
|
||||
print("\n📧 Email RAG Example")
|
||||
print("=" * 50)
|
||||
print("\nExample queries you can try:")
|
||||
print("- 'What did my boss say about deadlines?'")
|
||||
print("- 'Find emails about travel expenses'")
|
||||
print("- 'Show me emails from last month about the project'")
|
||||
print("- 'What food did I order from DoorDash?'")
|
||||
print("\nNote: You may need to grant Full Disk Access to your terminal\n")
|
||||
|
||||
rag = EmailRAG()
|
||||
asyncio.run(rag.run())
|
||||
@@ -1,3 +1,3 @@
|
||||
from .history import ChromeHistoryReader
|
||||
|
||||
__all__ = ['ChromeHistoryReader']
|
||||
__all__ = ["ChromeHistoryReader"]
|
||||
@@ -1,122 +1,126 @@
|
||||
import sqlite3
|
||||
import os
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
from typing import List, Any
|
||||
from typing import Any
|
||||
|
||||
from llama_index.core import Document
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
|
||||
|
||||
class ChromeHistoryReader(BaseReader):
|
||||
"""
|
||||
Chrome browser history reader that extracts browsing data from SQLite database.
|
||||
|
||||
|
||||
Reads Chrome history from the default Chrome profile location and creates documents
|
||||
with embedded metadata similar to the email reader structure.
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize."""
|
||||
pass
|
||||
|
||||
def load_data(self, input_dir: str = None, **load_kwargs: Any) -> List[Document]:
|
||||
|
||||
def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
|
||||
"""
|
||||
Load Chrome history data from the default Chrome profile location.
|
||||
|
||||
|
||||
Args:
|
||||
input_dir: Not used for Chrome history (kept for compatibility)
|
||||
**load_kwargs:
|
||||
max_count (int): Maximum amount of history entries to read.
|
||||
chrome_profile_path (str): Custom path to Chrome profile directory.
|
||||
"""
|
||||
docs: List[Document] = []
|
||||
max_count = load_kwargs.get('max_count', 1000)
|
||||
chrome_profile_path = load_kwargs.get('chrome_profile_path', None)
|
||||
|
||||
docs: list[Document] = []
|
||||
max_count = load_kwargs.get("max_count", 1000)
|
||||
chrome_profile_path = load_kwargs.get("chrome_profile_path", None)
|
||||
|
||||
# Default Chrome profile path on macOS
|
||||
if chrome_profile_path is None:
|
||||
chrome_profile_path = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
|
||||
|
||||
chrome_profile_path = os.path.expanduser(
|
||||
"~/Library/Application Support/Google/Chrome/Default"
|
||||
)
|
||||
|
||||
history_db_path = os.path.join(chrome_profile_path, "History")
|
||||
|
||||
|
||||
if not os.path.exists(history_db_path):
|
||||
print(f"Chrome history database not found at: {history_db_path}")
|
||||
return docs
|
||||
|
||||
|
||||
try:
|
||||
# Connect to the Chrome history database
|
||||
print(f"Connecting to database: {history_db_path}")
|
||||
conn = sqlite3.connect(history_db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
|
||||
# Query to get browsing history with metadata (removed created_time column)
|
||||
query = """
|
||||
SELECT
|
||||
SELECT
|
||||
datetime(last_visit_time/1000000-11644473600,'unixepoch','localtime') as last_visit,
|
||||
url,
|
||||
title,
|
||||
visit_count,
|
||||
typed_count,
|
||||
url,
|
||||
title,
|
||||
visit_count,
|
||||
typed_count,
|
||||
hidden
|
||||
FROM urls
|
||||
FROM urls
|
||||
ORDER BY last_visit_time DESC
|
||||
"""
|
||||
|
||||
|
||||
print(f"Executing query on database: {history_db_path}")
|
||||
cursor.execute(query)
|
||||
rows = cursor.fetchall()
|
||||
print(f"Query returned {len(rows)} rows")
|
||||
|
||||
|
||||
count = 0
|
||||
for row in rows:
|
||||
if count >= max_count and max_count > 0:
|
||||
break
|
||||
|
||||
last_visit, url, title, visit_count, typed_count, hidden = row
|
||||
|
||||
|
||||
last_visit, url, title, visit_count, typed_count, _hidden = row
|
||||
|
||||
# Create document content with metadata embedded in text
|
||||
doc_content = f"""
|
||||
[BROWSING HISTORY METADATA]
|
||||
URL: {url}
|
||||
Title: {title}
|
||||
Last Visit: {last_visit}
|
||||
Visit Count: {visit_count}
|
||||
Typed Count: {typed_count}
|
||||
Hidden: {hidden}
|
||||
[END METADATA]
|
||||
|
||||
Title: {title}
|
||||
URL: {url}
|
||||
Last visited: {last_visit}
|
||||
[Title]: {title}
|
||||
[URL of the page]: {url}
|
||||
[Last visited time]: {last_visit}
|
||||
[Visit times]: {visit_count}
|
||||
[Typed times]: {typed_count}
|
||||
"""
|
||||
|
||||
|
||||
# Create document with embedded metadata
|
||||
doc = Document(text=doc_content, metadata={})
|
||||
doc = Document(text=doc_content, metadata={"title": title[0:150]})
|
||||
# if len(title) > 150:
|
||||
# print(f"Title is too long: {title}")
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
|
||||
|
||||
conn.close()
|
||||
print(f"Loaded {len(docs)} Chrome history documents")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error reading Chrome history: {e}")
|
||||
# add you may need to close your browser to make the database file available
|
||||
# also highlight in red
|
||||
print(
|
||||
"\033[91mYou may need to close your browser to make the database file available\033[0m"
|
||||
)
|
||||
return docs
|
||||
|
||||
|
||||
return docs
|
||||
|
||||
@staticmethod
|
||||
def find_chrome_profiles() -> List[Path]:
|
||||
def find_chrome_profiles() -> list[Path]:
|
||||
"""
|
||||
Find all Chrome profile directories.
|
||||
|
||||
|
||||
Returns:
|
||||
List of Path objects pointing to Chrome profile directories
|
||||
"""
|
||||
chrome_base_path = Path(os.path.expanduser("~/Library/Application Support/Google/Chrome"))
|
||||
profile_dirs = []
|
||||
|
||||
|
||||
if not chrome_base_path.exists():
|
||||
print(f"Chrome directory not found at: {chrome_base_path}")
|
||||
return profile_dirs
|
||||
|
||||
|
||||
# Find all profile directories
|
||||
for profile_dir in chrome_base_path.iterdir():
|
||||
if profile_dir.is_dir() and profile_dir.name != "System Profile":
|
||||
@@ -124,53 +128,59 @@ Last visited: {last_visit}
|
||||
if history_path.exists():
|
||||
profile_dirs.append(profile_dir)
|
||||
print(f"Found Chrome profile: {profile_dir}")
|
||||
|
||||
|
||||
print(f"Found {len(profile_dirs)} Chrome profiles")
|
||||
return profile_dirs
|
||||
|
||||
@staticmethod
|
||||
def export_history_to_file(output_file: str = "chrome_history_export.txt", max_count: int = 1000):
|
||||
def export_history_to_file(
|
||||
output_file: str = "chrome_history_export.txt", max_count: int = 1000
|
||||
):
|
||||
"""
|
||||
Export Chrome history to a text file using the same SQL query format.
|
||||
|
||||
|
||||
Args:
|
||||
output_file: Path to the output file
|
||||
max_count: Maximum number of entries to export
|
||||
"""
|
||||
chrome_profile_path = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
|
||||
chrome_profile_path = os.path.expanduser(
|
||||
"~/Library/Application Support/Google/Chrome/Default"
|
||||
)
|
||||
history_db_path = os.path.join(chrome_profile_path, "History")
|
||||
|
||||
|
||||
if not os.path.exists(history_db_path):
|
||||
print(f"Chrome history database not found at: {history_db_path}")
|
||||
return
|
||||
|
||||
|
||||
try:
|
||||
conn = sqlite3.connect(history_db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
|
||||
query = """
|
||||
SELECT
|
||||
SELECT
|
||||
datetime(last_visit_time/1000000-11644473600,'unixepoch','localtime') as last_visit,
|
||||
url,
|
||||
title,
|
||||
visit_count,
|
||||
typed_count,
|
||||
url,
|
||||
title,
|
||||
visit_count,
|
||||
typed_count,
|
||||
hidden
|
||||
FROM urls
|
||||
FROM urls
|
||||
ORDER BY last_visit_time DESC
|
||||
LIMIT ?
|
||||
"""
|
||||
|
||||
|
||||
cursor.execute(query, (max_count,))
|
||||
rows = cursor.fetchall()
|
||||
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
for row in rows:
|
||||
last_visit, url, title, visit_count, typed_count, hidden = row
|
||||
f.write(f"{last_visit}\t{url}\t{title}\t{visit_count}\t{typed_count}\t{hidden}\n")
|
||||
|
||||
f.write(
|
||||
f"{last_visit}\t{url}\t{title}\t{visit_count}\t{typed_count}\t{hidden}\n"
|
||||
)
|
||||
|
||||
conn.close()
|
||||
print(f"Exported {len(rows)} history entries to {output_file}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error exporting Chrome history: {e}")
|
||||
print(f"Error exporting Chrome history: {e}")
|
||||
File diff suppressed because it is too large
Load Diff
189
apps/wechat_rag.py
Normal file
189
apps/wechat_rag.py
Normal file
@@ -0,0 +1,189 @@
|
||||
"""
|
||||
WeChat History RAG example using the unified interface.
|
||||
Supports WeChat chat history export and search.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from base_rag_example import BaseRAGExample
|
||||
|
||||
from .history_data.wechat_history import WeChatHistoryReader
|
||||
|
||||
|
||||
class WeChatRAG(BaseRAGExample):
|
||||
"""RAG example for WeChat chat history."""
|
||||
|
||||
def __init__(self):
|
||||
# Set default values BEFORE calling super().__init__
|
||||
self.max_items_default = -1 # Match original default
|
||||
self.embedding_model_default = (
|
||||
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
name="WeChat History",
|
||||
description="Process and query WeChat chat history with LEANN",
|
||||
default_index_name="wechat_history_magic_test_11Debug_new",
|
||||
)
|
||||
|
||||
def _add_specific_arguments(self, parser):
|
||||
"""Add WeChat-specific arguments."""
|
||||
wechat_group = parser.add_argument_group("WeChat Parameters")
|
||||
wechat_group.add_argument(
|
||||
"--export-dir",
|
||||
type=str,
|
||||
default="./wechat_export",
|
||||
help="Directory to store WeChat exports (default: ./wechat_export)",
|
||||
)
|
||||
wechat_group.add_argument(
|
||||
"--force-export",
|
||||
action="store_true",
|
||||
help="Force re-export of WeChat data even if exports exist",
|
||||
)
|
||||
wechat_group.add_argument(
|
||||
"--chunk-size", type=int, default=192, help="Text chunk size (default: 192)"
|
||||
)
|
||||
wechat_group.add_argument(
|
||||
"--chunk-overlap", type=int, default=64, help="Text chunk overlap (default: 64)"
|
||||
)
|
||||
|
||||
def _export_wechat_data(self, export_dir: Path) -> bool:
|
||||
"""Export WeChat data using wechattweak-cli."""
|
||||
print("Exporting WeChat data...")
|
||||
|
||||
# Check if WeChat is running
|
||||
try:
|
||||
result = subprocess.run(["pgrep", "WeChat"], capture_output=True, text=True)
|
||||
if result.returncode != 0:
|
||||
print("WeChat is not running. Please start WeChat first.")
|
||||
return False
|
||||
except Exception:
|
||||
pass # pgrep might not be available on all systems
|
||||
|
||||
# Create export directory
|
||||
export_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Run export command
|
||||
cmd = ["packages/wechat-exporter/wechattweak-cli", "export", str(export_dir)]
|
||||
|
||||
try:
|
||||
print(f"Running: {' '.join(cmd)}")
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
|
||||
if result.returncode == 0:
|
||||
print("WeChat data exported successfully!")
|
||||
return True
|
||||
else:
|
||||
print(f"Export failed: {result.stderr}")
|
||||
return False
|
||||
|
||||
except FileNotFoundError:
|
||||
print("\nError: wechattweak-cli not found!")
|
||||
print("Please install it first:")
|
||||
print(" sudo packages/wechat-exporter/wechattweak-cli install")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"Export error: {e}")
|
||||
return False
|
||||
|
||||
async def load_data(self, args) -> list[str]:
|
||||
"""Load WeChat history and convert to text chunks."""
|
||||
# Initialize WeChat reader with export capabilities
|
||||
reader = WeChatHistoryReader()
|
||||
|
||||
# Find existing exports or create new ones using the centralized method
|
||||
export_dirs = reader.find_or_export_wechat_data(args.export_dir)
|
||||
if not export_dirs:
|
||||
print("Failed to find or export WeChat data. Trying to find any existing exports...")
|
||||
# Try to find any existing exports in common locations
|
||||
export_dirs = reader.find_wechat_export_dirs()
|
||||
if not export_dirs:
|
||||
print("No WeChat data found. Please ensure WeChat exports exist.")
|
||||
return []
|
||||
|
||||
# Load documents from all found export directories
|
||||
all_documents = []
|
||||
total_processed = 0
|
||||
|
||||
for i, export_dir in enumerate(export_dirs):
|
||||
print(f"\nProcessing WeChat export {i + 1}/{len(export_dirs)}: {export_dir}")
|
||||
|
||||
try:
|
||||
# Apply max_items limit per export
|
||||
max_per_export = -1
|
||||
if args.max_items > 0:
|
||||
remaining = args.max_items - total_processed
|
||||
if remaining <= 0:
|
||||
break
|
||||
max_per_export = remaining
|
||||
|
||||
documents = reader.load_data(
|
||||
wechat_export_dir=str(export_dir),
|
||||
max_count=max_per_export,
|
||||
concatenate_messages=True, # Enable message concatenation for better context
|
||||
)
|
||||
|
||||
if documents:
|
||||
print(f"Loaded {len(documents)} chat documents from {export_dir}")
|
||||
all_documents.extend(documents)
|
||||
total_processed += len(documents)
|
||||
else:
|
||||
print(f"No documents loaded from {export_dir}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing {export_dir}: {e}")
|
||||
continue
|
||||
|
||||
if not all_documents:
|
||||
print("No documents loaded from any source. Exiting.")
|
||||
return []
|
||||
|
||||
print(f"\nTotal loaded {len(all_documents)} chat documents from {len(export_dirs)} exports")
|
||||
print("now starting to split into text chunks ... take some time")
|
||||
|
||||
# Convert to text chunks with contact information
|
||||
all_texts = []
|
||||
for doc in all_documents:
|
||||
# Split the document into chunks
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
text_splitter = SentenceSplitter(
|
||||
chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
||||
)
|
||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||
|
||||
for node in nodes:
|
||||
# Add contact information to each chunk
|
||||
contact_name = doc.metadata.get("contact_name", "Unknown")
|
||||
text = f"[Contact] means the message is from: {contact_name}\n" + node.get_content()
|
||||
all_texts.append(text)
|
||||
|
||||
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
|
||||
return all_texts
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
# Check platform
|
||||
if sys.platform != "darwin":
|
||||
print("\n⚠️ Warning: WeChat export is only supported on macOS")
|
||||
print(" You can still query existing exports on other platforms\n")
|
||||
|
||||
# Example queries for WeChat RAG
|
||||
print("\n💬 WeChat History RAG Example")
|
||||
print("=" * 50)
|
||||
print("\nExample queries you can try:")
|
||||
print("- 'Show me conversations about travel plans'")
|
||||
print("- 'Find group chats about weekend activities'")
|
||||
print("- '我想买魔术师约翰逊的球衣,给我一些对应聊天记录?'")
|
||||
print("- 'What did we discuss about the project last month?'")
|
||||
print("\nNote: WeChat must be running for export to work\n")
|
||||
|
||||
rag = WeChatRAG()
|
||||
asyncio.run(rag.run())
|
||||
BIN
assets/claude_code_leann.png
Normal file
BIN
assets/claude_code_leann.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 73 KiB |
BIN
assets/mcp_leann.png
Normal file
BIN
assets/mcp_leann.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 224 KiB |
BIN
assets/wechat_user_group.JPG
Normal file
BIN
assets/wechat_user_group.JPG
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 152 KiB |
@@ -1,13 +1,28 @@
|
||||
# 🧪 Leann Sanity Checks
|
||||
# 🧪 LEANN Benchmarks & Testing
|
||||
|
||||
This directory contains comprehensive sanity checks for the Leann system, ensuring all components work correctly across different configurations.
|
||||
This directory contains performance benchmarks and comprehensive tests for the LEANN system, including backend comparisons and sanity checks across different configurations.
|
||||
|
||||
## 📁 Test Files
|
||||
|
||||
### `diskann_vs_hnsw_speed_comparison.py`
|
||||
Performance comparison between DiskANN and HNSW backends:
|
||||
- ✅ **Search latency** comparison with both backends using recompute
|
||||
- ✅ **Index size** and **build time** measurements
|
||||
- ✅ **Score validity** testing (ensures no -inf scores)
|
||||
- ✅ **Configurable dataset sizes** for different scales
|
||||
|
||||
```bash
|
||||
# Quick comparison with 500 docs, 10 queries
|
||||
python benchmarks/diskann_vs_hnsw_speed_comparison.py
|
||||
|
||||
# Large-scale comparison with 2000 docs, 20 queries
|
||||
python benchmarks/diskann_vs_hnsw_speed_comparison.py 2000 20
|
||||
```
|
||||
|
||||
### `test_distance_functions.py`
|
||||
Tests all supported distance functions across DiskANN backend:
|
||||
- ✅ **MIPS** (Maximum Inner Product Search)
|
||||
- ✅ **L2** (Euclidean Distance)
|
||||
- ✅ **L2** (Euclidean Distance)
|
||||
- ✅ **Cosine** (Cosine Similarity)
|
||||
|
||||
```bash
|
||||
@@ -27,7 +42,7 @@ uv run python tests/sanity_checks/test_l2_verification.py
|
||||
### `test_sanity_check.py`
|
||||
Comprehensive end-to-end verification including:
|
||||
- Distance function testing
|
||||
- Embedding model compatibility
|
||||
- Embedding model compatibility
|
||||
- Search result correctness validation
|
||||
- Backend integration testing
|
||||
|
||||
@@ -64,7 +79,7 @@ When all tests pass, you should see:
|
||||
```
|
||||
📊 测试结果总结:
|
||||
mips : ✅ 通过
|
||||
l2 : ✅ 通过
|
||||
l2 : ✅ 通过
|
||||
cosine : ✅ 通过
|
||||
|
||||
🎉 测试完成!
|
||||
@@ -98,7 +113,7 @@ pkill -f "embedding_server"
|
||||
|
||||
### Typical Timing (3 documents, consumer hardware):
|
||||
- **Index Building**: 2-5 seconds per distance function
|
||||
- **Search Query**: 50-200ms
|
||||
- **Search Query**: 50-200ms
|
||||
- **Recompute Mode**: 5-15 seconds (higher accuracy)
|
||||
|
||||
### Memory Usage:
|
||||
@@ -117,4 +132,4 @@ These tests are designed to be run in automated environments:
|
||||
uv run python tests/sanity_checks/test_l2_verification.py
|
||||
```
|
||||
|
||||
The tests are deterministic and should produce consistent results across different platforms.
|
||||
The tests are deterministic and should produce consistent results across different platforms.
|
||||
@@ -1,43 +1,46 @@
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
from sentence_transformers import SentenceTransformer
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import torch
|
||||
from mlx_lm import load
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
# --- Configuration ---
|
||||
MODEL_NAME_TORCH = "Qwen/Qwen3-Embedding-0.6B"
|
||||
MODEL_NAME_MLX = "mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ"
|
||||
BATCH_SIZES = [1, 8, 16, 32, 64, 128]
|
||||
NUM_RUNS = 10 # Number of runs to average for each batch size
|
||||
WARMUP_RUNS = 2 # Number of warm-up runs
|
||||
WARMUP_RUNS = 2 # Number of warm-up runs
|
||||
|
||||
# --- Generate Dummy Data ---
|
||||
DUMMY_SENTENCES = ["This is a test sentence for benchmarking." * 5] * max(BATCH_SIZES)
|
||||
|
||||
# --- Benchmark Functions ---b
|
||||
|
||||
|
||||
def benchmark_torch(model, sentences):
|
||||
start_time = time.time()
|
||||
model.encode(sentences, convert_to_numpy=True)
|
||||
end_time = time.time()
|
||||
return (end_time - start_time) * 1000 # Return time in ms
|
||||
|
||||
|
||||
def benchmark_mlx(model, tokenizer, sentences):
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
# Tokenize sentences using MLX tokenizer
|
||||
tokens = []
|
||||
for sentence in sentences:
|
||||
token_ids = tokenizer.encode(sentence)
|
||||
tokens.append(token_ids)
|
||||
|
||||
|
||||
# Pad sequences to the same length
|
||||
max_len = max(len(t) for t in tokens)
|
||||
input_ids = []
|
||||
attention_mask = []
|
||||
|
||||
|
||||
for token_seq in tokens:
|
||||
# Pad sequence
|
||||
padded = token_seq + [tokenizer.eos_token_id] * (max_len - len(token_seq))
|
||||
@@ -45,24 +48,25 @@ def benchmark_mlx(model, tokenizer, sentences):
|
||||
# Create attention mask (1 for real tokens, 0 for padding)
|
||||
mask = [1] * len(token_seq) + [0] * (max_len - len(token_seq))
|
||||
attention_mask.append(mask)
|
||||
|
||||
|
||||
# Convert to MLX arrays
|
||||
input_ids = mx.array(input_ids)
|
||||
attention_mask = mx.array(attention_mask)
|
||||
|
||||
|
||||
# Get embeddings
|
||||
embeddings = model(input_ids)
|
||||
|
||||
|
||||
# Mean pooling
|
||||
mask = mx.expand_dims(attention_mask, -1)
|
||||
sum_embeddings = (embeddings * mask).sum(axis=1)
|
||||
sum_mask = mask.sum(axis=1)
|
||||
_ = sum_embeddings / sum_mask
|
||||
|
||||
|
||||
mx.eval() # Ensure computation is finished
|
||||
end_time = time.time()
|
||||
return (end_time - start_time) * 1000 # Return time in ms
|
||||
|
||||
|
||||
# --- Main Execution ---
|
||||
def main():
|
||||
print("--- Initializing Models ---")
|
||||
@@ -92,13 +96,15 @@ def main():
|
||||
for batch_size in BATCH_SIZES:
|
||||
print(f"Benchmarking batch size: {batch_size}")
|
||||
sentences_batch = DUMMY_SENTENCES[:batch_size]
|
||||
|
||||
|
||||
# Benchmark PyTorch
|
||||
torch_times = [benchmark_torch(model_torch, sentences_batch) for _ in range(NUM_RUNS)]
|
||||
results_torch.append(np.mean(torch_times))
|
||||
|
||||
|
||||
# Benchmark MLX
|
||||
mlx_times = [benchmark_mlx(model_mlx, tokenizer_mlx, sentences_batch) for _ in range(NUM_RUNS)]
|
||||
mlx_times = [
|
||||
benchmark_mlx(model_mlx, tokenizer_mlx, sentences_batch) for _ in range(NUM_RUNS)
|
||||
]
|
||||
results_mlx.append(np.mean(mlx_times))
|
||||
|
||||
print("\n--- Benchmark Results (Average time per batch in ms) ---")
|
||||
@@ -109,20 +115,27 @@ def main():
|
||||
# --- Plotting ---
|
||||
print("\n--- Generating Plot ---")
|
||||
plt.figure(figsize=(10, 6))
|
||||
plt.plot(BATCH_SIZES, results_torch, marker='o', linestyle='-', label=f'PyTorch ({device})')
|
||||
plt.plot(BATCH_SIZES, results_mlx, marker='s', linestyle='-', label='MLX')
|
||||
plt.plot(
|
||||
BATCH_SIZES,
|
||||
results_torch,
|
||||
marker="o",
|
||||
linestyle="-",
|
||||
label=f"PyTorch ({device})",
|
||||
)
|
||||
plt.plot(BATCH_SIZES, results_mlx, marker="s", linestyle="-", label="MLX")
|
||||
|
||||
plt.title(f'Embedding Performance: MLX vs PyTorch\nModel: {MODEL_NAME_TORCH}')
|
||||
plt.title(f"Embedding Performance: MLX vs PyTorch\nModel: {MODEL_NAME_TORCH}")
|
||||
plt.xlabel("Batch Size")
|
||||
plt.ylabel("Average Time per Batch (ms)")
|
||||
plt.xticks(BATCH_SIZES)
|
||||
plt.grid(True)
|
||||
plt.legend()
|
||||
|
||||
|
||||
# Save the plot
|
||||
output_filename = "embedding_benchmark.png"
|
||||
plt.savefig(output_filename)
|
||||
print(f"Plot saved to {output_filename}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
148
benchmarks/benchmark_no_recompute.py
Normal file
148
benchmarks/benchmark_no_recompute.py
Normal file
@@ -0,0 +1,148 @@
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
from leann import LeannBuilder, LeannSearcher
|
||||
|
||||
|
||||
def _meta_exists(index_path: str) -> bool:
|
||||
p = Path(index_path)
|
||||
return (p.parent / f"{p.stem}.meta.json").exists()
|
||||
|
||||
|
||||
def ensure_index(index_path: str, backend_name: str, num_docs: int, is_recompute: bool) -> None:
|
||||
# if _meta_exists(index_path):
|
||||
# return
|
||||
kwargs = {}
|
||||
if backend_name == "hnsw":
|
||||
kwargs["is_compact"] = is_recompute
|
||||
builder = LeannBuilder(
|
||||
backend_name=backend_name,
|
||||
embedding_model=os.getenv("LEANN_EMBED_MODEL", "facebook/contriever"),
|
||||
embedding_mode=os.getenv("LEANN_EMBED_MODE", "sentence-transformers"),
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_recompute=is_recompute,
|
||||
num_threads=4,
|
||||
**kwargs,
|
||||
)
|
||||
for i in range(num_docs):
|
||||
builder.add_text(
|
||||
f"This is a test document number {i}. It contains some repeated text for benchmarking."
|
||||
)
|
||||
builder.build_index(index_path)
|
||||
|
||||
|
||||
def _bench_group(
|
||||
index_path: str,
|
||||
recompute: bool,
|
||||
query: str,
|
||||
repeats: int,
|
||||
complexity: int = 32,
|
||||
top_k: int = 10,
|
||||
) -> float:
|
||||
# Independent searcher per group; fixed port when recompute
|
||||
searcher = LeannSearcher(index_path=index_path)
|
||||
|
||||
# Warm-up once
|
||||
_ = searcher.search(
|
||||
query,
|
||||
top_k=top_k,
|
||||
complexity=complexity,
|
||||
recompute_embeddings=recompute,
|
||||
)
|
||||
|
||||
def _once() -> float:
|
||||
t0 = time.time()
|
||||
_ = searcher.search(
|
||||
query,
|
||||
top_k=top_k,
|
||||
complexity=complexity,
|
||||
recompute_embeddings=recompute,
|
||||
)
|
||||
return time.time() - t0
|
||||
|
||||
if repeats <= 1:
|
||||
t = _once()
|
||||
else:
|
||||
vals = [_once() for _ in range(repeats)]
|
||||
vals.sort()
|
||||
t = vals[len(vals) // 2]
|
||||
|
||||
searcher.cleanup()
|
||||
return t
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--num-docs", type=int, default=5000)
|
||||
parser.add_argument("--repeats", type=int, default=3)
|
||||
parser.add_argument("--complexity", type=int, default=32)
|
||||
args = parser.parse_args()
|
||||
|
||||
base = Path.cwd() / ".leann" / "indexes" / f"bench_n{args.num_docs}"
|
||||
base.parent.mkdir(parents=True, exist_ok=True)
|
||||
# ---------- Build HNSW variants ----------
|
||||
hnsw_r = str(base / f"hnsw_recompute_n{args.num_docs}.leann")
|
||||
hnsw_nr = str(base / f"hnsw_norecompute_n{args.num_docs}.leann")
|
||||
ensure_index(hnsw_r, "hnsw", args.num_docs, True)
|
||||
ensure_index(hnsw_nr, "hnsw", args.num_docs, False)
|
||||
|
||||
# ---------- Build DiskANN variants ----------
|
||||
diskann_r = str(base / "diskann_r.leann")
|
||||
diskann_nr = str(base / "diskann_nr.leann")
|
||||
ensure_index(diskann_r, "diskann", args.num_docs, True)
|
||||
ensure_index(diskann_nr, "diskann", args.num_docs, False)
|
||||
|
||||
# ---------- Helpers ----------
|
||||
def _size_for(prefix: str) -> int:
|
||||
p = Path(prefix)
|
||||
base_dir = p.parent
|
||||
stem = p.stem
|
||||
total = 0
|
||||
for f in base_dir.iterdir():
|
||||
if f.is_file() and f.name.startswith(stem):
|
||||
total += f.stat().st_size
|
||||
return total
|
||||
|
||||
# ---------- HNSW benchmark ----------
|
||||
t_hnsw_r = _bench_group(
|
||||
hnsw_r, True, "test document number 42", repeats=args.repeats, complexity=args.complexity
|
||||
)
|
||||
t_hnsw_nr = _bench_group(
|
||||
hnsw_nr, False, "test document number 42", repeats=args.repeats, complexity=args.complexity
|
||||
)
|
||||
size_hnsw_r = _size_for(hnsw_r)
|
||||
size_hnsw_nr = _size_for(hnsw_nr)
|
||||
|
||||
print("Benchmark results (HNSW):")
|
||||
print(f" recompute=True: search_time={t_hnsw_r:.3f}s, size={size_hnsw_r / 1024 / 1024:.1f}MB")
|
||||
print(
|
||||
f" recompute=False: search_time={t_hnsw_nr:.3f}s, size={size_hnsw_nr / 1024 / 1024:.1f}MB"
|
||||
)
|
||||
print(" Expectation: no-recompute should be faster but larger on disk.")
|
||||
|
||||
# ---------- DiskANN benchmark ----------
|
||||
t_diskann_r = _bench_group(
|
||||
diskann_r, True, "DiskANN R test doc 123", repeats=args.repeats, complexity=args.complexity
|
||||
)
|
||||
t_diskann_nr = _bench_group(
|
||||
diskann_nr,
|
||||
False,
|
||||
"DiskANN NR test doc 123",
|
||||
repeats=args.repeats,
|
||||
complexity=args.complexity,
|
||||
)
|
||||
size_diskann_r = _size_for(diskann_r)
|
||||
size_diskann_nr = _size_for(diskann_nr)
|
||||
|
||||
print("\nBenchmark results (DiskANN):")
|
||||
print(f" build(recompute=True, partition): size={size_diskann_r / 1024 / 1024:.1f}MB")
|
||||
print(f" build(recompute=False): size={size_diskann_nr / 1024 / 1024:.1f}MB")
|
||||
print(f" search recompute=True (final rerank): {t_diskann_r:.3f}s")
|
||||
print(f" search recompute=False (PQ only): {t_diskann_nr:.3f}s")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -3,14 +3,15 @@
|
||||
Memory comparison between Faiss HNSW and LEANN HNSW backend
|
||||
"""
|
||||
|
||||
import gc
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
import psutil
|
||||
import gc
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
import psutil
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
# Setup logging
|
||||
@@ -61,7 +62,7 @@ def test_faiss_hnsw():
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[sys.executable, "examples/faiss_only.py"],
|
||||
[sys.executable, "benchmarks/faiss_only.py"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=300,
|
||||
@@ -83,9 +84,7 @@ def test_faiss_hnsw():
|
||||
|
||||
for line in lines:
|
||||
if "Peak Memory:" in line:
|
||||
peak_memory = float(
|
||||
line.split("Peak Memory:")[1].split("MB")[0].strip()
|
||||
)
|
||||
peak_memory = float(line.split("Peak Memory:")[1].split("MB")[0].strip())
|
||||
|
||||
return {"peak_memory": peak_memory}
|
||||
|
||||
@@ -111,13 +110,12 @@ def test_leann_hnsw():
|
||||
|
||||
tracker.checkpoint("After imports")
|
||||
|
||||
from leann.api import LeannBuilder
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
from leann.api import LeannBuilder, LeannSearcher
|
||||
|
||||
|
||||
# Load and parse documents
|
||||
documents = SimpleDirectoryReader(
|
||||
"examples/data",
|
||||
"data",
|
||||
recursive=True,
|
||||
encoding="utf-8",
|
||||
required_exts=[".pdf", ".txt", ".md"],
|
||||
@@ -135,6 +133,7 @@ def test_leann_hnsw():
|
||||
nodes = node_parser.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
print(f"Total number of chunks: {len(all_texts)}")
|
||||
|
||||
tracker.checkpoint("After text chunking")
|
||||
|
||||
@@ -196,16 +195,14 @@ def test_leann_hnsw():
|
||||
runtime_start_mem = get_memory_usage()
|
||||
print(f"Before load memory: {runtime_start_mem:.1f} MB")
|
||||
tracker.checkpoint("Before load memory")
|
||||
|
||||
|
||||
# Load searcher
|
||||
searcher = LeannSearcher(index_path)
|
||||
tracker.checkpoint("After searcher loading")
|
||||
|
||||
|
||||
|
||||
print("Running search queries...")
|
||||
queries = [
|
||||
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
|
||||
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
|
||||
"What is LEANN and how does it work?",
|
||||
"华为诺亚方舟实验室的主要研究内容",
|
||||
]
|
||||
@@ -303,21 +300,15 @@ def main():
|
||||
|
||||
print("\nLEANN vs Faiss Performance:")
|
||||
memory_saving = faiss_results["peak_memory"] - leann_results["peak_memory"]
|
||||
print(
|
||||
f" Search Memory: {memory_ratio:.1f}x less ({memory_saving:.1f} MB saved)"
|
||||
)
|
||||
print(f" Search Memory: {memory_ratio:.1f}x less ({memory_saving:.1f} MB saved)")
|
||||
|
||||
# Storage comparison
|
||||
if leann_storage_size > faiss_storage_size:
|
||||
storage_ratio = leann_storage_size / faiss_storage_size
|
||||
print(
|
||||
f" Storage Size: {storage_ratio:.1f}x larger (LEANN uses more storage)"
|
||||
)
|
||||
print(f" Storage Size: {storage_ratio:.1f}x larger (LEANN uses more storage)")
|
||||
elif faiss_storage_size > leann_storage_size:
|
||||
storage_ratio = faiss_storage_size / leann_storage_size
|
||||
print(
|
||||
f" Storage Size: {storage_ratio:.1f}x smaller (LEANN uses less storage)"
|
||||
)
|
||||
print(f" Storage Size: {storage_ratio:.1f}x smaller (LEANN uses less storage)")
|
||||
else:
|
||||
print(" Storage Size: similar")
|
||||
else:
|
||||
0
data/README.md → benchmarks/data/README.md
Normal file → Executable file
0
data/README.md → benchmarks/data/README.md
Normal file → Executable file
286
benchmarks/diskann_vs_hnsw_speed_comparison.py
Normal file
286
benchmarks/diskann_vs_hnsw_speed_comparison.py
Normal file
@@ -0,0 +1,286 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
DiskANN vs HNSW Search Performance Comparison
|
||||
|
||||
This benchmark compares search performance between DiskANN and HNSW backends:
|
||||
- DiskANN: With graph partitioning enabled (is_recompute=True)
|
||||
- HNSW: With recompute enabled (is_recompute=True)
|
||||
- Tests performance across different dataset sizes
|
||||
- Measures search latency, recall, and index size
|
||||
"""
|
||||
|
||||
import gc
|
||||
import multiprocessing as mp
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
# Prefer 'fork' start method to avoid POSIX semaphore leaks on macOS
|
||||
try:
|
||||
mp.set_start_method("fork", force=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def create_test_texts(n_docs: int) -> list[str]:
|
||||
"""Create synthetic test documents for benchmarking."""
|
||||
np.random.seed(42)
|
||||
topics = [
|
||||
"machine learning and artificial intelligence",
|
||||
"natural language processing and text analysis",
|
||||
"computer vision and image recognition",
|
||||
"data science and statistical analysis",
|
||||
"deep learning and neural networks",
|
||||
"information retrieval and search engines",
|
||||
"database systems and data management",
|
||||
"software engineering and programming",
|
||||
"cybersecurity and network protection",
|
||||
"cloud computing and distributed systems",
|
||||
]
|
||||
|
||||
texts = []
|
||||
for i in range(n_docs):
|
||||
topic = topics[i % len(topics)]
|
||||
variation = np.random.randint(1, 100)
|
||||
text = (
|
||||
f"This is document {i} about {topic}. Content variation {variation}. "
|
||||
f"Additional information about {topic} with details and examples. "
|
||||
f"Technical discussion of {topic} including implementation aspects."
|
||||
)
|
||||
texts.append(text)
|
||||
|
||||
return texts
|
||||
|
||||
|
||||
def benchmark_backend(
|
||||
backend_name: str, texts: list[str], test_queries: list[str], backend_kwargs: dict[str, Any]
|
||||
) -> dict[str, float]:
|
||||
"""Benchmark a specific backend with the given configuration."""
|
||||
from leann.api import LeannBuilder, LeannSearcher
|
||||
|
||||
print(f"\n🔧 Testing {backend_name.upper()} backend...")
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
index_path = str(Path(temp_dir) / f"benchmark_{backend_name}.leann")
|
||||
|
||||
# Build index
|
||||
print(f"📦 Building {backend_name} index with {len(texts)} documents...")
|
||||
start_time = time.time()
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name=backend_name,
|
||||
embedding_model="facebook/contriever",
|
||||
embedding_mode="sentence-transformers",
|
||||
**backend_kwargs,
|
||||
)
|
||||
|
||||
for text in texts:
|
||||
builder.add_text(text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
build_time = time.time() - start_time
|
||||
|
||||
# Measure index size
|
||||
index_dir = Path(index_path).parent
|
||||
index_files = list(index_dir.glob(f"{Path(index_path).stem}.*"))
|
||||
total_size = sum(f.stat().st_size for f in index_files if f.is_file())
|
||||
size_mb = total_size / (1024 * 1024)
|
||||
|
||||
print(f" ✅ Build completed in {build_time:.2f}s, index size: {size_mb:.1f}MB")
|
||||
|
||||
# Search benchmark
|
||||
print("🔍 Running search benchmark...")
|
||||
searcher = LeannSearcher(index_path)
|
||||
|
||||
search_times = []
|
||||
all_results = []
|
||||
|
||||
for query in test_queries:
|
||||
start_time = time.time()
|
||||
results = searcher.search(query, top_k=5)
|
||||
search_time = time.time() - start_time
|
||||
search_times.append(search_time)
|
||||
all_results.append(results)
|
||||
|
||||
avg_search_time = np.mean(search_times) * 1000 # Convert to ms
|
||||
print(f" ✅ Average search time: {avg_search_time:.1f}ms")
|
||||
|
||||
# Check for valid scores (detect -inf issues)
|
||||
all_scores = [
|
||||
result.score
|
||||
for results in all_results
|
||||
for result in results
|
||||
if result.score is not None
|
||||
]
|
||||
valid_scores = [
|
||||
score for score in all_scores if score != float("-inf") and score != float("inf")
|
||||
]
|
||||
score_validity_rate = len(valid_scores) / len(all_scores) if all_scores else 0
|
||||
|
||||
# Clean up (ensure embedding server shutdown and object GC)
|
||||
try:
|
||||
if hasattr(searcher, "cleanup"):
|
||||
searcher.cleanup()
|
||||
del searcher
|
||||
del builder
|
||||
gc.collect()
|
||||
except Exception as e:
|
||||
print(f"⚠️ Warning: Resource cleanup error: {e}")
|
||||
|
||||
return {
|
||||
"build_time": build_time,
|
||||
"avg_search_time_ms": avg_search_time,
|
||||
"index_size_mb": size_mb,
|
||||
"score_validity_rate": score_validity_rate,
|
||||
}
|
||||
|
||||
|
||||
def run_comparison(n_docs: int = 500, n_queries: int = 10):
|
||||
"""Run performance comparison between DiskANN and HNSW."""
|
||||
print("🚀 Starting DiskANN vs HNSW Performance Comparison")
|
||||
print(f"📊 Dataset: {n_docs} documents, {n_queries} test queries")
|
||||
|
||||
# Create test data
|
||||
texts = create_test_texts(n_docs)
|
||||
test_queries = [
|
||||
"machine learning algorithms",
|
||||
"natural language processing",
|
||||
"computer vision techniques",
|
||||
"data analysis methods",
|
||||
"neural network architectures",
|
||||
"database query optimization",
|
||||
"software development practices",
|
||||
"security vulnerabilities",
|
||||
"cloud infrastructure",
|
||||
"distributed computing",
|
||||
][:n_queries]
|
||||
|
||||
# HNSW benchmark
|
||||
hnsw_results = benchmark_backend(
|
||||
backend_name="hnsw",
|
||||
texts=texts,
|
||||
test_queries=test_queries,
|
||||
backend_kwargs={
|
||||
"is_recompute": True, # Enable recompute for fair comparison
|
||||
"M": 16,
|
||||
"efConstruction": 200,
|
||||
},
|
||||
)
|
||||
|
||||
# DiskANN benchmark
|
||||
diskann_results = benchmark_backend(
|
||||
backend_name="diskann",
|
||||
texts=texts,
|
||||
test_queries=test_queries,
|
||||
backend_kwargs={
|
||||
"is_recompute": True, # Enable graph partitioning
|
||||
"num_neighbors": 32,
|
||||
"search_list_size": 50,
|
||||
},
|
||||
)
|
||||
|
||||
# Performance comparison
|
||||
print("\n📈 Performance Comparison Results")
|
||||
print(f"{'=' * 60}")
|
||||
print(f"{'Metric':<25} {'HNSW':<15} {'DiskANN':<15} {'Speedup':<10}")
|
||||
print(f"{'-' * 60}")
|
||||
|
||||
# Build time comparison
|
||||
build_speedup = hnsw_results["build_time"] / diskann_results["build_time"]
|
||||
print(
|
||||
f"{'Build Time (s)':<25} {hnsw_results['build_time']:<15.2f} {diskann_results['build_time']:<15.2f} {build_speedup:<10.2f}x"
|
||||
)
|
||||
|
||||
# Search time comparison
|
||||
search_speedup = hnsw_results["avg_search_time_ms"] / diskann_results["avg_search_time_ms"]
|
||||
print(
|
||||
f"{'Search Time (ms)':<25} {hnsw_results['avg_search_time_ms']:<15.1f} {diskann_results['avg_search_time_ms']:<15.1f} {search_speedup:<10.2f}x"
|
||||
)
|
||||
|
||||
# Index size comparison
|
||||
size_ratio = diskann_results["index_size_mb"] / hnsw_results["index_size_mb"]
|
||||
print(
|
||||
f"{'Index Size (MB)':<25} {hnsw_results['index_size_mb']:<15.1f} {diskann_results['index_size_mb']:<15.1f} {size_ratio:<10.2f}x"
|
||||
)
|
||||
|
||||
# Score validity
|
||||
print(
|
||||
f"{'Score Validity (%)':<25} {hnsw_results['score_validity_rate'] * 100:<15.1f} {diskann_results['score_validity_rate'] * 100:<15.1f}"
|
||||
)
|
||||
|
||||
print(f"{'=' * 60}")
|
||||
print("\n🎯 Summary:")
|
||||
if search_speedup > 1:
|
||||
print(f" DiskANN is {search_speedup:.2f}x faster than HNSW for search")
|
||||
else:
|
||||
print(f" HNSW is {1 / search_speedup:.2f}x faster than DiskANN for search")
|
||||
|
||||
if size_ratio > 1:
|
||||
print(f" DiskANN uses {size_ratio:.2f}x more storage than HNSW")
|
||||
else:
|
||||
print(f" DiskANN uses {1 / size_ratio:.2f}x less storage than HNSW")
|
||||
|
||||
print(
|
||||
f" Both backends achieved {min(hnsw_results['score_validity_rate'], diskann_results['score_validity_rate']) * 100:.1f}% score validity"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
try:
|
||||
# Handle help request
|
||||
if len(sys.argv) > 1 and sys.argv[1] in ["-h", "--help", "help"]:
|
||||
print("DiskANN vs HNSW Performance Comparison")
|
||||
print("=" * 50)
|
||||
print(f"Usage: python {sys.argv[0]} [n_docs] [n_queries]")
|
||||
print()
|
||||
print("Arguments:")
|
||||
print(" n_docs Number of documents to index (default: 500)")
|
||||
print(" n_queries Number of test queries to run (default: 10)")
|
||||
print()
|
||||
print("Examples:")
|
||||
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py")
|
||||
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py 1000")
|
||||
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py 2000 20")
|
||||
sys.exit(0)
|
||||
|
||||
# Parse command line arguments
|
||||
n_docs = int(sys.argv[1]) if len(sys.argv) > 1 else 500
|
||||
n_queries = int(sys.argv[2]) if len(sys.argv) > 2 else 10
|
||||
|
||||
print("DiskANN vs HNSW Performance Comparison")
|
||||
print("=" * 50)
|
||||
print(f"Dataset: {n_docs} documents, {n_queries} queries")
|
||||
print()
|
||||
|
||||
run_comparison(n_docs=n_docs, n_queries=n_queries)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n⚠️ Benchmark interrupted by user")
|
||||
sys.exit(130)
|
||||
except Exception as e:
|
||||
print(f"\n❌ Benchmark failed: {e}")
|
||||
sys.exit(1)
|
||||
finally:
|
||||
# Ensure clean exit (forceful to prevent rare hangs from atexit/threads)
|
||||
try:
|
||||
gc.collect()
|
||||
print("\n🧹 Cleanup completed")
|
||||
# Flush stdio to ensure message is visible before hard-exit
|
||||
try:
|
||||
import sys as _sys
|
||||
|
||||
_sys.stdout.flush()
|
||||
_sys.stderr.flush()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
# Use os._exit to bypass atexit handlers that may hang in rare cases
|
||||
import os as _os
|
||||
|
||||
_os._exit(0)
|
||||
@@ -1,11 +1,11 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Test only Faiss HNSW"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
import psutil
|
||||
import gc
|
||||
import os
|
||||
|
||||
|
||||
def get_memory_usage():
|
||||
@@ -37,20 +37,20 @@ def main():
|
||||
import faiss
|
||||
except ImportError:
|
||||
print("Faiss is not installed.")
|
||||
print("Please install it with `uv pip install faiss-cpu`")
|
||||
print(
|
||||
"Please install it with `uv pip install faiss-cpu` and you can then run this script again"
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
from llama_index.core import (
|
||||
SimpleDirectoryReader,
|
||||
VectorStoreIndex,
|
||||
StorageContext,
|
||||
Settings,
|
||||
node_parser,
|
||||
Document,
|
||||
SimpleDirectoryReader,
|
||||
StorageContext,
|
||||
VectorStoreIndex,
|
||||
)
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||
|
||||
tracker = MemoryTracker("Faiss HNSW")
|
||||
tracker.checkpoint("Initial")
|
||||
@@ -65,7 +65,7 @@ def main():
|
||||
tracker.checkpoint("After Faiss index creation")
|
||||
|
||||
documents = SimpleDirectoryReader(
|
||||
"examples/data",
|
||||
"data",
|
||||
recursive=True,
|
||||
encoding="utf-8",
|
||||
required_exts=[".pdf", ".txt", ".md"],
|
||||
@@ -90,8 +90,9 @@ def main():
|
||||
vector_store=vector_store, persist_dir="./storage_faiss"
|
||||
)
|
||||
from llama_index.core import load_index_from_storage
|
||||
|
||||
index = load_index_from_storage(storage_context=storage_context)
|
||||
print(f"Index loaded from ./storage_faiss")
|
||||
print("Index loaded from ./storage_faiss")
|
||||
tracker.checkpoint("After loading existing index")
|
||||
index_loaded = True
|
||||
except Exception as e:
|
||||
@@ -99,19 +100,18 @@ def main():
|
||||
print("Cleaning up corrupted index and building new one...")
|
||||
# Clean up corrupted index
|
||||
import shutil
|
||||
|
||||
if os.path.exists("./storage_faiss"):
|
||||
shutil.rmtree("./storage_faiss")
|
||||
|
||||
|
||||
if not index_loaded:
|
||||
print("Building new Faiss HNSW index...")
|
||||
|
||||
|
||||
# Use the correct Faiss building pattern from the example
|
||||
vector_store = FaissVectorStore(faiss_index=faiss_index)
|
||||
storage_context = StorageContext.from_defaults(vector_store=vector_store)
|
||||
index = VectorStoreIndex.from_documents(
|
||||
documents,
|
||||
storage_context=storage_context,
|
||||
transformations=[node_parser]
|
||||
documents, storage_context=storage_context, transformations=[node_parser]
|
||||
)
|
||||
tracker.checkpoint("After index building")
|
||||
|
||||
@@ -124,10 +124,10 @@ def main():
|
||||
runtime_start_mem = get_memory_usage()
|
||||
print(f"Before load memory: {runtime_start_mem:.1f} MB")
|
||||
tracker.checkpoint("Before load memory")
|
||||
|
||||
|
||||
query_engine = index.as_query_engine(similarity_top_k=20)
|
||||
queries = [
|
||||
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
|
||||
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
|
||||
"What is LEANN and how does it work?",
|
||||
"华为诺亚方舟实验室的主要研究内容",
|
||||
]
|
||||
@@ -141,7 +141,7 @@ def main():
|
||||
|
||||
runtime_end_mem = get_memory_usage()
|
||||
runtime_overhead = runtime_end_mem - runtime_start_mem
|
||||
|
||||
|
||||
peak_memory = tracker.summary()
|
||||
print(f"Peak Memory: {peak_memory:.1f} MB")
|
||||
print(f"Runtime Memory Overhead: {runtime_overhead:.1f} MB")
|
||||
@@ -2,20 +2,20 @@
|
||||
|
||||
import argparse
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import AutoModel, BitsAndBytesConfig
|
||||
from tqdm import tqdm
|
||||
from contextlib import contextmanager
|
||||
from transformers import AutoModel, BitsAndBytesConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class BenchmarkConfig:
|
||||
model_path: str
|
||||
batch_sizes: List[int]
|
||||
batch_sizes: list[int]
|
||||
seq_length: int
|
||||
num_runs: int
|
||||
use_fp16: bool = True
|
||||
@@ -28,47 +28,45 @@ class BenchmarkConfig:
|
||||
|
||||
class GraphContainer:
|
||||
"""Container for managing graphs for different batch sizes (CUDA graphs on NVIDIA, regular on others)."""
|
||||
|
||||
|
||||
def __init__(self, model: nn.Module, seq_length: int):
|
||||
self.model = model
|
||||
self.seq_length = seq_length
|
||||
self.graphs: Dict[int, 'GraphWrapper'] = {}
|
||||
|
||||
def get_or_create(self, batch_size: int) -> 'GraphWrapper':
|
||||
self.graphs: dict[int, GraphWrapper] = {}
|
||||
|
||||
def get_or_create(self, batch_size: int) -> "GraphWrapper":
|
||||
if batch_size not in self.graphs:
|
||||
self.graphs[batch_size] = GraphWrapper(
|
||||
self.model, batch_size, self.seq_length
|
||||
)
|
||||
self.graphs[batch_size] = GraphWrapper(self.model, batch_size, self.seq_length)
|
||||
return self.graphs[batch_size]
|
||||
|
||||
|
||||
class GraphWrapper:
|
||||
"""Wrapper for graph capture and replay (CUDA graphs on NVIDIA, regular on others)."""
|
||||
|
||||
|
||||
def __init__(self, model: nn.Module, batch_size: int, seq_length: int):
|
||||
self.model = model
|
||||
self.device = self._get_device()
|
||||
self.static_input = self._create_random_batch(batch_size, seq_length)
|
||||
self.static_attention_mask = torch.ones_like(self.static_input)
|
||||
|
||||
|
||||
# Warm up
|
||||
self._warmup()
|
||||
|
||||
|
||||
# Only use CUDA graphs on NVIDIA GPUs
|
||||
if torch.cuda.is_available() and hasattr(torch.cuda, 'CUDAGraph'):
|
||||
if torch.cuda.is_available() and hasattr(torch.cuda, "CUDAGraph"):
|
||||
# Capture graph
|
||||
self.graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(self.graph):
|
||||
self.static_output = self.model(
|
||||
input_ids=self.static_input,
|
||||
attention_mask=self.static_attention_mask
|
||||
attention_mask=self.static_attention_mask,
|
||||
)
|
||||
self.use_cuda_graph = True
|
||||
else:
|
||||
# For MPS or CPU, just store the model
|
||||
self.use_cuda_graph = False
|
||||
self.static_output = None
|
||||
|
||||
|
||||
def _get_device(self) -> str:
|
||||
if torch.cuda.is_available():
|
||||
return "cuda"
|
||||
@@ -76,22 +74,20 @@ class GraphWrapper:
|
||||
return "mps"
|
||||
else:
|
||||
return "cpu"
|
||||
|
||||
|
||||
def _create_random_batch(self, batch_size: int, seq_length: int) -> torch.Tensor:
|
||||
return torch.randint(
|
||||
0, 1000, (batch_size, seq_length),
|
||||
device=self.device,
|
||||
dtype=torch.long
|
||||
0, 1000, (batch_size, seq_length), device=self.device, dtype=torch.long
|
||||
)
|
||||
|
||||
|
||||
def _warmup(self, num_warmup: int = 3):
|
||||
with torch.no_grad():
|
||||
for _ in range(num_warmup):
|
||||
self.model(
|
||||
input_ids=self.static_input,
|
||||
attention_mask=self.static_attention_mask
|
||||
attention_mask=self.static_attention_mask,
|
||||
)
|
||||
|
||||
|
||||
def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
|
||||
if self.use_cuda_graph:
|
||||
self.static_input.copy_(input_ids)
|
||||
@@ -105,14 +101,14 @@ class GraphWrapper:
|
||||
|
||||
class ModelOptimizer:
|
||||
"""Applies various optimizations to the model."""
|
||||
|
||||
|
||||
@staticmethod
|
||||
def optimize(model: nn.Module, config: BenchmarkConfig) -> nn.Module:
|
||||
print("\nApplying model optimizations:")
|
||||
|
||||
|
||||
if model is None:
|
||||
raise ValueError("Cannot optimize None model")
|
||||
|
||||
|
||||
# Move to GPU
|
||||
if torch.cuda.is_available():
|
||||
model = model.cuda()
|
||||
@@ -124,53 +120,59 @@ class ModelOptimizer:
|
||||
model = model.cpu()
|
||||
device = "cpu"
|
||||
print(f"- Model moved to {device}")
|
||||
|
||||
|
||||
# FP16
|
||||
if config.use_fp16 and not config.use_int4:
|
||||
model = model.half()
|
||||
# use torch compile
|
||||
model = torch.compile(model)
|
||||
print("- Using FP16 precision")
|
||||
|
||||
|
||||
# Check if using SDPA (only on CUDA)
|
||||
if torch.cuda.is_available() and torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6:
|
||||
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
|
||||
if (
|
||||
torch.cuda.is_available()
|
||||
and torch.version.cuda
|
||||
and float(torch.version.cuda[:3]) >= 11.6
|
||||
):
|
||||
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
|
||||
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
|
||||
else:
|
||||
print("- PyTorch SDPA not available")
|
||||
|
||||
|
||||
# Flash Attention (only on CUDA)
|
||||
if config.use_flash_attention and torch.cuda.is_available():
|
||||
try:
|
||||
from flash_attn.flash_attention import FlashAttention
|
||||
from flash_attn.flash_attention import FlashAttention # noqa: F401
|
||||
|
||||
print("- Flash Attention 2 available")
|
||||
if hasattr(model.config, "attention_mode"):
|
||||
model.config.attention_mode = "flash_attention_2"
|
||||
print(" - Enabled Flash Attention 2 mode")
|
||||
except ImportError:
|
||||
print("- Flash Attention not available")
|
||||
|
||||
|
||||
# Memory efficient attention (only on CUDA)
|
||||
if torch.cuda.is_available():
|
||||
try:
|
||||
from xformers.ops import memory_efficient_attention
|
||||
if hasattr(model, 'enable_xformers_memory_efficient_attention'):
|
||||
from xformers.ops import memory_efficient_attention # noqa: F401
|
||||
|
||||
if hasattr(model, "enable_xformers_memory_efficient_attention"):
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
print("- Enabled xformers memory efficient attention")
|
||||
else:
|
||||
print("- Model doesn't support xformers")
|
||||
except (ImportError, AttributeError):
|
||||
print("- Xformers not available")
|
||||
|
||||
|
||||
model.eval()
|
||||
print("- Model set to eval mode")
|
||||
|
||||
|
||||
return model
|
||||
|
||||
|
||||
class Timer:
|
||||
"""Handles accurate GPU timing using GPU events or CPU timing."""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
if torch.cuda.is_available():
|
||||
self.start_event = torch.cuda.Event(enable_timing=True)
|
||||
@@ -182,7 +184,7 @@ class Timer:
|
||||
else:
|
||||
# CPU timing
|
||||
self.use_gpu_timing = False
|
||||
|
||||
|
||||
@contextmanager
|
||||
def timing(self):
|
||||
if self.use_gpu_timing:
|
||||
@@ -195,7 +197,7 @@ class Timer:
|
||||
start_time = time.time()
|
||||
yield
|
||||
self.cpu_elapsed = time.time() - start_time
|
||||
|
||||
|
||||
def elapsed_time(self) -> float:
|
||||
if self.use_gpu_timing:
|
||||
return self.start_event.elapsed_time(self.end_event) / 1000 # ms to seconds
|
||||
@@ -205,14 +207,14 @@ class Timer:
|
||||
|
||||
class Benchmark:
|
||||
"""Main benchmark runner."""
|
||||
|
||||
|
||||
def __init__(self, config: BenchmarkConfig):
|
||||
self.config = config
|
||||
try:
|
||||
self.model = self._load_model()
|
||||
if self.model is None:
|
||||
raise ValueError("Model initialization failed - model is None")
|
||||
|
||||
|
||||
# Only use CUDA graphs on NVIDIA GPUs
|
||||
if config.use_cuda_graphs and torch.cuda.is_available():
|
||||
self.graphs = GraphContainer(self.model, config.seq_length)
|
||||
@@ -220,25 +222,27 @@ class Benchmark:
|
||||
self.graphs = None
|
||||
self.timer = Timer()
|
||||
except Exception as e:
|
||||
print(f"ERROR in benchmark initialization: {str(e)}")
|
||||
print(f"ERROR in benchmark initialization: {e!s}")
|
||||
raise
|
||||
|
||||
|
||||
def _load_model(self) -> nn.Module:
|
||||
print(f"Loading model from {self.config.model_path}...")
|
||||
|
||||
|
||||
try:
|
||||
# Int4 quantization using HuggingFace integration
|
||||
if self.config.use_int4:
|
||||
import bitsandbytes as bnb
|
||||
|
||||
print(f"- bitsandbytes version: {bnb.__version__}")
|
||||
|
||||
# 检查是否使用自定义的8bit量化
|
||||
if hasattr(self.config, 'use_linear8bitlt') and self.config.use_linear8bitlt:
|
||||
|
||||
# Check if using custom 8bit quantization
|
||||
if hasattr(self.config, "use_linear8bitlt") and self.config.use_linear8bitlt:
|
||||
print("- Using custom Linear8bitLt replacement for all linear layers")
|
||||
|
||||
# 加载原始模型(不使用量化配置)
|
||||
|
||||
# Load original model (without quantization config)
|
||||
import bitsandbytes as bnb
|
||||
import torch
|
||||
|
||||
# set default to half
|
||||
torch.set_default_dtype(torch.float16)
|
||||
compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
|
||||
@@ -246,112 +250,121 @@ class Benchmark:
|
||||
self.config.model_path,
|
||||
torch_dtype=compute_dtype,
|
||||
)
|
||||
|
||||
# 定义替换函数
|
||||
|
||||
# Define replacement function
|
||||
def replace_linear_with_linear8bitlt(model):
|
||||
"""递归地将模型中的所有nn.Linear层替换为Linear8bitLt"""
|
||||
"""Recursively replace all nn.Linear layers with Linear8bitLt"""
|
||||
for name, module in list(model.named_children()):
|
||||
if isinstance(module, nn.Linear):
|
||||
# 获取原始线性层的参数
|
||||
# Get original linear layer parameters
|
||||
in_features = module.in_features
|
||||
out_features = module.out_features
|
||||
bias = module.bias is not None
|
||||
|
||||
# 创建8bit线性层
|
||||
|
||||
# Create 8bit linear layer
|
||||
# print size
|
||||
print(f"in_features: {in_features}, out_features: {out_features}")
|
||||
new_module = bnb.nn.Linear8bitLt(
|
||||
in_features,
|
||||
out_features,
|
||||
bias=bias,
|
||||
has_fp16_weights=False
|
||||
in_features,
|
||||
out_features,
|
||||
bias=bias,
|
||||
has_fp16_weights=False,
|
||||
)
|
||||
|
||||
# 复制权重和偏置
|
||||
|
||||
# Copy weights and bias
|
||||
new_module.weight.data = module.weight.data
|
||||
if bias:
|
||||
new_module.bias.data = module.bias.data
|
||||
|
||||
# 替换模块
|
||||
|
||||
# Replace module
|
||||
setattr(model, name, new_module)
|
||||
else:
|
||||
# 递归处理子模块
|
||||
# Process child modules recursively
|
||||
replace_linear_with_linear8bitlt(module)
|
||||
|
||||
|
||||
return model
|
||||
|
||||
# 替换所有线性层
|
||||
|
||||
# Replace all linear layers
|
||||
model = replace_linear_with_linear8bitlt(model)
|
||||
# add torch compile
|
||||
model = torch.compile(model)
|
||||
|
||||
# 将模型移到GPU(量化发生在这里)
|
||||
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
||||
|
||||
# Move model to GPU (quantization happens here)
|
||||
device = (
|
||||
"cuda"
|
||||
if torch.cuda.is_available()
|
||||
else "mps"
|
||||
if torch.backends.mps.is_available()
|
||||
else "cpu"
|
||||
)
|
||||
model = model.to(device)
|
||||
|
||||
|
||||
print("- All linear layers replaced with Linear8bitLt")
|
||||
|
||||
|
||||
else:
|
||||
# 使用原来的Int4量化方法
|
||||
# Use original Int4 quantization method
|
||||
print("- Using bitsandbytes for Int4 quantization")
|
||||
|
||||
|
||||
# Create quantization config
|
||||
|
||||
|
||||
compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
|
||||
quantization_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=compute_dtype,
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_quant_type="nf4"
|
||||
bnb_4bit_quant_type="nf4",
|
||||
)
|
||||
|
||||
|
||||
print("- Quantization config:", quantization_config)
|
||||
|
||||
|
||||
# Load model directly with quantization config
|
||||
model = AutoModel.from_pretrained(
|
||||
self.config.model_path,
|
||||
quantization_config=quantization_config,
|
||||
torch_dtype=compute_dtype,
|
||||
device_map="auto" # Let HF decide on device mapping
|
||||
device_map="auto", # Let HF decide on device mapping
|
||||
)
|
||||
|
||||
|
||||
# Check if model loaded successfully
|
||||
if model is None:
|
||||
raise ValueError("Model loading returned None")
|
||||
|
||||
|
||||
print(f"- Model type: {type(model)}")
|
||||
|
||||
|
||||
# Apply optimizations directly here
|
||||
print("\nApplying model optimizations:")
|
||||
|
||||
if hasattr(self.config, 'use_linear8bitlt') and self.config.use_linear8bitlt:
|
||||
|
||||
if hasattr(self.config, "use_linear8bitlt") and self.config.use_linear8bitlt:
|
||||
print("- Model moved to GPU with Linear8bitLt quantization")
|
||||
else:
|
||||
# Skip moving to GPU since device_map="auto" already did that
|
||||
print("- Model already on GPU due to device_map='auto'")
|
||||
|
||||
|
||||
# Skip FP16 conversion since we specified compute_dtype
|
||||
print(f"- Using {compute_dtype} for compute dtype")
|
||||
|
||||
|
||||
# Check CUDA and SDPA
|
||||
if torch.cuda.is_available() and torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6:
|
||||
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
|
||||
if (
|
||||
torch.cuda.is_available()
|
||||
and torch.version.cuda
|
||||
and float(torch.version.cuda[:3]) >= 11.6
|
||||
):
|
||||
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
|
||||
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
|
||||
else:
|
||||
print("- PyTorch SDPA not available")
|
||||
|
||||
|
||||
# Try xformers if available (only on CUDA)
|
||||
if torch.cuda.is_available():
|
||||
try:
|
||||
from xformers.ops import memory_efficient_attention
|
||||
if hasattr(model, 'enable_xformers_memory_efficient_attention'):
|
||||
if hasattr(model, "enable_xformers_memory_efficient_attention"):
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
print("- Enabled xformers memory efficient attention")
|
||||
else:
|
||||
print("- Model doesn't support xformers")
|
||||
except (ImportError, AttributeError):
|
||||
print("- Xformers not available")
|
||||
|
||||
|
||||
# Set to eval mode
|
||||
model.eval()
|
||||
print("- Model set to eval mode")
|
||||
@@ -365,76 +378,83 @@ class Benchmark:
|
||||
llm_int8_threshold=6.0,
|
||||
llm_int8_has_fp16_weight=False,
|
||||
)
|
||||
|
||||
|
||||
model = AutoModel.from_pretrained(
|
||||
self.config.model_path,
|
||||
quantization_config=quantization_config,
|
||||
torch_dtype=compute_dtype,
|
||||
device_map="auto"
|
||||
device_map="auto",
|
||||
)
|
||||
|
||||
|
||||
if model is None:
|
||||
raise ValueError("Model loading returned None")
|
||||
|
||||
|
||||
print(f"- Model type: {type(model)}")
|
||||
model.eval()
|
||||
print("- Model set to eval mode")
|
||||
|
||||
|
||||
else:
|
||||
# Standard loading for FP16/FP32
|
||||
model = AutoModel.from_pretrained(self.config.model_path)
|
||||
print("- Model loaded in standard precision")
|
||||
print(f"- Model type: {type(model)}")
|
||||
|
||||
|
||||
# Apply standard optimizations
|
||||
# set default to half
|
||||
import torch
|
||||
|
||||
torch.set_default_dtype(torch.bfloat16)
|
||||
model = ModelOptimizer.optimize(model, self.config)
|
||||
model = model.half()
|
||||
# add torch compile
|
||||
model = torch.compile(model)
|
||||
|
||||
|
||||
# Final check to ensure model is not None
|
||||
if model is None:
|
||||
raise ValueError("Model is None after optimization")
|
||||
|
||||
|
||||
print(f"- Final model type: {type(model)}")
|
||||
return model
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"ERROR loading model: {str(e)}")
|
||||
print(f"ERROR loading model: {e!s}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
|
||||
def _create_random_batch(self, batch_size: int) -> torch.Tensor:
|
||||
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
||||
device = (
|
||||
"cuda"
|
||||
if torch.cuda.is_available()
|
||||
else "mps"
|
||||
if torch.backends.mps.is_available()
|
||||
else "cpu"
|
||||
)
|
||||
return torch.randint(
|
||||
0, 1000,
|
||||
0,
|
||||
1000,
|
||||
(batch_size, self.config.seq_length),
|
||||
device=device,
|
||||
dtype=torch.long
|
||||
dtype=torch.long,
|
||||
)
|
||||
|
||||
|
||||
def _run_inference(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
graph_wrapper: Optional[GraphWrapper] = None
|
||||
) -> Tuple[float, torch.Tensor]:
|
||||
self, input_ids: torch.Tensor, graph_wrapper: GraphWrapper | None = None
|
||||
) -> tuple[float, torch.Tensor]:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
|
||||
|
||||
with torch.no_grad(), self.timer.timing():
|
||||
if graph_wrapper is not None:
|
||||
output = graph_wrapper(input_ids, attention_mask)
|
||||
else:
|
||||
output = self.model(input_ids=input_ids, attention_mask=attention_mask)
|
||||
|
||||
|
||||
return self.timer.elapsed_time(), output
|
||||
|
||||
def run(self) -> Dict[int, Dict[str, float]]:
|
||||
|
||||
def run(self) -> dict[int, dict[str, float]]:
|
||||
results = {}
|
||||
|
||||
|
||||
# Reset peak memory stats
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
@@ -443,22 +463,20 @@ class Benchmark:
|
||||
pass
|
||||
else:
|
||||
print("- No GPU memory stats available")
|
||||
|
||||
|
||||
for batch_size in self.config.batch_sizes:
|
||||
print(f"\nTesting batch size: {batch_size}")
|
||||
times = []
|
||||
|
||||
|
||||
# Get or create graph for this batch size
|
||||
graph_wrapper = (
|
||||
self.graphs.get_or_create(batch_size)
|
||||
if self.graphs is not None
|
||||
else None
|
||||
self.graphs.get_or_create(batch_size) if self.graphs is not None else None
|
||||
)
|
||||
|
||||
|
||||
# Pre-allocate input tensor
|
||||
input_ids = self._create_random_batch(batch_size)
|
||||
print(f"Input shape: {input_ids.shape}")
|
||||
|
||||
|
||||
# Run benchmark
|
||||
for i in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
|
||||
try:
|
||||
@@ -469,44 +487,44 @@ class Benchmark:
|
||||
except Exception as e:
|
||||
print(f"Error during inference: {e}")
|
||||
break
|
||||
|
||||
|
||||
if not times:
|
||||
print(f"No successful runs for batch size {batch_size}, skipping")
|
||||
continue
|
||||
|
||||
|
||||
# Calculate statistics
|
||||
avg_time = np.mean(times)
|
||||
std_time = np.std(times)
|
||||
throughput = batch_size / avg_time
|
||||
|
||||
|
||||
results[batch_size] = {
|
||||
"avg_time": avg_time,
|
||||
"std_time": std_time,
|
||||
"throughput": throughput,
|
||||
}
|
||||
|
||||
|
||||
print(f"Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
|
||||
print(f"Throughput: {throughput:.2f} sequences/second")
|
||||
|
||||
|
||||
# Log memory usage
|
||||
if torch.cuda.is_available():
|
||||
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024 ** 3)
|
||||
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024**3)
|
||||
elif torch.backends.mps.is_available():
|
||||
# MPS doesn't have max_memory_allocated, use 0
|
||||
peak_memory_gb = 0.0
|
||||
else:
|
||||
peak_memory_gb = 0.0
|
||||
print("- No GPU memory usage available")
|
||||
|
||||
|
||||
if peak_memory_gb > 0:
|
||||
print(f"\nPeak GPU memory usage: {peak_memory_gb:.2f} GB")
|
||||
else:
|
||||
print("\n- GPU memory usage not available")
|
||||
|
||||
|
||||
# Add memory info to results
|
||||
for batch_size in results:
|
||||
results[batch_size]["peak_memory_gb"] = peak_memory_gb
|
||||
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@@ -566,14 +584,14 @@ def main():
|
||||
action="store_true",
|
||||
help="Enable Linear8bitLt quantization for all linear layers",
|
||||
)
|
||||
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
# Print arguments for debugging
|
||||
print("\nCommand line arguments:")
|
||||
for arg, value in vars(args).items():
|
||||
print(f"- {arg}: {value}")
|
||||
|
||||
|
||||
config = BenchmarkConfig(
|
||||
model_path=args.model_path,
|
||||
batch_sizes=[int(bs) for bs in args.batch_sizes.split(",")],
|
||||
@@ -586,45 +604,56 @@ def main():
|
||||
use_flash_attention=args.use_flash_attention,
|
||||
use_linear8bitlt=args.use_linear8bitlt,
|
||||
)
|
||||
|
||||
|
||||
# Print configuration for debugging
|
||||
print("\nBenchmark configuration:")
|
||||
for field, value in vars(config).items():
|
||||
print(f"- {field}: {value}")
|
||||
|
||||
|
||||
try:
|
||||
benchmark = Benchmark(config)
|
||||
results = benchmark.run()
|
||||
|
||||
|
||||
# Save results to file
|
||||
import json
|
||||
import os
|
||||
|
||||
|
||||
# Create results directory if it doesn't exist
|
||||
os.makedirs("results", exist_ok=True)
|
||||
|
||||
|
||||
# Generate filename based on configuration
|
||||
precision_type = "int4" if config.use_int4 else "int8" if config.use_int8 else "fp16" if config.use_fp16 else "fp32"
|
||||
precision_type = (
|
||||
"int4"
|
||||
if config.use_int4
|
||||
else "int8"
|
||||
if config.use_int8
|
||||
else "fp16"
|
||||
if config.use_fp16
|
||||
else "fp32"
|
||||
)
|
||||
model_name = os.path.basename(config.model_path)
|
||||
output_file = f"results/benchmark_{model_name}_{precision_type}.json"
|
||||
|
||||
|
||||
# Save results
|
||||
with open(output_file, "w") as f:
|
||||
json.dump(
|
||||
{
|
||||
"config": {k: str(v) if isinstance(v, list) else v for k, v in vars(config).items()},
|
||||
"results": {str(k): v for k, v in results.items()}
|
||||
},
|
||||
f,
|
||||
indent=2
|
||||
"config": {
|
||||
k: str(v) if isinstance(v, list) else v for k, v in vars(config).items()
|
||||
},
|
||||
"results": {str(k): v for k, v in results.items()},
|
||||
},
|
||||
f,
|
||||
indent=2,
|
||||
)
|
||||
print(f"Results saved to {output_file}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Benchmark failed: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
@@ -5,24 +5,21 @@ It correctly compares results by fetching the text content for both the new sear
|
||||
results and the golden standard results, making the comparison robust to ID changes.
|
||||
"""
|
||||
|
||||
import json
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
import sys
|
||||
import numpy as np
|
||||
from typing import List
|
||||
|
||||
from leann.api import LeannSearcher, LeannBuilder
|
||||
import numpy as np
|
||||
from leann.api import LeannBuilder, LeannChat, LeannSearcher
|
||||
|
||||
|
||||
def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
|
||||
"""Checks if the data directory exists, and if not, downloads it from HF Hub."""
|
||||
if not data_root.exists():
|
||||
print(f"Data directory '{data_root}' not found.")
|
||||
print(
|
||||
"Downloading evaluation data from Hugging Face Hub... (this may take a moment)"
|
||||
)
|
||||
print("Downloading evaluation data from Hugging Face Hub... (this may take a moment)")
|
||||
try:
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
@@ -63,7 +60,7 @@ def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def download_embeddings_if_needed(data_root: Path, dataset_type: str = None):
|
||||
def download_embeddings_if_needed(data_root: Path, dataset_type: str | None = None):
|
||||
"""Download embeddings files specifically."""
|
||||
embeddings_dir = data_root / "embeddings"
|
||||
|
||||
@@ -101,7 +98,7 @@ def download_embeddings_if_needed(data_root: Path, dataset_type: str = None):
|
||||
|
||||
|
||||
# --- Helper Function to get Golden Passages ---
|
||||
def get_golden_texts(searcher: LeannSearcher, golden_ids: List[int]) -> set:
|
||||
def get_golden_texts(searcher: LeannSearcher, golden_ids: list[int]) -> set:
|
||||
"""
|
||||
Retrieves the text for golden passage IDs directly from the LeannSearcher's
|
||||
passage manager.
|
||||
@@ -113,24 +110,20 @@ def get_golden_texts(searcher: LeannSearcher, golden_ids: List[int]) -> set:
|
||||
passage_data = searcher.passage_manager.get_passage(str(gid))
|
||||
golden_texts.add(passage_data["text"])
|
||||
except KeyError:
|
||||
print(
|
||||
f"Warning: Golden passage ID '{gid}' not found in the index's passage data."
|
||||
)
|
||||
print(f"Warning: Golden passage ID '{gid}' not found in the index's passage data.")
|
||||
return golden_texts
|
||||
|
||||
|
||||
def load_queries(file_path: Path) -> List[str]:
|
||||
def load_queries(file_path: Path) -> list[str]:
|
||||
queries = []
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
with open(file_path, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
data = json.loads(line)
|
||||
queries.append(data["query"])
|
||||
return queries
|
||||
|
||||
|
||||
def build_index_from_embeddings(
|
||||
embeddings_file: str, output_path: str, backend: str = "hnsw"
|
||||
):
|
||||
def build_index_from_embeddings(embeddings_file: str, output_path: str, backend: str = "hnsw"):
|
||||
"""
|
||||
Build a LEANN index from pre-computed embeddings.
|
||||
|
||||
@@ -173,9 +166,7 @@ def build_index_from_embeddings(
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run recall evaluation on a LEANN index."
|
||||
)
|
||||
parser = argparse.ArgumentParser(description="Run recall evaluation on a LEANN index.")
|
||||
parser.add_argument(
|
||||
"index_path",
|
||||
type=str,
|
||||
@@ -202,26 +193,41 @@ def main():
|
||||
parser.add_argument(
|
||||
"--num-queries", type=int, default=10, help="Number of queries to evaluate."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top-k", type=int, default=3, help="The 'k' value for recall@k."
|
||||
)
|
||||
parser.add_argument("--top-k", type=int, default=3, help="The 'k' value for recall@k.")
|
||||
parser.add_argument(
|
||||
"--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Batch size for HNSW batched search (0 disables batching)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--llm-type",
|
||||
type=str,
|
||||
choices=["ollama", "hf", "openai", "gemini", "simulated"],
|
||||
default="ollama",
|
||||
help="LLM backend type to optionally query during evaluation (default: ollama)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--llm-model",
|
||||
type=str,
|
||||
default="qwen3:1.7b",
|
||||
help="LLM model identifier for the chosen backend (default: qwen3:1.7b)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# --- Path Configuration ---
|
||||
# Assumes a project structure where the script is in 'examples/'
|
||||
# and data is in 'data/' at the project root.
|
||||
project_root = Path(__file__).resolve().parent.parent
|
||||
data_root = project_root / "data"
|
||||
# Assumes a project structure where the script is in 'benchmarks/'
|
||||
# and evaluation data is in 'benchmarks/data/'.
|
||||
script_dir = Path(__file__).resolve().parent
|
||||
data_root = script_dir / "data"
|
||||
|
||||
# Download data based on mode
|
||||
if args.mode == "build":
|
||||
# For building mode, we need embeddings
|
||||
download_data_if_needed(
|
||||
data_root, download_embeddings=False
|
||||
) # Basic data first
|
||||
download_data_if_needed(data_root, download_embeddings=False) # Basic data first
|
||||
|
||||
# Auto-detect dataset type and download embeddings
|
||||
if args.embeddings_file:
|
||||
@@ -262,9 +268,7 @@ def main():
|
||||
print(f"Index built successfully: {built_index_path}")
|
||||
|
||||
# Ask if user wants to run evaluation
|
||||
eval_response = (
|
||||
input("Run evaluation on the built index? (y/n): ").strip().lower()
|
||||
)
|
||||
eval_response = input("Run evaluation on the built index? (y/n): ").strip().lower()
|
||||
if eval_response != "y":
|
||||
print("Index building complete. Exiting.")
|
||||
return
|
||||
@@ -293,11 +297,9 @@ def main():
|
||||
break
|
||||
|
||||
if not args.index_path:
|
||||
print("No indices found. The data download should have included pre-built indices.")
|
||||
print(
|
||||
"No indices found. The data download should have included pre-built indices."
|
||||
)
|
||||
print(
|
||||
"Please check the data/indices/ directory or provide --index-path manually."
|
||||
"Please check the benchmarks/data/indices/ directory or provide --index-path manually."
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
@@ -310,14 +312,10 @@ def main():
|
||||
else:
|
||||
# Fallback: try to infer from the index directory name
|
||||
dataset_type = Path(args.index_path).name
|
||||
print(
|
||||
f"WARNING: Could not detect dataset type from path, inferred '{dataset_type}'."
|
||||
)
|
||||
print(f"WARNING: Could not detect dataset type from path, inferred '{dataset_type}'.")
|
||||
|
||||
queries_file = data_root / "queries" / "nq_open.jsonl"
|
||||
golden_results_file = (
|
||||
data_root / "ground_truth" / dataset_type / "flat_results_nq_k3.json"
|
||||
)
|
||||
golden_results_file = data_root / "ground_truth" / dataset_type / "flat_results_nq_k3.json"
|
||||
|
||||
print(f"INFO: Detected dataset type: {dataset_type}")
|
||||
print(f"INFO: Using queries file: {queries_file}")
|
||||
@@ -327,7 +325,7 @@ def main():
|
||||
searcher = LeannSearcher(args.index_path)
|
||||
queries = load_queries(queries_file)
|
||||
|
||||
with open(golden_results_file, "r") as f:
|
||||
with open(golden_results_file) as f:
|
||||
golden_results_data = json.load(f)
|
||||
|
||||
num_eval_queries = min(args.num_queries, len(queries))
|
||||
@@ -340,10 +338,23 @@ def main():
|
||||
for i in range(num_eval_queries):
|
||||
start_time = time.time()
|
||||
new_results = searcher.search(
|
||||
queries[i], top_k=args.top_k, ef=args.ef_search
|
||||
queries[i],
|
||||
top_k=args.top_k,
|
||||
complexity=args.ef_search,
|
||||
batch_size=args.batch_size,
|
||||
)
|
||||
search_times.append(time.time() - start_time)
|
||||
|
||||
# Optional: also call the LLM with configurable backend/model (does not affect recall)
|
||||
llm_config = {"type": args.llm_type, "model": args.llm_model}
|
||||
chat = LeannChat(args.index_path, llm_config=llm_config, searcher=searcher)
|
||||
answer = chat.ask(
|
||||
queries[i],
|
||||
top_k=args.top_k,
|
||||
complexity=args.ef_search,
|
||||
batch_size=args.batch_size,
|
||||
)
|
||||
print(f"Answer: {answer}")
|
||||
# Correct Recall Calculation: Based on TEXT content
|
||||
new_texts = {result.text for result in new_results}
|
||||
|
||||
@@ -1,26 +1,27 @@
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import AutoModel, BitsAndBytesConfig
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoModel
|
||||
|
||||
# Add MLX imports
|
||||
try:
|
||||
import mlx.core as mx
|
||||
from mlx_lm.utils import load
|
||||
|
||||
MLX_AVAILABLE = True
|
||||
except ImportError as e:
|
||||
except ImportError:
|
||||
print("MLX not available. Install with: uv pip install mlx mlx-lm")
|
||||
MLX_AVAILABLE = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class BenchmarkConfig:
|
||||
model_path: str = "facebook/contriever"
|
||||
batch_sizes: List[int] = None
|
||||
model_path: str = "facebook/contriever-msmarco"
|
||||
batch_sizes: list[int] = None
|
||||
seq_length: int = 256
|
||||
num_runs: int = 5
|
||||
use_fp16: bool = True
|
||||
@@ -30,18 +31,19 @@ class BenchmarkConfig:
|
||||
use_flash_attention: bool = False
|
||||
use_linear8bitlt: bool = False
|
||||
use_mlx: bool = False # New flag for MLX testing
|
||||
|
||||
|
||||
def __post_init__(self):
|
||||
if self.batch_sizes is None:
|
||||
self.batch_sizes = [1, 2, 4, 8, 16, 32, 64]
|
||||
self.batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
|
||||
|
||||
|
||||
class MLXBenchmark:
|
||||
"""MLX-specific benchmark for embedding models"""
|
||||
|
||||
|
||||
def __init__(self, config: BenchmarkConfig):
|
||||
self.config = config
|
||||
self.model, self.tokenizer = self._load_model()
|
||||
|
||||
|
||||
def _load_model(self):
|
||||
"""Load MLX model and tokenizer following the API pattern"""
|
||||
print(f"Loading MLX model from {self.config.model_path}...")
|
||||
@@ -52,55 +54,51 @@ class MLXBenchmark:
|
||||
except Exception as e:
|
||||
print(f"Error loading MLX model: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def _create_random_batch(self, batch_size: int):
|
||||
"""Create random input batches for MLX testing - same as PyTorch"""
|
||||
return torch.randint(
|
||||
0, 1000,
|
||||
(batch_size, self.config.seq_length),
|
||||
dtype=torch.long
|
||||
)
|
||||
|
||||
return torch.randint(0, 1000, (batch_size, self.config.seq_length), dtype=torch.long)
|
||||
|
||||
def _run_inference(self, input_ids: torch.Tensor) -> float:
|
||||
"""Run MLX inference with same input as PyTorch"""
|
||||
start_time = time.time()
|
||||
try:
|
||||
# Convert PyTorch tensor to MLX array
|
||||
input_ids_mlx = mx.array(input_ids.numpy())
|
||||
|
||||
|
||||
# Get embeddings
|
||||
embeddings = self.model(input_ids_mlx)
|
||||
|
||||
|
||||
# Mean pooling (following the API pattern)
|
||||
pooled = embeddings.mean(axis=1)
|
||||
|
||||
|
||||
# Convert to numpy (following the API pattern)
|
||||
pooled_numpy = np.array(pooled.tolist(), dtype=np.float32)
|
||||
|
||||
|
||||
# Force computation
|
||||
_ = pooled_numpy.shape
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"MLX inference error: {e}")
|
||||
return float('inf')
|
||||
return float("inf")
|
||||
end_time = time.time()
|
||||
|
||||
|
||||
return end_time - start_time
|
||||
|
||||
def run(self) -> Dict[int, Dict[str, float]]:
|
||||
|
||||
def run(self) -> dict[int, dict[str, float]]:
|
||||
"""Run the MLX benchmark across all batch sizes"""
|
||||
results = {}
|
||||
|
||||
|
||||
print(f"Starting MLX benchmark with model: {self.config.model_path}")
|
||||
print(f"Testing batch sizes: {self.config.batch_sizes}")
|
||||
|
||||
|
||||
for batch_size in self.config.batch_sizes:
|
||||
print(f"\n=== Testing MLX batch size: {batch_size} ===")
|
||||
times = []
|
||||
|
||||
|
||||
# Create input batch (same as PyTorch)
|
||||
input_ids = self._create_random_batch(batch_size)
|
||||
|
||||
|
||||
# Warm up
|
||||
print("Warming up...")
|
||||
for _ in range(3):
|
||||
@@ -109,26 +107,26 @@ class MLXBenchmark:
|
||||
except Exception as e:
|
||||
print(f"Warmup error: {e}")
|
||||
break
|
||||
|
||||
|
||||
# Run benchmark
|
||||
for i in tqdm(range(self.config.num_runs), desc=f"MLX Batch size {batch_size}"):
|
||||
for _i in tqdm(range(self.config.num_runs), desc=f"MLX Batch size {batch_size}"):
|
||||
try:
|
||||
elapsed_time = self._run_inference(input_ids)
|
||||
if elapsed_time != float('inf'):
|
||||
if elapsed_time != float("inf"):
|
||||
times.append(elapsed_time)
|
||||
except Exception as e:
|
||||
print(f"Error during MLX inference: {e}")
|
||||
break
|
||||
|
||||
|
||||
if not times:
|
||||
print(f"Skipping batch size {batch_size} due to errors")
|
||||
continue
|
||||
|
||||
|
||||
# Calculate statistics
|
||||
avg_time = np.mean(times)
|
||||
std_time = np.std(times)
|
||||
throughput = batch_size / avg_time
|
||||
|
||||
|
||||
results[batch_size] = {
|
||||
"avg_time": avg_time,
|
||||
"std_time": std_time,
|
||||
@@ -136,122 +134,133 @@ class MLXBenchmark:
|
||||
"min_time": np.min(times),
|
||||
"max_time": np.max(times),
|
||||
}
|
||||
|
||||
|
||||
print(f"MLX Results for batch size {batch_size}:")
|
||||
print(f" Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
|
||||
print(f" Min Time: {np.min(times):.4f}s")
|
||||
print(f" Max Time: {np.max(times):.4f}s")
|
||||
print(f" Throughput: {throughput:.2f} sequences/second")
|
||||
|
||||
|
||||
return results
|
||||
|
||||
|
||||
class Benchmark:
|
||||
def __init__(self, config: BenchmarkConfig):
|
||||
self.config = config
|
||||
self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
||||
self.device = (
|
||||
"cuda"
|
||||
if torch.cuda.is_available()
|
||||
else "mps"
|
||||
if torch.backends.mps.is_available()
|
||||
else "cpu"
|
||||
)
|
||||
self.model = self._load_model()
|
||||
|
||||
|
||||
def _load_model(self) -> nn.Module:
|
||||
print(f"Loading model from {self.config.model_path}...")
|
||||
|
||||
|
||||
|
||||
model = AutoModel.from_pretrained(self.config.model_path)
|
||||
if self.config.use_fp16:
|
||||
model = model.half()
|
||||
model = torch.compile(model)
|
||||
model = model.to(self.device)
|
||||
|
||||
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
def _create_random_batch(self, batch_size: int) -> torch.Tensor:
|
||||
return torch.randint(
|
||||
0, 1000,
|
||||
0,
|
||||
1000,
|
||||
(batch_size, self.config.seq_length),
|
||||
device=self.device,
|
||||
dtype=torch.long
|
||||
dtype=torch.long,
|
||||
)
|
||||
|
||||
|
||||
def _run_inference(self, input_ids: torch.Tensor) -> float:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
|
||||
# print shape of input_ids and attention_mask
|
||||
print(f"input_ids shape: {input_ids.shape}")
|
||||
print(f"attention_mask shape: {attention_mask.shape}")
|
||||
start_time = time.time()
|
||||
with torch.no_grad():
|
||||
output = self.model(input_ids=input_ids, attention_mask=attention_mask)
|
||||
self.model(input_ids=input_ids, attention_mask=attention_mask)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
if torch.backends.mps.is_available():
|
||||
torch.mps.synchronize()
|
||||
end_time = time.time()
|
||||
|
||||
|
||||
return end_time - start_time
|
||||
|
||||
def run(self) -> Dict[int, Dict[str, float]]:
|
||||
|
||||
def run(self) -> dict[int, dict[str, float]]:
|
||||
results = {}
|
||||
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
|
||||
for batch_size in self.config.batch_sizes:
|
||||
print(f"\nTesting batch size: {batch_size}")
|
||||
times = []
|
||||
|
||||
|
||||
input_ids = self._create_random_batch(batch_size)
|
||||
|
||||
for i in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
|
||||
|
||||
for _i in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
|
||||
try:
|
||||
elapsed_time = self._run_inference(input_ids)
|
||||
times.append(elapsed_time)
|
||||
except Exception as e:
|
||||
print(f"Error during inference: {e}")
|
||||
break
|
||||
|
||||
|
||||
if not times:
|
||||
continue
|
||||
|
||||
|
||||
avg_time = np.mean(times)
|
||||
std_time = np.std(times)
|
||||
throughput = batch_size / avg_time
|
||||
|
||||
|
||||
results[batch_size] = {
|
||||
"avg_time": avg_time,
|
||||
"std_time": std_time,
|
||||
"throughput": throughput,
|
||||
}
|
||||
|
||||
|
||||
print(f"Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
|
||||
print(f"Throughput: {throughput:.2f} sequences/second")
|
||||
|
||||
|
||||
if torch.cuda.is_available():
|
||||
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024 ** 3)
|
||||
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024**3)
|
||||
else:
|
||||
peak_memory_gb = 0.0
|
||||
|
||||
|
||||
for batch_size in results:
|
||||
results[batch_size]["peak_memory_gb"] = peak_memory_gb
|
||||
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def run_benchmark():
|
||||
"""Main function to run the benchmark with optimized parameters."""
|
||||
config = BenchmarkConfig()
|
||||
|
||||
|
||||
try:
|
||||
benchmark = Benchmark(config)
|
||||
results = benchmark.run()
|
||||
|
||||
|
||||
max_throughput = max(results[batch_size]["throughput"] for batch_size in results)
|
||||
avg_throughput = np.mean([results[batch_size]["throughput"] for batch_size in results])
|
||||
|
||||
|
||||
return {
|
||||
"max_throughput": max_throughput,
|
||||
"avg_throughput": avg_throughput,
|
||||
"results": results
|
||||
"results": results,
|
||||
}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Benchmark failed: {e}")
|
||||
return {
|
||||
"max_throughput": 0.0,
|
||||
"avg_throughput": 0.0,
|
||||
"error": str(e)
|
||||
}
|
||||
return {"max_throughput": 0.0, "avg_throughput": 0.0, "error": str(e)}
|
||||
|
||||
|
||||
def run_mlx_benchmark():
|
||||
"""Run MLX-specific benchmark"""
|
||||
@@ -260,55 +269,49 @@ def run_mlx_benchmark():
|
||||
return {
|
||||
"max_throughput": 0.0,
|
||||
"avg_throughput": 0.0,
|
||||
"error": "MLX not available"
|
||||
"error": "MLX not available",
|
||||
}
|
||||
|
||||
config = BenchmarkConfig(
|
||||
model_path="mlx-community/all-MiniLM-L6-v2-4bit",
|
||||
use_mlx=True
|
||||
)
|
||||
|
||||
|
||||
config = BenchmarkConfig(model_path="mlx-community/all-MiniLM-L6-v2-4bit", use_mlx=True)
|
||||
|
||||
try:
|
||||
benchmark = MLXBenchmark(config)
|
||||
results = benchmark.run()
|
||||
|
||||
|
||||
if not results:
|
||||
return {
|
||||
"max_throughput": 0.0,
|
||||
"avg_throughput": 0.0,
|
||||
"error": "No valid results"
|
||||
"error": "No valid results",
|
||||
}
|
||||
|
||||
|
||||
max_throughput = max(results[batch_size]["throughput"] for batch_size in results)
|
||||
avg_throughput = np.mean([results[batch_size]["throughput"] for batch_size in results])
|
||||
|
||||
|
||||
return {
|
||||
"max_throughput": max_throughput,
|
||||
"avg_throughput": avg_throughput,
|
||||
"results": results
|
||||
"results": results,
|
||||
}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"MLX benchmark failed: {e}")
|
||||
return {
|
||||
"max_throughput": 0.0,
|
||||
"avg_throughput": 0.0,
|
||||
"error": str(e)
|
||||
}
|
||||
return {"max_throughput": 0.0, "avg_throughput": 0.0, "error": str(e)}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("=== PyTorch Benchmark ===")
|
||||
pytorch_result = run_benchmark()
|
||||
print(f"PyTorch Max throughput: {pytorch_result['max_throughput']:.2f} sequences/second")
|
||||
print(f"PyTorch Average throughput: {pytorch_result['avg_throughput']:.2f} sequences/second")
|
||||
|
||||
|
||||
print("\n=== MLX Benchmark ===")
|
||||
mlx_result = run_mlx_benchmark()
|
||||
print(f"MLX Max throughput: {mlx_result['max_throughput']:.2f} sequences/second")
|
||||
print(f"MLX Average throughput: {mlx_result['avg_throughput']:.2f} sequences/second")
|
||||
|
||||
|
||||
# Compare results
|
||||
if pytorch_result['max_throughput'] > 0 and mlx_result['max_throughput'] > 0:
|
||||
speedup = mlx_result['max_throughput'] / pytorch_result['max_throughput']
|
||||
print(f"\n=== Comparison ===")
|
||||
print(f"MLX is {speedup:.2f}x {'faster' if speedup > 1 else 'slower'} than PyTorch")
|
||||
if pytorch_result["max_throughput"] > 0 and mlx_result["max_throughput"] > 0:
|
||||
speedup = mlx_result["max_throughput"] / pytorch_result["max_throughput"]
|
||||
print("\n=== Comparison ===")
|
||||
print(f"MLX is {speedup:.2f}x {'faster' if speedup > 1 else 'slower'} than PyTorch")
|
||||
82
data/.gitattributes
vendored
82
data/.gitattributes
vendored
@@ -1,82 +0,0 @@
|
||||
*.7z filter=lfs diff=lfs merge=lfs -text
|
||||
*.arrow filter=lfs diff=lfs merge=lfs -text
|
||||
*.bin filter=lfs diff=lfs merge=lfs -text
|
||||
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
||||
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
||||
*.ftz filter=lfs diff=lfs merge=lfs -text
|
||||
*.gz filter=lfs diff=lfs merge=lfs -text
|
||||
*.h5 filter=lfs diff=lfs merge=lfs -text
|
||||
*.joblib filter=lfs diff=lfs merge=lfs -text
|
||||
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
||||
*.lz4 filter=lfs diff=lfs merge=lfs -text
|
||||
*.mds filter=lfs diff=lfs merge=lfs -text
|
||||
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
||||
*.model filter=lfs diff=lfs merge=lfs -text
|
||||
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
||||
*.npy filter=lfs diff=lfs merge=lfs -text
|
||||
*.npz filter=lfs diff=lfs merge=lfs -text
|
||||
*.onnx filter=lfs diff=lfs merge=lfs -text
|
||||
*.ot filter=lfs diff=lfs merge=lfs -text
|
||||
*.parquet filter=lfs diff=lfs merge=lfs -text
|
||||
*.pb filter=lfs diff=lfs merge=lfs -text
|
||||
*.pickle filter=lfs diff=lfs merge=lfs -text
|
||||
*.pkl filter=lfs diff=lfs merge=lfs -text
|
||||
*.pt filter=lfs diff=lfs merge=lfs -text
|
||||
*.pth filter=lfs diff=lfs merge=lfs -text
|
||||
*.rar filter=lfs diff=lfs merge=lfs -text
|
||||
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
||||
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
||||
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
||||
*.tar filter=lfs diff=lfs merge=lfs -text
|
||||
*.tflite filter=lfs diff=lfs merge=lfs -text
|
||||
*.tgz filter=lfs diff=lfs merge=lfs -text
|
||||
*.wasm filter=lfs diff=lfs merge=lfs -text
|
||||
*.xz filter=lfs diff=lfs merge=lfs -text
|
||||
*.zip filter=lfs diff=lfs merge=lfs -text
|
||||
*.zst filter=lfs diff=lfs merge=lfs -text
|
||||
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
||||
# Audio files - uncompressed
|
||||
*.pcm filter=lfs diff=lfs merge=lfs -text
|
||||
*.sam filter=lfs diff=lfs merge=lfs -text
|
||||
*.raw filter=lfs diff=lfs merge=lfs -text
|
||||
# Audio files - compressed
|
||||
*.aac filter=lfs diff=lfs merge=lfs -text
|
||||
*.flac filter=lfs diff=lfs merge=lfs -text
|
||||
*.mp3 filter=lfs diff=lfs merge=lfs -text
|
||||
*.ogg filter=lfs diff=lfs merge=lfs -text
|
||||
*.wav filter=lfs diff=lfs merge=lfs -text
|
||||
# Image files - uncompressed
|
||||
*.bmp filter=lfs diff=lfs merge=lfs -text
|
||||
*.gif filter=lfs diff=lfs merge=lfs -text
|
||||
*.png filter=lfs diff=lfs merge=lfs -text
|
||||
*.tiff filter=lfs diff=lfs merge=lfs -text
|
||||
# Image files - compressed
|
||||
*.jpg filter=lfs diff=lfs merge=lfs -text
|
||||
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
||||
*.webp filter=lfs diff=lfs merge=lfs -text
|
||||
# Video files - compressed
|
||||
*.mp4 filter=lfs diff=lfs merge=lfs -text
|
||||
*.webm filter=lfs diff=lfs merge=lfs -text
|
||||
ground_truth/dpr/id_map.json filter=lfs diff=lfs merge=lfs -text
|
||||
indices/dpr/dpr_diskann.passages.idx filter=lfs diff=lfs merge=lfs -text
|
||||
indices/dpr/dpr_diskann.passages.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||
indices/dpr/dpr_diskann_disk.index filter=lfs diff=lfs merge=lfs -text
|
||||
indices/dpr/leann.labels.map filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/leann.labels.map filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.index filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.0.idx filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.0.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.1.idx filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.1.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.2.idx filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.2.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.3.idx filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.3.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.4.idx filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.4.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.5.idx filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.5.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.6.idx filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.6.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.7.idx filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.7.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||
@@ -1,5 +1,5 @@
|
||||
The Project Gutenberg eBook of Pride and Prejudice
|
||||
|
||||
|
||||
This ebook is for the use of anyone anywhere in the United States and
|
||||
most other parts of the world at no cost and with almost no restrictions
|
||||
whatsoever. You may copy it, give it away or re-use it under the terms
|
||||
@@ -14557,7 +14557,7 @@ her into Derbyshire, had been the means of uniting them.
|
||||
*** END OF THE PROJECT GUTENBERG EBOOK PRIDE AND PREJUDICE ***
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Updated editions will replace the previous one—the old editions will
|
||||
be renamed.
|
||||
@@ -14662,7 +14662,7 @@ performed, viewed, copied or distributed:
|
||||
at www.gutenberg.org. If you
|
||||
are not located in the United States, you will have to check the laws
|
||||
of the country where you are located before using this eBook.
|
||||
|
||||
|
||||
1.E.2. If an individual Project Gutenberg™ electronic work is
|
||||
derived from texts not protected by U.S. copyright law (does not
|
||||
contain a notice indicating that it is posted with permission of the
|
||||
@@ -14724,7 +14724,7 @@ provided that:
|
||||
Gutenberg Literary Archive Foundation at the address specified in
|
||||
Section 4, “Information about donations to the Project Gutenberg
|
||||
Literary Archive Foundation.”
|
||||
|
||||
|
||||
• You provide a full refund of any money paid by a user who notifies
|
||||
you in writing (or by e-mail) within 30 days of receipt that s/he
|
||||
does not agree to the terms of the full Project Gutenberg™
|
||||
@@ -14732,15 +14732,15 @@ provided that:
|
||||
copies of the works possessed in a physical medium and discontinue
|
||||
all use of and all access to other copies of Project Gutenberg™
|
||||
works.
|
||||
|
||||
|
||||
• You provide, in accordance with paragraph 1.F.3, a full refund of
|
||||
any money paid for a work or a replacement copy, if a defect in the
|
||||
electronic work is discovered and reported to you within 90 days of
|
||||
receipt of the work.
|
||||
|
||||
|
||||
• You comply with all other terms of this agreement for free
|
||||
distribution of Project Gutenberg™ works.
|
||||
|
||||
|
||||
|
||||
1.E.9. If you wish to charge a fee or distribute a Project
|
||||
Gutenberg™ electronic work or group of works on different terms than
|
||||
@@ -14903,5 +14903,3 @@ This website includes information about Project Gutenberg™,
|
||||
including how to make donations to the Project Gutenberg Literary
|
||||
Archive Foundation, how to help produce our new eBooks, and how to
|
||||
subscribe to our email newsletter to hear about new eBooks.
|
||||
|
||||
|
||||
125
demo.ipynb
125
demo.ipynb
@@ -1,35 +1,116 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Quick Start \n",
|
||||
"\n",
|
||||
"**Home GitHub Repository:** [LEANN on GitHub](https://github.com/yichuan-w/LEANN)\n",
|
||||
"\n",
|
||||
"**Important for Colab users:** Set your runtime type to T4 GPU for optimal performance. Go to Runtime → Change runtime type → Hardware accelerator → T4 GPU."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from leann.api import LeannBuilder, LeannSearcher, LeannChat\n",
|
||||
"# 1. Build index (no embeddings stored!)\n",
|
||||
"# install this if you are using colab\n",
|
||||
"! uv pip install leann-core leann-backend-hnsw --no-deps\n",
|
||||
"! uv pip install leann --no-deps\n",
|
||||
"# For Colab environment, we need to set some environment variables\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"os.environ[\"LEANN_LOG_LEVEL\"] = \"INFO\" # Enable more detailed logging"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from pathlib import Path\n",
|
||||
"\n",
|
||||
"INDEX_DIR = Path(\"./\").resolve()\n",
|
||||
"INDEX_PATH = str(INDEX_DIR / \"demo.leann\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Build the index"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from leann.api import LeannBuilder\n",
|
||||
"\n",
|
||||
"builder = LeannBuilder(backend_name=\"hnsw\")\n",
|
||||
"builder.add_text(\"C# is a powerful programming language but it is not very popular\")\n",
|
||||
"builder.add_text(\"Python is a powerful programming language and it is very popular\")\n",
|
||||
"builder.add_text(\"Machine learning transforms industries\") \n",
|
||||
"builder.add_text(\"Neural networks process complex data\")\n",
|
||||
"builder.add_text(\"Leann is a great storage saving engine for RAG on your macbook\")\n",
|
||||
"builder.build_index(\"knowledge.leann\")\n",
|
||||
"# 2. Search with real-time embeddings\n",
|
||||
"searcher = LeannSearcher(\"knowledge.leann\")\n",
|
||||
"results = searcher.search(\"programming languages\", top_k=2, recompute_beighbor_embeddings=True)\n",
|
||||
"print(results)\n",
|
||||
"\n",
|
||||
"llm_config = {\"type\": \"ollama\", \"model\": \"qwen3:8b\"}\n",
|
||||
"\n",
|
||||
"chat = LeannChat(index_path=\"knowledge.leann\", llm_config=llm_config)\n",
|
||||
"\n",
|
||||
"response = chat.ask(\n",
|
||||
" \"Compare the two retrieved programming languages and say which one is more popular today. Respond in a single well-formed sentence.\",\n",
|
||||
" top_k=2,\n",
|
||||
" recompute_beighbor_embeddings=True,\n",
|
||||
"builder.add_text(\"C# is a powerful programming language and it is good at game development\")\n",
|
||||
"builder.add_text(\n",
|
||||
" \"Python is a powerful programming language and it is good at machine learning tasks\"\n",
|
||||
")\n",
|
||||
"print(response)"
|
||||
"builder.add_text(\"Machine learning transforms industries\")\n",
|
||||
"builder.add_text(\"Neural networks process complex data\")\n",
|
||||
"builder.add_text(\"Leann is a great storage saving engine for RAG on your MacBook\")\n",
|
||||
"builder.build_index(INDEX_PATH)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Search with real-time embeddings"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from leann.api import LeannSearcher\n",
|
||||
"\n",
|
||||
"searcher = LeannSearcher(INDEX_PATH)\n",
|
||||
"results = searcher.search(\"programming languages\", top_k=2)\n",
|
||||
"results"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Chat with LEANN using retrieved results"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from leann.api import LeannChat\n",
|
||||
"\n",
|
||||
"llm_config = {\n",
|
||||
" \"type\": \"hf\",\n",
|
||||
" \"model\": \"Qwen/Qwen3-0.6B\",\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)\n",
|
||||
"response = chat.ask(\n",
|
||||
" \"Compare the two retrieved programming languages and tell me their advantages.\",\n",
|
||||
" top_k=2,\n",
|
||||
" llm_kwargs={\"max_tokens\": 128},\n",
|
||||
")\n",
|
||||
"response"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
||||
220
docs/CONTRIBUTING.md
Normal file
220
docs/CONTRIBUTING.md
Normal file
@@ -0,0 +1,220 @@
|
||||
# 🤝 Contributing
|
||||
|
||||
We welcome contributions! Leann is built by the community, for the community.
|
||||
|
||||
## Ways to Contribute
|
||||
|
||||
- 🐛 **Bug Reports**: Found an issue? Let us know!
|
||||
- 💡 **Feature Requests**: Have an idea? We'd love to hear it!
|
||||
- 🔧 **Code Contributions**: PRs welcome for all skill levels
|
||||
- 📖 **Documentation**: Help make Leann more accessible
|
||||
- 🧪 **Benchmarks**: Share your performance results
|
||||
|
||||
## 🚀 Development Setup
|
||||
|
||||
### Prerequisites
|
||||
|
||||
1. **Install uv** (fast Python package installer):
|
||||
```bash
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
```
|
||||
|
||||
2. **Clone the repository**:
|
||||
```bash
|
||||
git clone https://github.com/LEANN-RAG/LEANN-RAG.git
|
||||
cd LEANN-RAG
|
||||
```
|
||||
|
||||
3. **Install system dependencies**:
|
||||
|
||||
**macOS:**
|
||||
```bash
|
||||
brew install llvm libomp boost protobuf zeromq pkgconf
|
||||
```
|
||||
|
||||
**Ubuntu/Debian:**
|
||||
```bash
|
||||
sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler \
|
||||
libabsl-dev libmkl-full-dev libaio-dev libzmq3-dev
|
||||
```
|
||||
|
||||
4. **Build from source**:
|
||||
```bash
|
||||
# macOS
|
||||
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv sync
|
||||
|
||||
# Ubuntu/Debian
|
||||
uv sync
|
||||
```
|
||||
|
||||
## 🔨 Pre-commit Hooks
|
||||
|
||||
We use pre-commit hooks to ensure code quality and consistency. This runs automatically before each commit.
|
||||
|
||||
### Setup Pre-commit
|
||||
|
||||
1. **Install pre-commit** (already included when you run `uv sync`):
|
||||
```bash
|
||||
uv pip install pre-commit
|
||||
```
|
||||
|
||||
2. **Install the git hooks**:
|
||||
```bash
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
3. **Run pre-commit manually** (optional):
|
||||
```bash
|
||||
pre-commit run --all-files
|
||||
```
|
||||
|
||||
### Pre-commit Checks
|
||||
|
||||
Our pre-commit configuration includes:
|
||||
- **Trailing whitespace removal**
|
||||
- **End-of-file fixing**
|
||||
- **YAML validation**
|
||||
- **Large file prevention**
|
||||
- **Merge conflict detection**
|
||||
- **Debug statement detection**
|
||||
- **Code formatting with ruff**
|
||||
- **Code linting with ruff**
|
||||
|
||||
## 🧪 Testing
|
||||
|
||||
### Running Tests
|
||||
|
||||
```bash
|
||||
# Run all tests
|
||||
uv run pytest
|
||||
|
||||
# Run specific test file
|
||||
uv run pytest test/test_filename.py
|
||||
|
||||
# Run with coverage
|
||||
uv run pytest --cov=leann
|
||||
```
|
||||
|
||||
### Writing Tests
|
||||
|
||||
- Place tests in the `test/` directory
|
||||
- Follow the naming convention `test_*.py`
|
||||
- Use descriptive test names that explain what's being tested
|
||||
- Include both positive and negative test cases
|
||||
|
||||
## 📝 Code Style
|
||||
|
||||
We use `ruff` for both linting and formatting to ensure consistent code style.
|
||||
|
||||
### Format Your Code
|
||||
|
||||
```bash
|
||||
# Format all files
|
||||
ruff format
|
||||
|
||||
# Check formatting without changing files
|
||||
ruff format --check
|
||||
```
|
||||
|
||||
### Lint Your Code
|
||||
|
||||
```bash
|
||||
# Run linter with auto-fix
|
||||
ruff check --fix
|
||||
|
||||
# Just check without fixing
|
||||
ruff check
|
||||
```
|
||||
|
||||
### Style Guidelines
|
||||
|
||||
- Follow PEP 8 conventions
|
||||
- Use descriptive variable names
|
||||
- Add type hints where appropriate
|
||||
- Write docstrings for all public functions and classes
|
||||
- Keep functions focused and single-purpose
|
||||
|
||||
## 🚦 CI/CD
|
||||
|
||||
Our CI pipeline runs automatically on all pull requests. It includes:
|
||||
|
||||
1. **Linting and Formatting**: Ensures code follows our style guidelines
|
||||
2. **Multi-platform builds**: Tests on Ubuntu and macOS
|
||||
3. **Python version matrix**: Tests on Python 3.9-3.13
|
||||
4. **Wheel building**: Ensures packages can be built and distributed
|
||||
|
||||
### CI Commands
|
||||
|
||||
The CI uses the same commands as pre-commit to ensure consistency:
|
||||
```bash
|
||||
# Linting
|
||||
ruff check .
|
||||
|
||||
# Format checking
|
||||
ruff format --check .
|
||||
```
|
||||
|
||||
Make sure your code passes these checks locally before pushing!
|
||||
|
||||
## 🔄 Pull Request Process
|
||||
|
||||
1. **Fork the repository** and create your branch from `main`:
|
||||
```bash
|
||||
git checkout -b feature/your-feature-name
|
||||
```
|
||||
|
||||
2. **Make your changes**:
|
||||
- Write clean, documented code
|
||||
- Add tests for new functionality
|
||||
- Update documentation as needed
|
||||
|
||||
3. **Run pre-commit checks**:
|
||||
```bash
|
||||
pre-commit run --all-files
|
||||
```
|
||||
|
||||
4. **Test your changes**:
|
||||
```bash
|
||||
uv run pytest
|
||||
```
|
||||
|
||||
5. **Commit with descriptive messages**:
|
||||
```bash
|
||||
git commit -m "feat: add new search algorithm"
|
||||
```
|
||||
|
||||
Follow [Conventional Commits](https://www.conventionalcommits.org/):
|
||||
- `feat:` for new features
|
||||
- `fix:` for bug fixes
|
||||
- `docs:` for documentation changes
|
||||
- `test:` for test additions/changes
|
||||
- `refactor:` for code refactoring
|
||||
- `perf:` for performance improvements
|
||||
|
||||
6. **Push and create a pull request**:
|
||||
- Provide a clear description of your changes
|
||||
- Reference any related issues
|
||||
- Include examples or screenshots if applicable
|
||||
|
||||
## 📚 Documentation
|
||||
|
||||
When adding new features or making significant changes:
|
||||
|
||||
1. Update relevant documentation in `/docs`
|
||||
2. Add docstrings to new functions/classes
|
||||
3. Update README.md if needed
|
||||
4. Include usage examples
|
||||
|
||||
## 🤔 Getting Help
|
||||
|
||||
- **Discord**: Join our community for discussions
|
||||
- **Issues**: Check existing issues or create a new one
|
||||
- **Discussions**: For general questions and ideas
|
||||
|
||||
## 📄 License
|
||||
|
||||
By contributing, you agree that your contributions will be licensed under the same license as the project (MIT).
|
||||
|
||||
---
|
||||
|
||||
Thank you for contributing to LEANN! Every contribution, no matter how small, helps make the project better for everyone. 🌟
|
||||
22
docs/RELEASE.md
Normal file
22
docs/RELEASE.md
Normal file
@@ -0,0 +1,22 @@
|
||||
# Release Guide
|
||||
|
||||
## Setup (One-time)
|
||||
|
||||
Add `PYPI_API_TOKEN` to GitHub Secrets:
|
||||
1. Get token: https://pypi.org/manage/account/token/
|
||||
2. Add to secrets: Settings → Secrets → Actions → `PYPI_API_TOKEN`
|
||||
|
||||
## Release (One-click)
|
||||
|
||||
1. Go to: https://github.com/yichuan-w/LEANN/actions/workflows/release-manual.yml
|
||||
2. Click "Run workflow"
|
||||
3. Enter version: `0.1.2`
|
||||
4. Click green "Run workflow" button
|
||||
|
||||
That's it! The workflow will automatically:
|
||||
- ✅ Update version in all packages
|
||||
- ✅ Build all packages
|
||||
- ✅ Publish to PyPI
|
||||
- ✅ Create GitHub tag and release
|
||||
|
||||
Check progress: https://github.com/yichuan-w/LEANN/actions
|
||||
123
docs/THINKING_BUDGET_FEATURE.md
Normal file
123
docs/THINKING_BUDGET_FEATURE.md
Normal file
@@ -0,0 +1,123 @@
|
||||
# Thinking Budget Feature Implementation
|
||||
|
||||
## Overview
|
||||
|
||||
This document describes the implementation of the **thinking budget** feature for LEANN, which allows users to control the computational effort for reasoning models like GPT-Oss:20b.
|
||||
|
||||
## Feature Description
|
||||
|
||||
The thinking budget feature provides three levels of computational effort for reasoning models:
|
||||
- **`low`**: Fast responses, basic reasoning (default for simple queries)
|
||||
- **`medium`**: Balanced speed and reasoning depth
|
||||
- **`high`**: Maximum reasoning effort, best for complex analytical questions
|
||||
|
||||
## Implementation Details
|
||||
|
||||
### 1. Command Line Interface
|
||||
|
||||
Added `--thinking-budget` parameter to both CLI and RAG examples:
|
||||
|
||||
```bash
|
||||
# LEANN CLI
|
||||
leann ask my-index --llm ollama --model gpt-oss:20b --thinking-budget high
|
||||
|
||||
# RAG Examples
|
||||
python apps/email_rag.py --llm ollama --llm-model gpt-oss:20b --thinking-budget high
|
||||
python apps/document_rag.py --llm openai --llm-model o3 --thinking-budget medium
|
||||
```
|
||||
|
||||
### 2. LLM Backend Support
|
||||
|
||||
#### Ollama Backend (`packages/leann-core/src/leann/chat.py`)
|
||||
|
||||
```python
|
||||
def ask(self, prompt: str, **kwargs) -> str:
|
||||
# Handle thinking budget for reasoning models
|
||||
options = kwargs.copy()
|
||||
thinking_budget = kwargs.get("thinking_budget")
|
||||
if thinking_budget:
|
||||
options.pop("thinking_budget", None)
|
||||
if thinking_budget in ["low", "medium", "high"]:
|
||||
options["reasoning"] = {"effort": thinking_budget, "exclude": False}
|
||||
```
|
||||
|
||||
**API Format**: Uses Ollama's `reasoning` parameter with `effort` and `exclude` fields.
|
||||
|
||||
#### OpenAI Backend (`packages/leann-core/src/leann/chat.py`)
|
||||
|
||||
```python
|
||||
def ask(self, prompt: str, **kwargs) -> str:
|
||||
# Handle thinking budget for reasoning models
|
||||
thinking_budget = kwargs.get("thinking_budget")
|
||||
if thinking_budget and thinking_budget in ["low", "medium", "high"]:
|
||||
# Check if this is an o-series model
|
||||
o_series_models = ["o3", "o3-mini", "o4-mini", "o1", "o3-pro", "o3-deep-research"]
|
||||
if any(model in self.model for model in o_series_models):
|
||||
params["reasoning_effort"] = thinking_budget
|
||||
```
|
||||
|
||||
**API Format**: Uses OpenAI's `reasoning_effort` parameter for o-series models.
|
||||
|
||||
### 3. Parameter Propagation
|
||||
|
||||
The thinking budget parameter is properly propagated through the LEANN architecture:
|
||||
|
||||
1. **CLI** (`packages/leann-core/src/leann/cli.py`): Captures `--thinking-budget` argument
|
||||
2. **Base RAG** (`apps/base_rag_example.py`): Adds parameter to argument parser
|
||||
3. **LeannChat** (`packages/leann-core/src/leann/api.py`): Passes `llm_kwargs` to LLM
|
||||
4. **LLM Interface**: Handles the parameter in backend-specific implementations
|
||||
|
||||
## Files Modified
|
||||
|
||||
### Core Implementation
|
||||
- `packages/leann-core/src/leann/chat.py`: Added thinking budget support to OllamaChat and OpenAIChat
|
||||
- `packages/leann-core/src/leann/cli.py`: Added `--thinking-budget` argument
|
||||
- `apps/base_rag_example.py`: Added thinking budget parameter to RAG examples
|
||||
|
||||
### Documentation
|
||||
- `README.md`: Added thinking budget parameter to usage examples
|
||||
- `docs/configuration-guide.md`: Added detailed documentation and usage guidelines
|
||||
|
||||
### Examples
|
||||
- `examples/thinking_budget_demo.py`: Comprehensive demo script with usage examples
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### Basic Usage
|
||||
```bash
|
||||
# High reasoning effort for complex questions
|
||||
leann ask my-index --llm ollama --model gpt-oss:20b --thinking-budget high
|
||||
|
||||
# Medium reasoning for balanced performance
|
||||
leann ask my-index --llm openai --model gpt-4o --thinking-budget medium
|
||||
|
||||
# Low reasoning for fast responses
|
||||
leann ask my-index --llm ollama --model gpt-oss:20b --thinking-budget low
|
||||
```
|
||||
|
||||
### RAG Examples
|
||||
```bash
|
||||
# Email RAG with high reasoning
|
||||
python apps/email_rag.py --llm ollama --llm-model gpt-oss:20b --thinking-budget high
|
||||
|
||||
# Document RAG with medium reasoning
|
||||
python apps/document_rag.py --llm openai --llm-model gpt-4o --thinking-budget medium
|
||||
```
|
||||
|
||||
## Supported Models
|
||||
|
||||
### Ollama Models
|
||||
- **GPT-Oss:20b**: Primary target model with reasoning capabilities
|
||||
- **Other reasoning models**: Any Ollama model that supports the `reasoning` parameter
|
||||
|
||||
### OpenAI Models
|
||||
- **o3, o3-mini, o4-mini, o1**: o-series reasoning models with `reasoning_effort` parameter
|
||||
- **GPT-OSS models**: Models that support reasoning capabilities
|
||||
|
||||
## Testing
|
||||
|
||||
The implementation includes comprehensive testing:
|
||||
- Parameter handling verification
|
||||
- Backend-specific API format validation
|
||||
- CLI argument parsing tests
|
||||
- Integration with existing LEANN architecture
|
||||
143
docs/ast_chunking_guide.md
Normal file
143
docs/ast_chunking_guide.md
Normal file
@@ -0,0 +1,143 @@
|
||||
# AST-Aware Code chunking guide
|
||||
|
||||
## Overview
|
||||
|
||||
This guide covers best practices for using AST-aware code chunking in LEANN. AST chunking provides better semantic understanding of code structure compared to traditional text-based chunking.
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Basic Usage
|
||||
|
||||
```bash
|
||||
# Enable AST chunking for mixed content (code + docs)
|
||||
python -m apps.document_rag --enable-code-chunking --data-dir ./my_project
|
||||
|
||||
# Specialized code repository indexing
|
||||
python -m apps.code_rag --repo-dir ./my_codebase
|
||||
|
||||
# Global CLI with AST support
|
||||
leann build my-code-index --docs ./src --use-ast-chunking
|
||||
```
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# Install LEANN with AST chunking support
|
||||
uv pip install -e "."
|
||||
```
|
||||
|
||||
#### For normal users (PyPI install)
|
||||
- Use `pip install leann` or `uv pip install leann`.
|
||||
- `astchunk` is pulled automatically from PyPI as a dependency; no extra steps.
|
||||
|
||||
#### For developers (from source, editable)
|
||||
```bash
|
||||
git clone https://github.com/yichuan-w/LEANN.git leann
|
||||
cd leann
|
||||
git submodule update --init --recursive
|
||||
uv sync
|
||||
```
|
||||
- This repo vendors `astchunk` as a git submodule at `packages/astchunk-leann` (our fork).
|
||||
- `[tool.uv.sources]` maps the `astchunk` package to that path in editable mode.
|
||||
- You can edit code under `packages/astchunk-leann` and Python will use your changes immediately (no separate `pip install astchunk` needed).
|
||||
|
||||
## Best Practices
|
||||
|
||||
### When to Use AST Chunking
|
||||
|
||||
✅ **Recommended for:**
|
||||
- Code repositories with multiple languages
|
||||
- Mixed documentation and code content
|
||||
- Complex codebases with deep function/class hierarchies
|
||||
- When working with Claude Code for code assistance
|
||||
|
||||
❌ **Not recommended for:**
|
||||
- Pure text documents
|
||||
- Very large files (>1MB)
|
||||
- Languages not supported by tree-sitter
|
||||
|
||||
### Optimal Configuration
|
||||
|
||||
```bash
|
||||
# Recommended settings for most codebases
|
||||
python -m apps.code_rag \
|
||||
--repo-dir ./src \
|
||||
--ast-chunk-size 768 \
|
||||
--ast-chunk-overlap 96 \
|
||||
--exclude-dirs .git __pycache__ node_modules build dist
|
||||
```
|
||||
|
||||
### Supported Languages
|
||||
|
||||
| Extension | Language | Status |
|
||||
|-----------|----------|--------|
|
||||
| `.py` | Python | ✅ Full support |
|
||||
| `.java` | Java | ✅ Full support |
|
||||
| `.cs` | C# | ✅ Full support |
|
||||
| `.ts`, `.tsx` | TypeScript | ✅ Full support |
|
||||
| `.js`, `.jsx` | JavaScript | ✅ Via TypeScript parser |
|
||||
|
||||
## Integration Examples
|
||||
|
||||
### Document RAG with Code Support
|
||||
|
||||
```python
|
||||
# Enable code chunking in document RAG
|
||||
python -m apps.document_rag \
|
||||
--enable-code-chunking \
|
||||
--data-dir ./project \
|
||||
--query "How does authentication work in the codebase?"
|
||||
```
|
||||
|
||||
### Claude Code Integration
|
||||
|
||||
When using with Claude Code MCP server, AST chunking provides better context for:
|
||||
- Code completion and suggestions
|
||||
- Bug analysis and debugging
|
||||
- Architecture understanding
|
||||
- Refactoring assistance
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
1. **Fallback to Traditional Chunking**
|
||||
- Normal behavior for unsupported languages
|
||||
- Check logs for specific language support
|
||||
|
||||
2. **Performance with Large Files**
|
||||
- Adjust `--max-file-size` parameter
|
||||
- Use `--exclude-dirs` to skip unnecessary directories
|
||||
|
||||
3. **Quality Issues**
|
||||
- Try different `--ast-chunk-size` values (512, 768, 1024)
|
||||
- Adjust overlap for better context preservation
|
||||
|
||||
### Debug Mode
|
||||
|
||||
```bash
|
||||
export LEANN_LOG_LEVEL=DEBUG
|
||||
python -m apps.code_rag --repo-dir ./my_code
|
||||
```
|
||||
|
||||
## Migration from Traditional Chunking
|
||||
|
||||
Existing workflows continue to work without changes. To enable AST chunking:
|
||||
|
||||
```bash
|
||||
# Before
|
||||
python -m apps.document_rag --chunk-size 256
|
||||
|
||||
# After (maintains traditional chunking for non-code files)
|
||||
python -m apps.document_rag --enable-code-chunking --chunk-size 256 --ast-chunk-size 768
|
||||
```
|
||||
|
||||
## References
|
||||
|
||||
- [astchunk GitHub Repository](https://github.com/yilinjz/astchunk)
|
||||
- [LEANN MCP Integration](../packages/leann-mcp/README.md)
|
||||
- [Research Paper](https://arxiv.org/html/2506.15655v1)
|
||||
|
||||
---
|
||||
|
||||
**Note**: AST chunking maintains full backward compatibility while enhancing code understanding capabilities.
|
||||
98
docs/code/embedding_model_compare.py
Normal file
98
docs/code/embedding_model_compare.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""
|
||||
Comparison between Sentence Transformers and OpenAI embeddings
|
||||
|
||||
This example shows how different embedding models handle complex queries
|
||||
and demonstrates the differences between local and API-based embeddings.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from leann.embedding_compute import compute_embeddings
|
||||
|
||||
# OpenAI API key should be set as environment variable
|
||||
# export OPENAI_API_KEY="your-api-key-here"
|
||||
|
||||
# Test data
|
||||
conference_text = "[Title]: COLING 2025 Conference\n[URL]: https://coling2025.org/"
|
||||
browser_text = "[Title]: Browser Use Tool\n[URL]: https://github.com/browser-use"
|
||||
|
||||
# Two queries with same intent but different wording
|
||||
query1 = "Tell me my browser history about some conference i often visit"
|
||||
query2 = "browser history about conference I often visit"
|
||||
|
||||
texts = [query1, query2, conference_text, browser_text]
|
||||
|
||||
|
||||
def cosine_similarity(a, b):
|
||||
return np.dot(a, b) # Already normalized
|
||||
|
||||
|
||||
def analyze_embeddings(embeddings, model_name):
|
||||
print(f"\n=== {model_name} Results ===")
|
||||
|
||||
# Results for Query 1
|
||||
sim1_conf = cosine_similarity(embeddings[0], embeddings[2])
|
||||
sim1_browser = cosine_similarity(embeddings[0], embeddings[3])
|
||||
|
||||
print(f"Query 1: '{query1}'")
|
||||
print(f" → Conference similarity: {sim1_conf:.4f} {'✓' if sim1_conf > sim1_browser else ''}")
|
||||
print(
|
||||
f" → Browser similarity: {sim1_browser:.4f} {'✓' if sim1_browser > sim1_conf else ''}"
|
||||
)
|
||||
print(f" Winner: {'Conference' if sim1_conf > sim1_browser else 'Browser'}")
|
||||
|
||||
# Results for Query 2
|
||||
sim2_conf = cosine_similarity(embeddings[1], embeddings[2])
|
||||
sim2_browser = cosine_similarity(embeddings[1], embeddings[3])
|
||||
|
||||
print(f"\nQuery 2: '{query2}'")
|
||||
print(f" → Conference similarity: {sim2_conf:.4f} {'✓' if sim2_conf > sim2_browser else ''}")
|
||||
print(
|
||||
f" → Browser similarity: {sim2_browser:.4f} {'✓' if sim2_browser > sim2_conf else ''}"
|
||||
)
|
||||
print(f" Winner: {'Conference' if sim2_conf > sim2_browser else 'Browser'}")
|
||||
|
||||
# Show the impact
|
||||
print("\n=== Impact Analysis ===")
|
||||
print(f"Conference similarity change: {sim2_conf - sim1_conf:+.4f}")
|
||||
print(f"Browser similarity change: {sim2_browser - sim1_browser:+.4f}")
|
||||
|
||||
if sim1_conf > sim1_browser and sim2_browser > sim2_conf:
|
||||
print("❌ FLIP: Adding 'browser history' flips winner from Conference to Browser!")
|
||||
elif sim1_conf > sim1_browser and sim2_conf > sim2_browser:
|
||||
print("✅ STABLE: Conference remains winner in both queries")
|
||||
elif sim1_browser > sim1_conf and sim2_browser > sim2_conf:
|
||||
print("✅ STABLE: Browser remains winner in both queries")
|
||||
else:
|
||||
print("🔄 MIXED: Results vary between queries")
|
||||
|
||||
return {
|
||||
"query1_conf": sim1_conf,
|
||||
"query1_browser": sim1_browser,
|
||||
"query2_conf": sim2_conf,
|
||||
"query2_browser": sim2_browser,
|
||||
}
|
||||
|
||||
|
||||
# Test Sentence Transformers
|
||||
print("Testing Sentence Transformers (facebook/contriever)...")
|
||||
try:
|
||||
st_embeddings = compute_embeddings(texts, "facebook/contriever", mode="sentence-transformers")
|
||||
st_results = analyze_embeddings(st_embeddings, "Sentence Transformers (facebook/contriever)")
|
||||
except Exception as e:
|
||||
print(f"❌ Sentence Transformers failed: {e}")
|
||||
st_results = None
|
||||
|
||||
# Test OpenAI
|
||||
print("\n" + "=" * 60)
|
||||
print("Testing OpenAI (text-embedding-3-small)...")
|
||||
try:
|
||||
openai_embeddings = compute_embeddings(texts, "text-embedding-3-small", mode="openai")
|
||||
openai_results = analyze_embeddings(openai_embeddings, "OpenAI (text-embedding-3-small)")
|
||||
except Exception as e:
|
||||
print(f"❌ OpenAI failed: {e}")
|
||||
openai_results = None
|
||||
|
||||
# Compare results
|
||||
if st_results and openai_results:
|
||||
print("\n" + "=" * 60)
|
||||
print("=== COMPARISON SUMMARY ===")
|
||||
384
docs/configuration-guide.md
Normal file
384
docs/configuration-guide.md
Normal file
@@ -0,0 +1,384 @@
|
||||
# LEANN Configuration Guide
|
||||
|
||||
This guide helps you optimize LEANN for different use cases and understand the trade-offs between various configuration options.
|
||||
|
||||
## Getting Started: Simple is Better
|
||||
|
||||
When first trying LEANN, start with a small dataset to quickly validate your approach:
|
||||
|
||||
**For document RAG**: The default `data/` directory works perfectly - includes 2 AI research papers, Pride and Prejudice literature, and a technical report
|
||||
```bash
|
||||
python -m apps.document_rag --query "What techniques does LEANN use?"
|
||||
```
|
||||
|
||||
**For other data sources**: Limit the dataset size for quick testing
|
||||
```bash
|
||||
# WeChat: Test with recent messages only
|
||||
python -m apps.wechat_rag --max-items 100 --query "What did we discuss about the project timeline?"
|
||||
|
||||
# Browser history: Last few days
|
||||
python -m apps.browser_rag --max-items 500 --query "Find documentation about vector databases"
|
||||
|
||||
# Email: Recent inbox
|
||||
python -m apps.email_rag --max-items 200 --query "Who sent updates about the deployment status?"
|
||||
```
|
||||
|
||||
Once validated, scale up gradually:
|
||||
- 100 documents → 1,000 → 10,000 → full dataset (`--max-items -1`)
|
||||
- This helps identify issues early before committing to long processing times
|
||||
|
||||
## Embedding Model Selection: Understanding the Trade-offs
|
||||
|
||||
Based on our experience developing LEANN, embedding models fall into three categories:
|
||||
|
||||
### Small Models (< 100M parameters)
|
||||
**Example**: `sentence-transformers/all-MiniLM-L6-v2` (22M params)
|
||||
- **Pros**: Lightweight, fast for both indexing and inference
|
||||
- **Cons**: Lower semantic understanding, may miss nuanced relationships
|
||||
- **Use when**: Speed is critical, handling simple queries, interactive mode, or just experimenting with LEANN. If time is not a constraint, consider using a larger/better embedding model
|
||||
|
||||
### Medium Models (100M-500M parameters)
|
||||
**Example**: `facebook/contriever` (110M params), `BAAI/bge-base-en-v1.5` (110M params)
|
||||
- **Pros**: Balanced performance, good multilingual support, reasonable speed
|
||||
- **Cons**: Requires more compute than small models
|
||||
- **Use when**: Need quality results without extreme compute requirements, general-purpose RAG applications
|
||||
|
||||
### Large Models (500M+ parameters)
|
||||
**Example**: `Qwen/Qwen3-Embedding-0.6B` (600M params), `intfloat/multilingual-e5-large` (560M params)
|
||||
- **Pros**: Best semantic understanding, captures complex relationships, excellent multilingual support. **Qwen3-Embedding-0.6B achieves nearly OpenAI API performance!**
|
||||
- **Cons**: Slower inference, longer index build times
|
||||
- **Use when**: Quality is paramount and you have sufficient compute resources. **Highly recommended** for production use
|
||||
|
||||
### Quick Start: Cloud and Local Embedding Options
|
||||
|
||||
**OpenAI Embeddings (Fastest Setup)**
|
||||
For immediate testing without local model downloads(also if you [do not have GPU](https://github.com/yichuan-w/LEANN/issues/43) and do not care that much about your document leak, you should use this, we compute the embedding and recompute using openai API):
|
||||
```bash
|
||||
# Set OpenAI embeddings (requires OPENAI_API_KEY)
|
||||
--embedding-mode openai --embedding-model text-embedding-3-small
|
||||
```
|
||||
|
||||
**Ollama Embeddings (Privacy-Focused)**
|
||||
For local embeddings with complete privacy:
|
||||
```bash
|
||||
# First, pull an embedding model
|
||||
ollama pull nomic-embed-text
|
||||
|
||||
# Use Ollama embeddings
|
||||
--embedding-mode ollama --embedding-model nomic-embed-text
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary><strong>Cloud vs Local Trade-offs</strong></summary>
|
||||
|
||||
**OpenAI Embeddings** (`text-embedding-3-small/large`)
|
||||
- **Pros**: No local compute needed, consistently fast, high quality
|
||||
- **Cons**: Requires API key, costs money, data leaves your system, [known limitations with certain languages](https://yichuan-w.github.io/blog/lessons_learned_in_dev_leann/)
|
||||
- **When to use**: Prototyping, non-sensitive data, need immediate results
|
||||
|
||||
**Local Embeddings**
|
||||
- **Pros**: Complete privacy, no ongoing costs, full control, can sometimes outperform OpenAI embeddings
|
||||
- **Cons**: Slower than cloud APIs, requires local compute resources
|
||||
- **When to use**: Production systems, sensitive data, cost-sensitive applications
|
||||
|
||||
</details>
|
||||
|
||||
## Index Selection: Matching Your Scale
|
||||
|
||||
### HNSW (Hierarchical Navigable Small World)
|
||||
**Best for**: Small to medium datasets (< 10M vectors) - **Default and recommended for extreme low storage**
|
||||
- Full recomputation required
|
||||
- High memory usage during build phase
|
||||
- Excellent recall (95%+)
|
||||
|
||||
```bash
|
||||
# Optimal for most use cases
|
||||
--backend-name hnsw --graph-degree 32 --build-complexity 64
|
||||
```
|
||||
|
||||
### DiskANN
|
||||
**Best for**: Large datasets, especially when you want `recompute=True`.
|
||||
|
||||
**Key advantages:**
|
||||
- **Faster search** on large datasets (3x+ speedup vs HNSW in many cases)
|
||||
- **Smart storage**: `recompute=True` enables automatic graph partitioning for smaller indexes
|
||||
- **Better scaling**: Designed for 100k+ documents
|
||||
|
||||
**Recompute behavior:**
|
||||
- `recompute=True` (recommended): Pure PQ traversal + final reranking - faster and enables partitioning
|
||||
- `recompute=False`: PQ + partial real distances during traversal - slower but higher accuracy
|
||||
|
||||
```bash
|
||||
# Recommended for most use cases
|
||||
--backend-name diskann --graph-degree 32 --build-complexity 64
|
||||
```
|
||||
|
||||
**Performance Benchmark**: Run `uv run benchmarks/diskann_vs_hnsw_speed_comparison.py` to compare DiskANN and HNSW on your system.
|
||||
|
||||
## LLM Selection: Engine and Model Comparison
|
||||
|
||||
### LLM Engines
|
||||
|
||||
**OpenAI** (`--llm openai`)
|
||||
- **Pros**: Best quality, consistent performance, no local resources needed
|
||||
- **Cons**: Costs money ($0.15-2.5 per million tokens), requires internet, data privacy concerns
|
||||
- **Models**: `gpt-4o-mini` (fast, cheap), `gpt-4o` (best quality), `o3` (reasoning), `o3-mini` (reasoning, cheaper)
|
||||
- **Thinking Budget**: Use `--thinking-budget low/medium/high` for o-series reasoning models (o3, o3-mini, o4-mini)
|
||||
- **Note**: Our current default, but we recommend switching to Ollama for most use cases
|
||||
|
||||
**Ollama** (`--llm ollama`)
|
||||
- **Pros**: Fully local, free, privacy-preserving, good model variety
|
||||
- **Cons**: Requires local GPU/CPU resources, slower than cloud APIs, need to install extra [ollama app](https://github.com/ollama/ollama?tab=readme-ov-file#ollama) and pre-download models by `ollama pull`
|
||||
- **Models**: `qwen3:0.6b` (ultra-fast), `qwen3:1.7b` (balanced), `qwen3:4b` (good quality), `qwen3:7b` (high quality), `deepseek-r1:1.5b` (reasoning)
|
||||
- **Thinking Budget**: Use `--thinking-budget low/medium/high` for reasoning models like GPT-Oss:20b
|
||||
|
||||
**HuggingFace** (`--llm hf`)
|
||||
- **Pros**: Free tier available, huge model selection, direct model loading (vs Ollama's server-based approach)
|
||||
- **Cons**: More complex initial setup
|
||||
- **Models**: `Qwen/Qwen3-1.7B-FP8`
|
||||
|
||||
## Parameter Tuning Guide
|
||||
|
||||
### Search Complexity Parameters
|
||||
|
||||
**`--build-complexity`** (index building)
|
||||
- Controls thoroughness during index construction
|
||||
- Higher = better recall but slower build
|
||||
- Recommendations:
|
||||
- 32: Quick prototyping
|
||||
- 64: Balanced (default)
|
||||
- 128: Production systems
|
||||
- 256: Maximum quality
|
||||
|
||||
**`--search-complexity`** (query time)
|
||||
- Controls search thoroughness
|
||||
- Higher = better results but slower
|
||||
- Recommendations:
|
||||
- 16: Fast/Interactive search
|
||||
- 32: High quality with diversity
|
||||
- 64+: Maximum accuracy
|
||||
|
||||
### Top-K Selection
|
||||
|
||||
**`--top-k`** (number of retrieved chunks)
|
||||
- More chunks = better context but slower LLM processing
|
||||
- Should be always smaller than `--search-complexity`
|
||||
- Guidelines:
|
||||
- 10-20: General questions (default: 20)
|
||||
- 30+: Complex multi-hop reasoning requiring comprehensive context
|
||||
|
||||
**Trade-off formula**:
|
||||
- Retrieval time ∝ log(n) × search_complexity
|
||||
- LLM processing time ∝ top_k × chunk_size
|
||||
- Total context = top_k × chunk_size tokens
|
||||
|
||||
### Thinking Budget for Reasoning Models
|
||||
|
||||
**`--thinking-budget`** (reasoning effort level)
|
||||
- Controls the computational effort for reasoning models
|
||||
- Options: `low`, `medium`, `high`
|
||||
- Guidelines:
|
||||
- `low`: Fast responses, basic reasoning (default for simple queries)
|
||||
- `medium`: Balanced speed and reasoning depth
|
||||
- `high`: Maximum reasoning effort, best for complex analytical questions
|
||||
- **Supported Models**:
|
||||
- **Ollama**: `gpt-oss:20b`, `gpt-oss:120b`
|
||||
- **OpenAI**: `o3`, `o3-mini`, `o4-mini`, `o1` (o-series reasoning models)
|
||||
- **Note**: Models without reasoning support will show a warning and proceed without reasoning parameters
|
||||
- **Example**: `--thinking-budget high` for complex analytical questions
|
||||
|
||||
**📖 For detailed usage examples and implementation details, check out [Thinking Budget Documentation](THINKING_BUDGET_FEATURE.md)**
|
||||
|
||||
**💡 Quick Examples:**
|
||||
```bash
|
||||
# OpenAI o-series reasoning model
|
||||
python apps/document_rag.py --query "What are the main techniques LEANN explores?" \
|
||||
--index-dir hnswbuild --backend hnsw \
|
||||
--llm openai --llm-model o3 --thinking-budget medium
|
||||
|
||||
# Ollama reasoning model
|
||||
python apps/document_rag.py --query "What are the main techniques LEANN explores?" \
|
||||
--index-dir hnswbuild --backend hnsw \
|
||||
--llm ollama --llm-model gpt-oss:20b --thinking-budget high
|
||||
```
|
||||
|
||||
### Graph Degree (HNSW/DiskANN)
|
||||
|
||||
**`--graph-degree`**
|
||||
- Number of connections per node in the graph
|
||||
- Higher = better recall but more memory
|
||||
- HNSW: 16-32 (default: 32)
|
||||
- DiskANN: 32-128 (default: 64)
|
||||
|
||||
|
||||
## Performance Optimization Checklist
|
||||
|
||||
### If Embedding is Too Slow
|
||||
|
||||
1. **Switch to smaller model**:
|
||||
```bash
|
||||
# From large model
|
||||
--embedding-model Qwen/Qwen3-Embedding-0.6B
|
||||
# To small model
|
||||
--embedding-model sentence-transformers/all-MiniLM-L6-v2
|
||||
```
|
||||
|
||||
2. **Limit dataset size for testing**:
|
||||
```bash
|
||||
--max-items 1000 # Process first 1k items only
|
||||
```
|
||||
|
||||
3. **Use MLX on Apple Silicon** (optional optimization):
|
||||
```bash
|
||||
--embedding-mode mlx --embedding-model mlx-community/Qwen3-Embedding-0.6B-8bit
|
||||
```
|
||||
MLX might not be the best choice, as we tested and found that it only offers 1.3x acceleration compared to HF, so maybe using ollama is a better choice for embedding generation
|
||||
|
||||
4. **Use Ollama**
|
||||
```bash
|
||||
--embedding-mode ollama --embedding-model nomic-embed-text
|
||||
```
|
||||
To discover additional embedding models in ollama, check out https://ollama.com/search?c=embedding or read more about embedding models at https://ollama.com/blog/embedding-models, please do check the model size that works best for you
|
||||
### If Search Quality is Poor
|
||||
|
||||
1. **Increase retrieval count**:
|
||||
```bash
|
||||
--top-k 30 # Retrieve more candidates
|
||||
```
|
||||
|
||||
2. **Upgrade embedding model**:
|
||||
```bash
|
||||
# For English
|
||||
--embedding-model BAAI/bge-base-en-v1.5
|
||||
# For multilingual
|
||||
--embedding-model intfloat/multilingual-e5-large
|
||||
```
|
||||
|
||||
## Understanding the Trade-offs
|
||||
|
||||
Every configuration choice involves trade-offs:
|
||||
|
||||
| Factor | Small/Fast | Large/Quality |
|
||||
|--------|------------|---------------|
|
||||
| Embedding Model | `all-MiniLM-L6-v2` | `Qwen/Qwen3-Embedding-0.6B` |
|
||||
| Chunk Size | 512 tokens | 128 tokens |
|
||||
| Index Type | HNSW | DiskANN |
|
||||
| LLM | `qwen3:1.7b` | `gpt-4o` |
|
||||
|
||||
The key is finding the right balance for your specific use case. Start small and simple, measure performance, then scale up only where needed.
|
||||
|
||||
## Low-resource setups
|
||||
|
||||
If you don’t have a local GPU or builds/searches are too slow, use one or more of the options below.
|
||||
|
||||
### 1) Use OpenAI embeddings (no local compute)
|
||||
|
||||
Fastest path with zero local GPU requirements. Set your API key and use OpenAI embeddings during build and search:
|
||||
|
||||
```bash
|
||||
export OPENAI_API_KEY=sk-...
|
||||
|
||||
# Build with OpenAI embeddings
|
||||
leann build my-index \
|
||||
--embedding-mode openai \
|
||||
--embedding-model text-embedding-3-small
|
||||
|
||||
# Search with OpenAI embeddings (recompute at query time)
|
||||
leann search my-index "your query" \
|
||||
--recompute
|
||||
```
|
||||
|
||||
### 2) Run remote builds with SkyPilot (cloud GPU)
|
||||
|
||||
Offload embedding generation and index building to a GPU VM using [SkyPilot](https://skypilot.readthedocs.io/en/latest/). A template is provided at `sky/leann-build.yaml`.
|
||||
|
||||
```bash
|
||||
# One-time: install and configure SkyPilot
|
||||
pip install skypilot
|
||||
|
||||
# Launch with defaults (L4:1) and mount ./data to ~/leann-data; the build runs automatically
|
||||
sky launch -c leann-gpu sky/leann-build.yaml
|
||||
|
||||
# Override parameters via -e key=value (optional)
|
||||
sky launch -c leann-gpu sky/leann-build.yaml \
|
||||
-e index_name=my-index \
|
||||
-e backend=hnsw \
|
||||
-e embedding_mode=sentence-transformers \
|
||||
-e embedding_model=Qwen/Qwen3-Embedding-0.6B
|
||||
|
||||
# Copy the built index back to your local .leann (use rsync)
|
||||
rsync -Pavz leann-gpu:~/.leann/indexes/my-index ./.leann/indexes/
|
||||
```
|
||||
|
||||
### 3) Disable recomputation to trade storage for speed
|
||||
|
||||
If you need lower latency and have more storage/memory, disable recomputation. This stores full embeddings and avoids recomputing at search time.
|
||||
|
||||
```bash
|
||||
# Build without recomputation (HNSW requires non-compact in this mode)
|
||||
leann build my-index --no-recompute --no-compact
|
||||
|
||||
# Search without recomputation
|
||||
leann search my-index "your query" --no-recompute
|
||||
```
|
||||
|
||||
When to use:
|
||||
- Extreme low latency requirements (high QPS, interactive assistants)
|
||||
- Read-heavy workloads where storage is cheaper than latency
|
||||
- No always-available GPU
|
||||
|
||||
Constraints:
|
||||
- HNSW: when `--no-recompute` is set, LEANN automatically disables compact mode during build
|
||||
- DiskANN: supported; `--no-recompute` skips selective recompute during search
|
||||
|
||||
Storage impact:
|
||||
- Storing N embeddings of dimension D with float32 requires approximately N × D × 4 bytes
|
||||
- Example: 1,000,000 chunks × 768 dims × 4 bytes ≈ 2.86 GB (plus graph/metadata)
|
||||
|
||||
Converting an existing index (rebuild required):
|
||||
```bash
|
||||
# Rebuild in-place (ensure you still have original docs or can regenerate chunks)
|
||||
leann build my-index --force --no-recompute --no-compact
|
||||
```
|
||||
|
||||
Python API usage:
|
||||
```python
|
||||
from leann import LeannSearcher
|
||||
|
||||
searcher = LeannSearcher("/path/to/my-index.leann")
|
||||
results = searcher.search("your query", top_k=10, recompute_embeddings=False)
|
||||
```
|
||||
|
||||
Trade-offs:
|
||||
- Lower latency and fewer network hops at query time
|
||||
- Significantly higher storage (10–100× vs selective recomputation)
|
||||
- Slightly larger memory footprint during build and search
|
||||
|
||||
Quick benchmark results (`benchmarks/benchmark_no_recompute.py` with 5k texts, complexity=32):
|
||||
|
||||
- HNSW
|
||||
|
||||
```text
|
||||
recompute=True: search_time=0.818s, size=1.1MB
|
||||
recompute=False: search_time=0.012s, size=16.6MB
|
||||
```
|
||||
|
||||
- DiskANN
|
||||
|
||||
```text
|
||||
recompute=True: search_time=0.041s, size=5.9MB
|
||||
recompute=False: search_time=0.013s, size=24.6MB
|
||||
```
|
||||
|
||||
Conclusion:
|
||||
- **HNSW**: `no-recompute` is significantly faster (no embedding recomputation) but requires much more storage (stores all embeddings)
|
||||
- **DiskANN**: `no-recompute` uses PQ + partial real distances during traversal (slower but higher accuracy), while `recompute=True` uses pure PQ traversal + final reranking (faster traversal, enables build-time partitioning for smaller storage)
|
||||
|
||||
|
||||
|
||||
## Further Reading
|
||||
|
||||
- [Lessons Learned Developing LEANN](https://yichuan-w.github.io/blog/lessons_learned_in_dev_leann/)
|
||||
- [LEANN Technical Paper](https://arxiv.org/abs/2506.08276)
|
||||
- [DiskANN Original Paper](https://papers.nips.cc/paper/2019/file/09853c7fb1d3f8ee67a61b6bf4a7f8e6-Paper.pdf)
|
||||
- [SSD-based Graph Partitioning](https://github.com/SonglinLife/SSD_BASED_PLAN)
|
||||
10
docs/faq.md
Normal file
10
docs/faq.md
Normal file
@@ -0,0 +1,10 @@
|
||||
# FAQ
|
||||
|
||||
## 1. My building time seems long
|
||||
|
||||
You can speed up the process by using a lightweight embedding model. Add this to your arguments:
|
||||
|
||||
```bash
|
||||
--embedding-model sentence-transformers/all-MiniLM-L6-v2
|
||||
```
|
||||
**Model sizes:** `all-MiniLM-L6-v2` (30M parameters), `facebook/contriever` (~100M parameters), `Qwen3-0.6B` (600M parameters)
|
||||
23
docs/features.md
Normal file
23
docs/features.md
Normal file
@@ -0,0 +1,23 @@
|
||||
# ✨ Detailed Features
|
||||
|
||||
## 🔥 Core Features
|
||||
|
||||
- **🔄 Real-time Embeddings** - Eliminate heavy embedding storage with dynamic computation using optimized ZMQ servers and highly optimized search paradigm (overlapping and batching) with highly optimized embedding engine
|
||||
- **🧠 AST-Aware Code Chunking** - Intelligent code chunking that preserves semantic boundaries (functions, classes, methods) for Python, Java, C#, and TypeScript files
|
||||
- **📈 Scalable Architecture** - Handles millions of documents on consumer hardware; the larger your dataset, the more LEANN can save
|
||||
- **🎯 Graph Pruning** - Advanced techniques to minimize the storage overhead of vector search to a limited footprint
|
||||
- **🏗️ Pluggable Backends** - HNSW/FAISS (default), with optional DiskANN for large-scale deployments
|
||||
|
||||
## 🛠️ Technical Highlights
|
||||
- **🔄 Recompute Mode** - Highest accuracy scenarios while eliminating vector storage overhead
|
||||
- **⚡ Zero-copy Operations** - Minimize IPC overhead by transferring distances instead of embeddings
|
||||
- **🚀 High-throughput Embedding Pipeline** - Optimized batched processing for maximum efficiency
|
||||
- **🎯 Two-level Search** - Novel coarse-to-fine search overlap for accelerated query processing (optional)
|
||||
- **💾 Memory-mapped Indices** - Fast startup with raw text mapping to reduce memory overhead
|
||||
- **🚀 MLX Support** - Ultra-fast recompute/build with quantized embedding models, accelerating building and search ([minimal example](../examples/mlx_demo.py))
|
||||
|
||||
## 🎨 Developer Experience
|
||||
|
||||
- **Simple Python API** - Get started in minutes
|
||||
- **Extensible backend system** - Easy to add new algorithms
|
||||
- **Comprehensive examples** - From basic usage to production deployment
|
||||
149
docs/grep_search.md
Normal file
149
docs/grep_search.md
Normal file
@@ -0,0 +1,149 @@
|
||||
# LEANN Grep Search Usage Guide
|
||||
|
||||
## Overview
|
||||
|
||||
LEANN's grep search functionality provides exact text matching for finding specific code patterns, error messages, function names, or exact phrases in your indexed documents.
|
||||
|
||||
## Basic Usage
|
||||
|
||||
### Simple Grep Search
|
||||
|
||||
```python
|
||||
from leann.api import LeannSearcher
|
||||
|
||||
searcher = LeannSearcher("your_index_path")
|
||||
|
||||
# Exact text search
|
||||
results = searcher.search("def authenticate_user", use_grep=True, top_k=5)
|
||||
|
||||
for result in results:
|
||||
print(f"Score: {result.score}")
|
||||
print(f"Text: {result.text[:100]}...")
|
||||
print("-" * 40)
|
||||
```
|
||||
|
||||
### Comparison: Semantic vs Grep Search
|
||||
|
||||
```python
|
||||
# Semantic search - finds conceptually similar content
|
||||
semantic_results = searcher.search("machine learning algorithms", top_k=3)
|
||||
|
||||
# Grep search - finds exact text matches
|
||||
grep_results = searcher.search("def train_model", use_grep=True, top_k=3)
|
||||
```
|
||||
|
||||
## When to Use Grep Search
|
||||
|
||||
### Use Cases
|
||||
|
||||
- **Code Search**: Finding specific function definitions, class names, or variable references
|
||||
- **Error Debugging**: Locating exact error messages or stack traces
|
||||
- **Documentation**: Finding specific API endpoints or exact terminology
|
||||
|
||||
### Examples
|
||||
|
||||
```python
|
||||
# Find function definitions
|
||||
functions = searcher.search("def __init__", use_grep=True)
|
||||
|
||||
# Find import statements
|
||||
imports = searcher.search("from sklearn import", use_grep=True)
|
||||
|
||||
# Find specific error types
|
||||
errors = searcher.search("FileNotFoundError", use_grep=True)
|
||||
|
||||
# Find TODO comments
|
||||
todos = searcher.search("TODO:", use_grep=True)
|
||||
|
||||
# Find configuration entries
|
||||
configs = searcher.search("server_port=", use_grep=True)
|
||||
```
|
||||
|
||||
## Technical Details
|
||||
|
||||
### How It Works
|
||||
|
||||
1. **File Location**: Grep search operates on the raw text stored in `.jsonl` files
|
||||
2. **Command Execution**: Uses the system `grep` command with case-insensitive search
|
||||
3. **Result Processing**: Parses JSON lines and extracts text and metadata
|
||||
4. **Scoring**: Simple frequency-based scoring based on query term occurrences
|
||||
|
||||
### Search Process
|
||||
|
||||
```
|
||||
Query: "def train_model"
|
||||
↓
|
||||
grep -i -n "def train_model" documents.leann.passages.jsonl
|
||||
↓
|
||||
Parse matching JSON lines
|
||||
↓
|
||||
Calculate scores based on term frequency
|
||||
↓
|
||||
Return top_k results
|
||||
```
|
||||
|
||||
### Scoring Algorithm
|
||||
|
||||
```python
|
||||
# Term frequency in document
|
||||
score = text.lower().count(query.lower())
|
||||
```
|
||||
|
||||
Results are ranked by score (highest first), with higher scores indicating more occurrences of the search term.
|
||||
|
||||
## Error Handling
|
||||
|
||||
### Common Issues
|
||||
|
||||
#### Grep Command Not Found
|
||||
```
|
||||
RuntimeError: grep command not found. Please install grep or use semantic search.
|
||||
```
|
||||
|
||||
**Solution**: Install grep on your system:
|
||||
- **Ubuntu/Debian**: `sudo apt-get install grep`
|
||||
- **macOS**: grep is pre-installed
|
||||
- **Windows**: Use WSL or install grep via Git Bash/MSYS2
|
||||
|
||||
#### No Results Found
|
||||
```python
|
||||
# Check if your query exists in the raw data
|
||||
results = searcher.search("your_query", use_grep=True)
|
||||
if not results:
|
||||
print("No exact matches found. Try:")
|
||||
print("1. Check spelling and case")
|
||||
print("2. Use partial terms")
|
||||
print("3. Switch to semantic search")
|
||||
```
|
||||
|
||||
## Complete Example
|
||||
|
||||
```python
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Grep Search Example
|
||||
Demonstrates grep search for exact text matching.
|
||||
"""
|
||||
|
||||
from leann.api import LeannSearcher
|
||||
|
||||
def demonstrate_grep_search():
|
||||
# Initialize searcher
|
||||
searcher = LeannSearcher("my_index")
|
||||
|
||||
print("=== Function Search ===")
|
||||
functions = searcher.search("def __init__", use_grep=True, top_k=5)
|
||||
for i, result in enumerate(functions, 1):
|
||||
print(f"{i}. Score: {result.score}")
|
||||
print(f" Preview: {result.text[:60]}...")
|
||||
print()
|
||||
|
||||
print("=== Error Search ===")
|
||||
errors = searcher.search("FileNotFoundError", use_grep=True, top_k=3)
|
||||
for result in errors:
|
||||
print(f"Content: {result.text.strip()}")
|
||||
print("-" * 40)
|
||||
|
||||
if __name__ == "__main__":
|
||||
demonstrate_grep_search()
|
||||
```
|
||||
300
docs/metadata_filtering.md
Normal file
300
docs/metadata_filtering.md
Normal file
@@ -0,0 +1,300 @@
|
||||
# LEANN Metadata Filtering Usage Guide
|
||||
|
||||
## Overview
|
||||
|
||||
Leann possesses metadata filtering capabilities that allow you to filter search results based on arbitrary metadata fields set during chunking. This feature enables use cases like spoiler-free book search, document filtering by date/type, code search by file type, and potentially much more.
|
||||
|
||||
## Basic Usage
|
||||
|
||||
### Adding Metadata to Your Documents
|
||||
|
||||
When building your index, add metadata to each text chunk:
|
||||
|
||||
```python
|
||||
from leann.api import LeannBuilder
|
||||
|
||||
builder = LeannBuilder("hnsw")
|
||||
|
||||
# Add text with metadata
|
||||
builder.add_text(
|
||||
text="Chapter 1: Alice falls down the rabbit hole",
|
||||
metadata={
|
||||
"chapter": 1,
|
||||
"character": "Alice",
|
||||
"themes": ["adventure", "curiosity"],
|
||||
"word_count": 150
|
||||
}
|
||||
)
|
||||
|
||||
builder.build_index("alice_in_wonderland_index")
|
||||
```
|
||||
|
||||
### Searching with Metadata Filters
|
||||
|
||||
Use the `metadata_filters` parameter in search calls:
|
||||
|
||||
```python
|
||||
from leann.api import LeannSearcher
|
||||
|
||||
searcher = LeannSearcher("alice_in_wonderland_index")
|
||||
|
||||
# Search with filters
|
||||
results = searcher.search(
|
||||
query="What happens to Alice?",
|
||||
top_k=10,
|
||||
metadata_filters={
|
||||
"chapter": {"<=": 5}, # Only chapters 1-5
|
||||
"spoiler_level": {"!=": "high"} # No high spoilers
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
## Filter Syntax
|
||||
|
||||
### Basic Structure
|
||||
|
||||
```python
|
||||
metadata_filters = {
|
||||
"field_name": {"operator": value},
|
||||
"another_field": {"operator": value}
|
||||
}
|
||||
```
|
||||
|
||||
### Supported Operators
|
||||
|
||||
#### Comparison Operators
|
||||
- `"=="`: Equal to
|
||||
- `"!="`: Not equal to
|
||||
- `"<"`: Less than
|
||||
- `"<="`: Less than or equal
|
||||
- `">"`: Greater than
|
||||
- `">="`: Greater than or equal
|
||||
|
||||
```python
|
||||
# Examples
|
||||
{"chapter": {"==": 1}} # Exactly chapter 1
|
||||
{"page": {">": 100}} # Pages after 100
|
||||
{"rating": {">=": 4.0}} # Rating 4.0 or higher
|
||||
{"word_count": {"<": 500}} # Short passages
|
||||
```
|
||||
|
||||
#### Membership Operators
|
||||
- `"in"`: Value is in list
|
||||
- `"not_in"`: Value is not in list
|
||||
|
||||
```python
|
||||
# Examples
|
||||
{"character": {"in": ["Alice", "Bob"]}} # Alice OR Bob
|
||||
{"genre": {"not_in": ["horror", "thriller"]}} # Exclude genres
|
||||
{"tags": {"in": ["fiction", "adventure"]}} # Any of these tags
|
||||
```
|
||||
|
||||
#### String Operators
|
||||
- `"contains"`: String contains substring
|
||||
- `"starts_with"`: String starts with prefix
|
||||
- `"ends_with"`: String ends with suffix
|
||||
|
||||
```python
|
||||
# Examples
|
||||
{"title": {"contains": "alice"}} # Title contains "alice"
|
||||
{"filename": {"ends_with": ".py"}} # Python files
|
||||
{"author": {"starts_with": "Dr."}} # Authors with "Dr." prefix
|
||||
```
|
||||
|
||||
#### Boolean Operators
|
||||
- `"is_true"`: Field is truthy
|
||||
- `"is_false"`: Field is falsy
|
||||
|
||||
```python
|
||||
# Examples
|
||||
{"is_published": {"is_true": True}} # Published content
|
||||
{"is_draft": {"is_false": False}} # Not drafts
|
||||
```
|
||||
|
||||
### Multiple Operators on Same Field
|
||||
|
||||
You can apply multiple operators to the same field (AND logic):
|
||||
|
||||
```python
|
||||
metadata_filters = {
|
||||
"word_count": {
|
||||
">=": 100, # At least 100 words
|
||||
"<=": 500 # At most 500 words
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Compound Filters
|
||||
|
||||
Multiple fields are combined with AND logic:
|
||||
|
||||
```python
|
||||
metadata_filters = {
|
||||
"chapter": {"<=": 10}, # Up to chapter 10
|
||||
"character": {"==": "Alice"}, # About Alice
|
||||
"spoiler_level": {"!=": "high"} # No major spoilers
|
||||
}
|
||||
```
|
||||
|
||||
## Use Case Examples
|
||||
|
||||
### 1. Spoiler-Free Book Search
|
||||
|
||||
```python
|
||||
# Reader has only read up to chapter 5
|
||||
def search_spoiler_free(query, max_chapter):
|
||||
return searcher.search(
|
||||
query=query,
|
||||
metadata_filters={
|
||||
"chapter": {"<=": max_chapter},
|
||||
"spoiler_level": {"in": ["none", "low"]}
|
||||
}
|
||||
)
|
||||
|
||||
results = search_spoiler_free("What happens to Alice?", max_chapter=5)
|
||||
```
|
||||
|
||||
### 2. Document Management by Date
|
||||
|
||||
```python
|
||||
# Find recent documents
|
||||
recent_docs = searcher.search(
|
||||
query="project updates",
|
||||
metadata_filters={
|
||||
"date": {">=": "2024-01-01"},
|
||||
"document_type": {"==": "report"}
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
### 3. Code Search by File Type
|
||||
|
||||
```python
|
||||
# Search only Python files
|
||||
python_code = searcher.search(
|
||||
query="authentication function",
|
||||
metadata_filters={
|
||||
"file_extension": {"==": ".py"},
|
||||
"lines_of_code": {"<": 100}
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
### 4. Content Filtering by Audience
|
||||
|
||||
```python
|
||||
# Age-appropriate content
|
||||
family_content = searcher.search(
|
||||
query="adventure stories",
|
||||
metadata_filters={
|
||||
"age_rating": {"in": ["G", "PG"]},
|
||||
"content_warnings": {"not_in": ["violence", "adult_themes"]}
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
### 5. Multi-Book Series Management
|
||||
|
||||
```python
|
||||
# Search across first 3 books only
|
||||
early_series = searcher.search(
|
||||
query="character development",
|
||||
metadata_filters={
|
||||
"series": {"==": "Harry Potter"},
|
||||
"book_number": {"<=": 3}
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
## Running the Example
|
||||
|
||||
You can see metadata filtering in action with our spoiler-free book RAG example:
|
||||
|
||||
```bash
|
||||
# Don't forget to set up the environment
|
||||
uv venv
|
||||
source .venv/bin/activate
|
||||
|
||||
# Set your OpenAI API key (required for embeddings, but you can update the example locally and use ollama instead)
|
||||
export OPENAI_API_KEY="your-api-key-here"
|
||||
|
||||
# Run the spoiler-free book RAG example
|
||||
uv run examples/spoiler_free_book_rag.py
|
||||
```
|
||||
|
||||
This example demonstrates:
|
||||
- Building an index with metadata (chapter numbers, characters, themes, locations)
|
||||
- Searching with filters to avoid spoilers (e.g., only show results up to chapter 5)
|
||||
- Different scenarios for readers at various points in the book
|
||||
|
||||
The example uses Alice's Adventures in Wonderland as sample data and shows how you can search for information without revealing plot points from later chapters.
|
||||
|
||||
## Advanced Patterns
|
||||
|
||||
### Custom Chunking with metadata
|
||||
|
||||
```python
|
||||
def chunk_book_with_metadata(book_text, book_info):
|
||||
chunks = []
|
||||
|
||||
for chapter_num, chapter_text in parse_chapters(book_text):
|
||||
# Extract entities, themes, etc.
|
||||
characters = extract_characters(chapter_text)
|
||||
themes = classify_themes(chapter_text)
|
||||
spoiler_level = assess_spoiler_level(chapter_text, chapter_num)
|
||||
|
||||
# Create chunks with rich metadata
|
||||
for paragraph in split_paragraphs(chapter_text):
|
||||
chunks.append({
|
||||
"text": paragraph,
|
||||
"metadata": {
|
||||
"book_title": book_info["title"],
|
||||
"chapter": chapter_num,
|
||||
"characters": characters,
|
||||
"themes": themes,
|
||||
"spoiler_level": spoiler_level,
|
||||
"word_count": len(paragraph.split()),
|
||||
"reading_level": calculate_reading_level(paragraph)
|
||||
}
|
||||
})
|
||||
|
||||
return chunks
|
||||
```
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
### Efficient Filtering Strategies
|
||||
|
||||
1. **Post-search filtering**: Applies filters after vector search, which should be efficient for typical result sets (10-100 results).
|
||||
|
||||
2. **Metadata design**: Keep metadata fields simple and avoid deeply nested structures.
|
||||
|
||||
### Best Practices
|
||||
|
||||
1. **Consistent metadata schema**: Use consistent field names and value types across your documents.
|
||||
|
||||
2. **Reasonable metadata size**: Keep metadata reasonably sized to avoid storage overhead.
|
||||
|
||||
3. **Type consistency**: Use consistent data types for the same fields (e.g., always integers for chapter numbers).
|
||||
|
||||
4. **Index multiple granularities**: Consider chunking at different levels (paragraph, section, chapter) with appropriate metadata.
|
||||
|
||||
### Adding Metadata to Existing Indices
|
||||
|
||||
To add metadata filtering to existing indices, you'll need to rebuild them with metadata:
|
||||
|
||||
```python
|
||||
# Read existing passages and add metadata
|
||||
def add_metadata_to_existing_chunks(chunks):
|
||||
for chunk in chunks:
|
||||
# Extract or assign metadata based on content
|
||||
chunk["metadata"] = extract_metadata(chunk["text"])
|
||||
return chunks
|
||||
|
||||
# Rebuild index with metadata
|
||||
enhanced_chunks = add_metadata_to_existing_chunks(existing_chunks)
|
||||
builder = LeannBuilder("hnsw")
|
||||
for chunk in enhanced_chunks:
|
||||
builder.add_text(chunk["text"], chunk["metadata"])
|
||||
builder.build_index("enhanced_index")
|
||||
```
|
||||
75
docs/normalized_embeddings.md
Normal file
75
docs/normalized_embeddings.md
Normal file
@@ -0,0 +1,75 @@
|
||||
# Normalized Embeddings Support in LEANN
|
||||
|
||||
LEANN now automatically detects normalized embedding models and sets the appropriate distance metric for optimal performance.
|
||||
|
||||
## What are Normalized Embeddings?
|
||||
|
||||
Normalized embeddings are vectors with L2 norm = 1 (unit vectors). These embeddings are optimized for cosine similarity rather than Maximum Inner Product Search (MIPS).
|
||||
|
||||
## Automatic Detection
|
||||
|
||||
When you create a `LeannBuilder` instance with a normalized embedding model, LEANN will:
|
||||
|
||||
1. **Automatically set `distance_metric="cosine"`** if not specified
|
||||
2. **Show a warning** if you manually specify a different distance metric
|
||||
3. **Provide optimal search performance** with the correct metric
|
||||
|
||||
## Supported Normalized Embedding Models
|
||||
|
||||
### OpenAI
|
||||
All OpenAI text embedding models are normalized:
|
||||
- `text-embedding-ada-002`
|
||||
- `text-embedding-3-small`
|
||||
- `text-embedding-3-large`
|
||||
|
||||
### Voyage AI
|
||||
All Voyage AI embedding models are normalized:
|
||||
- `voyage-2`
|
||||
- `voyage-3`
|
||||
- `voyage-large-2`
|
||||
- `voyage-multilingual-2`
|
||||
- `voyage-code-2`
|
||||
|
||||
### Cohere
|
||||
All Cohere embedding models are normalized:
|
||||
- `embed-english-v3.0`
|
||||
- `embed-multilingual-v3.0`
|
||||
- `embed-english-light-v3.0`
|
||||
- `embed-multilingual-light-v3.0`
|
||||
|
||||
## Example Usage
|
||||
|
||||
```python
|
||||
from leann.api import LeannBuilder
|
||||
|
||||
# Automatic detection - will use cosine distance
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="text-embedding-3-small",
|
||||
embedding_mode="openai"
|
||||
)
|
||||
# Warning: Detected normalized embeddings model 'text-embedding-3-small'...
|
||||
# Automatically setting distance_metric='cosine'
|
||||
|
||||
# Manual override (not recommended)
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="text-embedding-3-small",
|
||||
embedding_mode="openai",
|
||||
distance_metric="mips" # Will show warning
|
||||
)
|
||||
# Warning: Using 'mips' distance metric with normalized embeddings...
|
||||
```
|
||||
|
||||
## Non-Normalized Embeddings
|
||||
|
||||
Models like `facebook/contriever` and other sentence-transformers models that are not normalized will continue to use MIPS by default, which is optimal for them.
|
||||
|
||||
## Why This Matters
|
||||
|
||||
Using the wrong distance metric with normalized embeddings can lead to:
|
||||
- **Poor search quality** due to HNSW's early termination with narrow score ranges
|
||||
- **Incorrect ranking** of search results
|
||||
- **Suboptimal performance** compared to using the correct metric
|
||||
|
||||
For more details on why this happens, see our analysis in the [embedding detection code](../packages/leann-core/src/leann/api.py) which automatically handles normalized embeddings and MIPS distance metric issues.
|
||||
21
docs/roadmap.md
Normal file
21
docs/roadmap.md
Normal file
@@ -0,0 +1,21 @@
|
||||
# 📈 Roadmap
|
||||
|
||||
## 🎯 Q2 2025
|
||||
|
||||
- [X] HNSW backend integration
|
||||
- [X] DiskANN backend with MIPS/L2/Cosine support
|
||||
- [X] Real-time embedding pipeline
|
||||
- [X] Memory-efficient graph pruning
|
||||
|
||||
## 🚀 Q3 2025
|
||||
|
||||
- [ ] Advanced caching strategies
|
||||
- [ ] Add contextual-retrieval https://www.anthropic.com/news/contextual-retrieval
|
||||
- [ ] Add sleep-time-compute and summarize agent! to summarilze the file on computer!
|
||||
- [ ] Add OpenAI recompute API
|
||||
|
||||
## 🌟 Q4 2025
|
||||
|
||||
- [ ] Integration with LangChain/LlamaIndex
|
||||
- [ ] Visual similarity search
|
||||
- [ ] Query rewrtiting, rerank and expansion
|
||||
@@ -1,21 +1,28 @@
|
||||
"""
|
||||
Simple demo showing basic leann usage
|
||||
Run: uv run python examples/simple_demo.py
|
||||
Run: uv run python examples/basic_demo.py
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from leann import LeannBuilder, LeannSearcher, LeannChat
|
||||
|
||||
from leann import LeannBuilder, LeannChat, LeannSearcher
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Simple demo of Leann with selectable embedding models.")
|
||||
parser.add_argument("--embedding_model", type=str, default="sentence-transformers/all-mpnet-base-v2",
|
||||
help="The embedding model to use, e.g., 'sentence-transformers/all-mpnet-base-v2' or 'text-embedding-ada-002'.")
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Simple demo of Leann with selectable embedding models."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--embedding_model",
|
||||
type=str,
|
||||
default="sentence-transformers/all-mpnet-base-v2",
|
||||
help="The embedding model to use, e.g., 'sentence-transformers/all-mpnet-base-v2' or 'text-embedding-ada-002'.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
print(f"=== Leann Simple Demo with {args.embedding_model} ===")
|
||||
print()
|
||||
|
||||
|
||||
# Sample knowledge base
|
||||
chunks = [
|
||||
"Machine learning is a subset of artificial intelligence that enables computers to learn without being explicitly programmed.",
|
||||
@@ -27,7 +34,7 @@ def main():
|
||||
"Big data refers to extremely large datasets that require special tools and techniques to process.",
|
||||
"Cloud computing provides on-demand access to computing resources over the internet.",
|
||||
]
|
||||
|
||||
|
||||
print("1. Building index (no embeddings stored)...")
|
||||
builder = LeannBuilder(
|
||||
embedding_model=args.embedding_model,
|
||||
@@ -37,45 +44,45 @@ def main():
|
||||
builder.add_text(chunk)
|
||||
builder.build_index("demo_knowledge.leann")
|
||||
print()
|
||||
|
||||
|
||||
print("2. Searching with real-time embeddings...")
|
||||
searcher = LeannSearcher("demo_knowledge.leann")
|
||||
|
||||
|
||||
queries = [
|
||||
"What is machine learning?",
|
||||
"How does neural network work?",
|
||||
"How does neural network work?",
|
||||
"Tell me about data processing",
|
||||
]
|
||||
|
||||
|
||||
for query in queries:
|
||||
print(f"Query: {query}")
|
||||
results = searcher.search(query, top_k=2)
|
||||
|
||||
|
||||
for i, result in enumerate(results, 1):
|
||||
print(f" {i}. Score: {result.score:.3f}")
|
||||
print(f" Text: {result.text[:100]}...")
|
||||
print()
|
||||
|
||||
|
||||
print("3. Interactive chat demo:")
|
||||
print(" (Note: Requires OpenAI API key for real responses)")
|
||||
|
||||
|
||||
chat = LeannChat("demo_knowledge.leann")
|
||||
|
||||
|
||||
# Demo questions
|
||||
demo_questions: list[str] = [
|
||||
"What is the difference between machine learning and deep learning?",
|
||||
"How is data science related to big data?",
|
||||
]
|
||||
|
||||
|
||||
for question in demo_questions:
|
||||
print(f" Q: {question}")
|
||||
response = chat.ask(question)
|
||||
print(f" A: {response}")
|
||||
print()
|
||||
|
||||
|
||||
print("Demo completed! Try running:")
|
||||
print(" uv run python examples/document_search.py")
|
||||
print(" uv run python apps/document_rag.py")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
@@ -1,146 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Document search demo with recompute mode
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
import time
|
||||
|
||||
# Import backend packages to trigger plugin registration
|
||||
try:
|
||||
import leann_backend_diskann
|
||||
import leann_backend_hnsw
|
||||
print("INFO: Backend packages imported successfully.")
|
||||
except ImportError as e:
|
||||
print(f"WARNING: Could not import backend packages. Error: {e}")
|
||||
|
||||
# Import upper-level API from leann-core
|
||||
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
||||
|
||||
|
||||
def load_sample_documents():
|
||||
"""Create sample documents for demonstration"""
|
||||
docs = [
|
||||
{"title": "Intro to Python", "content": "Python is a high-level, interpreted language known for simplicity."},
|
||||
{"title": "ML Basics", "content": "Machine learning builds systems that learn from data."},
|
||||
{"title": "Data Structures", "content": "Data structures like arrays, lists, and graphs organize data."},
|
||||
]
|
||||
return docs
|
||||
|
||||
def main():
|
||||
print("==========================================================")
|
||||
print("=== Leann Document Search Demo (DiskANN + Recompute) ===")
|
||||
print("==========================================================")
|
||||
|
||||
INDEX_DIR = Path("./test_indices")
|
||||
INDEX_PATH = str(INDEX_DIR / "documents.diskann")
|
||||
BACKEND_TO_TEST = "diskann"
|
||||
|
||||
if INDEX_DIR.exists():
|
||||
print(f"--- Cleaning up old index directory: {INDEX_DIR} ---")
|
||||
shutil.rmtree(INDEX_DIR)
|
||||
|
||||
# --- 1. Build index ---
|
||||
print(f"\n[PHASE 1] Building index using '{BACKEND_TO_TEST}' backend...")
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name=BACKEND_TO_TEST,
|
||||
graph_degree=32,
|
||||
complexity=64
|
||||
)
|
||||
|
||||
documents = load_sample_documents()
|
||||
print(f"Loaded {len(documents)} sample documents.")
|
||||
for doc in documents:
|
||||
builder.add_text(doc["content"], metadata={"title": doc["title"]})
|
||||
|
||||
builder.build_index(INDEX_PATH)
|
||||
print(f"\nIndex built!")
|
||||
|
||||
# --- 2. Basic search demo ---
|
||||
print(f"\n[PHASE 2] Basic search using '{BACKEND_TO_TEST}' backend...")
|
||||
searcher = LeannSearcher(index_path=INDEX_PATH)
|
||||
|
||||
query = "What is machine learning?"
|
||||
print(f"\nQuery: '{query}'")
|
||||
|
||||
print("\n--- Basic search mode (PQ computation) ---")
|
||||
start_time = time.time()
|
||||
results = searcher.search(query, top_k=2)
|
||||
basic_time = time.time() - start_time
|
||||
|
||||
print(f"⏱️ Basic search time: {basic_time:.3f} seconds")
|
||||
print(">>> Basic search results <<<")
|
||||
for i, res in enumerate(results, 1):
|
||||
print(f" {i}. ID: {res.id}, Score: {res.score:.4f}, Text: '{res.text}', Metadata: {res.metadata}")
|
||||
|
||||
# --- 3. Recompute search demo ---
|
||||
print(f"\n[PHASE 3] Recompute search using embedding server...")
|
||||
|
||||
print("\n--- Recompute search mode (get real embeddings via network) ---")
|
||||
|
||||
# Configure recompute parameters
|
||||
recompute_params = {
|
||||
"recompute_beighbor_embeddings": True, # Enable network recomputation
|
||||
"USE_DEFERRED_FETCH": False, # Don't use deferred fetch
|
||||
"skip_search_reorder": True, # Skip search reordering
|
||||
"dedup_node_dis": True, # Enable node distance deduplication
|
||||
"prune_ratio": 0.1, # Pruning ratio 10%
|
||||
"batch_recompute": False, # Don't use batch recomputation
|
||||
"global_pruning": False, # Don't use global pruning
|
||||
"zmq_port": 5555, # ZMQ port
|
||||
"embedding_model": "sentence-transformers/all-mpnet-base-v2"
|
||||
}
|
||||
|
||||
print("Recompute parameter configuration:")
|
||||
for key, value in recompute_params.items():
|
||||
print(f" {key}: {value}")
|
||||
|
||||
print(f"\n🔄 Executing Recompute search...")
|
||||
try:
|
||||
start_time = time.time()
|
||||
recompute_results = searcher.search(query, top_k=2, **recompute_params)
|
||||
recompute_time = time.time() - start_time
|
||||
|
||||
print(f"⏱️ Recompute search time: {recompute_time:.3f} seconds")
|
||||
print(">>> Recompute search results <<<")
|
||||
for i, res in enumerate(recompute_results, 1):
|
||||
print(f" {i}. ID: {res.id}, Score: {res.score:.4f}, Text: '{res.text}', Metadata: {res.metadata}")
|
||||
|
||||
# Compare results
|
||||
print(f"\n--- Result comparison ---")
|
||||
print(f"Basic search time: {basic_time:.3f} seconds")
|
||||
print(f"Recompute time: {recompute_time:.3f} seconds")
|
||||
|
||||
print("\nBasic search vs Recompute results:")
|
||||
for i in range(min(len(results), len(recompute_results))):
|
||||
basic_score = results[i].score
|
||||
recompute_score = recompute_results[i].score
|
||||
score_diff = abs(basic_score - recompute_score)
|
||||
print(f" Position {i+1}: PQ={basic_score:.4f}, Recompute={recompute_score:.4f}, Difference={score_diff:.4f}")
|
||||
|
||||
if recompute_time > basic_time:
|
||||
print(f"✅ Recompute mode working correctly (more accurate but slower)")
|
||||
else:
|
||||
print(f"ℹ️ Recompute time is unusually fast, network recomputation may not be enabled")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Recompute search failed: {e}")
|
||||
print("This usually indicates an embedding server connection issue")
|
||||
|
||||
# --- 4. Chat demo ---
|
||||
print(f"\n[PHASE 4] Starting chat session...")
|
||||
chat = LeannChat(index_path=INDEX_PATH)
|
||||
chat_response = chat.ask(query)
|
||||
print(f"You: {query}")
|
||||
print(f"Leann: {chat_response}")
|
||||
|
||||
print("\n==========================================================")
|
||||
print("✅ Demo finished successfully!")
|
||||
print("==========================================================")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
380
examples/dynamic_add_leann_no_recompute.py
Normal file
380
examples/dynamic_add_leann_no_recompute.py
Normal file
@@ -0,0 +1,380 @@
|
||||
"""
|
||||
Dynamic add example for LEANN using HNSW backend without recompute.
|
||||
|
||||
- Builds a base index from a directory of documents
|
||||
- Incrementally adds new documents without recomputing stored embeddings
|
||||
|
||||
Defaults:
|
||||
- Base data: /Users/yichuan/Desktop/code/LEANN/leann/data
|
||||
- Incremental data: /Users/yichuan/Desktop/code/LEANN/leann/test_add
|
||||
- Index path: <index_dir>/documents.leann
|
||||
|
||||
Usage examples:
|
||||
uv run python examples/dynamic_add_leann_no_recompute.py --build-base \
|
||||
--base-dir /Users/yichuan/Desktop/code/LEANN/leann/data \
|
||||
--index-dir ./test_doc_files
|
||||
|
||||
uv run python examples/dynamic_add_leann_no_recompute.py --add-incremental \
|
||||
--add-dir /Users/yichuan/Desktop/code/LEANN/leann/test_add \
|
||||
--index-dir ./test_doc_files
|
||||
|
||||
Quick recompute test (both true):
|
||||
# Recompute build
|
||||
uv run python examples/dynamic_add_leann_no_recompute.py --build-base \
|
||||
--recompute-build --ef-construction 200 \
|
||||
--base-dir /Users/yichuan/Desktop/code/LEANN/leann/data \
|
||||
--index-dir ./test_doc_files --index-name documents.leann
|
||||
|
||||
# Recompute add
|
||||
uv run python examples/dynamic_add_leann_no_recompute.py --add-incremental \
|
||||
--recompute-add --ef-construction 32 \
|
||||
--add-dir /Users/yichuan/Desktop/code/LEANN/leann/test_add \
|
||||
--index-dir ./test_doc_files --index-name documents.leann
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import pickle
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
# Ensure we can import from the local packages and apps folders
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
CORE_SRC = ROOT / "packages" / "leann-core" / "src"
|
||||
HNSW_PKG_DIR = ROOT / "packages" / "leann-backend-hnsw"
|
||||
APPS_DIR = ROOT / "apps"
|
||||
|
||||
|
||||
# Prefer the installed backend if available (it contains the compiled extension)
|
||||
def _prefer_installed(pkg_name: str) -> bool:
|
||||
try:
|
||||
import importlib
|
||||
import importlib.util
|
||||
|
||||
spec = importlib.util.find_spec(pkg_name)
|
||||
if spec and spec.origin and "site-packages" in spec.origin:
|
||||
# ensure the faiss shim/extension is importable from the installed package
|
||||
importlib.import_module(f"{pkg_name}.faiss")
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
return False
|
||||
|
||||
|
||||
# Prepend paths, but only add the repo backend if the installed one is not present
|
||||
paths_to_prepend = [CORE_SRC, APPS_DIR]
|
||||
if not _prefer_installed("leann_backend_hnsw"):
|
||||
paths_to_prepend.insert(1, HNSW_PKG_DIR)
|
||||
|
||||
for p in paths_to_prepend:
|
||||
p_str = str(p)
|
||||
if p_str not in sys.path:
|
||||
sys.path.insert(0, p_str)
|
||||
|
||||
# Defer non-stdlib imports until after sys.path setup within functions (avoid E402)
|
||||
|
||||
|
||||
def _load_documents(data_dir: str, required_exts: Optional[list[str]] = None) -> list[Any]:
|
||||
from llama_index.core import SimpleDirectoryReader # type: ignore
|
||||
|
||||
reader_kwargs: dict[str, Any] = {"recursive": True, "encoding": "utf-8"}
|
||||
if required_exts:
|
||||
reader_kwargs["required_exts"] = required_exts
|
||||
documents = SimpleDirectoryReader(data_dir, **reader_kwargs).load_data(show_progress=True)
|
||||
return documents
|
||||
|
||||
|
||||
def _ensure_index_dir(index_dir: Path) -> None:
|
||||
index_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def _index_files(index_path: Path) -> tuple[Path, Path, Path]:
|
||||
"""Return (passages.jsonl, passages.idx, index.index) paths for a given index base path.
|
||||
|
||||
Note: HNSWBackend writes the FAISS index using the stem (without .leann),
|
||||
i.e., for base 'documents.leann' the file is 'documents.index'. We prefer the
|
||||
existing file among candidates.
|
||||
"""
|
||||
passages_file = index_path.parent / f"{index_path.name}.passages.jsonl"
|
||||
offsets_file = index_path.parent / f"{index_path.name}.passages.idx"
|
||||
candidate_name_index = index_path.parent / f"{index_path.name}.index"
|
||||
candidate_stem_index = index_path.parent / f"{index_path.stem}.index"
|
||||
index_file = candidate_stem_index if candidate_stem_index.exists() else candidate_name_index
|
||||
return passages_file, offsets_file, index_file
|
||||
|
||||
|
||||
def _read_meta(index_path: Path) -> dict[str, Any]:
|
||||
meta_path = index_path.parent / f"{index_path.name}.meta.json"
|
||||
if not meta_path.exists():
|
||||
raise FileNotFoundError(f"Metadata file not found: {meta_path}")
|
||||
with open(meta_path, encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def _autodetect_index_base(index_dir: Path) -> Optional[Path]:
|
||||
"""If exactly one *.leann.meta.json exists, return its base path (without .meta.json)."""
|
||||
candidates = list(index_dir.glob("*.leann.meta.json"))
|
||||
if len(candidates) == 1:
|
||||
meta = candidates[0]
|
||||
base = meta.with_suffix("") # remove .json
|
||||
base = base.with_suffix("") # remove .meta
|
||||
return base
|
||||
return None
|
||||
|
||||
|
||||
def _load_offset_map(offsets_file: Path) -> dict[str, int]:
|
||||
if not offsets_file.exists():
|
||||
return {}
|
||||
with open(offsets_file, "rb") as f:
|
||||
return pickle.load(f)
|
||||
|
||||
|
||||
def _next_numeric_id(existing_ids: list[str]) -> int:
|
||||
numeric_ids = [int(x) for x in existing_ids if x.isdigit()]
|
||||
if not numeric_ids:
|
||||
return 0
|
||||
return max(numeric_ids) + 1
|
||||
|
||||
|
||||
def build_base_index(
|
||||
base_dir: str,
|
||||
index_dir: str,
|
||||
index_name: str,
|
||||
embedding_model: str,
|
||||
embedding_mode: str,
|
||||
chunk_size: int,
|
||||
chunk_overlap: int,
|
||||
file_types: Optional[list[str]] = None,
|
||||
max_items: int = -1,
|
||||
ef_construction: Optional[int] = None,
|
||||
recompute_build: bool = False,
|
||||
) -> str:
|
||||
print(f"Building base index from: {base_dir}")
|
||||
documents = _load_documents(base_dir, required_exts=file_types)
|
||||
if not documents:
|
||||
raise ValueError(f"No documents found in base_dir: {base_dir}")
|
||||
|
||||
from chunking import create_text_chunks
|
||||
|
||||
texts = create_text_chunks(
|
||||
documents,
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
use_ast_chunking=False,
|
||||
)
|
||||
if max_items > 0 and len(texts) > max_items:
|
||||
texts = texts[:max_items]
|
||||
print(f"Limiting to {max_items} chunks")
|
||||
|
||||
index_dir_path = Path(index_dir)
|
||||
_ensure_index_dir(index_dir_path)
|
||||
index_path = index_dir_path / index_name
|
||||
|
||||
print("Creating HNSW index (non-compact)...")
|
||||
from leann.api import LeannBuilder
|
||||
from leann.registry import register_project_directory
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model=embedding_model,
|
||||
embedding_mode=embedding_mode,
|
||||
is_recompute=recompute_build,
|
||||
is_compact=False,
|
||||
efConstruction=(ef_construction if ef_construction is not None else 200),
|
||||
)
|
||||
for t in texts:
|
||||
builder.add_text(t)
|
||||
builder.build_index(str(index_path))
|
||||
|
||||
# Register for discovery
|
||||
register_project_directory(Path.cwd())
|
||||
|
||||
print(f"Base index built at: {index_path}")
|
||||
return str(index_path)
|
||||
|
||||
|
||||
def add_incremental(
|
||||
add_dir: str,
|
||||
index_dir: str,
|
||||
index_name: Optional[str] = None,
|
||||
embedding_model: Optional[str] = None,
|
||||
embedding_mode: Optional[str] = None,
|
||||
chunk_size: int = 256,
|
||||
chunk_overlap: int = 128,
|
||||
file_types: Optional[list[str]] = None,
|
||||
max_items: int = -1,
|
||||
ef_construction: Optional[int] = None,
|
||||
recompute_add: bool = False,
|
||||
) -> str:
|
||||
print(f"Adding incremental data from: {add_dir}")
|
||||
index_dir_path = Path(index_dir)
|
||||
index_path = index_dir_path / (index_name or "documents.leann")
|
||||
|
||||
# If specified base doesn't exist, try to auto-detect an existing base
|
||||
try:
|
||||
_read_meta(index_path)
|
||||
except FileNotFoundError:
|
||||
auto_base = _autodetect_index_base(index_dir_path)
|
||||
if auto_base is not None:
|
||||
print(f"Auto-detected index base: {auto_base.name}")
|
||||
index_path = auto_base
|
||||
_read_meta(index_path)
|
||||
else:
|
||||
raise FileNotFoundError(
|
||||
f"No index metadata found for base '{index_path.name}'. Build base first with --build-base "
|
||||
f"or provide --index-name to match an existing index (e.g., 'test_doc_files.leann')."
|
||||
)
|
||||
|
||||
# Prepare validated context from core (checks backend/no-recompute and resolves embedding defaults)
|
||||
from leann.api import create_incremental_add_context, incremental_add_texts_with_context
|
||||
|
||||
ctx = create_incremental_add_context(
|
||||
str(index_path),
|
||||
embedding_model=embedding_model,
|
||||
embedding_mode=embedding_mode,
|
||||
data_dir=add_dir,
|
||||
required_exts=file_types,
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
max_items=max_items,
|
||||
)
|
||||
|
||||
# Use prepared texts from context to perform the add
|
||||
prepared_texts = ctx.prepared_texts or []
|
||||
if not prepared_texts:
|
||||
print("No new chunks to add.")
|
||||
return str(index_path)
|
||||
|
||||
added = incremental_add_texts_with_context(
|
||||
ctx,
|
||||
prepared_texts,
|
||||
ef_construction=ef_construction,
|
||||
recompute=recompute_add,
|
||||
)
|
||||
|
||||
print(f"Incremental add completed. Added {added} chunks. Index: {index_path}")
|
||||
return str(index_path)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Dynamic add to LEANN HNSW index without recompute",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
)
|
||||
|
||||
parser.add_argument("--build-base", action="store_true", help="Build base index")
|
||||
parser.add_argument("--add-incremental", action="store_true", help="Add incremental data")
|
||||
|
||||
parser.add_argument(
|
||||
"--base-dir",
|
||||
type=str,
|
||||
default="/Users/yichuan/Desktop/code/LEANN/leann/data",
|
||||
help="Base data directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--add-dir",
|
||||
type=str,
|
||||
default="/Users/yichuan/Desktop/code/LEANN/leann/test_add",
|
||||
help="Incremental data directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--index-dir",
|
||||
type=str,
|
||||
default="./test_doc_files",
|
||||
help="Directory containing the index",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--index-name",
|
||||
type=str,
|
||||
default="documents.leann",
|
||||
help=(
|
||||
"Index base file name. If you built via document_rag.py, use 'test_doc_files.leann'. "
|
||||
"Default: documents.leann"
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--embedding-model",
|
||||
type=str,
|
||||
default="facebook/contriever",
|
||||
help="Embedding model name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--embedding-mode",
|
||||
type=str,
|
||||
default="sentence-transformers",
|
||||
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
||||
help="Embedding backend mode",
|
||||
)
|
||||
|
||||
parser.add_argument("--chunk-size", type=int, default=256)
|
||||
parser.add_argument("--chunk-overlap", type=int, default=128)
|
||||
parser.add_argument("--file-types", nargs="+", default=None)
|
||||
parser.add_argument("--max-items", type=int, default=-1)
|
||||
parser.add_argument("--ef-construction", type=int, default=32)
|
||||
parser.add_argument(
|
||||
"--recompute-add", action="store_true", help="Enable recompute-mode add (non-compact only)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--recompute-build",
|
||||
action="store_true",
|
||||
help="Enable recompute-mode base build (non-compact only)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.build_base and not args.add_incremental:
|
||||
print("Nothing to do. Use --build-base and/or --add-incremental.")
|
||||
return
|
||||
|
||||
index_path_str: Optional[str] = None
|
||||
|
||||
if args.build_base:
|
||||
index_path_str = build_base_index(
|
||||
base_dir=args.base_dir,
|
||||
index_dir=args.index_dir,
|
||||
index_name=args.index_name,
|
||||
embedding_model=args.embedding_model,
|
||||
embedding_mode=args.embedding_mode,
|
||||
chunk_size=args.chunk_size,
|
||||
chunk_overlap=args.chunk_overlap,
|
||||
file_types=args.file_types,
|
||||
max_items=args.max_items,
|
||||
ef_construction=args.ef_construction,
|
||||
recompute_build=args.recompute_build,
|
||||
)
|
||||
|
||||
if args.add_incremental:
|
||||
index_path_str = add_incremental(
|
||||
add_dir=args.add_dir,
|
||||
index_dir=args.index_dir,
|
||||
index_name=args.index_name,
|
||||
embedding_model=args.embedding_model,
|
||||
embedding_mode=args.embedding_mode,
|
||||
chunk_size=args.chunk_size,
|
||||
chunk_overlap=args.chunk_overlap,
|
||||
file_types=args.file_types,
|
||||
max_items=args.max_items,
|
||||
ef_construction=args.ef_construction,
|
||||
recompute_add=args.recompute_add,
|
||||
)
|
||||
|
||||
# Optional: quick test query using searcher
|
||||
if index_path_str:
|
||||
try:
|
||||
from leann.api import LeannSearcher
|
||||
|
||||
searcher = LeannSearcher(index_path_str)
|
||||
query = "what is LEANN?"
|
||||
if args.add_incremental:
|
||||
query = "what is the multi vector search and how it works?"
|
||||
results = searcher.search(query, top_k=5)
|
||||
if results:
|
||||
print(f"Sample result: {results[0].text[:80]}...")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,124 +0,0 @@
|
||||
import os
|
||||
import email
|
||||
from pathlib import Path
|
||||
from typing import List, Any
|
||||
from llama_index.core import Document
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
|
||||
def find_all_messages_directories(root: str = None) -> List[Path]:
|
||||
"""
|
||||
Recursively find all 'Messages' directories under the given root.
|
||||
Returns a list of Path objects.
|
||||
"""
|
||||
if root is None:
|
||||
# Auto-detect user's mail path
|
||||
home_dir = os.path.expanduser("~")
|
||||
root = os.path.join(home_dir, "Library", "Mail")
|
||||
|
||||
messages_dirs = []
|
||||
for dirpath, dirnames, filenames in os.walk(root):
|
||||
if os.path.basename(dirpath) == "Messages":
|
||||
messages_dirs.append(Path(dirpath))
|
||||
return messages_dirs
|
||||
|
||||
class EmlxReader(BaseReader):
|
||||
"""
|
||||
Apple Mail .emlx file reader with embedded metadata.
|
||||
|
||||
Reads individual .emlx files from Apple Mail's storage format.
|
||||
"""
|
||||
|
||||
def __init__(self, include_html: bool = False) -> None:
|
||||
"""
|
||||
Initialize.
|
||||
|
||||
Args:
|
||||
include_html: Whether to include HTML content in the email body (default: False)
|
||||
"""
|
||||
self.include_html = include_html
|
||||
|
||||
def load_data(self, input_dir: str, **load_kwargs: Any) -> List[Document]:
|
||||
"""
|
||||
Load data from the input directory containing .emlx files.
|
||||
|
||||
Args:
|
||||
input_dir: Directory containing .emlx files
|
||||
**load_kwargs:
|
||||
max_count (int): Maximum amount of messages to read.
|
||||
"""
|
||||
docs: List[Document] = []
|
||||
max_count = load_kwargs.get('max_count', 1000)
|
||||
count = 0
|
||||
|
||||
# Walk through the directory recursively
|
||||
for dirpath, dirnames, filenames in os.walk(input_dir):
|
||||
# Skip hidden directories
|
||||
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
|
||||
|
||||
for filename in filenames:
|
||||
if count >= max_count:
|
||||
break
|
||||
|
||||
if filename.endswith(".emlx"):
|
||||
filepath = os.path.join(dirpath, filename)
|
||||
try:
|
||||
# Read the .emlx file
|
||||
with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
|
||||
content = f.read()
|
||||
|
||||
# .emlx files have a length prefix followed by the email content
|
||||
# The first line contains the length, followed by the email
|
||||
lines = content.split('\n', 1)
|
||||
if len(lines) >= 2:
|
||||
email_content = lines[1]
|
||||
|
||||
# Parse the email using Python's email module
|
||||
try:
|
||||
msg = email.message_from_string(email_content)
|
||||
|
||||
# Extract email metadata
|
||||
subject = msg.get('Subject', 'No Subject')
|
||||
from_addr = msg.get('From', 'Unknown')
|
||||
to_addr = msg.get('To', 'Unknown')
|
||||
date = msg.get('Date', 'Unknown')
|
||||
|
||||
# Extract email body
|
||||
body = ""
|
||||
if msg.is_multipart():
|
||||
for part in msg.walk():
|
||||
if part.get_content_type() == "text/plain" or part.get_content_type() == "text/html":
|
||||
if part.get_content_type() == "text/html" and not self.include_html:
|
||||
continue
|
||||
body += part.get_payload(decode=True).decode('utf-8', errors='ignore')
|
||||
# break
|
||||
else:
|
||||
body = msg.get_payload(decode=True).decode('utf-8', errors='ignore')
|
||||
|
||||
# Create document content with metadata embedded in text
|
||||
doc_content = f"""
|
||||
[EMAIL METADATA]
|
||||
File: {filename}
|
||||
From: {from_addr}
|
||||
To: {to_addr}
|
||||
Subject: {subject}
|
||||
Date: {date}
|
||||
[END METADATA]
|
||||
|
||||
{body}
|
||||
"""
|
||||
|
||||
# No separate metadata - everything is in the text
|
||||
doc = Document(text=doc_content, metadata={})
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error parsing email from {filepath}: {e}")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error reading file {filepath}: {e}")
|
||||
continue
|
||||
|
||||
print(f"Loaded {len(docs)} email documents")
|
||||
return docs
|
||||
@@ -1,281 +0,0 @@
|
||||
import os
|
||||
import asyncio
|
||||
import argparse
|
||||
try:
|
||||
import dotenv
|
||||
dotenv.load_dotenv()
|
||||
except ModuleNotFoundError:
|
||||
# python-dotenv is not installed; skip loading environment variables
|
||||
dotenv = None
|
||||
from pathlib import Path
|
||||
from typing import List, Any
|
||||
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
# dotenv.load_dotenv() # handled above if python-dotenv is available
|
||||
|
||||
# Default Chrome profile path
|
||||
DEFAULT_CHROME_PROFILE = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
|
||||
|
||||
def create_leann_index_from_multiple_chrome_profiles(profile_dirs: List[Path], index_path: str = "chrome_history_index.leann", max_count: int = -1):
|
||||
"""
|
||||
Create LEANN index from multiple Chrome profile data sources.
|
||||
|
||||
Args:
|
||||
profile_dirs: List of Path objects pointing to Chrome profile directories
|
||||
index_path: Path to save the LEANN index
|
||||
max_count: Maximum number of history entries to process per profile
|
||||
"""
|
||||
print("Creating LEANN index from multiple Chrome profile data sources...")
|
||||
|
||||
# Load documents using ChromeHistoryReader from history_data
|
||||
from history_data.history import ChromeHistoryReader
|
||||
reader = ChromeHistoryReader()
|
||||
|
||||
INDEX_DIR = Path(index_path).parent
|
||||
|
||||
if not INDEX_DIR.exists():
|
||||
print(f"--- Index directory not found, building new index ---")
|
||||
all_documents = []
|
||||
total_processed = 0
|
||||
|
||||
# Process each Chrome profile directory
|
||||
for i, profile_dir in enumerate(profile_dirs):
|
||||
print(f"\nProcessing Chrome profile {i+1}/{len(profile_dirs)}: {profile_dir}")
|
||||
|
||||
try:
|
||||
documents = reader.load_data(
|
||||
chrome_profile_path=str(profile_dir),
|
||||
max_count=max_count
|
||||
)
|
||||
if documents:
|
||||
print(f"Loaded {len(documents)} history documents from {profile_dir}")
|
||||
all_documents.extend(documents)
|
||||
total_processed += len(documents)
|
||||
|
||||
# Check if we've reached the max count
|
||||
if max_count > 0 and total_processed >= max_count:
|
||||
print(f"Reached max count of {max_count} documents")
|
||||
break
|
||||
else:
|
||||
print(f"No documents loaded from {profile_dir}")
|
||||
except Exception as e:
|
||||
print(f"Error processing {profile_dir}: {e}")
|
||||
continue
|
||||
|
||||
if not all_documents:
|
||||
print("No documents loaded from any source. Exiting.")
|
||||
return None
|
||||
|
||||
print(f"\nTotal loaded {len(all_documents)} history documents from {len(profile_dirs)} profiles")
|
||||
|
||||
# Create text splitter with 256 chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
||||
|
||||
# Convert Documents to text strings and chunk them
|
||||
all_texts = []
|
||||
for doc in all_documents:
|
||||
# Split the document into chunks
|
||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
|
||||
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
|
||||
|
||||
# Create LEANN index directory
|
||||
print(f"--- Index directory not found, building new index ---")
|
||||
INDEX_DIR.mkdir(exist_ok=True)
|
||||
|
||||
print(f"--- Building new LEANN index ---")
|
||||
|
||||
print(f"\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Use HNSW backend for better macOS compatibility
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="facebook/contriever",
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_compact=True,
|
||||
is_recompute=True,
|
||||
num_threads=1 # Force single-threaded mode
|
||||
)
|
||||
|
||||
print(f"Adding {len(all_texts)} history chunks to index...")
|
||||
for chunk_text in all_texts:
|
||||
builder.add_text(chunk_text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
print(f"\nLEANN index built at {index_path}!")
|
||||
else:
|
||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||
|
||||
return index_path
|
||||
|
||||
def create_leann_index(profile_path: str = None, index_path: str = "chrome_history_index.leann", max_count: int = 1000):
|
||||
"""
|
||||
Create LEANN index from Chrome history data.
|
||||
|
||||
Args:
|
||||
profile_path: Path to the Chrome profile directory (optional, uses default if None)
|
||||
index_path: Path to save the LEANN index
|
||||
max_count: Maximum number of history entries to process
|
||||
"""
|
||||
print("Creating LEANN index from Chrome history data...")
|
||||
INDEX_DIR = Path(index_path).parent
|
||||
|
||||
if not INDEX_DIR.exists():
|
||||
print(f"--- Index directory not found, building new index ---")
|
||||
INDEX_DIR.mkdir(exist_ok=True)
|
||||
|
||||
print(f"--- Building new LEANN index ---")
|
||||
|
||||
print(f"\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Load documents using ChromeHistoryReader from history_data
|
||||
from history_data.history import ChromeHistoryReader
|
||||
reader = ChromeHistoryReader()
|
||||
|
||||
documents = reader.load_data(
|
||||
chrome_profile_path=profile_path,
|
||||
max_count=max_count
|
||||
)
|
||||
|
||||
if not documents:
|
||||
print("No documents loaded. Exiting.")
|
||||
return None
|
||||
|
||||
print(f"Loaded {len(documents)} history documents")
|
||||
|
||||
# Create text splitter with 256 chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
||||
|
||||
# Convert Documents to text strings and chunk them
|
||||
all_texts = []
|
||||
for doc in documents:
|
||||
# Split the document into chunks
|
||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
|
||||
print(f"Created {len(all_texts)} text chunks from {len(documents)} documents")
|
||||
|
||||
# Create LEANN index directory
|
||||
print(f"--- Index directory not found, building new index ---")
|
||||
INDEX_DIR.mkdir(exist_ok=True)
|
||||
|
||||
print(f"--- Building new LEANN index ---")
|
||||
|
||||
print(f"\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Use HNSW backend for better macOS compatibility
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="facebook/contriever",
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_compact=True,
|
||||
is_recompute=True,
|
||||
num_threads=1 # Force single-threaded mode
|
||||
)
|
||||
|
||||
print(f"Adding {len(all_texts)} history chunks to index...")
|
||||
for chunk_text in all_texts:
|
||||
builder.add_text(chunk_text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
print(f"\nLEANN index built at {index_path}!")
|
||||
else:
|
||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||
|
||||
return index_path
|
||||
|
||||
async def query_leann_index(index_path: str, query: str):
|
||||
"""
|
||||
Query the LEANN index.
|
||||
|
||||
Args:
|
||||
index_path: Path to the LEANN index
|
||||
query: The query string
|
||||
"""
|
||||
print(f"\n[PHASE 2] Starting Leann chat session...")
|
||||
chat = LeannChat(index_path=index_path)
|
||||
|
||||
print(f"You: {query}")
|
||||
chat_response = chat.ask(
|
||||
query,
|
||||
top_k=10,
|
||||
recompute_beighbor_embeddings=True,
|
||||
complexity=32,
|
||||
beam_width=1,
|
||||
llm_config={
|
||||
"type": "openai",
|
||||
"model": "gpt-4o",
|
||||
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||
},
|
||||
llm_kwargs={
|
||||
"temperature": 0.0,
|
||||
"max_tokens": 1000
|
||||
}
|
||||
)
|
||||
print(f"Leann: {chat_response}")
|
||||
|
||||
async def main():
|
||||
# Parse command line arguments
|
||||
parser = argparse.ArgumentParser(description='LEANN Chrome History Reader - Create and query browser history index')
|
||||
parser.add_argument('--chrome-profile', type=str, default=DEFAULT_CHROME_PROFILE,
|
||||
help=f'Path to Chrome profile directory (default: {DEFAULT_CHROME_PROFILE}), usually you dont need to change this')
|
||||
parser.add_argument('--index-dir', type=str, default="./chrome_history_index_leann_test",
|
||||
help='Directory to store the LEANN index (default: ./chrome_history_index_leann_test)')
|
||||
parser.add_argument('--max-entries', type=int, default=1000,
|
||||
help='Maximum number of history entries to process (default: 1000)')
|
||||
parser.add_argument('--query', type=str, default=None,
|
||||
help='Single query to run (default: runs example queries)')
|
||||
parser.add_argument('--auto-find-profiles', action='store_true', default=True,
|
||||
help='Automatically find all Chrome profiles (default: True)')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
INDEX_DIR = Path(args.index_dir)
|
||||
INDEX_PATH = str(INDEX_DIR / "chrome_history.leann")
|
||||
|
||||
print(f"Using Chrome profile: {args.chrome_profile}")
|
||||
print(f"Index directory: {INDEX_DIR}")
|
||||
print(f"Max entries: {args.max_entries}")
|
||||
|
||||
# Find Chrome profile directories
|
||||
from history_data.history import ChromeHistoryReader
|
||||
|
||||
if args.auto_find_profiles:
|
||||
profile_dirs = ChromeHistoryReader.find_chrome_profiles()
|
||||
if not profile_dirs:
|
||||
print("No Chrome profiles found automatically. Exiting.")
|
||||
return
|
||||
else:
|
||||
# Use single specified profile
|
||||
profile_path = Path(args.chrome_profile)
|
||||
if not profile_path.exists():
|
||||
print(f"Chrome profile not found: {profile_path}")
|
||||
return
|
||||
profile_dirs = [profile_path]
|
||||
|
||||
# Create or load the LEANN index from all sources
|
||||
index_path = create_leann_index_from_multiple_chrome_profiles(profile_dirs, INDEX_PATH, args.max_entries)
|
||||
|
||||
if index_path:
|
||||
if args.query:
|
||||
# Run single query
|
||||
await query_leann_index(index_path, args.query)
|
||||
else:
|
||||
# Example queries
|
||||
queries = [
|
||||
"What websites did I visit about machine learning?",
|
||||
"Find my search history about programming"
|
||||
]
|
||||
|
||||
for query in queries:
|
||||
print("\n" + "="*60)
|
||||
await query_leann_index(index_path, query)
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
35
examples/grep_search_example.py
Normal file
35
examples/grep_search_example.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""
|
||||
Grep Search Example
|
||||
|
||||
Shows how to use grep-based text search instead of semantic search.
|
||||
Useful when you need exact text matches rather than meaning-based results.
|
||||
"""
|
||||
|
||||
from leann import LeannSearcher
|
||||
|
||||
# Load your index
|
||||
searcher = LeannSearcher("my-documents.leann")
|
||||
|
||||
# Regular semantic search
|
||||
print("=== Semantic Search ===")
|
||||
results = searcher.search("machine learning algorithms", top_k=3)
|
||||
for result in results:
|
||||
print(f"Score: {result.score:.3f}")
|
||||
print(f"Text: {result.text[:80]}...")
|
||||
print()
|
||||
|
||||
# Grep-based search for exact text matches
|
||||
print("=== Grep Search ===")
|
||||
results = searcher.search("def train_model", top_k=3, use_grep=True)
|
||||
for result in results:
|
||||
print(f"Score: {result.score}")
|
||||
print(f"Text: {result.text[:80]}...")
|
||||
print()
|
||||
|
||||
# Find specific error messages
|
||||
error_results = searcher.search("FileNotFoundError", use_grep=True)
|
||||
print(f"Found {len(error_results)} files mentioning FileNotFoundError")
|
||||
|
||||
# Search for function definitions
|
||||
func_results = searcher.search("class SearchResult", use_grep=True, top_k=5)
|
||||
print(f"Found {len(func_results)} class definitions")
|
||||
@@ -1,286 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
import asyncio
|
||||
import dotenv
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import List, Any
|
||||
|
||||
# Add the project root to Python path so we can import from examples
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
# Auto-detect user's mail path
|
||||
def get_mail_path():
|
||||
"""Get the mail path for the current user"""
|
||||
home_dir = os.path.expanduser("~")
|
||||
return os.path.join(home_dir, "Library", "Mail")
|
||||
|
||||
# Default mail path for macOS
|
||||
# DEFAULT_MAIL_PATH = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data"
|
||||
|
||||
def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_path: str = "mail_index.leann", max_count: int = -1, include_html: bool = False, embedding_model: str = "facebook/contriever"):
|
||||
"""
|
||||
Create LEANN index from multiple mail data sources.
|
||||
|
||||
Args:
|
||||
messages_dirs: List of Path objects pointing to Messages directories
|
||||
index_path: Path to save the LEANN index
|
||||
max_count: Maximum number of emails to process per directory
|
||||
include_html: Whether to include HTML content in email processing
|
||||
"""
|
||||
print("Creating LEANN index from multiple mail data sources...")
|
||||
|
||||
# Load documents using EmlxReader from LEANN_email_reader
|
||||
from examples.email_data.LEANN_email_reader import EmlxReader
|
||||
reader = EmlxReader(include_html=include_html)
|
||||
# from email_data.email import EmlxMboxReader
|
||||
# from pathlib import Path
|
||||
# reader = EmlxMboxReader()
|
||||
INDEX_DIR = Path(index_path).parent
|
||||
|
||||
if not INDEX_DIR.exists():
|
||||
print(f"--- Index directory not found, building new index ---")
|
||||
all_documents = []
|
||||
total_processed = 0
|
||||
|
||||
# Process each Messages directory
|
||||
for i, messages_dir in enumerate(messages_dirs):
|
||||
print(f"\nProcessing Messages directory {i+1}/{len(messages_dirs)}: {messages_dir}")
|
||||
|
||||
try:
|
||||
documents = reader.load_data(messages_dir)
|
||||
if documents:
|
||||
print(f"Loaded {len(documents)} email documents from {messages_dir}")
|
||||
all_documents.extend(documents)
|
||||
total_processed += len(documents)
|
||||
|
||||
# Check if we've reached the max count
|
||||
if max_count > 0 and total_processed >= max_count:
|
||||
print(f"Reached max count of {max_count} documents")
|
||||
break
|
||||
else:
|
||||
print(f"No documents loaded from {messages_dir}")
|
||||
except Exception as e:
|
||||
print(f"Error processing {messages_dir}: {e}")
|
||||
continue
|
||||
|
||||
if not all_documents:
|
||||
print("No documents loaded from any source. Exiting.")
|
||||
return None
|
||||
|
||||
print(f"\nTotal loaded {len(all_documents)} email documents from {len(messages_dirs)} directories")
|
||||
|
||||
# Create text splitter with 256 chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
||||
|
||||
# Convert Documents to text strings and chunk them
|
||||
all_texts = []
|
||||
for doc in all_documents:
|
||||
# Split the document into chunks
|
||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
|
||||
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
|
||||
|
||||
# Create LEANN index directory
|
||||
|
||||
print(f"--- Index directory not found, building new index ---")
|
||||
INDEX_DIR.mkdir(exist_ok=True)
|
||||
|
||||
print(f"--- Building new LEANN index ---")
|
||||
|
||||
print(f"\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Use HNSW backend for better macOS compatibility
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model=embedding_model,
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_compact=True,
|
||||
is_recompute=True,
|
||||
num_threads=1 # Force single-threaded mode
|
||||
)
|
||||
|
||||
print(f"Adding {len(all_texts)} email chunks to index...")
|
||||
for chunk_text in all_texts:
|
||||
builder.add_text(chunk_text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
print(f"\nLEANN index built at {index_path}!")
|
||||
else:
|
||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||
|
||||
return index_path
|
||||
|
||||
def create_leann_index(mail_path: str, index_path: str = "mail_index.leann", max_count: int = 1000, include_html: bool = False, embedding_model: str = "facebook/contriever"):
|
||||
"""
|
||||
Create LEANN index from mail data.
|
||||
|
||||
Args:
|
||||
mail_path: Path to the mail directory
|
||||
index_path: Path to save the LEANN index
|
||||
max_count: Maximum number of emails to process
|
||||
include_html: Whether to include HTML content in email processing
|
||||
"""
|
||||
print("Creating LEANN index from mail data...")
|
||||
INDEX_DIR = Path(index_path).parent
|
||||
|
||||
if not INDEX_DIR.exists():
|
||||
print(f"--- Index directory not found, building new index ---")
|
||||
INDEX_DIR.mkdir(exist_ok=True)
|
||||
|
||||
print(f"--- Building new LEANN index ---")
|
||||
|
||||
print(f"\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Load documents using EmlxReader from LEANN_email_reader
|
||||
from examples.email_data.LEANN_email_reader import EmlxReader
|
||||
reader = EmlxReader(include_html=include_html)
|
||||
# from email_data.email import EmlxMboxReader
|
||||
# from pathlib import Path
|
||||
# reader = EmlxMboxReader()
|
||||
documents = reader.load_data(Path(mail_path))
|
||||
|
||||
if not documents:
|
||||
print("No documents loaded. Exiting.")
|
||||
return None
|
||||
|
||||
print(f"Loaded {len(documents)} email documents")
|
||||
|
||||
# Create text splitter with 256 chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
||||
|
||||
# Convert Documents to text strings and chunk them
|
||||
all_texts = []
|
||||
for doc in documents:
|
||||
# Split the document into chunks
|
||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
|
||||
print(f"Created {len(all_texts)} text chunks from {len(documents)} documents")
|
||||
|
||||
# Create LEANN index directory
|
||||
|
||||
print(f"--- Index directory not found, building new index ---")
|
||||
INDEX_DIR.mkdir(exist_ok=True)
|
||||
|
||||
print(f"--- Building new LEANN index ---")
|
||||
|
||||
print(f"\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Use HNSW backend for better macOS compatibility
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model=embedding_model,
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_compact=True,
|
||||
is_recompute=True,
|
||||
num_threads=1 # Force single-threaded mode
|
||||
)
|
||||
|
||||
print(f"Adding {len(all_texts)} email chunks to index...")
|
||||
for chunk_text in all_texts:
|
||||
builder.add_text(chunk_text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
print(f"\nLEANN index built at {index_path}!")
|
||||
else:
|
||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||
|
||||
return index_path
|
||||
|
||||
async def query_leann_index(index_path: str, query: str):
|
||||
"""
|
||||
Query the LEANN index.
|
||||
|
||||
Args:
|
||||
index_path: Path to the LEANN index
|
||||
query: The query string
|
||||
"""
|
||||
print(f"\n[PHASE 2] Starting Leann chat session...")
|
||||
chat = LeannChat(index_path=index_path,
|
||||
llm_config={"type": "openai", "model": "gpt-4o"})
|
||||
|
||||
print(f"You: {query}")
|
||||
import time
|
||||
start_time = time.time()
|
||||
chat_response = chat.ask(
|
||||
query,
|
||||
top_k=10,
|
||||
recompute_beighbor_embeddings=True,
|
||||
complexity=12,
|
||||
beam_width=1,
|
||||
|
||||
)
|
||||
end_time = time.time()
|
||||
print(f"Time taken: {end_time - start_time} seconds")
|
||||
print(f"Leann: {chat_response}")
|
||||
|
||||
async def main():
|
||||
# Parse command line arguments
|
||||
parser = argparse.ArgumentParser(description='LEANN Mail Reader - Create and query email index')
|
||||
# Remove --mail-path argument and auto-detect all Messages directories
|
||||
# Remove DEFAULT_MAIL_PATH
|
||||
parser.add_argument('--index-dir', type=str, default="./mail_index_leann_raw_text_all_dicts",
|
||||
help='Directory to store the LEANN index (default: ./mail_index_leann_raw_text_all_dicts)')
|
||||
parser.add_argument('--max-emails', type=int, default=1000,
|
||||
help='Maximum number of emails to process (-1 means all)')
|
||||
parser.add_argument('--query', type=str, default="Give me some funny advertisement about apple or other companies",
|
||||
help='Single query to run (default: runs example queries)')
|
||||
parser.add_argument('--include-html', action='store_true', default=False,
|
||||
help='Include HTML content in email processing (default: False)')
|
||||
parser.add_argument('--embedding-model', type=str, default="facebook/contriever",
|
||||
help='Embedding model to use (default: facebook/contriever)')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print(f"args: {args}")
|
||||
|
||||
# Automatically find all Messages directories under the current user's Mail directory
|
||||
from examples.email_data.LEANN_email_reader import find_all_messages_directories
|
||||
mail_path = get_mail_path()
|
||||
print(f"Searching for email data in: {mail_path}")
|
||||
messages_dirs = find_all_messages_directories(mail_path)
|
||||
|
||||
print('len(messages_dirs): ', len(messages_dirs))
|
||||
|
||||
|
||||
if not messages_dirs:
|
||||
print("No Messages directories found. Exiting.")
|
||||
return
|
||||
|
||||
INDEX_DIR = Path(args.index_dir)
|
||||
INDEX_PATH = str(INDEX_DIR / "mail_documents.leann")
|
||||
print(f"Index directory: {INDEX_DIR}")
|
||||
print(f"Found {len(messages_dirs)} Messages directories.")
|
||||
|
||||
# Create or load the LEANN index from all sources
|
||||
index_path = create_leann_index_from_multiple_sources(messages_dirs, INDEX_PATH, args.max_emails, args.include_html, args.embedding_model)
|
||||
|
||||
if index_path:
|
||||
if args.query:
|
||||
# Run single query
|
||||
await query_leann_index(index_path, args.query)
|
||||
else:
|
||||
# Example queries
|
||||
queries = [
|
||||
"Hows Berkeley Graduate Student Instructor",
|
||||
"how's the icloud related advertisement saying",
|
||||
"Whats the number of class recommend to take per semester for incoming EECS students"
|
||||
]
|
||||
for query in queries:
|
||||
print("\n" + "="*60)
|
||||
await query_leann_index(index_path, query)
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -1,108 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import List, Any
|
||||
|
||||
# Add the project root to Python path so we can import from examples
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from llama_index.core import VectorStoreIndex, StorageContext
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
# --- EMBEDDING MODEL ---
|
||||
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
||||
import torch
|
||||
|
||||
# --- END EMBEDDING MODEL ---
|
||||
|
||||
# Import EmlxReader from the new module
|
||||
from examples.email_data.LEANN_email_reader import EmlxReader
|
||||
|
||||
def create_and_save_index(mail_path: str, save_dir: str = "mail_index_embedded", max_count: int = 1000, include_html: bool = False):
|
||||
print("Creating index from mail data with embedded metadata...")
|
||||
documents = EmlxReader(include_html=include_html).load_data(mail_path, max_count=max_count)
|
||||
if not documents:
|
||||
print("No documents loaded. Exiting.")
|
||||
return None
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
||||
# Use facebook/contriever as the embedder
|
||||
embed_model = HuggingFaceEmbedding(model_name="facebook/contriever")
|
||||
# set on device
|
||||
import torch
|
||||
if torch.cuda.is_available():
|
||||
embed_model._model.to("cuda")
|
||||
# set mps
|
||||
elif torch.backends.mps.is_available():
|
||||
embed_model._model.to("mps")
|
||||
else:
|
||||
embed_model._model.to("cpu")
|
||||
index = VectorStoreIndex.from_documents(
|
||||
documents,
|
||||
transformations=[text_splitter],
|
||||
embed_model=embed_model
|
||||
)
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
index.storage_context.persist(persist_dir=save_dir)
|
||||
print(f"Index saved to {save_dir}")
|
||||
return index
|
||||
|
||||
def load_index(save_dir: str = "mail_index_embedded"):
|
||||
try:
|
||||
storage_context = StorageContext.from_defaults(persist_dir=save_dir)
|
||||
index = VectorStoreIndex.from_vector_store(
|
||||
storage_context.vector_store,
|
||||
storage_context=storage_context
|
||||
)
|
||||
print(f"Index loaded from {save_dir}")
|
||||
return index
|
||||
except Exception as e:
|
||||
print(f"Error loading index: {e}")
|
||||
return None
|
||||
|
||||
def query_index(index, query: str):
|
||||
if index is None:
|
||||
print("No index available for querying.")
|
||||
return
|
||||
query_engine = index.as_query_engine()
|
||||
response = query_engine.query(query)
|
||||
print(f"Query: {query}")
|
||||
print(f"Response: {response}")
|
||||
|
||||
def main():
|
||||
# Parse command line arguments
|
||||
parser = argparse.ArgumentParser(description='LlamaIndex Mail Reader - Create and query email index')
|
||||
parser.add_argument('--mail-path', type=str,
|
||||
default="/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages",
|
||||
help='Path to mail data directory')
|
||||
parser.add_argument('--save-dir', type=str, default="mail_index_embedded",
|
||||
help='Directory to store the index (default: mail_index_embedded)')
|
||||
parser.add_argument('--max-emails', type=int, default=10000,
|
||||
help='Maximum number of emails to process')
|
||||
parser.add_argument('--include-html', action='store_true', default=False,
|
||||
help='Include HTML content in email processing (default: False)')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
mail_path = args.mail_path
|
||||
save_dir = args.save_dir
|
||||
|
||||
if os.path.exists(save_dir) and os.path.exists(os.path.join(save_dir, "vector_store.json")):
|
||||
print("Loading existing index...")
|
||||
index = load_index(save_dir)
|
||||
else:
|
||||
print("Creating new index...")
|
||||
index = create_and_save_index(mail_path, save_dir, max_count=args.max_emails, include_html=args.include_html)
|
||||
if index:
|
||||
queries = [
|
||||
"Hows Berkeley Graduate Student Instructor",
|
||||
"how's the icloud related advertisement saying",
|
||||
"Whats the number of class recommend to take per semester for incoming EECS students"
|
||||
]
|
||||
for query in queries:
|
||||
print("\n" + "="*50)
|
||||
query_index(index, query)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,110 +0,0 @@
|
||||
import argparse
|
||||
from llama_index.core import SimpleDirectoryReader, Settings
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
import asyncio
|
||||
import dotenv
|
||||
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
node_parser = SentenceSplitter(
|
||||
chunk_size=256, chunk_overlap=128, separator=" ", paragraph_separator="\n\n"
|
||||
)
|
||||
print("Loading documents...")
|
||||
documents = SimpleDirectoryReader(
|
||||
"examples/data",
|
||||
recursive=True,
|
||||
encoding="utf-8",
|
||||
required_exts=[".pdf", ".txt", ".md"],
|
||||
).load_data(show_progress=True)
|
||||
print("Documents loaded.")
|
||||
all_texts = []
|
||||
for doc in documents:
|
||||
nodes = node_parser.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
|
||||
|
||||
async def main(args):
|
||||
INDEX_DIR = Path(args.index_dir)
|
||||
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
|
||||
|
||||
if not INDEX_DIR.exists():
|
||||
print(f"--- Index directory not found, building new index ---")
|
||||
|
||||
print(f"\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Use HNSW backend for better macOS compatibility
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="facebook/contriever",
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_compact=True,
|
||||
is_recompute=True,
|
||||
num_threads=1, # Force single-threaded mode
|
||||
)
|
||||
|
||||
print(f"Loaded {len(all_texts)} text chunks from documents.")
|
||||
for chunk_text in all_texts:
|
||||
builder.add_text(chunk_text)
|
||||
|
||||
builder.build_index(INDEX_PATH)
|
||||
print(f"\nLeann index built at {INDEX_PATH}!")
|
||||
else:
|
||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||
|
||||
print(f"\n[PHASE 2] Starting Leann chat session...")
|
||||
|
||||
# llm_config = {"type": "hf", "model": "Qwen/Qwen3-4B"}
|
||||
llm_config = {"type": "ollama", "model": "qwen3:8b"}
|
||||
|
||||
chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)
|
||||
|
||||
query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
|
||||
|
||||
# query = (
|
||||
# "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"
|
||||
# )
|
||||
|
||||
print(f"You: {query}")
|
||||
chat_response = chat.ask(
|
||||
query, top_k=20, recompute_beighbor_embeddings=True, complexity=32
|
||||
)
|
||||
print(f"Leann: {chat_response}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run Leann Chat with various LLM backends."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--llm",
|
||||
type=str,
|
||||
default="hf",
|
||||
choices=["simulated", "ollama", "hf", "openai"],
|
||||
help="The LLM backend to use.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="Qwen/Qwen3-0.6B",
|
||||
help="The model name to use (e.g., 'llama3:8b' for ollama, 'deepseek-ai/deepseek-llm-7b-chat' for hf, 'gpt-4o' for openai).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
type=str,
|
||||
default="http://localhost:11434",
|
||||
help="The host for the Ollama API.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--index-dir",
|
||||
type=str,
|
||||
default="./test_doc_files",
|
||||
help="Directory where the Leann index will be stored.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
asyncio.run(main(args))
|
||||
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
||||
|
||||
from leann.api import LeannBuilder, LeannChat
|
||||
|
||||
# Define the path for our new MLX-based index
|
||||
INDEX_PATH = "./mlx_diskann_index/leann"
|
||||
@@ -12,7 +13,7 @@ else:
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ",
|
||||
use_mlx=True,
|
||||
embedding_mode="mlx",
|
||||
)
|
||||
|
||||
# 2. Add documents
|
||||
@@ -38,7 +39,5 @@ chat = LeannChat(index_path=INDEX_PATH)
|
||||
# add query
|
||||
query = "MLX is an array framework for machine learning on Apple silicon."
|
||||
print(f"Query: {query}")
|
||||
response = chat.ask(
|
||||
query, top_k=3, recompute_beighbor_embeddings=True, complexity=3, beam_width=1
|
||||
)
|
||||
response = chat.ask(query, top_k=3, recompute_beighbor_embeddings=True, complexity=3, beam_width=1)
|
||||
print(f"Response: {response}")
|
||||
@@ -1,319 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Multi-Vector Aggregator for Fat Embeddings
|
||||
==========================================
|
||||
|
||||
This module implements aggregation strategies for multi-vector embeddings,
|
||||
similar to ColPali's approach where multiple patch vectors represent a single document.
|
||||
|
||||
Key features:
|
||||
- MaxSim aggregation (take maximum similarity across patches)
|
||||
- Voting-based aggregation (count patch matches)
|
||||
- Weighted aggregation (attention-score weighted)
|
||||
- Spatial clustering of matching patches
|
||||
- Document-level result consolidation
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from typing import List, Dict, Any, Tuple, Optional
|
||||
from dataclasses import dataclass
|
||||
from collections import defaultdict
|
||||
import json
|
||||
|
||||
@dataclass
|
||||
class PatchResult:
|
||||
"""Represents a single patch search result."""
|
||||
patch_id: int
|
||||
image_name: str
|
||||
image_path: str
|
||||
coordinates: Tuple[int, int, int, int] # (x1, y1, x2, y2)
|
||||
score: float
|
||||
attention_score: float
|
||||
scale: float
|
||||
metadata: Dict[str, Any]
|
||||
|
||||
@dataclass
|
||||
class AggregatedResult:
|
||||
"""Represents an aggregated document-level result."""
|
||||
image_name: str
|
||||
image_path: str
|
||||
doc_score: float
|
||||
patch_count: int
|
||||
best_patch: PatchResult
|
||||
all_patches: List[PatchResult]
|
||||
aggregation_method: str
|
||||
spatial_clusters: Optional[List[List[PatchResult]]] = None
|
||||
|
||||
class MultiVectorAggregator:
|
||||
"""
|
||||
Aggregates multiple patch-level results into document-level results.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
aggregation_method: str = "maxsim",
|
||||
spatial_clustering: bool = True,
|
||||
cluster_distance_threshold: float = 100.0):
|
||||
"""
|
||||
Initialize the aggregator.
|
||||
|
||||
Args:
|
||||
aggregation_method: "maxsim", "voting", "weighted", or "mean"
|
||||
spatial_clustering: Whether to cluster spatially close patches
|
||||
cluster_distance_threshold: Distance threshold for spatial clustering
|
||||
"""
|
||||
self.aggregation_method = aggregation_method
|
||||
self.spatial_clustering = spatial_clustering
|
||||
self.cluster_distance_threshold = cluster_distance_threshold
|
||||
|
||||
def aggregate_results(self,
|
||||
search_results: List[Dict[str, Any]],
|
||||
top_k: int = 10) -> List[AggregatedResult]:
|
||||
"""
|
||||
Aggregate patch-level search results into document-level results.
|
||||
|
||||
Args:
|
||||
search_results: List of search results from LeannSearcher
|
||||
top_k: Number of top documents to return
|
||||
|
||||
Returns:
|
||||
List of aggregated document results
|
||||
"""
|
||||
# Group results by image
|
||||
image_groups = defaultdict(list)
|
||||
|
||||
for result in search_results:
|
||||
metadata = result.metadata
|
||||
if "image_name" in metadata and "patch_id" in metadata:
|
||||
patch_result = PatchResult(
|
||||
patch_id=metadata["patch_id"],
|
||||
image_name=metadata["image_name"],
|
||||
image_path=metadata["image_path"],
|
||||
coordinates=tuple(metadata["coordinates"]),
|
||||
score=result.score,
|
||||
attention_score=metadata.get("attention_score", 0.0),
|
||||
scale=metadata.get("scale", 1.0),
|
||||
metadata=metadata
|
||||
)
|
||||
image_groups[metadata["image_name"]].append(patch_result)
|
||||
|
||||
# Aggregate each image group
|
||||
aggregated_results = []
|
||||
for image_name, patches in image_groups.items():
|
||||
if len(patches) == 0:
|
||||
continue
|
||||
|
||||
agg_result = self._aggregate_image_patches(image_name, patches)
|
||||
aggregated_results.append(agg_result)
|
||||
|
||||
# Sort by aggregated score and return top-k
|
||||
aggregated_results.sort(key=lambda x: x.doc_score, reverse=True)
|
||||
return aggregated_results[:top_k]
|
||||
|
||||
def _aggregate_image_patches(self, image_name: str, patches: List[PatchResult]) -> AggregatedResult:
|
||||
"""Aggregate patches for a single image."""
|
||||
|
||||
if self.aggregation_method == "maxsim":
|
||||
doc_score = max(patch.score for patch in patches)
|
||||
best_patch = max(patches, key=lambda p: p.score)
|
||||
|
||||
elif self.aggregation_method == "voting":
|
||||
# Count patches above threshold
|
||||
threshold = np.percentile([p.score for p in patches], 75)
|
||||
doc_score = sum(1 for patch in patches if patch.score >= threshold)
|
||||
best_patch = max(patches, key=lambda p: p.score)
|
||||
|
||||
elif self.aggregation_method == "weighted":
|
||||
# Weight by attention scores
|
||||
total_weighted_score = sum(p.score * p.attention_score for p in patches)
|
||||
total_weights = sum(p.attention_score for p in patches)
|
||||
doc_score = total_weighted_score / max(total_weights, 1e-8)
|
||||
best_patch = max(patches, key=lambda p: p.score * p.attention_score)
|
||||
|
||||
elif self.aggregation_method == "mean":
|
||||
doc_score = np.mean([patch.score for patch in patches])
|
||||
best_patch = max(patches, key=lambda p: p.score)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown aggregation method: {self.aggregation_method}")
|
||||
|
||||
# Spatial clustering if enabled
|
||||
spatial_clusters = None
|
||||
if self.spatial_clustering:
|
||||
spatial_clusters = self._cluster_patches_spatially(patches)
|
||||
|
||||
return AggregatedResult(
|
||||
image_name=image_name,
|
||||
image_path=patches[0].image_path,
|
||||
doc_score=float(doc_score),
|
||||
patch_count=len(patches),
|
||||
best_patch=best_patch,
|
||||
all_patches=sorted(patches, key=lambda p: p.score, reverse=True),
|
||||
aggregation_method=self.aggregation_method,
|
||||
spatial_clusters=spatial_clusters
|
||||
)
|
||||
|
||||
def _cluster_patches_spatially(self, patches: List[PatchResult]) -> List[List[PatchResult]]:
|
||||
"""Cluster patches that are spatially close to each other."""
|
||||
if len(patches) <= 1:
|
||||
return [patches]
|
||||
|
||||
clusters = []
|
||||
remaining_patches = patches.copy()
|
||||
|
||||
while remaining_patches:
|
||||
# Start new cluster with highest scoring remaining patch
|
||||
seed_patch = max(remaining_patches, key=lambda p: p.score)
|
||||
current_cluster = [seed_patch]
|
||||
remaining_patches.remove(seed_patch)
|
||||
|
||||
# Add nearby patches to cluster
|
||||
added_to_cluster = True
|
||||
while added_to_cluster:
|
||||
added_to_cluster = False
|
||||
for patch in remaining_patches.copy():
|
||||
if self._is_patch_nearby(patch, current_cluster):
|
||||
current_cluster.append(patch)
|
||||
remaining_patches.remove(patch)
|
||||
added_to_cluster = True
|
||||
|
||||
clusters.append(current_cluster)
|
||||
|
||||
return sorted(clusters, key=lambda cluster: max(p.score for p in cluster), reverse=True)
|
||||
|
||||
def _is_patch_nearby(self, patch: PatchResult, cluster: List[PatchResult]) -> bool:
|
||||
"""Check if a patch is spatially close to any patch in the cluster."""
|
||||
patch_center = self._get_patch_center(patch.coordinates)
|
||||
|
||||
for cluster_patch in cluster:
|
||||
cluster_center = self._get_patch_center(cluster_patch.coordinates)
|
||||
distance = np.sqrt((patch_center[0] - cluster_center[0])**2 +
|
||||
(patch_center[1] - cluster_center[1])**2)
|
||||
|
||||
if distance <= self.cluster_distance_threshold:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _get_patch_center(self, coordinates: Tuple[int, int, int, int]) -> Tuple[float, float]:
|
||||
"""Get center point of a patch."""
|
||||
x1, y1, x2, y2 = coordinates
|
||||
return ((x1 + x2) / 2, (y1 + y2) / 2)
|
||||
|
||||
def print_aggregated_results(self, results: List[AggregatedResult], max_patches_per_doc: int = 3):
|
||||
"""Pretty print aggregated results."""
|
||||
print(f"\n🔍 Aggregated Results (method: {self.aggregation_method})")
|
||||
print("=" * 80)
|
||||
|
||||
for i, result in enumerate(results):
|
||||
print(f"\n{i+1}. {result.image_name}")
|
||||
print(f" Doc Score: {result.doc_score:.4f} | Patches: {result.patch_count}")
|
||||
print(f" Path: {result.image_path}")
|
||||
|
||||
# Show best patch
|
||||
best = result.best_patch
|
||||
print(f" 🌟 Best Patch: #{best.patch_id} at {best.coordinates} (score: {best.score:.4f})")
|
||||
|
||||
# Show top patches
|
||||
print(f" 📍 Top Patches:")
|
||||
for j, patch in enumerate(result.all_patches[:max_patches_per_doc]):
|
||||
print(f" {j+1}. Patch #{patch.patch_id}: {patch.score:.4f} at {patch.coordinates}")
|
||||
|
||||
# Show spatial clusters if available
|
||||
if result.spatial_clusters and len(result.spatial_clusters) > 1:
|
||||
print(f" 🗂️ Spatial Clusters: {len(result.spatial_clusters)}")
|
||||
for j, cluster in enumerate(result.spatial_clusters[:2]): # Show top 2 clusters
|
||||
cluster_score = max(p.score for p in cluster)
|
||||
print(f" Cluster {j+1}: {len(cluster)} patches (best: {cluster_score:.4f})")
|
||||
|
||||
def demo_aggregation():
|
||||
"""Demonstrate the multi-vector aggregation functionality."""
|
||||
print("=== Multi-Vector Aggregation Demo ===")
|
||||
|
||||
# Simulate some patch-level search results
|
||||
# In real usage, these would come from LeannSearcher.search()
|
||||
|
||||
class MockResult:
|
||||
def __init__(self, score, metadata):
|
||||
self.score = score
|
||||
self.metadata = metadata
|
||||
|
||||
# Simulate results for 2 images with multiple patches each
|
||||
mock_results = [
|
||||
# Image 1: cats_and_kitchen.jpg - 4 patches
|
||||
MockResult(0.85, {
|
||||
"image_name": "cats_and_kitchen.jpg",
|
||||
"image_path": "/path/to/cats_and_kitchen.jpg",
|
||||
"patch_id": 3,
|
||||
"coordinates": [100, 50, 224, 174], # Kitchen area
|
||||
"attention_score": 0.92,
|
||||
"scale": 1.0
|
||||
}),
|
||||
MockResult(0.78, {
|
||||
"image_name": "cats_and_kitchen.jpg",
|
||||
"image_path": "/path/to/cats_and_kitchen.jpg",
|
||||
"patch_id": 7,
|
||||
"coordinates": [200, 300, 324, 424], # Cat area
|
||||
"attention_score": 0.88,
|
||||
"scale": 1.0
|
||||
}),
|
||||
MockResult(0.72, {
|
||||
"image_name": "cats_and_kitchen.jpg",
|
||||
"image_path": "/path/to/cats_and_kitchen.jpg",
|
||||
"patch_id": 12,
|
||||
"coordinates": [150, 100, 274, 224], # Appliances
|
||||
"attention_score": 0.75,
|
||||
"scale": 1.0
|
||||
}),
|
||||
MockResult(0.65, {
|
||||
"image_name": "cats_and_kitchen.jpg",
|
||||
"image_path": "/path/to/cats_and_kitchen.jpg",
|
||||
"patch_id": 15,
|
||||
"coordinates": [50, 250, 174, 374], # Furniture
|
||||
"attention_score": 0.70,
|
||||
"scale": 1.0
|
||||
}),
|
||||
|
||||
# Image 2: city_street.jpg - 3 patches
|
||||
MockResult(0.68, {
|
||||
"image_name": "city_street.jpg",
|
||||
"image_path": "/path/to/city_street.jpg",
|
||||
"patch_id": 2,
|
||||
"coordinates": [300, 100, 424, 224], # Buildings
|
||||
"attention_score": 0.80,
|
||||
"scale": 1.0
|
||||
}),
|
||||
MockResult(0.62, {
|
||||
"image_name": "city_street.jpg",
|
||||
"image_path": "/path/to/city_street.jpg",
|
||||
"patch_id": 8,
|
||||
"coordinates": [100, 350, 224, 474], # Street level
|
||||
"attention_score": 0.75,
|
||||
"scale": 1.0
|
||||
}),
|
||||
MockResult(0.55, {
|
||||
"image_name": "city_street.jpg",
|
||||
"image_path": "/path/to/city_street.jpg",
|
||||
"patch_id": 11,
|
||||
"coordinates": [400, 200, 524, 324], # Sky area
|
||||
"attention_score": 0.60,
|
||||
"scale": 1.0
|
||||
}),
|
||||
]
|
||||
|
||||
# Test different aggregation methods
|
||||
methods = ["maxsim", "voting", "weighted", "mean"]
|
||||
|
||||
for method in methods:
|
||||
print(f"\n{'='*20} {method.upper()} AGGREGATION {'='*20}")
|
||||
|
||||
aggregator = MultiVectorAggregator(
|
||||
aggregation_method=method,
|
||||
spatial_clustering=True,
|
||||
cluster_distance_threshold=100.0
|
||||
)
|
||||
|
||||
aggregated = aggregator.aggregate_results(mock_results, top_k=5)
|
||||
aggregator.print_aggregated_results(aggregated)
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo_aggregation()
|
||||
@@ -1,108 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
OpenAI Embedding Example
|
||||
|
||||
Complete example showing how to build and search with OpenAI embeddings using HNSW backend.
|
||||
"""
|
||||
|
||||
import os
|
||||
import dotenv
|
||||
from pathlib import Path
|
||||
from leann.api import LeannBuilder, LeannSearcher
|
||||
|
||||
# Load environment variables
|
||||
dotenv.load_dotenv()
|
||||
|
||||
def main():
|
||||
# Check if OpenAI API key is available
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
if not api_key:
|
||||
print("ERROR: OPENAI_API_KEY environment variable not set")
|
||||
return False
|
||||
|
||||
print(f"✅ OpenAI API key found: {api_key[:10]}...")
|
||||
|
||||
# Sample texts
|
||||
sample_texts = [
|
||||
"Machine learning is a powerful technology that enables computers to learn from data.",
|
||||
"Natural language processing helps computers understand and generate human language.",
|
||||
"Deep learning uses neural networks with multiple layers to solve complex problems.",
|
||||
"Computer vision allows machines to interpret and understand visual information.",
|
||||
"Reinforcement learning trains agents to make decisions through trial and error.",
|
||||
"Data science combines statistics, math, and programming to extract insights from data.",
|
||||
"Artificial intelligence aims to create machines that can perform human-like tasks.",
|
||||
"Python is a popular programming language used extensively in data science and AI.",
|
||||
"Neural networks are inspired by the structure and function of the human brain.",
|
||||
"Big data refers to extremely large datasets that require special tools to process."
|
||||
]
|
||||
|
||||
INDEX_DIR = Path("./simple_openai_test_index")
|
||||
INDEX_PATH = str(INDEX_DIR / "simple_test.leann")
|
||||
|
||||
print(f"\n=== Building Index with OpenAI Embeddings ===")
|
||||
print(f"Index path: {INDEX_PATH}")
|
||||
|
||||
try:
|
||||
# Use proper configuration for OpenAI embeddings
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="text-embedding-3-small",
|
||||
embedding_mode="openai",
|
||||
# HNSW settings for OpenAI embeddings
|
||||
M=16, # Smaller graph degree
|
||||
efConstruction=64, # Smaller construction complexity
|
||||
is_compact=True, # Enable compact storage for recompute
|
||||
is_recompute=True, # MUST enable for OpenAI embeddings
|
||||
num_threads=1,
|
||||
)
|
||||
|
||||
print(f"Adding {len(sample_texts)} texts to the index...")
|
||||
for i, text in enumerate(sample_texts):
|
||||
metadata = {"id": f"doc_{i}", "topic": "AI"}
|
||||
builder.add_text(text, metadata)
|
||||
|
||||
print("Building index...")
|
||||
builder.build_index(INDEX_PATH)
|
||||
print(f"✅ Index built successfully!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error building index: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
print(f"\n=== Testing Search ===")
|
||||
|
||||
try:
|
||||
searcher = LeannSearcher(INDEX_PATH)
|
||||
|
||||
test_queries = [
|
||||
"What is machine learning?",
|
||||
"How do neural networks work?",
|
||||
"Programming languages for data science"
|
||||
]
|
||||
|
||||
for query in test_queries:
|
||||
print(f"\n🔍 Query: '{query}'")
|
||||
results = searcher.search(query, top_k=3)
|
||||
|
||||
print(f" Found {len(results)} results:")
|
||||
for i, result in enumerate(results):
|
||||
print(f" {i+1}. Score: {result.score:.4f}")
|
||||
print(f" Text: {result.text[:80]}...")
|
||||
|
||||
print(f"\n✅ Search test completed successfully!")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error during search: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = main()
|
||||
if success:
|
||||
print(f"\n🎉 Simple OpenAI index test completed successfully!")
|
||||
else:
|
||||
print(f"\n💥 Simple OpenAI index test failed!")
|
||||
@@ -1,18 +0,0 @@
|
||||
import asyncio
|
||||
from leann.api import LeannChat
|
||||
from pathlib import Path
|
||||
|
||||
INDEX_DIR = Path("./test_pdf_index_huawei")
|
||||
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
|
||||
|
||||
async def main():
|
||||
print(f"\n[PHASE 2] Starting Leann chat session...")
|
||||
chat = LeannChat(index_path=INDEX_PATH)
|
||||
query = "What is the main idea of RL and give me 5 exapmle of classic RL algorithms?"
|
||||
query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
|
||||
# query = "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"
|
||||
response = chat.ask(query,top_k=20,recompute_beighbor_embeddings=True,complexity=32,beam_width=1)
|
||||
print(f"\n[PHASE 2] Response: {response}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
250
examples/spoiler_free_book_rag.py
Normal file
250
examples/spoiler_free_book_rag.py
Normal file
@@ -0,0 +1,250 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Spoiler-Free Book RAG Example using LEANN Metadata Filtering
|
||||
|
||||
This example demonstrates how to use LEANN's metadata filtering to create
|
||||
a spoiler-free book RAG system where users can search for information
|
||||
up to a specific chapter they've read.
|
||||
|
||||
Usage:
|
||||
python spoiler_free_book_rag.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from typing import Any, Optional
|
||||
|
||||
# Add LEANN to path (adjust path as needed)
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../packages/leann-core/src"))
|
||||
|
||||
from leann.api import LeannBuilder, LeannSearcher
|
||||
|
||||
|
||||
def chunk_book_with_metadata(book_title: str = "Sample Book") -> list[dict[str, Any]]:
|
||||
"""
|
||||
Create sample book chunks with metadata for demonstration.
|
||||
|
||||
In a real implementation, this would parse actual book files (epub, txt, etc.)
|
||||
and extract chapter boundaries, character mentions, etc.
|
||||
|
||||
Args:
|
||||
book_title: Title of the book
|
||||
|
||||
Returns:
|
||||
List of chunk dictionaries with text and metadata
|
||||
"""
|
||||
# Sample book chunks with metadata
|
||||
# In practice, you'd use proper text processing libraries
|
||||
|
||||
sample_chunks = [
|
||||
{
|
||||
"text": "Alice was beginning to get very tired of sitting by her sister on the bank, and of having nothing to do.",
|
||||
"metadata": {
|
||||
"book": book_title,
|
||||
"chapter": 1,
|
||||
"page": 1,
|
||||
"characters": ["Alice", "Sister"],
|
||||
"themes": ["boredom", "curiosity"],
|
||||
"location": "riverbank",
|
||||
},
|
||||
},
|
||||
{
|
||||
"text": "So she was considering in her own mind (as well as she could, for the hot day made her feel very sleepy and stupid), whether the pleasure of making a daisy-chain would be worth the trouble of getting up and picking the daisies, when suddenly a White Rabbit with pink eyes ran close by her.",
|
||||
"metadata": {
|
||||
"book": book_title,
|
||||
"chapter": 1,
|
||||
"page": 2,
|
||||
"characters": ["Alice", "White Rabbit"],
|
||||
"themes": ["decision", "surprise", "magic"],
|
||||
"location": "riverbank",
|
||||
},
|
||||
},
|
||||
{
|
||||
"text": "Alice found herself falling down a very deep well. Either the well was very deep, or she fell very slowly, for she had plenty of time as she fell to look about her and to wonder what was going to happen next.",
|
||||
"metadata": {
|
||||
"book": book_title,
|
||||
"chapter": 2,
|
||||
"page": 15,
|
||||
"characters": ["Alice"],
|
||||
"themes": ["falling", "wonder", "transformation"],
|
||||
"location": "rabbit hole",
|
||||
},
|
||||
},
|
||||
{
|
||||
"text": "Alice meets the Cheshire Cat, who tells her that everyone in Wonderland is mad, including Alice herself.",
|
||||
"metadata": {
|
||||
"book": book_title,
|
||||
"chapter": 6,
|
||||
"page": 85,
|
||||
"characters": ["Alice", "Cheshire Cat"],
|
||||
"themes": ["madness", "philosophy", "identity"],
|
||||
"location": "Duchess's house",
|
||||
},
|
||||
},
|
||||
{
|
||||
"text": "At the Queen's croquet ground, Alice witnesses the absurd trial that reveals the arbitrary nature of Wonderland's justice system.",
|
||||
"metadata": {
|
||||
"book": book_title,
|
||||
"chapter": 8,
|
||||
"page": 120,
|
||||
"characters": ["Alice", "Queen of Hearts", "King of Hearts"],
|
||||
"themes": ["justice", "absurdity", "authority"],
|
||||
"location": "Queen's court",
|
||||
},
|
||||
},
|
||||
{
|
||||
"text": "Alice realizes that Wonderland was all a dream, even the Rabbit, as she wakes up on the riverbank next to her sister.",
|
||||
"metadata": {
|
||||
"book": book_title,
|
||||
"chapter": 12,
|
||||
"page": 180,
|
||||
"characters": ["Alice", "Sister", "Rabbit"],
|
||||
"themes": ["revelation", "reality", "growth"],
|
||||
"location": "riverbank",
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
return sample_chunks
|
||||
|
||||
|
||||
def build_spoiler_free_index(book_chunks: list[dict[str, Any]], index_name: str) -> str:
|
||||
"""
|
||||
Build a LEANN index with book chunks that include spoiler metadata.
|
||||
|
||||
Args:
|
||||
book_chunks: List of book chunks with metadata
|
||||
index_name: Name for the index
|
||||
|
||||
Returns:
|
||||
Path to the built index
|
||||
"""
|
||||
print(f"📚 Building spoiler-free book index: {index_name}")
|
||||
|
||||
# Initialize LEANN builder
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw", embedding_model="text-embedding-3-small", embedding_mode="openai"
|
||||
)
|
||||
|
||||
# Add each chunk with its metadata
|
||||
for chunk in book_chunks:
|
||||
builder.add_text(text=chunk["text"], metadata=chunk["metadata"])
|
||||
|
||||
# Build the index
|
||||
index_path = f"{index_name}_book_index"
|
||||
builder.build_index(index_path)
|
||||
|
||||
print(f"✅ Index built successfully: {index_path}")
|
||||
return index_path
|
||||
|
||||
|
||||
def spoiler_free_search(
|
||||
index_path: str,
|
||||
query: str,
|
||||
max_chapter: int,
|
||||
character_filter: Optional[list[str]] = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Perform a spoiler-free search on the book index.
|
||||
|
||||
Args:
|
||||
index_path: Path to the LEANN index
|
||||
query: Search query
|
||||
max_chapter: Maximum chapter number to include
|
||||
character_filter: Optional list of characters to focus on
|
||||
|
||||
Returns:
|
||||
List of search results safe for the reader
|
||||
"""
|
||||
print(f"🔍 Searching: '{query}' (up to chapter {max_chapter})")
|
||||
|
||||
searcher = LeannSearcher(index_path)
|
||||
|
||||
metadata_filters = {"chapter": {"<=": max_chapter}}
|
||||
|
||||
if character_filter:
|
||||
metadata_filters["characters"] = {"contains": character_filter[0]}
|
||||
|
||||
results = searcher.search(query=query, top_k=10, metadata_filters=metadata_filters)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def demo_spoiler_free_rag():
|
||||
"""
|
||||
Demonstrate the spoiler-free book RAG system.
|
||||
"""
|
||||
print("🎭 Spoiler-Free Book RAG Demo")
|
||||
print("=" * 40)
|
||||
|
||||
# Step 1: Prepare book data
|
||||
book_title = "Alice's Adventures in Wonderland"
|
||||
book_chunks = chunk_book_with_metadata(book_title)
|
||||
|
||||
print(f"📖 Loaded {len(book_chunks)} chunks from '{book_title}'")
|
||||
|
||||
# Step 2: Build the index (in practice, this would be done once)
|
||||
try:
|
||||
index_path = build_spoiler_free_index(book_chunks, "alice_wonderland")
|
||||
except Exception as e:
|
||||
print(f"❌ Failed to build index (likely missing dependencies): {e}")
|
||||
print(
|
||||
"💡 This demo shows the filtering logic - actual indexing requires LEANN dependencies"
|
||||
)
|
||||
return
|
||||
|
||||
# Step 3: Demonstrate various spoiler-free searches
|
||||
search_scenarios = [
|
||||
{
|
||||
"description": "Reader who has only read Chapter 1",
|
||||
"query": "What can you tell me about the rabbit?",
|
||||
"max_chapter": 1,
|
||||
},
|
||||
{
|
||||
"description": "Reader who has read up to Chapter 5",
|
||||
"query": "Tell me about Alice's adventures",
|
||||
"max_chapter": 5,
|
||||
},
|
||||
{
|
||||
"description": "Reader who has read most of the book",
|
||||
"query": "What does the Cheshire Cat represent?",
|
||||
"max_chapter": 10,
|
||||
},
|
||||
{
|
||||
"description": "Reader who has read the whole book",
|
||||
"query": "What can you tell me about the rabbit?",
|
||||
"max_chapter": 12,
|
||||
},
|
||||
]
|
||||
|
||||
for scenario in search_scenarios:
|
||||
print(f"\n📚 Scenario: {scenario['description']}")
|
||||
print(f" Query: {scenario['query']}")
|
||||
|
||||
try:
|
||||
results = spoiler_free_search(
|
||||
index_path=index_path,
|
||||
query=scenario["query"],
|
||||
max_chapter=scenario["max_chapter"],
|
||||
)
|
||||
|
||||
print(f" 📄 Found {len(results)} results:")
|
||||
for i, result in enumerate(results[:3], 1): # Show top 3
|
||||
chapter = result.metadata.get("chapter", "?")
|
||||
location = result.metadata.get("location", "?")
|
||||
print(f" {i}. Chapter {chapter} ({location}): {result.text[:80]}...")
|
||||
|
||||
except Exception as e:
|
||||
print(f" ❌ Search failed: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("📚 LEANN Spoiler-Free Book RAG Example")
|
||||
print("=====================================")
|
||||
|
||||
try:
|
||||
demo_spoiler_free_rag()
|
||||
except ImportError as e:
|
||||
print(f"❌ Cannot run demo due to missing dependencies: {e}")
|
||||
except Exception as e:
|
||||
print(f"❌ Error running demo: {e}")
|
||||
@@ -1,318 +0,0 @@
|
||||
import os
|
||||
import asyncio
|
||||
import dotenv
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import List, Any, Optional
|
||||
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
import requests
|
||||
import time
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
# Default WeChat export directory
|
||||
DEFAULT_WECHAT_EXPORT_DIR = "./wechat_export_direct"
|
||||
|
||||
|
||||
def create_leann_index_from_multiple_wechat_exports(
|
||||
export_dirs: List[Path],
|
||||
index_path: str = "wechat_history_index.leann",
|
||||
max_count: int = -1,
|
||||
):
|
||||
"""
|
||||
Create LEANN index from multiple WeChat export data sources.
|
||||
|
||||
Args:
|
||||
export_dirs: List of Path objects pointing to WeChat export directories
|
||||
index_path: Path to save the LEANN index
|
||||
max_count: Maximum number of chat entries to process per export
|
||||
"""
|
||||
print("Creating LEANN index from multiple WeChat export data sources...")
|
||||
|
||||
# Load documents using WeChatHistoryReader from history_data
|
||||
from history_data.wechat_history import WeChatHistoryReader
|
||||
|
||||
reader = WeChatHistoryReader()
|
||||
|
||||
INDEX_DIR = Path(index_path).parent
|
||||
|
||||
if not INDEX_DIR.exists():
|
||||
print(f"--- Index directory not found, building new index ---")
|
||||
all_documents = []
|
||||
total_processed = 0
|
||||
|
||||
# Process each WeChat export directory
|
||||
for i, export_dir in enumerate(export_dirs):
|
||||
print(
|
||||
f"\nProcessing WeChat export {i + 1}/{len(export_dirs)}: {export_dir}"
|
||||
)
|
||||
|
||||
try:
|
||||
documents = reader.load_data(
|
||||
wechat_export_dir=str(export_dir),
|
||||
max_count=max_count,
|
||||
concatenate_messages=False, # Disable concatenation - one message per document
|
||||
)
|
||||
if documents:
|
||||
print(f"Loaded {len(documents)} chat documents from {export_dir}")
|
||||
all_documents.extend(documents)
|
||||
total_processed += len(documents)
|
||||
|
||||
# Check if we've reached the max count
|
||||
if max_count > 0 and total_processed >= max_count:
|
||||
print(f"Reached max count of {max_count} documents")
|
||||
break
|
||||
else:
|
||||
print(f"No documents loaded from {export_dir}")
|
||||
except Exception as e:
|
||||
print(f"Error processing {export_dir}: {e}")
|
||||
continue
|
||||
|
||||
if not all_documents:
|
||||
print("No documents loaded from any source. Exiting.")
|
||||
return None
|
||||
|
||||
print(
|
||||
f"\nTotal loaded {len(all_documents)} chat documents from {len(export_dirs)} exports"
|
||||
)
|
||||
|
||||
# Create text splitter with 256 chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=128, chunk_overlap=64)
|
||||
|
||||
# Convert Documents to text strings and chunk them
|
||||
all_texts = []
|
||||
for doc in all_documents:
|
||||
# Split the document into chunks
|
||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
|
||||
print(
|
||||
f"Created {len(all_texts)} text chunks from {len(all_documents)} documents"
|
||||
)
|
||||
|
||||
# Create LEANN index directory
|
||||
print(f"--- Index directory not found, building new index ---")
|
||||
INDEX_DIR.mkdir(exist_ok=True)
|
||||
|
||||
print(f"--- Building new LEANN index ---")
|
||||
|
||||
print(f"\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Use HNSW backend for better macOS compatibility
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="Qwen/Qwen3-Embedding-0.6B",
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_compact=True,
|
||||
is_recompute=True,
|
||||
num_threads=1, # Force single-threaded mode
|
||||
)
|
||||
|
||||
print(f"Adding {len(all_texts)} chat chunks to index...")
|
||||
for chunk_text in all_texts:
|
||||
builder.add_text(chunk_text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
print(f"\nLEANN index built at {index_path}!")
|
||||
else:
|
||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||
|
||||
return index_path
|
||||
|
||||
|
||||
def create_leann_index(
|
||||
export_dir: str = None,
|
||||
index_path: str = "wechat_history_index.leann",
|
||||
max_count: int = 1000,
|
||||
):
|
||||
"""
|
||||
Create LEANN index from WeChat chat history data.
|
||||
|
||||
Args:
|
||||
export_dir: Path to the WeChat export directory (optional, uses default if None)
|
||||
index_path: Path to save the LEANN index
|
||||
max_count: Maximum number of chat entries to process
|
||||
"""
|
||||
print("Creating LEANN index from WeChat chat history data...")
|
||||
INDEX_DIR = Path(index_path).parent
|
||||
|
||||
if not INDEX_DIR.exists():
|
||||
print(f"--- Index directory not found, building new index ---")
|
||||
INDEX_DIR.mkdir(exist_ok=True)
|
||||
|
||||
print(f"--- Building new LEANN index ---")
|
||||
|
||||
print(f"\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Load documents using WeChatHistoryReader from history_data
|
||||
from history_data.wechat_history import WeChatHistoryReader
|
||||
|
||||
reader = WeChatHistoryReader()
|
||||
|
||||
documents = reader.load_data(
|
||||
wechat_export_dir=export_dir,
|
||||
max_count=max_count,
|
||||
concatenate_messages=False, # Disable concatenation - one message per document
|
||||
)
|
||||
|
||||
if not documents:
|
||||
print("No documents loaded. Exiting.")
|
||||
return None
|
||||
|
||||
print(f"Loaded {len(documents)} chat documents")
|
||||
|
||||
# Create text splitter with 256 chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
||||
|
||||
# Convert Documents to text strings and chunk them
|
||||
all_texts = []
|
||||
for doc in documents:
|
||||
# Split the document into chunks
|
||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
|
||||
print(f"Created {len(all_texts)} text chunks from {len(documents)} documents")
|
||||
|
||||
# Create LEANN index directory
|
||||
print(f"--- Index directory not found, building new index ---")
|
||||
INDEX_DIR.mkdir(exist_ok=True)
|
||||
|
||||
print(f"--- Building new LEANN index ---")
|
||||
|
||||
print(f"\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Use HNSW backend for better macOS compatibility
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ", # MLX-optimized model
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_compact=True,
|
||||
is_recompute=True,
|
||||
num_threads=1, # Force single-threaded mode
|
||||
)
|
||||
|
||||
print(f"Adding {len(all_texts)} chat chunks to index...")
|
||||
for chunk_text in all_texts:
|
||||
builder.add_text(chunk_text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
print(f"\nLEANN index built at {index_path}!")
|
||||
else:
|
||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||
|
||||
return index_path
|
||||
|
||||
|
||||
async def query_leann_index(index_path: str, query: str):
|
||||
"""
|
||||
Query the LEANN index.
|
||||
|
||||
Args:
|
||||
index_path: Path to the LEANN index
|
||||
query: The query string
|
||||
"""
|
||||
print(f"\n[PHASE 2] Starting Leann chat session...")
|
||||
chat = LeannChat(index_path=index_path)
|
||||
|
||||
print(f"You: {query}")
|
||||
chat_response = chat.ask(
|
||||
query,
|
||||
top_k=20,
|
||||
recompute_beighbor_embeddings=True,
|
||||
complexity=128,
|
||||
beam_width=1,
|
||||
llm_config={
|
||||
"type": "openai",
|
||||
"model": "gpt-4o",
|
||||
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||
},
|
||||
llm_kwargs={"temperature": 0.0, "max_tokens": 1000},
|
||||
)
|
||||
print(f"Leann: {chat_response}")
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main function with integrated WeChat export functionality."""
|
||||
|
||||
# Parse command line arguments
|
||||
parser = argparse.ArgumentParser(
|
||||
description="LEANN WeChat History Reader - Create and query WeChat chat history index"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--export-dir",
|
||||
type=str,
|
||||
default=DEFAULT_WECHAT_EXPORT_DIR,
|
||||
help=f"Directory to store WeChat exports (default: {DEFAULT_WECHAT_EXPORT_DIR})",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--index-dir",
|
||||
type=str,
|
||||
default="./wechat_history_june19_test",
|
||||
help="Directory to store the LEANN index (default: ./wechat_history_index_leann_test)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-entries",
|
||||
type=int,
|
||||
default=5000,
|
||||
help="Maximum number of chat entries to process (default: 5000)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--query",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Single query to run (default: runs example queries)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force-export",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Force re-export of WeChat data even if exports exist",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
INDEX_DIR = Path(args.index_dir)
|
||||
INDEX_PATH = str(INDEX_DIR / "wechat_history.leann")
|
||||
|
||||
print(f"Using WeChat export directory: {args.export_dir}")
|
||||
print(f"Index directory: {INDEX_DIR}")
|
||||
print(f"Max entries: {args.max_entries}")
|
||||
|
||||
# Initialize WeChat reader with export capabilities
|
||||
from history_data.wechat_history import WeChatHistoryReader
|
||||
|
||||
reader = WeChatHistoryReader()
|
||||
|
||||
# Find existing exports or create new ones using the centralized method
|
||||
export_dirs = reader.find_or_export_wechat_data(args.export_dir)
|
||||
if not export_dirs:
|
||||
print("Failed to find or export WeChat data. Exiting.")
|
||||
return
|
||||
|
||||
# Create or load the LEANN index from all sources
|
||||
index_path = create_leann_index_from_multiple_wechat_exports(
|
||||
export_dirs, INDEX_PATH, max_count=args.max_entries
|
||||
)
|
||||
|
||||
if index_path:
|
||||
if args.query:
|
||||
# Run single query
|
||||
await query_leann_index(index_path, args.query)
|
||||
else:
|
||||
# Example queries
|
||||
queries = [
|
||||
"我想买魔术师约翰逊的球衣,给我一些对应聊天记录?",
|
||||
]
|
||||
|
||||
for query in queries:
|
||||
print("\n" + "=" * 60)
|
||||
await query_leann_index(index_path, query)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
28
llms.txt
Normal file
28
llms.txt
Normal file
@@ -0,0 +1,28 @@
|
||||
# llms.txt — LEANN MCP and Agent Integration
|
||||
product: LEANN
|
||||
homepage: https://github.com/yichuan-w/LEANN
|
||||
contact: https://github.com/yichuan-w/LEANN/issues
|
||||
|
||||
# Installation
|
||||
install: uv tool install leann-core --with leann
|
||||
|
||||
# MCP Server Entry Point
|
||||
mcp.server: leann_mcp
|
||||
mcp.protocol_version: 2024-11-05
|
||||
|
||||
# Tools
|
||||
mcp.tools: leann_list, leann_search
|
||||
|
||||
mcp.tool.leann_list.description: List available LEANN indexes
|
||||
mcp.tool.leann_list.input: {}
|
||||
|
||||
mcp.tool.leann_search.description: Semantic search across a named LEANN index
|
||||
mcp.tool.leann_search.input.index_name: string, required
|
||||
mcp.tool.leann_search.input.query: string, required
|
||||
mcp.tool.leann_search.input.top_k: integer, optional, default=5, min=1, max=20
|
||||
mcp.tool.leann_search.input.complexity: integer, optional, default=32, min=16, max=128
|
||||
|
||||
# Notes
|
||||
note: Build indexes with `leann build <name> --docs <files...>` before searching.
|
||||
example.add: claude mcp add --scope user leann-server -- leann_mcp
|
||||
example.verify: claude mcp list | cat
|
||||
@@ -1 +0,0 @@
|
||||
|
||||
|
||||
1
packages/astchunk-leann
Submodule
1
packages/astchunk-leann
Submodule
Submodule packages/astchunk-leann added at ad9afa07b9
@@ -1,8 +0,0 @@
|
||||
# packages/leann-backend-diskann/CMakeLists.txt (最终简化版)
|
||||
|
||||
cmake_minimum_required(VERSION 3.20)
|
||||
project(leann_backend_diskann_wrapper)
|
||||
|
||||
# 告诉 CMake 直接进入 DiskANN 子模块并执行它自己的 CMakeLists.txt
|
||||
# DiskANN 会自己处理所有事情,包括编译 Python 绑定
|
||||
add_subdirectory(src/third_party/DiskANN)
|
||||
@@ -1 +1 @@
|
||||
# This file makes the directory a Python package
|
||||
# This file makes the directory a Python package
|
||||
|
||||
@@ -1 +1,7 @@
|
||||
from . import diskann_backend
|
||||
from . import diskann_backend as diskann_backend
|
||||
from . import graph_partition
|
||||
|
||||
# Export main classes and functions
|
||||
from .graph_partition import GraphPartitioner, partition_graph
|
||||
|
||||
__all__ = ["GraphPartitioner", "diskann_backend", "graph_partition", "partition_graph"]
|
||||
|
||||
@@ -1,18 +1,65 @@
|
||||
import numpy as np
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import struct
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List, Literal
|
||||
import contextlib
|
||||
import pickle
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from leann.searcher_base import BaseSearcher
|
||||
from leann.registry import register_backend
|
||||
import numpy as np
|
||||
import psutil
|
||||
from leann.interface import (
|
||||
LeannBackendFactoryInterface,
|
||||
LeannBackendBuilderInterface,
|
||||
LeannBackendFactoryInterface,
|
||||
LeannBackendSearcherInterface,
|
||||
)
|
||||
from leann.registry import register_backend
|
||||
from leann.searcher_base import BaseSearcher
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def suppress_cpp_output_if_needed():
|
||||
"""Suppress C++ stdout/stderr based on LEANN_LOG_LEVEL"""
|
||||
# In CI we avoid fiddling with low-level file descriptors to prevent aborts
|
||||
if os.getenv("CI") == "true":
|
||||
yield
|
||||
return
|
||||
|
||||
log_level = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||
|
||||
# Only suppress if log level is WARNING or higher (ERROR, CRITICAL)
|
||||
should_suppress = log_level in ["WARNING", "ERROR", "CRITICAL"]
|
||||
|
||||
if not should_suppress:
|
||||
# Don't suppress, just yield
|
||||
yield
|
||||
return
|
||||
|
||||
# Save original file descriptors
|
||||
stdout_fd = sys.stdout.fileno()
|
||||
stderr_fd = sys.stderr.fileno()
|
||||
|
||||
# Save original stdout/stderr
|
||||
stdout_dup = os.dup(stdout_fd)
|
||||
stderr_dup = os.dup(stderr_fd)
|
||||
|
||||
try:
|
||||
# Redirect to /dev/null
|
||||
devnull = os.open(os.devnull, os.O_WRONLY)
|
||||
os.dup2(devnull, stdout_fd)
|
||||
os.dup2(devnull, stderr_fd)
|
||||
os.close(devnull)
|
||||
|
||||
yield
|
||||
|
||||
finally:
|
||||
# Restore original file descriptors
|
||||
os.dup2(stdout_dup, stdout_fd)
|
||||
os.dup2(stderr_dup, stderr_fd)
|
||||
os.close(stdout_dup)
|
||||
os.close(stderr_dup)
|
||||
|
||||
|
||||
def _get_diskann_metrics():
|
||||
@@ -43,6 +90,43 @@ def _write_vectors_to_bin(data: np.ndarray, file_path: Path):
|
||||
f.write(data.tobytes())
|
||||
|
||||
|
||||
def _calculate_smart_memory_config(data: np.ndarray) -> tuple[float, float]:
|
||||
"""
|
||||
Calculate smart memory configuration for DiskANN based on data size and system specs.
|
||||
|
||||
Args:
|
||||
data: The embedding data array
|
||||
|
||||
Returns:
|
||||
tuple: (search_memory_maximum, build_memory_maximum) in GB
|
||||
"""
|
||||
num_vectors, dim = data.shape
|
||||
|
||||
# Calculate embedding storage size
|
||||
embedding_size_bytes = num_vectors * dim * 4 # float32 = 4 bytes
|
||||
embedding_size_gb = embedding_size_bytes / (1024**3)
|
||||
|
||||
# search_memory_maximum: 1/10 of embedding size for optimal PQ compression
|
||||
# This controls Product Quantization size - smaller means more compression
|
||||
search_memory_gb = max(0.1, embedding_size_gb / 10) # At least 100MB
|
||||
|
||||
# build_memory_maximum: Based on available system RAM for sharding control
|
||||
# This controls how much memory DiskANN uses during index construction
|
||||
available_memory_gb = psutil.virtual_memory().available / (1024**3)
|
||||
total_memory_gb = psutil.virtual_memory().total / (1024**3)
|
||||
|
||||
# Use 50% of available memory, but at least 2GB and at most 75% of total
|
||||
build_memory_gb = max(2.0, min(available_memory_gb * 0.5, total_memory_gb * 0.75))
|
||||
|
||||
logger.info(
|
||||
f"Smart memory config - Data: {embedding_size_gb:.2f}GB, "
|
||||
f"Search mem: {search_memory_gb:.2f}GB (PQ control), "
|
||||
f"Build mem: {build_memory_gb:.2f}GB (sharding control)"
|
||||
)
|
||||
|
||||
return search_memory_gb, build_memory_gb
|
||||
|
||||
|
||||
@register_backend("diskann")
|
||||
class DiskannBackend(LeannBackendFactoryInterface):
|
||||
@staticmethod
|
||||
@@ -58,29 +142,113 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
||||
def __init__(self, **kwargs):
|
||||
self.build_params = kwargs
|
||||
|
||||
def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs):
|
||||
def _safe_cleanup_after_partition(self, index_dir: Path, index_prefix: str):
|
||||
"""
|
||||
Safely cleanup files after partition.
|
||||
In partition mode, C++ doesn't read _disk.index content,
|
||||
so we can delete it if all derived files exist.
|
||||
"""
|
||||
disk_index_file = index_dir / f"{index_prefix}_disk.index"
|
||||
beam_search_file = index_dir / f"{index_prefix}_disk_beam_search.index"
|
||||
|
||||
# Required files that C++ partition mode needs
|
||||
# Note: C++ generates these with _disk.index suffix
|
||||
disk_suffix = "_disk.index"
|
||||
required_files = [
|
||||
f"{index_prefix}{disk_suffix}_medoids.bin", # Critical: assert fails if missing
|
||||
# Note: _centroids.bin is not created in single-shot build - C++ handles this automatically
|
||||
f"{index_prefix}_pq_pivots.bin", # PQ table
|
||||
f"{index_prefix}_pq_compressed.bin", # PQ compressed vectors
|
||||
]
|
||||
|
||||
# Check if all required files exist
|
||||
missing_files = []
|
||||
for filename in required_files:
|
||||
file_path = index_dir / filename
|
||||
if not file_path.exists():
|
||||
missing_files.append(filename)
|
||||
|
||||
if missing_files:
|
||||
logger.warning(
|
||||
f"Cannot safely delete _disk.index - missing required files: {missing_files}"
|
||||
)
|
||||
logger.info("Keeping all original files for safety")
|
||||
return
|
||||
|
||||
# Calculate space savings
|
||||
space_saved = 0
|
||||
files_to_delete = []
|
||||
|
||||
if disk_index_file.exists():
|
||||
space_saved += disk_index_file.stat().st_size
|
||||
files_to_delete.append(disk_index_file)
|
||||
|
||||
if beam_search_file.exists():
|
||||
space_saved += beam_search_file.stat().st_size
|
||||
files_to_delete.append(beam_search_file)
|
||||
|
||||
# Safe to delete!
|
||||
for file_to_delete in files_to_delete:
|
||||
try:
|
||||
os.remove(file_to_delete)
|
||||
logger.info(f"✅ Safely deleted: {file_to_delete.name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete {file_to_delete.name}: {e}")
|
||||
|
||||
if space_saved > 0:
|
||||
space_saved_mb = space_saved / (1024 * 1024)
|
||||
logger.info(f"💾 Space saved: {space_saved_mb:.1f} MB")
|
||||
|
||||
# Show what files are kept
|
||||
logger.info("📁 Kept essential files for partition mode:")
|
||||
for filename in required_files:
|
||||
file_path = index_dir / filename
|
||||
if file_path.exists():
|
||||
size_mb = file_path.stat().st_size / (1024 * 1024)
|
||||
logger.info(f" - {filename} ({size_mb:.1f} MB)")
|
||||
|
||||
def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs):
|
||||
path = Path(index_path)
|
||||
index_dir = path.parent
|
||||
index_prefix = path.stem
|
||||
index_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if data.dtype != np.float32:
|
||||
logger.warning(f"Converting data to float32, shape: {data.shape}")
|
||||
data = data.astype(np.float32)
|
||||
|
||||
data_filename = f"{index_prefix}_data.bin"
|
||||
_write_vectors_to_bin(data, index_dir / data_filename)
|
||||
|
||||
label_map = {i: str_id for i, str_id in enumerate(ids)}
|
||||
label_map_file = index_dir / "leann.labels.map"
|
||||
with open(label_map_file, "wb") as f:
|
||||
pickle.dump(label_map, f)
|
||||
|
||||
build_kwargs = {**self.build_params, **kwargs}
|
||||
|
||||
# Extract is_recompute from nested backend_kwargs if needed
|
||||
is_recompute = build_kwargs.get("is_recompute", False)
|
||||
if not is_recompute and "backend_kwargs" in build_kwargs:
|
||||
is_recompute = build_kwargs["backend_kwargs"].get("is_recompute", False)
|
||||
|
||||
# Flatten all backend_kwargs parameters to top level for compatibility
|
||||
if "backend_kwargs" in build_kwargs:
|
||||
nested_params = build_kwargs.pop("backend_kwargs")
|
||||
build_kwargs.update(nested_params)
|
||||
|
||||
metric_enum = _get_diskann_metrics().get(
|
||||
build_kwargs.get("distance_metric", "mips").lower()
|
||||
)
|
||||
if metric_enum is None:
|
||||
raise ValueError("Unsupported distance_metric.")
|
||||
raise ValueError(
|
||||
f"Unsupported distance_metric '{build_kwargs.get('distance_metric', 'unknown')}'."
|
||||
)
|
||||
|
||||
# Calculate smart memory configuration if not explicitly provided
|
||||
if (
|
||||
"search_memory_maximum" not in build_kwargs
|
||||
or "build_memory_maximum" not in build_kwargs
|
||||
):
|
||||
smart_search_mem, smart_build_mem = _calculate_smart_memory_config(data)
|
||||
else:
|
||||
smart_search_mem = build_kwargs.get("search_memory_maximum", 4.0)
|
||||
smart_build_mem = build_kwargs.get("build_memory_maximum", 8.0)
|
||||
|
||||
try:
|
||||
from . import _diskannpy as diskannpy # type: ignore
|
||||
@@ -92,46 +260,125 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
||||
index_prefix,
|
||||
build_kwargs.get("complexity", 64),
|
||||
build_kwargs.get("graph_degree", 32),
|
||||
build_kwargs.get("search_memory_maximum", 4.0),
|
||||
build_kwargs.get("build_memory_maximum", 8.0),
|
||||
build_kwargs.get("search_memory_maximum", smart_search_mem),
|
||||
build_kwargs.get("build_memory_maximum", smart_build_mem),
|
||||
build_kwargs.get("num_threads", 8),
|
||||
build_kwargs.get("pq_disk_bytes", 0),
|
||||
"",
|
||||
)
|
||||
|
||||
# Auto-partition if is_recompute is enabled
|
||||
if build_kwargs.get("is_recompute", False):
|
||||
logger.info("is_recompute=True, starting automatic graph partitioning...")
|
||||
from .graph_partition import partition_graph
|
||||
|
||||
# Partition the index using absolute paths
|
||||
# Convert to absolute paths to avoid issues with working directory changes
|
||||
absolute_index_dir = Path(index_dir).resolve()
|
||||
absolute_index_prefix_path = str(absolute_index_dir / index_prefix)
|
||||
disk_graph_path, partition_bin_path = partition_graph(
|
||||
index_prefix_path=absolute_index_prefix_path,
|
||||
output_dir=str(absolute_index_dir),
|
||||
partition_prefix=index_prefix,
|
||||
)
|
||||
|
||||
# Safe cleanup: In partition mode, C++ doesn't read _disk.index content
|
||||
# but still needs the derived files (_medoids.bin, _centroids.bin, etc.)
|
||||
self._safe_cleanup_after_partition(index_dir, index_prefix)
|
||||
|
||||
logger.info("✅ Graph partitioning completed successfully!")
|
||||
logger.info(f" - Disk graph: {disk_graph_path}")
|
||||
logger.info(f" - Partition file: {partition_bin_path}")
|
||||
|
||||
finally:
|
||||
temp_data_file = index_dir / data_filename
|
||||
if temp_data_file.exists():
|
||||
os.remove(temp_data_file)
|
||||
logger.debug(f"Cleaned up temporary data file: {temp_data_file}")
|
||||
|
||||
|
||||
class DiskannSearcher(BaseSearcher):
|
||||
def __init__(self, index_path: str, **kwargs):
|
||||
super().__init__(
|
||||
index_path,
|
||||
backend_module_name="leann_backend_diskann.embedding_server",
|
||||
backend_module_name="leann_backend_diskann.diskann_embedding_server",
|
||||
**kwargs,
|
||||
)
|
||||
from . import _diskannpy as diskannpy # type: ignore
|
||||
|
||||
distance_metric = kwargs.get("distance_metric", "mips").lower()
|
||||
metric_enum = _get_diskann_metrics().get(distance_metric)
|
||||
if metric_enum is None:
|
||||
raise ValueError(f"Unsupported distance_metric '{distance_metric}'.")
|
||||
# Initialize DiskANN index with suppressed C++ output based on log level
|
||||
with suppress_cpp_output_if_needed():
|
||||
from . import _diskannpy as diskannpy # type: ignore
|
||||
|
||||
self.num_threads = kwargs.get("num_threads", 8)
|
||||
self.zmq_port = kwargs.get("zmq_port", 6666)
|
||||
distance_metric = kwargs.get("distance_metric", "mips").lower()
|
||||
metric_enum = _get_diskann_metrics().get(distance_metric)
|
||||
if metric_enum is None:
|
||||
raise ValueError(f"Unsupported distance_metric '{distance_metric}'.")
|
||||
|
||||
full_index_prefix = str(self.index_dir / self.index_path.stem)
|
||||
self._index = diskannpy.StaticDiskFloatIndex(
|
||||
metric_enum,
|
||||
full_index_prefix,
|
||||
self.num_threads,
|
||||
kwargs.get("num_nodes_to_cache", 0),
|
||||
1,
|
||||
self.zmq_port,
|
||||
"",
|
||||
"",
|
||||
)
|
||||
self.num_threads = kwargs.get("num_threads", 8)
|
||||
|
||||
# For DiskANN, we need to reinitialize the index when zmq_port changes
|
||||
# Store the initialization parameters for later use
|
||||
# Note: C++ load method expects the BASE path (without _disk.index suffix)
|
||||
# C++ internally constructs: index_prefix + "_disk.index"
|
||||
index_name = self.index_path.stem # "simple_test.leann" -> "simple_test"
|
||||
diskann_index_prefix = str(self.index_dir / index_name) # /path/to/simple_test
|
||||
full_index_prefix = diskann_index_prefix # /path/to/simple_test (base path)
|
||||
|
||||
# Auto-detect partition files and set partition_prefix
|
||||
partition_graph_file = self.index_dir / f"{index_name}_disk_graph.index"
|
||||
partition_bin_file = self.index_dir / f"{index_name}_partition.bin"
|
||||
|
||||
partition_prefix = ""
|
||||
if partition_graph_file.exists() and partition_bin_file.exists():
|
||||
# C++ expects full path prefix, not just filename
|
||||
partition_prefix = str(self.index_dir / index_name) # /path/to/simple_test
|
||||
logger.info(
|
||||
f"✅ Detected partition files, using partition_prefix='{partition_prefix}'"
|
||||
)
|
||||
else:
|
||||
logger.debug("No partition files detected, using standard index files")
|
||||
|
||||
self._init_params = {
|
||||
"metric_enum": metric_enum,
|
||||
"full_index_prefix": full_index_prefix,
|
||||
"num_threads": self.num_threads,
|
||||
"num_nodes_to_cache": kwargs.get("num_nodes_to_cache", 0),
|
||||
"cache_mechanism": 1,
|
||||
"pq_prefix": "",
|
||||
"partition_prefix": partition_prefix,
|
||||
}
|
||||
|
||||
# Log partition configuration for debugging
|
||||
if partition_prefix:
|
||||
logger.info(
|
||||
f"✅ Detected partition files, using partition_prefix='{partition_prefix}'"
|
||||
)
|
||||
self._diskannpy = diskannpy
|
||||
self._current_zmq_port = None
|
||||
self._index = None
|
||||
logger.debug("DiskANN searcher initialized (index will be loaded on first search)")
|
||||
|
||||
def _ensure_index_loaded(self, zmq_port: int):
|
||||
"""Ensure the index is loaded with the correct zmq_port."""
|
||||
if self._index is None or self._current_zmq_port != zmq_port:
|
||||
# Need to (re)load the index with the correct zmq_port
|
||||
with suppress_cpp_output_if_needed():
|
||||
if self._index is not None:
|
||||
logger.debug(f"Reloading DiskANN index with new zmq_port: {zmq_port}")
|
||||
else:
|
||||
logger.debug(f"Loading DiskANN index with zmq_port: {zmq_port}")
|
||||
|
||||
self._index = self._diskannpy.StaticDiskFloatIndex(
|
||||
self._init_params["metric_enum"],
|
||||
self._init_params["full_index_prefix"],
|
||||
self._init_params["num_threads"],
|
||||
self._init_params["num_nodes_to_cache"],
|
||||
self._init_params["cache_mechanism"],
|
||||
zmq_port,
|
||||
self._init_params["pq_prefix"],
|
||||
self._init_params["partition_prefix"],
|
||||
)
|
||||
self._current_zmq_port = zmq_port
|
||||
|
||||
def search(
|
||||
self,
|
||||
@@ -142,11 +389,11 @@ class DiskannSearcher(BaseSearcher):
|
||||
prune_ratio: float = 0.0,
|
||||
recompute_embeddings: bool = False,
|
||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||
zmq_port: int = 5557,
|
||||
zmq_port: Optional[int] = None,
|
||||
batch_recompute: bool = False,
|
||||
dedup_node_dis: bool = False,
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Search for nearest neighbors using DiskANN index.
|
||||
|
||||
@@ -161,7 +408,7 @@ class DiskannSearcher(BaseSearcher):
|
||||
- "global": Use global pruning strategy (default)
|
||||
- "local": Use local pruning strategy
|
||||
- "proportional": Not supported in DiskANN, falls back to global
|
||||
zmq_port: ZMQ port for embedding server
|
||||
zmq_port: ZMQ port for embedding server communication. Must be provided if recompute_embeddings is True.
|
||||
batch_recompute: Whether to batch neighbor recomputation (DiskANN-specific)
|
||||
dedup_node_dis: Whether to cache and reuse distance computations (DiskANN-specific)
|
||||
**kwargs: Additional DiskANN-specific parameters (for legacy compatibility)
|
||||
@@ -169,22 +416,22 @@ class DiskannSearcher(BaseSearcher):
|
||||
Returns:
|
||||
Dict with 'labels' (list of lists) and 'distances' (ndarray)
|
||||
"""
|
||||
# Handle zmq_port compatibility: Ensure index is loaded with correct port
|
||||
if recompute_embeddings:
|
||||
if zmq_port is None:
|
||||
raise ValueError("zmq_port must be provided if recompute_embeddings is True")
|
||||
self._ensure_index_loaded(zmq_port)
|
||||
else:
|
||||
# If not recomputing, we still need an index, use a default port
|
||||
if self._index is None:
|
||||
self._ensure_index_loaded(6666) # Default port when not recomputing
|
||||
|
||||
# DiskANN doesn't support "proportional" strategy
|
||||
if pruning_strategy == "proportional":
|
||||
raise NotImplementedError(
|
||||
"DiskANN backend does not support 'proportional' pruning strategy. Use 'global' or 'local' instead."
|
||||
)
|
||||
|
||||
# Use recompute_embeddings parameter
|
||||
use_recompute = recompute_embeddings
|
||||
if use_recompute:
|
||||
meta_file_path = self.index_dir / f"{self.index_path.name}.meta.json"
|
||||
if not meta_file_path.exists():
|
||||
raise RuntimeError(
|
||||
f"FATAL: Recompute enabled but metadata file not found: {meta_file_path}"
|
||||
)
|
||||
self._ensure_server_running(str(meta_file_path), port=zmq_port, **kwargs)
|
||||
|
||||
if query.dtype != np.float32:
|
||||
query = query.astype(np.float32)
|
||||
|
||||
@@ -194,28 +441,31 @@ class DiskannSearcher(BaseSearcher):
|
||||
else: # "global"
|
||||
use_global_pruning = True
|
||||
|
||||
labels, distances = self._index.batch_search(
|
||||
query,
|
||||
query.shape[0],
|
||||
top_k,
|
||||
complexity,
|
||||
beam_width,
|
||||
self.num_threads,
|
||||
kwargs.get("USE_DEFERRED_FETCH", False),
|
||||
kwargs.get("skip_search_reorder", False),
|
||||
use_recompute,
|
||||
dedup_node_dis,
|
||||
prune_ratio,
|
||||
batch_recompute,
|
||||
use_global_pruning,
|
||||
)
|
||||
# Strategy:
|
||||
# - Traversal always uses PQ distances
|
||||
# - If recompute_embeddings=True, do a single final rerank via deferred fetch
|
||||
# (fetch embeddings for the final candidate set only)
|
||||
# - Do not recompute neighbor distances along the path
|
||||
use_deferred_fetch = True if recompute_embeddings else False
|
||||
recompute_neighors = False # Expected typo. For backward compatibility.
|
||||
|
||||
string_labels = [
|
||||
[
|
||||
self.label_map.get(int_label, f"unknown_{int_label}")
|
||||
for int_label in batch_labels
|
||||
]
|
||||
for batch_labels in labels
|
||||
]
|
||||
with suppress_cpp_output_if_needed():
|
||||
labels, distances = self._index.batch_search(
|
||||
query,
|
||||
query.shape[0],
|
||||
top_k,
|
||||
complexity,
|
||||
beam_width,
|
||||
self.num_threads,
|
||||
use_deferred_fetch,
|
||||
kwargs.get("skip_search_reorder", False),
|
||||
recompute_neighors,
|
||||
dedup_node_dis,
|
||||
prune_ratio,
|
||||
batch_recompute,
|
||||
use_global_pruning,
|
||||
)
|
||||
|
||||
string_labels = [[str(int_label) for int_label in batch_labels] for batch_labels in labels]
|
||||
|
||||
return {"labels": string_labels, "distances": distances}
|
||||
|
||||
@@ -0,0 +1,472 @@
|
||||
"""
|
||||
DiskANN-specific embedding server
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import zmq
|
||||
|
||||
# Set up logging based on environment variable
|
||||
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Force set logger level (don't rely on basicConfig in subprocess)
|
||||
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
|
||||
logger.setLevel(log_level)
|
||||
|
||||
# Ensure we have a handler if none exists
|
||||
if not logger.handlers:
|
||||
handler = logging.StreamHandler()
|
||||
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
|
||||
handler.setFormatter(formatter)
|
||||
logger.addHandler(handler)
|
||||
logger.propagate = False
|
||||
|
||||
|
||||
def create_diskann_embedding_server(
|
||||
passages_file: Optional[str] = None,
|
||||
zmq_port: int = 5555,
|
||||
model_name: str = "sentence-transformers/all-mpnet-base-v2",
|
||||
embedding_mode: str = "sentence-transformers",
|
||||
distance_metric: str = "l2",
|
||||
):
|
||||
"""
|
||||
Create and start a ZMQ-based embedding server for DiskANN backend.
|
||||
Uses ROUTER socket and protobuf communication as required by DiskANN C++ implementation.
|
||||
"""
|
||||
logger.info(f"Starting DiskANN server on port {zmq_port} with model {model_name}")
|
||||
logger.info(f"Using embedding mode: {embedding_mode}")
|
||||
|
||||
# Add leann-core to path for unified embedding computation
|
||||
current_dir = Path(__file__).parent
|
||||
leann_core_path = current_dir.parent.parent / "leann-core" / "src"
|
||||
sys.path.insert(0, str(leann_core_path))
|
||||
|
||||
try:
|
||||
from leann.api import PassageManager
|
||||
from leann.embedding_compute import compute_embeddings
|
||||
|
||||
logger.info("Successfully imported unified embedding computation module")
|
||||
except ImportError as e:
|
||||
logger.error(f"Failed to import embedding computation module: {e}")
|
||||
return
|
||||
finally:
|
||||
sys.path.pop(0)
|
||||
|
||||
# Check port availability
|
||||
import socket
|
||||
|
||||
def check_port(port):
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
return s.connect_ex(("localhost", port)) == 0
|
||||
|
||||
if check_port(zmq_port):
|
||||
logger.error(f"Port {zmq_port} is already in use")
|
||||
return
|
||||
|
||||
# Only support metadata file, fail fast for everything else
|
||||
if not passages_file or not passages_file.endswith(".meta.json"):
|
||||
raise ValueError("Only metadata files (.meta.json) are supported")
|
||||
|
||||
# Load metadata to get passage sources
|
||||
with open(passages_file) as f:
|
||||
meta = json.load(f)
|
||||
|
||||
logger.info(f"Loading PassageManager with metadata_file_path: {passages_file}")
|
||||
passages = PassageManager(meta["passage_sources"], metadata_file_path=passages_file)
|
||||
logger.info(f"Loaded PassageManager with {len(passages)} passages from metadata")
|
||||
|
||||
# Import protobuf after ensuring the path is correct
|
||||
try:
|
||||
from . import embedding_pb2
|
||||
except ImportError as e:
|
||||
logger.error(f"Failed to import protobuf module: {e}")
|
||||
return
|
||||
|
||||
def zmq_server_thread():
|
||||
"""ZMQ server thread using REP socket for universal compatibility"""
|
||||
context = zmq.Context()
|
||||
socket = context.socket(
|
||||
zmq.REP
|
||||
) # REP socket for both BaseSearcher and DiskANN C++ REQ clients
|
||||
socket.bind(f"tcp://*:{zmq_port}")
|
||||
logger.info(f"DiskANN ZMQ REP server listening on port {zmq_port}")
|
||||
|
||||
socket.setsockopt(zmq.RCVTIMEO, 1000)
|
||||
socket.setsockopt(zmq.SNDTIMEO, 1000)
|
||||
socket.setsockopt(zmq.LINGER, 0)
|
||||
|
||||
while True:
|
||||
try:
|
||||
# REP socket receives single-part messages
|
||||
message = socket.recv()
|
||||
|
||||
# Check for empty messages - REP socket requires response to every request
|
||||
if len(message) == 0:
|
||||
logger.debug("Received empty message, sending empty response")
|
||||
socket.send(b"") # REP socket must respond to every request
|
||||
continue
|
||||
|
||||
logger.debug(f"Received ZMQ request of size {len(message)} bytes")
|
||||
logger.debug(f"Message preview: {message[:50]}") # Show first 50 bytes
|
||||
|
||||
e2e_start = time.time()
|
||||
|
||||
# Try protobuf first (for DiskANN C++ node_ids requests - primary use case)
|
||||
texts = []
|
||||
node_ids = []
|
||||
is_text_request = False
|
||||
|
||||
try:
|
||||
req_proto = embedding_pb2.NodeEmbeddingRequest()
|
||||
req_proto.ParseFromString(message)
|
||||
node_ids = list(req_proto.node_ids)
|
||||
|
||||
if not node_ids:
|
||||
raise RuntimeError(
|
||||
f"PROTOBUF: Received empty node_ids! Message size: {len(message)}"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"✅ PROTOBUF: Node ID request for {len(node_ids)} node embeddings: {node_ids[:10]}"
|
||||
)
|
||||
except Exception as protobuf_error:
|
||||
logger.debug(f"Protobuf parsing failed: {protobuf_error}")
|
||||
# Fallback to msgpack (for BaseSearcher direct text requests)
|
||||
try:
|
||||
import msgpack
|
||||
|
||||
request = msgpack.unpackb(message)
|
||||
# For BaseSearcher compatibility, request is a list of texts directly
|
||||
if isinstance(request, list) and all(
|
||||
isinstance(item, str) for item in request
|
||||
):
|
||||
texts = request
|
||||
is_text_request = True
|
||||
logger.info(f"✅ MSGPACK: Direct text request for {len(texts)} texts")
|
||||
else:
|
||||
raise ValueError("Not a valid msgpack text request")
|
||||
except Exception as msgpack_error:
|
||||
raise RuntimeError(
|
||||
f"Both protobuf and msgpack parsing failed! Protobuf: {protobuf_error}, Msgpack: {msgpack_error}"
|
||||
)
|
||||
|
||||
# Look up texts by node IDs (only if not direct text request)
|
||||
if not is_text_request:
|
||||
for nid in node_ids:
|
||||
try:
|
||||
passage_data = passages.get_passage(str(nid))
|
||||
txt = passage_data["text"]
|
||||
if not txt:
|
||||
raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
|
||||
texts.append(txt)
|
||||
except KeyError as e:
|
||||
logger.error(f"Passage ID {nid} not found: {e}")
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.error(f"Exception looking up passage ID {nid}: {e}")
|
||||
raise
|
||||
|
||||
# Debug logging
|
||||
logger.debug(f"Processing {len(texts)} texts")
|
||||
logger.debug(f"Text lengths: {[len(t) for t in texts[:5]]}") # Show first 5
|
||||
|
||||
# Process embeddings using unified computation
|
||||
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
||||
logger.info(
|
||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||
)
|
||||
|
||||
# Prepare response based on request type
|
||||
if is_text_request:
|
||||
# For BaseSearcher compatibility: return msgpack format
|
||||
import msgpack
|
||||
|
||||
response_data = msgpack.packb(embeddings.tolist())
|
||||
else:
|
||||
# For DiskANN C++ compatibility: return protobuf format
|
||||
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
||||
hidden_contiguous = np.ascontiguousarray(embeddings, dtype=np.float32)
|
||||
|
||||
# Serialize embeddings data
|
||||
resp_proto.embeddings_data = hidden_contiguous.tobytes()
|
||||
resp_proto.dimensions.append(hidden_contiguous.shape[0])
|
||||
resp_proto.dimensions.append(hidden_contiguous.shape[1])
|
||||
|
||||
response_data = resp_proto.SerializeToString()
|
||||
|
||||
# Send response back to the client
|
||||
socket.send(response_data)
|
||||
|
||||
e2e_end = time.time()
|
||||
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
|
||||
|
||||
except zmq.Again:
|
||||
logger.debug("ZMQ socket timeout, continuing to listen")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"Error in ZMQ server loop: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
def zmq_server_thread_with_shutdown(shutdown_event):
|
||||
"""ZMQ server thread that respects shutdown signal.
|
||||
|
||||
This creates its own REP socket, binds to zmq_port, and periodically
|
||||
checks shutdown_event using recv timeouts to exit cleanly.
|
||||
"""
|
||||
logger.info("DiskANN ZMQ server thread started with shutdown support")
|
||||
|
||||
context = zmq.Context()
|
||||
rep_socket = context.socket(zmq.REP)
|
||||
rep_socket.bind(f"tcp://*:{zmq_port}")
|
||||
logger.info(f"DiskANN ZMQ REP server listening on port {zmq_port}")
|
||||
|
||||
# Set receive timeout so we can check shutdown_event periodically
|
||||
rep_socket.setsockopt(zmq.RCVTIMEO, 1000) # 1 second timeout
|
||||
rep_socket.setsockopt(zmq.SNDTIMEO, 1000)
|
||||
rep_socket.setsockopt(zmq.LINGER, 0)
|
||||
|
||||
try:
|
||||
while not shutdown_event.is_set():
|
||||
try:
|
||||
e2e_start = time.time()
|
||||
# REP socket receives single-part messages
|
||||
message = rep_socket.recv()
|
||||
|
||||
# Check for empty messages - REP socket requires response to every request
|
||||
if not message:
|
||||
logger.warning("Received empty message, sending empty response")
|
||||
rep_socket.send(b"")
|
||||
continue
|
||||
|
||||
# Try protobuf first (same logic as original)
|
||||
texts = []
|
||||
is_text_request = False
|
||||
|
||||
try:
|
||||
req_proto = embedding_pb2.NodeEmbeddingRequest()
|
||||
req_proto.ParseFromString(message)
|
||||
node_ids = list(req_proto.node_ids)
|
||||
|
||||
# Look up texts by node IDs
|
||||
for nid in node_ids:
|
||||
try:
|
||||
passage_data = passages.get_passage(str(nid))
|
||||
txt = passage_data["text"]
|
||||
if not txt:
|
||||
raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
|
||||
texts.append(txt)
|
||||
except KeyError:
|
||||
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
|
||||
|
||||
logger.info(f"ZMQ received protobuf request for {len(node_ids)} node IDs")
|
||||
except Exception:
|
||||
# Fallback to msgpack for text requests
|
||||
try:
|
||||
import msgpack
|
||||
|
||||
request = msgpack.unpackb(message)
|
||||
if isinstance(request, list) and all(
|
||||
isinstance(item, str) for item in request
|
||||
):
|
||||
texts = request
|
||||
is_text_request = True
|
||||
logger.info(
|
||||
f"ZMQ received msgpack text request for {len(texts)} texts"
|
||||
)
|
||||
else:
|
||||
raise ValueError("Not a valid msgpack text request")
|
||||
except Exception:
|
||||
logger.error("Both protobuf and msgpack parsing failed!")
|
||||
# Send error response
|
||||
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
||||
rep_socket.send(resp_proto.SerializeToString())
|
||||
continue
|
||||
|
||||
# Process the request
|
||||
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
||||
logger.info(f"Computed embeddings shape: {embeddings.shape}")
|
||||
|
||||
# Validation
|
||||
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
||||
logger.error("NaN or Inf detected in embeddings!")
|
||||
# Send error response
|
||||
if is_text_request:
|
||||
import msgpack
|
||||
|
||||
response_data = msgpack.packb([])
|
||||
else:
|
||||
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
||||
response_data = resp_proto.SerializeToString()
|
||||
rep_socket.send(response_data)
|
||||
continue
|
||||
|
||||
# Prepare response based on request type
|
||||
if is_text_request:
|
||||
# For direct text requests, return msgpack
|
||||
import msgpack
|
||||
|
||||
response_data = msgpack.packb(embeddings.tolist())
|
||||
else:
|
||||
# For protobuf requests, return protobuf
|
||||
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
||||
hidden_contiguous = np.ascontiguousarray(embeddings, dtype=np.float32)
|
||||
|
||||
resp_proto.embeddings_data = hidden_contiguous.tobytes()
|
||||
resp_proto.dimensions.append(hidden_contiguous.shape[0])
|
||||
resp_proto.dimensions.append(hidden_contiguous.shape[1])
|
||||
|
||||
response_data = resp_proto.SerializeToString()
|
||||
|
||||
# Send response back to the client
|
||||
rep_socket.send(response_data)
|
||||
|
||||
e2e_end = time.time()
|
||||
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
|
||||
|
||||
except zmq.Again:
|
||||
# Timeout - check shutdown_event and continue
|
||||
continue
|
||||
except Exception as e:
|
||||
if not shutdown_event.is_set():
|
||||
logger.error(f"Error in ZMQ server loop: {e}")
|
||||
try:
|
||||
# Send error response for REP socket
|
||||
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
||||
rep_socket.send(resp_proto.SerializeToString())
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
logger.info("Shutdown in progress, ignoring ZMQ error")
|
||||
break
|
||||
finally:
|
||||
try:
|
||||
rep_socket.close(0)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
context.term()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.info("DiskANN ZMQ server thread exiting gracefully")
|
||||
|
||||
# Add shutdown coordination
|
||||
shutdown_event = threading.Event()
|
||||
|
||||
def shutdown_zmq_server():
|
||||
"""Gracefully shutdown ZMQ server."""
|
||||
logger.info("Initiating graceful shutdown...")
|
||||
shutdown_event.set()
|
||||
|
||||
if zmq_thread.is_alive():
|
||||
logger.info("Waiting for ZMQ thread to finish...")
|
||||
zmq_thread.join(timeout=5)
|
||||
if zmq_thread.is_alive():
|
||||
logger.warning("ZMQ thread did not finish in time")
|
||||
|
||||
# Clean up ZMQ resources
|
||||
try:
|
||||
# Note: socket and context are cleaned up by thread exit
|
||||
logger.info("ZMQ resources cleaned up")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error cleaning ZMQ resources: {e}")
|
||||
|
||||
# Clean up other resources
|
||||
try:
|
||||
import gc
|
||||
|
||||
gc.collect()
|
||||
logger.info("Additional resources cleaned up")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error cleaning additional resources: {e}")
|
||||
|
||||
logger.info("Graceful shutdown completed")
|
||||
sys.exit(0)
|
||||
|
||||
# Register signal handlers within this function scope
|
||||
import signal
|
||||
|
||||
def signal_handler(sig, frame):
|
||||
logger.info(f"Received signal {sig}, shutting down gracefully...")
|
||||
shutdown_zmq_server()
|
||||
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
# Start ZMQ thread (NOT daemon!)
|
||||
zmq_thread = threading.Thread(
|
||||
target=lambda: zmq_server_thread_with_shutdown(shutdown_event),
|
||||
daemon=False, # Not daemon - we want to wait for it
|
||||
)
|
||||
zmq_thread.start()
|
||||
logger.info(f"Started DiskANN ZMQ server thread on port {zmq_port}")
|
||||
|
||||
# Keep the main thread alive
|
||||
try:
|
||||
while not shutdown_event.is_set():
|
||||
time.sleep(0.1) # Check shutdown more frequently
|
||||
except KeyboardInterrupt:
|
||||
logger.info("DiskANN Server shutting down...")
|
||||
shutdown_zmq_server()
|
||||
return
|
||||
|
||||
# If we reach here, shutdown was triggered by signal
|
||||
logger.info("Main loop exited, process should be shutting down")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
# Signal handlers are now registered within create_diskann_embedding_server
|
||||
|
||||
parser = argparse.ArgumentParser(description="DiskANN Embedding service")
|
||||
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
|
||||
parser.add_argument(
|
||||
"--passages-file",
|
||||
type=str,
|
||||
help="Metadata JSON file containing passage sources",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-name",
|
||||
type=str,
|
||||
default="sentence-transformers/all-mpnet-base-v2",
|
||||
help="Embedding model name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--embedding-mode",
|
||||
type=str,
|
||||
default="sentence-transformers",
|
||||
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
||||
help="Embedding backend mode",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--distance-metric",
|
||||
type=str,
|
||||
default="l2",
|
||||
choices=["l2", "mips", "cosine"],
|
||||
help="Distance metric for similarity computation",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Create and start the DiskANN embedding server
|
||||
create_diskann_embedding_server(
|
||||
passages_file=args.passages_file,
|
||||
zmq_port=args.zmq_port,
|
||||
model_name=args.model_name,
|
||||
embedding_mode=args.embedding_mode,
|
||||
distance_metric=args.distance_metric,
|
||||
)
|
||||
@@ -1,27 +1,28 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# source: embedding.proto
|
||||
# ruff: noqa
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf.internal import builder as _builder
|
||||
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
from google.protobuf.internal import builder as _builder
|
||||
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0f\x65mbedding.proto\x12\x0eprotoembedding\"(\n\x14NodeEmbeddingRequest\x12\x10\n\x08node_ids\x18\x01 \x03(\r\"Y\n\x15NodeEmbeddingResponse\x12\x17\n\x0f\x65mbeddings_data\x18\x01 \x01(\x0c\x12\x12\n\ndimensions\x18\x02 \x03(\x05\x12\x13\n\x0bmissing_ids\x18\x03 \x03(\rb\x06proto3')
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
|
||||
b'\n\x0f\x65mbedding.proto\x12\x0eprotoembedding"(\n\x14NodeEmbeddingRequest\x12\x10\n\x08node_ids\x18\x01 \x03(\r"Y\n\x15NodeEmbeddingResponse\x12\x17\n\x0f\x65mbeddings_data\x18\x01 \x01(\x0c\x12\x12\n\ndimensions\x18\x02 \x03(\x05\x12\x13\n\x0bmissing_ids\x18\x03 \x03(\rb\x06proto3'
|
||||
)
|
||||
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'embedding_pb2', globals())
|
||||
if _descriptor._USE_C_DESCRIPTORS == False:
|
||||
|
||||
DESCRIPTOR._options = None
|
||||
_NODEEMBEDDINGREQUEST._serialized_start=35
|
||||
_NODEEMBEDDINGREQUEST._serialized_end=75
|
||||
_NODEEMBEDDINGRESPONSE._serialized_start=77
|
||||
_NODEEMBEDDINGRESPONSE._serialized_end=166
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "embedding_pb2", globals())
|
||||
if not _descriptor._USE_C_DESCRIPTORS:
|
||||
DESCRIPTOR._options = None
|
||||
_NODEEMBEDDINGREQUEST._serialized_start = 35
|
||||
_NODEEMBEDDINGREQUEST._serialized_end = 75
|
||||
_NODEEMBEDDINGRESPONSE._serialized_start = 77
|
||||
_NODEEMBEDDINGRESPONSE._serialized_end = 166
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
|
||||
@@ -1,741 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Embedding server for leann-backend-diskann - Fixed ZMQ REQ-REP pattern
|
||||
"""
|
||||
|
||||
import pickle
|
||||
import argparse
|
||||
import time
|
||||
import json
|
||||
from typing import Dict, Any, Optional, Union
|
||||
|
||||
from transformers import AutoTokenizer, AutoModel
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
import zmq
|
||||
import numpy as np
|
||||
import msgpack
|
||||
from pathlib import Path
|
||||
import logging
|
||||
|
||||
RED = "\033[91m"
|
||||
|
||||
# Set up logging based on environment variable
|
||||
LOG_LEVEL = os.getenv('LEANN_LOG_LEVEL', 'INFO').upper()
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, LOG_LEVEL, logging.INFO),
|
||||
format='%(asctime)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
RESET = "\033[0m"
|
||||
|
||||
# --- New Passage Loader from HNSW backend ---
|
||||
class SimplePassageLoader:
|
||||
"""
|
||||
Simple passage loader that replaces config.py dependencies
|
||||
"""
|
||||
def __init__(self, passages_data: Optional[Dict[str, Any]] = None):
|
||||
self.passages_data = passages_data or {}
|
||||
self._meta_path = ''
|
||||
|
||||
def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
|
||||
"""Get passage by ID"""
|
||||
str_id = str(passage_id)
|
||||
if str_id in self.passages_data:
|
||||
return {"text": self.passages_data[str_id]}
|
||||
else:
|
||||
# Return empty text for missing passages
|
||||
return {"text": ""}
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.passages_data)
|
||||
|
||||
def keys(self):
|
||||
return self.passages_data.keys()
|
||||
|
||||
def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
|
||||
"""
|
||||
Load passages using metadata file with PassageManager for lazy loading
|
||||
"""
|
||||
# Load metadata to get passage sources
|
||||
with open(meta_file, 'r') as f:
|
||||
meta = json.load(f)
|
||||
|
||||
# Import PassageManager dynamically to avoid circular imports
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Find the leann package directory relative to this file
|
||||
current_dir = Path(__file__).parent
|
||||
leann_core_path = current_dir.parent.parent / "leann-core" / "src"
|
||||
sys.path.insert(0, str(leann_core_path))
|
||||
|
||||
try:
|
||||
from leann.api import PassageManager
|
||||
passage_manager = PassageManager(meta['passage_sources'])
|
||||
finally:
|
||||
sys.path.pop(0)
|
||||
|
||||
# Load label map
|
||||
passages_dir = Path(meta_file).parent
|
||||
label_map_file = passages_dir / "leann.labels.map"
|
||||
|
||||
if label_map_file.exists():
|
||||
import pickle
|
||||
with open(label_map_file, 'rb') as f:
|
||||
label_map = pickle.load(f)
|
||||
print(f"Loaded label map with {len(label_map)} entries")
|
||||
else:
|
||||
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
|
||||
|
||||
print(f"Initialized lazy passage loading for {len(label_map)} passages")
|
||||
|
||||
class LazyPassageLoader(SimplePassageLoader):
|
||||
def __init__(self, passage_manager, label_map):
|
||||
self.passage_manager = passage_manager
|
||||
self.label_map = label_map
|
||||
# Initialize parent with empty data
|
||||
super().__init__({})
|
||||
|
||||
def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
|
||||
"""Get passage by ID with lazy loading"""
|
||||
try:
|
||||
int_id = int(passage_id)
|
||||
if int_id in self.label_map:
|
||||
string_id = self.label_map[int_id]
|
||||
passage_data = self.passage_manager.get_passage(string_id)
|
||||
if passage_data and passage_data.get("text"):
|
||||
return {"text": passage_data["text"]}
|
||||
else:
|
||||
raise RuntimeError(f"FATAL: Empty text for ID {int_id} -> {string_id}")
|
||||
else:
|
||||
raise RuntimeError(f"FATAL: ID {int_id} not found in label_map")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"FATAL: Exception getting passage {passage_id}: {e}")
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.label_map)
|
||||
|
||||
def keys(self):
|
||||
return self.label_map.keys()
|
||||
|
||||
loader = LazyPassageLoader(passage_manager, label_map)
|
||||
loader._meta_path = meta_file
|
||||
return loader
|
||||
|
||||
def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
|
||||
"""
|
||||
Load passages from a JSONL file with label map support
|
||||
Expected format: {"id": "passage_id", "text": "passage_text", "metadata": {...}} (one per line)
|
||||
"""
|
||||
|
||||
if not os.path.exists(passages_file):
|
||||
raise FileNotFoundError(f"Passages file {passages_file} not found.")
|
||||
|
||||
if not passages_file.endswith('.jsonl'):
|
||||
raise ValueError(f"Expected .jsonl file format, got: {passages_file}")
|
||||
|
||||
# Load label map (int -> string_id)
|
||||
passages_dir = Path(passages_file).parent
|
||||
label_map_file = passages_dir / "leann.labels.map"
|
||||
|
||||
label_map = {}
|
||||
if label_map_file.exists():
|
||||
with open(label_map_file, 'rb') as f:
|
||||
label_map = pickle.load(f)
|
||||
print(f"Loaded label map with {len(label_map)} entries")
|
||||
else:
|
||||
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
|
||||
|
||||
# Load passages by string ID
|
||||
string_id_passages = {}
|
||||
with open(passages_file, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
passage = json.loads(line)
|
||||
string_id_passages[passage['id']] = passage['text']
|
||||
|
||||
# Create int ID -> text mapping using label map
|
||||
passages_data = {}
|
||||
for int_id, string_id in label_map.items():
|
||||
if string_id in string_id_passages:
|
||||
passages_data[str(int_id)] = string_id_passages[string_id]
|
||||
else:
|
||||
print(f"WARNING: String ID {string_id} from label map not found in passages")
|
||||
|
||||
print(f"Loaded {len(passages_data)} passages from JSONL file {passages_file} using label map")
|
||||
return SimplePassageLoader(passages_data)
|
||||
|
||||
def create_embedding_server_thread(
|
||||
zmq_port=5555,
|
||||
model_name="sentence-transformers/all-mpnet-base-v2",
|
||||
max_batch_size=128,
|
||||
passages_file: Optional[str] = None,
|
||||
embedding_mode: str = "sentence-transformers",
|
||||
enable_warmup: bool = False,
|
||||
):
|
||||
"""
|
||||
Create and run embedding server in the current thread
|
||||
This function is designed to be called in a separate thread
|
||||
"""
|
||||
logger.info(f"Initializing embedding server thread on port {zmq_port}")
|
||||
|
||||
try:
|
||||
# Check if port is already occupied
|
||||
import socket
|
||||
def check_port(port):
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
return s.connect_ex(('localhost', port)) == 0
|
||||
|
||||
if check_port(zmq_port):
|
||||
print(f"{RED}Port {zmq_port} is already in use{RESET}")
|
||||
return
|
||||
|
||||
# Auto-detect mode based on model name if not explicitly set
|
||||
if embedding_mode == "sentence-transformers" and model_name.startswith("text-embedding-"):
|
||||
embedding_mode = "openai"
|
||||
|
||||
if embedding_mode == "mlx":
|
||||
from leann.api import compute_embeddings_mlx
|
||||
import torch
|
||||
logger.info("Using MLX for embeddings")
|
||||
# Set device to CPU for compatibility with DeviceTimer class
|
||||
device = torch.device("cpu")
|
||||
cuda_available = False
|
||||
mps_available = False
|
||||
elif embedding_mode == "openai":
|
||||
from leann.api import compute_embeddings_openai
|
||||
import torch
|
||||
logger.info("Using OpenAI API for embeddings")
|
||||
# Set device to CPU for compatibility with DeviceTimer class
|
||||
device = torch.device("cpu")
|
||||
cuda_available = False
|
||||
mps_available = False
|
||||
elif embedding_mode == "sentence-transformers":
|
||||
# Initialize model
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
|
||||
import torch
|
||||
|
||||
# Select device
|
||||
mps_available = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
|
||||
cuda_available = torch.cuda.is_available()
|
||||
|
||||
if cuda_available:
|
||||
device = torch.device("cuda")
|
||||
logger.info("Using CUDA device")
|
||||
elif mps_available:
|
||||
device = torch.device("mps")
|
||||
logger.info("Using MPS device (Apple Silicon)")
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
logger.info("Using CPU device")
|
||||
|
||||
# Load model
|
||||
logger.info(f"Loading model {model_name}")
|
||||
model = AutoModel.from_pretrained(model_name).to(device).eval()
|
||||
|
||||
# Optimize model
|
||||
if cuda_available or mps_available:
|
||||
try:
|
||||
model = model.half()
|
||||
model = torch.compile(model)
|
||||
logger.info(f"Using FP16 precision with model: {model_name}")
|
||||
except Exception as e:
|
||||
print(f"WARNING: Model optimization failed: {e}")
|
||||
else:
|
||||
raise ValueError(f"Unsupported embedding mode: {embedding_mode}. Supported modes: sentence-transformers, mlx, openai")
|
||||
|
||||
# Load passages from file if provided
|
||||
if passages_file and os.path.exists(passages_file):
|
||||
# Check if it's a metadata file or a single passages file
|
||||
if passages_file.endswith('.meta.json'):
|
||||
passages = load_passages_from_metadata(passages_file)
|
||||
else:
|
||||
# Try to find metadata file in same directory
|
||||
passages_dir = Path(passages_file).parent
|
||||
meta_files = list(passages_dir.glob("*.meta.json"))
|
||||
if meta_files:
|
||||
print(f"Found metadata file: {meta_files[0]}, using lazy loading")
|
||||
passages = load_passages_from_metadata(str(meta_files[0]))
|
||||
else:
|
||||
# Fallback to original single file loading (will cause warnings)
|
||||
print("WARNING: No metadata file found, using single file loading (may cause missing passage warnings)")
|
||||
passages = load_passages_from_file(passages_file)
|
||||
else:
|
||||
print("WARNING: No passages file provided or file not found. Using an empty passage loader.")
|
||||
passages = SimplePassageLoader()
|
||||
|
||||
logger.info(f"Loaded {len(passages)} passages.")
|
||||
|
||||
def client_warmup(zmq_port):
|
||||
"""Perform client-side warmup for DiskANN server"""
|
||||
time.sleep(2)
|
||||
print(f"Performing client-side warmup with model {model_name}...")
|
||||
|
||||
# Get actual passage IDs from the loaded passages
|
||||
sample_ids = []
|
||||
if hasattr(passages, 'keys') and len(passages) > 0:
|
||||
available_ids = list(passages.keys())
|
||||
# Take up to 5 actual IDs, but at least 1
|
||||
sample_ids = available_ids[:min(5, len(available_ids))]
|
||||
print(f"Using actual passage IDs for warmup: {sample_ids}")
|
||||
else:
|
||||
print("No passages available for warmup, skipping warmup...")
|
||||
return
|
||||
|
||||
try:
|
||||
context = zmq.Context()
|
||||
socket = context.socket(zmq.REQ)
|
||||
socket.connect(f"tcp://localhost:{zmq_port}")
|
||||
socket.setsockopt(zmq.RCVTIMEO, 30000)
|
||||
socket.setsockopt(zmq.SNDTIMEO, 30000)
|
||||
|
||||
try:
|
||||
ids_to_send = [int(x) for x in sample_ids]
|
||||
except ValueError:
|
||||
print("Warning: Could not convert sample IDs to integers, skipping warmup")
|
||||
return
|
||||
|
||||
if not ids_to_send:
|
||||
print("Skipping warmup send.")
|
||||
return
|
||||
|
||||
# Use protobuf format for warmup
|
||||
from . import embedding_pb2
|
||||
req_proto = embedding_pb2.NodeEmbeddingRequest()
|
||||
req_proto.node_ids.extend(ids_to_send)
|
||||
request_bytes = req_proto.SerializeToString()
|
||||
|
||||
for i in range(3):
|
||||
print(f"Sending warmup request {i + 1}/3 via ZMQ (Protobuf)...")
|
||||
socket.send(request_bytes)
|
||||
response_bytes = socket.recv()
|
||||
|
||||
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
||||
resp_proto.ParseFromString(response_bytes)
|
||||
embeddings_count = resp_proto.dimensions[0] if resp_proto.dimensions else 0
|
||||
print(f"Warmup request {i + 1}/3 successful, received {embeddings_count} embeddings")
|
||||
time.sleep(0.1)
|
||||
|
||||
print("Client-side Protobuf ZMQ warmup complete")
|
||||
socket.close()
|
||||
context.term()
|
||||
except Exception as e:
|
||||
print(f"Error during Protobuf ZMQ warmup: {e}")
|
||||
|
||||
class DeviceTimer:
|
||||
"""Device timer"""
|
||||
def __init__(self, name="", device=device):
|
||||
self.name = name
|
||||
self.device = device
|
||||
self.start_time = 0
|
||||
self.end_time = 0
|
||||
|
||||
if embedding_mode == "sentence-transformers" and torch.cuda.is_available():
|
||||
self.start_event = torch.cuda.Event(enable_timing=True)
|
||||
self.end_event = torch.cuda.Event(enable_timing=True)
|
||||
else:
|
||||
self.start_event = None
|
||||
self.end_event = None
|
||||
|
||||
@contextmanager
|
||||
def timing(self):
|
||||
self.start()
|
||||
yield
|
||||
self.end()
|
||||
|
||||
def start(self):
|
||||
if embedding_mode == "sentence-transformers" and torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
self.start_event.record()
|
||||
else:
|
||||
if embedding_mode == "sentence-transformers" and self.device.type == "mps":
|
||||
torch.mps.synchronize()
|
||||
self.start_time = time.time()
|
||||
|
||||
def end(self):
|
||||
if embedding_mode == "sentence-transformers" and torch.cuda.is_available():
|
||||
self.end_event.record()
|
||||
torch.cuda.synchronize()
|
||||
else:
|
||||
if embedding_mode == "sentence-transformers" and self.device.type == "mps":
|
||||
torch.mps.synchronize()
|
||||
self.end_time = time.time()
|
||||
|
||||
def elapsed_time(self):
|
||||
if embedding_mode == "sentence-transformers" and torch.cuda.is_available():
|
||||
return self.start_event.elapsed_time(self.end_event) / 1000.0
|
||||
else:
|
||||
return self.end_time - self.start_time
|
||||
|
||||
def print_elapsed(self):
|
||||
elapsed = self.elapsed_time()
|
||||
print(f"[{self.name}] Elapsed time: {elapsed:.3f}s")
|
||||
|
||||
def process_batch_pytorch(texts_batch, ids_batch, missing_ids):
|
||||
"""Process text batch"""
|
||||
if not texts_batch:
|
||||
return np.array([])
|
||||
|
||||
# Filter out empty texts and their corresponding IDs
|
||||
valid_texts = []
|
||||
valid_ids = []
|
||||
for i, text in enumerate(texts_batch):
|
||||
if text.strip(): # Only include non-empty texts
|
||||
valid_texts.append(text)
|
||||
valid_ids.append(ids_batch[i])
|
||||
|
||||
if not valid_texts:
|
||||
print("WARNING: No valid texts in batch")
|
||||
return np.array([])
|
||||
|
||||
# Tokenize
|
||||
token_timer = DeviceTimer("tokenization")
|
||||
with token_timer.timing():
|
||||
inputs = tokenizer(
|
||||
valid_texts,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=512,
|
||||
return_tensors="pt"
|
||||
).to(device)
|
||||
|
||||
# Compute embeddings
|
||||
embed_timer = DeviceTimer("embedding computation")
|
||||
with embed_timer.timing():
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
hidden_states = outputs.last_hidden_state
|
||||
|
||||
# Mean pooling
|
||||
attention_mask = inputs['attention_mask']
|
||||
mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
|
||||
sum_embeddings = torch.sum(hidden_states * mask_expanded, 1)
|
||||
sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
|
||||
batch_embeddings = sum_embeddings / sum_mask
|
||||
embed_timer.print_elapsed()
|
||||
|
||||
return batch_embeddings.cpu().numpy()
|
||||
|
||||
# ZMQ server main loop - modified to use REP socket
|
||||
context = zmq.Context()
|
||||
socket = context.socket(zmq.ROUTER) # Changed to REP socket
|
||||
socket.bind(f"tcp://127.0.0.1:{zmq_port}")
|
||||
print(f"INFO: ZMQ ROUTER server listening on port {zmq_port}")
|
||||
|
||||
# Set timeouts
|
||||
socket.setsockopt(zmq.RCVTIMEO, 5000) # 5 second receive timeout
|
||||
socket.setsockopt(zmq.SNDTIMEO, 300000) # 300 second send timeout
|
||||
|
||||
from . import embedding_pb2
|
||||
|
||||
print(f"INFO: Embedding server ready to serve requests")
|
||||
|
||||
# Start warmup thread if enabled
|
||||
if enable_warmup and len(passages) > 0:
|
||||
import threading
|
||||
print(f"Warmup enabled: starting warmup thread")
|
||||
warmup_thread = threading.Thread(target=client_warmup, args=(zmq_port,))
|
||||
warmup_thread.daemon = True
|
||||
warmup_thread.start()
|
||||
else:
|
||||
print(f"Warmup disabled or no passages available (enable_warmup={enable_warmup}, passages={len(passages)})")
|
||||
|
||||
while True:
|
||||
try:
|
||||
parts = socket.recv_multipart()
|
||||
|
||||
# --- Restore robust message format detection ---
|
||||
# Must check parts length to avoid IndexError
|
||||
if len(parts) >= 3:
|
||||
identity = parts[0]
|
||||
# empty = parts[1] # We usually don't care about the middle empty frame
|
||||
message = parts[2]
|
||||
elif len(parts) == 2:
|
||||
# Can also handle cases without empty frame
|
||||
identity = parts[0]
|
||||
message = parts[1]
|
||||
else:
|
||||
# If received message format is wrong, print warning and ignore it instead of crashing
|
||||
print(f"WARNING: Received unexpected message format with {len(parts)} parts. Ignoring.")
|
||||
continue
|
||||
print(f"INFO: Received ZMQ request from client {identity.hex()[:8]}, size {len(message)} bytes")
|
||||
|
||||
# Handle control messages (MessagePack format)
|
||||
try:
|
||||
request_payload = msgpack.unpackb(message)
|
||||
if isinstance(request_payload, list) and len(request_payload) >= 1:
|
||||
if request_payload[0] == "__QUERY_META_PATH__":
|
||||
# Return the current meta path being used by the server
|
||||
current_meta_path = getattr(passages, '_meta_path', '') if hasattr(passages, '_meta_path') else ''
|
||||
response = [current_meta_path]
|
||||
socket.send_multipart([identity, b'', msgpack.packb(response)])
|
||||
continue
|
||||
|
||||
elif request_payload[0] == "__UPDATE_META_PATH__" and len(request_payload) >= 2:
|
||||
# Update the server's meta path and reload passages
|
||||
new_meta_path = request_payload[1]
|
||||
try:
|
||||
print(f"INFO: Updating server meta path to: {new_meta_path}")
|
||||
# Reload passages from the new meta file
|
||||
passages = load_passages_from_metadata(new_meta_path)
|
||||
# Store the meta path for future queries
|
||||
passages._meta_path = new_meta_path
|
||||
response = ["SUCCESS"]
|
||||
print(f"INFO: Successfully updated meta path and reloaded {len(passages)} passages")
|
||||
except Exception as e:
|
||||
print(f"ERROR: Failed to update meta path: {e}")
|
||||
response = ["FAILED", str(e)]
|
||||
socket.send_multipart([identity, b'', msgpack.packb(response)])
|
||||
continue
|
||||
|
||||
elif request_payload[0] == "__QUERY_MODEL__":
|
||||
# Return the current model being used by the server
|
||||
response = [model_name]
|
||||
socket.send_multipart([identity, b'', msgpack.packb(response)])
|
||||
continue
|
||||
|
||||
elif request_payload[0] == "__UPDATE_MODEL__" and len(request_payload) >= 2:
|
||||
# Update the server's embedding model
|
||||
new_model_name = request_payload[1]
|
||||
try:
|
||||
print(f"INFO: Updating server model from {model_name} to: {new_model_name}")
|
||||
|
||||
# Clean up old model to free memory
|
||||
if not use_mlx:
|
||||
print("INFO: Releasing old model from memory...")
|
||||
old_model = model
|
||||
old_tokenizer = tokenizer
|
||||
|
||||
# Load new tokenizer first
|
||||
print(f"Loading new tokenizer for {new_model_name}...")
|
||||
tokenizer = AutoTokenizer.from_pretrained(new_model_name, use_fast=True)
|
||||
|
||||
# Load new model
|
||||
print(f"Loading new model {new_model_name}...")
|
||||
model = AutoModel.from_pretrained(new_model_name).to(device).eval()
|
||||
|
||||
# Optimize new model
|
||||
if cuda_available or mps_available:
|
||||
try:
|
||||
model = model.half()
|
||||
model = torch.compile(model)
|
||||
print(f"INFO: Using FP16 precision with model: {new_model_name}")
|
||||
except Exception as e:
|
||||
print(f"WARNING: Model optimization failed: {e}")
|
||||
|
||||
# Now safely delete old model after new one is loaded
|
||||
del old_model
|
||||
del old_tokenizer
|
||||
|
||||
# Clear GPU cache if available
|
||||
if device.type == "cuda":
|
||||
torch.cuda.empty_cache()
|
||||
print("INFO: Cleared CUDA cache")
|
||||
elif device.type == "mps":
|
||||
torch.mps.empty_cache()
|
||||
print("INFO: Cleared MPS cache")
|
||||
|
||||
# Force garbage collection
|
||||
import gc
|
||||
gc.collect()
|
||||
print("INFO: Memory cleanup completed")
|
||||
|
||||
# Update model name
|
||||
model_name = new_model_name
|
||||
|
||||
response = ["SUCCESS"]
|
||||
print(f"INFO: Successfully updated model to: {new_model_name}")
|
||||
except Exception as e:
|
||||
print(f"ERROR: Failed to update model: {e}")
|
||||
response = ["FAILED", str(e)]
|
||||
socket.send_multipart([identity, b'', msgpack.packb(response)])
|
||||
continue
|
||||
except:
|
||||
# Not a control message, continue with normal protobuf processing
|
||||
pass
|
||||
|
||||
e2e_start = time.time()
|
||||
lookup_timer = DeviceTimer("text lookup")
|
||||
|
||||
# Parse request
|
||||
req_proto = embedding_pb2.NodeEmbeddingRequest()
|
||||
req_proto.ParseFromString(message)
|
||||
node_ids = req_proto.node_ids
|
||||
print(f"INFO: Request for {len(node_ids)} node embeddings: {list(node_ids)}")
|
||||
|
||||
# Add debug information
|
||||
if len(node_ids) > 0:
|
||||
print(f"DEBUG: Node ID range: {min(node_ids)} to {max(node_ids)}")
|
||||
|
||||
# Look up texts
|
||||
texts = []
|
||||
missing_ids = []
|
||||
with lookup_timer.timing():
|
||||
for nid in node_ids:
|
||||
txtinfo = passages[nid]
|
||||
txt = txtinfo["text"]
|
||||
if txt:
|
||||
texts.append(txt)
|
||||
else:
|
||||
# If text is empty, we still need a placeholder for batch processing,
|
||||
# but record its ID as missing
|
||||
texts.append("")
|
||||
missing_ids.append(nid)
|
||||
lookup_timer.print_elapsed()
|
||||
|
||||
if missing_ids:
|
||||
print(f"WARNING: Missing passages for IDs: {missing_ids}")
|
||||
|
||||
# Process batch
|
||||
total_size = len(texts)
|
||||
print(f"INFO: Total batch size: {total_size}, max_batch_size: {max_batch_size}")
|
||||
|
||||
all_embeddings = []
|
||||
|
||||
if total_size > max_batch_size:
|
||||
print(f"INFO: Splitting batch of size {total_size} into chunks of {max_batch_size}")
|
||||
for i in range(0, total_size, max_batch_size):
|
||||
end_idx = min(i + max_batch_size, total_size)
|
||||
print(f"INFO: Processing chunk {i//max_batch_size + 1}/{(total_size + max_batch_size - 1)//max_batch_size}: items {i} to {end_idx-1}")
|
||||
|
||||
chunk_texts = texts[i:end_idx]
|
||||
chunk_ids = node_ids[i:end_idx]
|
||||
|
||||
if embedding_mode == "mlx":
|
||||
embeddings_chunk = compute_embeddings_mlx(chunk_texts, model_name, batch_size=16)
|
||||
elif embedding_mode == "openai":
|
||||
embeddings_chunk = compute_embeddings_openai(chunk_texts, model_name)
|
||||
else: # sentence-transformers
|
||||
embeddings_chunk = process_batch_pytorch(chunk_texts, chunk_ids, missing_ids)
|
||||
all_embeddings.append(embeddings_chunk)
|
||||
|
||||
if embedding_mode == "sentence-transformers":
|
||||
if cuda_available:
|
||||
torch.cuda.empty_cache()
|
||||
elif device.type == "mps":
|
||||
torch.mps.empty_cache()
|
||||
|
||||
hidden = np.vstack(all_embeddings)
|
||||
print(f"INFO: Combined embeddings shape: {hidden.shape}")
|
||||
else:
|
||||
if embedding_mode == "mlx":
|
||||
hidden = compute_embeddings_mlx(texts, model_name, batch_size=16)
|
||||
elif embedding_mode == "openai":
|
||||
hidden = compute_embeddings_openai(texts, model_name)
|
||||
else: # sentence-transformers
|
||||
hidden = process_batch_pytorch(texts, node_ids, missing_ids)
|
||||
|
||||
# Serialize response
|
||||
ser_start = time.time()
|
||||
|
||||
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
||||
hidden_contiguous = np.ascontiguousarray(hidden, dtype=np.float32)
|
||||
resp_proto.embeddings_data = hidden_contiguous.tobytes()
|
||||
resp_proto.dimensions.append(hidden_contiguous.shape[0])
|
||||
resp_proto.dimensions.append(hidden_contiguous.shape[1])
|
||||
resp_proto.missing_ids.extend(missing_ids)
|
||||
|
||||
response_data = resp_proto.SerializeToString()
|
||||
|
||||
# REP socket sends a single response
|
||||
socket.send_multipart([identity, b'', response_data])
|
||||
|
||||
ser_end = time.time()
|
||||
|
||||
print(f"INFO: Serialize time: {ser_end - ser_start:.6f} seconds")
|
||||
|
||||
if embedding_mode == "sentence-transformers":
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
elif device.type == "mps":
|
||||
torch.mps.synchronize()
|
||||
e2e_end = time.time()
|
||||
print(f"INFO: ZMQ E2E time: {e2e_end - e2e_start:.6f} seconds")
|
||||
|
||||
except zmq.Again:
|
||||
print("INFO: ZMQ socket timeout, continuing to listen")
|
||||
continue
|
||||
except Exception as e:
|
||||
print(f"ERROR: Error in ZMQ server: {e}")
|
||||
try:
|
||||
# Send empty response to maintain REQ-REP state
|
||||
empty_resp = embedding_pb2.NodeEmbeddingResponse()
|
||||
socket.send(empty_resp.SerializeToString())
|
||||
except:
|
||||
# If sending fails, recreate socket
|
||||
socket.close()
|
||||
socket = context.socket(zmq.REP)
|
||||
socket.bind(f"tcp://127.0.0.1:{zmq_port}")
|
||||
socket.setsockopt(zmq.RCVTIMEO, 5000)
|
||||
socket.setsockopt(zmq.SNDTIMEO, 300000)
|
||||
print("INFO: ZMQ socket recreated after error")
|
||||
|
||||
except Exception as e:
|
||||
print(f"ERROR: Failed to start embedding server: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def create_embedding_server(
|
||||
domain="demo",
|
||||
load_passages=True,
|
||||
load_embeddings=False,
|
||||
use_fp16=True,
|
||||
use_int8=False,
|
||||
use_cuda_graphs=False,
|
||||
zmq_port=5555,
|
||||
max_batch_size=128,
|
||||
lazy_load_passages=False,
|
||||
model_name="sentence-transformers/all-mpnet-base-v2",
|
||||
passages_file: Optional[str] = None,
|
||||
embedding_mode: str = "sentence-transformers",
|
||||
enable_warmup: bool = False,
|
||||
):
|
||||
"""
|
||||
原有的 create_embedding_server 函数保持不变
|
||||
这个是阻塞版本,用于直接运行
|
||||
"""
|
||||
create_embedding_server_thread(zmq_port, model_name, max_batch_size, passages_file, embedding_mode, enable_warmup)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Embedding service")
|
||||
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
|
||||
parser.add_argument("--domain", type=str, default="demo", help="Domain name")
|
||||
parser.add_argument("--passages-file", type=str, help="JSON file containing passage ID to text mapping")
|
||||
parser.add_argument("--load-passages", action="store_true", default=True)
|
||||
parser.add_argument("--load-embeddings", action="store_true", default=False)
|
||||
parser.add_argument("--use-fp16", action="store_true", default=False)
|
||||
parser.add_argument("--use-int8", action="store_true", default=False)
|
||||
parser.add_argument("--use-cuda-graphs", action="store_true", default=False)
|
||||
parser.add_argument("--max-batch-size", type=int, default=128, help="Maximum batch size before splitting")
|
||||
parser.add_argument("--lazy-load-passages", action="store_true", default=True)
|
||||
parser.add_argument("--model-name", type=str, default="sentence-transformers/all-mpnet-base-v2",
|
||||
help="Embedding model name")
|
||||
parser.add_argument("--embedding-mode", type=str, default="sentence-transformers",
|
||||
choices=["sentence-transformers", "mlx", "openai"],
|
||||
help="Embedding backend mode")
|
||||
parser.add_argument("--use-mlx", action="store_true", default=False, help="Use MLX backend for embeddings (deprecated: use --embedding-mode mlx)")
|
||||
parser.add_argument("--disable-warmup", action="store_true", default=False, help="Disable warmup requests on server start")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Handle backward compatibility with use_mlx
|
||||
embedding_mode = args.embedding_mode
|
||||
if args.use_mlx:
|
||||
embedding_mode = "mlx"
|
||||
|
||||
create_embedding_server(
|
||||
domain=args.domain,
|
||||
load_passages=args.load_passages,
|
||||
load_embeddings=args.load_embeddings,
|
||||
use_fp16=args.use_fp16,
|
||||
use_int8=args.use_int8,
|
||||
use_cuda_graphs=args.use_cuda_graphs,
|
||||
zmq_port=args.zmq_port,
|
||||
max_batch_size=args.max_batch_size,
|
||||
lazy_load_passages=args.lazy_load_passages,
|
||||
model_name=args.model_name,
|
||||
passages_file=args.passages_file,
|
||||
embedding_mode=embedding_mode,
|
||||
enable_warmup=not args.disable_warmup,
|
||||
)
|
||||
@@ -0,0 +1,299 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Graph Partition Module for LEANN DiskANN Backend
|
||||
|
||||
This module provides Python bindings for the graph partition functionality
|
||||
of DiskANN, allowing users to partition disk-based indices for better
|
||||
performance.
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class GraphPartitioner:
|
||||
"""
|
||||
A Python interface for DiskANN's graph partition functionality.
|
||||
|
||||
This class provides methods to partition disk-based indices for improved
|
||||
search performance and memory efficiency.
|
||||
"""
|
||||
|
||||
def __init__(self, build_type: str = "release"):
|
||||
"""
|
||||
Initialize the GraphPartitioner.
|
||||
|
||||
Args:
|
||||
build_type: Build type for the executables ("debug" or "release")
|
||||
"""
|
||||
self.build_type = build_type
|
||||
self._ensure_executables()
|
||||
|
||||
def _get_executable_path(self, name: str) -> str:
|
||||
"""Get the path to a graph partition executable."""
|
||||
# Get the directory where this Python module is located
|
||||
module_dir = Path(__file__).parent
|
||||
# Navigate to the graph_partition directory
|
||||
graph_partition_dir = module_dir.parent / "third_party" / "DiskANN" / "graph_partition"
|
||||
executable_path = graph_partition_dir / "build" / self.build_type / "graph_partition" / name
|
||||
|
||||
if not executable_path.exists():
|
||||
raise FileNotFoundError(f"Executable {name} not found at {executable_path}")
|
||||
|
||||
return str(executable_path)
|
||||
|
||||
def _ensure_executables(self):
|
||||
"""Ensure that the required executables are built."""
|
||||
try:
|
||||
self._get_executable_path("partitioner")
|
||||
self._get_executable_path("index_relayout")
|
||||
except FileNotFoundError:
|
||||
# Try to build the executables automatically
|
||||
print("Executables not found, attempting to build them...")
|
||||
self._build_executables()
|
||||
|
||||
def _build_executables(self):
|
||||
"""Build the required executables."""
|
||||
graph_partition_dir = (
|
||||
Path(__file__).parent.parent / "third_party" / "DiskANN" / "graph_partition"
|
||||
)
|
||||
original_dir = os.getcwd()
|
||||
|
||||
try:
|
||||
os.chdir(graph_partition_dir)
|
||||
|
||||
# Clean any existing build
|
||||
if (graph_partition_dir / "build").exists():
|
||||
shutil.rmtree(graph_partition_dir / "build")
|
||||
|
||||
# Run the build script
|
||||
cmd = ["./build.sh", self.build_type, "split_graph", "/tmp/dummy"]
|
||||
subprocess.run(cmd, capture_output=True, text=True, cwd=graph_partition_dir)
|
||||
|
||||
# Check if executables were created
|
||||
partitioner_path = self._get_executable_path("partitioner")
|
||||
relayout_path = self._get_executable_path("index_relayout")
|
||||
|
||||
print(f"✅ Built partitioner: {partitioner_path}")
|
||||
print(f"✅ Built index_relayout: {relayout_path}")
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to build executables: {e}")
|
||||
finally:
|
||||
os.chdir(original_dir)
|
||||
|
||||
def partition_graph(
|
||||
self,
|
||||
index_prefix_path: str,
|
||||
output_dir: Optional[str] = None,
|
||||
partition_prefix: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Partition a disk-based index for improved performance.
|
||||
|
||||
Args:
|
||||
index_prefix_path: Path to the index prefix (e.g., "/path/to/index")
|
||||
output_dir: Output directory for results (defaults to parent of index_prefix_path)
|
||||
partition_prefix: Prefix for output files (defaults to basename of index_prefix_path)
|
||||
**kwargs: Additional parameters for graph partitioning:
|
||||
- gp_times: Number of LDG partition iterations (default: 10)
|
||||
- lock_nums: Number of lock nodes (default: 10)
|
||||
- cut: Cut adjacency list degree (default: 100)
|
||||
- scale_factor: Scale factor (default: 1)
|
||||
- data_type: Data type (default: "float")
|
||||
- thread_nums: Number of threads (default: 10)
|
||||
|
||||
Returns:
|
||||
Tuple of (disk_graph_index_path, partition_bin_path)
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the partitioning process fails
|
||||
"""
|
||||
# Set default parameters
|
||||
params = {
|
||||
"gp_times": 10,
|
||||
"lock_nums": 10,
|
||||
"cut": 100,
|
||||
"scale_factor": 1,
|
||||
"data_type": "float",
|
||||
"thread_nums": 10,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
# Determine output directory
|
||||
if output_dir is None:
|
||||
output_dir = str(Path(index_prefix_path).parent)
|
||||
|
||||
# Create output directory if it doesn't exist
|
||||
Path(output_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Determine partition prefix
|
||||
if partition_prefix is None:
|
||||
partition_prefix = Path(index_prefix_path).name
|
||||
|
||||
# Get executable paths
|
||||
partitioner_path = self._get_executable_path("partitioner")
|
||||
relayout_path = self._get_executable_path("index_relayout")
|
||||
|
||||
# Create temporary directory for processing
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Change to the graph_partition directory for temporary files
|
||||
graph_partition_dir = (
|
||||
Path(__file__).parent.parent / "third_party" / "DiskANN" / "graph_partition"
|
||||
)
|
||||
original_dir = os.getcwd()
|
||||
|
||||
try:
|
||||
os.chdir(graph_partition_dir)
|
||||
|
||||
# Create temporary data directory
|
||||
temp_data_dir = Path(temp_dir) / "data"
|
||||
temp_data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Set up paths for temporary files
|
||||
graph_path = temp_data_dir / "starling" / "_M_R_L_B" / "GRAPH"
|
||||
graph_gp_path = (
|
||||
graph_path
|
||||
/ f"GP_TIMES_{params['gp_times']}_LOCK_{params['lock_nums']}_GP_USE_FREQ0_CUT{params['cut']}_SCALE{params['scale_factor']}"
|
||||
)
|
||||
graph_gp_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Find input index file
|
||||
old_index_file = f"{index_prefix_path}_disk_beam_search.index"
|
||||
if not os.path.exists(old_index_file):
|
||||
old_index_file = f"{index_prefix_path}_disk.index"
|
||||
|
||||
if not os.path.exists(old_index_file):
|
||||
raise RuntimeError(f"Index file not found: {old_index_file}")
|
||||
|
||||
# Run partitioner
|
||||
gp_file_path = graph_gp_path / "_part.bin"
|
||||
partitioner_cmd = [
|
||||
partitioner_path,
|
||||
"--index_file",
|
||||
old_index_file,
|
||||
"--data_type",
|
||||
params["data_type"],
|
||||
"--gp_file",
|
||||
str(gp_file_path),
|
||||
"-T",
|
||||
str(params["thread_nums"]),
|
||||
"--ldg_times",
|
||||
str(params["gp_times"]),
|
||||
"--scale",
|
||||
str(params["scale_factor"]),
|
||||
"--mode",
|
||||
"1",
|
||||
]
|
||||
|
||||
print(f"Running partitioner: {' '.join(partitioner_cmd)}")
|
||||
result = subprocess.run(
|
||||
partitioner_cmd, capture_output=True, text=True, cwd=graph_partition_dir
|
||||
)
|
||||
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(
|
||||
f"Partitioner failed with return code {result.returncode}.\n"
|
||||
f"stdout: {result.stdout}\n"
|
||||
f"stderr: {result.stderr}"
|
||||
)
|
||||
|
||||
# Run relayout
|
||||
part_tmp_index = graph_gp_path / "_part_tmp.index"
|
||||
relayout_cmd = [
|
||||
relayout_path,
|
||||
old_index_file,
|
||||
str(gp_file_path),
|
||||
params["data_type"],
|
||||
"1",
|
||||
]
|
||||
|
||||
print(f"Running relayout: {' '.join(relayout_cmd)}")
|
||||
result = subprocess.run(
|
||||
relayout_cmd, capture_output=True, text=True, cwd=graph_partition_dir
|
||||
)
|
||||
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(
|
||||
f"Relayout failed with return code {result.returncode}.\n"
|
||||
f"stdout: {result.stdout}\n"
|
||||
f"stderr: {result.stderr}"
|
||||
)
|
||||
|
||||
# Copy results to output directory
|
||||
disk_graph_path = Path(output_dir) / f"{partition_prefix}_disk_graph.index"
|
||||
partition_bin_path = Path(output_dir) / f"{partition_prefix}_partition.bin"
|
||||
|
||||
shutil.copy2(part_tmp_index, disk_graph_path)
|
||||
shutil.copy2(gp_file_path, partition_bin_path)
|
||||
|
||||
print(f"Results copied to: {output_dir}")
|
||||
return str(disk_graph_path), str(partition_bin_path)
|
||||
|
||||
finally:
|
||||
os.chdir(original_dir)
|
||||
|
||||
def get_partition_info(self, partition_bin_path: str) -> dict:
|
||||
"""
|
||||
Get information about a partition file.
|
||||
|
||||
Args:
|
||||
partition_bin_path: Path to the partition binary file
|
||||
|
||||
Returns:
|
||||
Dictionary containing partition information
|
||||
"""
|
||||
if not os.path.exists(partition_bin_path):
|
||||
raise FileNotFoundError(f"Partition file not found: {partition_bin_path}")
|
||||
|
||||
# For now, return basic file information
|
||||
# In the future, this could parse the binary file for detailed info
|
||||
stat = os.stat(partition_bin_path)
|
||||
return {
|
||||
"file_size": stat.st_size,
|
||||
"file_path": partition_bin_path,
|
||||
"modified_time": stat.st_mtime,
|
||||
}
|
||||
|
||||
|
||||
def partition_graph(
|
||||
index_prefix_path: str,
|
||||
output_dir: Optional[str] = None,
|
||||
partition_prefix: Optional[str] = None,
|
||||
build_type: str = "release",
|
||||
**kwargs,
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Convenience function to partition a graph index.
|
||||
|
||||
Args:
|
||||
index_prefix_path: Path to the index prefix
|
||||
output_dir: Output directory (defaults to parent of index_prefix_path)
|
||||
partition_prefix: Prefix for output files (defaults to basename of index_prefix_path)
|
||||
build_type: Build type for executables ("debug" or "release")
|
||||
**kwargs: Additional parameters for graph partitioning
|
||||
|
||||
Returns:
|
||||
Tuple of (disk_graph_index_path, partition_bin_path)
|
||||
"""
|
||||
partitioner = GraphPartitioner(build_type=build_type)
|
||||
return partitioner.partition_graph(index_prefix_path, output_dir, partition_prefix, **kwargs)
|
||||
|
||||
|
||||
# Example usage:
|
||||
if __name__ == "__main__":
|
||||
# Example: partition an index
|
||||
try:
|
||||
disk_graph_path, partition_bin_path = partition_graph(
|
||||
"/path/to/your/index_prefix", gp_times=10, lock_nums=10, cut=100
|
||||
)
|
||||
print("Partitioning completed successfully!")
|
||||
print(f"Disk graph index: {disk_graph_path}")
|
||||
print(f"Partition binary: {partition_bin_path}")
|
||||
except Exception as e:
|
||||
print(f"Partitioning failed: {e}")
|
||||
@@ -4,16 +4,18 @@ build-backend = "scikit_build_core.build"
|
||||
|
||||
[project]
|
||||
name = "leann-backend-diskann"
|
||||
version = "0.1.0"
|
||||
dependencies = ["leann-core==0.1.0", "numpy"]
|
||||
version = "0.3.4"
|
||||
dependencies = ["leann-core==0.3.4", "numpy", "protobuf>=3.19.0"]
|
||||
|
||||
[tool.scikit-build]
|
||||
# 关键:简化的 CMake 路径
|
||||
# Key: simplified CMake path
|
||||
cmake.source-dir = "third_party/DiskANN"
|
||||
# 关键:Python 包在根目录,路径完全匹配
|
||||
# Key: Python package in root directory, paths match exactly
|
||||
wheel.packages = ["leann_backend_diskann"]
|
||||
# 使用默认的 redirect 模式
|
||||
# Use default redirect mode
|
||||
editable.mode = "redirect"
|
||||
cmake.build-type = "Release"
|
||||
build.verbose = true
|
||||
build.tool-args = ["-j8"]
|
||||
build.tool-args = ["-j8"]
|
||||
# Let CMake find packages via Homebrew prefix
|
||||
cmake.define = {CMAKE_PREFIX_PATH = {env = "CMAKE_PREFIX_PATH"}, OpenMP_ROOT = {env = "OpenMP_ROOT"}}
|
||||
|
||||
Submodule packages/leann-backend-diskann/third_party/DiskANN updated: af2a26481e...19f9603c72
@@ -2,12 +2,12 @@ syntax = "proto3";
|
||||
|
||||
package protoembedding;
|
||||
|
||||
message NodeEmbeddingRequest {
|
||||
repeated uint32 node_ids = 1;
|
||||
message NodeEmbeddingRequest {
|
||||
repeated uint32 node_ids = 1;
|
||||
}
|
||||
|
||||
message NodeEmbeddingResponse {
|
||||
bytes embeddings_data = 1; // All embedded binary datas
|
||||
repeated int32 dimensions = 2; // Shape [batch_size, embedding_dim]
|
||||
repeated uint32 missing_ids = 3; // Missing node ids
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,25 +1,37 @@
|
||||
# 最终简化版
|
||||
cmake_minimum_required(VERSION 3.24)
|
||||
project(leann_backend_hnsw_wrapper)
|
||||
set(CMAKE_C_COMPILER_WORKS 1)
|
||||
set(CMAKE_CXX_COMPILER_WORKS 1)
|
||||
|
||||
# Set OpenMP path for macOS
|
||||
if(APPLE)
|
||||
set(OpenMP_C_FLAGS "-Xpreprocessor -fopenmp -I/opt/homebrew/opt/libomp/include")
|
||||
set(OpenMP_CXX_FLAGS "-Xpreprocessor -fopenmp -I/opt/homebrew/opt/libomp/include")
|
||||
# Detect Homebrew installation path (Apple Silicon vs Intel)
|
||||
if(EXISTS "/opt/homebrew/opt/libomp")
|
||||
set(HOMEBREW_PREFIX "/opt/homebrew")
|
||||
elseif(EXISTS "/usr/local/opt/libomp")
|
||||
set(HOMEBREW_PREFIX "/usr/local")
|
||||
else()
|
||||
message(FATAL_ERROR "Could not find libomp installation. Please install with: brew install libomp")
|
||||
endif()
|
||||
|
||||
set(OpenMP_C_FLAGS "-Xpreprocessor -fopenmp -I${HOMEBREW_PREFIX}/opt/libomp/include")
|
||||
set(OpenMP_CXX_FLAGS "-Xpreprocessor -fopenmp -I${HOMEBREW_PREFIX}/opt/libomp/include")
|
||||
set(OpenMP_C_LIB_NAMES "omp")
|
||||
set(OpenMP_CXX_LIB_NAMES "omp")
|
||||
set(OpenMP_omp_LIBRARY "/opt/homebrew/opt/libomp/lib/libomp.dylib")
|
||||
set(OpenMP_omp_LIBRARY "${HOMEBREW_PREFIX}/opt/libomp/lib/libomp.dylib")
|
||||
|
||||
# Force use of system libc++ to avoid version mismatch
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -stdlib=libc++")
|
||||
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -stdlib=libc++")
|
||||
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -stdlib=libc++")
|
||||
|
||||
# Set minimum macOS version for better compatibility
|
||||
set(CMAKE_OSX_DEPLOYMENT_TARGET "11.0" CACHE STRING "Minimum macOS version")
|
||||
endif()
|
||||
|
||||
# Build ZeroMQ from source
|
||||
set(ZMQ_BUILD_TESTS OFF CACHE BOOL "" FORCE)
|
||||
set(ENABLE_DRAFTS OFF CACHE BOOL "" FORCE)
|
||||
set(ENABLE_PRECOMPILED OFF CACHE BOOL "" FORCE)
|
||||
set(WITH_PERF_TOOL OFF CACHE BOOL "" FORCE)
|
||||
set(WITH_DOCS OFF CACHE BOOL "" FORCE)
|
||||
set(BUILD_SHARED OFF CACHE BOOL "" FORCE)
|
||||
set(BUILD_STATIC ON CACHE BOOL "" FORCE)
|
||||
add_subdirectory(third_party/libzmq)
|
||||
# Use system ZeroMQ instead of building from source
|
||||
find_package(PkgConfig REQUIRED)
|
||||
pkg_check_modules(ZMQ REQUIRED libzmq)
|
||||
|
||||
# Add cppzmq headers
|
||||
include_directories(third_party/cppzmq)
|
||||
@@ -29,6 +41,7 @@ set(MSGPACK_USE_BOOST OFF CACHE BOOL "" FORCE)
|
||||
add_compile_definitions(MSGPACK_NO_BOOST)
|
||||
include_directories(third_party/msgpack-c/include)
|
||||
|
||||
# Faiss configuration - streamlined build
|
||||
set(FAISS_ENABLE_PYTHON ON CACHE BOOL "" FORCE)
|
||||
set(FAISS_ENABLE_GPU OFF CACHE BOOL "" FORCE)
|
||||
set(FAISS_ENABLE_EXTRAS OFF CACHE BOOL "" FORCE)
|
||||
@@ -36,4 +49,43 @@ set(BUILD_TESTING OFF CACHE BOOL "" FORCE)
|
||||
set(FAISS_ENABLE_C_API OFF CACHE BOOL "" FORCE)
|
||||
set(FAISS_OPT_LEVEL "generic" CACHE STRING "" FORCE)
|
||||
|
||||
add_subdirectory(third_party/faiss)
|
||||
# Disable x86-specific SIMD optimizations (important for ARM64 compatibility)
|
||||
set(FAISS_ENABLE_AVX2 OFF CACHE BOOL "" FORCE)
|
||||
set(FAISS_ENABLE_AVX512 OFF CACHE BOOL "" FORCE)
|
||||
set(FAISS_ENABLE_SSE4_1 OFF CACHE BOOL "" FORCE)
|
||||
|
||||
# ARM64-specific configuration
|
||||
if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64|arm64")
|
||||
message(STATUS "Configuring Faiss for ARM64 architecture")
|
||||
|
||||
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
|
||||
# Use SVE optimization level for ARM64 Linux (as seen in Faiss conda build)
|
||||
set(FAISS_OPT_LEVEL "sve" CACHE STRING "" FORCE)
|
||||
message(STATUS "Setting FAISS_OPT_LEVEL to 'sve' for ARM64 Linux")
|
||||
else()
|
||||
# Use generic optimization for other ARM64 platforms (like macOS)
|
||||
set(FAISS_OPT_LEVEL "generic" CACHE STRING "" FORCE)
|
||||
message(STATUS "Setting FAISS_OPT_LEVEL to 'generic' for ARM64 ${CMAKE_SYSTEM_NAME}")
|
||||
endif()
|
||||
|
||||
# ARM64 compatibility: Faiss submodule has been modified to fix x86 header inclusion
|
||||
message(STATUS "Using ARM64-compatible Faiss submodule")
|
||||
endif()
|
||||
|
||||
# Additional optimization options from INSTALL.md
|
||||
set(CMAKE_BUILD_TYPE "Release" CACHE STRING "" FORCE)
|
||||
set(BUILD_SHARED_LIBS OFF CACHE BOOL "" FORCE) # Static library is faster to build
|
||||
|
||||
# Avoid building demos and benchmarks
|
||||
set(BUILD_DEMOS OFF CACHE BOOL "" FORCE)
|
||||
set(BUILD_BENCHS OFF CACHE BOOL "" FORCE)
|
||||
|
||||
# NEW: Tell Faiss to only build the generic version
|
||||
set(FAISS_BUILD_GENERIC ON CACHE BOOL "" FORCE)
|
||||
set(FAISS_BUILD_AVX2 OFF CACHE BOOL "" FORCE)
|
||||
set(FAISS_BUILD_AVX512 OFF CACHE BOOL "" FORCE)
|
||||
|
||||
# IMPORTANT: Disable building AVX versions to speed up compilation
|
||||
set(FAISS_BUILD_AVX_VERSIONS OFF CACHE BOOL "" FORCE)
|
||||
|
||||
add_subdirectory(third_party/faiss)
|
||||
|
||||
@@ -1 +1 @@
|
||||
from . import hnsw_backend
|
||||
from . import hnsw_backend as hnsw_backend
|
||||
|
||||
@@ -1,87 +1,122 @@
|
||||
import argparse
|
||||
import gc # Import garbage collector interface
|
||||
import logging
|
||||
import os
|
||||
import struct
|
||||
import sys
|
||||
import numpy as np
|
||||
import os
|
||||
import argparse
|
||||
import gc # Import garbage collector interface
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
# Set up logging to avoid print buffer issues
|
||||
logger = logging.getLogger(__name__)
|
||||
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
|
||||
logger.setLevel(log_level)
|
||||
|
||||
# --- FourCCs (add more if needed) ---
|
||||
INDEX_HNSW_FLAT_FOURCC = int.from_bytes(b'IHNf', 'little')
|
||||
INDEX_HNSW_FLAT_FOURCC = int.from_bytes(b"IHNf", "little")
|
||||
# Add other HNSW fourccs if you expect different storage types inside HNSW
|
||||
# INDEX_HNSW_PQ_FOURCC = int.from_bytes(b'IHNp', 'little')
|
||||
# INDEX_HNSW_SQ_FOURCC = int.from_bytes(b'IHNs', 'little')
|
||||
# INDEX_HNSW_CAGRA_FOURCC = int.from_bytes(b'IHNc', 'little') # Example
|
||||
|
||||
EXPECTED_HNSW_FOURCCS = {INDEX_HNSW_FLAT_FOURCC} # Modify if needed
|
||||
NULL_INDEX_FOURCC = int.from_bytes(b'null', 'little')
|
||||
EXPECTED_HNSW_FOURCCS = {INDEX_HNSW_FLAT_FOURCC} # Modify if needed
|
||||
NULL_INDEX_FOURCC = int.from_bytes(b"null", "little")
|
||||
|
||||
# --- Helper functions for reading/writing binary data ---
|
||||
|
||||
|
||||
def read_struct(f, fmt):
|
||||
"""Reads data according to the struct format."""
|
||||
size = struct.calcsize(fmt)
|
||||
data = f.read(size)
|
||||
if len(data) != size:
|
||||
raise EOFError(f"File ended unexpectedly reading struct fmt '{fmt}'. Expected {size} bytes, got {len(data)}.")
|
||||
raise EOFError(
|
||||
f"File ended unexpectedly reading struct fmt '{fmt}'. Expected {size} bytes, got {len(data)}."
|
||||
)
|
||||
return struct.unpack(fmt, data)[0]
|
||||
|
||||
|
||||
def read_vector_raw(f, element_fmt_char):
|
||||
"""Reads a vector (size followed by data), returns count and raw bytes."""
|
||||
count = -1 # Initialize count
|
||||
total_bytes = -1 # Initialize total_bytes
|
||||
count = -1 # Initialize count
|
||||
total_bytes = -1 # Initialize total_bytes
|
||||
try:
|
||||
count = read_struct(f, '<Q') # size_t usually 64-bit unsigned
|
||||
count = read_struct(f, "<Q") # size_t usually 64-bit unsigned
|
||||
element_size = struct.calcsize(element_fmt_char)
|
||||
# --- FIX for MemoryError: Check for unreasonably large count ---
|
||||
max_reasonable_count = 10 * (10**9) # ~10 billion elements limit
|
||||
max_reasonable_count = 10 * (10**9) # ~10 billion elements limit
|
||||
if count > max_reasonable_count or count < 0:
|
||||
raise MemoryError(f"Vector count {count} seems unreasonably large, possibly due to file corruption or incorrect format read.")
|
||||
raise MemoryError(
|
||||
f"Vector count {count} seems unreasonably large, possibly due to file corruption or incorrect format read."
|
||||
)
|
||||
|
||||
total_bytes = count * element_size
|
||||
# --- FIX for MemoryError: Check for huge byte size before allocation ---
|
||||
max_reasonable_bytes = 50 * (1024**3) # ~50 GB limit
|
||||
if total_bytes > max_reasonable_bytes or total_bytes < 0: # Check for overflow
|
||||
raise MemoryError(f"Attempting to read {total_bytes} bytes ({count} elements * {element_size} bytes/element), which exceeds the safety limit. File might be corrupted or format mismatch.")
|
||||
max_reasonable_bytes = 50 * (1024**3) # ~50 GB limit
|
||||
if total_bytes > max_reasonable_bytes or total_bytes < 0: # Check for overflow
|
||||
raise MemoryError(
|
||||
f"Attempting to read {total_bytes} bytes ({count} elements * {element_size} bytes/element), which exceeds the safety limit. File might be corrupted or format mismatch."
|
||||
)
|
||||
|
||||
data_bytes = f.read(total_bytes)
|
||||
|
||||
if len(data_bytes) != total_bytes:
|
||||
raise EOFError(f"File ended unexpectedly reading vector data. Expected {total_bytes} bytes, got {len(data_bytes)}.")
|
||||
raise EOFError(
|
||||
f"File ended unexpectedly reading vector data. Expected {total_bytes} bytes, got {len(data_bytes)}."
|
||||
)
|
||||
return count, data_bytes
|
||||
except (MemoryError, OverflowError) as e:
|
||||
# Add context to the error message
|
||||
print(f"\nError during raw vector read (element_fmt='{element_fmt_char}', count={count}, total_bytes={total_bytes}): {e}", file=sys.stderr)
|
||||
raise e # Re-raise the original error type
|
||||
# Add context to the error message
|
||||
print(
|
||||
f"\nError during raw vector read (element_fmt='{element_fmt_char}', count={count}, total_bytes={total_bytes}): {e}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
raise e # Re-raise the original error type
|
||||
|
||||
|
||||
def read_numpy_vector(f, np_dtype, struct_fmt_char):
|
||||
"""Reads a vector into a NumPy array."""
|
||||
count = -1 # Initialize count for robust error handling
|
||||
print(f" Reading vector (dtype={np_dtype}, fmt='{struct_fmt_char}')... ", end='', flush=True)
|
||||
count = -1 # Initialize count for robust error handling
|
||||
print(
|
||||
f" Reading vector (dtype={np_dtype}, fmt='{struct_fmt_char}')... ",
|
||||
end="",
|
||||
flush=True,
|
||||
)
|
||||
try:
|
||||
count, data_bytes = read_vector_raw(f, struct_fmt_char)
|
||||
print(f"Count={count}, Bytes={len(data_bytes)}")
|
||||
if count > 0 and len(data_bytes) > 0:
|
||||
arr = np.frombuffer(data_bytes, dtype=np_dtype)
|
||||
if arr.size != count:
|
||||
raise ValueError(f"Inconsistent array size after reading. Expected {count}, got {arr.size}")
|
||||
raise ValueError(
|
||||
f"Inconsistent array size after reading. Expected {count}, got {arr.size}"
|
||||
)
|
||||
return arr
|
||||
elif count == 0:
|
||||
return np.array([], dtype=np_dtype)
|
||||
return np.array([], dtype=np_dtype)
|
||||
else:
|
||||
raise ValueError("Read zero bytes but count > 0.")
|
||||
raise ValueError("Read zero bytes but count > 0.")
|
||||
except MemoryError as e:
|
||||
# Now count should be defined (or -1 if error was in read_struct)
|
||||
print(f"\nMemoryError creating NumPy array (dtype={np_dtype}, count={count}). {e}", file=sys.stderr)
|
||||
print(
|
||||
f"\nMemoryError creating NumPy array (dtype={np_dtype}, count={count}). {e}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
raise e
|
||||
except Exception as e: # Catch other potential errors like ValueError
|
||||
print(f"\nError reading numpy vector (dtype={np_dtype}, fmt='{struct_fmt_char}', count={count}): {e}", file=sys.stderr)
|
||||
except Exception as e: # Catch other potential errors like ValueError
|
||||
print(
|
||||
f"\nError reading numpy vector (dtype={np_dtype}, fmt='{struct_fmt_char}', count={count}): {e}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
raise e
|
||||
|
||||
|
||||
def write_numpy_vector(f, arr, struct_fmt_char):
|
||||
"""Writes a NumPy array as a vector (size followed by data)."""
|
||||
count = arr.size
|
||||
f.write(struct.pack('<Q', count))
|
||||
f.write(struct.pack("<Q", count))
|
||||
try:
|
||||
expected_dtype = np.dtype(struct_fmt_char)
|
||||
if arr.dtype != expected_dtype:
|
||||
@@ -89,23 +124,30 @@ def write_numpy_vector(f, arr, struct_fmt_char):
|
||||
else:
|
||||
data_to_write = arr.tobytes()
|
||||
f.write(data_to_write)
|
||||
del data_to_write # Hint GC
|
||||
del data_to_write # Hint GC
|
||||
except MemoryError as e:
|
||||
print(f"\nMemoryError converting NumPy array to bytes for writing (size={count}, dtype={arr.dtype}). {e}", file=sys.stderr)
|
||||
raise e
|
||||
print(
|
||||
f"\nMemoryError converting NumPy array to bytes for writing (size={count}, dtype={arr.dtype}). {e}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
raise e
|
||||
|
||||
|
||||
def write_list_vector(f, lst, struct_fmt_char):
|
||||
"""Writes a Python list as a vector iteratively."""
|
||||
count = len(lst)
|
||||
f.write(struct.pack('<Q', count))
|
||||
fmt = '<' + struct_fmt_char
|
||||
f.write(struct.pack("<Q", count))
|
||||
fmt = "<" + struct_fmt_char
|
||||
chunk_size = 1024 * 1024
|
||||
element_size = struct.calcsize(fmt)
|
||||
# Allocate buffer outside the loop if possible, or handle MemoryError during allocation
|
||||
try:
|
||||
buffer = bytearray(chunk_size * element_size)
|
||||
except MemoryError:
|
||||
print(f"MemoryError: Cannot allocate buffer for writing list vector chunk (size {chunk_size * element_size} bytes).", file=sys.stderr)
|
||||
print(
|
||||
f"MemoryError: Cannot allocate buffer for writing list vector chunk (size {chunk_size * element_size} bytes).",
|
||||
file=sys.stderr,
|
||||
)
|
||||
raise
|
||||
buffer_count = 0
|
||||
|
||||
@@ -116,66 +158,80 @@ def write_list_vector(f, lst, struct_fmt_char):
|
||||
buffer_count += 1
|
||||
|
||||
if buffer_count == chunk_size or i == count - 1:
|
||||
f.write(buffer[:buffer_count * element_size])
|
||||
f.write(buffer[: buffer_count * element_size])
|
||||
buffer_count = 0
|
||||
|
||||
except struct.error as e:
|
||||
print(f"\nStruct packing error for item {item} at index {i} with format '{fmt}'. {e}", file=sys.stderr)
|
||||
print(
|
||||
f"\nStruct packing error for item {item} at index {i} with format '{fmt}'. {e}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
raise e
|
||||
|
||||
|
||||
def get_cum_neighbors(cum_nneighbor_per_level_np, level):
|
||||
"""Helper to get cumulative neighbors count, matching C++ logic."""
|
||||
if level < 0: return 0
|
||||
if level < 0:
|
||||
return 0
|
||||
if level < len(cum_nneighbor_per_level_np):
|
||||
return cum_nneighbor_per_level_np[level]
|
||||
else:
|
||||
return cum_nneighbor_per_level_np[-1] if len(cum_nneighbor_per_level_np) > 0 else 0
|
||||
|
||||
def write_compact_format(f_out, original_hnsw_data, assign_probas_np, cum_nneighbor_per_level_np,
|
||||
levels_np, compact_level_ptr, compact_node_offsets_np,
|
||||
compact_neighbors_data, storage_fourcc, storage_data):
|
||||
|
||||
def write_compact_format(
|
||||
f_out,
|
||||
original_hnsw_data,
|
||||
assign_probas_np,
|
||||
cum_nneighbor_per_level_np,
|
||||
levels_np,
|
||||
compact_level_ptr,
|
||||
compact_node_offsets_np,
|
||||
compact_neighbors_data,
|
||||
storage_fourcc,
|
||||
storage_data,
|
||||
):
|
||||
"""Write HNSW data in compact format following C++ read order exactly."""
|
||||
# Write IndexHNSW Header
|
||||
f_out.write(struct.pack('<I', original_hnsw_data['index_fourcc']))
|
||||
f_out.write(struct.pack('<i', original_hnsw_data['d']))
|
||||
f_out.write(struct.pack('<q', original_hnsw_data['ntotal']))
|
||||
f_out.write(struct.pack('<q', original_hnsw_data['dummy1']))
|
||||
f_out.write(struct.pack('<q', original_hnsw_data['dummy2']))
|
||||
f_out.write(struct.pack('<?', original_hnsw_data['is_trained']))
|
||||
f_out.write(struct.pack('<i', original_hnsw_data['metric_type']))
|
||||
if original_hnsw_data['metric_type'] > 1:
|
||||
f_out.write(struct.pack('<f', original_hnsw_data['metric_arg']))
|
||||
f_out.write(struct.pack("<I", original_hnsw_data["index_fourcc"]))
|
||||
f_out.write(struct.pack("<i", original_hnsw_data["d"]))
|
||||
f_out.write(struct.pack("<q", original_hnsw_data["ntotal"]))
|
||||
f_out.write(struct.pack("<q", original_hnsw_data["dummy1"]))
|
||||
f_out.write(struct.pack("<q", original_hnsw_data["dummy2"]))
|
||||
f_out.write(struct.pack("<?", original_hnsw_data["is_trained"]))
|
||||
f_out.write(struct.pack("<i", original_hnsw_data["metric_type"]))
|
||||
if original_hnsw_data["metric_type"] > 1:
|
||||
f_out.write(struct.pack("<f", original_hnsw_data["metric_arg"]))
|
||||
|
||||
# Write HNSW struct parts (standard order)
|
||||
write_numpy_vector(f_out, assign_probas_np, 'd')
|
||||
write_numpy_vector(f_out, cum_nneighbor_per_level_np, 'i')
|
||||
write_numpy_vector(f_out, levels_np, 'i')
|
||||
write_numpy_vector(f_out, assign_probas_np, "d")
|
||||
write_numpy_vector(f_out, cum_nneighbor_per_level_np, "i")
|
||||
write_numpy_vector(f_out, levels_np, "i")
|
||||
|
||||
# Write compact format flag
|
||||
f_out.write(struct.pack('<?', True)) # storage_is_compact = True
|
||||
f_out.write(struct.pack("<?", True)) # storage_is_compact = True
|
||||
|
||||
# Write compact data in CORRECT C++ read order: level_ptr, node_offsets FIRST
|
||||
if isinstance(compact_level_ptr, np.ndarray):
|
||||
write_numpy_vector(f_out, compact_level_ptr, 'Q')
|
||||
write_numpy_vector(f_out, compact_level_ptr, "Q")
|
||||
else:
|
||||
write_list_vector(f_out, compact_level_ptr, 'Q')
|
||||
|
||||
write_numpy_vector(f_out, compact_node_offsets_np, 'Q')
|
||||
write_list_vector(f_out, compact_level_ptr, "Q")
|
||||
|
||||
write_numpy_vector(f_out, compact_node_offsets_np, "Q")
|
||||
|
||||
# Write HNSW scalar parameters
|
||||
f_out.write(struct.pack('<i', original_hnsw_data['entry_point']))
|
||||
f_out.write(struct.pack('<i', original_hnsw_data['max_level']))
|
||||
f_out.write(struct.pack('<i', original_hnsw_data['efConstruction']))
|
||||
f_out.write(struct.pack('<i', original_hnsw_data['efSearch']))
|
||||
f_out.write(struct.pack('<i', original_hnsw_data['dummy_upper_beam']))
|
||||
f_out.write(struct.pack("<i", original_hnsw_data["entry_point"]))
|
||||
f_out.write(struct.pack("<i", original_hnsw_data["max_level"]))
|
||||
f_out.write(struct.pack("<i", original_hnsw_data["efConstruction"]))
|
||||
f_out.write(struct.pack("<i", original_hnsw_data["efSearch"]))
|
||||
f_out.write(struct.pack("<i", original_hnsw_data["dummy_upper_beam"]))
|
||||
|
||||
# Write storage fourcc (this determines how to read what follows)
|
||||
f_out.write(struct.pack('<I', storage_fourcc))
|
||||
|
||||
f_out.write(struct.pack("<I", storage_fourcc))
|
||||
|
||||
# Write compact neighbors data AFTER storage fourcc
|
||||
write_list_vector(f_out, compact_neighbors_data, 'i')
|
||||
|
||||
write_list_vector(f_out, compact_neighbors_data, "i")
|
||||
|
||||
# Write storage data if not NULL (only after neighbors)
|
||||
if storage_fourcc != NULL_INDEX_FOURCC and storage_data:
|
||||
f_out.write(storage_data)
|
||||
@@ -183,185 +239,244 @@ def write_compact_format(f_out, original_hnsw_data, assign_probas_np, cum_nneigh
|
||||
|
||||
# --- Main Conversion Logic ---
|
||||
|
||||
|
||||
def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=True):
|
||||
"""
|
||||
Converts an HNSW graph file to the CSR format.
|
||||
Supports both original and already-compact formats (backward compatibility).
|
||||
|
||||
|
||||
Args:
|
||||
input_filename: Input HNSW index file
|
||||
output_filename: Output CSR index file
|
||||
prune_embeddings: Whether to prune embedding storage (write NULL storage marker)
|
||||
"""
|
||||
# Keep prints simple; rely on CI runner to flush output as needed
|
||||
|
||||
print(f"Starting conversion: {input_filename} -> {output_filename}")
|
||||
start_time = time.time()
|
||||
original_hnsw_data = {}
|
||||
neighbors_np = None # Initialize to allow check in finally block
|
||||
neighbors_np = None # Initialize to allow check in finally block
|
||||
try:
|
||||
with open(input_filename, 'rb') as f_in, open(output_filename, 'wb') as f_out:
|
||||
|
||||
with open(input_filename, "rb") as f_in, open(output_filename, "wb") as f_out:
|
||||
# --- Read IndexHNSW FourCC and Header ---
|
||||
print(f"[{time.time() - start_time:.2f}s] Reading Index HNSW header...")
|
||||
# ... (Keep the header reading logic as before) ...
|
||||
hnsw_index_fourcc = read_struct(f_in, '<I')
|
||||
hnsw_index_fourcc = read_struct(f_in, "<I")
|
||||
if hnsw_index_fourcc not in EXPECTED_HNSW_FOURCCS:
|
||||
print(f"Error: Expected HNSW Index FourCC ({list(EXPECTED_HNSW_FOURCCS)}), got {hnsw_index_fourcc:08x}.", file=sys.stderr)
|
||||
return False
|
||||
original_hnsw_data['index_fourcc'] = hnsw_index_fourcc
|
||||
original_hnsw_data['d'] = read_struct(f_in, '<i')
|
||||
original_hnsw_data['ntotal'] = read_struct(f_in, '<q')
|
||||
original_hnsw_data['dummy1'] = read_struct(f_in, '<q')
|
||||
original_hnsw_data['dummy2'] = read_struct(f_in, '<q')
|
||||
original_hnsw_data['is_trained'] = read_struct(f_in, '?')
|
||||
original_hnsw_data['metric_type'] = read_struct(f_in, '<i')
|
||||
original_hnsw_data['metric_arg'] = 0.0
|
||||
if original_hnsw_data['metric_type'] > 1:
|
||||
original_hnsw_data['metric_arg'] = read_struct(f_in, '<f')
|
||||
print(f"[{time.time() - start_time:.2f}s] Header read: d={original_hnsw_data['d']}, ntotal={original_hnsw_data['ntotal']}")
|
||||
|
||||
print(
|
||||
f"Error: Expected HNSW Index FourCC ({list(EXPECTED_HNSW_FOURCCS)}), got {hnsw_index_fourcc:08x}.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return False
|
||||
original_hnsw_data["index_fourcc"] = hnsw_index_fourcc
|
||||
original_hnsw_data["d"] = read_struct(f_in, "<i")
|
||||
original_hnsw_data["ntotal"] = read_struct(f_in, "<q")
|
||||
original_hnsw_data["dummy1"] = read_struct(f_in, "<q")
|
||||
original_hnsw_data["dummy2"] = read_struct(f_in, "<q")
|
||||
original_hnsw_data["is_trained"] = read_struct(f_in, "?")
|
||||
original_hnsw_data["metric_type"] = read_struct(f_in, "<i")
|
||||
original_hnsw_data["metric_arg"] = 0.0
|
||||
if original_hnsw_data["metric_type"] > 1:
|
||||
original_hnsw_data["metric_arg"] = read_struct(f_in, "<f")
|
||||
print(
|
||||
f"[{time.time() - start_time:.2f}s] Header read: d={original_hnsw_data['d']}, ntotal={original_hnsw_data['ntotal']}"
|
||||
)
|
||||
|
||||
# --- Read original HNSW struct data ---
|
||||
print(f"[{time.time() - start_time:.2f}s] Reading HNSW struct vectors...")
|
||||
assign_probas_np = read_numpy_vector(f_in, np.float64, 'd')
|
||||
print(f"[{time.time() - start_time:.2f}s] Read assign_probas ({assign_probas_np.size})")
|
||||
assign_probas_np = read_numpy_vector(f_in, np.float64, "d")
|
||||
print(
|
||||
f"[{time.time() - start_time:.2f}s] Read assign_probas ({assign_probas_np.size})"
|
||||
)
|
||||
gc.collect()
|
||||
|
||||
cum_nneighbor_per_level_np = read_numpy_vector(f_in, np.int32, 'i')
|
||||
print(f"[{time.time() - start_time:.2f}s] Read cum_nneighbor_per_level ({cum_nneighbor_per_level_np.size})")
|
||||
cum_nneighbor_per_level_np = read_numpy_vector(f_in, np.int32, "i")
|
||||
print(
|
||||
f"[{time.time() - start_time:.2f}s] Read cum_nneighbor_per_level ({cum_nneighbor_per_level_np.size})"
|
||||
)
|
||||
gc.collect()
|
||||
|
||||
levels_np = read_numpy_vector(f_in, np.int32, 'i')
|
||||
levels_np = read_numpy_vector(f_in, np.int32, "i")
|
||||
print(f"[{time.time() - start_time:.2f}s] Read levels ({levels_np.size})")
|
||||
gc.collect()
|
||||
|
||||
ntotal = len(levels_np)
|
||||
if ntotal != original_hnsw_data['ntotal']:
|
||||
print(f"Warning: ntotal mismatch! Header says {original_hnsw_data['ntotal']}, levels vector size is {ntotal}. Using levels vector size.", file=sys.stderr)
|
||||
original_hnsw_data['ntotal'] = ntotal
|
||||
if ntotal != original_hnsw_data["ntotal"]:
|
||||
print(
|
||||
f"Warning: ntotal mismatch! Header says {original_hnsw_data['ntotal']}, levels vector size is {ntotal}. Using levels vector size.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
original_hnsw_data["ntotal"] = ntotal
|
||||
|
||||
# --- Check for compact format flag ---
|
||||
print(f"[{time.time() - start_time:.2f}s] Probing for compact storage flag...")
|
||||
pos_before_compact = f_in.tell()
|
||||
try:
|
||||
is_compact_flag = read_struct(f_in, '<?')
|
||||
is_compact_flag = read_struct(f_in, "<?")
|
||||
print(f"[{time.time() - start_time:.2f}s] Found compact flag: {is_compact_flag}")
|
||||
|
||||
|
||||
if is_compact_flag:
|
||||
# Input is already in compact format - read compact data
|
||||
print(f"[{time.time() - start_time:.2f}s] Input is already in compact format, reading compact data...")
|
||||
|
||||
compact_level_ptr = read_numpy_vector(f_in, np.uint64, 'Q')
|
||||
print(f"[{time.time() - start_time:.2f}s] Read compact_level_ptr ({compact_level_ptr.size})")
|
||||
|
||||
compact_node_offsets_np = read_numpy_vector(f_in, np.uint64, 'Q')
|
||||
print(f"[{time.time() - start_time:.2f}s] Read compact_node_offsets ({compact_node_offsets_np.size})")
|
||||
|
||||
print(
|
||||
f"[{time.time() - start_time:.2f}s] Input is already in compact format, reading compact data..."
|
||||
)
|
||||
|
||||
compact_level_ptr = read_numpy_vector(f_in, np.uint64, "Q")
|
||||
print(
|
||||
f"[{time.time() - start_time:.2f}s] Read compact_level_ptr ({compact_level_ptr.size})"
|
||||
)
|
||||
|
||||
compact_node_offsets_np = read_numpy_vector(f_in, np.uint64, "Q")
|
||||
print(
|
||||
f"[{time.time() - start_time:.2f}s] Read compact_node_offsets ({compact_node_offsets_np.size})"
|
||||
)
|
||||
|
||||
# Read scalar parameters
|
||||
original_hnsw_data['entry_point'] = read_struct(f_in, '<i')
|
||||
original_hnsw_data['max_level'] = read_struct(f_in, '<i')
|
||||
original_hnsw_data['efConstruction'] = read_struct(f_in, '<i')
|
||||
original_hnsw_data['efSearch'] = read_struct(f_in, '<i')
|
||||
original_hnsw_data['dummy_upper_beam'] = read_struct(f_in, '<i')
|
||||
print(f"[{time.time() - start_time:.2f}s] Read scalar params (ep={original_hnsw_data['entry_point']}, max_lvl={original_hnsw_data['max_level']})")
|
||||
original_hnsw_data["entry_point"] = read_struct(f_in, "<i")
|
||||
original_hnsw_data["max_level"] = read_struct(f_in, "<i")
|
||||
original_hnsw_data["efConstruction"] = read_struct(f_in, "<i")
|
||||
original_hnsw_data["efSearch"] = read_struct(f_in, "<i")
|
||||
original_hnsw_data["dummy_upper_beam"] = read_struct(f_in, "<i")
|
||||
print(
|
||||
f"[{time.time() - start_time:.2f}s] Read scalar params (ep={original_hnsw_data['entry_point']}, max_lvl={original_hnsw_data['max_level']})"
|
||||
)
|
||||
|
||||
# Read storage fourcc
|
||||
storage_fourcc = read_struct(f_in, '<I')
|
||||
print(f"[{time.time() - start_time:.2f}s] Found storage fourcc: {storage_fourcc:08x}")
|
||||
|
||||
storage_fourcc = read_struct(f_in, "<I")
|
||||
print(
|
||||
f"[{time.time() - start_time:.2f}s] Found storage fourcc: {storage_fourcc:08x}"
|
||||
)
|
||||
|
||||
if prune_embeddings and storage_fourcc != NULL_INDEX_FOURCC:
|
||||
# Read compact neighbors data
|
||||
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, 'i')
|
||||
print(f"[{time.time() - start_time:.2f}s] Read compact neighbors data ({compact_neighbors_data_np.size})")
|
||||
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, "i")
|
||||
print(
|
||||
f"[{time.time() - start_time:.2f}s] Read compact neighbors data ({compact_neighbors_data_np.size})"
|
||||
)
|
||||
compact_neighbors_data = compact_neighbors_data_np.tolist()
|
||||
del compact_neighbors_data_np
|
||||
|
||||
|
||||
# Skip storage data and write with NULL marker
|
||||
print(f"[{time.time() - start_time:.2f}s] Pruning embeddings: Writing NULL storage marker.")
|
||||
print(
|
||||
f"[{time.time() - start_time:.2f}s] Pruning embeddings: Writing NULL storage marker."
|
||||
)
|
||||
storage_fourcc = NULL_INDEX_FOURCC
|
||||
elif not prune_embeddings:
|
||||
# Read and preserve compact neighbors and storage
|
||||
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, 'i')
|
||||
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, "i")
|
||||
compact_neighbors_data = compact_neighbors_data_np.tolist()
|
||||
del compact_neighbors_data_np
|
||||
|
||||
|
||||
# Read remaining storage data
|
||||
storage_data = f_in.read()
|
||||
else:
|
||||
# Already pruned (NULL storage)
|
||||
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, 'i')
|
||||
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, "i")
|
||||
compact_neighbors_data = compact_neighbors_data_np.tolist()
|
||||
del compact_neighbors_data_np
|
||||
storage_data = b''
|
||||
|
||||
storage_data = b""
|
||||
|
||||
# Write the updated compact format
|
||||
print(f"[{time.time() - start_time:.2f}s] Writing updated compact format...")
|
||||
write_compact_format(f_out, original_hnsw_data, assign_probas_np, cum_nneighbor_per_level_np,
|
||||
levels_np, compact_level_ptr, compact_node_offsets_np,
|
||||
compact_neighbors_data, storage_fourcc, storage_data if not prune_embeddings else b'')
|
||||
|
||||
write_compact_format(
|
||||
f_out,
|
||||
original_hnsw_data,
|
||||
assign_probas_np,
|
||||
cum_nneighbor_per_level_np,
|
||||
levels_np,
|
||||
compact_level_ptr,
|
||||
compact_node_offsets_np,
|
||||
compact_neighbors_data,
|
||||
storage_fourcc,
|
||||
storage_data if not prune_embeddings else b"",
|
||||
)
|
||||
|
||||
print(f"[{time.time() - start_time:.2f}s] Conversion complete.")
|
||||
return True
|
||||
|
||||
|
||||
else:
|
||||
# is_compact=False, rewind and read original format
|
||||
f_in.seek(pos_before_compact)
|
||||
print(f"[{time.time() - start_time:.2f}s] Compact flag is False, reading original format...")
|
||||
|
||||
print(
|
||||
f"[{time.time() - start_time:.2f}s] Compact flag is False, reading original format..."
|
||||
)
|
||||
|
||||
except EOFError:
|
||||
# No compact flag found, assume original format
|
||||
f_in.seek(pos_before_compact)
|
||||
print(f"[{time.time() - start_time:.2f}s] No compact flag found, assuming original format...")
|
||||
print(
|
||||
f"[{time.time() - start_time:.2f}s] No compact flag found, assuming original format..."
|
||||
)
|
||||
|
||||
# --- Handle potential extra byte in original format (like C++ code) ---
|
||||
print(f"[{time.time() - start_time:.2f}s] Probing for potential extra byte before non-compact offsets...")
|
||||
print(
|
||||
f"[{time.time() - start_time:.2f}s] Probing for potential extra byte before non-compact offsets..."
|
||||
)
|
||||
pos_before_probe = f_in.tell()
|
||||
try:
|
||||
suspected_flag = read_struct(f_in, '<B') # Read 1 byte
|
||||
suspected_flag = read_struct(f_in, "<B") # Read 1 byte
|
||||
if suspected_flag == 0x00:
|
||||
print(f"[{time.time() - start_time:.2f}s] Found and consumed an unexpected 0x00 byte.")
|
||||
print(
|
||||
f"[{time.time() - start_time:.2f}s] Found and consumed an unexpected 0x00 byte."
|
||||
)
|
||||
elif suspected_flag == 0x01:
|
||||
print(f"[{time.time() - start_time:.2f}s] ERROR: Found 0x01 but is_compact should be False")
|
||||
print(
|
||||
f"[{time.time() - start_time:.2f}s] ERROR: Found 0x01 but is_compact should be False"
|
||||
)
|
||||
raise ValueError("Inconsistent compact flag state")
|
||||
else:
|
||||
# Rewind - this byte is part of offsets data
|
||||
f_in.seek(pos_before_probe)
|
||||
print(f"[{time.time() - start_time:.2f}s] Rewound to original position (byte was 0x{suspected_flag:02x})")
|
||||
print(
|
||||
f"[{time.time() - start_time:.2f}s] Rewound to original position (byte was 0x{suspected_flag:02x})"
|
||||
)
|
||||
except EOFError:
|
||||
f_in.seek(pos_before_probe)
|
||||
print(f"[{time.time() - start_time:.2f}s] No extra byte found (EOF), proceeding with offsets read")
|
||||
print(
|
||||
f"[{time.time() - start_time:.2f}s] No extra byte found (EOF), proceeding with offsets read"
|
||||
)
|
||||
|
||||
# --- Read original format data ---
|
||||
offsets_np = read_numpy_vector(f_in, np.uint64, 'Q')
|
||||
offsets_np = read_numpy_vector(f_in, np.uint64, "Q")
|
||||
print(f"[{time.time() - start_time:.2f}s] Read offsets ({offsets_np.size})")
|
||||
if len(offsets_np) != ntotal + 1:
|
||||
raise ValueError(f"Inconsistent offsets size: len(levels)={ntotal} but len(offsets)={len(offsets_np)}")
|
||||
raise ValueError(
|
||||
f"Inconsistent offsets size: len(levels)={ntotal} but len(offsets)={len(offsets_np)}"
|
||||
)
|
||||
gc.collect()
|
||||
|
||||
print(f"[{time.time() - start_time:.2f}s] Attempting to read neighbors vector...")
|
||||
neighbors_np = read_numpy_vector(f_in, np.int32, 'i')
|
||||
neighbors_np = read_numpy_vector(f_in, np.int32, "i")
|
||||
print(f"[{time.time() - start_time:.2f}s] Read neighbors ({neighbors_np.size})")
|
||||
expected_neighbors_size = offsets_np[-1] if ntotal > 0 else 0
|
||||
if neighbors_np.size != expected_neighbors_size:
|
||||
print(f"Warning: neighbors vector size mismatch. Expected {expected_neighbors_size} based on offsets, got {neighbors_np.size}.")
|
||||
print(
|
||||
f"Warning: neighbors vector size mismatch. Expected {expected_neighbors_size} based on offsets, got {neighbors_np.size}."
|
||||
)
|
||||
gc.collect()
|
||||
|
||||
original_hnsw_data['entry_point'] = read_struct(f_in, '<i')
|
||||
original_hnsw_data['max_level'] = read_struct(f_in, '<i')
|
||||
original_hnsw_data['efConstruction'] = read_struct(f_in, '<i')
|
||||
original_hnsw_data['efSearch'] = read_struct(f_in, '<i')
|
||||
original_hnsw_data['dummy_upper_beam'] = read_struct(f_in, '<i')
|
||||
print(f"[{time.time() - start_time:.2f}s] Read scalar params (ep={original_hnsw_data['entry_point']}, max_lvl={original_hnsw_data['max_level']})")
|
||||
original_hnsw_data["entry_point"] = read_struct(f_in, "<i")
|
||||
original_hnsw_data["max_level"] = read_struct(f_in, "<i")
|
||||
original_hnsw_data["efConstruction"] = read_struct(f_in, "<i")
|
||||
original_hnsw_data["efSearch"] = read_struct(f_in, "<i")
|
||||
original_hnsw_data["dummy_upper_beam"] = read_struct(f_in, "<i")
|
||||
print(
|
||||
f"[{time.time() - start_time:.2f}s] Read scalar params (ep={original_hnsw_data['entry_point']}, max_lvl={original_hnsw_data['max_level']})"
|
||||
)
|
||||
|
||||
print(f"[{time.time() - start_time:.2f}s] Checking for storage data...")
|
||||
storage_fourcc = None
|
||||
try:
|
||||
storage_fourcc = read_struct(f_in, '<I')
|
||||
print(f"[{time.time() - start_time:.2f}s] Found storage fourcc: {storage_fourcc:08x}.")
|
||||
storage_fourcc = read_struct(f_in, "<I")
|
||||
print(
|
||||
f"[{time.time() - start_time:.2f}s] Found storage fourcc: {storage_fourcc:08x}."
|
||||
)
|
||||
except EOFError:
|
||||
print(f"[{time.time() - start_time:.2f}s] No storage data found (EOF).")
|
||||
print(f"[{time.time() - start_time:.2f}s] No storage data found (EOF).")
|
||||
except Exception as e:
|
||||
print(f"[{time.time() - start_time:.2f}s] Error reading potential storage data: {e}")
|
||||
|
||||
print(
|
||||
f"[{time.time() - start_time:.2f}s] Error reading potential storage data: {e}"
|
||||
)
|
||||
|
||||
# --- Perform Conversion ---
|
||||
print(f"[{time.time() - start_time:.2f}s] Converting to CSR format...")
|
||||
@@ -373,17 +488,21 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
|
||||
|
||||
current_level_ptr_idx = 0
|
||||
current_data_idx = 0
|
||||
total_valid_neighbors_counted = 0 # For validation
|
||||
total_valid_neighbors_counted = 0 # For validation
|
||||
|
||||
# Optimize calculation by getting slices once per node if possible
|
||||
for i in range(ntotal):
|
||||
if i > 0 and i % (ntotal // 100 or 1) == 0: # Log progress roughly every 1%
|
||||
if i > 0 and i % (ntotal // 100 or 1) == 0: # Log progress roughly every 1%
|
||||
progress = (i / ntotal) * 100
|
||||
elapsed = time.time() - start_time
|
||||
print(f"\r[{elapsed:.2f}s] Converting node {i}/{ntotal} ({progress:.1f}%)...", end="")
|
||||
print(
|
||||
f"\r[{elapsed:.2f}s] Converting node {i}/{ntotal} ({progress:.1f}%)...",
|
||||
end="",
|
||||
)
|
||||
|
||||
node_max_level = levels_np[i] - 1
|
||||
if node_max_level < -1: node_max_level = -1
|
||||
if node_max_level < -1:
|
||||
node_max_level = -1
|
||||
|
||||
node_ptr_start_index = current_level_ptr_idx
|
||||
compact_node_offsets_np[i] = node_ptr_start_index
|
||||
@@ -394,13 +513,17 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
|
||||
for level in range(node_max_level + 1):
|
||||
compact_level_ptr.append(current_data_idx)
|
||||
|
||||
begin_orig_np = original_offset_start + get_cum_neighbors(cum_nneighbor_per_level_np, level)
|
||||
end_orig_np = original_offset_start + get_cum_neighbors(cum_nneighbor_per_level_np, level + 1)
|
||||
begin_orig_np = original_offset_start + get_cum_neighbors(
|
||||
cum_nneighbor_per_level_np, level
|
||||
)
|
||||
end_orig_np = original_offset_start + get_cum_neighbors(
|
||||
cum_nneighbor_per_level_np, level + 1
|
||||
)
|
||||
|
||||
begin_orig = int(begin_orig_np)
|
||||
end_orig = int(end_orig_np)
|
||||
|
||||
neighbors_len = len(neighbors_np) # Cache length
|
||||
neighbors_len = len(neighbors_np) # Cache length
|
||||
begin_orig = min(max(0, begin_orig), neighbors_len)
|
||||
end_orig = min(max(begin_orig, end_orig), neighbors_len)
|
||||
|
||||
@@ -413,83 +536,117 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
|
||||
|
||||
if num_valid > 0:
|
||||
# Append valid neighbors
|
||||
compact_neighbors_data.extend(level_neighbors_slice[valid_neighbors_mask])
|
||||
compact_neighbors_data.extend(
|
||||
level_neighbors_slice[valid_neighbors_mask]
|
||||
)
|
||||
current_data_idx += num_valid
|
||||
total_valid_neighbors_counted += num_valid
|
||||
|
||||
|
||||
compact_level_ptr.append(current_data_idx)
|
||||
current_level_ptr_idx += num_pointers_expected
|
||||
|
||||
compact_node_offsets_np[ntotal] = current_level_ptr_idx
|
||||
print(f"\r[{time.time() - start_time:.2f}s] Conversion loop finished. ") # Clear progress line
|
||||
print(
|
||||
f"\r[{time.time() - start_time:.2f}s] Conversion loop finished. "
|
||||
) # Clear progress line
|
||||
|
||||
# --- Validation Checks ---
|
||||
print(f"[{time.time() - start_time:.2f}s] Running validation checks...")
|
||||
valid_check_passed = True
|
||||
# Check 1: Total valid neighbors count
|
||||
print(f" Checking total valid neighbor count...")
|
||||
print(" Checking total valid neighbor count...")
|
||||
expected_valid_count = np.sum(neighbors_np >= 0)
|
||||
if total_valid_neighbors_counted != len(compact_neighbors_data):
|
||||
print(f"Error: Mismatch between counted valid neighbors ({total_valid_neighbors_counted}) and final compact_data size ({len(compact_neighbors_data)})!", file=sys.stderr)
|
||||
valid_check_passed = False
|
||||
print(
|
||||
f"Error: Mismatch between counted valid neighbors ({total_valid_neighbors_counted}) and final compact_data size ({len(compact_neighbors_data)})!",
|
||||
file=sys.stderr,
|
||||
)
|
||||
valid_check_passed = False
|
||||
if expected_valid_count != len(compact_neighbors_data):
|
||||
print(f"Error: Mismatch between NumPy count of valid neighbors ({expected_valid_count}) and final compact_data size ({len(compact_neighbors_data)})!", file=sys.stderr)
|
||||
valid_check_passed = False
|
||||
print(
|
||||
f"Error: Mismatch between NumPy count of valid neighbors ({expected_valid_count}) and final compact_data size ({len(compact_neighbors_data)})!",
|
||||
file=sys.stderr,
|
||||
)
|
||||
valid_check_passed = False
|
||||
else:
|
||||
print(f" OK: Total valid neighbors = {len(compact_neighbors_data)}")
|
||||
print(f" OK: Total valid neighbors = {len(compact_neighbors_data)}")
|
||||
|
||||
# Check 2: Final pointer indices consistency
|
||||
print(f" Checking final pointer indices...")
|
||||
print(" Checking final pointer indices...")
|
||||
if compact_node_offsets_np[ntotal] != len(compact_level_ptr):
|
||||
print(f"Error: Final node offset ({compact_node_offsets_np[ntotal]}) doesn't match level_ptr size ({len(compact_level_ptr)})!", file=sys.stderr)
|
||||
valid_check_passed = False
|
||||
if (len(compact_level_ptr) > 0 and compact_level_ptr[-1] != len(compact_neighbors_data)) or \
|
||||
(len(compact_level_ptr) == 0 and len(compact_neighbors_data) != 0):
|
||||
last_ptr = compact_level_ptr[-1] if len(compact_level_ptr) > 0 else -1
|
||||
print(f"Error: Last level pointer ({last_ptr}) doesn't match compact_data size ({len(compact_neighbors_data)})!", file=sys.stderr)
|
||||
valid_check_passed = False
|
||||
print(
|
||||
f"Error: Final node offset ({compact_node_offsets_np[ntotal]}) doesn't match level_ptr size ({len(compact_level_ptr)})!",
|
||||
file=sys.stderr,
|
||||
)
|
||||
valid_check_passed = False
|
||||
if (
|
||||
len(compact_level_ptr) > 0 and compact_level_ptr[-1] != len(compact_neighbors_data)
|
||||
) or (len(compact_level_ptr) == 0 and len(compact_neighbors_data) != 0):
|
||||
last_ptr = compact_level_ptr[-1] if len(compact_level_ptr) > 0 else -1
|
||||
print(
|
||||
f"Error: Last level pointer ({last_ptr}) doesn't match compact_data size ({len(compact_neighbors_data)})!",
|
||||
file=sys.stderr,
|
||||
)
|
||||
valid_check_passed = False
|
||||
else:
|
||||
print(f" OK: Final pointers match data size.")
|
||||
print(" OK: Final pointers match data size.")
|
||||
|
||||
if not valid_check_passed:
|
||||
print("Error: Validation checks failed. Output file might be incorrect.", file=sys.stderr)
|
||||
print(
|
||||
"Error: Validation checks failed. Output file might be incorrect.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
# Optional: Exit here if validation fails
|
||||
# return False
|
||||
|
||||
# --- Explicitly delete large intermediate arrays ---
|
||||
print(f"[{time.time() - start_time:.2f}s] Deleting original neighbors and offsets arrays...")
|
||||
print(
|
||||
f"[{time.time() - start_time:.2f}s] Deleting original neighbors and offsets arrays..."
|
||||
)
|
||||
del neighbors_np
|
||||
del offsets_np
|
||||
gc.collect()
|
||||
|
||||
print(f" CSR Stats: |data|={len(compact_neighbors_data)}, |level_ptr|={len(compact_level_ptr)}")
|
||||
print(
|
||||
f" CSR Stats: |data|={len(compact_neighbors_data)}, |level_ptr|={len(compact_level_ptr)}"
|
||||
)
|
||||
|
||||
# --- Write CSR HNSW graph data using unified function ---
|
||||
print(f"[{time.time() - start_time:.2f}s] Writing CSR HNSW graph data in FAISS-compatible order...")
|
||||
|
||||
print(
|
||||
f"[{time.time() - start_time:.2f}s] Writing CSR HNSW graph data in FAISS-compatible order..."
|
||||
)
|
||||
|
||||
# Determine storage fourcc and data based on prune_embeddings
|
||||
if prune_embeddings:
|
||||
print(f" Pruning embeddings: Writing NULL storage marker.")
|
||||
print(" Pruning embeddings: Writing NULL storage marker.")
|
||||
output_storage_fourcc = NULL_INDEX_FOURCC
|
||||
storage_data = b''
|
||||
storage_data = b""
|
||||
else:
|
||||
# Keep embeddings - read and preserve original storage data
|
||||
if storage_fourcc and storage_fourcc != NULL_INDEX_FOURCC:
|
||||
print(f" Preserving embeddings: Reading original storage data...")
|
||||
print(" Preserving embeddings: Reading original storage data...")
|
||||
storage_data = f_in.read() # Read remaining storage data
|
||||
output_storage_fourcc = storage_fourcc
|
||||
print(f" Read {len(storage_data)} bytes of storage data")
|
||||
else:
|
||||
print(f" No embeddings found in original file (NULL storage)")
|
||||
print(" No embeddings found in original file (NULL storage)")
|
||||
output_storage_fourcc = NULL_INDEX_FOURCC
|
||||
storage_data = b''
|
||||
|
||||
storage_data = b""
|
||||
|
||||
# Use the unified write function
|
||||
write_compact_format(f_out, original_hnsw_data, assign_probas_np, cum_nneighbor_per_level_np,
|
||||
levels_np, compact_level_ptr, compact_node_offsets_np,
|
||||
compact_neighbors_data, output_storage_fourcc, storage_data)
|
||||
|
||||
write_compact_format(
|
||||
f_out,
|
||||
original_hnsw_data,
|
||||
assign_probas_np,
|
||||
cum_nneighbor_per_level_np,
|
||||
levels_np,
|
||||
compact_level_ptr,
|
||||
compact_node_offsets_np,
|
||||
compact_neighbors_data,
|
||||
output_storage_fourcc,
|
||||
storage_data,
|
||||
)
|
||||
|
||||
# Clean up memory
|
||||
del assign_probas_np, cum_nneighbor_per_level_np, levels_np
|
||||
del compact_neighbors_data, compact_level_ptr, compact_node_offsets_np
|
||||
@@ -503,40 +660,66 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
|
||||
print(f"Error: Input file not found: {input_filename}", file=sys.stderr)
|
||||
return False
|
||||
except MemoryError as e:
|
||||
print(f"\nFatal MemoryError during conversion: {e}. Insufficient RAM.", file=sys.stderr)
|
||||
# Clean up potentially partially written output file?
|
||||
try: os.remove(output_filename)
|
||||
except OSError: pass
|
||||
return False
|
||||
print(
|
||||
f"\nFatal MemoryError during conversion: {e}. Insufficient RAM.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
# Clean up potentially partially written output file?
|
||||
try:
|
||||
os.remove(output_filename)
|
||||
except OSError:
|
||||
pass
|
||||
return False
|
||||
except EOFError as e:
|
||||
print(f"Error: Reached end of file unexpectedly reading {input_filename}. {e}", file=sys.stderr)
|
||||
try: os.remove(output_filename)
|
||||
except OSError: pass
|
||||
print(
|
||||
f"Error: Reached end of file unexpectedly reading {input_filename}. {e}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
try:
|
||||
os.remove(output_filename)
|
||||
except OSError:
|
||||
pass
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"An unexpected error occurred during conversion: {e}", file=sys.stderr)
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
try:
|
||||
os.remove(output_filename)
|
||||
except OSError: pass
|
||||
except OSError:
|
||||
pass
|
||||
return False
|
||||
# Ensure neighbors_np is deleted even if an error occurs after its allocation
|
||||
finally:
|
||||
if 'neighbors_np' in locals() and neighbors_np is not None:
|
||||
del neighbors_np
|
||||
gc.collect()
|
||||
try:
|
||||
if "neighbors_np" in locals() and neighbors_np is not None:
|
||||
del neighbors_np
|
||||
gc.collect()
|
||||
except NameError:
|
||||
pass
|
||||
|
||||
|
||||
# --- Script Execution ---
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Convert a Faiss IndexHNSWFlat file to a CSR-based HNSW graph file.")
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Convert a Faiss IndexHNSWFlat file to a CSR-based HNSW graph file."
|
||||
)
|
||||
parser.add_argument("input_index_file", help="Path to the input IndexHNSWFlat file")
|
||||
parser.add_argument("output_csr_graph_file", help="Path to write the output CSR HNSW graph file")
|
||||
parser.add_argument("--prune-embeddings", action="store_true", default=True,
|
||||
help="Prune embedding storage (write NULL storage marker)")
|
||||
parser.add_argument("--keep-embeddings", action="store_true",
|
||||
help="Keep embedding storage (overrides --prune-embeddings)")
|
||||
parser.add_argument(
|
||||
"output_csr_graph_file", help="Path to write the output CSR HNSW graph file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prune-embeddings",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Prune embedding storage (write NULL storage marker)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--keep-embeddings",
|
||||
action="store_true",
|
||||
help="Keep embedding storage (overrides --prune-embeddings)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -545,10 +728,12 @@ if __name__ == "__main__":
|
||||
sys.exit(1)
|
||||
|
||||
if os.path.abspath(args.input_index_file) == os.path.abspath(args.output_csr_graph_file):
|
||||
print(f"Error: Input and output filenames cannot be the same.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
print("Error: Input and output filenames cannot be the same.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
prune_embeddings = args.prune_embeddings and not args.keep_embeddings
|
||||
success = convert_hnsw_graph_to_csr(args.input_index_file, args.output_csr_graph_file, prune_embeddings)
|
||||
success = convert_hnsw_graph_to_csr(
|
||||
args.input_index_file, args.output_csr_graph_file, prune_embeddings
|
||||
)
|
||||
if not success:
|
||||
sys.exit(1)
|
||||
sys.exit(1)
|
||||
|
||||
@@ -1,20 +1,23 @@
|
||||
import numpy as np
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List, Literal
|
||||
import pickle
|
||||
import shutil
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from leann.searcher_base import BaseSearcher
|
||||
from .convert_to_csr import convert_hnsw_graph_to_csr
|
||||
|
||||
from leann.registry import register_backend
|
||||
import numpy as np
|
||||
from leann.interface import (
|
||||
LeannBackendFactoryInterface,
|
||||
LeannBackendBuilderInterface,
|
||||
LeannBackendFactoryInterface,
|
||||
LeannBackendSearcherInterface,
|
||||
)
|
||||
from leann.registry import register_backend
|
||||
from leann.searcher_base import BaseSearcher
|
||||
|
||||
from .convert_to_csr import convert_hnsw_graph_to_csr
|
||||
from .prune_index import prune_embeddings_preserve_graph_inplace
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_metric_map():
|
||||
@@ -27,6 +30,12 @@ def get_metric_map():
|
||||
}
|
||||
|
||||
|
||||
def normalize_l2(data: np.ndarray) -> np.ndarray:
|
||||
norms = np.linalg.norm(data, axis=1, keepdims=True)
|
||||
norms[norms == 0] = 1 # Avoid division by zero
|
||||
return data / norms
|
||||
|
||||
|
||||
@register_backend("hnsw")
|
||||
class HNSWBackend(LeannBackendFactoryInterface):
|
||||
@staticmethod
|
||||
@@ -47,8 +56,15 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
||||
self.efConstruction = self.build_params.setdefault("efConstruction", 200)
|
||||
self.distance_metric = self.build_params.setdefault("distance_metric", "mips")
|
||||
self.dimensions = self.build_params.get("dimensions")
|
||||
if not self.is_recompute and self.is_compact:
|
||||
# Auto-correct: non-recompute requires non-compact storage for HNSW
|
||||
logger.warning(
|
||||
"is_recompute=False requires non-compact HNSW. Forcing is_compact=False."
|
||||
)
|
||||
self.is_compact = False
|
||||
self.build_params["is_compact"] = False
|
||||
|
||||
def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs):
|
||||
def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs):
|
||||
from . import faiss # type: ignore
|
||||
|
||||
path = Path(index_path)
|
||||
@@ -57,13 +73,9 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
||||
index_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if data.dtype != np.float32:
|
||||
logger.warning(f"Converting data to float32, shape: {data.shape}")
|
||||
data = data.astype(np.float32)
|
||||
|
||||
label_map = {i: str_id for i, str_id in enumerate(ids)}
|
||||
label_map_file = index_dir / "leann.labels.map"
|
||||
with open(label_map_file, "wb") as f:
|
||||
pickle.dump(label_map, f)
|
||||
|
||||
metric_enum = get_metric_map().get(self.distance_metric.lower())
|
||||
if metric_enum is None:
|
||||
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
|
||||
@@ -73,19 +85,27 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
||||
index.hnsw.efConstruction = self.efConstruction
|
||||
|
||||
if self.distance_metric.lower() == "cosine":
|
||||
faiss.normalize_L2(data)
|
||||
data = normalize_l2(data)
|
||||
|
||||
index.add(data.shape[0], faiss.swig_ptr(data))
|
||||
index_file = index_dir / f"{index_prefix}.index"
|
||||
faiss.write_index(index, str(index_file))
|
||||
|
||||
if self.is_compact:
|
||||
self._convert_to_csr(index_file)
|
||||
if self.is_recompute:
|
||||
if self.is_compact:
|
||||
self._convert_to_csr(index_file)
|
||||
else:
|
||||
# Non-compact format: prune only embeddings, keep original graph
|
||||
ok = prune_embeddings_preserve_graph_inplace(str(index_file))
|
||||
if not ok:
|
||||
raise RuntimeError(
|
||||
"Pruning embeddings while preserving graph failed for non-compact index"
|
||||
)
|
||||
|
||||
def _convert_to_csr(self, index_file: Path):
|
||||
"""Convert built index to CSR format"""
|
||||
mode_str = "CSR-pruned" if self.is_recompute else "CSR-standard"
|
||||
print(f"INFO: Converting HNSW index to {mode_str} format...")
|
||||
logger.info(f"INFO: Converting HNSW index to {mode_str} format...")
|
||||
|
||||
csr_temp_file = index_file.with_suffix(".csr.tmp")
|
||||
|
||||
@@ -94,20 +114,16 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
||||
)
|
||||
|
||||
if success:
|
||||
print("✅ CSR conversion successful.")
|
||||
index_file_old = index_file.with_suffix(".old")
|
||||
shutil.move(str(index_file), str(index_file_old))
|
||||
logger.info("✅ CSR conversion successful.")
|
||||
# index_file_old = index_file.with_suffix(".old")
|
||||
# shutil.move(str(index_file), str(index_file_old))
|
||||
shutil.move(str(csr_temp_file), str(index_file))
|
||||
print(
|
||||
f"INFO: Replaced original index with {mode_str} version at '{index_file}'"
|
||||
)
|
||||
logger.info(f"INFO: Replaced original index with {mode_str} version at '{index_file}'")
|
||||
else:
|
||||
# Clean up and fail fast
|
||||
if csr_temp_file.exists():
|
||||
os.remove(csr_temp_file)
|
||||
raise RuntimeError(
|
||||
"CSR conversion failed - cannot proceed with compact format"
|
||||
)
|
||||
raise RuntimeError("CSR conversion failed - cannot proceed with compact format")
|
||||
|
||||
|
||||
class HNSWSearcher(BaseSearcher):
|
||||
@@ -119,7 +135,9 @@ class HNSWSearcher(BaseSearcher):
|
||||
)
|
||||
from . import faiss # type: ignore
|
||||
|
||||
self.distance_metric = self.meta.get("distance_metric", "mips").lower()
|
||||
self.distance_metric = (
|
||||
self.meta.get("backend_kwargs", {}).get("distance_metric", "mips").lower()
|
||||
)
|
||||
metric_enum = get_metric_map().get(self.distance_metric)
|
||||
if metric_enum is None:
|
||||
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
|
||||
@@ -135,34 +153,31 @@ class HNSWSearcher(BaseSearcher):
|
||||
|
||||
hnsw_config = faiss.HNSWIndexConfig()
|
||||
hnsw_config.is_compact = self.is_compact
|
||||
hnsw_config.is_recompute = self.is_pruned or kwargs.get("is_recompute", False)
|
||||
hnsw_config.is_recompute = (
|
||||
self.is_pruned
|
||||
) # In C++ code, it's called is_recompute, but it's only for loading IIUC.
|
||||
|
||||
if self.is_pruned and not hnsw_config.is_recompute:
|
||||
raise RuntimeError("Index is pruned but recompute is disabled.")
|
||||
# If pruned (recompute mode), explicitly skip storage to avoid reading
|
||||
# the pruned section. Still allow MMAP for graph.
|
||||
io_flags = faiss.IO_FLAG_MMAP
|
||||
if self.is_pruned:
|
||||
io_flags |= faiss.IO_FLAG_SKIP_STORAGE
|
||||
|
||||
self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)
|
||||
|
||||
# Load label mapping
|
||||
label_map_file = self.index_dir / "leann.labels.map"
|
||||
if not label_map_file.exists():
|
||||
raise FileNotFoundError(f"Label map file not found at {label_map_file}")
|
||||
|
||||
with open(label_map_file, "rb") as f:
|
||||
self.label_map = pickle.load(f)
|
||||
self._index = faiss.read_index(str(index_file), io_flags, hnsw_config)
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: np.ndarray,
|
||||
top_k: int,
|
||||
zmq_port: Optional[int] = None,
|
||||
complexity: int = 64,
|
||||
beam_width: int = 1,
|
||||
prune_ratio: float = 0.0,
|
||||
recompute_embeddings: bool = False,
|
||||
recompute_embeddings: bool = True,
|
||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||
zmq_port: int = 5557,
|
||||
batch_size: int = 0,
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Search for nearest neighbors using HNSW index.
|
||||
|
||||
@@ -177,7 +192,7 @@ class HNSWSearcher(BaseSearcher):
|
||||
- "global": Use global PQ queue size for selection (default)
|
||||
- "local": Local pruning, sort and select best candidates
|
||||
- "proportional": Base selection on new neighbor count ratio
|
||||
zmq_port: ZMQ port for embedding server
|
||||
zmq_port: ZMQ port for embedding server communication. Must be provided if recompute_embeddings is True.
|
||||
batch_size: Neighbor processing batch size, 0=disabled (HNSW-specific)
|
||||
**kwargs: Additional HNSW-specific parameters (for legacy compatibility)
|
||||
|
||||
@@ -186,26 +201,36 @@ class HNSWSearcher(BaseSearcher):
|
||||
"""
|
||||
from . import faiss # type: ignore
|
||||
|
||||
# Use recompute_embeddings parameter
|
||||
use_recompute = recompute_embeddings or self.is_pruned
|
||||
if use_recompute:
|
||||
meta_file_path = self.index_dir / f"{self.index_path.name}.meta.json"
|
||||
if not meta_file_path.exists():
|
||||
raise RuntimeError(
|
||||
f"FATAL: Recompute enabled but metadata file not found: {meta_file_path}"
|
||||
)
|
||||
self._ensure_server_running(str(meta_file_path), port=zmq_port, **kwargs)
|
||||
if not recompute_embeddings and self.is_pruned:
|
||||
raise RuntimeError(
|
||||
"Recompute is required for pruned/compact HNSW index. "
|
||||
"Re-run search with --recompute, or rebuild with --no-recompute and --no-compact."
|
||||
)
|
||||
if recompute_embeddings:
|
||||
if zmq_port is None:
|
||||
raise ValueError("zmq_port must be provided if recompute_embeddings is True")
|
||||
|
||||
if query.dtype != np.float32:
|
||||
query = query.astype(np.float32)
|
||||
if self.distance_metric == "cosine":
|
||||
faiss.normalize_L2(query)
|
||||
query = normalize_l2(query)
|
||||
|
||||
params = faiss.SearchParametersHNSW()
|
||||
params.zmq_port = zmq_port
|
||||
if zmq_port is not None:
|
||||
params.zmq_port = zmq_port # C++ code won't use this if recompute_embeddings is False
|
||||
params.efSearch = complexity
|
||||
params.beam_size = beam_width
|
||||
|
||||
# For OpenAI embeddings with cosine distance, disable relative distance check
|
||||
# This prevents early termination when all scores are in a narrow range
|
||||
embedding_model = self.meta.get("embedding_model", "").lower()
|
||||
if self.distance_metric == "cosine" and any(
|
||||
openai_model in embedding_model for openai_model in ["text-embedding", "openai"]
|
||||
):
|
||||
params.check_relative_distance = False
|
||||
else:
|
||||
params.check_relative_distance = True
|
||||
|
||||
# PQ pruning: direct mapping to HNSW's pq_pruning_ratio
|
||||
params.pq_pruning_ratio = prune_ratio
|
||||
|
||||
@@ -215,9 +240,7 @@ class HNSWSearcher(BaseSearcher):
|
||||
params.send_neigh_times_ratio = 0.0
|
||||
elif pruning_strategy == "proportional":
|
||||
params.local_prune = False
|
||||
params.send_neigh_times_ratio = (
|
||||
1.0 # Any value > 1e-6 triggers proportional mode
|
||||
)
|
||||
params.send_neigh_times_ratio = 1.0 # Any value > 1e-6 triggers proportional mode
|
||||
else: # "global"
|
||||
params.local_prune = False
|
||||
params.send_neigh_times_ratio = 0.0
|
||||
@@ -229,6 +252,7 @@ class HNSWSearcher(BaseSearcher):
|
||||
distances = np.empty((batch_size_query, top_k), dtype=np.float32)
|
||||
labels = np.empty((batch_size_query, top_k), dtype=np.int64)
|
||||
|
||||
search_time = time.time()
|
||||
self._index.search(
|
||||
query.shape[0],
|
||||
faiss.swig_ptr(query),
|
||||
@@ -237,13 +261,60 @@ class HNSWSearcher(BaseSearcher):
|
||||
faiss.swig_ptr(labels),
|
||||
params,
|
||||
)
|
||||
|
||||
string_labels = [
|
||||
[
|
||||
self.label_map.get(int_label, f"unknown_{int_label}")
|
||||
for int_label in batch_labels
|
||||
]
|
||||
for batch_labels in labels
|
||||
]
|
||||
search_time = time.time() - search_time
|
||||
logger.info(f" Search time in HNSWSearcher.search() backend: {search_time} seconds")
|
||||
string_labels = [[str(int_label) for int_label in batch_labels] for batch_labels in labels]
|
||||
|
||||
return {"labels": string_labels, "distances": distances}
|
||||
|
||||
|
||||
# ---------- Helper API for incremental add (Python-level) ----------
|
||||
def add_vectors(
|
||||
index_file_path: str,
|
||||
embeddings: np.ndarray,
|
||||
*,
|
||||
ef_construction: Optional[int] = None,
|
||||
recompute: bool = False,
|
||||
) -> None:
|
||||
"""Append vectors to an existing non-compact HNSW index.
|
||||
|
||||
Args:
|
||||
index_file_path: Path to the HNSW .index file
|
||||
embeddings: float32 numpy array (N, D)
|
||||
ef_construction: Optional override for efConstruction during insertion
|
||||
recompute: Reserved for future use to control insertion-time recompute behaviors
|
||||
"""
|
||||
from . import faiss # type: ignore
|
||||
|
||||
if embeddings.dtype != np.float32:
|
||||
embeddings = embeddings.astype(np.float32)
|
||||
if not embeddings.flags.c_contiguous:
|
||||
embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
|
||||
|
||||
# Load index normally to ensure storage is present; toggle is_recompute on the object
|
||||
index = faiss.read_index(str(index_file_path), faiss.IO_FLAG_MMAP)
|
||||
|
||||
# Best-effort: explicitly set flag on the object if the binding exposes it
|
||||
try:
|
||||
index.is_recompute = bool(recompute)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
if ef_construction is not None:
|
||||
index.hnsw.efConstruction = int(ef_construction)
|
||||
except Exception:
|
||||
# Best-effort; ignore if backend doesn't expose setter
|
||||
pass
|
||||
|
||||
# For non-compact HNSW, calling add directly is sufficient. When is_recompute is set
|
||||
# (via config or attribute), FAISS will run the insertion/search path accordingly.
|
||||
# To strictly follow per-point insert semantics in recompute mode, add one-by-one.
|
||||
if recompute:
|
||||
# Insert row by row
|
||||
n = embeddings.shape[0]
|
||||
for i in range(n):
|
||||
row = embeddings[i : i + 1]
|
||||
index.add(1, faiss.swig_ptr(row))
|
||||
else:
|
||||
index.add(embeddings.shape[0], faiss.swig_ptr(embeddings))
|
||||
faiss.write_index(index, str(index_file_path))
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
149
packages/leann-backend-hnsw/leann_backend_hnsw/prune_index.py
Normal file
149
packages/leann-backend-hnsw/leann_backend_hnsw/prune_index.py
Normal file
@@ -0,0 +1,149 @@
|
||||
import os
|
||||
import struct
|
||||
from pathlib import Path
|
||||
|
||||
from .convert_to_csr import (
|
||||
EXPECTED_HNSW_FOURCCS,
|
||||
NULL_INDEX_FOURCC,
|
||||
read_struct,
|
||||
read_vector_raw,
|
||||
)
|
||||
|
||||
|
||||
def _write_vector_raw(f_out, count: int, data_bytes: bytes) -> None:
|
||||
"""Write a vector in the same binary layout as read_vector_raw reads: <Q count> + raw bytes."""
|
||||
f_out.write(struct.pack("<Q", count))
|
||||
if count > 0 and data_bytes:
|
||||
f_out.write(data_bytes)
|
||||
|
||||
|
||||
def prune_embeddings_preserve_graph(input_filename: str, output_filename: str) -> bool:
|
||||
"""
|
||||
Copy an original (non-compact) HNSW index file while pruning the trailing embedding storage.
|
||||
Preserves the graph structure and metadata exactly; only writes a NULL storage marker instead of
|
||||
the original storage fourcc and payload.
|
||||
|
||||
Returns True on success.
|
||||
"""
|
||||
print(f"Pruning embeddings from {input_filename} to {output_filename}")
|
||||
print("--------------------------------")
|
||||
# running in mode is-recompute=True and is-compact=False
|
||||
in_path = Path(input_filename)
|
||||
out_path = Path(output_filename)
|
||||
|
||||
try:
|
||||
with open(in_path, "rb") as f_in, open(out_path, "wb") as f_out:
|
||||
# Header
|
||||
index_fourcc = read_struct(f_in, "<I")
|
||||
if index_fourcc not in EXPECTED_HNSW_FOURCCS:
|
||||
# Still proceed, but this is unexpected
|
||||
pass
|
||||
f_out.write(struct.pack("<I", index_fourcc))
|
||||
|
||||
d = read_struct(f_in, "<i")
|
||||
ntotal_hdr = read_struct(f_in, "<q")
|
||||
dummy1 = read_struct(f_in, "<q")
|
||||
dummy2 = read_struct(f_in, "<q")
|
||||
is_trained = read_struct(f_in, "?")
|
||||
metric_type = read_struct(f_in, "<i")
|
||||
f_out.write(struct.pack("<i", d))
|
||||
f_out.write(struct.pack("<q", ntotal_hdr))
|
||||
f_out.write(struct.pack("<q", dummy1))
|
||||
f_out.write(struct.pack("<q", dummy2))
|
||||
f_out.write(struct.pack("<?", is_trained))
|
||||
f_out.write(struct.pack("<i", metric_type))
|
||||
|
||||
if metric_type > 1:
|
||||
metric_arg = read_struct(f_in, "<f")
|
||||
f_out.write(struct.pack("<f", metric_arg))
|
||||
|
||||
# Vectors: assign_probas (double), cum_nneighbor_per_level (int32), levels (int32)
|
||||
cnt, data = read_vector_raw(f_in, "d")
|
||||
_write_vector_raw(f_out, cnt, data)
|
||||
|
||||
cnt, data = read_vector_raw(f_in, "i")
|
||||
_write_vector_raw(f_out, cnt, data)
|
||||
|
||||
cnt, data = read_vector_raw(f_in, "i")
|
||||
_write_vector_raw(f_out, cnt, data)
|
||||
|
||||
# Probe potential extra alignment/flag byte present in some original formats
|
||||
probe = f_in.read(1)
|
||||
if probe:
|
||||
if probe == b"\x00":
|
||||
# Preserve this unexpected 0x00 byte
|
||||
f_out.write(probe)
|
||||
else:
|
||||
# Likely part of the next vector; rewind
|
||||
f_in.seek(-1, os.SEEK_CUR)
|
||||
|
||||
# Offsets (uint64) and neighbors (int32)
|
||||
cnt, data = read_vector_raw(f_in, "Q")
|
||||
_write_vector_raw(f_out, cnt, data)
|
||||
|
||||
cnt, data = read_vector_raw(f_in, "i")
|
||||
_write_vector_raw(f_out, cnt, data)
|
||||
|
||||
# Scalar params
|
||||
entry_point = read_struct(f_in, "<i")
|
||||
max_level = read_struct(f_in, "<i")
|
||||
ef_construction = read_struct(f_in, "<i")
|
||||
ef_search = read_struct(f_in, "<i")
|
||||
dummy_upper_beam = read_struct(f_in, "<i")
|
||||
f_out.write(struct.pack("<i", entry_point))
|
||||
f_out.write(struct.pack("<i", max_level))
|
||||
f_out.write(struct.pack("<i", ef_construction))
|
||||
f_out.write(struct.pack("<i", ef_search))
|
||||
f_out.write(struct.pack("<i", dummy_upper_beam))
|
||||
|
||||
# Storage fourcc (if present) — write NULL marker and drop any remaining data
|
||||
try:
|
||||
read_struct(f_in, "<I")
|
||||
# Regardless of original, write NULL
|
||||
f_out.write(struct.pack("<I", NULL_INDEX_FOURCC))
|
||||
# Discard the rest of the file (embedding payload)
|
||||
# (Do not copy anything else)
|
||||
except EOFError:
|
||||
# No storage section; nothing else to write
|
||||
pass
|
||||
|
||||
return True
|
||||
except Exception:
|
||||
# Best-effort cleanup
|
||||
try:
|
||||
if out_path.exists():
|
||||
out_path.unlink()
|
||||
except OSError:
|
||||
pass
|
||||
return False
|
||||
|
||||
|
||||
def prune_embeddings_preserve_graph_inplace(index_file_path: str) -> bool:
|
||||
"""
|
||||
Convenience wrapper: write pruned file to a temporary path next to the
|
||||
original, then atomically replace on success.
|
||||
"""
|
||||
print(f"Pruning embeddings from {index_file_path} to {index_file_path}")
|
||||
print("--------------------------------")
|
||||
# running in mode is-recompute=True and is-compact=False
|
||||
src = Path(index_file_path)
|
||||
tmp = src.with_suffix(".pruned.tmp")
|
||||
ok = prune_embeddings_preserve_graph(str(src), str(tmp))
|
||||
if not ok:
|
||||
if tmp.exists():
|
||||
try:
|
||||
tmp.unlink()
|
||||
except OSError:
|
||||
pass
|
||||
return False
|
||||
try:
|
||||
os.replace(str(tmp), str(src))
|
||||
except Exception:
|
||||
# Rollback on failure
|
||||
try:
|
||||
if tmp.exists():
|
||||
tmp.unlink()
|
||||
except OSError:
|
||||
pass
|
||||
return False
|
||||
return True
|
||||
@@ -6,13 +6,24 @@ build-backend = "scikit_build_core.build"
|
||||
|
||||
[project]
|
||||
name = "leann-backend-hnsw"
|
||||
version = "0.1.0"
|
||||
version = "0.3.4"
|
||||
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
|
||||
dependencies = ["leann-core==0.1.0", "numpy"]
|
||||
dependencies = [
|
||||
"leann-core==0.3.4",
|
||||
"numpy",
|
||||
"pyzmq>=23.0.0",
|
||||
"msgpack>=1.0.0",
|
||||
]
|
||||
|
||||
[tool.scikit-build]
|
||||
wheel.packages = ["leann_backend_hnsw"]
|
||||
editable.mode = "redirect"
|
||||
cmake.build-type = "Release"
|
||||
build.verbose = true
|
||||
build.tool-args = ["-j8"]
|
||||
build.tool-args = ["-j8"]
|
||||
|
||||
# CMake definitions to optimize compilation and find Homebrew packages
|
||||
[tool.scikit-build.cmake.define]
|
||||
CMAKE_BUILD_PARALLEL_LEVEL = "8"
|
||||
CMAKE_PREFIX_PATH = {env = "CMAKE_PREFIX_PATH"}
|
||||
OpenMP_ROOT = {env = "OpenMP_ROOT"}
|
||||
|
||||
Submodule packages/leann-backend-hnsw/third_party/faiss updated: 2547df4377...ea86d06ceb
Submodule packages/leann-backend-hnsw/third_party/msgpack-c updated: 9b801f087a...a0b2ec09da
@@ -4,16 +4,49 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "leann-core"
|
||||
version = "0.1.0"
|
||||
description = "Core API and plugin system for Leann."
|
||||
version = "0.3.4"
|
||||
description = "Core API and plugin system for LEANN"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.9"
|
||||
license = { text = "MIT" }
|
||||
|
||||
# All required dependencies included
|
||||
dependencies = [
|
||||
"numpy>=1.20.0",
|
||||
"tqdm>=4.60.0"
|
||||
"tqdm>=4.60.0",
|
||||
"psutil>=5.8.0",
|
||||
"pyzmq>=23.0.0",
|
||||
"msgpack>=1.0.0",
|
||||
"torch>=2.0.0",
|
||||
"sentence-transformers>=2.2.0",
|
||||
"llama-index-core>=0.12.0",
|
||||
"llama-index-readers-file>=0.4.0", # Essential for document reading
|
||||
"llama-index-embeddings-huggingface>=0.5.5", # For embeddings
|
||||
"python-dotenv>=1.0.0",
|
||||
"openai>=1.0.0",
|
||||
"huggingface-hub>=0.20.0",
|
||||
"transformers>=4.30.0",
|
||||
"requests>=2.25.0",
|
||||
"accelerate>=0.20.0",
|
||||
"PyPDF2>=3.0.0",
|
||||
"pymupdf>=1.23.0",
|
||||
"pdfplumber>=0.10.0",
|
||||
"nbconvert>=7.0.0", # For .ipynb file support
|
||||
"gitignore-parser>=0.1.12", # For proper .gitignore handling
|
||||
"mlx>=0.26.3; sys_platform == 'darwin' and platform_machine == 'arm64'",
|
||||
"mlx-lm>=0.26.0; sys_platform == 'darwin' and platform_machine == 'arm64'",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
colab = [
|
||||
"torch>=2.0.0,<3.0.0", # Limit torch version to avoid conflicts
|
||||
"transformers>=4.30.0,<5.0.0", # Limit transformers version
|
||||
"accelerate>=0.20.0,<1.0.0", # Limit accelerate version
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
leann = "leann.cli:main"
|
||||
leann_mcp = "leann.mcp:main"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
where = ["src"]
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user