Compare commits

perf-build...v0.1.2 (86 commits)

SHA1:
cea1f6f87c, 6c0e39372b, 2bec67d2b6, 133e715832, 95cf2f16e2, 47a4c153eb, faf5ae3533, a44dccecac,
9cf9358b9c, de252fef31, 9076bc27b8, 50686c0819, 1614203786, 3d4c75a56c, 2684ee71dc, 1d321953ba,
b3cb251369, 0a17d2c9d8, e3defbca84, e407f63977, 7add391b2c, efd6373b32, d502fa24b0, 258a9a5c7f,
5d41ac6115, 2a0fdb49b8, 9d1b7231b6, ed3095b478, 88eca75917, 42de27e16a, c083bda5b7, e86da38726,
99076e38bc, 9698c1a02c, 851f0f04c3, ae16d9d888, 6e1af2eb0c, 7695dd0d50, c2065473ad, 5f3870564d,
c214b2e33e, 2420c5fd35, f48f526f0a, 5dd74982ba, e07aaf52a7, 30e5f12616, 594427bf87, a97d3ada1c,
6217bb5638, 2760e99e18, 0544f96b79, 2ebb29de65, 43762d44c7, cdaf0c98be, aa9a14a917, 9efcc6d95c,
f3f5d91207, 6070160959, 43155d2811, d3f85678ec, 2a96d05b21, 851e888535, 90120d4dff, 8513471573,
71e5f1774c, 870a443446, cefaa2a4cc, ab72a2ab9d, 046d457d22, 7fd0a30fee, c2f35c8e73, 573313f0b6,
f7af6805fa, 966de3a399, 8a75829f3a, 0f7e34b9e2, be0322b616, 232a525a62, 587ce65cf6, ccf6c8bfd7,
c112956d2d, b3970793cf, 727724990e, 530f6e4af5, 2f224f5793, 1b6272ce0e
.github/workflows/build-and-publish.yml  (vendored, new file, 262 lines)
@@ -0,0 +1,262 @@

name: CI - Build Multi-Platform Packages

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]
  workflow_dispatch:
    inputs:
      publish:
        description: 'Publish to PyPI (only use for emergency fixes)'
        required: true
        default: 'false'
        type: choice
        options:
          - 'false'
          - 'test'
          - 'prod'

jobs:
  # Build pure Python package: leann-core
  build-core:
    name: Build leann-core
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.11'

      - name: Install uv
        uses: astral-sh/setup-uv@v4

      - name: Install build dependencies
        run: |
          uv pip install --system build twine

      - name: Build package
        run: |
          cd packages/leann-core
          uv build

      - name: Upload artifacts
        uses: actions/upload-artifact@v4
        with:
          name: leann-core-dist
          path: packages/leann-core/dist/

  # Build binary package: leann-backend-hnsw (default backend)
  build-hnsw:
    name: Build leann-backend-hnsw
    strategy:
      matrix:
        os: [ubuntu-latest, macos-latest]
        python-version: ['3.9', '3.10', '3.11', '3.12']
    runs-on: ${{ matrix.os }}
    steps:
      - uses: actions/checkout@v4
        with:
          submodules: recursive

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install uv
        uses: astral-sh/setup-uv@v4

      - name: Install system dependencies (Ubuntu)
        if: runner.os == 'Linux'
        run: |
          sudo apt-get update
          sudo apt-get install -y libomp-dev libboost-all-dev libzmq3-dev \
            pkg-config libopenblas-dev patchelf

      - name: Install system dependencies (macOS)
        if: runner.os == 'macOS'
        run: |
          brew install libomp boost zeromq

      - name: Install build dependencies
        run: |
          uv pip install --system scikit-build-core numpy swig
          uv pip install --system auditwheel delocate

      - name: Build wheel
        run: |
          cd packages/leann-backend-hnsw
          uv build --wheel --python python

      - name: Repair wheel (Linux)
        if: runner.os == 'Linux'
        run: |
          cd packages/leann-backend-hnsw
          auditwheel repair dist/*.whl -w dist_repaired
          rm -rf dist
          mv dist_repaired dist

      - name: Repair wheel (macOS)
        if: runner.os == 'macOS'
        run: |
          cd packages/leann-backend-hnsw
          delocate-wheel -w dist_repaired -v dist/*.whl
          rm -rf dist
          mv dist_repaired dist

      - name: Upload artifacts
        uses: actions/upload-artifact@v4
        with:
          name: hnsw-${{ matrix.os }}-py${{ matrix.python-version }}
          path: packages/leann-backend-hnsw/dist/

  # Build binary package: leann-backend-diskann (multi-platform)
  build-diskann:
    name: Build leann-backend-diskann
    strategy:
      matrix:
        os: [ubuntu-latest, macos-latest]
        python-version: ['3.9', '3.10', '3.11', '3.12']
    runs-on: ${{ matrix.os }}
    steps:
      - uses: actions/checkout@v4
        with:
          submodules: recursive

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install uv
        uses: astral-sh/setup-uv@v4

      - name: Install system dependencies (Ubuntu)
        if: runner.os == 'Linux'
        run: |
          sudo apt-get update
          sudo apt-get install -y libomp-dev libboost-all-dev libaio-dev libzmq3-dev \
            protobuf-compiler libprotobuf-dev libabsl-dev patchelf

          # Install Intel MKL using Intel's installer
          wget https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh
          sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
          source /opt/intel/oneapi/setvars.sh
          echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV
          echo "LD_LIBRARY_PATH=/opt/intel/oneapi/mkl/latest/lib/intel64:$LD_LIBRARY_PATH" >> $GITHUB_ENV

      - name: Install system dependencies (macOS)
        if: runner.os == 'macOS'
        run: |
          brew install libomp boost zeromq protobuf
          # MKL is not available on Homebrew, but DiskANN can work without it

      - name: Install build dependencies
        run: |
          uv pip install --system scikit-build-core numpy Cython pybind11
          if [[ "$RUNNER_OS" == "Linux" ]]; then
            uv pip install --system auditwheel
          else
            uv pip install --system delocate
          fi

      - name: Build wheel
        run: |
          cd packages/leann-backend-diskann
          uv build --wheel --python python

      - name: Repair wheel (Linux)
        if: runner.os == 'Linux'
        run: |
          cd packages/leann-backend-diskann
          auditwheel repair dist/*.whl -w dist_repaired
          rm -rf dist
          mv dist_repaired dist

      - name: Repair wheel (macOS)
        if: runner.os == 'macOS'
        run: |
          cd packages/leann-backend-diskann
          delocate-wheel -w dist_repaired -v dist/*.whl
          rm -rf dist
          mv dist_repaired dist

      - name: Upload artifacts
        uses: actions/upload-artifact@v4
        with:
          name: diskann-${{ matrix.os }}-py${{ matrix.python-version }}
          path: packages/leann-backend-diskann/dist/

  # Build meta-package: leann (build last)
  build-meta:
    name: Build leann meta-package
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.11'

      - name: Install uv
        uses: astral-sh/setup-uv@v4

      - name: Install build dependencies
        run: |
          uv pip install --system build

      - name: Build package
        run: |
          cd packages/leann
          uv build

      - name: Upload artifacts
        uses: actions/upload-artifact@v4
        with:
          name: leann-meta-dist
          path: packages/leann/dist/

  # Publish to PyPI (only for emergency fixes or manual triggers)
  publish:
    name: Publish to PyPI (Emergency)
    needs: [build-core, build-hnsw, build-diskann, build-meta]
    runs-on: ubuntu-latest
    if: github.event_name == 'workflow_dispatch' && github.event.inputs.publish != 'false'

    steps:
      - name: Download all artifacts
        uses: actions/download-artifact@v4
        with:
          path: dist

      - name: Flatten directory structure
        run: |
          mkdir -p all_wheels
          find dist -name "*.whl" -exec cp {} all_wheels/ \;
          find dist -name "*.tar.gz" -exec cp {} all_wheels/ \;

      - name: Show what will be published
        run: |
          echo "📦 Packages to be published:"
          ls -la all_wheels/

      - name: Publish to Test PyPI
        if: github.event.inputs.publish == 'test'
        uses: pypa/gh-action-pypi-publish@release/v1
        with:
          password: ${{ secrets.TEST_PYPI_API_TOKEN }}
          repository-url: https://test.pypi.org/legacy/
          packages-dir: all_wheels/
          skip-existing: true

      - name: Publish to PyPI
        if: github.event.inputs.publish == 'prod'
        uses: pypa/gh-action-pypi-publish@release/v1
        with:
          password: ${{ secrets.PYPI_API_TOKEN }}
          packages-dir: all_wheels/
          skip-existing: true
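The publish job above only runs on a manual `workflow_dispatch` whose `publish` input is not `'false'`. A minimal sketch of dispatching such a run from the GitHub CLI, assuming the workflow file name above and an authenticated `gh`:

```bash
# Request a TestPyPI publish from the CI workflow;
# use -f publish=prod for a real (emergency) PyPI release.
gh workflow run build-and-publish.yml -f publish=test

# Follow the run that was just dispatched.
gh run watch
```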
.github/workflows/release-manual.yml  (vendored, new file, 206 lines)
@@ -0,0 +1,206 @@

name: Manual Release

on:
  workflow_dispatch:
    inputs:
      version:
        description: 'Version to release (e.g., 0.1.1)'
        required: true
        type: string
      test_pypi:
        description: 'Test on TestPyPI first'
        required: false
        type: boolean
        default: true

jobs:
  validate-and-release:
    runs-on: ubuntu-latest
    permissions:
      contents: write
      actions: read

    steps:
      - uses: actions/checkout@v4
        with:
          token: ${{ secrets.GITHUB_TOKEN }}

      - name: Check CI status
        run: |
          echo "ℹ️ This workflow will download build artifacts from the latest CI run."
          echo " CI must have completed successfully on the current commit."
          echo ""

      - name: Validate version format
        run: |
          if ! [[ "${{ inputs.version }}" =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
            echo "❌ Invalid version format. Use semantic versioning (e.g., 0.1.1)"
            exit 1
          fi
          echo "✅ Version format valid: ${{ inputs.version }}"

      - name: Check if version already exists
        run: |
          if git tag | grep -q "^v${{ inputs.version }}$"; then
            echo "❌ Version v${{ inputs.version }} already exists!"
            exit 1
          fi
          echo "✅ Version is new"

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.13'

      - name: Install uv
        run: |
          curl -LsSf https://astral.sh/uv/install.sh | sh
          echo "$HOME/.cargo/bin" >> $GITHUB_PATH

      - name: Update versions
        run: |
          ./scripts/bump_version.sh ${{ inputs.version }}
          git config user.name "GitHub Actions"
          git config user.email "actions@github.com"
          git add packages/*/pyproject.toml
          git commit -m "chore: release v${{ inputs.version }}"

      - name: Get CI run ID
        id: get-ci-run
        run: |
          # Get the latest successful CI run on the previous commit (before version bump)
          COMMIT_SHA=$(git rev-parse HEAD~1)
          RUN_ID=$(gh run list \
            --workflow="CI - Build Multi-Platform Packages" \
            --status=success \
            --commit=$COMMIT_SHA \
            --json databaseId \
            --jq '.[0].databaseId')

          if [ -z "$RUN_ID" ]; then
            echo "❌ No successful CI run found for commit $COMMIT_SHA"
            echo ""
            echo "This usually means:"
            echo "1. CI hasn't run on the latest commit yet"
            echo "2. CI failed on the latest commit"
            echo ""
            echo "Please ensure CI passes on main branch before releasing."
            exit 1
          fi

          echo "✅ Found CI run: $RUN_ID"
          echo "run-id=$RUN_ID" >> $GITHUB_OUTPUT
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}

      - name: Download artifacts from CI run
        run: |
          echo "📦 Downloading artifacts from CI run ${{ steps.get-ci-run.outputs.run-id }}..."

          # Download all artifacts (not just wheels-*)
          gh run download ${{ steps.get-ci-run.outputs.run-id }} \
            --dir ./dist-downloads

          # Consolidate all wheels into packages/*/dist/
          mkdir -p packages/leann-core/dist
          mkdir -p packages/leann-backend-hnsw/dist
          mkdir -p packages/leann-backend-diskann/dist
          mkdir -p packages/leann/dist

          find ./dist-downloads -name "*.whl" -exec cp {} ./packages/ \;

          # Move wheels to correct package directories
          for wheel in packages/*.whl; do
            if [[ $wheel == *"leann_core"* ]]; then
              mv "$wheel" packages/leann-core/dist/
            elif [[ $wheel == *"leann_backend_hnsw"* ]]; then
              mv "$wheel" packages/leann-backend-hnsw/dist/
            elif [[ $wheel == *"leann_backend_diskann"* ]]; then
              mv "$wheel" packages/leann-backend-diskann/dist/
            elif [[ $wheel == *"leann-"* ]] && [[ $wheel != *"backend"* ]] && [[ $wheel != *"core"* ]]; then
              mv "$wheel" packages/leann/dist/
            fi
          done

          # List downloaded wheels
          echo "✅ Downloaded wheels:"
          find packages/*/dist -name "*.whl" -type f | sort
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}

      - name: Test on TestPyPI (optional)
        if: inputs.test_pypi
        continue-on-error: true
        env:
          TWINE_USERNAME: __token__
          TWINE_PASSWORD: ${{ secrets.TEST_PYPI_API_TOKEN }}
        run: |
          if [ -z "$TWINE_PASSWORD" ]; then
            echo "⚠️ TEST_PYPI_API_TOKEN not configured, skipping TestPyPI upload"
            echo " To enable TestPyPI testing, add TEST_PYPI_API_TOKEN to repository secrets"
            exit 0
          fi

          pip install twine
          echo "📦 Uploading to TestPyPI..."
          twine upload --repository testpypi packages/*/dist/* --verbose || {
            echo "⚠️ TestPyPI upload failed, but continuing with release"
            echo " This is optional and won't block the release"
            exit 0
          }
          echo "✅ Test upload successful!"
          echo "📋 Check packages at: https://test.pypi.org/user/your-username/"
          echo ""
          echo "To test installation:"
          echo "pip install -i https://test.pypi.org/simple/ leann"

      - name: Publish to PyPI
        env:
          TWINE_USERNAME: __token__
          TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
        run: |
          if [ -z "$TWINE_PASSWORD" ]; then
            echo "❌ PYPI_API_TOKEN not configured!"
            echo " Please add PYPI_API_TOKEN to repository secrets"
            exit 1
          fi

          pip install twine
          echo "📦 Publishing to PyPI..."

          # Collect all wheels in one place
          mkdir -p all_wheels
          find packages/*/dist -name "*.whl" -exec cp {} all_wheels/ \;
          find packages/*/dist -name "*.tar.gz" -exec cp {} all_wheels/ \;

          echo "📋 Packages to publish:"
          ls -la all_wheels/

          # Upload to PyPI
          twine upload all_wheels/* --skip-existing --verbose

          echo "✅ Published to PyPI!"
          echo "🎉 Check packages at: https://pypi.org/project/leann/"

      - name: Create and push tag
        run: |
          git tag "v${{ inputs.version }}"
          git push origin main
          git push origin "v${{ inputs.version }}"
          echo "✅ Tag v${{ inputs.version }} created and pushed"

      - name: Create GitHub Release
        uses: softprops/action-gh-release@v1
        with:
          tag_name: v${{ inputs.version }}
          name: Release v${{ inputs.version }}
          body: |
            ## 🚀 Release v${{ inputs.version }}

            ### What's Changed
            See the [full changelog](https://github.com/${{ github.repository }}/compare/...v${{ inputs.version }})

            ### Installation
            ```bash
            pip install leann==${{ inputs.version }}
            ```
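In practice a release with this workflow is kicked off by supplying the `version` and `test_pypi` inputs. A hedged sketch using the GitHub CLI, assuming the file name above and an illustrative version number:

```bash
# Dispatch a manual release of 0.1.2, uploading to TestPyPI first.
gh workflow run release-manual.yml -f version=0.1.2 -f test_pypi=true
```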
.gitignore  (vendored, 3 changed lines)
@@ -12,7 +12,6 @@ outputs/
 *.idx
 *.map
 .history/
-scripts/
 lm_eval.egg-info/
 demo/experiment_results/**/*.json
 *.jsonl
@@ -85,3 +84,5 @@ packages/leann-backend-diskann/third_party/DiskANN/_deps/
 
 *.meta.json
 *.passages.json
+
+batchtest.py
.vscode/extensions.json  (vendored, deleted file, 9 lines)
@@ -1,9 +0,0 @@
{
    "recommendations": [
        "llvm-vs-code-extensions.vscode-clangd",
        "ms-python.python",
        "ms-vscode.cmake-tools",
        "vadimcn.vscode-lldb",
        "eamodio.gitlens",
    ]
}
.vscode/launch.json  (vendored, deleted file, 283 lines)
@@ -1,283 +0,0 @@
{
    // Use IntelliSense to learn about possible attributes.
    // Hover to view descriptions of existing attributes.
    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
    "version": "0.2.0",
    "configurations": [
        // new emdedder
        {
            "name": "New Embedder",
            "type": "debugpy",
            "request": "launch",
            "program": "demo/main.py",
            "console": "integratedTerminal",
            "args": ["--search", "--use-original", "--domain", "dpr", "--nprobe", "5000", "--load", "flat", "--embedder", "intfloat/multilingual-e5-small"]
        }
        //python /home/ubuntu/Power-RAG/faiss/demo/simple_build.py
        {
            "name": "main.py",
            "type": "debugpy",
            "request": "launch",
            "program": "demo/main.py",
            "console": "integratedTerminal",
            "cwd": "${workspaceFolder}",
            "args": ["--query", "1000", "--load", "bm25"]
        },
        {
            "name": "Simple Build",
            "type": "lldb",
            "request": "launch",
            "program": "${workspaceFolder}/.venv/bin/python",
            "console": "integratedTerminal",
            "cwd": "${workspaceFolder}",
            "args": ["faiss/demo/simple_build.py"],
            "env": {
                "LD_PRELOAD": "/lib/x86_64-linux-gnu/libmkl_core.so:/lib/x86_64-linux-gnu/libmkl_intel_thread.so:/lib/x86_64-linux-gnu/libmkl_intel_lp64.so:/lib/x86_64-linux-gnu/libiomp5.so"
            }
        },
        //# Fix for Intel MKL error
        //export LD_PRELOAD=/lib/x86_64-linux-gnu/libmkl_core.so:/lib/x86_64-linux-gnu/libmkl_intel_thread.so:/lib/x86_64-linux-gnu/libmkl_intel_lp64.so:/lib/x86_64-linux-gnu/libiomp5.so
        //python faiss/demo/build_demo.py
        {
            "name": "Build Demo",
            "type": "lldb",
            "request": "launch",
            "program": "${workspaceFolder}/.venv/bin/python",
            "console": "integratedTerminal",
            "cwd": "${workspaceFolder}",
            "args": ["faiss/demo/build_demo.py"],
            "env": {
                "LD_PRELOAD": "/lib/x86_64-linux-gnu/libmkl_core.so:/lib/x86_64-linux-gnu/libmkl_intel_thread.so:/lib/x86_64-linux-gnu/libmkl_intel_lp64.so:/lib/x86_64-linux-gnu/libiomp5.so"
            }
        },
        {
            "name": "DiskANN Serve",
            "type": "lldb",
            "request": "launch",
            "program": "${workspaceFolder}/.venv/bin/python",
            "console": "integratedTerminal",
            "cwd": "${workspaceFolder}",
            "args": ["demo/main.py", "--mode", "serve", "--engine", "sglang", "--load-indices", "diskann", "--domain", "rpj_wiki", "--lazy-load", "--recompute-beighbor-embeddings", "--port", "8082", "--diskann-search-memory-maximum", "2", "--diskann-graph", "240", "--search-only"],
            "env": {
                "PYTHONPATH": "${workspaceFolder}/faiss_repo/build/faiss/python:$PYTHONPATH"
            },
            "preLaunchTask": "CMake: build",
        },
        {
            "name": "DiskANN Serve MAC",
            "type": "lldb",
            "request": "launch",
            "program": "${workspaceFolder}/.venv/bin/python",
            "console": "integratedTerminal",
            "cwd": "${workspaceFolder}",
            "args": ["demo/main.py", "--mode", "serve", "--engine", "ollama", "--load-indices", "diskann", "--domain", "rpj_wiki", "--lazy-load", "--recompute-beighbor-embeddings"],
            "preLaunchTask": "CMake: build",
            "env": {
                "KMP_DUPLICATE_LIB_OK": "TRUE",
                "OMP_NUM_THREADS": "1",
                "MKL_NUM_THREADS": "1",
                "DYLD_INSERT_LIBRARIES": "/Users/ec2-user/Power-RAG/.venv/lib/python3.10/site-packages/torch/lib/libomp.dylib",
                "KMP_BLOCKTIME": "0"
            }
        },
        {
            "name": "Python Debugger: Current File with Arguments",
            "type": "debugpy",
            "request": "launch",
            "program": "ric/main_ric.py",
            "console": "integratedTerminal",
            "cwd": "${workspaceFolder}",
            "args": ["--config-name", "${input:configSelection}"],
            "justMyCode": false
        },
        //python ./demo/validate_equivalence.py sglang
        {
            "name": "Validate Equivalence",
            "type": "debugpy",
            "request": "launch",
            "program": "demo/validate_equivalence.py",
            "console": "integratedTerminal",
            "args": ["sglang"],
        },
        //python demo/retrieval_demo.py --engine sglang --skip-embeddings --domain dpr --load-indices flat ivf_flat
        {
            "name": "Retrieval Demo",
            "type": "debugpy",
            "request": "launch",
            "program": "demo/retrieval_demo.py",
            "console": "integratedTerminal",
            "args": ["--engine", "vllm", "--skip-embeddings", "--domain", "dpr", "--load-indices", /* "flat", */ "ivf_flat"],
        },
        //python demo/retrieval_demo.py --engine sglang --skip-embeddings --domain dpr --load-indices diskann --hnsw-M 64 --hnsw-efConstruction 150 --hnsw-efSearch 128 --hnsw-sq-bits 8
        {
            "name": "Retrieval Demo DiskANN",
            "type": "debugpy",
            "request": "launch",
            "program": "demo/retrieval_demo.py",
            "console": "integratedTerminal",
            "args": ["--engine", "sglang", "--skip-embeddings", "--domain", "dpr", "--load-indices", "diskann", "--hnsw-M", "64", "--hnsw-efConstruction", "150", "--hnsw-efSearch", "128", "--hnsw-sq-bits", "8"],
        },
        {
            "name": "Find Probe",
            "type": "debugpy",
            "request": "launch",
            "program": "find_probe.py",
            "console": "integratedTerminal",
            "cwd": "${workspaceFolder}",
        },
        {
            "name": "Python: Attach",
            "type": "debugpy",
            "request": "attach",
            "processId": "${command:pickProcess}",
            "justMyCode": true
        },
        {
            "name": "Edge RAG",
            "type": "lldb",
            "request": "launch",
            "program": "${workspaceFolder}/.venv/bin/python",
            "console": "integratedTerminal",
            "cwd": "${workspaceFolder}",
            "args": ["edgerag_demo.py"],
            "env": {
                "LD_PRELOAD": "/lib/x86_64-linux-gnu/libiomp5.so /lib/x86_64-linux-gnu/libmkl_core.so /lib/x86_64-linux-gnu/libmkl_intel_lp64.so /lib/x86_64-linux-gnu/libmkl_intel_thread.so",
                "MKL_NUM_THREADS": "1",
                "OMP_NUM_THREADS": "1",
            }
        },
        {
            "name": "Launch Embedding Server",
            "type": "debugpy",
            "request": "launch",
            "program": "demo/embedding_server.py",
            "console": "integratedTerminal",
            "cwd": "${workspaceFolder}",
            "args": ["--domain", "rpj_wiki", "--zmq-port", "5556",]
        },
        {
            "name": "HNSW Serve",
            "type": "lldb",
            "request": "launch",
            "program": "${workspaceFolder}/.venv/bin/python",
            "console": "integratedTerminal",
            "cwd": "${workspaceFolder}",
            "args": ["demo/main.py", "--domain", "rpj_wiki", "--load", "hnsw", "--mode", "serve", "--search", "--skip-pa", "--recompute", "--hnsw-old"],
            "env": {
                "LD_PRELOAD": "/lib/x86_64-linux-gnu/libmkl_core.so:/lib/x86_64-linux-gnu/libmkl_intel_thread.so:/lib/x86_64-linux-gnu/libmkl_intel_lp64.so:/lib/x86_64-linux-gnu/libiomp5.so"
            }
        },
    ],
    "inputs": [
        {
            "id": "configSelection",
            "type": "pickString",
            "description": "Select a configuration",
            "options": ["example_config", "vllm_gritlm"],
            "default": "example_config"
        }
    ],
}
.vscode/settings.json  (vendored, deleted file, 43 lines)
@@ -1,43 +0,0 @@
{
    "python.analysis.extraPaths": ["./sglang_repo/python"],
    "cmake.sourceDirectory": "${workspaceFolder}/DiskANN",
    "cmake.configureArgs": ["-DPYBIND=True", "-DUPDATE_EDITABLE_INSTALL=ON",],
    "cmake.environment": {
        "PATH": "/Users/ec2-user/Power-RAG/.venv/bin:${env:PATH}"
    },
    "cmake.buildDirectory": "${workspaceFolder}/build",
    "files.associations": {
        "*.tcc": "cpp", "deque": "cpp", "string": "cpp", "unordered_map": "cpp", "vector": "cpp",
        "map": "cpp", "unordered_set": "cpp", "atomic": "cpp", "inplace_vector": "cpp", "*.ipp": "cpp",
        "forward_list": "cpp", "list": "cpp", "any": "cpp", "system_error": "cpp", "__hash_table": "cpp",
        "__split_buffer": "cpp", "__tree": "cpp", "ios": "cpp", "set": "cpp", "__string": "cpp",
        "string_view": "cpp", "ranges": "cpp", "iosfwd": "cpp"
    },
    "lldb.displayFormat": "auto",
    "lldb.showDisassembly": "auto",
    "lldb.dereferencePointers": true,
    "lldb.consoleMode": "commands",
}
.vscode/tasks.json  (vendored, deleted file, 16 lines)
@@ -1,16 +0,0 @@
{
    "version": "2.0.0",
    "tasks": [
        {
            "type": "cmake",
            "label": "CMake: build",
            "command": "build",
            "targets": ["all"],
            "group": "build",
            "problemMatcher": [],
            "detail": "CMake template build task"
        }
    ]
}
README.md  (223 changed lines)
@@ -12,11 +12,11 @@
 The smallest vector index in the world. RAG Everything with LEANN!
 </h2>
 
-LEANN is a revolutionary vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **[97% less storage]** than traditional solutions **without accuracy loss**.
+LEANN is a revolutionary vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **97% less storage** than traditional solutions **without accuracy loss**.
 
-LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration →](#️-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
+LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration Fig →](#️-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
 
-**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can search your **[file system](#process-any-documents-pdf-txt-md)**, **[emails](#search-your-entire-life)**, **[browser history](#time-machine-for-the-web)**, **[chat history](#wechat-detective)**, or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
+**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)**, or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
 
 
 
@@ -26,9 +26,8 @@ LEANN achieves this through *graph-based selective recomputation* with *high-deg
 <img src="assets/effects.png" alt="LEANN vs Traditional Vector DB Storage Comparison" width="70%">
 </p>
 
-**The numbers speak for themselves:** Index 60 million Wikipedia articles in just 6GB instead of 201GB. From emails to browser history, everything fits on your laptop. [See detailed benchmarks below ↓](#storage-usage-comparison)
+> **The numbers speak for themselves:** Index 60 million Wikipedia chunks in just 6GB instead of 201GB. From emails to browser history, everything fits on your laptop. [See detailed benchmarks for different applications below ↓](#storage-usage-comparison)
 
-## Why This Matters
 
 🔒 **Privacy:** Your data never leaves your laptop. No OpenAI, no cloud, no "terms of service".
 
@@ -38,7 +37,7 @@ LEANN achieves this through *graph-based selective recomputation* with *high-deg
 
 ✨ **No Accuracy Loss:** Maintain the same search quality as heavyweight solutions while using 97% less storage.
 
-## Quick Start in 1 minute
+## Installation
 
 ```bash
 git clone git@github.com:yichuan-w/LEANN.git leann
@@ -48,33 +47,30 @@ git submodule update --init --recursive
 
 **macOS:**
 ```bash
-brew install llvm libomp boost protobuf
-export CC=$(brew --prefix llvm)/bin/clang
-export CXX=$(brew --prefix llvm)/bin/clang++
+brew install llvm libomp boost protobuf zeromq pkgconf
 
 # Install with HNSW backend (default, recommended for most users)
-uv sync
-# Or add DiskANN backend if you want to test more options
-uv sync --extra diskann
+# Install uv first if you don't have it:
+# curl -LsSf https://astral.sh/uv/install.sh | sh
+# See: https://docs.astral.sh/uv/getting-started/installation/#installation-methods
+CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv sync
 ```
 
-**Linux (Ubuntu/Debian):**
+**Linux:**
 ```bash
-sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler libabsl-dev libmkl-full-dev libaio-dev
+sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler libabsl-dev libmkl-full-dev libaio-dev libzmq3-dev
 
 # Install with HNSW backend (default, recommended for most users)
 uv sync
 
-# Or add DiskANN backend if you want to test more options
-uv sync --extra diskann
 ```
 
-**Ollama Setup (Optional for Local LLM):**
+**Ollama Setup (Recommended for full privacy):**
 
-*We support both hf-transformers and Ollama for local LLMs. Ollama is recommended for faster performance.*
+> *You can skip this installation if you only want to use OpenAI API for generation.*
 
-*macOS:*
+**macOS:**
 
-
 First, [download Ollama for macOS](https://ollama.com/download/mac).
 
@@ -83,7 +79,7 @@ First, [download Ollama for macOS](https://ollama.com/download/mac).
 ollama pull llama3.2:1b
 ```
 
-*Linux:*
+**Linux:**
 ```bash
 # Install Ollama
 curl -fsSL https://ollama.ai/install.sh | sh
@@ -95,62 +91,70 @@ ollama serve &
 ollama pull llama3.2:1b
 ```
 
-You can also replace `llama3.2:1b` to `deepseek-r1:1.5b` or `qwen3:4b` for better performance but higher memory usage.
+## Quick Start in 30s
 
-## Dead Simple API
-Just 3 lines of code. Our declarative API makes RAG as easy as writing a config file:
+Our declarative API makes RAG as easy as writing a config file.
+[Try in this ipynb file →](demo.ipynb)
 
 ```python
-from leann.api import LeannBuilder, LeannSearcher
+from leann.api import LeannBuilder, LeannSearcher, LeannChat
 
-# 1. Build index (no embeddings stored!)
+# 1. Build the index (no embeddings stored!)
 builder = LeannBuilder(backend_name="hnsw")
 builder.add_text("C# is a powerful programming language")
-builder.add_text("Python is a powerful programming language")
+builder.add_text("Python is a powerful programming language and it is very popular")
 builder.add_text("Machine learning transforms industries")
 builder.add_text("Neural networks process complex data")
-builder.add_text("Leann is a great storage saving engine for RAG on your macbook")
+builder.add_text("Leann is a great storage saving engine for RAG on your MacBook")
 builder.build_index("knowledge.leann")
 
 # 2. Search with real-time embeddings
 searcher = LeannSearcher("knowledge.leann")
-results = searcher.search("C++ programming languages", top_k=2, recompute_beighbor_embeddings=True)
-print(results)
+results = searcher.search("programming languages", top_k=2)
+
+# 3. Chat with LEANN using retrieved results
+llm_config = {
+    "type": "ollama",
+    "model": "llama3.2:1b"
+}
+
+chat = LeannChat(index_path="knowledge.leann", llm_config=llm_config)
+response = chat.ask(
+    "Compare the two retrieved programming languages and say which one is more popular today.",
+    top_k=2,
+)
 ```
 
-**That's it.** No cloud setup, no API keys, no "fine-tuning". Just your data, your questions, your laptop.
+## RAG on Everything!
 
-[Try the interactive demo →](demo.ipynb)
+LEANN supports RAG on various data sources including documents (.pdf, .txt, .md), Apple Mail, Google Search History, WeChat, and more.
 
-## Wild Things You Can Do
+### 📄 Personal Data Manager: Process Any Documents (.pdf, .txt, .md)!
 
-LEANN supports RAGing a lot of data sources, like .pdf, .txt, .md, and also supports RAGing your WeChat, Google Search History, and more.
+Ask questions directly about your personal PDFs, documents, and any directory containing your files!
 
-### Process Any Documents (.pdf, .txt, .md)
+The example below asks a question about summarizing two papers (uses default data in `examples/data`):
 
-Above we showed the Python API, while this CLI script demonstrates the same concepts while directly processing PDFs and documents.
-
 ```bash
 # Drop your PDFs, .txt, .md files into examples/data/
 uv run ./examples/main_cli_example.py
+```
 
+```
 # Or use python directly
 source .venv/bin/activate
 python ./examples/main_cli_example.py
 ```
 
-Uses Ollama `qwen3:8b` by default. For other models: `--llm openai --model gpt-4o` (requires `OPENAI_API_KEY` environment variable) or `--llm hf --model Qwen/Qwen3-4B`.
-
-**Works with any text format** - research papers, personal notes, presentations. Built with LlamaIndex for document parsing.
-
-### Search Your Entire Life
+### 📧 Your Personal Email Secretary: RAG on Apple Mail!
 
+**Note:** You need to grant full disk access to your terminal/VS Code in System Preferences → Privacy & Security → Full Disk Access.
 ```bash
-python examples/mail_reader_leann.py
-# "What did my boss say about the Christmas party last year?"
-# "Find all emails from my mom about birthday plans"
+python examples/mail_reader_leann.py --query "What's the food I ordered by doordash or Uber eat mostly?"
 ```
-**90K emails → 14MB.** Finally, search your email like you search Google.
+**780K email chunks → 78MB storage** Finally, search your email like you search Google.
 
 <details>
 <summary><strong>📋 Click to expand: Command Examples</strong></summary>
@@ -183,13 +187,11 @@ Once the index is built, you can ask questions like:
 - "Show me emails about travel expenses"
 </details>
 
-### Time Machine for the Web
+### 🔍 Time Machine for the Web: RAG Your Entire Google Browser History!
 ```bash
-python examples/google_history_reader_leann.py
-# "What was that AI paper I read last month?"
-# "Show me all the cooking videos I watched"
+python examples/google_history_reader_leann.py --query "Tell me my browser history about machine learning?"
 ```
-**38K browser entries → 6MB.** Your browser history becomes your personal search engine.
+**38K browser entries → 6MB storage.** Your browser history becomes your personal search engine.
 
 <details>
 <summary><strong>📋 Click to expand: Command Examples</strong></summary>
@@ -238,13 +240,13 @@ Once the index is built, you can ask questions like:
 
 </details>
 
-### WeChat Detective
+### 💬 WeChat Detective: Unlock Your Golden Memories!
 
 ```bash
-python examples/wechat_history_reader_leann.py
-# "Show me all group chats about weekend plans"
+python examples/wechat_history_reader_leann.py --query "Show me all group chats about weekend plans"
 ```
-**400K messages → 64MB.** Search years of chat history in any language.
+**400K messages → 64MB storage** Search years of chat history in any language.
+
 
 <details>
 <summary><strong>🔧 Click to expand: Installation Requirements</strong></summary>
@@ -255,7 +257,13 @@ First, you need to install the WeChat exporter:
 sudo packages/wechat-exporter/wechattweak-cli install
 ```
 
-**Troubleshooting**: If you encounter installation issues, check the [WeChatTweak-CLI issues page](https://github.com/sunnyyoung/WeChatTweak-CLI/issues/41).
+**Troubleshooting:**
+- **Installation issues**: Check the [WeChatTweak-CLI issues page](https://github.com/sunnyyoung/WeChatTweak-CLI/issues/41)
+- **Export errors**: If you encounter the error below, try restarting WeChat
+```
+Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed.
+Failed to find or export WeChat data. Exiting.
+```
 </details>
 
 <details>
@@ -290,6 +298,73 @@ Once the index is built, you can ask questions like:
 </details>
 
 
+
+## 🖥️ Command Line Interface
+
+LEANN includes a powerful CLI for document processing and search. Perfect for quick document indexing and interactive chat.
+
+```bash
+# Build an index from documents
+leann build my-docs --docs ./documents
+
+# Search your documents
+leann search my-docs "machine learning concepts"
+
+# Interactive chat with your documents
+leann ask my-docs --interactive
+
+# List all your indexes
+leann list
+```
+
+**Key CLI features:**
+- Auto-detects document formats (PDF, TXT, MD, DOCX)
+- Smart text chunking with overlap
+- Multiple LLM providers (Ollama, OpenAI, HuggingFace)
+- Organized index storage in `~/.leann/indexes/`
+- Support for advanced search parameters
+
+<details>
+<summary><strong>📋 Click to expand: Complete CLI Reference</strong></summary>
+
+**Build Command:**
+```bash
+leann build INDEX_NAME --docs DIRECTORY [OPTIONS]
+
+Options:
+  --backend {hnsw,diskann}     Backend to use (default: hnsw)
+  --embedding-model MODEL      Embedding model (default: facebook/contriever)
+  --graph-degree N             Graph degree (default: 32)
+  --complexity N               Build complexity (default: 64)
+  --force                      Force rebuild existing index
+  --compact                    Use compact storage (default: true)
+  --recompute                  Enable recomputation (default: true)
+```
+
+**Search Command:**
+```bash
+leann search INDEX_NAME QUERY [OPTIONS]
+
+Options:
+  --top-k N                   Number of results (default: 5)
+  --complexity N              Search complexity (default: 64)
+  --recompute-embeddings      Use recomputation for highest accuracy
+  --pruning-strategy {global,local,proportional}
+```
+
+**Ask Command:**
+```bash
+leann ask INDEX_NAME [OPTIONS]
+
+Options:
+  --llm {ollama,openai,hf}    LLM provider (default: ollama)
+  --model MODEL               Model name (default: qwen3:8b)
+  --interactive               Interactive chat mode
+  --top-k N                   Retrieval count (default: 20)
+```
+
+</details>
+
 ## 🏗️ Architecture & How It Works
 
 <p align="center">
@@ -321,23 +396,15 @@ python examples/compare_faiss_vs_leann.py
 
 Same dataset, same hardware, same embedding model. LEANN just works better.
 
-## Reproduce Our Results
-
-```bash
-uv pip install -e ".[dev]" # Install dev dependencies
-python examples/run_evaluation.py data/indices/dpr/dpr_diskann # DPR dataset
-python examples/run_evaluation.py data/indices/rpj_wiki/rpj_wiki.index # Wikipedia
-```
-The evaluation script downloads data automatically on first run.
 
 ### Storage Usage Comparison
 
-| System | DPR (2.1M chunks) | RPJ-wiki (60M chunks) | Chat history (400K messages) | Apple emails (90K messages chunks) |Google Search History (38K entries)
+| System | DPR (2.1M chunks) | RPJ-wiki (60M chunks) | Chat history (400K messages) | Apple emails (780K messages chunks) |Google Search History (38K entries)
 |-----------------------|------------------|------------------------|-----------------------------|------------------------------|------------------------------|
-| Traditional Vector DB(FAISS) | 3.8 GB | 201 GB | 1.8G | 305.8 MB |130.4 MB |
-| **LEANN** | **324 MB** | **6 GB** | **64 MB** | **14.8 MB** |**6.4MB** |
-| **Reduction** | **91% smaller** | **97% smaller** | **97% smaller** | **95% smaller** |**95% smaller** |
+| Traditional Vector DB(FAISS) | 3.8 GB | 201 GB | 1.8G | 2.4G |130.4 MB |
+| **LEANN** | **324 MB** | **6 GB** | **64 MB** | **79 MB** |**6.4MB** |
+| **Reduction** | **91% smaller** | **97% smaller** | **97% smaller** | **97% smaller** |**95% smaller** |
 
 <!-- ### Memory Usage Comparison
 
@@ -356,6 +423,15 @@ The evaluation script downloads data automatically on first run.
 
 *Benchmarks run on Apple M3 Pro 36 GB*
 
+## Reproduce Our Results
+
+```bash
+uv pip install -e ".[dev]" # Install dev dependencies
+python examples/run_evaluation.py data/indices/dpr/dpr_diskann # DPR dataset
+python examples/run_evaluation.py data/indices/rpj_wiki/rpj_wiki.index # Wikipedia
+```
+
+The evaluation script downloads data automatically on first run. The last three results were tested with partial personal data, and you can reproduce them with your own data!
 ## 🔬 Paper
 
 If you find Leann useful, please cite:
@@ -432,6 +508,17 @@ export NCCL_IB_DISABLE=1
 export NCCL_NET_PLUGIN=none
 export NCCL_SOCKET_IFNAME=ens5
 ``` -->
+
+## FAQ
+
+### 1. My building time seems long
+
+You can speed up the process by using a lightweight embedding model. Add this to your arguments:
+
+```bash
+--embedding-model sentence-transformers/all-MiniLM-L6-v2
+```
+**Model sizes:** `all-MiniLM-L6-v2` (30M parameters), `facebook/contriever` (~100M parameters), `Qwen3-0.6B` (600M parameters)
 
 ## 📈 Roadmap
 
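The FAQ entry added to the README pairs naturally with the new `leann build` command: a minimal sketch that combines the documented `--docs` and `--embedding-model` options with the lightweight model suggested there (the index name and docs path are illustrative):

```bash
# Build an index over ./documents with a smaller, faster embedding model
# (30M parameters instead of facebook/contriever's ~100M).
leann build my-docs --docs ./documents --embedding-model sentence-transformers/all-MiniLM-L6-v2
```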
demo.ipynb  (322 changed lines)
@@ -1,35 +1,321 @@
 {
  "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Quick Start in 30s"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
-    "from leann.api import LeannBuilder, LeannSearcher, LeannChat\n",
-    "# 1. Build index (no embeddings stored!)\n",
-    "builder = LeannBuilder(backend_name=\"hnsw\")\n",
-    "builder.add_text(\"C# is a powerful programming language but it is not very popular\")\n",
-    "builder.add_text(\"Python is a powerful programming language and it is very popular\")\n",
-    "builder.add_text(\"Machine learning transforms industries\") \n",
-    "builder.add_text(\"Neural networks process complex data\")\n",
-    "builder.add_text(\"Leann is a great storage saving engine for RAG on your macbook\")\n",
-    "builder.build_index(\"knowledge.leann\")\n",
-    "# 2. Search with real-time embeddings\n",
-    "searcher = LeannSearcher(\"knowledge.leann\")\n",
-    "results = searcher.search(\"programming languages\", top_k=2, recompute_beighbor_embeddings=True)\n",
-    "print(results)\n",
+    "# install this if you areusing colab\n",
+    "! pip install leann"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Build the index"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "INFO: Registering backend 'hnsw'\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/Users/yichuan/Desktop/code/LEANN/leann/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+      " from .autonotebook import tqdm as notebook_tqdm\n",
+      "INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: facebook/contriever\n",
+      "WARNING:sentence_transformers.SentenceTransformer:No sentence-transformers model found with name facebook/contriever. Creating a new one with mean pooling.\n",
+      "Writing passages: 100%|██████████| 5/5 [00:00<00:00, 27887.66chunk/s]\n",
+      "Batches: 100%|██████████| 1/1 [00:00<00:00, 13.51it/s]\n",
+      "WARNING:leann_backend_hnsw.hnsw_backend:Converting data to float32, shape: (5, 768)\n",
+      "INFO:leann_backend_hnsw.hnsw_backend:INFO: Converting HNSW index to CSR-pruned format...\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "M: 64 for level: 0\n",
+      "Starting conversion: knowledge.index -> knowledge.csr.tmp\n",
+      "[0.00s] Reading Index HNSW header...\n",
+      "[0.00s] Header read: d=768, ntotal=5\n",
+      "[0.00s] Reading HNSW struct vectors...\n",
+      " Reading vector (dtype=<class 'numpy.float64'>, fmt='d')... Count=6, Bytes=48\n",
+      "[0.00s] Read assign_probas (6)\n",
+      " Reading vector (dtype=<class 'numpy.int32'>, fmt='i')... Count=7, Bytes=28\n",
+      "[0.11s] Read cum_nneighbor_per_level (7)\n",
+      " Reading vector (dtype=<class 'numpy.int32'>, fmt='i')... Count=5, Bytes=20\n",
+      "[0.21s] Read levels (5)\n",
+      "[0.30s] Probing for compact storage flag...\n",
+      "[0.30s] Found compact flag: False\n",
+      "[0.30s] Compact flag is False, reading original format...\n",
+      "[0.30s] Probing for potential extra byte before non-compact offsets...\n",
+      "[0.30s] Found and consumed an unexpected 0x00 byte.\n",
+      " Reading vector (dtype=<class 'numpy.uint64'>, fmt='Q')... Count=6, Bytes=48\n",
+      "[0.30s] Read offsets (6)\n",
+      "[0.40s] Attempting to read neighbors vector...\n",
+      " Reading vector (dtype=<class 'numpy.int32'>, fmt='i')... Count=320, Bytes=1280\n",
+      "[0.40s] Read neighbors (320)\n",
+      "[0.50s] Read scalar params (ep=4, max_lvl=0)\n",
+      "[0.50s] Checking for storage data...\n",
+      "[0.50s] Found storage fourcc: 49467849.\n",
+      "[0.50s] Converting to CSR format...\n",
+      "[0.50s] Conversion loop finished. \n",
+      "[0.50s] Running validation checks...\n",
+      " Checking total valid neighbor count...\n",
+      " OK: Total valid neighbors = 20\n",
+      " Checking final pointer indices...\n",
+      " OK: Final pointers match data size.\n",
+      "[0.50s] Deleting original neighbors and offsets arrays...\n",
+      " CSR Stats: |data|=20, |level_ptr|=10\n",
+      "[0.59s] Writing CSR HNSW graph data in FAISS-compatible order...\n",
+      " Pruning embeddings: Writing NULL storage marker.\n",
+      "[0.69s] Conversion complete.\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:leann_backend_hnsw.hnsw_backend:✅ CSR conversion successful.\n",
+      "INFO:leann_backend_hnsw.hnsw_backend:INFO: Replaced original index with CSR-pruned version at 'knowledge.index'\n"
+     ]
+    }
+   ],
+   "source": [
+    "from leann.api import LeannBuilder\n",
     "\n",
-    "llm_config = {\"type\": \"ollama\", \"model\": \"qwen3:8b\"}\n",
+    "builder = LeannBuilder(backend_name=\"hnsw\")\n",
+    "builder.add_text(\"C# is a powerful programming language and it is good at game development\")\n",
+    "builder.add_text(\"Python is a powerful programming language and it is good at machine learning tasks\")\n",
+    "builder.add_text(\"Machine learning transforms industries\")\n",
+    "builder.add_text(\"Neural networks process complex data\")\n",
+    "builder.add_text(\"Leann is a great storage saving engine for RAG on your MacBook\")\n",
+    "builder.build_index(\"knowledge.leann\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Search with real-time embeddings"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:leann.api:🔍 LeannSearcher.search() called:\n",
+      "INFO:leann.api: Query: 'programming languages'\n",
+      "INFO:leann.api: Top_k: 2\n",
+      "INFO:leann.api: Additional kwargs: {}\n",
+      "INFO:leann.embedding_server_manager:Port 5557 has incompatible server, trying next port...\n",
+      "INFO:leann.embedding_server_manager:Port 5558 has incompatible server, trying next port...\n",
+      "INFO:leann.embedding_server_manager:Port 5559 has incompatible server, trying next port...\n",
+      "INFO:leann.embedding_server_manager:Using port 5560 instead of 5557\n",
+      "INFO:leann.embedding_server_manager:Starting embedding server on port 5560...\n",
+      "INFO:leann.embedding_server_manager:Command: /Users/yichuan/Desktop/code/LEANN/leann/.venv/bin/python -m leann_backend_hnsw.hnsw_embedding_server --zmq-port 5560 --model-name facebook/contriever --passages-file knowledge.leann.meta.json\n",
+      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
+      "To disable this warning, you can either:\n",
+      "\t- Avoid using `tokenizers` before the fork if possible\n",
+      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
+      "INFO:leann.embedding_server_manager:Server process started with PID: 4574\n"
+     ]
+    },
+    {
+     "name": "stdout",
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"[read_HNSW - CSR NL v4] Reading metadata & CSR indices (manual offset)...\n",
|
||||||
|
"[read_HNSW NL v4] Read levels vector, size: 5\n",
|
||||||
|
"[read_HNSW NL v4] Reading Compact Storage format indices...\n",
|
||||||
|
"[read_HNSW NL v4] Read compact_level_ptr, size: 10\n",
|
||||||
|
"[read_HNSW NL v4] Read compact_node_offsets, size: 6\n",
|
||||||
|
"[read_HNSW NL v4] Read entry_point: 4, max_level: 0\n",
|
||||||
|
"[read_HNSW NL v4] Read storage fourcc: 0x6c6c756e\n",
|
||||||
|
"[read_HNSW NL v4 FIX] Detected FileIOReader. Neighbors size field offset: 326\n",
|
||||||
|
"[read_HNSW NL v4] Reading neighbors data into memory.\n",
|
||||||
|
"[read_HNSW NL v4] Read neighbors data, size: 20\n",
|
||||||
|
"[read_HNSW NL v4] Finished reading metadata and CSR indices.\n",
|
||||||
|
"INFO: Skipping external storage loading, since is_recompute is true.\n",
|
||||||
|
"INFO: Registering backend 'hnsw'\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"INFO:leann.embedding_server_manager:Embedding server is ready!\n",
|
||||||
|
"INFO:leann.api: Launching server time: 1.078078269958496 seconds\n",
|
||||||
|
"INFO:leann.embedding_server_manager:Existing server process (PID 4574) is compatible\n",
|
||||||
|
"INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: facebook/contriever\n",
|
||||||
|
"WARNING:sentence_transformers.SentenceTransformer:No sentence-transformers model found with name facebook/contriever. Creating a new one with mean pooling.\n",
|
||||||
|
"INFO:leann.api: Generated embedding shape: (1, 768)\n",
|
||||||
|
"INFO:leann.api: Embedding time: 2.9307072162628174 seconds\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"ZmqDistanceComputer initialized: d=768, metric=0\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"INFO:leann.api: Search time: 0.27327895164489746 seconds\n",
|
||||||
|
"INFO:leann.api: Backend returned: labels=2 results\n",
|
||||||
|
"INFO:leann.api: Processing 2 passage IDs:\n",
|
||||||
|
"INFO:leann.api: 1. passage_id='0' -> SUCCESS: C# is a powerful programming language and it is good at game development...\n",
|
||||||
|
"INFO:leann.api: 2. passage_id='1' -> SUCCESS: Python is a powerful programming language and it is good at machine learning tasks...\n",
|
||||||
|
"INFO:leann.api: Final enriched results: 2 passages\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"[SearchResult(id='0', score=np.float32(0.9874103), text='C# is a powerful programming language and it is good at game development', metadata={}),\n",
|
||||||
|
" SearchResult(id='1', score=np.float32(0.8922168), text='Python is a powerful programming language and it is good at machine learning tasks', metadata={})]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 2,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"from leann.api import LeannSearcher\n",
|
||||||
|
"\n",
|
||||||
|
"searcher = LeannSearcher(\"knowledge.leann\")\n",
|
||||||
|
"results = searcher.search(\"programming languages\", top_k=2)\n",
|
||||||
|
"results"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Chat with LEANN using retrieved results"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"INFO:leann.chat:Attempting to create LLM of type='hf' with model='Qwen/Qwen3-0.6B'\n",
|
||||||
|
"INFO:leann.chat:Initializing HFChat with model='Qwen/Qwen3-0.6B'\n",
|
||||||
|
"INFO:leann.chat:MPS is available. Using Apple Silicon GPU.\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"[read_HNSW - CSR NL v4] Reading metadata & CSR indices (manual offset)...\n",
|
||||||
|
"[read_HNSW NL v4] Read levels vector, size: 5\n",
|
||||||
|
"[read_HNSW NL v4] Reading Compact Storage format indices...\n",
|
||||||
|
"[read_HNSW NL v4] Read compact_level_ptr, size: 10\n",
|
||||||
|
"[read_HNSW NL v4] Read compact_node_offsets, size: 6\n",
|
||||||
|
"[read_HNSW NL v4] Read entry_point: 4, max_level: 0\n",
|
||||||
|
"[read_HNSW NL v4] Read storage fourcc: 0x6c6c756e\n",
|
||||||
|
"[read_HNSW NL v4 FIX] Detected FileIOReader. Neighbors size field offset: 326\n",
|
||||||
|
"[read_HNSW NL v4] Reading neighbors data into memory.\n",
|
||||||
|
"[read_HNSW NL v4] Read neighbors data, size: 20\n",
|
||||||
|
"[read_HNSW NL v4] Finished reading metadata and CSR indices.\n",
|
||||||
|
"INFO: Skipping external storage loading, since is_recompute is true.\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"INFO:leann.api:🔍 LeannSearcher.search() called:\n",
|
||||||
|
"INFO:leann.api: Query: 'Compare the two retrieved programming languages and tell me their advantages.'\n",
|
||||||
|
"INFO:leann.api: Top_k: 2\n",
|
||||||
|
"INFO:leann.api: Additional kwargs: {}\n",
|
||||||
|
"INFO:leann.embedding_server_manager:Port 5557 has incompatible server, trying next port...\n",
|
||||||
|
"INFO:leann.embedding_server_manager:Port 5558 has incompatible server, trying next port...\n",
|
||||||
|
"INFO:leann.embedding_server_manager:Port 5559 has incompatible server, trying next port...\n",
|
||||||
|
"INFO:leann.embedding_server_manager:Found compatible server on port 5560\n",
|
||||||
|
"INFO:leann.embedding_server_manager:Using existing compatible server on port 5560\n",
|
||||||
|
"INFO:leann.api: Launching server time: 0.04932403564453125 seconds\n",
|
||||||
|
"INFO:leann.embedding_server_manager:Found compatible server on port 5560\n",
|
||||||
|
"INFO:leann.embedding_server_manager:Using existing compatible server on port 5560\n",
|
||||||
|
"INFO:leann.api: Generated embedding shape: (1, 768)\n",
|
||||||
|
"INFO:leann.api: Embedding time: 0.06902289390563965 seconds\n",
|
||||||
|
"INFO:leann.api: Search time: 0.026793241500854492 seconds\n",
|
||||||
|
"INFO:leann.api: Backend returned: labels=2 results\n",
|
||||||
|
"INFO:leann.api: Processing 2 passage IDs:\n",
|
||||||
|
"INFO:leann.api: 1. passage_id='0' -> SUCCESS: C# is a powerful programming language and it is good at game development...\n",
|
||||||
|
"INFO:leann.api: 2. passage_id='1' -> SUCCESS: Python is a powerful programming language and it is good at machine learning tasks...\n",
|
||||||
|
"INFO:leann.api: Final enriched results: 2 passages\n",
|
||||||
|
"INFO:leann.chat:Generating with HuggingFace model, config: {'max_new_tokens': 128, 'temperature': 0.7, 'top_p': 0.9, 'do_sample': True, 'pad_token_id': 151645, 'eos_token_id': 151645}\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"ZmqDistanceComputer initialized: d=768, metric=0\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"\"<think>\\n\\n</think>\\n\\nBased on the context provided, here's a comparison of the two retrieved programming languages:\\n\\n**C#** is known for being a powerful programming language and is well-suited for game development. It is often used in game development and is popular among developers working on Windows applications.\\n\\n**Python**, on the other hand, is also a powerful language and is well-suited for machine learning tasks. It is widely used for data analysis, scientific computing, and other applications that require handling large datasets or performing complex calculations.\\n\\n**Advantages**:\\n- C#: Strong for game development and cross-platform compatibility.\\n- Python: Strong for\""
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 8,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"from leann.api import LeannChat\n",
|
||||||
|
"\n",
|
||||||
|
"llm_config = {\n",
|
||||||
|
" \"type\": \"hf\",\n",
|
||||||
|
" \"model\": \"Qwen/Qwen3-0.6B\",\n",
|
||||||
|
"}\n",
|
||||||
"\n",
|
"\n",
|
||||||
"chat = LeannChat(index_path=\"knowledge.leann\", llm_config=llm_config)\n",
|
"chat = LeannChat(index_path=\"knowledge.leann\", llm_config=llm_config)\n",
|
||||||
"\n",
|
|
||||||
"response = chat.ask(\n",
|
"response = chat.ask(\n",
|
||||||
" \"Compare the two retrieved programming languages and say which one is more popular today. Respond in a single well-formed sentence.\",\n",
|
" \"Compare the two retrieved programming languages and tell me their advantages.\",\n",
|
||||||
" top_k=2,\n",
|
" top_k=2,\n",
|
||||||
" recompute_beighbor_embeddings=True,\n",
|
" llm_kwargs={\"max_tokens\": 128}\n",
|
||||||
")\n",
|
")\n",
|
||||||
"print(response)"
|
"response"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
|||||||
100
docs/RELEASE.md
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
# Release Guide
|
||||||
|
|
||||||
|
## 📋 Prerequisites
|
||||||
|
|
||||||
|
Before releasing, ensure:
|
||||||
|
1. ✅ All code changes are committed and pushed
|
||||||
|
2. ✅ CI has passed on the latest commit (check [Actions](https://github.com/yichuan-w/LEANN/actions/workflows/ci.yml), or use the command-line check sketched below)
|
||||||
|
3. ✅ You have determined the new version number
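
If you prefer to verify CI from the terminal, a minimal check with the GitHub CLI looks like the sketch below (it assumes `gh` is installed and authenticated against this repository):

```bash
# Show the most recent CI run on main; a healthy run reports "completed" and "success"
gh run list --workflow=ci.yml --branch main --limit 1
```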
|
||||||
|
|
||||||
|
### Required: PyPI Configuration
|
||||||
|
|
||||||
|
To enable PyPI publishing:
|
||||||
|
1. Get a PyPI API token from https://pypi.org/manage/account/token/
|
||||||
|
2. Add it to repository secrets: Settings → Secrets → Actions → New repository secret
|
||||||
|
- Name: `PYPI_API_TOKEN`
|
||||||
|
- Value: Your PyPI token (starts with `pypi-`)
|
||||||
|
|
||||||
|
### Optional: TestPyPI Configuration
|
||||||
|
|
||||||
|
To enable TestPyPI testing (recommended but not required):
|
||||||
|
1. Get a TestPyPI API token from https://test.pypi.org/manage/account/token/
|
||||||
|
2. Add it to repository secrets: Settings → Secrets → Actions → New repository secret
|
||||||
|
- Name: `TEST_PYPI_API_TOKEN`
|
||||||
|
- Value: Your TestPyPI token (starts with `pypi-`)
|
||||||
|
|
||||||
|
**Note**: TestPyPI testing is optional. If not configured, the release will skip TestPyPI and proceed.
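
If you do exercise the TestPyPI path, a quick way to confirm the upload is to install from TestPyPI in a throwaway environment. The sketch below assumes the published package is named `leann`; dependencies are pulled from the real PyPI because TestPyPI usually does not host them:

```bash
pip install --index-url https://test.pypi.org/simple/ \
    --extra-index-url https://pypi.org/simple/ \
    leann
python -c "import leann; print('TestPyPI install OK')"
```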
|
||||||
|
|
||||||
|
## 🚀 Recommended: Manual Release Workflow
|
||||||
|
|
||||||
|
### Via GitHub UI (Most Reliable)
|
||||||
|
|
||||||
|
1. **Verify CI Status**: Check that the latest commit has a green checkmark ✅
|
||||||
|
2. Go to [Actions → Manual Release](https://github.com/yichuan-w/LEANN/actions/workflows/release-manual.yml)
|
||||||
|
3. Click "Run workflow"
|
||||||
|
4. Enter version (e.g., `0.1.1`)
|
||||||
|
5. Toggle "Test on TestPyPI first" if desired
|
||||||
|
6. Click "Run workflow"
|
||||||
|
|
||||||
|
**What happens:**
|
||||||
|
- ✅ Downloads pre-built packages from CI (no rebuild needed!)
|
||||||
|
- ✅ Updates all package versions
|
||||||
|
- ✅ Optionally tests on TestPyPI
|
||||||
|
- ✅ **Publishes directly to PyPI**
|
||||||
|
- ✅ Creates tag and GitHub release
|
||||||
|
|
||||||
|
### Via Command Line
|
||||||
|
|
||||||
|
```bash
|
||||||
|
gh workflow run release-manual.yml -f version=0.1.1 -f test_pypi=true
|
||||||
|
```
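
To follow the release run from the same terminal, the GitHub CLI can stream its progress; if no run ID is given, `gh` prompts you to pick one:

```bash
gh run watch
```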
|
||||||
|
|
||||||
|
## ⚡ Quick Release (One-Line)
|
||||||
|
|
||||||
|
For experienced users who want the fastest path:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./scripts/release.sh 0.1.1
|
||||||
|
```
|
||||||
|
|
||||||
|
This script will:
|
||||||
|
1. Update all package versions
|
||||||
|
2. Commit and push changes
|
||||||
|
3. Create GitHub release
|
||||||
|
4. **Manual Release workflow will automatically publish to PyPI**
|
||||||
|
|
||||||
|
⚠️ **Note**: If CI fails, you'll need to manually fix and re-tag
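
Once the workflow has published, a quick post-release sanity check might look like the following; the package name `leann` and the version are placeholders, so substitute whatever you actually released:

```bash
python -m venv verify_env && source verify_env/bin/activate
pip install leann==0.1.1
python -c "import leann; print('leann 0.1.1 imports cleanly')"
```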
|
||||||
|
|
||||||
|
## Manual Testing Before Release
|
||||||
|
|
||||||
|
For testing specific packages locally (especially DiskANN on macOS):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Build specific package locally
|
||||||
|
./scripts/build_and_test.sh diskann # or hnsw, core, meta, all
|
||||||
|
|
||||||
|
# Test installation in a clean environment
|
||||||
|
python -m venv test_env
|
||||||
|
source test_env/bin/activate
|
||||||
|
pip install packages/*/dist/*.whl
|
||||||
|
|
||||||
|
# Upload to Test PyPI (optional)
|
||||||
|
./scripts/upload_to_pypi.sh test
|
||||||
|
|
||||||
|
# Upload to Production PyPI (use with caution)
|
||||||
|
./scripts/upload_to_pypi.sh prod
|
||||||
|
```
|
||||||
|
|
||||||
|
## First-time setup
|
||||||
|
|
||||||
|
1. Install GitHub CLI:
|
||||||
|
```bash
|
||||||
|
brew install gh
|
||||||
|
gh auth login
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Set PyPI token in GitHub:
|
||||||
|
```bash
|
||||||
|
gh secret set PYPI_API_TOKEN
|
||||||
|
# Paste your PyPI token when prompted
|
||||||
|
```
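
3. (Optional) Set the TestPyPI token the same way if you want the TestPyPI step available; this mirrors step 2 and assumes you created a token on test.pypi.org:
```bash
gh secret set TEST_PYPI_API_TOKEN
# Paste your TestPyPI token when prompted
```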
|
||||||
@@ -96,14 +96,12 @@ class EmlxReader(BaseReader):
|
|||||||
|
|
||||||
 # Create document content with metadata embedded in text
 doc_content = f"""
-[EMAIL METADATA]
-File: {filename}
-From: {from_addr}
-To: {to_addr}
-Subject: {subject}
-Date: {date}
-[END METADATA]
+[File]: {filename}
+[From]: {from_addr}
+[To]: {to_addr}
+[Subject]: {subject}
+[Date]: {date}
+[EMAIL BODY Start]:

 {body}
 """
|
||||||
|
|
||||||
|
|||||||
@@ -65,12 +65,14 @@ def create_leann_index_from_multiple_chrome_profiles(profile_dirs: List[Path], i
|
|||||||
|
|
||||||
if not all_documents:
|
if not all_documents:
|
||||||
print("No documents loaded from any source. Exiting.")
|
print("No documents loaded from any source. Exiting.")
|
||||||
|
# Highlight the instruction that all Chrome browsers must be closed before running this script
|
||||||
|
print("\033[91mYou need to close or quit all chrome browser before running this script\033[0m")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
print(f"\nTotal loaded {len(all_documents)} history documents from {len(profile_dirs)} profiles")
|
print(f"\nTotal loaded {len(all_documents)} history documents from {len(profile_dirs)} profiles")
|
||||||
|
|
||||||
# Create text splitter with 256 chunk size
|
# Create text splitter with 256 chunk size
|
||||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)
|
||||||
|
|
||||||
# Convert Documents to text strings and chunk them
|
# Convert Documents to text strings and chunk them
|
||||||
all_texts = []
|
all_texts = []
|
||||||
@@ -78,7 +80,9 @@ def create_leann_index_from_multiple_chrome_profiles(profile_dirs: List[Path], i
|
|||||||
# Split the document into chunks
|
# Split the document into chunks
|
||||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
all_texts.append(node.get_content())
|
text = node.get_content()
|
||||||
|
# text = '[Title] ' + doc.metadata["title"] + '\n' + text
|
||||||
|
all_texts.append(text)
|
||||||
|
|
||||||
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
|
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
|
||||||
|
|
||||||
@@ -225,7 +229,7 @@ async def main():
|
|||||||
parser = argparse.ArgumentParser(description='LEANN Chrome History Reader - Create and query browser history index')
|
parser = argparse.ArgumentParser(description='LEANN Chrome History Reader - Create and query browser history index')
|
||||||
parser.add_argument('--chrome-profile', type=str, default=DEFAULT_CHROME_PROFILE,
|
parser.add_argument('--chrome-profile', type=str, default=DEFAULT_CHROME_PROFILE,
|
||||||
help=f'Path to Chrome profile directory (default: {DEFAULT_CHROME_PROFILE}), usually you dont need to change this')
|
help=f'Path to Chrome profile directory (default: {DEFAULT_CHROME_PROFILE}), usually you dont need to change this')
|
||||||
parser.add_argument('--index-dir', type=str, default="./chrome_history_index_leann_test",
|
parser.add_argument('--index-dir', type=str, default="./all_google_new",
|
||||||
help='Directory to store the LEANN index (default: ./chrome_history_index_leann_test)')
|
help='Directory to store the LEANN index (default: ./chrome_history_index_leann_test)')
|
||||||
parser.add_argument('--max-entries', type=int, default=1000,
|
parser.add_argument('--max-entries', type=int, default=1000,
|
||||||
help='Maximum number of history entries to process (default: 1000)')
|
help='Maximum number of history entries to process (default: 1000)')
|
||||||
|
|||||||
@@ -74,22 +74,17 @@ class ChromeHistoryReader(BaseReader):
|
|||||||
|
|
||||||
 # Create document content with metadata embedded in text
 doc_content = f"""
-[BROWSING HISTORY METADATA]
-URL: {url}
-Title: {title}
-Last Visit: {last_visit}
-Visit Count: {visit_count}
-Typed Count: {typed_count}
-Hidden: {hidden}
-[END METADATA]
-
-Title: {title}
-URL: {url}
-Last visited: {last_visit}
+[Title]: {title}
+[URL of the page]: {url}
+[Last visited time]: {last_visit}
+[Visit times]: {visit_count}
+[Typed times]: {typed_count}
 """

 # Create document with embedded metadata
-doc = Document(text=doc_content, metadata={})
+doc = Document(text=doc_content, metadata={ "title": title[0:150]})
|
||||||
|
# if len(title) > 150:
|
||||||
|
# print(f"Title is too long: {title}")
|
||||||
docs.append(doc)
|
docs.append(doc)
|
||||||
count += 1
|
count += 1
|
||||||
|
|
||||||
|
|||||||
@@ -335,14 +335,15 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
if create_time:
|
if create_time:
|
||||||
try:
|
try:
|
||||||
timestamp = datetime.fromtimestamp(create_time)
|
timestamp = datetime.fromtimestamp(create_time)
|
||||||
time_str = timestamp.strftime('%H:%M:%S')
|
# change to YYYY-MM-DD HH:MM:SS
|
||||||
|
time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
||||||
except:
|
except:
|
||||||
time_str = str(create_time)
|
time_str = str(create_time)
|
||||||
else:
|
else:
|
||||||
time_str = "Unknown"
|
time_str = "Unknown"
|
||||||
|
|
||||||
sender = "Me" if is_sent_from_self else "Contact"
|
sender = "[Me]" if is_sent_from_self else "[Contact]"
|
||||||
message_parts.append(f"[{time_str}] {sender}: {readable_text}")
|
message_parts.append(f"({time_str}) {sender}: {readable_text}")
|
||||||
|
|
||||||
concatenated_text = "\n".join(message_parts)
|
concatenated_text = "\n".join(message_parts)
|
||||||
|
|
||||||
@@ -354,13 +355,11 @@ Messages ({len(messages)} messages, {message_group['total_length']} chars):
|
|||||||
|
|
||||||
{concatenated_text}
|
{concatenated_text}
|
||||||
"""
|
"""
|
||||||
|
# TODO @yichuan give better format and rich info here!
|
||||||
doc_content = f"""
|
doc_content = f"""
|
||||||
Contact: {contact_name}
|
|
||||||
|
|
||||||
{concatenated_text}
|
{concatenated_text}
|
||||||
"""
|
"""
|
||||||
return doc_content
|
return doc_content, contact_name
|
||||||
|
|
||||||
def load_data(self, input_dir: str = None, **load_kwargs: Any) -> List[Document]:
|
def load_data(self, input_dir: str = None, **load_kwargs: Any) -> List[Document]:
|
||||||
"""
|
"""
|
||||||
@@ -441,8 +440,8 @@ Contact: {contact_name}
|
|||||||
if count >= max_count and max_count > 0:
|
if count >= max_count and max_count > 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
doc_content = self._create_concatenated_content(message_group, contact_name)
|
doc_content, contact_name = self._create_concatenated_content(message_group, contact_name)
|
||||||
doc = Document(text=doc_content, metadata={})
|
doc = Document(text=doc_content, metadata={"contact_name": contact_name})
|
||||||
docs.append(doc)
|
docs.append(doc)
|
||||||
count += 1
|
count += 1
|
||||||
|
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ def get_mail_path():
|
|||||||
return os.path.join(home_dir, "Library", "Mail")
|
return os.path.join(home_dir, "Library", "Mail")
|
||||||
|
|
||||||
# Default mail path for macOS
|
# Default mail path for macOS
|
||||||
# DEFAULT_MAIL_PATH = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data"
|
DEFAULT_MAIL_PATH = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data"
|
||||||
|
|
||||||
def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_path: str = "mail_index.leann", max_count: int = -1, include_html: bool = False, embedding_model: str = "facebook/contriever"):
|
def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_path: str = "mail_index.leann", max_count: int = -1, include_html: bool = False, embedding_model: str = "facebook/contriever"):
|
||||||
"""
|
"""
|
||||||
@@ -74,7 +74,7 @@ def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_pa
|
|||||||
print("No documents loaded from any source. Exiting.")
|
print("No documents loaded from any source. Exiting.")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
print(f"\nTotal loaded {len(all_documents)} email documents from {len(messages_dirs)} directories")
|
print(f"\nTotal loaded {len(all_documents)} email documents from {len(messages_dirs)} directories and starting to split them into chunks")
|
||||||
|
|
||||||
# Create text splitter with 256 chunk size
|
# Create text splitter with 256 chunk size
|
||||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
||||||
@@ -85,9 +85,11 @@ def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_pa
|
|||||||
# Split the document into chunks
|
# Split the document into chunks
|
||||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
all_texts.append(node.get_content())
|
text = node.get_content()
|
||||||
|
# text = '[subject] ' + doc.metadata["subject"] + '\n' + text
|
||||||
|
all_texts.append(text)
|
||||||
|
|
||||||
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
|
print(f"Finished splitting {len(all_documents)} documents into {len(all_texts)} text chunks")
|
||||||
|
|
||||||
# Create LEANN index directory
|
# Create LEANN index directory
|
||||||
|
|
||||||
@@ -156,7 +158,7 @@ def create_leann_index(mail_path: str, index_path: str = "mail_index.leann", max
|
|||||||
print(f"Loaded {len(documents)} email documents")
|
print(f"Loaded {len(documents)} email documents")
|
||||||
|
|
||||||
# Create text splitter with 256 chunk size
|
# Create text splitter with 256 chunk size
|
||||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)
|
||||||
|
|
||||||
# Convert Documents to text strings and chunk them
|
# Convert Documents to text strings and chunk them
|
||||||
all_texts = []
|
all_texts = []
|
||||||
@@ -216,11 +218,10 @@ async def query_leann_index(index_path: str, query: str):
|
|||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
chat_response = chat.ask(
|
chat_response = chat.ask(
|
||||||
query,
|
query,
|
||||||
top_k=10,
|
top_k=20,
|
||||||
recompute_beighbor_embeddings=True,
|
recompute_beighbor_embeddings=True,
|
||||||
complexity=12,
|
complexity=32,
|
||||||
beam_width=1,
|
beam_width=1,
|
||||||
|
|
||||||
)
|
)
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
print(f"Time taken: {end_time - start_time} seconds")
|
print(f"Time taken: {end_time - start_time} seconds")
|
||||||
@@ -231,7 +232,7 @@ async def main():
|
|||||||
parser = argparse.ArgumentParser(description='LEANN Mail Reader - Create and query email index')
|
parser = argparse.ArgumentParser(description='LEANN Mail Reader - Create and query email index')
|
||||||
# Remove --mail-path argument and auto-detect all Messages directories
|
# Remove --mail-path argument and auto-detect all Messages directories
|
||||||
# Remove DEFAULT_MAIL_PATH
|
# Remove DEFAULT_MAIL_PATH
|
||||||
parser.add_argument('--index-dir', type=str, default="./mail_index_leann_raw_text_all_dicts",
|
parser.add_argument('--index-dir', type=str, default="./mail_index_index_file",
|
||||||
help='Directory to store the LEANN index (default: ./mail_index_leann_raw_text_all_dicts)')
|
help='Directory to store the LEANN index (default: ./mail_index_leann_raw_text_all_dicts)')
|
||||||
parser.add_argument('--max-emails', type=int, default=1000,
|
parser.add_argument('--max-emails', type=int, default=1000,
|
||||||
help='Maximum number of emails to process (-1 means all)')
|
help='Maximum number of emails to process (-1 means all)')
|
||||||
@@ -251,6 +252,9 @@ async def main():
|
|||||||
mail_path = get_mail_path()
|
mail_path = get_mail_path()
|
||||||
print(f"Searching for email data in: {mail_path}")
|
print(f"Searching for email data in: {mail_path}")
|
||||||
messages_dirs = find_all_messages_directories(mail_path)
|
messages_dirs = find_all_messages_directories(mail_path)
|
||||||
|
# messages_dirs = find_all_messages_directories(DEFAULT_MAIL_PATH)
|
||||||
|
# messages_dirs = [DEFAULT_MAIL_PATH]
|
||||||
|
# messages_dirs = messages_dirs[:1]
|
||||||
|
|
||||||
print('len(messages_dirs): ', len(messages_dirs))
|
print('len(messages_dirs): ', len(messages_dirs))
|
||||||
|
|
||||||
|
|||||||
@@ -1,40 +1,40 @@
|
|||||||
import argparse
|
import argparse
|
||||||
from llama_index.core import SimpleDirectoryReader, Settings
|
from llama_index.core import SimpleDirectoryReader
|
||||||
from llama_index.core.node_parser import SentenceSplitter
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
import asyncio
|
import asyncio
|
||||||
import dotenv
|
import dotenv
|
||||||
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
from leann.api import LeannBuilder, LeannChat
|
||||||
import shutil
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
dotenv.load_dotenv()
|
dotenv.load_dotenv()
|
||||||
|
|
||||||
node_parser = SentenceSplitter(
|
|
||||||
chunk_size=256, chunk_overlap=128, separator=" ", paragraph_separator="\n\n"
|
|
||||||
)
|
|
||||||
print("Loading documents...")
|
|
||||||
documents = SimpleDirectoryReader(
|
|
||||||
"examples/data",
|
|
||||||
recursive=True,
|
|
||||||
encoding="utf-8",
|
|
||||||
required_exts=[".pdf", ".txt", ".md"],
|
|
||||||
).load_data(show_progress=True)
|
|
||||||
print("Documents loaded.")
|
|
||||||
all_texts = []
|
|
||||||
for doc in documents:
|
|
||||||
nodes = node_parser.get_nodes_from_documents([doc])
|
|
||||||
for node in nodes:
|
|
||||||
all_texts.append(node.get_content())
|
|
||||||
|
|
||||||
|
|
||||||
async def main(args):
|
async def main(args):
|
||||||
INDEX_DIR = Path(args.index_dir)
|
INDEX_DIR = Path(args.index_dir)
|
||||||
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
|
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
|
||||||
|
|
||||||
if not INDEX_DIR.exists():
|
if not INDEX_DIR.exists():
|
||||||
print(f"--- Index directory not found, building new index ---")
|
node_parser = SentenceSplitter(
|
||||||
|
chunk_size=256, chunk_overlap=128, separator=" ", paragraph_separator="\n\n"
|
||||||
|
)
|
||||||
|
|
||||||
print(f"\n[PHASE 1] Building Leann index...")
|
print("Loading documents...")
|
||||||
|
documents = SimpleDirectoryReader(
|
||||||
|
args.data_dir,
|
||||||
|
recursive=True,
|
||||||
|
encoding="utf-8",
|
||||||
|
required_exts=[".pdf", ".txt", ".md"],
|
||||||
|
).load_data(show_progress=True)
|
||||||
|
print("Documents loaded.")
|
||||||
|
all_texts = []
|
||||||
|
for doc in documents:
|
||||||
|
nodes = node_parser.get_nodes_from_documents([doc])
|
||||||
|
for node in nodes:
|
||||||
|
all_texts.append(node.get_content())
|
||||||
|
|
||||||
|
print("--- Index directory not found, building new index ---")
|
||||||
|
|
||||||
|
print("\n[PHASE 1] Building Leann index...")
|
||||||
|
|
||||||
# Use HNSW backend for better macOS compatibility
|
# Use HNSW backend for better macOS compatibility
|
||||||
builder = LeannBuilder(
|
builder = LeannBuilder(
|
||||||
@@ -58,8 +58,9 @@ async def main(args):
|
|||||||
|
|
||||||
print(f"\n[PHASE 2] Starting Leann chat session...")
|
print(f"\n[PHASE 2] Starting Leann chat session...")
|
||||||
|
|
||||||
# llm_config = {"type": "hf", "model": "Qwen/Qwen3-4B"}
|
llm_config = {"type": "hf", "model": "Qwen/Qwen3-4B"}
|
||||||
llm_config = {"type": "ollama", "model": "qwen3:8b"}
|
llm_config = {"type": "ollama", "model": "qwen3:8b"}
|
||||||
|
llm_config = {"type": "openai", "model": "gpt-4o"}
|
||||||
|
|
||||||
chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)
|
chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)
|
||||||
|
|
||||||
@@ -70,9 +71,7 @@ async def main(args):
|
|||||||
# )
|
# )
|
||||||
|
|
||||||
print(f"You: {query}")
|
print(f"You: {query}")
|
||||||
chat_response = chat.ask(
|
chat_response = chat.ask(query, top_k=20, recompute_embeddings=True, complexity=32)
|
||||||
query, top_k=20, recompute_beighbor_embeddings=True, complexity=32
|
|
||||||
)
|
|
||||||
print(f"Leann: {chat_response}")
|
print(f"Leann: {chat_response}")
|
||||||
|
|
||||||
|
|
||||||
@@ -105,6 +104,12 @@ if __name__ == "__main__":
|
|||||||
default="./test_doc_files",
|
default="./test_doc_files",
|
||||||
help="Directory where the Leann index will be stored.",
|
help="Directory where the Leann index will be stored.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--data-dir",
|
||||||
|
type=str,
|
||||||
|
default="examples/data",
|
||||||
|
help="Directory containing documents to index (PDF, TXT, MD files).",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
asyncio.run(main(args))
|
asyncio.run(main(args))
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ def create_leann_index_from_multiple_wechat_exports(
|
|||||||
documents = reader.load_data(
|
documents = reader.load_data(
|
||||||
wechat_export_dir=str(export_dir),
|
wechat_export_dir=str(export_dir),
|
||||||
max_count=max_count,
|
max_count=max_count,
|
||||||
concatenate_messages=False, # Disable concatenation - one message per document
|
concatenate_messages=True, # Disable concatenation - one message per document
|
||||||
)
|
)
|
||||||
if documents:
|
if documents:
|
||||||
print(f"Loaded {len(documents)} chat documents from {export_dir}")
|
print(f"Loaded {len(documents)} chat documents from {export_dir}")
|
||||||
@@ -74,11 +74,11 @@ def create_leann_index_from_multiple_wechat_exports(
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
print(
|
print(
|
||||||
f"\nTotal loaded {len(all_documents)} chat documents from {len(export_dirs)} exports"
|
f"\nTotal loaded {len(all_documents)} chat documents from {len(export_dirs)} exports and starting to split them into chunks"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create text splitter with 256 chunk size
|
# Create text splitter with 256 chunk size
|
||||||
text_splitter = SentenceSplitter(chunk_size=128, chunk_overlap=64)
|
text_splitter = SentenceSplitter(chunk_size=192, chunk_overlap=64)
|
||||||
|
|
||||||
# Convert Documents to text strings and chunk them
|
# Convert Documents to text strings and chunk them
|
||||||
all_texts = []
|
all_texts = []
|
||||||
@@ -86,10 +86,11 @@ def create_leann_index_from_multiple_wechat_exports(
|
|||||||
# Split the document into chunks
|
# Split the document into chunks
|
||||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
all_texts.append(node.get_content())
|
text = '[Contact] means the message is from: ' + doc.metadata["contact_name"] + '\n' + node.get_content()
|
||||||
|
all_texts.append(text)
|
||||||
|
|
||||||
print(
|
print(
|
||||||
f"Created {len(all_texts)} text chunks from {len(all_documents)} documents"
|
f"Finished splitting {len(all_documents)} documents into {len(all_texts)} text chunks"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create LEANN index directory
|
# Create LEANN index directory
|
||||||
@@ -224,7 +225,7 @@ async def query_leann_index(index_path: str, query: str):
|
|||||||
query,
|
query,
|
||||||
top_k=20,
|
top_k=20,
|
||||||
recompute_beighbor_embeddings=True,
|
recompute_beighbor_embeddings=True,
|
||||||
complexity=128,
|
complexity=16,
|
||||||
beam_width=1,
|
beam_width=1,
|
||||||
llm_config={
|
llm_config={
|
||||||
"type": "openai",
|
"type": "openai",
|
||||||
@@ -252,13 +253,13 @@ async def main():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--index-dir",
|
"--index-dir",
|
||||||
type=str,
|
type=str,
|
||||||
default="./wechat_history_june19_test",
|
default="./wechat_history_magic_test_11Debug_new",
|
||||||
help="Directory to store the LEANN index (default: ./wechat_history_index_leann_test)",
|
help="Directory to store the LEANN index (default: ./wechat_history_index_leann_test)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-entries",
|
"--max-entries",
|
||||||
type=int,
|
type=int,
|
||||||
default=5000,
|
default=50,
|
||||||
help="Maximum number of chat entries to process (default: 5000)",
|
help="Maximum number of chat entries to process (default: 5000)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
# packages/leann-backend-diskann/CMakeLists.txt (最终简化版)
|
# packages/leann-backend-diskann/CMakeLists.txt (simplified version)
|
||||||
|
|
||||||
cmake_minimum_required(VERSION 3.20)
|
cmake_minimum_required(VERSION 3.20)
|
||||||
project(leann_backend_diskann_wrapper)
|
project(leann_backend_diskann_wrapper)
|
||||||
|
|
||||||
# 告诉 CMake 直接进入 DiskANN 子模块并执行它自己的 CMakeLists.txt
|
# Tell CMake to directly enter the DiskANN submodule and execute its own CMakeLists.txt
|
||||||
# DiskANN 会自己处理所有事情,包括编译 Python 绑定
|
# DiskANN will handle everything itself, including compiling Python bindings
|
||||||
add_subdirectory(src/third_party/DiskANN)
|
add_subdirectory(src/third_party/DiskANN)
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import os
|
import os
|
||||||
import struct
|
import struct
|
||||||
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Any, List, Literal
|
from typing import Dict, Any, List, Literal, Optional
|
||||||
import contextlib
|
import contextlib
|
||||||
import pickle
|
|
||||||
|
import logging
|
||||||
|
|
||||||
from leann.searcher_base import BaseSearcher
|
from leann.searcher_base import BaseSearcher
|
||||||
from leann.registry import register_backend
|
from leann.registry import register_backend
|
||||||
@@ -14,6 +16,46 @@ from leann.interface import (
|
|||||||
LeannBackendSearcherInterface,
|
LeannBackendSearcherInterface,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def suppress_cpp_output_if_needed():
|
||||||
|
"""Suppress C++ stdout/stderr based on LEANN_LOG_LEVEL"""
|
||||||
|
log_level = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||||
|
|
||||||
|
# Only suppress if log level is WARNING or higher (ERROR, CRITICAL)
|
||||||
|
should_suppress = log_level in ["WARNING", "ERROR", "CRITICAL"]
|
||||||
|
|
||||||
|
if not should_suppress:
|
||||||
|
# Don't suppress, just yield
|
||||||
|
yield
|
||||||
|
return
|
||||||
|
|
||||||
|
# Save original file descriptors
|
||||||
|
stdout_fd = sys.stdout.fileno()
|
||||||
|
stderr_fd = sys.stderr.fileno()
|
||||||
|
|
||||||
|
# Save original stdout/stderr
|
||||||
|
stdout_dup = os.dup(stdout_fd)
|
||||||
|
stderr_dup = os.dup(stderr_fd)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Redirect to /dev/null
|
||||||
|
devnull = os.open(os.devnull, os.O_WRONLY)
|
||||||
|
os.dup2(devnull, stdout_fd)
|
||||||
|
os.dup2(devnull, stderr_fd)
|
||||||
|
os.close(devnull)
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Restore original file descriptors
|
||||||
|
os.dup2(stdout_dup, stdout_fd)
|
||||||
|
os.dup2(stderr_dup, stderr_fd)
|
||||||
|
os.close(stdout_dup)
|
||||||
|
os.close(stderr_dup)
|
||||||
|
|
||||||
|
|
||||||
def _get_diskann_metrics():
|
def _get_diskann_metrics():
|
||||||
from . import _diskannpy as diskannpy # type: ignore
|
from . import _diskannpy as diskannpy # type: ignore
|
||||||
@@ -65,22 +107,20 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
|||||||
index_dir.mkdir(parents=True, exist_ok=True)
|
index_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
if data.dtype != np.float32:
|
if data.dtype != np.float32:
|
||||||
|
logger.warning(f"Converting data to float32, shape: {data.shape}")
|
||||||
data = data.astype(np.float32)
|
data = data.astype(np.float32)
|
||||||
|
|
||||||
data_filename = f"{index_prefix}_data.bin"
|
data_filename = f"{index_prefix}_data.bin"
|
||||||
_write_vectors_to_bin(data, index_dir / data_filename)
|
_write_vectors_to_bin(data, index_dir / data_filename)
|
||||||
|
|
||||||
label_map = {i: str_id for i, str_id in enumerate(ids)}
|
|
||||||
label_map_file = index_dir / "leann.labels.map"
|
|
||||||
with open(label_map_file, "wb") as f:
|
|
||||||
pickle.dump(label_map, f)
|
|
||||||
|
|
||||||
build_kwargs = {**self.build_params, **kwargs}
|
build_kwargs = {**self.build_params, **kwargs}
|
||||||
metric_enum = _get_diskann_metrics().get(
|
metric_enum = _get_diskann_metrics().get(
|
||||||
build_kwargs.get("distance_metric", "mips").lower()
|
build_kwargs.get("distance_metric", "mips").lower()
|
||||||
)
|
)
|
||||||
if metric_enum is None:
|
if metric_enum is None:
|
||||||
raise ValueError("Unsupported distance_metric.")
|
raise ValueError(
|
||||||
|
f"Unsupported distance_metric '{build_kwargs.get('distance_metric', 'unknown')}'."
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from . import _diskannpy as diskannpy # type: ignore
|
from . import _diskannpy as diskannpy # type: ignore
|
||||||
@@ -102,36 +142,40 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
|||||||
temp_data_file = index_dir / data_filename
|
temp_data_file = index_dir / data_filename
|
||||||
if temp_data_file.exists():
|
if temp_data_file.exists():
|
||||||
os.remove(temp_data_file)
|
os.remove(temp_data_file)
|
||||||
|
logger.debug(f"Cleaned up temporary data file: {temp_data_file}")
|
||||||
|
|
||||||
|
|
||||||
class DiskannSearcher(BaseSearcher):
|
class DiskannSearcher(BaseSearcher):
|
||||||
def __init__(self, index_path: str, **kwargs):
|
def __init__(self, index_path: str, **kwargs):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
index_path,
|
index_path,
|
||||||
backend_module_name="leann_backend_diskann.embedding_server",
|
backend_module_name="leann_backend_diskann.diskann_embedding_server",
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
from . import _diskannpy as diskannpy # type: ignore
|
|
||||||
|
|
||||||
distance_metric = kwargs.get("distance_metric", "mips").lower()
|
# Initialize DiskANN index with suppressed C++ output based on log level
|
||||||
metric_enum = _get_diskann_metrics().get(distance_metric)
|
with suppress_cpp_output_if_needed():
|
||||||
if metric_enum is None:
|
from . import _diskannpy as diskannpy # type: ignore
|
||||||
raise ValueError(f"Unsupported distance_metric '{distance_metric}'.")
|
|
||||||
|
|
||||||
self.num_threads = kwargs.get("num_threads", 8)
|
distance_metric = kwargs.get("distance_metric", "mips").lower()
|
||||||
self.zmq_port = kwargs.get("zmq_port", 6666)
|
metric_enum = _get_diskann_metrics().get(distance_metric)
|
||||||
|
if metric_enum is None:
|
||||||
|
raise ValueError(f"Unsupported distance_metric '{distance_metric}'.")
|
||||||
|
|
||||||
full_index_prefix = str(self.index_dir / self.index_path.stem)
|
self.num_threads = kwargs.get("num_threads", 8)
|
||||||
self._index = diskannpy.StaticDiskFloatIndex(
|
|
||||||
metric_enum,
|
fake_zmq_port = 6666
|
||||||
full_index_prefix,
|
full_index_prefix = str(self.index_dir / self.index_path.stem)
|
||||||
self.num_threads,
|
self._index = diskannpy.StaticDiskFloatIndex(
|
||||||
kwargs.get("num_nodes_to_cache", 0),
|
metric_enum,
|
||||||
1,
|
full_index_prefix,
|
||||||
self.zmq_port,
|
self.num_threads,
|
||||||
"",
|
kwargs.get("num_nodes_to_cache", 0),
|
||||||
"",
|
1,
|
||||||
)
|
fake_zmq_port, # Initial port, can be updated at runtime
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
|
||||||
def search(
|
def search(
|
||||||
self,
|
self,
|
||||||
@@ -142,7 +186,7 @@ class DiskannSearcher(BaseSearcher):
|
|||||||
prune_ratio: float = 0.0,
|
prune_ratio: float = 0.0,
|
||||||
recompute_embeddings: bool = False,
|
recompute_embeddings: bool = False,
|
||||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||||
zmq_port: int = 5557,
|
zmq_port: Optional[int] = None,
|
||||||
batch_recompute: bool = False,
|
batch_recompute: bool = False,
|
||||||
dedup_node_dis: bool = False,
|
dedup_node_dis: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@@ -161,7 +205,7 @@ class DiskannSearcher(BaseSearcher):
|
|||||||
- "global": Use global pruning strategy (default)
|
- "global": Use global pruning strategy (default)
|
||||||
- "local": Use local pruning strategy
|
- "local": Use local pruning strategy
|
||||||
- "proportional": Not supported in DiskANN, falls back to global
|
- "proportional": Not supported in DiskANN, falls back to global
|
||||||
zmq_port: ZMQ port for embedding server
|
zmq_port: ZMQ port for embedding server communication. Must be provided if recompute_embeddings is True.
|
||||||
batch_recompute: Whether to batch neighbor recomputation (DiskANN-specific)
|
batch_recompute: Whether to batch neighbor recomputation (DiskANN-specific)
|
||||||
dedup_node_dis: Whether to cache and reuse distance computations (DiskANN-specific)
|
dedup_node_dis: Whether to cache and reuse distance computations (DiskANN-specific)
|
||||||
**kwargs: Additional DiskANN-specific parameters (for legacy compatibility)
|
**kwargs: Additional DiskANN-specific parameters (for legacy compatibility)
|
||||||
@@ -169,22 +213,25 @@ class DiskannSearcher(BaseSearcher):
|
|||||||
Returns:
|
Returns:
|
||||||
Dict with 'labels' (list of lists) and 'distances' (ndarray)
|
Dict with 'labels' (list of lists) and 'distances' (ndarray)
|
||||||
"""
|
"""
|
||||||
|
# Handle zmq_port compatibility: DiskANN can now update port at runtime
|
||||||
|
if recompute_embeddings:
|
||||||
|
if zmq_port is None:
|
||||||
|
raise ValueError(
|
||||||
|
"zmq_port must be provided if recompute_embeddings is True"
|
||||||
|
)
|
||||||
|
current_port = self._index.get_zmq_port()
|
||||||
|
if zmq_port != current_port:
|
||||||
|
logger.debug(
|
||||||
|
f"Updating DiskANN zmq_port from {current_port} to {zmq_port}"
|
||||||
|
)
|
||||||
|
self._index.set_zmq_port(zmq_port)
|
||||||
|
|
||||||
# DiskANN doesn't support "proportional" strategy
|
# DiskANN doesn't support "proportional" strategy
|
||||||
if pruning_strategy == "proportional":
|
if pruning_strategy == "proportional":
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"DiskANN backend does not support 'proportional' pruning strategy. Use 'global' or 'local' instead."
|
"DiskANN backend does not support 'proportional' pruning strategy. Use 'global' or 'local' instead."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Use recompute_embeddings parameter
|
|
||||||
use_recompute = recompute_embeddings
|
|
||||||
if use_recompute:
|
|
||||||
meta_file_path = self.index_dir / f"{self.index_path.name}.meta.json"
|
|
||||||
if not meta_file_path.exists():
|
|
||||||
raise RuntimeError(
|
|
||||||
f"FATAL: Recompute enabled but metadata file not found: {meta_file_path}"
|
|
||||||
)
|
|
||||||
self._ensure_server_running(str(meta_file_path), port=zmq_port, **kwargs)
|
|
||||||
|
|
||||||
if query.dtype != np.float32:
|
if query.dtype != np.float32:
|
||||||
query = query.astype(np.float32)
|
query = query.astype(np.float32)
|
||||||
|
|
||||||
@@ -194,28 +241,26 @@ class DiskannSearcher(BaseSearcher):
|
|||||||
else: # "global"
|
else: # "global"
|
||||||
use_global_pruning = True
|
use_global_pruning = True
|
||||||
|
|
||||||
labels, distances = self._index.batch_search(
|
# Perform search with suppressed C++ output based on log level
|
||||||
query,
|
with suppress_cpp_output_if_needed():
|
||||||
query.shape[0],
|
labels, distances = self._index.batch_search(
|
||||||
top_k,
|
query,
|
||||||
complexity,
|
query.shape[0],
|
||||||
beam_width,
|
top_k,
|
||||||
self.num_threads,
|
complexity,
|
||||||
kwargs.get("USE_DEFERRED_FETCH", False),
|
beam_width,
|
||||||
kwargs.get("skip_search_reorder", False),
|
self.num_threads,
|
||||||
use_recompute,
|
kwargs.get("USE_DEFERRED_FETCH", False),
|
||||||
dedup_node_dis,
|
kwargs.get("skip_search_reorder", False),
|
||||||
prune_ratio,
|
recompute_embeddings,
|
||||||
batch_recompute,
|
dedup_node_dis,
|
||||||
use_global_pruning,
|
prune_ratio,
|
||||||
)
|
batch_recompute,
|
||||||
|
use_global_pruning,
|
||||||
|
)
|
||||||
|
|
||||||
string_labels = [
|
string_labels = [
|
||||||
[
|
[str(int_label) for int_label in batch_labels] for batch_labels in labels
|
||||||
self.label_map.get(int_label, f"unknown_{int_label}")
|
|
||||||
for int_label in batch_labels
|
|
||||||
]
|
|
||||||
for batch_labels in labels
|
|
||||||
]
|
]
|
||||||
|
|
||||||
return {"labels": string_labels, "distances": distances}
|
return {"labels": string_labels, "distances": distances}
|
||||||
|
|||||||
@@ -0,0 +1,283 @@
|
|||||||
|
"""
|
||||||
|
DiskANN-specific embedding server
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import os
|
||||||
|
import zmq
|
||||||
|
import numpy as np
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
import sys
|
||||||
|
import logging
|
||||||
|
|
||||||
|
# Set up logging based on environment variable
|
||||||
|
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Force set logger level (don't rely on basicConfig in subprocess)
|
||||||
|
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
|
||||||
|
logger.setLevel(log_level)
|
||||||
|
|
||||||
|
# Ensure we have a handler if none exists
|
||||||
|
if not logger.handlers:
|
||||||
|
handler = logging.StreamHandler()
|
||||||
|
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
|
||||||
|
handler.setFormatter(formatter)
|
||||||
|
logger.addHandler(handler)
|
||||||
|
logger.propagate = False
|
||||||
|
|
||||||
|
|
||||||
|
def create_diskann_embedding_server(
|
||||||
|
passages_file: Optional[str] = None,
|
||||||
|
zmq_port: int = 5555,
|
||||||
|
model_name: str = "sentence-transformers/all-mpnet-base-v2",
|
||||||
|
embedding_mode: str = "sentence-transformers",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Create and start a ZMQ-based embedding server for DiskANN backend.
|
||||||
|
Uses ROUTER socket and protobuf communication as required by DiskANN C++ implementation.
|
||||||
|
"""
|
||||||
|
logger.info(f"Starting DiskANN server on port {zmq_port} with model {model_name}")
|
||||||
|
logger.info(f"Using embedding mode: {embedding_mode}")
|
||||||
|
|
||||||
|
# Add leann-core to path for unified embedding computation
|
||||||
|
current_dir = Path(__file__).parent
|
||||||
|
leann_core_path = current_dir.parent.parent / "leann-core" / "src"
|
||||||
|
sys.path.insert(0, str(leann_core_path))
|
||||||
|
|
||||||
|
try:
|
||||||
|
from leann.embedding_compute import compute_embeddings
|
||||||
|
from leann.api import PassageManager
|
||||||
|
|
||||||
|
logger.info("Successfully imported unified embedding computation module")
|
||||||
|
except ImportError as e:
|
||||||
|
logger.error(f"Failed to import embedding computation module: {e}")
|
||||||
|
return
|
||||||
|
finally:
|
||||||
|
sys.path.pop(0)
|
||||||
|
|
||||||
|
# Check port availability
|
||||||
|
import socket
|
||||||
|
|
||||||
|
def check_port(port):
|
||||||
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||||
|
return s.connect_ex(("localhost", port)) == 0
|
||||||
|
|
||||||
|
if check_port(zmq_port):
|
||||||
|
logger.error(f"Port {zmq_port} is already in use")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Only support metadata file, fail fast for everything else
|
||||||
|
if not passages_file or not passages_file.endswith(".meta.json"):
|
||||||
|
raise ValueError("Only metadata files (.meta.json) are supported")
|
||||||
|
|
||||||
|
# Load metadata to get passage sources
|
||||||
|
with open(passages_file, "r") as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
|
||||||
|
passages = PassageManager(meta["passage_sources"])
|
||||||
|
logger.info(
|
||||||
|
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Import protobuf after ensuring the path is correct
|
||||||
|
try:
|
||||||
|
from . import embedding_pb2
|
||||||
|
except ImportError as e:
|
||||||
|
logger.error(f"Failed to import protobuf module: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
def zmq_server_thread():
|
||||||
|
"""ZMQ server thread using REP socket for universal compatibility"""
|
||||||
|
context = zmq.Context()
|
||||||
|
socket = context.socket(
|
||||||
|
zmq.REP
|
||||||
|
) # REP socket for both BaseSearcher and DiskANN C++ REQ clients
|
||||||
|
socket.bind(f"tcp://*:{zmq_port}")
|
||||||
|
logger.info(f"DiskANN ZMQ REP server listening on port {zmq_port}")
|
||||||
|
|
||||||
|
socket.setsockopt(zmq.RCVTIMEO, 300000)
|
||||||
|
socket.setsockopt(zmq.SNDTIMEO, 300000)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
# REP socket receives single-part messages
|
||||||
|
message = socket.recv()
|
||||||
|
|
||||||
|
# Check for empty messages - REP socket requires response to every request
|
||||||
|
if len(message) == 0:
|
||||||
|
logger.debug("Received empty message, sending empty response")
|
||||||
|
socket.send(b"") # REP socket must respond to every request
|
||||||
|
continue
|
||||||
|
|
||||||
|
logger.debug(f"Received ZMQ request of size {len(message)} bytes")
|
||||||
|
logger.debug(f"Message preview: {message[:50]}") # Show first 50 bytes
|
||||||
|
|
||||||
|
e2e_start = time.time()
|
||||||
|
|
||||||
|
# Try protobuf first (for DiskANN C++ node_ids requests - primary use case)
|
||||||
|
texts = []
|
||||||
|
node_ids = []
|
||||||
|
is_text_request = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
req_proto = embedding_pb2.NodeEmbeddingRequest()
|
||||||
|
req_proto.ParseFromString(message)
|
||||||
|
node_ids = list(req_proto.node_ids)
|
||||||
|
|
||||||
|
if not node_ids:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"PROTOBUF: Received empty node_ids! Message size: {len(message)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"✅ PROTOBUF: Node ID request for {len(node_ids)} node embeddings: {node_ids[:10]}"
|
||||||
|
)
|
||||||
|
except Exception as protobuf_error:
|
||||||
|
logger.debug(f"Protobuf parsing failed: {protobuf_error}")
|
||||||
|
# Fallback to msgpack (for BaseSearcher direct text requests)
|
||||||
|
try:
|
||||||
|
import msgpack
|
||||||
|
|
||||||
|
request = msgpack.unpackb(message)
|
||||||
|
# For BaseSearcher compatibility, request is a list of texts directly
|
||||||
|
if isinstance(request, list) and all(
|
||||||
|
isinstance(item, str) for item in request
|
||||||
|
):
|
||||||
|
texts = request
|
||||||
|
is_text_request = True
|
||||||
|
logger.info(
|
||||||
|
f"✅ MSGPACK: Direct text request for {len(texts)} texts"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("Not a valid msgpack text request")
|
||||||
|
except Exception as msgpack_error:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Both protobuf and msgpack parsing failed! Protobuf: {protobuf_error}, Msgpack: {msgpack_error}"
|
||||||
|
)
|
||||||
|
|
||||||
|
                # Look up texts by node IDs (only if not direct text request)
                if not is_text_request:
                    for nid in node_ids:
                        try:
                            passage_data = passages.get_passage(str(nid))
                            txt = passage_data["text"]
                            if not txt:
                                raise RuntimeError(
                                    f"FATAL: Empty text for passage ID {nid}"
                                )
                            texts.append(txt)
                        except KeyError as e:
                            logger.error(f"Passage ID {nid} not found: {e}")
                            raise e
                        except Exception as e:
                            logger.error(f"Exception looking up passage ID {nid}: {e}")
                            raise

                # Debug logging
                logger.debug(f"Processing {len(texts)} texts")
                logger.debug(
                    f"Text lengths: {[len(t) for t in texts[:5]]}"
                )  # Show first 5

                # Process embeddings using unified computation
                embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
                logger.info(
                    f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
                )

                # Prepare response based on request type
                if is_text_request:
                    # For BaseSearcher compatibility: return msgpack format
                    import msgpack

                    response_data = msgpack.packb(embeddings.tolist())
                else:
                    # For DiskANN C++ compatibility: return protobuf format
                    resp_proto = embedding_pb2.NodeEmbeddingResponse()
                    hidden_contiguous = np.ascontiguousarray(
                        embeddings, dtype=np.float32
                    )

                    # Serialize embeddings data
                    resp_proto.embeddings_data = hidden_contiguous.tobytes()
                    resp_proto.dimensions.append(hidden_contiguous.shape[0])
                    resp_proto.dimensions.append(hidden_contiguous.shape[1])

                    response_data = resp_proto.SerializeToString()

                # Send response back to the client
                socket.send(response_data)

                e2e_end = time.time()
                logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")

            except zmq.Again:
                logger.debug("ZMQ socket timeout, continuing to listen")
                continue
            except Exception as e:
                logger.error(f"Error in ZMQ server loop: {e}")
                import traceback

                traceback.print_exc()
                raise

    zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
    zmq_thread.start()
    logger.info(f"Started DiskANN ZMQ server thread on port {zmq_port}")

    # Keep the main thread alive
    try:
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        logger.info("DiskANN Server shutting down...")
        return


if __name__ == "__main__":
    import signal
    import sys

    def signal_handler(sig, frame):
        logger.info(f"Received signal {sig}, shutting down gracefully...")
        sys.exit(0)

    # Register signal handlers for graceful shutdown
    signal.signal(signal.SIGTERM, signal_handler)
    signal.signal(signal.SIGINT, signal_handler)

    parser = argparse.ArgumentParser(description="DiskANN Embedding service")
    parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
    parser.add_argument(
        "--passages-file",
        type=str,
        help="Metadata JSON file containing passage sources",
    )
    parser.add_argument(
        "--model-name",
        type=str,
        default="sentence-transformers/all-mpnet-base-v2",
        help="Embedding model name",
    )
    parser.add_argument(
        "--embedding-mode",
        type=str,
        default="sentence-transformers",
        choices=["sentence-transformers", "openai", "mlx"],
        help="Embedding backend mode",
    )

    args = parser.parse_args()

    # Create and start the DiskANN embedding server
    create_diskann_embedding_server(
        passages_file=args.passages_file,
        zmq_port=args.zmq_port,
        model_name=args.model_name,
        embedding_mode=args.embedding_mode,
    )
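Usage note (not part of the diff): the REP server above accepts either a protobuf NodeEmbeddingRequest or a raw msgpack list of strings on the same port. The following is a minimal sketch of a client for the msgpack text path; the port number and the example texts are assumptions, not values taken from this changeset.

# Minimal sketch of a msgpack text-request client for the REP server above.
# Assumes the server was started with --zmq-port 5555 (the default).
import msgpack
import numpy as np
import zmq

context = zmq.Context()
socket = context.socket(zmq.REQ)  # REQ pairs with the server's REP socket
socket.connect("tcp://localhost:5555")

# A direct text request is just a msgpack-encoded list of strings
socket.send(msgpack.packb(["hello world", "second passage"]))

# The server replies with a msgpack-encoded list of embedding vectors
embeddings = np.array(msgpack.unpackb(socket.recv()), dtype=np.float32)
print(embeddings.shape)

socket.close()
context.term()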
@@ -1,741 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Embedding server for leann-backend-diskann - Fixed ZMQ REQ-REP pattern
|
|
||||||
"""
|
|
||||||
|
|
||||||
import pickle
|
|
||||||
import argparse
|
|
||||||
import time
|
|
||||||
import json
|
|
||||||
from typing import Dict, Any, Optional, Union
|
|
||||||
|
|
||||||
from transformers import AutoTokenizer, AutoModel
|
|
||||||
import os
|
|
||||||
from contextlib import contextmanager
|
|
||||||
import zmq
|
|
||||||
import numpy as np
|
|
||||||
import msgpack
|
|
||||||
from pathlib import Path
|
|
||||||
import logging
|
|
||||||
|
|
||||||
RED = "\033[91m"
|
|
||||||
|
|
||||||
# Set up logging based on environment variable
|
|
||||||
LOG_LEVEL = os.getenv('LEANN_LOG_LEVEL', 'INFO').upper()
|
|
||||||
logging.basicConfig(
|
|
||||||
level=getattr(logging, LOG_LEVEL, logging.INFO),
|
|
||||||
format='%(asctime)s - %(levelname)s - %(message)s'
|
|
||||||
)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
RESET = "\033[0m"
|
|
||||||
|
|
||||||
# --- New Passage Loader from HNSW backend ---
|
|
||||||
class SimplePassageLoader:
|
|
||||||
"""
|
|
||||||
Simple passage loader that replaces config.py dependencies
|
|
||||||
"""
|
|
||||||
def __init__(self, passages_data: Optional[Dict[str, Any]] = None):
|
|
||||||
self.passages_data = passages_data or {}
|
|
||||||
self._meta_path = ''
|
|
||||||
|
|
||||||
def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
|
|
||||||
"""Get passage by ID"""
|
|
||||||
str_id = str(passage_id)
|
|
||||||
if str_id in self.passages_data:
|
|
||||||
return {"text": self.passages_data[str_id]}
|
|
||||||
else:
|
|
||||||
# Return empty text for missing passages
|
|
||||||
return {"text": ""}
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
return len(self.passages_data)
|
|
||||||
|
|
||||||
def keys(self):
|
|
||||||
return self.passages_data.keys()
|
|
||||||
|
|
||||||
def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
|
|
||||||
"""
|
|
||||||
Load passages using metadata file with PassageManager for lazy loading
|
|
||||||
"""
|
|
||||||
# Load metadata to get passage sources
|
|
||||||
with open(meta_file, 'r') as f:
|
|
||||||
meta = json.load(f)
|
|
||||||
|
|
||||||
# Import PassageManager dynamically to avoid circular imports
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
# Find the leann package directory relative to this file
|
|
||||||
current_dir = Path(__file__).parent
|
|
||||||
leann_core_path = current_dir.parent.parent / "leann-core" / "src"
|
|
||||||
sys.path.insert(0, str(leann_core_path))
|
|
||||||
|
|
||||||
try:
|
|
||||||
from leann.api import PassageManager
|
|
||||||
passage_manager = PassageManager(meta['passage_sources'])
|
|
||||||
finally:
|
|
||||||
sys.path.pop(0)
|
|
||||||
|
|
||||||
# Load label map
|
|
||||||
passages_dir = Path(meta_file).parent
|
|
||||||
label_map_file = passages_dir / "leann.labels.map"
|
|
||||||
|
|
||||||
if label_map_file.exists():
|
|
||||||
import pickle
|
|
||||||
with open(label_map_file, 'rb') as f:
|
|
||||||
label_map = pickle.load(f)
|
|
||||||
print(f"Loaded label map with {len(label_map)} entries")
|
|
||||||
else:
|
|
||||||
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
|
|
||||||
|
|
||||||
print(f"Initialized lazy passage loading for {len(label_map)} passages")
|
|
||||||
|
|
||||||
class LazyPassageLoader(SimplePassageLoader):
|
|
||||||
def __init__(self, passage_manager, label_map):
|
|
||||||
self.passage_manager = passage_manager
|
|
||||||
self.label_map = label_map
|
|
||||||
# Initialize parent with empty data
|
|
||||||
super().__init__({})
|
|
||||||
|
|
||||||
def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
|
|
||||||
"""Get passage by ID with lazy loading"""
|
|
||||||
try:
|
|
||||||
int_id = int(passage_id)
|
|
||||||
if int_id in self.label_map:
|
|
||||||
string_id = self.label_map[int_id]
|
|
||||||
passage_data = self.passage_manager.get_passage(string_id)
|
|
||||||
if passage_data and passage_data.get("text"):
|
|
||||||
return {"text": passage_data["text"]}
|
|
||||||
else:
|
|
||||||
raise RuntimeError(f"FATAL: Empty text for ID {int_id} -> {string_id}")
|
|
||||||
else:
|
|
||||||
raise RuntimeError(f"FATAL: ID {int_id} not found in label_map")
|
|
||||||
except Exception as e:
|
|
||||||
raise RuntimeError(f"FATAL: Exception getting passage {passage_id}: {e}")
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
return len(self.label_map)
|
|
||||||
|
|
||||||
def keys(self):
|
|
||||||
return self.label_map.keys()
|
|
||||||
|
|
||||||
loader = LazyPassageLoader(passage_manager, label_map)
|
|
||||||
loader._meta_path = meta_file
|
|
||||||
return loader
|
|
||||||
|
|
||||||
def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
|
|
||||||
"""
|
|
||||||
Load passages from a JSONL file with label map support
|
|
||||||
Expected format: {"id": "passage_id", "text": "passage_text", "metadata": {...}} (one per line)
|
|
||||||
"""
|
|
||||||
|
|
||||||
if not os.path.exists(passages_file):
|
|
||||||
raise FileNotFoundError(f"Passages file {passages_file} not found.")
|
|
||||||
|
|
||||||
if not passages_file.endswith('.jsonl'):
|
|
||||||
raise ValueError(f"Expected .jsonl file format, got: {passages_file}")
|
|
||||||
|
|
||||||
# Load label map (int -> string_id)
|
|
||||||
passages_dir = Path(passages_file).parent
|
|
||||||
label_map_file = passages_dir / "leann.labels.map"
|
|
||||||
|
|
||||||
label_map = {}
|
|
||||||
if label_map_file.exists():
|
|
||||||
with open(label_map_file, 'rb') as f:
|
|
||||||
label_map = pickle.load(f)
|
|
||||||
print(f"Loaded label map with {len(label_map)} entries")
|
|
||||||
else:
|
|
||||||
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
|
|
||||||
|
|
||||||
# Load passages by string ID
|
|
||||||
string_id_passages = {}
|
|
||||||
with open(passages_file, 'r', encoding='utf-8') as f:
|
|
||||||
for line in f:
|
|
||||||
if line.strip():
|
|
||||||
passage = json.loads(line)
|
|
||||||
string_id_passages[passage['id']] = passage['text']
|
|
||||||
|
|
||||||
# Create int ID -> text mapping using label map
|
|
||||||
passages_data = {}
|
|
||||||
for int_id, string_id in label_map.items():
|
|
||||||
if string_id in string_id_passages:
|
|
||||||
passages_data[str(int_id)] = string_id_passages[string_id]
|
|
||||||
else:
|
|
||||||
print(f"WARNING: String ID {string_id} from label map not found in passages")
|
|
||||||
|
|
||||||
print(f"Loaded {len(passages_data)} passages from JSONL file {passages_file} using label map")
|
|
||||||
return SimplePassageLoader(passages_data)
|
|
||||||
|
|
||||||
def create_embedding_server_thread(
|
|
||||||
zmq_port=5555,
|
|
||||||
model_name="sentence-transformers/all-mpnet-base-v2",
|
|
||||||
max_batch_size=128,
|
|
||||||
passages_file: Optional[str] = None,
|
|
||||||
embedding_mode: str = "sentence-transformers",
|
|
||||||
enable_warmup: bool = False,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Create and run embedding server in the current thread
|
|
||||||
This function is designed to be called in a separate thread
|
|
||||||
"""
|
|
||||||
logger.info(f"Initializing embedding server thread on port {zmq_port}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Check if port is already occupied
|
|
||||||
import socket
|
|
||||||
def check_port(port):
|
|
||||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
|
||||||
return s.connect_ex(('localhost', port)) == 0
|
|
||||||
|
|
||||||
if check_port(zmq_port):
|
|
||||||
print(f"{RED}Port {zmq_port} is already in use{RESET}")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Auto-detect mode based on model name if not explicitly set
|
|
||||||
if embedding_mode == "sentence-transformers" and model_name.startswith("text-embedding-"):
|
|
||||||
embedding_mode = "openai"
|
|
||||||
|
|
||||||
if embedding_mode == "mlx":
|
|
||||||
from leann.api import compute_embeddings_mlx
|
|
||||||
import torch
|
|
||||||
logger.info("Using MLX for embeddings")
|
|
||||||
# Set device to CPU for compatibility with DeviceTimer class
|
|
||||||
device = torch.device("cpu")
|
|
||||||
cuda_available = False
|
|
||||||
mps_available = False
|
|
||||||
elif embedding_mode == "openai":
|
|
||||||
from leann.api import compute_embeddings_openai
|
|
||||||
import torch
|
|
||||||
logger.info("Using OpenAI API for embeddings")
|
|
||||||
# Set device to CPU for compatibility with DeviceTimer class
|
|
||||||
device = torch.device("cpu")
|
|
||||||
cuda_available = False
|
|
||||||
mps_available = False
|
|
||||||
elif embedding_mode == "sentence-transformers":
|
|
||||||
# Initialize model
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
|
|
||||||
import torch
|
|
||||||
|
|
||||||
# Select device
|
|
||||||
mps_available = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
|
|
||||||
cuda_available = torch.cuda.is_available()
|
|
||||||
|
|
||||||
if cuda_available:
|
|
||||||
device = torch.device("cuda")
|
|
||||||
logger.info("Using CUDA device")
|
|
||||||
elif mps_available:
|
|
||||||
device = torch.device("mps")
|
|
||||||
logger.info("Using MPS device (Apple Silicon)")
|
|
||||||
else:
|
|
||||||
device = torch.device("cpu")
|
|
||||||
logger.info("Using CPU device")
|
|
||||||
|
|
||||||
# Load model
|
|
||||||
logger.info(f"Loading model {model_name}")
|
|
||||||
model = AutoModel.from_pretrained(model_name).to(device).eval()
|
|
||||||
|
|
||||||
# Optimize model
|
|
||||||
if cuda_available or mps_available:
|
|
||||||
try:
|
|
||||||
model = model.half()
|
|
||||||
model = torch.compile(model)
|
|
||||||
logger.info(f"Using FP16 precision with model: {model_name}")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"WARNING: Model optimization failed: {e}")
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported embedding mode: {embedding_mode}. Supported modes: sentence-transformers, mlx, openai")
|
|
||||||
|
|
||||||
# Load passages from file if provided
|
|
||||||
if passages_file and os.path.exists(passages_file):
|
|
||||||
# Check if it's a metadata file or a single passages file
|
|
||||||
if passages_file.endswith('.meta.json'):
|
|
||||||
passages = load_passages_from_metadata(passages_file)
|
|
||||||
else:
|
|
||||||
# Try to find metadata file in same directory
|
|
||||||
passages_dir = Path(passages_file).parent
|
|
||||||
meta_files = list(passages_dir.glob("*.meta.json"))
|
|
||||||
if meta_files:
|
|
||||||
print(f"Found metadata file: {meta_files[0]}, using lazy loading")
|
|
||||||
passages = load_passages_from_metadata(str(meta_files[0]))
|
|
||||||
else:
|
|
||||||
# Fallback to original single file loading (will cause warnings)
|
|
||||||
print("WARNING: No metadata file found, using single file loading (may cause missing passage warnings)")
|
|
||||||
passages = load_passages_from_file(passages_file)
|
|
||||||
else:
|
|
||||||
print("WARNING: No passages file provided or file not found. Using an empty passage loader.")
|
|
||||||
passages = SimplePassageLoader()
|
|
||||||
|
|
||||||
logger.info(f"Loaded {len(passages)} passages.")
|
|
||||||
|
|
||||||
def client_warmup(zmq_port):
|
|
||||||
"""Perform client-side warmup for DiskANN server"""
|
|
||||||
time.sleep(2)
|
|
||||||
print(f"Performing client-side warmup with model {model_name}...")
|
|
||||||
|
|
||||||
# Get actual passage IDs from the loaded passages
|
|
||||||
sample_ids = []
|
|
||||||
if hasattr(passages, 'keys') and len(passages) > 0:
|
|
||||||
available_ids = list(passages.keys())
|
|
||||||
# Take up to 5 actual IDs, but at least 1
|
|
||||||
sample_ids = available_ids[:min(5, len(available_ids))]
|
|
||||||
print(f"Using actual passage IDs for warmup: {sample_ids}")
|
|
||||||
else:
|
|
||||||
print("No passages available for warmup, skipping warmup...")
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
context = zmq.Context()
|
|
||||||
socket = context.socket(zmq.REQ)
|
|
||||||
socket.connect(f"tcp://localhost:{zmq_port}")
|
|
||||||
socket.setsockopt(zmq.RCVTIMEO, 30000)
|
|
||||||
socket.setsockopt(zmq.SNDTIMEO, 30000)
|
|
||||||
|
|
||||||
try:
|
|
||||||
ids_to_send = [int(x) for x in sample_ids]
|
|
||||||
except ValueError:
|
|
||||||
print("Warning: Could not convert sample IDs to integers, skipping warmup")
|
|
||||||
return
|
|
||||||
|
|
||||||
if not ids_to_send:
|
|
||||||
print("Skipping warmup send.")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Use protobuf format for warmup
|
|
||||||
from . import embedding_pb2
|
|
||||||
req_proto = embedding_pb2.NodeEmbeddingRequest()
|
|
||||||
req_proto.node_ids.extend(ids_to_send)
|
|
||||||
request_bytes = req_proto.SerializeToString()
|
|
||||||
|
|
||||||
for i in range(3):
|
|
||||||
print(f"Sending warmup request {i + 1}/3 via ZMQ (Protobuf)...")
|
|
||||||
socket.send(request_bytes)
|
|
||||||
response_bytes = socket.recv()
|
|
||||||
|
|
||||||
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
|
||||||
resp_proto.ParseFromString(response_bytes)
|
|
||||||
embeddings_count = resp_proto.dimensions[0] if resp_proto.dimensions else 0
|
|
||||||
print(f"Warmup request {i + 1}/3 successful, received {embeddings_count} embeddings")
|
|
||||||
time.sleep(0.1)
|
|
||||||
|
|
||||||
print("Client-side Protobuf ZMQ warmup complete")
|
|
||||||
socket.close()
|
|
||||||
context.term()
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error during Protobuf ZMQ warmup: {e}")
|
|
||||||
|
|
||||||
class DeviceTimer:
|
|
||||||
"""Device timer"""
|
|
||||||
def __init__(self, name="", device=device):
|
|
||||||
self.name = name
|
|
||||||
self.device = device
|
|
||||||
self.start_time = 0
|
|
||||||
self.end_time = 0
|
|
||||||
|
|
||||||
if embedding_mode == "sentence-transformers" and torch.cuda.is_available():
|
|
||||||
self.start_event = torch.cuda.Event(enable_timing=True)
|
|
||||||
self.end_event = torch.cuda.Event(enable_timing=True)
|
|
||||||
else:
|
|
||||||
self.start_event = None
|
|
||||||
self.end_event = None
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def timing(self):
|
|
||||||
self.start()
|
|
||||||
yield
|
|
||||||
self.end()
|
|
||||||
|
|
||||||
def start(self):
|
|
||||||
if embedding_mode == "sentence-transformers" and torch.cuda.is_available():
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
self.start_event.record()
|
|
||||||
else:
|
|
||||||
if embedding_mode == "sentence-transformers" and self.device.type == "mps":
|
|
||||||
torch.mps.synchronize()
|
|
||||||
self.start_time = time.time()
|
|
||||||
|
|
||||||
def end(self):
|
|
||||||
if embedding_mode == "sentence-transformers" and torch.cuda.is_available():
|
|
||||||
self.end_event.record()
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
else:
|
|
||||||
if embedding_mode == "sentence-transformers" and self.device.type == "mps":
|
|
||||||
torch.mps.synchronize()
|
|
||||||
self.end_time = time.time()
|
|
||||||
|
|
||||||
def elapsed_time(self):
|
|
||||||
if embedding_mode == "sentence-transformers" and torch.cuda.is_available():
|
|
||||||
return self.start_event.elapsed_time(self.end_event) / 1000.0
|
|
||||||
else:
|
|
||||||
return self.end_time - self.start_time
|
|
||||||
|
|
||||||
def print_elapsed(self):
|
|
||||||
elapsed = self.elapsed_time()
|
|
||||||
print(f"[{self.name}] Elapsed time: {elapsed:.3f}s")
|
|
||||||
|
|
||||||
def process_batch_pytorch(texts_batch, ids_batch, missing_ids):
|
|
||||||
"""Process text batch"""
|
|
||||||
if not texts_batch:
|
|
||||||
return np.array([])
|
|
||||||
|
|
||||||
# Filter out empty texts and their corresponding IDs
|
|
||||||
valid_texts = []
|
|
||||||
valid_ids = []
|
|
||||||
for i, text in enumerate(texts_batch):
|
|
||||||
if text.strip(): # Only include non-empty texts
|
|
||||||
valid_texts.append(text)
|
|
||||||
valid_ids.append(ids_batch[i])
|
|
||||||
|
|
||||||
if not valid_texts:
|
|
||||||
print("WARNING: No valid texts in batch")
|
|
||||||
return np.array([])
|
|
||||||
|
|
||||||
# Tokenize
|
|
||||||
token_timer = DeviceTimer("tokenization")
|
|
||||||
with token_timer.timing():
|
|
||||||
inputs = tokenizer(
|
|
||||||
valid_texts,
|
|
||||||
padding=True,
|
|
||||||
truncation=True,
|
|
||||||
max_length=512,
|
|
||||||
return_tensors="pt"
|
|
||||||
).to(device)
|
|
||||||
|
|
||||||
# Compute embeddings
|
|
||||||
embed_timer = DeviceTimer("embedding computation")
|
|
||||||
with embed_timer.timing():
|
|
||||||
with torch.no_grad():
|
|
||||||
outputs = model(**inputs)
|
|
||||||
hidden_states = outputs.last_hidden_state
|
|
||||||
|
|
||||||
# Mean pooling
|
|
||||||
attention_mask = inputs['attention_mask']
|
|
||||||
mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
|
|
||||||
sum_embeddings = torch.sum(hidden_states * mask_expanded, 1)
|
|
||||||
sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
|
|
||||||
batch_embeddings = sum_embeddings / sum_mask
|
|
||||||
embed_timer.print_elapsed()
|
|
||||||
|
|
||||||
return batch_embeddings.cpu().numpy()
|
|
||||||
|
|
||||||
# ZMQ server main loop - uses a ROUTER socket
|
|
||||||
context = zmq.Context()
|
|
||||||
socket = context.socket(zmq.ROUTER)  # ROUTER socket so multiple REQ clients can be served
|
|
||||||
socket.bind(f"tcp://127.0.0.1:{zmq_port}")
|
|
||||||
print(f"INFO: ZMQ ROUTER server listening on port {zmq_port}")
|
|
||||||
|
|
||||||
# Set timeouts
|
|
||||||
socket.setsockopt(zmq.RCVTIMEO, 5000) # 5 second receive timeout
|
|
||||||
socket.setsockopt(zmq.SNDTIMEO, 300000) # 300 second send timeout
|
|
||||||
|
|
||||||
from . import embedding_pb2
|
|
||||||
|
|
||||||
print(f"INFO: Embedding server ready to serve requests")
|
|
||||||
|
|
||||||
# Start warmup thread if enabled
|
|
||||||
if enable_warmup and len(passages) > 0:
|
|
||||||
import threading
|
|
||||||
print(f"Warmup enabled: starting warmup thread")
|
|
||||||
warmup_thread = threading.Thread(target=client_warmup, args=(zmq_port,))
|
|
||||||
warmup_thread.daemon = True
|
|
||||||
warmup_thread.start()
|
|
||||||
else:
|
|
||||||
print(f"Warmup disabled or no passages available (enable_warmup={enable_warmup}, passages={len(passages)})")
|
|
||||||
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
parts = socket.recv_multipart()
|
|
||||||
|
|
||||||
# --- Restore robust message format detection ---
|
|
||||||
# Must check parts length to avoid IndexError
|
|
||||||
if len(parts) >= 3:
|
|
||||||
identity = parts[0]
|
|
||||||
# empty = parts[1] # We usually don't care about the middle empty frame
|
|
||||||
message = parts[2]
|
|
||||||
elif len(parts) == 2:
|
|
||||||
# Can also handle cases without empty frame
|
|
||||||
identity = parts[0]
|
|
||||||
message = parts[1]
|
|
||||||
else:
|
|
||||||
# If received message format is wrong, print warning and ignore it instead of crashing
|
|
||||||
print(f"WARNING: Received unexpected message format with {len(parts)} parts. Ignoring.")
|
|
||||||
continue
|
|
||||||
print(f"INFO: Received ZMQ request from client {identity.hex()[:8]}, size {len(message)} bytes")
|
|
||||||
|
|
||||||
# Handle control messages (MessagePack format)
|
|
||||||
try:
|
|
||||||
request_payload = msgpack.unpackb(message)
|
|
||||||
if isinstance(request_payload, list) and len(request_payload) >= 1:
|
|
||||||
if request_payload[0] == "__QUERY_META_PATH__":
|
|
||||||
# Return the current meta path being used by the server
|
|
||||||
current_meta_path = getattr(passages, '_meta_path', '') if hasattr(passages, '_meta_path') else ''
|
|
||||||
response = [current_meta_path]
|
|
||||||
socket.send_multipart([identity, b'', msgpack.packb(response)])
|
|
||||||
continue
|
|
||||||
|
|
||||||
elif request_payload[0] == "__UPDATE_META_PATH__" and len(request_payload) >= 2:
|
|
||||||
# Update the server's meta path and reload passages
|
|
||||||
new_meta_path = request_payload[1]
|
|
||||||
try:
|
|
||||||
print(f"INFO: Updating server meta path to: {new_meta_path}")
|
|
||||||
# Reload passages from the new meta file
|
|
||||||
passages = load_passages_from_metadata(new_meta_path)
|
|
||||||
# Store the meta path for future queries
|
|
||||||
passages._meta_path = new_meta_path
|
|
||||||
response = ["SUCCESS"]
|
|
||||||
print(f"INFO: Successfully updated meta path and reloaded {len(passages)} passages")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"ERROR: Failed to update meta path: {e}")
|
|
||||||
response = ["FAILED", str(e)]
|
|
||||||
socket.send_multipart([identity, b'', msgpack.packb(response)])
|
|
||||||
continue
|
|
||||||
|
|
||||||
elif request_payload[0] == "__QUERY_MODEL__":
|
|
||||||
# Return the current model being used by the server
|
|
||||||
response = [model_name]
|
|
||||||
socket.send_multipart([identity, b'', msgpack.packb(response)])
|
|
||||||
continue
|
|
||||||
|
|
||||||
elif request_payload[0] == "__UPDATE_MODEL__" and len(request_payload) >= 2:
|
|
||||||
# Update the server's embedding model
|
|
||||||
new_model_name = request_payload[1]
|
|
||||||
try:
|
|
||||||
print(f"INFO: Updating server model from {model_name} to: {new_model_name}")
|
|
||||||
|
|
||||||
# Clean up old model to free memory
|
|
||||||
if not use_mlx:
|
|
||||||
print("INFO: Releasing old model from memory...")
|
|
||||||
old_model = model
|
|
||||||
old_tokenizer = tokenizer
|
|
||||||
|
|
||||||
# Load new tokenizer first
|
|
||||||
print(f"Loading new tokenizer for {new_model_name}...")
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(new_model_name, use_fast=True)
|
|
||||||
|
|
||||||
# Load new model
|
|
||||||
print(f"Loading new model {new_model_name}...")
|
|
||||||
model = AutoModel.from_pretrained(new_model_name).to(device).eval()
|
|
||||||
|
|
||||||
# Optimize new model
|
|
||||||
if cuda_available or mps_available:
|
|
||||||
try:
|
|
||||||
model = model.half()
|
|
||||||
model = torch.compile(model)
|
|
||||||
print(f"INFO: Using FP16 precision with model: {new_model_name}")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"WARNING: Model optimization failed: {e}")
|
|
||||||
|
|
||||||
# Now safely delete old model after new one is loaded
|
|
||||||
del old_model
|
|
||||||
del old_tokenizer
|
|
||||||
|
|
||||||
# Clear GPU cache if available
|
|
||||||
if device.type == "cuda":
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
print("INFO: Cleared CUDA cache")
|
|
||||||
elif device.type == "mps":
|
|
||||||
torch.mps.empty_cache()
|
|
||||||
print("INFO: Cleared MPS cache")
|
|
||||||
|
|
||||||
# Force garbage collection
|
|
||||||
import gc
|
|
||||||
gc.collect()
|
|
||||||
print("INFO: Memory cleanup completed")
|
|
||||||
|
|
||||||
# Update model name
|
|
||||||
model_name = new_model_name
|
|
||||||
|
|
||||||
response = ["SUCCESS"]
|
|
||||||
print(f"INFO: Successfully updated model to: {new_model_name}")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"ERROR: Failed to update model: {e}")
|
|
||||||
response = ["FAILED", str(e)]
|
|
||||||
socket.send_multipart([identity, b'', msgpack.packb(response)])
|
|
||||||
continue
|
|
||||||
except:
|
|
||||||
# Not a control message, continue with normal protobuf processing
|
|
||||||
pass
|
|
||||||
|
|
||||||
e2e_start = time.time()
|
|
||||||
lookup_timer = DeviceTimer("text lookup")
|
|
||||||
|
|
||||||
# Parse request
|
|
||||||
req_proto = embedding_pb2.NodeEmbeddingRequest()
|
|
||||||
req_proto.ParseFromString(message)
|
|
||||||
node_ids = req_proto.node_ids
|
|
||||||
print(f"INFO: Request for {len(node_ids)} node embeddings: {list(node_ids)}")
|
|
||||||
|
|
||||||
# Add debug information
|
|
||||||
if len(node_ids) > 0:
|
|
||||||
print(f"DEBUG: Node ID range: {min(node_ids)} to {max(node_ids)}")
|
|
||||||
|
|
||||||
# Look up texts
|
|
||||||
texts = []
|
|
||||||
missing_ids = []
|
|
||||||
with lookup_timer.timing():
|
|
||||||
for nid in node_ids:
|
|
||||||
txtinfo = passages[nid]
|
|
||||||
txt = txtinfo["text"]
|
|
||||||
if txt:
|
|
||||||
texts.append(txt)
|
|
||||||
else:
|
|
||||||
# If text is empty, we still need a placeholder for batch processing,
|
|
||||||
# but record its ID as missing
|
|
||||||
texts.append("")
|
|
||||||
missing_ids.append(nid)
|
|
||||||
lookup_timer.print_elapsed()
|
|
||||||
|
|
||||||
if missing_ids:
|
|
||||||
print(f"WARNING: Missing passages for IDs: {missing_ids}")
|
|
||||||
|
|
||||||
# Process batch
|
|
||||||
total_size = len(texts)
|
|
||||||
print(f"INFO: Total batch size: {total_size}, max_batch_size: {max_batch_size}")
|
|
||||||
|
|
||||||
all_embeddings = []
|
|
||||||
|
|
||||||
if total_size > max_batch_size:
|
|
||||||
print(f"INFO: Splitting batch of size {total_size} into chunks of {max_batch_size}")
|
|
||||||
for i in range(0, total_size, max_batch_size):
|
|
||||||
end_idx = min(i + max_batch_size, total_size)
|
|
||||||
print(f"INFO: Processing chunk {i//max_batch_size + 1}/{(total_size + max_batch_size - 1)//max_batch_size}: items {i} to {end_idx-1}")
|
|
||||||
|
|
||||||
chunk_texts = texts[i:end_idx]
|
|
||||||
chunk_ids = node_ids[i:end_idx]
|
|
||||||
|
|
||||||
if embedding_mode == "mlx":
|
|
||||||
embeddings_chunk = compute_embeddings_mlx(chunk_texts, model_name, batch_size=16)
|
|
||||||
elif embedding_mode == "openai":
|
|
||||||
embeddings_chunk = compute_embeddings_openai(chunk_texts, model_name)
|
|
||||||
else: # sentence-transformers
|
|
||||||
embeddings_chunk = process_batch_pytorch(chunk_texts, chunk_ids, missing_ids)
|
|
||||||
all_embeddings.append(embeddings_chunk)
|
|
||||||
|
|
||||||
if embedding_mode == "sentence-transformers":
|
|
||||||
if cuda_available:
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
elif device.type == "mps":
|
|
||||||
torch.mps.empty_cache()
|
|
||||||
|
|
||||||
hidden = np.vstack(all_embeddings)
|
|
||||||
print(f"INFO: Combined embeddings shape: {hidden.shape}")
|
|
||||||
else:
|
|
||||||
if embedding_mode == "mlx":
|
|
||||||
hidden = compute_embeddings_mlx(texts, model_name, batch_size=16)
|
|
||||||
elif embedding_mode == "openai":
|
|
||||||
hidden = compute_embeddings_openai(texts, model_name)
|
|
||||||
else: # sentence-transformers
|
|
||||||
hidden = process_batch_pytorch(texts, node_ids, missing_ids)
|
|
||||||
|
|
||||||
# Serialize response
|
|
||||||
ser_start = time.time()
|
|
||||||
|
|
||||||
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
|
||||||
hidden_contiguous = np.ascontiguousarray(hidden, dtype=np.float32)
|
|
||||||
resp_proto.embeddings_data = hidden_contiguous.tobytes()
|
|
||||||
resp_proto.dimensions.append(hidden_contiguous.shape[0])
|
|
||||||
resp_proto.dimensions.append(hidden_contiguous.shape[1])
|
|
||||||
resp_proto.missing_ids.extend(missing_ids)
|
|
||||||
|
|
||||||
response_data = resp_proto.SerializeToString()
|
|
||||||
|
|
||||||
# ROUTER socket sends an identity-framed response
|
|
||||||
socket.send_multipart([identity, b'', response_data])
|
|
||||||
|
|
||||||
ser_end = time.time()
|
|
||||||
|
|
||||||
print(f"INFO: Serialize time: {ser_end - ser_start:.6f} seconds")
|
|
||||||
|
|
||||||
if embedding_mode == "sentence-transformers":
|
|
||||||
if device.type == "cuda":
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
elif device.type == "mps":
|
|
||||||
torch.mps.synchronize()
|
|
||||||
e2e_end = time.time()
|
|
||||||
print(f"INFO: ZMQ E2E time: {e2e_end - e2e_start:.6f} seconds")
|
|
||||||
|
|
||||||
except zmq.Again:
|
|
||||||
print("INFO: ZMQ socket timeout, continuing to listen")
|
|
||||||
continue
|
|
||||||
except Exception as e:
|
|
||||||
print(f"ERROR: Error in ZMQ server: {e}")
|
|
||||||
try:
|
|
||||||
# Send empty response to maintain REQ-REP state
|
|
||||||
empty_resp = embedding_pb2.NodeEmbeddingResponse()
|
|
||||||
socket.send(empty_resp.SerializeToString())
|
|
||||||
except:
|
|
||||||
# If sending fails, recreate socket
|
|
||||||
socket.close()
|
|
||||||
socket = context.socket(zmq.REP)
|
|
||||||
socket.bind(f"tcp://127.0.0.1:{zmq_port}")
|
|
||||||
socket.setsockopt(zmq.RCVTIMEO, 5000)
|
|
||||||
socket.setsockopt(zmq.SNDTIMEO, 300000)
|
|
||||||
print("INFO: ZMQ socket recreated after error")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"ERROR: Failed to start embedding server: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
def create_embedding_server(
|
|
||||||
domain="demo",
|
|
||||||
load_passages=True,
|
|
||||||
load_embeddings=False,
|
|
||||||
use_fp16=True,
|
|
||||||
use_int8=False,
|
|
||||||
use_cuda_graphs=False,
|
|
||||||
zmq_port=5555,
|
|
||||||
max_batch_size=128,
|
|
||||||
lazy_load_passages=False,
|
|
||||||
model_name="sentence-transformers/all-mpnet-base-v2",
|
|
||||||
passages_file: Optional[str] = None,
|
|
||||||
embedding_mode: str = "sentence-transformers",
|
|
||||||
enable_warmup: bool = False,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
原有的 create_embedding_server 函数保持不变
|
|
||||||
这个是阻塞版本,用于直接运行
|
|
||||||
"""
|
|
||||||
create_embedding_server_thread(zmq_port, model_name, max_batch_size, passages_file, embedding_mode, enable_warmup)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser(description="Embedding service")
|
|
||||||
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
|
|
||||||
parser.add_argument("--domain", type=str, default="demo", help="Domain name")
|
|
||||||
parser.add_argument("--passages-file", type=str, help="JSON file containing passage ID to text mapping")
|
|
||||||
parser.add_argument("--load-passages", action="store_true", default=True)
|
|
||||||
parser.add_argument("--load-embeddings", action="store_true", default=False)
|
|
||||||
parser.add_argument("--use-fp16", action="store_true", default=False)
|
|
||||||
parser.add_argument("--use-int8", action="store_true", default=False)
|
|
||||||
parser.add_argument("--use-cuda-graphs", action="store_true", default=False)
|
|
||||||
parser.add_argument("--max-batch-size", type=int, default=128, help="Maximum batch size before splitting")
|
|
||||||
parser.add_argument("--lazy-load-passages", action="store_true", default=True)
|
|
||||||
parser.add_argument("--model-name", type=str, default="sentence-transformers/all-mpnet-base-v2",
|
|
||||||
help="Embedding model name")
|
|
||||||
parser.add_argument("--embedding-mode", type=str, default="sentence-transformers",
|
|
||||||
choices=["sentence-transformers", "mlx", "openai"],
|
|
||||||
help="Embedding backend mode")
|
|
||||||
parser.add_argument("--use-mlx", action="store_true", default=False, help="Use MLX backend for embeddings (deprecated: use --embedding-mode mlx)")
|
|
||||||
parser.add_argument("--disable-warmup", action="store_true", default=False, help="Disable warmup requests on server start")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# Handle backward compatibility with use_mlx
|
|
||||||
embedding_mode = args.embedding_mode
|
|
||||||
if args.use_mlx:
|
|
||||||
embedding_mode = "mlx"
|
|
||||||
|
|
||||||
create_embedding_server(
|
|
||||||
domain=args.domain,
|
|
||||||
load_passages=args.load_passages,
|
|
||||||
load_embeddings=args.load_embeddings,
|
|
||||||
use_fp16=args.use_fp16,
|
|
||||||
use_int8=args.use_int8,
|
|
||||||
use_cuda_graphs=args.use_cuda_graphs,
|
|
||||||
zmq_port=args.zmq_port,
|
|
||||||
max_batch_size=args.max_batch_size,
|
|
||||||
lazy_load_passages=args.lazy_load_passages,
|
|
||||||
model_name=args.model_name,
|
|
||||||
passages_file=args.passages_file,
|
|
||||||
embedding_mode=embedding_mode,
|
|
||||||
enable_warmup=not args.disable_warmup,
|
|
||||||
)
|
|
||||||
@@ -4,15 +4,15 @@ build-backend = "scikit_build_core.build"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "leann-backend-diskann"
|
name = "leann-backend-diskann"
|
||||||
version = "0.1.0"
|
version = "0.1.2"
|
||||||
dependencies = ["leann-core==0.1.0", "numpy"]
|
dependencies = ["leann-core==0.1.2", "numpy"]
|
||||||
|
|
||||||
[tool.scikit-build]
|
[tool.scikit-build]
|
||||||
# 关键:简化的 CMake 路径
|
# Key: simplified CMake path
|
||||||
cmake.source-dir = "third_party/DiskANN"
|
cmake.source-dir = "third_party/DiskANN"
|
||||||
# 关键:Python 包在根目录,路径完全匹配
|
# Key: Python package in root directory, paths match exactly
|
||||||
wheel.packages = ["leann_backend_diskann"]
|
wheel.packages = ["leann_backend_diskann"]
|
||||||
# 使用默认的 redirect 模式
|
# Use default redirect mode
|
||||||
editable.mode = "redirect"
|
editable.mode = "redirect"
|
||||||
cmake.build-type = "Release"
|
cmake.build-type = "Release"
|
||||||
build.verbose = true
|
build.verbose = true
|
||||||
|
|||||||
Submodule packages/leann-backend-diskann/third_party/DiskANN updated: af2a26481e...25339b0341
@@ -1,6 +1,7 @@
|
|||||||
# Final simplified version
|
|
||||||
cmake_minimum_required(VERSION 3.24)
|
cmake_minimum_required(VERSION 3.24)
|
||||||
project(leann_backend_hnsw_wrapper)
|
project(leann_backend_hnsw_wrapper)
|
||||||
|
set(CMAKE_C_COMPILER_WORKS 1)
|
||||||
|
set(CMAKE_CXX_COMPILER_WORKS 1)
|
||||||
|
|
||||||
# Set OpenMP path for macOS
|
# Set OpenMP path for macOS
|
||||||
if(APPLE)
|
if(APPLE)
|
||||||
@@ -11,15 +12,9 @@ if(APPLE)
|
|||||||
set(OpenMP_omp_LIBRARY "/opt/homebrew/opt/libomp/lib/libomp.dylib")
|
set(OpenMP_omp_LIBRARY "/opt/homebrew/opt/libomp/lib/libomp.dylib")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# Build ZeroMQ from source
|
# Use system ZeroMQ instead of building from source
|
||||||
set(ZMQ_BUILD_TESTS OFF CACHE BOOL "" FORCE)
|
find_package(PkgConfig REQUIRED)
|
||||||
set(ENABLE_DRAFTS OFF CACHE BOOL "" FORCE)
|
pkg_check_modules(ZMQ REQUIRED libzmq)
|
||||||
set(ENABLE_PRECOMPILED OFF CACHE BOOL "" FORCE)
|
|
||||||
set(WITH_PERF_TOOL OFF CACHE BOOL "" FORCE)
|
|
||||||
set(WITH_DOCS OFF CACHE BOOL "" FORCE)
|
|
||||||
set(BUILD_SHARED OFF CACHE BOOL "" FORCE)
|
|
||||||
set(BUILD_STATIC ON CACHE BOOL "" FORCE)
|
|
||||||
add_subdirectory(third_party/libzmq)
|
|
||||||
|
|
||||||
# Add cppzmq headers
|
# Add cppzmq headers
|
||||||
include_directories(third_party/cppzmq)
|
include_directories(third_party/cppzmq)
|
||||||
@@ -29,6 +24,7 @@ set(MSGPACK_USE_BOOST OFF CACHE BOOL "" FORCE)
|
|||||||
add_compile_definitions(MSGPACK_NO_BOOST)
|
add_compile_definitions(MSGPACK_NO_BOOST)
|
||||||
include_directories(third_party/msgpack-c/include)
|
include_directories(third_party/msgpack-c/include)
|
||||||
|
|
||||||
|
# Faiss configuration - streamlined build
|
||||||
set(FAISS_ENABLE_PYTHON ON CACHE BOOL "" FORCE)
|
set(FAISS_ENABLE_PYTHON ON CACHE BOOL "" FORCE)
|
||||||
set(FAISS_ENABLE_GPU OFF CACHE BOOL "" FORCE)
|
set(FAISS_ENABLE_GPU OFF CACHE BOOL "" FORCE)
|
||||||
set(FAISS_ENABLE_EXTRAS OFF CACHE BOOL "" FORCE)
|
set(FAISS_ENABLE_EXTRAS OFF CACHE BOOL "" FORCE)
|
||||||
@@ -36,4 +32,24 @@ set(BUILD_TESTING OFF CACHE BOOL "" FORCE)
|
|||||||
set(FAISS_ENABLE_C_API OFF CACHE BOOL "" FORCE)
|
set(FAISS_ENABLE_C_API OFF CACHE BOOL "" FORCE)
|
||||||
set(FAISS_OPT_LEVEL "generic" CACHE STRING "" FORCE)
|
set(FAISS_OPT_LEVEL "generic" CACHE STRING "" FORCE)
|
||||||
|
|
||||||
|
# Disable additional SIMD versions to speed up compilation
|
||||||
|
set(FAISS_ENABLE_AVX2 OFF CACHE BOOL "" FORCE)
|
||||||
|
set(FAISS_ENABLE_AVX512 OFF CACHE BOOL "" FORCE)
|
||||||
|
|
||||||
|
# Additional optimization options from INSTALL.md
|
||||||
|
set(CMAKE_BUILD_TYPE "Release" CACHE STRING "" FORCE)
|
||||||
|
set(BUILD_SHARED_LIBS OFF CACHE BOOL "" FORCE) # Static library is faster to build
|
||||||
|
|
||||||
|
# Avoid building demos and benchmarks
|
||||||
|
set(BUILD_DEMOS OFF CACHE BOOL "" FORCE)
|
||||||
|
set(BUILD_BENCHS OFF CACHE BOOL "" FORCE)
|
||||||
|
|
||||||
|
# NEW: Tell Faiss to only build the generic version
|
||||||
|
set(FAISS_BUILD_GENERIC ON CACHE BOOL "" FORCE)
|
||||||
|
set(FAISS_BUILD_AVX2 OFF CACHE BOOL "" FORCE)
|
||||||
|
set(FAISS_BUILD_AVX512 OFF CACHE BOOL "" FORCE)
|
||||||
|
|
||||||
|
# IMPORTANT: Disable building AVX versions to speed up compilation
|
||||||
|
set(FAISS_BUILD_AVX_VERSIONS OFF CACHE BOOL "" FORCE)
|
||||||
|
|
||||||
add_subdirectory(third_party/faiss)
|
add_subdirectory(third_party/faiss)
|
||||||
@@ -1,10 +1,9 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Any, List, Literal
|
from typing import Dict, Any, List, Literal, Optional
|
||||||
import pickle
|
|
||||||
import shutil
|
import shutil
|
||||||
import time
|
import logging
|
||||||
|
|
||||||
from leann.searcher_base import BaseSearcher
|
from leann.searcher_base import BaseSearcher
|
||||||
from .convert_to_csr import convert_hnsw_graph_to_csr
|
from .convert_to_csr import convert_hnsw_graph_to_csr
|
||||||
@@ -16,6 +15,8 @@ from leann.interface import (
|
|||||||
LeannBackendSearcherInterface,
|
LeannBackendSearcherInterface,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def get_metric_map():
|
def get_metric_map():
|
||||||
from . import faiss # type: ignore
|
from . import faiss # type: ignore
|
||||||
@@ -57,13 +58,9 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
|||||||
index_dir.mkdir(parents=True, exist_ok=True)
|
index_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
if data.dtype != np.float32:
|
if data.dtype != np.float32:
|
||||||
|
logger.warning(f"Converting data to float32, shape: {data.shape}")
|
||||||
data = data.astype(np.float32)
|
data = data.astype(np.float32)
|
||||||
|
|
||||||
label_map = {i: str_id for i, str_id in enumerate(ids)}
|
|
||||||
label_map_file = index_dir / "leann.labels.map"
|
|
||||||
with open(label_map_file, "wb") as f:
|
|
||||||
pickle.dump(label_map, f)
|
|
||||||
|
|
||||||
metric_enum = get_metric_map().get(self.distance_metric.lower())
|
metric_enum = get_metric_map().get(self.distance_metric.lower())
|
||||||
if metric_enum is None:
|
if metric_enum is None:
|
||||||
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
|
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
|
||||||
@@ -85,7 +82,7 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
|||||||
def _convert_to_csr(self, index_file: Path):
|
def _convert_to_csr(self, index_file: Path):
|
||||||
"""Convert built index to CSR format"""
|
"""Convert built index to CSR format"""
|
||||||
mode_str = "CSR-pruned" if self.is_recompute else "CSR-standard"
|
mode_str = "CSR-pruned" if self.is_recompute else "CSR-standard"
|
||||||
print(f"INFO: Converting HNSW index to {mode_str} format...")
|
logger.info(f"INFO: Converting HNSW index to {mode_str} format...")
|
||||||
|
|
||||||
csr_temp_file = index_file.with_suffix(".csr.tmp")
|
csr_temp_file = index_file.with_suffix(".csr.tmp")
|
||||||
|
|
||||||
@@ -94,11 +91,11 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if success:
|
if success:
|
||||||
print("✅ CSR conversion successful.")
|
logger.info("✅ CSR conversion successful.")
|
||||||
index_file_old = index_file.with_suffix(".old")
|
index_file_old = index_file.with_suffix(".old")
|
||||||
shutil.move(str(index_file), str(index_file_old))
|
shutil.move(str(index_file), str(index_file_old))
|
||||||
shutil.move(str(csr_temp_file), str(index_file))
|
shutil.move(str(csr_temp_file), str(index_file))
|
||||||
print(
|
logger.info(
|
||||||
f"INFO: Replaced original index with {mode_str} version at '{index_file}'"
|
f"INFO: Replaced original index with {mode_str} version at '{index_file}'"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -135,31 +132,22 @@ class HNSWSearcher(BaseSearcher):
|
|||||||
|
|
||||||
hnsw_config = faiss.HNSWIndexConfig()
|
hnsw_config = faiss.HNSWIndexConfig()
|
||||||
hnsw_config.is_compact = self.is_compact
|
hnsw_config.is_compact = self.is_compact
|
||||||
hnsw_config.is_recompute = self.is_pruned or kwargs.get("is_recompute", False)
|
hnsw_config.is_recompute = (
|
||||||
|
self.is_pruned
|
||||||
if self.is_pruned and not hnsw_config.is_recompute:
|
) # In C++ code, it's called is_recompute, but it's only for loading IIUC.
|
||||||
raise RuntimeError("Index is pruned but recompute is disabled.")
|
|
||||||
|
|
||||||
self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)
|
self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)
|
||||||
|
|
||||||
# Load label mapping
|
|
||||||
label_map_file = self.index_dir / "leann.labels.map"
|
|
||||||
if not label_map_file.exists():
|
|
||||||
raise FileNotFoundError(f"Label map file not found at {label_map_file}")
|
|
||||||
|
|
||||||
with open(label_map_file, "rb") as f:
|
|
||||||
self.label_map = pickle.load(f)
|
|
||||||
|
|
||||||
def search(
|
def search(
|
||||||
self,
|
self,
|
||||||
query: np.ndarray,
|
query: np.ndarray,
|
||||||
top_k: int,
|
top_k: int,
|
||||||
|
zmq_port: Optional[int] = None,
|
||||||
complexity: int = 64,
|
complexity: int = 64,
|
||||||
beam_width: int = 1,
|
beam_width: int = 1,
|
||||||
prune_ratio: float = 0.0,
|
prune_ratio: float = 0.0,
|
||||||
recompute_embeddings: bool = False,
|
recompute_embeddings: bool = True,
|
||||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||||
zmq_port: int = 5557,
|
|
||||||
batch_size: int = 0,
|
batch_size: int = 0,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
@@ -177,7 +165,7 @@ class HNSWSearcher(BaseSearcher):
|
|||||||
- "global": Use global PQ queue size for selection (default)
|
- "global": Use global PQ queue size for selection (default)
|
||||||
- "local": Local pruning, sort and select best candidates
|
- "local": Local pruning, sort and select best candidates
|
||||||
- "proportional": Base selection on new neighbor count ratio
|
- "proportional": Base selection on new neighbor count ratio
|
||||||
zmq_port: ZMQ port for embedding server
|
zmq_port: ZMQ port for embedding server communication. Must be provided if recompute_embeddings is True.
|
||||||
batch_size: Neighbor processing batch size, 0=disabled (HNSW-specific)
|
batch_size: Neighbor processing batch size, 0=disabled (HNSW-specific)
|
||||||
**kwargs: Additional HNSW-specific parameters (for legacy compatibility)
|
**kwargs: Additional HNSW-specific parameters (for legacy compatibility)
|
||||||
|
|
||||||
@@ -186,15 +174,14 @@ class HNSWSearcher(BaseSearcher):
|
|||||||
"""
|
"""
|
||||||
from . import faiss # type: ignore
|
from . import faiss # type: ignore
|
||||||
|
|
||||||
# Use recompute_embeddings parameter
|
if not recompute_embeddings:
|
||||||
use_recompute = recompute_embeddings or self.is_pruned
|
if self.is_pruned:
|
||||||
if use_recompute:
|
raise RuntimeError("Recompute is required for pruned index.")
|
||||||
meta_file_path = self.index_dir / f"{self.index_path.name}.meta.json"
|
if recompute_embeddings:
|
||||||
if not meta_file_path.exists():
|
if zmq_port is None:
|
||||||
raise RuntimeError(
|
raise ValueError(
|
||||||
f"FATAL: Recompute enabled but metadata file not found: {meta_file_path}"
|
"zmq_port must be provided if recompute_embeddings is True"
|
||||||
)
|
)
|
||||||
self._ensure_server_running(str(meta_file_path), port=zmq_port, **kwargs)
|
|
||||||
|
|
||||||
if query.dtype != np.float32:
|
if query.dtype != np.float32:
|
||||||
query = query.astype(np.float32)
|
query = query.astype(np.float32)
|
||||||
@@ -202,7 +189,10 @@ class HNSWSearcher(BaseSearcher):
|
|||||||
faiss.normalize_L2(query)
|
faiss.normalize_L2(query)
|
||||||
|
|
||||||
params = faiss.SearchParametersHNSW()
|
params = faiss.SearchParametersHNSW()
|
||||||
params.zmq_port = zmq_port
|
if zmq_port is not None:
|
||||||
|
params.zmq_port = (
|
||||||
|
zmq_port # C++ code won't use this if recompute_embeddings is False
|
||||||
|
)
|
||||||
params.efSearch = complexity
|
params.efSearch = complexity
|
||||||
params.beam_size = beam_width
|
params.beam_size = beam_width
|
||||||
|
|
||||||
@@ -239,11 +229,7 @@ class HNSWSearcher(BaseSearcher):
|
|||||||
)
|
)
|
||||||
|
|
||||||
string_labels = [
|
string_labels = [
|
||||||
[
|
[str(int_label) for int_label in batch_labels] for batch_labels in labels
|
||||||
self.label_map.get(int_label, f"unknown_{int_label}")
|
|
||||||
for int_label in batch_labels
|
|
||||||
]
|
|
||||||
for batch_labels in labels
|
|
||||||
]
|
]
|
||||||
|
|
||||||
return {"labels": string_labels, "distances": distances}
|
return {"labels": string_labels, "distances": distances}
|
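Usage note for the search() signature change above: recompute_embeddings now defaults to True and zmq_port has no default, so callers must pass a port explicitly when recomputation is on. A minimal sketch follows; it assumes `searcher` is an already constructed HNSWSearcher and that the embedding dimension is 768 (neither is shown in this diff).

# Hypothetical call against the new search() signature.
import numpy as np

query = np.random.rand(1, 768).astype(np.float32)  # 768 is an assumed dimension

# recompute_embeddings defaults to True now, so an explicit zmq_port is required;
# omitting it raises ValueError per the change above.
results = searcher.search(query, top_k=10, zmq_port=5557, recompute_embeddings=True)
print(results["labels"][0], results["distances"][0])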
File diff suppressed because it is too large
@@ -6,9 +6,14 @@ build-backend = "scikit_build_core.build"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "leann-backend-hnsw"
|
name = "leann-backend-hnsw"
|
||||||
version = "0.1.0"
|
version = "0.1.2"
|
||||||
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
|
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
|
||||||
dependencies = ["leann-core==0.1.0", "numpy"]
|
dependencies = [
|
||||||
|
"leann-core==0.1.2",
|
||||||
|
"numpy",
|
||||||
|
"pyzmq>=23.0.0",
|
||||||
|
"msgpack>=1.0.0",
|
||||||
|
]
|
||||||
|
|
||||||
[tool.scikit-build]
|
[tool.scikit-build]
|
||||||
wheel.packages = ["leann_backend_hnsw"]
|
wheel.packages = ["leann_backend_hnsw"]
|
||||||
@@ -16,3 +21,7 @@ editable.mode = "redirect"
|
|||||||
cmake.build-type = "Release"
|
cmake.build-type = "Release"
|
||||||
build.verbose = true
|
build.verbose = true
|
||||||
build.tool-args = ["-j8"]
|
build.tool-args = ["-j8"]
|
||||||
|
|
||||||
|
# CMake definitions to optimize compilation
|
||||||
|
[tool.scikit-build.cmake.define]
|
||||||
|
CMAKE_BUILD_PARALLEL_LEVEL = "8"
|
||||||
Submodule packages/leann-backend-hnsw/third_party/faiss updated: 2547df4377...ff22e2c86b
Submodule packages/leann-backend-hnsw/third_party/msgpack-c updated: 9b801f087a...a0b2ec09da
@@ -4,16 +4,27 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "leann-core"
|
name = "leann-core"
|
||||||
version = "0.1.0"
|
version = "0.1.2"
|
||||||
description = "Core API and plugin system for Leann."
|
description = "Core API and plugin system for LEANN"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.9"
|
requires-python = ">=3.9"
|
||||||
license = { text = "MIT" }
|
license = { text = "MIT" }
|
||||||
|
|
||||||
|
# All required dependencies included
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"numpy>=1.20.0",
|
"numpy>=1.20.0",
|
||||||
"tqdm>=4.60.0"
|
"tqdm>=4.60.0",
|
||||||
|
"psutil>=5.8.0",
|
||||||
|
"pyzmq>=23.0.0",
|
||||||
|
"msgpack>=1.0.0",
|
||||||
|
"torch>=2.0.0",
|
||||||
|
"sentence-transformers>=2.2.0",
|
||||||
|
"llama-index-core>=0.12.0",
|
||||||
|
"python-dotenv>=1.0.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[project.scripts]
|
||||||
|
leann = "leann.cli:main"
|
||||||
|
|
||||||
[tool.setuptools.packages.find]
|
[tool.setuptools.packages.find]
|
||||||
where = ["src"]
|
where = ["src"]
|
||||||
@@ -5,16 +5,18 @@ with the correct, original embedding logic from the user's reference code.
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import pickle
|
import pickle
|
||||||
|
from leann.interface import LeannBackendSearcherInterface
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Dict, Any, Optional, Literal
|
from typing import List, Dict, Any, Optional, Literal
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
import uuid
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from .registry import BACKEND_REGISTRY
|
from .registry import BACKEND_REGISTRY
|
||||||
from .interface import LeannBackendFactoryInterface
|
from .interface import LeannBackendFactoryInterface
|
||||||
from .chat import get_llm
|
from .chat import get_llm
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def compute_embeddings(
|
def compute_embeddings(
|
||||||
@@ -22,7 +24,8 @@ def compute_embeddings(
|
|||||||
model_name: str,
|
model_name: str,
|
||||||
mode: str = "sentence-transformers",
|
mode: str = "sentence-transformers",
|
||||||
use_server: bool = True,
|
use_server: bool = True,
|
||||||
use_mlx: bool = False # Backward compatibility: if True, override mode to 'mlx',
|
port: Optional[int] = None,
|
||||||
|
is_build=False,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Computes embeddings using different backends.
|
Computes embeddings using different backends.
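To make the new compute_embeddings() signature concrete, here is a small sketch of the two call paths it now supports. The import path is assumed to be leann.api, and the port and model name are placeholders, not values taken from this changeset.

# Hypothetical calls against the new compute_embeddings() signature shown above.
from leann.api import compute_embeddings  # assumed module path

chunks = ["first passage", "second passage"]
model = "sentence-transformers/all-mpnet-base-v2"  # placeholder model name

# Search/query path: route through a running embedding server; port is now required.
embs = compute_embeddings(chunks, model, use_server=True, port=5557)

# Build path: compute directly in-process, bypassing the server.
embs_direct = compute_embeddings(chunks, model, use_server=False, is_build=True)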
@@ -39,251 +42,63 @@ def compute_embeddings(
|
|||||||
Returns:
|
Returns:
|
||||||
numpy array of embeddings
|
numpy array of embeddings
|
||||||
"""
|
"""
|
||||||
# Override mode for backward compatibility
|
if use_server:
|
||||||
if use_mlx:
|
# Use embedding server (for search/query)
|
||||||
mode = "mlx"
|
if port is None:
|
||||||
|
raise ValueError("port is required when use_server is True")
|
||||||
# Auto-detect mode based on model name if not explicitly set
|
return compute_embeddings_via_server(chunks, model_name, port=port)
|
||||||
if mode == "sentence-transformers" and model_name.startswith("text-embedding-"):
|
|
||||||
mode = "openai"
|
|
||||||
|
|
||||||
if mode == "mlx":
|
|
||||||
return compute_embeddings_mlx(chunks, model_name, batch_size=16)
|
|
||||||
elif mode == "openai":
|
|
||||||
return compute_embeddings_openai(chunks, model_name)
|
|
||||||
elif mode == "sentence-transformers":
|
|
||||||
return compute_embeddings_sentence_transformers(
|
|
||||||
chunks, model_name, use_server=use_server
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
# Use direct computation (for build_index)
|
||||||
f"Unsupported embedding mode: {mode}. Supported modes: sentence-transformers, mlx, openai"
|
from .embedding_compute import (
|
||||||
|
compute_embeddings as compute_embeddings_direct,
|
||||||
|
)
|
||||||
|
|
||||||
|
return compute_embeddings_direct(
|
||||||
|
chunks,
|
||||||
|
model_name,
|
||||||
|
mode=mode,
|
||||||
|
is_build=is_build,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
-def compute_embeddings_sentence_transformers(
-    chunks: List[str], model_name: str, use_server: bool = True
-) -> np.ndarray:
+def compute_embeddings_via_server(
+    chunks: List[str], model_name: str, port: int
+) -> np.ndarray:
     """Computes embeddings using sentence-transformers.

     Args:
         chunks: List of text chunks to embed
         model_name: Name of the sentence transformer model
-        use_server: If True, use embedding server (good for search). If False, use direct computation (good for build).
     """
-    [removed: the previous in-module implementation. It printed INFO messages, started an
-    EmbeddingServerManager for "leann_backend_hnsw.hnsw_embedding_server" on a hard-coded
-    port 5557, sent the chunks to it over ZMQ/msgpack, and on any failure fell back to a
-    local _compute_embeddings_sentence_transformers_direct() helper (SentenceTransformer +
-    model.half() + encode(batch_size=16) on CUDA/MPS); that fallback helper is deleted too.]
+    logger.info(
+        f"Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (via embedding server)..."
+    )
+    import zmq
+    import msgpack
+    import numpy as np
+
+    # Connect to embedding server
+    context = zmq.Context()
+    socket = context.socket(zmq.REQ)
+    socket.connect(f"tcp://localhost:{port}")
+
+    # Send chunks to server for embedding computation
+    request = chunks
+    socket.send(msgpack.packb(request))
+
+    # Receive embeddings from server
+    response = socket.recv()
+    embeddings_list = msgpack.unpackb(response)
+
+    # Convert back to numpy array
+    embeddings = np.array(embeddings_list, dtype=np.float32)
+
+    socket.close()
+    context.term()
+
+    return embeddings
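The REQ socket above only assumes a peer that answers one msgpack message with another. A minimal stand-in for the server side could look like the following sketch (hypothetical example, not the project's hnsw_embedding_server; assumes pyzmq, msgpack, and sentence-transformers are installed, and the model name is illustrative):

# Hypothetical counterpart of the client above: a ZMQ REP loop that answers each
# msgpack-encoded list of texts with msgpack-encoded embedding rows.
import msgpack
import zmq
from sentence_transformers import SentenceTransformer

def serve(model_name: str = "facebook/contriever", port: int = 5557) -> None:
    model = SentenceTransformer(model_name)
    context = zmq.Context()
    socket = context.socket(zmq.REP)
    socket.bind(f"tcp://*:{port}")
    while True:
        texts = msgpack.unpackb(socket.recv())        # list[str] sent by the client
        vectors = model.encode(texts, convert_to_numpy=True)
        socket.send(msgpack.packb(vectors.tolist()))  # nested lists -> float32 array client-side

if __name__ == "__main__":
    serve()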
-def compute_embeddings_openai(chunks: List[str], model_name: str) -> np.ndarray:
-    """Computes embeddings using OpenAI API."""
-    [removed: read OPENAI_API_KEY from the environment, called client.embeddings.create()
-    in batches of at most 100 chunks with an optional tqdm progress bar, and stacked the
-    results into a float32 array]
-
-def compute_embeddings_mlx(chunks: List[str], model_name: str, batch_size: int = 16) -> np.ndarray:
-    """Computes embeddings using an MLX model."""
-    [removed: loaded the model with mlx_lm.utils.load(), tokenized and zero-padded each
-    batch, ran the model, mean-pooled over the sequence axis, and returned np.stack() of
-    the per-chunk float32 vectors]

Equivalents of both code paths now live in the new packages/leann-core/src/leann/embedding_compute.py added later in this diff.
@dataclass
class SearchResult:
    id: str
@@ -299,25 +114,24 @@ class PassageManager:
        self.global_offset_map = {}  # Combined map for fast lookup

        for source in passage_sources:
-           if source["type"] == "jsonl":
+           assert source["type"] == "jsonl", "only jsonl is supported"
            passage_file = source["path"]
-           index_file = source["index_path"]
+           index_file = source["index_path"]  # .idx file
            if not Path(index_file).exists():
-               raise FileNotFoundError(
-                   f"Passage index file not found: {index_file}"
-               )
+               raise FileNotFoundError(f"Passage index file not found: {index_file}")
            with open(index_file, "rb") as f:
                offset_map = pickle.load(f)
            self.offset_maps[passage_file] = offset_map
            self.passage_files[passage_file] = passage_file

            # Build global map for O(1) lookup
            for passage_id, offset in offset_map.items():
                self.global_offset_map[passage_id] = (passage_file, offset)

    def get_passage(self, passage_id: str) -> Dict[str, Any]:
        if passage_id in self.global_offset_map:
            passage_file, offset = self.global_offset_map[passage_id]
+           # Lazy file opening - only open when needed
            with open(passage_file, "r", encoding="utf-8") as f:
                f.seek(offset)
                return json.loads(f.readline())
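The .jsonl/.idx pair this class reads can be produced with nothing more than json, pickle, and file offsets; a self-contained sketch (file names are illustrative):

# Hypothetical sketch: write passages.jsonl plus a pickled id -> byte-offset map,
# the layout PassageManager's seek()-based get_passage() relies on.
import json
import pickle

passages = [{"id": "0", "text": "hello", "metadata": {}},
            {"id": "1", "text": "world", "metadata": {}}]

offset_map = {}
with open("passages.jsonl", "w", encoding="utf-8") as f:
    for p in passages:
        offset_map[p["id"]] = f.tell()       # byte offset of this line
        f.write(json.dumps(p) + "\n")

with open("passages.idx", "wb") as f:
    pickle.dump(offset_map, f)

# Reading one passage back without scanning the whole file:
with open("passages.jsonl", "r", encoding="utf-8") as f:
    f.seek(offset_map["1"])
    print(json.loads(f.readline())["text"])  # -> "world"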
@@ -328,7 +142,7 @@ class LeannBuilder:
    def __init__(
        self,
        backend_name: str,
-       embedding_model: str = "facebook/contriever-msmarco",
+       embedding_model: str = "facebook/contriever",
        dimensions: Optional[int] = None,
        embedding_mode: str = "sentence-transformers",
        **backend_kwargs,
@@ -344,14 +158,12 @@ class LeannBuilder:
        self.dimensions = dimensions
        self.embedding_mode = embedding_mode
        self.backend_kwargs = backend_kwargs
-       if 'mlx' in self.embedding_model:
-           self.embedding_mode = "mlx"
        self.chunks: List[Dict[str, Any]] = []

    def add_text(self, text: str, metadata: Optional[Dict[str, Any]] = None):
        if metadata is None:
            metadata = {}
-       passage_id = metadata.get("id", str(uuid.uuid4()))
+       passage_id = metadata.get("id", str(len(self.chunks)))
        chunk_data = {"id": passage_id, "text": text, "metadata": metadata}
        self.chunks.append(chunk_data)
@@ -377,7 +189,10 @@ class LeannBuilder:
        with open(passages_file, "w", encoding="utf-8") as f:
            try:
                from tqdm import tqdm
-               chunk_iterator = tqdm(self.chunks, desc="Writing passages", unit="chunk")
+               chunk_iterator = tqdm(
+                   self.chunks, desc="Writing passages", unit="chunk"
+               )
            except ImportError:
                chunk_iterator = self.chunks
@@ -398,7 +213,11 @@ class LeannBuilder:
        pickle.dump(offset_map, f)
        texts_to_embed = [c["text"] for c in self.chunks]
        embeddings = compute_embeddings(
-           texts_to_embed, self.embedding_model, self.embedding_mode, use_server=False
+           texts_to_embed,
+           self.embedding_model,
+           self.embedding_mode,
+           use_server=False,
+           is_build=True,
        )
        string_ids = [chunk["id"] for chunk in self.chunks]
        current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
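Taken together with add_text() above, the build path can be driven end to end roughly as follows (illustrative sketch; assumes leann-core plus the hnsw backend package are installed, and the index path is made up):

# Minimal usage sketch of the builder API shown in this diff.
from leann.api import LeannBuilder

builder = LeannBuilder(
    backend_name="hnsw",
    embedding_model="facebook/contriever",
    embedding_mode="sentence-transformers",
)
builder.add_text("LEANN stores passages in JSONL with a pickled offset index.")
builder.add_text("Embeddings are computed once at build time (is_build=True).")
builder.build_index("./indexes/demo/documents.leann")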
@@ -472,7 +291,7 @@ class LeannBuilder:
                    f"Dimension mismatch: expected {self.dimensions}, got {embedding_dim}"
                )

-       print(
+       logger.info(
            f"Building index from precomputed embeddings: {len(ids)} items, {embedding_dim} dimensions"
        )
@@ -480,7 +299,7 @@ class LeannBuilder:
        if len(self.chunks) != len(ids):
            # If no text chunks provided, create placeholder text entries
            if not self.chunks:
-               print("No text chunks provided, creating placeholder entries...")
+               logger.info("No text chunks provided, creating placeholder entries...")
                for id_val in ids:
                    self.add_text(
                        f"Document {id_val}",
@@ -555,15 +374,19 @@ class LeannBuilder:
        with open(leann_meta_path, "w", encoding="utf-8") as f:
            json.dump(meta_data, f, indent=2)

-       print(f"Index built successfully from precomputed embeddings: {index_path}")
+       logger.info(
+           f"Index built successfully from precomputed embeddings: {index_path}"
+       )


class LeannSearcher:
    def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
-       meta_path_str = f"{index_path}.meta.json"
-       if not Path(meta_path_str).exists():
-           raise FileNotFoundError(f"Leann metadata file not found at {meta_path_str}")
-       with open(meta_path_str, "r", encoding="utf-8") as f:
+       self.meta_path_str = f"{index_path}.meta.json"
+       if not Path(self.meta_path_str).exists():
+           raise FileNotFoundError(
+               f"Leann metadata file not found at {self.meta_path_str}"
+           )
+       with open(self.meta_path_str, "r", encoding="utf-8") as f:
            self.meta_data = json.load(f)
        backend_name = self.meta_data["backend_name"]
        self.embedding_model = self.meta_data["embedding_model"]
@@ -571,16 +394,15 @@ class LeannSearcher:
        self.embedding_mode = self.meta_data.get(
            "embedding_mode", "sentence-transformers"
        )
-       # Backward compatibility with use_mlx
-       if self.meta_data.get("use_mlx", False):
-           self.embedding_mode = "mlx"
        self.passage_manager = PassageManager(self.meta_data.get("passage_sources", []))
        backend_factory = BACKEND_REGISTRY.get(backend_name)
        if backend_factory is None:
            raise ValueError(f"Backend '{backend_name}' not found.")
        final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
        final_kwargs["enable_warmup"] = enable_warmup
-       self.backend_impl = backend_factory.searcher(index_path, **final_kwargs)
+       self.backend_impl: LeannBackendSearcherInterface = backend_factory.searcher(
+           index_path, **final_kwargs
+       )

    def search(
        self,
@@ -589,26 +411,39 @@ class LeannSearcher:
        complexity: int = 64,
        beam_width: int = 1,
        prune_ratio: float = 0.0,
-       recompute_embeddings: bool = False,
+       recompute_embeddings: bool = True,
        pruning_strategy: Literal["global", "local", "proportional"] = "global",
-       zmq_port: int = 5557,
+       expected_zmq_port: int = 5557,
        **kwargs,
    ) -> List[SearchResult]:
-       print("🔍 DEBUG LeannSearcher.search() called:")
-       print(f"   Query: '{query}'")
-       print(f"   Top_k: {top_k}")
-       print(f"   Additional kwargs: {kwargs}")
+       logger.info("🔍 LeannSearcher.search() called:")
+       logger.info(f"   Query: '{query}'")
+       logger.info(f"   Top_k: {top_k}")
+       logger.info(f"   Additional kwargs: {kwargs}")

-       # Use backend's compute_query_embedding method
-       # This will automatically use embedding server if available and needed
-       import time
+       zmq_port = None
+       start_time = time.time()
+       if recompute_embeddings:
+           zmq_port = self.backend_impl._ensure_server_running(
+               self.meta_path_str,
+               port=expected_zmq_port,
+               **kwargs,
+           )
+           del expected_zmq_port
+           zmq_time = time.time() - start_time
+           logger.info(f"   Launching server time: {zmq_time} seconds")

        start_time = time.time()

-       query_embedding = self.backend_impl.compute_query_embedding(query, zmq_port)
-       print(f"   Generated embedding shape: {query_embedding.shape}")
+       query_embedding = self.backend_impl.compute_query_embedding(
+           query,
+           use_server_if_available=recompute_embeddings,
+           zmq_port=zmq_port,
+       )
+       logger.info(f"   Generated embedding shape: {query_embedding.shape}")
        embedding_time = time.time() - start_time
-       print(f"   Embedding time: {embedding_time} seconds")
+       logger.info(f"   Embedding time: {embedding_time} seconds")

        start_time = time.time()
        results = self.backend_impl.search(
@@ -623,14 +458,14 @@ class LeannSearcher:
            **kwargs,
        )
        search_time = time.time() - start_time
-       print(f"   Search time: {search_time} seconds")
-       print(
+       logger.info(f"   Search time: {search_time} seconds")
+       logger.info(
            f"   Backend returned: labels={len(results.get('labels', [[]])[0])} results"
        )

        enriched_results = []
        if "labels" in results and "distances" in results:
-           print(f"   Processing {len(results['labels'][0])} passage IDs:")
+           logger.info(f"   Processing {len(results['labels'][0])} passage IDs:")
            for i, (string_id, dist) in enumerate(
                zip(results["labels"][0], results["distances"][0])
            ):
@@ -644,15 +479,15 @@ class LeannSearcher:
                            metadata=passage_data.get("metadata", {}),
                        )
                    )
-                   print(
+                   logger.info(
                        f"   {i + 1}. passage_id='{string_id}' -> SUCCESS: {passage_data['text']}..."
                    )
                except KeyError:
-                   print(
+                   logger.error(
                        f"   {i + 1}. passage_id='{string_id}' -> ERROR: Passage not found in PassageManager!"
                    )

-       print(f"   Final enriched results: {len(enriched_results)} passages")
+       logger.info(f"   Final enriched results: {len(enriched_results)} passages")
        return enriched_results
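For context, a typical call into this method looks like the sketch below (index path illustrative; assumes an index was built beforehand with LeannBuilder):

# Usage sketch for LeannSearcher.search() and the SearchResult fields it returns.
from leann.api import LeannSearcher

searcher = LeannSearcher(index_path="./indexes/demo/documents.leann")
results = searcher.search("how are passages stored?", top_k=3, complexity=32)
for r in results:
    print(r.score, r.text[:80])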
@@ -674,10 +509,10 @@ class LeannChat:
        complexity: int = 64,
        beam_width: int = 1,
        prune_ratio: float = 0.0,
-       recompute_embeddings: bool = False,
+       recompute_embeddings: bool = True,
        pruning_strategy: Literal["global", "local", "proportional"] = "global",
-       zmq_port: int = 5557,
        llm_kwargs: Optional[Dict[str, Any]] = None,
+       expected_zmq_port: int = 5557,
        **search_kwargs,
    ):
        if llm_kwargs is None:
@@ -691,7 +526,7 @@ class LeannChat:
            prune_ratio=prune_ratio,
            recompute_embeddings=recompute_embeddings,
            pruning_strategy=pruning_strategy,
-           zmq_port=zmq_port,
+           expected_zmq_port=expected_zmq_port,
            **search_kwargs,
        )
        context = "\n\n".join([r.text for r in results])
@@ -9,6 +9,7 @@ from typing import Dict, Any, Optional, List
import logging
import os
import difflib
+import torch

# Configure logging
logging.basicConfig(level=logging.INFO)
@@ -28,6 +29,68 @@ def check_ollama_models() -> List[str]:
    return []


+def check_ollama_model_exists_remotely(model_name: str) -> tuple[bool, list[str]]:
+    """Check if a model exists in Ollama's remote library and return available tags
+
+    Returns:
+        (model_exists, available_tags): bool and list of matching tags
+    """
+    try:
+        import requests
+        import re
+
+        # Split model name and tag
+        if ':' in model_name:
+            base_model, requested_tag = model_name.split(':', 1)
+        else:
+            base_model, requested_tag = model_name, None
+
+        # First check if base model exists in library
+        library_response = requests.get("https://ollama.com/library", timeout=8)
+        if library_response.status_code != 200:
+            return True, []  # Assume exists if can't check
+
+        # Extract model names from library page
+        models_in_library = re.findall(r'href="/library/([^"]+)"', library_response.text)
+
+        if base_model not in models_in_library:
+            return False, []  # Base model doesn't exist
+
+        # If base model exists, get available tags
+        tags_response = requests.get(f"https://ollama.com/library/{base_model}/tags", timeout=8)
+        if tags_response.status_code != 200:
+            return True, []  # Base model exists but can't get tags
+
+        # Extract tags for this model - be more specific to avoid HTML artifacts
+        tag_pattern = rf'{re.escape(base_model)}:[a-zA-Z0-9\.\-_]+'
+        raw_tags = re.findall(tag_pattern, tags_response.text)
+
+        # Clean up tags - remove HTML artifacts and duplicates
+        available_tags = []
+        seen = set()
+        for tag in raw_tags:
+            # Skip if it looks like HTML (contains < or >)
+            if '<' in tag or '>' in tag:
+                continue
+            if tag not in seen:
+                seen.add(tag)
+                available_tags.append(tag)
+
+        # Check if exact model exists
+        if requested_tag is None:
+            # User just requested base model, suggest tags
+            return True, available_tags[:10]  # Return up to 10 tags
+        else:
+            exact_match = model_name in available_tags
+            return exact_match, available_tags[:10]
+
+    except Exception:
+        pass
+
+    # If scraping fails, assume model might exist (don't block user)
+    return True, []

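The local-model check that feeds this function (check_ollama_models(), defined earlier in this file) amounts to asking the local daemon what is installed; a minimal sketch against Ollama's REST endpoint could look like this (assumes a daemon on the default localhost:11434; the /api/tags endpoint is per Ollama's public API, not taken from this diff):

# Hedged sketch: list locally installed Ollama models via the local REST API.
import requests

def list_local_ollama_models(host: str = "http://localhost:11434") -> list[str]:
    try:
        resp = requests.get(f"{host}/api/tags", timeout=5)
        resp.raise_for_status()
        return [m["name"] for m in resp.json().get("models", [])]
    except requests.RequestException:
        return []  # daemon not running or unreachable

print(list_local_ollama_models())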
def search_ollama_models_fuzzy(query: str, available_models: List[str]) -> List[str]:
    """Use intelligent fuzzy search for Ollama models"""
    if not available_models:
@@ -243,24 +306,66 @@ def validate_model_and_suggest(model_name: str, llm_type: str) -> Optional[str]:
    if llm_type == "ollama":
        available_models = check_ollama_models()
        if available_models and model_name not in available_models:
-           # Use intelligent fuzzy search based on locally installed models
-           suggestions = search_ollama_models_fuzzy(model_name, available_models)

            error_msg = f"Model '{model_name}' not found in your local Ollama installation."
-           if suggestions:
-               error_msg += "\n\nDid you mean one of these installed models?\n"
-               for i, suggestion in enumerate(suggestions, 1):
-                   error_msg += f"  {i}. {suggestion}\n"
-           else:
-               error_msg += "\n\nYour installed models:\n"
-               for i, model in enumerate(available_models[:8], 1):
-                   error_msg += f"  {i}. {model}\n"
-               if len(available_models) > 8:
-                   error_msg += f"  ... and {len(available_models) - 8} more\n"
-
-           error_msg += "\nTo list all models: ollama list"
-           error_msg += "\nTo download a new model: ollama pull <model_name>"
-           error_msg += "\nBrowse models: https://ollama.com/library"
+           # Check if the model exists remotely and get available tags
+           model_exists_remotely, available_tags = check_ollama_model_exists_remotely(model_name)
+
+           if model_exists_remotely and model_name in available_tags:
+               # Exact model exists remotely - suggest pulling it
+               error_msg += f"\n\nTo install the requested model:\n"
+               error_msg += f"  ollama pull {model_name}\n"
+
+               # Show local alternatives
+               suggestions = search_ollama_models_fuzzy(model_name, available_models)
+               if suggestions:
+                   error_msg += "\nOr use one of these similar installed models:\n"
+                   for i, suggestion in enumerate(suggestions, 1):
+                       error_msg += f"  {i}. {suggestion}\n"
+
+           elif model_exists_remotely and available_tags:
+               # Base model exists but requested tag doesn't - suggest correct tags
+               base_model = model_name.split(':')[0]
+               requested_tag = model_name.split(':', 1)[1] if ':' in model_name else None
+
+               error_msg += f"\n\nModel '{base_model}' exists, but tag '{requested_tag}' is not available."
+               error_msg += f"\n\nAvailable {base_model} models you can install:\n"
+               for i, tag in enumerate(available_tags[:8], 1):
+                   error_msg += f"  {i}. ollama pull {tag}\n"
+               if len(available_tags) > 8:
+                   error_msg += f"  ... and {len(available_tags) - 8} more variants\n"
+
+               # Also show local alternatives
+               suggestions = search_ollama_models_fuzzy(model_name, available_models)
+               if suggestions:
+                   error_msg += "\nOr use one of these similar installed models:\n"
+                   for i, suggestion in enumerate(suggestions, 1):
+                       error_msg += f"  {i}. {suggestion}\n"
+
+           else:
+               # Model doesn't exist remotely - show fuzzy suggestions
+               suggestions = search_ollama_models_fuzzy(model_name, available_models)
+               error_msg += f"\n\nModel '{model_name}' was not found in Ollama's library."
+
+               if suggestions:
+                   error_msg += "\n\nDid you mean one of these installed models?\n"
+                   for i, suggestion in enumerate(suggestions, 1):
+                       error_msg += f"  {i}. {suggestion}\n"
+               else:
+                   error_msg += "\n\nYour installed models:\n"
+                   for i, model in enumerate(available_models[:8], 1):
+                       error_msg += f"  {i}. {model}\n"
+                   if len(available_models) > 8:
+                       error_msg += f"  ... and {len(available_models) - 8} more\n"
+
+           error_msg += "\n\nCommands:"
+           error_msg += "\n  ollama list                  # List installed models"
+           if model_exists_remotely and available_tags:
+               if model_name in available_tags:
+                   error_msg += f"\n  ollama pull {model_name}  # Install requested model"
+               else:
+                   error_msg += f"\n  ollama pull {available_tags[0]}  # Install recommended variant"
+           error_msg += "\n  https://ollama.com/library   # Browse available models"
            return error_msg

    elif llm_type == "hf":
@@ -375,8 +480,9 @@ class OllamaChat(LLMInterface):
            "stream": False,  # Keep it simple for now
            "options": kwargs,
        }
-       logger.info(f"Sending request to Ollama: {payload}")
+       logger.debug(f"Sending request to Ollama: {payload}")
        try:
+           logger.info(f"Sending request to Ollama and waiting for response...")
            response = requests.post(full_url, data=json.dumps(payload))
            response.raise_for_status()
@@ -396,7 +502,7 @@ class OllamaChat(LLMInterface):


class HFChat(LLMInterface):
-   """LLM interface for local Hugging Face Transformers models."""
+   """LLM interface for local Hugging Face Transformers models with proper chat templates."""

    def __init__(self, model_name: str = "deepseek-ai/deepseek-llm-7b-chat"):
        logger.info(f"Initializing HFChat with model='{model_name}'")
@@ -407,7 +513,7 @@ class HFChat(LLMInterface):
            raise ValueError(model_error)

        try:
-           from transformers.pipelines import pipeline
+           from transformers import AutoTokenizer, AutoModelForCausalLM
            import torch
        except ImportError:
            raise ImportError(
@@ -416,54 +522,101 @@ class HFChat(LLMInterface):

        # Auto-detect device
        if torch.cuda.is_available():
-           device = "cuda"
+           self.device = "cuda"
            logger.info("CUDA is available. Using GPU.")
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
-           device = "mps"
+           self.device = "mps"
            logger.info("MPS is available. Using Apple Silicon GPU.")
        else:
-           device = "cpu"
+           self.device = "cpu"
            logger.info("No GPU detected. Using CPU.")

-       self.pipeline = pipeline("text-generation", model=model_name, device=device)
+       # Load tokenizer and model
+       self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+       self.model = AutoModelForCausalLM.from_pretrained(
+           model_name,
+           torch_dtype=torch.float16 if self.device != "cpu" else torch.float32,
+           device_map="auto" if self.device != "cpu" else None,
+           trust_remote_code=True
+       )
+
+       # Move model to device if not using device_map
+       if self.device != "cpu" and "device_map" not in str(self.model):
+           self.model = self.model.to(self.device)
+
+       # Set pad token if not present
+       if self.tokenizer.pad_token is None:
+           self.tokenizer.pad_token = self.tokenizer.eos_token

    def ask(self, prompt: str, **kwargs) -> str:
-       # Map OpenAI-style arguments to Hugging Face equivalents
-       if "max_tokens" in kwargs:
-           # Prefer user-provided max_new_tokens if both are present
-           kwargs.setdefault("max_new_tokens", kwargs["max_tokens"])
-           # Remove the unsupported key to avoid errors in Transformers
-           kwargs.pop("max_tokens")
-
-       # Handle temperature=0 edge-case for greedy decoding
-       if "temperature" in kwargs and kwargs["temperature"] == 0.0:
-           # Remove unsupported zero temperature and use deterministic generation
-           kwargs.pop("temperature")
-           kwargs.setdefault("do_sample", False)
-
-       # Sensible defaults for text generation
-       params = {"max_length": 500, "num_return_sequences": 1, **kwargs}
-       logger.info(f"Generating text with Hugging Face model with params: {params}")
-       results = self.pipeline(prompt, **params)
-
-       # Handle different response formats from transformers
-       if isinstance(results, list) and len(results) > 0:
-           generated_text = (
-               results[0].get("generated_text", "")
-               if isinstance(results[0], dict)
-               else str(results[0])
-           )
-       else:
-           generated_text = str(results)
-
-       # Extract only the newly generated portion by removing the original prompt
-       if isinstance(generated_text, str) and generated_text.startswith(prompt):
-           response = generated_text[len(prompt) :].strip()
-       else:
-           # Fallback: return the full response if prompt removal fails
-           response = str(generated_text)
-
-       return response
+       print('kwargs in HF: ', kwargs)
+       # Check if this is a Qwen model and add /no_think by default
+       is_qwen_model = "qwen" in self.model.config._name_or_path.lower()
+
+       # For Qwen models, automatically add /no_think to the prompt
+       if is_qwen_model and "/no_think" not in prompt and "/think" not in prompt:
+           prompt = prompt + " /no_think"
+
+       # Prepare chat template
+       messages = [{"role": "user", "content": prompt}]
+
+       # Apply chat template if available
+       if hasattr(self.tokenizer, "apply_chat_template"):
+           try:
+               formatted_prompt = self.tokenizer.apply_chat_template(
+                   messages,
+                   tokenize=False,
+                   add_generation_prompt=True
+               )
+           except Exception as e:
+               logger.warning(f"Chat template failed, using raw prompt: {e}")
+               formatted_prompt = prompt
+       else:
+           # Fallback for models without chat template
+           formatted_prompt = prompt
+
+       # Tokenize input
+       inputs = self.tokenizer(
+           formatted_prompt,
+           return_tensors="pt",
+           padding=True,
+           truncation=True,
+           max_length=2048
+       )
+
+       # Move inputs to device
+       if self.device != "cpu":
+           inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+       # Set generation parameters
+       generation_config = {
+           "max_new_tokens": kwargs.get("max_tokens", kwargs.get("max_new_tokens", 512)),
+           "temperature": kwargs.get("temperature", 0.7),
+           "top_p": kwargs.get("top_p", 0.9),
+           "do_sample": kwargs.get("temperature", 0.7) > 0,
+           "pad_token_id": self.tokenizer.eos_token_id,
+           "eos_token_id": self.tokenizer.eos_token_id,
+       }
+
+       # Handle temperature=0 for greedy decoding
+       if generation_config["temperature"] == 0.0:
+           generation_config["do_sample"] = False
+           generation_config.pop("temperature")
+
+       logger.info(f"Generating with HuggingFace model, config: {generation_config}")
+
+       # Generate
+       with torch.no_grad():
+           outputs = self.model.generate(
+               **inputs,
+               **generation_config
+           )
+
+       # Decode response
+       generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
+       response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
+
+       return response.strip()


class OpenAIChat(LLMInterface):
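The new ask() path above is the standard chat-template flow from transformers; stripped of the class plumbing it is roughly the following standalone sketch (the model name is illustrative and not taken from this diff):

# Hedged sketch of the chat-template + greedy-decode pattern used by HFChat.ask().
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Qwen/Qwen2.5-0.5B-Instruct"  # illustrative small instruct model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

messages = [{"role": "user", "content": "Summarize LEANN in one sentence."}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(prompt, return_tensors="pt")
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=64, do_sample=False)
# Decode only the newly generated tokens, as the class does.
print(tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))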
315
packages/leann-core/src/leann/cli.py
Normal file
@@ -0,0 +1,315 @@
import argparse
import asyncio
from pathlib import Path

from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter

from .api import LeannBuilder, LeannSearcher, LeannChat


class LeannCLI:
    def __init__(self):
        self.indexes_dir = Path.home() / ".leann" / "indexes"
        self.indexes_dir.mkdir(parents=True, exist_ok=True)

        self.node_parser = SentenceSplitter(
            chunk_size=256, chunk_overlap=128, separator=" ", paragraph_separator="\n\n"
        )

    def get_index_path(self, index_name: str) -> str:
        index_dir = self.indexes_dir / index_name
        return str(index_dir / "documents.leann")

    def index_exists(self, index_name: str) -> bool:
        index_dir = self.indexes_dir / index_name
        meta_file = index_dir / "documents.leann.meta.json"
        return meta_file.exists()

    def create_parser(self) -> argparse.ArgumentParser:
        parser = argparse.ArgumentParser(
            prog="leann",
            description="LEANN - Local Enhanced AI Navigation",
            formatter_class=argparse.RawDescriptionHelpFormatter,
            epilog="""
Examples:
  leann build my-docs --docs ./documents   # Build index named my-docs
  leann search my-docs "query"             # Search in my-docs index
  leann ask my-docs "question"             # Ask my-docs index
  leann list                               # List all stored indexes
""",
        )

        subparsers = parser.add_subparsers(dest="command", help="Available commands")

        # Build command
        build_parser = subparsers.add_parser("build", help="Build document index")
        build_parser.add_argument("index_name", help="Index name")
        build_parser.add_argument("--docs", type=str, required=True, help="Documents directory")
        build_parser.add_argument("--backend", type=str, default="hnsw", choices=["hnsw", "diskann"])
        build_parser.add_argument("--embedding-model", type=str, default="facebook/contriever")
        build_parser.add_argument("--force", "-f", action="store_true", help="Force rebuild")
        build_parser.add_argument("--graph-degree", type=int, default=32)
        build_parser.add_argument("--complexity", type=int, default=64)
        build_parser.add_argument("--num-threads", type=int, default=1)
        build_parser.add_argument("--compact", action="store_true", default=True)
        build_parser.add_argument("--recompute", action="store_true", default=True)

        # Search command
        search_parser = subparsers.add_parser("search", help="Search documents")
        search_parser.add_argument("index_name", help="Index name")
        search_parser.add_argument("query", help="Search query")
        search_parser.add_argument("--top-k", type=int, default=5)
        search_parser.add_argument("--complexity", type=int, default=64)
        search_parser.add_argument("--beam-width", type=int, default=1)
        search_parser.add_argument("--prune-ratio", type=float, default=0.0)
        search_parser.add_argument("--recompute-embeddings", action="store_true")
        search_parser.add_argument("--pruning-strategy", choices=["global", "local", "proportional"], default="global")

        # Ask command
        ask_parser = subparsers.add_parser("ask", help="Ask questions")
        ask_parser.add_argument("index_name", help="Index name")
        ask_parser.add_argument("--llm", type=str, default="ollama", choices=["simulated", "ollama", "hf", "openai"])
        ask_parser.add_argument("--model", type=str, default="qwen3:8b")
        ask_parser.add_argument("--host", type=str, default="http://localhost:11434")
        ask_parser.add_argument("--interactive", "-i", action="store_true")
        ask_parser.add_argument("--top-k", type=int, default=20)
        ask_parser.add_argument("--complexity", type=int, default=32)
        ask_parser.add_argument("--beam-width", type=int, default=1)
        ask_parser.add_argument("--prune-ratio", type=float, default=0.0)
        ask_parser.add_argument("--recompute-embeddings", action="store_true")
        ask_parser.add_argument("--pruning-strategy", choices=["global", "local", "proportional"], default="global")

        # List command
        list_parser = subparsers.add_parser("list", help="List all indexes")

        return parser

    def list_indexes(self):
        print("Stored LEANN indexes:")

        if not self.indexes_dir.exists():
            print("No indexes found. Use 'leann build <name> --docs <dir>' to create one.")
            return

        index_dirs = [d for d in self.indexes_dir.iterdir() if d.is_dir()]

        if not index_dirs:
            print("No indexes found. Use 'leann build <name> --docs <dir>' to create one.")
            return

        print(f"Found {len(index_dirs)} indexes:")
        for i, index_dir in enumerate(index_dirs, 1):
            index_name = index_dir.name
            status = "✓" if self.index_exists(index_name) else "✗"

            print(f"  {i}. {index_name} [{status}]")
            if self.index_exists(index_name):
                meta_file = index_dir / "documents.leann.meta.json"
                size_mb = sum(f.stat().st_size for f in index_dir.iterdir() if f.is_file()) / (1024 * 1024)
                print(f"     Size: {size_mb:.1f} MB")

        if index_dirs:
            example_name = index_dirs[0].name
            print(f"\nUsage:")
            print(f'  leann search {example_name} "your query"')
            print(f"  leann ask {example_name} --interactive")

    def load_documents(self, docs_dir: str):
        print(f"Loading documents from {docs_dir}...")

        documents = SimpleDirectoryReader(
            docs_dir,
            recursive=True,
            encoding="utf-8",
            required_exts=[".pdf", ".txt", ".md", ".docx"],
        ).load_data(show_progress=True)

        all_texts = []
        for doc in documents:
            nodes = self.node_parser.get_nodes_from_documents([doc])
            for node in nodes:
                all_texts.append(node.get_content())

        print(f"Loaded {len(documents)} documents, {len(all_texts)} chunks")
        return all_texts

    async def build_index(self, args):
        docs_dir = args.docs
        index_name = args.index_name
        index_dir = self.indexes_dir / index_name
        index_path = self.get_index_path(index_name)

        if index_dir.exists() and not args.force:
            print(f"Index '{index_name}' already exists. Use --force to rebuild.")
            return

        all_texts = self.load_documents(docs_dir)
        if not all_texts:
            print("No documents found")
            return

        index_dir.mkdir(parents=True, exist_ok=True)

        print(f"Building index '{index_name}' with {args.backend} backend...")

        builder = LeannBuilder(
            backend_name=args.backend,
            embedding_model=args.embedding_model,
            graph_degree=args.graph_degree,
            complexity=args.complexity,
            is_compact=args.compact,
            is_recompute=args.recompute,
            num_threads=args.num_threads,
        )

        for chunk_text in all_texts:
            builder.add_text(chunk_text)

        builder.build_index(index_path)
        print(f"Index built at {index_path}")

    async def search_documents(self, args):
        index_name = args.index_name
        query = args.query
        index_path = self.get_index_path(index_name)

        if not self.index_exists(index_name):
            print(f"Index '{index_name}' not found. Use 'leann build {index_name} --docs <dir>' to create it.")
            return

        searcher = LeannSearcher(index_path=index_path)
        results = searcher.search(
            query,
            top_k=args.top_k,
            complexity=args.complexity,
            beam_width=args.beam_width,
            prune_ratio=args.prune_ratio,
            recompute_embeddings=args.recompute_embeddings,
            pruning_strategy=args.pruning_strategy,
        )

        print(f"Search results for '{query}' (top {len(results)}):")
        for i, result in enumerate(results, 1):
            print(f"{i}. Score: {result.score:.3f}")
            print(f"   {result.text[:200]}...")
            print()

    async def ask_questions(self, args):
        index_name = args.index_name
        index_path = self.get_index_path(index_name)

        if not self.index_exists(index_name):
            print(f"Index '{index_name}' not found. Use 'leann build {index_name} --docs <dir>' to create it.")
            return

        print(f"Starting chat with index '{index_name}'...")
        print(f"Using {args.model} ({args.llm})")

        llm_config = {"type": args.llm, "model": args.model}
        if args.llm == "ollama":
            llm_config["host"] = args.host

        chat = LeannChat(index_path=index_path, llm_config=llm_config)

        if args.interactive:
            print("LEANN Assistant ready! Type 'quit' to exit")
            print("=" * 40)

            while True:
                user_input = input("\nYou: ").strip()
                if user_input.lower() in ["quit", "exit", "q"]:
                    print("Goodbye!")
                    break

                if not user_input:
                    continue

                response = chat.ask(
                    user_input,
                    top_k=args.top_k,
                    complexity=args.complexity,
                    beam_width=args.beam_width,
                    prune_ratio=args.prune_ratio,
                    recompute_embeddings=args.recompute_embeddings,
                    pruning_strategy=args.pruning_strategy,
                )
                print(f"LEANN: {response}")
        else:
            query = input("Enter your question: ").strip()
            if query:
                response = chat.ask(
                    query,
                    top_k=args.top_k,
                    complexity=args.complexity,
                    beam_width=args.beam_width,
                    prune_ratio=args.prune_ratio,
                    recompute_embeddings=args.recompute_embeddings,
                    pruning_strategy=args.pruning_strategy,
                )
                print(f"LEANN: {response}")

    async def run(self, args=None):
        parser = self.create_parser()

        if args is None:
            args = parser.parse_args()

        if not args.command:
            parser.print_help()
            return

        if args.command == "list":
            self.list_indexes()
        elif args.command == "build":
            await self.build_index(args)
        elif args.command == "search":
            await self.search_documents(args)
        elif args.command == "ask":
            await self.ask_questions(args)
        else:
            parser.print_help()


def main():
    import dotenv

    dotenv.load_dotenv()

    cli = LeannCLI()
    asyncio.run(cli.run())


if __name__ == "__main__":
    main()
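A quick way to sanity-check the argument wiring without building an index is to drive the parser directly (sketch; assumes the package is installed so that leann.cli is importable):

# Hypothetical smoke test for the CLI argument wiring (no index is touched).
from leann.cli import LeannCLI

parser = LeannCLI().create_parser()
args = parser.parse_args(["search", "my-docs", "what is leann?", "--top-k", "3"])
print(args.command, args.index_name, args.top_k)  # -> search my-docs 3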
375
packages/leann-core/src/leann/embedding_compute.py
Normal file
@@ -0,0 +1,375 @@
"""
Unified embedding computation module
Consolidates all embedding computation logic using SentenceTransformer
Preserves all optimization parameters to ensure performance
"""

import numpy as np
import torch
from typing import List, Dict, Any
import logging
import os

# Set up logger with proper level
logger = logging.getLogger(__name__)
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
logger.setLevel(log_level)

# Global model cache to avoid repeated loading
_model_cache: Dict[str, Any] = {}


def compute_embeddings(
    texts: List[str],
    model_name: str,
    mode: str = "sentence-transformers",
    is_build: bool = False,
    batch_size: int = 32,
    adaptive_optimization: bool = True,
) -> np.ndarray:
    """
    Unified embedding computation entry point

    Args:
        texts: List of texts to compute embeddings for
        model_name: Model name
        mode: Computation mode ('sentence-transformers', 'openai', 'mlx')
        is_build: Whether this is a build operation (shows progress bar)
        batch_size: Batch size for processing
        adaptive_optimization: Whether to use adaptive optimization based on batch size

    Returns:
        Normalized embeddings array, shape: (len(texts), embedding_dim)
    """
    if mode == "sentence-transformers":
        return compute_embeddings_sentence_transformers(
            texts,
            model_name,
            is_build=is_build,
            batch_size=batch_size,
            adaptive_optimization=adaptive_optimization,
        )
    elif mode == "openai":
        return compute_embeddings_openai(texts, model_name)
    elif mode == "mlx":
        return compute_embeddings_mlx(texts, model_name)
    else:
        raise ValueError(f"Unsupported embedding mode: {mode}")

def compute_embeddings_sentence_transformers(
|
||||||
|
texts: List[str],
|
||||||
|
model_name: str,
|
||||||
|
use_fp16: bool = True,
|
||||||
|
device: str = "auto",
|
||||||
|
batch_size: int = 32,
|
||||||
|
is_build: bool = False,
|
||||||
|
adaptive_optimization: bool = True,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Compute embeddings using SentenceTransformer with model caching and adaptive optimization
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: List of texts to compute embeddings for
|
||||||
|
model_name: Model name
|
||||||
|
use_fp16: Whether to use FP16 precision
|
||||||
|
device: Device to use ('auto', 'cuda', 'mps', 'cpu')
|
||||||
|
batch_size: Batch size for processing
|
||||||
|
is_build: Whether this is a build operation (shows progress bar)
|
||||||
|
adaptive_optimization: Whether to use adaptive optimization based on batch size
|
||||||
|
"""
|
||||||
|
# Handle empty input
|
||||||
|
if not texts:
|
||||||
|
raise ValueError("Cannot compute embeddings for empty text list")
|
||||||
|
logger.info(
|
||||||
|
f"Computing embeddings for {len(texts)} texts using SentenceTransformer, model: '{model_name}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Auto-detect device
|
||||||
|
if device == "auto":
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = "cuda"
|
||||||
|
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||||
|
device = "mps"
|
||||||
|
else:
|
||||||
|
device = "cpu"
|
||||||
|
|
||||||
|
# Apply optimizations based on benchmark results
|
||||||
|
if adaptive_optimization:
|
||||||
|
# Use optimal batch_size constants for different devices based on benchmark results
|
||||||
|
if device == "mps":
|
||||||
|
batch_size = 128 # MPS optimal batch size from benchmark
|
||||||
|
if model_name == "Qwen/Qwen3-Embedding-0.6B":
|
||||||
|
batch_size = 32
|
||||||
|
elif device == "cuda":
|
||||||
|
batch_size = 256 # CUDA optimal batch size
|
||||||
|
# Keep original batch_size for CPU
|
||||||
|
|
||||||
|
# Create cache key
|
||||||
|
cache_key = f"sentence_transformers_{model_name}_{device}_{use_fp16}_optimized"
|
||||||
|
|
||||||
|
# Check if model is already cached
|
||||||
|
if cache_key in _model_cache:
|
||||||
|
logger.info(f"Using cached optimized model: {model_name}")
|
||||||
|
model = _model_cache[cache_key]
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
f"Loading and caching optimized SentenceTransformer model: {model_name}"
|
||||||
|
)
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
|
logger.info(f"Using device: {device}")
|
||||||
|
|
||||||
|
# Apply hardware optimizations
|
||||||
|
if device == "cuda":
|
||||||
|
# TODO: Haven't tested this yet
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
|
torch.backends.cudnn.benchmark = True
|
||||||
|
torch.backends.cudnn.deterministic = False
|
||||||
|
torch.cuda.set_per_process_memory_fraction(0.9)
|
||||||
|
elif device == "mps":
|
||||||
|
try:
|
||||||
|
if hasattr(torch.mps, "set_per_process_memory_fraction"):
|
||||||
|
torch.mps.set_per_process_memory_fraction(0.9)
|
||||||
|
except AttributeError:
|
||||||
|
logger.warning(
|
||||||
|
"Some MPS optimizations not available in this PyTorch version"
|
||||||
|
)
|
||||||
|
elif device == "cpu":
|
||||||
|
# TODO: Haven't tested this yet
|
||||||
|
torch.set_num_threads(min(8, os.cpu_count() or 4))
|
||||||
|
try:
|
||||||
|
torch.backends.mkldnn.enabled = True
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Prepare optimized model and tokenizer parameters
|
||||||
|
model_kwargs = {
|
||||||
|
"torch_dtype": torch.float16 if use_fp16 else torch.float32,
|
||||||
|
"low_cpu_mem_usage": True,
|
||||||
|
"_fast_init": True,
|
||||||
|
"attn_implementation": "eager", # Use eager attention for speed
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenizer_kwargs = {
|
||||||
|
"use_fast": True,
|
||||||
|
"padding": True,
|
||||||
|
"truncation": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Try local loading first
|
||||||
|
model_kwargs["local_files_only"] = True
|
||||||
|
tokenizer_kwargs["local_files_only"] = True
|
||||||
|
|
||||||
|
model = SentenceTransformer(
|
||||||
|
model_name,
|
||||||
|
device=device,
|
||||||
|
model_kwargs=model_kwargs,
|
||||||
|
tokenizer_kwargs=tokenizer_kwargs,
|
||||||
|
local_files_only=True,
|
||||||
|
)
|
||||||
|
logger.info("Model loaded successfully! (local + optimized)")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Local loading failed ({e}), trying network download...")
|
||||||
|
# Fallback to network loading
|
||||||
|
model_kwargs["local_files_only"] = False
|
||||||
|
tokenizer_kwargs["local_files_only"] = False
|
||||||
|
|
||||||
|
model = SentenceTransformer(
|
||||||
|
model_name,
|
||||||
|
device=device,
|
||||||
|
model_kwargs=model_kwargs,
|
||||||
|
tokenizer_kwargs=tokenizer_kwargs,
|
||||||
|
local_files_only=False,
|
||||||
|
)
|
||||||
|
logger.info("Model loaded successfully! (network + optimized)")
|
||||||
|
|
||||||
|
# Apply additional optimizations based on mode
|
||||||
|
if use_fp16 and device in ["cuda", "mps"]:
|
||||||
|
try:
|
||||||
|
model = model.half()
|
||||||
|
logger.info(f"Applied FP16 precision: {model_name}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"FP16 optimization failed: {e}")
|
||||||
|
|
||||||
|
# Apply torch.compile optimization
|
||||||
|
if device in ["cuda", "mps"]:
|
||||||
|
try:
|
||||||
|
model = torch.compile(model, mode="reduce-overhead", dynamic=True)
|
||||||
|
logger.info(f"Applied torch.compile optimization: {model_name}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"torch.compile optimization failed: {e}")
|
||||||
|
|
||||||
|
# Set model to eval mode and disable gradients for inference
|
||||||
|
model.eval()
|
||||||
|
for param in model.parameters():
|
||||||
|
param.requires_grad_(False)
|
||||||
|
|
||||||
|
# Cache the model
|
||||||
|
_model_cache[cache_key] = model
|
||||||
|
logger.info(f"Model cached: {cache_key}")
|
||||||
|
|
||||||
|
# Compute embeddings with optimized inference mode
|
||||||
|
logger.info(f"Starting embedding computation... (batch_size: {batch_size})")
|
||||||
|
|
||||||
|
# Use torch.inference_mode for optimal performance
|
||||||
|
with torch.inference_mode():
|
||||||
|
embeddings = model.encode(
|
||||||
|
texts,
|
||||||
|
batch_size=batch_size,
|
||||||
|
show_progress_bar=is_build, # Don't show progress bar in server environment
|
||||||
|
convert_to_numpy=True,
|
||||||
|
normalize_embeddings=False,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate results
|
||||||
|
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Detected NaN or Inf values in embeddings, model: {model_name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
|
def compute_embeddings_openai(texts: List[str], model_name: str) -> np.ndarray:
|
||||||
|
# TODO: @yichuan-w add progress bar only in build mode
|
||||||
|
"""Compute embeddings using OpenAI API"""
|
||||||
|
try:
|
||||||
|
import openai
|
||||||
|
import os
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(f"OpenAI package not installed: {e}")
|
||||||
|
|
||||||
|
api_key = os.getenv("OPENAI_API_KEY")
|
||||||
|
if not api_key:
|
||||||
|
raise RuntimeError("OPENAI_API_KEY environment variable not set")
|
||||||
|
|
||||||
|
# Cache OpenAI client
|
||||||
|
cache_key = "openai_client"
|
||||||
|
if cache_key in _model_cache:
|
||||||
|
client = _model_cache[cache_key]
|
||||||
|
else:
|
||||||
|
client = openai.OpenAI(api_key=api_key)
|
||||||
|
_model_cache[cache_key] = client
|
||||||
|
logger.info("OpenAI client cached")
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Computing embeddings for {len(texts)} texts using OpenAI API, model: '{model_name}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
# OpenAI has limits on batch size and input length
|
||||||
|
max_batch_size = 100 # Conservative batch size
|
||||||
|
all_embeddings = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
total_batches = (len(texts) + max_batch_size - 1) // max_batch_size
|
||||||
|
batch_range = range(0, len(texts), max_batch_size)
|
||||||
|
batch_iterator = tqdm(
|
||||||
|
batch_range, desc="Computing embeddings", unit="batch", total=total_batches
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
# Fallback when tqdm is not available
|
||||||
|
batch_iterator = range(0, len(texts), max_batch_size)
|
||||||
|
|
||||||
|
for i in batch_iterator:
|
||||||
|
batch_texts = texts[i : i + max_batch_size]
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = client.embeddings.create(model=model_name, input=batch_texts)
|
||||||
|
batch_embeddings = [embedding.embedding for embedding in response.data]
|
||||||
|
all_embeddings.extend(batch_embeddings)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Batch {i} failed: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
embeddings = np.array(all_embeddings, dtype=np.float32)
|
||||||
|
logger.info(
|
||||||
|
f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}"
|
||||||
|
)
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
|
def compute_embeddings_mlx(
|
||||||
|
chunks: List[str], model_name: str, batch_size: int = 16
|
||||||
|
) -> np.ndarray:
|
||||||
|
# TODO: @yichuan-w add progress bar only in build mode
|
||||||
|
"""Computes embeddings using an MLX model."""
|
||||||
|
try:
|
||||||
|
import mlx.core as mx
|
||||||
|
from mlx_lm.utils import load
|
||||||
|
except ImportError as e:
|
||||||
|
raise RuntimeError(
|
||||||
|
"MLX or related libraries not available. Install with: uv pip install mlx mlx-lm"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Computing embeddings for {len(chunks)} chunks using MLX model '{model_name}' with batch_size={batch_size}..."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Cache MLX model and tokenizer
|
||||||
|
cache_key = f"mlx_{model_name}"
|
||||||
|
if cache_key in _model_cache:
|
||||||
|
logger.info(f"Using cached MLX model: {model_name}")
|
||||||
|
model, tokenizer = _model_cache[cache_key]
|
||||||
|
else:
|
||||||
|
logger.info(f"Loading and caching MLX model: {model_name}")
|
||||||
|
model, tokenizer = load(model_name)
|
||||||
|
_model_cache[cache_key] = (model, tokenizer)
|
||||||
|
logger.info(f"MLX model cached: {cache_key}")
|
||||||
|
|
||||||
|
# Process chunks in batches with progress bar
|
||||||
|
all_embeddings = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
batch_iterator = tqdm(
|
||||||
|
range(0, len(chunks), batch_size), desc="Computing embeddings", unit="batch"
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
batch_iterator = range(0, len(chunks), batch_size)
|
||||||
|
|
||||||
|
for i in batch_iterator:
|
||||||
|
batch_chunks = chunks[i : i + batch_size]
|
||||||
|
|
||||||
|
# Tokenize all chunks in the batch
|
||||||
|
batch_token_ids = []
|
||||||
|
for chunk in batch_chunks:
|
||||||
|
token_ids = tokenizer.encode(chunk) # type: ignore
|
||||||
|
batch_token_ids.append(token_ids)
|
||||||
|
|
||||||
|
# Pad sequences to the same length for batch processing
|
||||||
|
max_length = max(len(ids) for ids in batch_token_ids)
|
||||||
|
padded_token_ids = []
|
||||||
|
for token_ids in batch_token_ids:
|
||||||
|
# Pad with tokenizer.pad_token_id or 0
|
||||||
|
padded = token_ids + [0] * (max_length - len(token_ids))
|
||||||
|
padded_token_ids.append(padded)
|
||||||
|
|
||||||
|
# Convert to MLX array with batch dimension
|
||||||
|
input_ids = mx.array(padded_token_ids)
|
||||||
|
|
||||||
|
# Get embeddings for the batch
|
||||||
|
embeddings = model(input_ids)
|
||||||
|
|
||||||
|
# Mean pooling for each sequence in the batch
|
||||||
|
pooled = embeddings.mean(axis=1) # Shape: (batch_size, hidden_size)
|
||||||
|
|
||||||
|
# Convert batch embeddings to numpy
|
||||||
|
for j in range(len(batch_chunks)):
|
||||||
|
pooled_list = pooled[j].tolist() # Convert to list
|
||||||
|
pooled_numpy = np.array(pooled_list, dtype=np.float32)
|
||||||
|
all_embeddings.append(pooled_numpy)
|
||||||
|
|
||||||
|
# Stack numpy arrays
|
||||||
|
return np.stack(all_embeddings)
|
||||||
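Editor's note: all three helpers above share the module-level `_model_cache`, so a long-lived process loads each model only once. A hedged sketch of calling the unified dispatcher that the searcher code later imports (the import path and model name are assumptions, not part of this diff):

```python
from leann.embedding_compute import compute_embeddings  # import path assumed

texts = ["LEANN recomputes embeddings at query time.", "The graph is stored, the vectors are not."]

# First call loads and caches the model; the second call reuses the cached instance.
emb = compute_embeddings(texts, "sentence-transformers/all-MiniLM-L6-v2", "sentence-transformers")
emb = compute_embeddings(texts, "sentence-transformers/all-MiniLM-L6-v2", "sentence-transformers")
print(emb.shape)  # (2, D) numpy array
```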
@@ -1,14 +1,21 @@
-import threading
 import time
 import atexit
 import socket
 import subprocess
 import sys
-import zmq
-import msgpack
+import os
+import logging
 from pathlib import Path
 from typing import Optional
-import select
+import psutil
+
+# Set up logging based on environment variable
+LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
+logging.basicConfig(
+    level=getattr(logging, LOG_LEVEL, logging.INFO),
+    format="%(levelname)s - %(name)s - %(message)s",
+)
+logger = logging.getLogger(__name__)


 def _check_port(port: int) -> bool:
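Editor's note: the new import block wires logging to the `LEANN_LOG_LEVEL` environment variable instead of unconditional prints. A hedged snippet showing how a caller opts into verbose output (the variable name comes from the diff; the rest is illustrative):

```python
import os

# Set this before importing the manager module: logging.basicConfig() above runs at
# import time, so changing the variable afterwards has no effect on the level.
os.environ["LEANN_LOG_LEVEL"] = "DEBUG"
```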
@@ -17,151 +24,135 @@ def _check_port(port: int) -> bool:
         return s.connect_ex(("localhost", port)) == 0


-def _check_server_meta_path(port: int, expected_meta_path: str) -> bool:
+def _check_process_matches_config(
+    port: int, expected_model: str, expected_passages_file: str
+) -> bool:
     """
-    Check if the existing server on the port is using the correct meta file.
-    Returns True if the server has the right meta path, False otherwise.
+    Check if the process using the port matches our expected model and passages file.
+    Returns True if matches, False otherwise.
     """
     try:
-        context = zmq.Context()
-        socket = context.socket(zmq.REQ)
-        socket.setsockopt(zmq.RCVTIMEO, 3000)  # 3 second timeout
-        socket.connect(f"tcp://localhost:{port}")
-
-        # Send a special control message to query the server's meta path
-        control_request = ["__QUERY_META_PATH__"]
-        request_bytes = msgpack.packb(control_request)
-        socket.send(request_bytes)
-
-        # Wait for response
-        response_bytes = socket.recv()
-        response = msgpack.unpackb(response_bytes)
-
-        socket.close()
-        context.term()
-
-        # Check if the response contains the meta path and if it matches
-        if isinstance(response, list) and len(response) > 0:
-            server_meta_path = response[0]
-            # Normalize paths for comparison
-            expected_path = Path(expected_meta_path).resolve()
-            server_path = Path(server_meta_path).resolve() if server_meta_path else None
-            return server_path == expected_path
-
+        for proc in psutil.process_iter(["pid", "cmdline"]):
+            if not _is_process_listening_on_port(proc, port):
+                continue
+
+            cmdline = proc.info["cmdline"]
+            if not cmdline:
+                continue
+
+            return _check_cmdline_matches_config(
+                cmdline, port, expected_model, expected_passages_file
+            )
+
+        logger.debug(f"No process found listening on port {port}")
         return False

     except Exception as e:
-        print(f"WARNING: Could not query server meta path on port {port}: {e}")
+        logger.warning(f"Could not check process on port {port}: {e}")
         return False


-def _update_server_meta_path(port: int, new_meta_path: str) -> bool:
-    """
-    Send a control message to update the server's meta path.
-    Returns True if successful, False otherwise.
-    """
+def _is_process_listening_on_port(proc, port: int) -> bool:
+    """Check if a process is listening on the given port."""
     try:
-        context = zmq.Context()
-        socket = context.socket(zmq.REQ)
-        socket.setsockopt(zmq.RCVTIMEO, 5000)  # 5 second timeout
-        socket.connect(f"tcp://localhost:{port}")
-
-        # Send a control message to update the meta path
-        control_request = ["__UPDATE_META_PATH__", new_meta_path]
-        request_bytes = msgpack.packb(control_request)
-        socket.send(request_bytes)
-
-        # Wait for response
-        response_bytes = socket.recv()
-        response = msgpack.unpackb(response_bytes)
-
-        socket.close()
-        context.term()
-
-        # Check if the update was successful
-        if isinstance(response, list) and len(response) > 0:
-            return response[0] == "SUCCESS"
-
+        connections = proc.net_connections()
+        for conn in connections:
+            if conn.laddr.port == port and conn.status == psutil.CONN_LISTEN:
+                return True
         return False
-
-    except Exception as e:
-        print(f"ERROR: Could not update server meta path on port {port}: {e}")
+    except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
         return False


-def _check_server_model(port: int, expected_model: str) -> bool:
-    """
-    Check if the existing server on the port is using the correct embedding model.
-    Returns True if the server has the right model, False otherwise.
-    """
-    try:
-        context = zmq.Context()
-        socket = context.socket(zmq.REQ)
-        socket.setsockopt(zmq.RCVTIMEO, 3000)  # 3 second timeout
-        socket.connect(f"tcp://localhost:{port}")
-
-        # Send a special control message to query the server's model
-        control_request = ["__QUERY_MODEL__"]
-        request_bytes = msgpack.packb(control_request)
-        socket.send(request_bytes)
-
-        # Wait for response
-        response_bytes = socket.recv()
-        response = msgpack.unpackb(response_bytes)
-
-        socket.close()
-        context.term()
-
-        # Check if the response contains the model name and if it matches
-        if isinstance(response, list) and len(response) > 0:
-            server_model = response[0]
-            return server_model == expected_model
-
-        return False
-
-    except Exception as e:
-        print(f"WARNING: Could not query server model on port {port}: {e}")
-        return False
-
-
-def _update_server_model(port: int, new_model: str) -> bool:
-    """
-    Send a control message to update the server's embedding model.
-    Returns True if successful, False otherwise.
-    """
-    try:
-        context = zmq.Context()
-        socket = context.socket(zmq.REQ)
-        socket.setsockopt(zmq.RCVTIMEO, 30000)  # 30 second timeout for model loading
-        socket.setsockopt(zmq.SNDTIMEO, 5000)  # 5 second timeout for sending
-        socket.connect(f"tcp://localhost:{port}")
-
-        # Send a control message to update the model
-        control_request = ["__UPDATE_MODEL__", new_model]
-        request_bytes = msgpack.packb(control_request)
-        socket.send(request_bytes)
-
-        # Wait for response
-        response_bytes = socket.recv()
-        response = msgpack.unpackb(response_bytes)
-
-        socket.close()
-        context.term()
-
-        # Check if the update was successful
-        if isinstance(response, list) and len(response) > 0:
-            return response[0] == "SUCCESS"
-
-        return False
-
-    except Exception as e:
-        print(f"ERROR: Could not update server model on port {port}: {e}")
-        return False
+def _check_cmdline_matches_config(
+    cmdline: list, port: int, expected_model: str, expected_passages_file: str
+) -> bool:
+    """Check if command line matches our expected configuration."""
+    cmdline_str = " ".join(cmdline)
+    logger.debug(f"Found process on port {port}: {cmdline_str}")
+
+    # Check if it's our embedding server
+    is_embedding_server = any(
+        server_type in cmdline_str
+        for server_type in [
+            "embedding_server",
+            "leann_backend_diskann.embedding_server",
+            "leann_backend_hnsw.hnsw_embedding_server",
+        ]
+    )
+
+    if not is_embedding_server:
+        logger.debug(f"Process on port {port} is not our embedding server")
+        return False
+
+    # Check model name
+    model_matches = _check_model_in_cmdline(cmdline, expected_model)
+
+    # Check passages file if provided
+    passages_matches = _check_passages_in_cmdline(cmdline, expected_passages_file)
+
+    result = model_matches and passages_matches
+    logger.debug(
+        f"model_matches: {model_matches}, passages_matches: {passages_matches}, overall: {result}"
+    )
+    return result
+
+
+def _check_model_in_cmdline(cmdline: list, expected_model: str) -> bool:
+    """Check if the command line contains the expected model."""
+    if "--model-name" not in cmdline:
+        return False
+
+    model_idx = cmdline.index("--model-name")
+    if model_idx + 1 >= len(cmdline):
+        return False
+
+    actual_model = cmdline[model_idx + 1]
+    return actual_model == expected_model
+
+
+def _check_passages_in_cmdline(cmdline: list, expected_passages_file: str) -> bool:
+    """Check if the command line contains the expected passages file."""
+    if "--passages-file" not in cmdline:
+        return False  # Expected but not found
+
+    passages_idx = cmdline.index("--passages-file")
+    if passages_idx + 1 >= len(cmdline):
+        return False
+
+    actual_passages = cmdline[passages_idx + 1]
+    expected_path = Path(expected_passages_file).resolve()
+    actual_path = Path(actual_passages).resolve()
+    return actual_path == expected_path
+
+
+def _find_compatible_port_or_next_available(
+    start_port: int, model_name: str, passages_file: str, max_attempts: int = 100
+) -> tuple[int, bool]:
+    """
+    Find a port that either has a compatible server or is available.
+    Returns (port, is_compatible) where is_compatible indicates if we found a matching server.
+    """
+    for port in range(start_port, start_port + max_attempts):
+        if not _check_port(port):
+            # Port is available
+            return port, False
+
+        # Port is in use, check if it's compatible
+        if _check_process_matches_config(port, model_name, passages_file):
+            logger.info(f"Found compatible server on port {port}")
+            return port, True
+        else:
+            logger.info(f"Port {port} has incompatible server, trying next port...")
+
+    raise RuntimeError(
+        f"Could not find compatible or available port in range {start_port}-{start_port + max_attempts}"
+    )


 class EmbeddingServerManager:
     """
-    A generic manager for handling the lifecycle of a backend-specific embedding server process.
+    A simplified manager for embedding server processes that avoids complex update mechanisms.
     """

     def __init__(self, backend_module_name: str):
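Editor's note: the rewritten helpers identify a compatible server by inspecting the listening process's command line with psutil instead of querying it over ZMQ. A standalone, hedged sketch of that detection idea (function name and return shape are illustrative):

```python
from typing import List, Optional

import psutil


def find_listener_cmdline(port: int) -> Optional[List[str]]:
    """Return the command line of the process listening on `port`, if any."""
    for proc in psutil.process_iter(["pid", "cmdline"]):
        try:
            for conn in proc.net_connections():
                if conn.laddr and conn.laddr.port == port and conn.status == psutil.CONN_LISTEN:
                    return proc.info["cmdline"] or []
        except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
            continue
    return None
```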
@@ -175,246 +166,183 @@ class EmbeddingServerManager:
         self.backend_module_name = backend_module_name
         self.server_process: Optional[subprocess.Popen] = None
         self.server_port: Optional[int] = None
-        atexit.register(self.stop_server)
+        self._atexit_registered = False

-    def start_server(self, port: int, model_name: str, embedding_mode: str = "sentence-transformers", **kwargs) -> bool:
+    def start_server(
+        self,
+        port: int,
+        model_name: str,
+        embedding_mode: str = "sentence-transformers",
+        **kwargs,
+    ) -> tuple[bool, int]:
         """
         Starts the embedding server process.

         Args:
-            port (int): The ZMQ port for the server.
+            port (int): The preferred ZMQ port for the server.
             model_name (str): The name of the embedding model to use.
-            **kwargs: Additional arguments for the server (e.g., passages_file, distance_metric, enable_warmup).
+            **kwargs: Additional arguments for the server.

         Returns:
-            bool: True if the server is started successfully or already running, False otherwise.
+            tuple[bool, int]: (success, actual_port_used)
         """
-        if self.server_process and self.server_process.poll() is None:
-            # Even if we have a running process, check if model/meta path match
-            if self.server_port is not None:
-                port_in_use = _check_port(self.server_port)
-                if port_in_use:
-                    print(
-                        f"INFO: Checking compatibility of existing server process (PID {self.server_process.pid})"
-                    )
-
-                    # Check model compatibility
-                    model_matches = _check_server_model(self.server_port, model_name)
-                    if model_matches:
-                        print(
-                            f"✅ Existing server already using correct model: {model_name}"
-                        )
-
-                        # Still check meta path if provided
-                        passages_file = kwargs.get("passages_file")
-                        if passages_file and str(passages_file).endswith(
-                            ".meta.json"
-                        ):
-                            meta_matches = _check_server_meta_path(
-                                self.server_port, str(passages_file)
-                            )
-                            if not meta_matches:
-                                print("⚠️ Updating meta path to: {passages_file}")
-                                _update_server_meta_path(
-                                    self.server_port, str(passages_file)
-                                )
-
-                        return True
-                    else:
-                        print(
-                            f"⚠️ Existing server has different model. Attempting to update to: {model_name}"
-                        )
-                        if not _update_server_model(self.server_port, model_name):
-                            print(
-                                "❌ Failed to update existing server model. Restarting server..."
-                            )
-                            self.stop_server()
-                            # Continue to start new server below
-                        else:
-                            print(
-                                f"✅ Successfully updated existing server model to: {model_name}"
-                            )
-
-                            # Also check meta path if provided
-                            passages_file = kwargs.get("passages_file")
-                            if passages_file and str(passages_file).endswith(
-                                ".meta.json"
-                            ):
-                                meta_matches = _check_server_meta_path(
-                                    self.server_port, str(passages_file)
-                                )
-                                if not meta_matches:
-                                    print("⚠️ Updating meta path to: {passages_file}")
-                                    _update_server_meta_path(
-                                        self.server_port, str(passages_file)
-                                    )
-
-                            return True
-                else:
-                    # Server process exists but port not responding - restart
-                    print("⚠️ Server process exists but not responding. Restarting...")
-                    self.stop_server()
-                    # Continue to start new server below
-            else:
-                # No port stored - restart
-                print("⚠️ No port information stored. Restarting server...")
-                self.stop_server()
-                # Continue to start new server below
-
-        if _check_port(port):
-            # Port is in use, check if it's using the correct meta file and model
-            passages_file = kwargs.get("passages_file")
-
-            print(f"INFO: Port {port} is in use. Checking server compatibility...")
-
-            # Check model compatibility first
-            model_matches = _check_server_model(port, model_name)
-            if model_matches:
-                print(
-                    f"✅ Existing server on port {port} is using correct model: {model_name}"
-                )
-            else:
-                print(
-                    f"⚠️ Existing server on port {port} has different model. Attempting to update to: {model_name}"
-                )
-                if not _update_server_model(port, model_name):
-                    raise RuntimeError(
-                        f"❌ Failed to update server model to {model_name}. Consider using a different port."
-                    )
-                print(f"✅ Successfully updated server model to: {model_name}")
-
-            # Check meta path compatibility if provided
-            if passages_file and str(passages_file).endswith(".meta.json"):
-                meta_matches = _check_server_meta_path(port, str(passages_file))
-                if not meta_matches:
-                    print(
-                        f"⚠️ Existing server on port {port} has different meta path. Attempting to update..."
-                    )
-                    if not _update_server_meta_path(port, str(passages_file)):
-                        raise RuntimeError(
-                            "❌ Failed to update server meta path. This may cause data synchronization issues."
-                        )
-                    print(
-                        f"✅ Successfully updated server meta path to: {passages_file}"
-                    )
-                else:
-                    print(
-                        f"✅ Existing server on port {port} is using correct meta path: {passages_file}"
-                    )
-
-            print(f"✅ Server on port {port} is compatible and ready to use.")
-            return True
-
-        print(
-            f"INFO: Starting session-level embedding server for '{self.backend_module_name}'..."
-        )
+        passages_file = kwargs.get("passages_file")
+        assert isinstance(passages_file, str), "passages_file must be a string"
+
+        # Check if we have a compatible running server
+        if self._has_compatible_running_server(model_name, passages_file):
+            assert self.server_port is not None, (
+                "a compatible running server should set server_port"
+            )
+            return True, self.server_port
+
+        # Find available port (compatible or free)
+        try:
+            actual_port, is_compatible = _find_compatible_port_or_next_available(
+                port, model_name, passages_file
+            )
+        except RuntimeError as e:
+            logger.error(str(e))
+            return False, port
+
+        if is_compatible:
+            logger.info(f"Using existing compatible server on port {actual_port}")
+            self.server_port = actual_port
+            self.server_process = None  # We don't own this process
+            return True, actual_port
+
+        if actual_port != port:
+            logger.info(f"Using port {actual_port} instead of {port}")
+
+        # Start new server
+        return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs)
+
+    def _has_compatible_running_server(
+        self, model_name: str, passages_file: str
+    ) -> bool:
+        """Check if we have a compatible running server."""
+        if not (
+            self.server_process
+            and self.server_process.poll() is None
+            and self.server_port
+        ):
+            return False
+
+        if _check_process_matches_config(self.server_port, model_name, passages_file):
+            logger.info(
+                f"Existing server process (PID {self.server_process.pid}) is compatible"
+            )
+            return True
+
+        logger.info(
+            "Existing server process is incompatible. Should start a new server."
+        )
+        return False
+
+    def _start_new_server(
+        self, port: int, model_name: str, embedding_mode: str, **kwargs
+    ) -> tuple[bool, int]:
+        """Start a new embedding server on the given port."""
+        logger.info(f"Starting embedding server on port {port}...")
+
+        command = self._build_server_command(port, model_name, embedding_mode, **kwargs)

         try:
-            command = [
-                sys.executable,
-                "-m",
-                self.backend_module_name,
-                "--zmq-port",
-                str(port),
-                "--model-name",
-                model_name,
-            ]
-
-            # Add extra arguments for specific backends
-            if "passages_file" in kwargs and kwargs["passages_file"]:
-                command.extend(["--passages-file", str(kwargs["passages_file"])])
-            # if "distance_metric" in kwargs and kwargs["distance_metric"]:
-            #     command.extend(["--distance-metric", kwargs["distance_metric"]])
-            if embedding_mode != "sentence-transformers":
-                command.extend(["--embedding-mode", embedding_mode])
-            if "enable_warmup" in kwargs and not kwargs["enable_warmup"]:
-                command.extend(["--disable-warmup"])
-
-            project_root = Path(__file__).parent.parent.parent.parent.parent
-            print(f"INFO: Running command from project root: {project_root}")
-            print(f"INFO: Command: {' '.join(command)}")  # Debug: show actual command
-
-            self.server_process = subprocess.Popen(
-                command,
-                cwd=project_root,
-                stdout=subprocess.PIPE,
-                stderr=subprocess.STDOUT,  # Merge stderr into stdout for easier monitoring
-                text=True,
-                encoding="utf-8",
-                bufsize=1,  # Line buffered
-                universal_newlines=True,
-            )
-            self.server_port = port
-            print(f"INFO: Server process started with PID: {self.server_process.pid}")
-
-            max_wait, wait_interval = 120, 0.5
-            for _ in range(int(max_wait / wait_interval)):
-                if _check_port(port):
-                    print("✅ Embedding server is up and ready for this session.")
-                    log_thread = threading.Thread(target=self._log_monitor, daemon=True)
-                    log_thread.start()
-                    return True
-                if self.server_process.poll() is not None:
-                    print(
-                        "❌ ERROR: Server process terminated unexpectedly during startup."
-                    )
-                    self._print_recent_output()
-                    return False
-                time.sleep(wait_interval)
-
-            print(
-                f"❌ ERROR: Server process failed to start listening within {max_wait} seconds."
-            )
-            self.stop_server()
-            return False
-
+            self._launch_server_process(command, port)
+            return self._wait_for_server_ready(port)
         except Exception as e:
-            print(f"❌ ERROR: Failed to start embedding server process: {e}")
-            return False
+            logger.error(f"Failed to start embedding server: {e}")
+            return False, port

-    def _print_recent_output(self):
-        """Print any recent output from the server process."""
-        if not self.server_process or not self.server_process.stdout:
-            return
-        try:
-            # Read any available output
-            if select.select([self.server_process.stdout], [], [], 0)[0]:
-                output = self.server_process.stdout.read()
-                if output:
-                    print(f"[{self.backend_module_name} OUTPUT]: {output}")
-        except Exception as e:
-            print(f"Error reading server output: {e}")
-
-    def _log_monitor(self):
-        """Monitors and prints the server's stdout and stderr."""
-        if not self.server_process:
-            return
-        try:
-            if self.server_process.stdout:
-                while True:
-                    line = self.server_process.stdout.readline()
-                    if not line:
-                        break
-                    print(
-                        f"[{self.backend_module_name} LOG]: {line.strip()}", flush=True
-                    )
-        except Exception as e:
-            print(f"Log monitor error: {e}")
+    def _build_server_command(
+        self, port: int, model_name: str, embedding_mode: str, **kwargs
+    ) -> list:
+        """Build the command to start the embedding server."""
+        command = [
+            sys.executable,
+            "-m",
+            self.backend_module_name,
+            "--zmq-port",
+            str(port),
+            "--model-name",
+            model_name,
+        ]
+
+        if kwargs.get("passages_file"):
+            command.extend(["--passages-file", str(kwargs["passages_file"])])
+        if embedding_mode != "sentence-transformers":
+            command.extend(["--embedding-mode", embedding_mode])
+
+        return command
+
+    def _launch_server_process(self, command: list, port: int) -> None:
+        """Launch the server process."""
+        project_root = Path(__file__).parent.parent.parent.parent.parent
+        logger.info(f"Command: {' '.join(command)}")
+
+        # Let server output go directly to console
+        # The server will respect LEANN_LOG_LEVEL environment variable
+        self.server_process = subprocess.Popen(
+            command,
+            cwd=project_root,
+            stdout=None,  # Direct to console
+            stderr=None,  # Direct to console
+        )
+        self.server_port = port
+        logger.info(f"Server process started with PID: {self.server_process.pid}")
+
+        # Register atexit callback only when we actually start a process
+        if not self._atexit_registered:
+            # Use a lambda to avoid issues with bound methods
+            atexit.register(lambda: self.stop_server() if self.server_process else None)
+            self._atexit_registered = True
+
+    def _wait_for_server_ready(self, port: int) -> tuple[bool, int]:
+        """Wait for the server to be ready."""
+        max_wait, wait_interval = 120, 0.5
+        for _ in range(int(max_wait / wait_interval)):
+            if _check_port(port):
+                logger.info("Embedding server is ready!")
+                return True, port
+
+            if self.server_process and self.server_process.poll() is not None:
+                logger.error("Server terminated during startup.")
+                return False, port
+
+            time.sleep(wait_interval)
+
+        logger.error(f"Server failed to start within {max_wait} seconds.")
+        self.stop_server()
+        return False, port

     def stop_server(self):
         """Stops the embedding server process if it's running."""
-        if self.server_process and self.server_process.poll() is None:
-            print(
-                f"INFO: Terminating session server process (PID: {self.server_process.pid})..."
-            )
-            self.server_process.terminate()
-            try:
-                self.server_process.wait(timeout=5)
-                print("INFO: Server process terminated.")
-            except subprocess.TimeoutExpired:
-                print(
-                    "WARNING: Server process did not terminate gracefully, killing it."
-                )
-                self.server_process.kill()
-            self.server_process = None
+        if not self.server_process:
+            return
+
+        if self.server_process.poll() is not None:
+            # Process already terminated
+            self.server_process = None
+            return
+
+        logger.info(
+            f"Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}..."
+        )
+        self.server_process.terminate()
+
+        try:
+            self.server_process.wait(timeout=5)
+            logger.info(f"Server process {self.server_process.pid} terminated.")
+        except subprocess.TimeoutExpired:
+            logger.warning(
+                f"Server process {self.server_process.pid} did not terminate gracefully, killing it."
+            )
+            self.server_process.kill()
+
+        # Clean up process resources to prevent resource tracker warnings
+        try:
+            self.server_process.wait()  # Ensure process is fully cleaned up
+        except Exception:
+            pass
+
+        self.server_process = None
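Editor's note: under the rewritten manager, `start_server` now returns `(success, actual_port)` and may quietly reuse a compatible server it does not own. A hedged usage sketch (the import path, model, and file names are illustrative; only the call shape comes from the diff):

```python
from leann_core.embedding_server_manager import EmbeddingServerManager  # import path assumed

manager = EmbeddingServerManager("leann_backend_hnsw.hnsw_embedding_server")
ok, port = manager.start_server(
    port=5557,
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    passages_file="my_index.leann.meta.json",  # must be a str per the new assert
)
if ok:
    print(f"embedding server reachable on port {port}")
manager.stop_server()  # returns early if the manager does not own the process
```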
@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
 import numpy as np
-from typing import Dict, Any, List, Literal
+from typing import Dict, Any, List, Literal, Optional


 class LeannBackendBuilderInterface(ABC):
@@ -34,6 +34,13 @@ class LeannBackendSearcherInterface(ABC):
         """
         pass

+    @abstractmethod
+    def _ensure_server_running(
+        self, passages_source_file: str, port: Optional[int], **kwargs
+    ) -> int:
+        """Ensure server is running"""
+        pass
+
     @abstractmethod
     def search(
         self,
@@ -44,7 +51,7 @@ class LeannBackendSearcherInterface(ABC):
         prune_ratio: float = 0.0,
         recompute_embeddings: bool = False,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
-        zmq_port: int = 5557,
+        zmq_port: Optional[int] = None,
         **kwargs,
     ) -> Dict[str, Any]:
         """Search for nearest neighbors
@@ -57,7 +64,7 @@ class LeannBackendSearcherInterface(ABC):
             prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
             recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored PQ codes
             pruning_strategy: PQ candidate selection strategy - "global" (default), "local", or "proportional"
-            zmq_port: ZMQ port for embedding server communication
+            zmq_port: ZMQ port for embedding server communication. Must be provided if recompute_embeddings is True.
             **kwargs: Backend-specific parameters

         Returns:
@@ -67,7 +74,10 @@ class LeannBackendSearcherInterface(ABC):

     @abstractmethod
     def compute_query_embedding(
-        self, query: str, zmq_port: int = 5557, use_server_if_available: bool = True
+        self,
+        query: str,
+        use_server_if_available: bool = True,
+        zmq_port: Optional[int] = None,
     ) -> np.ndarray:
         """Compute embedding for a query string
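Editor's note: the interface now makes `zmq_port` optional and adds an abstract `_ensure_server_running` that returns the port actually in use. A hedged sketch of the intended call pattern (the `searcher` object and file name are placeholders, not part of this diff):

```python
# Contract sketch only; `searcher` stands for any concrete backend searcher.
def embed_with_recompute(searcher, query: str):
    # Resolve a live server first (port=None lets the searcher choose or reuse one),
    # then hand the returned port to the call that needs recomputation.
    port = searcher._ensure_server_running("my_index.leann.meta.json", None)
    return searcher.compute_query_embedding(query, use_server_if_available=True, zmq_port=port)
```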
@@ -7,30 +7,37 @@ import importlib.metadata
 if TYPE_CHECKING:
     from leann.interface import LeannBackendFactoryInterface

-BACKEND_REGISTRY: Dict[str, 'LeannBackendFactoryInterface'] = {}
+BACKEND_REGISTRY: Dict[str, "LeannBackendFactoryInterface"] = {}


 def register_backend(name: str):
     """A decorator to register a new backend class."""

     def decorator(cls):
         print(f"INFO: Registering backend '{name}'")
         BACKEND_REGISTRY[name] = cls
         return cls

     return decorator


 def autodiscover_backends():
     """Automatically discovers and imports all 'leann-backend-*' packages."""
-    print("INFO: Starting backend auto-discovery...")
+    # print("INFO: Starting backend auto-discovery...")
     discovered_backends = []
     for dist in importlib.metadata.distributions():
-        dist_name = dist.metadata['name']
-        if dist_name.startswith('leann-backend-'):
-            backend_module_name = dist_name.replace('-', '_')
+        dist_name = dist.metadata["name"]
+        if dist_name.startswith("leann-backend-"):
+            backend_module_name = dist_name.replace("-", "_")
             discovered_backends.append(backend_module_name)

-    for backend_module_name in sorted(discovered_backends):  # sort for deterministic loading
+    for backend_module_name in sorted(
+        discovered_backends
+    ):  # sort for deterministic loading
         try:
             importlib.import_module(backend_module_name)
             # Registration message is printed by the decorator
         except ImportError as e:
-            print(f"WARN: Could not import backend module '{backend_module_name}': {e}")
-    print("INFO: Backend auto-discovery finished.")
+            # print(f"WARN: Could not import backend module '{backend_module_name}': {e}")
+            pass
+    # print("INFO: Backend auto-discovery finished.")
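Editor's note: the registry itself is unchanged apart from quieter output; for context, a hedged sketch of how a backend package is expected to use the decorator (the module path and the stub class are placeholders):

```python
from leann.registry import register_backend  # module path assumed


@register_backend("hnsw")
class HNSWBackendFactory:
    """Placeholder factory; a real one implements LeannBackendFactoryInterface."""
    pass
```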
@@ -1,8 +1,7 @@
 import json
-import pickle
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Dict, Any, Literal
+from typing import Dict, Any, Literal, Optional

 import numpy as np

@@ -43,10 +42,10 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
                 "WARNING: embedding_model not found in meta.json. Recompute will fail."
             )

-        self.label_map = self._load_label_map()
+        self.embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")

         self.embedding_server_manager = EmbeddingServerManager(
-            backend_module_name=backend_module_name
+            backend_module_name=backend_module_name,
         )

     def _load_meta(self) -> Dict[str, Any]:
@@ -58,17 +57,9 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
         with open(meta_path, "r", encoding="utf-8") as f:
             return json.load(f)

-    def _load_label_map(self) -> Dict[int, str]:
-        """Loads the mapping from integer IDs to string IDs."""
-        label_map_file = self.index_dir / "leann.labels.map"
-        if not label_map_file.exists():
-            raise FileNotFoundError(f"Label map file not found: {label_map_file}")
-        with open(label_map_file, "rb") as f:
-            return pickle.load(f)
-
     def _ensure_server_running(
         self, passages_source_file: str, port: int, **kwargs
-    ) -> None:
+    ) -> int:
         """
         Ensures the embedding server is running if recompute is needed.
         This is a helper for subclasses.
@@ -78,21 +69,26 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
                 "Cannot use recompute mode without 'embedding_model' in meta.json."
             )

-        embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
-
-        server_started = self.embedding_server_manager.start_server(
+        server_started, actual_port = self.embedding_server_manager.start_server(
             port=port,
             model_name=self.embedding_model,
+            embedding_mode=self.embedding_mode,
             passages_file=passages_source_file,
             distance_metric=kwargs.get("distance_metric"),
-            embedding_mode=embedding_mode,
             enable_warmup=kwargs.get("enable_warmup", False),
         )
         if not server_started:
-            raise RuntimeError(f"Failed to start embedding server on port {port}")
+            raise RuntimeError(
+                f"Failed to start embedding server on port {actual_port}"
+            )
+
+        return actual_port

     def compute_query_embedding(
-        self, query: str, zmq_port: int = 5557, use_server_if_available: bool = True
+        self,
+        query: str,
+        use_server_if_available: bool = True,
+        zmq_port: int = 5557,
     ) -> np.ndarray:
         """
         Compute embedding for a query string.
@@ -106,12 +102,20 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
             Query embedding as numpy array
         """
         # Try to use embedding server if available and requested
-        if (
-            use_server_if_available
-            and self.embedding_server_manager
-            and self.embedding_server_manager.server_process
-        ):
+        if use_server_if_available:
             try:
+                # TODO: Maybe we can directly use this port here?
+                # For this internal method, it's ok to assume that the server is running
+                # on that port?
+
+                # Ensure we have a server with passages_file for compatibility
+                passages_source_file = (
+                    self.index_dir / f"{self.index_path.name}.meta.json"
+                )
+                zmq_port = self._ensure_server_running(
+                    str(passages_source_file), zmq_port
+                )
+
                 return self._compute_embedding_via_server([query], zmq_port)[
                     0:1
                 ]  # Return (1, D) shape
@@ -120,7 +124,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
                 print("⏭️ Falling back to direct model loading...")

         # Fallback to direct computation
-        from .api import compute_embeddings
+        from .embedding_compute import compute_embeddings

         embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
         return compute_embeddings([query], self.embedding_model, embedding_mode)
@@ -167,7 +171,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
         prune_ratio: float = 0.0,
         recompute_embeddings: bool = False,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
-        zmq_port: int = 5557,
+        zmq_port: Optional[int] = None,
         **kwargs,
     ) -> Dict[str, Any]:
         """
@@ -181,7 +185,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
             prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
             recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored PQ codes
             pruning_strategy: PQ candidate selection strategy - "global" (default), "local", or "proportional"
-            zmq_port: ZMQ port for embedding server communication
+            zmq_port: ZMQ port for embedding server communication. Must be provided if recompute_embeddings is True.
             **kwargs: Backend-specific parameters (e.g., batch_size, dedup_node_dis, etc.)

         Returns:
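Editor's note: the net effect of the searcher changes above is a two-step strategy: try the embedding server (starting or reusing one via `_ensure_server_running`), and fall back to loading the model in-process. A condensed, hedged restatement of that flow (not a drop-in replacement):

```python
def compute_query_embedding_sketch(searcher, query: str, zmq_port=None):
    try:
        # Server path: resolve the live port, then ask the server for the embedding.
        meta_file = searcher.index_dir / f"{searcher.index_path.name}.meta.json"
        port = searcher._ensure_server_running(str(meta_file), zmq_port)
        return searcher._compute_embedding_via_server([query], port)[0:1]
    except Exception:
        # Direct, in-process fallback path.
        from leann.embedding_compute import compute_embeddings  # import path assumed
        mode = searcher.meta.get("embedding_mode", "sentence-transformers")
        return compute_embeddings([query], searcher.embedding_model, mode)
```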
40
packages/leann/README.md
Normal file
@@ -0,0 +1,40 @@
# LEANN - The smallest vector index in the world

LEANN is a revolutionary vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **97% less storage** than traditional solutions **without accuracy loss**.

## Installation

```bash
# Default installation (HNSW backend, recommended)
uv pip install leann

# With DiskANN backend (for large-scale deployments)
uv pip install leann[diskann]
```

## Quick Start

```python
from leann import LeannBuilder, LeannSearcher, LeannChat

# Build an index
builder = LeannBuilder(backend_name="hnsw")
builder.add_text("LEANN saves 97% storage compared to traditional vector databases.")
builder.build_index("my_index.leann")

# Search
searcher = LeannSearcher("my_index.leann")
results = searcher.search("storage savings", top_k=3)

# Chat with your data
chat = LeannChat("my_index.leann", llm_config={"type": "ollama", "model": "llama3.2:1b"})
response = chat.ask("How much storage does LEANN save?")
```

## Documentation

For full documentation, visit [https://leann.readthedocs.io](https://leann.readthedocs.io)

## License

MIT License
12
packages/leann/__init__.py
Normal file
@@ -0,0 +1,12 @@
"""
LEANN - Low-storage Embedding Approximation for Neural Networks

A revolutionary vector database that democratizes personal AI.
"""

__version__ = "0.1.0"

# Re-export main API from leann-core
from leann_core import LeannBuilder, LeannSearcher, LeannChat

__all__ = ["LeannBuilder", "LeannSearcher", "LeannChat"]
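Editor's note: a quick, hedged smoke test of the new meta-package shim; it assumes `leann-core` is installed so the re-exports resolve:

```python
import leann

print(leann.__version__)  # "0.1.0" in the shim above
print(leann.__all__)      # ['LeannBuilder', 'LeannSearcher', 'LeannChat']

from leann import LeannBuilder
builder = LeannBuilder(backend_name="hnsw")  # same constructor as in the README example
```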
42
packages/leann/pyproject.toml
Normal file
@@ -0,0 +1,42 @@
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[project]
name = "leann"
version = "0.1.2"
description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!"
readme = "README.md"
requires-python = ">=3.9"
license = { text = "MIT" }
authors = [
    { name = "LEANN Team" }
]
keywords = ["vector-database", "rag", "embeddings", "search", "ai"]
classifiers = [
    "Development Status :: 4 - Beta",
    "Intended Audience :: Developers",
    "License :: OSI Approved :: MIT License",
    "Programming Language :: Python :: 3",
    "Programming Language :: Python :: 3.9",
    "Programming Language :: Python :: 3.10",
    "Programming Language :: Python :: 3.11",
    "Programming Language :: Python :: 3.12",
]

# Default installation: core + hnsw
dependencies = [
    "leann-core>=0.1.0",
    "leann-backend-hnsw>=0.1.0",
]

[project.optional-dependencies]
diskann = [
    "leann-backend-diskann>=0.1.0",
]

[project.urls]
Homepage = "https://github.com/yourusername/leann"
Documentation = "https://leann.readthedocs.io"
Repository = "https://github.com/yourusername/leann"
Issues = "https://github.com/yourusername/leann/issues"
@@ -33,8 +33,9 @@ dependencies = [
     "msgpack>=1.1.1",
     "llama-index-vector-stores-faiss>=0.4.0",
     "llama-index-embeddings-huggingface>=0.5.5",
-    "mlx>=0.26.3",
-    "mlx-lm>=0.26.0",
+    "mlx>=0.26.3; sys_platform == 'darwin'",
+    "mlx-lm>=0.26.0; sys_platform == 'darwin'",
+    "psutil>=5.8.0",
 ]

 [project.optional-dependencies]
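Editor's note: with `mlx` and `mlx-lm` now gated behind `sys_platform == 'darwin'`, any code path that reaches `compute_embeddings_mlx` should treat MLX as optional. A hedged guard sketch (function name is illustrative):

```python
import sys


def mlx_available() -> bool:
    """True only on macOS with the optional MLX packages installed."""
    if sys.platform != "darwin":
        return False
    try:
        import mlx.core  # noqa: F401  (installed only via the darwin-only dependency)
        return True
    except ImportError:
        return False
```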
87
scripts/build_and_test.sh
Executable file
87
scripts/build_and_test.sh
Executable file
@@ -0,0 +1,87 @@
#!/bin/bash

# Manual build and test script for local testing

PACKAGE=${1:-"all"} # Default to all packages

echo "Building package: $PACKAGE"

# Ensure we're in a virtual environment
if [ -z "$VIRTUAL_ENV" ]; then
    echo "Error: Please activate a virtual environment first"
    echo "Run: source .venv/bin/activate (or .venv/bin/activate.fish for fish shell)"
    exit 1
fi

# Install build tools
uv pip install build twine delocate auditwheel scikit-build-core cmake pybind11 numpy

build_package() {
    local package_dir=$1
    local package_name=$(basename $package_dir)

    echo "Building $package_name..."
    cd $package_dir

    # Clean previous builds
    rm -rf dist/ build/ _skbuild/

    # Build directly with pip wheel (avoids sdist issues)
    pip wheel . --no-deps -w dist

    # Repair wheel for binary packages
    if [[ "$package_name" != "leann-core" ]] && [[ "$package_name" != "leann" ]]; then
        if [[ "$OSTYPE" == "darwin"* ]]; then
            # For macOS
            for wheel in dist/*.whl; do
                if [[ -f "$wheel" ]]; then
                    delocate-wheel -w dist_repaired -v "$wheel"
                fi
            done
            if [[ -d dist_repaired ]]; then
                rm -rf dist/*.whl
                mv dist_repaired/*.whl dist/
                rmdir dist_repaired
            fi
        else
            # For Linux
            for wheel in dist/*.whl; do
                if [[ -f "$wheel" ]]; then
                    auditwheel repair "$wheel" -w dist_repaired
                fi
            done
            if [[ -d dist_repaired ]]; then
                rm -rf dist/*.whl
                mv dist_repaired/*.whl dist/
                rmdir dist_repaired
            fi
        fi
    fi

    echo "Built wheels in $package_dir/dist/"
    ls -la dist/
    cd - > /dev/null
}

# Build specific package or all
if [ "$PACKAGE" == "diskann" ]; then
    build_package "packages/leann-backend-diskann"
elif [ "$PACKAGE" == "hnsw" ]; then
    build_package "packages/leann-backend-hnsw"
elif [ "$PACKAGE" == "core" ]; then
    build_package "packages/leann-core"
elif [ "$PACKAGE" == "meta" ]; then
    build_package "packages/leann"
elif [ "$PACKAGE" == "all" ]; then
    build_package "packages/leann-core"
    build_package "packages/leann-backend-hnsw"
    build_package "packages/leann-backend-diskann"
    build_package "packages/leann"
else
    echo "Unknown package: $PACKAGE"
    echo "Usage: $0 [diskann|hnsw|core|meta|all]"
    exit 1
fi

echo -e "\nBuild complete! Test with:"
echo "uv pip install packages/*/dist/*.whl"
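After the repair step in this script, the backend wheels should advertise concrete platform tags (delocate stamps macosx_* tags, auditwheel stamps manylinux_* tags) instead of the generic tag that a plain pip wheel build emits. A minimal sketch for spot-checking that locally, assuming the third-party packaging library is installed; the glob mirrors the dist/ layout this script writes to:

# Sketch: list the platform tags of the wheels built by build_and_test.sh.
# Assumes "pip install packaging" and that wheels live under packages/*/dist/.
from pathlib import Path
from packaging.utils import parse_wheel_filename

for wheel in sorted(Path("packages").glob("*/dist/*.whl")):
    name, version, build, tags = parse_wheel_filename(wheel.name)
    platforms = sorted({tag.platform for tag in tags})
    print(f"{wheel.name}: {platforms}")
    # Binary backends should show macosx_* / manylinux_* here, not plain linux_x86_64.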
scripts/bump_version.sh (Executable file, 31 lines)
@@ -0,0 +1,31 @@
#!/bin/bash

if [ $# -eq 0 ]; then
    echo "Usage: $0 <new_version>"
    exit 1
fi

NEW_VERSION=$1

# Get the directory where the script is located
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
PROJECT_ROOT="$( cd "$SCRIPT_DIR/.." && pwd )"

# Update all pyproject.toml files
echo "Updating versions in $PROJECT_ROOT/packages/"

# Use different sed syntax for macOS vs Linux
if [[ "$OSTYPE" == "darwin"* ]]; then
    # Update version fields
    find "$PROJECT_ROOT/packages" -name "pyproject.toml" -exec sed -i '' "s/version = \".*\"/version = \"$NEW_VERSION\"/" {} \;
    # Update leann-core dependencies
    find "$PROJECT_ROOT/packages" -name "pyproject.toml" -exec sed -i '' "s/leann-core==[0-9.]*/leann-core==$NEW_VERSION/" {} \;
else
    # Update version fields
    find "$PROJECT_ROOT/packages" -name "pyproject.toml" -exec sed -i "s/version = \".*\"/version = \"$NEW_VERSION\"/" {} \;
    # Update leann-core dependencies
    find "$PROJECT_ROOT/packages" -name "pyproject.toml" -exec sed -i "s/leann-core==[0-9.]*/leann-core==$NEW_VERSION/" {} \;
fi

echo "✅ Version updated to $NEW_VERSION"
echo "✅ Dependencies updated to use leann-core==$NEW_VERSION"
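The branch on $OSTYPE exists only because BSD sed on macOS requires an explicit (here empty) backup suffix after -i, while GNU sed does not. If that split ever becomes a nuisance, the same substitutions can be done portably in Python; the snippet below is a hypothetical alternative sketch, not a script that exists in the repository:

# Hypothetical cross-platform equivalent of bump_version.sh (not in the repo).
import re
import sys
from pathlib import Path

new_version = sys.argv[1]
project_root = Path(__file__).resolve().parent.parent

for pyproject in (project_root / "packages").rglob("pyproject.toml"):
    text = pyproject.read_text()
    # Mirror the two sed substitutions: the version field and the pinned leann-core dep.
    text = re.sub(r'version = ".*"', f'version = "{new_version}"', text)
    text = re.sub(r"leann-core==[0-9.]+", f"leann-core=={new_version}", text)
    pyproject.write_text(text)
    print(f"Updated {pyproject}")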
scripts/release.sh (Executable file, 18 lines)
@@ -0,0 +1,18 @@
#!/bin/bash

if [ $# -eq 0 ]; then
    echo "Usage: $0 <version>"
    echo "Example: $0 0.1.1"
    exit 1
fi

VERSION=$1

# Update version
./scripts/bump_version.sh $VERSION

# Commit and push
git add . && git commit -m "chore: bump version to $VERSION" && git push

# Create release (triggers CI)
gh release create v$VERSION --generate-notes
scripts/upload_to_pypi.sh (Executable file, 30 lines)
@@ -0,0 +1,30 @@
#!/bin/bash

# Manual upload script for testing

TARGET=${1:-"test"} # Default to test pypi

if [ "$TARGET" != "test" ] && [ "$TARGET" != "prod" ]; then
    echo "Usage: $0 [test|prod]"
    exit 1
fi

# Check for built packages
if ! ls packages/*/dist/*.whl >/dev/null 2>&1; then
    echo "No built packages found. Run ./scripts/build_and_test.sh first"
    exit 1
fi

if [ "$TARGET" == "test" ]; then
    echo "Uploading to Test PyPI..."
    twine upload --repository testpypi packages/*/dist/*
else
    echo "Uploading to PyPI..."
    echo "Are you sure? (y/N)"
    read -r response
    if [ "$response" == "y" ]; then
        twine upload packages/*/dist/*
    else
        echo "Cancelled"
    fi
fi
@@ -12,7 +12,7 @@ else:
     builder = LeannBuilder(
         backend_name="hnsw",
         embedding_model="mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ",
-        use_mlx=True,
+        embedding_mode="mlx",
     )

     # 2. Add documents
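For code that constructs the builder itself, the hunk above means the MLX switch is now expressed through embedding_mode="mlx" rather than a use_mlx=True flag. A minimal sketch of the updated call, taken from the diff; the import path is an assumption and should match whatever the example file already uses:

# Sketch of the updated builder configuration shown in the hunk above.
# The import path is assumed for illustration; the constructor arguments are
# exactly those from the diff. use_mlx=True is the old, removed spelling.
from leann import LeannBuilder  # assumption: adjust to the example's actual import

builder = LeannBuilder(
    backend_name="hnsw",
    embedding_model="mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ",
    embedding_mode="mlx",  # replaces use_mlx=True
)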