Compare commits
68 Commits
perf-build
...
v0.1.0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e3defbca84 | ||
|
|
e407f63977 | ||
|
|
7add391b2c | ||
|
|
efd6373b32 | ||
|
|
d502fa24b0 | ||
|
|
258a9a5c7f | ||
|
|
5d41ac6115 | ||
|
|
2a0fdb49b8 | ||
|
|
9d1b7231b6 | ||
|
|
ed3095b478 | ||
|
|
88eca75917 | ||
|
|
42de27e16a | ||
|
|
c083bda5b7 | ||
|
|
e86da38726 | ||
|
|
99076e38bc | ||
|
|
9698c1a02c | ||
|
|
851f0f04c3 | ||
|
|
ae16d9d888 | ||
|
|
6e1af2eb0c | ||
|
|
7695dd0d50 | ||
|
|
c2065473ad | ||
|
|
5f3870564d | ||
|
|
c214b2e33e | ||
|
|
2420c5fd35 | ||
|
|
f48f526f0a | ||
|
|
5dd74982ba | ||
|
|
e07aaf52a7 | ||
|
|
30e5f12616 | ||
|
|
594427bf87 | ||
|
|
a97d3ada1c | ||
|
|
6217bb5638 | ||
|
|
2760e99e18 | ||
|
|
0544f96b79 | ||
|
|
2ebb29de65 | ||
|
|
43762d44c7 | ||
|
|
cdaf0c98be | ||
|
|
aa9a14a917 | ||
|
|
9efcc6d95c | ||
|
|
f3f5d91207 | ||
|
|
6070160959 | ||
|
|
43155d2811 | ||
|
|
d3f85678ec | ||
|
|
2a96d05b21 | ||
|
|
851e888535 | ||
|
|
90120d4dff | ||
|
|
8513471573 | ||
|
|
71e5f1774c | ||
|
|
870a443446 | ||
|
|
cefaa2a4cc | ||
|
|
ab72a2ab9d | ||
|
|
046d457d22 | ||
|
|
7fd0a30fee | ||
|
|
c2f35c8e73 | ||
|
|
573313f0b6 | ||
|
|
f7af6805fa | ||
|
|
966de3a399 | ||
|
|
8a75829f3a | ||
|
|
0f7e34b9e2 | ||
|
|
be0322b616 | ||
|
|
232a525a62 | ||
|
|
587ce65cf6 | ||
|
|
ccf6c8bfd7 | ||
|
|
c112956d2d | ||
|
|
b3970793cf | ||
|
|
727724990e | ||
|
|
530f6e4af5 | ||
|
|
2f224f5793 | ||
|
|
1b6272ce0e |
256
.github/workflows/build-and-publish.yml
vendored
Normal file
256
.github/workflows/build-and-publish.yml
vendored
Normal file
@@ -0,0 +1,256 @@
|
||||
name: Build and Publish to PyPI
|
||||
|
||||
on:
|
||||
release:
|
||||
types: [published]
|
||||
push:
|
||||
tags:
|
||||
- 'v*'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
publish:
|
||||
description: 'Publish to PyPI'
|
||||
required: true
|
||||
default: 'false'
|
||||
type: choice
|
||||
options:
|
||||
- 'false'
|
||||
- 'test'
|
||||
- 'prod'
|
||||
|
||||
jobs:
|
||||
# Build pure Python package: leann-core
|
||||
build-core:
|
||||
name: Build leann-core
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v4
|
||||
|
||||
- name: Install build dependencies
|
||||
run: |
|
||||
uv pip install --system build twine
|
||||
|
||||
- name: Build package
|
||||
run: |
|
||||
cd packages/leann-core
|
||||
uv build
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: leann-core-dist
|
||||
path: packages/leann-core/dist/
|
||||
|
||||
# Build binary package: leann-backend-hnsw (default backend)
|
||||
build-hnsw:
|
||||
name: Build leann-backend-hnsw
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest, macos-latest]
|
||||
python-version: ['3.9', '3.10', '3.11', '3.12']
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v4
|
||||
|
||||
- name: Install system dependencies (Ubuntu)
|
||||
if: runner.os == 'Linux'
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y libomp-dev libboost-all-dev libzmq3-dev \
|
||||
pkg-config libopenblas-dev patchelf
|
||||
|
||||
- name: Install system dependencies (macOS)
|
||||
if: runner.os == 'macOS'
|
||||
run: |
|
||||
brew install libomp boost zeromq
|
||||
|
||||
- name: Install build dependencies
|
||||
run: |
|
||||
uv pip install --system scikit-build-core numpy swig
|
||||
uv pip install --system auditwheel delocate
|
||||
|
||||
- name: Build wheel
|
||||
run: |
|
||||
cd packages/leann-backend-hnsw
|
||||
uv build --wheel
|
||||
|
||||
- name: Repair wheel (Linux)
|
||||
if: runner.os == 'Linux'
|
||||
run: |
|
||||
cd packages/leann-backend-hnsw
|
||||
auditwheel repair dist/*.whl -w dist_repaired
|
||||
rm -rf dist
|
||||
mv dist_repaired dist
|
||||
|
||||
- name: Repair wheel (macOS)
|
||||
if: runner.os == 'macOS'
|
||||
run: |
|
||||
cd packages/leann-backend-hnsw
|
||||
delocate-wheel -w dist_repaired -v dist/*.whl
|
||||
rm -rf dist
|
||||
mv dist_repaired dist
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: hnsw-${{ matrix.os }}-py${{ matrix.python-version }}
|
||||
path: packages/leann-backend-hnsw/dist/
|
||||
|
||||
# Build binary package: leann-backend-diskann (multi-platform)
|
||||
build-diskann:
|
||||
name: Build leann-backend-diskann
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest, macos-latest]
|
||||
python-version: ['3.9', '3.10', '3.11', '3.12']
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v4
|
||||
|
||||
- name: Install system dependencies (Ubuntu)
|
||||
if: runner.os == 'Linux'
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y libomp-dev libboost-all-dev libaio-dev libzmq3-dev \
|
||||
protobuf-compiler libprotobuf-dev libabsl-dev patchelf
|
||||
|
||||
# Install Intel MKL using Intel's installer
|
||||
wget https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh
|
||||
sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
|
||||
source /opt/intel/oneapi/setvars.sh
|
||||
echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV
|
||||
echo "LD_LIBRARY_PATH=/opt/intel/oneapi/mkl/latest/lib/intel64:$LD_LIBRARY_PATH" >> $GITHUB_ENV
|
||||
|
||||
- name: Install system dependencies (macOS)
|
||||
if: runner.os == 'macOS'
|
||||
run: |
|
||||
brew install libomp boost zeromq protobuf
|
||||
# MKL is not available on Homebrew, but DiskANN can work without it
|
||||
|
||||
- name: Install build dependencies
|
||||
run: |
|
||||
uv pip install --system scikit-build-core numpy Cython pybind11
|
||||
if [[ "$RUNNER_OS" == "Linux" ]]; then
|
||||
uv pip install --system auditwheel
|
||||
else
|
||||
uv pip install --system delocate
|
||||
fi
|
||||
|
||||
- name: Build wheel
|
||||
run: |
|
||||
cd packages/leann-backend-diskann
|
||||
uv build --wheel
|
||||
|
||||
- name: Repair wheel (Linux)
|
||||
if: runner.os == 'Linux'
|
||||
run: |
|
||||
cd packages/leann-backend-diskann
|
||||
auditwheel repair dist/*.whl -w dist_repaired
|
||||
rm -rf dist
|
||||
mv dist_repaired dist
|
||||
|
||||
- name: Repair wheel (macOS)
|
||||
if: runner.os == 'macOS'
|
||||
run: |
|
||||
cd packages/leann-backend-diskann
|
||||
delocate-wheel -w dist_repaired -v dist/*.whl
|
||||
rm -rf dist
|
||||
mv dist_repaired dist
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: diskann-${{ matrix.os }}-py${{ matrix.python-version }}
|
||||
path: packages/leann-backend-diskann/dist/
|
||||
|
||||
# Build meta-package: leann (build last)
|
||||
build-meta:
|
||||
name: Build leann meta-package
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v4
|
||||
|
||||
- name: Install build dependencies
|
||||
run: |
|
||||
uv pip install --system build
|
||||
|
||||
- name: Build package
|
||||
run: |
|
||||
cd packages/leann
|
||||
uv build
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: leann-meta-dist
|
||||
path: packages/leann/dist/
|
||||
|
||||
# Publish to PyPI
|
||||
publish:
|
||||
name: Publish to PyPI
|
||||
needs: [build-core, build-hnsw, build-diskann, build-meta]
|
||||
runs-on: ubuntu-latest
|
||||
if: github.event_name == 'release' || github.event.inputs.publish != 'false'
|
||||
|
||||
steps:
|
||||
- name: Download all artifacts
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
path: dist
|
||||
|
||||
- name: Flatten directory structure
|
||||
run: |
|
||||
mkdir -p all_wheels
|
||||
find dist -name "*.whl" -exec cp {} all_wheels/ \;
|
||||
find dist -name "*.tar.gz" -exec cp {} all_wheels/ \;
|
||||
|
||||
- name: Publish to Test PyPI
|
||||
if: github.event.inputs.publish == 'test' || github.event_name == 'workflow_dispatch'
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
||||
with:
|
||||
password: ${{ secrets.TEST_PYPI_API_TOKEN }}
|
||||
repository-url: https://test.pypi.org/legacy/
|
||||
packages-dir: all_wheels/
|
||||
|
||||
- name: Publish to PyPI
|
||||
if: github.event_name == 'release' || github.event.inputs.publish == 'prod'
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
||||
with:
|
||||
password: ${{ secrets.PYPI_API_TOKEN }}
|
||||
packages-dir: all_wheels/
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -84,4 +84,6 @@ test_*.py
|
||||
packages/leann-backend-diskann/third_party/DiskANN/_deps/
|
||||
|
||||
*.meta.json
|
||||
*.passages.json
|
||||
*.passages.json
|
||||
|
||||
batchtest.py
|
||||
9
.vscode/extensions.json
vendored
9
.vscode/extensions.json
vendored
@@ -1,9 +0,0 @@
|
||||
{
|
||||
"recommendations": [
|
||||
"llvm-vs-code-extensions.vscode-clangd",
|
||||
"ms-python.python",
|
||||
"ms-vscode.cmake-tools",
|
||||
"vadimcn.vscode-lldb",
|
||||
"eamodio.gitlens",
|
||||
]
|
||||
}
|
||||
283
.vscode/launch.json
vendored
283
.vscode/launch.json
vendored
@@ -1,283 +0,0 @@
|
||||
{
|
||||
// Use IntelliSense to learn about possible attributes.
|
||||
// Hover to view descriptions of existing attributes.
|
||||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
// new emdedder
|
||||
{
|
||||
"name": "New Embedder",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "demo/main.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": [
|
||||
"--search",
|
||||
"--use-original",
|
||||
"--domain",
|
||||
"dpr",
|
||||
"--nprobe",
|
||||
"5000",
|
||||
"--load",
|
||||
"flat",
|
||||
"--embedder",
|
||||
"intfloat/multilingual-e5-small"
|
||||
]
|
||||
}
|
||||
//python /home/ubuntu/Power-RAG/faiss/demo/simple_build.py
|
||||
{
|
||||
"name": "main.py",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "demo/main.py",
|
||||
"console": "integratedTerminal",
|
||||
"cwd": "${workspaceFolder}",
|
||||
"args": [
|
||||
"--query",
|
||||
"1000",
|
||||
"--load",
|
||||
"bm25"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Simple Build",
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"program": "${workspaceFolder}/.venv/bin/python",
|
||||
"console": "integratedTerminal",
|
||||
"cwd": "${workspaceFolder}",
|
||||
"args": [
|
||||
"faiss/demo/simple_build.py"
|
||||
],
|
||||
"env": {
|
||||
"LD_PRELOAD": "/lib/x86_64-linux-gnu/libmkl_core.so:/lib/x86_64-linux-gnu/libmkl_intel_thread.so:/lib/x86_64-linux-gnu/libmkl_intel_lp64.so:/lib/x86_64-linux-gnu/libiomp5.so"
|
||||
}
|
||||
},
|
||||
//# Fix for Intel MKL error
|
||||
//export LD_PRELOAD=/lib/x86_64-linux-gnu/libmkl_core.so:/lib/x86_64-linux-gnu/libmkl_intel_thread.so:/lib/x86_64-linux-gnu/libmkl_intel_lp64.so:/lib/x86_64-linux-gnu/libiomp5.so
|
||||
//python faiss/demo/build_demo.py
|
||||
{
|
||||
"name": "Build Demo",
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"program": "${workspaceFolder}/.venv/bin/python",
|
||||
"console": "integratedTerminal",
|
||||
"cwd": "${workspaceFolder}",
|
||||
"args": [
|
||||
"faiss/demo/build_demo.py"
|
||||
],
|
||||
"env": {
|
||||
"LD_PRELOAD": "/lib/x86_64-linux-gnu/libmkl_core.so:/lib/x86_64-linux-gnu/libmkl_intel_thread.so:/lib/x86_64-linux-gnu/libmkl_intel_lp64.so:/lib/x86_64-linux-gnu/libiomp5.so"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "DiskANN Serve",
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"program": "${workspaceFolder}/.venv/bin/python",
|
||||
"console": "integratedTerminal",
|
||||
"cwd": "${workspaceFolder}",
|
||||
"args": [
|
||||
"demo/main.py",
|
||||
"--mode",
|
||||
"serve",
|
||||
"--engine",
|
||||
"sglang",
|
||||
"--load-indices",
|
||||
"diskann",
|
||||
"--domain",
|
||||
"rpj_wiki",
|
||||
"--lazy-load",
|
||||
"--recompute-beighbor-embeddings",
|
||||
"--port",
|
||||
"8082",
|
||||
"--diskann-search-memory-maximum",
|
||||
"2",
|
||||
"--diskann-graph",
|
||||
"240",
|
||||
"--search-only"
|
||||
],
|
||||
"env": {
|
||||
"PYTHONPATH": "${workspaceFolder}/faiss_repo/build/faiss/python:$PYTHONPATH"
|
||||
},
|
||||
"preLaunchTask": "CMake: build",
|
||||
},
|
||||
{
|
||||
"name": "DiskANN Serve MAC",
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"program": "${workspaceFolder}/.venv/bin/python",
|
||||
"console": "integratedTerminal",
|
||||
"cwd": "${workspaceFolder}",
|
||||
"args": [
|
||||
"demo/main.py",
|
||||
"--mode",
|
||||
"serve",
|
||||
"--engine",
|
||||
"ollama",
|
||||
"--load-indices",
|
||||
"diskann",
|
||||
"--domain",
|
||||
"rpj_wiki",
|
||||
"--lazy-load",
|
||||
"--recompute-beighbor-embeddings"
|
||||
],
|
||||
"preLaunchTask": "CMake: build",
|
||||
"env": {
|
||||
"KMP_DUPLICATE_LIB_OK": "TRUE",
|
||||
"OMP_NUM_THREADS": "1",
|
||||
"MKL_NUM_THREADS": "1",
|
||||
"DYLD_INSERT_LIBRARIES": "/Users/ec2-user/Power-RAG/.venv/lib/python3.10/site-packages/torch/lib/libomp.dylib",
|
||||
"KMP_BLOCKTIME": "0"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Python Debugger: Current File with Arguments",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "ric/main_ric.py",
|
||||
"console": "integratedTerminal",
|
||||
"cwd": "${workspaceFolder}",
|
||||
"args": [
|
||||
"--config-name",
|
||||
"${input:configSelection}"
|
||||
],
|
||||
"justMyCode": false
|
||||
},
|
||||
//python ./demo/validate_equivalence.py sglang
|
||||
{
|
||||
"name": "Validate Equivalence",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "demo/validate_equivalence.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": [
|
||||
"sglang"
|
||||
],
|
||||
},
|
||||
//python demo/retrieval_demo.py --engine sglang --skip-embeddings --domain dpr --load-indices flat ivf_flat
|
||||
{
|
||||
"name": "Retrieval Demo",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "demo/retrieval_demo.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": [
|
||||
"--engine",
|
||||
"vllm",
|
||||
"--skip-embeddings",
|
||||
"--domain",
|
||||
"dpr",
|
||||
"--load-indices",
|
||||
// "flat",
|
||||
"ivf_flat"
|
||||
],
|
||||
},
|
||||
//python demo/retrieval_demo.py --engine sglang --skip-embeddings --domain dpr --load-indices diskann --hnsw-M 64 --hnsw-efConstruction 150 --hnsw-efSearch 128 --hnsw-sq-bits 8
|
||||
{
|
||||
"name": "Retrieval Demo DiskANN",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "demo/retrieval_demo.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": [
|
||||
"--engine",
|
||||
"sglang",
|
||||
"--skip-embeddings",
|
||||
"--domain",
|
||||
"dpr",
|
||||
"--load-indices",
|
||||
"diskann",
|
||||
"--hnsw-M",
|
||||
"64",
|
||||
"--hnsw-efConstruction",
|
||||
"150",
|
||||
"--hnsw-efSearch",
|
||||
"128",
|
||||
"--hnsw-sq-bits",
|
||||
"8"
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "Find Probe",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "find_probe.py",
|
||||
"console": "integratedTerminal",
|
||||
"cwd": "${workspaceFolder}",
|
||||
},
|
||||
{
|
||||
"name": "Python: Attach",
|
||||
"type": "debugpy",
|
||||
"request": "attach",
|
||||
"processId": "${command:pickProcess}",
|
||||
"justMyCode": true
|
||||
},
|
||||
{
|
||||
"name": "Edge RAG",
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"program": "${workspaceFolder}/.venv/bin/python",
|
||||
"console": "integratedTerminal",
|
||||
"cwd": "${workspaceFolder}",
|
||||
"args": [
|
||||
"edgerag_demo.py"
|
||||
],
|
||||
"env": {
|
||||
"LD_PRELOAD": "/lib/x86_64-linux-gnu/libiomp5.so /lib/x86_64-linux-gnu/libmkl_core.so /lib/x86_64-linux-gnu/libmkl_intel_lp64.so /lib/x86_64-linux-gnu/libmkl_intel_thread.so",
|
||||
"MKL_NUM_THREADS": "1",
|
||||
"OMP_NUM_THREADS": "1",
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Launch Embedding Server",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "demo/embedding_server.py",
|
||||
"console": "integratedTerminal",
|
||||
"cwd": "${workspaceFolder}",
|
||||
"args": [
|
||||
"--domain",
|
||||
"rpj_wiki",
|
||||
"--zmq-port",
|
||||
"5556",
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "HNSW Serve",
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"program": "${workspaceFolder}/.venv/bin/python",
|
||||
"console": "integratedTerminal",
|
||||
"cwd": "${workspaceFolder}",
|
||||
"args": [
|
||||
"demo/main.py",
|
||||
"--domain",
|
||||
"rpj_wiki",
|
||||
"--load",
|
||||
"hnsw",
|
||||
"--mode",
|
||||
"serve",
|
||||
"--search",
|
||||
"--skip-pa",
|
||||
"--recompute",
|
||||
"--hnsw-old"
|
||||
],
|
||||
"env": {
|
||||
"LD_PRELOAD": "/lib/x86_64-linux-gnu/libmkl_core.so:/lib/x86_64-linux-gnu/libmkl_intel_thread.so:/lib/x86_64-linux-gnu/libmkl_intel_lp64.so:/lib/x86_64-linux-gnu/libiomp5.so"
|
||||
}
|
||||
},
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"id": "configSelection",
|
||||
"type": "pickString",
|
||||
"description": "Select a configuration",
|
||||
"options": [
|
||||
"example_config",
|
||||
"vllm_gritlm"
|
||||
],
|
||||
"default": "example_config"
|
||||
}
|
||||
],
|
||||
}
|
||||
43
.vscode/settings.json
vendored
43
.vscode/settings.json
vendored
@@ -1,43 +0,0 @@
|
||||
{
|
||||
"python.analysis.extraPaths": [
|
||||
"./sglang_repo/python"
|
||||
],
|
||||
"cmake.sourceDirectory": "${workspaceFolder}/DiskANN",
|
||||
"cmake.configureArgs": [
|
||||
"-DPYBIND=True",
|
||||
"-DUPDATE_EDITABLE_INSTALL=ON",
|
||||
],
|
||||
"cmake.environment": {
|
||||
"PATH": "/Users/ec2-user/Power-RAG/.venv/bin:${env:PATH}"
|
||||
},
|
||||
"cmake.buildDirectory": "${workspaceFolder}/build",
|
||||
"files.associations": {
|
||||
"*.tcc": "cpp",
|
||||
"deque": "cpp",
|
||||
"string": "cpp",
|
||||
"unordered_map": "cpp",
|
||||
"vector": "cpp",
|
||||
"map": "cpp",
|
||||
"unordered_set": "cpp",
|
||||
"atomic": "cpp",
|
||||
"inplace_vector": "cpp",
|
||||
"*.ipp": "cpp",
|
||||
"forward_list": "cpp",
|
||||
"list": "cpp",
|
||||
"any": "cpp",
|
||||
"system_error": "cpp",
|
||||
"__hash_table": "cpp",
|
||||
"__split_buffer": "cpp",
|
||||
"__tree": "cpp",
|
||||
"ios": "cpp",
|
||||
"set": "cpp",
|
||||
"__string": "cpp",
|
||||
"string_view": "cpp",
|
||||
"ranges": "cpp",
|
||||
"iosfwd": "cpp"
|
||||
},
|
||||
"lldb.displayFormat": "auto",
|
||||
"lldb.showDisassembly": "auto",
|
||||
"lldb.dereferencePointers": true,
|
||||
"lldb.consoleMode": "commands",
|
||||
}
|
||||
16
.vscode/tasks.json
vendored
16
.vscode/tasks.json
vendored
@@ -1,16 +0,0 @@
|
||||
{
|
||||
"version": "2.0.0",
|
||||
"tasks": [
|
||||
{
|
||||
"type": "cmake",
|
||||
"label": "CMake: build",
|
||||
"command": "build",
|
||||
"targets": [
|
||||
"all"
|
||||
],
|
||||
"group": "build",
|
||||
"problemMatcher": [],
|
||||
"detail": "CMake template build task"
|
||||
}
|
||||
]
|
||||
}
|
||||
223
README.md
223
README.md
@@ -12,11 +12,11 @@
|
||||
The smallest vector index in the world. RAG Everything with LEANN!
|
||||
</h2>
|
||||
|
||||
LEANN is a revolutionary vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **[97% less storage]** than traditional solutions **without accuracy loss**.
|
||||
LEANN is a revolutionary vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **97% less storage** than traditional solutions **without accuracy loss**.
|
||||
|
||||
LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration →](#️-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
|
||||
LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration Fig →](#️-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
|
||||
|
||||
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can search your **[file system](#process-any-documents-pdf-txt-md)**, **[emails](#search-your-entire-life)**, **[browser history](#time-machine-for-the-web)**, **[chat history](#wechat-detective)**, or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
|
||||
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)**, or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
|
||||
|
||||
|
||||
|
||||
@@ -26,9 +26,8 @@ LEANN achieves this through *graph-based selective recomputation* with *high-deg
|
||||
<img src="assets/effects.png" alt="LEANN vs Traditional Vector DB Storage Comparison" width="70%">
|
||||
</p>
|
||||
|
||||
**The numbers speak for themselves:** Index 60 million Wikipedia articles in just 6GB instead of 201GB. From emails to browser history, everything fits on your laptop. [See detailed benchmarks below ↓](#storage-usage-comparison)
|
||||
> **The numbers speak for themselves:** Index 60 million Wikipedia chunks in just 6GB instead of 201GB. From emails to browser history, everything fits on your laptop. [See detailed benchmarks for different applications below ↓](#storage-usage-comparison)
|
||||
|
||||
## Why This Matters
|
||||
|
||||
🔒 **Privacy:** Your data never leaves your laptop. No OpenAI, no cloud, no "terms of service".
|
||||
|
||||
@@ -38,7 +37,7 @@ LEANN achieves this through *graph-based selective recomputation* with *high-deg
|
||||
|
||||
✨ **No Accuracy Loss:** Maintain the same search quality as heavyweight solutions while using 97% less storage.
|
||||
|
||||
## Quick Start in 1 minute
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
git clone git@github.com:yichuan-w/LEANN.git leann
|
||||
@@ -48,33 +47,30 @@ git submodule update --init --recursive
|
||||
|
||||
**macOS:**
|
||||
```bash
|
||||
brew install llvm libomp boost protobuf zeromq
|
||||
export CC=$(brew --prefix llvm)/bin/clang
|
||||
export CXX=$(brew --prefix llvm)/bin/clang++
|
||||
brew install llvm libomp boost protobuf zeromq pkgconf
|
||||
|
||||
# Install with HNSW backend (default, recommended for most users)
|
||||
uv sync
|
||||
|
||||
# Or add DiskANN backend if you want to test more options
|
||||
uv sync --extra diskann
|
||||
# Install uv first if you don't have it:
|
||||
# curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
# See: https://docs.astral.sh/uv/getting-started/installation/#installation-methods
|
||||
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv sync
|
||||
```
|
||||
|
||||
**Linux (Ubuntu/Debian):**
|
||||
**Linux:**
|
||||
```bash
|
||||
sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler libabsl-dev libmkl-full-dev libaio-dev libzmq3-dev
|
||||
|
||||
# Install with HNSW backend (default, recommended for most users)
|
||||
uv sync
|
||||
|
||||
# Or add DiskANN backend if you want to test more options
|
||||
uv sync --extra diskann
|
||||
```
|
||||
|
||||
**Ollama Setup (Optional for Local LLM):**
|
||||
|
||||
*We support both hf-transformers and Ollama for local LLMs. Ollama is recommended for faster performance.*
|
||||
**Ollama Setup (Recommended for full privacy):**
|
||||
|
||||
*macOS:*
|
||||
> *You can skip this installation if you only want to use OpenAI API for generation.*
|
||||
|
||||
|
||||
**macOS:**
|
||||
|
||||
First, [download Ollama for macOS](https://ollama.com/download/mac).
|
||||
|
||||
@@ -83,7 +79,7 @@ First, [download Ollama for macOS](https://ollama.com/download/mac).
|
||||
ollama pull llama3.2:1b
|
||||
```
|
||||
|
||||
*Linux:*
|
||||
**Linux:**
|
||||
```bash
|
||||
# Install Ollama
|
||||
curl -fsSL https://ollama.ai/install.sh | sh
|
||||
@@ -95,62 +91,70 @@ ollama serve &
|
||||
ollama pull llama3.2:1b
|
||||
```
|
||||
|
||||
You can also replace `llama3.2:1b` to `deepseek-r1:1.5b` or `qwen3:4b` for better performance but higher memory usage.
|
||||
## Quick Start in 30s
|
||||
|
||||
## Dead Simple API
|
||||
|
||||
Just 3 lines of code. Our declarative API makes RAG as easy as writing a config file:
|
||||
Our declarative API makes RAG as easy as writing a config file.
|
||||
[Try in this ipynb file →](demo.ipynb)
|
||||
|
||||
```python
|
||||
from leann.api import LeannBuilder, LeannSearcher
|
||||
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
||||
|
||||
# 1. Build index (no embeddings stored!)
|
||||
# 1. Build the index (no embeddings stored!)
|
||||
builder = LeannBuilder(backend_name="hnsw")
|
||||
builder.add_text("C# is a powerful programming language")
|
||||
builder.add_text("Python is a powerful programming language")
|
||||
builder.add_text("Machine learning transforms industries")
|
||||
builder.add_text("Python is a powerful programming language and it is very popular")
|
||||
builder.add_text("Machine learning transforms industries")
|
||||
builder.add_text("Neural networks process complex data")
|
||||
builder.add_text("Leann is a great storage saving engine for RAG on your macbook")
|
||||
builder.add_text("Leann is a great storage saving engine for RAG on your MacBook")
|
||||
builder.build_index("knowledge.leann")
|
||||
|
||||
# 2. Search with real-time embeddings
|
||||
searcher = LeannSearcher("knowledge.leann")
|
||||
results = searcher.search("C++ programming languages", top_k=2, recompute_beighbor_embeddings=True)
|
||||
print(results)
|
||||
results = searcher.search("programming languages", top_k=2)
|
||||
|
||||
# 3. Chat with LEANN using retrieved results
|
||||
llm_config = {
|
||||
"type": "ollama",
|
||||
"model": "llama3.2:1b"
|
||||
}
|
||||
|
||||
chat = LeannChat(index_path="knowledge.leann", llm_config=llm_config)
|
||||
response = chat.ask(
|
||||
"Compare the two retrieved programming languages and say which one is more popular today.",
|
||||
top_k=2,
|
||||
)
|
||||
```
|
||||
|
||||
**That's it.** No cloud setup, no API keys, no "fine-tuning". Just your data, your questions, your laptop.
|
||||
## RAG on Everything!
|
||||
|
||||
[Try the interactive demo →](demo.ipynb)
|
||||
LEANN supports RAG on various data sources including documents (.pdf, .txt, .md), Apple Mail, Google Search History, WeChat, and more.
|
||||
|
||||
## Wild Things You Can Do
|
||||
### 📄 Personal Data Manager: Process Any Documents (.pdf, .txt, .md)!
|
||||
|
||||
LEANN supports RAGing a lot of data sources, like .pdf, .txt, .md, and also supports RAGing your WeChat, Google Search History, and more.
|
||||
Ask questions directly about your personal PDFs, documents, and any directory containing your files!
|
||||
|
||||
### Process Any Documents (.pdf, .txt, .md)
|
||||
|
||||
Above we showed the Python API, while this CLI script demonstrates the same concepts while directly processing PDFs and documents.
|
||||
The example below asks a question about summarizing two papers (uses default data in `examples/data`):
|
||||
|
||||
```bash
|
||||
# Drop your PDFs, .txt, .md files into examples/data/
|
||||
uv run ./examples/main_cli_example.py
|
||||
```
|
||||
|
||||
```
|
||||
# Or use python directly
|
||||
source .venv/bin/activate
|
||||
python ./examples/main_cli_example.py
|
||||
```
|
||||
|
||||
Uses Ollama `qwen3:8b` by default. For other models: `--llm openai --model gpt-4o` (requires `OPENAI_API_KEY` environment variable) or `--llm hf --model Qwen/Qwen3-4B`.
|
||||
|
||||
**Works with any text format** - research papers, personal notes, presentations. Built with LlamaIndex for document parsing.
|
||||
|
||||
### Search Your Entire Life
|
||||
### 📧 Your Personal Email Secretary: RAG on Apple Mail!
|
||||
|
||||
**Note:** You need to grant full disk access to your terminal/VS Code in System Preferences → Privacy & Security → Full Disk Access.
|
||||
```bash
|
||||
python examples/mail_reader_leann.py
|
||||
# "What did my boss say about the Christmas party last year?"
|
||||
# "Find all emails from my mom about birthday plans"
|
||||
python examples/mail_reader_leann.py --query "What's the food I ordered by doordash or Uber eat mostly?"
|
||||
```
|
||||
**90K emails → 14MB.** Finally, search your email like you search Google.
|
||||
**780K email chunks → 78MB storage** Finally, search your email like you search Google.
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: Command Examples</strong></summary>
|
||||
@@ -183,13 +187,11 @@ Once the index is built, you can ask questions like:
|
||||
- "Show me emails about travel expenses"
|
||||
</details>
|
||||
|
||||
### Time Machine for the Web
|
||||
### 🔍 Time Machine for the Web: RAG Your Entire Google Browser History!
|
||||
```bash
|
||||
python examples/google_history_reader_leann.py
|
||||
# "What was that AI paper I read last month?"
|
||||
# "Show me all the cooking videos I watched"
|
||||
python examples/google_history_reader_leann.py --query "Tell me my browser history about machine learning?"
|
||||
```
|
||||
**38K browser entries → 6MB.** Your browser history becomes your personal search engine.
|
||||
**38K browser entries → 6MB storage.** Your browser history becomes your personal search engine.
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: Command Examples</strong></summary>
|
||||
@@ -238,13 +240,13 @@ Once the index is built, you can ask questions like:
|
||||
|
||||
</details>
|
||||
|
||||
### WeChat Detective
|
||||
### 💬 WeChat Detective: Unlock Your Golden Memories!
|
||||
|
||||
```bash
|
||||
python examples/wechat_history_reader_leann.py
|
||||
# "Show me all group chats about weekend plans"
|
||||
python examples/wechat_history_reader_leann.py --query "Show me all group chats about weekend plans"
|
||||
```
|
||||
**400K messages → 64MB.** Search years of chat history in any language.
|
||||
**400K messages → 64MB storage** Search years of chat history in any language.
|
||||
|
||||
|
||||
<details>
|
||||
<summary><strong>🔧 Click to expand: Installation Requirements</strong></summary>
|
||||
@@ -255,7 +257,13 @@ First, you need to install the WeChat exporter:
|
||||
sudo packages/wechat-exporter/wechattweak-cli install
|
||||
```
|
||||
|
||||
**Troubleshooting**: If you encounter installation issues, check the [WeChatTweak-CLI issues page](https://github.com/sunnyyoung/WeChatTweak-CLI/issues/41).
|
||||
**Troubleshooting:**
|
||||
- **Installation issues**: Check the [WeChatTweak-CLI issues page](https://github.com/sunnyyoung/WeChatTweak-CLI/issues/41)
|
||||
- **Export errors**: If you encounter the error below, try restarting WeChat
|
||||
```
|
||||
Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed.
|
||||
Failed to find or export WeChat data. Exiting.
|
||||
```
|
||||
</details>
|
||||
|
||||
<details>
|
||||
@@ -290,6 +298,73 @@ Once the index is built, you can ask questions like:
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
## 🖥️ Command Line Interface
|
||||
|
||||
LEANN includes a powerful CLI for document processing and search. Perfect for quick document indexing and interactive chat.
|
||||
|
||||
```bash
|
||||
# Build an index from documents
|
||||
leann build my-docs --docs ./documents
|
||||
|
||||
# Search your documents
|
||||
leann search my-docs "machine learning concepts"
|
||||
|
||||
# Interactive chat with your documents
|
||||
leann ask my-docs --interactive
|
||||
|
||||
# List all your indexes
|
||||
leann list
|
||||
```
|
||||
|
||||
**Key CLI features:**
|
||||
- Auto-detects document formats (PDF, TXT, MD, DOCX)
|
||||
- Smart text chunking with overlap
|
||||
- Multiple LLM providers (Ollama, OpenAI, HuggingFace)
|
||||
- Organized index storage in `~/.leann/indexes/`
|
||||
- Support for advanced search parameters
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: Complete CLI Reference</strong></summary>
|
||||
|
||||
**Build Command:**
|
||||
```bash
|
||||
leann build INDEX_NAME --docs DIRECTORY [OPTIONS]
|
||||
|
||||
Options:
|
||||
--backend {hnsw,diskann} Backend to use (default: hnsw)
|
||||
--embedding-model MODEL Embedding model (default: facebook/contriever)
|
||||
--graph-degree N Graph degree (default: 32)
|
||||
--complexity N Build complexity (default: 64)
|
||||
--force Force rebuild existing index
|
||||
--compact Use compact storage (default: true)
|
||||
--recompute Enable recomputation (default: true)
|
||||
```
|
||||
|
||||
**Search Command:**
|
||||
```bash
|
||||
leann search INDEX_NAME QUERY [OPTIONS]
|
||||
|
||||
Options:
|
||||
--top-k N Number of results (default: 5)
|
||||
--complexity N Search complexity (default: 64)
|
||||
--recompute-embeddings Use recomputation for highest accuracy
|
||||
--pruning-strategy {global,local,proportional}
|
||||
```
|
||||
|
||||
**Ask Command:**
|
||||
```bash
|
||||
leann ask INDEX_NAME [OPTIONS]
|
||||
|
||||
Options:
|
||||
--llm {ollama,openai,hf} LLM provider (default: ollama)
|
||||
--model MODEL Model name (default: qwen3:8b)
|
||||
--interactive Interactive chat mode
|
||||
--top-k N Retrieval count (default: 20)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## 🏗️ Architecture & How It Works
|
||||
|
||||
<p align="center">
|
||||
@@ -321,23 +396,15 @@ python examples/compare_faiss_vs_leann.py
|
||||
|
||||
Same dataset, same hardware, same embedding model. LEANN just works better.
|
||||
|
||||
## Reproduce Our Results
|
||||
|
||||
```bash
|
||||
uv pip install -e ".[dev]" # Install dev dependencies
|
||||
python examples/run_evaluation.py data/indices/dpr/dpr_diskann # DPR dataset
|
||||
python examples/run_evaluation.py data/indices/rpj_wiki/rpj_wiki.index # Wikipedia
|
||||
```
|
||||
|
||||
The evaluation script downloads data automatically on first run.
|
||||
|
||||
### Storage Usage Comparison
|
||||
|
||||
| System | DPR (2.1M chunks) | RPJ-wiki (60M chunks) | Chat history (400K messages) | Apple emails (90K messages chunks) |Google Search History (38K entries)
|
||||
| System | DPR (2.1M chunks) | RPJ-wiki (60M chunks) | Chat history (400K messages) | Apple emails (780K messages chunks) |Google Search History (38K entries)
|
||||
|-----------------------|------------------|------------------------|-----------------------------|------------------------------|------------------------------|
|
||||
| Traditional Vector DB(FAISS) | 3.8 GB | 201 GB | 1.8G | 305.8 MB |130.4 MB |
|
||||
| **LEANN** | **324 MB** | **6 GB** | **64 MB** | **14.8 MB** |**6.4MB** |
|
||||
| **Reduction** | **91% smaller** | **97% smaller** | **97% smaller** | **95% smaller** |**95% smaller** |
|
||||
| Traditional Vector DB(FAISS) | 3.8 GB | 201 GB | 1.8G | 2.4G |130.4 MB |
|
||||
| **LEANN** | **324 MB** | **6 GB** | **64 MB** | **79 MB** |**6.4MB** |
|
||||
| **Reduction** | **91% smaller** | **97% smaller** | **97% smaller** | **97% smaller** |**95% smaller** |
|
||||
|
||||
<!-- ### Memory Usage Comparison
|
||||
|
||||
@@ -356,6 +423,15 @@ The evaluation script downloads data automatically on first run.
|
||||
|
||||
*Benchmarks run on Apple M3 Pro 36 GB*
|
||||
|
||||
## Reproduce Our Results
|
||||
|
||||
```bash
|
||||
uv pip install -e ".[dev]" # Install dev dependencies
|
||||
python examples/run_evaluation.py data/indices/dpr/dpr_diskann # DPR dataset
|
||||
python examples/run_evaluation.py data/indices/rpj_wiki/rpj_wiki.index # Wikipedia
|
||||
```
|
||||
|
||||
The evaluation script downloads data automatically on first run. The last three results were tested with partial personal data, and you can reproduce them with your own data!
|
||||
## 🔬 Paper
|
||||
|
||||
If you find Leann useful, please cite:
|
||||
@@ -432,6 +508,17 @@ export NCCL_IB_DISABLE=1
|
||||
export NCCL_NET_PLUGIN=none
|
||||
export NCCL_SOCKET_IFNAME=ens5
|
||||
``` -->
|
||||
## FAQ
|
||||
|
||||
### 1. My building time seems long
|
||||
|
||||
You can speed up the process by using a lightweight embedding model. Add this to your arguments:
|
||||
|
||||
```bash
|
||||
--embedding-model sentence-transformers/all-MiniLM-L6-v2
|
||||
```
|
||||
**Model sizes:** `all-MiniLM-L6-v2` (30M parameters), `facebook/contriever` (~100M parameters), `Qwen3-0.6B` (600M parameters)
|
||||
|
||||
|
||||
## 📈 Roadmap
|
||||
|
||||
|
||||
297
demo.ipynb
297
demo.ipynb
@@ -1,35 +1,296 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Quick Start in 30s"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from leann.api import LeannBuilder, LeannSearcher, LeannChat\n",
|
||||
"# 1. Build index (no embeddings stored!)\n",
|
||||
"builder = LeannBuilder(backend_name=\"hnsw\")\n",
|
||||
"builder.add_text(\"C# is a powerful programming language but it is not very popular\")\n",
|
||||
"builder.add_text(\"Python is a powerful programming language and it is very popular\")\n",
|
||||
"builder.add_text(\"Machine learning transforms industries\") \n",
|
||||
"builder.add_text(\"Neural networks process complex data\")\n",
|
||||
"builder.add_text(\"Leann is a great storage saving engine for RAG on your macbook\")\n",
|
||||
"builder.build_index(\"knowledge.leann\")\n",
|
||||
"# 2. Search with real-time embeddings\n",
|
||||
"searcher = LeannSearcher(\"knowledge.leann\")\n",
|
||||
"results = searcher.search(\"programming languages\", top_k=2, recompute_beighbor_embeddings=True)\n",
|
||||
"print(results)\n",
|
||||
"# install this if you areusing colab\n",
|
||||
"! pip install leann"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Build the index"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO: Registering backend 'hnsw'\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/yichuan/Desktop/code/LEANN/leann/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
||||
" from .autonotebook import tqdm as notebook_tqdm\n",
|
||||
"INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: facebook/contriever\n",
|
||||
"WARNING:sentence_transformers.SentenceTransformer:No sentence-transformers model found with name facebook/contriever. Creating a new one with mean pooling.\n",
|
||||
"Writing passages: 100%|██████████| 5/5 [00:00<00:00, 31254.13chunk/s]\n",
|
||||
"Batches: 100%|██████████| 1/1 [00:00<00:00, 12.19it/s]\n",
|
||||
"WARNING:leann_backend_hnsw.hnsw_backend:Converting data to float32, shape: (5, 768)\n",
|
||||
"INFO:leann_backend_hnsw.hnsw_backend:INFO: Converting HNSW index to CSR-pruned format...\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"M: 64 for level: 0\n",
|
||||
"Starting conversion: knowledge.index -> knowledge.csr.tmp\n",
|
||||
"[0.00s] Reading Index HNSW header...\n",
|
||||
"[0.00s] Header read: d=768, ntotal=5\n",
|
||||
"[0.00s] Reading HNSW struct vectors...\n",
|
||||
" Reading vector (dtype=<class 'numpy.float64'>, fmt='d')... Count=6, Bytes=48\n",
|
||||
"[0.00s] Read assign_probas (6)\n",
|
||||
" Reading vector (dtype=<class 'numpy.int32'>, fmt='i')... Count=7, Bytes=28\n",
|
||||
"[0.11s] Read cum_nneighbor_per_level (7)\n",
|
||||
" Reading vector (dtype=<class 'numpy.int32'>, fmt='i')... Count=5, Bytes=20\n",
|
||||
"[0.23s] Read levels (5)\n",
|
||||
"[0.34s] Probing for compact storage flag...\n",
|
||||
"[0.34s] Found compact flag: False\n",
|
||||
"[0.34s] Compact flag is False, reading original format...\n",
|
||||
"[0.34s] Probing for potential extra byte before non-compact offsets...\n",
|
||||
"[0.34s] Found and consumed an unexpected 0x00 byte.\n",
|
||||
" Reading vector (dtype=<class 'numpy.uint64'>, fmt='Q')... Count=6, Bytes=48\n",
|
||||
"[0.34s] Read offsets (6)\n",
|
||||
"[0.44s] Attempting to read neighbors vector...\n",
|
||||
" Reading vector (dtype=<class 'numpy.int32'>, fmt='i')... Count=320, Bytes=1280\n",
|
||||
"[0.44s] Read neighbors (320)\n",
|
||||
"[0.54s] Read scalar params (ep=4, max_lvl=0)\n",
|
||||
"[0.54s] Checking for storage data...\n",
|
||||
"[0.54s] Found storage fourcc: 49467849.\n",
|
||||
"[0.54s] Converting to CSR format...\n",
|
||||
"[0.54s] Conversion loop finished. \n",
|
||||
"[0.54s] Running validation checks...\n",
|
||||
" Checking total valid neighbor count...\n",
|
||||
" OK: Total valid neighbors = 20\n",
|
||||
" Checking final pointer indices...\n",
|
||||
" OK: Final pointers match data size.\n",
|
||||
"[0.54s] Deleting original neighbors and offsets arrays...\n",
|
||||
" CSR Stats: |data|=20, |level_ptr|=10\n",
|
||||
"[0.63s] Writing CSR HNSW graph data in FAISS-compatible order...\n",
|
||||
" Pruning embeddings: Writing NULL storage marker.\n",
|
||||
"[0.73s] Conversion complete.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO:leann_backend_hnsw.hnsw_backend:✅ CSR conversion successful.\n",
|
||||
"INFO:leann_backend_hnsw.hnsw_backend:INFO: Replaced original index with CSR-pruned version at 'knowledge.index'\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from leann.api import LeannBuilder\n",
|
||||
"\n",
|
||||
"llm_config = {\"type\": \"ollama\", \"model\": \"qwen3:8b\"}\n",
|
||||
"builder = LeannBuilder(backend_name=\"hnsw\")\n",
|
||||
"builder.add_text(\"C# is a powerful programming language\")\n",
|
||||
"builder.add_text(\"Python is a powerful programming language and it is very popular\")\n",
|
||||
"builder.add_text(\"Machine learning transforms industries\")\n",
|
||||
"builder.add_text(\"Neural networks process complex data\")\n",
|
||||
"builder.add_text(\"Leann is a great storage saving engine for RAG on your MacBook\")\n",
|
||||
"builder.build_index(\"knowledge.leann\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Search with real-time embeddings"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO:leann.api:🔍 LeannSearcher.search() called:\n",
|
||||
"INFO:leann.api: Query: 'programming languages'\n",
|
||||
"INFO:leann.api: Top_k: 2\n",
|
||||
"INFO:leann.api: Additional kwargs: {}\n",
|
||||
"INFO:leann.embedding_server_manager:Port 5557 has incompatible server, trying next port...\n",
|
||||
"INFO:leann.embedding_server_manager:Port 5558 has incompatible server, trying next port...\n",
|
||||
"INFO:leann.embedding_server_manager:Port 5559 has incompatible server, trying next port...\n",
|
||||
"INFO:leann.embedding_server_manager:Found compatible server on port 5560\n",
|
||||
"INFO:leann.embedding_server_manager:Using existing compatible server on port 5560\n",
|
||||
"INFO:leann.api: Launching server time: 0.05758476257324219 seconds\n",
|
||||
"INFO:leann.embedding_server_manager:Found compatible server on port 5560\n",
|
||||
"INFO:leann.embedding_server_manager:Using existing compatible server on port 5560\n",
|
||||
"INFO:leann.api: Generated embedding shape: (1, 768)\n",
|
||||
"INFO:leann.api: Embedding time: 0.05983591079711914 seconds\n",
|
||||
"INFO:leann.api: Search time: 0.039762258529663086 seconds\n",
|
||||
"INFO:leann.api: Backend returned: labels=2 results\n",
|
||||
"INFO:leann.api: Processing 2 passage IDs:\n",
|
||||
"INFO:leann.api: 1. passage_id='0' -> SUCCESS: C# is a powerful programming language...\n",
|
||||
"INFO:leann.api: 2. passage_id='1' -> SUCCESS: Python is a powerful programming language and it is very popular...\n",
|
||||
"INFO:leann.api: Final enriched results: 2 passages\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[read_HNSW - CSR NL v4] Reading metadata & CSR indices (manual offset)...\n",
|
||||
"[read_HNSW NL v4] Read levels vector, size: 5\n",
|
||||
"[read_HNSW NL v4] Reading Compact Storage format indices...\n",
|
||||
"[read_HNSW NL v4] Read compact_level_ptr, size: 10\n",
|
||||
"[read_HNSW NL v4] Read compact_node_offsets, size: 6\n",
|
||||
"[read_HNSW NL v4] Read entry_point: 4, max_level: 0\n",
|
||||
"[read_HNSW NL v4] Read storage fourcc: 0x6c6c756e\n",
|
||||
"[read_HNSW NL v4 FIX] Detected FileIOReader. Neighbors size field offset: 326\n",
|
||||
"[read_HNSW NL v4] Reading neighbors data into memory.\n",
|
||||
"[read_HNSW NL v4] Read neighbors data, size: 20\n",
|
||||
"[read_HNSW NL v4] Finished reading metadata and CSR indices.\n",
|
||||
"INFO: Skipping external storage loading, since is_recompute is true.\n",
|
||||
"ZmqDistanceComputer initialized: d=768, metric=0\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[SearchResult(id='0', score=np.float32(0.9646692), text='C# is a powerful programming language', metadata={}),\n",
|
||||
" SearchResult(id='1', score=np.float32(0.91955304), text='Python is a powerful programming language and it is very popular', metadata={})]"
|
||||
]
|
||||
},
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from leann.api import LeannSearcher\n",
|
||||
"\n",
|
||||
"searcher = LeannSearcher(\"knowledge.leann\")\n",
|
||||
"results = searcher.search(\"programming languages\", top_k=2)\n",
|
||||
"results"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Chat with LEANN using retrieved results"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO:leann.chat:Attempting to create LLM of type='hf' with model='Qwen/Qwen3-0.6B'\n",
|
||||
"INFO:leann.chat:Initializing HFChat with model='Qwen/Qwen3-0.6B'\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO: Registering backend 'hnsw'\n",
|
||||
"[read_HNSW - CSR NL v4] Reading metadata & CSR indices (manual offset)...\n",
|
||||
"[read_HNSW NL v4] Read levels vector, size: 5\n",
|
||||
"[read_HNSW NL v4] Reading Compact Storage format indices...\n",
|
||||
"[read_HNSW NL v4] Read compact_level_ptr, size: 10\n",
|
||||
"[read_HNSW NL v4] Read compact_node_offsets, size: 6\n",
|
||||
"[read_HNSW NL v4] Read entry_point: 4, max_level: 0\n",
|
||||
"[read_HNSW NL v4] Read storage fourcc: 0x6c6c756e\n",
|
||||
"[read_HNSW NL v4 FIX] Detected FileIOReader. Neighbors size field offset: 326\n",
|
||||
"[read_HNSW NL v4] Reading neighbors data into memory.\n",
|
||||
"[read_HNSW NL v4] Read neighbors data, size: 20\n",
|
||||
"[read_HNSW NL v4] Finished reading metadata and CSR indices.\n",
|
||||
"INFO: Skipping external storage loading, since is_recompute is true.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/yichuan/Desktop/code/LEANN/leann/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
||||
" from .autonotebook import tqdm as notebook_tqdm\n",
|
||||
"INFO:leann.chat:MPS is available. Using Apple Silicon GPU.\n",
|
||||
"INFO:leann.api:🔍 LeannSearcher.search() called:\n",
|
||||
"INFO:leann.api: Query: 'Compare the two retrieved programming languages and say which one is more popular today.'\n",
|
||||
"INFO:leann.api: Top_k: 2\n",
|
||||
"INFO:leann.api: Additional kwargs: {}\n",
|
||||
"INFO:leann.embedding_server_manager:Port 5557 has incompatible server, trying next port...\n",
|
||||
"INFO:leann.embedding_server_manager:Port 5558 has incompatible server, trying next port...\n",
|
||||
"INFO:leann.embedding_server_manager:Port 5559 has incompatible server, trying next port...\n",
|
||||
"INFO:leann.embedding_server_manager:Found compatible server on port 5560\n",
|
||||
"INFO:leann.embedding_server_manager:Using existing compatible server on port 5560\n",
|
||||
"INFO:leann.api: Launching server time: 0.11421084403991699 seconds\n",
|
||||
"INFO:leann.embedding_server_manager:Found compatible server on port 5560\n",
|
||||
"INFO:leann.embedding_server_manager:Using existing compatible server on port 5560\n",
|
||||
"INFO:leann.api: Generated embedding shape: (1, 768)\n",
|
||||
"INFO:leann.api: Embedding time: 0.1147918701171875 seconds\n",
|
||||
"INFO:leann.api: Search time: 0.05468583106994629 seconds\n",
|
||||
"INFO:leann.api: Backend returned: labels=2 results\n",
|
||||
"INFO:leann.api: Processing 2 passage IDs:\n",
|
||||
"INFO:leann.api: 1. passage_id='1' -> SUCCESS: Python is a powerful programming language and it is very popular...\n",
|
||||
"INFO:leann.api: 2. passage_id='0' -> SUCCESS: C# is a powerful programming language...\n",
|
||||
"INFO:leann.api: Final enriched results: 2 passages\n",
|
||||
"INFO:leann.chat:Generating with HuggingFace model, config: {'max_new_tokens': 512, 'temperature': 0.7, 'top_p': 0.9, 'do_sample': True, 'pad_token_id': 151645, 'eos_token_id': 151645}\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"ZmqDistanceComputer initialized: d=768, metric=0\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'<think>\\n\\n</think>\\n\\nBased on the context provided, both Python and C# are mentioned as powerful programming languages, but no specific information is given about their popularity today. However, generally, Python is more popular for data science, web development, and other tasks, while C# is widely used in enterprise applications and game development. Since the context does not explicitly state which is more popular, but Python is often considered more popular in many cases, the best answer would be:\\n\\n**Python is more popular today.**'"
|
||||
]
|
||||
},
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from leann.api import LeannChat\n",
|
||||
"\n",
|
||||
"llm_config = {\n",
|
||||
" \"type\": \"hf\",\n",
|
||||
" \"model\": \"Qwen/Qwen3-0.6B\"\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"chat = LeannChat(index_path=\"knowledge.leann\", llm_config=llm_config)\n",
|
||||
"\n",
|
||||
"response = chat.ask(\n",
|
||||
" \"Compare the two retrieved programming languages and say which one is more popular today. Respond in a single well-formed sentence.\",\n",
|
||||
" \"Compare the two retrieved programming languages and say which one is more popular today.\",\n",
|
||||
" top_k=2,\n",
|
||||
" recompute_beighbor_embeddings=True,\n",
|
||||
")\n",
|
||||
"print(response)"
|
||||
"response"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
||||
54
docs/RELEASE.md
Normal file
54
docs/RELEASE.md
Normal file
@@ -0,0 +1,54 @@
|
||||
# Release Guide
|
||||
|
||||
## One-line Release 🚀
|
||||
|
||||
```bash
|
||||
./scripts/release.sh 0.1.1
|
||||
```
|
||||
|
||||
That's it! This script will:
|
||||
1. Update all package versions
|
||||
2. Commit and push changes
|
||||
3. Create GitHub release
|
||||
4. CI automatically builds and publishes to PyPI
|
||||
|
||||
## Manual Testing Before Release
|
||||
|
||||
For testing specific packages locally (especially DiskANN on macOS):
|
||||
|
||||
```bash
|
||||
# Build specific package locally
|
||||
./scripts/build_and_test.sh diskann # or hnsw, core, meta, all
|
||||
|
||||
# Test installation in a clean environment
|
||||
python -m venv test_env
|
||||
source test_env/bin/activate
|
||||
pip install packages/*/dist/*.whl
|
||||
|
||||
# Upload to Test PyPI (optional)
|
||||
./scripts/upload_to_pypi.sh test
|
||||
|
||||
# Upload to Production PyPI (use with caution)
|
||||
./scripts/upload_to_pypi.sh prod
|
||||
```
|
||||
|
||||
### Why Manual Build for DiskANN?
|
||||
|
||||
DiskANN's complex dependencies (protobuf, abseil, etc.) sometimes require local testing before release. The build script will:
|
||||
- Compile the C++ extension
|
||||
- Use `delocate` (macOS) or `auditwheel` (Linux) to bundle system libraries
|
||||
- Create a self-contained wheel with no external dependencies
|
||||
|
||||
## First-time setup
|
||||
|
||||
1. Install GitHub CLI:
|
||||
```bash
|
||||
brew install gh
|
||||
gh auth login
|
||||
```
|
||||
|
||||
2. Set PyPI token in GitHub:
|
||||
```bash
|
||||
gh secret set PYPI_API_TOKEN
|
||||
# Paste your PyPI token when prompted
|
||||
```
|
||||
@@ -96,14 +96,12 @@ class EmlxReader(BaseReader):
|
||||
|
||||
# Create document content with metadata embedded in text
|
||||
doc_content = f"""
|
||||
[EMAIL METADATA]
|
||||
File: {filename}
|
||||
From: {from_addr}
|
||||
To: {to_addr}
|
||||
Subject: {subject}
|
||||
Date: {date}
|
||||
[END METADATA]
|
||||
|
||||
[File]: {filename}
|
||||
[From]: {from_addr}
|
||||
[To]: {to_addr}
|
||||
[Subject]: {subject}
|
||||
[Date]: {date}
|
||||
[EMAIL BODY Start]:
|
||||
{body}
|
||||
"""
|
||||
|
||||
|
||||
@@ -65,12 +65,14 @@ def create_leann_index_from_multiple_chrome_profiles(profile_dirs: List[Path], i
|
||||
|
||||
if not all_documents:
|
||||
print("No documents loaded from any source. Exiting.")
|
||||
# highlight info that you need to close all chrome browser before running this script and high light the instruction!!
|
||||
print("\033[91mYou need to close or quit all chrome browser before running this script\033[0m")
|
||||
return None
|
||||
|
||||
print(f"\nTotal loaded {len(all_documents)} history documents from {len(profile_dirs)} profiles")
|
||||
|
||||
# Create text splitter with 256 chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)
|
||||
|
||||
# Convert Documents to text strings and chunk them
|
||||
all_texts = []
|
||||
@@ -78,7 +80,9 @@ def create_leann_index_from_multiple_chrome_profiles(profile_dirs: List[Path], i
|
||||
# Split the document into chunks
|
||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
text = node.get_content()
|
||||
# text = '[Title] ' + doc.metadata["title"] + '\n' + text
|
||||
all_texts.append(text)
|
||||
|
||||
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
|
||||
|
||||
@@ -225,7 +229,7 @@ async def main():
|
||||
parser = argparse.ArgumentParser(description='LEANN Chrome History Reader - Create and query browser history index')
|
||||
parser.add_argument('--chrome-profile', type=str, default=DEFAULT_CHROME_PROFILE,
|
||||
help=f'Path to Chrome profile directory (default: {DEFAULT_CHROME_PROFILE}), usually you dont need to change this')
|
||||
parser.add_argument('--index-dir', type=str, default="./chrome_history_index_leann_test",
|
||||
parser.add_argument('--index-dir', type=str, default="./all_google_new",
|
||||
help='Directory to store the LEANN index (default: ./chrome_history_index_leann_test)')
|
||||
parser.add_argument('--max-entries', type=int, default=1000,
|
||||
help='Maximum number of history entries to process (default: 1000)')
|
||||
|
||||
@@ -74,22 +74,17 @@ class ChromeHistoryReader(BaseReader):
|
||||
|
||||
# Create document content with metadata embedded in text
|
||||
doc_content = f"""
|
||||
[BROWSING HISTORY METADATA]
|
||||
URL: {url}
|
||||
Title: {title}
|
||||
Last Visit: {last_visit}
|
||||
Visit Count: {visit_count}
|
||||
Typed Count: {typed_count}
|
||||
Hidden: {hidden}
|
||||
[END METADATA]
|
||||
|
||||
Title: {title}
|
||||
URL: {url}
|
||||
Last visited: {last_visit}
|
||||
[Title]: {title}
|
||||
[URL of the page]: {url}
|
||||
[Last visited time]: {last_visit}
|
||||
[Visit times]: {visit_count}
|
||||
[Typed times]: {typed_count}
|
||||
"""
|
||||
|
||||
# Create document with embedded metadata
|
||||
doc = Document(text=doc_content, metadata={})
|
||||
doc = Document(text=doc_content, metadata={ "title": title[0:150]})
|
||||
# if len(title) > 150:
|
||||
# print(f"Title is too long: {title}")
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
|
||||
|
||||
@@ -335,14 +335,15 @@ class WeChatHistoryReader(BaseReader):
|
||||
if create_time:
|
||||
try:
|
||||
timestamp = datetime.fromtimestamp(create_time)
|
||||
time_str = timestamp.strftime('%H:%M:%S')
|
||||
# change to YYYY-MM-DD HH:MM:SS
|
||||
time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
||||
except:
|
||||
time_str = str(create_time)
|
||||
else:
|
||||
time_str = "Unknown"
|
||||
|
||||
sender = "Me" if is_sent_from_self else "Contact"
|
||||
message_parts.append(f"[{time_str}] {sender}: {readable_text}")
|
||||
sender = "[Me]" if is_sent_from_self else "[Contact]"
|
||||
message_parts.append(f"({time_str}) {sender}: {readable_text}")
|
||||
|
||||
concatenated_text = "\n".join(message_parts)
|
||||
|
||||
@@ -354,13 +355,11 @@ Messages ({len(messages)} messages, {message_group['total_length']} chars):
|
||||
|
||||
{concatenated_text}
|
||||
"""
|
||||
|
||||
# TODO @yichuan give better format and rich info here!
|
||||
doc_content = f"""
|
||||
Contact: {contact_name}
|
||||
|
||||
{concatenated_text}
|
||||
"""
|
||||
return doc_content
|
||||
return doc_content, contact_name
|
||||
|
||||
def load_data(self, input_dir: str = None, **load_kwargs: Any) -> List[Document]:
|
||||
"""
|
||||
@@ -441,8 +440,8 @@ Contact: {contact_name}
|
||||
if count >= max_count and max_count > 0:
|
||||
break
|
||||
|
||||
doc_content = self._create_concatenated_content(message_group, contact_name)
|
||||
doc = Document(text=doc_content, metadata={})
|
||||
doc_content, contact_name = self._create_concatenated_content(message_group, contact_name)
|
||||
doc = Document(text=doc_content, metadata={"contact_name": contact_name})
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ def get_mail_path():
|
||||
return os.path.join(home_dir, "Library", "Mail")
|
||||
|
||||
# Default mail path for macOS
|
||||
# DEFAULT_MAIL_PATH = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data"
|
||||
DEFAULT_MAIL_PATH = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data"
|
||||
|
||||
def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_path: str = "mail_index.leann", max_count: int = -1, include_html: bool = False, embedding_model: str = "facebook/contriever"):
|
||||
"""
|
||||
@@ -74,7 +74,7 @@ def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_pa
|
||||
print("No documents loaded from any source. Exiting.")
|
||||
return None
|
||||
|
||||
print(f"\nTotal loaded {len(all_documents)} email documents from {len(messages_dirs)} directories")
|
||||
print(f"\nTotal loaded {len(all_documents)} email documents from {len(messages_dirs)} directories and starting to split them into chunks")
|
||||
|
||||
# Create text splitter with 256 chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
||||
@@ -85,9 +85,11 @@ def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_pa
|
||||
# Split the document into chunks
|
||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
text = node.get_content()
|
||||
# text = '[subject] ' + doc.metadata["subject"] + '\n' + text
|
||||
all_texts.append(text)
|
||||
|
||||
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
|
||||
print(f"Finished splitting {len(all_documents)} documents into {len(all_texts)} text chunks")
|
||||
|
||||
# Create LEANN index directory
|
||||
|
||||
@@ -156,7 +158,7 @@ def create_leann_index(mail_path: str, index_path: str = "mail_index.leann", max
|
||||
print(f"Loaded {len(documents)} email documents")
|
||||
|
||||
# Create text splitter with 256 chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)
|
||||
|
||||
# Convert Documents to text strings and chunk them
|
||||
all_texts = []
|
||||
@@ -216,9 +218,9 @@ async def query_leann_index(index_path: str, query: str):
|
||||
start_time = time.time()
|
||||
chat_response = chat.ask(
|
||||
query,
|
||||
top_k=10,
|
||||
top_k=20,
|
||||
recompute_beighbor_embeddings=True,
|
||||
complexity=12,
|
||||
complexity=32,
|
||||
beam_width=1,
|
||||
|
||||
)
|
||||
@@ -231,7 +233,7 @@ async def main():
|
||||
parser = argparse.ArgumentParser(description='LEANN Mail Reader - Create and query email index')
|
||||
# Remove --mail-path argument and auto-detect all Messages directories
|
||||
# Remove DEFAULT_MAIL_PATH
|
||||
parser.add_argument('--index-dir', type=str, default="./mail_index_leann_raw_text_all_dicts",
|
||||
parser.add_argument('--index-dir', type=str, default="./mail_index_index_file",
|
||||
help='Directory to store the LEANN index (default: ./mail_index_leann_raw_text_all_dicts)')
|
||||
parser.add_argument('--max-emails', type=int, default=1000,
|
||||
help='Maximum number of emails to process (-1 means all)')
|
||||
@@ -251,6 +253,9 @@ async def main():
|
||||
mail_path = get_mail_path()
|
||||
print(f"Searching for email data in: {mail_path}")
|
||||
messages_dirs = find_all_messages_directories(mail_path)
|
||||
# messages_dirs = find_all_messages_directories(DEFAULT_MAIL_PATH)
|
||||
# messages_dirs = [DEFAULT_MAIL_PATH]
|
||||
# messages_dirs = messages_dirs[:1]
|
||||
|
||||
print('len(messages_dirs): ', len(messages_dirs))
|
||||
|
||||
|
||||
@@ -1,40 +1,40 @@
|
||||
import argparse
|
||||
from llama_index.core import SimpleDirectoryReader, Settings
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
import asyncio
|
||||
import dotenv
|
||||
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
||||
import shutil
|
||||
from leann.api import LeannBuilder, LeannChat
|
||||
from pathlib import Path
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
node_parser = SentenceSplitter(
|
||||
chunk_size=256, chunk_overlap=128, separator=" ", paragraph_separator="\n\n"
|
||||
)
|
||||
print("Loading documents...")
|
||||
documents = SimpleDirectoryReader(
|
||||
"examples/data",
|
||||
recursive=True,
|
||||
encoding="utf-8",
|
||||
required_exts=[".pdf", ".txt", ".md"],
|
||||
).load_data(show_progress=True)
|
||||
print("Documents loaded.")
|
||||
all_texts = []
|
||||
for doc in documents:
|
||||
nodes = node_parser.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
|
||||
|
||||
async def main(args):
|
||||
INDEX_DIR = Path(args.index_dir)
|
||||
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
|
||||
|
||||
if not INDEX_DIR.exists():
|
||||
print(f"--- Index directory not found, building new index ---")
|
||||
node_parser = SentenceSplitter(
|
||||
chunk_size=256, chunk_overlap=128, separator=" ", paragraph_separator="\n\n"
|
||||
)
|
||||
|
||||
print(f"\n[PHASE 1] Building Leann index...")
|
||||
print("Loading documents...")
|
||||
documents = SimpleDirectoryReader(
|
||||
args.data_dir,
|
||||
recursive=True,
|
||||
encoding="utf-8",
|
||||
required_exts=[".pdf", ".txt", ".md"],
|
||||
).load_data(show_progress=True)
|
||||
print("Documents loaded.")
|
||||
all_texts = []
|
||||
for doc in documents:
|
||||
nodes = node_parser.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
|
||||
print("--- Index directory not found, building new index ---")
|
||||
|
||||
print("\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Use HNSW backend for better macOS compatibility
|
||||
builder = LeannBuilder(
|
||||
@@ -58,8 +58,9 @@ async def main(args):
|
||||
|
||||
print(f"\n[PHASE 2] Starting Leann chat session...")
|
||||
|
||||
# llm_config = {"type": "hf", "model": "Qwen/Qwen3-4B"}
|
||||
llm_config = {"type": "hf", "model": "Qwen/Qwen3-4B"}
|
||||
llm_config = {"type": "ollama", "model": "qwen3:8b"}
|
||||
llm_config = {"type": "openai", "model": "gpt-4o"}
|
||||
|
||||
chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)
|
||||
|
||||
@@ -70,9 +71,7 @@ async def main(args):
|
||||
# )
|
||||
|
||||
print(f"You: {query}")
|
||||
chat_response = chat.ask(
|
||||
query, top_k=20, recompute_beighbor_embeddings=True, complexity=32
|
||||
)
|
||||
chat_response = chat.ask(query, top_k=20, recompute_embeddings=True, complexity=32)
|
||||
print(f"Leann: {chat_response}")
|
||||
|
||||
|
||||
@@ -105,6 +104,12 @@ if __name__ == "__main__":
|
||||
default="./test_doc_files",
|
||||
help="Directory where the Leann index will be stored.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data-dir",
|
||||
type=str,
|
||||
default="examples/data",
|
||||
help="Directory containing documents to index (PDF, TXT, MD files).",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
asyncio.run(main(args))
|
||||
|
||||
@@ -52,7 +52,7 @@ def create_leann_index_from_multiple_wechat_exports(
|
||||
documents = reader.load_data(
|
||||
wechat_export_dir=str(export_dir),
|
||||
max_count=max_count,
|
||||
concatenate_messages=False, # Disable concatenation - one message per document
|
||||
concatenate_messages=True, # Disable concatenation - one message per document
|
||||
)
|
||||
if documents:
|
||||
print(f"Loaded {len(documents)} chat documents from {export_dir}")
|
||||
@@ -74,11 +74,11 @@ def create_leann_index_from_multiple_wechat_exports(
|
||||
return None
|
||||
|
||||
print(
|
||||
f"\nTotal loaded {len(all_documents)} chat documents from {len(export_dirs)} exports"
|
||||
f"\nTotal loaded {len(all_documents)} chat documents from {len(export_dirs)} exports and starting to split them into chunks"
|
||||
)
|
||||
|
||||
# Create text splitter with 256 chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=128, chunk_overlap=64)
|
||||
text_splitter = SentenceSplitter(chunk_size=192, chunk_overlap=64)
|
||||
|
||||
# Convert Documents to text strings and chunk them
|
||||
all_texts = []
|
||||
@@ -86,10 +86,11 @@ def create_leann_index_from_multiple_wechat_exports(
|
||||
# Split the document into chunks
|
||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
text = '[Contact] means the message is from: ' + doc.metadata["contact_name"] + '\n' + node.get_content()
|
||||
all_texts.append(text)
|
||||
|
||||
print(
|
||||
f"Created {len(all_texts)} text chunks from {len(all_documents)} documents"
|
||||
f"Finished splitting {len(all_documents)} documents into {len(all_texts)} text chunks"
|
||||
)
|
||||
|
||||
# Create LEANN index directory
|
||||
@@ -224,7 +225,7 @@ async def query_leann_index(index_path: str, query: str):
|
||||
query,
|
||||
top_k=20,
|
||||
recompute_beighbor_embeddings=True,
|
||||
complexity=128,
|
||||
complexity=16,
|
||||
beam_width=1,
|
||||
llm_config={
|
||||
"type": "openai",
|
||||
@@ -252,13 +253,13 @@ async def main():
|
||||
parser.add_argument(
|
||||
"--index-dir",
|
||||
type=str,
|
||||
default="./wechat_history_june19_test",
|
||||
default="./wechat_history_magic_test_11Debug_new",
|
||||
help="Directory to store the LEANN index (default: ./wechat_history_index_leann_test)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-entries",
|
||||
type=int,
|
||||
default=5000,
|
||||
default=50,
|
||||
help="Maximum number of chat entries to process (default: 5000)",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import numpy as np
|
||||
import os
|
||||
import struct
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List, Literal
|
||||
from typing import Dict, Any, List, Literal, Optional
|
||||
import contextlib
|
||||
import pickle
|
||||
|
||||
import logging
|
||||
|
||||
from leann.searcher_base import BaseSearcher
|
||||
from leann.registry import register_backend
|
||||
@@ -14,6 +16,46 @@ from leann.interface import (
|
||||
LeannBackendSearcherInterface,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def suppress_cpp_output_if_needed():
|
||||
"""Suppress C++ stdout/stderr based on LEANN_LOG_LEVEL"""
|
||||
log_level = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||
|
||||
# Only suppress if log level is WARNING or higher (ERROR, CRITICAL)
|
||||
should_suppress = log_level in ["WARNING", "ERROR", "CRITICAL"]
|
||||
|
||||
if not should_suppress:
|
||||
# Don't suppress, just yield
|
||||
yield
|
||||
return
|
||||
|
||||
# Save original file descriptors
|
||||
stdout_fd = sys.stdout.fileno()
|
||||
stderr_fd = sys.stderr.fileno()
|
||||
|
||||
# Save original stdout/stderr
|
||||
stdout_dup = os.dup(stdout_fd)
|
||||
stderr_dup = os.dup(stderr_fd)
|
||||
|
||||
try:
|
||||
# Redirect to /dev/null
|
||||
devnull = os.open(os.devnull, os.O_WRONLY)
|
||||
os.dup2(devnull, stdout_fd)
|
||||
os.dup2(devnull, stderr_fd)
|
||||
os.close(devnull)
|
||||
|
||||
yield
|
||||
|
||||
finally:
|
||||
# Restore original file descriptors
|
||||
os.dup2(stdout_dup, stdout_fd)
|
||||
os.dup2(stderr_dup, stderr_fd)
|
||||
os.close(stdout_dup)
|
||||
os.close(stderr_dup)
|
||||
|
||||
|
||||
def _get_diskann_metrics():
|
||||
from . import _diskannpy as diskannpy # type: ignore
|
||||
@@ -65,18 +107,20 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
||||
index_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if data.dtype != np.float32:
|
||||
logger.warning(f"Converting data to float32, shape: {data.shape}")
|
||||
data = data.astype(np.float32)
|
||||
|
||||
data_filename = f"{index_prefix}_data.bin"
|
||||
_write_vectors_to_bin(data, index_dir / data_filename)
|
||||
|
||||
|
||||
build_kwargs = {**self.build_params, **kwargs}
|
||||
metric_enum = _get_diskann_metrics().get(
|
||||
build_kwargs.get("distance_metric", "mips").lower()
|
||||
)
|
||||
if metric_enum is None:
|
||||
raise ValueError("Unsupported distance_metric.")
|
||||
raise ValueError(
|
||||
f"Unsupported distance_metric '{build_kwargs.get('distance_metric', 'unknown')}'."
|
||||
)
|
||||
|
||||
try:
|
||||
from . import _diskannpy as diskannpy # type: ignore
|
||||
@@ -98,36 +142,40 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
||||
temp_data_file = index_dir / data_filename
|
||||
if temp_data_file.exists():
|
||||
os.remove(temp_data_file)
|
||||
logger.debug(f"Cleaned up temporary data file: {temp_data_file}")
|
||||
|
||||
|
||||
class DiskannSearcher(BaseSearcher):
|
||||
def __init__(self, index_path: str, **kwargs):
|
||||
super().__init__(
|
||||
index_path,
|
||||
backend_module_name="leann_backend_diskann.embedding_server",
|
||||
backend_module_name="leann_backend_diskann.diskann_embedding_server",
|
||||
**kwargs,
|
||||
)
|
||||
from . import _diskannpy as diskannpy # type: ignore
|
||||
|
||||
distance_metric = kwargs.get("distance_metric", "mips").lower()
|
||||
metric_enum = _get_diskann_metrics().get(distance_metric)
|
||||
if metric_enum is None:
|
||||
raise ValueError(f"Unsupported distance_metric '{distance_metric}'.")
|
||||
# Initialize DiskANN index with suppressed C++ output based on log level
|
||||
with suppress_cpp_output_if_needed():
|
||||
from . import _diskannpy as diskannpy # type: ignore
|
||||
|
||||
self.num_threads = kwargs.get("num_threads", 8)
|
||||
self.zmq_port = kwargs.get("zmq_port", 6666)
|
||||
distance_metric = kwargs.get("distance_metric", "mips").lower()
|
||||
metric_enum = _get_diskann_metrics().get(distance_metric)
|
||||
if metric_enum is None:
|
||||
raise ValueError(f"Unsupported distance_metric '{distance_metric}'.")
|
||||
|
||||
full_index_prefix = str(self.index_dir / self.index_path.stem)
|
||||
self._index = diskannpy.StaticDiskFloatIndex(
|
||||
metric_enum,
|
||||
full_index_prefix,
|
||||
self.num_threads,
|
||||
kwargs.get("num_nodes_to_cache", 0),
|
||||
1,
|
||||
self.zmq_port,
|
||||
"",
|
||||
"",
|
||||
)
|
||||
self.num_threads = kwargs.get("num_threads", 8)
|
||||
|
||||
fake_zmq_port = 6666
|
||||
full_index_prefix = str(self.index_dir / self.index_path.stem)
|
||||
self._index = diskannpy.StaticDiskFloatIndex(
|
||||
metric_enum,
|
||||
full_index_prefix,
|
||||
self.num_threads,
|
||||
kwargs.get("num_nodes_to_cache", 0),
|
||||
1,
|
||||
fake_zmq_port, # Initial port, can be updated at runtime
|
||||
"",
|
||||
"",
|
||||
)
|
||||
|
||||
def search(
|
||||
self,
|
||||
@@ -138,7 +186,7 @@ class DiskannSearcher(BaseSearcher):
|
||||
prune_ratio: float = 0.0,
|
||||
recompute_embeddings: bool = False,
|
||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||
zmq_port: int = 5557,
|
||||
zmq_port: Optional[int] = None,
|
||||
batch_recompute: bool = False,
|
||||
dedup_node_dis: bool = False,
|
||||
**kwargs,
|
||||
@@ -157,7 +205,7 @@ class DiskannSearcher(BaseSearcher):
|
||||
- "global": Use global pruning strategy (default)
|
||||
- "local": Use local pruning strategy
|
||||
- "proportional": Not supported in DiskANN, falls back to global
|
||||
zmq_port: ZMQ port for embedding server
|
||||
zmq_port: ZMQ port for embedding server communication. Must be provided if recompute_embeddings is True.
|
||||
batch_recompute: Whether to batch neighbor recomputation (DiskANN-specific)
|
||||
dedup_node_dis: Whether to cache and reuse distance computations (DiskANN-specific)
|
||||
**kwargs: Additional DiskANN-specific parameters (for legacy compatibility)
|
||||
@@ -165,22 +213,25 @@ class DiskannSearcher(BaseSearcher):
|
||||
Returns:
|
||||
Dict with 'labels' (list of lists) and 'distances' (ndarray)
|
||||
"""
|
||||
# Handle zmq_port compatibility: DiskANN can now update port at runtime
|
||||
if recompute_embeddings:
|
||||
if zmq_port is None:
|
||||
raise ValueError(
|
||||
"zmq_port must be provided if recompute_embeddings is True"
|
||||
)
|
||||
current_port = self._index.get_zmq_port()
|
||||
if zmq_port != current_port:
|
||||
logger.debug(
|
||||
f"Updating DiskANN zmq_port from {current_port} to {zmq_port}"
|
||||
)
|
||||
self._index.set_zmq_port(zmq_port)
|
||||
|
||||
# DiskANN doesn't support "proportional" strategy
|
||||
if pruning_strategy == "proportional":
|
||||
raise NotImplementedError(
|
||||
"DiskANN backend does not support 'proportional' pruning strategy. Use 'global' or 'local' instead."
|
||||
)
|
||||
|
||||
# Use recompute_embeddings parameter
|
||||
use_recompute = recompute_embeddings
|
||||
if use_recompute:
|
||||
meta_file_path = self.index_dir / f"{self.index_path.name}.meta.json"
|
||||
if not meta_file_path.exists():
|
||||
raise RuntimeError(
|
||||
f"FATAL: Recompute enabled but metadata file not found: {meta_file_path}"
|
||||
)
|
||||
self._ensure_server_running(str(meta_file_path), port=zmq_port, **kwargs)
|
||||
|
||||
if query.dtype != np.float32:
|
||||
query = query.astype(np.float32)
|
||||
|
||||
@@ -190,25 +241,26 @@ class DiskannSearcher(BaseSearcher):
|
||||
else: # "global"
|
||||
use_global_pruning = True
|
||||
|
||||
labels, distances = self._index.batch_search(
|
||||
query,
|
||||
query.shape[0],
|
||||
top_k,
|
||||
complexity,
|
||||
beam_width,
|
||||
self.num_threads,
|
||||
kwargs.get("USE_DEFERRED_FETCH", False),
|
||||
kwargs.get("skip_search_reorder", False),
|
||||
use_recompute,
|
||||
dedup_node_dis,
|
||||
prune_ratio,
|
||||
batch_recompute,
|
||||
use_global_pruning,
|
||||
)
|
||||
# Perform search with suppressed C++ output based on log level
|
||||
with suppress_cpp_output_if_needed():
|
||||
labels, distances = self._index.batch_search(
|
||||
query,
|
||||
query.shape[0],
|
||||
top_k,
|
||||
complexity,
|
||||
beam_width,
|
||||
self.num_threads,
|
||||
kwargs.get("USE_DEFERRED_FETCH", False),
|
||||
kwargs.get("skip_search_reorder", False),
|
||||
recompute_embeddings,
|
||||
dedup_node_dis,
|
||||
prune_ratio,
|
||||
batch_recompute,
|
||||
use_global_pruning,
|
||||
)
|
||||
|
||||
string_labels = [
|
||||
[str(int_label) for int_label in batch_labels]
|
||||
for batch_labels in labels
|
||||
[str(int_label) for int_label in batch_labels] for batch_labels in labels
|
||||
]
|
||||
|
||||
return {"labels": string_labels, "distances": distances}
|
||||
|
||||
@@ -0,0 +1,283 @@
|
||||
"""
|
||||
DiskANN-specific embedding server
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import threading
|
||||
import time
|
||||
import os
|
||||
import zmq
|
||||
import numpy as np
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
import sys
|
||||
import logging
|
||||
|
||||
# Set up logging based on environment variable
|
||||
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Force set logger level (don't rely on basicConfig in subprocess)
|
||||
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
|
||||
logger.setLevel(log_level)
|
||||
|
||||
# Ensure we have a handler if none exists
|
||||
if not logger.handlers:
|
||||
handler = logging.StreamHandler()
|
||||
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
|
||||
handler.setFormatter(formatter)
|
||||
logger.addHandler(handler)
|
||||
logger.propagate = False
|
||||
|
||||
|
||||
def create_diskann_embedding_server(
|
||||
passages_file: Optional[str] = None,
|
||||
zmq_port: int = 5555,
|
||||
model_name: str = "sentence-transformers/all-mpnet-base-v2",
|
||||
embedding_mode: str = "sentence-transformers",
|
||||
):
|
||||
"""
|
||||
Create and start a ZMQ-based embedding server for DiskANN backend.
|
||||
Uses ROUTER socket and protobuf communication as required by DiskANN C++ implementation.
|
||||
"""
|
||||
logger.info(f"Starting DiskANN server on port {zmq_port} with model {model_name}")
|
||||
logger.info(f"Using embedding mode: {embedding_mode}")
|
||||
|
||||
# Add leann-core to path for unified embedding computation
|
||||
current_dir = Path(__file__).parent
|
||||
leann_core_path = current_dir.parent.parent / "leann-core" / "src"
|
||||
sys.path.insert(0, str(leann_core_path))
|
||||
|
||||
try:
|
||||
from leann.embedding_compute import compute_embeddings
|
||||
from leann.api import PassageManager
|
||||
|
||||
logger.info("Successfully imported unified embedding computation module")
|
||||
except ImportError as e:
|
||||
logger.error(f"Failed to import embedding computation module: {e}")
|
||||
return
|
||||
finally:
|
||||
sys.path.pop(0)
|
||||
|
||||
# Check port availability
|
||||
import socket
|
||||
|
||||
def check_port(port):
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
return s.connect_ex(("localhost", port)) == 0
|
||||
|
||||
if check_port(zmq_port):
|
||||
logger.error(f"Port {zmq_port} is already in use")
|
||||
return
|
||||
|
||||
# Only support metadata file, fail fast for everything else
|
||||
if not passages_file or not passages_file.endswith(".meta.json"):
|
||||
raise ValueError("Only metadata files (.meta.json) are supported")
|
||||
|
||||
# Load metadata to get passage sources
|
||||
with open(passages_file, "r") as f:
|
||||
meta = json.load(f)
|
||||
|
||||
passages = PassageManager(meta["passage_sources"])
|
||||
logger.info(
|
||||
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
|
||||
)
|
||||
|
||||
# Import protobuf after ensuring the path is correct
|
||||
try:
|
||||
from . import embedding_pb2
|
||||
except ImportError as e:
|
||||
logger.error(f"Failed to import protobuf module: {e}")
|
||||
return
|
||||
|
||||
def zmq_server_thread():
|
||||
"""ZMQ server thread using REP socket for universal compatibility"""
|
||||
context = zmq.Context()
|
||||
socket = context.socket(
|
||||
zmq.REP
|
||||
) # REP socket for both BaseSearcher and DiskANN C++ REQ clients
|
||||
socket.bind(f"tcp://*:{zmq_port}")
|
||||
logger.info(f"DiskANN ZMQ REP server listening on port {zmq_port}")
|
||||
|
||||
socket.setsockopt(zmq.RCVTIMEO, 300000)
|
||||
socket.setsockopt(zmq.SNDTIMEO, 300000)
|
||||
|
||||
while True:
|
||||
try:
|
||||
# REP socket receives single-part messages
|
||||
message = socket.recv()
|
||||
|
||||
# Check for empty messages - REP socket requires response to every request
|
||||
if len(message) == 0:
|
||||
logger.debug("Received empty message, sending empty response")
|
||||
socket.send(b"") # REP socket must respond to every request
|
||||
continue
|
||||
|
||||
logger.debug(f"Received ZMQ request of size {len(message)} bytes")
|
||||
logger.debug(f"Message preview: {message[:50]}") # Show first 50 bytes
|
||||
|
||||
e2e_start = time.time()
|
||||
|
||||
# Try protobuf first (for DiskANN C++ node_ids requests - primary use case)
|
||||
texts = []
|
||||
node_ids = []
|
||||
is_text_request = False
|
||||
|
||||
try:
|
||||
req_proto = embedding_pb2.NodeEmbeddingRequest()
|
||||
req_proto.ParseFromString(message)
|
||||
node_ids = list(req_proto.node_ids)
|
||||
|
||||
if not node_ids:
|
||||
raise RuntimeError(
|
||||
f"PROTOBUF: Received empty node_ids! Message size: {len(message)}"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"✅ PROTOBUF: Node ID request for {len(node_ids)} node embeddings: {node_ids[:10]}"
|
||||
)
|
||||
except Exception as protobuf_error:
|
||||
logger.debug(f"Protobuf parsing failed: {protobuf_error}")
|
||||
# Fallback to msgpack (for BaseSearcher direct text requests)
|
||||
try:
|
||||
import msgpack
|
||||
|
||||
request = msgpack.unpackb(message)
|
||||
# For BaseSearcher compatibility, request is a list of texts directly
|
||||
if isinstance(request, list) and all(
|
||||
isinstance(item, str) for item in request
|
||||
):
|
||||
texts = request
|
||||
is_text_request = True
|
||||
logger.info(
|
||||
f"✅ MSGPACK: Direct text request for {len(texts)} texts"
|
||||
)
|
||||
else:
|
||||
raise ValueError("Not a valid msgpack text request")
|
||||
except Exception as msgpack_error:
|
||||
raise RuntimeError(
|
||||
f"Both protobuf and msgpack parsing failed! Protobuf: {protobuf_error}, Msgpack: {msgpack_error}"
|
||||
)
|
||||
|
||||
# Look up texts by node IDs (only if not direct text request)
|
||||
if not is_text_request:
|
||||
for nid in node_ids:
|
||||
try:
|
||||
passage_data = passages.get_passage(str(nid))
|
||||
txt = passage_data["text"]
|
||||
if not txt:
|
||||
raise RuntimeError(
|
||||
f"FATAL: Empty text for passage ID {nid}"
|
||||
)
|
||||
texts.append(txt)
|
||||
except KeyError as e:
|
||||
logger.error(f"Passage ID {nid} not found: {e}")
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.error(f"Exception looking up passage ID {nid}: {e}")
|
||||
raise
|
||||
|
||||
# Debug logging
|
||||
logger.debug(f"Processing {len(texts)} texts")
|
||||
logger.debug(
|
||||
f"Text lengths: {[len(t) for t in texts[:5]]}"
|
||||
) # Show first 5
|
||||
|
||||
# Process embeddings using unified computation
|
||||
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
||||
logger.info(
|
||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||
)
|
||||
|
||||
# Prepare response based on request type
|
||||
if is_text_request:
|
||||
# For BaseSearcher compatibility: return msgpack format
|
||||
import msgpack
|
||||
|
||||
response_data = msgpack.packb(embeddings.tolist())
|
||||
else:
|
||||
# For DiskANN C++ compatibility: return protobuf format
|
||||
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
||||
hidden_contiguous = np.ascontiguousarray(
|
||||
embeddings, dtype=np.float32
|
||||
)
|
||||
|
||||
# Serialize embeddings data
|
||||
resp_proto.embeddings_data = hidden_contiguous.tobytes()
|
||||
resp_proto.dimensions.append(hidden_contiguous.shape[0])
|
||||
resp_proto.dimensions.append(hidden_contiguous.shape[1])
|
||||
|
||||
response_data = resp_proto.SerializeToString()
|
||||
|
||||
# Send response back to the client
|
||||
socket.send(response_data)
|
||||
|
||||
e2e_end = time.time()
|
||||
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
|
||||
|
||||
except zmq.Again:
|
||||
logger.debug("ZMQ socket timeout, continuing to listen")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"Error in ZMQ server loop: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
|
||||
zmq_thread.start()
|
||||
logger.info(f"Started DiskANN ZMQ server thread on port {zmq_port}")
|
||||
|
||||
# Keep the main thread alive
|
||||
try:
|
||||
while True:
|
||||
time.sleep(1)
|
||||
except KeyboardInterrupt:
|
||||
logger.info("DiskANN Server shutting down...")
|
||||
return
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import signal
|
||||
import sys
|
||||
|
||||
def signal_handler(sig, frame):
|
||||
logger.info(f"Received signal {sig}, shutting down gracefully...")
|
||||
sys.exit(0)
|
||||
|
||||
# Register signal handlers for graceful shutdown
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
parser = argparse.ArgumentParser(description="DiskANN Embedding service")
|
||||
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
|
||||
parser.add_argument(
|
||||
"--passages-file",
|
||||
type=str,
|
||||
help="Metadata JSON file containing passage sources",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-name",
|
||||
type=str,
|
||||
default="sentence-transformers/all-mpnet-base-v2",
|
||||
help="Embedding model name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--embedding-mode",
|
||||
type=str,
|
||||
default="sentence-transformers",
|
||||
choices=["sentence-transformers", "openai", "mlx"],
|
||||
help="Embedding backend mode",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Create and start the DiskANN embedding server
|
||||
create_diskann_embedding_server(
|
||||
passages_file=args.passages_file,
|
||||
zmq_port=args.zmq_port,
|
||||
model_name=args.model_name,
|
||||
embedding_mode=args.embedding_mode,
|
||||
)
|
||||
@@ -1,705 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Embedding server for leann-backend-diskann - Fixed ZMQ REQ-REP pattern
|
||||
"""
|
||||
|
||||
import pickle
|
||||
import argparse
|
||||
import time
|
||||
import json
|
||||
from typing import Dict, Any, Optional, Union
|
||||
|
||||
from transformers import AutoTokenizer, AutoModel
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
import zmq
|
||||
import numpy as np
|
||||
import msgpack
|
||||
from pathlib import Path
|
||||
import logging
|
||||
|
||||
RED = "\033[91m"
|
||||
|
||||
# Set up logging based on environment variable
|
||||
LOG_LEVEL = os.getenv('LEANN_LOG_LEVEL', 'INFO').upper()
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, LOG_LEVEL, logging.INFO),
|
||||
format='%(asctime)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
RESET = "\033[0m"
|
||||
|
||||
# --- New Passage Loader from HNSW backend ---
|
||||
class SimplePassageLoader:
|
||||
"""
|
||||
Simple passage loader that replaces config.py dependencies
|
||||
"""
|
||||
def __init__(self, passages_data: Optional[Dict[str, Any]] = None):
|
||||
self.passages_data = passages_data or {}
|
||||
self._meta_path = ''
|
||||
|
||||
def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
|
||||
"""Get passage by ID"""
|
||||
str_id = str(passage_id)
|
||||
if str_id in self.passages_data:
|
||||
return {"text": self.passages_data[str_id]}
|
||||
else:
|
||||
# Return empty text for missing passages
|
||||
return {"text": ""}
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.passages_data)
|
||||
|
||||
def keys(self):
|
||||
return self.passages_data.keys()
|
||||
|
||||
def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
|
||||
"""
|
||||
Load passages using metadata file with PassageManager for lazy loading
|
||||
"""
|
||||
# Load metadata to get passage sources
|
||||
with open(meta_file, 'r') as f:
|
||||
meta = json.load(f)
|
||||
|
||||
# Import PassageManager dynamically to avoid circular imports
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Find the leann package directory relative to this file
|
||||
current_dir = Path(__file__).parent
|
||||
leann_core_path = current_dir.parent.parent / "leann-core" / "src"
|
||||
sys.path.insert(0, str(leann_core_path))
|
||||
|
||||
try:
|
||||
from leann.api import PassageManager
|
||||
passage_manager = PassageManager(meta['passage_sources'])
|
||||
finally:
|
||||
sys.path.pop(0)
|
||||
|
||||
print(f"Initialized lazy passage loading for {len(passage_manager.global_offset_map)} passages")
|
||||
|
||||
class LazyPassageLoader(SimplePassageLoader):
|
||||
def __init__(self, passage_manager):
|
||||
self.passage_manager = passage_manager
|
||||
# Initialize parent with empty data
|
||||
super().__init__({})
|
||||
|
||||
def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
|
||||
"""Get passage by ID with lazy loading"""
|
||||
try:
|
||||
int_id = int(passage_id)
|
||||
string_id = str(int_id)
|
||||
passage_data = self.passage_manager.get_passage(string_id)
|
||||
if passage_data and passage_data.get("text"):
|
||||
return {"text": passage_data["text"]}
|
||||
else:
|
||||
raise RuntimeError(f"FATAL: Empty text for ID {int_id} -> {string_id}")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"FATAL: Exception getting passage {passage_id}: {e}")
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.passage_manager.global_offset_map)
|
||||
|
||||
def keys(self):
|
||||
return self.passage_manager.global_offset_map.keys()
|
||||
|
||||
loader = LazyPassageLoader(passage_manager)
|
||||
loader._meta_path = meta_file
|
||||
return loader
|
||||
|
||||
def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
|
||||
"""
|
||||
Load passages from a JSONL file with label map support
|
||||
Expected format: {"id": "passage_id", "text": "passage_text", "metadata": {...}} (one per line)
|
||||
"""
|
||||
|
||||
if not os.path.exists(passages_file):
|
||||
raise FileNotFoundError(f"Passages file {passages_file} not found.")
|
||||
|
||||
if not passages_file.endswith('.jsonl'):
|
||||
raise ValueError(f"Expected .jsonl file format, got: {passages_file}")
|
||||
|
||||
# Load passages directly by their sequential IDs
|
||||
passages_data = {}
|
||||
with open(passages_file, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
passage = json.loads(line)
|
||||
passages_data[passage['id']] = passage['text']
|
||||
|
||||
print(f"Loaded {len(passages_data)} passages from JSONL file {passages_file}")
|
||||
return SimplePassageLoader(passages_data)
|
||||
|
||||
def create_embedding_server_thread(
|
||||
zmq_port=5555,
|
||||
model_name="sentence-transformers/all-mpnet-base-v2",
|
||||
max_batch_size=128,
|
||||
passages_file: Optional[str] = None,
|
||||
embedding_mode: str = "sentence-transformers",
|
||||
enable_warmup: bool = False,
|
||||
):
|
||||
"""
|
||||
Create and run embedding server in the current thread
|
||||
This function is designed to be called in a separate thread
|
||||
"""
|
||||
logger.info(f"Initializing embedding server thread on port {zmq_port}")
|
||||
|
||||
try:
|
||||
# Check if port is already occupied
|
||||
import socket
|
||||
def check_port(port):
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
return s.connect_ex(('localhost', port)) == 0
|
||||
|
||||
if check_port(zmq_port):
|
||||
print(f"{RED}Port {zmq_port} is already in use{RESET}")
|
||||
return
|
||||
|
||||
# Auto-detect mode based on model name if not explicitly set
|
||||
if embedding_mode == "sentence-transformers" and model_name.startswith("text-embedding-"):
|
||||
embedding_mode = "openai"
|
||||
|
||||
if embedding_mode == "mlx":
|
||||
from leann.api import compute_embeddings_mlx
|
||||
import torch
|
||||
logger.info("Using MLX for embeddings")
|
||||
# Set device to CPU for compatibility with DeviceTimer class
|
||||
device = torch.device("cpu")
|
||||
cuda_available = False
|
||||
mps_available = False
|
||||
elif embedding_mode == "openai":
|
||||
from leann.api import compute_embeddings_openai
|
||||
import torch
|
||||
logger.info("Using OpenAI API for embeddings")
|
||||
# Set device to CPU for compatibility with DeviceTimer class
|
||||
device = torch.device("cpu")
|
||||
cuda_available = False
|
||||
mps_available = False
|
||||
elif embedding_mode == "sentence-transformers":
|
||||
# Initialize model
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
|
||||
import torch
|
||||
|
||||
# Select device
|
||||
mps_available = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
|
||||
cuda_available = torch.cuda.is_available()
|
||||
|
||||
if cuda_available:
|
||||
device = torch.device("cuda")
|
||||
logger.info("Using CUDA device")
|
||||
elif mps_available:
|
||||
device = torch.device("mps")
|
||||
logger.info("Using MPS device (Apple Silicon)")
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
logger.info("Using CPU device")
|
||||
|
||||
# Load model
|
||||
logger.info(f"Loading model {model_name}")
|
||||
model = AutoModel.from_pretrained(model_name).to(device).eval()
|
||||
|
||||
# Optimize model
|
||||
if cuda_available or mps_available:
|
||||
try:
|
||||
model = model.half()
|
||||
model = torch.compile(model)
|
||||
logger.info(f"Using FP16 precision with model: {model_name}")
|
||||
except Exception as e:
|
||||
print(f"WARNING: Model optimization failed: {e}")
|
||||
else:
|
||||
raise ValueError(f"Unsupported embedding mode: {embedding_mode}. Supported modes: sentence-transformers, mlx, openai")
|
||||
|
||||
# Load passages from file if provided
|
||||
if passages_file and os.path.exists(passages_file):
|
||||
# Check if it's a metadata file or a single passages file
|
||||
if passages_file.endswith('.meta.json'):
|
||||
passages = load_passages_from_metadata(passages_file)
|
||||
else:
|
||||
# Try to find metadata file in same directory
|
||||
passages_dir = Path(passages_file).parent
|
||||
meta_files = list(passages_dir.glob("*.meta.json"))
|
||||
if meta_files:
|
||||
print(f"Found metadata file: {meta_files[0]}, using lazy loading")
|
||||
passages = load_passages_from_metadata(str(meta_files[0]))
|
||||
else:
|
||||
# Fallback to original single file loading (will cause warnings)
|
||||
print("WARNING: No metadata file found, using single file loading (may cause missing passage warnings)")
|
||||
passages = load_passages_from_file(passages_file)
|
||||
else:
|
||||
print("WARNING: No passages file provided or file not found. Using an empty passage loader.")
|
||||
passages = SimplePassageLoader()
|
||||
|
||||
logger.info(f"Loaded {len(passages)} passages.")
|
||||
|
||||
def client_warmup(zmq_port):
|
||||
"""Perform client-side warmup for DiskANN server"""
|
||||
time.sleep(2)
|
||||
print(f"Performing client-side warmup with model {model_name}...")
|
||||
|
||||
# Get actual passage IDs from the loaded passages
|
||||
sample_ids = []
|
||||
if hasattr(passages, 'keys') and len(passages) > 0:
|
||||
available_ids = list(passages.keys())
|
||||
# Take up to 5 actual IDs, but at least 1
|
||||
sample_ids = available_ids[:min(5, len(available_ids))]
|
||||
print(f"Using actual passage IDs for warmup: {sample_ids}")
|
||||
else:
|
||||
print("No passages available for warmup, skipping warmup...")
|
||||
return
|
||||
|
||||
try:
|
||||
context = zmq.Context()
|
||||
socket = context.socket(zmq.REQ)
|
||||
socket.connect(f"tcp://localhost:{zmq_port}")
|
||||
socket.setsockopt(zmq.RCVTIMEO, 30000)
|
||||
socket.setsockopt(zmq.SNDTIMEO, 30000)
|
||||
|
||||
try:
|
||||
ids_to_send = [int(x) for x in sample_ids]
|
||||
except ValueError:
|
||||
print("Warning: Could not convert sample IDs to integers, skipping warmup")
|
||||
return
|
||||
|
||||
if not ids_to_send:
|
||||
print("Skipping warmup send.")
|
||||
return
|
||||
|
||||
# Use protobuf format for warmup
|
||||
from . import embedding_pb2
|
||||
req_proto = embedding_pb2.NodeEmbeddingRequest()
|
||||
req_proto.node_ids.extend(ids_to_send)
|
||||
request_bytes = req_proto.SerializeToString()
|
||||
|
||||
for i in range(3):
|
||||
print(f"Sending warmup request {i + 1}/3 via ZMQ (Protobuf)...")
|
||||
socket.send(request_bytes)
|
||||
response_bytes = socket.recv()
|
||||
|
||||
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
||||
resp_proto.ParseFromString(response_bytes)
|
||||
embeddings_count = resp_proto.dimensions[0] if resp_proto.dimensions else 0
|
||||
print(f"Warmup request {i + 1}/3 successful, received {embeddings_count} embeddings")
|
||||
time.sleep(0.1)
|
||||
|
||||
print("Client-side Protobuf ZMQ warmup complete")
|
||||
socket.close()
|
||||
context.term()
|
||||
except Exception as e:
|
||||
print(f"Error during Protobuf ZMQ warmup: {e}")
|
||||
|
||||
class DeviceTimer:
|
||||
"""Device timer"""
|
||||
def __init__(self, name="", device=device):
|
||||
self.name = name
|
||||
self.device = device
|
||||
self.start_time = 0
|
||||
self.end_time = 0
|
||||
|
||||
if embedding_mode == "sentence-transformers" and torch.cuda.is_available():
|
||||
self.start_event = torch.cuda.Event(enable_timing=True)
|
||||
self.end_event = torch.cuda.Event(enable_timing=True)
|
||||
else:
|
||||
self.start_event = None
|
||||
self.end_event = None
|
||||
|
||||
@contextmanager
|
||||
def timing(self):
|
||||
self.start()
|
||||
yield
|
||||
self.end()
|
||||
|
||||
def start(self):
|
||||
if embedding_mode == "sentence-transformers" and torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
self.start_event.record()
|
||||
else:
|
||||
if embedding_mode == "sentence-transformers" and self.device.type == "mps":
|
||||
torch.mps.synchronize()
|
||||
self.start_time = time.time()
|
||||
|
||||
def end(self):
|
||||
if embedding_mode == "sentence-transformers" and torch.cuda.is_available():
|
||||
self.end_event.record()
|
||||
torch.cuda.synchronize()
|
||||
else:
|
||||
if embedding_mode == "sentence-transformers" and self.device.type == "mps":
|
||||
torch.mps.synchronize()
|
||||
self.end_time = time.time()
|
||||
|
||||
def elapsed_time(self):
|
||||
if embedding_mode == "sentence-transformers" and torch.cuda.is_available():
|
||||
return self.start_event.elapsed_time(self.end_event) / 1000.0
|
||||
else:
|
||||
return self.end_time - self.start_time
|
||||
|
||||
def print_elapsed(self):
|
||||
elapsed = self.elapsed_time()
|
||||
print(f"[{self.name}] Elapsed time: {elapsed:.3f}s")
|
||||
|
||||
def process_batch_pytorch(texts_batch, ids_batch, missing_ids):
|
||||
"""Process text batch"""
|
||||
if not texts_batch:
|
||||
return np.array([])
|
||||
|
||||
# Filter out empty texts and their corresponding IDs
|
||||
valid_texts = []
|
||||
valid_ids = []
|
||||
for i, text in enumerate(texts_batch):
|
||||
if text.strip(): # Only include non-empty texts
|
||||
valid_texts.append(text)
|
||||
valid_ids.append(ids_batch[i])
|
||||
|
||||
if not valid_texts:
|
||||
print("WARNING: No valid texts in batch")
|
||||
return np.array([])
|
||||
|
||||
# Tokenize
|
||||
token_timer = DeviceTimer("tokenization")
|
||||
with token_timer.timing():
|
||||
inputs = tokenizer(
|
||||
valid_texts,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=512,
|
||||
return_tensors="pt"
|
||||
).to(device)
|
||||
|
||||
# Compute embeddings
|
||||
embed_timer = DeviceTimer("embedding computation")
|
||||
with embed_timer.timing():
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
hidden_states = outputs.last_hidden_state
|
||||
|
||||
# Mean pooling
|
||||
attention_mask = inputs['attention_mask']
|
||||
mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
|
||||
sum_embeddings = torch.sum(hidden_states * mask_expanded, 1)
|
||||
sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
|
||||
batch_embeddings = sum_embeddings / sum_mask
|
||||
embed_timer.print_elapsed()
|
||||
|
||||
return batch_embeddings.cpu().numpy()
|
||||
|
||||
# ZMQ server main loop - modified to use REP socket
|
||||
context = zmq.Context()
|
||||
socket = context.socket(zmq.ROUTER) # Changed to REP socket
|
||||
socket.bind(f"tcp://127.0.0.1:{zmq_port}")
|
||||
print(f"INFO: ZMQ ROUTER server listening on port {zmq_port}")
|
||||
|
||||
# Set timeouts
|
||||
socket.setsockopt(zmq.RCVTIMEO, 5000) # 5 second receive timeout
|
||||
socket.setsockopt(zmq.SNDTIMEO, 300000) # 300 second send timeout
|
||||
|
||||
from . import embedding_pb2
|
||||
|
||||
print(f"INFO: Embedding server ready to serve requests")
|
||||
|
||||
# Start warmup thread if enabled
|
||||
if enable_warmup and len(passages) > 0:
|
||||
import threading
|
||||
print(f"Warmup enabled: starting warmup thread")
|
||||
warmup_thread = threading.Thread(target=client_warmup, args=(zmq_port,))
|
||||
warmup_thread.daemon = True
|
||||
warmup_thread.start()
|
||||
else:
|
||||
print(f"Warmup disabled or no passages available (enable_warmup={enable_warmup}, passages={len(passages)})")
|
||||
|
||||
while True:
|
||||
try:
|
||||
parts = socket.recv_multipart()
|
||||
|
||||
# --- Restore robust message format detection ---
|
||||
# Must check parts length to avoid IndexError
|
||||
if len(parts) >= 3:
|
||||
identity = parts[0]
|
||||
# empty = parts[1] # We usually don't care about the middle empty frame
|
||||
message = parts[2]
|
||||
elif len(parts) == 2:
|
||||
# Can also handle cases without empty frame
|
||||
identity = parts[0]
|
||||
message = parts[1]
|
||||
else:
|
||||
# If received message format is wrong, print warning and ignore it instead of crashing
|
||||
print(f"WARNING: Received unexpected message format with {len(parts)} parts. Ignoring.")
|
||||
continue
|
||||
print(f"INFO: Received ZMQ request from client {identity.hex()[:8]}, size {len(message)} bytes")
|
||||
|
||||
# Handle control messages (MessagePack format)
|
||||
try:
|
||||
request_payload = msgpack.unpackb(message)
|
||||
if isinstance(request_payload, list) and len(request_payload) >= 1:
|
||||
if request_payload[0] == "__QUERY_META_PATH__":
|
||||
# Return the current meta path being used by the server
|
||||
current_meta_path = getattr(passages, '_meta_path', '') if hasattr(passages, '_meta_path') else ''
|
||||
response = [current_meta_path]
|
||||
socket.send_multipart([identity, b'', msgpack.packb(response)])
|
||||
continue
|
||||
|
||||
elif request_payload[0] == "__UPDATE_META_PATH__" and len(request_payload) >= 2:
|
||||
# Update the server's meta path and reload passages
|
||||
new_meta_path = request_payload[1]
|
||||
try:
|
||||
print(f"INFO: Updating server meta path to: {new_meta_path}")
|
||||
# Reload passages from the new meta file
|
||||
passages = load_passages_from_metadata(new_meta_path)
|
||||
# Store the meta path for future queries
|
||||
passages._meta_path = new_meta_path
|
||||
response = ["SUCCESS"]
|
||||
print(f"INFO: Successfully updated meta path and reloaded {len(passages)} passages")
|
||||
except Exception as e:
|
||||
print(f"ERROR: Failed to update meta path: {e}")
|
||||
response = ["FAILED", str(e)]
|
||||
socket.send_multipart([identity, b'', msgpack.packb(response)])
|
||||
continue
|
||||
|
||||
elif request_payload[0] == "__QUERY_MODEL__":
|
||||
# Return the current model being used by the server
|
||||
response = [model_name]
|
||||
socket.send_multipart([identity, b'', msgpack.packb(response)])
|
||||
continue
|
||||
|
||||
elif request_payload[0] == "__UPDATE_MODEL__" and len(request_payload) >= 2:
|
||||
# Update the server's embedding model
|
||||
new_model_name = request_payload[1]
|
||||
try:
|
||||
print(f"INFO: Updating server model from {model_name} to: {new_model_name}")
|
||||
|
||||
# Clean up old model to free memory
|
||||
if not use_mlx:
|
||||
print("INFO: Releasing old model from memory...")
|
||||
old_model = model
|
||||
old_tokenizer = tokenizer
|
||||
|
||||
# Load new tokenizer first
|
||||
print(f"Loading new tokenizer for {new_model_name}...")
|
||||
tokenizer = AutoTokenizer.from_pretrained(new_model_name, use_fast=True)
|
||||
|
||||
# Load new model
|
||||
print(f"Loading new model {new_model_name}...")
|
||||
model = AutoModel.from_pretrained(new_model_name).to(device).eval()
|
||||
|
||||
# Optimize new model
|
||||
if cuda_available or mps_available:
|
||||
try:
|
||||
model = model.half()
|
||||
model = torch.compile(model)
|
||||
print(f"INFO: Using FP16 precision with model: {new_model_name}")
|
||||
except Exception as e:
|
||||
print(f"WARNING: Model optimization failed: {e}")
|
||||
|
||||
# Now safely delete old model after new one is loaded
|
||||
del old_model
|
||||
del old_tokenizer
|
||||
|
||||
# Clear GPU cache if available
|
||||
if device.type == "cuda":
|
||||
torch.cuda.empty_cache()
|
||||
print("INFO: Cleared CUDA cache")
|
||||
elif device.type == "mps":
|
||||
torch.mps.empty_cache()
|
||||
print("INFO: Cleared MPS cache")
|
||||
|
||||
# Force garbage collection
|
||||
import gc
|
||||
gc.collect()
|
||||
print("INFO: Memory cleanup completed")
|
||||
|
||||
# Update model name
|
||||
model_name = new_model_name
|
||||
|
||||
response = ["SUCCESS"]
|
||||
print(f"INFO: Successfully updated model to: {new_model_name}")
|
||||
except Exception as e:
|
||||
print(f"ERROR: Failed to update model: {e}")
|
||||
response = ["FAILED", str(e)]
|
||||
socket.send_multipart([identity, b'', msgpack.packb(response)])
|
||||
continue
|
||||
except:
|
||||
# Not a control message, continue with normal protobuf processing
|
||||
pass
|
||||
|
||||
e2e_start = time.time()
|
||||
lookup_timer = DeviceTimer("text lookup")
|
||||
|
||||
# Parse request
|
||||
req_proto = embedding_pb2.NodeEmbeddingRequest()
|
||||
req_proto.ParseFromString(message)
|
||||
node_ids = req_proto.node_ids
|
||||
print(f"INFO: Request for {len(node_ids)} node embeddings: {list(node_ids)}")
|
||||
|
||||
# Add debug information
|
||||
if len(node_ids) > 0:
|
||||
print(f"DEBUG: Node ID range: {min(node_ids)} to {max(node_ids)}")
|
||||
|
||||
# Look up texts
|
||||
texts = []
|
||||
missing_ids = []
|
||||
with lookup_timer.timing():
|
||||
for nid in node_ids:
|
||||
txtinfo = passages[nid]
|
||||
txt = txtinfo["text"]
|
||||
if txt:
|
||||
texts.append(txt)
|
||||
else:
|
||||
# If text is empty, we still need a placeholder for batch processing,
|
||||
# but record its ID as missing
|
||||
texts.append("")
|
||||
missing_ids.append(nid)
|
||||
lookup_timer.print_elapsed()
|
||||
|
||||
if missing_ids:
|
||||
print(f"WARNING: Missing passages for IDs: {missing_ids}")
|
||||
|
||||
# Process batch
|
||||
total_size = len(texts)
|
||||
print(f"INFO: Total batch size: {total_size}, max_batch_size: {max_batch_size}")
|
||||
|
||||
all_embeddings = []
|
||||
|
||||
if total_size > max_batch_size:
|
||||
print(f"INFO: Splitting batch of size {total_size} into chunks of {max_batch_size}")
|
||||
for i in range(0, total_size, max_batch_size):
|
||||
end_idx = min(i + max_batch_size, total_size)
|
||||
print(f"INFO: Processing chunk {i//max_batch_size + 1}/{(total_size + max_batch_size - 1)//max_batch_size}: items {i} to {end_idx-1}")
|
||||
|
||||
chunk_texts = texts[i:end_idx]
|
||||
chunk_ids = node_ids[i:end_idx]
|
||||
|
||||
if embedding_mode == "mlx":
|
||||
embeddings_chunk = compute_embeddings_mlx(chunk_texts, model_name, batch_size=16)
|
||||
elif embedding_mode == "openai":
|
||||
embeddings_chunk = compute_embeddings_openai(chunk_texts, model_name)
|
||||
else: # sentence-transformers
|
||||
embeddings_chunk = process_batch_pytorch(chunk_texts, chunk_ids, missing_ids)
|
||||
all_embeddings.append(embeddings_chunk)
|
||||
|
||||
if embedding_mode == "sentence-transformers":
|
||||
if cuda_available:
|
||||
torch.cuda.empty_cache()
|
||||
elif device.type == "mps":
|
||||
torch.mps.empty_cache()
|
||||
|
||||
hidden = np.vstack(all_embeddings)
|
||||
print(f"INFO: Combined embeddings shape: {hidden.shape}")
|
||||
else:
|
||||
if embedding_mode == "mlx":
|
||||
hidden = compute_embeddings_mlx(texts, model_name, batch_size=16)
|
||||
elif embedding_mode == "openai":
|
||||
hidden = compute_embeddings_openai(texts, model_name)
|
||||
else: # sentence-transformers
|
||||
hidden = process_batch_pytorch(texts, node_ids, missing_ids)
|
||||
|
||||
# Serialize response
|
||||
ser_start = time.time()
|
||||
|
||||
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
||||
hidden_contiguous = np.ascontiguousarray(hidden, dtype=np.float32)
|
||||
resp_proto.embeddings_data = hidden_contiguous.tobytes()
|
||||
resp_proto.dimensions.append(hidden_contiguous.shape[0])
|
||||
resp_proto.dimensions.append(hidden_contiguous.shape[1])
|
||||
resp_proto.missing_ids.extend(missing_ids)
|
||||
|
||||
response_data = resp_proto.SerializeToString()
|
||||
|
||||
# REP socket sends a single response
|
||||
socket.send_multipart([identity, b'', response_data])
|
||||
|
||||
ser_end = time.time()
|
||||
|
||||
print(f"INFO: Serialize time: {ser_end - ser_start:.6f} seconds")
|
||||
|
||||
if embedding_mode == "sentence-transformers":
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
elif device.type == "mps":
|
||||
torch.mps.synchronize()
|
||||
e2e_end = time.time()
|
||||
print(f"INFO: ZMQ E2E time: {e2e_end - e2e_start:.6f} seconds")
|
||||
|
||||
except zmq.Again:
|
||||
print("INFO: ZMQ socket timeout, continuing to listen")
|
||||
continue
|
||||
except Exception as e:
|
||||
print(f"ERROR: Error in ZMQ server: {e}")
|
||||
try:
|
||||
# Send empty response to maintain REQ-REP state
|
||||
empty_resp = embedding_pb2.NodeEmbeddingResponse()
|
||||
socket.send(empty_resp.SerializeToString())
|
||||
except:
|
||||
# If sending fails, recreate socket
|
||||
socket.close()
|
||||
socket = context.socket(zmq.REP)
|
||||
socket.bind(f"tcp://127.0.0.1:{zmq_port}")
|
||||
socket.setsockopt(zmq.RCVTIMEO, 5000)
|
||||
socket.setsockopt(zmq.SNDTIMEO, 300000)
|
||||
print("INFO: ZMQ socket recreated after error")
|
||||
|
||||
except Exception as e:
|
||||
print(f"ERROR: Failed to start embedding server: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def create_embedding_server(
|
||||
domain="demo",
|
||||
load_passages=True,
|
||||
load_embeddings=False,
|
||||
use_fp16=True,
|
||||
use_int8=False,
|
||||
use_cuda_graphs=False,
|
||||
zmq_port=5555,
|
||||
max_batch_size=128,
|
||||
lazy_load_passages=False,
|
||||
model_name="sentence-transformers/all-mpnet-base-v2",
|
||||
passages_file: Optional[str] = None,
|
||||
embedding_mode: str = "sentence-transformers",
|
||||
enable_warmup: bool = False,
|
||||
):
|
||||
"""
|
||||
原有的 create_embedding_server 函数保持不变
|
||||
这个是阻塞版本,用于直接运行
|
||||
"""
|
||||
create_embedding_server_thread(zmq_port, model_name, max_batch_size, passages_file, embedding_mode, enable_warmup)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Embedding service")
|
||||
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
|
||||
parser.add_argument("--domain", type=str, default="demo", help="Domain name")
|
||||
parser.add_argument("--passages-file", type=str, help="JSON file containing passage ID to text mapping")
|
||||
parser.add_argument("--load-passages", action="store_true", default=True)
|
||||
parser.add_argument("--load-embeddings", action="store_true", default=False)
|
||||
parser.add_argument("--use-fp16", action="store_true", default=False)
|
||||
parser.add_argument("--use-int8", action="store_true", default=False)
|
||||
parser.add_argument("--use-cuda-graphs", action="store_true", default=False)
|
||||
parser.add_argument("--max-batch-size", type=int, default=128, help="Maximum batch size before splitting")
|
||||
parser.add_argument("--lazy-load-passages", action="store_true", default=True)
|
||||
parser.add_argument("--model-name", type=str, default="sentence-transformers/all-mpnet-base-v2",
|
||||
help="Embedding model name")
|
||||
parser.add_argument("--embedding-mode", type=str, default="sentence-transformers",
|
||||
choices=["sentence-transformers", "mlx", "openai"],
|
||||
help="Embedding backend mode")
|
||||
parser.add_argument("--use-mlx", action="store_true", default=False, help="Use MLX backend for embeddings (deprecated: use --embedding-mode mlx)")
|
||||
parser.add_argument("--disable-warmup", action="store_true", default=False, help="Disable warmup requests on server start")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Handle backward compatibility with use_mlx
|
||||
embedding_mode = args.embedding_mode
|
||||
if args.use_mlx:
|
||||
embedding_mode = "mlx"
|
||||
|
||||
create_embedding_server(
|
||||
domain=args.domain,
|
||||
load_passages=args.load_passages,
|
||||
load_embeddings=args.load_embeddings,
|
||||
use_fp16=args.use_fp16,
|
||||
use_int8=args.use_int8,
|
||||
use_cuda_graphs=args.use_cuda_graphs,
|
||||
zmq_port=args.zmq_port,
|
||||
max_batch_size=args.max_batch_size,
|
||||
lazy_load_passages=args.lazy_load_passages,
|
||||
model_name=args.model_name,
|
||||
passages_file=args.passages_file,
|
||||
embedding_mode=embedding_mode,
|
||||
enable_warmup=not args.disable_warmup,
|
||||
)
|
||||
Submodule packages/leann-backend-diskann/third_party/DiskANN updated: af2a26481e...25339b0341
@@ -1,10 +1,9 @@
|
||||
import numpy as np
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List, Literal
|
||||
import pickle
|
||||
from typing import Dict, Any, List, Literal, Optional
|
||||
import shutil
|
||||
import time
|
||||
import logging
|
||||
|
||||
from leann.searcher_base import BaseSearcher
|
||||
from .convert_to_csr import convert_hnsw_graph_to_csr
|
||||
@@ -16,6 +15,8 @@ from leann.interface import (
|
||||
LeannBackendSearcherInterface,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_metric_map():
|
||||
from . import faiss # type: ignore
|
||||
@@ -57,9 +58,9 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
||||
index_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if data.dtype != np.float32:
|
||||
logger.warning(f"Converting data to float32, shape: {data.shape}")
|
||||
data = data.astype(np.float32)
|
||||
|
||||
|
||||
metric_enum = get_metric_map().get(self.distance_metric.lower())
|
||||
if metric_enum is None:
|
||||
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
|
||||
@@ -81,7 +82,7 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
||||
def _convert_to_csr(self, index_file: Path):
|
||||
"""Convert built index to CSR format"""
|
||||
mode_str = "CSR-pruned" if self.is_recompute else "CSR-standard"
|
||||
print(f"INFO: Converting HNSW index to {mode_str} format...")
|
||||
logger.info(f"INFO: Converting HNSW index to {mode_str} format...")
|
||||
|
||||
csr_temp_file = index_file.with_suffix(".csr.tmp")
|
||||
|
||||
@@ -90,11 +91,11 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
||||
)
|
||||
|
||||
if success:
|
||||
print("✅ CSR conversion successful.")
|
||||
logger.info("✅ CSR conversion successful.")
|
||||
index_file_old = index_file.with_suffix(".old")
|
||||
shutil.move(str(index_file), str(index_file_old))
|
||||
shutil.move(str(csr_temp_file), str(index_file))
|
||||
print(
|
||||
logger.info(
|
||||
f"INFO: Replaced original index with {mode_str} version at '{index_file}'"
|
||||
)
|
||||
else:
|
||||
@@ -131,24 +132,22 @@ class HNSWSearcher(BaseSearcher):
|
||||
|
||||
hnsw_config = faiss.HNSWIndexConfig()
|
||||
hnsw_config.is_compact = self.is_compact
|
||||
hnsw_config.is_recompute = self.is_pruned or kwargs.get("is_recompute", False)
|
||||
|
||||
if self.is_pruned and not hnsw_config.is_recompute:
|
||||
raise RuntimeError("Index is pruned but recompute is disabled.")
|
||||
hnsw_config.is_recompute = (
|
||||
self.is_pruned
|
||||
) # In C++ code, it's called is_recompute, but it's only for loading IIUC.
|
||||
|
||||
self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)
|
||||
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: np.ndarray,
|
||||
top_k: int,
|
||||
zmq_port: Optional[int] = None,
|
||||
complexity: int = 64,
|
||||
beam_width: int = 1,
|
||||
prune_ratio: float = 0.0,
|
||||
recompute_embeddings: bool = False,
|
||||
recompute_embeddings: bool = True,
|
||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||
zmq_port: int = 5557,
|
||||
batch_size: int = 0,
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
@@ -166,7 +165,7 @@ class HNSWSearcher(BaseSearcher):
|
||||
- "global": Use global PQ queue size for selection (default)
|
||||
- "local": Local pruning, sort and select best candidates
|
||||
- "proportional": Base selection on new neighbor count ratio
|
||||
zmq_port: ZMQ port for embedding server
|
||||
zmq_port: ZMQ port for embedding server communication. Must be provided if recompute_embeddings is True.
|
||||
batch_size: Neighbor processing batch size, 0=disabled (HNSW-specific)
|
||||
**kwargs: Additional HNSW-specific parameters (for legacy compatibility)
|
||||
|
||||
@@ -175,15 +174,14 @@ class HNSWSearcher(BaseSearcher):
|
||||
"""
|
||||
from . import faiss # type: ignore
|
||||
|
||||
# Use recompute_embeddings parameter
|
||||
use_recompute = recompute_embeddings or self.is_pruned
|
||||
if use_recompute:
|
||||
meta_file_path = self.index_dir / f"{self.index_path.name}.meta.json"
|
||||
if not meta_file_path.exists():
|
||||
raise RuntimeError(
|
||||
f"FATAL: Recompute enabled but metadata file not found: {meta_file_path}"
|
||||
if not recompute_embeddings:
|
||||
if self.is_pruned:
|
||||
raise RuntimeError("Recompute is required for pruned index.")
|
||||
if recompute_embeddings:
|
||||
if zmq_port is None:
|
||||
raise ValueError(
|
||||
"zmq_port must be provided if recompute_embeddings is True"
|
||||
)
|
||||
self._ensure_server_running(str(meta_file_path), port=zmq_port, **kwargs)
|
||||
|
||||
if query.dtype != np.float32:
|
||||
query = query.astype(np.float32)
|
||||
@@ -191,7 +189,10 @@ class HNSWSearcher(BaseSearcher):
|
||||
faiss.normalize_L2(query)
|
||||
|
||||
params = faiss.SearchParametersHNSW()
|
||||
params.zmq_port = zmq_port
|
||||
if zmq_port is not None:
|
||||
params.zmq_port = (
|
||||
zmq_port # C++ code won't use this if recompute_embeddings is False
|
||||
)
|
||||
params.efSearch = complexity
|
||||
params.beam_size = beam_width
|
||||
|
||||
@@ -228,8 +229,7 @@ class HNSWSearcher(BaseSearcher):
|
||||
)
|
||||
|
||||
string_labels = [
|
||||
[str(int_label) for int_label in batch_labels]
|
||||
for batch_labels in labels
|
||||
[str(int_label) for int_label in batch_labels] for batch_labels in labels
|
||||
]
|
||||
|
||||
return {"labels": string_labels, "distances": distances}
|
||||
|
||||
@@ -11,25 +11,29 @@ import numpy as np
|
||||
import msgpack
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional, Union
|
||||
from typing import Optional
|
||||
import sys
|
||||
import logging
|
||||
|
||||
RED = "\033[91m"
|
||||
RESET = "\033[0m"
|
||||
|
||||
# Set up logging based on environment variable
|
||||
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "INFO").upper()
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, LOG_LEVEL, logging.INFO),
|
||||
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Force set logger level (don't rely on basicConfig in subprocess)
|
||||
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
|
||||
logger.setLevel(log_level)
|
||||
|
||||
# Ensure we have a handler if none exists
|
||||
if not logger.handlers:
|
||||
handler = logging.StreamHandler()
|
||||
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
|
||||
handler.setFormatter(formatter)
|
||||
logger.addHandler(handler)
|
||||
logger.propagate = False
|
||||
|
||||
|
||||
def create_hnsw_embedding_server(
|
||||
passages_file: Optional[str] = None,
|
||||
passages_data: Optional[Dict[str, str]] = None,
|
||||
zmq_port: int = 5555,
|
||||
model_name: str = "sentence-transformers/all-mpnet-base-v2",
|
||||
distance_metric: str = "mips",
|
||||
@@ -39,14 +43,8 @@ def create_hnsw_embedding_server(
|
||||
Create and start a ZMQ-based embedding server for HNSW backend.
|
||||
Simplified version using unified embedding computation module.
|
||||
"""
|
||||
# Auto-detect mode based on model name if not explicitly set
|
||||
if embedding_mode == "sentence-transformers" and model_name.startswith(
|
||||
"text-embedding-"
|
||||
):
|
||||
embedding_mode = "openai"
|
||||
|
||||
print(f"Starting HNSW server on port {zmq_port} with model {model_name}")
|
||||
print(f"Using embedding mode: {embedding_mode}")
|
||||
logger.info(f"Starting HNSW server on port {zmq_port} with model {model_name}")
|
||||
logger.info(f"Using embedding mode: {embedding_mode}")
|
||||
|
||||
# Add leann-core to path for unified embedding computation
|
||||
current_dir = Path(__file__).parent
|
||||
@@ -57,9 +55,9 @@ def create_hnsw_embedding_server(
|
||||
from leann.embedding_compute import compute_embeddings
|
||||
from leann.api import PassageManager
|
||||
|
||||
print("Successfully imported unified embedding computation module")
|
||||
logger.info("Successfully imported unified embedding computation module")
|
||||
except ImportError as e:
|
||||
print(f"ERROR: Failed to import embedding computation module: {e}")
|
||||
logger.error(f"Failed to import embedding computation module: {e}")
|
||||
return
|
||||
finally:
|
||||
sys.path.pop(0)
|
||||
@@ -72,26 +70,28 @@ def create_hnsw_embedding_server(
|
||||
return s.connect_ex(("localhost", port)) == 0
|
||||
|
||||
if check_port(zmq_port):
|
||||
print(f"{RED}Port {zmq_port} is already in use{RESET}")
|
||||
logger.error(f"Port {zmq_port} is already in use")
|
||||
return
|
||||
|
||||
# Only support metadata file, fail fast for everything else
|
||||
if not passages_file or not passages_file.endswith(".meta.json"):
|
||||
raise ValueError("Only metadata files (.meta.json) are supported")
|
||||
|
||||
|
||||
# Load metadata to get passage sources
|
||||
with open(passages_file, "r") as f:
|
||||
meta = json.load(f)
|
||||
|
||||
|
||||
passages = PassageManager(meta["passage_sources"])
|
||||
print(f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata")
|
||||
logger.info(
|
||||
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
|
||||
)
|
||||
|
||||
def zmq_server_thread():
|
||||
"""ZMQ server thread"""
|
||||
context = zmq.Context()
|
||||
socket = context.socket(zmq.REP)
|
||||
socket.bind(f"tcp://*:{zmq_port}")
|
||||
print(f"HNSW ZMQ server listening on port {zmq_port}")
|
||||
logger.info(f"HNSW ZMQ server listening on port {zmq_port}")
|
||||
|
||||
socket.setsockopt(zmq.RCVTIMEO, 300000)
|
||||
socket.setsockopt(zmq.SNDTIMEO, 300000)
|
||||
@@ -99,12 +99,12 @@ def create_hnsw_embedding_server(
|
||||
while True:
|
||||
try:
|
||||
message_bytes = socket.recv()
|
||||
print(f"Received ZMQ request of size {len(message_bytes)} bytes")
|
||||
logger.debug(f"Received ZMQ request of size {len(message_bytes)} bytes")
|
||||
|
||||
e2e_start = time.time()
|
||||
request_payload = msgpack.unpackb(message_bytes)
|
||||
|
||||
# Handle direct text embedding request (for OpenAI and sentence-transformers)
|
||||
# Handle direct text embedding request
|
||||
if isinstance(request_payload, list) and len(request_payload) > 0:
|
||||
# Check if this is a direct text request (list of strings)
|
||||
if all(isinstance(item, str) for item in request_payload):
|
||||
@@ -112,7 +112,7 @@ def create_hnsw_embedding_server(
|
||||
f"Processing direct text embedding request for {len(request_payload)} texts in {embedding_mode} mode"
|
||||
)
|
||||
|
||||
# Use unified embedding computation
|
||||
# Use unified embedding computation (now with model caching)
|
||||
embeddings = compute_embeddings(
|
||||
request_payload, model_name, mode=embedding_mode
|
||||
)
|
||||
@@ -136,8 +136,8 @@ def create_hnsw_embedding_server(
|
||||
query_vector = np.array(request_payload[1], dtype=np.float32)
|
||||
|
||||
logger.debug("Distance calculation request received")
|
||||
print(f" Node IDs: {node_ids}")
|
||||
print(f" Query vector dim: {len(query_vector)}")
|
||||
logger.debug(f" Node IDs: {node_ids}")
|
||||
logger.debug(f" Query vector dim: {len(query_vector)}")
|
||||
|
||||
# Get embeddings for node IDs
|
||||
texts = []
|
||||
@@ -147,18 +147,20 @@ def create_hnsw_embedding_server(
|
||||
txt = passage_data["text"]
|
||||
texts.append(txt)
|
||||
except KeyError:
|
||||
print(f"ERROR: Passage ID {nid} not found")
|
||||
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
|
||||
logger.error(f"Passage ID {nid} not found")
|
||||
raise RuntimeError(
|
||||
f"FATAL: Passage with ID {nid} not found"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"ERROR: Exception looking up passage ID {nid}: {e}")
|
||||
logger.error(f"Exception looking up passage ID {nid}: {e}")
|
||||
raise
|
||||
|
||||
# Process embeddings
|
||||
embeddings = compute_embeddings(
|
||||
texts, model_name, mode=embedding_mode
|
||||
)
|
||||
print(
|
||||
f"INFO: Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||
logger.info(
|
||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||
)
|
||||
|
||||
# Calculate distances
|
||||
@@ -173,7 +175,9 @@ def create_hnsw_embedding_server(
|
||||
response_bytes = msgpack.packb(
|
||||
[response_payload], use_single_float=True
|
||||
)
|
||||
print(f"Sending distance response with {len(distances)} distances")
|
||||
logger.debug(
|
||||
f"Sending distance response with {len(distances)} distances"
|
||||
)
|
||||
|
||||
socket.send(response_bytes)
|
||||
e2e_end = time.time()
|
||||
@@ -188,14 +192,14 @@ def create_hnsw_embedding_server(
|
||||
or len(request_payload) != 1
|
||||
or not isinstance(request_payload[0], list)
|
||||
):
|
||||
print(
|
||||
f"Error: Invalid MessagePack request format. Expected [[ids...]] or [texts...], got: {type(request_payload)}"
|
||||
logger.error(
|
||||
f"Invalid MessagePack request format. Expected [[ids...]] or [texts...], got: {type(request_payload)}"
|
||||
)
|
||||
socket.send(msgpack.packb([[], []]))
|
||||
continue
|
||||
|
||||
node_ids = request_payload[0]
|
||||
print(f"Request for {len(node_ids)} node embeddings")
|
||||
logger.debug(f"Request for {len(node_ids)} node embeddings")
|
||||
|
||||
# Look up texts by node IDs
|
||||
texts = []
|
||||
@@ -204,24 +208,26 @@ def create_hnsw_embedding_server(
|
||||
passage_data = passages.get_passage(str(nid))
|
||||
txt = passage_data["text"]
|
||||
if not txt:
|
||||
raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
|
||||
raise RuntimeError(
|
||||
f"FATAL: Empty text for passage ID {nid}"
|
||||
)
|
||||
texts.append(txt)
|
||||
except KeyError:
|
||||
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
|
||||
except Exception as e:
|
||||
print(f"ERROR: Exception looking up passage ID {nid}: {e}")
|
||||
logger.error(f"Exception looking up passage ID {nid}: {e}")
|
||||
raise
|
||||
|
||||
# Process embeddings
|
||||
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
||||
print(
|
||||
f"INFO: Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||
logger.info(
|
||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||
)
|
||||
|
||||
# Serialization and response
|
||||
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
||||
print(
|
||||
f"{RED}!!! ERROR: NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}...{RESET}"
|
||||
logger.error(
|
||||
f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
|
||||
)
|
||||
assert False
|
||||
|
||||
@@ -242,7 +248,7 @@ def create_hnsw_embedding_server(
|
||||
logger.debug("ZMQ socket timeout, continuing to listen")
|
||||
continue
|
||||
except Exception as e:
|
||||
print(f"Error in ZMQ server loop: {e}")
|
||||
logger.error(f"Error in ZMQ server loop: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
@@ -250,18 +256,29 @@ def create_hnsw_embedding_server(
|
||||
|
||||
zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
|
||||
zmq_thread.start()
|
||||
print(f"Started HNSW ZMQ server thread on port {zmq_port}")
|
||||
logger.info(f"Started HNSW ZMQ server thread on port {zmq_port}")
|
||||
|
||||
# Keep the main thread alive
|
||||
try:
|
||||
while True:
|
||||
time.sleep(1)
|
||||
except KeyboardInterrupt:
|
||||
print("HNSW Server shutting down...")
|
||||
logger.info("HNSW Server shutting down...")
|
||||
return
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import signal
|
||||
import sys
|
||||
|
||||
def signal_handler(sig, frame):
|
||||
logger.info(f"Received signal {sig}, shutting down gracefully...")
|
||||
sys.exit(0)
|
||||
|
||||
# Register signal handlers for graceful shutdown
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
parser = argparse.ArgumentParser(description="HNSW Embedding service")
|
||||
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
|
||||
parser.add_argument(
|
||||
@@ -282,7 +299,7 @@ if __name__ == "__main__":
|
||||
"--embedding-mode",
|
||||
type=str,
|
||||
default="sentence-transformers",
|
||||
choices=["sentence-transformers", "openai"],
|
||||
choices=["sentence-transformers", "openai", "mlx"],
|
||||
help="Embedding backend mode",
|
||||
)
|
||||
|
||||
|
||||
@@ -8,7 +8,12 @@ build-backend = "scikit_build_core.build"
|
||||
name = "leann-backend-hnsw"
|
||||
version = "0.1.0"
|
||||
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
|
||||
dependencies = ["leann-core==0.1.0", "numpy"]
|
||||
dependencies = [
|
||||
"leann-core==0.1.0",
|
||||
"numpy",
|
||||
"pyzmq>=23.0.0",
|
||||
"msgpack>=1.0.0",
|
||||
]
|
||||
|
||||
[tool.scikit-build]
|
||||
wheel.packages = ["leann_backend_hnsw"]
|
||||
|
||||
Submodule packages/leann-backend-hnsw/third_party/msgpack-c updated: 9b801f087a...a0b2ec09da
@@ -5,14 +5,22 @@ build-backend = "setuptools.build_meta"
|
||||
[project]
|
||||
name = "leann-core"
|
||||
version = "0.1.0"
|
||||
description = "Core API and plugin system for Leann."
|
||||
description = "Core API and plugin system for LEANN"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.9"
|
||||
license = { text = "MIT" }
|
||||
|
||||
# All required dependencies included
|
||||
dependencies = [
|
||||
"numpy>=1.20.0",
|
||||
"tqdm>=4.60.0"
|
||||
"tqdm>=4.60.0",
|
||||
"psutil>=5.8.0",
|
||||
"pyzmq>=23.0.0",
|
||||
"msgpack>=1.0.0",
|
||||
"torch>=2.0.0",
|
||||
"sentence-transformers>=2.2.0",
|
||||
"llama-index-core>=0.12.0",
|
||||
"python-dotenv>=1.0.0",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -5,13 +5,18 @@ with the correct, original embedding logic from the user's reference code.
|
||||
|
||||
import json
|
||||
import pickle
|
||||
from leann.interface import LeannBackendSearcherInterface
|
||||
import numpy as np
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Optional, Literal
|
||||
from dataclasses import dataclass, field
|
||||
from .registry import BACKEND_REGISTRY
|
||||
from .interface import LeannBackendFactoryInterface
|
||||
from .chat import get_llm
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def compute_embeddings(
|
||||
@@ -19,7 +24,8 @@ def compute_embeddings(
|
||||
model_name: str,
|
||||
mode: str = "sentence-transformers",
|
||||
use_server: bool = True,
|
||||
port: int = 5557,
|
||||
port: Optional[int] = None,
|
||||
is_build=False,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Computes embeddings using different backends.
|
||||
@@ -38,6 +44,8 @@ def compute_embeddings(
|
||||
"""
|
||||
if use_server:
|
||||
# Use embedding server (for search/query)
|
||||
if port is None:
|
||||
raise ValueError("port is required when use_server is True")
|
||||
return compute_embeddings_via_server(chunks, model_name, port=port)
|
||||
else:
|
||||
# Use direct computation (for build_index)
|
||||
@@ -49,6 +57,7 @@ def compute_embeddings(
|
||||
chunks,
|
||||
model_name,
|
||||
mode=mode,
|
||||
is_build=is_build,
|
||||
)
|
||||
|
||||
|
||||
@@ -61,8 +70,8 @@ def compute_embeddings_via_server(
|
||||
chunks: List of text chunks to embed
|
||||
model_name: Name of the sentence transformer model
|
||||
"""
|
||||
print(
|
||||
f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (via embedding server)..."
|
||||
logger.info(
|
||||
f"Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (via embedding server)..."
|
||||
)
|
||||
import zmq
|
||||
import msgpack
|
||||
@@ -105,25 +114,24 @@ class PassageManager:
|
||||
self.global_offset_map = {} # Combined map for fast lookup
|
||||
|
||||
for source in passage_sources:
|
||||
if source["type"] == "jsonl":
|
||||
passage_file = source["path"]
|
||||
index_file = source["index_path"]
|
||||
if not Path(index_file).exists():
|
||||
raise FileNotFoundError(
|
||||
f"Passage index file not found: {index_file}"
|
||||
)
|
||||
with open(index_file, "rb") as f:
|
||||
offset_map = pickle.load(f)
|
||||
self.offset_maps[passage_file] = offset_map
|
||||
self.passage_files[passage_file] = passage_file
|
||||
assert source["type"] == "jsonl", "only jsonl is supported"
|
||||
passage_file = source["path"]
|
||||
index_file = source["index_path"] # .idx file
|
||||
if not Path(index_file).exists():
|
||||
raise FileNotFoundError(f"Passage index file not found: {index_file}")
|
||||
with open(index_file, "rb") as f:
|
||||
offset_map = pickle.load(f)
|
||||
self.offset_maps[passage_file] = offset_map
|
||||
self.passage_files[passage_file] = passage_file
|
||||
|
||||
# Build global map for O(1) lookup
|
||||
for passage_id, offset in offset_map.items():
|
||||
self.global_offset_map[passage_id] = (passage_file, offset)
|
||||
# Build global map for O(1) lookup
|
||||
for passage_id, offset in offset_map.items():
|
||||
self.global_offset_map[passage_id] = (passage_file, offset)
|
||||
|
||||
def get_passage(self, passage_id: str) -> Dict[str, Any]:
|
||||
if passage_id in self.global_offset_map:
|
||||
passage_file, offset = self.global_offset_map[passage_id]
|
||||
# Lazy file opening - only open when needed
|
||||
with open(passage_file, "r", encoding="utf-8") as f:
|
||||
f.seek(offset)
|
||||
return json.loads(f.readline())
|
||||
@@ -134,7 +142,7 @@ class LeannBuilder:
|
||||
def __init__(
|
||||
self,
|
||||
backend_name: str,
|
||||
embedding_model: str = "facebook/contriever-msmarco",
|
||||
embedding_model: str = "facebook/contriever",
|
||||
dimensions: Optional[int] = None,
|
||||
embedding_mode: str = "sentence-transformers",
|
||||
**backend_kwargs,
|
||||
@@ -209,7 +217,7 @@ class LeannBuilder:
|
||||
self.embedding_model,
|
||||
self.embedding_mode,
|
||||
use_server=False,
|
||||
port=5557,
|
||||
is_build=True,
|
||||
)
|
||||
string_ids = [chunk["id"] for chunk in self.chunks]
|
||||
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
|
||||
@@ -283,7 +291,7 @@ class LeannBuilder:
|
||||
f"Dimension mismatch: expected {self.dimensions}, got {embedding_dim}"
|
||||
)
|
||||
|
||||
print(
|
||||
logger.info(
|
||||
f"Building index from precomputed embeddings: {len(ids)} items, {embedding_dim} dimensions"
|
||||
)
|
||||
|
||||
@@ -291,7 +299,7 @@ class LeannBuilder:
|
||||
if len(self.chunks) != len(ids):
|
||||
# If no text chunks provided, create placeholder text entries
|
||||
if not self.chunks:
|
||||
print("No text chunks provided, creating placeholder entries...")
|
||||
logger.info("No text chunks provided, creating placeholder entries...")
|
||||
for id_val in ids:
|
||||
self.add_text(
|
||||
f"Document {id_val}",
|
||||
@@ -366,15 +374,19 @@ class LeannBuilder:
|
||||
with open(leann_meta_path, "w", encoding="utf-8") as f:
|
||||
json.dump(meta_data, f, indent=2)
|
||||
|
||||
print(f"Index built successfully from precomputed embeddings: {index_path}")
|
||||
logger.info(
|
||||
f"Index built successfully from precomputed embeddings: {index_path}"
|
||||
)
|
||||
|
||||
|
||||
class LeannSearcher:
|
||||
def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
|
||||
meta_path_str = f"{index_path}.meta.json"
|
||||
if not Path(meta_path_str).exists():
|
||||
raise FileNotFoundError(f"Leann metadata file not found at {meta_path_str}")
|
||||
with open(meta_path_str, "r", encoding="utf-8") as f:
|
||||
self.meta_path_str = f"{index_path}.meta.json"
|
||||
if not Path(self.meta_path_str).exists():
|
||||
raise FileNotFoundError(
|
||||
f"Leann metadata file not found at {self.meta_path_str}"
|
||||
)
|
||||
with open(self.meta_path_str, "r", encoding="utf-8") as f:
|
||||
self.meta_data = json.load(f)
|
||||
backend_name = self.meta_data["backend_name"]
|
||||
self.embedding_model = self.meta_data["embedding_model"]
|
||||
@@ -382,16 +394,15 @@ class LeannSearcher:
|
||||
self.embedding_mode = self.meta_data.get(
|
||||
"embedding_mode", "sentence-transformers"
|
||||
)
|
||||
# Backward compatibility with use_mlx
|
||||
if self.meta_data.get("use_mlx", False):
|
||||
self.embedding_mode = "mlx"
|
||||
self.passage_manager = PassageManager(self.meta_data.get("passage_sources", []))
|
||||
backend_factory = BACKEND_REGISTRY.get(backend_name)
|
||||
if backend_factory is None:
|
||||
raise ValueError(f"Backend '{backend_name}' not found.")
|
||||
final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
|
||||
final_kwargs["enable_warmup"] = enable_warmup
|
||||
self.backend_impl = backend_factory.searcher(index_path, **final_kwargs)
|
||||
self.backend_impl: LeannBackendSearcherInterface = backend_factory.searcher(
|
||||
index_path, **final_kwargs
|
||||
)
|
||||
|
||||
def search(
|
||||
self,
|
||||
@@ -400,26 +411,39 @@ class LeannSearcher:
|
||||
complexity: int = 64,
|
||||
beam_width: int = 1,
|
||||
prune_ratio: float = 0.0,
|
||||
recompute_embeddings: bool = False,
|
||||
recompute_embeddings: bool = True,
|
||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||
zmq_port: int = 5557,
|
||||
expected_zmq_port: int = 5557,
|
||||
**kwargs,
|
||||
) -> List[SearchResult]:
|
||||
print("🔍 DEBUG LeannSearcher.search() called:")
|
||||
print(f" Query: '{query}'")
|
||||
print(f" Top_k: {top_k}")
|
||||
print(f" Additional kwargs: {kwargs}")
|
||||
logger.info("🔍 LeannSearcher.search() called:")
|
||||
logger.info(f" Query: '{query}'")
|
||||
logger.info(f" Top_k: {top_k}")
|
||||
logger.info(f" Additional kwargs: {kwargs}")
|
||||
|
||||
# Use backend's compute_query_embedding method
|
||||
# This will automatically use embedding server if available and needed
|
||||
import time
|
||||
zmq_port = None
|
||||
|
||||
start_time = time.time()
|
||||
if recompute_embeddings:
|
||||
zmq_port = self.backend_impl._ensure_server_running(
|
||||
self.meta_path_str,
|
||||
port=expected_zmq_port,
|
||||
**kwargs,
|
||||
)
|
||||
del expected_zmq_port
|
||||
zmq_time = time.time() - start_time
|
||||
logger.info(f" Launching server time: {zmq_time} seconds")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
query_embedding = self.backend_impl.compute_query_embedding(query, zmq_port)
|
||||
print(f" Generated embedding shape: {query_embedding.shape}")
|
||||
query_embedding = self.backend_impl.compute_query_embedding(
|
||||
query,
|
||||
use_server_if_available=recompute_embeddings,
|
||||
zmq_port=zmq_port,
|
||||
)
|
||||
logger.info(f" Generated embedding shape: {query_embedding.shape}")
|
||||
embedding_time = time.time() - start_time
|
||||
print(f" Embedding time: {embedding_time} seconds")
|
||||
logger.info(f" Embedding time: {embedding_time} seconds")
|
||||
|
||||
start_time = time.time()
|
||||
results = self.backend_impl.search(
|
||||
@@ -434,14 +458,14 @@ class LeannSearcher:
|
||||
**kwargs,
|
||||
)
|
||||
search_time = time.time() - start_time
|
||||
print(f" Search time: {search_time} seconds")
|
||||
print(
|
||||
logger.info(f" Search time: {search_time} seconds")
|
||||
logger.info(
|
||||
f" Backend returned: labels={len(results.get('labels', [[]])[0])} results"
|
||||
)
|
||||
|
||||
enriched_results = []
|
||||
if "labels" in results and "distances" in results:
|
||||
print(f" Processing {len(results['labels'][0])} passage IDs:")
|
||||
logger.info(f" Processing {len(results['labels'][0])} passage IDs:")
|
||||
for i, (string_id, dist) in enumerate(
|
||||
zip(results["labels"][0], results["distances"][0])
|
||||
):
|
||||
@@ -455,15 +479,15 @@ class LeannSearcher:
|
||||
metadata=passage_data.get("metadata", {}),
|
||||
)
|
||||
)
|
||||
print(
|
||||
logger.info(
|
||||
f" {i + 1}. passage_id='{string_id}' -> SUCCESS: {passage_data['text']}..."
|
||||
)
|
||||
except KeyError:
|
||||
print(
|
||||
logger.error(
|
||||
f" {i + 1}. passage_id='{string_id}' -> ERROR: Passage not found in PassageManager!"
|
||||
)
|
||||
|
||||
print(f" Final enriched results: {len(enriched_results)} passages")
|
||||
logger.info(f" Final enriched results: {len(enriched_results)} passages")
|
||||
return enriched_results
|
||||
|
||||
|
||||
@@ -485,10 +509,10 @@ class LeannChat:
|
||||
complexity: int = 64,
|
||||
beam_width: int = 1,
|
||||
prune_ratio: float = 0.0,
|
||||
recompute_embeddings: bool = False,
|
||||
recompute_embeddings: bool = True,
|
||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||
zmq_port: int = 5557,
|
||||
llm_kwargs: Optional[Dict[str, Any]] = None,
|
||||
expected_zmq_port: int = 5557,
|
||||
**search_kwargs,
|
||||
):
|
||||
if llm_kwargs is None:
|
||||
@@ -502,7 +526,7 @@ class LeannChat:
|
||||
prune_ratio=prune_ratio,
|
||||
recompute_embeddings=recompute_embeddings,
|
||||
pruning_strategy=pruning_strategy,
|
||||
zmq_port=zmq_port,
|
||||
expected_zmq_port=expected_zmq_port,
|
||||
**search_kwargs,
|
||||
)
|
||||
context = "\n\n".join([r.text for r in results])
|
||||
|
||||
@@ -9,6 +9,7 @@ from typing import Dict, Any, Optional, List
|
||||
import logging
|
||||
import os
|
||||
import difflib
|
||||
import torch
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
@@ -28,6 +29,68 @@ def check_ollama_models() -> List[str]:
|
||||
return []
|
||||
|
||||
|
||||
def check_ollama_model_exists_remotely(model_name: str) -> tuple[bool, list[str]]:
|
||||
"""Check if a model exists in Ollama's remote library and return available tags
|
||||
|
||||
Returns:
|
||||
(model_exists, available_tags): bool and list of matching tags
|
||||
"""
|
||||
try:
|
||||
import requests
|
||||
import re
|
||||
|
||||
# Split model name and tag
|
||||
if ':' in model_name:
|
||||
base_model, requested_tag = model_name.split(':', 1)
|
||||
else:
|
||||
base_model, requested_tag = model_name, None
|
||||
|
||||
# First check if base model exists in library
|
||||
library_response = requests.get("https://ollama.com/library", timeout=8)
|
||||
if library_response.status_code != 200:
|
||||
return True, [] # Assume exists if can't check
|
||||
|
||||
# Extract model names from library page
|
||||
models_in_library = re.findall(r'href="/library/([^"]+)"', library_response.text)
|
||||
|
||||
if base_model not in models_in_library:
|
||||
return False, [] # Base model doesn't exist
|
||||
|
||||
# If base model exists, get available tags
|
||||
tags_response = requests.get(f"https://ollama.com/library/{base_model}/tags", timeout=8)
|
||||
if tags_response.status_code != 200:
|
||||
return True, [] # Base model exists but can't get tags
|
||||
|
||||
# Extract tags for this model - be more specific to avoid HTML artifacts
|
||||
tag_pattern = rf'{re.escape(base_model)}:[a-zA-Z0-9\.\-_]+'
|
||||
raw_tags = re.findall(tag_pattern, tags_response.text)
|
||||
|
||||
# Clean up tags - remove HTML artifacts and duplicates
|
||||
available_tags = []
|
||||
seen = set()
|
||||
for tag in raw_tags:
|
||||
# Skip if it looks like HTML (contains < or >)
|
||||
if '<' in tag or '>' in tag:
|
||||
continue
|
||||
if tag not in seen:
|
||||
seen.add(tag)
|
||||
available_tags.append(tag)
|
||||
|
||||
# Check if exact model exists
|
||||
if requested_tag is None:
|
||||
# User just requested base model, suggest tags
|
||||
return True, available_tags[:10] # Return up to 10 tags
|
||||
else:
|
||||
exact_match = model_name in available_tags
|
||||
return exact_match, available_tags[:10]
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# If scraping fails, assume model might exist (don't block user)
|
||||
return True, []
|
||||
|
||||
|
||||
def search_ollama_models_fuzzy(query: str, available_models: List[str]) -> List[str]:
|
||||
"""Use intelligent fuzzy search for Ollama models"""
|
||||
if not available_models:
|
||||
@@ -243,24 +306,66 @@ def validate_model_and_suggest(model_name: str, llm_type: str) -> Optional[str]:
|
||||
if llm_type == "ollama":
|
||||
available_models = check_ollama_models()
|
||||
if available_models and model_name not in available_models:
|
||||
# Use intelligent fuzzy search based on locally installed models
|
||||
suggestions = search_ollama_models_fuzzy(model_name, available_models)
|
||||
|
||||
error_msg = f"Model '{model_name}' not found in your local Ollama installation."
|
||||
if suggestions:
|
||||
error_msg += "\n\nDid you mean one of these installed models?\n"
|
||||
for i, suggestion in enumerate(suggestions, 1):
|
||||
error_msg += f" {i}. {suggestion}\n"
|
||||
else:
|
||||
error_msg += "\n\nYour installed models:\n"
|
||||
for i, model in enumerate(available_models[:8], 1):
|
||||
error_msg += f" {i}. {model}\n"
|
||||
if len(available_models) > 8:
|
||||
error_msg += f" ... and {len(available_models) - 8} more\n"
|
||||
|
||||
error_msg += "\nTo list all models: ollama list"
|
||||
error_msg += "\nTo download a new model: ollama pull <model_name>"
|
||||
error_msg += "\nBrowse models: https://ollama.com/library"
|
||||
# Check if the model exists remotely and get available tags
|
||||
model_exists_remotely, available_tags = check_ollama_model_exists_remotely(model_name)
|
||||
|
||||
if model_exists_remotely and model_name in available_tags:
|
||||
# Exact model exists remotely - suggest pulling it
|
||||
error_msg += f"\n\nTo install the requested model:\n"
|
||||
error_msg += f" ollama pull {model_name}\n"
|
||||
|
||||
# Show local alternatives
|
||||
suggestions = search_ollama_models_fuzzy(model_name, available_models)
|
||||
if suggestions:
|
||||
error_msg += "\nOr use one of these similar installed models:\n"
|
||||
for i, suggestion in enumerate(suggestions, 1):
|
||||
error_msg += f" {i}. {suggestion}\n"
|
||||
|
||||
elif model_exists_remotely and available_tags:
|
||||
# Base model exists but requested tag doesn't - suggest correct tags
|
||||
base_model = model_name.split(':')[0]
|
||||
requested_tag = model_name.split(':', 1)[1] if ':' in model_name else None
|
||||
|
||||
error_msg += f"\n\nModel '{base_model}' exists, but tag '{requested_tag}' is not available."
|
||||
error_msg += f"\n\nAvailable {base_model} models you can install:\n"
|
||||
for i, tag in enumerate(available_tags[:8], 1):
|
||||
error_msg += f" {i}. ollama pull {tag}\n"
|
||||
if len(available_tags) > 8:
|
||||
error_msg += f" ... and {len(available_tags) - 8} more variants\n"
|
||||
|
||||
# Also show local alternatives
|
||||
suggestions = search_ollama_models_fuzzy(model_name, available_models)
|
||||
if suggestions:
|
||||
error_msg += "\nOr use one of these similar installed models:\n"
|
||||
for i, suggestion in enumerate(suggestions, 1):
|
||||
error_msg += f" {i}. {suggestion}\n"
|
||||
|
||||
else:
|
||||
# Model doesn't exist remotely - show fuzzy suggestions
|
||||
suggestions = search_ollama_models_fuzzy(model_name, available_models)
|
||||
error_msg += f"\n\nModel '{model_name}' was not found in Ollama's library."
|
||||
|
||||
if suggestions:
|
||||
error_msg += "\n\nDid you mean one of these installed models?\n"
|
||||
for i, suggestion in enumerate(suggestions, 1):
|
||||
error_msg += f" {i}. {suggestion}\n"
|
||||
else:
|
||||
error_msg += "\n\nYour installed models:\n"
|
||||
for i, model in enumerate(available_models[:8], 1):
|
||||
error_msg += f" {i}. {model}\n"
|
||||
if len(available_models) > 8:
|
||||
error_msg += f" ... and {len(available_models) - 8} more\n"
|
||||
|
||||
error_msg += "\n\nCommands:"
|
||||
error_msg += "\n ollama list # List installed models"
|
||||
if model_exists_remotely and available_tags:
|
||||
if model_name in available_tags:
|
||||
error_msg += f"\n ollama pull {model_name} # Install requested model"
|
||||
else:
|
||||
error_msg += f"\n ollama pull {available_tags[0]} # Install recommended variant"
|
||||
error_msg += "\n https://ollama.com/library # Browse available models"
|
||||
return error_msg
|
||||
|
||||
elif llm_type == "hf":
|
||||
@@ -375,8 +480,9 @@ class OllamaChat(LLMInterface):
|
||||
"stream": False, # Keep it simple for now
|
||||
"options": kwargs,
|
||||
}
|
||||
logger.info(f"Sending request to Ollama: {payload}")
|
||||
logger.debug(f"Sending request to Ollama: {payload}")
|
||||
try:
|
||||
logger.info(f"Sending request to Ollama and waiting for response...")
|
||||
response = requests.post(full_url, data=json.dumps(payload))
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -396,7 +502,7 @@ class OllamaChat(LLMInterface):
|
||||
|
||||
|
||||
class HFChat(LLMInterface):
|
||||
"""LLM interface for local Hugging Face Transformers models."""
|
||||
"""LLM interface for local Hugging Face Transformers models with proper chat templates."""
|
||||
|
||||
def __init__(self, model_name: str = "deepseek-ai/deepseek-llm-7b-chat"):
|
||||
logger.info(f"Initializing HFChat with model='{model_name}'")
|
||||
@@ -407,7 +513,7 @@ class HFChat(LLMInterface):
|
||||
raise ValueError(model_error)
|
||||
|
||||
try:
|
||||
from transformers.pipelines import pipeline
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
import torch
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
@@ -416,54 +522,100 @@ class HFChat(LLMInterface):
|
||||
|
||||
# Auto-detect device
|
||||
if torch.cuda.is_available():
|
||||
device = "cuda"
|
||||
self.device = "cuda"
|
||||
logger.info("CUDA is available. Using GPU.")
|
||||
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||
device = "mps"
|
||||
self.device = "mps"
|
||||
logger.info("MPS is available. Using Apple Silicon GPU.")
|
||||
else:
|
||||
device = "cpu"
|
||||
self.device = "cpu"
|
||||
logger.info("No GPU detected. Using CPU.")
|
||||
|
||||
self.pipeline = pipeline("text-generation", model=model_name, device=device)
|
||||
# Load tokenizer and model
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.float16 if self.device != "cpu" else torch.float32,
|
||||
device_map="auto" if self.device != "cpu" else None,
|
||||
trust_remote_code=True
|
||||
)
|
||||
|
||||
# Move model to device if not using device_map
|
||||
if self.device != "cpu" and "device_map" not in str(self.model):
|
||||
self.model = self.model.to(self.device)
|
||||
|
||||
# Set pad token if not present
|
||||
if self.tokenizer.pad_token is None:
|
||||
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||
|
||||
def ask(self, prompt: str, **kwargs) -> str:
|
||||
# Map OpenAI-style arguments to Hugging Face equivalents
|
||||
if "max_tokens" in kwargs:
|
||||
# Prefer user-provided max_new_tokens if both are present
|
||||
kwargs.setdefault("max_new_tokens", kwargs["max_tokens"])
|
||||
# Remove the unsupported key to avoid errors in Transformers
|
||||
kwargs.pop("max_tokens")
|
||||
# Check if this is a Qwen model and add /no_think by default
|
||||
is_qwen_model = "qwen" in self.model.config._name_or_path.lower()
|
||||
|
||||
# For Qwen models, automatically add /no_think to the prompt
|
||||
if is_qwen_model and "/no_think" not in prompt and "/think" not in prompt:
|
||||
prompt = prompt + " /no_think"
|
||||
|
||||
# Prepare chat template
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
|
||||
# Apply chat template if available
|
||||
if hasattr(self.tokenizer, "apply_chat_template"):
|
||||
try:
|
||||
formatted_prompt = self.tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Chat template failed, using raw prompt: {e}")
|
||||
formatted_prompt = prompt
|
||||
else:
|
||||
# Fallback for models without chat template
|
||||
formatted_prompt = prompt
|
||||
|
||||
# Handle temperature=0 edge-case for greedy decoding
|
||||
if "temperature" in kwargs and kwargs["temperature"] == 0.0:
|
||||
# Remove unsupported zero temperature and use deterministic generation
|
||||
kwargs.pop("temperature")
|
||||
kwargs.setdefault("do_sample", False)
|
||||
# Tokenize input
|
||||
inputs = self.tokenizer(
|
||||
formatted_prompt,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=2048
|
||||
)
|
||||
|
||||
# Move inputs to device
|
||||
if self.device != "cpu":
|
||||
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
||||
|
||||
# Sensible defaults for text generation
|
||||
params = {"max_length": 500, "num_return_sequences": 1, **kwargs}
|
||||
logger.info(f"Generating text with Hugging Face model with params: {params}")
|
||||
results = self.pipeline(prompt, **params)
|
||||
# Set generation parameters
|
||||
generation_config = {
|
||||
"max_new_tokens": kwargs.get("max_tokens", kwargs.get("max_new_tokens", 512)),
|
||||
"temperature": kwargs.get("temperature", 0.7),
|
||||
"top_p": kwargs.get("top_p", 0.9),
|
||||
"do_sample": kwargs.get("temperature", 0.7) > 0,
|
||||
"pad_token_id": self.tokenizer.eos_token_id,
|
||||
"eos_token_id": self.tokenizer.eos_token_id,
|
||||
}
|
||||
|
||||
# Handle temperature=0 for greedy decoding
|
||||
if generation_config["temperature"] == 0.0:
|
||||
generation_config["do_sample"] = False
|
||||
generation_config.pop("temperature")
|
||||
|
||||
# Handle different response formats from transformers
|
||||
if isinstance(results, list) and len(results) > 0:
|
||||
generated_text = (
|
||||
results[0].get("generated_text", "")
|
||||
if isinstance(results[0], dict)
|
||||
else str(results[0])
|
||||
logger.info(f"Generating with HuggingFace model, config: {generation_config}")
|
||||
|
||||
# Generate
|
||||
with torch.no_grad():
|
||||
outputs = self.model.generate(
|
||||
**inputs,
|
||||
**generation_config
|
||||
)
|
||||
else:
|
||||
generated_text = str(results)
|
||||
|
||||
# Extract only the newly generated portion by removing the original prompt
|
||||
if isinstance(generated_text, str) and generated_text.startswith(prompt):
|
||||
response = generated_text[len(prompt) :].strip()
|
||||
else:
|
||||
# Fallback: return the full response if prompt removal fails
|
||||
response = str(generated_text)
|
||||
|
||||
return response
|
||||
# Decode response
|
||||
generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
|
||||
response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
|
||||
|
||||
return response.strip()
|
||||
|
||||
|
||||
class OpenAIChat(LLMInterface):
|
||||
|
||||
@@ -1,10 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
import argparse
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
import os
|
||||
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
@@ -16,20 +12,20 @@ class LeannCLI:
|
||||
def __init__(self):
|
||||
self.indexes_dir = Path.home() / ".leann" / "indexes"
|
||||
self.indexes_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
self.node_parser = SentenceSplitter(
|
||||
chunk_size=256, chunk_overlap=128, separator=" ", paragraph_separator="\n\n"
|
||||
)
|
||||
|
||||
|
||||
def get_index_path(self, index_name: str) -> str:
|
||||
index_dir = self.indexes_dir / index_name
|
||||
return str(index_dir / "documents.leann")
|
||||
|
||||
|
||||
def index_exists(self, index_name: str) -> bool:
|
||||
index_dir = self.indexes_dir / index_name
|
||||
meta_file = index_dir / "documents.leann.meta.json"
|
||||
return meta_file.exists()
|
||||
|
||||
|
||||
def create_parser(self) -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="leann",
|
||||
@@ -41,24 +37,32 @@ Examples:
|
||||
leann search my-docs "query" # Search in my-docs index
|
||||
leann ask my-docs "question" # Ask my-docs index
|
||||
leann list # List all stored indexes
|
||||
"""
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
subparsers = parser.add_subparsers(dest="command", help="Available commands")
|
||||
|
||||
|
||||
# Build command
|
||||
build_parser = subparsers.add_parser("build", help="Build document index")
|
||||
build_parser.add_argument("index_name", help="Index name")
|
||||
build_parser.add_argument("--docs", type=str, required=True, help="Documents directory")
|
||||
build_parser.add_argument("--backend", type=str, default="hnsw", choices=["hnsw", "diskann"])
|
||||
build_parser.add_argument("--embedding-model", type=str, default="facebook/contriever")
|
||||
build_parser.add_argument("--force", "-f", action="store_true", help="Force rebuild")
|
||||
build_parser.add_argument(
|
||||
"--docs", type=str, required=True, help="Documents directory"
|
||||
)
|
||||
build_parser.add_argument(
|
||||
"--backend", type=str, default="hnsw", choices=["hnsw", "diskann"]
|
||||
)
|
||||
build_parser.add_argument(
|
||||
"--embedding-model", type=str, default="facebook/contriever"
|
||||
)
|
||||
build_parser.add_argument(
|
||||
"--force", "-f", action="store_true", help="Force rebuild"
|
||||
)
|
||||
build_parser.add_argument("--graph-degree", type=int, default=32)
|
||||
build_parser.add_argument("--complexity", type=int, default=64)
|
||||
build_parser.add_argument("--num-threads", type=int, default=1)
|
||||
build_parser.add_argument("--compact", action="store_true", default=True)
|
||||
build_parser.add_argument("--recompute", action="store_true", default=True)
|
||||
|
||||
|
||||
# Search command
|
||||
search_parser = subparsers.add_parser("search", help="Search documents")
|
||||
search_parser.add_argument("index_name", help="Index name")
|
||||
@@ -68,12 +72,21 @@ Examples:
|
||||
search_parser.add_argument("--beam-width", type=int, default=1)
|
||||
search_parser.add_argument("--prune-ratio", type=float, default=0.0)
|
||||
search_parser.add_argument("--recompute-embeddings", action="store_true")
|
||||
search_parser.add_argument("--pruning-strategy", choices=["global", "local", "proportional"], default="global")
|
||||
|
||||
search_parser.add_argument(
|
||||
"--pruning-strategy",
|
||||
choices=["global", "local", "proportional"],
|
||||
default="global",
|
||||
)
|
||||
|
||||
# Ask command
|
||||
ask_parser = subparsers.add_parser("ask", help="Ask questions")
|
||||
ask_parser.add_argument("index_name", help="Index name")
|
||||
ask_parser.add_argument("--llm", type=str, default="ollama", choices=["simulated", "ollama", "hf", "openai"])
|
||||
ask_parser.add_argument(
|
||||
"--llm",
|
||||
type=str,
|
||||
default="ollama",
|
||||
choices=["simulated", "ollama", "hf", "openai"],
|
||||
)
|
||||
ask_parser.add_argument("--model", type=str, default="qwen3:8b")
|
||||
ask_parser.add_argument("--host", type=str, default="http://localhost:11434")
|
||||
ask_parser.add_argument("--interactive", "-i", action="store_true")
|
||||
@@ -82,81 +95,91 @@ Examples:
|
||||
ask_parser.add_argument("--beam-width", type=int, default=1)
|
||||
ask_parser.add_argument("--prune-ratio", type=float, default=0.0)
|
||||
ask_parser.add_argument("--recompute-embeddings", action="store_true")
|
||||
ask_parser.add_argument("--pruning-strategy", choices=["global", "local", "proportional"], default="global")
|
||||
|
||||
ask_parser.add_argument(
|
||||
"--pruning-strategy",
|
||||
choices=["global", "local", "proportional"],
|
||||
default="global",
|
||||
)
|
||||
|
||||
# List command
|
||||
list_parser = subparsers.add_parser("list", help="List all indexes")
|
||||
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def list_indexes(self):
|
||||
print("Stored LEANN indexes:")
|
||||
|
||||
|
||||
if not self.indexes_dir.exists():
|
||||
print("No indexes found. Use 'leann build <name> --docs <dir>' to create one.")
|
||||
print(
|
||||
"No indexes found. Use 'leann build <name> --docs <dir>' to create one."
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
index_dirs = [d for d in self.indexes_dir.iterdir() if d.is_dir()]
|
||||
|
||||
|
||||
if not index_dirs:
|
||||
print("No indexes found. Use 'leann build <name> --docs <dir>' to create one.")
|
||||
print(
|
||||
"No indexes found. Use 'leann build <name> --docs <dir>' to create one."
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
print(f"Found {len(index_dirs)} indexes:")
|
||||
for i, index_dir in enumerate(index_dirs, 1):
|
||||
index_name = index_dir.name
|
||||
status = "✓" if self.index_exists(index_name) else "✗"
|
||||
|
||||
|
||||
print(f" {i}. {index_name} [{status}]")
|
||||
if self.index_exists(index_name):
|
||||
meta_file = index_dir / "documents.leann.meta.json"
|
||||
size_mb = sum(f.stat().st_size for f in index_dir.iterdir() if f.is_file()) / (1024 * 1024)
|
||||
size_mb = sum(
|
||||
f.stat().st_size for f in index_dir.iterdir() if f.is_file()
|
||||
) / (1024 * 1024)
|
||||
print(f" Size: {size_mb:.1f} MB")
|
||||
|
||||
|
||||
if index_dirs:
|
||||
example_name = index_dirs[0].name
|
||||
print(f"\nUsage:")
|
||||
print(f" leann search {example_name} \"your query\"")
|
||||
print(f' leann search {example_name} "your query"')
|
||||
print(f" leann ask {example_name} --interactive")
|
||||
|
||||
|
||||
def load_documents(self, docs_dir: str):
|
||||
print(f"Loading documents from {docs_dir}...")
|
||||
|
||||
|
||||
documents = SimpleDirectoryReader(
|
||||
docs_dir,
|
||||
recursive=True,
|
||||
encoding="utf-8",
|
||||
required_exts=[".pdf", ".txt", ".md", ".docx"],
|
||||
).load_data(show_progress=True)
|
||||
|
||||
|
||||
all_texts = []
|
||||
for doc in documents:
|
||||
nodes = self.node_parser.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
|
||||
|
||||
print(f"Loaded {len(documents)} documents, {len(all_texts)} chunks")
|
||||
return all_texts
|
||||
|
||||
|
||||
async def build_index(self, args):
|
||||
docs_dir = args.docs
|
||||
index_name = args.index_name
|
||||
index_dir = self.indexes_dir / index_name
|
||||
index_path = self.get_index_path(index_name)
|
||||
|
||||
|
||||
if index_dir.exists() and not args.force:
|
||||
print(f"Index '{index_name}' already exists. Use --force to rebuild.")
|
||||
return
|
||||
|
||||
|
||||
all_texts = self.load_documents(docs_dir)
|
||||
if not all_texts:
|
||||
print("No documents found")
|
||||
return
|
||||
|
||||
|
||||
index_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
print(f"Building index '{index_name}' with {args.backend} backend...")
|
||||
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name=args.backend,
|
||||
embedding_model=args.embedding_model,
|
||||
@@ -166,103 +189,107 @@ Examples:
|
||||
is_recompute=args.recompute,
|
||||
num_threads=args.num_threads,
|
||||
)
|
||||
|
||||
|
||||
for chunk_text in all_texts:
|
||||
builder.add_text(chunk_text)
|
||||
|
||||
|
||||
builder.build_index(index_path)
|
||||
print(f"Index built at {index_path}")
|
||||
|
||||
|
||||
async def search_documents(self, args):
|
||||
index_name = args.index_name
|
||||
query = args.query
|
||||
index_path = self.get_index_path(index_name)
|
||||
|
||||
|
||||
if not self.index_exists(index_name):
|
||||
print(f"Index '{index_name}' not found. Use 'leann build {index_name} --docs <dir>' to create it.")
|
||||
print(
|
||||
f"Index '{index_name}' not found. Use 'leann build {index_name} --docs <dir>' to create it."
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
searcher = LeannSearcher(index_path=index_path)
|
||||
results = searcher.search(
|
||||
query,
|
||||
query,
|
||||
top_k=args.top_k,
|
||||
complexity=args.complexity,
|
||||
beam_width=args.beam_width,
|
||||
prune_ratio=args.prune_ratio,
|
||||
recompute_embeddings=args.recompute_embeddings,
|
||||
pruning_strategy=args.pruning_strategy
|
||||
pruning_strategy=args.pruning_strategy,
|
||||
)
|
||||
|
||||
|
||||
print(f"Search results for '{query}' (top {len(results)}):")
|
||||
for i, result in enumerate(results, 1):
|
||||
print(f"{i}. Score: {result.score:.3f}")
|
||||
print(f" {result.text[:200]}...")
|
||||
print()
|
||||
|
||||
|
||||
async def ask_questions(self, args):
|
||||
index_name = args.index_name
|
||||
index_path = self.get_index_path(index_name)
|
||||
|
||||
|
||||
if not self.index_exists(index_name):
|
||||
print(f"Index '{index_name}' not found. Use 'leann build {index_name} --docs <dir>' to create it.")
|
||||
print(
|
||||
f"Index '{index_name}' not found. Use 'leann build {index_name} --docs <dir>' to create it."
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
print(f"Starting chat with index '{index_name}'...")
|
||||
print(f"Using {args.model} ({args.llm})")
|
||||
|
||||
|
||||
llm_config = {"type": args.llm, "model": args.model}
|
||||
if args.llm == "ollama":
|
||||
llm_config["host"] = args.host
|
||||
|
||||
|
||||
chat = LeannChat(index_path=index_path, llm_config=llm_config)
|
||||
|
||||
|
||||
if args.interactive:
|
||||
print("LEANN Assistant ready! Type 'quit' to exit")
|
||||
print("=" * 40)
|
||||
|
||||
|
||||
while True:
|
||||
user_input = input("\nYou: ").strip()
|
||||
if user_input.lower() in ['quit', 'exit', 'q']:
|
||||
if user_input.lower() in ["quit", "exit", "q"]:
|
||||
print("Goodbye!")
|
||||
break
|
||||
|
||||
|
||||
if not user_input:
|
||||
continue
|
||||
|
||||
|
||||
response = chat.ask(
|
||||
user_input,
|
||||
user_input,
|
||||
top_k=args.top_k,
|
||||
complexity=args.complexity,
|
||||
beam_width=args.beam_width,
|
||||
prune_ratio=args.prune_ratio,
|
||||
recompute_embeddings=args.recompute_embeddings,
|
||||
pruning_strategy=args.pruning_strategy
|
||||
pruning_strategy=args.pruning_strategy,
|
||||
)
|
||||
print(f"LEANN: {response}")
|
||||
else:
|
||||
query = input("Enter your question: ").strip()
|
||||
if query:
|
||||
response = chat.ask(
|
||||
query,
|
||||
query,
|
||||
top_k=args.top_k,
|
||||
complexity=args.complexity,
|
||||
beam_width=args.beam_width,
|
||||
prune_ratio=args.prune_ratio,
|
||||
recompute_embeddings=args.recompute_embeddings,
|
||||
pruning_strategy=args.pruning_strategy
|
||||
pruning_strategy=args.pruning_strategy,
|
||||
)
|
||||
print(f"LEANN: {response}")
|
||||
|
||||
|
||||
async def run(self, args=None):
|
||||
parser = self.create_parser()
|
||||
|
||||
|
||||
if args is None:
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
if not args.command:
|
||||
parser.print_help()
|
||||
return
|
||||
|
||||
|
||||
if args.command == "list":
|
||||
self.list_indexes()
|
||||
elif args.command == "build":
|
||||
@@ -277,11 +304,12 @@ Examples:
|
||||
|
||||
def main():
|
||||
import dotenv
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
|
||||
cli = LeannCLI()
|
||||
asyncio.run(cli.run())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
||||
@@ -6,14 +6,27 @@ Preserves all optimization parameters to ensure performance
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing import List
|
||||
from typing import List, Dict, Any
|
||||
import logging
|
||||
import os
|
||||
|
||||
# Set up logger with proper level
|
||||
logger = logging.getLogger(__name__)
|
||||
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
|
||||
logger.setLevel(log_level)
|
||||
|
||||
# Global model cache to avoid repeated loading
|
||||
_model_cache: Dict[str, Any] = {}
|
||||
|
||||
|
||||
def compute_embeddings(
|
||||
texts: List[str], model_name: str, mode: str = "sentence-transformers"
|
||||
texts: List[str],
|
||||
model_name: str,
|
||||
mode: str = "sentence-transformers",
|
||||
is_build: bool = False,
|
||||
batch_size: int = 32,
|
||||
adaptive_optimization: bool = True,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Unified embedding computation entry point
|
||||
@@ -22,12 +35,21 @@ def compute_embeddings(
|
||||
texts: List of texts to compute embeddings for
|
||||
model_name: Model name
|
||||
mode: Computation mode ('sentence-transformers', 'openai', 'mlx')
|
||||
is_build: Whether this is a build operation (shows progress bar)
|
||||
batch_size: Batch size for processing
|
||||
adaptive_optimization: Whether to use adaptive optimization based on batch size
|
||||
|
||||
Returns:
|
||||
Normalized embeddings array, shape: (len(texts), embedding_dim)
|
||||
"""
|
||||
if mode == "sentence-transformers":
|
||||
return compute_embeddings_sentence_transformers(texts, model_name)
|
||||
return compute_embeddings_sentence_transformers(
|
||||
texts,
|
||||
model_name,
|
||||
is_build=is_build,
|
||||
batch_size=batch_size,
|
||||
adaptive_optimization=adaptive_optimization,
|
||||
)
|
||||
elif mode == "openai":
|
||||
return compute_embeddings_openai(texts, model_name)
|
||||
elif mode == "mlx":
|
||||
@@ -42,27 +64,28 @@ def compute_embeddings_sentence_transformers(
|
||||
use_fp16: bool = True,
|
||||
device: str = "auto",
|
||||
batch_size: int = 32,
|
||||
is_build: bool = False,
|
||||
adaptive_optimization: bool = True,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Compute embeddings using SentenceTransformer
|
||||
Preserves all optimization parameters to ensure consistency with original embedding_server
|
||||
Compute embeddings using SentenceTransformer with model caching and adaptive optimization
|
||||
|
||||
Args:
|
||||
texts: List of texts to compute embeddings for
|
||||
model_name: SentenceTransformer model name
|
||||
model_name: Model name
|
||||
use_fp16: Whether to use FP16 precision
|
||||
device: Device selection ('auto', 'cuda', 'mps', 'cpu')
|
||||
device: Device to use ('auto', 'cuda', 'mps', 'cpu')
|
||||
batch_size: Batch size for processing
|
||||
|
||||
Returns:
|
||||
Normalized embeddings array, shape: (len(texts), embedding_dim)
|
||||
is_build: Whether this is a build operation (shows progress bar)
|
||||
adaptive_optimization: Whether to use adaptive optimization based on batch size
|
||||
"""
|
||||
print(
|
||||
f"INFO: Computing embeddings for {len(texts)} texts using SentenceTransformer, model: '{model_name}'"
|
||||
# Handle empty input
|
||||
if not texts:
|
||||
raise ValueError("Cannot compute embeddings for empty text list")
|
||||
logger.info(
|
||||
f"Computing embeddings for {len(texts)} texts using SentenceTransformer, model: '{model_name}'"
|
||||
)
|
||||
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
# Auto-detect device
|
||||
if device == "auto":
|
||||
if torch.cuda.is_available():
|
||||
@@ -72,75 +95,139 @@ def compute_embeddings_sentence_transformers(
|
||||
else:
|
||||
device = "cpu"
|
||||
|
||||
print(f"INFO: Using device: {device}")
|
||||
# Apply optimizations based on benchmark results
|
||||
if adaptive_optimization:
|
||||
# Use optimal batch_size constants for different devices based on benchmark results
|
||||
if device == "mps":
|
||||
batch_size = 128 # MPS optimal batch size from benchmark
|
||||
if model_name == "Qwen/Qwen3-Embedding-0.6B":
|
||||
batch_size = 32
|
||||
elif device == "cuda":
|
||||
batch_size = 256 # CUDA optimal batch size
|
||||
# Keep original batch_size for CPU
|
||||
|
||||
# Prepare model and tokenizer optimization parameters (consistent with original embedding_server)
|
||||
model_kwargs = {
|
||||
"torch_dtype": torch.float16 if use_fp16 else torch.float32,
|
||||
"low_cpu_mem_usage": True,
|
||||
"_fast_init": True, # Skip weight initialization checks for faster loading
|
||||
}
|
||||
# Create cache key
|
||||
cache_key = f"sentence_transformers_{model_name}_{device}_{use_fp16}_optimized"
|
||||
|
||||
tokenizer_kwargs = {
|
||||
"use_fast": True, # Use fast tokenizer for better runtime performance
|
||||
}
|
||||
|
||||
# Load SentenceTransformer (try local first, then network)
|
||||
print(f"INFO: Loading SentenceTransformer model: {model_name}")
|
||||
|
||||
try:
|
||||
# Try local loading (avoid network delays)
|
||||
model_kwargs["local_files_only"] = True
|
||||
tokenizer_kwargs["local_files_only"] = True
|
||||
|
||||
model = SentenceTransformer(
|
||||
model_name,
|
||||
device=device,
|
||||
model_kwargs=model_kwargs,
|
||||
tokenizer_kwargs=tokenizer_kwargs,
|
||||
local_files_only=True,
|
||||
# Check if model is already cached
|
||||
if cache_key in _model_cache:
|
||||
logger.info(f"Using cached optimized model: {model_name}")
|
||||
model = _model_cache[cache_key]
|
||||
else:
|
||||
logger.info(
|
||||
f"Loading and caching optimized SentenceTransformer model: {model_name}"
|
||||
)
|
||||
print("✅ Model loaded successfully! (local + optimized)")
|
||||
except Exception as e:
|
||||
print(f"Local loading failed ({e}), trying network download...")
|
||||
# Fallback to network loading
|
||||
model_kwargs["local_files_only"] = False
|
||||
tokenizer_kwargs["local_files_only"] = False
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
model = SentenceTransformer(
|
||||
model_name,
|
||||
device=device,
|
||||
model_kwargs=model_kwargs,
|
||||
tokenizer_kwargs=tokenizer_kwargs,
|
||||
local_files_only=False,
|
||||
)
|
||||
print("✅ Model loaded successfully! (network + optimized)")
|
||||
logger.info(f"Using device: {device}")
|
||||
|
||||
# Apply hardware optimizations
|
||||
if device == "cuda":
|
||||
# TODO: Haven't tested this yet
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cudnn.deterministic = False
|
||||
torch.cuda.set_per_process_memory_fraction(0.9)
|
||||
elif device == "mps":
|
||||
try:
|
||||
if hasattr(torch.mps, "set_per_process_memory_fraction"):
|
||||
torch.mps.set_per_process_memory_fraction(0.9)
|
||||
except AttributeError:
|
||||
logger.warning(
|
||||
"Some MPS optimizations not available in this PyTorch version"
|
||||
)
|
||||
elif device == "cpu":
|
||||
# TODO: Haven't tested this yet
|
||||
torch.set_num_threads(min(8, os.cpu_count() or 4))
|
||||
try:
|
||||
torch.backends.mkldnn.enabled = True
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
# Prepare optimized model and tokenizer parameters
|
||||
model_kwargs = {
|
||||
"torch_dtype": torch.float16 if use_fp16 else torch.float32,
|
||||
"low_cpu_mem_usage": True,
|
||||
"_fast_init": True,
|
||||
"attn_implementation": "eager", # Use eager attention for speed
|
||||
}
|
||||
|
||||
tokenizer_kwargs = {
|
||||
"use_fast": True,
|
||||
"padding": True,
|
||||
"truncation": True,
|
||||
}
|
||||
|
||||
# Apply additional optimizations (if supported)
|
||||
if use_fp16 and device in ["cuda", "mps"]:
|
||||
try:
|
||||
model = model.half()
|
||||
model = torch.compile(model)
|
||||
print(f"✅ Using FP16 precision and compile optimization: {model_name}")
|
||||
except Exception as e:
|
||||
print(
|
||||
f"FP16 or compile optimization failed, continuing with default settings: {e}"
|
||||
# Try local loading first
|
||||
model_kwargs["local_files_only"] = True
|
||||
tokenizer_kwargs["local_files_only"] = True
|
||||
|
||||
model = SentenceTransformer(
|
||||
model_name,
|
||||
device=device,
|
||||
model_kwargs=model_kwargs,
|
||||
tokenizer_kwargs=tokenizer_kwargs,
|
||||
local_files_only=True,
|
||||
)
|
||||
logger.info("Model loaded successfully! (local + optimized)")
|
||||
except Exception as e:
|
||||
logger.warning(f"Local loading failed ({e}), trying network download...")
|
||||
# Fallback to network loading
|
||||
model_kwargs["local_files_only"] = False
|
||||
tokenizer_kwargs["local_files_only"] = False
|
||||
|
||||
# Compute embeddings (using SentenceTransformer's optimized implementation)
|
||||
print("INFO: Starting embedding computation...")
|
||||
model = SentenceTransformer(
|
||||
model_name,
|
||||
device=device,
|
||||
model_kwargs=model_kwargs,
|
||||
tokenizer_kwargs=tokenizer_kwargs,
|
||||
local_files_only=False,
|
||||
)
|
||||
logger.info("Model loaded successfully! (network + optimized)")
|
||||
|
||||
embeddings = model.encode(
|
||||
texts,
|
||||
batch_size=batch_size,
|
||||
show_progress_bar=False, # Don't show progress bar in server environment
|
||||
convert_to_numpy=True,
|
||||
normalize_embeddings=False, # Keep consistent with original API behavior
|
||||
device=device,
|
||||
)
|
||||
# Apply additional optimizations based on mode
|
||||
if use_fp16 and device in ["cuda", "mps"]:
|
||||
try:
|
||||
model = model.half()
|
||||
logger.info(f"Applied FP16 precision: {model_name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"FP16 optimization failed: {e}")
|
||||
|
||||
print(
|
||||
f"INFO: Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}"
|
||||
# Apply torch.compile optimization
|
||||
if device in ["cuda", "mps"]:
|
||||
try:
|
||||
model = torch.compile(model, mode="reduce-overhead", dynamic=True)
|
||||
logger.info(f"Applied torch.compile optimization: {model_name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"torch.compile optimization failed: {e}")
|
||||
|
||||
# Set model to eval mode and disable gradients for inference
|
||||
model.eval()
|
||||
for param in model.parameters():
|
||||
param.requires_grad_(False)
|
||||
|
||||
# Cache the model
|
||||
_model_cache[cache_key] = model
|
||||
logger.info(f"Model cached: {cache_key}")
|
||||
|
||||
# Compute embeddings with optimized inference mode
|
||||
logger.info(f"Starting embedding computation... (batch_size: {batch_size})")
|
||||
|
||||
# Use torch.inference_mode for optimal performance
|
||||
with torch.inference_mode():
|
||||
embeddings = model.encode(
|
||||
texts,
|
||||
batch_size=batch_size,
|
||||
show_progress_bar=is_build, # Don't show progress bar in server environment
|
||||
convert_to_numpy=True,
|
||||
normalize_embeddings=False,
|
||||
device=device,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}"
|
||||
)
|
||||
|
||||
# Validate results
|
||||
@@ -153,6 +240,7 @@ def compute_embeddings_sentence_transformers(
|
||||
|
||||
|
||||
def compute_embeddings_openai(texts: List[str], model_name: str) -> np.ndarray:
|
||||
# TODO: @yichuan-w add progress bar only in build mode
|
||||
"""Compute embeddings using OpenAI API"""
|
||||
try:
|
||||
import openai
|
||||
@@ -164,10 +252,17 @@ def compute_embeddings_openai(texts: List[str], model_name: str) -> np.ndarray:
|
||||
if not api_key:
|
||||
raise RuntimeError("OPENAI_API_KEY environment variable not set")
|
||||
|
||||
client = openai.OpenAI(api_key=api_key)
|
||||
# Cache OpenAI client
|
||||
cache_key = "openai_client"
|
||||
if cache_key in _model_cache:
|
||||
client = _model_cache[cache_key]
|
||||
else:
|
||||
client = openai.OpenAI(api_key=api_key)
|
||||
_model_cache[cache_key] = client
|
||||
logger.info("OpenAI client cached")
|
||||
|
||||
print(
|
||||
f"INFO: Computing embeddings for {len(texts)} texts using OpenAI API, model: '{model_name}'"
|
||||
logger.info(
|
||||
f"Computing embeddings for {len(texts)} texts using OpenAI API, model: '{model_name}'"
|
||||
)
|
||||
|
||||
# OpenAI has limits on batch size and input length
|
||||
@@ -194,12 +289,12 @@ def compute_embeddings_openai(texts: List[str], model_name: str) -> np.ndarray:
|
||||
batch_embeddings = [embedding.embedding for embedding in response.data]
|
||||
all_embeddings.extend(batch_embeddings)
|
||||
except Exception as e:
|
||||
print(f"ERROR: Batch {i} failed: {e}")
|
||||
logger.error(f"Batch {i} failed: {e}")
|
||||
raise
|
||||
|
||||
embeddings = np.array(all_embeddings, dtype=np.float32)
|
||||
print(
|
||||
f"INFO: Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}"
|
||||
logger.info(
|
||||
f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}"
|
||||
)
|
||||
return embeddings
|
||||
|
||||
@@ -207,22 +302,30 @@ def compute_embeddings_openai(texts: List[str], model_name: str) -> np.ndarray:
|
||||
def compute_embeddings_mlx(
|
||||
chunks: List[str], model_name: str, batch_size: int = 16
|
||||
) -> np.ndarray:
|
||||
# TODO: @yichuan-w add progress bar only in build mode
|
||||
"""Computes embeddings using an MLX model."""
|
||||
try:
|
||||
import mlx.core as mx
|
||||
from mlx_lm.utils import load
|
||||
from tqdm import tqdm
|
||||
except ImportError as e:
|
||||
raise RuntimeError(
|
||||
"MLX or related libraries not available. Install with: uv pip install mlx mlx-lm"
|
||||
) from e
|
||||
|
||||
print(
|
||||
f"INFO: Computing embeddings for {len(chunks)} chunks using MLX model '{model_name}' with batch_size={batch_size}..."
|
||||
logger.info(
|
||||
f"Computing embeddings for {len(chunks)} chunks using MLX model '{model_name}' with batch_size={batch_size}..."
|
||||
)
|
||||
|
||||
# Load model and tokenizer
|
||||
model, tokenizer = load(model_name)
|
||||
# Cache MLX model and tokenizer
|
||||
cache_key = f"mlx_{model_name}"
|
||||
if cache_key in _model_cache:
|
||||
logger.info(f"Using cached MLX model: {model_name}")
|
||||
model, tokenizer = _model_cache[cache_key]
|
||||
else:
|
||||
logger.info(f"Loading and caching MLX model: {model_name}")
|
||||
model, tokenizer = load(model_name)
|
||||
_model_cache[cache_key] = (model, tokenizer)
|
||||
logger.info(f"MLX model cached: {cache_key}")
|
||||
|
||||
# Process chunks in batches with progress bar
|
||||
all_embeddings = []
|
||||
|
||||
@@ -1,14 +1,22 @@
|
||||
import threading
|
||||
import time
|
||||
import atexit
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
import os
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
import select
|
||||
import psutil
|
||||
|
||||
# Set up logging based on environment variable
|
||||
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, LOG_LEVEL, logging.INFO),
|
||||
format="%(levelname)s - %(name)s - %(message)s",
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _check_port(port: int) -> bool:
|
||||
"""Check if a port is in use"""
|
||||
@@ -36,11 +44,11 @@ def _check_process_matches_config(
|
||||
cmdline, port, expected_model, expected_passages_file
|
||||
)
|
||||
|
||||
print(f"DEBUG: No process found listening on port {port}")
|
||||
logger.debug(f"No process found listening on port {port}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"WARNING: Could not check process on port {port}: {e}")
|
||||
logger.warning(f"Could not check process on port {port}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
@@ -61,7 +69,7 @@ def _check_cmdline_matches_config(
|
||||
) -> bool:
|
||||
"""Check if command line matches our expected configuration."""
|
||||
cmdline_str = " ".join(cmdline)
|
||||
print(f"DEBUG: Found process on port {port}: {cmdline_str}")
|
||||
logger.debug(f"Found process on port {port}: {cmdline_str}")
|
||||
|
||||
# Check if it's our embedding server
|
||||
is_embedding_server = any(
|
||||
@@ -74,7 +82,7 @@ def _check_cmdline_matches_config(
|
||||
)
|
||||
|
||||
if not is_embedding_server:
|
||||
print(f"DEBUG: Process on port {port} is not our embedding server")
|
||||
logger.debug(f"Process on port {port} is not our embedding server")
|
||||
return False
|
||||
|
||||
# Check model name
|
||||
@@ -84,8 +92,8 @@ def _check_cmdline_matches_config(
|
||||
passages_matches = _check_passages_in_cmdline(cmdline, expected_passages_file)
|
||||
|
||||
result = model_matches and passages_matches
|
||||
print(
|
||||
f"DEBUG: model_matches: {model_matches}, passages_matches: {passages_matches}, overall: {result}"
|
||||
logger.debug(
|
||||
f"model_matches: {model_matches}, passages_matches: {passages_matches}, overall: {result}"
|
||||
)
|
||||
return result
|
||||
|
||||
@@ -132,10 +140,10 @@ def _find_compatible_port_or_next_available(
|
||||
|
||||
# Port is in use, check if it's compatible
|
||||
if _check_process_matches_config(port, model_name, passages_file):
|
||||
print(f"✅ Found compatible server on port {port}")
|
||||
logger.info(f"Found compatible server on port {port}")
|
||||
return port, True
|
||||
else:
|
||||
print(f"⚠️ Port {port} has incompatible server, trying next port...")
|
||||
logger.info(f"Port {port} has incompatible server, trying next port...")
|
||||
|
||||
raise RuntimeError(
|
||||
f"Could not find compatible or available port in range {start_port}-{start_port + max_attempts}"
|
||||
@@ -194,17 +202,17 @@ class EmbeddingServerManager:
|
||||
port, model_name, passages_file
|
||||
)
|
||||
except RuntimeError as e:
|
||||
print(f"❌ {e}")
|
||||
logger.error(str(e))
|
||||
return False, port
|
||||
|
||||
if is_compatible:
|
||||
print(f"✅ Using existing compatible server on port {actual_port}")
|
||||
logger.info(f"Using existing compatible server on port {actual_port}")
|
||||
self.server_port = actual_port
|
||||
self.server_process = None # We don't own this process
|
||||
return True, actual_port
|
||||
|
||||
if actual_port != port:
|
||||
print(f"⚠️ Using port {actual_port} instead of {port}")
|
||||
logger.info(f"Using port {actual_port} instead of {port}")
|
||||
|
||||
# Start new server
|
||||
return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs)
|
||||
@@ -221,19 +229,21 @@ class EmbeddingServerManager:
|
||||
return False
|
||||
|
||||
if _check_process_matches_config(self.server_port, model_name, passages_file):
|
||||
print(
|
||||
f"✅ Existing server process (PID {self.server_process.pid}) is compatible"
|
||||
logger.info(
|
||||
f"Existing server process (PID {self.server_process.pid}) is compatible"
|
||||
)
|
||||
return True
|
||||
|
||||
print("⚠️ Existing server process is incompatible. Should start a new server.")
|
||||
logger.info(
|
||||
"Existing server process is incompatible. Should start a new server."
|
||||
)
|
||||
return False
|
||||
|
||||
def _start_new_server(
|
||||
self, port: int, model_name: str, embedding_mode: str, **kwargs
|
||||
) -> tuple[bool, int]:
|
||||
"""Start a new embedding server on the given port."""
|
||||
print(f"INFO: Starting embedding server on port {port}...")
|
||||
logger.info(f"Starting embedding server on port {port}...")
|
||||
|
||||
command = self._build_server_command(port, model_name, embedding_mode, **kwargs)
|
||||
|
||||
@@ -241,7 +251,7 @@ class EmbeddingServerManager:
|
||||
self._launch_server_process(command, port)
|
||||
return self._wait_for_server_ready(port)
|
||||
except Exception as e:
|
||||
print(f"❌ ERROR: Failed to start embedding server: {e}")
|
||||
logger.error(f"Failed to start embedding server: {e}")
|
||||
return False, port
|
||||
|
||||
def _build_server_command(
|
||||
@@ -268,20 +278,18 @@ class EmbeddingServerManager:
|
||||
def _launch_server_process(self, command: list, port: int) -> None:
|
||||
"""Launch the server process."""
|
||||
project_root = Path(__file__).parent.parent.parent.parent.parent
|
||||
print(f"INFO: Command: {' '.join(command)}")
|
||||
logger.info(f"Command: {' '.join(command)}")
|
||||
|
||||
# Let server output go directly to console
|
||||
# The server will respect LEANN_LOG_LEVEL environment variable
|
||||
self.server_process = subprocess.Popen(
|
||||
command,
|
||||
cwd=project_root,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
encoding="utf-8",
|
||||
bufsize=1,
|
||||
universal_newlines=True,
|
||||
stdout=None, # Direct to console
|
||||
stderr=None, # Direct to console
|
||||
)
|
||||
self.server_port = port
|
||||
print(f"INFO: Server process started with PID: {self.server_process.pid}")
|
||||
logger.info(f"Server process started with PID: {self.server_process.pid}")
|
||||
|
||||
# Register atexit callback only when we actually start a process
|
||||
if not self._atexit_registered:
|
||||
@@ -294,49 +302,19 @@ class EmbeddingServerManager:
|
||||
max_wait, wait_interval = 120, 0.5
|
||||
for _ in range(int(max_wait / wait_interval)):
|
||||
if _check_port(port):
|
||||
print("✅ Embedding server is ready!")
|
||||
threading.Thread(target=self._log_monitor, daemon=True).start()
|
||||
logger.info("Embedding server is ready!")
|
||||
return True, port
|
||||
|
||||
if self.server_process.poll() is not None:
|
||||
print("❌ ERROR: Server terminated during startup.")
|
||||
self._print_recent_output()
|
||||
if self.server_process and self.server_process.poll() is not None:
|
||||
logger.error("Server terminated during startup.")
|
||||
return False, port
|
||||
|
||||
time.sleep(wait_interval)
|
||||
|
||||
print(f"❌ ERROR: Server failed to start within {max_wait} seconds.")
|
||||
logger.error(f"Server failed to start within {max_wait} seconds.")
|
||||
self.stop_server()
|
||||
return False, port
|
||||
|
||||
def _print_recent_output(self):
|
||||
"""Print any recent output from the server process."""
|
||||
if not self.server_process or not self.server_process.stdout:
|
||||
return
|
||||
try:
|
||||
if select.select([self.server_process.stdout], [], [], 0)[0]:
|
||||
output = self.server_process.stdout.read()
|
||||
if output:
|
||||
print(f"[{self.backend_module_name} OUTPUT]: {output}")
|
||||
except Exception as e:
|
||||
print(f"Error reading server output: {e}")
|
||||
|
||||
def _log_monitor(self):
|
||||
"""Monitors and prints the server's stdout and stderr."""
|
||||
if not self.server_process:
|
||||
return
|
||||
try:
|
||||
if self.server_process.stdout:
|
||||
while True:
|
||||
line = self.server_process.stdout.readline()
|
||||
if not line:
|
||||
break
|
||||
print(
|
||||
f"[{self.backend_module_name} LOG]: {line.strip()}", flush=True
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Log monitor error: {e}")
|
||||
|
||||
def stop_server(self):
|
||||
"""Stops the embedding server process if it's running."""
|
||||
if not self.server_process:
|
||||
@@ -347,18 +325,24 @@ class EmbeddingServerManager:
|
||||
self.server_process = None
|
||||
return
|
||||
|
||||
print(
|
||||
f"INFO: Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}..."
|
||||
logger.info(
|
||||
f"Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}..."
|
||||
)
|
||||
self.server_process.terminate()
|
||||
|
||||
try:
|
||||
self.server_process.wait(timeout=5)
|
||||
print(f"INFO: Server process {self.server_process.pid} terminated.")
|
||||
logger.info(f"Server process {self.server_process.pid} terminated.")
|
||||
except subprocess.TimeoutExpired:
|
||||
print(
|
||||
f"WARNING: Server process {self.server_process.pid} did not terminate gracefully, killing it."
|
||||
logger.warning(
|
||||
f"Server process {self.server_process.pid} did not terminate gracefully, killing it."
|
||||
)
|
||||
self.server_process.kill()
|
||||
|
||||
# Clean up process resources to prevent resource tracker warnings
|
||||
try:
|
||||
self.server_process.wait() # Ensure process is fully cleaned up
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
self.server_process = None
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import numpy as np
|
||||
from typing import Dict, Any, List, Literal
|
||||
from typing import Dict, Any, List, Literal, Optional
|
||||
|
||||
|
||||
class LeannBackendBuilderInterface(ABC):
|
||||
@@ -34,6 +34,13 @@ class LeannBackendSearcherInterface(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _ensure_server_running(
|
||||
self, passages_source_file: str, port: Optional[int], **kwargs
|
||||
) -> int:
|
||||
"""Ensure server is running"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search(
|
||||
self,
|
||||
@@ -44,7 +51,7 @@ class LeannBackendSearcherInterface(ABC):
|
||||
prune_ratio: float = 0.0,
|
||||
recompute_embeddings: bool = False,
|
||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||
zmq_port: int = 5557,
|
||||
zmq_port: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
"""Search for nearest neighbors
|
||||
@@ -57,7 +64,7 @@ class LeannBackendSearcherInterface(ABC):
|
||||
prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
|
||||
recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored PQ codes
|
||||
pruning_strategy: PQ candidate selection strategy - "global" (default), "local", or "proportional"
|
||||
zmq_port: ZMQ port for embedding server communication
|
||||
zmq_port: ZMQ port for embedding server communication. Must be provided if recompute_embeddings is True.
|
||||
**kwargs: Backend-specific parameters
|
||||
|
||||
Returns:
|
||||
@@ -67,7 +74,10 @@ class LeannBackendSearcherInterface(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def compute_query_embedding(
|
||||
self, query: str, zmq_port: int = 5557, use_server_if_available: bool = True
|
||||
self,
|
||||
query: str,
|
||||
use_server_if_available: bool = True,
|
||||
zmq_port: Optional[int] = None,
|
||||
) -> np.ndarray:
|
||||
"""Compute embedding for a query string
|
||||
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import json
|
||||
import pickle
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Literal
|
||||
from typing import Dict, Any, Literal, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -43,8 +42,10 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
"WARNING: embedding_model not found in meta.json. Recompute will fail."
|
||||
)
|
||||
|
||||
self.embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
|
||||
|
||||
self.embedding_server_manager = EmbeddingServerManager(
|
||||
backend_module_name=backend_module_name
|
||||
backend_module_name=backend_module_name,
|
||||
)
|
||||
|
||||
def _load_meta(self) -> Dict[str, Any]:
|
||||
@@ -68,14 +69,12 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
"Cannot use recompute mode without 'embedding_model' in meta.json."
|
||||
)
|
||||
|
||||
embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
|
||||
|
||||
server_started, actual_port = self.embedding_server_manager.start_server(
|
||||
port=port,
|
||||
model_name=self.embedding_model,
|
||||
embedding_mode=self.embedding_mode,
|
||||
passages_file=passages_source_file,
|
||||
distance_metric=kwargs.get("distance_metric"),
|
||||
embedding_mode=embedding_mode,
|
||||
enable_warmup=kwargs.get("enable_warmup", False),
|
||||
)
|
||||
if not server_started:
|
||||
@@ -86,7 +85,10 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
return actual_port
|
||||
|
||||
def compute_query_embedding(
|
||||
self, query: str, zmq_port: int = 5557, use_server_if_available: bool = True
|
||||
self,
|
||||
query: str,
|
||||
use_server_if_available: bool = True,
|
||||
zmq_port: int = 5557,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Compute embedding for a query string.
|
||||
@@ -102,6 +104,10 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
# Try to use embedding server if available and requested
|
||||
if use_server_if_available:
|
||||
try:
|
||||
# TODO: Maybe we can directly use this port here?
|
||||
# For this internal method, it's ok to assume that the server is running
|
||||
# on that port?
|
||||
|
||||
# Ensure we have a server with passages_file for compatibility
|
||||
passages_source_file = (
|
||||
self.index_dir / f"{self.index_path.name}.meta.json"
|
||||
@@ -118,7 +124,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
print("⏭️ Falling back to direct model loading...")
|
||||
|
||||
# Fallback to direct computation
|
||||
from .api import compute_embeddings
|
||||
from .embedding_compute import compute_embeddings
|
||||
|
||||
embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
|
||||
return compute_embeddings([query], self.embedding_model, embedding_mode)
|
||||
@@ -165,7 +171,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
prune_ratio: float = 0.0,
|
||||
recompute_embeddings: bool = False,
|
||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||
zmq_port: int = 5557,
|
||||
zmq_port: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
@@ -179,7 +185,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
|
||||
recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored PQ codes
|
||||
pruning_strategy: PQ candidate selection strategy - "global" (default), "local", or "proportional"
|
||||
zmq_port: ZMQ port for embedding server communication
|
||||
zmq_port: ZMQ port for embedding server communication. Must be provided if recompute_embeddings is True.
|
||||
**kwargs: Backend-specific parameters (e.g., batch_size, dedup_node_dis, etc.)
|
||||
|
||||
Returns:
|
||||
|
||||
40
packages/leann/README.md
Normal file
40
packages/leann/README.md
Normal file
@@ -0,0 +1,40 @@
|
||||
# LEANN - The smallest vector index in the world
|
||||
|
||||
LEANN is a revolutionary vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **97% less storage** than traditional solutions **without accuracy loss**.
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
# Default installation (HNSW backend, recommended)
|
||||
uv pip install leann
|
||||
|
||||
# With DiskANN backend (for large-scale deployments)
|
||||
uv pip install leann[diskann]
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
```python
|
||||
from leann import LeannBuilder, LeannSearcher, LeannChat
|
||||
|
||||
# Build an index
|
||||
builder = LeannBuilder(backend_name="hnsw")
|
||||
builder.add_text("LEANN saves 97% storage compared to traditional vector databases.")
|
||||
builder.build_index("my_index.leann")
|
||||
|
||||
# Search
|
||||
searcher = LeannSearcher("my_index.leann")
|
||||
results = searcher.search("storage savings", top_k=3)
|
||||
|
||||
# Chat with your data
|
||||
chat = LeannChat("my_index.leann", llm_config={"type": "ollama", "model": "llama3.2:1b"})
|
||||
response = chat.ask("How much storage does LEANN save?")
|
||||
```
|
||||
|
||||
## Documentation
|
||||
|
||||
For full documentation, visit [https://leann.readthedocs.io](https://leann.readthedocs.io)
|
||||
|
||||
## License
|
||||
|
||||
MIT License
|
||||
12
packages/leann/__init__.py
Normal file
12
packages/leann/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""
|
||||
LEANN - Low-storage Embedding Approximation for Neural Networks
|
||||
|
||||
A revolutionary vector database that democratizes personal AI.
|
||||
"""
|
||||
|
||||
__version__ = "0.1.0"
|
||||
|
||||
# Re-export main API from leann-core
|
||||
from leann_core import LeannBuilder, LeannSearcher, LeannChat
|
||||
|
||||
__all__ = ["LeannBuilder", "LeannSearcher", "LeannChat"]
|
||||
42
packages/leann/pyproject.toml
Normal file
42
packages/leann/pyproject.toml
Normal file
@@ -0,0 +1,42 @@
|
||||
[build-system]
|
||||
requires = ["setuptools>=61.0"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "leann"
|
||||
version = "0.1.0"
|
||||
description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.9"
|
||||
license = { text = "MIT" }
|
||||
authors = [
|
||||
{ name = "LEANN Team" }
|
||||
]
|
||||
keywords = ["vector-database", "rag", "embeddings", "search", "ai"]
|
||||
classifiers = [
|
||||
"Development Status :: 4 - Beta",
|
||||
"Intended Audience :: Developers",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
]
|
||||
|
||||
# Default installation: core + hnsw
|
||||
dependencies = [
|
||||
"leann-core>=0.1.0",
|
||||
"leann-backend-hnsw>=0.1.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
diskann = [
|
||||
"leann-backend-diskann>=0.1.0",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://github.com/yourusername/leann"
|
||||
Documentation = "https://leann.readthedocs.io"
|
||||
Repository = "https://github.com/yourusername/leann"
|
||||
Issues = "https://github.com/yourusername/leann/issues"
|
||||
@@ -33,8 +33,8 @@ dependencies = [
|
||||
"msgpack>=1.1.1",
|
||||
"llama-index-vector-stores-faiss>=0.4.0",
|
||||
"llama-index-embeddings-huggingface>=0.5.5",
|
||||
"mlx>=0.26.3",
|
||||
"mlx-lm>=0.26.0",
|
||||
"mlx>=0.26.3; sys_platform == 'darwin'",
|
||||
"mlx-lm>=0.26.0; sys_platform == 'darwin'",
|
||||
"psutil>=5.8.0",
|
||||
]
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ else:
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ",
|
||||
use_mlx=True,
|
||||
embedding_mode="mlx",
|
||||
)
|
||||
|
||||
# 2. Add documents
|
||||
|
||||
Reference in New Issue
Block a user