Compare commits

...

118 Commits

Author SHA1 Message Date
Andy Lee
d8b4ea7564 fix: add write permissions for GitHub Actions to push commits 2025-07-24 20:55:24 -07:00
Andy Lee
f0a2ef96b4 fix: restore complete build configuration from working version 2025-07-24 19:49:38 -07:00
Andy Lee
7d73c2c803 fix: remove invalid --extra build flag from build commands 2025-07-24 19:43:23 -07:00
Andy Lee
e8d2ecab03 refactor: use reusable workflow to avoid code duplication 2025-07-24 19:35:12 -07:00
Andy Lee
32a374d094 feat: true one-click automated release with multi-platform support 2025-07-24 19:30:44 -07:00
Andy Lee
d45c013806 fix: handle workflow trigger permission gracefully 2025-07-24 19:25:29 -07:00
GitHub Actions
9000a7083d chore: release v0.1.4 2025-07-25 02:23:36 +00:00
Andy Lee
8307555d54 fix: manually trigger CI after version push in release workflow 2025-07-24 19:21:32 -07:00
GitHub Actions
20f2aece08 chore: release v0.1.3 2025-07-25 02:05:11 +00:00
yichuan520030910320
43eb4f9a1d Merge branch 'main' of https://github.com/yichuan-w/LEANN 2025-07-24 19:03:52 -07:00
yichuan520030910320
5461b71d8c colab dev 2025-07-24 19:03:46 -07:00
Andy Lee
374db0ebb8 fix: release workflow to build new version before publishing 2025-07-24 19:03:09 -07:00
GitHub Actions
cea1f6f87c chore: release v0.1.2 2025-07-25 01:53:29 +00:00
Andy Lee
6c0e39372b fix: download all artifacts in release workflow 2025-07-24 17:45:46 -07:00
Andy Lee
2bec67d2b6 feat: auto-update leann-core dependencies during release
- Enhanced bump_version.sh to automatically update leann-core dependency versions
- Script now updates both package versions and their leann-core dependencies
- This ensures version consistency across all packages during release

No more manual dependency version updates needed!
2025-07-24 17:22:41 -07:00
Andy Lee
133e715832 fix: resolve CI issues and consolidate workflows
- Fix version dependencies: update backend packages to depend on leann-core==0.1.1
- Remove duplicate ci.yml workflow (keeping build-and-publish.yml as main CI)
- Update release-manual.yml to reference correct CI workflow name

This fixes the dependency resolution error and eliminates duplicate builds.
2025-07-24 17:20:58 -07:00
Andy Lee
95cf2f16e2 refactor: consolidate release and publish into single workflow
- Manual Release workflow now directly publishes to PyPI after downloading CI artifacts
- No more duplicate builds - reuses artifacts from CI
- build-and-publish.yml renamed to 'CI - Build Multi-Platform Packages'
- Publishing in CI workflow only for emergency manual triggers
- Updated RELEASE.md to reflect the new streamlined process

This fixes the issue where releases would trigger redundant builds.
2025-07-24 17:04:47 -07:00
Andy Lee
47a4c153eb fix: enable PyPI publish on tag push
- Manual Release workflow creates tags but build-and-publish.yml only published on 'release' events
- Now build-and-publish.yml will also publish when v* tags are pushed
- This fixes the issue where manual releases didn't trigger PyPI uploads
2025-07-24 17:00:21 -07:00
GitHub Actions
faf5ae3533 chore: release v0.1.1 2025-07-24 23:36:23 +00:00
Andy Lee
a44dccecac fix: make TestPyPI upload optional and non-blocking
- Add continue-on-error to TestPyPI step
- Check if TEST_PYPI_API_TOKEN exists before attempting upload
- Add graceful failure handling with clear messages
- Update docs to explain TestPyPI token configuration
- Clarify that TestPyPI testing is optional

Now the release won't fail if TestPyPI is not configured or upload fails
2025-07-24 16:02:07 -07:00
yichuan520030910320
9cf9358b9c Merge branch 'main' of https://github.com/yichuan-w/LEANN 2025-07-24 14:40:39 -07:00
yichuan520030910320
de252fef31 [chat] update 30s example 2025-07-24 14:40:33 -07:00
Andy Lee
9076bc27b8 fix: resolve CI run detection issues in release workflow
- Add 'actions: read' permission to access workflow runs
- Use workflow name instead of filename for gh run list
- Look for CI run on HEAD~1 (before version bump commit)
- Improve error messages for better debugging

Fixes HTTP 403 error when trying to find successful CI runs
2025-07-24 14:27:26 -07:00
Andy Lee
50686c0819 refactor: use CI artifacts in release workflow instead of rebuilding
- Download pre-built wheels from successful CI runs
- Avoids duplicate builds and ensures consistency
- CI artifacts are already tested across all platforms
- Faster release process (no build time)
- Updates release documentation to reflect new flow

This ensures the released packages are exactly what was tested in CI.
2025-07-24 14:24:03 -07:00
Andy Lee
1614203786 fix: make bump_version.sh work on both macOS and Linux
- macOS uses sed -i '' while Linux uses sed -i
- Add OS detection to use correct syntax
- Ensures script works in CI (Linux) and local dev (macOS)
2025-07-24 14:13:31 -07:00
Andy Lee
3d4c75a56c fix: add missing scripts directory to git
- Remove scripts/ from .gitignore
- Add build_and_test.sh for local testing
- Add bump_version.sh for version updates (used by CI)
- Add release.sh and upload_to_pypi.sh for publishing
- Fixes CI error: ./scripts/bump_version.sh: No such file or directory
2025-07-24 14:13:14 -07:00
Andy Lee
2684ee71dc fix: ensure uv build uses correct Python version in CI
- Add --python python flag to uv build commands
- This ensures wheels are built with the correct Python version (cp313 for Python 3.13, etc)
- Fixes issue where Python 3.13 CI was building cp311 wheels
- Also adds Python version verification before build
2025-07-24 13:44:02 -07:00
Andy Lee
1d321953ba ci: update all GitHub Actions to latest versions
- Update actions/upload-artifact from v3 to v4 (v3 deprecated April 2024)
- Update actions/setup-python from v4 to v5 (latest version)
- Add Python 3.12 and 3.13 to CI test matrix
- Ensure compatibility with latest Python versions and GitHub Actions
2025-07-24 13:36:21 -07:00
Andy Lee
b3cb251369 ci: add Python 3.12 and 3.13 to test matrix
- Add Python 3.12 and 3.13 to CI test matrix
- Ensure compatibility with latest Python versions
- Python 3.12 is stable, 3.13 was released in Oct 2024
2025-07-24 13:32:29 -07:00
Andy Lee
0a17d2c9d8 feat: implement comprehensive CI/CD pipeline with two-stage release
- Add ci.yml for continuous integration on every commit
  - Test builds on Ubuntu/macOS with Python 3.9/3.10/3.11
  - Ensure code quality before any release

- Add release-manual.yml for controlled releases
  - Manual trigger prevents accidental releases
  - Version validation and tag creation
  - Optional TestPyPI testing before production
  - Only creates tag after validation passes

- Keep build-and-publish.yml for automated PyPI deployment
  - Triggered by new tags (separation of concerns)
  - Handles multi-platform wheel building
  - Allows retry if PyPI upload fails

- Update RELEASE.md with clear prerequisites and workflow

This setup ensures:
1. Every commit is tested (CI)
2. Releases are deliberate (manual trigger)
3. Failed CI won't create broken tags
4. PyPI publish can be retried independently
2025-07-24 13:29:21 -07:00
Andy Lee
e3defbca84 fix: add minimal CI dependencies for HNSW and DiskANN backends
- HNSW (Ubuntu): add libopenblas-dev for BLAS requirements
- DiskANN (Ubuntu): keep MKL, remove redundant pkg-config (HNSW already has it)
- DiskANN (macOS): add protobuf for build requirements
- Both: ensure patchelf for auditwheel on Linux

This avoids OpenBLAS/MKL conflicts by using them in separate jobs
2025-07-24 01:06:57 -07:00
Andy Lee
e407f63977 chore: fix uv build 2025-07-24 00:51:57 -07:00
Andy Lee
7add391b2c chore: build and package 2025-07-24 00:47:46 -07:00
yichuan520030910320
efd6373b32 [chat] update huggingface chat and make qwen no thinking 2025-07-24 00:11:42 -07:00
yichuan520030910320
d502fa24b0 [installation] update install for linux 2025-07-24 02:17:17 +00:00
yichuan520030910320
258a9a5c7f [misc]test link again 2025-07-23 18:29:32 -07:00
yichuan520030910320
5d41ac6115 test link 2025-07-23 18:28:22 -07:00
yichuan520030910320
2a0fdb49b8 test link 2025-07-23 18:27:08 -07:00
yichuan520030910320
9d1b7231b6 fix broken link 2025-07-23 18:25:22 -07:00
yichuan520030910320
ed3095b478 fix broken link 2025-07-23 18:24:17 -07:00
yichuan520030910320
88eca75917 fix readme 2025-07-23 18:22:10 -07:00
yichuan520030910320
42de27e16a Merge branch 'main' of https://github.com/yichuan-w/LEANN 2025-07-23 18:17:19 -07:00
yichuan520030910320
c083bda5b7 fix several bug 2025-07-23 18:17:11 -07:00
Andy Lee
e86da38726 fix: ollama hint for similar models 2025-07-23 15:45:10 -07:00
yichuan520030910320
99076e38bc update install 2025-07-23 14:55:34 -07:00
yichuan520030910320
9698c1a02c fix readme 2025-07-23 14:52:01 -07:00
yichuan520030910320
851f0f04c3 fix some para 2025-07-23 01:46:34 -07:00
yichuan520030910320
ae16d9d888 fix readme 2025-07-23 00:44:43 -07:00
yichuan520030910320
6e1af2eb0c fix readme 2025-07-23 00:43:46 -07:00
yichuan520030910320
7695dd0d50 fix readme 2025-07-23 00:42:17 -07:00
yichuan520030910320
c2065473ad fix readme 2025-07-23 00:30:42 -07:00
yichuan520030910320
5f3870564d Merge branch 'main' of https://github.com/yichuan-w/LEANN 2025-07-23 00:09:30 -07:00
yichuan520030910320
c214b2e33e fix readme 2025-07-23 00:09:24 -07:00
Andy Lee
2420c5fd35 chore: update sentence-transformer to prevent MixIn not found error 2025-07-22 23:27:25 -07:00
yichuan520030910320
f48f526f0a fix readme 2025-07-22 23:21:15 -07:00
yichuan520030910320
5dd74982ba fix readme 2025-07-22 23:14:31 -07:00
Andy Lee
e07aaf52a7 docs: align 2025-07-22 22:37:27 -07:00
Andy Lee
30e5f12616 docs: quick start 2025-07-22 22:33:04 -07:00
Andy Lee
594427bf87 docs: demo 2025-07-22 22:32:18 -07:00
yichuan520030910320
a97d3ada1c fix readme need to polish example 2025-07-22 22:09:55 -07:00
yichuan520030910320
6217bb5638 fix readme 2025-07-22 22:05:28 -07:00
yichuan520030910320
2760e99e18 fix readme 2025-07-22 22:03:19 -07:00
yichuan520030910320
0544f96b79 default main cli to openai add data dict as a args 2025-07-22 21:56:30 -07:00
yichuan520030910320
2ebb29de65 default main cli to openai 2025-07-22 21:55:18 -07:00
yichuan520030910320
43762d44c7 fix readme 2025-07-22 21:51:30 -07:00
yichuan520030910320
cdaf0c98be fix readme 2025-07-22 21:44:52 -07:00
yichuan520030910320
aa9a14a917 make the email wonderful format 2025-07-22 21:41:58 -07:00
yichuan520030910320
9efcc6d95c Merge branch 'main' of https://github.com/yichuan-w/LEANN 2025-07-22 20:44:02 -07:00
yichuan520030910320
f3f5d91207 make the google history wonderful format 2025-07-22 20:43:56 -07:00
Andy Lee
6070160959 chore: remove .vscode 2025-07-22 19:59:35 -07:00
Andy Lee
43155d2811 fix: supress resources leak logs 2025-07-22 19:53:45 -07:00
Andy Lee
d3f85678ec perf: much faster loading and embedding serving 2025-07-22 19:38:22 -07:00
yichuan520030910320
2a96d05b21 upd readme 2025-07-22 17:06:33 -07:00
yichuan520030910320
851e888535 Merge branch 'main' of https://github.com/yichuan-w/LEANN 2025-07-22 17:01:04 -07:00
yichuan520030910320
90120d4dff upd the structure in the chat for better perf 2025-07-22 17:00:56 -07:00
Andy Lee
8513471573 feat: make diskann runnable 2025-07-22 14:26:03 -07:00
Andy Lee
71e5f1774c docs: cli 2025-07-21 23:48:40 -07:00
yichuan520030910320
870a443446 Merge branch 'main' of https://github.com/yichuan-w/LEANN 2025-07-21 23:13:45 -07:00
yichuan520030910320
cefaa2a4cc upd readme 2025-07-21 23:13:38 -07:00
Andy Lee
ab72a2ab9d fix: more logs 2025-07-21 23:08:53 -07:00
yichuan520030910320
046d457d22 Merge branch 'main' of https://github.com/yichuan-w/LEANN 2025-07-21 23:04:00 -07:00
yichuan520030910320
7fd0a30fee upd log 2025-07-21 23:03:52 -07:00
Andy Lee
c2f35c8e73 fix: logs 2025-07-21 23:02:13 -07:00
Andy Lee
573313f0b6 refactor: logs 2025-07-21 22:45:24 -07:00
yichuan520030910320
f7af6805fa readme 2025-07-21 22:33:03 -07:00
yichuan520030910320
966de3a399 readme 2025-07-21 22:32:02 -07:00
yichuan520030910320
8a75829f3a readme 2025-07-21 22:30:03 -07:00
yichuan520030910320
0f7e34b9e2 readme 2025-07-21 22:18:00 -07:00
yichuan520030910320
be0322b616 readme 2025-07-21 22:16:52 -07:00
yichuan520030910320
232a525a62 readme 2025-07-21 22:14:43 -07:00
yichuan520030910320
587ce65cf6 Merge branch 'main' of https://github.com/yichuan-w/LEANN 2025-07-21 21:54:27 -07:00
yichuan520030910320
ccf6c8bfd7 fix flush print 2025-07-21 21:54:20 -07:00
Andy Lee
c112956d2d fix: mlx 2025-07-21 21:29:15 -07:00
Andy Lee
b3970793cf fix: cache the loaded model 2025-07-21 21:20:53 -07:00
yichuan520030910320
727724990e add todo 2025-07-21 20:59:09 -07:00
yichuan520030910320
530f6e4af5 add progress bar in build 2025-07-21 20:55:18 -07:00
Andy Lee
2f224f5793 fix: use server to emb query only when recompute 2025-07-21 20:40:21 -07:00
Andy Lee
1b6272ce0e Building, CLI tool & Embedding Server Fixed (#5)
* chore: shorter build time

* chore: update faiss

* fix: no longger do embedding server reuse

* fix: do not reuse emb_server and close it properly

* feat: cli tool

* feat: cli more args

* fix: same embedding logic
2025-07-21 20:17:25 -07:00
yichuan520030910320
5259ace111 [Readme] 2025-07-21 20:06:21 -07:00
yichuan520030910320
48ea5566e9 [Readme] detail number 2025-07-21 19:51:51 -07:00
yichuan520030910320
3f8b6c5bbd [Readme] 2025-07-21 18:15:00 -07:00
yichuan520030910320
725b32e74f [Readme] 2025-07-21 17:57:44 -07:00
yichuan520030910320
d0b71f393f [Readme] 2025-07-21 17:56:10 -07:00
yichuan520030910320
8a92efdae3 [Readme] 2025-07-21 17:53:59 -07:00
yichuan520030910320
019cdce2e8 [Readme] 2025-07-21 17:30:11 -07:00
yichuan520030910320
b64aa54fac fix break link 2025-07-21 17:29:35 -07:00
yichuan520030910320
c0d040f9d4 Merge branch 'main' of https://github.com/yichuan-w/LEANN 2025-07-21 16:22:24 -07:00
yichuan520030910320
32364320f8 update wechat and we should fix the bug introduced in 1c5fec5 2025-07-21 16:22:16 -07:00
Andy Lee
34c71c072d chore: parallel compile fix 2025-07-19 22:51:47 -07:00
Andy Lee
6d2149c503 chore: parallel compile fix 2025-07-19 22:46:24 -07:00
Andy Lee
043b0bf69d chore: parallel compile fix 2025-07-19 22:34:19 -07:00
Andy Lee
9b07e392c6 chore: parallel compile 2025-07-19 22:32:13 -07:00
Andy Lee
e60fad8c73 chore: mark diskann as optional 2025-07-19 22:24:44 -07:00
Andy Lee
19c1b182c3 docs: effects figure 2025-07-19 22:07:04 -07:00
Andy Lee
49edea780c docs: figure 2025-07-19 21:59:58 -07:00
Andy Lee
12ef5a1900 docs: effects 2025-07-19 21:57:12 -07:00
Andy Lee
d21a134b2a docs: polish 2025-07-19 21:53:41 -07:00
Andy Lee
1cd809aa41 [Docs] README polished version (#4)
* docs: polish

* docs: logo

* docs: logo

* docs: logo with text

* docs: readme effects

* docs: polish

* docs: highlight applications

* docs: polish

* docs: how it works earlier

* docs: polish

* docs: polish

* docs: follow yichuan's suggestion

* docs: follow yichuan's suggestion

---------

Co-authored-by: Yichuan Wang <73766326+yichuan-w@users.noreply.github.com>
2025-07-19 21:47:25 -07:00
54 changed files with 4308 additions and 4385 deletions

11
.github/workflows/build-and-publish.yml vendored Normal file
View File

@@ -0,0 +1,11 @@
name: CI
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
jobs:
build:
uses: ./.github/workflows/build-reusable.yml

163
.github/workflows/build-reusable.yml vendored Normal file
View File

@@ -0,0 +1,163 @@
name: Reusable Build
on:
workflow_call:
inputs:
ref:
description: 'Git ref to build'
required: false
type: string
default: ''
jobs:
build:
name: Build ${{ matrix.os }} Python ${{ matrix.python }}
strategy:
matrix:
include:
- os: ubuntu-latest
python: '3.9'
- os: ubuntu-latest
python: '3.10'
- os: ubuntu-latest
python: '3.11'
- os: ubuntu-latest
python: '3.12'
- os: macos-latest
python: '3.9'
- os: macos-latest
python: '3.10'
- os: macos-latest
python: '3.11'
- os: macos-latest
python: '3.12'
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v4
with:
ref: ${{ inputs.ref }}
submodules: recursive
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python }}
- name: Install uv
uses: astral-sh/setup-uv@v4
- name: Install system dependencies (Ubuntu)
if: runner.os == 'Linux'
run: |
sudo apt-get update
sudo apt-get install -y libomp-dev libboost-all-dev protobuf-compiler libzmq3-dev \
pkg-config libopenblas-dev patchelf libabsl-dev libaio-dev libprotobuf-dev
# Install Intel MKL for DiskANN
wget -q https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh
sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
source /opt/intel/oneapi/setvars.sh
echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV
echo "LD_LIBRARY_PATH=/opt/intel/oneapi/mkl/latest/lib/intel64:$LD_LIBRARY_PATH" >> $GITHUB_ENV
- name: Install system dependencies (macOS)
if: runner.os == 'macOS'
run: |
brew install llvm libomp boost protobuf zeromq
- name: Install build dependencies
run: |
uv pip install --system scikit-build-core numpy swig Cython pybind11
if [[ "$RUNNER_OS" == "Linux" ]]; then
uv pip install --system auditwheel
else
uv pip install --system delocate
fi
- name: Build packages
run: |
# Build core (platform independent)
if [ "${{ matrix.os }}" == "ubuntu-latest" ]; then
cd packages/leann-core
uv build
cd ../..
fi
# Build HNSW backend
cd packages/leann-backend-hnsw
if [ "${{ matrix.os }}" == "macos-latest" ]; then
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv build --wheel --python python
else
uv build --wheel --python python
fi
cd ../..
# Build DiskANN backend
cd packages/leann-backend-diskann
if [ "${{ matrix.os }}" == "macos-latest" ]; then
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv build --wheel --python python
else
uv build --wheel --python python
fi
cd ../..
# Build meta package (platform independent)
if [ "${{ matrix.os }}" == "ubuntu-latest" ]; then
cd packages/leann
uv build
cd ../..
fi
- name: Repair wheels (Linux)
if: runner.os == 'Linux'
run: |
# Repair HNSW wheel
cd packages/leann-backend-hnsw
if [ -d dist ]; then
auditwheel repair dist/*.whl -w dist_repaired
rm -rf dist
mv dist_repaired dist
fi
cd ../..
# Repair DiskANN wheel
cd packages/leann-backend-diskann
if [ -d dist ]; then
auditwheel repair dist/*.whl -w dist_repaired
rm -rf dist
mv dist_repaired dist
fi
cd ../..
- name: Repair wheels (macOS)
if: runner.os == 'macOS'
run: |
# Repair HNSW wheel
cd packages/leann-backend-hnsw
if [ -d dist ]; then
delocate-wheel -w dist_repaired -v dist/*.whl
rm -rf dist
mv dist_repaired dist
fi
cd ../..
# Repair DiskANN wheel
cd packages/leann-backend-diskann
if [ -d dist ]; then
delocate-wheel -w dist_repaired -v dist/*.whl
rm -rf dist
mv dist_repaired dist
fi
cd ../..
- name: List built packages
run: |
echo "📦 Built packages:"
find packages/*/dist -name "*.whl" -o -name "*.tar.gz" | sort
- name: Upload artifacts
uses: actions/upload-artifact@v4
with:
name: packages-${{ matrix.os }}-py${{ matrix.python }}
path: packages/*/dist/

103
.github/workflows/release-manual.yml vendored Normal file
View File

@@ -0,0 +1,103 @@
name: Release
on:
workflow_dispatch:
inputs:
version:
description: 'Version to release (e.g., 0.1.2)'
required: true
type: string
jobs:
update-version:
name: Update Version
runs-on: ubuntu-latest
permissions:
contents: write
outputs:
commit-sha: ${{ steps.push.outputs.commit-sha }}
steps:
- uses: actions/checkout@v4
- name: Validate version
run: |
if ! [[ "${{ inputs.version }}" =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
echo "❌ Invalid version format"
exit 1
fi
echo "✅ Version format valid"
- name: Update versions and push
id: push
run: |
./scripts/bump_version.sh ${{ inputs.version }}
git config user.name "GitHub Actions"
git config user.email "actions@github.com"
git add packages/*/pyproject.toml
git commit -m "chore: release v${{ inputs.version }}"
git push origin main
COMMIT_SHA=$(git rev-parse HEAD)
echo "commit-sha=$COMMIT_SHA" >> $GITHUB_OUTPUT
echo "✅ Pushed version update: $COMMIT_SHA"
build-packages:
name: Build packages
needs: update-version
uses: ./.github/workflows/build-reusable.yml
with:
ref: ${{ needs.update-version.outputs.commit-sha }}
publish:
name: Publish and Release
needs: build-packages
runs-on: ubuntu-latest
permissions:
contents: write
steps:
- uses: actions/checkout@v4
with:
ref: ${{ needs.update-version.outputs.commit-sha }}
- name: Download all artifacts
uses: actions/download-artifact@v4
with:
path: dist-artifacts
- name: Collect packages
run: |
mkdir -p dist
find dist-artifacts -name "*.whl" -exec cp {} dist/ \;
find dist-artifacts -name "*.tar.gz" -exec cp {} dist/ \;
echo "📦 Packages to publish:"
ls -la dist/
- name: Publish to PyPI
env:
TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
run: |
if [ -z "$TWINE_PASSWORD" ]; then
echo "❌ PYPI_API_TOKEN not configured!"
exit 1
fi
pip install twine
twine upload dist/* --skip-existing --verbose
echo "✅ Published to PyPI!"
- name: Create release
run: |
git tag "v${{ inputs.version }}"
git push origin "v${{ inputs.version }}"
gh release create "v${{ inputs.version }}" \
--title "Release v${{ inputs.version }}" \
--notes "🚀 Released to PyPI: https://pypi.org/project/leann/${{ inputs.version }}/" \
--latest
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}

5
.gitignore vendored
View File

@@ -12,7 +12,6 @@ outputs/
*.idx
*.map
.history/
scripts/
lm_eval.egg-info/
demo/experiment_results/**/*.json
*.jsonl
@@ -84,4 +83,6 @@ test_*.py
packages/leann-backend-diskann/third_party/DiskANN/_deps/
*.meta.json
*.passages.json
*.passages.json
batchtest.py

4
.gitmodules vendored
View File

@@ -1,9 +1,9 @@
[submodule "packages/leann-backend-diskann/third_party/DiskANN"]
path = packages/leann-backend-diskann/third_party/DiskANN
url = https://github.com/yichuan520030910320/DiskANN.git
url = https://github.com/yichuan-w/DiskANN.git
[submodule "packages/leann-backend-hnsw/third_party/faiss"]
path = packages/leann-backend-hnsw/third_party/faiss
url = https://github.com/yichuan520030910320/faiss.git
url = https://github.com/yichuan-w/faiss.git
[submodule "packages/leann-backend-hnsw/third_party/msgpack-c"]
path = packages/leann-backend-hnsw/third_party/msgpack-c
url = https://github.com/msgpack/msgpack-c.git

View File

@@ -1,9 +0,0 @@
{
"recommendations": [
"llvm-vs-code-extensions.vscode-clangd",
"ms-python.python",
"ms-vscode.cmake-tools",
"vadimcn.vscode-lldb",
"eamodio.gitlens",
]
}

283
.vscode/launch.json vendored
View File

@@ -1,283 +0,0 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
// new emdedder
{
"name": "New Embedder",
"type": "debugpy",
"request": "launch",
"program": "demo/main.py",
"console": "integratedTerminal",
"args": [
"--search",
"--use-original",
"--domain",
"dpr",
"--nprobe",
"5000",
"--load",
"flat",
"--embedder",
"intfloat/multilingual-e5-small"
]
}
//python /home/ubuntu/Power-RAG/faiss/demo/simple_build.py
{
"name": "main.py",
"type": "debugpy",
"request": "launch",
"program": "demo/main.py",
"console": "integratedTerminal",
"cwd": "${workspaceFolder}",
"args": [
"--query",
"1000",
"--load",
"bm25"
]
},
{
"name": "Simple Build",
"type": "lldb",
"request": "launch",
"program": "${workspaceFolder}/.venv/bin/python",
"console": "integratedTerminal",
"cwd": "${workspaceFolder}",
"args": [
"faiss/demo/simple_build.py"
],
"env": {
"LD_PRELOAD": "/lib/x86_64-linux-gnu/libmkl_core.so:/lib/x86_64-linux-gnu/libmkl_intel_thread.so:/lib/x86_64-linux-gnu/libmkl_intel_lp64.so:/lib/x86_64-linux-gnu/libiomp5.so"
}
},
//# Fix for Intel MKL error
//export LD_PRELOAD=/lib/x86_64-linux-gnu/libmkl_core.so:/lib/x86_64-linux-gnu/libmkl_intel_thread.so:/lib/x86_64-linux-gnu/libmkl_intel_lp64.so:/lib/x86_64-linux-gnu/libiomp5.so
//python faiss/demo/build_demo.py
{
"name": "Build Demo",
"type": "lldb",
"request": "launch",
"program": "${workspaceFolder}/.venv/bin/python",
"console": "integratedTerminal",
"cwd": "${workspaceFolder}",
"args": [
"faiss/demo/build_demo.py"
],
"env": {
"LD_PRELOAD": "/lib/x86_64-linux-gnu/libmkl_core.so:/lib/x86_64-linux-gnu/libmkl_intel_thread.so:/lib/x86_64-linux-gnu/libmkl_intel_lp64.so:/lib/x86_64-linux-gnu/libiomp5.so"
}
},
{
"name": "DiskANN Serve",
"type": "lldb",
"request": "launch",
"program": "${workspaceFolder}/.venv/bin/python",
"console": "integratedTerminal",
"cwd": "${workspaceFolder}",
"args": [
"demo/main.py",
"--mode",
"serve",
"--engine",
"sglang",
"--load-indices",
"diskann",
"--domain",
"rpj_wiki",
"--lazy-load",
"--recompute-beighbor-embeddings",
"--port",
"8082",
"--diskann-search-memory-maximum",
"2",
"--diskann-graph",
"240",
"--search-only"
],
"env": {
"PYTHONPATH": "${workspaceFolder}/faiss_repo/build/faiss/python:$PYTHONPATH"
},
"preLaunchTask": "CMake: build",
},
{
"name": "DiskANN Serve MAC",
"type": "lldb",
"request": "launch",
"program": "${workspaceFolder}/.venv/bin/python",
"console": "integratedTerminal",
"cwd": "${workspaceFolder}",
"args": [
"demo/main.py",
"--mode",
"serve",
"--engine",
"ollama",
"--load-indices",
"diskann",
"--domain",
"rpj_wiki",
"--lazy-load",
"--recompute-beighbor-embeddings"
],
"preLaunchTask": "CMake: build",
"env": {
"KMP_DUPLICATE_LIB_OK": "TRUE",
"OMP_NUM_THREADS": "1",
"MKL_NUM_THREADS": "1",
"DYLD_INSERT_LIBRARIES": "/Users/ec2-user/Power-RAG/.venv/lib/python3.10/site-packages/torch/lib/libomp.dylib",
"KMP_BLOCKTIME": "0"
}
},
{
"name": "Python Debugger: Current File with Arguments",
"type": "debugpy",
"request": "launch",
"program": "ric/main_ric.py",
"console": "integratedTerminal",
"cwd": "${workspaceFolder}",
"args": [
"--config-name",
"${input:configSelection}"
],
"justMyCode": false
},
//python ./demo/validate_equivalence.py sglang
{
"name": "Validate Equivalence",
"type": "debugpy",
"request": "launch",
"program": "demo/validate_equivalence.py",
"console": "integratedTerminal",
"args": [
"sglang"
],
},
//python demo/retrieval_demo.py --engine sglang --skip-embeddings --domain dpr --load-indices flat ivf_flat
{
"name": "Retrieval Demo",
"type": "debugpy",
"request": "launch",
"program": "demo/retrieval_demo.py",
"console": "integratedTerminal",
"args": [
"--engine",
"vllm",
"--skip-embeddings",
"--domain",
"dpr",
"--load-indices",
// "flat",
"ivf_flat"
],
},
//python demo/retrieval_demo.py --engine sglang --skip-embeddings --domain dpr --load-indices diskann --hnsw-M 64 --hnsw-efConstruction 150 --hnsw-efSearch 128 --hnsw-sq-bits 8
{
"name": "Retrieval Demo DiskANN",
"type": "debugpy",
"request": "launch",
"program": "demo/retrieval_demo.py",
"console": "integratedTerminal",
"args": [
"--engine",
"sglang",
"--skip-embeddings",
"--domain",
"dpr",
"--load-indices",
"diskann",
"--hnsw-M",
"64",
"--hnsw-efConstruction",
"150",
"--hnsw-efSearch",
"128",
"--hnsw-sq-bits",
"8"
],
},
{
"name": "Find Probe",
"type": "debugpy",
"request": "launch",
"program": "find_probe.py",
"console": "integratedTerminal",
"cwd": "${workspaceFolder}",
},
{
"name": "Python: Attach",
"type": "debugpy",
"request": "attach",
"processId": "${command:pickProcess}",
"justMyCode": true
},
{
"name": "Edge RAG",
"type": "lldb",
"request": "launch",
"program": "${workspaceFolder}/.venv/bin/python",
"console": "integratedTerminal",
"cwd": "${workspaceFolder}",
"args": [
"edgerag_demo.py"
],
"env": {
"LD_PRELOAD": "/lib/x86_64-linux-gnu/libiomp5.so /lib/x86_64-linux-gnu/libmkl_core.so /lib/x86_64-linux-gnu/libmkl_intel_lp64.so /lib/x86_64-linux-gnu/libmkl_intel_thread.so",
"MKL_NUM_THREADS": "1",
"OMP_NUM_THREADS": "1",
}
},
{
"name": "Launch Embedding Server",
"type": "debugpy",
"request": "launch",
"program": "demo/embedding_server.py",
"console": "integratedTerminal",
"cwd": "${workspaceFolder}",
"args": [
"--domain",
"rpj_wiki",
"--zmq-port",
"5556",
]
},
{
"name": "HNSW Serve",
"type": "lldb",
"request": "launch",
"program": "${workspaceFolder}/.venv/bin/python",
"console": "integratedTerminal",
"cwd": "${workspaceFolder}",
"args": [
"demo/main.py",
"--domain",
"rpj_wiki",
"--load",
"hnsw",
"--mode",
"serve",
"--search",
"--skip-pa",
"--recompute",
"--hnsw-old"
],
"env": {
"LD_PRELOAD": "/lib/x86_64-linux-gnu/libmkl_core.so:/lib/x86_64-linux-gnu/libmkl_intel_thread.so:/lib/x86_64-linux-gnu/libmkl_intel_lp64.so:/lib/x86_64-linux-gnu/libiomp5.so"
}
},
],
"inputs": [
{
"id": "configSelection",
"type": "pickString",
"description": "Select a configuration",
"options": [
"example_config",
"vllm_gritlm"
],
"default": "example_config"
}
],
}

43
.vscode/settings.json vendored
View File

@@ -1,43 +0,0 @@
{
"python.analysis.extraPaths": [
"./sglang_repo/python"
],
"cmake.sourceDirectory": "${workspaceFolder}/DiskANN",
"cmake.configureArgs": [
"-DPYBIND=True",
"-DUPDATE_EDITABLE_INSTALL=ON",
],
"cmake.environment": {
"PATH": "/Users/ec2-user/Power-RAG/.venv/bin:${env:PATH}"
},
"cmake.buildDirectory": "${workspaceFolder}/build",
"files.associations": {
"*.tcc": "cpp",
"deque": "cpp",
"string": "cpp",
"unordered_map": "cpp",
"vector": "cpp",
"map": "cpp",
"unordered_set": "cpp",
"atomic": "cpp",
"inplace_vector": "cpp",
"*.ipp": "cpp",
"forward_list": "cpp",
"list": "cpp",
"any": "cpp",
"system_error": "cpp",
"__hash_table": "cpp",
"__split_buffer": "cpp",
"__tree": "cpp",
"ios": "cpp",
"set": "cpp",
"__string": "cpp",
"string_view": "cpp",
"ranges": "cpp",
"iosfwd": "cpp"
},
"lldb.displayFormat": "auto",
"lldb.showDisassembly": "auto",
"lldb.dereferencePointers": true,
"lldb.consoleMode": "commands",
}

16
.vscode/tasks.json vendored
View File

@@ -1,16 +0,0 @@
{
"version": "2.0.0",
"tasks": [
{
"type": "cmake",
"label": "CMake: build",
"command": "build",
"targets": [
"all"
],
"group": "build",
"problemMatcher": [],
"detail": "CMake template build task"
}
]
}

757
README.md
View File

@@ -1,104 +1,85 @@
<h1 align="center">🚀 LEANN: A Low-Storage Vector Index</h1>
<p align="center">
<img src="assets/logo-text.png" alt="LEANN Logo" width="400">
</p>
<p align="center">
<img src="https://img.shields.io/badge/Python-3.9%2B-blue.svg" alt="Python 3.9+">
<img src="https://img.shields.io/badge/License-MIT-green.svg" alt="MIT License">
<img src="https://img.shields.io/badge/PRs-welcome-brightgreen.svg" alt="PRs Welcome">
<img src="https://img.shields.io/badge/Platform-Linux%20%7C%20macOS-lightgrey" alt="Platform">
<img src="https://img.shields.io/badge/Platform-Linux%20%7C%20macOS-lightgrey" alt="Platform">
</p>
<h2 align="center" tabindex="-1" class="heading-element" dir="auto">
The smallest vector index in the world. RAG Everything with LEANN!
</h2>
LEANN is a revolutionary vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **97% less storage** than traditional solutions **without accuracy loss**.
LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration Fig →](#-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)**, or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
## Why LEANN?
<p align="center">
<strong>💾 Extreme Storage Saving • 🔒 100% Private • 📚 RAG Everything • ⚡ Easy & Accurate</strong>
<img src="assets/effects.png" alt="LEANN vs Traditional Vector DB Storage Comparison" width="70%">
</p>
<p align="center">
<a href="#-quick-start">Quick Start</a> •
<a href="#-features">Features</a> •
<a href="#-benchmarks">Benchmarks</a> •
<a href="https://arxiv.org/abs/2506.08276" target="_blank">Paper</a>
</p>
> **The numbers speak for themselves:** Index 60 million Wikipedia chunks in just 6GB instead of 201GB. From emails to browser history, everything fits on your laptop. [See detailed benchmarks for different applications below ↓](#storage-usage-comparison)
---
## 🌟 What is LEANN-RAG?
🔒 **Privacy:** Your data never leaves your laptop. No OpenAI, no cloud, no "terms of service".
**LEANN-RAG** is a lightweight, locally deployable **Retrieval-Augmented Generation (RAG)** engine designed for personal devices. It combines **compact storage**, **clean usability**, and **privacy-by-design**, making it easy to build personalized retrieval systems over your own data — emails, notes, documents, chats, or anything else.
🪶 **Lightweight:** Graph-based recomputation eliminates heavy embedding storage, while smart graph pruning and CSR format minimize graph storage overhead. Always less storage, less memory usage!
Unlike traditional vector databases that rely on massive embedding storage, LEANN reduces storage needs dramatically by using **graph-based recomputation** and **pruned HNSW search**, while maintaining responsive and reliable performance — all without sending any data to the cloud.
📈 **Scalability:** Handle messy personal data that would crash traditional vector DBs, easily managing your growing personalized data and agent generated memory!
---
**No Accuracy Loss:** Maintain the same search quality as heavyweight solutions while using 97% less storage.
## 🔥 Key Highlights
### 💾 1. Extreme Storage Efficiency
LEANN reduces storage usage by **up to 97%** compared to conventional vector DBs (e.g., FAISS), by storing only pruned graph structures and computing embeddings at query time.
> For example: 60M chunks can be indexed in just **6GB**, compared to **200GB+** with dense storage.
### 🔒 2. Fully Private, Cloud-Free
LEANN runs entirely locally. No cloud services, no API keys, and no risk of leaking sensitive data.
> Converse with your own files **without compromising privacy**.
### 🧠 3. RAG Everything
Build truly personalized assistants by querying over **your own** chat logs, email archives, browser history, or agent memory.
> LEANN makes it easy to integrate personal context into RAG workflows.
### ⚡ 4. Easy, Accurate, and Fast
LEANN is designed to be **easy to install**, with a **clean API** and minimal setup. It runs efficiently on consumer hardware without sacrificing retrieval accuracy.
> One command to install, one click to run.
---
## 🚀 Why Choose LEANN?
Traditional RAG systems often require trade-offs between storage, privacy, and usability. **LEANN-RAG aims to simplify the stack** with a more practical design:
-**No embedding storage** — compute on demand, save disk space
-**Low memory footprint** — lightweight and hardware-friendly
-**Privacy-first** — 100% local, no network dependency
-**Simple to use** — developer-friendly API and seamless setup
> 📄 For more details, see our [academic paper](https://arxiv.org/abs/2506.08276)
## 🚀 Quick Start
### Installation
## Installation
```bash
git clone git@github.com:yichuan520030910320/LEANN-RAG.git leann
git clone git@github.com:yichuan-w/LEANN.git leann
cd leann
git submodule update --init --recursive
```
**macOS:**
```bash
brew install llvm libomp boost protobuf
export CC=$(brew --prefix llvm)/bin/clang
export CXX=$(brew --prefix llvm)/bin/clang++
uv sync
brew install llvm libomp boost protobuf zeromq pkgconf
# Install with HNSW backend (default, recommended for most users)
# Install uv first if you don't have it:
# curl -LsSf https://astral.sh/uv/install.sh | sh
# See: https://docs.astral.sh/uv/getting-started/installation/#installation-methods
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv sync
```
**Linux (Ubuntu/Debian):**
**Linux:**
```bash
sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler libabsl-dev libmkl-full-dev libaio-dev
sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler libabsl-dev libmkl-full-dev libaio-dev libzmq3-dev
# Install with HNSW backend (default, recommended for most users)
uv sync
```
**Ollama Setup (Optional for Local LLM):**
*macOS:*
**Ollama Setup (Recommended for full privacy):**
> *You can skip this installation if you only want to use OpenAI API for generation.*
**macOS:**
First, [download Ollama for macOS](https://ollama.com/download/mac).
```bash
# Install Ollama
brew install ollama
```bash
# Pull a lightweight model (recommended for consumer hardware)
ollama pull llama3.2:1b
# For better performance but higher memory usage
ollama pull llama3.2:3b
```
*Linux:*
**Linux:**
```bash
# Install Ollama
curl -fsSL https://ollama.ai/install.sh | sh
@@ -108,53 +89,367 @@ ollama serve &
# Pull a lightweight model (recommended for consumer hardware)
ollama pull llama3.2:1b
# For better performance but higher memory usage
ollama pull llama3.2:3b
```
**Note:** For Hugging Face models >1B parameters, you may encounter OOM errors on consumer hardware. Consider using smaller models like Qwen3-0.6B or switch to Ollama for better memory management.
## Quick Start in 30s
### 30-Second Example
Try it out in [**demo.ipynb**](demo.ipynb)
Our declarative API makes RAG as easy as writing a config file.
[Try in this ipynb file →](demo.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/yichuan-w/LEANN/blob/main/demo.ipynb)
```python
from leann.api import LeannBuilder, LeannSearcher
# 1. Build index (no embeddings stored!)
from leann.api import LeannBuilder, LeannSearcher, LeannChat
# 1. Build the index (no embeddings stored!)
builder = LeannBuilder(backend_name="hnsw")
builder.add_text("C# is a powerful programming language")
builder.add_text("Python is a powerful programming language")
builder.add_text("Machine learning transforms industries")
builder.add_text("Python is a powerful programming language and it is very popular")
builder.add_text("Machine learning transforms industries")
builder.add_text("Neural networks process complex data")
builder.add_text("Leann is a great storage saving engine for RAG on your macbook")
builder.add_text("Leann is a great storage saving engine for RAG on your MacBook")
builder.build_index("knowledge.leann")
# 2. Search with real-time embeddings
searcher = LeannSearcher("knowledge.leann")
results = searcher.search("C++ programming languages", top_k=2, recompute_beighbor_embeddings=True)
print(results)
results = searcher.search("programming languages", top_k=2)
# 3. Chat with LEANN using retrieved results
llm_config = {
"type": "ollama",
"model": "llama3.2:1b"
}
chat = LeannChat(index_path="knowledge.leann", llm_config=llm_config)
response = chat.ask(
"Compare the two retrieved programming languages and say which one is more popular today.",
top_k=2,
)
```
### Run the Demo (support .pdf,.txt,.docx, .pptx, .csv, .md etc)
## RAG on Everything!
LEANN supports RAG on various data sources including documents (.pdf, .txt, .md), Apple Mail, Google Search History, WeChat, and more.
### 📄 Personal Data Manager: Process Any Documents (.pdf, .txt, .md)!
Ask questions directly about your personal PDFs, documents, and any directory containing your files!
The example below asks a question about summarizing two papers (uses default data in `examples/data`):
```bash
# Drop your PDFs, .txt, .md files into examples/data/
uv run ./examples/main_cli_example.py
```
or you want to use python
```bash
```
# Or use python directly
source .venv/bin/activate
python ./examples/main_cli_example.py
```
**PDF RAG Demo (using LlamaIndex for document parsing and Leann for indexing/search)**
This demo showcases how to build a RAG system for PDF/md documents using Leann.
1. Place your PDF files (and other supported formats like .docx, .pptx, .xlsx) into the `examples/data/` directory.
2. Ensure you have an `OPENAI_API_KEY` set in your environment variables or in a `.env` file for the LLM to function.
### 📧 Your Personal Email Secretary: RAG on Apple Mail!
**Note:** You need to grant full disk access to your terminal/VS Code in System Preferences → Privacy & Security → Full Disk Access.
```bash
python examples/mail_reader_leann.py --query "What's the food I ordered by doordash or Uber eat mostly?"
```
**780K email chunks → 78MB storage** Finally, search your email like you search Google.
<details>
<summary><strong>📋 Click to expand: Command Examples</strong></summary>
```bash
# Use default mail path (works for most macOS setups)
python examples/mail_reader_leann.py
# Run with custom index directory
python examples/mail_reader_leann.py --index-dir "./my_mail_index"
# Process all emails (may take time but indexes everything)
python examples/mail_reader_leann.py --max-emails -1
# Limit number of emails processed (useful for testing)
python examples/mail_reader_leann.py --max-emails 1000
# Run a single query
python examples/mail_reader_leann.py --query "What did my boss say about deadlines?"
```
</details>
<details>
<summary><strong>📋 Click to expand: Example queries you can try</strong></summary>
Once the index is built, you can ask questions like:
- "Find emails from my boss about deadlines"
- "What did John say about the project timeline?"
- "Show me emails about travel expenses"
</details>
### 🔍 Time Machine for the Web: RAG Your Entire Google Browser History!
```bash
python examples/google_history_reader_leann.py --query "Tell me my browser history about machine learning?"
```
**38K browser entries → 6MB storage.** Your browser history becomes your personal search engine.
<details>
<summary><strong>📋 Click to expand: Command Examples</strong></summary>
```bash
# Use default Chrome profile (auto-finds all profiles)
python examples/google_history_reader_leann.py
# Run with custom index directory
python examples/google_history_reader_leann.py --index-dir "./my_chrome_index"
# Limit number of history entries processed (useful for testing)
python examples/google_history_reader_leann.py --max-entries 500
# Run a single query
python examples/google_history_reader_leann.py --query "What websites did I visit about machine learning?"
```
</details>
<details>
<summary><strong>📋 Click to expand: How to find your Chrome profile</strong></summary>
The default Chrome profile path is configured for a typical macOS setup. If you need to find your specific Chrome profile:
1. Open Terminal
2. Run: `ls ~/Library/Application\ Support/Google/Chrome/`
3. Look for folders like "Default", "Profile 1", "Profile 2", etc.
4. Use the full path as your `--chrome-profile` argument
**Common Chrome profile locations:**
- macOS: `~/Library/Application Support/Google/Chrome/Default`
- Linux: `~/.config/google-chrome/Default`
</details>
<details>
<summary><strong>💬 Click to expand: Example queries you can try</strong></summary>
Once the index is built, you can ask questions like:
- "What websites did I visit about machine learning?"
- "Find my search history about programming"
- "What YouTube videos did I watch recently?"
- "Show me websites I visited about travel planning"
</details>
### 💬 WeChat Detective: Unlock Your Golden Memories!
```bash
python examples/wechat_history_reader_leann.py --query "Show me all group chats about weekend plans"
```
**400K messages → 64MB storage** Search years of chat history in any language.
<details>
<summary><strong>🔧 Click to expand: Installation Requirements</strong></summary>
First, you need to install the WeChat exporter:
```bash
sudo packages/wechat-exporter/wechattweak-cli install
```
**Troubleshooting:**
- **Installation issues**: Check the [WeChatTweak-CLI issues page](https://github.com/sunnyyoung/WeChatTweak-CLI/issues/41)
- **Export errors**: If you encounter the error below, try restarting WeChat
```
Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed.
Failed to find or export WeChat data. Exiting.
```
</details>
<details>
<summary><strong>📋 Click to expand: Command Examples</strong></summary>
```bash
# Use default settings (recommended for first run)
python examples/wechat_history_reader_leann.py
# Run with custom export directory and wehn we run the first time, LEANN will export all chat history automatically for you
python examples/wechat_history_reader_leann.py --export-dir "./my_wechat_exports"
# Run with custom index directory
python examples/wechat_history_reader_leann.py --index-dir "./my_wechat_index"
# Limit number of chat entries processed (useful for testing)
python examples/wechat_history_reader_leann.py --max-entries 1000
# Run a single query
python examples/wechat_history_reader_leann.py --query "Show me conversations about travel plans"
```
</details>
<details>
<summary><strong>💬 Click to expand: Example queries you can try</strong></summary>
Once the index is built, you can ask questions like:
- "我想买魔术师约翰逊的球衣,给我一些对应聊天记录?" (Chinese: Show me chat records about buying Magic Johnson's jersey)
</details>
## 🖥️ Command Line Interface
LEANN includes a powerful CLI for document processing and search. Perfect for quick document indexing and interactive chat.
```bash
# Build an index from documents
leann build my-docs --docs ./documents
# Search your documents
leann search my-docs "machine learning concepts"
# Interactive chat with your documents
leann ask my-docs --interactive
# List all your indexes
leann list
```
**Key CLI features:**
- Auto-detects document formats (PDF, TXT, MD, DOCX)
- Smart text chunking with overlap
- Multiple LLM providers (Ollama, OpenAI, HuggingFace)
- Organized index storage in `~/.leann/indexes/`
- Support for advanced search parameters
<details>
<summary><strong>📋 Click to expand: Complete CLI Reference</strong></summary>
**Build Command:**
```bash
leann build INDEX_NAME --docs DIRECTORY [OPTIONS]
Options:
--backend {hnsw,diskann} Backend to use (default: hnsw)
--embedding-model MODEL Embedding model (default: facebook/contriever)
--graph-degree N Graph degree (default: 32)
--complexity N Build complexity (default: 64)
--force Force rebuild existing index
--compact Use compact storage (default: true)
--recompute Enable recomputation (default: true)
```
**Search Command:**
```bash
leann search INDEX_NAME QUERY [OPTIONS]
Options:
--top-k N Number of results (default: 5)
--complexity N Search complexity (default: 64)
--recompute-embeddings Use recomputation for highest accuracy
--pruning-strategy {global,local,proportional}
```
**Ask Command:**
```bash
leann ask INDEX_NAME [OPTIONS]
Options:
--llm {ollama,openai,hf} LLM provider (default: ollama)
--model MODEL Model name (default: qwen3:8b)
--interactive Interactive chat mode
--top-k N Retrieval count (default: 20)
```
</details>
## 🏗️ Architecture & How It Works
<p align="center">
<img src="assets/arch.png" alt="LEANN Architecture" width="800">
</p>
**The magic:** Most vector DBs store every single embedding (expensive). LEANN stores a pruned graph structure (cheap) and recomputes embeddings only when needed (fast).
**Core techniques:**
- **Graph-based selective recomputation:** Only compute embeddings for nodes in the search path
- **High-degree preserving pruning:** Keep important "hub" nodes while removing redundant connections
- **Dynamic batching:** Efficiently batch embedding computations for GPU utilization
- **Two-level search:** Smart graph traversal that prioritizes promising nodes
**Backends:** DiskANN or HNSW - pick what works for your data size.
## Benchmarks
Run the comparison yourself:
```bash
python examples/compare_faiss_vs_leann.py
```
| System | Storage |
|--------|---------|
| FAISS HNSW | 5.5 MB |
| LEANN | 0.5 MB |
| **Savings** | **91%** |
Same dataset, same hardware, same embedding model. LEANN just works better.
### Storage Usage Comparison
| System | DPR (2.1M chunks) | RPJ-wiki (60M chunks) | Chat history (400K messages) | Apple emails (780K messages chunks) |Google Search History (38K entries)
|-----------------------|------------------|------------------------|-----------------------------|------------------------------|------------------------------|
| Traditional Vector DB(FAISS) | 3.8 GB | 201 GB | 1.8G | 2.4G |130.4 MB |
| **LEANN** | **324 MB** | **6 GB** | **64 MB** | **79 MB** |**6.4MB** |
| **Reduction** | **91% smaller** | **97% smaller** | **97% smaller** | **97% smaller** |**95% smaller** |
<!-- ### Memory Usage Comparison
| System j | DPR(2M docs) | RPJ-wiki(60M docs) | Chat history() |
| --------------------- | ---------------- | ---------------- | ---------------- |
| Traditional Vector DB(LLamaindex faiss) | x GB | x GB | x GB |
| **Leann** | **xx MB** | **x GB** | **x GB** |
| **Reduction** | **x%** | **x%** | **x%** |
### Query Performance of LEANN
| Backend | Index Size | Query Time | Recall@3 |
| ------------------- | ---------- | ---------- | --------- |
| DiskANN | 1M docs | xms | 0.95 |
| HNSW | 1M docs | xms | 0.95 | -->
*Benchmarks run on Apple M3 Pro 36 GB*
## Reproduce Our Results
```bash
uv pip install -e ".[dev]" # Install dev dependencies
python examples/run_evaluation.py data/indices/dpr/dpr_diskann # DPR dataset
python examples/run_evaluation.py data/indices/rpj_wiki/rpj_wiki.index # Wikipedia
```
The evaluation script downloads data automatically on first run. The last three results were tested with partial personal data, and you can reproduce them with your own data!
## 🔬 Paper
If you find Leann useful, please cite:
**[LEANN: A Low-Storage Vector Index](https://arxiv.org/abs/2506.08276)**
```bibtex
@misc{wang2025leannlowstoragevectorindex,
title={LEANN: A Low-Storage Vector Index},
author={Yichuan Wang and Shu Liu and Zhifei Li and Yongji Wu and Ziming Mao and Yilong Zhao and Xiao Yan and Zhiying Xu and Yang Zhou and Ion Stoica and Sewon Min and Matei Zaharia and Joseph E. Gonzalez},
year={2025},
eprint={2506.08276},
archivePrefix={arXiv},
primaryClass={cs.DB},
url={https://arxiv.org/abs/2506.08276},
}
```
## ✨ Features
### 🔥 Core Features
@@ -178,283 +473,6 @@ This demo showcases how to build a RAG system for PDF/md documents using Leann.
- **Extensible backend system** - Easy to add new algorithms
- **Comprehensive examples** - From basic usage to production deployment
## Applications on your MacBook
### 📧 Lightweight RAG on your Apple Mail
LEANN can create a searchable index of your Apple Mail emails, allowing you to query your email history using natural language.
#### Quick Start
<details>
<summary><strong>📋 Click to expand: Command Examples</strong></summary>
```bash
# Use default mail path (works for most macOS setups)
python examples/mail_reader_leann.py
# Run with custom index directory
python examples/mail_reader_leann.py --index-dir "./my_mail_index"
# embedd and search all of your email(this may take a long preprocessing time but it will encode all your emails)
python examples/mail_reader_leann.py --max-emails -1
# Limit number of emails processed (useful for testing)
python examples/mail_reader_leann.py --max-emails 1000
# Run a single query
python examples/mail_reader_leann.py --query "Whats the number of class recommend to take per semester for incoming EECS students"
```
</details>
#### Example Queries
<details>
<summary><strong>💬 Click to expand: Example queries you can try</strong></summary>
Once the index is built, you can ask questions like:
- "Find emails from my boss about deadlines"
- "What did John say about the project timeline?"
- "Show me emails about travel expenses"
</details>
### 🌐 Lightweight RAG on your Google Chrome History
LEANN can create a searchable index of your Chrome browser history, allowing you to query your browsing history using natural language.
#### Quick Start
<details>
<summary><strong>📋 Click to expand: Command Examples</strong></summary>
Note you need to quit google right now to successfully run this.
```bash
# Use default Chrome profile (auto-finds all profiles) and recommand method to run this because usually default file is enough
python examples/google_history_reader_leann.py
# Run with custom index directory
python examples/google_history_reader_leann.py --index-dir "./my_chrome_index"
# Limit number of history entries processed (useful for testing)
python examples/google_history_reader_leann.py --max-entries 500
# Run a single query
python examples/google_history_reader_leann.py --query "What websites did I visit about machine learning?"
# Use only a specific profile (disable auto-find)
python examples/google_history_reader_leann.py --chrome-profile "~/Library/Application Support/Google/Chrome/Default" --no-auto-find-profiles
```
</details>
#### Finding Your Chrome Profile
<details>
<summary><strong>🔍 Click to expand: How to find your Chrome profile</strong></summary>
The default Chrome profile path is configured for a typical macOS setup. If you need to find your specific Chrome profile:
1. Open Terminal
2. Run: `ls ~/Library/Application\ Support/Google/Chrome/`
3. Look for folders like "Default", "Profile 1", "Profile 2", etc.
4. Use the full path as your `--chrome-profile` argument
**Common Chrome profile locations:**
- macOS: `~/Library/Application Support/Google/Chrome/Default`
- Linux: `~/.config/google-chrome/Default`
</details>
#### Example Queries
<details>
<summary><strong>💬 Click to expand: Example queries you can try</strong></summary>
Once the index is built, you can ask questions like:
- "What websites did I visit about machine learning?"
- "Find my search history about programming"
- "What YouTube videos did I watch recently?"
- "Show me websites I visited about travel planning"
</details>
### 💬 Lightweight RAG on your WeChat History
LEANN can create a searchable index of your WeChat chat history, allowing you to query your conversations using natural language.
#### Prerequisites
<details>
<summary><strong>🔧 Click to expand: Installation Requirements</strong></summary>
First, you need to install the WeChat exporter:
```bash
sudo packages/wechat-exporter/wechattweak-cli install
```
**Troubleshooting**: If you encounter installation issues, check the [WeChatTweak-CLI issues page](https://github.com/sunnyyoung/WeChatTweak-CLI/issues/41).
</details>
#### Quick Start
<details>
<summary><strong>📋 Click to expand: Command Examples</strong></summary>
```bash
# Use default settings (recommended for first run)
python examples/wechat_history_reader_leann.py
# Run with custom export directory and wehn we run the first time, LEANN will export all chat history automatically for you
python examples/wechat_history_reader_leann.py --export-dir "./my_wechat_exports"
# Run with custom index directory
python examples/wechat_history_reader_leann.py --index-dir "./my_wechat_index"
# Limit number of chat entries processed (useful for testing)
python examples/wechat_history_reader_leann.py --max-entries 1000
# Run a single query
python examples/wechat_history_reader_leann.py --query "Show me conversations about travel plans"
```
</details>
#### Example Queries
<details>
<summary><strong>💬 Click to expand: Example queries you can try</strong></summary>
Once the index is built, you can ask questions like:
- "我想买魔术师约翰逊的球衣,给我一些对应聊天记录?" (Chinese: Show me chat records about buying Magic Johnson's jersey)
</details>
## ⚡ Performance Comparison
### LEANN vs Faiss HNSW
We benchmarked LEANN against the popular Faiss HNSW implementation to demonstrate the significant memory and storage savings our approach provides:
```bash
# Run the comparison benchmark
python examples/compare_faiss_vs_leann.py
```
#### 🎯 Results Summary
| Metric | Faiss HNSW | LEANN HNSW | **Improvement** |
|--------|------------|-------------|-----------------|
| **Storage Size** | 5.5 MB | 0.5 MB | **11.4x smaller** (5.0 MB saved) |
#### 📈 Key Takeaways
- **💾 Storage Optimization**: LEANN requires **91% less storage** for the same dataset
- **⚖️ Fair Comparison**: Both systems tested on identical hardware with the same 2,573 document dataset and the same embedding model and chunk method
> **Note**: Results may vary based on dataset size, hardware configuration, and query patterns. The comparison excludes text storage to focus purely on index structures.
*Benchmark results obtained on Apple Silicon with consistent environmental conditions*
## 📊 Benchmarks
### How to Reproduce Evaluation Results
Reproducing our benchmarks is straightforward. The evaluation script is designed to be self-contained, automatically downloading all necessary data on its first run.
#### 1. Environment Setup
First, ensure you have followed the installation instructions in the [Quick Start](#-quick-start) section. This will install all core dependencies.
Next, install the optional development dependencies, which include the `huggingface-hub` library required for automatic data download:
```bash
# This command installs all development dependencies
uv pip install -e ".[dev]"
```
#### 2. Run the Evaluation
Simply run the evaluation script. The first time you run it, it will detect that the data is missing, download it from Hugging Face Hub, and then proceed with the evaluation.
**To evaluate the DPR dataset:**
```bash
python examples/run_evaluation.py data/indices/dpr/dpr_diskann
```
**To evaluate the RPJ-Wiki dataset:**
```bash
python examples/run_evaluation.py data/indices/rpj_wiki/rpj_wiki.index
```
The script will print the recall and search time for each query, followed by the average results.
### Storage Usage Comparison
| System | DPR (2.1M chunks) | RPJ-wiki (60M chunks) | Chat history (400K messages) | Apple emails (90K messages chunks) |Google Search History (38K entries)
|-----------------------|------------------|------------------------|-----------------------------|------------------------------|------------------------------|
| Traditional Vector DB(FAISS) | 3.8 GB | 201 GB | 1.8G | 305.8 MB |130.4 MB |
| **LEANN** | **324 MB** | **6 GB** | **64 MB** | **14.8 MB** |**6.4MB** |
| **Reduction** | **91% smaller** | **97% smaller** | **97% smaller** | **95% smaller** |**95% smaller** |
<!-- ### Memory Usage Comparison
| System j | DPR(2M docs) | RPJ-wiki(60M docs) | Chat history() |
| --------------------- | ---------------- | ---------------- | ---------------- |
| Traditional Vector DB(LLamaindex faiss) | x GB | x GB | x GB |
| **Leann** | **xx MB** | **x GB** | **x GB** |
| **Reduction** | **x%** | **x%** | **x%** |
### Query Performance of LEANN
| Backend | Index Size | Query Time | Recall@3 |
| ------------------- | ---------- | ---------- | --------- |
| DiskANN | 1M docs | xms | 0.95 |
| HNSW | 1M docs | xms | 0.95 | -->
*Benchmarks run on Apple M3 Pro 36 GB*
## 🏗️ Architecture
<p align="center">
<img src="asset/arch.png" alt="LEANN Architecture" width="800">
</p>
## 🔬 Paper
If you find Leann useful, please cite:
**[LEANN: A Low-Storage Vector Index](https://arxiv.org/abs/2506.08276)**
```bibtex
@misc{wang2025leannlowstoragevectorindex,
title={LEANN: A Low-Storage Vector Index},
author={Yichuan Wang and Shu Liu and Zhifei Li and Yongji Wu and Ziming Mao and Yilong Zhao and Xiao Yan and Zhiying Xu and Yang Zhou and Ion Stoica and Sewon Min and Matei Zaharia and Joseph E. Gonzalez},
year={2025},
eprint={2506.08276},
archivePrefix={arXiv},
primaryClass={cs.DB},
url={https://arxiv.org/abs/2506.08276},
}
```
## 🤝 Contributing
We welcome contributions! Leann is built by the community, for the community.
@@ -490,6 +508,17 @@ export NCCL_IB_DISABLE=1
export NCCL_NET_PLUGIN=none
export NCCL_SOCKET_IFNAME=ens5
``` -->
## FAQ
### 1. My building time seems long
You can speed up the process by using a lightweight embedding model. Add this to your arguments:
```bash
--embedding-model sentence-transformers/all-MiniLM-L6-v2
```
**Model sizes:** `all-MiniLM-L6-v2` (30M parameters), `facebook/contriever` (~100M parameters), `Qwen3-0.6B` (600M parameters)
## 📈 Roadmap

View File

Before

Width:  |  Height:  |  Size: 78 KiB

After

Width:  |  Height:  |  Size: 78 KiB

BIN
assets/effects.png Normal file
View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 339 KiB

BIN
assets/logo-text.png Normal file
View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 818 KiB

BIN
assets/logo.png Normal file
View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 276 KiB

View File

@@ -1,35 +1,321 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Quick Start in 30s"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from leann.api import LeannBuilder, LeannSearcher, LeannChat\n",
"# 1. Build index (no embeddings stored!)\n",
"builder = LeannBuilder(backend_name=\"hnsw\")\n",
"builder.add_text(\"C# is a powerful programming language but it is not very popular\")\n",
"builder.add_text(\"Python is a powerful programming language and it is very popular\")\n",
"builder.add_text(\"Machine learning transforms industries\") \n",
"builder.add_text(\"Neural networks process complex data\")\n",
"builder.add_text(\"Leann is a great storage saving engine for RAG on your macbook\")\n",
"builder.build_index(\"knowledge.leann\")\n",
"# 2. Search with real-time embeddings\n",
"searcher = LeannSearcher(\"knowledge.leann\")\n",
"results = searcher.search(\"programming languages\", top_k=2, recompute_beighbor_embeddings=True)\n",
"print(results)\n",
"# install this if you areusing colab\n",
"! pip install leann"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Build the index"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO: Registering backend 'hnsw'\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/yichuan/Desktop/code/LEANN/leann/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: facebook/contriever\n",
"WARNING:sentence_transformers.SentenceTransformer:No sentence-transformers model found with name facebook/contriever. Creating a new one with mean pooling.\n",
"Writing passages: 100%|██████████| 5/5 [00:00<00:00, 27887.66chunk/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 13.51it/s]\n",
"WARNING:leann_backend_hnsw.hnsw_backend:Converting data to float32, shape: (5, 768)\n",
"INFO:leann_backend_hnsw.hnsw_backend:INFO: Converting HNSW index to CSR-pruned format...\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"M: 64 for level: 0\n",
"Starting conversion: knowledge.index -> knowledge.csr.tmp\n",
"[0.00s] Reading Index HNSW header...\n",
"[0.00s] Header read: d=768, ntotal=5\n",
"[0.00s] Reading HNSW struct vectors...\n",
" Reading vector (dtype=<class 'numpy.float64'>, fmt='d')... Count=6, Bytes=48\n",
"[0.00s] Read assign_probas (6)\n",
" Reading vector (dtype=<class 'numpy.int32'>, fmt='i')... Count=7, Bytes=28\n",
"[0.11s] Read cum_nneighbor_per_level (7)\n",
" Reading vector (dtype=<class 'numpy.int32'>, fmt='i')... Count=5, Bytes=20\n",
"[0.21s] Read levels (5)\n",
"[0.30s] Probing for compact storage flag...\n",
"[0.30s] Found compact flag: False\n",
"[0.30s] Compact flag is False, reading original format...\n",
"[0.30s] Probing for potential extra byte before non-compact offsets...\n",
"[0.30s] Found and consumed an unexpected 0x00 byte.\n",
" Reading vector (dtype=<class 'numpy.uint64'>, fmt='Q')... Count=6, Bytes=48\n",
"[0.30s] Read offsets (6)\n",
"[0.40s] Attempting to read neighbors vector...\n",
" Reading vector (dtype=<class 'numpy.int32'>, fmt='i')... Count=320, Bytes=1280\n",
"[0.40s] Read neighbors (320)\n",
"[0.50s] Read scalar params (ep=4, max_lvl=0)\n",
"[0.50s] Checking for storage data...\n",
"[0.50s] Found storage fourcc: 49467849.\n",
"[0.50s] Converting to CSR format...\n",
"[0.50s] Conversion loop finished. \n",
"[0.50s] Running validation checks...\n",
" Checking total valid neighbor count...\n",
" OK: Total valid neighbors = 20\n",
" Checking final pointer indices...\n",
" OK: Final pointers match data size.\n",
"[0.50s] Deleting original neighbors and offsets arrays...\n",
" CSR Stats: |data|=20, |level_ptr|=10\n",
"[0.59s] Writing CSR HNSW graph data in FAISS-compatible order...\n",
" Pruning embeddings: Writing NULL storage marker.\n",
"[0.69s] Conversion complete.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:leann_backend_hnsw.hnsw_backend:✅ CSR conversion successful.\n",
"INFO:leann_backend_hnsw.hnsw_backend:INFO: Replaced original index with CSR-pruned version at 'knowledge.index'\n"
]
}
],
"source": [
"from leann.api import LeannBuilder\n",
"\n",
"llm_config = {\"type\": \"ollama\", \"model\": \"qwen3:8b\"}\n",
"builder = LeannBuilder(backend_name=\"hnsw\")\n",
"builder.add_text(\"C# is a powerful programming language and it is good at game development\")\n",
"builder.add_text(\"Python is a powerful programming language and it is good at machine learning tasks\")\n",
"builder.add_text(\"Machine learning transforms industries\")\n",
"builder.add_text(\"Neural networks process complex data\")\n",
"builder.add_text(\"Leann is a great storage saving engine for RAG on your MacBook\")\n",
"builder.build_index(\"knowledge.leann\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Search with real-time embeddings"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:leann.api:🔍 LeannSearcher.search() called:\n",
"INFO:leann.api: Query: 'programming languages'\n",
"INFO:leann.api: Top_k: 2\n",
"INFO:leann.api: Additional kwargs: {}\n",
"INFO:leann.embedding_server_manager:Port 5557 has incompatible server, trying next port...\n",
"INFO:leann.embedding_server_manager:Port 5558 has incompatible server, trying next port...\n",
"INFO:leann.embedding_server_manager:Port 5559 has incompatible server, trying next port...\n",
"INFO:leann.embedding_server_manager:Using port 5560 instead of 5557\n",
"INFO:leann.embedding_server_manager:Starting embedding server on port 5560...\n",
"INFO:leann.embedding_server_manager:Command: /Users/yichuan/Desktop/code/LEANN/leann/.venv/bin/python -m leann_backend_hnsw.hnsw_embedding_server --zmq-port 5560 --model-name facebook/contriever --passages-file knowledge.leann.meta.json\n",
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
"INFO:leann.embedding_server_manager:Server process started with PID: 4574\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[read_HNSW - CSR NL v4] Reading metadata & CSR indices (manual offset)...\n",
"[read_HNSW NL v4] Read levels vector, size: 5\n",
"[read_HNSW NL v4] Reading Compact Storage format indices...\n",
"[read_HNSW NL v4] Read compact_level_ptr, size: 10\n",
"[read_HNSW NL v4] Read compact_node_offsets, size: 6\n",
"[read_HNSW NL v4] Read entry_point: 4, max_level: 0\n",
"[read_HNSW NL v4] Read storage fourcc: 0x6c6c756e\n",
"[read_HNSW NL v4 FIX] Detected FileIOReader. Neighbors size field offset: 326\n",
"[read_HNSW NL v4] Reading neighbors data into memory.\n",
"[read_HNSW NL v4] Read neighbors data, size: 20\n",
"[read_HNSW NL v4] Finished reading metadata and CSR indices.\n",
"INFO: Skipping external storage loading, since is_recompute is true.\n",
"INFO: Registering backend 'hnsw'\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:leann.embedding_server_manager:Embedding server is ready!\n",
"INFO:leann.api: Launching server time: 1.078078269958496 seconds\n",
"INFO:leann.embedding_server_manager:Existing server process (PID 4574) is compatible\n",
"INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: facebook/contriever\n",
"WARNING:sentence_transformers.SentenceTransformer:No sentence-transformers model found with name facebook/contriever. Creating a new one with mean pooling.\n",
"INFO:leann.api: Generated embedding shape: (1, 768)\n",
"INFO:leann.api: Embedding time: 2.9307072162628174 seconds\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"ZmqDistanceComputer initialized: d=768, metric=0\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:leann.api: Search time: 0.27327895164489746 seconds\n",
"INFO:leann.api: Backend returned: labels=2 results\n",
"INFO:leann.api: Processing 2 passage IDs:\n",
"INFO:leann.api: 1. passage_id='0' -> SUCCESS: C# is a powerful programming language and it is good at game development...\n",
"INFO:leann.api: 2. passage_id='1' -> SUCCESS: Python is a powerful programming language and it is good at machine learning tasks...\n",
"INFO:leann.api: Final enriched results: 2 passages\n"
]
},
{
"data": {
"text/plain": [
"[SearchResult(id='0', score=np.float32(0.9874103), text='C# is a powerful programming language and it is good at game development', metadata={}),\n",
" SearchResult(id='1', score=np.float32(0.8922168), text='Python is a powerful programming language and it is good at machine learning tasks', metadata={})]"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from leann.api import LeannSearcher\n",
"\n",
"searcher = LeannSearcher(\"knowledge.leann\")\n",
"results = searcher.search(\"programming languages\", top_k=2)\n",
"results"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Chat with LEANN using retrieved results"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:leann.chat:Attempting to create LLM of type='hf' with model='Qwen/Qwen3-0.6B'\n",
"INFO:leann.chat:Initializing HFChat with model='Qwen/Qwen3-0.6B'\n",
"INFO:leann.chat:MPS is available. Using Apple Silicon GPU.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[read_HNSW - CSR NL v4] Reading metadata & CSR indices (manual offset)...\n",
"[read_HNSW NL v4] Read levels vector, size: 5\n",
"[read_HNSW NL v4] Reading Compact Storage format indices...\n",
"[read_HNSW NL v4] Read compact_level_ptr, size: 10\n",
"[read_HNSW NL v4] Read compact_node_offsets, size: 6\n",
"[read_HNSW NL v4] Read entry_point: 4, max_level: 0\n",
"[read_HNSW NL v4] Read storage fourcc: 0x6c6c756e\n",
"[read_HNSW NL v4 FIX] Detected FileIOReader. Neighbors size field offset: 326\n",
"[read_HNSW NL v4] Reading neighbors data into memory.\n",
"[read_HNSW NL v4] Read neighbors data, size: 20\n",
"[read_HNSW NL v4] Finished reading metadata and CSR indices.\n",
"INFO: Skipping external storage loading, since is_recompute is true.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:leann.api:🔍 LeannSearcher.search() called:\n",
"INFO:leann.api: Query: 'Compare the two retrieved programming languages and tell me their advantages.'\n",
"INFO:leann.api: Top_k: 2\n",
"INFO:leann.api: Additional kwargs: {}\n",
"INFO:leann.embedding_server_manager:Port 5557 has incompatible server, trying next port...\n",
"INFO:leann.embedding_server_manager:Port 5558 has incompatible server, trying next port...\n",
"INFO:leann.embedding_server_manager:Port 5559 has incompatible server, trying next port...\n",
"INFO:leann.embedding_server_manager:Found compatible server on port 5560\n",
"INFO:leann.embedding_server_manager:Using existing compatible server on port 5560\n",
"INFO:leann.api: Launching server time: 0.04932403564453125 seconds\n",
"INFO:leann.embedding_server_manager:Found compatible server on port 5560\n",
"INFO:leann.embedding_server_manager:Using existing compatible server on port 5560\n",
"INFO:leann.api: Generated embedding shape: (1, 768)\n",
"INFO:leann.api: Embedding time: 0.06902289390563965 seconds\n",
"INFO:leann.api: Search time: 0.026793241500854492 seconds\n",
"INFO:leann.api: Backend returned: labels=2 results\n",
"INFO:leann.api: Processing 2 passage IDs:\n",
"INFO:leann.api: 1. passage_id='0' -> SUCCESS: C# is a powerful programming language and it is good at game development...\n",
"INFO:leann.api: 2. passage_id='1' -> SUCCESS: Python is a powerful programming language and it is good at machine learning tasks...\n",
"INFO:leann.api: Final enriched results: 2 passages\n",
"INFO:leann.chat:Generating with HuggingFace model, config: {'max_new_tokens': 128, 'temperature': 0.7, 'top_p': 0.9, 'do_sample': True, 'pad_token_id': 151645, 'eos_token_id': 151645}\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"ZmqDistanceComputer initialized: d=768, metric=0\n"
]
},
{
"data": {
"text/plain": [
"\"<think>\\n\\n</think>\\n\\nBased on the context provided, here's a comparison of the two retrieved programming languages:\\n\\n**C#** is known for being a powerful programming language and is well-suited for game development. It is often used in game development and is popular among developers working on Windows applications.\\n\\n**Python**, on the other hand, is also a powerful language and is well-suited for machine learning tasks. It is widely used for data analysis, scientific computing, and other applications that require handling large datasets or performing complex calculations.\\n\\n**Advantages**:\\n- C#: Strong for game development and cross-platform compatibility.\\n- Python: Strong for\""
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from leann.api import LeannChat\n",
"\n",
"llm_config = {\n",
" \"type\": \"hf\",\n",
" \"model\": \"Qwen/Qwen3-0.6B\",\n",
"}\n",
"\n",
"chat = LeannChat(index_path=\"knowledge.leann\", llm_config=llm_config)\n",
"\n",
"response = chat.ask(\n",
" \"Compare the two retrieved programming languages and say which one is more popular today. Respond in a single well-formed sentence.\",\n",
" \"Compare the two retrieved programming languages and tell me their advantages.\",\n",
" top_k=2,\n",
" recompute_beighbor_embeddings=True,\n",
" llm_kwargs={\"max_tokens\": 128}\n",
")\n",
"print(response)"
"response"
]
}
],

22
docs/RELEASE.md Normal file
View File

@@ -0,0 +1,22 @@
# Release Guide
## Setup (One-time)
Add `PYPI_API_TOKEN` to GitHub Secrets:
1. Get token: https://pypi.org/manage/account/token/
2. Add to secrets: Settings → Secrets → Actions → `PYPI_API_TOKEN`
## Release (One-click)
1. Go to: https://github.com/yichuan-w/LEANN/actions/workflows/release-manual.yml
2. Click "Run workflow"
3. Enter version: `0.1.2`
4. Click green "Run workflow" button
That's it! The workflow will automatically:
- ✅ Update version in all packages
- ✅ Build all packages
- ✅ Publish to PyPI
- ✅ Create GitHub tag and release
Check progress: https://github.com/yichuan-w/LEANN/actions

View File

@@ -96,14 +96,12 @@ class EmlxReader(BaseReader):
# Create document content with metadata embedded in text
doc_content = f"""
[EMAIL METADATA]
File: {filename}
From: {from_addr}
To: {to_addr}
Subject: {subject}
Date: {date}
[END METADATA]
[File]: {filename}
[From]: {from_addr}
[To]: {to_addr}
[Subject]: {subject}
[Date]: {date}
[EMAIL BODY Start]:
{body}
"""

View File

@@ -65,12 +65,14 @@ def create_leann_index_from_multiple_chrome_profiles(profile_dirs: List[Path], i
if not all_documents:
print("No documents loaded from any source. Exiting.")
# highlight info that you need to close all chrome browser before running this script and high light the instruction!!
print("\033[91mYou need to close or quit all chrome browser before running this script\033[0m")
return None
print(f"\nTotal loaded {len(all_documents)} history documents from {len(profile_dirs)} profiles")
# Create text splitter with 256 chunk size
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)
# Convert Documents to text strings and chunk them
all_texts = []
@@ -78,7 +80,9 @@ def create_leann_index_from_multiple_chrome_profiles(profile_dirs: List[Path], i
# Split the document into chunks
nodes = text_splitter.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.get_content())
text = node.get_content()
# text = '[Title] ' + doc.metadata["title"] + '\n' + text
all_texts.append(text)
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
@@ -225,7 +229,7 @@ async def main():
parser = argparse.ArgumentParser(description='LEANN Chrome History Reader - Create and query browser history index')
parser.add_argument('--chrome-profile', type=str, default=DEFAULT_CHROME_PROFILE,
help=f'Path to Chrome profile directory (default: {DEFAULT_CHROME_PROFILE}), usually you dont need to change this')
parser.add_argument('--index-dir', type=str, default="./chrome_history_index_leann_test",
parser.add_argument('--index-dir', type=str, default="./all_google_new",
help='Directory to store the LEANN index (default: ./chrome_history_index_leann_test)')
parser.add_argument('--max-entries', type=int, default=1000,
help='Maximum number of history entries to process (default: 1000)')

View File

@@ -74,22 +74,17 @@ class ChromeHistoryReader(BaseReader):
# Create document content with metadata embedded in text
doc_content = f"""
[BROWSING HISTORY METADATA]
URL: {url}
Title: {title}
Last Visit: {last_visit}
Visit Count: {visit_count}
Typed Count: {typed_count}
Hidden: {hidden}
[END METADATA]
Title: {title}
URL: {url}
Last visited: {last_visit}
[Title]: {title}
[URL of the page]: {url}
[Last visited time]: {last_visit}
[Visit times]: {visit_count}
[Typed times]: {typed_count}
"""
# Create document with embedded metadata
doc = Document(text=doc_content, metadata={})
doc = Document(text=doc_content, metadata={ "title": title[0:150]})
# if len(title) > 150:
# print(f"Title is too long: {title}")
docs.append(doc)
count += 1

View File

@@ -197,8 +197,8 @@ class WeChatHistoryReader(BaseReader):
Args:
messages: List of message dictionaries
max_length: Maximum length for concatenated message groups
time_window_minutes: Time window in minutes to group messages together
max_length: Maximum length for concatenated message groups. Use -1 to disable length constraint.
time_window_minutes: Time window in minutes to group messages together. Use -1 to disable time constraint.
overlap_messages: Number of messages to overlap between consecutive groups
Returns:
@@ -230,8 +230,8 @@ class WeChatHistoryReader(BaseReader):
if not readable_text.strip():
continue
# Check time window constraint
if last_timestamp is not None and create_time > 0:
# Check time window constraint (only if time_window_minutes != -1)
if time_window_minutes != -1 and last_timestamp is not None and create_time > 0:
time_diff_minutes = (create_time - last_timestamp) / 60
if time_diff_minutes > time_window_minutes:
# Time gap too large, start new group
@@ -250,9 +250,9 @@ class WeChatHistoryReader(BaseReader):
current_group = []
current_length = 0
# Check length constraint
# Check length constraint (only if max_length != -1)
message_length = len(readable_text)
if current_length + message_length > max_length and current_group:
if max_length != -1 and current_length + message_length > max_length and current_group:
# Current group would exceed max length, save it and start new
concatenated_groups.append({
'messages': current_group,
@@ -335,14 +335,15 @@ class WeChatHistoryReader(BaseReader):
if create_time:
try:
timestamp = datetime.fromtimestamp(create_time)
time_str = timestamp.strftime('%H:%M:%S')
# change to YYYY-MM-DD HH:MM:SS
time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
except:
time_str = str(create_time)
else:
time_str = "Unknown"
sender = "Me" if is_sent_from_self else "Contact"
message_parts.append(f"[{time_str}] {sender}: {readable_text}")
sender = "[Me]" if is_sent_from_self else "[Contact]"
message_parts.append(f"({time_str}) {sender}: {readable_text}")
concatenated_text = "\n".join(message_parts)
@@ -354,13 +355,11 @@ Messages ({len(messages)} messages, {message_group['total_length']} chars):
{concatenated_text}
"""
# TODO @yichuan give better format and rich info here!
doc_content = f"""
Contact: {contact_name}
{concatenated_text}
"""
return doc_content
return doc_content, contact_name
def load_data(self, input_dir: str = None, **load_kwargs: Any) -> List[Document]:
"""
@@ -431,9 +430,9 @@ Contact: {contact_name}
# Concatenate messages based on rules
message_groups = self._concatenate_messages(
readable_messages,
max_length=max_length,
time_window_minutes=time_window_minutes,
overlap_messages=2 # Keep 2 messages overlap between groups
max_length=-1,
time_window_minutes=-1,
overlap_messages=0 # Keep 2 messages overlap between groups
)
# Create documents from concatenated groups
@@ -441,8 +440,8 @@ Contact: {contact_name}
if count >= max_count and max_count > 0:
break
doc_content = self._create_concatenated_content(message_group, contact_name)
doc = Document(text=doc_content, metadata={})
doc_content, contact_name = self._create_concatenated_content(message_group, contact_name)
doc = Document(text=doc_content, metadata={"contact_name": contact_name})
docs.append(doc)
count += 1

View File

@@ -22,7 +22,7 @@ def get_mail_path():
return os.path.join(home_dir, "Library", "Mail")
# Default mail path for macOS
# DEFAULT_MAIL_PATH = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data"
DEFAULT_MAIL_PATH = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data"
def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_path: str = "mail_index.leann", max_count: int = -1, include_html: bool = False, embedding_model: str = "facebook/contriever"):
"""
@@ -74,7 +74,7 @@ def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_pa
print("No documents loaded from any source. Exiting.")
return None
print(f"\nTotal loaded {len(all_documents)} email documents from {len(messages_dirs)} directories")
print(f"\nTotal loaded {len(all_documents)} email documents from {len(messages_dirs)} directories and starting to split them into chunks")
# Create text splitter with 256 chunk size
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
@@ -85,9 +85,11 @@ def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_pa
# Split the document into chunks
nodes = text_splitter.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.get_content())
text = node.get_content()
# text = '[subject] ' + doc.metadata["subject"] + '\n' + text
all_texts.append(text)
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
print(f"Finished splitting {len(all_documents)} documents into {len(all_texts)} text chunks")
# Create LEANN index directory
@@ -156,7 +158,7 @@ def create_leann_index(mail_path: str, index_path: str = "mail_index.leann", max
print(f"Loaded {len(documents)} email documents")
# Create text splitter with 256 chunk size
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)
# Convert Documents to text strings and chunk them
all_texts = []
@@ -216,11 +218,10 @@ async def query_leann_index(index_path: str, query: str):
start_time = time.time()
chat_response = chat.ask(
query,
top_k=10,
top_k=20,
recompute_beighbor_embeddings=True,
complexity=12,
complexity=32,
beam_width=1,
)
end_time = time.time()
print(f"Time taken: {end_time - start_time} seconds")
@@ -231,7 +232,7 @@ async def main():
parser = argparse.ArgumentParser(description='LEANN Mail Reader - Create and query email index')
# Remove --mail-path argument and auto-detect all Messages directories
# Remove DEFAULT_MAIL_PATH
parser.add_argument('--index-dir', type=str, default="./mail_index_leann_raw_text_all_dicts",
parser.add_argument('--index-dir', type=str, default="./mail_index_index_file",
help='Directory to store the LEANN index (default: ./mail_index_leann_raw_text_all_dicts)')
parser.add_argument('--max-emails', type=int, default=1000,
help='Maximum number of emails to process (-1 means all)')
@@ -251,6 +252,9 @@ async def main():
mail_path = get_mail_path()
print(f"Searching for email data in: {mail_path}")
messages_dirs = find_all_messages_directories(mail_path)
# messages_dirs = find_all_messages_directories(DEFAULT_MAIL_PATH)
# messages_dirs = [DEFAULT_MAIL_PATH]
# messages_dirs = messages_dirs[:1]
print('len(messages_dirs): ', len(messages_dirs))

View File

@@ -1,40 +1,40 @@
import argparse
from llama_index.core import SimpleDirectoryReader, Settings
from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter
import asyncio
import dotenv
from leann.api import LeannBuilder, LeannSearcher, LeannChat
import shutil
from leann.api import LeannBuilder, LeannChat
from pathlib import Path
dotenv.load_dotenv()
node_parser = SentenceSplitter(
chunk_size=256, chunk_overlap=128, separator=" ", paragraph_separator="\n\n"
)
print("Loading documents...")
documents = SimpleDirectoryReader(
"examples/data",
recursive=True,
encoding="utf-8",
required_exts=[".pdf", ".txt", ".md"],
).load_data(show_progress=True)
print("Documents loaded.")
all_texts = []
for doc in documents:
nodes = node_parser.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.get_content())
async def main(args):
INDEX_DIR = Path(args.index_dir)
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
if not INDEX_DIR.exists():
print(f"--- Index directory not found, building new index ---")
node_parser = SentenceSplitter(
chunk_size=256, chunk_overlap=128, separator=" ", paragraph_separator="\n\n"
)
print(f"\n[PHASE 1] Building Leann index...")
print("Loading documents...")
documents = SimpleDirectoryReader(
args.data_dir,
recursive=True,
encoding="utf-8",
required_exts=[".pdf", ".txt", ".md"],
).load_data(show_progress=True)
print("Documents loaded.")
all_texts = []
for doc in documents:
nodes = node_parser.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.get_content())
print("--- Index directory not found, building new index ---")
print("\n[PHASE 1] Building Leann index...")
# Use HNSW backend for better macOS compatibility
builder = LeannBuilder(
@@ -58,8 +58,9 @@ async def main(args):
print(f"\n[PHASE 2] Starting Leann chat session...")
# llm_config = {"type": "hf", "model": "Qwen/Qwen3-4B"}
llm_config = {"type": "hf", "model": "Qwen/Qwen3-4B"}
llm_config = {"type": "ollama", "model": "qwen3:8b"}
llm_config = {"type": "openai", "model": "gpt-4o"}
chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)
@@ -70,9 +71,7 @@ async def main(args):
# )
print(f"You: {query}")
chat_response = chat.ask(
query, top_k=20, recompute_beighbor_embeddings=True, complexity=32
)
chat_response = chat.ask(query, top_k=20, recompute_embeddings=True, complexity=32)
print(f"Leann: {chat_response}")
@@ -105,6 +104,12 @@ if __name__ == "__main__":
default="./test_doc_files",
help="Directory where the Leann index will be stored.",
)
parser.add_argument(
"--data-dir",
type=str,
default="examples/data",
help="Directory containing documents to index (PDF, TXT, MD files).",
)
args = parser.parse_args()
asyncio.run(main(args))

View File

@@ -74,11 +74,11 @@ def create_leann_index_from_multiple_wechat_exports(
return None
print(
f"\nTotal loaded {len(all_documents)} chat documents from {len(export_dirs)} exports"
f"\nTotal loaded {len(all_documents)} chat documents from {len(export_dirs)} exports and starting to split them into chunks"
)
# Create text splitter with 256 chunk size
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
text_splitter = SentenceSplitter(chunk_size=192, chunk_overlap=64)
# Convert Documents to text strings and chunk them
all_texts = []
@@ -86,10 +86,11 @@ def create_leann_index_from_multiple_wechat_exports(
# Split the document into chunks
nodes = text_splitter.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.get_content())
text = '[Contact] means the message is from: ' + doc.metadata["contact_name"] + '\n' + node.get_content()
all_texts.append(text)
print(
f"Created {len(all_texts)} text chunks from {len(all_documents)} documents"
f"Finished splitting {len(all_documents)} documents into {len(all_texts)} text chunks"
)
# Create LEANN index directory
@@ -224,7 +225,7 @@ async def query_leann_index(index_path: str, query: str):
query,
top_k=20,
recompute_beighbor_embeddings=True,
complexity=64,
complexity=16,
beam_width=1,
llm_config={
"type": "openai",
@@ -252,13 +253,13 @@ async def main():
parser.add_argument(
"--index-dir",
type=str,
default="./wechat_history_june19_test",
default="./wechat_history_magic_test_11Debug_new",
help="Directory to store the LEANN index (default: ./wechat_history_index_leann_test)",
)
parser.add_argument(
"--max-entries",
type=int,
default=5000,
default=50,
help="Maximum number of chat entries to process (default: 5000)",
)
parser.add_argument(

View File

@@ -1,8 +1,8 @@
# packages/leann-backend-diskann/CMakeLists.txt (最终简化版)
# packages/leann-backend-diskann/CMakeLists.txt (simplified version)
cmake_minimum_required(VERSION 3.20)
project(leann_backend_diskann_wrapper)
# 告诉 CMake 直接进入 DiskANN 子模块并执行它自己的 CMakeLists.txt
# DiskANN 会自己处理所有事情,包括编译 Python 绑定
# Tell CMake to directly enter the DiskANN submodule and execute its own CMakeLists.txt
# DiskANN will handle everything itself, including compiling Python bindings
add_subdirectory(src/third_party/DiskANN)

View File

@@ -1,10 +1,12 @@
import numpy as np
import os
import struct
import sys
from pathlib import Path
from typing import Dict, Any, List, Literal
from typing import Dict, Any, List, Literal, Optional
import contextlib
import pickle
import logging
from leann.searcher_base import BaseSearcher
from leann.registry import register_backend
@@ -14,6 +16,46 @@ from leann.interface import (
LeannBackendSearcherInterface,
)
logger = logging.getLogger(__name__)
@contextlib.contextmanager
def suppress_cpp_output_if_needed():
"""Suppress C++ stdout/stderr based on LEANN_LOG_LEVEL"""
log_level = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
# Only suppress if log level is WARNING or higher (ERROR, CRITICAL)
should_suppress = log_level in ["WARNING", "ERROR", "CRITICAL"]
if not should_suppress:
# Don't suppress, just yield
yield
return
# Save original file descriptors
stdout_fd = sys.stdout.fileno()
stderr_fd = sys.stderr.fileno()
# Save original stdout/stderr
stdout_dup = os.dup(stdout_fd)
stderr_dup = os.dup(stderr_fd)
try:
# Redirect to /dev/null
devnull = os.open(os.devnull, os.O_WRONLY)
os.dup2(devnull, stdout_fd)
os.dup2(devnull, stderr_fd)
os.close(devnull)
yield
finally:
# Restore original file descriptors
os.dup2(stdout_dup, stdout_fd)
os.dup2(stderr_dup, stderr_fd)
os.close(stdout_dup)
os.close(stderr_dup)
def _get_diskann_metrics():
from . import _diskannpy as diskannpy # type: ignore
@@ -65,22 +107,20 @@ class DiskannBuilder(LeannBackendBuilderInterface):
index_dir.mkdir(parents=True, exist_ok=True)
if data.dtype != np.float32:
logger.warning(f"Converting data to float32, shape: {data.shape}")
data = data.astype(np.float32)
data_filename = f"{index_prefix}_data.bin"
_write_vectors_to_bin(data, index_dir / data_filename)
label_map = {i: str_id for i, str_id in enumerate(ids)}
label_map_file = index_dir / "leann.labels.map"
with open(label_map_file, "wb") as f:
pickle.dump(label_map, f)
build_kwargs = {**self.build_params, **kwargs}
metric_enum = _get_diskann_metrics().get(
build_kwargs.get("distance_metric", "mips").lower()
)
if metric_enum is None:
raise ValueError("Unsupported distance_metric.")
raise ValueError(
f"Unsupported distance_metric '{build_kwargs.get('distance_metric', 'unknown')}'."
)
try:
from . import _diskannpy as diskannpy # type: ignore
@@ -102,36 +142,40 @@ class DiskannBuilder(LeannBackendBuilderInterface):
temp_data_file = index_dir / data_filename
if temp_data_file.exists():
os.remove(temp_data_file)
logger.debug(f"Cleaned up temporary data file: {temp_data_file}")
class DiskannSearcher(BaseSearcher):
def __init__(self, index_path: str, **kwargs):
super().__init__(
index_path,
backend_module_name="leann_backend_diskann.embedding_server",
backend_module_name="leann_backend_diskann.diskann_embedding_server",
**kwargs,
)
from . import _diskannpy as diskannpy # type: ignore
distance_metric = kwargs.get("distance_metric", "mips").lower()
metric_enum = _get_diskann_metrics().get(distance_metric)
if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{distance_metric}'.")
# Initialize DiskANN index with suppressed C++ output based on log level
with suppress_cpp_output_if_needed():
from . import _diskannpy as diskannpy # type: ignore
self.num_threads = kwargs.get("num_threads", 8)
self.zmq_port = kwargs.get("zmq_port", 6666)
distance_metric = kwargs.get("distance_metric", "mips").lower()
metric_enum = _get_diskann_metrics().get(distance_metric)
if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{distance_metric}'.")
full_index_prefix = str(self.index_dir / self.index_path.stem)
self._index = diskannpy.StaticDiskFloatIndex(
metric_enum,
full_index_prefix,
self.num_threads,
kwargs.get("num_nodes_to_cache", 0),
1,
self.zmq_port,
"",
"",
)
self.num_threads = kwargs.get("num_threads", 8)
fake_zmq_port = 6666
full_index_prefix = str(self.index_dir / self.index_path.stem)
self._index = diskannpy.StaticDiskFloatIndex(
metric_enum,
full_index_prefix,
self.num_threads,
kwargs.get("num_nodes_to_cache", 0),
1,
fake_zmq_port, # Initial port, can be updated at runtime
"",
"",
)
def search(
self,
@@ -142,7 +186,7 @@ class DiskannSearcher(BaseSearcher):
prune_ratio: float = 0.0,
recompute_embeddings: bool = False,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
zmq_port: int = 5557,
zmq_port: Optional[int] = None,
batch_recompute: bool = False,
dedup_node_dis: bool = False,
**kwargs,
@@ -161,7 +205,7 @@ class DiskannSearcher(BaseSearcher):
- "global": Use global pruning strategy (default)
- "local": Use local pruning strategy
- "proportional": Not supported in DiskANN, falls back to global
zmq_port: ZMQ port for embedding server
zmq_port: ZMQ port for embedding server communication. Must be provided if recompute_embeddings is True.
batch_recompute: Whether to batch neighbor recomputation (DiskANN-specific)
dedup_node_dis: Whether to cache and reuse distance computations (DiskANN-specific)
**kwargs: Additional DiskANN-specific parameters (for legacy compatibility)
@@ -169,22 +213,25 @@ class DiskannSearcher(BaseSearcher):
Returns:
Dict with 'labels' (list of lists) and 'distances' (ndarray)
"""
# Handle zmq_port compatibility: DiskANN can now update port at runtime
if recompute_embeddings:
if zmq_port is None:
raise ValueError(
"zmq_port must be provided if recompute_embeddings is True"
)
current_port = self._index.get_zmq_port()
if zmq_port != current_port:
logger.debug(
f"Updating DiskANN zmq_port from {current_port} to {zmq_port}"
)
self._index.set_zmq_port(zmq_port)
# DiskANN doesn't support "proportional" strategy
if pruning_strategy == "proportional":
raise NotImplementedError(
"DiskANN backend does not support 'proportional' pruning strategy. Use 'global' or 'local' instead."
)
# Use recompute_embeddings parameter
use_recompute = recompute_embeddings
if use_recompute:
meta_file_path = self.index_dir / f"{self.index_path.name}.meta.json"
if not meta_file_path.exists():
raise RuntimeError(
f"FATAL: Recompute enabled but metadata file not found: {meta_file_path}"
)
self._ensure_server_running(str(meta_file_path), port=zmq_port, **kwargs)
if query.dtype != np.float32:
query = query.astype(np.float32)
@@ -194,28 +241,26 @@ class DiskannSearcher(BaseSearcher):
else: # "global"
use_global_pruning = True
labels, distances = self._index.batch_search(
query,
query.shape[0],
top_k,
complexity,
beam_width,
self.num_threads,
kwargs.get("USE_DEFERRED_FETCH", False),
kwargs.get("skip_search_reorder", False),
use_recompute,
dedup_node_dis,
prune_ratio,
batch_recompute,
use_global_pruning,
)
# Perform search with suppressed C++ output based on log level
with suppress_cpp_output_if_needed():
labels, distances = self._index.batch_search(
query,
query.shape[0],
top_k,
complexity,
beam_width,
self.num_threads,
kwargs.get("USE_DEFERRED_FETCH", False),
kwargs.get("skip_search_reorder", False),
recompute_embeddings,
dedup_node_dis,
prune_ratio,
batch_recompute,
use_global_pruning,
)
string_labels = [
[
self.label_map.get(int_label, f"unknown_{int_label}")
for int_label in batch_labels
]
for batch_labels in labels
[str(int_label) for int_label in batch_labels] for batch_labels in labels
]
return {"labels": string_labels, "distances": distances}

View File

@@ -0,0 +1,283 @@
"""
DiskANN-specific embedding server
"""
import argparse
import threading
import time
import os
import zmq
import numpy as np
import json
from pathlib import Path
from typing import Optional
import sys
import logging
# Set up logging based on environment variable
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
logger = logging.getLogger(__name__)
# Force set logger level (don't rely on basicConfig in subprocess)
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
logger.setLevel(log_level)
# Ensure we have a handler if none exists
if not logger.handlers:
handler = logging.StreamHandler()
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.propagate = False
def create_diskann_embedding_server(
passages_file: Optional[str] = None,
zmq_port: int = 5555,
model_name: str = "sentence-transformers/all-mpnet-base-v2",
embedding_mode: str = "sentence-transformers",
):
"""
Create and start a ZMQ-based embedding server for DiskANN backend.
Uses ROUTER socket and protobuf communication as required by DiskANN C++ implementation.
"""
logger.info(f"Starting DiskANN server on port {zmq_port} with model {model_name}")
logger.info(f"Using embedding mode: {embedding_mode}")
# Add leann-core to path for unified embedding computation
current_dir = Path(__file__).parent
leann_core_path = current_dir.parent.parent / "leann-core" / "src"
sys.path.insert(0, str(leann_core_path))
try:
from leann.embedding_compute import compute_embeddings
from leann.api import PassageManager
logger.info("Successfully imported unified embedding computation module")
except ImportError as e:
logger.error(f"Failed to import embedding computation module: {e}")
return
finally:
sys.path.pop(0)
# Check port availability
import socket
def check_port(port):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(("localhost", port)) == 0
if check_port(zmq_port):
logger.error(f"Port {zmq_port} is already in use")
return
# Only support metadata file, fail fast for everything else
if not passages_file or not passages_file.endswith(".meta.json"):
raise ValueError("Only metadata files (.meta.json) are supported")
# Load metadata to get passage sources
with open(passages_file, "r") as f:
meta = json.load(f)
passages = PassageManager(meta["passage_sources"])
logger.info(
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
)
# Import protobuf after ensuring the path is correct
try:
from . import embedding_pb2
except ImportError as e:
logger.error(f"Failed to import protobuf module: {e}")
return
def zmq_server_thread():
"""ZMQ server thread using REP socket for universal compatibility"""
context = zmq.Context()
socket = context.socket(
zmq.REP
) # REP socket for both BaseSearcher and DiskANN C++ REQ clients
socket.bind(f"tcp://*:{zmq_port}")
logger.info(f"DiskANN ZMQ REP server listening on port {zmq_port}")
socket.setsockopt(zmq.RCVTIMEO, 300000)
socket.setsockopt(zmq.SNDTIMEO, 300000)
while True:
try:
# REP socket receives single-part messages
message = socket.recv()
# Check for empty messages - REP socket requires response to every request
if len(message) == 0:
logger.debug("Received empty message, sending empty response")
socket.send(b"") # REP socket must respond to every request
continue
logger.debug(f"Received ZMQ request of size {len(message)} bytes")
logger.debug(f"Message preview: {message[:50]}") # Show first 50 bytes
e2e_start = time.time()
# Try protobuf first (for DiskANN C++ node_ids requests - primary use case)
texts = []
node_ids = []
is_text_request = False
try:
req_proto = embedding_pb2.NodeEmbeddingRequest()
req_proto.ParseFromString(message)
node_ids = list(req_proto.node_ids)
if not node_ids:
raise RuntimeError(
f"PROTOBUF: Received empty node_ids! Message size: {len(message)}"
)
logger.info(
f"✅ PROTOBUF: Node ID request for {len(node_ids)} node embeddings: {node_ids[:10]}"
)
except Exception as protobuf_error:
logger.debug(f"Protobuf parsing failed: {protobuf_error}")
# Fallback to msgpack (for BaseSearcher direct text requests)
try:
import msgpack
request = msgpack.unpackb(message)
# For BaseSearcher compatibility, request is a list of texts directly
if isinstance(request, list) and all(
isinstance(item, str) for item in request
):
texts = request
is_text_request = True
logger.info(
f"✅ MSGPACK: Direct text request for {len(texts)} texts"
)
else:
raise ValueError("Not a valid msgpack text request")
except Exception as msgpack_error:
raise RuntimeError(
f"Both protobuf and msgpack parsing failed! Protobuf: {protobuf_error}, Msgpack: {msgpack_error}"
)
# Look up texts by node IDs (only if not direct text request)
if not is_text_request:
for nid in node_ids:
try:
passage_data = passages.get_passage(str(nid))
txt = passage_data["text"]
if not txt:
raise RuntimeError(
f"FATAL: Empty text for passage ID {nid}"
)
texts.append(txt)
except KeyError as e:
logger.error(f"Passage ID {nid} not found: {e}")
raise e
except Exception as e:
logger.error(f"Exception looking up passage ID {nid}: {e}")
raise
# Debug logging
logger.debug(f"Processing {len(texts)} texts")
logger.debug(
f"Text lengths: {[len(t) for t in texts[:5]]}"
) # Show first 5
# Process embeddings using unified computation
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
# Prepare response based on request type
if is_text_request:
# For BaseSearcher compatibility: return msgpack format
import msgpack
response_data = msgpack.packb(embeddings.tolist())
else:
# For DiskANN C++ compatibility: return protobuf format
resp_proto = embedding_pb2.NodeEmbeddingResponse()
hidden_contiguous = np.ascontiguousarray(
embeddings, dtype=np.float32
)
# Serialize embeddings data
resp_proto.embeddings_data = hidden_contiguous.tobytes()
resp_proto.dimensions.append(hidden_contiguous.shape[0])
resp_proto.dimensions.append(hidden_contiguous.shape[1])
response_data = resp_proto.SerializeToString()
# Send response back to the client
socket.send(response_data)
e2e_end = time.time()
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
except zmq.Again:
logger.debug("ZMQ socket timeout, continuing to listen")
continue
except Exception as e:
logger.error(f"Error in ZMQ server loop: {e}")
import traceback
traceback.print_exc()
raise
zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
zmq_thread.start()
logger.info(f"Started DiskANN ZMQ server thread on port {zmq_port}")
# Keep the main thread alive
try:
while True:
time.sleep(1)
except KeyboardInterrupt:
logger.info("DiskANN Server shutting down...")
return
if __name__ == "__main__":
import signal
import sys
def signal_handler(sig, frame):
logger.info(f"Received signal {sig}, shutting down gracefully...")
sys.exit(0)
# Register signal handlers for graceful shutdown
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
parser = argparse.ArgumentParser(description="DiskANN Embedding service")
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
parser.add_argument(
"--passages-file",
type=str,
help="Metadata JSON file containing passage sources",
)
parser.add_argument(
"--model-name",
type=str,
default="sentence-transformers/all-mpnet-base-v2",
help="Embedding model name",
)
parser.add_argument(
"--embedding-mode",
type=str,
default="sentence-transformers",
choices=["sentence-transformers", "openai", "mlx"],
help="Embedding backend mode",
)
args = parser.parse_args()
# Create and start the DiskANN embedding server
create_diskann_embedding_server(
passages_file=args.passages_file,
zmq_port=args.zmq_port,
model_name=args.model_name,
embedding_mode=args.embedding_mode,
)

View File

@@ -1,741 +0,0 @@
#!/usr/bin/env python3
"""
Embedding server for leann-backend-diskann - Fixed ZMQ REQ-REP pattern
"""
import pickle
import argparse
import time
import json
from typing import Dict, Any, Optional, Union
from transformers import AutoTokenizer, AutoModel
import os
from contextlib import contextmanager
import zmq
import numpy as np
import msgpack
from pathlib import Path
import logging
RED = "\033[91m"
# Set up logging based on environment variable
LOG_LEVEL = os.getenv('LEANN_LOG_LEVEL', 'INFO').upper()
logging.basicConfig(
level=getattr(logging, LOG_LEVEL, logging.INFO),
format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
RESET = "\033[0m"
# --- New Passage Loader from HNSW backend ---
class SimplePassageLoader:
"""
Simple passage loader that replaces config.py dependencies
"""
def __init__(self, passages_data: Optional[Dict[str, Any]] = None):
self.passages_data = passages_data or {}
self._meta_path = ''
def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
"""Get passage by ID"""
str_id = str(passage_id)
if str_id in self.passages_data:
return {"text": self.passages_data[str_id]}
else:
# Return empty text for missing passages
return {"text": ""}
def __len__(self) -> int:
return len(self.passages_data)
def keys(self):
return self.passages_data.keys()
def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
"""
Load passages using metadata file with PassageManager for lazy loading
"""
# Load metadata to get passage sources
with open(meta_file, 'r') as f:
meta = json.load(f)
# Import PassageManager dynamically to avoid circular imports
import sys
from pathlib import Path
# Find the leann package directory relative to this file
current_dir = Path(__file__).parent
leann_core_path = current_dir.parent.parent / "leann-core" / "src"
sys.path.insert(0, str(leann_core_path))
try:
from leann.api import PassageManager
passage_manager = PassageManager(meta['passage_sources'])
finally:
sys.path.pop(0)
# Load label map
passages_dir = Path(meta_file).parent
label_map_file = passages_dir / "leann.labels.map"
if label_map_file.exists():
import pickle
with open(label_map_file, 'rb') as f:
label_map = pickle.load(f)
print(f"Loaded label map with {len(label_map)} entries")
else:
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
print(f"Initialized lazy passage loading for {len(label_map)} passages")
class LazyPassageLoader(SimplePassageLoader):
def __init__(self, passage_manager, label_map):
self.passage_manager = passage_manager
self.label_map = label_map
# Initialize parent with empty data
super().__init__({})
def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
"""Get passage by ID with lazy loading"""
try:
int_id = int(passage_id)
if int_id in self.label_map:
string_id = self.label_map[int_id]
passage_data = self.passage_manager.get_passage(string_id)
if passage_data and passage_data.get("text"):
return {"text": passage_data["text"]}
else:
raise RuntimeError(f"FATAL: Empty text for ID {int_id} -> {string_id}")
else:
raise RuntimeError(f"FATAL: ID {int_id} not found in label_map")
except Exception as e:
raise RuntimeError(f"FATAL: Exception getting passage {passage_id}: {e}")
def __len__(self) -> int:
return len(self.label_map)
def keys(self):
return self.label_map.keys()
loader = LazyPassageLoader(passage_manager, label_map)
loader._meta_path = meta_file
return loader
def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
"""
Load passages from a JSONL file with label map support
Expected format: {"id": "passage_id", "text": "passage_text", "metadata": {...}} (one per line)
"""
if not os.path.exists(passages_file):
raise FileNotFoundError(f"Passages file {passages_file} not found.")
if not passages_file.endswith('.jsonl'):
raise ValueError(f"Expected .jsonl file format, got: {passages_file}")
# Load label map (int -> string_id)
passages_dir = Path(passages_file).parent
label_map_file = passages_dir / "leann.labels.map"
label_map = {}
if label_map_file.exists():
with open(label_map_file, 'rb') as f:
label_map = pickle.load(f)
print(f"Loaded label map with {len(label_map)} entries")
else:
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
# Load passages by string ID
string_id_passages = {}
with open(passages_file, 'r', encoding='utf-8') as f:
for line in f:
if line.strip():
passage = json.loads(line)
string_id_passages[passage['id']] = passage['text']
# Create int ID -> text mapping using label map
passages_data = {}
for int_id, string_id in label_map.items():
if string_id in string_id_passages:
passages_data[str(int_id)] = string_id_passages[string_id]
else:
print(f"WARNING: String ID {string_id} from label map not found in passages")
print(f"Loaded {len(passages_data)} passages from JSONL file {passages_file} using label map")
return SimplePassageLoader(passages_data)
def create_embedding_server_thread(
zmq_port=5555,
model_name="sentence-transformers/all-mpnet-base-v2",
max_batch_size=128,
passages_file: Optional[str] = None,
embedding_mode: str = "sentence-transformers",
enable_warmup: bool = False,
):
"""
Create and run embedding server in the current thread
This function is designed to be called in a separate thread
"""
logger.info(f"Initializing embedding server thread on port {zmq_port}")
try:
# Check if port is already occupied
import socket
def check_port(port):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(('localhost', port)) == 0
if check_port(zmq_port):
print(f"{RED}Port {zmq_port} is already in use{RESET}")
return
# Auto-detect mode based on model name if not explicitly set
if embedding_mode == "sentence-transformers" and model_name.startswith("text-embedding-"):
embedding_mode = "openai"
if embedding_mode == "mlx":
from leann.api import compute_embeddings_mlx
import torch
logger.info("Using MLX for embeddings")
# Set device to CPU for compatibility with DeviceTimer class
device = torch.device("cpu")
cuda_available = False
mps_available = False
elif embedding_mode == "openai":
from leann.api import compute_embeddings_openai
import torch
logger.info("Using OpenAI API for embeddings")
# Set device to CPU for compatibility with DeviceTimer class
device = torch.device("cpu")
cuda_available = False
mps_available = False
elif embedding_mode == "sentence-transformers":
# Initialize model
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
import torch
# Select device
mps_available = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
cuda_available = torch.cuda.is_available()
if cuda_available:
device = torch.device("cuda")
logger.info("Using CUDA device")
elif mps_available:
device = torch.device("mps")
logger.info("Using MPS device (Apple Silicon)")
else:
device = torch.device("cpu")
logger.info("Using CPU device")
# Load model
logger.info(f"Loading model {model_name}")
model = AutoModel.from_pretrained(model_name).to(device).eval()
# Optimize model
if cuda_available or mps_available:
try:
model = model.half()
model = torch.compile(model)
logger.info(f"Using FP16 precision with model: {model_name}")
except Exception as e:
print(f"WARNING: Model optimization failed: {e}")
else:
raise ValueError(f"Unsupported embedding mode: {embedding_mode}. Supported modes: sentence-transformers, mlx, openai")
# Load passages from file if provided
if passages_file and os.path.exists(passages_file):
# Check if it's a metadata file or a single passages file
if passages_file.endswith('.meta.json'):
passages = load_passages_from_metadata(passages_file)
else:
# Try to find metadata file in same directory
passages_dir = Path(passages_file).parent
meta_files = list(passages_dir.glob("*.meta.json"))
if meta_files:
print(f"Found metadata file: {meta_files[0]}, using lazy loading")
passages = load_passages_from_metadata(str(meta_files[0]))
else:
# Fallback to original single file loading (will cause warnings)
print("WARNING: No metadata file found, using single file loading (may cause missing passage warnings)")
passages = load_passages_from_file(passages_file)
else:
print("WARNING: No passages file provided or file not found. Using an empty passage loader.")
passages = SimplePassageLoader()
logger.info(f"Loaded {len(passages)} passages.")
def client_warmup(zmq_port):
"""Perform client-side warmup for DiskANN server"""
time.sleep(2)
print(f"Performing client-side warmup with model {model_name}...")
# Get actual passage IDs from the loaded passages
sample_ids = []
if hasattr(passages, 'keys') and len(passages) > 0:
available_ids = list(passages.keys())
# Take up to 5 actual IDs, but at least 1
sample_ids = available_ids[:min(5, len(available_ids))]
print(f"Using actual passage IDs for warmup: {sample_ids}")
else:
print("No passages available for warmup, skipping warmup...")
return
try:
context = zmq.Context()
socket = context.socket(zmq.REQ)
socket.connect(f"tcp://localhost:{zmq_port}")
socket.setsockopt(zmq.RCVTIMEO, 30000)
socket.setsockopt(zmq.SNDTIMEO, 30000)
try:
ids_to_send = [int(x) for x in sample_ids]
except ValueError:
print("Warning: Could not convert sample IDs to integers, skipping warmup")
return
if not ids_to_send:
print("Skipping warmup send.")
return
# Use protobuf format for warmup
from . import embedding_pb2
req_proto = embedding_pb2.NodeEmbeddingRequest()
req_proto.node_ids.extend(ids_to_send)
request_bytes = req_proto.SerializeToString()
for i in range(3):
print(f"Sending warmup request {i + 1}/3 via ZMQ (Protobuf)...")
socket.send(request_bytes)
response_bytes = socket.recv()
resp_proto = embedding_pb2.NodeEmbeddingResponse()
resp_proto.ParseFromString(response_bytes)
embeddings_count = resp_proto.dimensions[0] if resp_proto.dimensions else 0
print(f"Warmup request {i + 1}/3 successful, received {embeddings_count} embeddings")
time.sleep(0.1)
print("Client-side Protobuf ZMQ warmup complete")
socket.close()
context.term()
except Exception as e:
print(f"Error during Protobuf ZMQ warmup: {e}")
class DeviceTimer:
"""Device timer"""
def __init__(self, name="", device=device):
self.name = name
self.device = device
self.start_time = 0
self.end_time = 0
if embedding_mode == "sentence-transformers" and torch.cuda.is_available():
self.start_event = torch.cuda.Event(enable_timing=True)
self.end_event = torch.cuda.Event(enable_timing=True)
else:
self.start_event = None
self.end_event = None
@contextmanager
def timing(self):
self.start()
yield
self.end()
def start(self):
if embedding_mode == "sentence-transformers" and torch.cuda.is_available():
torch.cuda.synchronize()
self.start_event.record()
else:
if embedding_mode == "sentence-transformers" and self.device.type == "mps":
torch.mps.synchronize()
self.start_time = time.time()
def end(self):
if embedding_mode == "sentence-transformers" and torch.cuda.is_available():
self.end_event.record()
torch.cuda.synchronize()
else:
if embedding_mode == "sentence-transformers" and self.device.type == "mps":
torch.mps.synchronize()
self.end_time = time.time()
def elapsed_time(self):
if embedding_mode == "sentence-transformers" and torch.cuda.is_available():
return self.start_event.elapsed_time(self.end_event) / 1000.0
else:
return self.end_time - self.start_time
def print_elapsed(self):
elapsed = self.elapsed_time()
print(f"[{self.name}] Elapsed time: {elapsed:.3f}s")
def process_batch_pytorch(texts_batch, ids_batch, missing_ids):
"""Process text batch"""
if not texts_batch:
return np.array([])
# Filter out empty texts and their corresponding IDs
valid_texts = []
valid_ids = []
for i, text in enumerate(texts_batch):
if text.strip(): # Only include non-empty texts
valid_texts.append(text)
valid_ids.append(ids_batch[i])
if not valid_texts:
print("WARNING: No valid texts in batch")
return np.array([])
# Tokenize
token_timer = DeviceTimer("tokenization")
with token_timer.timing():
inputs = tokenizer(
valid_texts,
padding=True,
truncation=True,
max_length=512,
return_tensors="pt"
).to(device)
# Compute embeddings
embed_timer = DeviceTimer("embedding computation")
with embed_timer.timing():
with torch.no_grad():
outputs = model(**inputs)
hidden_states = outputs.last_hidden_state
# Mean pooling
attention_mask = inputs['attention_mask']
mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
sum_embeddings = torch.sum(hidden_states * mask_expanded, 1)
sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
batch_embeddings = sum_embeddings / sum_mask
embed_timer.print_elapsed()
return batch_embeddings.cpu().numpy()
# ZMQ server main loop - modified to use REP socket
context = zmq.Context()
socket = context.socket(zmq.ROUTER) # Changed to REP socket
socket.bind(f"tcp://127.0.0.1:{zmq_port}")
print(f"INFO: ZMQ ROUTER server listening on port {zmq_port}")
# Set timeouts
socket.setsockopt(zmq.RCVTIMEO, 5000) # 5 second receive timeout
socket.setsockopt(zmq.SNDTIMEO, 300000) # 300 second send timeout
from . import embedding_pb2
print(f"INFO: Embedding server ready to serve requests")
# Start warmup thread if enabled
if enable_warmup and len(passages) > 0:
import threading
print(f"Warmup enabled: starting warmup thread")
warmup_thread = threading.Thread(target=client_warmup, args=(zmq_port,))
warmup_thread.daemon = True
warmup_thread.start()
else:
print(f"Warmup disabled or no passages available (enable_warmup={enable_warmup}, passages={len(passages)})")
while True:
try:
parts = socket.recv_multipart()
# --- Restore robust message format detection ---
# Must check parts length to avoid IndexError
if len(parts) >= 3:
identity = parts[0]
# empty = parts[1] # We usually don't care about the middle empty frame
message = parts[2]
elif len(parts) == 2:
# Can also handle cases without empty frame
identity = parts[0]
message = parts[1]
else:
# If received message format is wrong, print warning and ignore it instead of crashing
print(f"WARNING: Received unexpected message format with {len(parts)} parts. Ignoring.")
continue
print(f"INFO: Received ZMQ request from client {identity.hex()[:8]}, size {len(message)} bytes")
# Handle control messages (MessagePack format)
try:
request_payload = msgpack.unpackb(message)
if isinstance(request_payload, list) and len(request_payload) >= 1:
if request_payload[0] == "__QUERY_META_PATH__":
# Return the current meta path being used by the server
current_meta_path = getattr(passages, '_meta_path', '') if hasattr(passages, '_meta_path') else ''
response = [current_meta_path]
socket.send_multipart([identity, b'', msgpack.packb(response)])
continue
elif request_payload[0] == "__UPDATE_META_PATH__" and len(request_payload) >= 2:
# Update the server's meta path and reload passages
new_meta_path = request_payload[1]
try:
print(f"INFO: Updating server meta path to: {new_meta_path}")
# Reload passages from the new meta file
passages = load_passages_from_metadata(new_meta_path)
# Store the meta path for future queries
passages._meta_path = new_meta_path
response = ["SUCCESS"]
print(f"INFO: Successfully updated meta path and reloaded {len(passages)} passages")
except Exception as e:
print(f"ERROR: Failed to update meta path: {e}")
response = ["FAILED", str(e)]
socket.send_multipart([identity, b'', msgpack.packb(response)])
continue
elif request_payload[0] == "__QUERY_MODEL__":
# Return the current model being used by the server
response = [model_name]
socket.send_multipart([identity, b'', msgpack.packb(response)])
continue
elif request_payload[0] == "__UPDATE_MODEL__" and len(request_payload) >= 2:
# Update the server's embedding model
new_model_name = request_payload[1]
try:
print(f"INFO: Updating server model from {model_name} to: {new_model_name}")
# Clean up old model to free memory
if not use_mlx:
print("INFO: Releasing old model from memory...")
old_model = model
old_tokenizer = tokenizer
# Load new tokenizer first
print(f"Loading new tokenizer for {new_model_name}...")
tokenizer = AutoTokenizer.from_pretrained(new_model_name, use_fast=True)
# Load new model
print(f"Loading new model {new_model_name}...")
model = AutoModel.from_pretrained(new_model_name).to(device).eval()
# Optimize new model
if cuda_available or mps_available:
try:
model = model.half()
model = torch.compile(model)
print(f"INFO: Using FP16 precision with model: {new_model_name}")
except Exception as e:
print(f"WARNING: Model optimization failed: {e}")
# Now safely delete old model after new one is loaded
del old_model
del old_tokenizer
# Clear GPU cache if available
if device.type == "cuda":
torch.cuda.empty_cache()
print("INFO: Cleared CUDA cache")
elif device.type == "mps":
torch.mps.empty_cache()
print("INFO: Cleared MPS cache")
# Force garbage collection
import gc
gc.collect()
print("INFO: Memory cleanup completed")
# Update model name
model_name = new_model_name
response = ["SUCCESS"]
print(f"INFO: Successfully updated model to: {new_model_name}")
except Exception as e:
print(f"ERROR: Failed to update model: {e}")
response = ["FAILED", str(e)]
socket.send_multipart([identity, b'', msgpack.packb(response)])
continue
except:
# Not a control message, continue with normal protobuf processing
pass
e2e_start = time.time()
lookup_timer = DeviceTimer("text lookup")
# Parse request
req_proto = embedding_pb2.NodeEmbeddingRequest()
req_proto.ParseFromString(message)
node_ids = req_proto.node_ids
print(f"INFO: Request for {len(node_ids)} node embeddings: {list(node_ids)}")
# Add debug information
if len(node_ids) > 0:
print(f"DEBUG: Node ID range: {min(node_ids)} to {max(node_ids)}")
# Look up texts
texts = []
missing_ids = []
with lookup_timer.timing():
for nid in node_ids:
txtinfo = passages[nid]
txt = txtinfo["text"]
if txt:
texts.append(txt)
else:
# If text is empty, we still need a placeholder for batch processing,
# but record its ID as missing
texts.append("")
missing_ids.append(nid)
lookup_timer.print_elapsed()
if missing_ids:
print(f"WARNING: Missing passages for IDs: {missing_ids}")
# Process batch
total_size = len(texts)
print(f"INFO: Total batch size: {total_size}, max_batch_size: {max_batch_size}")
all_embeddings = []
if total_size > max_batch_size:
print(f"INFO: Splitting batch of size {total_size} into chunks of {max_batch_size}")
for i in range(0, total_size, max_batch_size):
end_idx = min(i + max_batch_size, total_size)
print(f"INFO: Processing chunk {i//max_batch_size + 1}/{(total_size + max_batch_size - 1)//max_batch_size}: items {i} to {end_idx-1}")
chunk_texts = texts[i:end_idx]
chunk_ids = node_ids[i:end_idx]
if embedding_mode == "mlx":
embeddings_chunk = compute_embeddings_mlx(chunk_texts, model_name, batch_size=16)
elif embedding_mode == "openai":
embeddings_chunk = compute_embeddings_openai(chunk_texts, model_name)
else: # sentence-transformers
embeddings_chunk = process_batch_pytorch(chunk_texts, chunk_ids, missing_ids)
all_embeddings.append(embeddings_chunk)
if embedding_mode == "sentence-transformers":
if cuda_available:
torch.cuda.empty_cache()
elif device.type == "mps":
torch.mps.empty_cache()
hidden = np.vstack(all_embeddings)
print(f"INFO: Combined embeddings shape: {hidden.shape}")
else:
if embedding_mode == "mlx":
hidden = compute_embeddings_mlx(texts, model_name, batch_size=16)
elif embedding_mode == "openai":
hidden = compute_embeddings_openai(texts, model_name)
else: # sentence-transformers
hidden = process_batch_pytorch(texts, node_ids, missing_ids)
# Serialize response
ser_start = time.time()
resp_proto = embedding_pb2.NodeEmbeddingResponse()
hidden_contiguous = np.ascontiguousarray(hidden, dtype=np.float32)
resp_proto.embeddings_data = hidden_contiguous.tobytes()
resp_proto.dimensions.append(hidden_contiguous.shape[0])
resp_proto.dimensions.append(hidden_contiguous.shape[1])
resp_proto.missing_ids.extend(missing_ids)
response_data = resp_proto.SerializeToString()
# REP socket sends a single response
socket.send_multipart([identity, b'', response_data])
ser_end = time.time()
print(f"INFO: Serialize time: {ser_end - ser_start:.6f} seconds")
if embedding_mode == "sentence-transformers":
if device.type == "cuda":
torch.cuda.synchronize()
elif device.type == "mps":
torch.mps.synchronize()
e2e_end = time.time()
print(f"INFO: ZMQ E2E time: {e2e_end - e2e_start:.6f} seconds")
except zmq.Again:
print("INFO: ZMQ socket timeout, continuing to listen")
continue
except Exception as e:
print(f"ERROR: Error in ZMQ server: {e}")
try:
# Send empty response to maintain REQ-REP state
empty_resp = embedding_pb2.NodeEmbeddingResponse()
socket.send(empty_resp.SerializeToString())
except:
# If sending fails, recreate socket
socket.close()
socket = context.socket(zmq.REP)
socket.bind(f"tcp://127.0.0.1:{zmq_port}")
socket.setsockopt(zmq.RCVTIMEO, 5000)
socket.setsockopt(zmq.SNDTIMEO, 300000)
print("INFO: ZMQ socket recreated after error")
except Exception as e:
print(f"ERROR: Failed to start embedding server: {e}")
raise
def create_embedding_server(
domain="demo",
load_passages=True,
load_embeddings=False,
use_fp16=True,
use_int8=False,
use_cuda_graphs=False,
zmq_port=5555,
max_batch_size=128,
lazy_load_passages=False,
model_name="sentence-transformers/all-mpnet-base-v2",
passages_file: Optional[str] = None,
embedding_mode: str = "sentence-transformers",
enable_warmup: bool = False,
):
"""
原有的 create_embedding_server 函数保持不变
这个是阻塞版本,用于直接运行
"""
create_embedding_server_thread(zmq_port, model_name, max_batch_size, passages_file, embedding_mode, enable_warmup)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Embedding service")
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
parser.add_argument("--domain", type=str, default="demo", help="Domain name")
parser.add_argument("--passages-file", type=str, help="JSON file containing passage ID to text mapping")
parser.add_argument("--load-passages", action="store_true", default=True)
parser.add_argument("--load-embeddings", action="store_true", default=False)
parser.add_argument("--use-fp16", action="store_true", default=False)
parser.add_argument("--use-int8", action="store_true", default=False)
parser.add_argument("--use-cuda-graphs", action="store_true", default=False)
parser.add_argument("--max-batch-size", type=int, default=128, help="Maximum batch size before splitting")
parser.add_argument("--lazy-load-passages", action="store_true", default=True)
parser.add_argument("--model-name", type=str, default="sentence-transformers/all-mpnet-base-v2",
help="Embedding model name")
parser.add_argument("--embedding-mode", type=str, default="sentence-transformers",
choices=["sentence-transformers", "mlx", "openai"],
help="Embedding backend mode")
parser.add_argument("--use-mlx", action="store_true", default=False, help="Use MLX backend for embeddings (deprecated: use --embedding-mode mlx)")
parser.add_argument("--disable-warmup", action="store_true", default=False, help="Disable warmup requests on server start")
args = parser.parse_args()
# Handle backward compatibility with use_mlx
embedding_mode = args.embedding_mode
if args.use_mlx:
embedding_mode = "mlx"
create_embedding_server(
domain=args.domain,
load_passages=args.load_passages,
load_embeddings=args.load_embeddings,
use_fp16=args.use_fp16,
use_int8=args.use_int8,
use_cuda_graphs=args.use_cuda_graphs,
zmq_port=args.zmq_port,
max_batch_size=args.max_batch_size,
lazy_load_passages=args.lazy_load_passages,
model_name=args.model_name,
passages_file=args.passages_file,
embedding_mode=embedding_mode,
enable_warmup=not args.disable_warmup,
)

View File

@@ -4,13 +4,16 @@ build-backend = "scikit_build_core.build"
[project]
name = "leann-backend-diskann"
version = "0.1.0"
dependencies = ["leann-core==0.1.0", "numpy"]
version = "0.1.4"
dependencies = ["leann-core==0.1.4", "numpy"]
[tool.scikit-build]
# 关键:简化的 CMake 路径
# Key: simplified CMake path
cmake.source-dir = "third_party/DiskANN"
# 关键:Python 包在根目录,路径完全匹配
# Key: Python package in root directory, paths match exactly
wheel.packages = ["leann_backend_diskann"]
# 使用默认的 redirect 模式
editable.mode = "redirect"
# Use default redirect mode
editable.mode = "redirect"
cmake.build-type = "Release"
build.verbose = true
build.tool-args = ["-j8"]

View File

@@ -1,6 +1,7 @@
# 最终简化版
cmake_minimum_required(VERSION 3.24)
project(leann_backend_hnsw_wrapper)
set(CMAKE_C_COMPILER_WORKS 1)
set(CMAKE_CXX_COMPILER_WORKS 1)
# Set OpenMP path for macOS
if(APPLE)
@@ -11,15 +12,9 @@ if(APPLE)
set(OpenMP_omp_LIBRARY "/opt/homebrew/opt/libomp/lib/libomp.dylib")
endif()
# Build ZeroMQ from source
set(ZMQ_BUILD_TESTS OFF CACHE BOOL "" FORCE)
set(ENABLE_DRAFTS OFF CACHE BOOL "" FORCE)
set(ENABLE_PRECOMPILED OFF CACHE BOOL "" FORCE)
set(WITH_PERF_TOOL OFF CACHE BOOL "" FORCE)
set(WITH_DOCS OFF CACHE BOOL "" FORCE)
set(BUILD_SHARED OFF CACHE BOOL "" FORCE)
set(BUILD_STATIC ON CACHE BOOL "" FORCE)
add_subdirectory(third_party/libzmq)
# Use system ZeroMQ instead of building from source
find_package(PkgConfig REQUIRED)
pkg_check_modules(ZMQ REQUIRED libzmq)
# Add cppzmq headers
include_directories(third_party/cppzmq)
@@ -29,6 +24,7 @@ set(MSGPACK_USE_BOOST OFF CACHE BOOL "" FORCE)
add_compile_definitions(MSGPACK_NO_BOOST)
include_directories(third_party/msgpack-c/include)
# Faiss configuration - streamlined build
set(FAISS_ENABLE_PYTHON ON CACHE BOOL "" FORCE)
set(FAISS_ENABLE_GPU OFF CACHE BOOL "" FORCE)
set(FAISS_ENABLE_EXTRAS OFF CACHE BOOL "" FORCE)
@@ -36,4 +32,24 @@ set(BUILD_TESTING OFF CACHE BOOL "" FORCE)
set(FAISS_ENABLE_C_API OFF CACHE BOOL "" FORCE)
set(FAISS_OPT_LEVEL "generic" CACHE STRING "" FORCE)
# Disable additional SIMD versions to speed up compilation
set(FAISS_ENABLE_AVX2 OFF CACHE BOOL "" FORCE)
set(FAISS_ENABLE_AVX512 OFF CACHE BOOL "" FORCE)
# Additional optimization options from INSTALL.md
set(CMAKE_BUILD_TYPE "Release" CACHE STRING "" FORCE)
set(BUILD_SHARED_LIBS OFF CACHE BOOL "" FORCE) # Static library is faster to build
# Avoid building demos and benchmarks
set(BUILD_DEMOS OFF CACHE BOOL "" FORCE)
set(BUILD_BENCHS OFF CACHE BOOL "" FORCE)
# NEW: Tell Faiss to only build the generic version
set(FAISS_BUILD_GENERIC ON CACHE BOOL "" FORCE)
set(FAISS_BUILD_AVX2 OFF CACHE BOOL "" FORCE)
set(FAISS_BUILD_AVX512 OFF CACHE BOOL "" FORCE)
# IMPORTANT: Disable building AVX versions to speed up compilation
set(FAISS_BUILD_AVX_VERSIONS OFF CACHE BOOL "" FORCE)
add_subdirectory(third_party/faiss)

View File

@@ -1,10 +1,9 @@
import numpy as np
import os
from pathlib import Path
from typing import Dict, Any, List, Literal
import pickle
from typing import Dict, Any, List, Literal, Optional
import shutil
import time
import logging
from leann.searcher_base import BaseSearcher
from .convert_to_csr import convert_hnsw_graph_to_csr
@@ -16,6 +15,8 @@ from leann.interface import (
LeannBackendSearcherInterface,
)
logger = logging.getLogger(__name__)
def get_metric_map():
from . import faiss # type: ignore
@@ -57,13 +58,9 @@ class HNSWBuilder(LeannBackendBuilderInterface):
index_dir.mkdir(parents=True, exist_ok=True)
if data.dtype != np.float32:
logger.warning(f"Converting data to float32, shape: {data.shape}")
data = data.astype(np.float32)
label_map = {i: str_id for i, str_id in enumerate(ids)}
label_map_file = index_dir / "leann.labels.map"
with open(label_map_file, "wb") as f:
pickle.dump(label_map, f)
metric_enum = get_metric_map().get(self.distance_metric.lower())
if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
@@ -85,7 +82,7 @@ class HNSWBuilder(LeannBackendBuilderInterface):
def _convert_to_csr(self, index_file: Path):
"""Convert built index to CSR format"""
mode_str = "CSR-pruned" if self.is_recompute else "CSR-standard"
print(f"INFO: Converting HNSW index to {mode_str} format...")
logger.info(f"INFO: Converting HNSW index to {mode_str} format...")
csr_temp_file = index_file.with_suffix(".csr.tmp")
@@ -94,11 +91,11 @@ class HNSWBuilder(LeannBackendBuilderInterface):
)
if success:
print("✅ CSR conversion successful.")
logger.info("✅ CSR conversion successful.")
index_file_old = index_file.with_suffix(".old")
shutil.move(str(index_file), str(index_file_old))
shutil.move(str(csr_temp_file), str(index_file))
print(
logger.info(
f"INFO: Replaced original index with {mode_str} version at '{index_file}'"
)
else:
@@ -135,31 +132,22 @@ class HNSWSearcher(BaseSearcher):
hnsw_config = faiss.HNSWIndexConfig()
hnsw_config.is_compact = self.is_compact
hnsw_config.is_recompute = self.is_pruned or kwargs.get("is_recompute", False)
if self.is_pruned and not hnsw_config.is_recompute:
raise RuntimeError("Index is pruned but recompute is disabled.")
hnsw_config.is_recompute = (
self.is_pruned
) # In C++ code, it's called is_recompute, but it's only for loading IIUC.
self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)
# Load label mapping
label_map_file = self.index_dir / "leann.labels.map"
if not label_map_file.exists():
raise FileNotFoundError(f"Label map file not found at {label_map_file}")
with open(label_map_file, "rb") as f:
self.label_map = pickle.load(f)
def search(
self,
query: np.ndarray,
top_k: int,
zmq_port: Optional[int] = None,
complexity: int = 64,
beam_width: int = 1,
prune_ratio: float = 0.0,
recompute_embeddings: bool = False,
recompute_embeddings: bool = True,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
zmq_port: int = 5557,
batch_size: int = 0,
**kwargs,
) -> Dict[str, Any]:
@@ -177,7 +165,7 @@ class HNSWSearcher(BaseSearcher):
- "global": Use global PQ queue size for selection (default)
- "local": Local pruning, sort and select best candidates
- "proportional": Base selection on new neighbor count ratio
zmq_port: ZMQ port for embedding server
zmq_port: ZMQ port for embedding server communication. Must be provided if recompute_embeddings is True.
batch_size: Neighbor processing batch size, 0=disabled (HNSW-specific)
**kwargs: Additional HNSW-specific parameters (for legacy compatibility)
@@ -186,15 +174,14 @@ class HNSWSearcher(BaseSearcher):
"""
from . import faiss # type: ignore
# Use recompute_embeddings parameter
use_recompute = recompute_embeddings or self.is_pruned
if use_recompute:
meta_file_path = self.index_dir / f"{self.index_path.name}.meta.json"
if not meta_file_path.exists():
raise RuntimeError(
f"FATAL: Recompute enabled but metadata file not found: {meta_file_path}"
if not recompute_embeddings:
if self.is_pruned:
raise RuntimeError("Recompute is required for pruned index.")
if recompute_embeddings:
if zmq_port is None:
raise ValueError(
"zmq_port must be provided if recompute_embeddings is True"
)
self._ensure_server_running(str(meta_file_path), port=zmq_port, **kwargs)
if query.dtype != np.float32:
query = query.astype(np.float32)
@@ -202,7 +189,10 @@ class HNSWSearcher(BaseSearcher):
faiss.normalize_L2(query)
params = faiss.SearchParametersHNSW()
params.zmq_port = zmq_port
if zmq_port is not None:
params.zmq_port = (
zmq_port # C++ code won't use this if recompute_embeddings is False
)
params.efSearch = complexity
params.beam_size = beam_width
@@ -239,11 +229,7 @@ class HNSWSearcher(BaseSearcher):
)
string_labels = [
[
self.label_map.get(int_label, f"unknown_{int_label}")
for int_label in batch_labels
]
for batch_labels in labels
[str(int_label) for int_label in batch_labels] for batch_labels in labels
]
return {"labels": string_labels, "distances": distances}

View File

File diff suppressed because it is too large Load Diff

View File

@@ -6,12 +6,22 @@ build-backend = "scikit_build_core.build"
[project]
name = "leann-backend-hnsw"
version = "0.1.0"
version = "0.1.4"
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
dependencies = ["leann-core==0.1.0", "numpy"]
dependencies = [
"leann-core==0.1.4",
"numpy",
"pyzmq>=23.0.0",
"msgpack>=1.0.0",
]
[tool.scikit-build]
wheel.packages = ["leann_backend_hnsw"]
editable.mode = "redirect"
cmake.build-type = "Debug"
build.verbose = true
cmake.build-type = "Release"
build.verbose = true
build.tool-args = ["-j8"]
# CMake definitions to optimize compilation
[tool.scikit-build.cmake.define]
CMAKE_BUILD_PARALLEL_LEVEL = "8"

View File

@@ -4,16 +4,27 @@ build-backend = "setuptools.build_meta"
[project]
name = "leann-core"
version = "0.1.0"
description = "Core API and plugin system for Leann."
version = "0.1.4"
description = "Core API and plugin system for LEANN"
readme = "README.md"
requires-python = ">=3.9"
license = { text = "MIT" }
# All required dependencies included
dependencies = [
"numpy>=1.20.0",
"tqdm>=4.60.0"
"tqdm>=4.60.0",
"psutil>=5.8.0",
"pyzmq>=23.0.0",
"msgpack>=1.0.0",
"torch>=2.0.0",
"sentence-transformers>=2.2.0",
"llama-index-core>=0.12.0",
"python-dotenv>=1.0.0",
]
[project.scripts]
leann = "leann.cli:main"
[tool.setuptools.packages.find]
where = ["src"]

View File

@@ -5,16 +5,18 @@ with the correct, original embedding logic from the user's reference code.
import json
import pickle
from leann.interface import LeannBackendSearcherInterface
import numpy as np
import time
from pathlib import Path
from typing import List, Dict, Any, Optional, Literal
from dataclasses import dataclass, field
import uuid
import torch
from .registry import BACKEND_REGISTRY
from .interface import LeannBackendFactoryInterface
from .chat import get_llm
import logging
logger = logging.getLogger(__name__)
def compute_embeddings(
@@ -22,7 +24,8 @@ def compute_embeddings(
model_name: str,
mode: str = "sentence-transformers",
use_server: bool = True,
use_mlx: bool = False # Backward compatibility: if True, override mode to 'mlx',
port: Optional[int] = None,
is_build=False,
) -> np.ndarray:
"""
Computes embeddings using different backends.
@@ -39,251 +42,63 @@ def compute_embeddings(
Returns:
numpy array of embeddings
"""
# Override mode for backward compatibility
if use_mlx:
mode = "mlx"
# Auto-detect mode based on model name if not explicitly set
if mode == "sentence-transformers" and model_name.startswith("text-embedding-"):
mode = "openai"
if mode == "mlx":
return compute_embeddings_mlx(chunks, model_name, batch_size=16)
elif mode == "openai":
return compute_embeddings_openai(chunks, model_name)
elif mode == "sentence-transformers":
return compute_embeddings_sentence_transformers(
chunks, model_name, use_server=use_server
)
if use_server:
# Use embedding server (for search/query)
if port is None:
raise ValueError("port is required when use_server is True")
return compute_embeddings_via_server(chunks, model_name, port=port)
else:
raise ValueError(
f"Unsupported embedding mode: {mode}. Supported modes: sentence-transformers, mlx, openai"
# Use direct computation (for build_index)
from .embedding_compute import (
compute_embeddings as compute_embeddings_direct,
)
return compute_embeddings_direct(
chunks,
model_name,
mode=mode,
is_build=is_build,
)
def compute_embeddings_sentence_transformers(
chunks: List[str], model_name: str, use_server: bool = True
def compute_embeddings_via_server(
chunks: List[str], model_name: str, port: int
) -> np.ndarray:
"""Computes embeddings using sentence-transformers.
Args:
chunks: List of text chunks to embed
model_name: Name of the sentence transformer model
use_server: If True, use embedding server (good for search). If False, use direct computation (good for build).
"""
if not use_server:
print(
f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (direct)..."
)
return _compute_embeddings_sentence_transformers_direct(chunks, model_name)
print(
f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (via embedding server)..."
logger.info(
f"Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (via embedding server)..."
)
import zmq
import msgpack
import numpy as np
# Use embedding server for sentence-transformers too
# This avoids loading the model twice (once in API, once in server)
try:
# Import ZMQ client functionality and server manager
import zmq
import msgpack
import numpy as np
from .embedding_server_manager import EmbeddingServerManager
# Connect to embedding server
context = zmq.Context()
socket = context.socket(zmq.REQ)
socket.connect(f"tcp://localhost:{port}")
# Ensure embedding server is running
port = 5557
server_manager = EmbeddingServerManager(
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
)
# Send chunks to server for embedding computation
request = chunks
socket.send(msgpack.packb(request))
server_started = server_manager.start_server(
port=port,
model_name=model_name,
embedding_mode="sentence-transformers",
enable_warmup=False,
)
# Receive embeddings from server
response = socket.recv()
embeddings_list = msgpack.unpackb(response)
if not server_started:
raise RuntimeError(f"Failed to start embedding server on port {port}")
# Convert back to numpy array
embeddings = np.array(embeddings_list, dtype=np.float32)
# Connect to embedding server
context = zmq.Context()
socket = context.socket(zmq.REQ)
socket.connect(f"tcp://localhost:{port}")
# Send chunks to server for embedding computation
request = chunks
socket.send(msgpack.packb(request))
# Receive embeddings from server
response = socket.recv()
embeddings_list = msgpack.unpackb(response)
# Convert back to numpy array
embeddings = np.array(embeddings_list, dtype=np.float32)
socket.close()
context.term()
return embeddings
except Exception as e:
# Fallback to direct sentence-transformers if server connection fails
print(
f"Warning: Failed to connect to embedding server, falling back to direct computation: {e}"
)
return _compute_embeddings_sentence_transformers_direct(chunks, model_name)
def _compute_embeddings_sentence_transformers_direct(
chunks: List[str], model_name: str
) -> np.ndarray:
"""Direct sentence-transformers computation (fallback)."""
try:
from sentence_transformers import SentenceTransformer
except ImportError as e:
raise RuntimeError(
"sentence-transformers not available. Install with: uv pip install sentence-transformers"
) from e
# Load model using sentence-transformers
model = SentenceTransformer(model_name)
model = model.half()
print(
f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (direct)..."
)
# use acclerater GPU or MAC GPU
if torch.cuda.is_available():
model = model.to("cuda")
elif torch.backends.mps.is_available():
model = model.to("mps")
# Generate embeddings
# give use an warning if OOM here means we need to turn down the batch size
embeddings = model.encode(
chunks, convert_to_numpy=True, show_progress_bar=True, batch_size=16
)
socket.close()
context.term()
return embeddings
def compute_embeddings_openai(chunks: List[str], model_name: str) -> np.ndarray:
"""Computes embeddings using OpenAI API."""
try:
import openai
import os
except ImportError as e:
raise RuntimeError(
"openai not available. Install with: uv pip install openai"
) from e
# Get API key from environment
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
raise RuntimeError("OPENAI_API_KEY environment variable not set")
client = openai.OpenAI(api_key=api_key)
print(
f"INFO: Computing embeddings for {len(chunks)} chunks using OpenAI model '{model_name}'..."
)
# OpenAI has a limit on batch size and input length
max_batch_size = 100 # Conservative batch size
all_embeddings = []
try:
from tqdm import tqdm
total_batches = (len(chunks) + max_batch_size - 1) // max_batch_size
batch_range = range(0, len(chunks), max_batch_size)
batch_iterator = tqdm(batch_range, desc="Computing embeddings", unit="batch", total=total_batches)
except ImportError:
# Fallback without progress bar
batch_iterator = range(0, len(chunks), max_batch_size)
for i in batch_iterator:
batch_chunks = chunks[i:i + max_batch_size]
try:
response = client.embeddings.create(model=model_name, input=batch_chunks)
batch_embeddings = [embedding.embedding for embedding in response.data]
all_embeddings.extend(batch_embeddings)
except Exception as e:
print(f"ERROR: Failed to get embeddings for batch starting at {i}: {e}")
raise
embeddings = np.array(all_embeddings, dtype=np.float32)
print(
f"INFO: Generated {len(embeddings)} embeddings with dimension {embeddings.shape[1]}"
)
return embeddings
def compute_embeddings_mlx(chunks: List[str], model_name: str, batch_size: int = 16) -> np.ndarray:
"""Computes embeddings using an MLX model."""
try:
import mlx.core as mx
from mlx_lm.utils import load
from tqdm import tqdm
except ImportError as e:
raise RuntimeError(
"MLX or related libraries not available. Install with: uv pip install mlx mlx-lm"
) from e
print(
f"INFO: Computing embeddings for {len(chunks)} chunks using MLX model '{model_name}' with batch_size={batch_size}..."
)
# Load model and tokenizer
model, tokenizer = load(model_name)
# Process chunks in batches with progress bar
all_embeddings = []
try:
from tqdm import tqdm
batch_iterator = tqdm(range(0, len(chunks), batch_size), desc="Computing embeddings", unit="batch")
except ImportError:
batch_iterator = range(0, len(chunks), batch_size)
for i in batch_iterator:
batch_chunks = chunks[i:i + batch_size]
# Tokenize all chunks in the batch
batch_token_ids = []
for chunk in batch_chunks:
token_ids = tokenizer.encode(chunk) # type: ignore
batch_token_ids.append(token_ids)
# Pad sequences to the same length for batch processing
max_length = max(len(ids) for ids in batch_token_ids)
padded_token_ids = []
for token_ids in batch_token_ids:
# Pad with tokenizer.pad_token_id or 0
padded = token_ids + [0] * (max_length - len(token_ids))
padded_token_ids.append(padded)
# Convert to MLX array with batch dimension
input_ids = mx.array(padded_token_ids)
# Get embeddings for the batch
embeddings = model(input_ids)
# Mean pooling for each sequence in the batch
pooled = embeddings.mean(axis=1) # Shape: (batch_size, hidden_size)
# Convert batch embeddings to numpy
for j in range(len(batch_chunks)):
pooled_list = pooled[j].tolist() # Convert to list
pooled_numpy = np.array(pooled_list, dtype=np.float32)
all_embeddings.append(pooled_numpy)
# Stack numpy arrays
return np.stack(all_embeddings)
@dataclass
class SearchResult:
id: str
@@ -299,25 +114,24 @@ class PassageManager:
self.global_offset_map = {} # Combined map for fast lookup
for source in passage_sources:
if source["type"] == "jsonl":
passage_file = source["path"]
index_file = source["index_path"]
if not Path(index_file).exists():
raise FileNotFoundError(
f"Passage index file not found: {index_file}"
)
with open(index_file, "rb") as f:
offset_map = pickle.load(f)
self.offset_maps[passage_file] = offset_map
self.passage_files[passage_file] = passage_file
assert source["type"] == "jsonl", "only jsonl is supported"
passage_file = source["path"]
index_file = source["index_path"] # .idx file
if not Path(index_file).exists():
raise FileNotFoundError(f"Passage index file not found: {index_file}")
with open(index_file, "rb") as f:
offset_map = pickle.load(f)
self.offset_maps[passage_file] = offset_map
self.passage_files[passage_file] = passage_file
# Build global map for O(1) lookup
for passage_id, offset in offset_map.items():
self.global_offset_map[passage_id] = (passage_file, offset)
# Build global map for O(1) lookup
for passage_id, offset in offset_map.items():
self.global_offset_map[passage_id] = (passage_file, offset)
def get_passage(self, passage_id: str) -> Dict[str, Any]:
if passage_id in self.global_offset_map:
passage_file, offset = self.global_offset_map[passage_id]
# Lazy file opening - only open when needed
with open(passage_file, "r", encoding="utf-8") as f:
f.seek(offset)
return json.loads(f.readline())
@@ -328,7 +142,7 @@ class LeannBuilder:
def __init__(
self,
backend_name: str,
embedding_model: str = "facebook/contriever-msmarco",
embedding_model: str = "facebook/contriever",
dimensions: Optional[int] = None,
embedding_mode: str = "sentence-transformers",
**backend_kwargs,
@@ -344,14 +158,12 @@ class LeannBuilder:
self.dimensions = dimensions
self.embedding_mode = embedding_mode
self.backend_kwargs = backend_kwargs
if 'mlx' in self.embedding_model:
self.embedding_mode = "mlx"
self.chunks: List[Dict[str, Any]] = []
def add_text(self, text: str, metadata: Optional[Dict[str, Any]] = None):
if metadata is None:
metadata = {}
passage_id = metadata.get("id", str(uuid.uuid4()))
passage_id = metadata.get("id", str(len(self.chunks)))
chunk_data = {"id": passage_id, "text": text, "metadata": metadata}
self.chunks.append(chunk_data)
@@ -377,10 +189,13 @@ class LeannBuilder:
with open(passages_file, "w", encoding="utf-8") as f:
try:
from tqdm import tqdm
chunk_iterator = tqdm(self.chunks, desc="Writing passages", unit="chunk")
chunk_iterator = tqdm(
self.chunks, desc="Writing passages", unit="chunk"
)
except ImportError:
chunk_iterator = self.chunks
for chunk in chunk_iterator:
offset = f.tell()
json.dump(
@@ -398,7 +213,11 @@ class LeannBuilder:
pickle.dump(offset_map, f)
texts_to_embed = [c["text"] for c in self.chunks]
embeddings = compute_embeddings(
texts_to_embed, self.embedding_model, self.embedding_mode, use_server=False
texts_to_embed,
self.embedding_model,
self.embedding_mode,
use_server=False,
is_build=True,
)
string_ids = [chunk["id"] for chunk in self.chunks]
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
@@ -472,7 +291,7 @@ class LeannBuilder:
f"Dimension mismatch: expected {self.dimensions}, got {embedding_dim}"
)
print(
logger.info(
f"Building index from precomputed embeddings: {len(ids)} items, {embedding_dim} dimensions"
)
@@ -480,7 +299,7 @@ class LeannBuilder:
if len(self.chunks) != len(ids):
# If no text chunks provided, create placeholder text entries
if not self.chunks:
print("No text chunks provided, creating placeholder entries...")
logger.info("No text chunks provided, creating placeholder entries...")
for id_val in ids:
self.add_text(
f"Document {id_val}",
@@ -555,15 +374,19 @@ class LeannBuilder:
with open(leann_meta_path, "w", encoding="utf-8") as f:
json.dump(meta_data, f, indent=2)
print(f"Index built successfully from precomputed embeddings: {index_path}")
logger.info(
f"Index built successfully from precomputed embeddings: {index_path}"
)
class LeannSearcher:
def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
meta_path_str = f"{index_path}.meta.json"
if not Path(meta_path_str).exists():
raise FileNotFoundError(f"Leann metadata file not found at {meta_path_str}")
with open(meta_path_str, "r", encoding="utf-8") as f:
self.meta_path_str = f"{index_path}.meta.json"
if not Path(self.meta_path_str).exists():
raise FileNotFoundError(
f"Leann metadata file not found at {self.meta_path_str}"
)
with open(self.meta_path_str, "r", encoding="utf-8") as f:
self.meta_data = json.load(f)
backend_name = self.meta_data["backend_name"]
self.embedding_model = self.meta_data["embedding_model"]
@@ -571,16 +394,15 @@ class LeannSearcher:
self.embedding_mode = self.meta_data.get(
"embedding_mode", "sentence-transformers"
)
# Backward compatibility with use_mlx
if self.meta_data.get("use_mlx", False):
self.embedding_mode = "mlx"
self.passage_manager = PassageManager(self.meta_data.get("passage_sources", []))
backend_factory = BACKEND_REGISTRY.get(backend_name)
if backend_factory is None:
raise ValueError(f"Backend '{backend_name}' not found.")
final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
final_kwargs["enable_warmup"] = enable_warmup
self.backend_impl = backend_factory.searcher(index_path, **final_kwargs)
self.backend_impl: LeannBackendSearcherInterface = backend_factory.searcher(
index_path, **final_kwargs
)
def search(
self,
@@ -589,26 +411,39 @@ class LeannSearcher:
complexity: int = 64,
beam_width: int = 1,
prune_ratio: float = 0.0,
recompute_embeddings: bool = False,
recompute_embeddings: bool = True,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
zmq_port: int = 5557,
expected_zmq_port: int = 5557,
**kwargs,
) -> List[SearchResult]:
print("🔍 DEBUG LeannSearcher.search() called:")
print(f" Query: '{query}'")
print(f" Top_k: {top_k}")
print(f" Additional kwargs: {kwargs}")
logger.info("🔍 LeannSearcher.search() called:")
logger.info(f" Query: '{query}'")
logger.info(f" Top_k: {top_k}")
logger.info(f" Additional kwargs: {kwargs}")
# Use backend's compute_query_embedding method
# This will automatically use embedding server if available and needed
import time
zmq_port = None
start_time = time.time()
if recompute_embeddings:
zmq_port = self.backend_impl._ensure_server_running(
self.meta_path_str,
port=expected_zmq_port,
**kwargs,
)
del expected_zmq_port
zmq_time = time.time() - start_time
logger.info(f" Launching server time: {zmq_time} seconds")
start_time = time.time()
query_embedding = self.backend_impl.compute_query_embedding(query, zmq_port)
print(f" Generated embedding shape: {query_embedding.shape}")
query_embedding = self.backend_impl.compute_query_embedding(
query,
use_server_if_available=recompute_embeddings,
zmq_port=zmq_port,
)
logger.info(f" Generated embedding shape: {query_embedding.shape}")
embedding_time = time.time() - start_time
print(f" Embedding time: {embedding_time} seconds")
logger.info(f" Embedding time: {embedding_time} seconds")
start_time = time.time()
results = self.backend_impl.search(
@@ -623,14 +458,14 @@ class LeannSearcher:
**kwargs,
)
search_time = time.time() - start_time
print(f" Search time: {search_time} seconds")
print(
logger.info(f" Search time: {search_time} seconds")
logger.info(
f" Backend returned: labels={len(results.get('labels', [[]])[0])} results"
)
enriched_results = []
if "labels" in results and "distances" in results:
print(f" Processing {len(results['labels'][0])} passage IDs:")
logger.info(f" Processing {len(results['labels'][0])} passage IDs:")
for i, (string_id, dist) in enumerate(
zip(results["labels"][0], results["distances"][0])
):
@@ -644,15 +479,15 @@ class LeannSearcher:
metadata=passage_data.get("metadata", {}),
)
)
print(
logger.info(
f" {i + 1}. passage_id='{string_id}' -> SUCCESS: {passage_data['text']}..."
)
except KeyError:
print(
logger.error(
f" {i + 1}. passage_id='{string_id}' -> ERROR: Passage not found in PassageManager!"
)
print(f" Final enriched results: {len(enriched_results)} passages")
logger.info(f" Final enriched results: {len(enriched_results)} passages")
return enriched_results
@@ -674,10 +509,10 @@ class LeannChat:
complexity: int = 64,
beam_width: int = 1,
prune_ratio: float = 0.0,
recompute_embeddings: bool = False,
recompute_embeddings: bool = True,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
zmq_port: int = 5557,
llm_kwargs: Optional[Dict[str, Any]] = None,
expected_zmq_port: int = 5557,
**search_kwargs,
):
if llm_kwargs is None:
@@ -691,7 +526,7 @@ class LeannChat:
prune_ratio=prune_ratio,
recompute_embeddings=recompute_embeddings,
pruning_strategy=pruning_strategy,
zmq_port=zmq_port,
expected_zmq_port=expected_zmq_port,
**search_kwargs,
)
context = "\n\n".join([r.text for r in results])

View File

@@ -9,6 +9,7 @@ from typing import Dict, Any, Optional, List
import logging
import os
import difflib
import torch
# Configure logging
logging.basicConfig(level=logging.INFO)
@@ -28,6 +29,68 @@ def check_ollama_models() -> List[str]:
return []
def check_ollama_model_exists_remotely(model_name: str) -> tuple[bool, list[str]]:
"""Check if a model exists in Ollama's remote library and return available tags
Returns:
(model_exists, available_tags): bool and list of matching tags
"""
try:
import requests
import re
# Split model name and tag
if ':' in model_name:
base_model, requested_tag = model_name.split(':', 1)
else:
base_model, requested_tag = model_name, None
# First check if base model exists in library
library_response = requests.get("https://ollama.com/library", timeout=8)
if library_response.status_code != 200:
return True, [] # Assume exists if can't check
# Extract model names from library page
models_in_library = re.findall(r'href="/library/([^"]+)"', library_response.text)
if base_model not in models_in_library:
return False, [] # Base model doesn't exist
# If base model exists, get available tags
tags_response = requests.get(f"https://ollama.com/library/{base_model}/tags", timeout=8)
if tags_response.status_code != 200:
return True, [] # Base model exists but can't get tags
# Extract tags for this model - be more specific to avoid HTML artifacts
tag_pattern = rf'{re.escape(base_model)}:[a-zA-Z0-9\.\-_]+'
raw_tags = re.findall(tag_pattern, tags_response.text)
# Clean up tags - remove HTML artifacts and duplicates
available_tags = []
seen = set()
for tag in raw_tags:
# Skip if it looks like HTML (contains < or >)
if '<' in tag or '>' in tag:
continue
if tag not in seen:
seen.add(tag)
available_tags.append(tag)
# Check if exact model exists
if requested_tag is None:
# User just requested base model, suggest tags
return True, available_tags[:10] # Return up to 10 tags
else:
exact_match = model_name in available_tags
return exact_match, available_tags[:10]
except Exception:
pass
# If scraping fails, assume model might exist (don't block user)
return True, []
def search_ollama_models_fuzzy(query: str, available_models: List[str]) -> List[str]:
"""Use intelligent fuzzy search for Ollama models"""
if not available_models:
@@ -243,24 +306,66 @@ def validate_model_and_suggest(model_name: str, llm_type: str) -> Optional[str]:
if llm_type == "ollama":
available_models = check_ollama_models()
if available_models and model_name not in available_models:
# Use intelligent fuzzy search based on locally installed models
suggestions = search_ollama_models_fuzzy(model_name, available_models)
error_msg = f"Model '{model_name}' not found in your local Ollama installation."
if suggestions:
error_msg += "\n\nDid you mean one of these installed models?\n"
for i, suggestion in enumerate(suggestions, 1):
error_msg += f" {i}. {suggestion}\n"
else:
error_msg += "\n\nYour installed models:\n"
for i, model in enumerate(available_models[:8], 1):
error_msg += f" {i}. {model}\n"
if len(available_models) > 8:
error_msg += f" ... and {len(available_models) - 8} more\n"
error_msg += "\nTo list all models: ollama list"
error_msg += "\nTo download a new model: ollama pull <model_name>"
error_msg += "\nBrowse models: https://ollama.com/library"
# Check if the model exists remotely and get available tags
model_exists_remotely, available_tags = check_ollama_model_exists_remotely(model_name)
if model_exists_remotely and model_name in available_tags:
# Exact model exists remotely - suggest pulling it
error_msg += f"\n\nTo install the requested model:\n"
error_msg += f" ollama pull {model_name}\n"
# Show local alternatives
suggestions = search_ollama_models_fuzzy(model_name, available_models)
if suggestions:
error_msg += "\nOr use one of these similar installed models:\n"
for i, suggestion in enumerate(suggestions, 1):
error_msg += f" {i}. {suggestion}\n"
elif model_exists_remotely and available_tags:
# Base model exists but requested tag doesn't - suggest correct tags
base_model = model_name.split(':')[0]
requested_tag = model_name.split(':', 1)[1] if ':' in model_name else None
error_msg += f"\n\nModel '{base_model}' exists, but tag '{requested_tag}' is not available."
error_msg += f"\n\nAvailable {base_model} models you can install:\n"
for i, tag in enumerate(available_tags[:8], 1):
error_msg += f" {i}. ollama pull {tag}\n"
if len(available_tags) > 8:
error_msg += f" ... and {len(available_tags) - 8} more variants\n"
# Also show local alternatives
suggestions = search_ollama_models_fuzzy(model_name, available_models)
if suggestions:
error_msg += "\nOr use one of these similar installed models:\n"
for i, suggestion in enumerate(suggestions, 1):
error_msg += f" {i}. {suggestion}\n"
else:
# Model doesn't exist remotely - show fuzzy suggestions
suggestions = search_ollama_models_fuzzy(model_name, available_models)
error_msg += f"\n\nModel '{model_name}' was not found in Ollama's library."
if suggestions:
error_msg += "\n\nDid you mean one of these installed models?\n"
for i, suggestion in enumerate(suggestions, 1):
error_msg += f" {i}. {suggestion}\n"
else:
error_msg += "\n\nYour installed models:\n"
for i, model in enumerate(available_models[:8], 1):
error_msg += f" {i}. {model}\n"
if len(available_models) > 8:
error_msg += f" ... and {len(available_models) - 8} more\n"
error_msg += "\n\nCommands:"
error_msg += "\n ollama list # List installed models"
if model_exists_remotely and available_tags:
if model_name in available_tags:
error_msg += f"\n ollama pull {model_name} # Install requested model"
else:
error_msg += f"\n ollama pull {available_tags[0]} # Install recommended variant"
error_msg += "\n https://ollama.com/library # Browse available models"
return error_msg
elif llm_type == "hf":
@@ -375,8 +480,9 @@ class OllamaChat(LLMInterface):
"stream": False, # Keep it simple for now
"options": kwargs,
}
logger.info(f"Sending request to Ollama: {payload}")
logger.debug(f"Sending request to Ollama: {payload}")
try:
logger.info(f"Sending request to Ollama and waiting for response...")
response = requests.post(full_url, data=json.dumps(payload))
response.raise_for_status()
@@ -396,7 +502,7 @@ class OllamaChat(LLMInterface):
class HFChat(LLMInterface):
"""LLM interface for local Hugging Face Transformers models."""
"""LLM interface for local Hugging Face Transformers models with proper chat templates."""
def __init__(self, model_name: str = "deepseek-ai/deepseek-llm-7b-chat"):
logger.info(f"Initializing HFChat with model='{model_name}'")
@@ -407,7 +513,7 @@ class HFChat(LLMInterface):
raise ValueError(model_error)
try:
from transformers.pipelines import pipeline
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
except ImportError:
raise ImportError(
@@ -416,54 +522,101 @@ class HFChat(LLMInterface):
# Auto-detect device
if torch.cuda.is_available():
device = "cuda"
self.device = "cuda"
logger.info("CUDA is available. Using GPU.")
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
device = "mps"
self.device = "mps"
logger.info("MPS is available. Using Apple Silicon GPU.")
else:
device = "cpu"
self.device = "cpu"
logger.info("No GPU detected. Using CPU.")
self.pipeline = pipeline("text-generation", model=model_name, device=device)
# Load tokenizer and model
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16 if self.device != "cpu" else torch.float32,
device_map="auto" if self.device != "cpu" else None,
trust_remote_code=True
)
# Move model to device if not using device_map
if self.device != "cpu" and "device_map" not in str(self.model):
self.model = self.model.to(self.device)
# Set pad token if not present
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
def ask(self, prompt: str, **kwargs) -> str:
# Map OpenAI-style arguments to Hugging Face equivalents
if "max_tokens" in kwargs:
# Prefer user-provided max_new_tokens if both are present
kwargs.setdefault("max_new_tokens", kwargs["max_tokens"])
# Remove the unsupported key to avoid errors in Transformers
kwargs.pop("max_tokens")
print('kwargs in HF: ', kwargs)
# Check if this is a Qwen model and add /no_think by default
is_qwen_model = "qwen" in self.model.config._name_or_path.lower()
# For Qwen models, automatically add /no_think to the prompt
if is_qwen_model and "/no_think" not in prompt and "/think" not in prompt:
prompt = prompt + " /no_think"
# Prepare chat template
messages = [{"role": "user", "content": prompt}]
# Apply chat template if available
if hasattr(self.tokenizer, "apply_chat_template"):
try:
formatted_prompt = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
except Exception as e:
logger.warning(f"Chat template failed, using raw prompt: {e}")
formatted_prompt = prompt
else:
# Fallback for models without chat template
formatted_prompt = prompt
# Handle temperature=0 edge-case for greedy decoding
if "temperature" in kwargs and kwargs["temperature"] == 0.0:
# Remove unsupported zero temperature and use deterministic generation
kwargs.pop("temperature")
kwargs.setdefault("do_sample", False)
# Tokenize input
inputs = self.tokenizer(
formatted_prompt,
return_tensors="pt",
padding=True,
truncation=True,
max_length=2048
)
# Move inputs to device
if self.device != "cpu":
inputs = {k: v.to(self.device) for k, v in inputs.items()}
# Sensible defaults for text generation
params = {"max_length": 500, "num_return_sequences": 1, **kwargs}
logger.info(f"Generating text with Hugging Face model with params: {params}")
results = self.pipeline(prompt, **params)
# Set generation parameters
generation_config = {
"max_new_tokens": kwargs.get("max_tokens", kwargs.get("max_new_tokens", 512)),
"temperature": kwargs.get("temperature", 0.7),
"top_p": kwargs.get("top_p", 0.9),
"do_sample": kwargs.get("temperature", 0.7) > 0,
"pad_token_id": self.tokenizer.eos_token_id,
"eos_token_id": self.tokenizer.eos_token_id,
}
# Handle temperature=0 for greedy decoding
if generation_config["temperature"] == 0.0:
generation_config["do_sample"] = False
generation_config.pop("temperature")
# Handle different response formats from transformers
if isinstance(results, list) and len(results) > 0:
generated_text = (
results[0].get("generated_text", "")
if isinstance(results[0], dict)
else str(results[0])
logger.info(f"Generating with HuggingFace model, config: {generation_config}")
# Generate
with torch.no_grad():
outputs = self.model.generate(
**inputs,
**generation_config
)
else:
generated_text = str(results)
# Extract only the newly generated portion by removing the original prompt
if isinstance(generated_text, str) and generated_text.startswith(prompt):
response = generated_text[len(prompt) :].strip()
else:
# Fallback: return the full response if prompt removal fails
response = str(generated_text)
return response
# Decode response
generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
return response.strip()
class OpenAIChat(LLMInterface):

View File

@@ -0,0 +1,315 @@
import argparse
import asyncio
from pathlib import Path
from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter
from .api import LeannBuilder, LeannSearcher, LeannChat
class LeannCLI:
def __init__(self):
self.indexes_dir = Path.home() / ".leann" / "indexes"
self.indexes_dir.mkdir(parents=True, exist_ok=True)
self.node_parser = SentenceSplitter(
chunk_size=256, chunk_overlap=128, separator=" ", paragraph_separator="\n\n"
)
def get_index_path(self, index_name: str) -> str:
index_dir = self.indexes_dir / index_name
return str(index_dir / "documents.leann")
def index_exists(self, index_name: str) -> bool:
index_dir = self.indexes_dir / index_name
meta_file = index_dir / "documents.leann.meta.json"
return meta_file.exists()
def create_parser(self) -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
prog="leann",
description="LEANN - Local Enhanced AI Navigation",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
leann build my-docs --docs ./documents # Build index named my-docs
leann search my-docs "query" # Search in my-docs index
leann ask my-docs "question" # Ask my-docs index
leann list # List all stored indexes
""",
)
subparsers = parser.add_subparsers(dest="command", help="Available commands")
# Build command
build_parser = subparsers.add_parser("build", help="Build document index")
build_parser.add_argument("index_name", help="Index name")
build_parser.add_argument(
"--docs", type=str, required=True, help="Documents directory"
)
build_parser.add_argument(
"--backend", type=str, default="hnsw", choices=["hnsw", "diskann"]
)
build_parser.add_argument(
"--embedding-model", type=str, default="facebook/contriever"
)
build_parser.add_argument(
"--force", "-f", action="store_true", help="Force rebuild"
)
build_parser.add_argument("--graph-degree", type=int, default=32)
build_parser.add_argument("--complexity", type=int, default=64)
build_parser.add_argument("--num-threads", type=int, default=1)
build_parser.add_argument("--compact", action="store_true", default=True)
build_parser.add_argument("--recompute", action="store_true", default=True)
# Search command
search_parser = subparsers.add_parser("search", help="Search documents")
search_parser.add_argument("index_name", help="Index name")
search_parser.add_argument("query", help="Search query")
search_parser.add_argument("--top-k", type=int, default=5)
search_parser.add_argument("--complexity", type=int, default=64)
search_parser.add_argument("--beam-width", type=int, default=1)
search_parser.add_argument("--prune-ratio", type=float, default=0.0)
search_parser.add_argument("--recompute-embeddings", action="store_true")
search_parser.add_argument(
"--pruning-strategy",
choices=["global", "local", "proportional"],
default="global",
)
# Ask command
ask_parser = subparsers.add_parser("ask", help="Ask questions")
ask_parser.add_argument("index_name", help="Index name")
ask_parser.add_argument(
"--llm",
type=str,
default="ollama",
choices=["simulated", "ollama", "hf", "openai"],
)
ask_parser.add_argument("--model", type=str, default="qwen3:8b")
ask_parser.add_argument("--host", type=str, default="http://localhost:11434")
ask_parser.add_argument("--interactive", "-i", action="store_true")
ask_parser.add_argument("--top-k", type=int, default=20)
ask_parser.add_argument("--complexity", type=int, default=32)
ask_parser.add_argument("--beam-width", type=int, default=1)
ask_parser.add_argument("--prune-ratio", type=float, default=0.0)
ask_parser.add_argument("--recompute-embeddings", action="store_true")
ask_parser.add_argument(
"--pruning-strategy",
choices=["global", "local", "proportional"],
default="global",
)
# List command
list_parser = subparsers.add_parser("list", help="List all indexes")
return parser
def list_indexes(self):
print("Stored LEANN indexes:")
if not self.indexes_dir.exists():
print(
"No indexes found. Use 'leann build <name> --docs <dir>' to create one."
)
return
index_dirs = [d for d in self.indexes_dir.iterdir() if d.is_dir()]
if not index_dirs:
print(
"No indexes found. Use 'leann build <name> --docs <dir>' to create one."
)
return
print(f"Found {len(index_dirs)} indexes:")
for i, index_dir in enumerate(index_dirs, 1):
index_name = index_dir.name
status = "" if self.index_exists(index_name) else ""
print(f" {i}. {index_name} [{status}]")
if self.index_exists(index_name):
meta_file = index_dir / "documents.leann.meta.json"
size_mb = sum(
f.stat().st_size for f in index_dir.iterdir() if f.is_file()
) / (1024 * 1024)
print(f" Size: {size_mb:.1f} MB")
if index_dirs:
example_name = index_dirs[0].name
print(f"\nUsage:")
print(f' leann search {example_name} "your query"')
print(f" leann ask {example_name} --interactive")
def load_documents(self, docs_dir: str):
print(f"Loading documents from {docs_dir}...")
documents = SimpleDirectoryReader(
docs_dir,
recursive=True,
encoding="utf-8",
required_exts=[".pdf", ".txt", ".md", ".docx"],
).load_data(show_progress=True)
all_texts = []
for doc in documents:
nodes = self.node_parser.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.get_content())
print(f"Loaded {len(documents)} documents, {len(all_texts)} chunks")
return all_texts
async def build_index(self, args):
docs_dir = args.docs
index_name = args.index_name
index_dir = self.indexes_dir / index_name
index_path = self.get_index_path(index_name)
if index_dir.exists() and not args.force:
print(f"Index '{index_name}' already exists. Use --force to rebuild.")
return
all_texts = self.load_documents(docs_dir)
if not all_texts:
print("No documents found")
return
index_dir.mkdir(parents=True, exist_ok=True)
print(f"Building index '{index_name}' with {args.backend} backend...")
builder = LeannBuilder(
backend_name=args.backend,
embedding_model=args.embedding_model,
graph_degree=args.graph_degree,
complexity=args.complexity,
is_compact=args.compact,
is_recompute=args.recompute,
num_threads=args.num_threads,
)
for chunk_text in all_texts:
builder.add_text(chunk_text)
builder.build_index(index_path)
print(f"Index built at {index_path}")
async def search_documents(self, args):
index_name = args.index_name
query = args.query
index_path = self.get_index_path(index_name)
if not self.index_exists(index_name):
print(
f"Index '{index_name}' not found. Use 'leann build {index_name} --docs <dir>' to create it."
)
return
searcher = LeannSearcher(index_path=index_path)
results = searcher.search(
query,
top_k=args.top_k,
complexity=args.complexity,
beam_width=args.beam_width,
prune_ratio=args.prune_ratio,
recompute_embeddings=args.recompute_embeddings,
pruning_strategy=args.pruning_strategy,
)
print(f"Search results for '{query}' (top {len(results)}):")
for i, result in enumerate(results, 1):
print(f"{i}. Score: {result.score:.3f}")
print(f" {result.text[:200]}...")
print()
async def ask_questions(self, args):
index_name = args.index_name
index_path = self.get_index_path(index_name)
if not self.index_exists(index_name):
print(
f"Index '{index_name}' not found. Use 'leann build {index_name} --docs <dir>' to create it."
)
return
print(f"Starting chat with index '{index_name}'...")
print(f"Using {args.model} ({args.llm})")
llm_config = {"type": args.llm, "model": args.model}
if args.llm == "ollama":
llm_config["host"] = args.host
chat = LeannChat(index_path=index_path, llm_config=llm_config)
if args.interactive:
print("LEANN Assistant ready! Type 'quit' to exit")
print("=" * 40)
while True:
user_input = input("\nYou: ").strip()
if user_input.lower() in ["quit", "exit", "q"]:
print("Goodbye!")
break
if not user_input:
continue
response = chat.ask(
user_input,
top_k=args.top_k,
complexity=args.complexity,
beam_width=args.beam_width,
prune_ratio=args.prune_ratio,
recompute_embeddings=args.recompute_embeddings,
pruning_strategy=args.pruning_strategy,
)
print(f"LEANN: {response}")
else:
query = input("Enter your question: ").strip()
if query:
response = chat.ask(
query,
top_k=args.top_k,
complexity=args.complexity,
beam_width=args.beam_width,
prune_ratio=args.prune_ratio,
recompute_embeddings=args.recompute_embeddings,
pruning_strategy=args.pruning_strategy,
)
print(f"LEANN: {response}")
async def run(self, args=None):
parser = self.create_parser()
if args is None:
args = parser.parse_args()
if not args.command:
parser.print_help()
return
if args.command == "list":
self.list_indexes()
elif args.command == "build":
await self.build_index(args)
elif args.command == "search":
await self.search_documents(args)
elif args.command == "ask":
await self.ask_questions(args)
else:
parser.print_help()
def main():
import dotenv
dotenv.load_dotenv()
cli = LeannCLI()
asyncio.run(cli.run())
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,375 @@
"""
Unified embedding computation module
Consolidates all embedding computation logic using SentenceTransformer
Preserves all optimization parameters to ensure performance
"""
import numpy as np
import torch
from typing import List, Dict, Any
import logging
import os
# Set up logger with proper level
logger = logging.getLogger(__name__)
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
logger.setLevel(log_level)
# Global model cache to avoid repeated loading
_model_cache: Dict[str, Any] = {}
def compute_embeddings(
texts: List[str],
model_name: str,
mode: str = "sentence-transformers",
is_build: bool = False,
batch_size: int = 32,
adaptive_optimization: bool = True,
) -> np.ndarray:
"""
Unified embedding computation entry point
Args:
texts: List of texts to compute embeddings for
model_name: Model name
mode: Computation mode ('sentence-transformers', 'openai', 'mlx')
is_build: Whether this is a build operation (shows progress bar)
batch_size: Batch size for processing
adaptive_optimization: Whether to use adaptive optimization based on batch size
Returns:
Normalized embeddings array, shape: (len(texts), embedding_dim)
"""
if mode == "sentence-transformers":
return compute_embeddings_sentence_transformers(
texts,
model_name,
is_build=is_build,
batch_size=batch_size,
adaptive_optimization=adaptive_optimization,
)
elif mode == "openai":
return compute_embeddings_openai(texts, model_name)
elif mode == "mlx":
return compute_embeddings_mlx(texts, model_name)
else:
raise ValueError(f"Unsupported embedding mode: {mode}")
def compute_embeddings_sentence_transformers(
texts: List[str],
model_name: str,
use_fp16: bool = True,
device: str = "auto",
batch_size: int = 32,
is_build: bool = False,
adaptive_optimization: bool = True,
) -> np.ndarray:
"""
Compute embeddings using SentenceTransformer with model caching and adaptive optimization
Args:
texts: List of texts to compute embeddings for
model_name: Model name
use_fp16: Whether to use FP16 precision
device: Device to use ('auto', 'cuda', 'mps', 'cpu')
batch_size: Batch size for processing
is_build: Whether this is a build operation (shows progress bar)
adaptive_optimization: Whether to use adaptive optimization based on batch size
"""
# Handle empty input
if not texts:
raise ValueError("Cannot compute embeddings for empty text list")
logger.info(
f"Computing embeddings for {len(texts)} texts using SentenceTransformer, model: '{model_name}'"
)
# Auto-detect device
if device == "auto":
if torch.cuda.is_available():
device = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
device = "mps"
else:
device = "cpu"
# Apply optimizations based on benchmark results
if adaptive_optimization:
# Use optimal batch_size constants for different devices based on benchmark results
if device == "mps":
batch_size = 128 # MPS optimal batch size from benchmark
if model_name == "Qwen/Qwen3-Embedding-0.6B":
batch_size = 32
elif device == "cuda":
batch_size = 256 # CUDA optimal batch size
# Keep original batch_size for CPU
# Create cache key
cache_key = f"sentence_transformers_{model_name}_{device}_{use_fp16}_optimized"
# Check if model is already cached
if cache_key in _model_cache:
logger.info(f"Using cached optimized model: {model_name}")
model = _model_cache[cache_key]
else:
logger.info(
f"Loading and caching optimized SentenceTransformer model: {model_name}"
)
from sentence_transformers import SentenceTransformer
logger.info(f"Using device: {device}")
# Apply hardware optimizations
if device == "cuda":
# TODO: Haven't tested this yet
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False
torch.cuda.set_per_process_memory_fraction(0.9)
elif device == "mps":
try:
if hasattr(torch.mps, "set_per_process_memory_fraction"):
torch.mps.set_per_process_memory_fraction(0.9)
except AttributeError:
logger.warning(
"Some MPS optimizations not available in this PyTorch version"
)
elif device == "cpu":
# TODO: Haven't tested this yet
torch.set_num_threads(min(8, os.cpu_count() or 4))
try:
torch.backends.mkldnn.enabled = True
except AttributeError:
pass
# Prepare optimized model and tokenizer parameters
model_kwargs = {
"torch_dtype": torch.float16 if use_fp16 else torch.float32,
"low_cpu_mem_usage": True,
"_fast_init": True,
"attn_implementation": "eager", # Use eager attention for speed
}
tokenizer_kwargs = {
"use_fast": True,
"padding": True,
"truncation": True,
}
try:
# Try local loading first
model_kwargs["local_files_only"] = True
tokenizer_kwargs["local_files_only"] = True
model = SentenceTransformer(
model_name,
device=device,
model_kwargs=model_kwargs,
tokenizer_kwargs=tokenizer_kwargs,
local_files_only=True,
)
logger.info("Model loaded successfully! (local + optimized)")
except Exception as e:
logger.warning(f"Local loading failed ({e}), trying network download...")
# Fallback to network loading
model_kwargs["local_files_only"] = False
tokenizer_kwargs["local_files_only"] = False
model = SentenceTransformer(
model_name,
device=device,
model_kwargs=model_kwargs,
tokenizer_kwargs=tokenizer_kwargs,
local_files_only=False,
)
logger.info("Model loaded successfully! (network + optimized)")
# Apply additional optimizations based on mode
if use_fp16 and device in ["cuda", "mps"]:
try:
model = model.half()
logger.info(f"Applied FP16 precision: {model_name}")
except Exception as e:
logger.warning(f"FP16 optimization failed: {e}")
# Apply torch.compile optimization
if device in ["cuda", "mps"]:
try:
model = torch.compile(model, mode="reduce-overhead", dynamic=True)
logger.info(f"Applied torch.compile optimization: {model_name}")
except Exception as e:
logger.warning(f"torch.compile optimization failed: {e}")
# Set model to eval mode and disable gradients for inference
model.eval()
for param in model.parameters():
param.requires_grad_(False)
# Cache the model
_model_cache[cache_key] = model
logger.info(f"Model cached: {cache_key}")
# Compute embeddings with optimized inference mode
logger.info(f"Starting embedding computation... (batch_size: {batch_size})")
# Use torch.inference_mode for optimal performance
with torch.inference_mode():
embeddings = model.encode(
texts,
batch_size=batch_size,
show_progress_bar=is_build, # Don't show progress bar in server environment
convert_to_numpy=True,
normalize_embeddings=False,
device=device,
)
logger.info(
f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}"
)
# Validate results
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
raise RuntimeError(
f"Detected NaN or Inf values in embeddings, model: {model_name}"
)
return embeddings
def compute_embeddings_openai(texts: List[str], model_name: str) -> np.ndarray:
# TODO: @yichuan-w add progress bar only in build mode
"""Compute embeddings using OpenAI API"""
try:
import openai
import os
except ImportError as e:
raise ImportError(f"OpenAI package not installed: {e}")
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
raise RuntimeError("OPENAI_API_KEY environment variable not set")
# Cache OpenAI client
cache_key = "openai_client"
if cache_key in _model_cache:
client = _model_cache[cache_key]
else:
client = openai.OpenAI(api_key=api_key)
_model_cache[cache_key] = client
logger.info("OpenAI client cached")
logger.info(
f"Computing embeddings for {len(texts)} texts using OpenAI API, model: '{model_name}'"
)
# OpenAI has limits on batch size and input length
max_batch_size = 100 # Conservative batch size
all_embeddings = []
try:
from tqdm import tqdm
total_batches = (len(texts) + max_batch_size - 1) // max_batch_size
batch_range = range(0, len(texts), max_batch_size)
batch_iterator = tqdm(
batch_range, desc="Computing embeddings", unit="batch", total=total_batches
)
except ImportError:
# Fallback when tqdm is not available
batch_iterator = range(0, len(texts), max_batch_size)
for i in batch_iterator:
batch_texts = texts[i : i + max_batch_size]
try:
response = client.embeddings.create(model=model_name, input=batch_texts)
batch_embeddings = [embedding.embedding for embedding in response.data]
all_embeddings.extend(batch_embeddings)
except Exception as e:
logger.error(f"Batch {i} failed: {e}")
raise
embeddings = np.array(all_embeddings, dtype=np.float32)
logger.info(
f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}"
)
return embeddings
def compute_embeddings_mlx(
chunks: List[str], model_name: str, batch_size: int = 16
) -> np.ndarray:
# TODO: @yichuan-w add progress bar only in build mode
"""Computes embeddings using an MLX model."""
try:
import mlx.core as mx
from mlx_lm.utils import load
except ImportError as e:
raise RuntimeError(
"MLX or related libraries not available. Install with: uv pip install mlx mlx-lm"
) from e
logger.info(
f"Computing embeddings for {len(chunks)} chunks using MLX model '{model_name}' with batch_size={batch_size}..."
)
# Cache MLX model and tokenizer
cache_key = f"mlx_{model_name}"
if cache_key in _model_cache:
logger.info(f"Using cached MLX model: {model_name}")
model, tokenizer = _model_cache[cache_key]
else:
logger.info(f"Loading and caching MLX model: {model_name}")
model, tokenizer = load(model_name)
_model_cache[cache_key] = (model, tokenizer)
logger.info(f"MLX model cached: {cache_key}")
# Process chunks in batches with progress bar
all_embeddings = []
try:
from tqdm import tqdm
batch_iterator = tqdm(
range(0, len(chunks), batch_size), desc="Computing embeddings", unit="batch"
)
except ImportError:
batch_iterator = range(0, len(chunks), batch_size)
for i in batch_iterator:
batch_chunks = chunks[i : i + batch_size]
# Tokenize all chunks in the batch
batch_token_ids = []
for chunk in batch_chunks:
token_ids = tokenizer.encode(chunk) # type: ignore
batch_token_ids.append(token_ids)
# Pad sequences to the same length for batch processing
max_length = max(len(ids) for ids in batch_token_ids)
padded_token_ids = []
for token_ids in batch_token_ids:
# Pad with tokenizer.pad_token_id or 0
padded = token_ids + [0] * (max_length - len(token_ids))
padded_token_ids.append(padded)
# Convert to MLX array with batch dimension
input_ids = mx.array(padded_token_ids)
# Get embeddings for the batch
embeddings = model(input_ids)
# Mean pooling for each sequence in the batch
pooled = embeddings.mean(axis=1) # Shape: (batch_size, hidden_size)
# Convert batch embeddings to numpy
for j in range(len(batch_chunks)):
pooled_list = pooled[j].tolist() # Convert to list
pooled_numpy = np.array(pooled_list, dtype=np.float32)
all_embeddings.append(pooled_numpy)
# Stack numpy arrays
return np.stack(all_embeddings)

View File

@@ -1,14 +1,21 @@
import threading
import time
import atexit
import socket
import subprocess
import sys
import zmq
import msgpack
import os
import logging
from pathlib import Path
from typing import Optional
import select
import psutil
# Set up logging based on environment variable
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
logging.basicConfig(
level=getattr(logging, LOG_LEVEL, logging.INFO),
format="%(levelname)s - %(name)s - %(message)s",
)
logger = logging.getLogger(__name__)
def _check_port(port: int) -> bool:
@@ -17,151 +24,135 @@ def _check_port(port: int) -> bool:
return s.connect_ex(("localhost", port)) == 0
def _check_server_meta_path(port: int, expected_meta_path: str) -> bool:
def _check_process_matches_config(
port: int, expected_model: str, expected_passages_file: str
) -> bool:
"""
Check if the existing server on the port is using the correct meta file.
Returns True if the server has the right meta path, False otherwise.
Check if the process using the port matches our expected model and passages file.
Returns True if matches, False otherwise.
"""
try:
context = zmq.Context()
socket = context.socket(zmq.REQ)
socket.setsockopt(zmq.RCVTIMEO, 3000) # 3 second timeout
socket.connect(f"tcp://localhost:{port}")
for proc in psutil.process_iter(["pid", "cmdline"]):
if not _is_process_listening_on_port(proc, port):
continue
# Send a special control message to query the server's meta path
control_request = ["__QUERY_META_PATH__"]
request_bytes = msgpack.packb(control_request)
socket.send(request_bytes)
cmdline = proc.info["cmdline"]
if not cmdline:
continue
# Wait for response
response_bytes = socket.recv()
response = msgpack.unpackb(response_bytes)
socket.close()
context.term()
# Check if the response contains the meta path and if it matches
if isinstance(response, list) and len(response) > 0:
server_meta_path = response[0]
# Normalize paths for comparison
expected_path = Path(expected_meta_path).resolve()
server_path = Path(server_meta_path).resolve() if server_meta_path else None
return server_path == expected_path
return _check_cmdline_matches_config(
cmdline, port, expected_model, expected_passages_file
)
logger.debug(f"No process found listening on port {port}")
return False
except Exception as e:
print(f"WARNING: Could not query server meta path on port {port}: {e}")
logger.warning(f"Could not check process on port {port}: {e}")
return False
def _update_server_meta_path(port: int, new_meta_path: str) -> bool:
"""
Send a control message to update the server's meta path.
Returns True if successful, False otherwise.
"""
def _is_process_listening_on_port(proc, port: int) -> bool:
"""Check if a process is listening on the given port."""
try:
context = zmq.Context()
socket = context.socket(zmq.REQ)
socket.setsockopt(zmq.RCVTIMEO, 5000) # 5 second timeout
socket.connect(f"tcp://localhost:{port}")
# Send a control message to update the meta path
control_request = ["__UPDATE_META_PATH__", new_meta_path]
request_bytes = msgpack.packb(control_request)
socket.send(request_bytes)
# Wait for response
response_bytes = socket.recv()
response = msgpack.unpackb(response_bytes)
socket.close()
context.term()
# Check if the update was successful
if isinstance(response, list) and len(response) > 0:
return response[0] == "SUCCESS"
connections = proc.net_connections()
for conn in connections:
if conn.laddr.port == port and conn.status == psutil.CONN_LISTEN:
return True
return False
except Exception as e:
print(f"ERROR: Could not update server meta path on port {port}: {e}")
except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
return False
def _check_server_model(port: int, expected_model: str) -> bool:
def _check_cmdline_matches_config(
cmdline: list, port: int, expected_model: str, expected_passages_file: str
) -> bool:
"""Check if command line matches our expected configuration."""
cmdline_str = " ".join(cmdline)
logger.debug(f"Found process on port {port}: {cmdline_str}")
# Check if it's our embedding server
is_embedding_server = any(
server_type in cmdline_str
for server_type in [
"embedding_server",
"leann_backend_diskann.embedding_server",
"leann_backend_hnsw.hnsw_embedding_server",
]
)
if not is_embedding_server:
logger.debug(f"Process on port {port} is not our embedding server")
return False
# Check model name
model_matches = _check_model_in_cmdline(cmdline, expected_model)
# Check passages file if provided
passages_matches = _check_passages_in_cmdline(cmdline, expected_passages_file)
result = model_matches and passages_matches
logger.debug(
f"model_matches: {model_matches}, passages_matches: {passages_matches}, overall: {result}"
)
return result
def _check_model_in_cmdline(cmdline: list, expected_model: str) -> bool:
"""Check if the command line contains the expected model."""
if "--model-name" not in cmdline:
return False
model_idx = cmdline.index("--model-name")
if model_idx + 1 >= len(cmdline):
return False
actual_model = cmdline[model_idx + 1]
return actual_model == expected_model
def _check_passages_in_cmdline(cmdline: list, expected_passages_file: str) -> bool:
"""Check if the command line contains the expected passages file."""
if "--passages-file" not in cmdline:
return False # Expected but not found
passages_idx = cmdline.index("--passages-file")
if passages_idx + 1 >= len(cmdline):
return False
actual_passages = cmdline[passages_idx + 1]
expected_path = Path(expected_passages_file).resolve()
actual_path = Path(actual_passages).resolve()
return actual_path == expected_path
def _find_compatible_port_or_next_available(
start_port: int, model_name: str, passages_file: str, max_attempts: int = 100
) -> tuple[int, bool]:
"""
Check if the existing server on the port is using the correct embedding model.
Returns True if the server has the right model, False otherwise.
Find a port that either has a compatible server or is available.
Returns (port, is_compatible) where is_compatible indicates if we found a matching server.
"""
try:
context = zmq.Context()
socket = context.socket(zmq.REQ)
socket.setsockopt(zmq.RCVTIMEO, 3000) # 3 second timeout
socket.connect(f"tcp://localhost:{port}")
for port in range(start_port, start_port + max_attempts):
if not _check_port(port):
# Port is available
return port, False
# Send a special control message to query the server's model
control_request = ["__QUERY_MODEL__"]
request_bytes = msgpack.packb(control_request)
socket.send(request_bytes)
# Port is in use, check if it's compatible
if _check_process_matches_config(port, model_name, passages_file):
logger.info(f"Found compatible server on port {port}")
return port, True
else:
logger.info(f"Port {port} has incompatible server, trying next port...")
# Wait for response
response_bytes = socket.recv()
response = msgpack.unpackb(response_bytes)
socket.close()
context.term()
# Check if the response contains the model name and if it matches
if isinstance(response, list) and len(response) > 0:
server_model = response[0]
return server_model == expected_model
return False
except Exception as e:
print(f"WARNING: Could not query server model on port {port}: {e}")
return False
def _update_server_model(port: int, new_model: str) -> bool:
"""
Send a control message to update the server's embedding model.
Returns True if successful, False otherwise.
"""
try:
context = zmq.Context()
socket = context.socket(zmq.REQ)
socket.setsockopt(zmq.RCVTIMEO, 30000) # 30 second timeout for model loading
socket.setsockopt(zmq.SNDTIMEO, 5000) # 5 second timeout for sending
socket.connect(f"tcp://localhost:{port}")
# Send a control message to update the model
control_request = ["__UPDATE_MODEL__", new_model]
request_bytes = msgpack.packb(control_request)
socket.send(request_bytes)
# Wait for response
response_bytes = socket.recv()
response = msgpack.unpackb(response_bytes)
socket.close()
context.term()
# Check if the update was successful
if isinstance(response, list) and len(response) > 0:
return response[0] == "SUCCESS"
return False
except Exception as e:
print(f"ERROR: Could not update server model on port {port}: {e}")
return False
raise RuntimeError(
f"Could not find compatible or available port in range {start_port}-{start_port + max_attempts}"
)
class EmbeddingServerManager:
"""
A generic manager for handling the lifecycle of a backend-specific embedding server process.
A simplified manager for embedding server processes that avoids complex update mechanisms.
"""
def __init__(self, backend_module_name: str):
@@ -175,246 +166,183 @@ class EmbeddingServerManager:
self.backend_module_name = backend_module_name
self.server_process: Optional[subprocess.Popen] = None
self.server_port: Optional[int] = None
atexit.register(self.stop_server)
self._atexit_registered = False
def start_server(self, port: int, model_name: str, embedding_mode: str = "sentence-transformers", **kwargs) -> bool:
def start_server(
self,
port: int,
model_name: str,
embedding_mode: str = "sentence-transformers",
**kwargs,
) -> tuple[bool, int]:
"""
Starts the embedding server process.
Args:
port (int): The ZMQ port for the server.
port (int): The preferred ZMQ port for the server.
model_name (str): The name of the embedding model to use.
**kwargs: Additional arguments for the server (e.g., passages_file, distance_metric, enable_warmup).
**kwargs: Additional arguments for the server.
Returns:
bool: True if the server is started successfully or already running, False otherwise.
tuple[bool, int]: (success, actual_port_used)
"""
if self.server_process and self.server_process.poll() is None:
# Even if we have a running process, check if model/meta path match
if self.server_port is not None:
port_in_use = _check_port(self.server_port)
if port_in_use:
print(
f"INFO: Checking compatibility of existing server process (PID {self.server_process.pid})"
)
passages_file = kwargs.get("passages_file")
assert isinstance(passages_file, str), "passages_file must be a string"
# Check model compatibility
model_matches = _check_server_model(self.server_port, model_name)
if model_matches:
print(
f"✅ Existing server already using correct model: {model_name}"
)
# Still check meta path if provided
passages_file = kwargs.get("passages_file")
if passages_file and str(passages_file).endswith(
".meta.json"
):
meta_matches = _check_server_meta_path(
self.server_port, str(passages_file)
)
if not meta_matches:
print("⚠️ Updating meta path to: {passages_file}")
_update_server_meta_path(
self.server_port, str(passages_file)
)
return True
else:
print(
f"⚠️ Existing server has different model. Attempting to update to: {model_name}"
)
if not _update_server_model(self.server_port, model_name):
print(
"❌ Failed to update existing server model. Restarting server..."
)
self.stop_server()
# Continue to start new server below
else:
print(
f"✅ Successfully updated existing server model to: {model_name}"
)
# Check if we have a compatible running server
if self._has_compatible_running_server(model_name, passages_file):
assert self.server_port is not None, (
"a compatible running server should set server_port"
)
return True, self.server_port
# Also check meta path if provided
passages_file = kwargs.get("passages_file")
if passages_file and str(passages_file).endswith(
".meta.json"
):
meta_matches = _check_server_meta_path(
self.server_port, str(passages_file)
)
if not meta_matches:
print("⚠️ Updating meta path to: {passages_file}")
_update_server_meta_path(
self.server_port, str(passages_file)
)
# Find available port (compatible or free)
try:
actual_port, is_compatible = _find_compatible_port_or_next_available(
port, model_name, passages_file
)
except RuntimeError as e:
logger.error(str(e))
return False, port
return True
else:
# Server process exists but port not responding - restart
print("⚠️ Server process exists but not responding. Restarting...")
self.stop_server()
# Continue to start new server below
else:
# No port stored - restart
print("⚠️ No port information stored. Restarting server...")
self.stop_server()
# Continue to start new server below
if is_compatible:
logger.info(f"Using existing compatible server on port {actual_port}")
self.server_port = actual_port
self.server_process = None # We don't own this process
return True, actual_port
if _check_port(port):
# Port is in use, check if it's using the correct meta file and model
passages_file = kwargs.get("passages_file")
if actual_port != port:
logger.info(f"Using port {actual_port} instead of {port}")
print(f"INFO: Port {port} is in use. Checking server compatibility...")
# Start new server
return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs)
# Check model compatibility first
model_matches = _check_server_model(port, model_name)
if model_matches:
print(
f"✅ Existing server on port {port} is using correct model: {model_name}"
)
else:
print(
f"⚠️ Existing server on port {port} has different model. Attempting to update to: {model_name}"
)
if not _update_server_model(port, model_name):
raise RuntimeError(
f"❌ Failed to update server model to {model_name}. Consider using a different port."
)
print(f"✅ Successfully updated server model to: {model_name}")
def _has_compatible_running_server(
self, model_name: str, passages_file: str
) -> bool:
"""Check if we have a compatible running server."""
if not (
self.server_process
and self.server_process.poll() is None
and self.server_port
):
return False
# Check meta path compatibility if provided
if passages_file and str(passages_file).endswith(".meta.json"):
meta_matches = _check_server_meta_path(port, str(passages_file))
if not meta_matches:
print(
f"⚠️ Existing server on port {port} has different meta path. Attempting to update..."
)
if not _update_server_meta_path(port, str(passages_file)):
raise RuntimeError(
"❌ Failed to update server meta path. This may cause data synchronization issues."
)
print(
f"✅ Successfully updated server meta path to: {passages_file}"
)
else:
print(
f"✅ Existing server on port {port} is using correct meta path: {passages_file}"
)
print(f"✅ Server on port {port} is compatible and ready to use.")
if _check_process_matches_config(self.server_port, model_name, passages_file):
logger.info(
f"Existing server process (PID {self.server_process.pid}) is compatible"
)
return True
print(
f"INFO: Starting session-level embedding server for '{self.backend_module_name}'..."
logger.info(
"Existing server process is incompatible. Should start a new server."
)
return False
def _start_new_server(
self, port: int, model_name: str, embedding_mode: str, **kwargs
) -> tuple[bool, int]:
"""Start a new embedding server on the given port."""
logger.info(f"Starting embedding server on port {port}...")
command = self._build_server_command(port, model_name, embedding_mode, **kwargs)
try:
command = [
sys.executable,
"-m",
self.backend_module_name,
"--zmq-port",
str(port),
"--model-name",
model_name,
]
# Add extra arguments for specific backends
if "passages_file" in kwargs and kwargs["passages_file"]:
command.extend(["--passages-file", str(kwargs["passages_file"])])
# if "distance_metric" in kwargs and kwargs["distance_metric"]:
# command.extend(["--distance-metric", kwargs["distance_metric"]])
if embedding_mode != "sentence-transformers":
command.extend(["--embedding-mode", embedding_mode])
if "enable_warmup" in kwargs and not kwargs["enable_warmup"]:
command.extend(["--disable-warmup"])
project_root = Path(__file__).parent.parent.parent.parent.parent
print(f"INFO: Running command from project root: {project_root}")
print(f"INFO: Command: {' '.join(command)}") # Debug: show actual command
self.server_process = subprocess.Popen(
command,
cwd=project_root,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT, # Merge stderr into stdout for easier monitoring
text=True,
encoding="utf-8",
bufsize=1, # Line buffered
universal_newlines=True,
)
self.server_port = port
print(f"INFO: Server process started with PID: {self.server_process.pid}")
max_wait, wait_interval = 120, 0.5
for _ in range(int(max_wait / wait_interval)):
if _check_port(port):
print("✅ Embedding server is up and ready for this session.")
log_thread = threading.Thread(target=self._log_monitor, daemon=True)
log_thread.start()
return True
if self.server_process.poll() is not None:
print(
"❌ ERROR: Server process terminated unexpectedly during startup."
)
self._print_recent_output()
return False
time.sleep(wait_interval)
print(
f"❌ ERROR: Server process failed to start listening within {max_wait} seconds."
)
self.stop_server()
return False
self._launch_server_process(command, port)
return self._wait_for_server_ready(port)
except Exception as e:
print(f"❌ ERROR: Failed to start embedding server process: {e}")
return False
logger.error(f"Failed to start embedding server: {e}")
return False, port
def _print_recent_output(self):
"""Print any recent output from the server process."""
if not self.server_process or not self.server_process.stdout:
return
try:
# Read any available output
def _build_server_command(
self, port: int, model_name: str, embedding_mode: str, **kwargs
) -> list:
"""Build the command to start the embedding server."""
command = [
sys.executable,
"-m",
self.backend_module_name,
"--zmq-port",
str(port),
"--model-name",
model_name,
]
if select.select([self.server_process.stdout], [], [], 0)[0]:
output = self.server_process.stdout.read()
if output:
print(f"[{self.backend_module_name} OUTPUT]: {output}")
except Exception as e:
print(f"Error reading server output: {e}")
if kwargs.get("passages_file"):
command.extend(["--passages-file", str(kwargs["passages_file"])])
if embedding_mode != "sentence-transformers":
command.extend(["--embedding-mode", embedding_mode])
def _log_monitor(self):
"""Monitors and prints the server's stdout and stderr."""
if not self.server_process:
return
try:
if self.server_process.stdout:
while True:
line = self.server_process.stdout.readline()
if not line:
break
print(
f"[{self.backend_module_name} LOG]: {line.strip()}", flush=True
)
except Exception as e:
print(f"Log monitor error: {e}")
return command
def _launch_server_process(self, command: list, port: int) -> None:
"""Launch the server process."""
project_root = Path(__file__).parent.parent.parent.parent.parent
logger.info(f"Command: {' '.join(command)}")
# Let server output go directly to console
# The server will respect LEANN_LOG_LEVEL environment variable
self.server_process = subprocess.Popen(
command,
cwd=project_root,
stdout=None, # Direct to console
stderr=None, # Direct to console
)
self.server_port = port
logger.info(f"Server process started with PID: {self.server_process.pid}")
# Register atexit callback only when we actually start a process
if not self._atexit_registered:
# Use a lambda to avoid issues with bound methods
atexit.register(lambda: self.stop_server() if self.server_process else None)
self._atexit_registered = True
def _wait_for_server_ready(self, port: int) -> tuple[bool, int]:
"""Wait for the server to be ready."""
max_wait, wait_interval = 120, 0.5
for _ in range(int(max_wait / wait_interval)):
if _check_port(port):
logger.info("Embedding server is ready!")
return True, port
if self.server_process and self.server_process.poll() is not None:
logger.error("Server terminated during startup.")
return False, port
time.sleep(wait_interval)
logger.error(f"Server failed to start within {max_wait} seconds.")
self.stop_server()
return False, port
def stop_server(self):
"""Stops the embedding server process if it's running."""
if self.server_process and self.server_process.poll() is None:
print(
f"INFO: Terminating session server process (PID: {self.server_process.pid})..."
if not self.server_process:
return
if self.server_process.poll() is not None:
# Process already terminated
self.server_process = None
return
logger.info(
f"Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}..."
)
self.server_process.terminate()
try:
self.server_process.wait(timeout=5)
logger.info(f"Server process {self.server_process.pid} terminated.")
except subprocess.TimeoutExpired:
logger.warning(
f"Server process {self.server_process.pid} did not terminate gracefully, killing it."
)
self.server_process.terminate()
try:
self.server_process.wait(timeout=5)
print("INFO: Server process terminated.")
except subprocess.TimeoutExpired:
print(
"WARNING: Server process did not terminate gracefully, killing it."
)
self.server_process.kill()
self.server_process.kill()
# Clean up process resources to prevent resource tracker warnings
try:
self.server_process.wait() # Ensure process is fully cleaned up
except Exception:
pass
self.server_process = None

View File

@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
import numpy as np
from typing import Dict, Any, List, Literal
from typing import Dict, Any, List, Literal, Optional
class LeannBackendBuilderInterface(ABC):
@@ -34,6 +34,13 @@ class LeannBackendSearcherInterface(ABC):
"""
pass
@abstractmethod
def _ensure_server_running(
self, passages_source_file: str, port: Optional[int], **kwargs
) -> int:
"""Ensure server is running"""
pass
@abstractmethod
def search(
self,
@@ -44,7 +51,7 @@ class LeannBackendSearcherInterface(ABC):
prune_ratio: float = 0.0,
recompute_embeddings: bool = False,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
zmq_port: int = 5557,
zmq_port: Optional[int] = None,
**kwargs,
) -> Dict[str, Any]:
"""Search for nearest neighbors
@@ -57,7 +64,7 @@ class LeannBackendSearcherInterface(ABC):
prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored PQ codes
pruning_strategy: PQ candidate selection strategy - "global" (default), "local", or "proportional"
zmq_port: ZMQ port for embedding server communication
zmq_port: ZMQ port for embedding server communication. Must be provided if recompute_embeddings is True.
**kwargs: Backend-specific parameters
Returns:
@@ -67,7 +74,10 @@ class LeannBackendSearcherInterface(ABC):
@abstractmethod
def compute_query_embedding(
self, query: str, zmq_port: int = 5557, use_server_if_available: bool = True
self,
query: str,
use_server_if_available: bool = True,
zmq_port: Optional[int] = None,
) -> np.ndarray:
"""Compute embedding for a query string

View File

@@ -7,30 +7,37 @@ import importlib.metadata
if TYPE_CHECKING:
from leann.interface import LeannBackendFactoryInterface
BACKEND_REGISTRY: Dict[str, 'LeannBackendFactoryInterface'] = {}
BACKEND_REGISTRY: Dict[str, "LeannBackendFactoryInterface"] = {}
def register_backend(name: str):
"""A decorator to register a new backend class."""
def decorator(cls):
print(f"INFO: Registering backend '{name}'")
BACKEND_REGISTRY[name] = cls
return cls
return decorator
def autodiscover_backends():
"""Automatically discovers and imports all 'leann-backend-*' packages."""
print("INFO: Starting backend auto-discovery...")
# print("INFO: Starting backend auto-discovery...")
discovered_backends = []
for dist in importlib.metadata.distributions():
dist_name = dist.metadata['name']
if dist_name.startswith('leann-backend-'):
backend_module_name = dist_name.replace('-', '_')
dist_name = dist.metadata["name"]
if dist_name.startswith("leann-backend-"):
backend_module_name = dist_name.replace("-", "_")
discovered_backends.append(backend_module_name)
for backend_module_name in sorted(discovered_backends): # sort for deterministic loading
for backend_module_name in sorted(
discovered_backends
): # sort for deterministic loading
try:
importlib.import_module(backend_module_name)
# Registration message is printed by the decorator
except ImportError as e:
print(f"WARN: Could not import backend module '{backend_module_name}': {e}")
print("INFO: Backend auto-discovery finished.")
# print(f"WARN: Could not import backend module '{backend_module_name}': {e}")
pass
# print("INFO: Backend auto-discovery finished.")

View File

@@ -1,8 +1,7 @@
import json
import pickle
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, Any, Literal
from typing import Dict, Any, Literal, Optional
import numpy as np
@@ -43,10 +42,10 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
"WARNING: embedding_model not found in meta.json. Recompute will fail."
)
self.label_map = self._load_label_map()
self.embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
self.embedding_server_manager = EmbeddingServerManager(
backend_module_name=backend_module_name
backend_module_name=backend_module_name,
)
def _load_meta(self) -> Dict[str, Any]:
@@ -58,17 +57,9 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
with open(meta_path, "r", encoding="utf-8") as f:
return json.load(f)
def _load_label_map(self) -> Dict[int, str]:
"""Loads the mapping from integer IDs to string IDs."""
label_map_file = self.index_dir / "leann.labels.map"
if not label_map_file.exists():
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
with open(label_map_file, "rb") as f:
return pickle.load(f)
def _ensure_server_running(
self, passages_source_file: str, port: int, **kwargs
) -> None:
) -> int:
"""
Ensures the embedding server is running if recompute is needed.
This is a helper for subclasses.
@@ -78,21 +69,26 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
"Cannot use recompute mode without 'embedding_model' in meta.json."
)
embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
server_started = self.embedding_server_manager.start_server(
server_started, actual_port = self.embedding_server_manager.start_server(
port=port,
model_name=self.embedding_model,
embedding_mode=self.embedding_mode,
passages_file=passages_source_file,
distance_metric=kwargs.get("distance_metric"),
embedding_mode=embedding_mode,
enable_warmup=kwargs.get("enable_warmup", False),
)
if not server_started:
raise RuntimeError(f"Failed to start embedding server on port {port}")
raise RuntimeError(
f"Failed to start embedding server on port {actual_port}"
)
return actual_port
def compute_query_embedding(
self, query: str, zmq_port: int = 5557, use_server_if_available: bool = True
self,
query: str,
use_server_if_available: bool = True,
zmq_port: int = 5557,
) -> np.ndarray:
"""
Compute embedding for a query string.
@@ -106,12 +102,20 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
Query embedding as numpy array
"""
# Try to use embedding server if available and requested
if (
use_server_if_available
and self.embedding_server_manager
and self.embedding_server_manager.server_process
):
if use_server_if_available:
try:
# TODO: Maybe we can directly use this port here?
# For this internal method, it's ok to assume that the server is running
# on that port?
# Ensure we have a server with passages_file for compatibility
passages_source_file = (
self.index_dir / f"{self.index_path.name}.meta.json"
)
zmq_port = self._ensure_server_running(
str(passages_source_file), zmq_port
)
return self._compute_embedding_via_server([query], zmq_port)[
0:1
] # Return (1, D) shape
@@ -120,7 +124,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
print("⏭️ Falling back to direct model loading...")
# Fallback to direct computation
from .api import compute_embeddings
from .embedding_compute import compute_embeddings
embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
return compute_embeddings([query], self.embedding_model, embedding_mode)
@@ -167,7 +171,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
prune_ratio: float = 0.0,
recompute_embeddings: bool = False,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
zmq_port: int = 5557,
zmq_port: Optional[int] = None,
**kwargs,
) -> Dict[str, Any]:
"""
@@ -181,7 +185,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored PQ codes
pruning_strategy: PQ candidate selection strategy - "global" (default), "local", or "proportional"
zmq_port: ZMQ port for embedding server communication
zmq_port: ZMQ port for embedding server communication. Must be provided if recompute_embeddings is True.
**kwargs: Backend-specific parameters (e.g., batch_size, dedup_node_dis, etc.)
Returns:

40
packages/leann/README.md Normal file
View File

@@ -0,0 +1,40 @@
# LEANN - The smallest vector index in the world
LEANN is a revolutionary vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **97% less storage** than traditional solutions **without accuracy loss**.
## Installation
```bash
# Default installation (HNSW backend, recommended)
uv pip install leann
# With DiskANN backend (for large-scale deployments)
uv pip install leann[diskann]
```
## Quick Start
```python
from leann import LeannBuilder, LeannSearcher, LeannChat
# Build an index
builder = LeannBuilder(backend_name="hnsw")
builder.add_text("LEANN saves 97% storage compared to traditional vector databases.")
builder.build_index("my_index.leann")
# Search
searcher = LeannSearcher("my_index.leann")
results = searcher.search("storage savings", top_k=3)
# Chat with your data
chat = LeannChat("my_index.leann", llm_config={"type": "ollama", "model": "llama3.2:1b"})
response = chat.ask("How much storage does LEANN save?")
```
## Documentation
For full documentation, visit [https://leann.readthedocs.io](https://leann.readthedocs.io)
## License
MIT License

View File

@@ -0,0 +1,12 @@
"""
LEANN - Low-storage Embedding Approximation for Neural Networks
A revolutionary vector database that democratizes personal AI.
"""
__version__ = "0.1.0"
# Re-export main API from leann-core
from leann_core import LeannBuilder, LeannSearcher, LeannChat
__all__ = ["LeannBuilder", "LeannSearcher", "LeannChat"]

View File

@@ -0,0 +1,42 @@
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
[project]
name = "leann"
version = "0.1.4"
description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!"
readme = "README.md"
requires-python = ">=3.9"
license = { text = "MIT" }
authors = [
{ name = "LEANN Team" }
]
keywords = ["vector-database", "rag", "embeddings", "search", "ai"]
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
]
# Default installation: core + hnsw
dependencies = [
"leann-core>=0.1.0",
"leann-backend-hnsw>=0.1.0",
]
[project.optional-dependencies]
diskann = [
"leann-backend-diskann>=0.1.0",
]
[project.urls]
Homepage = "https://github.com/yourusername/leann"
Documentation = "https://leann.readthedocs.io"
Repository = "https://github.com/yourusername/leann"
Issues = "https://github.com/yourusername/leann/issues"

View File

@@ -9,7 +9,6 @@ requires-python = ">=3.10"
dependencies = [
"leann-core",
"leann-backend-diskann",
"leann-backend-hnsw",
"numpy>=1.26.0",
"torch",
@@ -34,8 +33,9 @@ dependencies = [
"msgpack>=1.1.1",
"llama-index-vector-stores-faiss>=0.4.0",
"llama-index-embeddings-huggingface>=0.5.5",
"mlx>=0.26.3",
"mlx-lm>=0.26.0",
"mlx>=0.26.3; sys_platform == 'darwin'",
"mlx-lm>=0.26.0; sys_platform == 'darwin'",
"psutil>=5.8.0",
]
[project.optional-dependencies]
@@ -48,6 +48,10 @@ dev = [
"huggingface-hub>=0.20.0",
]
diskann = [
"leann-backend-diskann",
]
[tool.setuptools]
py-modules = []

87
scripts/build_and_test.sh Executable file
View File

@@ -0,0 +1,87 @@
#!/bin/bash
# Manual build and test script for local testing
PACKAGE=${1:-"all"} # Default to all packages
echo "Building package: $PACKAGE"
# Ensure we're in a virtual environment
if [ -z "$VIRTUAL_ENV" ]; then
echo "Error: Please activate a virtual environment first"
echo "Run: source .venv/bin/activate (or .venv/bin/activate.fish for fish shell)"
exit 1
fi
# Install build tools
uv pip install build twine delocate auditwheel scikit-build-core cmake pybind11 numpy
build_package() {
local package_dir=$1
local package_name=$(basename $package_dir)
echo "Building $package_name..."
cd $package_dir
# Clean previous builds
rm -rf dist/ build/ _skbuild/
# Build directly with pip wheel (avoids sdist issues)
pip wheel . --no-deps -w dist
# Repair wheel for binary packages
if [[ "$package_name" != "leann-core" ]] && [[ "$package_name" != "leann" ]]; then
if [[ "$OSTYPE" == "darwin"* ]]; then
# For macOS
for wheel in dist/*.whl; do
if [[ -f "$wheel" ]]; then
delocate-wheel -w dist_repaired -v "$wheel"
fi
done
if [[ -d dist_repaired ]]; then
rm -rf dist/*.whl
mv dist_repaired/*.whl dist/
rmdir dist_repaired
fi
else
# For Linux
for wheel in dist/*.whl; do
if [[ -f "$wheel" ]]; then
auditwheel repair "$wheel" -w dist_repaired
fi
done
if [[ -d dist_repaired ]]; then
rm -rf dist/*.whl
mv dist_repaired/*.whl dist/
rmdir dist_repaired
fi
fi
fi
echo "Built wheels in $package_dir/dist/"
ls -la dist/
cd - > /dev/null
}
# Build specific package or all
if [ "$PACKAGE" == "diskann" ]; then
build_package "packages/leann-backend-diskann"
elif [ "$PACKAGE" == "hnsw" ]; then
build_package "packages/leann-backend-hnsw"
elif [ "$PACKAGE" == "core" ]; then
build_package "packages/leann-core"
elif [ "$PACKAGE" == "meta" ]; then
build_package "packages/leann"
elif [ "$PACKAGE" == "all" ]; then
build_package "packages/leann-core"
build_package "packages/leann-backend-hnsw"
build_package "packages/leann-backend-diskann"
build_package "packages/leann"
else
echo "Unknown package: $PACKAGE"
echo "Usage: $0 [diskann|hnsw|core|meta|all]"
exit 1
fi
echo -e "\nBuild complete! Test with:"
echo "uv pip install packages/*/dist/*.whl"

31
scripts/bump_version.sh Executable file
View File

@@ -0,0 +1,31 @@
#!/bin/bash
if [ $# -eq 0 ]; then
echo "Usage: $0 <new_version>"
exit 1
fi
NEW_VERSION=$1
# Get the directory where the script is located
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
PROJECT_ROOT="$( cd "$SCRIPT_DIR/.." && pwd )"
# Update all pyproject.toml files
echo "Updating versions in $PROJECT_ROOT/packages/"
# Use different sed syntax for macOS vs Linux
if [[ "$OSTYPE" == "darwin"* ]]; then
# Update version fields
find "$PROJECT_ROOT/packages" -name "pyproject.toml" -exec sed -i '' "s/version = \".*\"/version = \"$NEW_VERSION\"/" {} \;
# Update leann-core dependencies
find "$PROJECT_ROOT/packages" -name "pyproject.toml" -exec sed -i '' "s/leann-core==[0-9.]*/leann-core==$NEW_VERSION/" {} \;
else
# Update version fields
find "$PROJECT_ROOT/packages" -name "pyproject.toml" -exec sed -i "s/version = \".*\"/version = \"$NEW_VERSION\"/" {} \;
# Update leann-core dependencies
find "$PROJECT_ROOT/packages" -name "pyproject.toml" -exec sed -i "s/leann-core==[0-9.]*/leann-core==$NEW_VERSION/" {} \;
fi
echo "✅ Version updated to $NEW_VERSION"
echo "✅ Dependencies updated to use leann-core==$NEW_VERSION"

18
scripts/release.sh Executable file
View File

@@ -0,0 +1,18 @@
#!/bin/bash
if [ $# -eq 0 ]; then
echo "Usage: $0 <version>"
echo "Example: $0 0.1.1"
exit 1
fi
VERSION=$1
# Update version
./scripts/bump_version.sh $VERSION
# Commit and push
git add . && git commit -m "chore: bump version to $VERSION" && git push
# Create release (triggers CI)
gh release create v$VERSION --generate-notes

30
scripts/upload_to_pypi.sh Executable file
View File

@@ -0,0 +1,30 @@
#!/bin/bash
# Manual upload script for testing
TARGET=${1:-"test"} # Default to test pypi
if [ "$TARGET" != "test" ] && [ "$TARGET" != "prod" ]; then
echo "Usage: $0 [test|prod]"
exit 1
fi
# Check for built packages
if ! ls packages/*/dist/*.whl >/dev/null 2>&1; then
echo "No built packages found. Run ./scripts/build_and_test.sh first"
exit 1
fi
if [ "$TARGET" == "test" ]; then
echo "Uploading to Test PyPI..."
twine upload --repository testpypi packages/*/dist/*
else
echo "Uploading to PyPI..."
echo "Are you sure? (y/N)"
read -r response
if [ "$response" == "y" ]; then
twine upload packages/*/dist/*
else
echo "Cancelled"
fi
fi

View File

@@ -12,7 +12,7 @@ else:
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ",
use_mlx=True,
embedding_mode="mlx",
)
# 2. Add documents

1865
uv.lock generated
View File

File diff suppressed because it is too large Load Diff