Compare commits
67 Commits
v0.3.4
...
feature/cu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9996c29618 | ||
|
|
12951ad4d5 | ||
|
|
a878d2459b | ||
|
|
6c39a3427f | ||
|
|
17cbd07b25 | ||
|
|
3629ccf8f7 | ||
|
|
a0bbf831db | ||
|
|
76cc798e3e | ||
|
|
d599566fd7 | ||
|
|
00770aebbb | ||
|
|
e268392d5b | ||
|
|
eb909ccec5 | ||
|
|
969f514564 | ||
|
|
1ef9cba7de | ||
|
|
a63550944b | ||
|
|
97493a2896 | ||
|
|
f7d2dc6e7c | ||
|
|
ea86b283cb | ||
|
|
e7519bceaa | ||
|
|
abf0b2c676 | ||
|
|
3c4785bb63 | ||
|
|
930b79cc98 | ||
|
|
3766ad1fd2 | ||
|
|
c3aceed1e0 | ||
|
|
dc6c9f696e | ||
|
|
2406c41eef | ||
|
|
d4f5f2896f | ||
|
|
366984e92e | ||
|
|
64b92a04a7 | ||
|
|
a85d0ad4a7 | ||
|
|
dbb5f4d352 | ||
|
|
f180b83589 | ||
|
|
abf312d998 | ||
|
|
ab251ab751 | ||
|
|
28085f6f04 | ||
|
|
6495833887 | ||
|
|
5543b3c5f7 | ||
|
|
a99983b3d9 | ||
|
|
36482e016c | ||
|
|
32967daf81 | ||
|
|
b4bb8dec75 | ||
|
|
5ba9cf6442 | ||
|
|
1484406a8d | ||
|
|
761ec1f0ac | ||
|
|
4808afc686 | ||
|
|
0bba4b2157 | ||
|
|
e67b5f44fa | ||
|
|
658bce47ef | ||
|
|
6b399ad8d2 | ||
|
|
16f35aa067 | ||
|
|
ab9c6bd69e | ||
|
|
e2b37914ce | ||
|
|
e588100674 | ||
|
|
fecee94af1 | ||
|
|
01475c10a0 | ||
|
|
c8aa063f48 | ||
|
|
576beb13db | ||
|
|
63c7b0c8a3 | ||
|
|
ec889f7ef4 | ||
|
|
322e5c162d | ||
|
|
edde0cdeb2 | ||
|
|
db7ba27ff6 | ||
|
|
5f7806e16f | ||
|
|
d034e2195b | ||
|
|
43894ff605 | ||
|
|
10311cc611 | ||
|
|
ad0d2faabc |
50
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
Normal file
@@ -0,0 +1,50 @@
|
||||
name: Bug Report
|
||||
description: Report a bug in LEANN
|
||||
labels: ["bug"]
|
||||
|
||||
body:
|
||||
- type: textarea
|
||||
id: description
|
||||
attributes:
|
||||
label: What happened?
|
||||
description: A clear description of the bug
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: reproduce
|
||||
attributes:
|
||||
label: How to reproduce
|
||||
placeholder: |
|
||||
1. Install with...
|
||||
2. Run command...
|
||||
3. See error
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: error
|
||||
attributes:
|
||||
label: Error message
|
||||
description: Paste any error messages
|
||||
render: shell
|
||||
|
||||
- type: input
|
||||
id: version
|
||||
attributes:
|
||||
label: LEANN Version
|
||||
placeholder: "0.1.0"
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: dropdown
|
||||
id: os
|
||||
attributes:
|
||||
label: Operating System
|
||||
options:
|
||||
- macOS
|
||||
- Linux
|
||||
- Windows
|
||||
- Docker
|
||||
validations:
|
||||
required: true
|
||||
8
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
@@ -0,0 +1,8 @@
|
||||
blank_issues_enabled: true
|
||||
contact_links:
|
||||
- name: Documentation
|
||||
url: https://github.com/LEANN-RAG/LEANN-RAG/tree/main/docs
|
||||
about: Read the docs first
|
||||
- name: Discussions
|
||||
url: https://github.com/LEANN-RAG/LEANN-RAG/discussions
|
||||
about: Ask questions and share ideas
|
||||
27
.github/ISSUE_TEMPLATE/feature_request.yml
vendored
Normal file
@@ -0,0 +1,27 @@
|
||||
name: Feature Request
|
||||
description: Suggest a new feature for LEANN
|
||||
labels: ["enhancement"]
|
||||
|
||||
body:
|
||||
- type: textarea
|
||||
id: problem
|
||||
attributes:
|
||||
label: What problem does this solve?
|
||||
description: Describe the problem or need
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: solution
|
||||
attributes:
|
||||
label: Proposed solution
|
||||
description: How would you like this to work?
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: example
|
||||
attributes:
|
||||
label: Example usage
|
||||
description: Show how the API might look
|
||||
render: python
|
||||
13
.github/pull_request_template.md
vendored
Normal file
@@ -0,0 +1,13 @@
|
||||
## What does this PR do?
|
||||
|
||||
<!-- Brief description of your changes -->
|
||||
|
||||
## Related Issues
|
||||
|
||||
Fixes #
|
||||
|
||||
## Checklist
|
||||
|
||||
- [ ] Tests pass (`uv run pytest`)
|
||||
- [ ] Code formatted (`ruff format` and `ruff check`)
|
||||
- [ ] Pre-commit hooks pass (`pre-commit run --all-files`)
|
||||
111
.github/workflows/build-reusable.yml
vendored
@@ -17,26 +17,17 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ inputs.ref }}
|
||||
submodules: recursive
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
- name: Install uv and Python
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v4
|
||||
|
||||
- name: Install ruff
|
||||
- name: Run pre-commit with only lint group (no project deps)
|
||||
run: |
|
||||
uv tool install ruff
|
||||
uv run --only-group lint pre-commit run --all-files --show-diff-on-failure
|
||||
|
||||
- name: Run ruff check
|
||||
run: |
|
||||
ruff check .
|
||||
|
||||
- name: Run ruff format check
|
||||
run: |
|
||||
ruff format --check .
|
||||
|
||||
build:
|
||||
needs: lint
|
||||
@@ -103,14 +94,11 @@ jobs:
|
||||
ref: ${{ inputs.ref }}
|
||||
submodules: recursive
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
- name: Install uv and Python
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
python-version: ${{ matrix.python }}
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
|
||||
- name: Install system dependencies (Ubuntu)
|
||||
if: runner.os == 'Linux'
|
||||
run: |
|
||||
@@ -168,11 +156,24 @@ jobs:
|
||||
|
||||
- name: Install build dependencies
|
||||
run: |
|
||||
uv pip install --system scikit-build-core numpy swig Cython pybind11
|
||||
if [[ "$RUNNER_OS" == "Linux" ]]; then
|
||||
uv pip install --system auditwheel
|
||||
uv python install ${{ matrix.python }}
|
||||
uv venv --python ${{ matrix.python }} .uv-build
|
||||
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
||||
BUILD_PY=".uv-build\\Scripts\\python.exe"
|
||||
else
|
||||
uv pip install --system delocate
|
||||
BUILD_PY=".uv-build/bin/python"
|
||||
fi
|
||||
uv pip install --python "$BUILD_PY" scikit-build-core numpy swig Cython pybind11
|
||||
if [[ "$RUNNER_OS" == "Linux" ]]; then
|
||||
uv pip install --python "$BUILD_PY" auditwheel
|
||||
else
|
||||
uv pip install --python "$BUILD_PY" delocate
|
||||
fi
|
||||
|
||||
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
||||
echo "$(pwd)\\.uv-build\\Scripts" >> $GITHUB_PATH
|
||||
else
|
||||
echo "$(pwd)/.uv-build/bin" >> $GITHUB_PATH
|
||||
fi
|
||||
|
||||
- name: Set macOS environment variables
|
||||
@@ -308,18 +309,66 @@ jobs:
|
||||
|
||||
- name: Install built packages for testing
|
||||
run: |
|
||||
# Create a virtual environment with the correct Python version
|
||||
# Create uv-managed virtual environment with the requested interpreter
|
||||
uv python install ${{ matrix.python }}
|
||||
uv venv --python ${{ matrix.python }}
|
||||
source .venv/bin/activate || source .venv/Scripts/activate
|
||||
|
||||
# Install packages using --find-links to prioritize local builds
|
||||
uv pip install --find-links packages/leann-core/dist --find-links packages/leann-backend-hnsw/dist --find-links packages/leann-backend-diskann/dist packages/leann-core/dist/*.whl || uv pip install --find-links packages/leann-core/dist packages/leann-core/dist/*.tar.gz
|
||||
uv pip install --find-links packages/leann-core/dist packages/leann-backend-hnsw/dist/*.whl
|
||||
uv pip install --find-links packages/leann-core/dist packages/leann-backend-diskann/dist/*.whl
|
||||
uv pip install packages/leann/dist/*.whl || uv pip install packages/leann/dist/*.tar.gz
|
||||
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
||||
UV_PY=".venv\\Scripts\\python.exe"
|
||||
else
|
||||
UV_PY=".venv/bin/python"
|
||||
fi
|
||||
|
||||
# Install test dependencies using extras
|
||||
uv pip install -e ".[test]"
|
||||
# Install test dependency group only (avoids reinstalling project package)
|
||||
uv pip install --python "$UV_PY" --group test
|
||||
|
||||
# Install core wheel built in this job
|
||||
CORE_WHL=$(find packages/leann-core/dist -maxdepth 1 -name "*.whl" -print -quit)
|
||||
if [[ -n "$CORE_WHL" ]]; then
|
||||
uv pip install --python "$UV_PY" "$CORE_WHL"
|
||||
else
|
||||
uv pip install --python "$UV_PY" packages/leann-core/dist/*.tar.gz
|
||||
fi
|
||||
|
||||
PY_TAG=$($UV_PY -c "import sys; print(f'cp{sys.version_info[0]}{sys.version_info[1]}')")
|
||||
|
||||
if [[ "$RUNNER_OS" == "macOS" ]]; then
|
||||
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=13.3
|
||||
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=14.0
|
||||
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=15.0
|
||||
fi
|
||||
fi
|
||||
|
||||
HNSW_WHL=$(find packages/leann-backend-hnsw/dist -maxdepth 1 -name "*-${PY_TAG}-*.whl" -print -quit)
|
||||
if [[ -z "$HNSW_WHL" ]]; then
|
||||
HNSW_WHL=$(find packages/leann-backend-hnsw/dist -maxdepth 1 -name "*-py3-*.whl" -print -quit)
|
||||
fi
|
||||
if [[ -n "$HNSW_WHL" ]]; then
|
||||
uv pip install --python "$UV_PY" "$HNSW_WHL"
|
||||
else
|
||||
uv pip install --python "$UV_PY" ./packages/leann-backend-hnsw
|
||||
fi
|
||||
|
||||
DISKANN_WHL=$(find packages/leann-backend-diskann/dist -maxdepth 1 -name "*-${PY_TAG}-*.whl" -print -quit)
|
||||
if [[ -z "$DISKANN_WHL" ]]; then
|
||||
DISKANN_WHL=$(find packages/leann-backend-diskann/dist -maxdepth 1 -name "*-py3-*.whl" -print -quit)
|
||||
fi
|
||||
if [[ -n "$DISKANN_WHL" ]]; then
|
||||
uv pip install --python "$UV_PY" "$DISKANN_WHL"
|
||||
else
|
||||
uv pip install --python "$UV_PY" ./packages/leann-backend-diskann
|
||||
fi
|
||||
|
||||
LEANN_WHL=$(find packages/leann/dist -maxdepth 1 -name "*.whl" -print -quit)
|
||||
if [[ -n "$LEANN_WHL" ]]; then
|
||||
uv pip install --python "$UV_PY" "$LEANN_WHL"
|
||||
else
|
||||
uv pip install --python "$UV_PY" packages/leann/dist/*.tar.gz
|
||||
fi
|
||||
|
||||
- name: Run tests with pytest
|
||||
env:
|
||||
|
||||
2
.github/workflows/link-check.yml
vendored
@@ -14,6 +14,6 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: lycheeverse/lychee-action@v2
|
||||
with:
|
||||
args: --no-progress --insecure --user-agent 'curl/7.68.0' README.md docs/ apps/ examples/ benchmarks/
|
||||
args: --no-progress --insecure --user-agent 'curl/7.68.0' --exclude '.*api\.star-history\.com.*' --accept 200,201,202,203,204,205,206,207,208,226,300,301,302,303,304,305,306,307,308,503 README.md docs/ apps/ examples/ benchmarks/
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
22
.gitignore
vendored
@@ -18,6 +18,7 @@ demo/experiment_results/**/*.json
|
||||
*.eml
|
||||
*.emlx
|
||||
*.json
|
||||
*.png
|
||||
!.vscode/*.json
|
||||
*.sh
|
||||
*.txt
|
||||
@@ -90,14 +91,21 @@ packages/leann-backend-diskann/third_party/DiskANN/_deps/
|
||||
|
||||
*.meta.json
|
||||
*.passages.json
|
||||
|
||||
*.npy
|
||||
*.db
|
||||
batchtest.py
|
||||
tests/__pytest_cache__/
|
||||
tests/__pycache__/
|
||||
paru-bin/
|
||||
|
||||
CLAUDE.md
|
||||
CLAUDE.local.md
|
||||
.claude/*.local.*
|
||||
.claude/local/*
|
||||
benchmarks/data/
|
||||
|
||||
## multi vector
|
||||
apps/multimodal/vision-based-pdf-multi-vector/multi-vector-colpali-native-weaviate.py
|
||||
|
||||
# Ignore all PDFs (keep data exceptions above) and do not track demo PDFs
|
||||
# If you need to commit a specific demo PDF, remove this negation locally.
|
||||
# The following line used to force-add a large demo PDF; remove it to satisfy pre-commit:
|
||||
# !apps/multimodal/vision-based-pdf-multi-vector/pdfs/2004.12832v2.pdf
|
||||
!apps/multimodal/vision-based-pdf-multi-vector/fig/*
|
||||
|
||||
# AUR build directory (Arch Linux)
|
||||
paru-bin/
|
||||
|
||||
3
.gitmodules
vendored
@@ -16,5 +16,4 @@
|
||||
url = https://github.com/zeromq/libzmq.git
|
||||
[submodule "packages/astchunk-leann"]
|
||||
path = packages/astchunk-leann
|
||||
url = git@github.com:yichuan-w/astchunk-leann.git
|
||||
branch = main
|
||||
url = https://github.com/yichuan-w/astchunk-leann.git
|
||||
|
||||
499
README.md
@@ -8,19 +8,35 @@
|
||||
<img src="https://img.shields.io/badge/Platform-Ubuntu%20%26%20Arch%20%26%20WSL%20%7C%20macOS%20(ARM64%2FIntel)-lightgrey" alt="Platform">
|
||||
<img src="https://img.shields.io/badge/License-MIT-green.svg" alt="MIT License">
|
||||
<img src="https://img.shields.io/badge/MCP-Native%20Integration-blue" alt="MCP Integration">
|
||||
<a href="https://join.slack.com/t/leann-e2u9779/shared_invite/zt-3ckd2f6w1-OX08~NN4gkWhh10PRVBj1Q"><img src="https://img.shields.io/badge/Slack-Join-4A154B?logo=slack&logoColor=white" alt="Join Slack">
|
||||
<a href="assets/wechat_user_group.JPG" title="Join WeChat group"><img src="https://img.shields.io/badge/WeChat-Join-2DC100?logo=wechat&logoColor=white" alt="Join WeChat group"></a>
|
||||
<a href="https://join.slack.com/t/leann-e2u9779/shared_invite/zt-3ckd2f6w1-OX08~NN4gkWhh10PRVBj1Q">
|
||||
<img src="https://img.shields.io/badge/Slack-Join-4A154B?logo=slack&logoColor=white" alt="Join Slack">
|
||||
</a>
|
||||
<a href="assets/wechat_user_group.JPG" title="Join WeChat group">
|
||||
<img src="https://img.shields.io/badge/WeChat-Join-2DC100?logo=wechat&logoColor=white" alt="Join WeChat group">
|
||||
</a>
|
||||
</p>
|
||||
|
||||
<div align="center">
|
||||
<a href="https://forms.gle/rDbZf864gMNxhpTq8">
|
||||
<img src="https://img.shields.io/badge/📣_Community_Survey-Help_Shape_v0.4-007ec6?style=for-the-badge&logo=google-forms&logoColor=white" alt="Take Survey">
|
||||
</a>
|
||||
<p>
|
||||
We track <b>zero telemetry</b>. This survey is the ONLY way to tell us if you want <br>
|
||||
<b>GPU Acceleration</b> or <b>More Integrations</b> next.<br>
|
||||
👉 <a href="https://forms.gle/rDbZf864gMNxhpTq8"><b>Click here to cast your vote (2 mins)</b></a>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<h2 align="center" tabindex="-1" class="heading-element" dir="auto">
|
||||
The smallest vector index in the world. RAG Everything with LEANN!
|
||||
</h2>
|
||||
|
||||
LEANN is an innovative vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **97% less storage** than traditional solutions **without accuracy loss**.
|
||||
|
||||
|
||||
LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration Fig →](#️-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
|
||||
|
||||
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can semantic search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)**, **[codebase](#-claude-code-integration-transform-your-development-workflow)**\* , or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
|
||||
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can semantic search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)** ([WeChat](#-wechat-detective-unlock-your-golden-memories), [iMessage](#-imessage-history-your-personal-conversation-archive)), **[agent memory](#-chatgpt-chat-history-your-personal-ai-conversation-archive)** ([ChatGPT](#-chatgpt-chat-history-your-personal-ai-conversation-archive), [Claude](#-claude-chat-history-your-personal-ai-conversation-archive)), **[live data](#mcp-integration-rag-on-live-data-from-any-platform)** ([Slack](#mcp-integration-rag-on-live-data-from-any-platform), [Twitter](#mcp-integration-rag-on-live-data-from-any-platform)), **[codebase](#-claude-code-integration-transform-your-development-workflow)**\* , or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
|
||||
|
||||
|
||||
\* Claude Code only supports basic `grep`-style keyword search. **LEANN** is a drop-in **semantic search MCP service fully compatible with Claude Code**, unlocking intelligent retrieval without changing your workflow. 🔥 Check out [the easy setup →](packages/leann-mcp/README.md)
|
||||
@@ -72,8 +88,9 @@ uv venv
|
||||
source .venv/bin/activate
|
||||
uv pip install leann
|
||||
```
|
||||
|
||||
<!--
|
||||
> Low-resource? See “Low-resource setups” in the [Configuration Guide](docs/configuration-guide.md#low-resource-setups). -->
|
||||
> Low-resource? See "Low-resource setups" in the [Configuration Guide](docs/configuration-guide.md#low-resource-setups). -->
|
||||
|
||||
<details>
|
||||
<summary>
|
||||
@@ -176,13 +193,16 @@ response = chat.ask("How much storage does LEANN save?", top_k=1)
|
||||
|
||||
## RAG on Everything!
|
||||
|
||||
LEANN supports RAG on various data sources including documents (`.pdf`, `.txt`, `.md`), Apple Mail, Google Search History, WeChat, and more.
|
||||
LEANN supports RAG on various data sources including documents (`.pdf`, `.txt`, `.md`), Apple Mail, Google Search History, WeChat, ChatGPT conversations, Claude conversations, iMessage conversations, and **live data from any platform through MCP (Model Context Protocol) servers** - including Slack, Twitter, and more.
|
||||
|
||||
|
||||
|
||||
### Generation Model Setup
|
||||
|
||||
LEANN supports multiple LLM providers for text generation (OpenAI API, HuggingFace, Ollama).
|
||||
#### LLM Backend
|
||||
|
||||
LEANN supports many LLM providers for text generation (HuggingFace, Ollama, Anthropic, and Any OpenAI compatible API).
|
||||
|
||||
|
||||
<details>
|
||||
<summary><strong>🔑 OpenAI API Setup (Default)</strong></summary>
|
||||
@@ -193,6 +213,69 @@ Set your OpenAI API key as an environment variable:
|
||||
export OPENAI_API_KEY="your-api-key-here"
|
||||
```
|
||||
|
||||
Make sure to use `--llm openai` flag when using the CLI.
|
||||
You can also specify the model name with `--llm-model <model-name>` flag.
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><strong>🛠️ Supported LLM & Embedding Providers (via OpenAI Compatibility)</strong></summary>
|
||||
|
||||
Thanks to the widespread adoption of the OpenAI API format, LEANN is compatible out-of-the-box with a vast array of LLM and embedding providers. Simply set the `OPENAI_BASE_URL` and `OPENAI_API_KEY` environment variables to connect to your preferred service.
|
||||
|
||||
```sh
|
||||
export OPENAI_API_KEY="xxx"
|
||||
export OPENAI_BASE_URL="http://localhost:1234/v1" # base url of the provider
|
||||
```
|
||||
|
||||
To use OpenAI compatible endpoint with the CLI interface:
|
||||
|
||||
If you are using it for text generation, make sure to use `--llm openai` flag and specify the model name with `--llm-model <model-name>` flag.
|
||||
|
||||
If you are using it for embedding, set the `--embedding-mode openai` flag and specify the model name with `--embedding-model <MODEL>`.
|
||||
|
||||
-----
|
||||
|
||||
|
||||
Below is a list of base URLs for common providers to get you started.
|
||||
|
||||
|
||||
### 🖥️ Local Inference Engines (Recommended for full privacy)
|
||||
|
||||
| Provider | Sample Base URL |
|
||||
| ---------------- | --------------------------- |
|
||||
| **Ollama** | `http://localhost:11434/v1` |
|
||||
| **LM Studio** | `http://localhost:1234/v1` |
|
||||
| **vLLM** | `http://localhost:8000/v1` |
|
||||
| **llama.cpp** | `http://localhost:8080/v1` |
|
||||
| **SGLang** | `http://localhost:30000/v1` |
|
||||
| **LiteLLM** | `http://localhost:4000` |
|
||||
|
||||
-----
|
||||
|
||||
### ☁️ Cloud Providers
|
||||
|
||||
> **🚨 A Note on Privacy:** Before choosing a cloud provider, carefully review their privacy and data retention policies. Depending on their terms, your data may be used for their own purposes, including but not limited to human reviews and model training, which can lead to serious consequences if not handled properly.
|
||||
|
||||
|
||||
| Provider | Base URL |
|
||||
| ---------------- | ---------------------------------------------------------- |
|
||||
| **OpenAI** | `https://api.openai.com/v1` |
|
||||
| **OpenRouter** | `https://openrouter.ai/api/v1` |
|
||||
| **Gemini** | `https://generativelanguage.googleapis.com/v1beta/openai/` |
|
||||
| **x.AI (Grok)** | `https://api.x.ai/v1` |
|
||||
| **Groq AI** | `https://api.groq.com/openai/v1` |
|
||||
| **DeepSeek** | `https://api.deepseek.com/v1` |
|
||||
| **SiliconFlow** | `https://api.siliconflow.cn/v1` |
|
||||
| **Zhipu (BigModel)** | `https://open.bigmodel.cn/api/paas/v4/` |
|
||||
| **Mistral AI** | `https://api.mistral.ai/v1` |
|
||||
| **Anthropic** | `https://api.anthropic.com/v1` |
|
||||
|
||||
|
||||
|
||||
|
||||
If your provider isn't on this list, don't worry! Check their documentation for an OpenAI-compatible endpoint—chances are, it's OpenAI Compatible too!
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
@@ -246,7 +329,7 @@ All RAG examples share these common parameters. **Interactive mode** is availabl
|
||||
--embedding-mode MODE # sentence-transformers, openai, mlx, or ollama
|
||||
|
||||
# LLM Parameters (Text generation models)
|
||||
--llm TYPE # LLM backend: openai, ollama, or hf (default: openai)
|
||||
--llm TYPE # LLM backend: openai, ollama, hf, or anthropic (default: openai)
|
||||
--llm-model MODEL # Model name (default: gpt-4o) e.g., gpt-4o-mini, llama3.2:1b, Qwen/Qwen2.5-1.5B-Instruct
|
||||
--thinking-budget LEVEL # Thinking budget for reasoning models: low/medium/high (supported by o3, o3-mini, GPT-Oss:20b, and other reasoning models)
|
||||
|
||||
@@ -477,10 +560,386 @@ Once the index is built, you can ask questions like:
|
||||
|
||||
</details>
|
||||
|
||||
### 🤖 ChatGPT Chat History: Your Personal AI Conversation Archive!
|
||||
|
||||
Transform your ChatGPT conversations into a searchable knowledge base! Search through all your ChatGPT discussions about coding, research, brainstorming, and more.
|
||||
|
||||
```bash
|
||||
python -m apps.chatgpt_rag --export-path chatgpt_export.html --query "How do I create a list in Python?"
|
||||
```
|
||||
|
||||
**Unlock your AI conversation history.** Never lose track of valuable insights from your ChatGPT discussions again.
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: How to Export ChatGPT Data</strong></summary>
|
||||
|
||||
**Step-by-step export process:**
|
||||
|
||||
1. **Sign in to ChatGPT**
|
||||
2. **Click your profile icon** in the top right corner
|
||||
3. **Navigate to Settings** → **Data Controls**
|
||||
4. **Click "Export"** under Export Data
|
||||
5. **Confirm the export** request
|
||||
6. **Download the ZIP file** from the email link (expires in 24 hours)
|
||||
7. **Extract or use directly** with LEANN
|
||||
|
||||
**Supported formats:**
|
||||
- `.html` files from ChatGPT exports
|
||||
- `.zip` archives from ChatGPT
|
||||
- Directories with multiple export files
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: ChatGPT-Specific Arguments</strong></summary>
|
||||
|
||||
#### Parameters
|
||||
```bash
|
||||
--export-path PATH # Path to ChatGPT export file (.html/.zip) or directory (default: ./chatgpt_export)
|
||||
--separate-messages # Process each message separately instead of concatenated conversations
|
||||
--chunk-size N # Text chunk size (default: 512)
|
||||
--chunk-overlap N # Overlap between chunks (default: 128)
|
||||
```
|
||||
|
||||
#### Example Commands
|
||||
```bash
|
||||
# Basic usage with HTML export
|
||||
python -m apps.chatgpt_rag --export-path conversations.html
|
||||
|
||||
# Process ZIP archive from ChatGPT
|
||||
python -m apps.chatgpt_rag --export-path chatgpt_export.zip
|
||||
|
||||
# Search with specific query
|
||||
python -m apps.chatgpt_rag --export-path chatgpt_data.html --query "Python programming help"
|
||||
|
||||
# Process individual messages for fine-grained search
|
||||
python -m apps.chatgpt_rag --separate-messages --export-path chatgpt_export.html
|
||||
|
||||
# Process directory containing multiple exports
|
||||
python -m apps.chatgpt_rag --export-path ./chatgpt_exports/ --max-items 1000
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><strong>💡 Click to expand: Example queries you can try</strong></summary>
|
||||
|
||||
Once your ChatGPT conversations are indexed, you can search with queries like:
|
||||
- "What did I ask ChatGPT about Python programming?"
|
||||
- "Show me conversations about machine learning algorithms"
|
||||
- "Find discussions about web development frameworks"
|
||||
- "What coding advice did ChatGPT give me?"
|
||||
- "Search for conversations about debugging techniques"
|
||||
- "Find ChatGPT's recommendations for learning resources"
|
||||
|
||||
</details>
|
||||
|
||||
### 🤖 Claude Chat History: Your Personal AI Conversation Archive!
|
||||
|
||||
Transform your Claude conversations into a searchable knowledge base! Search through all your Claude discussions about coding, research, brainstorming, and more.
|
||||
|
||||
```bash
|
||||
python -m apps.claude_rag --export-path claude_export.json --query "What did I ask about Python dictionaries?"
|
||||
```
|
||||
|
||||
**Unlock your AI conversation history.** Never lose track of valuable insights from your Claude discussions again.
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: How to Export Claude Data</strong></summary>
|
||||
|
||||
**Step-by-step export process:**
|
||||
|
||||
1. **Open Claude** in your browser
|
||||
2. **Navigate to Settings** (look for gear icon or settings menu)
|
||||
3. **Find Export/Download** options in your account settings
|
||||
4. **Download conversation data** (usually in JSON format)
|
||||
5. **Place the file** in your project directory
|
||||
|
||||
*Note: Claude export methods may vary depending on the interface you're using. Check Claude's help documentation for the most current export instructions.*
|
||||
|
||||
**Supported formats:**
|
||||
- `.json` files (recommended)
|
||||
- `.zip` archives containing JSON data
|
||||
- Directories with multiple export files
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: Claude-Specific Arguments</strong></summary>
|
||||
|
||||
#### Parameters
|
||||
```bash
|
||||
--export-path PATH # Path to Claude export file (.json/.zip) or directory (default: ./claude_export)
|
||||
--separate-messages # Process each message separately instead of concatenated conversations
|
||||
--chunk-size N # Text chunk size (default: 512)
|
||||
--chunk-overlap N # Overlap between chunks (default: 128)
|
||||
```
|
||||
|
||||
#### Example Commands
|
||||
```bash
|
||||
# Basic usage with JSON export
|
||||
python -m apps.claude_rag --export-path my_claude_conversations.json
|
||||
|
||||
# Process ZIP archive from Claude
|
||||
python -m apps.claude_rag --export-path claude_export.zip
|
||||
|
||||
# Search with specific query
|
||||
python -m apps.claude_rag --export-path claude_data.json --query "machine learning advice"
|
||||
|
||||
# Process individual messages for fine-grained search
|
||||
python -m apps.claude_rag --separate-messages --export-path claude_export.json
|
||||
|
||||
# Process directory containing multiple exports
|
||||
python -m apps.claude_rag --export-path ./claude_exports/ --max-items 1000
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><strong>💡 Click to expand: Example queries you can try</strong></summary>
|
||||
|
||||
Once your Claude conversations are indexed, you can search with queries like:
|
||||
- "What did I ask Claude about Python programming?"
|
||||
- "Show me conversations about machine learning algorithms"
|
||||
- "Find discussions about software architecture patterns"
|
||||
- "What debugging advice did Claude give me?"
|
||||
- "Search for conversations about data structures"
|
||||
- "Find Claude's recommendations for learning resources"
|
||||
|
||||
</details>
|
||||
|
||||
### 💬 iMessage History: Your Personal Conversation Archive!
|
||||
|
||||
Transform your iMessage conversations into a searchable knowledge base! Search through all your text messages, group chats, and conversations with friends, family, and colleagues.
|
||||
|
||||
```bash
|
||||
python -m apps.imessage_rag --query "What did we discuss about the weekend plans?"
|
||||
```
|
||||
|
||||
**Unlock your message history.** Never lose track of important conversations, shared links, or memorable moments from your iMessage history.
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: How to Access iMessage Data</strong></summary>
|
||||
|
||||
**iMessage data location:**
|
||||
|
||||
iMessage conversations are stored in a SQLite database on your Mac at:
|
||||
```
|
||||
~/Library/Messages/chat.db
|
||||
```
|
||||
|
||||
**Important setup requirements:**
|
||||
|
||||
1. **Grant Full Disk Access** to your terminal or IDE:
|
||||
- Open **System Preferences** → **Security & Privacy** → **Privacy**
|
||||
- Select **Full Disk Access** from the left sidebar
|
||||
- Click the **+** button and add your terminal app (Terminal, iTerm2) or IDE (VS Code, etc.)
|
||||
- Restart your terminal/IDE after granting access
|
||||
|
||||
2. **Alternative: Use a backup database**
|
||||
- If you have Time Machine backups or manual copies of the database
|
||||
- Use `--db-path` to specify a custom location
|
||||
|
||||
**Supported formats:**
|
||||
- Direct access to `~/Library/Messages/chat.db` (default)
|
||||
- Custom database path with `--db-path`
|
||||
- Works with backup copies of the database
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: iMessage-Specific Arguments</strong></summary>
|
||||
|
||||
#### Parameters
|
||||
```bash
|
||||
--db-path PATH # Path to chat.db file (default: ~/Library/Messages/chat.db)
|
||||
--concatenate-conversations # Group messages by conversation (default: True)
|
||||
--no-concatenate-conversations # Process each message individually
|
||||
--chunk-size N # Text chunk size (default: 1000)
|
||||
--chunk-overlap N # Overlap between chunks (default: 200)
|
||||
```
|
||||
|
||||
#### Example Commands
|
||||
```bash
|
||||
# Basic usage (requires Full Disk Access)
|
||||
python -m apps.imessage_rag
|
||||
|
||||
# Search with specific query
|
||||
python -m apps.imessage_rag --query "family dinner plans"
|
||||
|
||||
# Use custom database path
|
||||
python -m apps.imessage_rag --db-path /path/to/backup/chat.db
|
||||
|
||||
# Process individual messages instead of conversations
|
||||
python -m apps.imessage_rag --no-concatenate-conversations
|
||||
|
||||
# Limit processing for testing
|
||||
python -m apps.imessage_rag --max-items 100 --query "weekend"
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><strong>💡 Click to expand: Example queries you can try</strong></summary>
|
||||
|
||||
Once your iMessage conversations are indexed, you can search with queries like:
|
||||
- "What did we discuss about vacation plans?"
|
||||
- "Find messages about restaurant recommendations"
|
||||
- "Show me conversations with John about the project"
|
||||
- "Search for shared links about technology"
|
||||
- "Find group chat discussions about weekend events"
|
||||
- "What did mom say about the family gathering?"
|
||||
|
||||
</details>
|
||||
|
||||
### MCP Integration: RAG on Live Data from Any Platform
|
||||
|
||||
Connect to live data sources through the Model Context Protocol (MCP). LEANN now supports real-time RAG on platforms like Slack, Twitter, and more through standardized MCP servers.
|
||||
|
||||
**Key Benefits:**
|
||||
- **Live Data Access**: Fetch real-time data without manual exports
|
||||
- **Standardized Protocol**: Use any MCP-compatible server
|
||||
- **Easy Extension**: Add new platforms with minimal code
|
||||
- **Secure Access**: MCP servers handle authentication
|
||||
|
||||
#### 💬 Slack Messages: Search Your Team Conversations
|
||||
|
||||
Transform your Slack workspace into a searchable knowledge base! Find discussions, decisions, and shared knowledge across all your channels.
|
||||
|
||||
```bash
|
||||
# Test MCP server connection
|
||||
python -m apps.slack_rag --mcp-server "slack-mcp-server" --test-connection
|
||||
|
||||
# Index and search Slack messages
|
||||
python -m apps.slack_rag \
|
||||
--mcp-server "slack-mcp-server" \
|
||||
--workspace-name "my-team" \
|
||||
--channels general dev-team random \
|
||||
--query "What did we decide about the product launch?"
|
||||
```
|
||||
|
||||
**📖 Comprehensive Setup Guide**: For detailed setup instructions, troubleshooting common issues (like "users cache is not ready yet"), and advanced configuration options, see our [**Slack Setup Guide**](docs/slack-setup-guide.md).
|
||||
|
||||
**Quick Setup:**
|
||||
1. Install a Slack MCP server (e.g., `npm install -g slack-mcp-server`)
|
||||
2. Create a Slack App and get API credentials (see detailed guide above)
|
||||
3. Set environment variables:
|
||||
```bash
|
||||
export SLACK_BOT_TOKEN="xoxb-your-bot-token"
|
||||
export SLACK_APP_TOKEN="xapp-your-app-token" # Optional
|
||||
```
|
||||
4. Test connection with `--test-connection` flag
|
||||
|
||||
**Arguments:**
|
||||
- `--mcp-server`: Command to start the Slack MCP server
|
||||
- `--workspace-name`: Slack workspace name for organization
|
||||
- `--channels`: Specific channels to index (optional)
|
||||
- `--concatenate-conversations`: Group messages by channel (default: true)
|
||||
- `--max-messages-per-channel`: Limit messages per channel (default: 100)
|
||||
- `--max-retries`: Maximum retries for cache sync issues (default: 5)
|
||||
- `--retry-delay`: Initial delay between retries in seconds (default: 2.0)
|
||||
|
||||
#### 🐦 Twitter Bookmarks: Your Personal Tweet Library
|
||||
|
||||
Search through your Twitter bookmarks! Find that perfect article, thread, or insight you saved for later.
|
||||
|
||||
```bash
|
||||
# Test MCP server connection
|
||||
python -m apps.twitter_rag --mcp-server "twitter-mcp-server" --test-connection
|
||||
|
||||
# Index and search Twitter bookmarks
|
||||
python -m apps.twitter_rag \
|
||||
--mcp-server "twitter-mcp-server" \
|
||||
--max-bookmarks 1000 \
|
||||
--query "What AI articles did I bookmark about machine learning?"
|
||||
```
|
||||
|
||||
**Setup Requirements:**
|
||||
1. Install a Twitter MCP server (e.g., `npm install -g twitter-mcp-server`)
|
||||
2. Get Twitter API credentials:
|
||||
- Apply for a Twitter Developer Account at [developer.twitter.com](https://developer.twitter.com)
|
||||
- Create a new app in the Twitter Developer Portal
|
||||
- Generate API keys and access tokens with "Read" permissions
|
||||
- For bookmarks access, you may need Twitter API v2 with appropriate scopes
|
||||
```bash
|
||||
export TWITTER_API_KEY="your-api-key"
|
||||
export TWITTER_API_SECRET="your-api-secret"
|
||||
export TWITTER_ACCESS_TOKEN="your-access-token"
|
||||
export TWITTER_ACCESS_TOKEN_SECRET="your-access-token-secret"
|
||||
```
|
||||
3. Test connection with `--test-connection` flag
|
||||
|
||||
**Arguments:**
|
||||
- `--mcp-server`: Command to start the Twitter MCP server
|
||||
- `--username`: Filter bookmarks by username (optional)
|
||||
- `--max-bookmarks`: Maximum bookmarks to fetch (default: 1000)
|
||||
- `--no-tweet-content`: Exclude tweet content, only metadata
|
||||
- `--no-metadata`: Exclude engagement metadata
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><strong>💡 Click to expand: Example queries you can try</strong></summary>
|
||||
|
||||
**Slack Queries:**
|
||||
- "What did the team discuss about the project deadline?"
|
||||
- "Find messages about the new feature launch"
|
||||
- "Show me conversations about budget planning"
|
||||
- "What decisions were made in the dev-team channel?"
|
||||
|
||||
**Twitter Queries:**
|
||||
- "What AI articles did I bookmark last month?"
|
||||
- "Find tweets about machine learning techniques"
|
||||
- "Show me bookmarked threads about startup advice"
|
||||
- "What Python tutorials did I save?"
|
||||
|
||||
</details>
|
||||
<summary><strong>🔧 Using MCP with CLI Commands</strong></summary>
|
||||
|
||||
**Want to use MCP data with regular LEANN CLI?** You can combine MCP apps with CLI commands:
|
||||
|
||||
```bash
|
||||
# Step 1: Use MCP app to fetch and index data
|
||||
python -m apps.slack_rag --mcp-server "slack-mcp-server" --workspace-name "my-team"
|
||||
|
||||
# Step 2: The data is now indexed and available via CLI
|
||||
leann search slack_messages "project deadline"
|
||||
leann ask slack_messages "What decisions were made about the product launch?"
|
||||
|
||||
# Same for Twitter bookmarks
|
||||
python -m apps.twitter_rag --mcp-server "twitter-mcp-server"
|
||||
leann search twitter_bookmarks "machine learning articles"
|
||||
```
|
||||
|
||||
**MCP vs Manual Export:**
|
||||
- **MCP**: Live data, automatic updates, requires server setup
|
||||
- **Manual Export**: One-time setup, works offline, requires manual data export
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><strong>🔧 Adding New MCP Platforms</strong></summary>
|
||||
|
||||
Want to add support for other platforms? LEANN's MCP integration is designed for easy extension:
|
||||
|
||||
1. **Find or create an MCP server** for your platform
|
||||
2. **Create a reader class** following the pattern in `apps/slack_data/slack_mcp_reader.py`
|
||||
3. **Create a RAG application** following the pattern in `apps/slack_rag.py`
|
||||
4. **Test and contribute** back to the community!
|
||||
|
||||
**Popular MCP servers to explore:**
|
||||
- GitHub repositories and issues
|
||||
- Discord messages
|
||||
- Notion pages
|
||||
- Google Drive documents
|
||||
- And many more in the MCP ecosystem!
|
||||
|
||||
</details>
|
||||
|
||||
### 🚀 Claude Code Integration: Transform Your Development Workflow!
|
||||
|
||||
<details>
|
||||
<summary><strong>NEW!! AST‑Aware Code Chunking</strong></summary>
|
||||
<summary><strong>AST‑Aware Code Chunking</strong></summary>
|
||||
|
||||
LEANN features intelligent code chunking that preserves semantic boundaries (functions, classes, methods) for Python, Java, C#, and TypeScript, improving code understanding compared to text-based chunking.
|
||||
|
||||
@@ -508,7 +967,7 @@ Try our fully agentic pipeline with auto query rewriting, semantic search planni
|
||||
|
||||
**🔥 Ready to supercharge your coding?** [Complete Setup Guide →](packages/leann-mcp/README.md)
|
||||
|
||||
## 🖥️ Command Line Interface
|
||||
## Command Line Interface
|
||||
|
||||
LEANN includes a powerful CLI for document processing and search. Perfect for quick document indexing and interactive chat.
|
||||
|
||||
@@ -546,6 +1005,9 @@ leann search my-docs "machine learning concepts"
|
||||
# Interactive chat with your documents
|
||||
leann ask my-docs --interactive
|
||||
|
||||
# Ask a single question (non-interactive)
|
||||
leann ask my-docs "Where are prompts configured?"
|
||||
|
||||
# List all your indexes
|
||||
leann list
|
||||
|
||||
@@ -596,10 +1058,10 @@ Options:
|
||||
leann ask INDEX_NAME [OPTIONS]
|
||||
|
||||
Options:
|
||||
--llm {ollama,openai,hf} LLM provider (default: ollama)
|
||||
--model MODEL Model name (default: qwen3:8b)
|
||||
--interactive Interactive chat mode
|
||||
--top-k N Retrieval count (default: 20)
|
||||
--llm {ollama,openai,hf,anthropic} LLM provider (default: ollama)
|
||||
--model MODEL Model name (default: qwen3:8b)
|
||||
--interactive Interactive chat mode
|
||||
--top-k N Retrieval count (default: 20)
|
||||
```
|
||||
|
||||
**List Command:**
|
||||
@@ -706,9 +1168,8 @@ results = searcher.search("banana‑crocodile", use_grep=True, top_k=1)
|
||||
## Reproduce Our Results
|
||||
|
||||
```bash
|
||||
uv pip install -e ".[dev]" # Install dev dependencies
|
||||
python benchmarks/run_evaluation.py # Will auto-download evaluation data and run benchmarks
|
||||
python benchmarks/run_evaluation.py benchmarks/data/indices/rpj_wiki/rpj_wiki --num-queries 2000 # After downloading data, you can run the benchmark with our biggest index
|
||||
uv run benchmarks/run_evaluation.py # Will auto-download evaluation data and run benchmarks
|
||||
uv run benchmarks/run_evaluation.py benchmarks/data/indices/rpj_wiki/rpj_wiki --num-queries 2000 # After downloading data, you can run the benchmark with our biggest index
|
||||
```
|
||||
|
||||
The evaluation script downloads data automatically on first run. The last three results were tested with partial personal data, and you can reproduce them with your own data!
|
||||
@@ -748,7 +1209,7 @@ MIT License - see [LICENSE](LICENSE) for details.
|
||||
|
||||
Core Contributors: [Yichuan Wang](https://yichuan-w.github.io/) & [Zhifei Li](https://github.com/andylizf).
|
||||
|
||||
Active Contributors: [Gabriel Dehan](https://github.com/gabriel-dehan)
|
||||
Active Contributors: [Gabriel Dehan](https://github.com/gabriel-dehan), [Aakash Suresh](https://github.com/ASuresh0524)
|
||||
|
||||
|
||||
We welcome more contributors! Feel free to open issues or submit PRs.
|
||||
@@ -765,3 +1226,7 @@ This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.ed
|
||||
<p align="center">
|
||||
Made with ❤️ by the Leann team
|
||||
</p>
|
||||
|
||||
## 🤖 Explore LEANN with AI
|
||||
|
||||
LEANN is indexed on [DeepWiki](https://deepwiki.com/yichuan-w/LEANN), so you can ask questions to LLMs using Deep Research to explore the codebase and get help to add new features.
|
||||
|
||||
@@ -10,8 +10,40 @@ from typing import Any
|
||||
|
||||
import dotenv
|
||||
from leann.api import LeannBuilder, LeannChat
|
||||
|
||||
# Optional import: older PyPI builds may not include interactive_utils
|
||||
try:
|
||||
from leann.interactive_utils import create_rag_session
|
||||
except ImportError:
|
||||
|
||||
def create_rag_session(app_name: str, data_description: str):
|
||||
class _SimpleSession:
|
||||
def run_interactive_loop(self, handler):
|
||||
print(f"Interactive session for {app_name}: {data_description}")
|
||||
print("Interactive mode not available in this build")
|
||||
|
||||
return _SimpleSession()
|
||||
|
||||
|
||||
from leann.registry import register_project_directory
|
||||
|
||||
# Optional import: older PyPI builds may not include settings
|
||||
try:
|
||||
from leann.settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
|
||||
except ImportError:
|
||||
# Minimal fallbacks if settings helpers are unavailable
|
||||
import os
|
||||
|
||||
def resolve_ollama_host(value: str | None) -> str | None:
|
||||
return value or os.getenv("LEANN_OLLAMA_HOST") or os.getenv("OLLAMA_HOST")
|
||||
|
||||
def resolve_openai_api_key(value: str | None) -> str | None:
|
||||
return value or os.getenv("OPENAI_API_KEY")
|
||||
|
||||
def resolve_openai_base_url(value: str | None) -> str | None:
|
||||
return value or os.getenv("OPENAI_BASE_URL")
|
||||
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
|
||||
@@ -78,6 +110,24 @@ class BaseRAGExample(ABC):
|
||||
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
||||
help="Embedding backend mode (default: sentence-transformers), we provide sentence-transformers, openai, mlx, or ollama",
|
||||
)
|
||||
embedding_group.add_argument(
|
||||
"--embedding-host",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Override Ollama-compatible embedding host",
|
||||
)
|
||||
embedding_group.add_argument(
|
||||
"--embedding-api-base",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Base URL for OpenAI-compatible embedding services",
|
||||
)
|
||||
embedding_group.add_argument(
|
||||
"--embedding-api-key",
|
||||
type=str,
|
||||
default=None,
|
||||
help="API key for embedding service (defaults to OPENAI_API_KEY)",
|
||||
)
|
||||
|
||||
# LLM parameters
|
||||
llm_group = parser.add_argument_group("LLM Parameters")
|
||||
@@ -97,8 +147,8 @@ class BaseRAGExample(ABC):
|
||||
llm_group.add_argument(
|
||||
"--llm-host",
|
||||
type=str,
|
||||
default="http://localhost:11434",
|
||||
help="Host for Ollama API (default: http://localhost:11434)",
|
||||
default=None,
|
||||
help="Host for Ollama-compatible APIs (defaults to LEANN_OLLAMA_HOST/OLLAMA_HOST)",
|
||||
)
|
||||
llm_group.add_argument(
|
||||
"--thinking-budget",
|
||||
@@ -107,6 +157,18 @@ class BaseRAGExample(ABC):
|
||||
default=None,
|
||||
help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.",
|
||||
)
|
||||
llm_group.add_argument(
|
||||
"--llm-api-base",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Base URL for OpenAI-compatible APIs",
|
||||
)
|
||||
llm_group.add_argument(
|
||||
"--llm-api-key",
|
||||
type=str,
|
||||
default=None,
|
||||
help="API key for OpenAI-compatible APIs (defaults to OPENAI_API_KEY)",
|
||||
)
|
||||
|
||||
# AST Chunking parameters
|
||||
ast_group = parser.add_argument_group("AST Chunking Parameters")
|
||||
@@ -118,14 +180,14 @@ class BaseRAGExample(ABC):
|
||||
ast_group.add_argument(
|
||||
"--ast-chunk-size",
|
||||
type=int,
|
||||
default=512,
|
||||
help="Maximum characters per AST chunk (default: 512)",
|
||||
default=300,
|
||||
help="Maximum CHARACTERS per AST chunk (default: 300). Final chunks may be larger due to overlap. For 512 token models: recommended 300 chars",
|
||||
)
|
||||
ast_group.add_argument(
|
||||
"--ast-chunk-overlap",
|
||||
type=int,
|
||||
default=64,
|
||||
help="Overlap between AST chunks (default: 64)",
|
||||
help="Overlap between AST chunks in CHARACTERS (default: 64). Added to chunk size, not included in it",
|
||||
)
|
||||
ast_group.add_argument(
|
||||
"--code-file-extensions",
|
||||
@@ -205,9 +267,13 @@ class BaseRAGExample(ABC):
|
||||
|
||||
if args.llm == "openai":
|
||||
config["model"] = args.llm_model or "gpt-4o"
|
||||
config["base_url"] = resolve_openai_base_url(args.llm_api_base)
|
||||
resolved_key = resolve_openai_api_key(args.llm_api_key)
|
||||
if resolved_key:
|
||||
config["api_key"] = resolved_key
|
||||
elif args.llm == "ollama":
|
||||
config["model"] = args.llm_model or "llama3.2:1b"
|
||||
config["host"] = args.llm_host
|
||||
config["host"] = resolve_ollama_host(args.llm_host)
|
||||
elif args.llm == "hf":
|
||||
config["model"] = args.llm_model or "Qwen/Qwen2.5-1.5B-Instruct"
|
||||
elif args.llm == "simulated":
|
||||
@@ -223,10 +289,20 @@ class BaseRAGExample(ABC):
|
||||
print(f"\n[Building Index] Creating {self.name} index...")
|
||||
print(f"Total text chunks: {len(texts)}")
|
||||
|
||||
embedding_options: dict[str, Any] = {}
|
||||
if args.embedding_mode == "ollama":
|
||||
embedding_options["host"] = resolve_ollama_host(args.embedding_host)
|
||||
elif args.embedding_mode == "openai":
|
||||
embedding_options["base_url"] = resolve_openai_base_url(args.embedding_api_base)
|
||||
resolved_embedding_key = resolve_openai_api_key(args.embedding_api_key)
|
||||
if resolved_embedding_key:
|
||||
embedding_options["api_key"] = resolved_embedding_key
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name=args.backend_name,
|
||||
embedding_model=args.embedding_model,
|
||||
embedding_mode=args.embedding_mode,
|
||||
embedding_options=embedding_options or None,
|
||||
graph_degree=args.graph_degree,
|
||||
complexity=args.build_complexity,
|
||||
is_compact=not args.no_compact,
|
||||
@@ -262,37 +338,26 @@ class BaseRAGExample(ABC):
|
||||
complexity=args.search_complexity,
|
||||
)
|
||||
|
||||
print(f"\n[Interactive Mode] Chat with your {self.name} data!")
|
||||
print("Type 'quit' or 'exit' to stop.\n")
|
||||
# Create interactive session
|
||||
session = create_rag_session(
|
||||
app_name=self.name.lower().replace(" ", "_"), data_description=self.name
|
||||
)
|
||||
|
||||
while True:
|
||||
try:
|
||||
query = input("You: ").strip()
|
||||
if query.lower() in ["quit", "exit", "q"]:
|
||||
print("Goodbye!")
|
||||
break
|
||||
def handle_query(query: str):
|
||||
# Prepare LLM kwargs with thinking budget if specified
|
||||
llm_kwargs = {}
|
||||
if hasattr(args, "thinking_budget") and args.thinking_budget:
|
||||
llm_kwargs["thinking_budget"] = args.thinking_budget
|
||||
|
||||
if not query:
|
||||
continue
|
||||
response = chat.ask(
|
||||
query,
|
||||
top_k=args.top_k,
|
||||
complexity=args.search_complexity,
|
||||
llm_kwargs=llm_kwargs,
|
||||
)
|
||||
print(f"\nAssistant: {response}\n")
|
||||
|
||||
# Prepare LLM kwargs with thinking budget if specified
|
||||
llm_kwargs = {}
|
||||
if hasattr(args, "thinking_budget") and args.thinking_budget:
|
||||
llm_kwargs["thinking_budget"] = args.thinking_budget
|
||||
|
||||
response = chat.ask(
|
||||
query,
|
||||
top_k=args.top_k,
|
||||
complexity=args.search_complexity,
|
||||
llm_kwargs=llm_kwargs,
|
||||
)
|
||||
print(f"\nAssistant: {response}\n")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\nGoodbye!")
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
session.run_interactive_loop(handle_query)
|
||||
|
||||
async def run_single_query(self, args, index_path: str, query: str):
|
||||
"""Run a single query against the index."""
|
||||
|
||||
0
apps/chatgpt_data/__init__.py
Normal file
413
apps/chatgpt_data/chatgpt_reader.py
Normal file
@@ -0,0 +1,413 @@
|
||||
"""
|
||||
ChatGPT export data reader.
|
||||
|
||||
Reads and processes ChatGPT export data from chat.html files.
|
||||
"""
|
||||
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from zipfile import ZipFile
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
from llama_index.core import Document
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
|
||||
|
||||
class ChatGPTReader(BaseReader):
|
||||
"""
|
||||
ChatGPT export data reader.
|
||||
|
||||
Reads ChatGPT conversation data from exported chat.html files or zip archives.
|
||||
Processes conversations into structured documents with metadata.
|
||||
"""
|
||||
|
||||
def __init__(self, concatenate_conversations: bool = True) -> None:
|
||||
"""
|
||||
Initialize.
|
||||
|
||||
Args:
|
||||
concatenate_conversations: Whether to concatenate messages within conversations for better context
|
||||
"""
|
||||
try:
|
||||
from bs4 import BeautifulSoup # noqa
|
||||
except ImportError:
|
||||
raise ImportError("`beautifulsoup4` package not found: `pip install beautifulsoup4`")
|
||||
|
||||
self.concatenate_conversations = concatenate_conversations
|
||||
|
||||
def _extract_html_from_zip(self, zip_path: Path) -> str | None:
|
||||
"""
|
||||
Extract chat.html from ChatGPT export zip file.
|
||||
|
||||
Args:
|
||||
zip_path: Path to the ChatGPT export zip file
|
||||
|
||||
Returns:
|
||||
HTML content as string, or None if not found
|
||||
"""
|
||||
try:
|
||||
with ZipFile(zip_path, "r") as zip_file:
|
||||
# Look for chat.html or conversations.html
|
||||
html_files = [
|
||||
f
|
||||
for f in zip_file.namelist()
|
||||
if f.endswith(".html") and ("chat" in f.lower() or "conversation" in f.lower())
|
||||
]
|
||||
|
||||
if not html_files:
|
||||
print(f"No HTML chat file found in {zip_path}")
|
||||
return None
|
||||
|
||||
# Use the first HTML file found
|
||||
html_file = html_files[0]
|
||||
print(f"Found HTML file: {html_file}")
|
||||
|
||||
with zip_file.open(html_file) as f:
|
||||
return f.read().decode("utf-8", errors="ignore")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error extracting HTML from zip {zip_path}: {e}")
|
||||
return None
|
||||
|
||||
def _parse_chatgpt_html(self, html_content: str) -> list[dict]:
|
||||
"""
|
||||
Parse ChatGPT HTML export to extract conversations.
|
||||
|
||||
Args:
|
||||
html_content: HTML content from ChatGPT export
|
||||
|
||||
Returns:
|
||||
List of conversation dictionaries
|
||||
"""
|
||||
soup = BeautifulSoup(html_content, "html.parser")
|
||||
conversations = []
|
||||
|
||||
# Try different possible structures for ChatGPT exports
|
||||
# Structure 1: Look for conversation containers
|
||||
conversation_containers = soup.find_all(
|
||||
["div", "section"], class_=re.compile(r"conversation|chat", re.I)
|
||||
)
|
||||
|
||||
if not conversation_containers:
|
||||
# Structure 2: Look for message containers directly
|
||||
conversation_containers = [soup] # Use the entire document as one conversation
|
||||
|
||||
for container in conversation_containers:
|
||||
conversation = self._extract_conversation_from_container(container)
|
||||
if conversation and conversation.get("messages"):
|
||||
conversations.append(conversation)
|
||||
|
||||
# If no structured conversations found, try to extract all text as one conversation
|
||||
if not conversations:
|
||||
all_text = soup.get_text(separator="\n", strip=True)
|
||||
if all_text:
|
||||
conversations.append(
|
||||
{
|
||||
"title": "ChatGPT Conversation",
|
||||
"messages": [{"role": "mixed", "content": all_text, "timestamp": None}],
|
||||
"timestamp": None,
|
||||
}
|
||||
)
|
||||
|
||||
return conversations
|
||||
|
||||
def _extract_conversation_from_container(self, container) -> dict | None:
|
||||
"""
|
||||
Extract conversation data from a container element.
|
||||
|
||||
Args:
|
||||
container: BeautifulSoup element containing conversation
|
||||
|
||||
Returns:
|
||||
Dictionary with conversation data or None
|
||||
"""
|
||||
messages = []
|
||||
|
||||
# Look for message elements with various possible structures
|
||||
message_selectors = ['[class*="message"]', '[class*="chat"]', "[data-message]", "p", "div"]
|
||||
|
||||
for selector in message_selectors:
|
||||
message_elements = container.select(selector)
|
||||
if message_elements:
|
||||
break
|
||||
else:
|
||||
message_elements = []
|
||||
|
||||
# If no structured messages found, treat the entire container as one message
|
||||
if not message_elements:
|
||||
text_content = container.get_text(separator="\n", strip=True)
|
||||
if text_content:
|
||||
messages.append({"role": "mixed", "content": text_content, "timestamp": None})
|
||||
else:
|
||||
for element in message_elements:
|
||||
message = self._extract_message_from_element(element)
|
||||
if message:
|
||||
messages.append(message)
|
||||
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
# Try to extract conversation title
|
||||
title_element = container.find(["h1", "h2", "h3", "title"])
|
||||
title = title_element.get_text(strip=True) if title_element else "ChatGPT Conversation"
|
||||
|
||||
# Try to extract timestamp from various possible locations
|
||||
timestamp = self._extract_timestamp_from_container(container)
|
||||
|
||||
return {"title": title, "messages": messages, "timestamp": timestamp}
|
||||
|
||||
def _extract_message_from_element(self, element) -> dict | None:
|
||||
"""
|
||||
Extract message data from an element.
|
||||
|
||||
Args:
|
||||
element: BeautifulSoup element containing message
|
||||
|
||||
Returns:
|
||||
Dictionary with message data or None
|
||||
"""
|
||||
text_content = element.get_text(separator=" ", strip=True)
|
||||
|
||||
# Skip empty or very short messages
|
||||
if not text_content or len(text_content.strip()) < 3:
|
||||
return None
|
||||
|
||||
# Try to determine role (user/assistant) from class names or content
|
||||
role = "mixed" # Default role
|
||||
|
||||
class_names = " ".join(element.get("class", [])).lower()
|
||||
if "user" in class_names or "human" in class_names:
|
||||
role = "user"
|
||||
elif "assistant" in class_names or "ai" in class_names or "gpt" in class_names:
|
||||
role = "assistant"
|
||||
elif text_content.lower().startswith(("you:", "user:", "me:")):
|
||||
role = "user"
|
||||
text_content = re.sub(r"^(you|user|me):\s*", "", text_content, flags=re.IGNORECASE)
|
||||
elif text_content.lower().startswith(("chatgpt:", "assistant:", "ai:")):
|
||||
role = "assistant"
|
||||
text_content = re.sub(
|
||||
r"^(chatgpt|assistant|ai):\s*", "", text_content, flags=re.IGNORECASE
|
||||
)
|
||||
|
||||
# Try to extract timestamp
|
||||
timestamp = self._extract_timestamp_from_element(element)
|
||||
|
||||
return {"role": role, "content": text_content, "timestamp": timestamp}
|
||||
|
||||
def _extract_timestamp_from_element(self, element) -> str | None:
|
||||
"""Extract timestamp from element."""
|
||||
# Look for timestamp in various attributes and child elements
|
||||
timestamp_attrs = ["data-timestamp", "timestamp", "datetime"]
|
||||
for attr in timestamp_attrs:
|
||||
if element.get(attr):
|
||||
return element.get(attr)
|
||||
|
||||
# Look for time elements
|
||||
time_element = element.find("time")
|
||||
if time_element:
|
||||
return time_element.get("datetime") or time_element.get_text(strip=True)
|
||||
|
||||
# Look for date-like text patterns
|
||||
text = element.get_text()
|
||||
date_patterns = [r"\d{4}-\d{2}-\d{2}", r"\d{1,2}/\d{1,2}/\d{4}", r"\w+ \d{1,2}, \d{4}"]
|
||||
|
||||
for pattern in date_patterns:
|
||||
match = re.search(pattern, text)
|
||||
if match:
|
||||
return match.group()
|
||||
|
||||
return None
|
||||
|
||||
def _extract_timestamp_from_container(self, container) -> str | None:
|
||||
"""Extract timestamp from conversation container."""
|
||||
return self._extract_timestamp_from_element(container)
|
||||
|
||||
def _create_concatenated_content(self, conversation: dict) -> str:
|
||||
"""
|
||||
Create concatenated content from conversation messages.
|
||||
|
||||
Args:
|
||||
conversation: Dictionary containing conversation data
|
||||
|
||||
Returns:
|
||||
Formatted concatenated content
|
||||
"""
|
||||
title = conversation.get("title", "ChatGPT Conversation")
|
||||
messages = conversation.get("messages", [])
|
||||
timestamp = conversation.get("timestamp", "Unknown")
|
||||
|
||||
# Build message content
|
||||
message_parts = []
|
||||
for message in messages:
|
||||
role = message.get("role", "mixed")
|
||||
content = message.get("content", "")
|
||||
msg_timestamp = message.get("timestamp", "")
|
||||
|
||||
if role == "user":
|
||||
prefix = "[You]"
|
||||
elif role == "assistant":
|
||||
prefix = "[ChatGPT]"
|
||||
else:
|
||||
prefix = "[Message]"
|
||||
|
||||
# Add timestamp if available
|
||||
if msg_timestamp:
|
||||
prefix += f" ({msg_timestamp})"
|
||||
|
||||
message_parts.append(f"{prefix}: {content}")
|
||||
|
||||
concatenated_text = "\n\n".join(message_parts)
|
||||
|
||||
# Create final document content
|
||||
doc_content = f"""Conversation: {title}
|
||||
Date: {timestamp}
|
||||
Messages ({len(messages)} messages):
|
||||
|
||||
{concatenated_text}
|
||||
"""
|
||||
return doc_content
|
||||
|
||||
def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
|
||||
"""
|
||||
Load ChatGPT export data.
|
||||
|
||||
Args:
|
||||
input_dir: Directory containing ChatGPT export files or path to specific file
|
||||
**load_kwargs:
|
||||
max_count (int): Maximum number of conversations to process
|
||||
chatgpt_export_path (str): Specific path to ChatGPT export file/directory
|
||||
include_metadata (bool): Whether to include metadata in documents
|
||||
"""
|
||||
docs: list[Document] = []
|
||||
max_count = load_kwargs.get("max_count", -1)
|
||||
chatgpt_export_path = load_kwargs.get("chatgpt_export_path", input_dir)
|
||||
include_metadata = load_kwargs.get("include_metadata", True)
|
||||
|
||||
if not chatgpt_export_path:
|
||||
print("No ChatGPT export path provided")
|
||||
return docs
|
||||
|
||||
export_path = Path(chatgpt_export_path)
|
||||
|
||||
if not export_path.exists():
|
||||
print(f"ChatGPT export path not found: {export_path}")
|
||||
return docs
|
||||
|
||||
html_content = None
|
||||
|
||||
# Handle different input types
|
||||
if export_path.is_file():
|
||||
if export_path.suffix.lower() == ".zip":
|
||||
# Extract HTML from zip file
|
||||
html_content = self._extract_html_from_zip(export_path)
|
||||
elif export_path.suffix.lower() == ".html":
|
||||
# Read HTML file directly
|
||||
try:
|
||||
with open(export_path, encoding="utf-8", errors="ignore") as f:
|
||||
html_content = f.read()
|
||||
except Exception as e:
|
||||
print(f"Error reading HTML file {export_path}: {e}")
|
||||
return docs
|
||||
else:
|
||||
print(f"Unsupported file type: {export_path.suffix}")
|
||||
return docs
|
||||
|
||||
elif export_path.is_dir():
|
||||
# Look for HTML files in directory
|
||||
html_files = list(export_path.glob("*.html"))
|
||||
zip_files = list(export_path.glob("*.zip"))
|
||||
|
||||
if html_files:
|
||||
# Use first HTML file found
|
||||
html_file = html_files[0]
|
||||
print(f"Found HTML file: {html_file}")
|
||||
try:
|
||||
with open(html_file, encoding="utf-8", errors="ignore") as f:
|
||||
html_content = f.read()
|
||||
except Exception as e:
|
||||
print(f"Error reading HTML file {html_file}: {e}")
|
||||
return docs
|
||||
|
||||
elif zip_files:
|
||||
# Use first zip file found
|
||||
zip_file = zip_files[0]
|
||||
print(f"Found zip file: {zip_file}")
|
||||
html_content = self._extract_html_from_zip(zip_file)
|
||||
|
||||
else:
|
||||
print(f"No HTML or zip files found in {export_path}")
|
||||
return docs
|
||||
|
||||
if not html_content:
|
||||
print("No HTML content found to process")
|
||||
return docs
|
||||
|
||||
# Parse conversations from HTML
|
||||
print("Parsing ChatGPT conversations from HTML...")
|
||||
conversations = self._parse_chatgpt_html(html_content)
|
||||
|
||||
if not conversations:
|
||||
print("No conversations found in HTML content")
|
||||
return docs
|
||||
|
||||
print(f"Found {len(conversations)} conversations")
|
||||
|
||||
# Process conversations into documents
|
||||
count = 0
|
||||
for conversation in conversations:
|
||||
if max_count > 0 and count >= max_count:
|
||||
break
|
||||
|
||||
if self.concatenate_conversations:
|
||||
# Create one document per conversation with concatenated messages
|
||||
doc_content = self._create_concatenated_content(conversation)
|
||||
|
||||
metadata = {}
|
||||
if include_metadata:
|
||||
metadata = {
|
||||
"title": conversation.get("title", "ChatGPT Conversation"),
|
||||
"timestamp": conversation.get("timestamp", "Unknown"),
|
||||
"message_count": len(conversation.get("messages", [])),
|
||||
"source": "ChatGPT Export",
|
||||
}
|
||||
|
||||
doc = Document(text=doc_content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
|
||||
else:
|
||||
# Create separate documents for each message
|
||||
for message in conversation.get("messages", []):
|
||||
if max_count > 0 and count >= max_count:
|
||||
break
|
||||
|
||||
role = message.get("role", "mixed")
|
||||
content = message.get("content", "")
|
||||
msg_timestamp = message.get("timestamp", "")
|
||||
|
||||
if not content.strip():
|
||||
continue
|
||||
|
||||
# Create document content with context
|
||||
doc_content = f"""Conversation: {conversation.get("title", "ChatGPT Conversation")}
|
||||
Role: {role}
|
||||
Timestamp: {msg_timestamp or conversation.get("timestamp", "Unknown")}
|
||||
Message: {content}
|
||||
"""
|
||||
|
||||
metadata = {}
|
||||
if include_metadata:
|
||||
metadata = {
|
||||
"conversation_title": conversation.get("title", "ChatGPT Conversation"),
|
||||
"role": role,
|
||||
"timestamp": msg_timestamp or conversation.get("timestamp", "Unknown"),
|
||||
"source": "ChatGPT Export",
|
||||
}
|
||||
|
||||
doc = Document(text=doc_content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
|
||||
print(f"Created {len(docs)} documents from ChatGPT export")
|
||||
return docs
|
||||
186
apps/chatgpt_rag.py
Normal file
@@ -0,0 +1,186 @@
|
||||
"""
|
||||
ChatGPT RAG example using the unified interface.
|
||||
Supports ChatGPT export data from chat.html files.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from base_rag_example import BaseRAGExample
|
||||
from chunking import create_text_chunks
|
||||
|
||||
from .chatgpt_data.chatgpt_reader import ChatGPTReader
|
||||
|
||||
|
||||
class ChatGPTRAG(BaseRAGExample):
|
||||
"""RAG example for ChatGPT conversation data."""
|
||||
|
||||
def __init__(self):
|
||||
# Set default values BEFORE calling super().__init__
|
||||
self.max_items_default = -1 # Process all conversations by default
|
||||
self.embedding_model_default = (
|
||||
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
name="ChatGPT",
|
||||
description="Process and query ChatGPT conversation exports with LEANN",
|
||||
default_index_name="chatgpt_conversations_index",
|
||||
)
|
||||
|
||||
def _add_specific_arguments(self, parser):
|
||||
"""Add ChatGPT-specific arguments."""
|
||||
chatgpt_group = parser.add_argument_group("ChatGPT Parameters")
|
||||
chatgpt_group.add_argument(
|
||||
"--export-path",
|
||||
type=str,
|
||||
default="./chatgpt_export",
|
||||
help="Path to ChatGPT export file (.zip or .html) or directory containing exports (default: ./chatgpt_export)",
|
||||
)
|
||||
chatgpt_group.add_argument(
|
||||
"--concatenate-conversations",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Concatenate messages within conversations for better context (default: True)",
|
||||
)
|
||||
chatgpt_group.add_argument(
|
||||
"--separate-messages",
|
||||
action="store_true",
|
||||
help="Process each message as a separate document (overrides --concatenate-conversations)",
|
||||
)
|
||||
chatgpt_group.add_argument(
|
||||
"--chunk-size", type=int, default=512, help="Text chunk size (default: 512)"
|
||||
)
|
||||
chatgpt_group.add_argument(
|
||||
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
|
||||
)
|
||||
|
||||
def _find_chatgpt_exports(self, export_path: Path) -> list[Path]:
|
||||
"""
|
||||
Find ChatGPT export files in the given path.
|
||||
|
||||
Args:
|
||||
export_path: Path to search for exports
|
||||
|
||||
Returns:
|
||||
List of paths to ChatGPT export files
|
||||
"""
|
||||
export_files = []
|
||||
|
||||
if export_path.is_file():
|
||||
if export_path.suffix.lower() in [".zip", ".html"]:
|
||||
export_files.append(export_path)
|
||||
elif export_path.is_dir():
|
||||
# Look for zip and html files
|
||||
export_files.extend(export_path.glob("*.zip"))
|
||||
export_files.extend(export_path.glob("*.html"))
|
||||
|
||||
return export_files
|
||||
|
||||
async def load_data(self, args) -> list[str]:
|
||||
"""Load ChatGPT export data and convert to text chunks."""
|
||||
export_path = Path(args.export_path)
|
||||
|
||||
if not export_path.exists():
|
||||
print(f"ChatGPT export path not found: {export_path}")
|
||||
print(
|
||||
"Please ensure you have exported your ChatGPT data and placed it in the correct location."
|
||||
)
|
||||
print("\nTo export your ChatGPT data:")
|
||||
print("1. Sign in to ChatGPT")
|
||||
print("2. Click on your profile icon → Settings → Data Controls")
|
||||
print("3. Click 'Export' under Export Data")
|
||||
print("4. Download the zip file from the email link")
|
||||
print("5. Extract or place the file/directory at the specified path")
|
||||
return []
|
||||
|
||||
# Find export files
|
||||
export_files = self._find_chatgpt_exports(export_path)
|
||||
|
||||
if not export_files:
|
||||
print(f"No ChatGPT export files (.zip or .html) found in: {export_path}")
|
||||
return []
|
||||
|
||||
print(f"Found {len(export_files)} ChatGPT export files")
|
||||
|
||||
# Create reader with appropriate settings
|
||||
concatenate = args.concatenate_conversations and not args.separate_messages
|
||||
reader = ChatGPTReader(concatenate_conversations=concatenate)
|
||||
|
||||
# Process each export file
|
||||
all_documents = []
|
||||
total_processed = 0
|
||||
|
||||
for i, export_file in enumerate(export_files):
|
||||
print(f"\nProcessing export file {i + 1}/{len(export_files)}: {export_file.name}")
|
||||
|
||||
try:
|
||||
# Apply max_items limit per file
|
||||
max_per_file = -1
|
||||
if args.max_items > 0:
|
||||
remaining = args.max_items - total_processed
|
||||
if remaining <= 0:
|
||||
break
|
||||
max_per_file = remaining
|
||||
|
||||
# Load conversations
|
||||
documents = reader.load_data(
|
||||
chatgpt_export_path=str(export_file),
|
||||
max_count=max_per_file,
|
||||
include_metadata=True,
|
||||
)
|
||||
|
||||
if documents:
|
||||
all_documents.extend(documents)
|
||||
total_processed += len(documents)
|
||||
print(f"Processed {len(documents)} conversations from this file")
|
||||
else:
|
||||
print(f"No conversations loaded from {export_file}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing {export_file}: {e}")
|
||||
continue
|
||||
|
||||
if not all_documents:
|
||||
print("No conversations found to process!")
|
||||
print("\nTroubleshooting:")
|
||||
print("- Ensure the export file is a valid ChatGPT export")
|
||||
print("- Check that the HTML file contains conversation data")
|
||||
print("- Try extracting the zip file and pointing to the HTML file directly")
|
||||
return []
|
||||
|
||||
print(f"\nTotal conversations processed: {len(all_documents)}")
|
||||
print("Now starting to split into text chunks... this may take some time")
|
||||
|
||||
# Convert to text chunks
|
||||
all_texts = create_text_chunks(
|
||||
all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
||||
)
|
||||
|
||||
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} conversations")
|
||||
return all_texts
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
# Example queries for ChatGPT RAG
|
||||
print("\n🤖 ChatGPT RAG Example")
|
||||
print("=" * 50)
|
||||
print("\nExample queries you can try:")
|
||||
print("- 'What did I ask about Python programming?'")
|
||||
print("- 'Show me conversations about machine learning'")
|
||||
print("- 'Find discussions about travel planning'")
|
||||
print("- 'What advice did ChatGPT give me about career development?'")
|
||||
print("- 'Search for conversations about cooking recipes'")
|
||||
print("\nTo get started:")
|
||||
print("1. Export your ChatGPT data from Settings → Data Controls → Export")
|
||||
print("2. Place the downloaded zip file or extracted HTML in ./chatgpt_export/")
|
||||
print("3. Run this script to build your personal ChatGPT knowledge base!")
|
||||
print("\nOr run without --query for interactive mode\n")
|
||||
|
||||
rag = ChatGPTRAG()
|
||||
asyncio.run(rag.run())
|
||||
@@ -12,6 +12,7 @@ from pathlib import Path
|
||||
try:
|
||||
from leann.chunking_utils import (
|
||||
CODE_EXTENSIONS,
|
||||
_traditional_chunks_as_dicts,
|
||||
create_ast_chunks,
|
||||
create_text_chunks,
|
||||
create_traditional_chunks,
|
||||
@@ -25,6 +26,7 @@ except Exception: # pragma: no cover - best-effort fallback for dev environment
|
||||
sys.path.insert(0, str(leann_src))
|
||||
from leann.chunking_utils import (
|
||||
CODE_EXTENSIONS,
|
||||
_traditional_chunks_as_dicts,
|
||||
create_ast_chunks,
|
||||
create_text_chunks,
|
||||
create_traditional_chunks,
|
||||
@@ -36,6 +38,7 @@ except Exception: # pragma: no cover - best-effort fallback for dev environment
|
||||
|
||||
__all__ = [
|
||||
"CODE_EXTENSIONS",
|
||||
"_traditional_chunks_as_dicts",
|
||||
"create_ast_chunks",
|
||||
"create_text_chunks",
|
||||
"create_traditional_chunks",
|
||||
|
||||
0
apps/claude_data/__init__.py
Normal file
420
apps/claude_data/claude_reader.py
Normal file
@@ -0,0 +1,420 @@
|
||||
"""
|
||||
Claude export data reader.
|
||||
|
||||
Reads and processes Claude conversation data from exported JSON files.
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from zipfile import ZipFile
|
||||
|
||||
from llama_index.core import Document
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
|
||||
|
||||
class ClaudeReader(BaseReader):
|
||||
"""
|
||||
Claude export data reader.
|
||||
|
||||
Reads Claude conversation data from exported JSON files or zip archives.
|
||||
Processes conversations into structured documents with metadata.
|
||||
"""
|
||||
|
||||
def __init__(self, concatenate_conversations: bool = True) -> None:
|
||||
"""
|
||||
Initialize.
|
||||
|
||||
Args:
|
||||
concatenate_conversations: Whether to concatenate messages within conversations for better context
|
||||
"""
|
||||
self.concatenate_conversations = concatenate_conversations
|
||||
|
||||
def _extract_json_from_zip(self, zip_path: Path) -> list[str]:
|
||||
"""
|
||||
Extract JSON files from Claude export zip file.
|
||||
|
||||
Args:
|
||||
zip_path: Path to the Claude export zip file
|
||||
|
||||
Returns:
|
||||
List of JSON content strings, or empty list if not found
|
||||
"""
|
||||
json_contents = []
|
||||
try:
|
||||
with ZipFile(zip_path, "r") as zip_file:
|
||||
# Look for JSON files
|
||||
json_files = [f for f in zip_file.namelist() if f.endswith(".json")]
|
||||
|
||||
if not json_files:
|
||||
print(f"No JSON files found in {zip_path}")
|
||||
return []
|
||||
|
||||
print(f"Found {len(json_files)} JSON files in archive")
|
||||
|
||||
for json_file in json_files:
|
||||
with zip_file.open(json_file) as f:
|
||||
content = f.read().decode("utf-8", errors="ignore")
|
||||
json_contents.append(content)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error extracting JSON from zip {zip_path}: {e}")
|
||||
|
||||
return json_contents
|
||||
|
||||
def _parse_claude_json(self, json_content: str) -> list[dict]:
|
||||
"""
|
||||
Parse Claude JSON export to extract conversations.
|
||||
|
||||
Args:
|
||||
json_content: JSON content from Claude export
|
||||
|
||||
Returns:
|
||||
List of conversation dictionaries
|
||||
"""
|
||||
try:
|
||||
data = json.loads(json_content)
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"Error parsing JSON: {e}")
|
||||
return []
|
||||
|
||||
conversations = []
|
||||
|
||||
# Handle different possible JSON structures
|
||||
if isinstance(data, list):
|
||||
# If data is a list of conversations
|
||||
for item in data:
|
||||
conversation = self._extract_conversation_from_json(item)
|
||||
if conversation:
|
||||
conversations.append(conversation)
|
||||
elif isinstance(data, dict):
|
||||
# Check for common structures
|
||||
if "conversations" in data:
|
||||
# Structure: {"conversations": [...]}
|
||||
for item in data["conversations"]:
|
||||
conversation = self._extract_conversation_from_json(item)
|
||||
if conversation:
|
||||
conversations.append(conversation)
|
||||
elif "messages" in data:
|
||||
# Single conversation with messages
|
||||
conversation = self._extract_conversation_from_json(data)
|
||||
if conversation:
|
||||
conversations.append(conversation)
|
||||
else:
|
||||
# Try to treat the whole object as a conversation
|
||||
conversation = self._extract_conversation_from_json(data)
|
||||
if conversation:
|
||||
conversations.append(conversation)
|
||||
|
||||
return conversations
|
||||
|
||||
def _extract_conversation_from_json(self, conv_data: dict) -> dict | None:
|
||||
"""
|
||||
Extract conversation data from a JSON object.
|
||||
|
||||
Args:
|
||||
conv_data: Dictionary containing conversation data
|
||||
|
||||
Returns:
|
||||
Dictionary with conversation data or None
|
||||
"""
|
||||
if not isinstance(conv_data, dict):
|
||||
return None
|
||||
|
||||
messages = []
|
||||
|
||||
# Look for messages in various possible structures
|
||||
message_sources = []
|
||||
if "messages" in conv_data:
|
||||
message_sources = conv_data["messages"]
|
||||
elif "chat" in conv_data:
|
||||
message_sources = conv_data["chat"]
|
||||
elif "conversation" in conv_data:
|
||||
message_sources = conv_data["conversation"]
|
||||
else:
|
||||
# If no clear message structure, try to extract from the object itself
|
||||
if "content" in conv_data and "role" in conv_data:
|
||||
message_sources = [conv_data]
|
||||
|
||||
for msg_data in message_sources:
|
||||
message = self._extract_message_from_json(msg_data)
|
||||
if message:
|
||||
messages.append(message)
|
||||
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
# Extract conversation metadata
|
||||
title = self._extract_title_from_conversation(conv_data, messages)
|
||||
timestamp = self._extract_timestamp_from_conversation(conv_data)
|
||||
|
||||
return {"title": title, "messages": messages, "timestamp": timestamp}
|
||||
|
||||
def _extract_message_from_json(self, msg_data: dict) -> dict | None:
|
||||
"""
|
||||
Extract message data from a JSON message object.
|
||||
|
||||
Args:
|
||||
msg_data: Dictionary containing message data
|
||||
|
||||
Returns:
|
||||
Dictionary with message data or None
|
||||
"""
|
||||
if not isinstance(msg_data, dict):
|
||||
return None
|
||||
|
||||
# Extract content from various possible fields
|
||||
content = ""
|
||||
content_fields = ["content", "text", "message", "body"]
|
||||
for field in content_fields:
|
||||
if msg_data.get(field):
|
||||
content = str(msg_data[field])
|
||||
break
|
||||
|
||||
if not content or len(content.strip()) < 3:
|
||||
return None
|
||||
|
||||
# Extract role (user/assistant/human/ai/claude)
|
||||
role = "mixed" # Default role
|
||||
role_fields = ["role", "sender", "from", "author", "type"]
|
||||
for field in role_fields:
|
||||
if msg_data.get(field):
|
||||
role_value = str(msg_data[field]).lower()
|
||||
if role_value in ["user", "human", "person"]:
|
||||
role = "user"
|
||||
elif role_value in ["assistant", "ai", "claude", "bot"]:
|
||||
role = "assistant"
|
||||
break
|
||||
|
||||
# Extract timestamp
|
||||
timestamp = self._extract_timestamp_from_message(msg_data)
|
||||
|
||||
return {"role": role, "content": content, "timestamp": timestamp}
|
||||
|
||||
def _extract_timestamp_from_message(self, msg_data: dict) -> str | None:
|
||||
"""Extract timestamp from message data."""
|
||||
timestamp_fields = ["timestamp", "created_at", "date", "time"]
|
||||
for field in timestamp_fields:
|
||||
if msg_data.get(field):
|
||||
return str(msg_data[field])
|
||||
return None
|
||||
|
||||
def _extract_timestamp_from_conversation(self, conv_data: dict) -> str | None:
|
||||
"""Extract timestamp from conversation data."""
|
||||
timestamp_fields = ["timestamp", "created_at", "date", "updated_at", "last_updated"]
|
||||
for field in timestamp_fields:
|
||||
if conv_data.get(field):
|
||||
return str(conv_data[field])
|
||||
return None
|
||||
|
||||
def _extract_title_from_conversation(self, conv_data: dict, messages: list) -> str:
|
||||
"""Extract or generate title for conversation."""
|
||||
# Try to find explicit title
|
||||
title_fields = ["title", "name", "subject", "topic"]
|
||||
for field in title_fields:
|
||||
if conv_data.get(field):
|
||||
return str(conv_data[field])
|
||||
|
||||
# Generate title from first user message
|
||||
for message in messages:
|
||||
if message.get("role") == "user":
|
||||
content = message.get("content", "")
|
||||
if content:
|
||||
# Use first 50 characters as title
|
||||
title = content[:50].strip()
|
||||
if len(content) > 50:
|
||||
title += "..."
|
||||
return title
|
||||
|
||||
return "Claude Conversation"
|
||||
|
||||
def _create_concatenated_content(self, conversation: dict) -> str:
|
||||
"""
|
||||
Create concatenated content from conversation messages.
|
||||
|
||||
Args:
|
||||
conversation: Dictionary containing conversation data
|
||||
|
||||
Returns:
|
||||
Formatted concatenated content
|
||||
"""
|
||||
title = conversation.get("title", "Claude Conversation")
|
||||
messages = conversation.get("messages", [])
|
||||
timestamp = conversation.get("timestamp", "Unknown")
|
||||
|
||||
# Build message content
|
||||
message_parts = []
|
||||
for message in messages:
|
||||
role = message.get("role", "mixed")
|
||||
content = message.get("content", "")
|
||||
msg_timestamp = message.get("timestamp", "")
|
||||
|
||||
if role == "user":
|
||||
prefix = "[You]"
|
||||
elif role == "assistant":
|
||||
prefix = "[Claude]"
|
||||
else:
|
||||
prefix = "[Message]"
|
||||
|
||||
# Add timestamp if available
|
||||
if msg_timestamp:
|
||||
prefix += f" ({msg_timestamp})"
|
||||
|
||||
message_parts.append(f"{prefix}: {content}")
|
||||
|
||||
concatenated_text = "\n\n".join(message_parts)
|
||||
|
||||
# Create final document content
|
||||
doc_content = f"""Conversation: {title}
|
||||
Date: {timestamp}
|
||||
Messages ({len(messages)} messages):
|
||||
|
||||
{concatenated_text}
|
||||
"""
|
||||
return doc_content
|
||||
|
||||
def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
|
||||
"""
|
||||
Load Claude export data.
|
||||
|
||||
Args:
|
||||
input_dir: Directory containing Claude export files or path to specific file
|
||||
**load_kwargs:
|
||||
max_count (int): Maximum number of conversations to process
|
||||
claude_export_path (str): Specific path to Claude export file/directory
|
||||
include_metadata (bool): Whether to include metadata in documents
|
||||
"""
|
||||
docs: list[Document] = []
|
||||
max_count = load_kwargs.get("max_count", -1)
|
||||
claude_export_path = load_kwargs.get("claude_export_path", input_dir)
|
||||
include_metadata = load_kwargs.get("include_metadata", True)
|
||||
|
||||
if not claude_export_path:
|
||||
print("No Claude export path provided")
|
||||
return docs
|
||||
|
||||
export_path = Path(claude_export_path)
|
||||
|
||||
if not export_path.exists():
|
||||
print(f"Claude export path not found: {export_path}")
|
||||
return docs
|
||||
|
||||
json_contents = []
|
||||
|
||||
# Handle different input types
|
||||
if export_path.is_file():
|
||||
if export_path.suffix.lower() == ".zip":
|
||||
# Extract JSON from zip file
|
||||
json_contents = self._extract_json_from_zip(export_path)
|
||||
elif export_path.suffix.lower() == ".json":
|
||||
# Read JSON file directly
|
||||
try:
|
||||
with open(export_path, encoding="utf-8", errors="ignore") as f:
|
||||
json_contents.append(f.read())
|
||||
except Exception as e:
|
||||
print(f"Error reading JSON file {export_path}: {e}")
|
||||
return docs
|
||||
else:
|
||||
print(f"Unsupported file type: {export_path.suffix}")
|
||||
return docs
|
||||
|
||||
elif export_path.is_dir():
|
||||
# Look for JSON files in directory
|
||||
json_files = list(export_path.glob("*.json"))
|
||||
zip_files = list(export_path.glob("*.zip"))
|
||||
|
||||
if json_files:
|
||||
print(f"Found {len(json_files)} JSON files in directory")
|
||||
for json_file in json_files:
|
||||
try:
|
||||
with open(json_file, encoding="utf-8", errors="ignore") as f:
|
||||
json_contents.append(f.read())
|
||||
except Exception as e:
|
||||
print(f"Error reading JSON file {json_file}: {e}")
|
||||
continue
|
||||
|
||||
if zip_files:
|
||||
print(f"Found {len(zip_files)} ZIP files in directory")
|
||||
for zip_file in zip_files:
|
||||
zip_contents = self._extract_json_from_zip(zip_file)
|
||||
json_contents.extend(zip_contents)
|
||||
|
||||
if not json_files and not zip_files:
|
||||
print(f"No JSON or ZIP files found in {export_path}")
|
||||
return docs
|
||||
|
||||
if not json_contents:
|
||||
print("No JSON content found to process")
|
||||
return docs
|
||||
|
||||
# Parse conversations from JSON content
|
||||
print("Parsing Claude conversations from JSON...")
|
||||
all_conversations = []
|
||||
for json_content in json_contents:
|
||||
conversations = self._parse_claude_json(json_content)
|
||||
all_conversations.extend(conversations)
|
||||
|
||||
if not all_conversations:
|
||||
print("No conversations found in JSON content")
|
||||
return docs
|
||||
|
||||
print(f"Found {len(all_conversations)} conversations")
|
||||
|
||||
# Process conversations into documents
|
||||
count = 0
|
||||
for conversation in all_conversations:
|
||||
if max_count > 0 and count >= max_count:
|
||||
break
|
||||
|
||||
if self.concatenate_conversations:
|
||||
# Create one document per conversation with concatenated messages
|
||||
doc_content = self._create_concatenated_content(conversation)
|
||||
|
||||
metadata = {}
|
||||
if include_metadata:
|
||||
metadata = {
|
||||
"title": conversation.get("title", "Claude Conversation"),
|
||||
"timestamp": conversation.get("timestamp", "Unknown"),
|
||||
"message_count": len(conversation.get("messages", [])),
|
||||
"source": "Claude Export",
|
||||
}
|
||||
|
||||
doc = Document(text=doc_content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
|
||||
else:
|
||||
# Create separate documents for each message
|
||||
for message in conversation.get("messages", []):
|
||||
if max_count > 0 and count >= max_count:
|
||||
break
|
||||
|
||||
role = message.get("role", "mixed")
|
||||
content = message.get("content", "")
|
||||
msg_timestamp = message.get("timestamp", "")
|
||||
|
||||
if not content.strip():
|
||||
continue
|
||||
|
||||
# Create document content with context
|
||||
doc_content = f"""Conversation: {conversation.get("title", "Claude Conversation")}
|
||||
Role: {role}
|
||||
Timestamp: {msg_timestamp or conversation.get("timestamp", "Unknown")}
|
||||
Message: {content}
|
||||
"""
|
||||
|
||||
metadata = {}
|
||||
if include_metadata:
|
||||
metadata = {
|
||||
"conversation_title": conversation.get("title", "Claude Conversation"),
|
||||
"role": role,
|
||||
"timestamp": msg_timestamp or conversation.get("timestamp", "Unknown"),
|
||||
"source": "Claude Export",
|
||||
}
|
||||
|
||||
doc = Document(text=doc_content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
|
||||
print(f"Created {len(docs)} documents from Claude export")
|
||||
return docs
|
||||
189
apps/claude_rag.py
Normal file
@@ -0,0 +1,189 @@
|
||||
"""
|
||||
Claude RAG example using the unified interface.
|
||||
Supports Claude export data from JSON files.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from base_rag_example import BaseRAGExample
|
||||
from chunking import create_text_chunks
|
||||
|
||||
from .claude_data.claude_reader import ClaudeReader
|
||||
|
||||
|
||||
class ClaudeRAG(BaseRAGExample):
|
||||
"""RAG example for Claude conversation data."""
|
||||
|
||||
def __init__(self):
|
||||
# Set default values BEFORE calling super().__init__
|
||||
self.max_items_default = -1 # Process all conversations by default
|
||||
self.embedding_model_default = (
|
||||
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
name="Claude",
|
||||
description="Process and query Claude conversation exports with LEANN",
|
||||
default_index_name="claude_conversations_index",
|
||||
)
|
||||
|
||||
def _add_specific_arguments(self, parser):
|
||||
"""Add Claude-specific arguments."""
|
||||
claude_group = parser.add_argument_group("Claude Parameters")
|
||||
claude_group.add_argument(
|
||||
"--export-path",
|
||||
type=str,
|
||||
default="./claude_export",
|
||||
help="Path to Claude export file (.json or .zip) or directory containing exports (default: ./claude_export)",
|
||||
)
|
||||
claude_group.add_argument(
|
||||
"--concatenate-conversations",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Concatenate messages within conversations for better context (default: True)",
|
||||
)
|
||||
claude_group.add_argument(
|
||||
"--separate-messages",
|
||||
action="store_true",
|
||||
help="Process each message as a separate document (overrides --concatenate-conversations)",
|
||||
)
|
||||
claude_group.add_argument(
|
||||
"--chunk-size", type=int, default=512, help="Text chunk size (default: 512)"
|
||||
)
|
||||
claude_group.add_argument(
|
||||
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
|
||||
)
|
||||
|
||||
def _find_claude_exports(self, export_path: Path) -> list[Path]:
|
||||
"""
|
||||
Find Claude export files in the given path.
|
||||
|
||||
Args:
|
||||
export_path: Path to search for exports
|
||||
|
||||
Returns:
|
||||
List of paths to Claude export files
|
||||
"""
|
||||
export_files = []
|
||||
|
||||
if export_path.is_file():
|
||||
if export_path.suffix.lower() in [".zip", ".json"]:
|
||||
export_files.append(export_path)
|
||||
elif export_path.is_dir():
|
||||
# Look for zip and json files
|
||||
export_files.extend(export_path.glob("*.zip"))
|
||||
export_files.extend(export_path.glob("*.json"))
|
||||
|
||||
return export_files
|
||||
|
||||
async def load_data(self, args) -> list[str]:
|
||||
"""Load Claude export data and convert to text chunks."""
|
||||
export_path = Path(args.export_path)
|
||||
|
||||
if not export_path.exists():
|
||||
print(f"Claude export path not found: {export_path}")
|
||||
print(
|
||||
"Please ensure you have exported your Claude data and placed it in the correct location."
|
||||
)
|
||||
print("\nTo export your Claude data:")
|
||||
print("1. Open Claude in your browser")
|
||||
print("2. Look for export/download options in settings or conversation menu")
|
||||
print("3. Download the conversation data (usually in JSON format)")
|
||||
print("4. Place the file/directory at the specified path")
|
||||
print(
|
||||
"\nNote: Claude export methods may vary. Check Claude's help documentation for current instructions."
|
||||
)
|
||||
return []
|
||||
|
||||
# Find export files
|
||||
export_files = self._find_claude_exports(export_path)
|
||||
|
||||
if not export_files:
|
||||
print(f"No Claude export files (.json or .zip) found in: {export_path}")
|
||||
return []
|
||||
|
||||
print(f"Found {len(export_files)} Claude export files")
|
||||
|
||||
# Create reader with appropriate settings
|
||||
concatenate = args.concatenate_conversations and not args.separate_messages
|
||||
reader = ClaudeReader(concatenate_conversations=concatenate)
|
||||
|
||||
# Process each export file
|
||||
all_documents = []
|
||||
total_processed = 0
|
||||
|
||||
for i, export_file in enumerate(export_files):
|
||||
print(f"\nProcessing export file {i + 1}/{len(export_files)}: {export_file.name}")
|
||||
|
||||
try:
|
||||
# Apply max_items limit per file
|
||||
max_per_file = -1
|
||||
if args.max_items > 0:
|
||||
remaining = args.max_items - total_processed
|
||||
if remaining <= 0:
|
||||
break
|
||||
max_per_file = remaining
|
||||
|
||||
# Load conversations
|
||||
documents = reader.load_data(
|
||||
claude_export_path=str(export_file),
|
||||
max_count=max_per_file,
|
||||
include_metadata=True,
|
||||
)
|
||||
|
||||
if documents:
|
||||
all_documents.extend(documents)
|
||||
total_processed += len(documents)
|
||||
print(f"Processed {len(documents)} conversations from this file")
|
||||
else:
|
||||
print(f"No conversations loaded from {export_file}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing {export_file}: {e}")
|
||||
continue
|
||||
|
||||
if not all_documents:
|
||||
print("No conversations found to process!")
|
||||
print("\nTroubleshooting:")
|
||||
print("- Ensure the export file is a valid Claude export")
|
||||
print("- Check that the JSON file contains conversation data")
|
||||
print("- Try using a different export format or method")
|
||||
print("- Check Claude's documentation for current export procedures")
|
||||
return []
|
||||
|
||||
print(f"\nTotal conversations processed: {len(all_documents)}")
|
||||
print("Now starting to split into text chunks... this may take some time")
|
||||
|
||||
# Convert to text chunks
|
||||
all_texts = create_text_chunks(
|
||||
all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
||||
)
|
||||
|
||||
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} conversations")
|
||||
return all_texts
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
# Example queries for Claude RAG
|
||||
print("\n🤖 Claude RAG Example")
|
||||
print("=" * 50)
|
||||
print("\nExample queries you can try:")
|
||||
print("- 'What did I ask Claude about Python programming?'")
|
||||
print("- 'Show me conversations about machine learning'")
|
||||
print("- 'Find discussions about code optimization'")
|
||||
print("- 'What advice did Claude give me about software design?'")
|
||||
print("- 'Search for conversations about debugging techniques'")
|
||||
print("\nTo get started:")
|
||||
print("1. Export your Claude conversation data")
|
||||
print("2. Place the JSON/ZIP file in ./claude_export/")
|
||||
print("3. Run this script to build your personal Claude knowledge base!")
|
||||
print("\nOr run without --query for interactive mode\n")
|
||||
|
||||
rag = ClaudeRAG()
|
||||
asyncio.run(rag.run())
|
||||
1
apps/imessage_data/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""iMessage data processing module."""
|
||||
342
apps/imessage_data/imessage_reader.py
Normal file
@@ -0,0 +1,342 @@
|
||||
"""
|
||||
iMessage data reader.
|
||||
|
||||
Reads and processes iMessage conversation data from the macOS Messages database.
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from llama_index.core import Document
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
|
||||
|
||||
class IMessageReader(BaseReader):
|
||||
"""
|
||||
iMessage data reader.
|
||||
|
||||
Reads iMessage conversation data from the macOS Messages database (chat.db).
|
||||
Processes conversations into structured documents with metadata.
|
||||
"""
|
||||
|
||||
def __init__(self, concatenate_conversations: bool = True) -> None:
|
||||
"""
|
||||
Initialize.
|
||||
|
||||
Args:
|
||||
concatenate_conversations: Whether to concatenate messages within conversations for better context
|
||||
"""
|
||||
self.concatenate_conversations = concatenate_conversations
|
||||
|
||||
def _get_default_chat_db_path(self) -> Path:
|
||||
"""
|
||||
Get the default path to the iMessage chat database.
|
||||
|
||||
Returns:
|
||||
Path to the chat.db file
|
||||
"""
|
||||
home = Path.home()
|
||||
return home / "Library" / "Messages" / "chat.db"
|
||||
|
||||
def _convert_cocoa_timestamp(self, cocoa_timestamp: int) -> str:
|
||||
"""
|
||||
Convert Cocoa timestamp to readable format.
|
||||
|
||||
Args:
|
||||
cocoa_timestamp: Timestamp in Cocoa format (nanoseconds since 2001-01-01)
|
||||
|
||||
Returns:
|
||||
Formatted timestamp string
|
||||
"""
|
||||
if cocoa_timestamp == 0:
|
||||
return "Unknown"
|
||||
|
||||
try:
|
||||
# Cocoa timestamp is nanoseconds since 2001-01-01 00:00:00 UTC
|
||||
# Convert to seconds and add to Unix epoch
|
||||
cocoa_epoch = datetime(2001, 1, 1)
|
||||
unix_timestamp = cocoa_timestamp / 1_000_000_000 # Convert nanoseconds to seconds
|
||||
message_time = cocoa_epoch.timestamp() + unix_timestamp
|
||||
return datetime.fromtimestamp(message_time).strftime("%Y-%m-%d %H:%M:%S")
|
||||
except (ValueError, OSError):
|
||||
return "Unknown"
|
||||
|
||||
def _get_contact_name(self, handle_id: str) -> str:
|
||||
"""
|
||||
Get a readable contact name from handle ID.
|
||||
|
||||
Args:
|
||||
handle_id: The handle ID (phone number or email)
|
||||
|
||||
Returns:
|
||||
Formatted contact name
|
||||
"""
|
||||
if not handle_id:
|
||||
return "Unknown"
|
||||
|
||||
# Clean up phone numbers and emails for display
|
||||
if "@" in handle_id:
|
||||
return handle_id # Email address
|
||||
elif handle_id.startswith("+"):
|
||||
return handle_id # International phone number
|
||||
else:
|
||||
# Try to format as phone number
|
||||
digits = "".join(filter(str.isdigit, handle_id))
|
||||
if len(digits) == 10:
|
||||
return f"({digits[:3]}) {digits[3:6]}-{digits[6:]}"
|
||||
elif len(digits) == 11 and digits[0] == "1":
|
||||
return f"+1 ({digits[1:4]}) {digits[4:7]}-{digits[7:]}"
|
||||
else:
|
||||
return handle_id
|
||||
|
||||
def _read_messages_from_db(self, db_path: Path) -> list[dict]:
|
||||
"""
|
||||
Read messages from the iMessage database.
|
||||
|
||||
Args:
|
||||
db_path: Path to the chat.db file
|
||||
|
||||
Returns:
|
||||
List of message dictionaries
|
||||
"""
|
||||
if not db_path.exists():
|
||||
print(f"iMessage database not found at: {db_path}")
|
||||
return []
|
||||
|
||||
try:
|
||||
# Connect to the database
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Query to get messages with chat and handle information
|
||||
query = """
|
||||
SELECT
|
||||
m.ROWID as message_id,
|
||||
m.text,
|
||||
m.date,
|
||||
m.is_from_me,
|
||||
m.service,
|
||||
c.chat_identifier,
|
||||
c.display_name as chat_display_name,
|
||||
h.id as handle_id,
|
||||
c.ROWID as chat_id
|
||||
FROM message m
|
||||
LEFT JOIN chat_message_join cmj ON m.ROWID = cmj.message_id
|
||||
LEFT JOIN chat c ON cmj.chat_id = c.ROWID
|
||||
LEFT JOIN handle h ON m.handle_id = h.ROWID
|
||||
WHERE m.text IS NOT NULL AND m.text != ''
|
||||
ORDER BY c.ROWID, m.date
|
||||
"""
|
||||
|
||||
cursor.execute(query)
|
||||
rows = cursor.fetchall()
|
||||
|
||||
messages = []
|
||||
for row in rows:
|
||||
(
|
||||
message_id,
|
||||
text,
|
||||
date,
|
||||
is_from_me,
|
||||
service,
|
||||
chat_identifier,
|
||||
chat_display_name,
|
||||
handle_id,
|
||||
chat_id,
|
||||
) = row
|
||||
|
||||
message = {
|
||||
"message_id": message_id,
|
||||
"text": text,
|
||||
"timestamp": self._convert_cocoa_timestamp(date),
|
||||
"is_from_me": bool(is_from_me),
|
||||
"service": service or "iMessage",
|
||||
"chat_identifier": chat_identifier or "Unknown",
|
||||
"chat_display_name": chat_display_name or "Unknown Chat",
|
||||
"handle_id": handle_id or "Unknown",
|
||||
"contact_name": self._get_contact_name(handle_id or ""),
|
||||
"chat_id": chat_id,
|
||||
}
|
||||
messages.append(message)
|
||||
|
||||
conn.close()
|
||||
print(f"Found {len(messages)} messages in database")
|
||||
return messages
|
||||
|
||||
except sqlite3.Error as e:
|
||||
print(f"Error reading iMessage database: {e}")
|
||||
return []
|
||||
except Exception as e:
|
||||
print(f"Unexpected error reading iMessage database: {e}")
|
||||
return []
|
||||
|
||||
def _group_messages_by_chat(self, messages: list[dict]) -> dict[int, list[dict]]:
|
||||
"""
|
||||
Group messages by chat ID.
|
||||
|
||||
Args:
|
||||
messages: List of message dictionaries
|
||||
|
||||
Returns:
|
||||
Dictionary mapping chat_id to list of messages
|
||||
"""
|
||||
chats = {}
|
||||
for message in messages:
|
||||
chat_id = message["chat_id"]
|
||||
if chat_id not in chats:
|
||||
chats[chat_id] = []
|
||||
chats[chat_id].append(message)
|
||||
|
||||
return chats
|
||||
|
||||
def _create_concatenated_content(self, chat_id: int, messages: list[dict]) -> str:
|
||||
"""
|
||||
Create concatenated content from chat messages.
|
||||
|
||||
Args:
|
||||
chat_id: The chat ID
|
||||
messages: List of messages in the chat
|
||||
|
||||
Returns:
|
||||
Concatenated text content
|
||||
"""
|
||||
if not messages:
|
||||
return ""
|
||||
|
||||
# Get chat info from first message
|
||||
first_msg = messages[0]
|
||||
chat_name = first_msg["chat_display_name"]
|
||||
chat_identifier = first_msg["chat_identifier"]
|
||||
|
||||
# Build message content
|
||||
message_parts = []
|
||||
for message in messages:
|
||||
timestamp = message["timestamp"]
|
||||
is_from_me = message["is_from_me"]
|
||||
text = message["text"]
|
||||
contact_name = message["contact_name"]
|
||||
|
||||
if is_from_me:
|
||||
prefix = "[You]"
|
||||
else:
|
||||
prefix = f"[{contact_name}]"
|
||||
|
||||
if timestamp != "Unknown":
|
||||
prefix += f" ({timestamp})"
|
||||
|
||||
message_parts.append(f"{prefix}: {text}")
|
||||
|
||||
concatenated_text = "\n\n".join(message_parts)
|
||||
|
||||
doc_content = f"""Chat: {chat_name}
|
||||
Identifier: {chat_identifier}
|
||||
Messages ({len(messages)} messages):
|
||||
|
||||
{concatenated_text}
|
||||
"""
|
||||
return doc_content
|
||||
|
||||
def _create_individual_content(self, message: dict) -> str:
|
||||
"""
|
||||
Create content for individual message.
|
||||
|
||||
Args:
|
||||
message: Message dictionary
|
||||
|
||||
Returns:
|
||||
Formatted message content
|
||||
"""
|
||||
timestamp = message["timestamp"]
|
||||
is_from_me = message["is_from_me"]
|
||||
text = message["text"]
|
||||
contact_name = message["contact_name"]
|
||||
chat_name = message["chat_display_name"]
|
||||
|
||||
sender = "You" if is_from_me else contact_name
|
||||
|
||||
return f"""Message from {sender} in chat "{chat_name}"
|
||||
Time: {timestamp}
|
||||
Content: {text}
|
||||
"""
|
||||
|
||||
def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
|
||||
"""
|
||||
Load iMessage data and return as documents.
|
||||
|
||||
Args:
|
||||
input_dir: Optional path to directory containing chat.db file.
|
||||
If not provided, uses default macOS location.
|
||||
**load_kwargs: Additional arguments (unused)
|
||||
|
||||
Returns:
|
||||
List of Document objects containing iMessage data
|
||||
"""
|
||||
docs = []
|
||||
|
||||
# Determine database path
|
||||
if input_dir:
|
||||
db_path = Path(input_dir) / "chat.db"
|
||||
else:
|
||||
db_path = self._get_default_chat_db_path()
|
||||
|
||||
print(f"Reading iMessage database from: {db_path}")
|
||||
|
||||
# Read messages from database
|
||||
messages = self._read_messages_from_db(db_path)
|
||||
if not messages:
|
||||
return docs
|
||||
|
||||
if self.concatenate_conversations:
|
||||
# Group messages by chat and create concatenated documents
|
||||
chats = self._group_messages_by_chat(messages)
|
||||
|
||||
for chat_id, chat_messages in chats.items():
|
||||
if not chat_messages:
|
||||
continue
|
||||
|
||||
content = self._create_concatenated_content(chat_id, chat_messages)
|
||||
|
||||
# Create metadata
|
||||
first_msg = chat_messages[0]
|
||||
last_msg = chat_messages[-1]
|
||||
|
||||
metadata = {
|
||||
"source": "iMessage",
|
||||
"chat_id": chat_id,
|
||||
"chat_name": first_msg["chat_display_name"],
|
||||
"chat_identifier": first_msg["chat_identifier"],
|
||||
"message_count": len(chat_messages),
|
||||
"first_message_date": first_msg["timestamp"],
|
||||
"last_message_date": last_msg["timestamp"],
|
||||
"participants": list(
|
||||
{msg["contact_name"] for msg in chat_messages if not msg["is_from_me"]}
|
||||
),
|
||||
}
|
||||
|
||||
doc = Document(text=content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
|
||||
else:
|
||||
# Create individual documents for each message
|
||||
for message in messages:
|
||||
content = self._create_individual_content(message)
|
||||
|
||||
metadata = {
|
||||
"source": "iMessage",
|
||||
"message_id": message["message_id"],
|
||||
"chat_id": message["chat_id"],
|
||||
"chat_name": message["chat_display_name"],
|
||||
"chat_identifier": message["chat_identifier"],
|
||||
"timestamp": message["timestamp"],
|
||||
"is_from_me": message["is_from_me"],
|
||||
"contact_name": message["contact_name"],
|
||||
"service": message["service"],
|
||||
}
|
||||
|
||||
doc = Document(text=content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
|
||||
print(f"Created {len(docs)} documents from iMessage data")
|
||||
return docs
|
||||
125
apps/imessage_rag.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""
|
||||
iMessage RAG Example.
|
||||
|
||||
This example demonstrates how to build a RAG system on your iMessage conversation history.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
from leann.chunking_utils import create_text_chunks
|
||||
|
||||
from apps.base_rag_example import BaseRAGExample
|
||||
from apps.imessage_data.imessage_reader import IMessageReader
|
||||
|
||||
|
||||
class IMessageRAG(BaseRAGExample):
|
||||
"""RAG example for iMessage conversation history."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="iMessage",
|
||||
description="RAG on your iMessage conversation history",
|
||||
default_index_name="imessage_index",
|
||||
)
|
||||
|
||||
def _add_specific_arguments(self, parser):
|
||||
"""Add iMessage-specific arguments."""
|
||||
imessage_group = parser.add_argument_group("iMessage Parameters")
|
||||
imessage_group.add_argument(
|
||||
"--db-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to iMessage chat.db file (default: ~/Library/Messages/chat.db)",
|
||||
)
|
||||
imessage_group.add_argument(
|
||||
"--concatenate-conversations",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Concatenate messages within conversations for better context (default: True)",
|
||||
)
|
||||
imessage_group.add_argument(
|
||||
"--no-concatenate-conversations",
|
||||
action="store_true",
|
||||
help="Process each message individually instead of concatenating by conversation",
|
||||
)
|
||||
imessage_group.add_argument(
|
||||
"--chunk-size",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Maximum characters per text chunk (default: 1000)",
|
||||
)
|
||||
imessage_group.add_argument(
|
||||
"--chunk-overlap",
|
||||
type=int,
|
||||
default=200,
|
||||
help="Overlap between text chunks (default: 200)",
|
||||
)
|
||||
|
||||
async def load_data(self, args) -> list[str]:
|
||||
"""Load iMessage history and convert to text chunks."""
|
||||
print("Loading iMessage conversation history...")
|
||||
|
||||
# Determine concatenation setting
|
||||
concatenate = args.concatenate_conversations and not args.no_concatenate_conversations
|
||||
|
||||
# Initialize iMessage reader
|
||||
reader = IMessageReader(concatenate_conversations=concatenate)
|
||||
|
||||
# Load documents
|
||||
try:
|
||||
if args.db_path:
|
||||
# Use custom database path
|
||||
db_dir = str(Path(args.db_path).parent)
|
||||
documents = reader.load_data(input_dir=db_dir)
|
||||
else:
|
||||
# Use default macOS location
|
||||
documents = reader.load_data()
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error loading iMessage data: {e}")
|
||||
print("\nTroubleshooting tips:")
|
||||
print("1. Make sure you have granted Full Disk Access to your terminal/IDE")
|
||||
print("2. Check that the iMessage database exists at ~/Library/Messages/chat.db")
|
||||
print("3. Try specifying a custom path with --db-path if you have a backup")
|
||||
return []
|
||||
|
||||
if not documents:
|
||||
print("No iMessage conversations found!")
|
||||
return []
|
||||
|
||||
print(f"Loaded {len(documents)} iMessage documents")
|
||||
|
||||
# Show some statistics
|
||||
total_messages = sum(doc.metadata.get("message_count", 1) for doc in documents)
|
||||
print(f"Total messages: {total_messages}")
|
||||
|
||||
if concatenate:
|
||||
# Show chat statistics
|
||||
chat_names = [doc.metadata.get("chat_name", "Unknown") for doc in documents]
|
||||
unique_chats = len(set(chat_names))
|
||||
print(f"Unique conversations: {unique_chats}")
|
||||
|
||||
# Convert to text chunks
|
||||
all_texts = create_text_chunks(
|
||||
documents,
|
||||
chunk_size=args.chunk_size,
|
||||
chunk_overlap=args.chunk_overlap,
|
||||
)
|
||||
|
||||
# Apply max_items limit if specified
|
||||
if args.max_items > 0:
|
||||
all_texts = all_texts[: args.max_items]
|
||||
print(f"Limited to {len(all_texts)} text chunks (max_items={args.max_items})")
|
||||
|
||||
return all_texts
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main entry point."""
|
||||
app = IMessageRAG()
|
||||
await app.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
113
apps/multimodal/vision-based-pdf-multi-vector/README.md
Normal file
@@ -0,0 +1,113 @@
|
||||
## Vision-based PDF Multi-Vector Demos (macOS/MPS)
|
||||
|
||||
This folder contains two demos to index PDF pages as images and run multi-vector retrieval with ColPali/ColQwen2, plus optional similarity map visualization and answer generation.
|
||||
|
||||
### What you’ll run
|
||||
- `multi-vector-leann-paper-example.py`: local PDF → pages → embed → build HNSW index → search.
|
||||
- `multi-vector-leann-similarity-map.py`: HF dataset (default) or local pages → embed → index → retrieve → similarity maps → optional Qwen-VL answer.
|
||||
|
||||
## Prerequisites (macOS)
|
||||
|
||||
### 1) Homebrew poppler (for pdf2image)
|
||||
```bash
|
||||
brew install poppler
|
||||
which pdfinfo && pdfinfo -v
|
||||
```
|
||||
|
||||
### 2) Python environment
|
||||
Use uv (recommended) or pip. Python 3.9+.
|
||||
|
||||
Using uv:
|
||||
```bash
|
||||
uv pip install \
|
||||
colpali_engine \
|
||||
pdf2image \
|
||||
pillow \
|
||||
matplotlib qwen_vl_utils \
|
||||
einops \
|
||||
seaborn
|
||||
```
|
||||
|
||||
Notes:
|
||||
- On first run, models download from Hugging Face. Login/config if needed.
|
||||
- The scripts auto-select device: CUDA > MPS > CPU. Verify MPS:
|
||||
```bash
|
||||
python -c "import torch; print('MPS available:', bool(getattr(torch.backends, 'mps', None) and torch.backends.mps.is_available()))"
|
||||
```
|
||||
|
||||
## Run the demos
|
||||
|
||||
### A) Local PDF example
|
||||
Converts a local PDF into page images, embeds them, builds an index, and searches.
|
||||
|
||||
```bash
|
||||
cd apps/multimodal/vision-based-pdf-multi-vector
|
||||
# If you don't have the sample PDF locally, download it (ignored by Git)
|
||||
mkdir -p pdfs
|
||||
curl -L -o pdfs/2004.12832v2.pdf https://arxiv.org/pdf/2004.12832.pdf
|
||||
ls pdfs/2004.12832v2.pdf
|
||||
# Ensure output dir exists
|
||||
mkdir -p pages
|
||||
python multi-vector-leann-paper-example.py
|
||||
```
|
||||
Expected:
|
||||
- Page images in `pages/`.
|
||||
- Console prints like `Using device=mps, dtype=...` and retrieved file paths for queries.
|
||||
|
||||
To use your own PDF: edit `pdf_path` near the top of the script.
|
||||
|
||||
### B) Similarity map + answer demo
|
||||
Uses HF dataset `weaviate/arXiv-AI-papers-multi-vector` by default; can switch to local pages.
|
||||
|
||||
```bash
|
||||
cd apps/multimodal/vision-based-pdf-multi-vector
|
||||
python multi-vector-leann-similarity-map.py
|
||||
```
|
||||
Artifacts (when enabled):
|
||||
- Retrieved pages: `./figures/retrieved_page_rank{K}.png`
|
||||
- Similarity maps: `./figures/similarity_map_rank{K}.png`
|
||||
|
||||
Key knobs in the script (top of file):
|
||||
- `QUERY`: your question
|
||||
- `MODEL`: `"colqwen2"` or `"colpali"`
|
||||
- `USE_HF_DATASET`: set `False` to use local pages
|
||||
- `PDF`, `PAGES_DIR`: for local mode
|
||||
- `INDEX_PATH`, `TOPK`, `FIRST_STAGE_K`, `REBUILD_INDEX`
|
||||
- `SIMILARITY_MAP`, `SIM_TOKEN_IDX`, `SIM_OUTPUT`
|
||||
- `ANSWER`, `MAX_NEW_TOKENS` (Qwen-VL)
|
||||
|
||||
## Troubleshooting
|
||||
- pdf2image errors on macOS: ensure `brew install poppler` and `pdfinfo` works in terminal.
|
||||
- Slow or OOM on MPS: reduce dataset size (e.g., set `MAX_DOCS`) or switch to CPU.
|
||||
- NaNs on MPS: keep fp32 on MPS (default in similarity-map script); avoid fp16 there.
|
||||
- First-run model downloads can be large; ensure network access (HF mirrors if needed).
|
||||
|
||||
## Notes
|
||||
- Index files are under `./indexes/`. Delete or set `REBUILD_INDEX=True` to rebuild.
|
||||
- For local PDFs, page images go to `./pages/`.
|
||||
|
||||
|
||||
### Retrieval and Visualization Example
|
||||
|
||||
Example settings in `multi-vector-leann-similarity-map.py`:
|
||||
- `QUERY = "How does DeepSeek-V2 compare against the LLaMA family of LLMs?"`
|
||||
- `SIMILARITY_MAP = True` (to generate heatmaps)
|
||||
- `TOPK = 1` (save the top retrieved page and its similarity map)
|
||||
|
||||
Run:
|
||||
```bash
|
||||
cd apps/multimodal/vision-based-pdf-multi-vector
|
||||
python multi-vector-leann-similarity-map.py
|
||||
```
|
||||
|
||||
Outputs (by default):
|
||||
- Retrieved page: `./figures/retrieved_page_rank1.png`
|
||||
- Similarity map: `./figures/similarity_map_rank1.png`
|
||||
|
||||
Sample visualization (example result, and the query is "QUERY = "How does Vim model performance and efficiency compared to other models?"
|
||||
"):
|
||||

|
||||
|
||||
Notes:
|
||||
- Set `SIM_TOKEN_IDX` to visualize a specific token index; set `-1` to auto-select the most salient token.
|
||||
- If you change `SIM_OUTPUT` to a file path (e.g., `./figures/my_map.png`), multiple ranks are saved as `my_map_rank{K}.png`.
|
||||
132
apps/multimodal/vision-based-pdf-multi-vector/colqwen_forward.py
Executable file
@@ -0,0 +1,132 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Simple test script to test colqwen2 forward pass with a single image."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add the current directory to path to import leann_multi_vector
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
import torch
|
||||
from leann_multi_vector import _embed_images, _ensure_repo_paths_importable, _load_colvision
|
||||
from PIL import Image
|
||||
|
||||
# Ensure repo paths are importable
|
||||
_ensure_repo_paths_importable(__file__)
|
||||
|
||||
# Set environment variable
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
|
||||
def create_test_image():
|
||||
"""Create a simple test image."""
|
||||
# Create a simple RGB image (800x600)
|
||||
img = Image.new("RGB", (800, 600), color="white")
|
||||
return img
|
||||
|
||||
|
||||
def load_test_image_from_file():
|
||||
"""Try to load an image from the indexes directory if available."""
|
||||
# Try to find an existing image in the indexes directory
|
||||
indexes_dir = Path(__file__).parent / "indexes"
|
||||
|
||||
# Look for images in common locations
|
||||
possible_paths = [
|
||||
indexes_dir / "vidore_fastplaid" / "images",
|
||||
indexes_dir / "colvision_large.leann.images",
|
||||
indexes_dir / "colvision.leann.images",
|
||||
]
|
||||
|
||||
for img_dir in possible_paths:
|
||||
if img_dir.exists():
|
||||
# Find first image file
|
||||
for ext in [".png", ".jpg", ".jpeg"]:
|
||||
for img_file in img_dir.glob(f"*{ext}"):
|
||||
print(f"Loading test image from: {img_file}")
|
||||
return Image.open(img_file)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def main():
|
||||
print("=" * 60)
|
||||
print("Testing ColQwen2 Forward Pass")
|
||||
print("=" * 60)
|
||||
|
||||
# Step 1: Load or create test image
|
||||
print("\n[Step 1] Loading test image...")
|
||||
test_image = load_test_image_from_file()
|
||||
if test_image is None:
|
||||
print("No existing image found, creating a simple test image...")
|
||||
test_image = create_test_image()
|
||||
else:
|
||||
print(f"✓ Loaded image: {test_image.size} ({test_image.mode})")
|
||||
|
||||
# Convert to RGB if needed
|
||||
if test_image.mode != "RGB":
|
||||
test_image = test_image.convert("RGB")
|
||||
print(f"✓ Converted to RGB: {test_image.size}")
|
||||
|
||||
# Step 2: Load model
|
||||
print("\n[Step 2] Loading ColQwen2 model...")
|
||||
try:
|
||||
model_name, model, processor, device_str, device, dtype = _load_colvision("colqwen2")
|
||||
print(f"✓ Model loaded: {model_name}")
|
||||
print(f"✓ Device: {device_str}, dtype: {dtype}")
|
||||
|
||||
# Print model info
|
||||
if hasattr(model, "device"):
|
||||
print(f"✓ Model device: {model.device}")
|
||||
if hasattr(model, "dtype"):
|
||||
print(f"✓ Model dtype: {model.dtype}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Error loading model: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return
|
||||
|
||||
# Step 3: Test forward pass
|
||||
print("\n[Step 3] Running forward pass...")
|
||||
try:
|
||||
# Use the _embed_images function which handles batching and forward pass
|
||||
images = [test_image]
|
||||
print(f"Processing {len(images)} image(s)...")
|
||||
|
||||
doc_vecs = _embed_images(model, processor, images)
|
||||
|
||||
print("✓ Forward pass completed!")
|
||||
print(f"✓ Number of embeddings: {len(doc_vecs)}")
|
||||
|
||||
if len(doc_vecs) > 0:
|
||||
emb = doc_vecs[0]
|
||||
print(f"✓ Embedding shape: {emb.shape}")
|
||||
print(f"✓ Embedding dtype: {emb.dtype}")
|
||||
print("✓ Embedding stats:")
|
||||
print(f" - Min: {emb.min().item():.4f}")
|
||||
print(f" - Max: {emb.max().item():.4f}")
|
||||
print(f" - Mean: {emb.mean().item():.4f}")
|
||||
print(f" - Std: {emb.std().item():.4f}")
|
||||
|
||||
# Check for NaN or Inf
|
||||
if torch.isnan(emb).any():
|
||||
print("⚠ Warning: Embedding contains NaN values!")
|
||||
if torch.isinf(emb).any():
|
||||
print("⚠ Warning: Embedding contains Inf values!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Error during forward pass: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Test completed successfully!")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
BIN
apps/multimodal/vision-based-pdf-multi-vector/fig/image.png
Normal file
|
After Width: | Height: | Size: 166 KiB |
1452
apps/multimodal/vision-based-pdf-multi-vector/leann_multi_vector.py
Normal file
@@ -0,0 +1,112 @@
|
||||
# pip install pdf2image
|
||||
# pip install pymilvus
|
||||
# pip install colpali_engine
|
||||
# pip install tqdm
|
||||
# pip install pillow
|
||||
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
# Ensure local leann packages are importable before importing them
|
||||
_repo_root = Path(__file__).resolve().parents[3]
|
||||
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
|
||||
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
|
||||
if str(_leann_core_src) not in sys.path:
|
||||
sys.path.append(str(_leann_core_src))
|
||||
if str(_leann_hnsw_pkg) not in sys.path:
|
||||
sys.path.append(str(_leann_hnsw_pkg))
|
||||
|
||||
|
||||
import torch
|
||||
from colpali_engine.models import ColPali
|
||||
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
|
||||
from colpali_engine.utils.torch_utils import ListDataset, get_torch_device
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
# Auto-select device: CUDA > MPS (mac) > CPU
|
||||
_device_str = (
|
||||
"cuda"
|
||||
if torch.cuda.is_available()
|
||||
else (
|
||||
"mps"
|
||||
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
|
||||
else "cpu"
|
||||
)
|
||||
)
|
||||
device = get_torch_device(_device_str)
|
||||
# Prefer fp16 on GPU/MPS, bfloat16 on CPU
|
||||
_dtype = torch.float16 if _device_str in ("cuda", "mps") else torch.bfloat16
|
||||
model_name = "vidore/colpali-v1.2"
|
||||
|
||||
model = ColPali.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=_dtype,
|
||||
device_map=device,
|
||||
).eval()
|
||||
print(f"Using device={_device_str}, dtype={_dtype}")
|
||||
|
||||
queries = [
|
||||
"How to end-to-end retrieval with ColBert",
|
||||
"Where is ColBERT performance Table, including text representation results?",
|
||||
]
|
||||
|
||||
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
|
||||
|
||||
dataloader = DataLoader(
|
||||
dataset=ListDataset[str](queries),
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
collate_fn=lambda x: processor.process_queries(x),
|
||||
)
|
||||
|
||||
qs: list[torch.Tensor] = []
|
||||
for batch_query in dataloader:
|
||||
with torch.no_grad():
|
||||
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
|
||||
embeddings_query = model(**batch_query)
|
||||
qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
|
||||
print(qs[0].shape)
|
||||
# %%
|
||||
page_filenames = sorted(os.listdir("./pages"), key=lambda n: int(re.search(r"\d+", n).group()))
|
||||
images = [Image.open(os.path.join("./pages", name)) for name in page_filenames]
|
||||
|
||||
dataloader = DataLoader(
|
||||
dataset=ListDataset[str](images),
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
collate_fn=lambda x: processor.process_images(x),
|
||||
)
|
||||
|
||||
ds: list[torch.Tensor] = []
|
||||
for batch_doc in tqdm(dataloader):
|
||||
with torch.no_grad():
|
||||
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
|
||||
embeddings_doc = model(**batch_doc)
|
||||
ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
|
||||
|
||||
print(ds[0].shape)
|
||||
|
||||
# %%
|
||||
# Build HNSW index via LeannRetriever primitives and run search
|
||||
index_path = "./indexes/colpali.leann"
|
||||
retriever = LeannRetriever(index_path=index_path, dim=int(ds[0].shape[-1]))
|
||||
retriever.create_collection()
|
||||
filepaths = [os.path.join("./pages", name) for name in page_filenames]
|
||||
for i in range(len(filepaths)):
|
||||
data = {
|
||||
"colbert_vecs": ds[i].float().numpy(),
|
||||
"doc_id": i,
|
||||
"filepath": filepaths[i],
|
||||
}
|
||||
retriever.insert(data)
|
||||
retriever.create_index()
|
||||
for query in qs:
|
||||
query_np = query.float().numpy()
|
||||
result = retriever.search(query_np, topk=1)
|
||||
print(filepaths[result[0][1]])
|
||||
@@ -0,0 +1,713 @@
|
||||
## Jupyter-style notebook script
|
||||
# %%
|
||||
# uv pip install matplotlib qwen_vl_utils
|
||||
import argparse
|
||||
import faulthandler
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
# Enable faulthandler to get stack trace on segfault
|
||||
faulthandler.enable()
|
||||
|
||||
|
||||
from leann_multi_vector import ( # utility functions/classes
|
||||
_ensure_repo_paths_importable,
|
||||
_load_images_from_dir,
|
||||
_maybe_convert_pdf_to_images,
|
||||
_load_colvision,
|
||||
_embed_images,
|
||||
_embed_queries,
|
||||
_build_index,
|
||||
_load_retriever_if_index_exists,
|
||||
_generate_similarity_map,
|
||||
_build_fast_plaid_index,
|
||||
_load_fast_plaid_index_if_exists,
|
||||
_search_fast_plaid,
|
||||
_get_fast_plaid_image,
|
||||
_get_fast_plaid_metadata,
|
||||
QwenVL,
|
||||
)
|
||||
|
||||
_ensure_repo_paths_importable(__file__)
|
||||
|
||||
# %%
|
||||
# Config
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
QUERY = "The paper talk about the latent video generative model and data curation in the related work part?"
|
||||
MODEL: str = "colqwen2" # "colpali" or "colqwen2"
|
||||
|
||||
# Data source: set to True to use the Hugging Face dataset example (recommended)
|
||||
USE_HF_DATASET: bool = True
|
||||
# Single dataset name (used when DATASET_NAMES is None)
|
||||
DATASET_NAME: str = "weaviate/arXiv-AI-papers-multi-vector"
|
||||
# Multiple datasets to combine (if provided, DATASET_NAME is ignored)
|
||||
# Can be:
|
||||
# - List of strings: ["dataset1", "dataset2"]
|
||||
# - List of tuples: [("dataset1", "config1"), ("dataset2", None)] # None = no config needed
|
||||
# - Mixed: ["dataset1", ("dataset2", "config2")]
|
||||
#
|
||||
# Some potential datasets with images (may need IMAGE_FIELD_NAME adjustment):
|
||||
# - "weaviate/arXiv-AI-papers-multi-vector" (current, has "page_image" field)
|
||||
# - ("lmms-lab/DocVQA", "DocVQA") (has "image" field, document images, needs config)
|
||||
# - ("lmms-lab/DocVQA", "InfographicVQA") (has "image" field, infographic images)
|
||||
# - "pixparse/arxiv-papers" (if available, arXiv papers)
|
||||
# - "allenai/ai2d" (AI2D diagram dataset, has "image" field)
|
||||
# - "huggingface/document-images" (if available)
|
||||
# Note: Check dataset structure first - some may need IMAGE_FIELD_NAME specified
|
||||
# DATASET_NAMES: Optional[list[str | tuple[str, Optional[str]]]] = None
|
||||
DATASET_NAMES = [
|
||||
"weaviate/arXiv-AI-papers-multi-vector",
|
||||
# ("lmms-lab/DocVQA", "DocVQA"), # Specify config name for datasets with multiple configs
|
||||
]
|
||||
# Load multiple splits to get more data (e.g., ["train", "test", "validation"])
|
||||
# Set to None to try loading all available splits automatically
|
||||
DATASET_SPLITS: Optional[list[str]] = ["train", "test"] # None = auto-detect all splits
|
||||
# Image field name in the dataset (auto-detect if None)
|
||||
# Common names: "page_image", "image", "images", "img"
|
||||
IMAGE_FIELD_NAME: Optional[str] = None # None = auto-detect
|
||||
MAX_DOCS: Optional[int] = None # limit number of pages to index; None = all
|
||||
|
||||
# Local pages (used when USE_HF_DATASET == False)
|
||||
PDF: Optional[str] = None # e.g., "./pdfs/2004.12832v2.pdf"
|
||||
PAGES_DIR: str = "./pages"
|
||||
# Custom folder path (takes precedence over USE_HF_DATASET and PAGES_DIR)
|
||||
# If set, images will be loaded directly from this folder
|
||||
CUSTOM_FOLDER_PATH: Optional[str] = None # e.g., "/home/ubuntu/dr-tulu/agent/screenshots"
|
||||
# Whether to recursively search subdirectories when loading from custom folder
|
||||
CUSTOM_FOLDER_RECURSIVE: bool = False # Set to True to search subdirectories
|
||||
|
||||
# Index + retrieval settings
|
||||
# Use a different index path for larger dataset to avoid overwriting existing index
|
||||
INDEX_PATH: str = "./indexes/colvision_large.leann"
|
||||
# Fast-Plaid index settings (alternative to LEANN index)
|
||||
# These are now command-line arguments (see CLI overrides section)
|
||||
TOPK: int = 3
|
||||
FIRST_STAGE_K: int = 500
|
||||
REBUILD_INDEX: bool = False # Set to True to force rebuild even if index exists
|
||||
|
||||
# Artifacts
|
||||
SAVE_TOP_IMAGE: Optional[str] = "./figures/retrieved_page.png"
|
||||
SIMILARITY_MAP: bool = True
|
||||
SIM_TOKEN_IDX: int = 13 # -1 means auto-select the most salient token
|
||||
SIM_OUTPUT: str = "./figures/similarity_map.png"
|
||||
ANSWER: bool = True
|
||||
MAX_NEW_TOKENS: int = 1024
|
||||
|
||||
|
||||
# %%
|
||||
# CLI overrides
|
||||
parser = argparse.ArgumentParser(description="Multi-vector LEANN similarity map demo")
|
||||
parser.add_argument(
|
||||
"--search-method",
|
||||
type=str,
|
||||
choices=["ann", "exact", "exact-all"],
|
||||
default="ann",
|
||||
help="Which search method to use: 'ann' (fast ANN), 'exact' (ANN + exact rerank), or 'exact-all' (exact over all docs).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--query",
|
||||
type=str,
|
||||
default=QUERY,
|
||||
help=f"Query string to search for. Default: '{QUERY}'",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use-fast-plaid",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Set to True to use fast-plaid instead of LEANN. Default: False",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fast-plaid-index-path",
|
||||
type=str,
|
||||
default="./indexes/colvision_fastplaid",
|
||||
help="Path to the Fast-Plaid index. Default: './indexes/colvision_fastplaid'",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--topk",
|
||||
type=int,
|
||||
default=TOPK,
|
||||
help=f"Number of top results to retrieve. Default: {TOPK}",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--custom-folder",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to a custom folder containing images to search. Takes precedence over dataset loading. Default: None",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--recursive",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Recursively search subdirectories when loading images from custom folder. Default: False",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rebuild-index",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Force rebuild the index even if it already exists. Default: False (reuse existing index if available)",
|
||||
)
|
||||
cli_args, _unknown = parser.parse_known_args()
|
||||
SEARCH_METHOD: str = cli_args.search_method
|
||||
QUERY = cli_args.query # Override QUERY with CLI argument if provided
|
||||
USE_FAST_PLAID: bool = cli_args.use_fast_plaid
|
||||
FAST_PLAID_INDEX_PATH: str = cli_args.fast_plaid_index_path
|
||||
TOPK: int = cli_args.topk # Override TOPK with CLI argument if provided
|
||||
CUSTOM_FOLDER_PATH = cli_args.custom_folder if cli_args.custom_folder else CUSTOM_FOLDER_PATH # Override with CLI argument if provided
|
||||
CUSTOM_FOLDER_RECURSIVE = cli_args.recursive if cli_args.recursive else CUSTOM_FOLDER_RECURSIVE # Override with CLI argument if provided
|
||||
REBUILD_INDEX = cli_args.rebuild_index # Override REBUILD_INDEX with CLI argument
|
||||
|
||||
# %%
|
||||
|
||||
# Step 1: Check if we can skip data loading (index already exists)
|
||||
retriever: Optional[Any] = None
|
||||
fast_plaid_index: Optional[Any] = None
|
||||
need_to_build_index = REBUILD_INDEX
|
||||
|
||||
if USE_FAST_PLAID:
|
||||
# Fast-Plaid index handling
|
||||
if not REBUILD_INDEX:
|
||||
try:
|
||||
fast_plaid_index = _load_fast_plaid_index_if_exists(FAST_PLAID_INDEX_PATH)
|
||||
if fast_plaid_index is not None:
|
||||
print(f"✓ Fast-Plaid index found at {FAST_PLAID_INDEX_PATH}")
|
||||
need_to_build_index = False
|
||||
else:
|
||||
print(f"Fast-Plaid index not found, will build new index")
|
||||
need_to_build_index = True
|
||||
except Exception as e:
|
||||
# If loading fails (e.g., memory error, corrupted index), rebuild
|
||||
print(f"Warning: Failed to load Fast-Plaid index: {e}")
|
||||
print("Will rebuild the index...")
|
||||
need_to_build_index = True
|
||||
fast_plaid_index = None
|
||||
else:
|
||||
print(f"REBUILD_INDEX=True, will rebuild Fast-Plaid index")
|
||||
need_to_build_index = True
|
||||
else:
|
||||
# Original LEANN index handling
|
||||
if not REBUILD_INDEX:
|
||||
retriever = _load_retriever_if_index_exists(INDEX_PATH)
|
||||
if retriever is not None:
|
||||
print(f"✓ Index loaded from {INDEX_PATH}")
|
||||
print(f"✓ Images available at: {retriever._images_dir_path()}")
|
||||
need_to_build_index = False
|
||||
else:
|
||||
print(f"Index not found, will build new index")
|
||||
need_to_build_index = True
|
||||
else:
|
||||
print(f"REBUILD_INDEX=True, will rebuild index")
|
||||
need_to_build_index = True
|
||||
|
||||
# Step 2: Load data only if we need to build the index
|
||||
if need_to_build_index:
|
||||
print("Loading dataset...")
|
||||
# Check for custom folder path first (takes precedence)
|
||||
if CUSTOM_FOLDER_PATH:
|
||||
if not os.path.isdir(CUSTOM_FOLDER_PATH):
|
||||
raise RuntimeError(f"Custom folder path does not exist: {CUSTOM_FOLDER_PATH}")
|
||||
print(f"Loading images from custom folder: {CUSTOM_FOLDER_PATH}")
|
||||
if CUSTOM_FOLDER_RECURSIVE:
|
||||
print(" (recursive mode: searching subdirectories)")
|
||||
filepaths, images = _load_images_from_dir(CUSTOM_FOLDER_PATH, recursive=CUSTOM_FOLDER_RECURSIVE)
|
||||
print(f" Found {len(filepaths)} image files")
|
||||
if not images:
|
||||
raise RuntimeError(
|
||||
f"No images found in {CUSTOM_FOLDER_PATH}. Ensure the folder contains image files (.png, .jpg, .jpeg, .webp)."
|
||||
)
|
||||
print(f" Successfully loaded {len(images)} images")
|
||||
# Use filenames as identifiers instead of full paths for cleaner metadata
|
||||
filepaths = [os.path.basename(fp) for fp in filepaths]
|
||||
elif USE_HF_DATASET:
|
||||
from datasets import load_dataset, concatenate_datasets, DatasetDict
|
||||
|
||||
# Determine which datasets to load
|
||||
if DATASET_NAMES is not None:
|
||||
dataset_names_to_load = DATASET_NAMES
|
||||
print(f"Loading {len(dataset_names_to_load)} datasets: {dataset_names_to_load}")
|
||||
else:
|
||||
dataset_names_to_load = [DATASET_NAME]
|
||||
print(f"Loading single dataset: {DATASET_NAME}")
|
||||
|
||||
# Load and combine datasets
|
||||
all_datasets_to_concat = []
|
||||
|
||||
for dataset_entry in dataset_names_to_load:
|
||||
# Handle both string and tuple formats
|
||||
if isinstance(dataset_entry, tuple):
|
||||
dataset_name, config_name = dataset_entry
|
||||
else:
|
||||
dataset_name = dataset_entry
|
||||
config_name = None
|
||||
|
||||
print(f"\nProcessing dataset: {dataset_name}" + (f" (config: {config_name})" if config_name else ""))
|
||||
|
||||
# Load dataset to check available splits
|
||||
# If config_name is provided, use it; otherwise try without config
|
||||
try:
|
||||
if config_name:
|
||||
dataset_dict = load_dataset(dataset_name, config_name)
|
||||
else:
|
||||
dataset_dict = load_dataset(dataset_name)
|
||||
except ValueError as e:
|
||||
if "Config name is missing" in str(e):
|
||||
# Try to get available configs and suggest
|
||||
from datasets import get_dataset_config_names
|
||||
try:
|
||||
available_configs = get_dataset_config_names(dataset_name)
|
||||
raise ValueError(
|
||||
f"Dataset '{dataset_name}' requires a config name. "
|
||||
f"Available configs: {available_configs}. "
|
||||
f"Please specify as: ('{dataset_name}', 'config_name')"
|
||||
) from e
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f"Dataset '{dataset_name}' requires a config name. "
|
||||
f"Please specify as: ('{dataset_name}', 'config_name')"
|
||||
) from e
|
||||
raise
|
||||
|
||||
# Determine which splits to load
|
||||
if DATASET_SPLITS is None:
|
||||
# Auto-detect: try to load all available splits
|
||||
available_splits = list(dataset_dict.keys())
|
||||
print(f" Auto-detected splits: {available_splits}")
|
||||
splits_to_load = available_splits
|
||||
else:
|
||||
splits_to_load = DATASET_SPLITS
|
||||
|
||||
# Load and concatenate multiple splits for this dataset
|
||||
datasets_to_concat = []
|
||||
for split in splits_to_load:
|
||||
if split not in dataset_dict:
|
||||
print(f" Warning: Split '{split}' not found in dataset. Available splits: {list(dataset_dict.keys())}")
|
||||
continue
|
||||
split_dataset = dataset_dict[split]
|
||||
print(f" Loaded split '{split}': {len(split_dataset)} pages")
|
||||
datasets_to_concat.append(split_dataset)
|
||||
|
||||
if not datasets_to_concat:
|
||||
print(f" Warning: No valid splits found for {dataset_name}. Skipping.")
|
||||
continue
|
||||
|
||||
# Concatenate splits for this dataset
|
||||
if len(datasets_to_concat) > 1:
|
||||
combined_dataset = concatenate_datasets(datasets_to_concat)
|
||||
print(f" Concatenated {len(datasets_to_concat)} splits into {len(combined_dataset)} pages")
|
||||
else:
|
||||
combined_dataset = datasets_to_concat[0]
|
||||
|
||||
all_datasets_to_concat.append(combined_dataset)
|
||||
|
||||
if not all_datasets_to_concat:
|
||||
raise RuntimeError("No valid datasets or splits found.")
|
||||
|
||||
# Concatenate all datasets
|
||||
if len(all_datasets_to_concat) > 1:
|
||||
dataset = concatenate_datasets(all_datasets_to_concat)
|
||||
print(f"\nConcatenated {len(all_datasets_to_concat)} datasets into {len(dataset)} total pages")
|
||||
else:
|
||||
dataset = all_datasets_to_concat[0]
|
||||
|
||||
# Apply MAX_DOCS limit if specified
|
||||
N = len(dataset) if MAX_DOCS is None else min(MAX_DOCS, len(dataset))
|
||||
if N < len(dataset):
|
||||
print(f"Limiting to {N} pages (from {len(dataset)} total)")
|
||||
dataset = dataset.select(range(N))
|
||||
|
||||
# Auto-detect image field name if not specified
|
||||
if IMAGE_FIELD_NAME is None:
|
||||
# Check multiple samples to find the most common image field
|
||||
# (useful when datasets are merged and may have different field names)
|
||||
possible_image_fields = ["page_image", "image", "images", "img", "page", "document_image"]
|
||||
field_counts = {}
|
||||
|
||||
# Check first few samples to find image fields
|
||||
num_samples_to_check = min(10, len(dataset))
|
||||
for sample_idx in range(num_samples_to_check):
|
||||
sample = dataset[sample_idx]
|
||||
for field in possible_image_fields:
|
||||
if field in sample and sample[field] is not None:
|
||||
value = sample[field]
|
||||
if isinstance(value, Image.Image) or (hasattr(value, 'size') and hasattr(value, 'mode')):
|
||||
field_counts[field] = field_counts.get(field, 0) + 1
|
||||
|
||||
# Choose the most common field, or first found if tied
|
||||
if field_counts:
|
||||
image_field = max(field_counts.items(), key=lambda x: x[1])[0]
|
||||
print(f"Auto-detected image field: '{image_field}' (found in {field_counts[image_field]}/{num_samples_to_check} samples)")
|
||||
else:
|
||||
# Fallback: check first sample only
|
||||
sample = dataset[0]
|
||||
image_field = None
|
||||
for field in possible_image_fields:
|
||||
if field in sample:
|
||||
value = sample[field]
|
||||
if isinstance(value, Image.Image) or (hasattr(value, 'size') and hasattr(value, 'mode')):
|
||||
image_field = field
|
||||
break
|
||||
if image_field is None:
|
||||
raise RuntimeError(
|
||||
f"Could not auto-detect image field. Available fields: {list(sample.keys())}. "
|
||||
f"Please specify IMAGE_FIELD_NAME manually."
|
||||
)
|
||||
print(f"Auto-detected image field: '{image_field}'")
|
||||
else:
|
||||
image_field = IMAGE_FIELD_NAME
|
||||
if image_field not in dataset[0]:
|
||||
raise RuntimeError(
|
||||
f"Image field '{image_field}' not found. Available fields: {list(dataset[0].keys())}"
|
||||
)
|
||||
|
||||
filepaths: list[str] = []
|
||||
images: list[Image.Image] = []
|
||||
for i in tqdm(range(len(dataset)), desc="Loading dataset", total=len(dataset)):
|
||||
p = dataset[i]
|
||||
# Try to compose a descriptive identifier
|
||||
# Handle different dataset structures
|
||||
identifier_parts = []
|
||||
|
||||
# Helper function to safely get field value
|
||||
def safe_get(field_name, default=None):
|
||||
if field_name in p and p[field_name] is not None:
|
||||
return p[field_name]
|
||||
return default
|
||||
|
||||
# Try to get various identifier fields
|
||||
if safe_get("paper_arxiv_id"):
|
||||
identifier_parts.append(f"arXiv:{p['paper_arxiv_id']}")
|
||||
if safe_get("paper_title"):
|
||||
identifier_parts.append(f"title:{p['paper_title']}")
|
||||
if safe_get("page_number") is not None:
|
||||
try:
|
||||
identifier_parts.append(f"page:{int(p['page_number'])}")
|
||||
except (ValueError, TypeError):
|
||||
# If conversion fails, use the raw value or skip
|
||||
if p['page_number']:
|
||||
identifier_parts.append(f"page:{p['page_number']}")
|
||||
if safe_get("page_id"):
|
||||
identifier_parts.append(f"id:{p['page_id']}")
|
||||
elif safe_get("questionId"):
|
||||
identifier_parts.append(f"qid:{p['questionId']}")
|
||||
elif safe_get("docId"):
|
||||
identifier_parts.append(f"docId:{p['docId']}")
|
||||
elif safe_get("id"):
|
||||
identifier_parts.append(f"id:{p['id']}")
|
||||
|
||||
# If no identifier parts found, create one from index
|
||||
if identifier_parts:
|
||||
identifier = "|".join(identifier_parts)
|
||||
else:
|
||||
# Create identifier from available fields or index
|
||||
fallback_parts = []
|
||||
# Try common fields that might exist
|
||||
for field in ["ucsf_document_id", "docId", "questionId", "id"]:
|
||||
if safe_get(field):
|
||||
fallback_parts.append(f"{field}:{p[field]}")
|
||||
break
|
||||
if fallback_parts:
|
||||
identifier = "|".join(fallback_parts) + f"|idx:{i}"
|
||||
else:
|
||||
identifier = f"doc_{i}"
|
||||
|
||||
filepaths.append(identifier)
|
||||
|
||||
# Get image - try detected field first, then fallback to other common fields
|
||||
img = None
|
||||
if image_field in p and p[image_field] is not None:
|
||||
img = p[image_field]
|
||||
else:
|
||||
# Fallback: try other common image field names
|
||||
for fallback_field in ["image", "page_image", "images", "img"]:
|
||||
if fallback_field in p and p[fallback_field] is not None:
|
||||
img = p[fallback_field]
|
||||
break
|
||||
|
||||
if img is None:
|
||||
raise RuntimeError(
|
||||
f"No image found for sample {i}. Available fields: {list(p.keys())}. "
|
||||
f"Expected field: {image_field}"
|
||||
)
|
||||
|
||||
# Ensure it's a PIL Image
|
||||
if not isinstance(img, Image.Image):
|
||||
if hasattr(img, 'convert'):
|
||||
img = img.convert('RGB')
|
||||
else:
|
||||
img = Image.fromarray(img) if hasattr(img, '__array__') else Image.open(img)
|
||||
images.append(img)
|
||||
else:
|
||||
_maybe_convert_pdf_to_images(PDF, PAGES_DIR)
|
||||
filepaths, images = _load_images_from_dir(PAGES_DIR)
|
||||
if not images:
|
||||
raise RuntimeError(
|
||||
f"No images found in {PAGES_DIR}. Provide PDF path in PDF variable or ensure images exist."
|
||||
)
|
||||
print(f"Loaded {len(images)} images")
|
||||
|
||||
# Memory check before loading model
|
||||
try:
|
||||
import psutil
|
||||
import torch
|
||||
process = psutil.Process(os.getpid())
|
||||
mem_info = process.memory_info()
|
||||
print(f"Memory usage after loading images: {mem_info.rss / 1024 / 1024 / 1024:.2f} GB")
|
||||
if torch.cuda.is_available():
|
||||
print(f"GPU memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
|
||||
print(f"GPU memory reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
print("Skipping dataset loading (using existing index)")
|
||||
filepaths = [] # Not needed when using existing index
|
||||
images = [] # Not needed when using existing index
|
||||
|
||||
|
||||
# %%
|
||||
# Step 3: Load model and processor (only if we need to build index or perform search)
|
||||
print("Step 3: Loading model and processor...")
|
||||
print(f" Model: {MODEL}")
|
||||
try:
|
||||
import sys
|
||||
print(f" Python version: {sys.version}")
|
||||
print(f" Python executable: {sys.executable}")
|
||||
|
||||
model_name, model, processor, device_str, device, dtype = _load_colvision(MODEL)
|
||||
print(f"✓ Using model={model_name}, device={device_str}, dtype={dtype}")
|
||||
|
||||
# Memory check after loading model
|
||||
try:
|
||||
import psutil
|
||||
import torch
|
||||
process = psutil.Process(os.getpid())
|
||||
mem_info = process.memory_info()
|
||||
print(f" Memory usage after loading model: {mem_info.rss / 1024 / 1024 / 1024:.2f} GB")
|
||||
if torch.cuda.is_available():
|
||||
print(f" GPU memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
|
||||
print(f" GPU memory reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception as e:
|
||||
print(f"✗ Error loading model: {type(e).__name__}: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
|
||||
# %%
|
||||
|
||||
# %%
|
||||
# Step 4: Build index if needed
|
||||
if need_to_build_index:
|
||||
print("Step 4: Building index...")
|
||||
print(f" Number of images: {len(images)}")
|
||||
print(f" Number of filepaths: {len(filepaths)}")
|
||||
|
||||
try:
|
||||
print(" Embedding images...")
|
||||
doc_vecs = _embed_images(model, processor, images)
|
||||
print(f" Embedded {len(doc_vecs)} documents")
|
||||
print(f" First doc vec shape: {doc_vecs[0].shape if len(doc_vecs) > 0 else 'N/A'}")
|
||||
except Exception as e:
|
||||
print(f"Error embedding images: {type(e).__name__}: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
if USE_FAST_PLAID:
|
||||
# Build Fast-Plaid index
|
||||
print(" Building Fast-Plaid index...")
|
||||
try:
|
||||
fast_plaid_index, build_secs = _build_fast_plaid_index(
|
||||
FAST_PLAID_INDEX_PATH, doc_vecs, filepaths, images
|
||||
)
|
||||
from pathlib import Path
|
||||
print(f"✓ Fast-Plaid index built in {build_secs:.3f}s")
|
||||
print(f"✓ Index saved to: {FAST_PLAID_INDEX_PATH}")
|
||||
print(f"✓ Images saved to: {Path(FAST_PLAID_INDEX_PATH) / 'images'}")
|
||||
except Exception as e:
|
||||
print(f"Error building Fast-Plaid index: {type(e).__name__}: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
raise
|
||||
finally:
|
||||
# Clear memory
|
||||
print(" Clearing memory...")
|
||||
del images, filepaths, doc_vecs
|
||||
else:
|
||||
# Build original LEANN index
|
||||
try:
|
||||
retriever = _build_index(INDEX_PATH, doc_vecs, filepaths, images)
|
||||
print(f"✓ Index built and images saved to: {retriever._images_dir_path()}")
|
||||
except Exception as e:
|
||||
print(f"Error building LEANN index: {type(e).__name__}: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
raise
|
||||
finally:
|
||||
# Clear memory
|
||||
print(" Clearing memory...")
|
||||
del images, filepaths, doc_vecs
|
||||
|
||||
# Note: Images are now stored separately, retriever/fast_plaid_index will reference them
|
||||
|
||||
|
||||
# %%
|
||||
# Step 5: Embed query and search
|
||||
_t0 = time.perf_counter()
|
||||
q_vec = _embed_queries(model, processor, [QUERY])[0]
|
||||
query_embed_secs = time.perf_counter() - _t0
|
||||
|
||||
print(f"[Search] Method: {SEARCH_METHOD}")
|
||||
print(f"[Timing] Query embedding: {query_embed_secs:.3f}s")
|
||||
|
||||
# Run the selected search method and time it
|
||||
if USE_FAST_PLAID:
|
||||
# Fast-Plaid search
|
||||
if fast_plaid_index is None:
|
||||
fast_plaid_index = _load_fast_plaid_index_if_exists(FAST_PLAID_INDEX_PATH)
|
||||
if fast_plaid_index is None:
|
||||
raise RuntimeError(f"Fast-Plaid index not found at {FAST_PLAID_INDEX_PATH}")
|
||||
|
||||
results, search_secs = _search_fast_plaid(fast_plaid_index, q_vec, TOPK)
|
||||
print(f"[Timing] Fast-Plaid Search: {search_secs:.3f}s")
|
||||
else:
|
||||
# Original LEANN search
|
||||
query_np = q_vec.float().numpy()
|
||||
|
||||
if SEARCH_METHOD == "ann":
|
||||
results = retriever.search(query_np, topk=TOPK, first_stage_k=FIRST_STAGE_K)
|
||||
search_secs = time.perf_counter() - _t0
|
||||
print(f"[Timing] Search (ANN): {search_secs:.3f}s (first_stage_k={FIRST_STAGE_K})")
|
||||
elif SEARCH_METHOD == "exact":
|
||||
results = retriever.search_exact(query_np, topk=TOPK, first_stage_k=FIRST_STAGE_K)
|
||||
search_secs = time.perf_counter() - _t0
|
||||
print(f"[Timing] Search (Exact rerank): {search_secs:.3f}s (first_stage_k={FIRST_STAGE_K})")
|
||||
elif SEARCH_METHOD == "exact-all":
|
||||
results = retriever.search_exact_all(query_np, topk=TOPK)
|
||||
search_secs = time.perf_counter() - _t0
|
||||
print(f"[Timing] Search (Exact all): {search_secs:.3f}s")
|
||||
else:
|
||||
results = []
|
||||
if not results:
|
||||
print("No results found.")
|
||||
else:
|
||||
print(f'Top {len(results)} results for query: "{QUERY}"')
|
||||
print("\n[DEBUG] Retrieval details:")
|
||||
top_images: list[Image.Image] = []
|
||||
image_hashes = {} # Track image hashes to detect duplicates
|
||||
|
||||
for rank, (score, doc_id) in enumerate(results, start=1):
|
||||
# Retrieve image and metadata based on index type
|
||||
if USE_FAST_PLAID:
|
||||
# Fast-Plaid: load image and get metadata
|
||||
image = _get_fast_plaid_image(FAST_PLAID_INDEX_PATH, doc_id)
|
||||
if image is None:
|
||||
print(f"Warning: Could not find image for doc_id {doc_id}")
|
||||
continue
|
||||
|
||||
metadata = _get_fast_plaid_metadata(FAST_PLAID_INDEX_PATH, doc_id)
|
||||
path = metadata.get("filepath", f"doc_{doc_id}") if metadata else f"doc_{doc_id}"
|
||||
top_images.append(image)
|
||||
else:
|
||||
# Original LEANN: retrieve from retriever
|
||||
image = retriever.get_image(doc_id)
|
||||
if image is None:
|
||||
print(f"Warning: Could not retrieve image for doc_id {doc_id}")
|
||||
continue
|
||||
|
||||
metadata = retriever.get_metadata(doc_id)
|
||||
path = metadata.get("filepath", "unknown") if metadata else "unknown"
|
||||
top_images.append(image)
|
||||
|
||||
# Calculate image hash to detect duplicates
|
||||
import hashlib
|
||||
import io
|
||||
# Convert image to bytes for hashing
|
||||
img_bytes = io.BytesIO()
|
||||
image.save(img_bytes, format='PNG')
|
||||
image_bytes = img_bytes.getvalue()
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()[:8]
|
||||
|
||||
# Check if this image was already seen
|
||||
duplicate_info = ""
|
||||
if image_hash in image_hashes:
|
||||
duplicate_info = f" [DUPLICATE of rank {image_hashes[image_hash]}]"
|
||||
else:
|
||||
image_hashes[image_hash] = rank
|
||||
|
||||
# Print detailed information
|
||||
print(f"{rank}) doc_id={doc_id}, MaxSim={score:.4f}, Page={path}, ImageHash={image_hash}{duplicate_info}")
|
||||
if metadata:
|
||||
print(f" Metadata: {metadata}")
|
||||
|
||||
if SAVE_TOP_IMAGE:
|
||||
from pathlib import Path as _Path
|
||||
|
||||
base = _Path(SAVE_TOP_IMAGE)
|
||||
base.parent.mkdir(parents=True, exist_ok=True)
|
||||
for rank, img in enumerate(top_images[:TOPK], start=1):
|
||||
if base.suffix:
|
||||
out_path = base.parent / f"{base.stem}_rank{rank}{base.suffix}"
|
||||
else:
|
||||
out_path = base / f"retrieved_page_rank{rank}.png"
|
||||
img.save(str(out_path))
|
||||
# Print the retrieval score (document-level MaxSim) alongside the saved path
|
||||
try:
|
||||
score, _doc_id = results[rank - 1]
|
||||
print(f"Saved retrieved page (rank {rank}) [MaxSim={score:.4f}] to: {out_path}")
|
||||
except Exception:
|
||||
print(f"Saved retrieved page (rank {rank}) to: {out_path}")
|
||||
|
||||
|
||||
# %%
|
||||
# Step 6: Similarity maps for top-K results
|
||||
if results and SIMILARITY_MAP:
|
||||
token_idx = None if SIM_TOKEN_IDX < 0 else int(SIM_TOKEN_IDX)
|
||||
from pathlib import Path as _Path
|
||||
|
||||
output_base = _Path(SIM_OUTPUT) if SIM_OUTPUT else None
|
||||
for rank, img in enumerate(top_images[:TOPK], start=1):
|
||||
if output_base:
|
||||
if output_base.suffix:
|
||||
out_dir = output_base.parent
|
||||
out_name = f"{output_base.stem}_rank{rank}{output_base.suffix}"
|
||||
out_path = str(out_dir / out_name)
|
||||
else:
|
||||
out_dir = output_base
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
out_path = str(out_dir / f"similarity_map_rank{rank}.png")
|
||||
else:
|
||||
out_path = None
|
||||
chosen_idx, max_sim = _generate_similarity_map(
|
||||
model=model,
|
||||
processor=processor,
|
||||
image=img,
|
||||
query=QUERY,
|
||||
token_idx=token_idx,
|
||||
output_path=out_path,
|
||||
)
|
||||
if out_path:
|
||||
print(
|
||||
f"Saved similarity map for rank {rank}, token #{chosen_idx} (max={max_sim:.2f}) to: {out_path}"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"Computed similarity map for rank {rank}, token #{chosen_idx} (max={max_sim:.2f})"
|
||||
)
|
||||
|
||||
|
||||
# %%
|
||||
# Step 7: Optional answer generation
|
||||
if results and ANSWER:
|
||||
qwen = QwenVL(device=device_str)
|
||||
_t0 = time.perf_counter()
|
||||
response = qwen.answer(QUERY, top_images[:TOPK], max_new_tokens=MAX_NEW_TOKENS)
|
||||
gen_secs = time.perf_counter() - _t0
|
||||
print(f"[Timing] Generation: {gen_secs:.3f}s")
|
||||
print("\nAnswer:")
|
||||
print(response)
|
||||
@@ -0,0 +1,448 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Modular script to reproduce NDCG results for ViDoRe v1 benchmark.
|
||||
|
||||
This script uses the interface from leann_multi_vector.py to:
|
||||
1. Download ViDoRe v1 datasets
|
||||
2. Build indexes (LEANN or Fast-Plaid)
|
||||
3. Perform retrieval
|
||||
4. Evaluate using NDCG metrics
|
||||
|
||||
Usage:
|
||||
# Evaluate all ViDoRe v1 tasks
|
||||
python vidore_v1_benchmark.py --model colqwen2 --tasks all
|
||||
|
||||
# Evaluate specific task
|
||||
python vidore_v1_benchmark.py --model colqwen2 --task VidoreArxivQARetrieval
|
||||
|
||||
# Use Fast-Plaid index
|
||||
python vidore_v1_benchmark.py --model colqwen2 --use-fast-plaid --fast-plaid-index-path ./indexes/vidore_fastplaid
|
||||
|
||||
# Rebuild index
|
||||
python vidore_v1_benchmark.py --model colqwen2 --rebuild-index
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from datasets import load_dataset
|
||||
from leann_multi_vector import (
|
||||
ViDoReBenchmarkEvaluator,
|
||||
_ensure_repo_paths_importable,
|
||||
)
|
||||
|
||||
_ensure_repo_paths_importable(__file__)
|
||||
|
||||
# ViDoRe v1 task configurations
|
||||
# Prompts match MTEB task metadata prompts
|
||||
VIDORE_V1_TASKS = {
|
||||
"VidoreArxivQARetrieval": {
|
||||
"dataset_path": "vidore/arxivqa_test_subsampled_beir",
|
||||
"revision": "7d94d570960eac2408d3baa7a33f9de4822ae3e4",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreDocVQARetrieval": {
|
||||
"dataset_path": "vidore/docvqa_test_subsampled_beir",
|
||||
"revision": "162ba2fc1a8437eda8b6c37b240bc1c0f0deb092",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreInfoVQARetrieval": {
|
||||
"dataset_path": "vidore/infovqa_test_subsampled_beir",
|
||||
"revision": "b802cc5fd6c605df2d673a963667d74881d2c9a4",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreTabfquadRetrieval": {
|
||||
"dataset_path": "vidore/tabfquad_test_subsampled_beir",
|
||||
"revision": "61a2224bcd29b7b261a4892ff4c8bea353527a31",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreTatdqaRetrieval": {
|
||||
"dataset_path": "vidore/tatdqa_test_beir",
|
||||
"revision": "5feb5630fdff4d8d189ffedb2dba56862fdd45c0",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreShiftProjectRetrieval": {
|
||||
"dataset_path": "vidore/shiftproject_test_beir",
|
||||
"revision": "84a382e05c4473fed9cff2bbae95fe2379416117",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreSyntheticDocQAAIRetrieval": {
|
||||
"dataset_path": "vidore/syntheticDocQA_artificial_intelligence_test_beir",
|
||||
"revision": "2d9ebea5a1c6e9ef4a3b902a612f605dca11261c",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreSyntheticDocQAEnergyRetrieval": {
|
||||
"dataset_path": "vidore/syntheticDocQA_energy_test_beir",
|
||||
"revision": "9935aadbad5c8deec30910489db1b2c7133ae7a7",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreSyntheticDocQAGovernmentReportsRetrieval": {
|
||||
"dataset_path": "vidore/syntheticDocQA_government_reports_test_beir",
|
||||
"revision": "b4909afa930f81282fd20601e860668073ad02aa",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreSyntheticDocQAHealthcareIndustryRetrieval": {
|
||||
"dataset_path": "vidore/syntheticDocQA_healthcare_industry_test_beir",
|
||||
"revision": "f9e25d5b6e13e1ad9f5c3cce202565031b3ab164",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
}
|
||||
|
||||
# Task name aliases (short names -> full names)
|
||||
TASK_ALIASES = {
|
||||
"arxivqa": "VidoreArxivQARetrieval",
|
||||
"docvqa": "VidoreDocVQARetrieval",
|
||||
"infovqa": "VidoreInfoVQARetrieval",
|
||||
"tabfquad": "VidoreTabfquadRetrieval",
|
||||
"tatdqa": "VidoreTatdqaRetrieval",
|
||||
"shiftproject": "VidoreShiftProjectRetrieval",
|
||||
"syntheticdocqa_ai": "VidoreSyntheticDocQAAIRetrieval",
|
||||
"syntheticdocqa_energy": "VidoreSyntheticDocQAEnergyRetrieval",
|
||||
"syntheticdocqa_government": "VidoreSyntheticDocQAGovernmentReportsRetrieval",
|
||||
"syntheticdocqa_healthcare": "VidoreSyntheticDocQAHealthcareIndustryRetrieval",
|
||||
}
|
||||
|
||||
|
||||
def normalize_task_name(task_name: str) -> str:
|
||||
"""Normalize task name (handle aliases)."""
|
||||
task_name_lower = task_name.lower()
|
||||
if task_name in VIDORE_V1_TASKS:
|
||||
return task_name
|
||||
if task_name_lower in TASK_ALIASES:
|
||||
return TASK_ALIASES[task_name_lower]
|
||||
# Try partial match
|
||||
for alias, full_name in TASK_ALIASES.items():
|
||||
if alias in task_name_lower or task_name_lower in alias:
|
||||
return full_name
|
||||
return task_name
|
||||
|
||||
|
||||
def get_safe_model_name(model_name: str) -> str:
|
||||
"""Get a safe model name for use in file paths."""
|
||||
import hashlib
|
||||
import os
|
||||
|
||||
# If it's a path, use basename or hash
|
||||
if os.path.exists(model_name) and os.path.isdir(model_name):
|
||||
# Use basename if it's reasonable, otherwise use hash
|
||||
basename = os.path.basename(model_name.rstrip("/"))
|
||||
if basename and len(basename) < 100 and not basename.startswith("."):
|
||||
return basename
|
||||
# Use hash for very long or problematic paths
|
||||
return hashlib.md5(model_name.encode()).hexdigest()[:16]
|
||||
# For HuggingFace model names, replace / with _
|
||||
return model_name.replace("/", "_").replace(":", "_")
|
||||
|
||||
|
||||
def load_vidore_v1_data(
|
||||
dataset_path: str,
|
||||
revision: Optional[str] = None,
|
||||
split: str = "test",
|
||||
):
|
||||
"""
|
||||
Load ViDoRe v1 dataset.
|
||||
|
||||
Returns:
|
||||
corpus: dict mapping corpus_id to PIL Image
|
||||
queries: dict mapping query_id to query text
|
||||
qrels: dict mapping query_id to dict of {corpus_id: relevance_score}
|
||||
"""
|
||||
print(f"Loading dataset: {dataset_path} (split={split})")
|
||||
|
||||
# Load queries
|
||||
query_ds = load_dataset(dataset_path, "queries", split=split, revision=revision)
|
||||
|
||||
queries = {}
|
||||
for row in query_ds:
|
||||
query_id = f"query-{split}-{row['query-id']}"
|
||||
queries[query_id] = row["query"]
|
||||
|
||||
# Load corpus (images)
|
||||
corpus_ds = load_dataset(dataset_path, "corpus", split=split, revision=revision)
|
||||
|
||||
corpus = {}
|
||||
for row in corpus_ds:
|
||||
corpus_id = f"corpus-{split}-{row['corpus-id']}"
|
||||
# Extract image from the dataset row
|
||||
if "image" in row:
|
||||
corpus[corpus_id] = row["image"]
|
||||
elif "page_image" in row:
|
||||
corpus[corpus_id] = row["page_image"]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"No image field found in corpus. Available fields: {list(row.keys())}"
|
||||
)
|
||||
|
||||
# Load qrels (relevance judgments)
|
||||
qrels_ds = load_dataset(dataset_path, "qrels", split=split, revision=revision)
|
||||
|
||||
qrels = {}
|
||||
for row in qrels_ds:
|
||||
query_id = f"query-{split}-{row['query-id']}"
|
||||
corpus_id = f"corpus-{split}-{row['corpus-id']}"
|
||||
if query_id not in qrels:
|
||||
qrels[query_id] = {}
|
||||
qrels[query_id][corpus_id] = int(row["score"])
|
||||
|
||||
print(
|
||||
f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings"
|
||||
)
|
||||
|
||||
# Filter qrels to only include queries that exist
|
||||
qrels = {qid: rel_docs for qid, rel_docs in qrels.items() if qid in queries}
|
||||
|
||||
# Filter out queries without any relevant documents (matching MTEB behavior)
|
||||
# This is important for correct NDCG calculation
|
||||
qrels_filtered = {qid: rel_docs for qid, rel_docs in qrels.items() if len(rel_docs) > 0}
|
||||
queries_filtered = {
|
||||
qid: query_text for qid, query_text in queries.items() if qid in qrels_filtered
|
||||
}
|
||||
|
||||
print(
|
||||
f"After filtering queries without positives: {len(queries_filtered)} queries, {len(qrels_filtered)} query-relevance mappings"
|
||||
)
|
||||
|
||||
return corpus, queries_filtered, qrels_filtered
|
||||
|
||||
|
||||
def evaluate_task(
|
||||
task_name: str,
|
||||
model_name: str,
|
||||
index_path: str,
|
||||
use_fast_plaid: bool = False,
|
||||
fast_plaid_index_path: Optional[str] = None,
|
||||
rebuild_index: bool = False,
|
||||
top_k: int = 1000,
|
||||
first_stage_k: int = 500,
|
||||
k_values: Optional[list[int]] = None,
|
||||
output_dir: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Evaluate a single ViDoRe v1 task.
|
||||
"""
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f"Evaluating task: {task_name}")
|
||||
print(f"{'=' * 80}")
|
||||
|
||||
# Normalize task name (handle aliases)
|
||||
task_name = normalize_task_name(task_name)
|
||||
|
||||
# Get task config
|
||||
if task_name not in VIDORE_V1_TASKS:
|
||||
raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V1_TASKS.keys())}")
|
||||
|
||||
task_config = VIDORE_V1_TASKS[task_name]
|
||||
dataset_path = task_config["dataset_path"]
|
||||
revision = task_config["revision"]
|
||||
|
||||
# Load data
|
||||
corpus, queries, qrels = load_vidore_v1_data(
|
||||
dataset_path=dataset_path,
|
||||
revision=revision,
|
||||
split="test",
|
||||
)
|
||||
|
||||
# Initialize k_values if not provided
|
||||
if k_values is None:
|
||||
k_values = [1, 3, 5, 10, 20, 100, 1000]
|
||||
|
||||
# Check if we have any queries
|
||||
if len(queries) == 0:
|
||||
print(f"\nWarning: No queries found for task {task_name}. Skipping evaluation.")
|
||||
# Return zero scores
|
||||
scores = {}
|
||||
for k in k_values:
|
||||
scores[f"ndcg_at_{k}"] = 0.0
|
||||
scores[f"map_at_{k}"] = 0.0
|
||||
scores[f"recall_at_{k}"] = 0.0
|
||||
scores[f"precision_at_{k}"] = 0.0
|
||||
scores[f"mrr_at_{k}"] = 0.0
|
||||
return scores
|
||||
|
||||
# Initialize evaluator
|
||||
evaluator = ViDoReBenchmarkEvaluator(
|
||||
model_name=model_name,
|
||||
use_fast_plaid=use_fast_plaid,
|
||||
top_k=top_k,
|
||||
first_stage_k=first_stage_k,
|
||||
k_values=k_values,
|
||||
)
|
||||
|
||||
# Build or load index
|
||||
# Use safe model name for index path (different models need different indexes)
|
||||
safe_model_name = get_safe_model_name(model_name)
|
||||
index_path_full = index_path if not use_fast_plaid else fast_plaid_index_path
|
||||
if index_path_full is None:
|
||||
index_path_full = f"./indexes/{task_name}_{safe_model_name}"
|
||||
if use_fast_plaid:
|
||||
index_path_full = f"./indexes/{task_name}_{safe_model_name}_fastplaid"
|
||||
|
||||
index_or_retriever, corpus_ids_ordered = evaluator.build_index_from_corpus(
|
||||
corpus=corpus,
|
||||
index_path=index_path_full,
|
||||
rebuild=rebuild_index,
|
||||
)
|
||||
|
||||
# Search queries
|
||||
task_prompt = task_config.get("prompt")
|
||||
results = evaluator.search_queries(
|
||||
queries=queries,
|
||||
corpus_ids=corpus_ids_ordered,
|
||||
index_or_retriever=index_or_retriever,
|
||||
fast_plaid_index_path=fast_plaid_index_path,
|
||||
task_prompt=task_prompt,
|
||||
)
|
||||
|
||||
# Evaluate
|
||||
scores = evaluator.evaluate_results(results, qrels, k_values=k_values)
|
||||
|
||||
# Print results
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f"Results for {task_name}:")
|
||||
print(f"{'=' * 80}")
|
||||
for metric, value in scores.items():
|
||||
if isinstance(value, (int, float)):
|
||||
print(f" {metric}: {value:.5f}")
|
||||
|
||||
# Save results
|
||||
if output_dir:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
results_file = os.path.join(output_dir, f"{task_name}_results.json")
|
||||
scores_file = os.path.join(output_dir, f"{task_name}_scores.json")
|
||||
|
||||
with open(results_file, "w") as f:
|
||||
json.dump(results, f, indent=2)
|
||||
print(f"\nSaved results to: {results_file}")
|
||||
|
||||
with open(scores_file, "w") as f:
|
||||
json.dump(scores, f, indent=2)
|
||||
print(f"Saved scores to: {scores_file}")
|
||||
|
||||
return scores
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Evaluate ViDoRe v1 benchmark using LEANN/Fast-Plaid indexing"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="colqwen2",
|
||||
help="Model to use: 'colqwen2', 'colpali', or path to a model directory (supports LoRA adapters)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Specific task to evaluate (or 'all' for all tasks)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tasks",
|
||||
type=str,
|
||||
default="all",
|
||||
help="Tasks to evaluate: 'all' or comma-separated list",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--index-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to LEANN index (auto-generated if not provided)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use-fast-plaid",
|
||||
action="store_true",
|
||||
help="Use Fast-Plaid instead of LEANN",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fast-plaid-index-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to Fast-Plaid index (auto-generated if not provided)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rebuild-index",
|
||||
action="store_true",
|
||||
help="Rebuild index even if it exists",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top-k",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Top-k results to retrieve (MTEB default is max(k_values)=1000)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--first-stage-k",
|
||||
type=int,
|
||||
default=500,
|
||||
help="First stage k for LEANN search",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--k-values",
|
||||
type=str,
|
||||
default="1,3,5,10,20,100,1000",
|
||||
help="Comma-separated k values for evaluation (e.g., '1,3,5,10,100')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=str,
|
||||
default="./vidore_v1_results",
|
||||
help="Output directory for results",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Parse k_values
|
||||
k_values = [int(k.strip()) for k in args.k_values.split(",")]
|
||||
|
||||
# Determine tasks to evaluate
|
||||
if args.task:
|
||||
tasks_to_eval = [normalize_task_name(args.task)]
|
||||
elif args.tasks.lower() == "all":
|
||||
tasks_to_eval = list(VIDORE_V1_TASKS.keys())
|
||||
else:
|
||||
tasks_to_eval = [normalize_task_name(t.strip()) for t in args.tasks.split(",")]
|
||||
|
||||
print(f"Tasks to evaluate: {tasks_to_eval}")
|
||||
|
||||
# Evaluate each task
|
||||
all_scores = {}
|
||||
for task_name in tasks_to_eval:
|
||||
try:
|
||||
scores = evaluate_task(
|
||||
task_name=task_name,
|
||||
model_name=args.model,
|
||||
index_path=args.index_path,
|
||||
use_fast_plaid=args.use_fast_plaid,
|
||||
fast_plaid_index_path=args.fast_plaid_index_path,
|
||||
rebuild_index=args.rebuild_index,
|
||||
top_k=args.top_k,
|
||||
first_stage_k=args.first_stage_k,
|
||||
k_values=k_values,
|
||||
output_dir=args.output_dir,
|
||||
)
|
||||
all_scores[task_name] = scores
|
||||
except Exception as e:
|
||||
print(f"\nError evaluating {task_name}: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
continue
|
||||
|
||||
# Print summary
|
||||
if all_scores:
|
||||
print(f"\n{'=' * 80}")
|
||||
print("SUMMARY")
|
||||
print(f"{'=' * 80}")
|
||||
for task_name, scores in all_scores.items():
|
||||
print(f"\n{task_name}:")
|
||||
# Print main metrics
|
||||
for metric in ["ndcg_at_5", "ndcg_at_10", "ndcg_at_100", "map_at_10", "recall_at_10"]:
|
||||
if metric in scores:
|
||||
print(f" {metric}: {scores[metric]:.5f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,439 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Modular script to reproduce NDCG results for ViDoRe v2 benchmark.
|
||||
|
||||
This script uses the interface from leann_multi_vector.py to:
|
||||
1. Download ViDoRe v2 datasets
|
||||
2. Build indexes (LEANN or Fast-Plaid)
|
||||
3. Perform retrieval
|
||||
4. Evaluate using NDCG metrics
|
||||
|
||||
Usage:
|
||||
# Evaluate all ViDoRe v2 tasks
|
||||
python vidore_v2_benchmark.py --model colqwen2 --tasks all
|
||||
|
||||
# Evaluate specific task
|
||||
python vidore_v2_benchmark.py --model colqwen2 --task Vidore2ESGReportsRetrieval
|
||||
|
||||
# Use Fast-Plaid index
|
||||
python vidore_v2_benchmark.py --model colqwen2 --use-fast-plaid --fast-plaid-index-path ./indexes/vidore_fastplaid
|
||||
|
||||
# Rebuild index
|
||||
python vidore_v2_benchmark.py --model colqwen2 --rebuild-index
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from datasets import load_dataset
|
||||
from leann_multi_vector import (
|
||||
ViDoReBenchmarkEvaluator,
|
||||
_ensure_repo_paths_importable,
|
||||
)
|
||||
|
||||
_ensure_repo_paths_importable(__file__)
|
||||
|
||||
# Language name to dataset language field value mapping
|
||||
# Dataset uses ISO 639-3 + ISO 15924 format (e.g., "eng-Latn")
|
||||
LANGUAGE_MAPPING = {
|
||||
"english": "eng-Latn",
|
||||
"french": "fra-Latn",
|
||||
"spanish": "spa-Latn",
|
||||
"german": "deu-Latn",
|
||||
}
|
||||
|
||||
# ViDoRe v2 task configurations
|
||||
# Prompts match MTEB task metadata prompts
|
||||
VIDORE_V2_TASKS = {
|
||||
"Vidore2ESGReportsRetrieval": {
|
||||
"dataset_path": "vidore/esg_reports_v2",
|
||||
"revision": "0542c0d03da0ec1c8cbc517c8d78e7e95c75d3d3",
|
||||
"languages": ["french", "spanish", "english", "german"],
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"Vidore2EconomicsReportsRetrieval": {
|
||||
"dataset_path": "vidore/economics_reports_v2",
|
||||
"revision": "b3e3a04b07fbbaffe79be49dabf92f691fbca252",
|
||||
"languages": ["french", "spanish", "english", "german"],
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"Vidore2BioMedicalLecturesRetrieval": {
|
||||
"dataset_path": "vidore/biomedical_lectures_v2",
|
||||
"revision": "a29202f0da409034d651614d87cd8938d254e2ea",
|
||||
"languages": ["french", "spanish", "english", "german"],
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"Vidore2ESGReportsHLRetrieval": {
|
||||
"dataset_path": "vidore/esg_reports_human_labeled_v2",
|
||||
"revision": "6d467dedb09a75144ede1421747e47cf036857dd",
|
||||
# Note: This dataset doesn't have language filtering - all queries are English
|
||||
"languages": None, # No language filtering needed
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def load_vidore_v2_data(
|
||||
dataset_path: str,
|
||||
revision: Optional[str] = None,
|
||||
split: str = "test",
|
||||
language: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Load ViDoRe v2 dataset.
|
||||
|
||||
Returns:
|
||||
corpus: dict mapping corpus_id to PIL Image
|
||||
queries: dict mapping query_id to query text
|
||||
qrels: dict mapping query_id to dict of {corpus_id: relevance_score}
|
||||
"""
|
||||
print(f"Loading dataset: {dataset_path} (split={split}, language={language})")
|
||||
|
||||
# Load queries
|
||||
query_ds = load_dataset(dataset_path, "queries", split=split, revision=revision)
|
||||
|
||||
# Check if dataset has language field before filtering
|
||||
has_language_field = len(query_ds) > 0 and "language" in query_ds.column_names
|
||||
|
||||
if language and has_language_field:
|
||||
# Map language name to dataset language field value (e.g., "english" -> "eng-Latn")
|
||||
dataset_language = LANGUAGE_MAPPING.get(language, language)
|
||||
query_ds_filtered = query_ds.filter(lambda x: x.get("language") == dataset_language)
|
||||
# Check if filtering resulted in empty dataset
|
||||
if len(query_ds_filtered) == 0:
|
||||
print(
|
||||
f"Warning: No queries found after filtering by language '{language}' (mapped to '{dataset_language}')."
|
||||
)
|
||||
# Try with original language value (dataset might use simple names like 'english')
|
||||
print(f"Trying with original language value '{language}'...")
|
||||
query_ds_filtered = query_ds.filter(lambda x: x.get("language") == language)
|
||||
if len(query_ds_filtered) == 0:
|
||||
# Try to get a sample to see actual language values
|
||||
try:
|
||||
sample_ds = load_dataset(
|
||||
dataset_path, "queries", split=split, revision=revision
|
||||
)
|
||||
if len(sample_ds) > 0 and "language" in sample_ds.column_names:
|
||||
sample_langs = set(sample_ds["language"])
|
||||
print(f"Available language values in dataset: {sample_langs}")
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
print(
|
||||
f"Found {len(query_ds_filtered)} queries using original language value '{language}'"
|
||||
)
|
||||
query_ds = query_ds_filtered
|
||||
|
||||
queries = {}
|
||||
for row in query_ds:
|
||||
query_id = f"query-{split}-{row['query-id']}"
|
||||
queries[query_id] = row["query"]
|
||||
|
||||
# Load corpus (images)
|
||||
corpus_ds = load_dataset(dataset_path, "corpus", split=split, revision=revision)
|
||||
|
||||
corpus = {}
|
||||
for row in corpus_ds:
|
||||
corpus_id = f"corpus-{split}-{row['corpus-id']}"
|
||||
# Extract image from the dataset row
|
||||
if "image" in row:
|
||||
corpus[corpus_id] = row["image"]
|
||||
elif "page_image" in row:
|
||||
corpus[corpus_id] = row["page_image"]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"No image field found in corpus. Available fields: {list(row.keys())}"
|
||||
)
|
||||
|
||||
# Load qrels (relevance judgments)
|
||||
qrels_ds = load_dataset(dataset_path, "qrels", split=split, revision=revision)
|
||||
|
||||
qrels = {}
|
||||
for row in qrels_ds:
|
||||
query_id = f"query-{split}-{row['query-id']}"
|
||||
corpus_id = f"corpus-{split}-{row['corpus-id']}"
|
||||
if query_id not in qrels:
|
||||
qrels[query_id] = {}
|
||||
qrels[query_id][corpus_id] = int(row["score"])
|
||||
|
||||
print(
|
||||
f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings"
|
||||
)
|
||||
|
||||
# Filter qrels to only include queries that exist
|
||||
qrels = {qid: rel_docs for qid, rel_docs in qrels.items() if qid in queries}
|
||||
|
||||
# Filter out queries without any relevant documents (matching MTEB behavior)
|
||||
# This is important for correct NDCG calculation
|
||||
qrels_filtered = {qid: rel_docs for qid, rel_docs in qrels.items() if len(rel_docs) > 0}
|
||||
queries_filtered = {
|
||||
qid: query_text for qid, query_text in queries.items() if qid in qrels_filtered
|
||||
}
|
||||
|
||||
print(
|
||||
f"After filtering queries without positives: {len(queries_filtered)} queries, {len(qrels_filtered)} query-relevance mappings"
|
||||
)
|
||||
|
||||
return corpus, queries_filtered, qrels_filtered
|
||||
|
||||
|
||||
def evaluate_task(
|
||||
task_name: str,
|
||||
model_name: str,
|
||||
index_path: str,
|
||||
use_fast_plaid: bool = False,
|
||||
fast_plaid_index_path: Optional[str] = None,
|
||||
language: Optional[str] = None,
|
||||
rebuild_index: bool = False,
|
||||
top_k: int = 100,
|
||||
first_stage_k: int = 500,
|
||||
k_values: Optional[list[int]] = None,
|
||||
output_dir: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Evaluate a single ViDoRe v2 task.
|
||||
"""
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f"Evaluating task: {task_name}")
|
||||
print(f"{'=' * 80}")
|
||||
|
||||
# Get task config
|
||||
if task_name not in VIDORE_V2_TASKS:
|
||||
raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V2_TASKS.keys())}")
|
||||
|
||||
task_config = VIDORE_V2_TASKS[task_name]
|
||||
dataset_path = task_config["dataset_path"]
|
||||
revision = task_config["revision"]
|
||||
|
||||
# Determine language
|
||||
if language is None:
|
||||
# Use first language if multiple available
|
||||
languages = task_config.get("languages")
|
||||
if languages is None:
|
||||
# Task doesn't support language filtering (e.g., Vidore2ESGReportsHLRetrieval)
|
||||
language = None
|
||||
elif len(languages) == 1:
|
||||
language = languages[0]
|
||||
else:
|
||||
language = None
|
||||
|
||||
# Initialize k_values if not provided
|
||||
if k_values is None:
|
||||
k_values = [1, 3, 5, 10, 100]
|
||||
|
||||
# Load data
|
||||
corpus, queries, qrels = load_vidore_v2_data(
|
||||
dataset_path=dataset_path,
|
||||
revision=revision,
|
||||
split="test",
|
||||
language=language,
|
||||
)
|
||||
|
||||
# Check if we have any queries
|
||||
if len(queries) == 0:
|
||||
print(
|
||||
f"\nWarning: No queries found for task {task_name} with language {language}. Skipping evaluation."
|
||||
)
|
||||
# Return zero scores
|
||||
scores = {}
|
||||
for k in k_values:
|
||||
scores[f"ndcg_at_{k}"] = 0.0
|
||||
scores[f"map_at_{k}"] = 0.0
|
||||
scores[f"recall_at_{k}"] = 0.0
|
||||
scores[f"precision_at_{k}"] = 0.0
|
||||
scores[f"mrr_at_{k}"] = 0.0
|
||||
return scores
|
||||
|
||||
# Initialize evaluator
|
||||
evaluator = ViDoReBenchmarkEvaluator(
|
||||
model_name=model_name,
|
||||
use_fast_plaid=use_fast_plaid,
|
||||
top_k=top_k,
|
||||
first_stage_k=first_stage_k,
|
||||
k_values=k_values,
|
||||
)
|
||||
|
||||
# Build or load index
|
||||
index_path_full = index_path if not use_fast_plaid else fast_plaid_index_path
|
||||
if index_path_full is None:
|
||||
index_path_full = f"./indexes/{task_name}_{model_name}"
|
||||
if use_fast_plaid:
|
||||
index_path_full = f"./indexes/{task_name}_{model_name}_fastplaid"
|
||||
|
||||
index_or_retriever, corpus_ids_ordered = evaluator.build_index_from_corpus(
|
||||
corpus=corpus,
|
||||
index_path=index_path_full,
|
||||
rebuild=rebuild_index,
|
||||
)
|
||||
|
||||
# Search queries
|
||||
task_prompt = task_config.get("prompt")
|
||||
results = evaluator.search_queries(
|
||||
queries=queries,
|
||||
corpus_ids=corpus_ids_ordered,
|
||||
index_or_retriever=index_or_retriever,
|
||||
fast_plaid_index_path=fast_plaid_index_path,
|
||||
task_prompt=task_prompt,
|
||||
)
|
||||
|
||||
# Evaluate
|
||||
scores = evaluator.evaluate_results(results, qrels, k_values=k_values)
|
||||
|
||||
# Print results
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f"Results for {task_name}:")
|
||||
print(f"{'=' * 80}")
|
||||
for metric, value in scores.items():
|
||||
if isinstance(value, (int, float)):
|
||||
print(f" {metric}: {value:.5f}")
|
||||
|
||||
# Save results
|
||||
if output_dir:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
results_file = os.path.join(output_dir, f"{task_name}_results.json")
|
||||
scores_file = os.path.join(output_dir, f"{task_name}_scores.json")
|
||||
|
||||
with open(results_file, "w") as f:
|
||||
json.dump(results, f, indent=2)
|
||||
print(f"\nSaved results to: {results_file}")
|
||||
|
||||
with open(scores_file, "w") as f:
|
||||
json.dump(scores, f, indent=2)
|
||||
print(f"Saved scores to: {scores_file}")
|
||||
|
||||
return scores
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Evaluate ViDoRe v2 benchmark using LEANN/Fast-Plaid indexing"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="colqwen2",
|
||||
choices=["colqwen2", "colpali"],
|
||||
help="Model to use",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Specific task to evaluate (or 'all' for all tasks)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tasks",
|
||||
type=str,
|
||||
default="all",
|
||||
help="Tasks to evaluate: 'all' or comma-separated list",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--index-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to LEANN index (auto-generated if not provided)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use-fast-plaid",
|
||||
action="store_true",
|
||||
help="Use Fast-Plaid instead of LEANN",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fast-plaid-index-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to Fast-Plaid index (auto-generated if not provided)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rebuild-index",
|
||||
action="store_true",
|
||||
help="Rebuild index even if it exists",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--language",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Language to evaluate (default: first available)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top-k",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Top-k results to retrieve",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--first-stage-k",
|
||||
type=int,
|
||||
default=500,
|
||||
help="First stage k for LEANN search",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--k-values",
|
||||
type=str,
|
||||
default="1,3,5,10,100",
|
||||
help="Comma-separated k values for evaluation (e.g., '1,3,5,10,100')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=str,
|
||||
default="./vidore_v2_results",
|
||||
help="Output directory for results",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Parse k_values
|
||||
k_values = [int(k.strip()) for k in args.k_values.split(",")]
|
||||
|
||||
# Determine tasks to evaluate
|
||||
if args.task:
|
||||
tasks_to_eval = [args.task]
|
||||
elif args.tasks.lower() == "all":
|
||||
tasks_to_eval = list(VIDORE_V2_TASKS.keys())
|
||||
else:
|
||||
tasks_to_eval = [t.strip() for t in args.tasks.split(",")]
|
||||
|
||||
print(f"Tasks to evaluate: {tasks_to_eval}")
|
||||
|
||||
# Evaluate each task
|
||||
all_scores = {}
|
||||
for task_name in tasks_to_eval:
|
||||
try:
|
||||
scores = evaluate_task(
|
||||
task_name=task_name,
|
||||
model_name=args.model,
|
||||
index_path=args.index_path,
|
||||
use_fast_plaid=args.use_fast_plaid,
|
||||
fast_plaid_index_path=args.fast_plaid_index_path,
|
||||
language=args.language,
|
||||
rebuild_index=args.rebuild_index,
|
||||
top_k=args.top_k,
|
||||
first_stage_k=args.first_stage_k,
|
||||
k_values=k_values,
|
||||
output_dir=args.output_dir,
|
||||
)
|
||||
all_scores[task_name] = scores
|
||||
except Exception as e:
|
||||
print(f"\nError evaluating {task_name}: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
continue
|
||||
|
||||
# Print summary
|
||||
if all_scores:
|
||||
print(f"\n{'=' * 80}")
|
||||
print("SUMMARY")
|
||||
print(f"{'=' * 80}")
|
||||
for task_name, scores in all_scores.items():
|
||||
print(f"\n{task_name}:")
|
||||
# Print main metrics
|
||||
for metric in ["ndcg_at_5", "ndcg_at_10", "ndcg_at_100", "map_at_10", "recall_at_10"]:
|
||||
if metric in scores:
|
||||
print(f" {metric}: {scores[metric]:.5f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
183
apps/semantic_file_search/leann-plus-temporal-search.py
Normal file
@@ -0,0 +1,183 @@
|
||||
#!/usr/bin/env python3
|
||||
import re
|
||||
import sys
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
|
||||
from leann import LeannSearcher
|
||||
|
||||
INDEX_PATH = str(Path("./").resolve() / "demo.leann")
|
||||
|
||||
|
||||
class TimeParser:
|
||||
def __init__(self):
|
||||
# Main pattern: captures optional fuzzy modifier, number, unit, and optional "ago"
|
||||
self.pattern = r"(?:(around|about|roughly|approximately)\s+)?(\d+)\s+(hour|day|week|month|year)s?(?:\s+ago)?"
|
||||
|
||||
# Compile for performance
|
||||
self.regex = re.compile(self.pattern, re.IGNORECASE)
|
||||
|
||||
# Stop words to remove before regex parsing
|
||||
self.stop_words = {
|
||||
"in",
|
||||
"at",
|
||||
"of",
|
||||
"by",
|
||||
"as",
|
||||
"me",
|
||||
"the",
|
||||
"a",
|
||||
"an",
|
||||
"and",
|
||||
"any",
|
||||
"find",
|
||||
"search",
|
||||
"list",
|
||||
"ago",
|
||||
"back",
|
||||
"past",
|
||||
"earlier",
|
||||
}
|
||||
|
||||
def clean_text(self, text):
|
||||
"""Remove stop words from text"""
|
||||
words = text.split()
|
||||
cleaned = " ".join(word for word in words if word.lower() not in self.stop_words)
|
||||
return cleaned
|
||||
|
||||
def parse(self, text):
|
||||
"""Extract all time expressions from text"""
|
||||
# Clean text first
|
||||
cleaned_text = self.clean_text(text)
|
||||
|
||||
matches = []
|
||||
for match in self.regex.finditer(cleaned_text):
|
||||
fuzzy = match.group(1) # "around", "about", etc.
|
||||
number = int(match.group(2))
|
||||
unit = match.group(3).lower()
|
||||
|
||||
matches.append(
|
||||
{
|
||||
"full_match": match.group(0),
|
||||
"fuzzy": bool(fuzzy),
|
||||
"number": number,
|
||||
"unit": unit,
|
||||
"range": self.calculate_range(number, unit, bool(fuzzy)),
|
||||
}
|
||||
)
|
||||
|
||||
return matches
|
||||
|
||||
def calculate_range(self, number, unit, is_fuzzy):
|
||||
"""Convert to actual datetime range and return ISO format strings"""
|
||||
units = {
|
||||
"hour": timedelta(hours=number),
|
||||
"day": timedelta(days=number),
|
||||
"week": timedelta(weeks=number),
|
||||
"month": timedelta(days=number * 30),
|
||||
"year": timedelta(days=number * 365),
|
||||
}
|
||||
|
||||
delta = units[unit]
|
||||
now = datetime.now()
|
||||
target = now - delta
|
||||
|
||||
if is_fuzzy:
|
||||
buffer = delta * 0.2 # 20% buffer for fuzzy
|
||||
start = (target - buffer).isoformat()
|
||||
end = (target + buffer).isoformat()
|
||||
else:
|
||||
start = target.isoformat()
|
||||
end = now.isoformat()
|
||||
|
||||
return (start, end)
|
||||
|
||||
|
||||
def search_files(query, top_k=15):
|
||||
"""Search the index and return results"""
|
||||
# Parse time expressions
|
||||
parser = TimeParser()
|
||||
time_matches = parser.parse(query)
|
||||
|
||||
# Remove time expressions from query for semantic search
|
||||
clean_query = query
|
||||
if time_matches:
|
||||
for match in time_matches:
|
||||
clean_query = clean_query.replace(match["full_match"], "").strip()
|
||||
|
||||
# Check if clean_query is less than 4 characters
|
||||
if len(clean_query) < 4:
|
||||
print("Error: add more input for accurate results.")
|
||||
return
|
||||
|
||||
# Single query to vector DB
|
||||
searcher = LeannSearcher(INDEX_PATH)
|
||||
results = searcher.search(
|
||||
clean_query if clean_query else query, top_k=top_k, recompute_embeddings=False
|
||||
)
|
||||
|
||||
# Filter by time if time expression found
|
||||
if time_matches:
|
||||
time_range = time_matches[0]["range"] # Use first time expression
|
||||
start_time, end_time = time_range
|
||||
|
||||
filtered_results = []
|
||||
for result in results:
|
||||
# Access metadata attribute directly (not .get())
|
||||
metadata = result.metadata if hasattr(result, "metadata") else {}
|
||||
|
||||
if metadata:
|
||||
# Check modification date first, fall back to creation date
|
||||
date_str = metadata.get("modification_date") or metadata.get("creation_date")
|
||||
|
||||
if date_str:
|
||||
# Convert strings to datetime objects for proper comparison
|
||||
try:
|
||||
file_date = datetime.fromisoformat(date_str)
|
||||
start_dt = datetime.fromisoformat(start_time)
|
||||
end_dt = datetime.fromisoformat(end_time)
|
||||
|
||||
# Compare dates properly
|
||||
if start_dt <= file_date <= end_dt:
|
||||
filtered_results.append(result)
|
||||
except (ValueError, TypeError):
|
||||
# Handle invalid date formats
|
||||
print(f"Warning: Invalid date format in metadata: {date_str}")
|
||||
continue
|
||||
|
||||
results = filtered_results
|
||||
|
||||
# Print results
|
||||
print(f"\nSearch results for: '{query}'")
|
||||
if time_matches:
|
||||
print(
|
||||
f"Time filter: {time_matches[0]['number']} {time_matches[0]['unit']}(s) {'(fuzzy)' if time_matches[0]['fuzzy'] else ''}"
|
||||
)
|
||||
print(
|
||||
f"Date range: {time_matches[0]['range'][0][:10]} to {time_matches[0]['range'][1][:10]}"
|
||||
)
|
||||
print("-" * 80)
|
||||
|
||||
for i, result in enumerate(results, 1):
|
||||
print(f"\n[{i}] Score: {result.score:.4f}")
|
||||
print(f"Content: {result.text}")
|
||||
|
||||
# Show metadata if present
|
||||
metadata = result.metadata if hasattr(result, "metadata") else None
|
||||
if metadata:
|
||||
if "creation_date" in metadata:
|
||||
print(f"Created: {metadata['creation_date']}")
|
||||
if "modification_date" in metadata:
|
||||
print(f"Modified: {metadata['modification_date']}")
|
||||
print("-" * 80)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) < 2:
|
||||
print('Usage: python search_index.py "<search query>" [top_k]')
|
||||
sys.exit(1)
|
||||
|
||||
query = sys.argv[1]
|
||||
top_k = int(sys.argv[2]) if len(sys.argv) > 2 else 15
|
||||
|
||||
search_files(query, top_k)
|
||||
82
apps/semantic_file_search/leann_index_builder.py
Normal file
@@ -0,0 +1,82 @@
|
||||
#!/usr/bin/env python3
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from leann import LeannBuilder
|
||||
|
||||
|
||||
def process_json_items(json_file_path):
|
||||
"""Load and process JSON file with metadata items"""
|
||||
|
||||
with open(json_file_path, encoding="utf-8") as f:
|
||||
items = json.load(f)
|
||||
|
||||
# Guard against empty JSON
|
||||
if not items:
|
||||
print("⚠️ No items found in the JSON file. Exiting gracefully.")
|
||||
return
|
||||
|
||||
INDEX_PATH = str(Path("./").resolve() / "demo.leann")
|
||||
builder = LeannBuilder(backend_name="hnsw", is_recompute=False)
|
||||
|
||||
total_items = len(items)
|
||||
items_added = 0
|
||||
print(f"Processing {total_items} items...")
|
||||
|
||||
for idx, item in enumerate(items):
|
||||
try:
|
||||
# Create embedding text sentence
|
||||
embedding_text = f"{item.get('Name', 'unknown')} located at {item.get('Path', 'unknown')} and size {item.get('Size', 'unknown')} bytes with content type {item.get('ContentType', 'unknown')} and kind {item.get('Kind', 'unknown')}"
|
||||
|
||||
# Prepare metadata with dates
|
||||
metadata = {}
|
||||
if "CreationDate" in item:
|
||||
metadata["creation_date"] = item["CreationDate"]
|
||||
if "ContentChangeDate" in item:
|
||||
metadata["modification_date"] = item["ContentChangeDate"]
|
||||
|
||||
# Add to builder
|
||||
builder.add_text(embedding_text, metadata=metadata)
|
||||
items_added += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n⚠️ Warning: Failed to process item {idx}: {e}")
|
||||
continue
|
||||
|
||||
# Show progress
|
||||
progress = (idx + 1) / total_items * 100
|
||||
sys.stdout.write(f"\rProgress: {idx + 1}/{total_items} ({progress:.1f}%)")
|
||||
sys.stdout.flush()
|
||||
|
||||
print() # New line after progress
|
||||
|
||||
# Guard against no successfully added items
|
||||
if items_added == 0:
|
||||
print("⚠️ No items were successfully added to the index. Exiting gracefully.")
|
||||
return
|
||||
|
||||
print(f"\n✅ Successfully processed {items_added}/{total_items} items")
|
||||
print("Building index...")
|
||||
|
||||
try:
|
||||
builder.build_index(INDEX_PATH)
|
||||
print(f"✓ Index saved to {INDEX_PATH}")
|
||||
except ValueError as e:
|
||||
if "No chunks added" in str(e):
|
||||
print("⚠️ No chunks were added to the builder. Index not created.")
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) != 2:
|
||||
print("Usage: python build_index.py <json_file>")
|
||||
sys.exit(1)
|
||||
|
||||
json_file = sys.argv[1]
|
||||
if not Path(json_file).exists():
|
||||
print(f"Error: File {json_file} not found")
|
||||
sys.exit(1)
|
||||
|
||||
process_json_items(json_file)
|
||||
265
apps/semantic_file_search/spotlight_index_dump.py
Normal file
@@ -0,0 +1,265 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Spotlight Metadata Dumper for Vector DB
|
||||
Extracts only essential metadata for semantic search embeddings
|
||||
Output is optimized for vector database storage with minimal fields
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
from datetime import datetime
|
||||
|
||||
# Check platform before importing macOS-specific modules
|
||||
if sys.platform != "darwin":
|
||||
print("This script requires macOS (uses Spotlight)")
|
||||
sys.exit(1)
|
||||
|
||||
from Foundation import NSDate, NSMetadataQuery, NSPredicate, NSRunLoop
|
||||
|
||||
# EDIT THIS LIST: Add or remove folders to search
|
||||
# Can be either:
|
||||
# - Folder names relative to home directory (e.g., "Desktop", "Downloads")
|
||||
# - Absolute paths (e.g., "/Applications", "/System/Library")
|
||||
SEARCH_FOLDERS = [
|
||||
"Desktop",
|
||||
"Downloads",
|
||||
"Documents",
|
||||
"Music",
|
||||
"Pictures",
|
||||
"Movies",
|
||||
# "Library", # Uncomment to include
|
||||
# "/Applications", # Absolute path example
|
||||
# "Code/Projects", # Subfolder example
|
||||
# Add any other folders here
|
||||
]
|
||||
|
||||
|
||||
def convert_to_serializable(obj):
|
||||
"""Convert NS objects to Python serializable types"""
|
||||
if obj is None:
|
||||
return None
|
||||
|
||||
# Handle NSDate
|
||||
if hasattr(obj, "timeIntervalSince1970"):
|
||||
return datetime.fromtimestamp(obj.timeIntervalSince1970()).isoformat()
|
||||
|
||||
# Handle NSArray
|
||||
if hasattr(obj, "count") and hasattr(obj, "objectAtIndex_"):
|
||||
return [convert_to_serializable(obj.objectAtIndex_(i)) for i in range(obj.count())]
|
||||
|
||||
# Convert to string
|
||||
try:
|
||||
return str(obj)
|
||||
except Exception:
|
||||
return repr(obj)
|
||||
|
||||
|
||||
def dump_spotlight_data(max_items=10, output_file="spotlight_dump.json"):
|
||||
"""
|
||||
Dump Spotlight data using public.item predicate
|
||||
"""
|
||||
# Build full paths from SEARCH_FOLDERS
|
||||
import os
|
||||
|
||||
home_dir = os.path.expanduser("~")
|
||||
search_paths = []
|
||||
|
||||
print("Search locations:")
|
||||
for folder in SEARCH_FOLDERS:
|
||||
# Check if it's an absolute path or relative
|
||||
if folder.startswith("/"):
|
||||
full_path = folder
|
||||
else:
|
||||
full_path = os.path.join(home_dir, folder)
|
||||
|
||||
if os.path.exists(full_path):
|
||||
search_paths.append(full_path)
|
||||
print(f" ✓ {full_path}")
|
||||
else:
|
||||
print(f" ✗ {full_path} (not found)")
|
||||
|
||||
if not search_paths:
|
||||
print("No valid search paths found!")
|
||||
return []
|
||||
|
||||
print(f"\nDumping {max_items} items from Spotlight (public.item)...")
|
||||
|
||||
# Create query with public.item predicate
|
||||
query = NSMetadataQuery.alloc().init()
|
||||
predicate = NSPredicate.predicateWithFormat_("kMDItemContentTypeTree CONTAINS 'public.item'")
|
||||
query.setPredicate_(predicate)
|
||||
|
||||
# Set search scopes to our specific folders
|
||||
query.setSearchScopes_(search_paths)
|
||||
|
||||
print("Starting query...")
|
||||
query.startQuery()
|
||||
|
||||
# Wait for gathering to complete
|
||||
run_loop = NSRunLoop.currentRunLoop()
|
||||
print("Gathering results...")
|
||||
|
||||
# Let it gather for a few seconds
|
||||
for i in range(50): # 5 seconds max
|
||||
run_loop.runMode_beforeDate_(
|
||||
"NSDefaultRunLoopMode", NSDate.dateWithTimeIntervalSinceNow_(0.1)
|
||||
)
|
||||
# Check gathering status periodically
|
||||
if i % 10 == 0:
|
||||
current_count = query.resultCount()
|
||||
if current_count > 0:
|
||||
print(f" Found {current_count} items so far...")
|
||||
|
||||
# Continue while still gathering (up to 2 more seconds)
|
||||
timeout = NSDate.dateWithTimeIntervalSinceNow_(2.0)
|
||||
while query.isGathering() and timeout.timeIntervalSinceNow() > 0:
|
||||
run_loop.runMode_beforeDate_(
|
||||
"NSDefaultRunLoopMode", NSDate.dateWithTimeIntervalSinceNow_(0.1)
|
||||
)
|
||||
|
||||
query.stopQuery()
|
||||
|
||||
total_results = query.resultCount()
|
||||
print(f"Found {total_results} total items")
|
||||
|
||||
if total_results == 0:
|
||||
print("No results found")
|
||||
return []
|
||||
|
||||
# Process items
|
||||
items_to_process = min(total_results, max_items)
|
||||
results = []
|
||||
|
||||
# ONLY relevant attributes for vector embeddings
|
||||
# These provide essential context for semantic search without bloat
|
||||
attributes = [
|
||||
"kMDItemPath", # Full path for file retrieval
|
||||
"kMDItemFSName", # Filename for display & embedding
|
||||
"kMDItemFSSize", # Size for filtering/ranking
|
||||
"kMDItemContentType", # File type for categorization
|
||||
"kMDItemKind", # Human-readable type for embedding
|
||||
"kMDItemFSCreationDate", # Temporal context
|
||||
"kMDItemFSContentChangeDate", # Recency for ranking
|
||||
]
|
||||
|
||||
print(f"Processing {items_to_process} items...")
|
||||
|
||||
for i in range(items_to_process):
|
||||
try:
|
||||
item = query.resultAtIndex_(i)
|
||||
metadata = {}
|
||||
|
||||
# Extract ONLY the relevant attributes
|
||||
for attr in attributes:
|
||||
try:
|
||||
value = item.valueForAttribute_(attr)
|
||||
if value is not None:
|
||||
# Keep the attribute name clean (remove kMDItem prefix for cleaner JSON)
|
||||
clean_key = attr.replace("kMDItem", "").replace("FS", "")
|
||||
metadata[clean_key] = convert_to_serializable(value)
|
||||
except (AttributeError, ValueError, TypeError):
|
||||
continue
|
||||
|
||||
# Only add if we have at least a path
|
||||
if metadata.get("Path"):
|
||||
results.append(metadata)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing item {i}: {e}")
|
||||
continue
|
||||
|
||||
# Save to JSON
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
json.dump(results, f, indent=2, ensure_ascii=False)
|
||||
|
||||
print(f"\n✓ Saved {len(results)} items to {output_file}")
|
||||
|
||||
# Show summary
|
||||
print("\nSample items:")
|
||||
import os
|
||||
|
||||
home_dir = os.path.expanduser("~")
|
||||
|
||||
for i, item in enumerate(results[:3]):
|
||||
print(f"\n[Item {i + 1}]")
|
||||
print(f" Path: {item.get('Path', 'N/A')}")
|
||||
print(f" Name: {item.get('Name', 'N/A')}")
|
||||
print(f" Type: {item.get('ContentType', 'N/A')}")
|
||||
print(f" Kind: {item.get('Kind', 'N/A')}")
|
||||
|
||||
# Handle size properly
|
||||
size = item.get("Size")
|
||||
if size:
|
||||
try:
|
||||
size_int = int(size)
|
||||
if size_int > 1024 * 1024:
|
||||
print(f" Size: {size_int / (1024 * 1024):.2f} MB")
|
||||
elif size_int > 1024:
|
||||
print(f" Size: {size_int / 1024:.2f} KB")
|
||||
else:
|
||||
print(f" Size: {size_int} bytes")
|
||||
except (ValueError, TypeError):
|
||||
print(f" Size: {size}")
|
||||
|
||||
# Show dates
|
||||
if "CreationDate" in item:
|
||||
print(f" Created: {item['CreationDate']}")
|
||||
if "ContentChangeDate" in item:
|
||||
print(f" Modified: {item['ContentChangeDate']}")
|
||||
|
||||
# Count by type
|
||||
type_counts = {}
|
||||
for item in results:
|
||||
content_type = item.get("ContentType", "unknown")
|
||||
type_counts[content_type] = type_counts.get(content_type, 0) + 1
|
||||
|
||||
print(f"\nTotal items saved: {len(results)}")
|
||||
|
||||
if type_counts:
|
||||
print("\nTop content types:")
|
||||
for ct, count in sorted(type_counts.items(), key=lambda x: x[1], reverse=True)[:5]:
|
||||
print(f" {ct}: {count} items")
|
||||
|
||||
# Count by folder
|
||||
folder_counts = {}
|
||||
for item in results:
|
||||
path = item.get("Path", "")
|
||||
for folder in SEARCH_FOLDERS:
|
||||
# Build the full folder path
|
||||
if folder.startswith("/"):
|
||||
folder_path = folder
|
||||
else:
|
||||
folder_path = os.path.join(home_dir, folder)
|
||||
|
||||
if path.startswith(folder_path):
|
||||
folder_counts[folder] = folder_counts.get(folder, 0) + 1
|
||||
break
|
||||
|
||||
if folder_counts:
|
||||
print("\nItems by location:")
|
||||
for folder, count in sorted(folder_counts.items(), key=lambda x: x[1], reverse=True):
|
||||
print(f" {folder}: {count} items")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def main():
|
||||
# Parse arguments
|
||||
if len(sys.argv) > 1:
|
||||
try:
|
||||
max_items = int(sys.argv[1])
|
||||
except ValueError:
|
||||
print("Usage: python spot.py [number_of_items]")
|
||||
print("Default: 10 items")
|
||||
sys.exit(1)
|
||||
else:
|
||||
max_items = 10
|
||||
|
||||
output_file = sys.argv[2] if len(sys.argv) > 2 else "spotlight_dump.json"
|
||||
|
||||
# Run dump
|
||||
dump_spotlight_data(max_items=max_items, output_file=output_file)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
1
apps/slack_data/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Slack MCP data integration for LEANN
|
||||
511
apps/slack_data/slack_mcp_reader.py
Normal file
@@ -0,0 +1,511 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Slack MCP Reader for LEANN
|
||||
|
||||
This module provides functionality to connect to Slack MCP servers and fetch message data
|
||||
for indexing in LEANN. It supports various Slack MCP server implementations and provides
|
||||
flexible message processing options.
|
||||
"""
|
||||
|
||||
import ast
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SlackMCPReader:
|
||||
"""
|
||||
Reader for Slack data via MCP (Model Context Protocol) servers.
|
||||
|
||||
This class connects to Slack MCP servers to fetch message data and convert it
|
||||
into a format suitable for LEANN indexing.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mcp_server_command: str,
|
||||
workspace_name: Optional[str] = None,
|
||||
concatenate_conversations: bool = True,
|
||||
max_messages_per_conversation: int = 100,
|
||||
max_retries: int = 5,
|
||||
retry_delay: float = 2.0,
|
||||
):
|
||||
"""
|
||||
Initialize the Slack MCP Reader.
|
||||
|
||||
Args:
|
||||
mcp_server_command: Command to start the MCP server (e.g., 'slack-mcp-server')
|
||||
workspace_name: Optional workspace name to filter messages
|
||||
concatenate_conversations: Whether to group messages by channel/thread
|
||||
max_messages_per_conversation: Maximum messages to include per conversation
|
||||
max_retries: Maximum number of retries for failed operations
|
||||
retry_delay: Initial delay between retries in seconds
|
||||
"""
|
||||
self.mcp_server_command = mcp_server_command
|
||||
self.workspace_name = workspace_name
|
||||
self.concatenate_conversations = concatenate_conversations
|
||||
self.max_messages_per_conversation = max_messages_per_conversation
|
||||
self.max_retries = max_retries
|
||||
self.retry_delay = retry_delay
|
||||
self.mcp_process = None
|
||||
|
||||
async def start_mcp_server(self):
|
||||
"""Start the MCP server process."""
|
||||
try:
|
||||
self.mcp_process = await asyncio.create_subprocess_exec(
|
||||
*self.mcp_server_command.split(),
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
logger.info(f"Started MCP server: {self.mcp_server_command}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start MCP server: {e}")
|
||||
raise
|
||||
|
||||
async def stop_mcp_server(self):
|
||||
"""Stop the MCP server process."""
|
||||
if self.mcp_process:
|
||||
self.mcp_process.terminate()
|
||||
await self.mcp_process.wait()
|
||||
logger.info("Stopped MCP server")
|
||||
|
||||
async def send_mcp_request(self, request: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Send a request to the MCP server and get response."""
|
||||
if not self.mcp_process:
|
||||
raise RuntimeError("MCP server not started")
|
||||
|
||||
request_json = json.dumps(request) + "\n"
|
||||
self.mcp_process.stdin.write(request_json.encode())
|
||||
await self.mcp_process.stdin.drain()
|
||||
|
||||
response_line = await self.mcp_process.stdout.readline()
|
||||
if not response_line:
|
||||
raise RuntimeError("No response from MCP server")
|
||||
|
||||
return json.loads(response_line.decode().strip())
|
||||
|
||||
async def initialize_mcp_connection(self):
|
||||
"""Initialize the MCP connection."""
|
||||
init_request = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {},
|
||||
"clientInfo": {"name": "leann-slack-reader", "version": "1.0.0"},
|
||||
},
|
||||
}
|
||||
|
||||
response = await self.send_mcp_request(init_request)
|
||||
if "error" in response:
|
||||
raise RuntimeError(f"MCP initialization failed: {response['error']}")
|
||||
|
||||
logger.info("MCP connection initialized successfully")
|
||||
|
||||
async def list_available_tools(self) -> list[dict[str, Any]]:
|
||||
"""List available tools from the MCP server."""
|
||||
list_request = {"jsonrpc": "2.0", "id": 2, "method": "tools/list", "params": {}}
|
||||
|
||||
response = await self.send_mcp_request(list_request)
|
||||
if "error" in response:
|
||||
raise RuntimeError(f"Failed to list tools: {response['error']}")
|
||||
|
||||
return response.get("result", {}).get("tools", [])
|
||||
|
||||
def _is_cache_sync_error(self, error: dict) -> bool:
|
||||
"""Check if the error is related to users cache not being ready."""
|
||||
if isinstance(error, dict):
|
||||
message = error.get("message", "").lower()
|
||||
return (
|
||||
"users cache is not ready" in message or "sync process is still running" in message
|
||||
)
|
||||
return False
|
||||
|
||||
async def _retry_with_backoff(self, func, *args, **kwargs):
|
||||
"""Retry a function with exponential backoff, especially for cache sync issues."""
|
||||
last_exception = None
|
||||
|
||||
for attempt in range(self.max_retries + 1):
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
|
||||
# Check if this is a cache sync error
|
||||
error_dict = {}
|
||||
if hasattr(e, "args") and e.args and isinstance(e.args[0], dict):
|
||||
error_dict = e.args[0]
|
||||
elif "Failed to fetch messages" in str(e):
|
||||
# Try to extract error from the exception message
|
||||
import re
|
||||
|
||||
match = re.search(r"'error':\s*(\{[^}]+\})", str(e))
|
||||
if match:
|
||||
try:
|
||||
error_dict = ast.literal_eval(match.group(1))
|
||||
except (ValueError, SyntaxError):
|
||||
pass
|
||||
else:
|
||||
# Try alternative format
|
||||
match = re.search(r"Failed to fetch messages:\s*(\{[^}]+\})", str(e))
|
||||
if match:
|
||||
try:
|
||||
error_dict = ast.literal_eval(match.group(1))
|
||||
except (ValueError, SyntaxError):
|
||||
pass
|
||||
|
||||
if self._is_cache_sync_error(error_dict):
|
||||
if attempt < self.max_retries:
|
||||
delay = self.retry_delay * (2**attempt) # Exponential backoff
|
||||
logger.info(
|
||||
f"Cache sync not ready, waiting {delay:.1f}s before retry {attempt + 1}/{self.max_retries}"
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
continue
|
||||
else:
|
||||
logger.warning(
|
||||
f"Cache sync still not ready after {self.max_retries} retries, giving up"
|
||||
)
|
||||
break
|
||||
else:
|
||||
# Not a cache sync error, don't retry
|
||||
break
|
||||
|
||||
# If we get here, all retries failed or it's not a retryable error
|
||||
raise last_exception
|
||||
|
||||
async def fetch_slack_messages(
|
||||
self, channel: Optional[str] = None, limit: int = 100
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Fetch Slack messages using MCP tools with retry logic for cache sync issues.
|
||||
|
||||
Args:
|
||||
channel: Optional channel name to filter messages
|
||||
limit: Maximum number of messages to fetch
|
||||
|
||||
Returns:
|
||||
List of message dictionaries
|
||||
"""
|
||||
return await self._retry_with_backoff(self._fetch_slack_messages_impl, channel, limit)
|
||||
|
||||
async def _fetch_slack_messages_impl(
|
||||
self, channel: Optional[str] = None, limit: int = 100
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Internal implementation of fetch_slack_messages without retry logic.
|
||||
"""
|
||||
# This is a generic implementation - specific MCP servers may have different tool names
|
||||
# Common tool names might be: 'get_messages', 'list_messages', 'fetch_channel_history'
|
||||
|
||||
tools = await self.list_available_tools()
|
||||
logger.info(f"Available tools: {[tool.get('name') for tool in tools]}")
|
||||
message_tool = None
|
||||
|
||||
# Look for a tool that can fetch messages - prioritize conversations_history
|
||||
message_tool = None
|
||||
|
||||
# First, try to find conversations_history specifically
|
||||
for tool in tools:
|
||||
tool_name = tool.get("name", "").lower()
|
||||
if "conversations_history" in tool_name:
|
||||
message_tool = tool
|
||||
logger.info(f"Found conversations_history tool: {tool}")
|
||||
break
|
||||
|
||||
# If not found, look for other message-fetching tools
|
||||
if not message_tool:
|
||||
for tool in tools:
|
||||
tool_name = tool.get("name", "").lower()
|
||||
if any(
|
||||
keyword in tool_name
|
||||
for keyword in ["conversations_search", "message", "history"]
|
||||
):
|
||||
message_tool = tool
|
||||
break
|
||||
|
||||
if not message_tool:
|
||||
raise RuntimeError("No message fetching tool found in MCP server")
|
||||
|
||||
# Prepare tool call parameters
|
||||
tool_params = {"limit": "180d"} # Use 180 days to get older messages
|
||||
if channel:
|
||||
# For conversations_history, use channel_id parameter
|
||||
if message_tool["name"] == "conversations_history":
|
||||
tool_params["channel_id"] = channel
|
||||
else:
|
||||
# Try common parameter names for channel specification
|
||||
for param_name in ["channel", "channel_id", "channel_name"]:
|
||||
tool_params[param_name] = channel
|
||||
break
|
||||
|
||||
logger.info(f"Tool parameters: {tool_params}")
|
||||
|
||||
fetch_request = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": 3,
|
||||
"method": "tools/call",
|
||||
"params": {"name": message_tool["name"], "arguments": tool_params},
|
||||
}
|
||||
|
||||
response = await self.send_mcp_request(fetch_request)
|
||||
if "error" in response:
|
||||
raise RuntimeError(f"Failed to fetch messages: {response['error']}")
|
||||
|
||||
# Extract messages from response - format may vary by MCP server
|
||||
result = response.get("result", {})
|
||||
if "content" in result and isinstance(result["content"], list):
|
||||
# Some MCP servers return content as a list
|
||||
content = result["content"][0] if result["content"] else {}
|
||||
if "text" in content:
|
||||
try:
|
||||
messages = json.loads(content["text"])
|
||||
except json.JSONDecodeError:
|
||||
# If not JSON, try to parse as CSV format (Slack MCP server format)
|
||||
messages = self._parse_csv_messages(content["text"], channel)
|
||||
else:
|
||||
messages = result["content"]
|
||||
else:
|
||||
# Direct message format
|
||||
messages = result.get("messages", [result])
|
||||
|
||||
return messages if isinstance(messages, list) else [messages]
|
||||
|
||||
def _parse_csv_messages(self, csv_text: str, channel: str) -> list[dict[str, Any]]:
|
||||
"""Parse CSV format messages from Slack MCP server."""
|
||||
import csv
|
||||
import io
|
||||
|
||||
messages = []
|
||||
try:
|
||||
# Split by lines and process each line as a CSV row
|
||||
lines = csv_text.strip().split("\n")
|
||||
if not lines:
|
||||
return messages
|
||||
|
||||
# Skip header line if it exists
|
||||
start_idx = 0
|
||||
if lines[0].startswith("MsgID,UserID,UserName"):
|
||||
start_idx = 1
|
||||
|
||||
for line in lines[start_idx:]:
|
||||
if not line.strip():
|
||||
continue
|
||||
|
||||
# Parse CSV line
|
||||
reader = csv.reader(io.StringIO(line))
|
||||
try:
|
||||
row = next(reader)
|
||||
if len(row) >= 7: # Ensure we have enough columns
|
||||
message = {
|
||||
"ts": row[0],
|
||||
"user": row[1],
|
||||
"username": row[2],
|
||||
"real_name": row[3],
|
||||
"channel": row[4],
|
||||
"thread_ts": row[5],
|
||||
"text": row[6],
|
||||
"time": row[7] if len(row) > 7 else "",
|
||||
"reactions": row[8] if len(row) > 8 else "",
|
||||
"cursor": row[9] if len(row) > 9 else "",
|
||||
}
|
||||
messages.append(message)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse CSV line: {line[:100]}... Error: {e}")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse CSV messages: {e}")
|
||||
# Fallback: treat entire text as one message
|
||||
messages = [{"text": csv_text, "channel": channel or "unknown"}]
|
||||
|
||||
return messages
|
||||
|
||||
def _format_message(self, message: dict[str, Any]) -> str:
|
||||
"""Format a single message for indexing."""
|
||||
text = message.get("text", "")
|
||||
user = message.get("user", message.get("username", "Unknown"))
|
||||
channel = message.get("channel", message.get("channel_name", "Unknown"))
|
||||
timestamp = message.get("ts", message.get("timestamp", ""))
|
||||
|
||||
# Format timestamp if available
|
||||
formatted_time = ""
|
||||
if timestamp:
|
||||
try:
|
||||
import datetime
|
||||
|
||||
if isinstance(timestamp, str) and "." in timestamp:
|
||||
dt = datetime.datetime.fromtimestamp(float(timestamp))
|
||||
formatted_time = dt.strftime("%Y-%m-%d %H:%M:%S")
|
||||
elif isinstance(timestamp, (int, float)):
|
||||
dt = datetime.datetime.fromtimestamp(timestamp)
|
||||
formatted_time = dt.strftime("%Y-%m-%d %H:%M:%S")
|
||||
else:
|
||||
formatted_time = str(timestamp)
|
||||
except (ValueError, TypeError):
|
||||
formatted_time = str(timestamp)
|
||||
|
||||
# Build formatted message
|
||||
parts = []
|
||||
if channel:
|
||||
parts.append(f"Channel: #{channel}")
|
||||
if user:
|
||||
parts.append(f"User: {user}")
|
||||
if formatted_time:
|
||||
parts.append(f"Time: {formatted_time}")
|
||||
if text:
|
||||
parts.append(f"Message: {text}")
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
def _create_concatenated_content(self, messages: list[dict[str, Any]], channel: str) -> str:
|
||||
"""Create concatenated content from multiple messages in a channel."""
|
||||
if not messages:
|
||||
return ""
|
||||
|
||||
# Sort messages by timestamp if available
|
||||
try:
|
||||
messages.sort(key=lambda x: float(x.get("ts", x.get("timestamp", 0))))
|
||||
except (ValueError, TypeError):
|
||||
pass # Keep original order if timestamps aren't numeric
|
||||
|
||||
# Limit messages per conversation
|
||||
if len(messages) > self.max_messages_per_conversation:
|
||||
messages = messages[-self.max_messages_per_conversation :]
|
||||
|
||||
# Create header
|
||||
content_parts = [
|
||||
f"Slack Channel: #{channel}",
|
||||
f"Message Count: {len(messages)}",
|
||||
f"Workspace: {self.workspace_name or 'Unknown'}",
|
||||
"=" * 50,
|
||||
"",
|
||||
]
|
||||
|
||||
# Add messages
|
||||
for message in messages:
|
||||
formatted_msg = self._format_message(message)
|
||||
if formatted_msg.strip():
|
||||
content_parts.append(formatted_msg)
|
||||
content_parts.append("-" * 30)
|
||||
content_parts.append("")
|
||||
|
||||
return "\n".join(content_parts)
|
||||
|
||||
async def get_all_channels(self) -> list[str]:
|
||||
"""Get list of all available channels."""
|
||||
try:
|
||||
channels_list_request = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": 4,
|
||||
"method": "tools/call",
|
||||
"params": {"name": "channels_list", "arguments": {}},
|
||||
}
|
||||
channels_response = await self.send_mcp_request(channels_list_request)
|
||||
if "result" in channels_response:
|
||||
result = channels_response["result"]
|
||||
if "content" in result and isinstance(result["content"], list):
|
||||
content = result["content"][0] if result["content"] else {}
|
||||
if "text" in content:
|
||||
# Parse the channels from the response
|
||||
channels = []
|
||||
lines = content["text"].split("\n")
|
||||
for line in lines:
|
||||
if line.strip() and ("#" in line or "C" in line[:10]):
|
||||
# Extract channel ID or name
|
||||
parts = line.split()
|
||||
for part in parts:
|
||||
if part.startswith("C") and len(part) > 5:
|
||||
channels.append(part)
|
||||
elif part.startswith("#"):
|
||||
channels.append(part[1:]) # Remove #
|
||||
logger.info(f"Found {len(channels)} channels: {channels}")
|
||||
return channels
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get channels list: {e}")
|
||||
return []
|
||||
|
||||
async def read_slack_data(self, channels: Optional[list[str]] = None) -> list[str]:
|
||||
"""
|
||||
Read Slack data and return formatted text chunks.
|
||||
|
||||
Args:
|
||||
channels: Optional list of channel names to fetch. If None, fetches from all available channels.
|
||||
|
||||
Returns:
|
||||
List of formatted text chunks ready for LEANN indexing
|
||||
"""
|
||||
try:
|
||||
await self.start_mcp_server()
|
||||
await self.initialize_mcp_connection()
|
||||
|
||||
all_texts = []
|
||||
|
||||
if channels:
|
||||
# Fetch specific channels
|
||||
for channel in channels:
|
||||
try:
|
||||
messages = await self.fetch_slack_messages(channel=channel, limit=1000)
|
||||
if messages:
|
||||
if self.concatenate_conversations:
|
||||
text_content = self._create_concatenated_content(messages, channel)
|
||||
if text_content.strip():
|
||||
all_texts.append(text_content)
|
||||
else:
|
||||
# Process individual messages
|
||||
for message in messages:
|
||||
formatted_msg = self._format_message(message)
|
||||
if formatted_msg.strip():
|
||||
all_texts.append(formatted_msg)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch messages from channel {channel}: {e}")
|
||||
continue
|
||||
else:
|
||||
# Fetch from all available channels
|
||||
logger.info("Fetching from all available channels...")
|
||||
all_channels = await self.get_all_channels()
|
||||
|
||||
if not all_channels:
|
||||
# Fallback to common channel names if we can't get the list
|
||||
all_channels = ["general", "random", "announcements", "C0GN5BX0F"]
|
||||
logger.info(f"Using fallback channels: {all_channels}")
|
||||
|
||||
for channel in all_channels:
|
||||
try:
|
||||
logger.info(f"Searching channel: {channel}")
|
||||
messages = await self.fetch_slack_messages(channel=channel, limit=1000)
|
||||
if messages:
|
||||
if self.concatenate_conversations:
|
||||
text_content = self._create_concatenated_content(messages, channel)
|
||||
if text_content.strip():
|
||||
all_texts.append(text_content)
|
||||
else:
|
||||
# Process individual messages
|
||||
for message in messages:
|
||||
formatted_msg = self._format_message(message)
|
||||
if formatted_msg.strip():
|
||||
all_texts.append(formatted_msg)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch messages from channel {channel}: {e}")
|
||||
continue
|
||||
|
||||
return all_texts
|
||||
|
||||
finally:
|
||||
await self.stop_mcp_server()
|
||||
|
||||
async def __aenter__(self):
|
||||
"""Async context manager entry."""
|
||||
await self.start_mcp_server()
|
||||
await self.initialize_mcp_connection()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Async context manager exit."""
|
||||
await self.stop_mcp_server()
|
||||
227
apps/slack_rag.py
Normal file
@@ -0,0 +1,227 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Slack RAG Application with MCP Support
|
||||
|
||||
This application enables RAG (Retrieval-Augmented Generation) on Slack messages
|
||||
by connecting to Slack MCP servers to fetch live data and index it in LEANN.
|
||||
|
||||
Usage:
|
||||
python -m apps.slack_rag --mcp-server "slack-mcp-server" --query "What did the team discuss about the project?"
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
|
||||
from apps.base_rag_example import BaseRAGExample
|
||||
from apps.slack_data.slack_mcp_reader import SlackMCPReader
|
||||
|
||||
|
||||
class SlackMCPRAG(BaseRAGExample):
|
||||
"""
|
||||
RAG application for Slack messages via MCP servers.
|
||||
|
||||
This class provides a complete RAG pipeline for Slack data, including
|
||||
MCP server connection, data fetching, indexing, and interactive chat.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="Slack MCP RAG",
|
||||
description="RAG application for Slack messages via MCP servers",
|
||||
default_index_name="slack_messages",
|
||||
)
|
||||
|
||||
def _add_specific_arguments(self, parser: argparse.ArgumentParser):
|
||||
"""Add Slack MCP-specific arguments."""
|
||||
parser.add_argument(
|
||||
"--mcp-server",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Command to start the Slack MCP server (e.g., 'slack-mcp-server' or 'npx slack-mcp-server')",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--workspace-name",
|
||||
type=str,
|
||||
help="Slack workspace name for better organization and filtering",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--channels",
|
||||
nargs="+",
|
||||
help="Specific Slack channels to index (e.g., general random). If not specified, fetches from all available channels",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--concatenate-conversations",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Group messages by channel/thread for better context (default: True)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--no-concatenate-conversations",
|
||||
action="store_true",
|
||||
help="Process individual messages instead of grouping by channel",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-messages-per-channel",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Maximum number of messages to include per channel (default: 100)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--test-connection",
|
||||
action="store_true",
|
||||
help="Test MCP server connection and list available tools without indexing",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-retries",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Maximum number of retries for failed operations (default: 5)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--retry-delay",
|
||||
type=float,
|
||||
default=2.0,
|
||||
help="Initial delay between retries in seconds (default: 2.0)",
|
||||
)
|
||||
|
||||
async def test_mcp_connection(self, args) -> bool:
|
||||
"""Test the MCP server connection and display available tools."""
|
||||
print(f"Testing connection to MCP server: {args.mcp_server}")
|
||||
|
||||
try:
|
||||
reader = SlackMCPReader(
|
||||
mcp_server_command=args.mcp_server,
|
||||
workspace_name=args.workspace_name,
|
||||
concatenate_conversations=not args.no_concatenate_conversations,
|
||||
max_messages_per_conversation=args.max_messages_per_channel,
|
||||
max_retries=args.max_retries,
|
||||
retry_delay=args.retry_delay,
|
||||
)
|
||||
|
||||
async with reader:
|
||||
tools = await reader.list_available_tools()
|
||||
|
||||
print("Successfully connected to MCP server!")
|
||||
print(f"Available tools ({len(tools)}):")
|
||||
|
||||
for i, tool in enumerate(tools, 1):
|
||||
name = tool.get("name", "Unknown")
|
||||
description = tool.get("description", "No description available")
|
||||
print(f"\n{i}. {name}")
|
||||
print(
|
||||
f" Description: {description[:100]}{'...' if len(description) > 100 else ''}"
|
||||
)
|
||||
|
||||
# Show input schema if available
|
||||
schema = tool.get("inputSchema", {})
|
||||
if schema.get("properties"):
|
||||
props = list(schema["properties"].keys())[:3] # Show first 3 properties
|
||||
print(
|
||||
f" Parameters: {', '.join(props)}{'...' if len(schema['properties']) > 3 else ''}"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to connect to MCP server: {e}")
|
||||
print("\nTroubleshooting tips:")
|
||||
print("1. Make sure the MCP server is installed and accessible")
|
||||
print("2. Check if the server command is correct")
|
||||
print("3. Ensure you have proper authentication/credentials configured")
|
||||
print("4. Try running the MCP server command directly to test it")
|
||||
return False
|
||||
|
||||
async def load_data(self, args) -> list[str]:
|
||||
"""Load Slack messages via MCP server."""
|
||||
print(f"Connecting to Slack MCP server: {args.mcp_server}")
|
||||
|
||||
if args.workspace_name:
|
||||
print(f"Workspace: {args.workspace_name}")
|
||||
|
||||
# Filter out empty strings from channels
|
||||
channels = [ch for ch in args.channels if ch.strip()] if args.channels else None
|
||||
|
||||
if channels:
|
||||
print(f"Channels: {', '.join(channels)}")
|
||||
else:
|
||||
print("Fetching from all available channels")
|
||||
|
||||
concatenate = not args.no_concatenate_conversations
|
||||
print(
|
||||
f"Processing mode: {'Concatenated conversations' if concatenate else 'Individual messages'}"
|
||||
)
|
||||
|
||||
try:
|
||||
reader = SlackMCPReader(
|
||||
mcp_server_command=args.mcp_server,
|
||||
workspace_name=args.workspace_name,
|
||||
concatenate_conversations=concatenate,
|
||||
max_messages_per_conversation=args.max_messages_per_channel,
|
||||
max_retries=args.max_retries,
|
||||
retry_delay=args.retry_delay,
|
||||
)
|
||||
|
||||
texts = await reader.read_slack_data(channels=channels)
|
||||
|
||||
if not texts:
|
||||
print("No messages found! This could mean:")
|
||||
print("- The MCP server couldn't fetch messages")
|
||||
print("- The specified channels don't exist or are empty")
|
||||
print("- Authentication issues with the Slack workspace")
|
||||
return []
|
||||
|
||||
print(f"Successfully loaded {len(texts)} text chunks from Slack")
|
||||
|
||||
# Show sample of what was loaded
|
||||
if texts:
|
||||
sample_text = texts[0][:200] + "..." if len(texts[0]) > 200 else texts[0]
|
||||
print("\nSample content:")
|
||||
print("-" * 40)
|
||||
print(sample_text)
|
||||
print("-" * 40)
|
||||
|
||||
return texts
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error loading Slack data: {e}")
|
||||
print("\nThis might be due to:")
|
||||
print("- MCP server connection issues")
|
||||
print("- Authentication problems")
|
||||
print("- Network connectivity issues")
|
||||
print("- Incorrect channel names")
|
||||
raise
|
||||
|
||||
async def run(self):
|
||||
"""Main entry point with MCP connection testing."""
|
||||
args = self.parser.parse_args()
|
||||
|
||||
# Test connection if requested
|
||||
if args.test_connection:
|
||||
success = await self.test_mcp_connection(args)
|
||||
if not success:
|
||||
return
|
||||
print(
|
||||
"MCP server is working! You can now run without --test-connection to start indexing."
|
||||
)
|
||||
return
|
||||
|
||||
# Run the standard RAG pipeline
|
||||
await super().run()
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main entry point for the Slack MCP RAG application."""
|
||||
app = SlackMCPRAG()
|
||||
await app.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
1
apps/twitter_data/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Twitter MCP data integration for LEANN
|
||||
295
apps/twitter_data/twitter_mcp_reader.py
Normal file
@@ -0,0 +1,295 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Twitter MCP Reader for LEANN
|
||||
|
||||
This module provides functionality to connect to Twitter MCP servers and fetch bookmark data
|
||||
for indexing in LEANN. It supports various Twitter MCP server implementations and provides
|
||||
flexible bookmark processing options.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TwitterMCPReader:
|
||||
"""
|
||||
Reader for Twitter bookmark data via MCP (Model Context Protocol) servers.
|
||||
|
||||
This class connects to Twitter MCP servers to fetch bookmark data and convert it
|
||||
into a format suitable for LEANN indexing.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mcp_server_command: str,
|
||||
username: Optional[str] = None,
|
||||
include_tweet_content: bool = True,
|
||||
include_metadata: bool = True,
|
||||
max_bookmarks: int = 1000,
|
||||
):
|
||||
"""
|
||||
Initialize the Twitter MCP Reader.
|
||||
|
||||
Args:
|
||||
mcp_server_command: Command to start the MCP server (e.g., 'twitter-mcp-server')
|
||||
username: Optional Twitter username to filter bookmarks
|
||||
include_tweet_content: Whether to include full tweet content
|
||||
include_metadata: Whether to include tweet metadata (likes, retweets, etc.)
|
||||
max_bookmarks: Maximum number of bookmarks to fetch
|
||||
"""
|
||||
self.mcp_server_command = mcp_server_command
|
||||
self.username = username
|
||||
self.include_tweet_content = include_tweet_content
|
||||
self.include_metadata = include_metadata
|
||||
self.max_bookmarks = max_bookmarks
|
||||
self.mcp_process = None
|
||||
|
||||
async def start_mcp_server(self):
|
||||
"""Start the MCP server process."""
|
||||
try:
|
||||
self.mcp_process = await asyncio.create_subprocess_exec(
|
||||
*self.mcp_server_command.split(),
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
logger.info(f"Started MCP server: {self.mcp_server_command}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start MCP server: {e}")
|
||||
raise
|
||||
|
||||
async def stop_mcp_server(self):
|
||||
"""Stop the MCP server process."""
|
||||
if self.mcp_process:
|
||||
self.mcp_process.terminate()
|
||||
await self.mcp_process.wait()
|
||||
logger.info("Stopped MCP server")
|
||||
|
||||
async def send_mcp_request(self, request: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Send a request to the MCP server and get response."""
|
||||
if not self.mcp_process:
|
||||
raise RuntimeError("MCP server not started")
|
||||
|
||||
request_json = json.dumps(request) + "\n"
|
||||
self.mcp_process.stdin.write(request_json.encode())
|
||||
await self.mcp_process.stdin.drain()
|
||||
|
||||
response_line = await self.mcp_process.stdout.readline()
|
||||
if not response_line:
|
||||
raise RuntimeError("No response from MCP server")
|
||||
|
||||
return json.loads(response_line.decode().strip())
|
||||
|
||||
async def initialize_mcp_connection(self):
|
||||
"""Initialize the MCP connection."""
|
||||
init_request = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {},
|
||||
"clientInfo": {"name": "leann-twitter-reader", "version": "1.0.0"},
|
||||
},
|
||||
}
|
||||
|
||||
response = await self.send_mcp_request(init_request)
|
||||
if "error" in response:
|
||||
raise RuntimeError(f"MCP initialization failed: {response['error']}")
|
||||
|
||||
logger.info("MCP connection initialized successfully")
|
||||
|
||||
async def list_available_tools(self) -> list[dict[str, Any]]:
|
||||
"""List available tools from the MCP server."""
|
||||
list_request = {"jsonrpc": "2.0", "id": 2, "method": "tools/list", "params": {}}
|
||||
|
||||
response = await self.send_mcp_request(list_request)
|
||||
if "error" in response:
|
||||
raise RuntimeError(f"Failed to list tools: {response['error']}")
|
||||
|
||||
return response.get("result", {}).get("tools", [])
|
||||
|
||||
async def fetch_twitter_bookmarks(self, limit: Optional[int] = None) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Fetch Twitter bookmarks using MCP tools.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of bookmarks to fetch
|
||||
|
||||
Returns:
|
||||
List of bookmark dictionaries
|
||||
"""
|
||||
tools = await self.list_available_tools()
|
||||
bookmark_tool = None
|
||||
|
||||
# Look for a tool that can fetch bookmarks
|
||||
for tool in tools:
|
||||
tool_name = tool.get("name", "").lower()
|
||||
if any(keyword in tool_name for keyword in ["bookmark", "saved", "favorite"]):
|
||||
bookmark_tool = tool
|
||||
break
|
||||
|
||||
if not bookmark_tool:
|
||||
raise RuntimeError("No bookmark fetching tool found in MCP server")
|
||||
|
||||
# Prepare tool call parameters
|
||||
tool_params = {}
|
||||
if limit or self.max_bookmarks:
|
||||
tool_params["limit"] = limit or self.max_bookmarks
|
||||
if self.username:
|
||||
tool_params["username"] = self.username
|
||||
|
||||
fetch_request = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": 3,
|
||||
"method": "tools/call",
|
||||
"params": {"name": bookmark_tool["name"], "arguments": tool_params},
|
||||
}
|
||||
|
||||
response = await self.send_mcp_request(fetch_request)
|
||||
if "error" in response:
|
||||
raise RuntimeError(f"Failed to fetch bookmarks: {response['error']}")
|
||||
|
||||
# Extract bookmarks from response
|
||||
result = response.get("result", {})
|
||||
if "content" in result and isinstance(result["content"], list):
|
||||
content = result["content"][0] if result["content"] else {}
|
||||
if "text" in content:
|
||||
try:
|
||||
bookmarks = json.loads(content["text"])
|
||||
except json.JSONDecodeError:
|
||||
# If not JSON, treat as plain text
|
||||
bookmarks = [{"text": content["text"], "source": "twitter"}]
|
||||
else:
|
||||
bookmarks = result["content"]
|
||||
else:
|
||||
bookmarks = result.get("bookmarks", result.get("tweets", [result]))
|
||||
|
||||
return bookmarks if isinstance(bookmarks, list) else [bookmarks]
|
||||
|
||||
def _format_bookmark(self, bookmark: dict[str, Any]) -> str:
|
||||
"""Format a single bookmark for indexing."""
|
||||
# Extract tweet information
|
||||
text = bookmark.get("text", bookmark.get("content", ""))
|
||||
author = bookmark.get(
|
||||
"author", bookmark.get("username", bookmark.get("user", {}).get("username", "Unknown"))
|
||||
)
|
||||
timestamp = bookmark.get("created_at", bookmark.get("timestamp", ""))
|
||||
url = bookmark.get("url", bookmark.get("tweet_url", ""))
|
||||
|
||||
# Extract metadata if available
|
||||
likes = bookmark.get("likes", bookmark.get("favorite_count", 0))
|
||||
retweets = bookmark.get("retweets", bookmark.get("retweet_count", 0))
|
||||
replies = bookmark.get("replies", bookmark.get("reply_count", 0))
|
||||
|
||||
# Build formatted bookmark
|
||||
parts = []
|
||||
|
||||
# Header
|
||||
parts.append("=== Twitter Bookmark ===")
|
||||
|
||||
if author:
|
||||
parts.append(f"Author: @{author}")
|
||||
|
||||
if timestamp:
|
||||
# Format timestamp if it's a standard format
|
||||
try:
|
||||
import datetime
|
||||
|
||||
if "T" in str(timestamp): # ISO format
|
||||
dt = datetime.datetime.fromisoformat(timestamp.replace("Z", "+00:00"))
|
||||
formatted_time = dt.strftime("%Y-%m-%d %H:%M:%S")
|
||||
else:
|
||||
formatted_time = str(timestamp)
|
||||
parts.append(f"Date: {formatted_time}")
|
||||
except (ValueError, TypeError):
|
||||
parts.append(f"Date: {timestamp}")
|
||||
|
||||
if url:
|
||||
parts.append(f"URL: {url}")
|
||||
|
||||
# Tweet content
|
||||
if text and self.include_tweet_content:
|
||||
parts.append("")
|
||||
parts.append("Content:")
|
||||
parts.append(text)
|
||||
|
||||
# Metadata
|
||||
if self.include_metadata and any([likes, retweets, replies]):
|
||||
parts.append("")
|
||||
parts.append("Engagement:")
|
||||
if likes:
|
||||
parts.append(f" Likes: {likes}")
|
||||
if retweets:
|
||||
parts.append(f" Retweets: {retweets}")
|
||||
if replies:
|
||||
parts.append(f" Replies: {replies}")
|
||||
|
||||
# Extract hashtags and mentions if available
|
||||
hashtags = bookmark.get("hashtags", [])
|
||||
mentions = bookmark.get("mentions", [])
|
||||
|
||||
if hashtags or mentions:
|
||||
parts.append("")
|
||||
if hashtags:
|
||||
parts.append(f"Hashtags: {', '.join(hashtags)}")
|
||||
if mentions:
|
||||
parts.append(f"Mentions: {', '.join(mentions)}")
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
async def read_twitter_bookmarks(self) -> list[str]:
|
||||
"""
|
||||
Read Twitter bookmark data and return formatted text chunks.
|
||||
|
||||
Returns:
|
||||
List of formatted text chunks ready for LEANN indexing
|
||||
"""
|
||||
try:
|
||||
await self.start_mcp_server()
|
||||
await self.initialize_mcp_connection()
|
||||
|
||||
print(f"Fetching up to {self.max_bookmarks} bookmarks...")
|
||||
if self.username:
|
||||
print(f"Filtering for user: @{self.username}")
|
||||
|
||||
bookmarks = await self.fetch_twitter_bookmarks()
|
||||
|
||||
if not bookmarks:
|
||||
print("No bookmarks found")
|
||||
return []
|
||||
|
||||
print(f"Processing {len(bookmarks)} bookmarks...")
|
||||
|
||||
all_texts = []
|
||||
processed_count = 0
|
||||
|
||||
for bookmark in bookmarks:
|
||||
try:
|
||||
formatted_bookmark = self._format_bookmark(bookmark)
|
||||
if formatted_bookmark.strip():
|
||||
all_texts.append(formatted_bookmark)
|
||||
processed_count += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to format bookmark: {e}")
|
||||
continue
|
||||
|
||||
print(f"Successfully processed {processed_count} bookmarks")
|
||||
return all_texts
|
||||
|
||||
finally:
|
||||
await self.stop_mcp_server()
|
||||
|
||||
async def __aenter__(self):
|
||||
"""Async context manager entry."""
|
||||
await self.start_mcp_server()
|
||||
await self.initialize_mcp_connection()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Async context manager exit."""
|
||||
await self.stop_mcp_server()
|
||||
195
apps/twitter_rag.py
Normal file
@@ -0,0 +1,195 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Twitter RAG Application with MCP Support
|
||||
|
||||
This application enables RAG (Retrieval-Augmented Generation) on Twitter bookmarks
|
||||
by connecting to Twitter MCP servers to fetch live data and index it in LEANN.
|
||||
|
||||
Usage:
|
||||
python -m apps.twitter_rag --mcp-server "twitter-mcp-server" --query "What articles did I bookmark about AI?"
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
|
||||
from apps.base_rag_example import BaseRAGExample
|
||||
from apps.twitter_data.twitter_mcp_reader import TwitterMCPReader
|
||||
|
||||
|
||||
class TwitterMCPRAG(BaseRAGExample):
|
||||
"""
|
||||
RAG application for Twitter bookmarks via MCP servers.
|
||||
|
||||
This class provides a complete RAG pipeline for Twitter bookmark data, including
|
||||
MCP server connection, data fetching, indexing, and interactive chat.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="Twitter MCP RAG",
|
||||
description="RAG application for Twitter bookmarks via MCP servers",
|
||||
default_index_name="twitter_bookmarks",
|
||||
)
|
||||
|
||||
def _add_specific_arguments(self, parser: argparse.ArgumentParser):
|
||||
"""Add Twitter MCP-specific arguments."""
|
||||
parser.add_argument(
|
||||
"--mcp-server",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Command to start the Twitter MCP server (e.g., 'twitter-mcp-server' or 'npx twitter-mcp-server')",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--username", type=str, help="Twitter username to filter bookmarks (without @)"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-bookmarks",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Maximum number of bookmarks to fetch (default: 1000)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--no-tweet-content",
|
||||
action="store_true",
|
||||
help="Exclude tweet content, only include metadata",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--no-metadata",
|
||||
action="store_true",
|
||||
help="Exclude engagement metadata (likes, retweets, etc.)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--test-connection",
|
||||
action="store_true",
|
||||
help="Test MCP server connection and list available tools without indexing",
|
||||
)
|
||||
|
||||
async def test_mcp_connection(self, args) -> bool:
|
||||
"""Test the MCP server connection and display available tools."""
|
||||
print(f"Testing connection to MCP server: {args.mcp_server}")
|
||||
|
||||
try:
|
||||
reader = TwitterMCPReader(
|
||||
mcp_server_command=args.mcp_server,
|
||||
username=args.username,
|
||||
include_tweet_content=not args.no_tweet_content,
|
||||
include_metadata=not args.no_metadata,
|
||||
max_bookmarks=args.max_bookmarks,
|
||||
)
|
||||
|
||||
async with reader:
|
||||
tools = await reader.list_available_tools()
|
||||
|
||||
print("\n✅ Successfully connected to MCP server!")
|
||||
print(f"Available tools ({len(tools)}):")
|
||||
|
||||
for i, tool in enumerate(tools, 1):
|
||||
name = tool.get("name", "Unknown")
|
||||
description = tool.get("description", "No description available")
|
||||
print(f"\n{i}. {name}")
|
||||
print(
|
||||
f" Description: {description[:100]}{'...' if len(description) > 100 else ''}"
|
||||
)
|
||||
|
||||
# Show input schema if available
|
||||
schema = tool.get("inputSchema", {})
|
||||
if schema.get("properties"):
|
||||
props = list(schema["properties"].keys())[:3] # Show first 3 properties
|
||||
print(
|
||||
f" Parameters: {', '.join(props)}{'...' if len(schema['properties']) > 3 else ''}"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n❌ Failed to connect to MCP server: {e}")
|
||||
print("\nTroubleshooting tips:")
|
||||
print("1. Make sure the Twitter MCP server is installed and accessible")
|
||||
print("2. Check if the server command is correct")
|
||||
print("3. Ensure you have proper Twitter API credentials configured")
|
||||
print("4. Verify your Twitter account has bookmarks to fetch")
|
||||
print("5. Try running the MCP server command directly to test it")
|
||||
return False
|
||||
|
||||
async def load_data(self, args) -> list[str]:
|
||||
"""Load Twitter bookmarks via MCP server."""
|
||||
print(f"Connecting to Twitter MCP server: {args.mcp_server}")
|
||||
|
||||
if args.username:
|
||||
print(f"Username filter: @{args.username}")
|
||||
|
||||
print(f"Max bookmarks: {args.max_bookmarks}")
|
||||
print(f"Include tweet content: {not args.no_tweet_content}")
|
||||
print(f"Include metadata: {not args.no_metadata}")
|
||||
|
||||
try:
|
||||
reader = TwitterMCPReader(
|
||||
mcp_server_command=args.mcp_server,
|
||||
username=args.username,
|
||||
include_tweet_content=not args.no_tweet_content,
|
||||
include_metadata=not args.no_metadata,
|
||||
max_bookmarks=args.max_bookmarks,
|
||||
)
|
||||
|
||||
texts = await reader.read_twitter_bookmarks()
|
||||
|
||||
if not texts:
|
||||
print("❌ No bookmarks found! This could mean:")
|
||||
print("- You don't have any bookmarks on Twitter")
|
||||
print("- The MCP server couldn't access your bookmarks")
|
||||
print("- Authentication issues with Twitter API")
|
||||
print("- The username filter didn't match any bookmarks")
|
||||
return []
|
||||
|
||||
print(f"✅ Successfully loaded {len(texts)} bookmarks from Twitter")
|
||||
|
||||
# Show sample of what was loaded
|
||||
if texts:
|
||||
sample_text = texts[0][:300] + "..." if len(texts[0]) > 300 else texts[0]
|
||||
print("\nSample bookmark:")
|
||||
print("-" * 50)
|
||||
print(sample_text)
|
||||
print("-" * 50)
|
||||
|
||||
return texts
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error loading Twitter bookmarks: {e}")
|
||||
print("\nThis might be due to:")
|
||||
print("- MCP server connection issues")
|
||||
print("- Twitter API authentication problems")
|
||||
print("- Network connectivity issues")
|
||||
print("- Rate limiting from Twitter API")
|
||||
raise
|
||||
|
||||
async def run(self):
|
||||
"""Main entry point with MCP connection testing."""
|
||||
args = self.parser.parse_args()
|
||||
|
||||
# Test connection if requested
|
||||
if args.test_connection:
|
||||
success = await self.test_mcp_connection(args)
|
||||
if not success:
|
||||
return
|
||||
print(
|
||||
"\n🎉 MCP server is working! You can now run without --test-connection to start indexing."
|
||||
)
|
||||
return
|
||||
|
||||
# Run the standard RAG pipeline
|
||||
await super().run()
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main entry point for the Twitter MCP RAG application."""
|
||||
app = TwitterMCPRAG()
|
||||
await app.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
0
benchmarks/__init__.py
Normal file
23
benchmarks/bm25_diskann_baselines/README.md
Normal file
@@ -0,0 +1,23 @@
|
||||
BM25 vs DiskANN Baselines
|
||||
|
||||
```bash
|
||||
aws s3 sync s3://powerrag-diskann-rpj-wiki-20250824-224037-194d640c/bm25_rpj_wiki/index_en_only/ benchmarks/data/indices/bm25_index/
|
||||
aws s3 sync s3://powerrag-diskann-rpj-wiki-20250824-224037-194d640c/diskann_rpj_wiki/ benchmarks/data/indices/diskann_rpj_wiki/
|
||||
```
|
||||
|
||||
- Dataset: `benchmarks/data/queries/nq_open.jsonl` (Natural Questions)
|
||||
- Machine-specific; results measured locally with the current repo.
|
||||
|
||||
DiskANN (NQ queries, search-only)
|
||||
- Command: `uv run --script benchmarks/bm25_diskann_baselines/run_diskann.py`
|
||||
- Settings: `recompute_embeddings=False`, embeddings precomputed (excluded from timing), batching off, caching off (`cache_mechanism=2`, `num_nodes_to_cache=0`)
|
||||
- Result: avg 0.011093 s/query, QPS 90.15 (p50 0.010731 s, p95 0.015000 s)
|
||||
|
||||
BM25
|
||||
- Command: `uv run --script benchmarks/bm25_diskann_baselines/run_bm25.py`
|
||||
- Settings: `k=10`, `k1=0.9`, `b=0.4`, queries=100
|
||||
- Result: avg 0.028589 s/query, QPS 34.97 (p50 0.026060 s, p90 0.043695 s, p95 0.053260 s, p99 0.055257 s)
|
||||
|
||||
Notes
|
||||
- DiskANN measures search-only latency on real NQ queries (embeddings computed beforehand and excluded from timing).
|
||||
- Use `benchmarks/bm25_diskann_baselines/run_diskann.py` for DiskANN; `benchmarks/bm25_diskann_baselines/run_bm25.py` for BM25.
|
||||
|
After Width: | Height: | Size: 1.3 KiB |
183
benchmarks/bm25_diskann_baselines/run_bm25.py
Normal file
@@ -0,0 +1,183 @@
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "pyserini"
|
||||
# ]
|
||||
# ///
|
||||
# sudo pacman -S jdk21-openjdk
|
||||
# export JAVA_HOME=/usr/lib/jvm/java-21-openjdk
|
||||
# sudo archlinux-java status
|
||||
# sudo archlinux-java set java-21-openjdk
|
||||
# set -Ux JAVA_HOME /usr/lib/jvm/java-21-openjdk
|
||||
# fish_add_path --global $JAVA_HOME/bin
|
||||
# set -Ux LD_LIBRARY_PATH $JAVA_HOME/lib/server $LD_LIBRARY_PATH
|
||||
# which javac # Should be /usr/lib/jvm/java-21-openjdk/bin/javac
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from statistics import mean
|
||||
|
||||
|
||||
def load_queries(path: str, limit: int | None) -> list[str]:
|
||||
queries: list[str] = []
|
||||
# Try JSONL with a 'query' or 'text' field; fallback to plain text (one query per line)
|
||||
_, ext = os.path.splitext(path)
|
||||
if ext.lower() in {".jsonl", ".json"}:
|
||||
with open(path, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
obj = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
# Not strict JSONL? treat the whole line as the query
|
||||
queries.append(line)
|
||||
continue
|
||||
q = obj.get("query") or obj.get("text") or obj.get("question")
|
||||
if q:
|
||||
queries.append(str(q))
|
||||
else:
|
||||
with open(path, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
s = line.strip()
|
||||
if s:
|
||||
queries.append(s)
|
||||
|
||||
if limit is not None and limit > 0:
|
||||
queries = queries[:limit]
|
||||
return queries
|
||||
|
||||
|
||||
def percentile(values: list[float], p: float) -> float:
|
||||
if not values:
|
||||
return 0.0
|
||||
s = sorted(values)
|
||||
k = (len(s) - 1) * (p / 100.0)
|
||||
f = int(k)
|
||||
c = min(f + 1, len(s) - 1)
|
||||
if f == c:
|
||||
return s[f]
|
||||
return s[f] + (s[c] - s[f]) * (k - f)
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser(description="Standalone BM25 latency benchmark (Pyserini)")
|
||||
ap.add_argument(
|
||||
"--bm25-index",
|
||||
default="benchmarks/data/indices/bm25_index",
|
||||
help="Path to Pyserini Lucene index directory",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--queries",
|
||||
default="benchmarks/data/queries/nq_open.jsonl",
|
||||
help="Path to queries file (JSONL with 'query'/'text' or plain txt one-per-line)",
|
||||
)
|
||||
ap.add_argument("--k", type=int, default=10, help="Top-k to retrieve (default: 10)")
|
||||
ap.add_argument("--k1", type=float, default=0.9, help="BM25 k1 (default: 0.9)")
|
||||
ap.add_argument("--b", type=float, default=0.4, help="BM25 b (default: 0.4)")
|
||||
ap.add_argument("--limit", type=int, default=100, help="Max queries to run (default: 100)")
|
||||
ap.add_argument(
|
||||
"--warmup", type=int, default=5, help="Warmup queries not counted in latency (default: 5)"
|
||||
)
|
||||
ap.add_argument(
|
||||
"--fetch-docs", action="store_true", help="Also fetch doc contents (slower; default: off)"
|
||||
)
|
||||
ap.add_argument("--report", type=str, default=None, help="Optional JSON report path")
|
||||
args = ap.parse_args()
|
||||
|
||||
try:
|
||||
from pyserini.search.lucene import LuceneSearcher
|
||||
except Exception:
|
||||
print("Pyserini not found. Install with: pip install pyserini", file=sys.stderr)
|
||||
raise
|
||||
|
||||
if not os.path.isdir(args.bm25_index):
|
||||
print(f"Index directory not found: {args.bm25_index}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
queries = load_queries(args.queries, args.limit)
|
||||
if not queries:
|
||||
print("No queries loaded.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
print(f"Loaded {len(queries)} queries from {args.queries}")
|
||||
print(f"Opening BM25 index: {args.bm25_index}")
|
||||
searcher = LuceneSearcher(args.bm25_index)
|
||||
# Some builds of pyserini require explicit set_bm25; others ignore
|
||||
try:
|
||||
searcher.set_bm25(k1=args.k1, b=args.b)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
latencies: list[float] = []
|
||||
total_searches = 0
|
||||
|
||||
# Warmup
|
||||
for i in range(min(args.warmup, len(queries))):
|
||||
_ = searcher.search(queries[i], k=args.k)
|
||||
|
||||
t0 = time.time()
|
||||
for i, q in enumerate(queries):
|
||||
t1 = time.time()
|
||||
hits = searcher.search(q, k=args.k)
|
||||
t2 = time.time()
|
||||
latencies.append(t2 - t1)
|
||||
total_searches += 1
|
||||
|
||||
if args.fetch_docs:
|
||||
# Optional doc fetch to include I/O time
|
||||
for h in hits:
|
||||
try:
|
||||
_ = searcher.doc(h.docid)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if (i + 1) % 50 == 0:
|
||||
print(f"Processed {i + 1}/{len(queries)} queries")
|
||||
|
||||
t1 = time.time()
|
||||
total_time = t1 - t0
|
||||
|
||||
if latencies:
|
||||
avg = mean(latencies)
|
||||
p50 = percentile(latencies, 50)
|
||||
p90 = percentile(latencies, 90)
|
||||
p95 = percentile(latencies, 95)
|
||||
p99 = percentile(latencies, 99)
|
||||
qps = total_searches / total_time if total_time > 0 else 0.0
|
||||
else:
|
||||
avg = p50 = p90 = p95 = p99 = qps = 0.0
|
||||
|
||||
print("BM25 Latency Report")
|
||||
print(f" queries: {total_searches}")
|
||||
print(f" k: {args.k}, k1: {args.k1}, b: {args.b}")
|
||||
print(f" avg per query: {avg:.6f} s")
|
||||
print(f" p50/p90/p95/p99: {p50:.6f}/{p90:.6f}/{p95:.6f}/{p99:.6f} s")
|
||||
print(f" total time: {total_time:.3f} s, qps: {qps:.2f}")
|
||||
|
||||
if args.report:
|
||||
payload = {
|
||||
"queries": total_searches,
|
||||
"k": args.k,
|
||||
"k1": args.k1,
|
||||
"b": args.b,
|
||||
"avg_s": avg,
|
||||
"p50_s": p50,
|
||||
"p90_s": p90,
|
||||
"p95_s": p95,
|
||||
"p99_s": p99,
|
||||
"total_time_s": total_time,
|
||||
"qps": qps,
|
||||
"index_dir": os.path.abspath(args.bm25_index),
|
||||
"fetch_docs": bool(args.fetch_docs),
|
||||
}
|
||||
with open(args.report, "w", encoding="utf-8") as f:
|
||||
json.dump(payload, f, indent=2)
|
||||
print(f"Saved report to {args.report}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
124
benchmarks/bm25_diskann_baselines/run_diskann.py
Normal file
@@ -0,0 +1,124 @@
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "leann-backend-diskann"
|
||||
# ]
|
||||
# ///
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def load_queries(path: Path, limit: int | None) -> list[str]:
|
||||
out: list[str] = []
|
||||
with open(path, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
obj = json.loads(line)
|
||||
out.append(obj["query"])
|
||||
if limit and len(out) >= limit:
|
||||
break
|
||||
return out
|
||||
|
||||
|
||||
def main() -> None:
|
||||
ap = argparse.ArgumentParser(
|
||||
description="DiskANN baseline on real NQ queries (search-only timing)"
|
||||
)
|
||||
ap.add_argument(
|
||||
"--index-dir",
|
||||
default="benchmarks/data/indices/diskann_rpj_wiki",
|
||||
help="Directory containing DiskANN files",
|
||||
)
|
||||
ap.add_argument("--index-prefix", default="ann")
|
||||
ap.add_argument("--queries-file", default="benchmarks/data/queries/nq_open.jsonl")
|
||||
ap.add_argument("--num-queries", type=int, default=200)
|
||||
ap.add_argument("--top-k", type=int, default=10)
|
||||
ap.add_argument("--complexity", type=int, default=62)
|
||||
ap.add_argument("--threads", type=int, default=1)
|
||||
ap.add_argument("--beam-width", type=int, default=1)
|
||||
ap.add_argument("--cache-mechanism", type=int, default=2)
|
||||
ap.add_argument("--num-nodes-to-cache", type=int, default=0)
|
||||
args = ap.parse_args()
|
||||
|
||||
index_dir = Path(args.index_dir).resolve()
|
||||
if not index_dir.is_dir():
|
||||
raise SystemExit(f"Index dir not found: {index_dir}")
|
||||
|
||||
qpath = Path(args.queries_file).resolve()
|
||||
if not qpath.exists():
|
||||
raise SystemExit(f"Queries file not found: {qpath}")
|
||||
|
||||
queries = load_queries(qpath, args.num_queries)
|
||||
print(f"Loaded {len(queries)} queries from {qpath}")
|
||||
|
||||
# Compute embeddings once (exclude from timing)
|
||||
from leann.api import compute_embeddings as _compute
|
||||
|
||||
embs = _compute(
|
||||
queries,
|
||||
model_name="facebook/contriever-msmarco",
|
||||
mode="sentence-transformers",
|
||||
use_server=False,
|
||||
).astype(np.float32)
|
||||
if embs.ndim != 2:
|
||||
raise SystemExit("Embedding compute failed or returned wrong shape")
|
||||
|
||||
# Build searcher
|
||||
from leann_backend_diskann.diskann_backend import DiskannSearcher as _DiskannSearcher
|
||||
|
||||
index_prefix_path = str(index_dir / args.index_prefix)
|
||||
searcher = _DiskannSearcher(
|
||||
index_prefix_path,
|
||||
num_threads=int(args.threads),
|
||||
cache_mechanism=int(args.cache_mechanism),
|
||||
num_nodes_to_cache=int(args.num_nodes_to_cache),
|
||||
)
|
||||
|
||||
# Warmup (not timed)
|
||||
_ = searcher.search(
|
||||
embs[0:1],
|
||||
top_k=args.top_k,
|
||||
complexity=args.complexity,
|
||||
beam_width=args.beam_width,
|
||||
prune_ratio=0.0,
|
||||
recompute_embeddings=False,
|
||||
batch_recompute=False,
|
||||
dedup_node_dis=False,
|
||||
)
|
||||
|
||||
# Timed loop
|
||||
times: list[float] = []
|
||||
for i in range(embs.shape[0]):
|
||||
t0 = time.time()
|
||||
_ = searcher.search(
|
||||
embs[i : i + 1],
|
||||
top_k=args.top_k,
|
||||
complexity=args.complexity,
|
||||
beam_width=args.beam_width,
|
||||
prune_ratio=0.0,
|
||||
recompute_embeddings=False,
|
||||
batch_recompute=False,
|
||||
dedup_node_dis=False,
|
||||
)
|
||||
times.append(time.time() - t0)
|
||||
|
||||
times_sorted = sorted(times)
|
||||
avg = float(sum(times) / len(times))
|
||||
p50 = times_sorted[len(times) // 2]
|
||||
p95 = times_sorted[max(0, int(len(times) * 0.95) - 1)]
|
||||
|
||||
print("\nDiskANN (NQ, search-only) Report")
|
||||
print(f" queries: {len(times)}")
|
||||
print(
|
||||
f" k: {args.top_k}, complexity: {args.complexity}, beam_width: {args.beam_width}, threads: {args.threads}"
|
||||
)
|
||||
print(f" avg per query: {avg:.6f} s")
|
||||
print(f" p50/p95: {p50:.6f}/{p95:.6f} s")
|
||||
print(f" QPS: {1.0 / avg:.2f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
141
benchmarks/enron_emails/README.md
Normal file
@@ -0,0 +1,141 @@
|
||||
# Enron Emails Benchmark
|
||||
|
||||
A comprehensive RAG benchmark for evaluating LEANN search and generation on the Enron email corpus. It mirrors the structure and CLI of the existing FinanceBench and LAION benches, using stage-based evaluation with Recall@3 and generation timing.
|
||||
|
||||
- Dataset: Enron email CSV (e.g., Kaggle wcukierski/enron-email-dataset) for passages
|
||||
- Queries: corbt/enron_emails_sample_questions (filtered for realistic questions)
|
||||
- Metrics: Recall@3 vs FAISS Flat baseline + Generation evaluation with Qwen3-8B
|
||||
|
||||
## Layout
|
||||
|
||||
benchmarks/enron_emails/
|
||||
- setup_enron_emails.py: Prepare passages, build LEANN index, build FAISS baseline
|
||||
- evaluate_enron_emails.py: Evaluate retrieval recall (Stages 2-5) + generation with Qwen3-8B
|
||||
- data/: Generated passages, queries, embeddings-related files
|
||||
- baseline/: FAISS Flat baseline files
|
||||
- llm_utils.py: LLM utilities for Qwen3-8B generation (in parent directory)
|
||||
|
||||
## Quickstart
|
||||
|
||||
1) Prepare the data and index
|
||||
|
||||
cd benchmarks/enron_emails
|
||||
python setup_enron_emails.py --data-dir data
|
||||
|
||||
Notes:
|
||||
- If `--emails-csv` is omitted, the script attempts to download from Kaggle dataset `wcukierski/enron-email-dataset` using Kaggle API (requires `KAGGLE_USERNAME` and `KAGGLE_KEY`).
|
||||
Alternatively, pass a local path to `--emails-csv`.
|
||||
|
||||
Notes:
|
||||
- The script parses emails, chunks header/body into passages, builds a compact LEANN index, and then builds a FAISS Flat baseline from the same passages and embedding model.
|
||||
- Optionally, it will also create evaluation queries from HuggingFace dataset `corbt/enron_emails_sample_questions`.
|
||||
|
||||
2) Run recall evaluation (Stage 2)
|
||||
|
||||
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 2
|
||||
|
||||
3) Complexity sweep (Stage 3)
|
||||
|
||||
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 3 --target-recall 0.90 --max-queries 200
|
||||
|
||||
Stage 3 uses binary search over complexity to find the minimal value achieving the target Recall@3 (assumes recall is non-decreasing with complexity). The search expands the upper bound as needed and snaps complexity to multiples of 8.
|
||||
|
||||
4) Index comparison (Stage 4)
|
||||
|
||||
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 4 --complexity 88 --max-queries 100 --output results.json
|
||||
|
||||
5) Generation evaluation (Stage 5)
|
||||
|
||||
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 5 --complexity 88 --llm-backend hf --model-name Qwen/Qwen3-8B
|
||||
|
||||
6) Combined index + generation evaluation (Stages 4+5, recommended)
|
||||
|
||||
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 45 --complexity 88 --llm-backend hf
|
||||
|
||||
Notes:
|
||||
- Minimal CLI: you can run from repo root with only `--index`, defaults match financebench/laion patterns:
|
||||
- `--stage` defaults to `all` (runs 2, 3, 4, 5)
|
||||
- `--baseline-dir` defaults to `baseline`
|
||||
- `--queries` defaults to `data/evaluation_queries.jsonl` (or falls back to the index directory)
|
||||
- `--llm-backend` defaults to `hf` (HuggingFace), can use `vllm`
|
||||
- `--model-name` defaults to `Qwen/Qwen3-8B`
|
||||
- Fail-fast behavior: no silent fallbacks. If compact index cannot run with recompute, it errors out.
|
||||
- Stage 5 requires Stage 4 retrieval results. Use `--stage 45` to run both efficiently.
|
||||
|
||||
Optional flags:
|
||||
- --queries data/evaluation_queries.jsonl (custom queries file)
|
||||
- --baseline-dir baseline (where FAISS baseline lives)
|
||||
- --complexity 88 (LEANN complexity parameter, optimal for 90% recall)
|
||||
- --llm-backend hf|vllm (LLM backend for generation)
|
||||
- --model-name Qwen/Qwen3-8B (LLM model for generation)
|
||||
- --max-queries 1000 (limit number of queries for evaluation)
|
||||
|
||||
## Files Produced
|
||||
- data/enron_passages_preview.jsonl: Small preview of passages used (for inspection)
|
||||
- data/enron_index_hnsw.leann.*: LEANN index files
|
||||
- baseline/faiss_flat.index + baseline/metadata.pkl: FAISS baseline with passage IDs
|
||||
- data/evaluation_queries.jsonl: Query file (id + query; includes GT IDs for reference)
|
||||
|
||||
## Notes
|
||||
- Evaluates both retrieval Recall@3 and generation timing with Qwen3-8B thinking model.
|
||||
- The emails CSV must contain a column named "message" (raw RFC822 email) and a column named "file" for source identifier. Message-ID headers are parsed as canonical message IDs when present.
|
||||
- Qwen3-8B requires special handling for thinking models with chat templates and <think></think> tag processing.
|
||||
|
||||
## Stages Summary
|
||||
|
||||
- Stage 2 (Recall@3):
|
||||
- Compares LEANN vs FAISS Flat baseline on Recall@3.
|
||||
- Compact index runs with `recompute_embeddings=True`.
|
||||
|
||||
- Stage 3 (Binary Search for Complexity):
|
||||
- Builds a non-compact index (`<index>_noncompact.leann`) and runs binary search with `recompute_embeddings=False` to find the minimal complexity achieving target Recall@3 (default 90%).
|
||||
|
||||
- Stage 4 (Index Comparison):
|
||||
- Reports .index-only sizes for compact vs non-compact.
|
||||
- Measures timings on queries by default: non-compact (no recompute) vs compact (with recompute).
|
||||
- Stores retrieval results for Stage 5 generation evaluation.
|
||||
- Fails fast if compact recompute cannot run.
|
||||
- If `--complexity` is not provided, the script tries to use the best complexity from Stage 3:
|
||||
- First from the current run (when running `--stage all`), otherwise
|
||||
- From `enron_stage3_results.json` saved next to the index during the last Stage 3 run.
|
||||
- If neither exists, Stage 4 will error and ask you to run Stage 3 or pass `--complexity`.
|
||||
|
||||
- Stage 5 (Generation Evaluation):
|
||||
- Uses Qwen3-8B thinking model for RAG generation on retrieved documents from Stage 4.
|
||||
- Supports HuggingFace (`hf`) and vLLM (`vllm`) backends.
|
||||
- Measures generation timing separately from search timing.
|
||||
- Requires Stage 4 results (no additional searching performed).
|
||||
|
||||
## Example Results
|
||||
|
||||
These are sample results obtained on Enron data using all-mpnet-base-v2 and Qwen3-8B.
|
||||
|
||||
- Stage 3 (Binary Search):
|
||||
- Minimal complexity achieving 90% Recall@3: 88
|
||||
- Sampled points:
|
||||
- C=8 → 59.9% Recall@3
|
||||
- C=72 → 89.4% Recall@3
|
||||
- C=88 → 90.2% Recall@3
|
||||
- C=96 → 90.7% Recall@3
|
||||
- C=112 → 91.1% Recall@3
|
||||
- C=136 → 91.3% Recall@3
|
||||
- C=256 → 92.0% Recall@3
|
||||
|
||||
- Stage 4 (Index Sizes, .index only):
|
||||
- Compact: ~2.2 MB
|
||||
- Non-compact: ~82.0 MB
|
||||
- Storage saving by compact: ~97.3%
|
||||
|
||||
- Stage 4 (Search Timing, 988 queries, complexity=88):
|
||||
- Non-compact (no recompute): ~0.0075 s avg per query
|
||||
- Compact (with recompute): ~1.981 s avg per query
|
||||
- Speed ratio (non-compact/compact): ~0.0038x
|
||||
|
||||
- Stage 5 (RAG Generation, 988 queries, Qwen3-8B):
|
||||
- Average generation time: ~22.302 s per query
|
||||
- Total queries processed: 988
|
||||
- LLM backend: HuggingFace transformers
|
||||
- Model: Qwen/Qwen3-8B (thinking model with <think></think> processing)
|
||||
|
||||
Full JSON output is saved by the script (see `--output`), e.g.:
|
||||
`benchmarks/enron_emails/results_enron_stage45.json`.
|
||||
1
benchmarks/enron_emails/data/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
downloads/
|
||||
614
benchmarks/enron_emails/evaluate_enron_emails.py
Normal file
@@ -0,0 +1,614 @@
|
||||
"""
|
||||
Enron Emails Benchmark Evaluation - Retrieval Recall@3 (Stages 2/3/4)
|
||||
Follows the style of FinanceBench/LAION: Stage 2 recall vs FAISS baseline,
|
||||
Stage 3 complexity sweep to target recall, Stage 4 index comparison.
|
||||
On errors, fail fast without fallbacks.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from leann import LeannBuilder, LeannSearcher
|
||||
from leann_backend_hnsw import faiss
|
||||
|
||||
from ..llm_utils import generate_hf, generate_vllm, load_hf_model, load_vllm_model
|
||||
|
||||
# Setup logging to reduce verbose output
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
logging.getLogger("leann.api").setLevel(logging.WARNING)
|
||||
logging.getLogger("leann_backend_hnsw").setLevel(logging.WARNING)
|
||||
|
||||
|
||||
class RecallEvaluator:
|
||||
"""Stage 2: Evaluate Recall@3 (LEANN vs FAISS)"""
|
||||
|
||||
def __init__(self, index_path: str, baseline_dir: str):
|
||||
self.index_path = index_path
|
||||
self.baseline_dir = baseline_dir
|
||||
self.searcher = LeannSearcher(index_path)
|
||||
|
||||
baseline_index_path = os.path.join(baseline_dir, "faiss_flat.index")
|
||||
metadata_path = os.path.join(baseline_dir, "metadata.pkl")
|
||||
|
||||
self.faiss_index = faiss.read_index(baseline_index_path)
|
||||
with open(metadata_path, "rb") as f:
|
||||
self.passage_ids = pickle.load(f)
|
||||
|
||||
print(f"📚 Loaded FAISS flat baseline with {self.faiss_index.ntotal} vectors")
|
||||
|
||||
# No fallbacks here; if embedding server is needed but fails, the caller will see the error.
|
||||
|
||||
def evaluate_recall_at_3(
|
||||
self, queries: list[str], complexity: int = 64, recompute_embeddings: bool = True
|
||||
) -> float:
|
||||
"""Evaluate recall@3 using FAISS Flat as ground truth"""
|
||||
from leann.api import compute_embeddings
|
||||
|
||||
recompute_str = "with recompute" if recompute_embeddings else "no recompute"
|
||||
print(f"🔍 Evaluating recall@3 with complexity={complexity} ({recompute_str})...")
|
||||
|
||||
total_recall = 0.0
|
||||
for i, query in enumerate(queries):
|
||||
# Compute query embedding with the same model/mode as the index
|
||||
q_emb = compute_embeddings(
|
||||
[query],
|
||||
self.searcher.embedding_model,
|
||||
mode=self.searcher.embedding_mode,
|
||||
use_server=False,
|
||||
).astype(np.float32)
|
||||
|
||||
# Search FAISS Flat ground truth
|
||||
n = q_emb.shape[0]
|
||||
k = 3
|
||||
distances = np.zeros((n, k), dtype=np.float32)
|
||||
labels = np.zeros((n, k), dtype=np.int64)
|
||||
self.faiss_index.search(
|
||||
n,
|
||||
faiss.swig_ptr(q_emb),
|
||||
k,
|
||||
faiss.swig_ptr(distances),
|
||||
faiss.swig_ptr(labels),
|
||||
)
|
||||
|
||||
baseline_ids = {self.passage_ids[idx] for idx in labels[0]}
|
||||
|
||||
# Search with LEANN (may require embedding server depending on index configuration)
|
||||
results = self.searcher.search(
|
||||
query,
|
||||
top_k=3,
|
||||
complexity=complexity,
|
||||
recompute_embeddings=recompute_embeddings,
|
||||
)
|
||||
test_ids = {r.id for r in results}
|
||||
|
||||
intersection = test_ids.intersection(baseline_ids)
|
||||
recall = len(intersection) / 3.0
|
||||
total_recall += recall
|
||||
|
||||
if i < 3:
|
||||
print(f" Q{i + 1}: '{query[:60]}...' -> Recall@3: {recall:.3f}")
|
||||
print(f" FAISS: {list(baseline_ids)}")
|
||||
print(f" LEANN: {list(test_ids)}")
|
||||
print(f" ∩: {list(intersection)}")
|
||||
|
||||
avg = total_recall / max(1, len(queries))
|
||||
print(f"📊 Average Recall@3: {avg:.3f} ({avg * 100:.1f}%)")
|
||||
return avg
|
||||
|
||||
def cleanup(self):
|
||||
if hasattr(self, "searcher"):
|
||||
self.searcher.cleanup()
|
||||
|
||||
|
||||
class EnronEvaluator:
|
||||
def __init__(self, index_path: str):
|
||||
self.index_path = index_path
|
||||
self.searcher = LeannSearcher(index_path)
|
||||
|
||||
def load_queries(self, queries_file: str) -> list[str]:
|
||||
queries: list[str] = []
|
||||
with open(queries_file, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if not line.strip():
|
||||
continue
|
||||
data = json.loads(line)
|
||||
if "query" in data:
|
||||
queries.append(data["query"])
|
||||
print(f"📊 Loaded {len(queries)} queries from {queries_file}")
|
||||
return queries
|
||||
|
||||
def cleanup(self):
|
||||
if self.searcher:
|
||||
self.searcher.cleanup()
|
||||
|
||||
def analyze_index_sizes(self) -> dict:
|
||||
"""Analyze index sizes (.index only), similar to LAION bench."""
|
||||
|
||||
print("📏 Analyzing index sizes (.index only)...")
|
||||
index_path = Path(self.index_path)
|
||||
index_dir = index_path.parent
|
||||
index_name = index_path.stem
|
||||
|
||||
sizes: dict[str, float] = {}
|
||||
index_file = index_dir / f"{index_name}.index"
|
||||
meta_file = index_dir / f"{index_path.name}.meta.json"
|
||||
passages_file = index_dir / f"{index_path.name}.passages.jsonl"
|
||||
passages_idx_file = index_dir / f"{index_path.name}.passages.idx"
|
||||
|
||||
sizes["index_only_mb"] = (
|
||||
index_file.stat().st_size / (1024 * 1024) if index_file.exists() else 0.0
|
||||
)
|
||||
sizes["metadata_mb"] = (
|
||||
meta_file.stat().st_size / (1024 * 1024) if meta_file.exists() else 0.0
|
||||
)
|
||||
sizes["passages_text_mb"] = (
|
||||
passages_file.stat().st_size / (1024 * 1024) if passages_file.exists() else 0.0
|
||||
)
|
||||
sizes["passages_index_mb"] = (
|
||||
passages_idx_file.stat().st_size / (1024 * 1024) if passages_idx_file.exists() else 0.0
|
||||
)
|
||||
|
||||
print(f" 📁 .index size: {sizes['index_only_mb']:.1f} MB")
|
||||
return sizes
|
||||
|
||||
def create_non_compact_index_for_comparison(self, non_compact_index_path: str) -> dict:
|
||||
"""Create a non-compact index for comparison using current passages and embeddings."""
|
||||
|
||||
current_index_path = Path(self.index_path)
|
||||
current_index_dir = current_index_path.parent
|
||||
current_index_name = current_index_path.name
|
||||
|
||||
# Read metadata to get passage source and embedding model
|
||||
meta_path = current_index_dir / f"{current_index_name}.meta.json"
|
||||
with open(meta_path, encoding="utf-8") as f:
|
||||
meta = json.load(f)
|
||||
|
||||
passage_source = meta["passage_sources"][0]
|
||||
passage_file = passage_source["path"]
|
||||
|
||||
# Convert relative path to absolute
|
||||
if not Path(passage_file).is_absolute():
|
||||
passage_file = current_index_dir / Path(passage_file).name
|
||||
|
||||
# Load all passages and ids
|
||||
ids: list[str] = []
|
||||
texts: list[str] = []
|
||||
with open(passage_file, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
data = json.loads(line)
|
||||
ids.append(str(data["id"]))
|
||||
texts.append(data["text"])
|
||||
|
||||
# Compute embeddings using the same method as LEANN
|
||||
from leann.api import compute_embeddings
|
||||
|
||||
embeddings = compute_embeddings(
|
||||
texts,
|
||||
meta["embedding_model"],
|
||||
mode=meta.get("embedding_mode", "sentence-transformers"),
|
||||
use_server=False,
|
||||
).astype(np.float32)
|
||||
|
||||
# Build non-compact index with same passages and embeddings
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model=meta["embedding_model"],
|
||||
embedding_mode=meta.get("embedding_mode", "sentence-transformers"),
|
||||
is_recompute=False,
|
||||
is_compact=False,
|
||||
**{
|
||||
k: v
|
||||
for k, v in meta.get("backend_kwargs", {}).items()
|
||||
if k not in ["is_recompute", "is_compact"]
|
||||
},
|
||||
)
|
||||
|
||||
# Persist a pickle for build_index_from_embeddings
|
||||
pkl_path = current_index_dir / f"{Path(non_compact_index_path).stem}_embeddings.pkl"
|
||||
with open(pkl_path, "wb") as pf:
|
||||
pickle.dump((ids, embeddings), pf)
|
||||
|
||||
print(
|
||||
f"🔨 Building non-compact index at {non_compact_index_path} from precomputed embeddings..."
|
||||
)
|
||||
builder.build_index_from_embeddings(non_compact_index_path, str(pkl_path))
|
||||
|
||||
# Analyze the non-compact index size
|
||||
temp_evaluator = EnronEvaluator(non_compact_index_path)
|
||||
non_compact_sizes = temp_evaluator.analyze_index_sizes()
|
||||
non_compact_sizes["index_type"] = "non_compact"
|
||||
|
||||
return non_compact_sizes
|
||||
|
||||
def compare_index_performance(
|
||||
self, non_compact_path: str, compact_path: str, test_queries: list[str], complexity: int
|
||||
) -> dict:
|
||||
"""Compare search speed for non-compact vs compact indexes."""
|
||||
import time
|
||||
|
||||
results: dict = {
|
||||
"non_compact": {"search_times": []},
|
||||
"compact": {"search_times": []},
|
||||
"avg_search_times": {},
|
||||
"speed_ratio": 0.0,
|
||||
"retrieval_results": [], # Store retrieval results for Stage 5
|
||||
}
|
||||
|
||||
print("⚡ Comparing search performance between indexes...")
|
||||
# Non-compact (no recompute)
|
||||
print(" 🔍 Testing non-compact index (no recompute)...")
|
||||
non_compact_searcher = LeannSearcher(non_compact_path)
|
||||
for q in test_queries:
|
||||
t0 = time.time()
|
||||
_ = non_compact_searcher.search(
|
||||
q, top_k=3, complexity=complexity, recompute_embeddings=False
|
||||
)
|
||||
results["non_compact"]["search_times"].append(time.time() - t0)
|
||||
|
||||
# Compact (with recompute). Fail fast if it cannot run.
|
||||
print(" 🔍 Testing compact index (with recompute)...")
|
||||
compact_searcher = LeannSearcher(compact_path)
|
||||
for q in test_queries:
|
||||
t0 = time.time()
|
||||
docs = compact_searcher.search(
|
||||
q, top_k=3, complexity=complexity, recompute_embeddings=True
|
||||
)
|
||||
results["compact"]["search_times"].append(time.time() - t0)
|
||||
|
||||
# Store retrieval results for Stage 5
|
||||
results["retrieval_results"].append(
|
||||
{"query": q, "retrieved_docs": [{"id": doc.id, "text": doc.text} for doc in docs]}
|
||||
)
|
||||
compact_searcher.cleanup()
|
||||
|
||||
if results["non_compact"]["search_times"]:
|
||||
results["avg_search_times"]["non_compact"] = sum(
|
||||
results["non_compact"]["search_times"]
|
||||
) / len(results["non_compact"]["search_times"])
|
||||
if results["compact"]["search_times"]:
|
||||
results["avg_search_times"]["compact"] = sum(results["compact"]["search_times"]) / len(
|
||||
results["compact"]["search_times"]
|
||||
)
|
||||
if results["avg_search_times"].get("compact", 0) > 0:
|
||||
results["speed_ratio"] = (
|
||||
results["avg_search_times"]["non_compact"] / results["avg_search_times"]["compact"]
|
||||
)
|
||||
else:
|
||||
results["speed_ratio"] = 0.0
|
||||
|
||||
non_compact_searcher.cleanup()
|
||||
return results
|
||||
|
||||
def evaluate_complexity(
|
||||
self,
|
||||
recall_eval: "RecallEvaluator",
|
||||
queries: list[str],
|
||||
target: float = 0.90,
|
||||
c_min: int = 8,
|
||||
c_max: int = 256,
|
||||
max_iters: int = 10,
|
||||
recompute: bool = False,
|
||||
) -> dict:
|
||||
"""Binary search minimal complexity achieving target recall (monotonic assumption)."""
|
||||
|
||||
def round_c(x: int) -> int:
|
||||
# snap to multiple of 8 like other benches typically do
|
||||
return max(1, int((x + 7) // 8) * 8)
|
||||
|
||||
metrics: list[dict] = []
|
||||
|
||||
lo = round_c(c_min)
|
||||
hi = round_c(c_max)
|
||||
|
||||
print(
|
||||
f"🧪 Binary search complexity in [{lo}, {hi}] for target Recall@3>={int(target * 100)}%..."
|
||||
)
|
||||
|
||||
# Ensure upper bound can reach target; expand if needed (up to a cap)
|
||||
r_lo = recall_eval.evaluate_recall_at_3(
|
||||
queries, complexity=lo, recompute_embeddings=recompute
|
||||
)
|
||||
metrics.append({"complexity": lo, "recall_at_3": r_lo})
|
||||
r_hi = recall_eval.evaluate_recall_at_3(
|
||||
queries, complexity=hi, recompute_embeddings=recompute
|
||||
)
|
||||
metrics.append({"complexity": hi, "recall_at_3": r_hi})
|
||||
|
||||
cap = 1024
|
||||
while r_hi < target and hi < cap:
|
||||
lo = hi
|
||||
r_lo = r_hi
|
||||
hi = round_c(hi * 2)
|
||||
r_hi = recall_eval.evaluate_recall_at_3(
|
||||
queries, complexity=hi, recompute_embeddings=recompute
|
||||
)
|
||||
metrics.append({"complexity": hi, "recall_at_3": r_hi})
|
||||
|
||||
if r_hi < target:
|
||||
print(f"⚠️ Max complexity {hi} did not reach target recall {target:.2f}.")
|
||||
print("📈 Observations:")
|
||||
for m in metrics:
|
||||
print(f" C={m['complexity']:>4} -> Recall@3={m['recall_at_3'] * 100:.1f}%")
|
||||
return {"metrics": metrics, "best_complexity": None, "target_recall": target}
|
||||
|
||||
# Binary search within [lo, hi]
|
||||
best = hi
|
||||
iters = 0
|
||||
while lo < hi and iters < max_iters:
|
||||
mid = round_c((lo + hi) // 2)
|
||||
r_mid = recall_eval.evaluate_recall_at_3(
|
||||
queries, complexity=mid, recompute_embeddings=recompute
|
||||
)
|
||||
metrics.append({"complexity": mid, "recall_at_3": r_mid})
|
||||
if r_mid >= target:
|
||||
best = mid
|
||||
hi = mid
|
||||
else:
|
||||
lo = mid + 8 # move past mid, respecting multiple-of-8 step
|
||||
iters += 1
|
||||
|
||||
print("📈 Binary search results (sampled points):")
|
||||
# Print unique complexity entries ordered by complexity
|
||||
for m in sorted(
|
||||
{m["complexity"]: m for m in metrics}.values(), key=lambda x: x["complexity"]
|
||||
):
|
||||
print(f" C={m['complexity']:>4} -> Recall@3={m['recall_at_3'] * 100:.1f}%")
|
||||
print(f"✅ Minimal complexity achieving {int(target * 100)}% recall: {best}")
|
||||
return {"metrics": metrics, "best_complexity": best, "target_recall": target}
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Enron Emails Benchmark Evaluation")
|
||||
parser.add_argument("--index", required=True, help="Path to LEANN index")
|
||||
parser.add_argument(
|
||||
"--queries", default="data/evaluation_queries.jsonl", help="Path to evaluation queries"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--stage",
|
||||
choices=["2", "3", "4", "5", "all", "45"],
|
||||
default="all",
|
||||
help="Which stage to run (2=recall, 3=complexity, 4=index comparison, 5=generation)",
|
||||
)
|
||||
parser.add_argument("--complexity", type=int, default=None, help="LEANN search complexity")
|
||||
parser.add_argument("--baseline-dir", default="baseline", help="Baseline output directory")
|
||||
parser.add_argument(
|
||||
"--max-queries", type=int, help="Limit number of queries to evaluate", default=1000
|
||||
)
|
||||
parser.add_argument(
|
||||
"--target-recall", type=float, default=0.90, help="Target Recall@3 for Stage 3"
|
||||
)
|
||||
parser.add_argument("--output", help="Save results to JSON file")
|
||||
parser.add_argument("--llm-backend", choices=["hf", "vllm"], default="hf", help="LLM backend")
|
||||
parser.add_argument("--model-name", default="Qwen/Qwen3-8B", help="Model name")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Resolve queries file: if default path not found, fall back to index's directory
|
||||
if not os.path.exists(args.queries):
|
||||
from pathlib import Path
|
||||
|
||||
idx_dir = Path(args.index).parent
|
||||
fallback_q = idx_dir / "evaluation_queries.jsonl"
|
||||
if fallback_q.exists():
|
||||
args.queries = str(fallback_q)
|
||||
|
||||
baseline_index_path = os.path.join(args.baseline_dir, "faiss_flat.index")
|
||||
if not os.path.exists(baseline_index_path):
|
||||
print(f"❌ FAISS baseline not found at {baseline_index_path}")
|
||||
print("💡 Please run setup_enron_emails.py first to build the baseline")
|
||||
raise SystemExit(1)
|
||||
|
||||
results_out: dict = {}
|
||||
|
||||
if args.stage in ("2", "all"):
|
||||
print("🚀 Starting Stage 2: Recall@3 evaluation")
|
||||
evaluator = RecallEvaluator(args.index, args.baseline_dir)
|
||||
|
||||
enron_eval = EnronEvaluator(args.index)
|
||||
queries = enron_eval.load_queries(args.queries)
|
||||
queries = queries[:10]
|
||||
print(f"🧪 Using first {len(queries)} queries")
|
||||
|
||||
complexity = args.complexity or 64
|
||||
r = evaluator.evaluate_recall_at_3(queries, complexity)
|
||||
results_out["stage2"] = {"complexity": complexity, "recall_at_3": r}
|
||||
evaluator.cleanup()
|
||||
enron_eval.cleanup()
|
||||
print("✅ Stage 2 completed!\n")
|
||||
|
||||
if args.stage in ("3", "all"):
|
||||
print("🚀 Starting Stage 3: Binary search for target recall (no recompute)")
|
||||
enron_eval = EnronEvaluator(args.index)
|
||||
queries = enron_eval.load_queries(args.queries)
|
||||
queries = queries[: args.max_queries]
|
||||
print(f"🧪 Using first {len(queries)} queries")
|
||||
|
||||
# Build non-compact index for fast binary search (recompute_embeddings=False)
|
||||
from pathlib import Path
|
||||
|
||||
index_path = Path(args.index)
|
||||
non_compact_index_path = str(index_path.parent / f"{index_path.stem}_noncompact.leann")
|
||||
enron_eval.create_non_compact_index_for_comparison(non_compact_index_path)
|
||||
|
||||
# Use non-compact evaluator for binary search with recompute=False
|
||||
evaluator_nc = RecallEvaluator(non_compact_index_path, args.baseline_dir)
|
||||
sweep = enron_eval.evaluate_complexity(
|
||||
evaluator_nc, queries, target=args.target_recall, recompute=False
|
||||
)
|
||||
results_out["stage3"] = sweep
|
||||
# Persist default stage 3 results near the index for Stage 4 auto-pickup
|
||||
from pathlib import Path
|
||||
|
||||
default_stage3_path = Path(args.index).parent / "enron_stage3_results.json"
|
||||
with open(default_stage3_path, "w", encoding="utf-8") as f:
|
||||
json.dump({"stage3": sweep}, f, indent=2)
|
||||
print(f"📝 Saved Stage 3 summary to {default_stage3_path}")
|
||||
evaluator_nc.cleanup()
|
||||
enron_eval.cleanup()
|
||||
print("✅ Stage 3 completed!\n")
|
||||
|
||||
if args.stage in ("4", "all", "45"):
|
||||
print("🚀 Starting Stage 4: Index size + performance comparison")
|
||||
evaluator = RecallEvaluator(args.index, args.baseline_dir)
|
||||
enron_eval = EnronEvaluator(args.index)
|
||||
queries = enron_eval.load_queries(args.queries)
|
||||
test_q = queries[: min(args.max_queries, len(queries))]
|
||||
|
||||
current_sizes = enron_eval.analyze_index_sizes()
|
||||
# Build non-compact index for comparison (no fallback)
|
||||
from pathlib import Path
|
||||
|
||||
index_path = Path(args.index)
|
||||
non_compact_path = str(index_path.parent / f"{index_path.stem}_noncompact.leann")
|
||||
non_compact_sizes = enron_eval.create_non_compact_index_for_comparison(non_compact_path)
|
||||
nc_eval = EnronEvaluator(non_compact_path)
|
||||
|
||||
if (
|
||||
current_sizes.get("index_only_mb", 0) > 0
|
||||
and non_compact_sizes.get("index_only_mb", 0) > 0
|
||||
):
|
||||
storage_saving_percent = max(
|
||||
0.0,
|
||||
100.0 * (1.0 - current_sizes["index_only_mb"] / non_compact_sizes["index_only_mb"]),
|
||||
)
|
||||
else:
|
||||
storage_saving_percent = 0.0
|
||||
|
||||
if args.complexity is None:
|
||||
# Prefer in-session Stage 3 result
|
||||
if "stage3" in results_out and results_out["stage3"].get("best_complexity") is not None:
|
||||
complexity = results_out["stage3"]["best_complexity"]
|
||||
print(f"📥 Using best complexity from Stage 3 in-session: {complexity}")
|
||||
else:
|
||||
# Try to load last saved Stage 3 result near index
|
||||
default_stage3_path = Path(args.index).parent / "enron_stage3_results.json"
|
||||
if default_stage3_path.exists():
|
||||
with open(default_stage3_path, encoding="utf-8") as f:
|
||||
prev = json.load(f)
|
||||
complexity = prev.get("stage3", {}).get("best_complexity")
|
||||
if complexity is None:
|
||||
raise SystemExit(
|
||||
"❌ Stage 4: No --complexity and no best_complexity found in saved Stage 3 results"
|
||||
)
|
||||
print(f"📥 Using best complexity from saved Stage 3: {complexity}")
|
||||
else:
|
||||
raise SystemExit(
|
||||
"❌ Stage 4 requires --complexity if Stage 3 hasn't been run. Run stage 3 first or pass --complexity."
|
||||
)
|
||||
else:
|
||||
complexity = args.complexity
|
||||
|
||||
comp = enron_eval.compare_index_performance(
|
||||
non_compact_path, args.index, test_q, complexity=complexity
|
||||
)
|
||||
results_out["stage4"] = {
|
||||
"current_index": current_sizes,
|
||||
"non_compact_index": non_compact_sizes,
|
||||
"storage_saving_percent": storage_saving_percent,
|
||||
"performance_comparison": comp,
|
||||
}
|
||||
nc_eval.cleanup()
|
||||
evaluator.cleanup()
|
||||
enron_eval.cleanup()
|
||||
print("✅ Stage 4 completed!\n")
|
||||
|
||||
if args.stage in ("5", "all"):
|
||||
print("🚀 Starting Stage 5: Generation evaluation with Qwen3-8B")
|
||||
|
||||
# Check if Stage 4 results exist
|
||||
if "stage4" not in results_out or "performance_comparison" not in results_out["stage4"]:
|
||||
print("❌ Stage 5 requires Stage 4 retrieval results")
|
||||
print("💡 Run Stage 4 first or use --stage all")
|
||||
raise SystemExit(1)
|
||||
|
||||
retrieval_results = results_out["stage4"]["performance_comparison"]["retrieval_results"]
|
||||
if not retrieval_results:
|
||||
print("❌ No retrieval results found from Stage 4")
|
||||
raise SystemExit(1)
|
||||
|
||||
print(f"📁 Using {len(retrieval_results)} retrieval results from Stage 4")
|
||||
|
||||
# Load LLM
|
||||
try:
|
||||
if args.llm_backend == "hf":
|
||||
tokenizer, model = load_hf_model(args.model_name)
|
||||
|
||||
def llm_func(prompt):
|
||||
return generate_hf(tokenizer, model, prompt)
|
||||
else: # vllm
|
||||
llm, sampling_params = load_vllm_model(args.model_name)
|
||||
|
||||
def llm_func(prompt):
|
||||
return generate_vllm(llm, sampling_params, prompt)
|
||||
|
||||
# Run generation using stored retrieval results
|
||||
import time
|
||||
|
||||
from llm_utils import create_prompt
|
||||
|
||||
generation_times = []
|
||||
responses = []
|
||||
|
||||
print("🤖 Running generation on pre-retrieved results...")
|
||||
for i, item in enumerate(retrieval_results):
|
||||
query = item["query"]
|
||||
retrieved_docs = item["retrieved_docs"]
|
||||
|
||||
# Prepare context from retrieved docs
|
||||
context = "\n\n".join([doc["text"] for doc in retrieved_docs])
|
||||
prompt = create_prompt(context, query, "emails")
|
||||
|
||||
# Time generation only
|
||||
gen_start = time.time()
|
||||
response = llm_func(prompt)
|
||||
gen_time = time.time() - gen_start
|
||||
|
||||
generation_times.append(gen_time)
|
||||
responses.append(response)
|
||||
|
||||
if i < 3:
|
||||
print(f" Q{i + 1}: Gen={gen_time:.3f}s")
|
||||
|
||||
avg_gen_time = sum(generation_times) / len(generation_times)
|
||||
|
||||
print("\n📊 Generation Results:")
|
||||
print(f" Total Queries: {len(retrieval_results)}")
|
||||
print(f" Avg Generation Time: {avg_gen_time:.3f}s")
|
||||
print(" (Search time from Stage 4)")
|
||||
|
||||
results_out["stage5"] = {
|
||||
"total_queries": len(retrieval_results),
|
||||
"avg_generation_time": avg_gen_time,
|
||||
"generation_times": generation_times,
|
||||
"responses": responses,
|
||||
}
|
||||
|
||||
# Show sample results
|
||||
print("\n📝 Sample Results:")
|
||||
for i in range(min(3, len(retrieval_results))):
|
||||
query = retrieval_results[i]["query"]
|
||||
response = responses[i]
|
||||
print(f" Q{i + 1}: {query[:60]}...")
|
||||
print(f" A{i + 1}: {response[:100]}...")
|
||||
print()
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Generation evaluation failed: {e}")
|
||||
print("💡 Make sure transformers/vllm is installed and model is available")
|
||||
|
||||
print("✅ Stage 5 completed!\n")
|
||||
|
||||
if args.output and results_out:
|
||||
with open(args.output, "w", encoding="utf-8") as f:
|
||||
json.dump(results_out, f, indent=2)
|
||||
print(f"📝 Saved results to {args.output}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
359
benchmarks/enron_emails/setup_enron_emails.py
Normal file
@@ -0,0 +1,359 @@
|
||||
"""
|
||||
Enron Emails Benchmark Setup Script
|
||||
Prepares passages from emails.csv, builds LEANN index, and FAISS Flat baseline
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from collections.abc import Iterable
|
||||
from email import message_from_string
|
||||
from email.policy import default
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from leann import LeannBuilder
|
||||
|
||||
|
||||
class EnronSetup:
|
||||
def __init__(self, data_dir: str = "data"):
|
||||
self.data_dir = Path(data_dir)
|
||||
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.passages_preview = self.data_dir / "enron_passages_preview.jsonl"
|
||||
self.index_path = self.data_dir / "enron_index_hnsw.leann"
|
||||
self.queries_file = self.data_dir / "evaluation_queries.jsonl"
|
||||
self.downloads_dir = self.data_dir / "downloads"
|
||||
self.downloads_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# ----------------------------
|
||||
# Dataset acquisition
|
||||
# ----------------------------
|
||||
def ensure_emails_csv(self, emails_csv: Optional[str]) -> str:
|
||||
"""Return a path to emails.csv, downloading from Kaggle if needed."""
|
||||
if emails_csv:
|
||||
p = Path(emails_csv)
|
||||
if not p.exists():
|
||||
raise FileNotFoundError(f"emails.csv not found: {emails_csv}")
|
||||
return str(p)
|
||||
|
||||
print(
|
||||
"📥 Trying to download Enron emails.csv from Kaggle (wcukierski/enron-email-dataset)..."
|
||||
)
|
||||
try:
|
||||
from kaggle.api.kaggle_api_extended import KaggleApi
|
||||
|
||||
api = KaggleApi()
|
||||
api.authenticate()
|
||||
api.dataset_download_files(
|
||||
"wcukierski/enron-email-dataset", path=str(self.downloads_dir), unzip=True
|
||||
)
|
||||
candidate = self.downloads_dir / "emails.csv"
|
||||
if candidate.exists():
|
||||
print(f"✅ Downloaded emails.csv: {candidate}")
|
||||
return str(candidate)
|
||||
else:
|
||||
raise FileNotFoundError(
|
||||
f"emails.csv was not found in {self.downloads_dir} after Kaggle download"
|
||||
)
|
||||
except Exception as e:
|
||||
print(
|
||||
"❌ Could not download via Kaggle automatically. Provide --emails-csv or configure Kaggle API."
|
||||
)
|
||||
print(
|
||||
" Set KAGGLE_USERNAME and KAGGLE_KEY env vars, or place emails.csv locally and pass --emails-csv."
|
||||
)
|
||||
raise e
|
||||
|
||||
# ----------------------------
|
||||
# Data preparation
|
||||
# ----------------------------
|
||||
@staticmethod
|
||||
def _extract_message_id(raw_email: str) -> str:
|
||||
msg = message_from_string(raw_email, policy=default)
|
||||
val = msg.get("Message-ID", "")
|
||||
if val.startswith("<") and val.endswith(">"):
|
||||
val = val[1:-1]
|
||||
return val or ""
|
||||
|
||||
@staticmethod
|
||||
def _split_header_body(raw_email: str) -> tuple[str, str]:
|
||||
parts = raw_email.split("\n\n", 1)
|
||||
if len(parts) == 2:
|
||||
return parts[0].strip(), parts[1].strip()
|
||||
# Heuristic fallback
|
||||
first_lines = raw_email.splitlines()
|
||||
if first_lines and ":" in first_lines[0]:
|
||||
return raw_email.strip(), ""
|
||||
return "", raw_email.strip()
|
||||
|
||||
@staticmethod
|
||||
def _split_fixed_words(text: str, chunk_words: int, keep_last: bool) -> list[str]:
|
||||
text = (text or "").strip()
|
||||
if not text:
|
||||
return []
|
||||
if chunk_words <= 0:
|
||||
return [text]
|
||||
words = text.split()
|
||||
if not words:
|
||||
return []
|
||||
limit = len(words)
|
||||
if not keep_last:
|
||||
limit = (len(words) // chunk_words) * chunk_words
|
||||
if limit == 0:
|
||||
return []
|
||||
chunks = [" ".join(words[i : i + chunk_words]) for i in range(0, limit, chunk_words)]
|
||||
return [c for c in (s.strip() for s in chunks) if c]
|
||||
|
||||
def _iter_passages_from_csv(
|
||||
self,
|
||||
emails_csv: Path,
|
||||
chunk_words: int = 256,
|
||||
keep_last_header: bool = True,
|
||||
keep_last_body: bool = True,
|
||||
max_emails: int | None = None,
|
||||
) -> Iterable[dict]:
|
||||
with open(emails_csv, encoding="utf-8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
count = 0
|
||||
for i, row in enumerate(reader):
|
||||
if max_emails is not None and count >= max_emails:
|
||||
break
|
||||
|
||||
raw_message = row.get("message", "")
|
||||
email_file_id = row.get("file", "")
|
||||
|
||||
if not raw_message.strip():
|
||||
continue
|
||||
|
||||
message_id = self._extract_message_id(raw_message)
|
||||
if not message_id:
|
||||
# Fallback ID based on CSV position and file path
|
||||
safe_file = re.sub(r"[^A-Za-z0-9_.-]", "_", email_file_id)
|
||||
message_id = f"enron_{i}_{safe_file}"
|
||||
|
||||
header, body = self._split_header_body(raw_message)
|
||||
|
||||
# Header chunks
|
||||
for chunk in self._split_fixed_words(header, chunk_words, keep_last_header):
|
||||
yield {
|
||||
"text": chunk,
|
||||
"metadata": {
|
||||
"message_id": message_id,
|
||||
"is_header": True,
|
||||
"email_file_id": email_file_id,
|
||||
},
|
||||
}
|
||||
|
||||
# Body chunks
|
||||
for chunk in self._split_fixed_words(body, chunk_words, keep_last_body):
|
||||
yield {
|
||||
"text": chunk,
|
||||
"metadata": {
|
||||
"message_id": message_id,
|
||||
"is_header": False,
|
||||
"email_file_id": email_file_id,
|
||||
},
|
||||
}
|
||||
|
||||
count += 1
|
||||
|
||||
# ----------------------------
|
||||
# Build LEANN index and FAISS baseline
|
||||
# ----------------------------
|
||||
def build_leann_index(
|
||||
self,
|
||||
emails_csv: Optional[str],
|
||||
backend: str = "hnsw",
|
||||
embedding_model: str = "sentence-transformers/all-mpnet-base-v2",
|
||||
chunk_words: int = 256,
|
||||
max_emails: int | None = None,
|
||||
) -> str:
|
||||
emails_csv_path = self.ensure_emails_csv(emails_csv)
|
||||
print(f"🏗️ Building LEANN index from {emails_csv_path}...")
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name=backend,
|
||||
embedding_model=embedding_model,
|
||||
embedding_mode="sentence-transformers",
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_recompute=True,
|
||||
is_compact=True,
|
||||
num_threads=4,
|
||||
)
|
||||
|
||||
# Stream passages and add to builder
|
||||
preview_written = 0
|
||||
with open(self.passages_preview, "w", encoding="utf-8") as preview_out:
|
||||
for p in self._iter_passages_from_csv(
|
||||
Path(emails_csv_path), chunk_words=chunk_words, max_emails=max_emails
|
||||
):
|
||||
builder.add_text(p["text"], metadata=p["metadata"])
|
||||
if preview_written < 200:
|
||||
preview_out.write(json.dumps({"text": p["text"][:200], **p["metadata"]}) + "\n")
|
||||
preview_written += 1
|
||||
|
||||
print(f"🔨 Building index at {self.index_path}...")
|
||||
builder.build_index(str(self.index_path))
|
||||
print("✅ LEANN index built!")
|
||||
return str(self.index_path)
|
||||
|
||||
def build_faiss_flat_baseline(self, index_path: str, output_dir: str = "baseline") -> str:
|
||||
print("🔨 Building FAISS Flat baseline from LEANN passages...")
|
||||
|
||||
import pickle
|
||||
|
||||
import numpy as np
|
||||
from leann.api import compute_embeddings
|
||||
from leann_backend_hnsw import faiss
|
||||
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
baseline_path = os.path.join(output_dir, "faiss_flat.index")
|
||||
metadata_path = os.path.join(output_dir, "metadata.pkl")
|
||||
|
||||
if os.path.exists(baseline_path) and os.path.exists(metadata_path):
|
||||
print(f"✅ Baseline already exists at {baseline_path}")
|
||||
return baseline_path
|
||||
|
||||
# Read meta for passage source and embedding model
|
||||
meta_path = f"{index_path}.meta.json"
|
||||
with open(meta_path, encoding="utf-8") as f:
|
||||
meta = json.load(f)
|
||||
|
||||
embedding_model = meta["embedding_model"]
|
||||
passage_source = meta["passage_sources"][0]
|
||||
passage_file = passage_source["path"]
|
||||
|
||||
if not os.path.isabs(passage_file):
|
||||
index_dir = os.path.dirname(index_path)
|
||||
passage_file = os.path.join(index_dir, os.path.basename(passage_file))
|
||||
|
||||
# Load passages from builder output so IDs match LEANN
|
||||
passages: list[str] = []
|
||||
passage_ids: list[str] = []
|
||||
with open(passage_file, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if not line.strip():
|
||||
continue
|
||||
data = json.loads(line)
|
||||
passages.append(data["text"])
|
||||
passage_ids.append(data["id"]) # builder-assigned ID
|
||||
|
||||
print(f"📄 Loaded {len(passages)} passages for baseline")
|
||||
print(f"🤖 Embedding model: {embedding_model}")
|
||||
|
||||
embeddings = compute_embeddings(
|
||||
passages,
|
||||
embedding_model,
|
||||
mode="sentence-transformers",
|
||||
use_server=False,
|
||||
)
|
||||
|
||||
# Build FAISS IndexFlatIP
|
||||
dim = embeddings.shape[1]
|
||||
index = faiss.IndexFlatIP(dim)
|
||||
emb_f32 = embeddings.astype(np.float32)
|
||||
index.add(emb_f32.shape[0], faiss.swig_ptr(emb_f32))
|
||||
|
||||
faiss.write_index(index, baseline_path)
|
||||
with open(metadata_path, "wb") as pf:
|
||||
pickle.dump(passage_ids, pf)
|
||||
|
||||
print(f"✅ FAISS baseline saved: {baseline_path}")
|
||||
print(f"✅ Metadata saved: {metadata_path}")
|
||||
print(f"📊 Total vectors: {index.ntotal}")
|
||||
return baseline_path
|
||||
|
||||
# ----------------------------
|
||||
# Queries (optional): prepare evaluation queries file
|
||||
# ----------------------------
|
||||
def prepare_queries(self, min_realism: float = 0.85) -> Path:
|
||||
print(
|
||||
"📝 Preparing evaluation queries from HuggingFace dataset corbt/enron_emails_sample_questions ..."
|
||||
)
|
||||
try:
|
||||
from datasets import load_dataset
|
||||
|
||||
ds = load_dataset("corbt/enron_emails_sample_questions", split="train")
|
||||
except Exception as e:
|
||||
print(f"⚠️ Failed to load dataset: {e}")
|
||||
return self.queries_file
|
||||
|
||||
kept = 0
|
||||
with open(self.queries_file, "w", encoding="utf-8") as out:
|
||||
for i, item in enumerate(ds):
|
||||
how_realistic = float(item.get("how_realistic", 0.0))
|
||||
if how_realistic < min_realism:
|
||||
continue
|
||||
qid = str(item.get("id", f"enron_q_{i}"))
|
||||
query = item.get("question", "")
|
||||
if not query:
|
||||
continue
|
||||
record = {
|
||||
"id": qid,
|
||||
"query": query,
|
||||
# For reference only, not used in recall metric below
|
||||
"gt_message_ids": item.get("message_ids", []),
|
||||
}
|
||||
out.write(json.dumps(record) + "\n")
|
||||
kept += 1
|
||||
print(f"✅ Wrote {kept} queries to {self.queries_file}")
|
||||
return self.queries_file
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Setup Enron Emails Benchmark")
|
||||
parser.add_argument(
|
||||
"--emails-csv",
|
||||
help="Path to emails.csv (Enron dataset). If omitted, attempt Kaggle download.",
|
||||
)
|
||||
parser.add_argument("--data-dir", default="data", help="Data directory")
|
||||
parser.add_argument("--backend", choices=["hnsw", "diskann"], default="hnsw")
|
||||
parser.add_argument(
|
||||
"--embedding-model",
|
||||
default="sentence-transformers/all-mpnet-base-v2",
|
||||
help="Embedding model for LEANN",
|
||||
)
|
||||
parser.add_argument("--chunk-words", type=int, default=256, help="Fixed word chunk size")
|
||||
parser.add_argument("--max-emails", type=int, help="Limit number of emails to process")
|
||||
parser.add_argument("--skip-queries", action="store_true", help="Skip creating queries file")
|
||||
parser.add_argument("--skip-build", action="store_true", help="Skip building LEANN index")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
setup = EnronSetup(args.data_dir)
|
||||
|
||||
# Build index
|
||||
if not args.skip_build:
|
||||
index_path = setup.build_leann_index(
|
||||
emails_csv=args.emails_csv,
|
||||
backend=args.backend,
|
||||
embedding_model=args.embedding_model,
|
||||
chunk_words=args.chunk_words,
|
||||
max_emails=args.max_emails,
|
||||
)
|
||||
|
||||
# Build FAISS baseline from the same passages & embeddings
|
||||
setup.build_faiss_flat_baseline(index_path)
|
||||
else:
|
||||
print("⏭️ Skipping LEANN index build and baseline")
|
||||
|
||||
# Queries file (optional)
|
||||
if not args.skip_queries:
|
||||
setup.prepare_queries()
|
||||
else:
|
||||
print("⏭️ Skipping query preparation")
|
||||
|
||||
print("\n🎉 Enron Emails setup completed!")
|
||||
print(f"📁 Data directory: {setup.data_dir.absolute()}")
|
||||
print("Next steps:")
|
||||
print(
|
||||
"1) Evaluate recall: python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 2"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
115
benchmarks/financebench/README.md
Normal file
@@ -0,0 +1,115 @@
|
||||
# FinanceBench Benchmark for LEANN-RAG
|
||||
|
||||
FinanceBench is a benchmark for evaluating retrieval-augmented generation (RAG) systems on financial document question-answering tasks.
|
||||
|
||||
## Dataset
|
||||
|
||||
- **Source**: [PatronusAI/financebench](https://huggingface.co/datasets/PatronusAI/financebench)
|
||||
- **Questions**: 150 financial Q&A examples
|
||||
- **Documents**: 368 PDF files (10-K, 10-Q, 8-K, earnings reports)
|
||||
- **Companies**: Major public companies (3M, Apple, Microsoft, Amazon, etc.)
|
||||
- **Paper**: [FinanceBench: A New Benchmark for Financial Question Answering](https://arxiv.org/abs/2311.11944)
|
||||
|
||||
## Structure
|
||||
|
||||
```
|
||||
benchmarks/financebench/
|
||||
├── setup_financebench.py # Downloads PDFs and builds index
|
||||
├── evaluate_financebench.py # Intelligent evaluation script
|
||||
├── data/
|
||||
│ ├── financebench_merged.jsonl # Q&A dataset
|
||||
│ ├── pdfs/ # Downloaded financial documents
|
||||
│ └── index/ # LEANN indexes
|
||||
│ └── financebench_full_hnsw.leann
|
||||
└── README.md
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### 1. Setup (Download & Build Index)
|
||||
|
||||
```bash
|
||||
cd benchmarks/financebench
|
||||
python setup_financebench.py
|
||||
```
|
||||
|
||||
This will:
|
||||
- Download the 150 Q&A examples
|
||||
- Download all 368 PDF documents (parallel processing)
|
||||
- Build a LEANN index from 53K+ text chunks
|
||||
- Verify setup with test query
|
||||
|
||||
### 2. Evaluation
|
||||
|
||||
```bash
|
||||
# Basic retrieval evaluation
|
||||
python evaluate_financebench.py --index data/index/financebench_full_hnsw.leann
|
||||
|
||||
|
||||
# RAG generation evaluation with Qwen3-8B
|
||||
python evaluate_financebench.py --index data/index/financebench_full_hnsw.leann --stage 4 --complexity 64 --llm-backend hf --model-name Qwen/Qwen3-8B --output results_qwen3.json
|
||||
```
|
||||
|
||||
## Evaluation Methods
|
||||
|
||||
### Retrieval Evaluation
|
||||
Uses intelligent matching with three strategies:
|
||||
1. **Exact text overlap** - Direct substring matches
|
||||
2. **Number matching** - Key financial figures ($1,577, 1.2B, etc.)
|
||||
3. **Semantic similarity** - Word overlap with 20% threshold
|
||||
|
||||
### QA Evaluation
|
||||
LLM-based answer evaluation using GPT-4o:
|
||||
- Handles numerical rounding and equivalent representations
|
||||
- Considers fractions, percentages, and decimal equivalents
|
||||
- Evaluates semantic meaning rather than exact text match
|
||||
|
||||
## Benchmark Results
|
||||
|
||||
### LEANN-RAG Performance (sentence-transformers/all-mpnet-base-v2)
|
||||
|
||||
**Retrieval Metrics:**
|
||||
- **Question Coverage**: 100.0% (all questions retrieve relevant docs)
|
||||
- **Exact Match Rate**: 0.7% (substring overlap with evidence)
|
||||
- **Number Match Rate**: 120.7% (key financial figures matched)*
|
||||
- **Semantic Match Rate**: 4.7% (word overlap ≥20%)
|
||||
- **Average Search Time**: 0.097s
|
||||
|
||||
**QA Metrics:**
|
||||
- **Accuracy**: 42.7% (LLM-evaluated answer correctness)
|
||||
- **Average QA Time**: 4.71s (end-to-end response time)
|
||||
|
||||
**System Performance:**
|
||||
- **Index Size**: 53,985 chunks from 368 PDFs
|
||||
- **Build Time**: ~5-10 minutes with sentence-transformers/all-mpnet-base-v2
|
||||
|
||||
*Note: Number match rate >100% indicates multiple retrieved documents contain the same financial figures, which is expected behavior for financial data appearing across multiple document sections.
|
||||
|
||||
### LEANN-RAG Generation Performance (Qwen3-8B)
|
||||
|
||||
- **Stage 4 (Index Comparison):**
|
||||
- Compact Index: 5.0 MB
|
||||
- Non-compact Index: 172.2 MB
|
||||
- **Storage Saving**: 97.1%
|
||||
- **Search Performance**:
|
||||
- Non-compact (no recompute): 0.009s avg per query
|
||||
- Compact (with recompute): 2.203s avg per query
|
||||
- Speed ratio: 0.004x
|
||||
|
||||
**Generation Evaluation (20 queries, complexity=64):**
|
||||
- **Average Search Time**: 1.638s per query
|
||||
- **Average Generation Time**: 45.957s per query
|
||||
- **LLM Backend**: HuggingFace transformers
|
||||
- **Model**: Qwen/Qwen3-8B (thinking model with <think></think> processing)
|
||||
- **Total Questions Processed**: 20
|
||||
|
||||
## Options
|
||||
|
||||
```bash
|
||||
# Use different backends
|
||||
python setup_financebench.py --backend diskann
|
||||
python evaluate_financebench.py --index data/index/financebench_full_diskann.leann
|
||||
|
||||
# Use different embedding models
|
||||
python setup_financebench.py --embedding-model facebook/contriever
|
||||
```
|
||||
923
benchmarks/financebench/evaluate_financebench.py
Executable file
@@ -0,0 +1,923 @@
|
||||
"""
|
||||
FinanceBench Evaluation Script - Modular Recall-based Evaluation
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import openai
|
||||
from leann import LeannChat, LeannSearcher
|
||||
from leann_backend_hnsw import faiss
|
||||
|
||||
from ..llm_utils import evaluate_rag, generate_hf, generate_vllm, load_hf_model, load_vllm_model
|
||||
|
||||
# Setup logging to reduce verbose output
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
logging.getLogger("leann.api").setLevel(logging.WARNING)
|
||||
logging.getLogger("leann_backend_hnsw").setLevel(logging.WARNING)
|
||||
|
||||
|
||||
class RecallEvaluator:
|
||||
"""Stage 2: Evaluate Recall@3 (searcher vs baseline)"""
|
||||
|
||||
def __init__(self, index_path: str, baseline_dir: str):
|
||||
self.index_path = index_path
|
||||
self.baseline_dir = baseline_dir
|
||||
self.searcher = LeannSearcher(index_path)
|
||||
|
||||
# Load FAISS flat baseline
|
||||
baseline_index_path = os.path.join(baseline_dir, "faiss_flat.index")
|
||||
metadata_path = os.path.join(baseline_dir, "metadata.pkl")
|
||||
|
||||
self.faiss_index = faiss.read_index(baseline_index_path)
|
||||
with open(metadata_path, "rb") as f:
|
||||
self.passage_ids = pickle.load(f)
|
||||
print(f"📚 Loaded FAISS flat baseline with {self.faiss_index.ntotal} vectors")
|
||||
|
||||
def evaluate_recall_at_3(
|
||||
self, queries: list[str], complexity: int = 64, recompute_embeddings: bool = True
|
||||
) -> float:
|
||||
"""Evaluate recall@3 for given queries at specified complexity"""
|
||||
recompute_str = "with recompute" if recompute_embeddings else "no recompute"
|
||||
print(f"🔍 Evaluating recall@3 with complexity={complexity} ({recompute_str})...")
|
||||
|
||||
total_recall = 0.0
|
||||
num_queries = len(queries)
|
||||
|
||||
for i, query in enumerate(queries):
|
||||
# Get ground truth: search with FAISS flat
|
||||
from leann.api import compute_embeddings
|
||||
|
||||
query_embedding = compute_embeddings(
|
||||
[query],
|
||||
self.searcher.embedding_model,
|
||||
mode=self.searcher.embedding_mode,
|
||||
use_server=False,
|
||||
).astype(np.float32)
|
||||
|
||||
# Search FAISS flat for ground truth using LEANN's modified faiss API
|
||||
n = query_embedding.shape[0] # Number of queries
|
||||
k = 3 # Number of nearest neighbors
|
||||
distances = np.zeros((n, k), dtype=np.float32)
|
||||
labels = np.zeros((n, k), dtype=np.int64)
|
||||
|
||||
self.faiss_index.search(
|
||||
n,
|
||||
faiss.swig_ptr(query_embedding),
|
||||
k,
|
||||
faiss.swig_ptr(distances),
|
||||
faiss.swig_ptr(labels),
|
||||
)
|
||||
|
||||
# Extract the results
|
||||
baseline_ids = {self.passage_ids[idx] for idx in labels[0]}
|
||||
|
||||
# Search with LEANN at specified complexity
|
||||
test_results = self.searcher.search(
|
||||
query,
|
||||
top_k=3,
|
||||
complexity=complexity,
|
||||
recompute_embeddings=recompute_embeddings,
|
||||
)
|
||||
test_ids = {result.id for result in test_results}
|
||||
|
||||
# Calculate recall@3 = |intersection| / |ground_truth|
|
||||
intersection = test_ids.intersection(baseline_ids)
|
||||
recall = len(intersection) / 3.0 # Ground truth size is 3
|
||||
total_recall += recall
|
||||
|
||||
if i < 3: # Show first few examples
|
||||
print(f" Query {i + 1}: '{query[:50]}...' -> Recall@3: {recall:.3f}")
|
||||
print(f" FAISS ground truth: {list(baseline_ids)}")
|
||||
print(f" LEANN results (C={complexity}, {recompute_str}): {list(test_ids)}")
|
||||
print(f" Intersection: {list(intersection)}")
|
||||
|
||||
avg_recall = total_recall / num_queries
|
||||
print(f"📊 Average Recall@3: {avg_recall:.3f} ({avg_recall * 100:.1f}%)")
|
||||
return avg_recall
|
||||
|
||||
def cleanup(self):
|
||||
"""Cleanup resources"""
|
||||
if hasattr(self, "searcher"):
|
||||
self.searcher.cleanup()
|
||||
|
||||
|
||||
class FinanceBenchEvaluator:
|
||||
def __init__(self, index_path: str, openai_api_key: Optional[str] = None):
|
||||
self.index_path = index_path
|
||||
self.openai_client = openai.OpenAI(api_key=openai_api_key) if openai_api_key else None
|
||||
|
||||
self.searcher = LeannSearcher(index_path)
|
||||
self.chat = LeannChat(index_path) if openai_api_key else None
|
||||
|
||||
def load_dataset(self, dataset_path: str = "data/financebench_merged.jsonl"):
|
||||
"""Load FinanceBench dataset"""
|
||||
data = []
|
||||
with open(dataset_path, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
data.append(json.loads(line))
|
||||
|
||||
print(f"📊 Loaded {len(data)} FinanceBench examples")
|
||||
return data
|
||||
|
||||
def analyze_index_sizes(self) -> dict:
|
||||
"""Analyze index sizes with and without embeddings"""
|
||||
|
||||
print("📏 Analyzing index sizes...")
|
||||
|
||||
# Get all index-related files
|
||||
index_path = Path(self.index_path)
|
||||
index_dir = index_path.parent
|
||||
index_name = index_path.stem # Remove .leann extension
|
||||
|
||||
sizes = {}
|
||||
total_with_embeddings = 0
|
||||
|
||||
# Core index files
|
||||
index_file = index_dir / f"{index_name}.index"
|
||||
meta_file = index_dir / f"{index_path.name}.meta.json" # Keep .leann for meta file
|
||||
passages_file = index_dir / f"{index_path.name}.passages.jsonl" # Keep .leann for passages
|
||||
passages_idx_file = index_dir / f"{index_path.name}.passages.idx" # Keep .leann for idx
|
||||
|
||||
for file_path, name in [
|
||||
(index_file, "index"),
|
||||
(meta_file, "metadata"),
|
||||
(passages_file, "passages_text"),
|
||||
(passages_idx_file, "passages_index"),
|
||||
]:
|
||||
if file_path.exists():
|
||||
size_mb = file_path.stat().st_size / (1024 * 1024)
|
||||
sizes[name] = size_mb
|
||||
total_with_embeddings += size_mb
|
||||
|
||||
else:
|
||||
sizes[name] = 0
|
||||
|
||||
sizes["total_with_embeddings"] = total_with_embeddings
|
||||
sizes["index_only_mb"] = sizes["index"] # Just the .index file for fair comparison
|
||||
|
||||
print(f" 📁 Total index size: {total_with_embeddings:.1f} MB")
|
||||
print(f" 📁 Index file only: {sizes['index']:.1f} MB")
|
||||
|
||||
return sizes
|
||||
|
||||
def create_compact_index_for_comparison(self, compact_index_path: str) -> dict:
|
||||
"""Create a compact index for comparison purposes"""
|
||||
print("🏗️ Building compact index from existing passages...")
|
||||
|
||||
# Load existing passages from current index
|
||||
|
||||
from leann import LeannBuilder
|
||||
|
||||
current_index_path = Path(self.index_path)
|
||||
current_index_dir = current_index_path.parent
|
||||
current_index_name = current_index_path.name
|
||||
|
||||
# Read metadata to get passage source
|
||||
meta_path = current_index_dir / f"{current_index_name}.meta.json"
|
||||
with open(meta_path) as f:
|
||||
import json
|
||||
|
||||
meta = json.load(f)
|
||||
|
||||
passage_source = meta["passage_sources"][0]
|
||||
passage_file = passage_source["path"]
|
||||
|
||||
# Convert relative path to absolute
|
||||
if not Path(passage_file).is_absolute():
|
||||
passage_file = current_index_dir / Path(passage_file).name
|
||||
|
||||
print(f"📄 Loading passages from {passage_file}...")
|
||||
|
||||
# Build compact index with same passages
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model=meta["embedding_model"],
|
||||
embedding_mode=meta.get("embedding_mode", "sentence-transformers"),
|
||||
is_recompute=True, # Enable recompute (no stored embeddings)
|
||||
is_compact=True, # Enable compact storage
|
||||
**meta.get("backend_kwargs", {}),
|
||||
)
|
||||
|
||||
# Load all passages
|
||||
with open(passage_file, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
data = json.loads(line)
|
||||
builder.add_text(data["text"], metadata=data.get("metadata", {}))
|
||||
|
||||
print(f"🔨 Building compact index at {compact_index_path}...")
|
||||
builder.build_index(compact_index_path)
|
||||
|
||||
# Analyze the compact index size
|
||||
temp_evaluator = FinanceBenchEvaluator(compact_index_path)
|
||||
compact_sizes = temp_evaluator.analyze_index_sizes()
|
||||
compact_sizes["index_type"] = "compact"
|
||||
|
||||
return compact_sizes
|
||||
|
||||
def create_non_compact_index_for_comparison(self, non_compact_index_path: str) -> dict:
|
||||
"""Create a non-compact index for comparison purposes"""
|
||||
print("🏗️ Building non-compact index from existing passages...")
|
||||
|
||||
# Load existing passages from current index
|
||||
|
||||
from leann import LeannBuilder
|
||||
|
||||
current_index_path = Path(self.index_path)
|
||||
current_index_dir = current_index_path.parent
|
||||
current_index_name = current_index_path.name
|
||||
|
||||
# Read metadata to get passage source
|
||||
meta_path = current_index_dir / f"{current_index_name}.meta.json"
|
||||
with open(meta_path) as f:
|
||||
import json
|
||||
|
||||
meta = json.load(f)
|
||||
|
||||
passage_source = meta["passage_sources"][0]
|
||||
passage_file = passage_source["path"]
|
||||
|
||||
# Convert relative path to absolute
|
||||
if not Path(passage_file).is_absolute():
|
||||
passage_file = current_index_dir / Path(passage_file).name
|
||||
|
||||
print(f"📄 Loading passages from {passage_file}...")
|
||||
|
||||
# Build non-compact index with same passages
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model=meta["embedding_model"],
|
||||
embedding_mode=meta.get("embedding_mode", "sentence-transformers"),
|
||||
is_recompute=False, # Disable recompute (store embeddings)
|
||||
is_compact=False, # Disable compact storage
|
||||
**{
|
||||
k: v
|
||||
for k, v in meta.get("backend_kwargs", {}).items()
|
||||
if k not in ["is_recompute", "is_compact"]
|
||||
},
|
||||
)
|
||||
|
||||
# Load all passages
|
||||
with open(passage_file, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
data = json.loads(line)
|
||||
builder.add_text(data["text"], metadata=data.get("metadata", {}))
|
||||
|
||||
print(f"🔨 Building non-compact index at {non_compact_index_path}...")
|
||||
builder.build_index(non_compact_index_path)
|
||||
|
||||
# Analyze the non-compact index size
|
||||
temp_evaluator = FinanceBenchEvaluator(non_compact_index_path)
|
||||
non_compact_sizes = temp_evaluator.analyze_index_sizes()
|
||||
non_compact_sizes["index_type"] = "non_compact"
|
||||
|
||||
return non_compact_sizes
|
||||
|
||||
def compare_index_performance(
|
||||
self, non_compact_path: str, compact_path: str, test_data: list, complexity: int
|
||||
) -> dict:
|
||||
"""Compare performance between non-compact and compact indexes"""
|
||||
print("⚡ Comparing search performance between indexes...")
|
||||
|
||||
import time
|
||||
|
||||
from leann import LeannSearcher
|
||||
|
||||
# Test queries
|
||||
test_queries = [item["question"] for item in test_data[:5]]
|
||||
|
||||
results = {
|
||||
"non_compact": {"search_times": []},
|
||||
"compact": {"search_times": []},
|
||||
"avg_search_times": {},
|
||||
"speed_ratio": 0.0,
|
||||
}
|
||||
|
||||
# Test non-compact index (no recompute)
|
||||
print(" 🔍 Testing non-compact index (no recompute)...")
|
||||
non_compact_searcher = LeannSearcher(non_compact_path)
|
||||
|
||||
for query in test_queries:
|
||||
start_time = time.time()
|
||||
_ = non_compact_searcher.search(
|
||||
query, top_k=3, complexity=complexity, recompute_embeddings=False
|
||||
)
|
||||
search_time = time.time() - start_time
|
||||
results["non_compact"]["search_times"].append(search_time)
|
||||
|
||||
# Test compact index (with recompute)
|
||||
print(" 🔍 Testing compact index (with recompute)...")
|
||||
compact_searcher = LeannSearcher(compact_path)
|
||||
|
||||
for query in test_queries:
|
||||
start_time = time.time()
|
||||
_ = compact_searcher.search(
|
||||
query, top_k=3, complexity=complexity, recompute_embeddings=True
|
||||
)
|
||||
search_time = time.time() - start_time
|
||||
results["compact"]["search_times"].append(search_time)
|
||||
|
||||
# Calculate averages
|
||||
results["avg_search_times"]["non_compact"] = sum(
|
||||
results["non_compact"]["search_times"]
|
||||
) / len(results["non_compact"]["search_times"])
|
||||
results["avg_search_times"]["compact"] = sum(results["compact"]["search_times"]) / len(
|
||||
results["compact"]["search_times"]
|
||||
)
|
||||
|
||||
# Performance ratio
|
||||
if results["avg_search_times"]["compact"] > 0:
|
||||
results["speed_ratio"] = (
|
||||
results["avg_search_times"]["non_compact"] / results["avg_search_times"]["compact"]
|
||||
)
|
||||
else:
|
||||
results["speed_ratio"] = float("inf")
|
||||
|
||||
print(
|
||||
f" Non-compact (no recompute): {results['avg_search_times']['non_compact']:.3f}s avg"
|
||||
)
|
||||
print(f" Compact (with recompute): {results['avg_search_times']['compact']:.3f}s avg")
|
||||
print(f" Speed ratio: {results['speed_ratio']:.2f}x")
|
||||
|
||||
# Cleanup
|
||||
non_compact_searcher.cleanup()
|
||||
compact_searcher.cleanup()
|
||||
|
||||
return results
|
||||
|
||||
def evaluate_timing_breakdown(
|
||||
self, data: list[dict], max_samples: Optional[int] = None
|
||||
) -> dict:
|
||||
"""Evaluate timing breakdown and accuracy by hacking LeannChat.ask() for separated timing"""
|
||||
if not self.chat or not self.openai_client:
|
||||
print("⚠️ Skipping timing evaluation (no OpenAI API key provided)")
|
||||
return {
|
||||
"total_questions": 0,
|
||||
"avg_search_time": 0.0,
|
||||
"avg_generation_time": 0.0,
|
||||
"avg_total_time": 0.0,
|
||||
"accuracy": 0.0,
|
||||
}
|
||||
|
||||
print("🔍🤖 Evaluating timing breakdown and accuracy (search + generation)...")
|
||||
|
||||
if max_samples:
|
||||
data = data[:max_samples]
|
||||
print(f"📝 Using first {max_samples} samples for timing evaluation")
|
||||
|
||||
search_times = []
|
||||
generation_times = []
|
||||
total_times = []
|
||||
correct_answers = 0
|
||||
|
||||
for i, item in enumerate(data):
|
||||
question = item["question"]
|
||||
ground_truth = item["answer"]
|
||||
|
||||
try:
|
||||
# Hack: Monkey-patch the ask method to capture internal timing
|
||||
original_ask = self.chat.ask
|
||||
captured_search_time = None
|
||||
captured_generation_time = None
|
||||
|
||||
def patched_ask(*args, **kwargs):
|
||||
nonlocal captured_search_time, captured_generation_time
|
||||
|
||||
# Time the search part
|
||||
search_start = time.time()
|
||||
results = self.chat.searcher.search(args[0], top_k=3, complexity=64)
|
||||
captured_search_time = time.time() - search_start
|
||||
|
||||
# Time the generation part
|
||||
context = "\n\n".join([r.text for r in results])
|
||||
prompt = (
|
||||
"Here is some retrieved context that might help answer your question:\n\n"
|
||||
f"{context}\n\n"
|
||||
f"Question: {args[0]}\n\n"
|
||||
"Please provide the best answer you can based on this context and your knowledge."
|
||||
)
|
||||
|
||||
generation_start = time.time()
|
||||
answer = self.chat.llm.ask(prompt)
|
||||
captured_generation_time = time.time() - generation_start
|
||||
|
||||
return answer
|
||||
|
||||
# Apply the patch
|
||||
self.chat.ask = patched_ask
|
||||
|
||||
# Time the total QA
|
||||
total_start = time.time()
|
||||
generated_answer = self.chat.ask(question)
|
||||
total_time = time.time() - total_start
|
||||
|
||||
# Restore original method
|
||||
self.chat.ask = original_ask
|
||||
|
||||
# Store the timings
|
||||
search_times.append(captured_search_time)
|
||||
generation_times.append(captured_generation_time)
|
||||
total_times.append(total_time)
|
||||
|
||||
# Check accuracy using LLM as judge
|
||||
is_correct = self._check_answer_accuracy(generated_answer, ground_truth, question)
|
||||
if is_correct:
|
||||
correct_answers += 1
|
||||
|
||||
status = "✅" if is_correct else "❌"
|
||||
print(
|
||||
f"Question {i + 1}/{len(data)}: {status} Search={captured_search_time:.3f}s, Gen={captured_generation_time:.3f}s, Total={total_time:.3f}s"
|
||||
)
|
||||
print(f" GT: {ground_truth}")
|
||||
print(f" Gen: {generated_answer[:100]}...")
|
||||
|
||||
except Exception as e:
|
||||
print(f" ❌ Error: {e}")
|
||||
search_times.append(0.0)
|
||||
generation_times.append(0.0)
|
||||
total_times.append(0.0)
|
||||
|
||||
accuracy = correct_answers / len(data) if data else 0.0
|
||||
|
||||
metrics = {
|
||||
"total_questions": len(data),
|
||||
"avg_search_time": sum(search_times) / len(search_times) if search_times else 0.0,
|
||||
"avg_generation_time": sum(generation_times) / len(generation_times)
|
||||
if generation_times
|
||||
else 0.0,
|
||||
"avg_total_time": sum(total_times) / len(total_times) if total_times else 0.0,
|
||||
"accuracy": accuracy,
|
||||
"correct_answers": correct_answers,
|
||||
"search_times": search_times,
|
||||
"generation_times": generation_times,
|
||||
"total_times": total_times,
|
||||
}
|
||||
|
||||
return metrics
|
||||
|
||||
def _check_answer_accuracy(
|
||||
self, generated_answer: str, ground_truth: str, question: str
|
||||
) -> bool:
|
||||
"""Check if generated answer matches ground truth using LLM as judge"""
|
||||
judge_prompt = f"""You are an expert judge evaluating financial question answering.
|
||||
|
||||
Question: {question}
|
||||
|
||||
Ground Truth Answer: {ground_truth}
|
||||
|
||||
Generated Answer: {generated_answer}
|
||||
|
||||
Task: Determine if the generated answer is factually correct compared to the ground truth. Focus on:
|
||||
1. Numerical accuracy (exact values, units, currency)
|
||||
2. Key financial concepts and terminology
|
||||
3. Overall factual correctness
|
||||
|
||||
For financial data, small formatting differences are OK (e.g., "$1,577" vs "1577 million" vs "$1.577 billion"), but the core numerical value must match.
|
||||
|
||||
Respond with exactly one word: "CORRECT" if the generated answer is factually accurate, or "INCORRECT" if it's wrong or significantly different."""
|
||||
|
||||
try:
|
||||
judge_response = self.openai_client.chat.completions.create(
|
||||
model="gpt-4o-mini",
|
||||
messages=[{"role": "user", "content": judge_prompt}],
|
||||
max_tokens=10,
|
||||
temperature=0,
|
||||
)
|
||||
judgment = judge_response.choices[0].message.content.strip().upper()
|
||||
return judgment == "CORRECT"
|
||||
except Exception as e:
|
||||
print(f" ⚠️ Judge error: {e}, falling back to string matching")
|
||||
# Fallback to simple string matching
|
||||
gen_clean = generated_answer.strip().lower().replace("$", "").replace(",", "")
|
||||
gt_clean = ground_truth.strip().lower().replace("$", "").replace(",", "")
|
||||
return gt_clean in gen_clean
|
||||
|
||||
def _print_results(self, timing_metrics: dict):
|
||||
"""Print evaluation results"""
|
||||
print("\n🎯 EVALUATION RESULTS")
|
||||
print("=" * 50)
|
||||
|
||||
# Index comparison analysis
|
||||
if "current_index" in timing_metrics and "non_compact_index" in timing_metrics:
|
||||
print("\n📏 Index Comparison Analysis:")
|
||||
current = timing_metrics["current_index"]
|
||||
non_compact = timing_metrics["non_compact_index"]
|
||||
|
||||
print(f" Compact index (current): {current.get('total_with_embeddings', 0):.1f} MB")
|
||||
print(
|
||||
f" Non-compact index (with embeddings): {non_compact.get('total_with_embeddings', 0):.1f} MB"
|
||||
)
|
||||
print(
|
||||
f" Storage saving by compact: {timing_metrics.get('storage_saving_percent', 0):.1f}%"
|
||||
)
|
||||
|
||||
print(" Component breakdown (non-compact):")
|
||||
print(f" - Main index: {non_compact.get('index', 0):.1f} MB")
|
||||
print(f" - Passages text: {non_compact.get('passages_text', 0):.1f} MB")
|
||||
print(f" - Passages index: {non_compact.get('passages_index', 0):.1f} MB")
|
||||
print(f" - Metadata: {non_compact.get('metadata', 0):.1f} MB")
|
||||
|
||||
# Performance comparison
|
||||
if "performance_comparison" in timing_metrics:
|
||||
perf = timing_metrics["performance_comparison"]
|
||||
print("\n⚡ Performance Comparison:")
|
||||
print(
|
||||
f" Non-compact (no recompute): {perf.get('avg_search_times', {}).get('non_compact', 0):.3f}s avg"
|
||||
)
|
||||
print(
|
||||
f" Compact (with recompute): {perf.get('avg_search_times', {}).get('compact', 0):.3f}s avg"
|
||||
)
|
||||
print(f" Speed ratio: {perf.get('speed_ratio', 0):.2f}x")
|
||||
|
||||
# Legacy single index analysis (fallback)
|
||||
if "total_with_embeddings" in timing_metrics and "current_index" not in timing_metrics:
|
||||
print("\n📏 Index Size Analysis:")
|
||||
print(f" Total index size: {timing_metrics.get('total_with_embeddings', 0):.1f} MB")
|
||||
|
||||
print("\n📊 Accuracy:")
|
||||
print(f" Accuracy: {timing_metrics.get('accuracy', 0) * 100:.1f}%")
|
||||
print(
|
||||
f" Correct Answers: {timing_metrics.get('correct_answers', 0)}/{timing_metrics.get('total_questions', 0)}"
|
||||
)
|
||||
|
||||
print("\n📊 Timing Breakdown:")
|
||||
print(f" Total Questions: {timing_metrics.get('total_questions', 0)}")
|
||||
print(f" Avg Search Time: {timing_metrics.get('avg_search_time', 0):.3f}s")
|
||||
print(f" Avg Generation Time: {timing_metrics.get('avg_generation_time', 0):.3f}s")
|
||||
print(f" Avg Total Time: {timing_metrics.get('avg_total_time', 0):.3f}s")
|
||||
|
||||
if timing_metrics.get("avg_total_time", 0) > 0:
|
||||
search_pct = (
|
||||
timing_metrics.get("avg_search_time", 0)
|
||||
/ timing_metrics.get("avg_total_time", 1)
|
||||
* 100
|
||||
)
|
||||
gen_pct = (
|
||||
timing_metrics.get("avg_generation_time", 0)
|
||||
/ timing_metrics.get("avg_total_time", 1)
|
||||
* 100
|
||||
)
|
||||
print("\n📈 Time Distribution:")
|
||||
print(f" Search: {search_pct:.1f}%")
|
||||
print(f" Generation: {gen_pct:.1f}%")
|
||||
|
||||
def cleanup(self):
|
||||
"""Cleanup resources"""
|
||||
if self.searcher:
|
||||
self.searcher.cleanup()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Modular FinanceBench Evaluation")
|
||||
parser.add_argument("--index", required=True, help="Path to LEANN index")
|
||||
parser.add_argument("--dataset", default="data/financebench_merged.jsonl", help="Dataset path")
|
||||
parser.add_argument(
|
||||
"--stage",
|
||||
choices=["2", "3", "4", "all"],
|
||||
default="all",
|
||||
help="Which stage to run (2=recall, 3=complexity, 4=generation)",
|
||||
)
|
||||
parser.add_argument("--complexity", type=int, default=None, help="Complexity for search")
|
||||
parser.add_argument("--baseline-dir", default="baseline", help="Baseline output directory")
|
||||
parser.add_argument("--openai-api-key", help="OpenAI API key for generation evaluation")
|
||||
parser.add_argument("--output", help="Save results to JSON file")
|
||||
parser.add_argument(
|
||||
"--llm-backend", choices=["openai", "hf", "vllm"], default="openai", help="LLM backend"
|
||||
)
|
||||
parser.add_argument("--model-name", default="Qwen3-8B", help="Model name for HF/vLLM")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
# Check if baseline exists
|
||||
baseline_index_path = os.path.join(args.baseline_dir, "faiss_flat.index")
|
||||
if not os.path.exists(baseline_index_path):
|
||||
print(f"❌ FAISS baseline not found at {baseline_index_path}")
|
||||
print("💡 Please run setup_financebench.py first to build the baseline")
|
||||
exit(1)
|
||||
|
||||
if args.stage == "2" or args.stage == "all":
|
||||
# Stage 2: Recall@3 evaluation
|
||||
print("🚀 Starting Stage 2: Recall@3 evaluation")
|
||||
|
||||
evaluator = RecallEvaluator(args.index, args.baseline_dir)
|
||||
|
||||
# Load FinanceBench queries for testing
|
||||
print("📖 Loading FinanceBench dataset...")
|
||||
queries = []
|
||||
with open(args.dataset, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
data = json.loads(line)
|
||||
queries.append(data["question"])
|
||||
|
||||
# Test with more queries for robust measurement
|
||||
test_queries = queries[:2000]
|
||||
print(f"🧪 Testing with {len(test_queries)} queries")
|
||||
|
||||
# Test with complexity 64
|
||||
complexity = 64
|
||||
recall = evaluator.evaluate_recall_at_3(test_queries, complexity)
|
||||
print(f"📈 Recall@3 at complexity {complexity}: {recall * 100:.1f}%")
|
||||
|
||||
evaluator.cleanup()
|
||||
print("✅ Stage 2 completed!\n")
|
||||
|
||||
# Shared non-compact index path for Stage 3 and 4
|
||||
non_compact_index_path = args.index.replace(".leann", "_noncompact.leann")
|
||||
complexity = args.complexity
|
||||
|
||||
if args.stage == "3" or args.stage == "all":
|
||||
# Stage 3: Binary search for 90% recall complexity (using non-compact index for speed)
|
||||
print("🚀 Starting Stage 3: Binary search for 90% recall complexity")
|
||||
print(
|
||||
"💡 Creating non-compact index for fast binary search with recompute_embeddings=False"
|
||||
)
|
||||
|
||||
# Create non-compact index for binary search (will be reused in Stage 4)
|
||||
print("🏗️ Creating non-compact index for binary search...")
|
||||
evaluator = FinanceBenchEvaluator(args.index)
|
||||
evaluator.create_non_compact_index_for_comparison(non_compact_index_path)
|
||||
|
||||
# Use non-compact index for binary search
|
||||
binary_search_evaluator = RecallEvaluator(non_compact_index_path, args.baseline_dir)
|
||||
|
||||
# Load queries for testing
|
||||
print("📖 Loading FinanceBench dataset...")
|
||||
queries = []
|
||||
with open(args.dataset, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
data = json.loads(line)
|
||||
queries.append(data["question"])
|
||||
|
||||
# Use more queries for robust measurement
|
||||
test_queries = queries[:200]
|
||||
print(f"🧪 Testing with {len(test_queries)} queries")
|
||||
|
||||
# Binary search for 90% recall complexity (without recompute for speed)
|
||||
target_recall = 0.9
|
||||
min_complexity, max_complexity = 1, 32
|
||||
|
||||
print(f"🔍 Binary search for {target_recall * 100}% recall complexity...")
|
||||
print(f"Search range: {min_complexity} to {max_complexity}")
|
||||
|
||||
best_complexity = None
|
||||
best_recall = 0.0
|
||||
|
||||
while min_complexity <= max_complexity:
|
||||
mid_complexity = (min_complexity + max_complexity) // 2
|
||||
|
||||
print(
|
||||
f"\n🧪 Testing complexity {mid_complexity} (no recompute, non-compact index)..."
|
||||
)
|
||||
# Use recompute_embeddings=False on non-compact index for fast binary search
|
||||
recall = binary_search_evaluator.evaluate_recall_at_3(
|
||||
test_queries, mid_complexity, recompute_embeddings=False
|
||||
)
|
||||
|
||||
print(
|
||||
f" Complexity {mid_complexity}: Recall@3 = {recall:.3f} ({recall * 100:.1f}%)"
|
||||
)
|
||||
|
||||
if recall >= target_recall:
|
||||
best_complexity = mid_complexity
|
||||
best_recall = recall
|
||||
max_complexity = mid_complexity - 1
|
||||
print(" ✅ Target reached! Searching for lower complexity...")
|
||||
else:
|
||||
min_complexity = mid_complexity + 1
|
||||
print(" ❌ Below target. Searching for higher complexity...")
|
||||
|
||||
if best_complexity is not None:
|
||||
print("\n🎯 Optimal complexity found!")
|
||||
print(f" Complexity: {best_complexity}")
|
||||
print(f" Recall@3: {best_recall:.3f} ({best_recall * 100:.1f}%)")
|
||||
|
||||
# Test a few complexities around the optimal one for verification
|
||||
print("\n🔬 Verification test around optimal complexity:")
|
||||
verification_complexities = [
|
||||
max(1, best_complexity - 2),
|
||||
max(1, best_complexity - 1),
|
||||
best_complexity,
|
||||
best_complexity + 1,
|
||||
best_complexity + 2,
|
||||
]
|
||||
|
||||
for complexity in verification_complexities:
|
||||
if complexity <= 512: # reasonable upper bound
|
||||
recall = binary_search_evaluator.evaluate_recall_at_3(
|
||||
test_queries, complexity, recompute_embeddings=False
|
||||
)
|
||||
status = "✅" if recall >= target_recall else "❌"
|
||||
print(f" {status} Complexity {complexity:3d}: {recall * 100:5.1f}%")
|
||||
|
||||
# Now test the optimal complexity with compact index and recompute for comparison
|
||||
print(
|
||||
f"\n🔄 Testing optimal complexity {best_complexity} on compact index WITH recompute..."
|
||||
)
|
||||
compact_evaluator = RecallEvaluator(args.index, args.baseline_dir)
|
||||
recall_with_recompute = compact_evaluator.evaluate_recall_at_3(
|
||||
test_queries[:10], best_complexity, recompute_embeddings=True
|
||||
)
|
||||
print(
|
||||
f" ✅ Complexity {best_complexity} (compact index with recompute): {recall_with_recompute * 100:.1f}%"
|
||||
)
|
||||
complexity = best_complexity
|
||||
print(
|
||||
f" 📊 Recall difference: {abs(best_recall - recall_with_recompute) * 100:.2f}%"
|
||||
)
|
||||
compact_evaluator.cleanup()
|
||||
else:
|
||||
print(f"\n❌ Could not find complexity achieving {target_recall * 100}% recall")
|
||||
print("All tested complexities were below target.")
|
||||
|
||||
# Cleanup evaluators (keep non-compact index for Stage 4)
|
||||
binary_search_evaluator.cleanup()
|
||||
evaluator.cleanup()
|
||||
|
||||
print("✅ Stage 3 completed! Non-compact index saved for Stage 4.\n")
|
||||
|
||||
if args.stage == "4" or args.stage == "all":
|
||||
# Stage 4: Comprehensive evaluation with dual index comparison
|
||||
print("🚀 Starting Stage 4: Comprehensive evaluation with dual index comparison")
|
||||
|
||||
# Use FinanceBench evaluator for QA evaluation
|
||||
evaluator = FinanceBenchEvaluator(
|
||||
args.index, args.openai_api_key if args.llm_backend == "openai" else None
|
||||
)
|
||||
|
||||
print("📖 Loading FinanceBench dataset...")
|
||||
data = evaluator.load_dataset(args.dataset)
|
||||
|
||||
# Step 1: Analyze current (compact) index
|
||||
print("\n📏 Analyzing current index (compact, pruned)...")
|
||||
compact_size_metrics = evaluator.analyze_index_sizes()
|
||||
compact_size_metrics["index_type"] = "compact"
|
||||
|
||||
# Step 2: Use existing non-compact index or create if needed
|
||||
from pathlib import Path
|
||||
|
||||
if Path(non_compact_index_path).exists():
|
||||
print(
|
||||
f"\n📁 Using existing non-compact index from Stage 3: {non_compact_index_path}"
|
||||
)
|
||||
temp_evaluator = FinanceBenchEvaluator(non_compact_index_path)
|
||||
non_compact_size_metrics = temp_evaluator.analyze_index_sizes()
|
||||
non_compact_size_metrics["index_type"] = "non_compact"
|
||||
else:
|
||||
print("\n🏗️ Creating non-compact index (with embeddings) for comparison...")
|
||||
non_compact_size_metrics = evaluator.create_non_compact_index_for_comparison(
|
||||
non_compact_index_path
|
||||
)
|
||||
|
||||
# Step 3: Compare index sizes
|
||||
print("\n📊 Index size comparison:")
|
||||
print(
|
||||
f" Compact index (current): {compact_size_metrics['total_with_embeddings']:.1f} MB"
|
||||
)
|
||||
print(
|
||||
f" Non-compact index: {non_compact_size_metrics['total_with_embeddings']:.1f} MB"
|
||||
)
|
||||
print("\n📊 Index-only size comparison (.index file only):")
|
||||
print(f" Compact index: {compact_size_metrics['index_only_mb']:.1f} MB")
|
||||
print(f" Non-compact index: {non_compact_size_metrics['index_only_mb']:.1f} MB")
|
||||
# Use index-only size for fair comparison (same as Enron emails)
|
||||
storage_saving = (
|
||||
(non_compact_size_metrics["index_only_mb"] - compact_size_metrics["index_only_mb"])
|
||||
/ non_compact_size_metrics["index_only_mb"]
|
||||
* 100
|
||||
)
|
||||
print(f" Storage saving by compact: {storage_saving:.1f}%")
|
||||
|
||||
# Step 4: Performance comparison between the two indexes
|
||||
if complexity is None:
|
||||
raise ValueError("Complexity is required for performance comparison")
|
||||
|
||||
print("\n⚡ Performance comparison between indexes...")
|
||||
performance_metrics = evaluator.compare_index_performance(
|
||||
non_compact_index_path, args.index, data[:10], complexity=complexity
|
||||
)
|
||||
|
||||
# Step 5: Generation evaluation
|
||||
test_samples = 20
|
||||
print(f"\n🧪 Testing with first {test_samples} samples for generation analysis")
|
||||
|
||||
if args.llm_backend == "openai" and args.openai_api_key:
|
||||
print("🔍🤖 Running OpenAI-based generation evaluation...")
|
||||
evaluation_start = time.time()
|
||||
timing_metrics = evaluator.evaluate_timing_breakdown(data[:test_samples])
|
||||
evaluation_time = time.time() - evaluation_start
|
||||
else:
|
||||
print(
|
||||
f"🔍🤖 Running {args.llm_backend} generation evaluation with {args.model_name}..."
|
||||
)
|
||||
try:
|
||||
# Load LLM
|
||||
if args.llm_backend == "hf":
|
||||
tokenizer, model = load_hf_model(args.model_name)
|
||||
|
||||
def llm_func(prompt):
|
||||
return generate_hf(tokenizer, model, prompt)
|
||||
else: # vllm
|
||||
llm, sampling_params = load_vllm_model(args.model_name)
|
||||
|
||||
def llm_func(prompt):
|
||||
return generate_vllm(llm, sampling_params, prompt)
|
||||
|
||||
# Simple generation evaluation
|
||||
queries = [item["question"] for item in data[:test_samples]]
|
||||
gen_results = evaluate_rag(
|
||||
evaluator.searcher,
|
||||
llm_func,
|
||||
queries,
|
||||
domain="finance",
|
||||
complexity=complexity,
|
||||
)
|
||||
|
||||
timing_metrics = {
|
||||
"total_questions": len(queries),
|
||||
"avg_search_time": gen_results["avg_search_time"],
|
||||
"avg_generation_time": gen_results["avg_generation_time"],
|
||||
"results": gen_results["results"],
|
||||
}
|
||||
evaluation_time = time.time()
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Generation evaluation failed: {e}")
|
||||
timing_metrics = {
|
||||
"total_questions": 0,
|
||||
"avg_search_time": 0,
|
||||
"avg_generation_time": 0,
|
||||
}
|
||||
evaluation_time = 0
|
||||
|
||||
# Combine all metrics
|
||||
combined_metrics = {
|
||||
**timing_metrics,
|
||||
"total_evaluation_time": evaluation_time,
|
||||
"current_index": compact_size_metrics,
|
||||
"non_compact_index": non_compact_size_metrics,
|
||||
"performance_comparison": performance_metrics,
|
||||
"storage_saving_percent": storage_saving,
|
||||
}
|
||||
|
||||
# Print results
|
||||
print("\n📊 Generation Results:")
|
||||
print(f" Total Questions: {timing_metrics.get('total_questions', 0)}")
|
||||
print(f" Avg Search Time: {timing_metrics.get('avg_search_time', 0):.3f}s")
|
||||
print(f" Avg Generation Time: {timing_metrics.get('avg_generation_time', 0):.3f}s")
|
||||
|
||||
# Save results if requested
|
||||
if args.output:
|
||||
print(f"\n💾 Saving results to {args.output}...")
|
||||
with open(args.output, "w") as f:
|
||||
json.dump(combined_metrics, f, indent=2, default=str)
|
||||
print(f"✅ Results saved to {args.output}")
|
||||
|
||||
evaluator.cleanup()
|
||||
print("✅ Stage 4 completed!\n")
|
||||
|
||||
if args.stage == "all":
|
||||
print("🎉 All evaluation stages completed successfully!")
|
||||
print("\n📋 Summary:")
|
||||
print(" Stage 2: ✅ Recall@3 evaluation completed")
|
||||
print(" Stage 3: ✅ Optimal complexity found")
|
||||
print(" Stage 4: ✅ Generation accuracy & timing evaluation completed")
|
||||
print("\n🔧 Recommended next steps:")
|
||||
print(" - Use optimal complexity for best speed/accuracy balance")
|
||||
print(" - Review accuracy and timing breakdown for performance optimization")
|
||||
print(" - Run full evaluation on complete dataset if needed")
|
||||
|
||||
# Clean up non-compact index after all stages complete
|
||||
print("\n🧹 Cleaning up temporary non-compact index...")
|
||||
from pathlib import Path
|
||||
|
||||
if Path(non_compact_index_path).exists():
|
||||
temp_index_dir = Path(non_compact_index_path).parent
|
||||
temp_index_name = Path(non_compact_index_path).name
|
||||
for temp_file in temp_index_dir.glob(f"{temp_index_name}*"):
|
||||
temp_file.unlink()
|
||||
print(f"✅ Cleaned up {non_compact_index_path}")
|
||||
else:
|
||||
print("📝 No temporary index to clean up")
|
||||
except KeyboardInterrupt:
|
||||
print("\n⚠️ Evaluation interrupted by user")
|
||||
exit(1)
|
||||
except Exception as e:
|
||||
print(f"\n❌ Stage {args.stage} failed: {e}")
|
||||
exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
462
benchmarks/financebench/setup_financebench.py
Executable file
@@ -0,0 +1,462 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
FinanceBench Complete Setup Script
|
||||
Downloads all PDFs and builds full LEANN datastore
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
|
||||
import pymupdf
|
||||
import requests
|
||||
from leann import LeannBuilder, LeannSearcher
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class FinanceBenchSetup:
|
||||
def __init__(self, data_dir: str = "data"):
|
||||
self.base_dir = Path(__file__).parent # benchmarks/financebench/
|
||||
self.data_dir = self.base_dir / data_dir
|
||||
self.pdf_dir = self.data_dir / "pdfs"
|
||||
self.dataset_file = self.data_dir / "financebench_merged.jsonl"
|
||||
self.index_dir = self.data_dir / "index"
|
||||
self.download_lock = Lock()
|
||||
|
||||
def download_dataset(self):
|
||||
"""Download the main FinanceBench dataset"""
|
||||
print("📊 Downloading FinanceBench dataset...")
|
||||
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if self.dataset_file.exists():
|
||||
print(f"✅ Dataset already exists: {self.dataset_file}")
|
||||
return
|
||||
|
||||
url = "https://huggingface.co/datasets/PatronusAI/financebench/raw/main/financebench_merged.jsonl"
|
||||
response = requests.get(url, stream=True)
|
||||
response.raise_for_status()
|
||||
|
||||
with open(self.dataset_file, "wb") as f:
|
||||
for chunk in response.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
|
||||
print(f"✅ Dataset downloaded: {self.dataset_file}")
|
||||
|
||||
def get_pdf_list(self):
|
||||
"""Get list of all PDF files from GitHub"""
|
||||
print("📋 Fetching PDF list from GitHub...")
|
||||
|
||||
response = requests.get(
|
||||
"https://api.github.com/repos/patronus-ai/financebench/contents/pdfs"
|
||||
)
|
||||
response.raise_for_status()
|
||||
pdf_files = response.json()
|
||||
|
||||
print(f"Found {len(pdf_files)} PDF files")
|
||||
return pdf_files
|
||||
|
||||
def download_single_pdf(self, pdf_info, position):
|
||||
"""Download a single PDF file"""
|
||||
pdf_name = pdf_info["name"]
|
||||
pdf_path = self.pdf_dir / pdf_name
|
||||
|
||||
# Skip if already downloaded
|
||||
if pdf_path.exists() and pdf_path.stat().st_size > 0:
|
||||
return f"✅ {pdf_name} (cached)"
|
||||
|
||||
try:
|
||||
# Download PDF
|
||||
response = requests.get(pdf_info["download_url"], timeout=60)
|
||||
response.raise_for_status()
|
||||
|
||||
# Write to file
|
||||
with self.download_lock:
|
||||
with open(pdf_path, "wb") as f:
|
||||
f.write(response.content)
|
||||
|
||||
return f"✅ {pdf_name} ({len(response.content) // 1024}KB)"
|
||||
|
||||
except Exception as e:
|
||||
return f"❌ {pdf_name}: {e!s}"
|
||||
|
||||
def download_all_pdfs(self, max_workers: int = 5):
|
||||
"""Download all PDF files with parallel processing"""
|
||||
self.pdf_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
pdf_files = self.get_pdf_list()
|
||||
|
||||
print(f"📥 Downloading {len(pdf_files)} PDFs with {max_workers} workers...")
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
# Submit all download tasks
|
||||
future_to_pdf = {
|
||||
executor.submit(self.download_single_pdf, pdf_info, i): pdf_info["name"]
|
||||
for i, pdf_info in enumerate(pdf_files)
|
||||
}
|
||||
|
||||
# Process completed downloads with progress bar
|
||||
with tqdm(total=len(pdf_files), desc="Downloading PDFs") as pbar:
|
||||
for future in as_completed(future_to_pdf):
|
||||
result = future.result()
|
||||
pbar.set_postfix_str(result.split()[-1] if "✅" in result else "Error")
|
||||
pbar.update(1)
|
||||
|
||||
# Verify downloads
|
||||
downloaded_pdfs = list(self.pdf_dir.glob("*.pdf"))
|
||||
print(f"✅ Successfully downloaded {len(downloaded_pdfs)}/{len(pdf_files)} PDFs")
|
||||
|
||||
# Show any failures
|
||||
missing_pdfs = []
|
||||
for pdf_info in pdf_files:
|
||||
pdf_path = self.pdf_dir / pdf_info["name"]
|
||||
if not pdf_path.exists() or pdf_path.stat().st_size == 0:
|
||||
missing_pdfs.append(pdf_info["name"])
|
||||
|
||||
if missing_pdfs:
|
||||
print(f"⚠️ Failed to download {len(missing_pdfs)} PDFs:")
|
||||
for pdf in missing_pdfs[:5]: # Show first 5
|
||||
print(f" - {pdf}")
|
||||
if len(missing_pdfs) > 5:
|
||||
print(f" ... and {len(missing_pdfs) - 5} more")
|
||||
|
||||
def build_leann_index(
|
||||
self,
|
||||
backend: str = "hnsw",
|
||||
embedding_model: str = "sentence-transformers/all-mpnet-base-v2",
|
||||
):
|
||||
"""Build LEANN index from all PDFs"""
|
||||
print(f"🏗️ Building LEANN index with {backend} backend...")
|
||||
|
||||
# Check if we have PDFs
|
||||
pdf_files = list(self.pdf_dir.glob("*.pdf"))
|
||||
if not pdf_files:
|
||||
raise RuntimeError("No PDF files found! Run download first.")
|
||||
|
||||
print(f"Found {len(pdf_files)} PDF files to process")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Initialize builder with standard compact configuration
|
||||
builder = LeannBuilder(
|
||||
backend_name=backend,
|
||||
embedding_model=embedding_model,
|
||||
embedding_mode="sentence-transformers",
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_recompute=True, # Enable recompute (no stored embeddings)
|
||||
is_compact=True, # Enable compact storage (pruned)
|
||||
num_threads=4,
|
||||
)
|
||||
|
||||
# Process PDFs and extract text
|
||||
total_chunks = 0
|
||||
failed_pdfs = []
|
||||
|
||||
for pdf_path in tqdm(pdf_files, desc="Processing PDFs"):
|
||||
try:
|
||||
chunks = self.extract_pdf_text(pdf_path)
|
||||
for chunk in chunks:
|
||||
builder.add_text(chunk["text"], metadata=chunk["metadata"])
|
||||
total_chunks += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Failed to process {pdf_path.name}: {e}")
|
||||
failed_pdfs.append(pdf_path.name)
|
||||
continue
|
||||
|
||||
# Build index in index directory
|
||||
self.index_dir.mkdir(parents=True, exist_ok=True)
|
||||
index_path = self.index_dir / f"financebench_full_{backend}.leann"
|
||||
print(f"🔨 Building index: {index_path}")
|
||||
builder.build_index(str(index_path))
|
||||
|
||||
build_time = time.time() - start_time
|
||||
|
||||
print("✅ Index built successfully!")
|
||||
print(f" 📁 Index path: {index_path}")
|
||||
print(f" 📊 Total chunks: {total_chunks:,}")
|
||||
print(f" 📄 Processed PDFs: {len(pdf_files) - len(failed_pdfs)}/{len(pdf_files)}")
|
||||
print(f" ⏱️ Build time: {build_time:.1f}s")
|
||||
|
||||
if failed_pdfs:
|
||||
print(f" ⚠️ Failed PDFs: {failed_pdfs}")
|
||||
|
||||
return str(index_path)
|
||||
|
||||
def build_faiss_flat_baseline(self, index_path: str, output_dir: str = "baseline"):
|
||||
"""Build FAISS flat baseline using the same embeddings as LEANN index"""
|
||||
print("🔨 Building FAISS Flat baseline...")
|
||||
|
||||
import os
|
||||
import pickle
|
||||
|
||||
import numpy as np
|
||||
from leann.api import compute_embeddings
|
||||
from leann_backend_hnsw import faiss
|
||||
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
baseline_path = os.path.join(output_dir, "faiss_flat.index")
|
||||
metadata_path = os.path.join(output_dir, "metadata.pkl")
|
||||
|
||||
if os.path.exists(baseline_path) and os.path.exists(metadata_path):
|
||||
print(f"✅ Baseline already exists at {baseline_path}")
|
||||
return baseline_path
|
||||
|
||||
# Read metadata from the built index
|
||||
meta_path = f"{index_path}.meta.json"
|
||||
with open(meta_path) as f:
|
||||
import json
|
||||
|
||||
meta = json.loads(f.read())
|
||||
|
||||
embedding_model = meta["embedding_model"]
|
||||
passage_source = meta["passage_sources"][0]
|
||||
passage_file = passage_source["path"]
|
||||
|
||||
# Convert relative path to absolute
|
||||
if not os.path.isabs(passage_file):
|
||||
index_dir = os.path.dirname(index_path)
|
||||
passage_file = os.path.join(index_dir, os.path.basename(passage_file))
|
||||
|
||||
print(f"📊 Loading passages from {passage_file}...")
|
||||
print(f"🤖 Using embedding model: {embedding_model}")
|
||||
|
||||
# Load all passages for baseline
|
||||
passages = []
|
||||
passage_ids = []
|
||||
with open(passage_file, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
data = json.loads(line)
|
||||
passages.append(data["text"])
|
||||
passage_ids.append(data["id"])
|
||||
|
||||
print(f"📄 Loaded {len(passages)} passages")
|
||||
|
||||
# Compute embeddings using the same method as LEANN
|
||||
print("🧮 Computing embeddings...")
|
||||
embeddings = compute_embeddings(
|
||||
passages,
|
||||
embedding_model,
|
||||
mode="sentence-transformers",
|
||||
use_server=False,
|
||||
)
|
||||
|
||||
print(f"📐 Embedding shape: {embeddings.shape}")
|
||||
|
||||
# Build FAISS flat index
|
||||
print("🏗️ Building FAISS IndexFlatIP...")
|
||||
dimension = embeddings.shape[1]
|
||||
index = faiss.IndexFlatIP(dimension)
|
||||
|
||||
# Add embeddings to flat index
|
||||
embeddings_f32 = embeddings.astype(np.float32)
|
||||
index.add(embeddings_f32.shape[0], faiss.swig_ptr(embeddings_f32))
|
||||
|
||||
# Save index and metadata
|
||||
faiss.write_index(index, baseline_path)
|
||||
with open(metadata_path, "wb") as f:
|
||||
pickle.dump(passage_ids, f)
|
||||
|
||||
print(f"✅ FAISS baseline saved to {baseline_path}")
|
||||
print(f"✅ Metadata saved to {metadata_path}")
|
||||
print(f"📊 Total vectors: {index.ntotal}")
|
||||
|
||||
return baseline_path
|
||||
|
||||
def extract_pdf_text(self, pdf_path: Path) -> list[dict]:
|
||||
"""Extract and chunk text from a PDF file"""
|
||||
chunks = []
|
||||
doc = pymupdf.open(pdf_path)
|
||||
|
||||
for page_num in range(len(doc)):
|
||||
page = doc[page_num]
|
||||
text = page.get_text() # type: ignore
|
||||
|
||||
if not text.strip():
|
||||
continue
|
||||
|
||||
# Create metadata
|
||||
metadata = {
|
||||
"source_file": pdf_path.name,
|
||||
"page_number": page_num + 1,
|
||||
"document_type": "10K" if "10K" in pdf_path.name else "10Q",
|
||||
"company": pdf_path.name.split("_")[0],
|
||||
"doc_period": self.extract_year_from_filename(pdf_path.name),
|
||||
}
|
||||
|
||||
# Use recursive character splitting like LangChain
|
||||
if len(text.split()) > 500:
|
||||
# Split by double newlines (paragraphs)
|
||||
paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
|
||||
|
||||
current_chunk = ""
|
||||
for para in paragraphs:
|
||||
# If adding this paragraph would make chunk too long, save current chunk
|
||||
if current_chunk and len((current_chunk + " " + para).split()) > 300:
|
||||
if current_chunk.strip():
|
||||
chunks.append(
|
||||
{
|
||||
"text": current_chunk.strip(),
|
||||
"metadata": {
|
||||
**metadata,
|
||||
"chunk_id": f"page_{page_num + 1}_chunk_{len(chunks)}",
|
||||
},
|
||||
}
|
||||
)
|
||||
current_chunk = para
|
||||
else:
|
||||
current_chunk = (current_chunk + " " + para).strip()
|
||||
|
||||
# Add the last chunk
|
||||
if current_chunk.strip():
|
||||
chunks.append(
|
||||
{
|
||||
"text": current_chunk.strip(),
|
||||
"metadata": {
|
||||
**metadata,
|
||||
"chunk_id": f"page_{page_num + 1}_chunk_{len(chunks)}",
|
||||
},
|
||||
}
|
||||
)
|
||||
else:
|
||||
# Page is short enough, use as single chunk
|
||||
chunks.append(
|
||||
{
|
||||
"text": text.strip(),
|
||||
"metadata": {**metadata, "chunk_id": f"page_{page_num + 1}"},
|
||||
}
|
||||
)
|
||||
|
||||
doc.close()
|
||||
return chunks
|
||||
|
||||
def extract_year_from_filename(self, filename: str) -> str:
|
||||
"""Extract year from PDF filename"""
|
||||
# Try to find 4-digit year in filename
|
||||
|
||||
match = re.search(r"(\d{4})", filename)
|
||||
return match.group(1) if match else "unknown"
|
||||
|
||||
def verify_setup(self, index_path: str):
|
||||
"""Verify the setup by testing a simple query"""
|
||||
print("🧪 Verifying setup with test query...")
|
||||
|
||||
try:
|
||||
searcher = LeannSearcher(index_path)
|
||||
|
||||
# Test query
|
||||
test_query = "What is the capital expenditure for 3M in 2018?"
|
||||
results = searcher.search(test_query, top_k=3)
|
||||
|
||||
print(f"✅ Test query successful! Found {len(results)} results:")
|
||||
for i, result in enumerate(results, 1):
|
||||
company = result.metadata.get("company", "Unknown")
|
||||
year = result.metadata.get("doc_period", "Unknown")
|
||||
page = result.metadata.get("page_number", "Unknown")
|
||||
print(f" {i}. {company} {year} (page {page}) - Score: {result.score:.3f}")
|
||||
print(f" {result.text[:100]}...")
|
||||
|
||||
searcher.cleanup()
|
||||
print("✅ Setup verification completed successfully!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Setup verification failed: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Setup FinanceBench with full PDF datastore")
|
||||
parser.add_argument("--data-dir", default="data", help="Data directory")
|
||||
parser.add_argument(
|
||||
"--backend", choices=["hnsw", "diskann"], default="hnsw", help="LEANN backend"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--embedding-model",
|
||||
default="sentence-transformers/all-mpnet-base-v2",
|
||||
help="Embedding model",
|
||||
)
|
||||
parser.add_argument("--max-workers", type=int, default=5, help="Parallel download workers")
|
||||
parser.add_argument("--skip-download", action="store_true", help="Skip PDF download")
|
||||
parser.add_argument("--skip-build", action="store_true", help="Skip index building")
|
||||
parser.add_argument(
|
||||
"--build-baseline-only",
|
||||
action="store_true",
|
||||
help="Only build FAISS baseline from existing index",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print("🏦 FinanceBench Complete Setup")
|
||||
print("=" * 50)
|
||||
|
||||
setup = FinanceBenchSetup(args.data_dir)
|
||||
|
||||
try:
|
||||
if args.build_baseline_only:
|
||||
# Only build baseline from existing index
|
||||
index_path = setup.index_dir / f"financebench_full_{args.backend}"
|
||||
index_file = f"{index_path}.index"
|
||||
meta_file = f"{index_path}.leann.meta.json"
|
||||
|
||||
if not os.path.exists(index_file) or not os.path.exists(meta_file):
|
||||
print("❌ Index files not found:")
|
||||
print(f" Index: {index_file}")
|
||||
print(f" Meta: {meta_file}")
|
||||
print("💡 Run without --build-baseline-only to build the index first")
|
||||
exit(1)
|
||||
|
||||
print(f"🔨 Building baseline from existing index: {index_path}")
|
||||
baseline_path = setup.build_faiss_flat_baseline(str(index_path))
|
||||
print(f"✅ Baseline built at {baseline_path}")
|
||||
return
|
||||
|
||||
# Step 1: Download dataset
|
||||
setup.download_dataset()
|
||||
|
||||
# Step 2: Download PDFs
|
||||
if not args.skip_download:
|
||||
setup.download_all_pdfs(max_workers=args.max_workers)
|
||||
else:
|
||||
print("⏭️ Skipping PDF download")
|
||||
|
||||
# Step 3: Build LEANN index
|
||||
if not args.skip_build:
|
||||
index_path = setup.build_leann_index(
|
||||
backend=args.backend, embedding_model=args.embedding_model
|
||||
)
|
||||
|
||||
# Step 4: Build FAISS flat baseline
|
||||
print("\n🔨 Building FAISS flat baseline...")
|
||||
baseline_path = setup.build_faiss_flat_baseline(index_path)
|
||||
print(f"✅ Baseline built at {baseline_path}")
|
||||
|
||||
# Step 5: Verify setup
|
||||
setup.verify_setup(index_path)
|
||||
else:
|
||||
print("⏭️ Skipping index building")
|
||||
|
||||
print("\n🎉 FinanceBench setup completed!")
|
||||
print(f"📁 Data directory: {setup.data_dir.absolute()}")
|
||||
print("\nNext steps:")
|
||||
print(
|
||||
"1. Run evaluation: python evaluate_financebench.py --index data/index/financebench_full_hnsw.leann"
|
||||
)
|
||||
print(
|
||||
"2. Or test manually: python -c \"from leann import LeannSearcher; s = LeannSearcher('data/index/financebench_full_hnsw.leann'); print(s.search('3M capital expenditure 2018'))\""
|
||||
)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n⚠️ Setup interrupted by user")
|
||||
exit(1)
|
||||
except Exception as e:
|
||||
print(f"\n❌ Setup failed: {e}")
|
||||
exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
214
benchmarks/financebench/verify_recall.py
Normal file
@@ -0,0 +1,214 @@
|
||||
#!/usr/bin/env python3
|
||||
# /// script
|
||||
# requires-python = ">=3.9"
|
||||
# dependencies = [
|
||||
# "faiss-cpu",
|
||||
# "numpy",
|
||||
# "sentence-transformers",
|
||||
# "torch",
|
||||
# "tqdm",
|
||||
# ]
|
||||
# ///
|
||||
|
||||
"""
|
||||
Independent recall verification script using standard FAISS.
|
||||
Creates two indexes (HNSW and Flat) and compares recall@3 at different complexities.
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import faiss
|
||||
import numpy as np
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def compute_embeddings_direct(chunks: list[str], model_name: str) -> np.ndarray:
|
||||
"""
|
||||
Direct embedding computation using sentence-transformers.
|
||||
Copied logic to avoid dependency issues.
|
||||
"""
|
||||
print(f"Loading model: {model_name}")
|
||||
model = SentenceTransformer(model_name)
|
||||
|
||||
print(f"Computing embeddings for {len(chunks)} chunks...")
|
||||
embeddings = model.encode(
|
||||
chunks,
|
||||
show_progress_bar=True,
|
||||
batch_size=32,
|
||||
convert_to_numpy=True,
|
||||
normalize_embeddings=False,
|
||||
)
|
||||
|
||||
return embeddings.astype(np.float32)
|
||||
|
||||
|
||||
def load_financebench_queries(dataset_path: str, max_queries: int = 200) -> list[str]:
|
||||
"""Load FinanceBench queries from dataset"""
|
||||
queries = []
|
||||
with open(dataset_path, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
data = json.loads(line)
|
||||
queries.append(data["question"])
|
||||
if len(queries) >= max_queries:
|
||||
break
|
||||
return queries
|
||||
|
||||
|
||||
def load_passages_from_leann_index(index_path: str) -> tuple[list[str], list[str]]:
|
||||
"""Load passages from LEANN index structure"""
|
||||
meta_path = f"{index_path}.meta.json"
|
||||
with open(meta_path) as f:
|
||||
meta = json.load(f)
|
||||
|
||||
passage_source = meta["passage_sources"][0]
|
||||
passage_file = passage_source["path"]
|
||||
|
||||
# Convert relative path to absolute
|
||||
if not Path(passage_file).is_absolute():
|
||||
index_dir = Path(index_path).parent
|
||||
passage_file = index_dir / Path(passage_file).name
|
||||
|
||||
print(f"Loading passages from {passage_file}")
|
||||
|
||||
passages = []
|
||||
passage_ids = []
|
||||
with open(passage_file, encoding="utf-8") as f:
|
||||
for line in tqdm(f, desc="Loading passages"):
|
||||
if line.strip():
|
||||
data = json.loads(line)
|
||||
passages.append(data["text"])
|
||||
passage_ids.append(data["id"])
|
||||
|
||||
print(f"Loaded {len(passages)} passages")
|
||||
return passages, passage_ids
|
||||
|
||||
|
||||
def build_faiss_indexes(embeddings: np.ndarray) -> tuple[faiss.Index, faiss.Index]:
|
||||
"""Build FAISS indexes: Flat (ground truth) and HNSW"""
|
||||
dimension = embeddings.shape[1]
|
||||
|
||||
# Build Flat index (ground truth)
|
||||
print("Building FAISS IndexFlatIP (ground truth)...")
|
||||
flat_index = faiss.IndexFlatIP(dimension)
|
||||
flat_index.add(embeddings)
|
||||
|
||||
# Build HNSW index
|
||||
print("Building FAISS IndexHNSWFlat...")
|
||||
M = 32 # Same as LEANN default
|
||||
hnsw_index = faiss.IndexHNSWFlat(dimension, M, faiss.METRIC_INNER_PRODUCT)
|
||||
hnsw_index.hnsw.efConstruction = 200 # Same as LEANN default
|
||||
hnsw_index.add(embeddings)
|
||||
|
||||
print(f"Built indexes with {flat_index.ntotal} vectors, dimension {dimension}")
|
||||
return flat_index, hnsw_index
|
||||
|
||||
|
||||
def evaluate_recall_at_k(
|
||||
query_embeddings: np.ndarray,
|
||||
flat_index: faiss.Index,
|
||||
hnsw_index: faiss.Index,
|
||||
passage_ids: list[str],
|
||||
k: int = 3,
|
||||
ef_search: int = 64,
|
||||
) -> float:
|
||||
"""Evaluate recall@k comparing HNSW vs Flat"""
|
||||
|
||||
# Set search parameters for HNSW
|
||||
hnsw_index.hnsw.efSearch = ef_search
|
||||
|
||||
total_recall = 0.0
|
||||
num_queries = query_embeddings.shape[0]
|
||||
|
||||
for i in range(num_queries):
|
||||
query = query_embeddings[i : i + 1] # Keep 2D shape
|
||||
|
||||
# Get ground truth from Flat index (standard FAISS API)
|
||||
flat_distances, flat_indices = flat_index.search(query, k)
|
||||
ground_truth_ids = {passage_ids[idx] for idx in flat_indices[0]}
|
||||
|
||||
# Get results from HNSW index (standard FAISS API)
|
||||
hnsw_distances, hnsw_indices = hnsw_index.search(query, k)
|
||||
hnsw_ids = {passage_ids[idx] for idx in hnsw_indices[0]}
|
||||
|
||||
# Calculate recall
|
||||
intersection = ground_truth_ids.intersection(hnsw_ids)
|
||||
recall = len(intersection) / k
|
||||
total_recall += recall
|
||||
|
||||
if i < 3: # Show first few examples
|
||||
print(f" Query {i + 1}: Recall@{k} = {recall:.3f}")
|
||||
print(f" Flat: {list(ground_truth_ids)}")
|
||||
print(f" HNSW: {list(hnsw_ids)}")
|
||||
print(f" Intersection: {list(intersection)}")
|
||||
|
||||
avg_recall = total_recall / num_queries
|
||||
return avg_recall
|
||||
|
||||
|
||||
def main():
|
||||
# Configuration
|
||||
dataset_path = "data/financebench_merged.jsonl"
|
||||
index_path = "data/index/financebench_full_hnsw.leann"
|
||||
embedding_model = "sentence-transformers/all-mpnet-base-v2"
|
||||
|
||||
print("🔍 FAISS Recall Verification")
|
||||
print("=" * 50)
|
||||
|
||||
# Check if files exist
|
||||
if not Path(dataset_path).exists():
|
||||
print(f"❌ Dataset not found: {dataset_path}")
|
||||
return
|
||||
if not Path(f"{index_path}.meta.json").exists():
|
||||
print(f"❌ Index metadata not found: {index_path}.meta.json")
|
||||
return
|
||||
|
||||
# Load data
|
||||
print("📖 Loading FinanceBench queries...")
|
||||
queries = load_financebench_queries(dataset_path, max_queries=50)
|
||||
print(f"Loaded {len(queries)} queries")
|
||||
|
||||
print("📄 Loading passages from LEANN index...")
|
||||
passages, passage_ids = load_passages_from_leann_index(index_path)
|
||||
|
||||
# Compute embeddings
|
||||
print("🧮 Computing passage embeddings...")
|
||||
passage_embeddings = compute_embeddings_direct(passages, embedding_model)
|
||||
|
||||
print("🧮 Computing query embeddings...")
|
||||
query_embeddings = compute_embeddings_direct(queries, embedding_model)
|
||||
|
||||
# Build FAISS indexes
|
||||
print("🏗️ Building FAISS indexes...")
|
||||
flat_index, hnsw_index = build_faiss_indexes(passage_embeddings)
|
||||
|
||||
# Test different efSearch values (equivalent to LEANN complexity)
|
||||
print("\n📊 Evaluating Recall@3 at different efSearch values...")
|
||||
ef_search_values = [16, 32, 64, 128, 256]
|
||||
|
||||
for ef_search in ef_search_values:
|
||||
print(f"\n🧪 Testing efSearch = {ef_search}")
|
||||
start_time = time.time()
|
||||
|
||||
recall = evaluate_recall_at_k(
|
||||
query_embeddings, flat_index, hnsw_index, passage_ids, k=3, ef_search=ef_search
|
||||
)
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
print(
|
||||
f"📈 efSearch {ef_search}: Recall@3 = {recall:.3f} ({recall * 100:.1f}%) in {elapsed:.2f}s"
|
||||
)
|
||||
|
||||
print("\n✅ Verification completed!")
|
||||
print("\n📋 Summary:")
|
||||
print(" - Built independent FAISS Flat and HNSW indexes")
|
||||
print(" - Compared recall@3 at different efSearch values")
|
||||
print(" - Used same embedding model as LEANN")
|
||||
print(" - This validates LEANN's recall measurements")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
1
benchmarks/laion/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
data/
|
||||
199
benchmarks/laion/README.md
Normal file
@@ -0,0 +1,199 @@
|
||||
# LAION Multimodal Benchmark
|
||||
|
||||
A multimodal benchmark for evaluating image retrieval and generation performance using LEANN with CLIP embeddings and Qwen2.5-VL for multimodal generation on LAION dataset subset.
|
||||
|
||||
## Overview
|
||||
|
||||
This benchmark evaluates:
|
||||
- **Image retrieval timing** using caption-based queries
|
||||
- **Recall@K performance** for image search
|
||||
- **Complexity analysis** across different search parameters
|
||||
- **Index size and storage efficiency**
|
||||
- **Multimodal generation** with Qwen2.5-VL for image understanding and description
|
||||
|
||||
## Dataset Configuration
|
||||
|
||||
- **Dataset**: LAION-400M subset (10,000 images)
|
||||
- **Embeddings**: Pre-computed CLIP ViT-B/32 (512 dimensions)
|
||||
- **Queries**: 200 random captions from the dataset
|
||||
- **Ground Truth**: Self-recall (query caption → original image)
|
||||
|
||||
## Quick Start
|
||||
|
||||
### 1. Setup the benchmark
|
||||
|
||||
```bash
|
||||
cd benchmarks/laion
|
||||
python setup_laion.py --num-samples 10000 --num-queries 200
|
||||
```
|
||||
|
||||
This will:
|
||||
- Create dummy LAION data (10K samples)
|
||||
- Generate CLIP embeddings (512-dim)
|
||||
- Build LEANN index with HNSW backend
|
||||
- Create 200 evaluation queries
|
||||
|
||||
### 2. Run evaluation
|
||||
|
||||
```bash
|
||||
# Run all evaluation stages
|
||||
python evaluate_laion.py --index data/laion_index.leann
|
||||
|
||||
# Run specific stages
|
||||
python evaluate_laion.py --index data/laion_index.leann --stage 2 # Recall evaluation
|
||||
python evaluate_laion.py --index data/laion_index.leann --stage 3 # Complexity analysis
|
||||
python evaluate_laion.py --index data/laion_index.leann --stage 4 # Index comparison
|
||||
python evaluate_laion.py --index data/laion_index.leann --stage 5 # Multimodal generation
|
||||
|
||||
# Multimodal generation with Qwen2.5-VL
|
||||
python evaluate_laion.py --index data/laion_index.leann --stage 5 --model-name Qwen/Qwen2.5-VL-7B-Instruct
|
||||
```
|
||||
|
||||
### 3. Save results
|
||||
|
||||
```bash
|
||||
python evaluate_laion.py --index data/laion_index.leann --output results.json
|
||||
```
|
||||
|
||||
## Configuration Options
|
||||
|
||||
### Setup Options
|
||||
```bash
|
||||
python setup_laion.py \
|
||||
--num-samples 10000 \
|
||||
--num-queries 200 \
|
||||
--index-path data/laion_index.leann \
|
||||
--backend hnsw
|
||||
```
|
||||
|
||||
### Evaluation Options
|
||||
```bash
|
||||
python evaluate_laion.py \
|
||||
--index data/laion_index.leann \
|
||||
--queries data/evaluation_queries.jsonl \
|
||||
--complexity 64 \
|
||||
--top-k 3 \
|
||||
--num-samples 100 \
|
||||
--stage all
|
||||
```
|
||||
|
||||
## Evaluation Stages
|
||||
|
||||
### Stage 2: Recall Evaluation
|
||||
- Evaluates Recall@3 for multimodal retrieval
|
||||
- Compares LEANN vs FAISS baseline performance
|
||||
- Self-recall: query caption should retrieve original image
|
||||
|
||||
### Stage 3: Complexity Analysis
|
||||
- Binary search for optimal complexity (90% recall target)
|
||||
- Tests performance across different complexity levels
|
||||
- Analyzes speed vs. accuracy tradeoffs
|
||||
|
||||
### Stage 4: Index Comparison
|
||||
- Compares compact vs non-compact index sizes
|
||||
- Measures search performance differences
|
||||
- Reports storage efficiency and speed ratios
|
||||
|
||||
### Stage 5: Multimodal Generation
|
||||
- Uses Qwen2.5-VL for image understanding and description
|
||||
- Retrieval-Augmented Generation (RAG) with multimodal context
|
||||
- Measures both search and generation timing
|
||||
|
||||
## Output Metrics
|
||||
|
||||
### Timing Metrics
|
||||
- Average/median/min/max search time
|
||||
- Standard deviation
|
||||
- Searches per second
|
||||
- Latency in milliseconds
|
||||
|
||||
### Recall Metrics
|
||||
- Recall@3 percentage for image retrieval
|
||||
- Number of queries with ground truth
|
||||
|
||||
### Index Metrics
|
||||
- Total index size (MB)
|
||||
- Component breakdown (index, passages, metadata)
|
||||
- Storage savings (compact vs non-compact)
|
||||
- Backend and embedding model info
|
||||
|
||||
### Generation Metrics (Stage 5)
|
||||
- Average search time per query
|
||||
- Average generation time per query
|
||||
- Time distribution (search vs generation)
|
||||
- Sample multimodal responses
|
||||
- Model: Qwen2.5-VL performance
|
||||
|
||||
## Benchmark Results
|
||||
|
||||
### LEANN-RAG Performance (CLIP ViT-L/14 + Qwen2.5-VL)
|
||||
|
||||
**Stage 3: Optimal Complexity Analysis**
|
||||
- **Optimal Complexity**: 85 (achieving 90% Recall@3)
|
||||
- **Binary Search Range**: 1-128
|
||||
- **Target Recall**: 90%
|
||||
- **Index Type**: Non-compact (for fast binary search)
|
||||
|
||||
**Stage 5: Multimodal Generation Performance (Qwen2.5-VL)**
|
||||
- **Total Queries**: 20
|
||||
- **Average Search Time**: 1.200s per query
|
||||
- **Average Generation Time**: 6.558s per query
|
||||
- **Time Distribution**: Search 15.5%, Generation 84.5%
|
||||
- **LLM Backend**: HuggingFace transformers
|
||||
- **Model**: Qwen/Qwen2.5-VL-7B-Instruct
|
||||
- **Optimal Complexity**: 85
|
||||
|
||||
**System Performance:**
|
||||
- **Index Size**: ~10,000 image embeddings from LAION subset
|
||||
- **Embedding Model**: CLIP ViT-L/14 (768 dimensions)
|
||||
- **Backend**: HNSW with cosine distance
|
||||
|
||||
### Example Results
|
||||
|
||||
```
|
||||
🎯 LAION MULTIMODAL BENCHMARK RESULTS
|
||||
============================================================
|
||||
|
||||
📊 Multimodal Generation Results:
|
||||
Total Queries: 20
|
||||
Avg Search Time: 1.200s
|
||||
Avg Generation Time: 6.558s
|
||||
Time Distribution: Search 15.5%, Generation 84.5%
|
||||
LLM Backend: HuggingFace transformers
|
||||
Model: Qwen/Qwen2.5-VL-7B-Instruct
|
||||
|
||||
⚙️ Optimal Complexity Analysis:
|
||||
Target Recall: 90%
|
||||
Optimal Complexity: 85
|
||||
Binary Search Range: 1-128
|
||||
Non-compact Index (fast search, no recompute)
|
||||
|
||||
🚀 Performance Summary:
|
||||
Multimodal RAG: 7.758s total per query
|
||||
Search: 15.5% of total time
|
||||
Generation: 84.5% of total time
|
||||
```
|
||||
|
||||
## Directory Structure
|
||||
|
||||
```
|
||||
benchmarks/laion/
|
||||
├── setup_laion.py # Setup script
|
||||
├── evaluate_laion.py # Evaluation script
|
||||
├── README.md # This file
|
||||
└── data/ # Generated data
|
||||
├── laion_images/ # Image files (placeholder)
|
||||
├── laion_metadata.jsonl # Image metadata
|
||||
├── laion_passages.jsonl # LEANN passages
|
||||
├── laion_embeddings.npy # CLIP embeddings
|
||||
├── evaluation_queries.jsonl # Evaluation queries
|
||||
└── laion_index.leann/ # LEANN index files
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
- Current implementation uses dummy data for demonstration
|
||||
- For real LAION data, implement actual download logic in `setup_laion.py`
|
||||
- CLIP embeddings are randomly generated - replace with real CLIP model for production
|
||||
- Adjust `num_samples` and `num_queries` based on available resources
|
||||
- Consider using `--num-samples` during evaluation for faster testing
|
||||
725
benchmarks/laion/evaluate_laion.py
Normal file
@@ -0,0 +1,725 @@
|
||||
"""
|
||||
LAION Multimodal Benchmark Evaluation Script - Modular Recall-based Evaluation
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from leann import LeannSearcher
|
||||
from leann_backend_hnsw import faiss
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
from ..llm_utils import evaluate_multimodal_rag, load_qwen_vl_model
|
||||
|
||||
# Setup logging to reduce verbose output
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
logging.getLogger("leann.api").setLevel(logging.WARNING)
|
||||
logging.getLogger("leann_backend_hnsw").setLevel(logging.WARNING)
|
||||
|
||||
|
||||
class RecallEvaluator:
|
||||
"""Stage 2: Evaluate Recall@3 (LEANN vs FAISS baseline for multimodal retrieval)"""
|
||||
|
||||
def __init__(self, index_path: str, baseline_dir: str):
|
||||
self.index_path = index_path
|
||||
self.baseline_dir = baseline_dir
|
||||
self.searcher = LeannSearcher(index_path)
|
||||
|
||||
# Load FAISS flat baseline (image embeddings)
|
||||
baseline_index_path = os.path.join(baseline_dir, "faiss_flat.index")
|
||||
metadata_path = os.path.join(baseline_dir, "metadata.pkl")
|
||||
|
||||
self.faiss_index = faiss.read_index(baseline_index_path)
|
||||
with open(metadata_path, "rb") as f:
|
||||
self.image_ids = pickle.load(f)
|
||||
print(f"📚 Loaded FAISS flat baseline with {self.faiss_index.ntotal} image vectors")
|
||||
|
||||
# Load sentence-transformers CLIP for text embedding (ViT-L/14)
|
||||
self.st_clip = SentenceTransformer("clip-ViT-L-14")
|
||||
|
||||
def evaluate_recall_at_3(
|
||||
self, captions: list[str], complexity: int = 64, recompute_embeddings: bool = True
|
||||
) -> float:
|
||||
"""Evaluate recall@3 for multimodal retrieval: caption queries -> image results"""
|
||||
recompute_str = "with recompute" if recompute_embeddings else "no recompute"
|
||||
print(f"🔍 Evaluating recall@3 with complexity={complexity} ({recompute_str})...")
|
||||
|
||||
total_recall = 0.0
|
||||
num_queries = len(captions)
|
||||
|
||||
for i, caption in enumerate(captions):
|
||||
# Get ground truth: search with FAISS flat using caption text embedding
|
||||
# Generate CLIP text embedding for caption via sentence-transformers (normalized)
|
||||
query_embedding = self.st_clip.encode(
|
||||
[caption], convert_to_numpy=True, normalize_embeddings=True, show_progress_bar=False
|
||||
).astype(np.float32)
|
||||
|
||||
# Search FAISS flat for ground truth using LEANN's modified faiss API
|
||||
n = query_embedding.shape[0] # Number of queries
|
||||
k = 3 # Number of nearest neighbors
|
||||
distances = np.zeros((n, k), dtype=np.float32)
|
||||
labels = np.zeros((n, k), dtype=np.int64)
|
||||
|
||||
self.faiss_index.search(
|
||||
n,
|
||||
faiss.swig_ptr(query_embedding),
|
||||
k,
|
||||
faiss.swig_ptr(distances),
|
||||
faiss.swig_ptr(labels),
|
||||
)
|
||||
|
||||
# Extract the results (image IDs from FAISS)
|
||||
baseline_ids = {self.image_ids[idx] for idx in labels[0]}
|
||||
|
||||
# Search with LEANN at specified complexity (using caption as text query)
|
||||
test_results = self.searcher.search(
|
||||
caption,
|
||||
top_k=3,
|
||||
complexity=complexity,
|
||||
recompute_embeddings=recompute_embeddings,
|
||||
)
|
||||
test_ids = {result.id for result in test_results}
|
||||
|
||||
# Calculate recall@3 = |intersection| / |ground_truth|
|
||||
intersection = test_ids.intersection(baseline_ids)
|
||||
recall = len(intersection) / 3.0 # Ground truth size is 3
|
||||
total_recall += recall
|
||||
|
||||
if i < 3: # Show first few examples
|
||||
print(f" Query {i + 1}: '{caption[:50]}...' -> Recall@3: {recall:.3f}")
|
||||
print(f" FAISS ground truth: {list(baseline_ids)}")
|
||||
print(f" LEANN results (C={complexity}, {recompute_str}): {list(test_ids)}")
|
||||
print(f" Intersection: {list(intersection)}")
|
||||
|
||||
avg_recall = total_recall / num_queries
|
||||
print(f"📊 Average Recall@3: {avg_recall:.3f} ({avg_recall * 100:.1f}%)")
|
||||
return avg_recall
|
||||
|
||||
def cleanup(self):
|
||||
"""Cleanup resources"""
|
||||
if hasattr(self, "searcher"):
|
||||
self.searcher.cleanup()
|
||||
|
||||
|
||||
class LAIONEvaluator:
|
||||
def __init__(self, index_path: str):
|
||||
self.index_path = index_path
|
||||
self.searcher = LeannSearcher(index_path)
|
||||
|
||||
def load_queries(self, queries_file: str) -> list[str]:
|
||||
"""Load caption queries from evaluation file"""
|
||||
captions = []
|
||||
with open(queries_file, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
query_data = json.loads(line)
|
||||
captions.append(query_data["query"])
|
||||
|
||||
print(f"📊 Loaded {len(captions)} caption queries")
|
||||
return captions
|
||||
|
||||
def analyze_index_sizes(self) -> dict:
|
||||
"""Analyze index sizes, emphasizing .index only (exclude passages)."""
|
||||
print("📏 Analyzing index sizes (.index only)...")
|
||||
|
||||
# Get all index-related files
|
||||
index_path = Path(self.index_path)
|
||||
index_dir = index_path.parent
|
||||
index_name = index_path.stem # Remove .leann extension
|
||||
|
||||
sizes: dict[str, float] = {}
|
||||
|
||||
# Core index files
|
||||
index_file = index_dir / f"{index_name}.index"
|
||||
meta_file = index_dir / f"{index_path.name}.meta.json" # Keep .leann for meta file
|
||||
passages_file = index_dir / f"{index_path.name}.passages.jsonl" # Keep .leann for passages
|
||||
passages_idx_file = index_dir / f"{index_path.name}.passages.idx" # Keep .leann for idx
|
||||
|
||||
# Core index size (.index only)
|
||||
index_mb = index_file.stat().st_size / (1024 * 1024) if index_file.exists() else 0.0
|
||||
sizes["index_only_mb"] = index_mb
|
||||
|
||||
# Other files for reference (not counted in index_only_mb)
|
||||
sizes["metadata_mb"] = (
|
||||
meta_file.stat().st_size / (1024 * 1024) if meta_file.exists() else 0.0
|
||||
)
|
||||
sizes["passages_text_mb"] = (
|
||||
passages_file.stat().st_size / (1024 * 1024) if passages_file.exists() else 0.0
|
||||
)
|
||||
sizes["passages_index_mb"] = (
|
||||
passages_idx_file.stat().st_size / (1024 * 1024) if passages_idx_file.exists() else 0.0
|
||||
)
|
||||
|
||||
print(f" 📁 .index size: {index_mb:.1f} MB")
|
||||
if sizes["metadata_mb"]:
|
||||
print(f" 🧾 metadata: {sizes['metadata_mb']:.3f} MB")
|
||||
if sizes["passages_text_mb"] or sizes["passages_index_mb"]:
|
||||
print(
|
||||
f" (passages excluded) text: {sizes['passages_text_mb']:.1f} MB, idx: {sizes['passages_index_mb']:.1f} MB"
|
||||
)
|
||||
|
||||
return sizes
|
||||
|
||||
def create_non_compact_index_for_comparison(self, non_compact_index_path: str) -> dict:
|
||||
"""Create a non-compact index for comparison purposes"""
|
||||
print("🏗️ Building non-compact index from existing passages...")
|
||||
|
||||
# Load existing passages from current index
|
||||
from leann import LeannBuilder
|
||||
|
||||
current_index_path = Path(self.index_path)
|
||||
current_index_dir = current_index_path.parent
|
||||
current_index_name = current_index_path.name
|
||||
|
||||
# Read metadata to get passage source
|
||||
meta_path = current_index_dir / f"{current_index_name}.meta.json"
|
||||
with open(meta_path) as f:
|
||||
meta = json.load(f)
|
||||
|
||||
passage_source = meta["passage_sources"][0]
|
||||
passage_file = passage_source["path"]
|
||||
|
||||
# Convert relative path to absolute
|
||||
if not Path(passage_file).is_absolute():
|
||||
passage_file = current_index_dir / Path(passage_file).name
|
||||
|
||||
print(f"📄 Loading passages from {passage_file}...")
|
||||
|
||||
# Load CLIP embeddings
|
||||
embeddings_file = current_index_dir / "clip_image_embeddings.npy"
|
||||
embeddings = np.load(embeddings_file)
|
||||
print(f"📐 Loaded embeddings shape: {embeddings.shape}")
|
||||
|
||||
# Build non-compact index with same passages and embeddings
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
# Use CLIP text encoder (ViT-L/14) to match image embeddings (768-dim)
|
||||
embedding_model="clip-ViT-L-14",
|
||||
embedding_mode="sentence-transformers",
|
||||
is_recompute=False, # Disable recompute (store embeddings)
|
||||
is_compact=False, # Disable compact storage
|
||||
distance_metric="cosine",
|
||||
**{
|
||||
k: v
|
||||
for k, v in meta.get("backend_kwargs", {}).items()
|
||||
if k not in ["is_recompute", "is_compact", "distance_metric"]
|
||||
},
|
||||
)
|
||||
|
||||
# Prepare ids and add passages
|
||||
ids: list[str] = []
|
||||
with open(passage_file, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
data = json.loads(line)
|
||||
ids.append(str(data["id"]))
|
||||
# Ensure metadata contains the id used by the vector index
|
||||
metadata = {**data.get("metadata", {}), "id": data["id"]}
|
||||
builder.add_text(text=data["text"], metadata=metadata)
|
||||
|
||||
if len(ids) != embeddings.shape[0]:
|
||||
raise ValueError(
|
||||
f"IDs count ({len(ids)}) does not match embeddings ({embeddings.shape[0]})."
|
||||
)
|
||||
|
||||
# Persist a pickle for build_index_from_embeddings
|
||||
pkl_path = current_index_dir / "clip_image_embeddings.pkl"
|
||||
with open(pkl_path, "wb") as pf:
|
||||
pickle.dump((ids, embeddings.astype(np.float32)), pf)
|
||||
|
||||
print(
|
||||
f"🔨 Building non-compact index at {non_compact_index_path} from precomputed embeddings..."
|
||||
)
|
||||
builder.build_index_from_embeddings(non_compact_index_path, str(pkl_path))
|
||||
|
||||
# Analyze the non-compact index size
|
||||
temp_evaluator = LAIONEvaluator(non_compact_index_path)
|
||||
non_compact_sizes = temp_evaluator.analyze_index_sizes()
|
||||
non_compact_sizes["index_type"] = "non_compact"
|
||||
|
||||
return non_compact_sizes
|
||||
|
||||
def compare_index_performance(
|
||||
self, non_compact_path: str, compact_path: str, test_captions: list, complexity: int
|
||||
) -> dict:
|
||||
"""Compare performance between non-compact and compact indexes"""
|
||||
print("⚡ Comparing search performance between indexes...")
|
||||
|
||||
# Test queries
|
||||
test_queries = test_captions[:5]
|
||||
|
||||
results = {
|
||||
"non_compact": {"search_times": []},
|
||||
"compact": {"search_times": []},
|
||||
"avg_search_times": {},
|
||||
"speed_ratio": 0.0,
|
||||
}
|
||||
|
||||
# Test non-compact index (no recompute)
|
||||
print(" 🔍 Testing non-compact index (no recompute)...")
|
||||
non_compact_searcher = LeannSearcher(non_compact_path)
|
||||
|
||||
for caption in test_queries:
|
||||
start_time = time.time()
|
||||
_ = non_compact_searcher.search(
|
||||
caption, top_k=3, complexity=complexity, recompute_embeddings=False
|
||||
)
|
||||
search_time = time.time() - start_time
|
||||
results["non_compact"]["search_times"].append(search_time)
|
||||
|
||||
# Test compact index (with recompute)
|
||||
print(" 🔍 Testing compact index (with recompute)...")
|
||||
compact_searcher = LeannSearcher(compact_path)
|
||||
|
||||
for caption in test_queries:
|
||||
start_time = time.time()
|
||||
_ = compact_searcher.search(
|
||||
caption, top_k=3, complexity=complexity, recompute_embeddings=True
|
||||
)
|
||||
search_time = time.time() - start_time
|
||||
results["compact"]["search_times"].append(search_time)
|
||||
|
||||
# Calculate averages
|
||||
results["avg_search_times"]["non_compact"] = sum(
|
||||
results["non_compact"]["search_times"]
|
||||
) / len(results["non_compact"]["search_times"])
|
||||
results["avg_search_times"]["compact"] = sum(results["compact"]["search_times"]) / len(
|
||||
results["compact"]["search_times"]
|
||||
)
|
||||
|
||||
# Performance ratio
|
||||
if results["avg_search_times"]["compact"] > 0:
|
||||
results["speed_ratio"] = (
|
||||
results["avg_search_times"]["non_compact"] / results["avg_search_times"]["compact"]
|
||||
)
|
||||
else:
|
||||
results["speed_ratio"] = float("inf")
|
||||
|
||||
print(
|
||||
f" Non-compact (no recompute): {results['avg_search_times']['non_compact']:.3f}s avg"
|
||||
)
|
||||
print(f" Compact (with recompute): {results['avg_search_times']['compact']:.3f}s avg")
|
||||
print(f" Speed ratio: {results['speed_ratio']:.2f}x")
|
||||
|
||||
# Cleanup
|
||||
non_compact_searcher.cleanup()
|
||||
compact_searcher.cleanup()
|
||||
|
||||
return results
|
||||
|
||||
def _print_results(self, timing_metrics: dict):
|
||||
"""Print evaluation results"""
|
||||
print("\n🎯 LAION MULTIMODAL BENCHMARK RESULTS")
|
||||
print("=" * 60)
|
||||
|
||||
# Index comparison analysis (prefer .index-only view if present)
|
||||
if "current_index" in timing_metrics and "non_compact_index" in timing_metrics:
|
||||
current = timing_metrics["current_index"]
|
||||
non_compact = timing_metrics["non_compact_index"]
|
||||
|
||||
if "index_only_mb" in current and "index_only_mb" in non_compact:
|
||||
print("\n📏 Index Comparison Analysis (.index only):")
|
||||
print(f" Compact index (current): {current.get('index_only_mb', 0):.1f} MB")
|
||||
print(f" Non-compact index: {non_compact.get('index_only_mb', 0):.1f} MB")
|
||||
print(
|
||||
f" Storage saving by compact: {timing_metrics.get('storage_saving_percent', 0):.1f}%"
|
||||
)
|
||||
# Show excluded components for reference if available
|
||||
if any(
|
||||
k in non_compact
|
||||
for k in ("passages_text_mb", "passages_index_mb", "metadata_mb")
|
||||
):
|
||||
print(" (passages excluded in totals, shown for reference):")
|
||||
print(
|
||||
f" - Passages text: {non_compact.get('passages_text_mb', 0):.1f} MB, "
|
||||
f"Passages index: {non_compact.get('passages_index_mb', 0):.1f} MB, "
|
||||
f"Metadata: {non_compact.get('metadata_mb', 0):.3f} MB"
|
||||
)
|
||||
else:
|
||||
# Fallback to legacy totals if running with older metrics
|
||||
print("\n📏 Index Comparison Analysis:")
|
||||
print(
|
||||
f" Compact index (current): {current.get('total_with_embeddings', 0):.1f} MB"
|
||||
)
|
||||
print(
|
||||
f" Non-compact index (with embeddings): {non_compact.get('total_with_embeddings', 0):.1f} MB"
|
||||
)
|
||||
print(
|
||||
f" Storage saving by compact: {timing_metrics.get('storage_saving_percent', 0):.1f}%"
|
||||
)
|
||||
print(" Component breakdown (non-compact):")
|
||||
print(f" - Main index: {non_compact.get('index', 0):.1f} MB")
|
||||
print(f" - Passages text: {non_compact.get('passages_text', 0):.1f} MB")
|
||||
print(f" - Passages index: {non_compact.get('passages_index', 0):.1f} MB")
|
||||
print(f" - Metadata: {non_compact.get('metadata', 0):.1f} MB")
|
||||
|
||||
# Performance comparison
|
||||
if "performance_comparison" in timing_metrics:
|
||||
perf = timing_metrics["performance_comparison"]
|
||||
print("\n⚡ Performance Comparison:")
|
||||
print(
|
||||
f" Non-compact (no recompute): {perf.get('avg_search_times', {}).get('non_compact', 0):.3f}s avg"
|
||||
)
|
||||
print(
|
||||
f" Compact (with recompute): {perf.get('avg_search_times', {}).get('compact', 0):.3f}s avg"
|
||||
)
|
||||
print(f" Speed ratio: {perf.get('speed_ratio', 0):.2f}x")
|
||||
|
||||
# Legacy single index analysis (fallback)
|
||||
if "total_with_embeddings" in timing_metrics and "current_index" not in timing_metrics:
|
||||
print("\n📏 Index Size Analysis:")
|
||||
print(
|
||||
f" Index with embeddings: {timing_metrics.get('total_with_embeddings', 0):.1f} MB"
|
||||
)
|
||||
print(
|
||||
f" Estimated pruned index: {timing_metrics.get('total_without_embeddings', 0):.1f} MB"
|
||||
)
|
||||
print(f" Compression ratio: {timing_metrics.get('compression_ratio', 0):.2f}x")
|
||||
|
||||
def cleanup(self):
|
||||
"""Cleanup resources"""
|
||||
if self.searcher:
|
||||
self.searcher.cleanup()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="LAION Multimodal Benchmark Evaluation")
|
||||
parser.add_argument("--index", required=True, help="Path to LEANN index")
|
||||
parser.add_argument(
|
||||
"--queries", default="data/evaluation_queries.jsonl", help="Path to evaluation queries"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--stage",
|
||||
choices=["2", "3", "4", "5", "all"],
|
||||
default="all",
|
||||
help="Which stage to run (2=recall, 3=complexity, 4=index comparison, 5=generation)",
|
||||
)
|
||||
parser.add_argument("--complexity", type=int, default=None, help="Complexity for search")
|
||||
parser.add_argument("--baseline-dir", default="baseline", help="Baseline output directory")
|
||||
parser.add_argument("--output", help="Save results to JSON file")
|
||||
parser.add_argument(
|
||||
"--llm-backend",
|
||||
choices=["hf"],
|
||||
default="hf",
|
||||
help="LLM backend (Qwen2.5-VL only supports HF)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-name", default="Qwen/Qwen2.5-VL-7B-Instruct", help="Multimodal model name"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
# Check if baseline exists
|
||||
baseline_index_path = os.path.join(args.baseline_dir, "faiss_flat.index")
|
||||
if not os.path.exists(baseline_index_path):
|
||||
print(f"❌ FAISS baseline not found at {baseline_index_path}")
|
||||
print("💡 Please run setup_laion.py first to build the baseline")
|
||||
exit(1)
|
||||
|
||||
if args.stage == "2" or args.stage == "all":
|
||||
# Stage 2: Recall@3 evaluation
|
||||
print("🚀 Starting Stage 2: Recall@3 evaluation for multimodal retrieval")
|
||||
|
||||
evaluator = RecallEvaluator(args.index, args.baseline_dir)
|
||||
|
||||
# Load caption queries for testing
|
||||
laion_evaluator = LAIONEvaluator(args.index)
|
||||
captions = laion_evaluator.load_queries(args.queries)
|
||||
|
||||
# Test with queries for robust measurement
|
||||
test_captions = captions[:100] # Use subset for speed
|
||||
print(f"🧪 Testing with {len(test_captions)} caption queries")
|
||||
|
||||
# Test with complexity 64
|
||||
complexity = 64
|
||||
recall = evaluator.evaluate_recall_at_3(test_captions, complexity)
|
||||
print(f"📈 Recall@3 at complexity {complexity}: {recall * 100:.1f}%")
|
||||
|
||||
evaluator.cleanup()
|
||||
print("✅ Stage 2 completed!\n")
|
||||
|
||||
# Shared non-compact index path for Stage 3 and 4
|
||||
non_compact_index_path = args.index.replace(".leann", "_noncompact.leann")
|
||||
complexity = args.complexity
|
||||
|
||||
if args.stage == "3" or args.stage == "all":
|
||||
# Stage 3: Binary search for 90% recall complexity
|
||||
print("🚀 Starting Stage 3: Binary search for 90% recall complexity")
|
||||
print(
|
||||
"💡 Creating non-compact index for fast binary search with recompute_embeddings=False"
|
||||
)
|
||||
|
||||
# Create non-compact index for binary search
|
||||
print("🏗️ Creating non-compact index for binary search...")
|
||||
evaluator = LAIONEvaluator(args.index)
|
||||
evaluator.create_non_compact_index_for_comparison(non_compact_index_path)
|
||||
|
||||
# Use non-compact index for binary search
|
||||
binary_search_evaluator = RecallEvaluator(non_compact_index_path, args.baseline_dir)
|
||||
|
||||
# Load caption queries for testing
|
||||
captions = evaluator.load_queries(args.queries)
|
||||
|
||||
# Use subset for robust measurement
|
||||
test_captions = captions[:50] # Smaller subset for binary search speed
|
||||
print(f"🧪 Testing with {len(test_captions)} caption queries")
|
||||
|
||||
# Binary search for 90% recall complexity
|
||||
target_recall = 0.9
|
||||
min_complexity, max_complexity = 1, 128
|
||||
|
||||
print(f"🔍 Binary search for {target_recall * 100}% recall complexity...")
|
||||
print(f"Search range: {min_complexity} to {max_complexity}")
|
||||
|
||||
best_complexity = None
|
||||
best_recall = 0.0
|
||||
|
||||
while min_complexity <= max_complexity:
|
||||
mid_complexity = (min_complexity + max_complexity) // 2
|
||||
|
||||
print(
|
||||
f"\n🧪 Testing complexity {mid_complexity} (no recompute, non-compact index)..."
|
||||
)
|
||||
# Use recompute_embeddings=False on non-compact index for fast binary search
|
||||
recall = binary_search_evaluator.evaluate_recall_at_3(
|
||||
test_captions, mid_complexity, recompute_embeddings=False
|
||||
)
|
||||
|
||||
print(
|
||||
f" Complexity {mid_complexity}: Recall@3 = {recall:.3f} ({recall * 100:.1f}%)"
|
||||
)
|
||||
|
||||
if recall >= target_recall:
|
||||
best_complexity = mid_complexity
|
||||
best_recall = recall
|
||||
max_complexity = mid_complexity - 1
|
||||
print(" ✅ Target reached! Searching for lower complexity...")
|
||||
else:
|
||||
min_complexity = mid_complexity + 1
|
||||
print(" ❌ Below target. Searching for higher complexity...")
|
||||
|
||||
if best_complexity is not None:
|
||||
print("\n🎯 Optimal complexity found!")
|
||||
print(f" Complexity: {best_complexity}")
|
||||
print(f" Recall@3: {best_recall:.3f} ({best_recall * 100:.1f}%)")
|
||||
|
||||
# Test a few complexities around the optimal one for verification
|
||||
print("\n🔬 Verification test around optimal complexity:")
|
||||
verification_complexities = [
|
||||
max(1, best_complexity - 2),
|
||||
max(1, best_complexity - 1),
|
||||
best_complexity,
|
||||
best_complexity + 1,
|
||||
best_complexity + 2,
|
||||
]
|
||||
|
||||
for complexity in verification_complexities:
|
||||
if complexity <= 512: # reasonable upper bound
|
||||
recall = binary_search_evaluator.evaluate_recall_at_3(
|
||||
test_captions, complexity, recompute_embeddings=False
|
||||
)
|
||||
status = "✅" if recall >= target_recall else "❌"
|
||||
print(f" {status} Complexity {complexity:3d}: {recall * 100:5.1f}%")
|
||||
|
||||
# Now test the optimal complexity with compact index and recompute for comparison
|
||||
print(
|
||||
f"\n🔄 Testing optimal complexity {best_complexity} on compact index WITH recompute..."
|
||||
)
|
||||
compact_evaluator = RecallEvaluator(args.index, args.baseline_dir)
|
||||
recall_with_recompute = compact_evaluator.evaluate_recall_at_3(
|
||||
test_captions[:10], best_complexity, recompute_embeddings=True
|
||||
)
|
||||
print(
|
||||
f" ✅ Complexity {best_complexity} (compact index with recompute): {recall_with_recompute * 100:.1f}%"
|
||||
)
|
||||
complexity = best_complexity
|
||||
print(
|
||||
f" 📊 Recall difference: {abs(best_recall - recall_with_recompute) * 100:.2f}%"
|
||||
)
|
||||
compact_evaluator.cleanup()
|
||||
else:
|
||||
print(f"\n❌ Could not find complexity achieving {target_recall * 100}% recall")
|
||||
print("All tested complexities were below target.")
|
||||
|
||||
# Cleanup evaluators (keep non-compact index for Stage 4)
|
||||
binary_search_evaluator.cleanup()
|
||||
evaluator.cleanup()
|
||||
|
||||
print("✅ Stage 3 completed! Non-compact index saved for Stage 4.\n")
|
||||
|
||||
if args.stage == "4" or args.stage == "all":
|
||||
# Stage 4: Index comparison (without LLM generation)
|
||||
print("🚀 Starting Stage 4: Index comparison analysis")
|
||||
|
||||
# Use LAION evaluator for index comparison
|
||||
evaluator = LAIONEvaluator(args.index)
|
||||
|
||||
# Load caption queries
|
||||
captions = evaluator.load_queries(args.queries)
|
||||
|
||||
# Step 1: Analyze current (compact) index
|
||||
print("\n📏 Analyzing current index (compact, pruned)...")
|
||||
compact_size_metrics = evaluator.analyze_index_sizes()
|
||||
compact_size_metrics["index_type"] = "compact"
|
||||
|
||||
# Step 2: Use existing non-compact index or create if needed
|
||||
if Path(non_compact_index_path).exists():
|
||||
print(
|
||||
f"\n📁 Using existing non-compact index from Stage 3: {non_compact_index_path}"
|
||||
)
|
||||
temp_evaluator = LAIONEvaluator(non_compact_index_path)
|
||||
non_compact_size_metrics = temp_evaluator.analyze_index_sizes()
|
||||
non_compact_size_metrics["index_type"] = "non_compact"
|
||||
else:
|
||||
print("\n🏗️ Creating non-compact index (with embeddings) for comparison...")
|
||||
non_compact_size_metrics = evaluator.create_non_compact_index_for_comparison(
|
||||
non_compact_index_path
|
||||
)
|
||||
|
||||
# Step 3: Compare index sizes (.index only)
|
||||
print("\n📊 Index size comparison (.index only):")
|
||||
print(
|
||||
f" Compact index (current): {compact_size_metrics.get('index_only_mb', 0):.1f} MB"
|
||||
)
|
||||
print(f" Non-compact index: {non_compact_size_metrics.get('index_only_mb', 0):.1f} MB")
|
||||
|
||||
storage_saving = 0.0
|
||||
if non_compact_size_metrics.get("index_only_mb", 0) > 0:
|
||||
storage_saving = (
|
||||
(
|
||||
non_compact_size_metrics.get("index_only_mb", 0)
|
||||
- compact_size_metrics.get("index_only_mb", 0)
|
||||
)
|
||||
/ non_compact_size_metrics.get("index_only_mb", 1)
|
||||
* 100
|
||||
)
|
||||
print(f" Storage saving by compact: {storage_saving:.1f}%")
|
||||
|
||||
# Step 4: Performance comparison between the two indexes
|
||||
if complexity is None:
|
||||
raise ValueError("Complexity is required for index comparison")
|
||||
|
||||
print("\n⚡ Performance comparison between indexes...")
|
||||
performance_metrics = evaluator.compare_index_performance(
|
||||
non_compact_index_path, args.index, captions[:10], complexity=complexity
|
||||
)
|
||||
|
||||
# Combine all metrics
|
||||
combined_metrics = {
|
||||
"current_index": compact_size_metrics,
|
||||
"non_compact_index": non_compact_size_metrics,
|
||||
"performance_comparison": performance_metrics,
|
||||
"storage_saving_percent": storage_saving,
|
||||
}
|
||||
|
||||
# Print comprehensive results
|
||||
evaluator._print_results(combined_metrics)
|
||||
|
||||
# Save results if requested
|
||||
if args.output:
|
||||
print(f"\n💾 Saving results to {args.output}...")
|
||||
with open(args.output, "w") as f:
|
||||
json.dump(combined_metrics, f, indent=2, default=str)
|
||||
print(f"✅ Results saved to {args.output}")
|
||||
|
||||
evaluator.cleanup()
|
||||
print("✅ Stage 4 completed!\n")
|
||||
|
||||
if args.stage in ("5", "all"):
|
||||
print("🚀 Starting Stage 5: Multimodal generation with Qwen2.5-VL")
|
||||
evaluator = LAIONEvaluator(args.index)
|
||||
captions = evaluator.load_queries(args.queries)
|
||||
test_captions = captions[: min(20, len(captions))] # Use subset for generation
|
||||
|
||||
print(f"🧪 Testing multimodal generation with {len(test_captions)} queries")
|
||||
|
||||
# Load Qwen2.5-VL model
|
||||
try:
|
||||
print("Loading Qwen2.5-VL model...")
|
||||
processor, model = load_qwen_vl_model(args.model_name)
|
||||
|
||||
# Run multimodal generation evaluation
|
||||
complexity = args.complexity or 64
|
||||
gen_results = evaluate_multimodal_rag(
|
||||
evaluator.searcher,
|
||||
test_captions,
|
||||
processor=processor,
|
||||
model=model,
|
||||
complexity=complexity,
|
||||
)
|
||||
|
||||
print("\n📊 Multimodal Generation Results:")
|
||||
print(f" Total Queries: {len(test_captions)}")
|
||||
print(f" Avg Search Time: {gen_results['avg_search_time']:.3f}s")
|
||||
print(f" Avg Generation Time: {gen_results['avg_generation_time']:.3f}s")
|
||||
total_time = gen_results["avg_search_time"] + gen_results["avg_generation_time"]
|
||||
search_pct = (gen_results["avg_search_time"] / total_time) * 100
|
||||
gen_pct = (gen_results["avg_generation_time"] / total_time) * 100
|
||||
print(f" Time Distribution: Search {search_pct:.1f}%, Generation {gen_pct:.1f}%")
|
||||
print(" LLM Backend: HuggingFace transformers")
|
||||
print(f" Model: {args.model_name}")
|
||||
|
||||
# Show sample results
|
||||
print("\n📝 Sample Multimodal Generations:")
|
||||
for i, response in enumerate(gen_results["results"][:3]):
|
||||
# Handle both string and dict formats for captions
|
||||
if isinstance(test_captions[i], dict):
|
||||
caption_text = test_captions[i].get("query", str(test_captions[i]))
|
||||
else:
|
||||
caption_text = str(test_captions[i])
|
||||
print(f" Query {i + 1}: {caption_text[:60]}...")
|
||||
print(f" Response {i + 1}: {response[:100]}...")
|
||||
print()
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Multimodal generation evaluation failed: {e}")
|
||||
print("💡 Make sure transformers and Qwen2.5-VL are installed")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
evaluator.cleanup()
|
||||
print("✅ Stage 5 completed!\n")
|
||||
|
||||
if args.stage == "all":
|
||||
print("🎉 All evaluation stages completed successfully!")
|
||||
print("\n📋 Summary:")
|
||||
print(" Stage 2: ✅ Multimodal Recall@3 evaluation completed")
|
||||
print(" Stage 3: ✅ Optimal complexity found")
|
||||
print(" Stage 4: ✅ Index comparison analysis completed")
|
||||
print(" Stage 5: ✅ Multimodal generation evaluation completed")
|
||||
print("\n🔧 Recommended next steps:")
|
||||
print(" - Use optimal complexity for best speed/accuracy balance")
|
||||
print(" - Review index comparison for storage vs performance tradeoffs")
|
||||
|
||||
# Clean up non-compact index after all stages complete
|
||||
print("\n🧹 Cleaning up temporary non-compact index...")
|
||||
if Path(non_compact_index_path).exists():
|
||||
temp_index_dir = Path(non_compact_index_path).parent
|
||||
temp_index_name = Path(non_compact_index_path).name
|
||||
for temp_file in temp_index_dir.glob(f"{temp_index_name}*"):
|
||||
temp_file.unlink()
|
||||
print(f"✅ Cleaned up {non_compact_index_path}")
|
||||
else:
|
||||
print("📝 No temporary index to clean up")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n⚠️ Evaluation interrupted by user")
|
||||
exit(1)
|
||||
except Exception as e:
|
||||
print(f"\n❌ Stage {args.stage} failed: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
576
benchmarks/laion/setup_laion.py
Normal file
@@ -0,0 +1,576 @@
|
||||
"""
|
||||
LAION Multimodal Benchmark Setup Script
|
||||
Downloads LAION subset and builds LEANN index with sentence embeddings
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import aiohttp
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
from leann import LeannBuilder
|
||||
from PIL import Image
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class LAIONSetup:
|
||||
def __init__(self, data_dir: str = "data"):
|
||||
self.data_dir = Path(data_dir)
|
||||
self.images_dir = self.data_dir / "laion_images"
|
||||
self.metadata_file = self.data_dir / "laion_metadata.jsonl"
|
||||
|
||||
# Create directories
|
||||
self.data_dir.mkdir(exist_ok=True)
|
||||
self.images_dir.mkdir(exist_ok=True)
|
||||
|
||||
async def download_single_image(self, session, sample_data, semaphore, progress_bar):
|
||||
"""Download a single image asynchronously"""
|
||||
async with semaphore: # Limit concurrent downloads
|
||||
try:
|
||||
image_url = sample_data["url"]
|
||||
image_path = sample_data["image_path"]
|
||||
|
||||
# Skip if already exists
|
||||
if os.path.exists(image_path):
|
||||
progress_bar.update(1)
|
||||
return sample_data
|
||||
|
||||
async with session.get(image_url, timeout=10) as response:
|
||||
if response.status == 200:
|
||||
content = await response.read()
|
||||
|
||||
# Verify it's a valid image
|
||||
try:
|
||||
img = Image.open(io.BytesIO(content))
|
||||
img = img.convert("RGB")
|
||||
img.save(image_path, "JPEG")
|
||||
progress_bar.update(1)
|
||||
return sample_data
|
||||
except Exception:
|
||||
progress_bar.update(1)
|
||||
return None # Skip invalid images
|
||||
else:
|
||||
progress_bar.update(1)
|
||||
return None
|
||||
|
||||
except Exception:
|
||||
progress_bar.update(1)
|
||||
return None
|
||||
|
||||
def download_laion_subset(self, num_samples: int = 1000):
|
||||
"""Download LAION subset from HuggingFace datasets with async parallel downloading"""
|
||||
print(f"📥 Downloading LAION subset ({num_samples} samples)...")
|
||||
|
||||
# Load LAION-400M subset from HuggingFace
|
||||
print("🤗 Loading from HuggingFace datasets...")
|
||||
dataset = load_dataset("laion/laion400m", split="train", streaming=True)
|
||||
|
||||
# Collect sample metadata first (fast)
|
||||
print("📋 Collecting sample metadata...")
|
||||
candidates = []
|
||||
for sample in dataset:
|
||||
if len(candidates) >= num_samples * 3: # Get 3x more candidates in case some fail
|
||||
break
|
||||
|
||||
image_url = sample.get("url", "")
|
||||
caption = sample.get("caption", "")
|
||||
|
||||
if not image_url or not caption:
|
||||
continue
|
||||
|
||||
image_filename = f"laion_{len(candidates):06d}.jpg"
|
||||
image_path = self.images_dir / image_filename
|
||||
|
||||
candidate = {
|
||||
"id": f"laion_{len(candidates):06d}",
|
||||
"url": image_url,
|
||||
"caption": caption,
|
||||
"image_path": str(image_path),
|
||||
"width": sample.get("original_width", 512),
|
||||
"height": sample.get("original_height", 512),
|
||||
"similarity": sample.get("similarity", 0.0),
|
||||
}
|
||||
candidates.append(candidate)
|
||||
|
||||
print(
|
||||
f"📊 Collected {len(candidates)} candidates, downloading {num_samples} in parallel..."
|
||||
)
|
||||
|
||||
# Download images in parallel
|
||||
async def download_batch():
|
||||
semaphore = asyncio.Semaphore(20) # Limit to 20 concurrent downloads
|
||||
connector = aiohttp.TCPConnector(limit=100, limit_per_host=20)
|
||||
timeout = aiohttp.ClientTimeout(total=30)
|
||||
|
||||
progress_bar = tqdm(total=len(candidates[: num_samples * 2]), desc="Downloading images")
|
||||
|
||||
async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session:
|
||||
tasks = []
|
||||
for candidate in candidates[: num_samples * 2]: # Try 2x more than needed
|
||||
task = self.download_single_image(session, candidate, semaphore, progress_bar)
|
||||
tasks.append(task)
|
||||
|
||||
# Wait for all downloads
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
progress_bar.close()
|
||||
|
||||
# Filter successful downloads
|
||||
successful = [r for r in results if r is not None and not isinstance(r, Exception)]
|
||||
return successful[:num_samples]
|
||||
|
||||
# Run async download
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
samples = loop.run_until_complete(download_batch())
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
# Save metadata
|
||||
with open(self.metadata_file, "w", encoding="utf-8") as f:
|
||||
for sample in samples:
|
||||
f.write(json.dumps(sample) + "\n")
|
||||
|
||||
print(f"✅ Downloaded {len(samples)} real LAION samples with async parallel downloading")
|
||||
return samples
|
||||
|
||||
def generate_clip_image_embeddings(self, samples: list[dict]):
|
||||
"""Generate CLIP image embeddings for downloaded images"""
|
||||
print("🔍 Generating CLIP image embeddings...")
|
||||
|
||||
# Load sentence-transformers CLIP (ViT-L/14, 768-dim) for image embeddings
|
||||
# This single model can encode both images and text.
|
||||
model = SentenceTransformer("clip-ViT-L-14")
|
||||
|
||||
embeddings = []
|
||||
valid_samples = []
|
||||
|
||||
for sample in tqdm(samples, desc="Processing images"):
|
||||
try:
|
||||
# Load image
|
||||
image_path = sample["image_path"]
|
||||
image = Image.open(image_path).convert("RGB")
|
||||
|
||||
# Encode image to 768-dim embedding via sentence-transformers (normalized)
|
||||
vec = model.encode(
|
||||
[image],
|
||||
convert_to_numpy=True,
|
||||
normalize_embeddings=True,
|
||||
batch_size=1,
|
||||
show_progress_bar=False,
|
||||
)[0]
|
||||
embeddings.append(vec.astype(np.float32))
|
||||
valid_samples.append(sample)
|
||||
|
||||
except Exception as e:
|
||||
print(f" ⚠️ Failed to process {sample['id']}: {e}")
|
||||
# Skip invalid images
|
||||
|
||||
embeddings = np.array(embeddings, dtype=np.float32)
|
||||
|
||||
# Save embeddings
|
||||
embeddings_file = self.data_dir / "clip_image_embeddings.npy"
|
||||
np.save(embeddings_file, embeddings)
|
||||
print(f"✅ Generated {len(embeddings)} image embeddings, shape: {embeddings.shape}")
|
||||
|
||||
return embeddings, valid_samples
|
||||
|
||||
def build_faiss_baseline(
|
||||
self, embeddings: np.ndarray, samples: list[dict], output_dir: str = "baseline"
|
||||
):
|
||||
"""Build FAISS flat baseline using CLIP image embeddings"""
|
||||
print("🔨 Building FAISS Flat baseline...")
|
||||
|
||||
from leann_backend_hnsw import faiss
|
||||
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
baseline_path = os.path.join(output_dir, "faiss_flat.index")
|
||||
metadata_path = os.path.join(output_dir, "metadata.pkl")
|
||||
|
||||
if os.path.exists(baseline_path) and os.path.exists(metadata_path):
|
||||
print(f"✅ Baseline already exists at {baseline_path}")
|
||||
return baseline_path
|
||||
|
||||
# Extract image IDs (must be present)
|
||||
if not samples or "id" not in samples[0]:
|
||||
raise KeyError("samples missing 'id' field for FAISS baseline")
|
||||
image_ids: list[str] = [str(sample["id"]) for sample in samples]
|
||||
|
||||
print(f"📐 Embedding shape: {embeddings.shape}")
|
||||
print(f"📄 Processing {len(image_ids)} images")
|
||||
|
||||
# Build FAISS flat index
|
||||
print("🏗️ Building FAISS IndexFlatIP...")
|
||||
dimension = embeddings.shape[1]
|
||||
index = faiss.IndexFlatIP(dimension)
|
||||
|
||||
# Add embeddings to flat index
|
||||
embeddings_f32 = embeddings.astype(np.float32)
|
||||
index.add(embeddings_f32.shape[0], faiss.swig_ptr(embeddings_f32))
|
||||
|
||||
# Save index and metadata
|
||||
faiss.write_index(index, baseline_path)
|
||||
with open(metadata_path, "wb") as f:
|
||||
pickle.dump(image_ids, f)
|
||||
|
||||
print(f"✅ FAISS baseline saved to {baseline_path}")
|
||||
print(f"✅ Metadata saved to {metadata_path}")
|
||||
print(f"📊 Total vectors: {index.ntotal}")
|
||||
|
||||
return baseline_path
|
||||
|
||||
def create_leann_passages(self, samples: list[dict]):
|
||||
"""Create LEANN-compatible passages from LAION data"""
|
||||
print("📝 Creating LEANN passages...")
|
||||
|
||||
passages_file = self.data_dir / "laion_passages.jsonl"
|
||||
|
||||
with open(passages_file, "w", encoding="utf-8") as f:
|
||||
for i, sample in enumerate(samples):
|
||||
passage = {
|
||||
"id": sample["id"],
|
||||
"text": sample["caption"], # Use caption as searchable text
|
||||
"metadata": {
|
||||
"image_url": sample["url"],
|
||||
"image_path": sample.get("image_path", ""),
|
||||
"width": sample["width"],
|
||||
"height": sample["height"],
|
||||
"similarity": sample["similarity"],
|
||||
"image_index": i, # Index for embedding lookup
|
||||
},
|
||||
}
|
||||
f.write(json.dumps(passage) + "\n")
|
||||
|
||||
print(f"✅ Created {len(samples)} passages")
|
||||
return passages_file
|
||||
|
||||
def build_compact_index(
|
||||
self, passages_file: Path, embeddings: np.ndarray, index_path: str, backend: str = "hnsw"
|
||||
):
|
||||
"""Build compact LEANN index with CLIP embeddings (recompute=True, compact=True)"""
|
||||
print(f"🏗️ Building compact LEANN index with {backend} backend...")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Save CLIP embeddings (npy) and also a pickle with (ids, embeddings)
|
||||
npy_path = self.data_dir / "clip_image_embeddings.npy"
|
||||
np.save(npy_path, embeddings)
|
||||
print(f"💾 Saved CLIP embeddings to {npy_path}")
|
||||
|
||||
# Prepare ids in the same order as passages_file (matches embeddings order)
|
||||
ids: list[str] = []
|
||||
with open(passages_file, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
rec = json.loads(line)
|
||||
ids.append(str(rec["id"]))
|
||||
|
||||
if len(ids) != embeddings.shape[0]:
|
||||
raise ValueError(
|
||||
f"IDs count ({len(ids)}) does not match embeddings ({embeddings.shape[0]})."
|
||||
)
|
||||
|
||||
pkl_path = self.data_dir / "clip_image_embeddings.pkl"
|
||||
with open(pkl_path, "wb") as pf:
|
||||
pickle.dump((ids, embeddings.astype(np.float32)), pf)
|
||||
print(f"💾 Saved (ids, embeddings) pickle to {pkl_path}")
|
||||
|
||||
# Initialize builder - compact with recompute
|
||||
# Note: For multimodal case, we need to handle embeddings differently
|
||||
# Let's try using sentence-transformers mode but with custom embeddings
|
||||
builder = LeannBuilder(
|
||||
backend_name=backend,
|
||||
# Use CLIP text encoder (ViT-L/14) to match image space (768-dim)
|
||||
embedding_model="clip-ViT-L-14",
|
||||
embedding_mode="sentence-transformers",
|
||||
# HNSW params (or forwarded to chosen backend)
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
# Compact/pruned with recompute at query time
|
||||
is_recompute=True,
|
||||
is_compact=True,
|
||||
distance_metric="cosine", # CLIP uses normalized vectors; cosine is appropriate
|
||||
num_threads=4,
|
||||
)
|
||||
|
||||
# Add passages (text + metadata)
|
||||
print("📚 Adding passages...")
|
||||
self._add_passages_with_embeddings(builder, passages_file, embeddings)
|
||||
|
||||
print(f"🔨 Building compact index at {index_path} from precomputed embeddings...")
|
||||
builder.build_index_from_embeddings(index_path, str(pkl_path))
|
||||
|
||||
build_time = time.time() - start_time
|
||||
print(f"✅ Compact index built in {build_time:.2f}s")
|
||||
|
||||
# Analyze index size
|
||||
self._analyze_index_size(index_path)
|
||||
|
||||
return index_path
|
||||
|
||||
def build_non_compact_index(
|
||||
self, passages_file: Path, embeddings: np.ndarray, index_path: str, backend: str = "hnsw"
|
||||
):
|
||||
"""Build non-compact LEANN index with CLIP embeddings (recompute=False, compact=False)"""
|
||||
print(f"🏗️ Building non-compact LEANN index with {backend} backend...")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Ensure embeddings are saved (npy + pickle)
|
||||
npy_path = self.data_dir / "clip_image_embeddings.npy"
|
||||
if not npy_path.exists():
|
||||
np.save(npy_path, embeddings)
|
||||
print(f"💾 Saved CLIP embeddings to {npy_path}")
|
||||
# Prepare ids in same order as passages_file
|
||||
ids: list[str] = []
|
||||
with open(passages_file, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
rec = json.loads(line)
|
||||
ids.append(str(rec["id"]))
|
||||
if len(ids) != embeddings.shape[0]:
|
||||
raise ValueError(
|
||||
f"IDs count ({len(ids)}) does not match embeddings ({embeddings.shape[0]})."
|
||||
)
|
||||
pkl_path = self.data_dir / "clip_image_embeddings.pkl"
|
||||
if not pkl_path.exists():
|
||||
with open(pkl_path, "wb") as pf:
|
||||
pickle.dump((ids, embeddings.astype(np.float32)), pf)
|
||||
print(f"💾 Saved (ids, embeddings) pickle to {pkl_path}")
|
||||
|
||||
# Initialize builder - non-compact without recompute
|
||||
builder = LeannBuilder(
|
||||
backend_name=backend,
|
||||
embedding_model="clip-ViT-L-14",
|
||||
embedding_mode="sentence-transformers",
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_recompute=False, # Store embeddings (no recompute needed)
|
||||
is_compact=False, # Store full index (not pruned)
|
||||
distance_metric="cosine",
|
||||
num_threads=4,
|
||||
)
|
||||
|
||||
# Add passages - embeddings will be loaded from file
|
||||
print("📚 Adding passages...")
|
||||
self._add_passages_with_embeddings(builder, passages_file, embeddings)
|
||||
|
||||
print(f"🔨 Building non-compact index at {index_path} from precomputed embeddings...")
|
||||
builder.build_index_from_embeddings(index_path, str(pkl_path))
|
||||
|
||||
build_time = time.time() - start_time
|
||||
print(f"✅ Non-compact index built in {build_time:.2f}s")
|
||||
|
||||
# Analyze index size
|
||||
self._analyze_index_size(index_path)
|
||||
|
||||
return index_path
|
||||
|
||||
def _add_passages_with_embeddings(self, builder, passages_file: Path, embeddings: np.ndarray):
|
||||
"""Helper to add passages with pre-computed CLIP embeddings"""
|
||||
with open(passages_file, encoding="utf-8") as f:
|
||||
for line in tqdm(f, desc="Adding passages"):
|
||||
if line.strip():
|
||||
passage = json.loads(line)
|
||||
|
||||
# Add image metadata - LEANN will handle embeddings separately
|
||||
# Note: We store image metadata and caption text for searchability
|
||||
# Important: ensure passage ID in metadata matches vector ID
|
||||
builder.add_text(
|
||||
text=passage["text"], # Image caption for searchability
|
||||
metadata={**passage["metadata"], "id": passage["id"]},
|
||||
)
|
||||
|
||||
def _analyze_index_size(self, index_path: str):
|
||||
"""Analyze index file sizes"""
|
||||
print("📏 Analyzing index sizes...")
|
||||
|
||||
index_path = Path(index_path)
|
||||
index_dir = index_path.parent
|
||||
index_name = index_path.name # e.g., laion_index.leann
|
||||
index_prefix = index_path.stem # e.g., laion_index
|
||||
|
||||
files = [
|
||||
(f"{index_prefix}.index", ".index", "core"),
|
||||
(f"{index_name}.meta.json", ".meta.json", "core"),
|
||||
(f"{index_name}.ids.txt", ".ids.txt", "core"),
|
||||
(f"{index_name}.passages.jsonl", ".passages.jsonl", "passages"),
|
||||
(f"{index_name}.passages.idx", ".passages.idx", "passages"),
|
||||
]
|
||||
|
||||
def _fmt_size(bytes_val: int) -> str:
|
||||
if bytes_val < 1024:
|
||||
return f"{bytes_val} B"
|
||||
kb = bytes_val / 1024
|
||||
if kb < 1024:
|
||||
return f"{kb:.1f} KB"
|
||||
mb = kb / 1024
|
||||
if mb < 1024:
|
||||
return f"{mb:.2f} MB"
|
||||
gb = mb / 1024
|
||||
return f"{gb:.2f} GB"
|
||||
|
||||
total_index_only_mb = 0.0
|
||||
total_all_mb = 0.0
|
||||
for filename, label, group in files:
|
||||
file_path = index_dir / filename
|
||||
if file_path.exists():
|
||||
size_bytes = file_path.stat().st_size
|
||||
print(f" {label}: {_fmt_size(size_bytes)}")
|
||||
size_mb = size_bytes / (1024 * 1024)
|
||||
total_all_mb += size_mb
|
||||
if group == "core":
|
||||
total_index_only_mb += size_mb
|
||||
else:
|
||||
print(f" {label}: (missing)")
|
||||
print(f" Total (index only, exclude passages): {total_index_only_mb:.2f} MB")
|
||||
print(f" Total (including passages): {total_all_mb:.2f} MB")
|
||||
|
||||
def create_evaluation_queries(self, samples: list[dict], num_queries: int = 200):
|
||||
"""Create evaluation queries from captions"""
|
||||
print(f"📝 Creating {num_queries} evaluation queries...")
|
||||
|
||||
# Sample random captions as queries
|
||||
import random
|
||||
|
||||
random.seed(42) # For reproducibility
|
||||
|
||||
query_samples = random.sample(samples, min(num_queries, len(samples)))
|
||||
|
||||
queries_file = self.data_dir / "evaluation_queries.jsonl"
|
||||
with open(queries_file, "w", encoding="utf-8") as f:
|
||||
for sample in query_samples:
|
||||
query = {
|
||||
"id": sample["id"],
|
||||
"query": sample["caption"],
|
||||
"ground_truth_id": sample["id"], # For potential recall evaluation
|
||||
}
|
||||
f.write(json.dumps(query) + "\n")
|
||||
|
||||
print(f"✅ Created {len(query_samples)} evaluation queries")
|
||||
return queries_file
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Setup LAION Multimodal Benchmark")
|
||||
parser.add_argument("--data-dir", default="data", help="Data directory")
|
||||
parser.add_argument("--num-samples", type=int, default=1000, help="Number of LAION samples")
|
||||
parser.add_argument("--num-queries", type=int, default=50, help="Number of evaluation queries")
|
||||
parser.add_argument("--index-path", default="data/laion_index.leann", help="Output index path")
|
||||
parser.add_argument(
|
||||
"--backend", default="hnsw", choices=["hnsw", "diskann"], help="LEANN backend"
|
||||
)
|
||||
parser.add_argument("--skip-download", action="store_true", help="Skip LAION dataset download")
|
||||
parser.add_argument("--skip-build", action="store_true", help="Skip index building")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print("🚀 Setting up LAION Multimodal Benchmark")
|
||||
print("=" * 50)
|
||||
|
||||
try:
|
||||
# Initialize setup
|
||||
setup = LAIONSetup(args.data_dir)
|
||||
|
||||
# Step 1: Download LAION subset
|
||||
if not args.skip_download:
|
||||
print("\n📦 Step 1: Download LAION subset")
|
||||
samples = setup.download_laion_subset(args.num_samples)
|
||||
|
||||
# Step 2: Generate CLIP image embeddings
|
||||
print("\n🔍 Step 2: Generate CLIP image embeddings")
|
||||
embeddings, valid_samples = setup.generate_clip_image_embeddings(samples)
|
||||
|
||||
# Step 3: Create LEANN passages (image metadata with embeddings)
|
||||
print("\n📝 Step 3: Create LEANN passages")
|
||||
passages_file = setup.create_leann_passages(valid_samples)
|
||||
else:
|
||||
print("⏭️ Skipping LAION dataset download")
|
||||
# Load existing data
|
||||
passages_file = setup.data_dir / "laion_passages.jsonl"
|
||||
embeddings_file = setup.data_dir / "clip_image_embeddings.npy"
|
||||
|
||||
if not passages_file.exists() or not embeddings_file.exists():
|
||||
raise FileNotFoundError(
|
||||
"Passages or embeddings file not found. Run without --skip-download first."
|
||||
)
|
||||
|
||||
embeddings = np.load(embeddings_file)
|
||||
print(f"📊 Loaded {len(embeddings)} embeddings from {embeddings_file}")
|
||||
|
||||
# Step 4: Build LEANN indexes (both compact and non-compact)
|
||||
if not args.skip_build:
|
||||
print("\n🏗️ Step 4: Build LEANN indexes with CLIP image embeddings")
|
||||
|
||||
# Build compact index (production mode - small, recompute required)
|
||||
compact_index_path = args.index_path
|
||||
print(f"Building compact index: {compact_index_path}")
|
||||
setup.build_compact_index(passages_file, embeddings, compact_index_path, args.backend)
|
||||
|
||||
# Build non-compact index (comparison mode - large, fast search)
|
||||
non_compact_index_path = args.index_path.replace(".leann", "_noncompact.leann")
|
||||
print(f"Building non-compact index: {non_compact_index_path}")
|
||||
setup.build_non_compact_index(
|
||||
passages_file, embeddings, non_compact_index_path, args.backend
|
||||
)
|
||||
|
||||
# Step 5: Build FAISS flat baseline
|
||||
print("\n🔨 Step 5: Build FAISS flat baseline")
|
||||
if not args.skip_download:
|
||||
baseline_path = setup.build_faiss_baseline(embeddings, valid_samples)
|
||||
else:
|
||||
# Load valid_samples from passages file for FAISS baseline
|
||||
valid_samples = []
|
||||
with open(passages_file, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
passage = json.loads(line)
|
||||
valid_samples.append({"id": passage["id"], "caption": passage["text"]})
|
||||
baseline_path = setup.build_faiss_baseline(embeddings, valid_samples)
|
||||
|
||||
# Step 6: Create evaluation queries
|
||||
print("\n📝 Step 6: Create evaluation queries")
|
||||
queries_file = setup.create_evaluation_queries(valid_samples, args.num_queries)
|
||||
else:
|
||||
print("⏭️ Skipping index building")
|
||||
baseline_path = "data/baseline/faiss_index.bin"
|
||||
queries_file = setup.data_dir / "evaluation_queries.jsonl"
|
||||
|
||||
print("\n🎉 Setup completed successfully!")
|
||||
print("📊 Summary:")
|
||||
if not args.skip_download:
|
||||
print(f" Downloaded samples: {len(samples)}")
|
||||
print(f" Valid samples with embeddings: {len(valid_samples)}")
|
||||
else:
|
||||
print(f" Loaded {len(embeddings)} embeddings")
|
||||
|
||||
if not args.skip_build:
|
||||
print(f" Compact index: {compact_index_path}")
|
||||
print(f" Non-compact index: {non_compact_index_path}")
|
||||
print(f" FAISS baseline: {baseline_path}")
|
||||
print(f" Queries: {queries_file}")
|
||||
|
||||
print("\n🔧 Next steps:")
|
||||
print(f" Run evaluation: python evaluate_laion.py --index {compact_index_path}")
|
||||
print(f" Or compare with: python evaluate_laion.py --index {non_compact_index_path}")
|
||||
else:
|
||||
print(" Skipped building indexes")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n⚠️ Setup interrupted by user")
|
||||
exit(1)
|
||||
except Exception as e:
|
||||
print(f"\n❌ Setup failed: {e}")
|
||||
exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
342
benchmarks/llm_utils.py
Normal file
@@ -0,0 +1,342 @@
|
||||
"""
|
||||
LLM utils for RAG benchmarks with Qwen3-8B and Qwen2.5-VL (multimodal)
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
try:
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
HF_AVAILABLE = True
|
||||
except ImportError:
|
||||
HF_AVAILABLE = False
|
||||
|
||||
try:
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
VLLM_AVAILABLE = True
|
||||
except ImportError:
|
||||
VLLM_AVAILABLE = False
|
||||
|
||||
|
||||
def is_qwen3_model(model_name):
|
||||
"""Check if model is Qwen3"""
|
||||
return "Qwen3" in model_name or "qwen3" in model_name.lower()
|
||||
|
||||
|
||||
def is_qwen_vl_model(model_name):
|
||||
"""Check if model is Qwen2.5-VL"""
|
||||
return "Qwen2.5-VL" in model_name or "qwen2.5-vl" in model_name.lower()
|
||||
|
||||
|
||||
def apply_qwen3_chat_template(tokenizer, prompt):
|
||||
"""Apply Qwen3 chat template with thinking enabled"""
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
return tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=True,
|
||||
)
|
||||
|
||||
|
||||
def extract_thinking_answer(response):
|
||||
"""Extract final answer from Qwen3 thinking model response"""
|
||||
if "<think>" in response and "</think>" in response:
|
||||
try:
|
||||
think_end = response.index("</think>") + len("</think>")
|
||||
final_answer = response[think_end:].strip()
|
||||
return final_answer
|
||||
except (ValueError, IndexError):
|
||||
pass
|
||||
|
||||
return response.strip()
|
||||
|
||||
|
||||
def load_hf_model(model_name="Qwen/Qwen3-8B", trust_remote_code=False):
|
||||
"""Load HuggingFace model
|
||||
|
||||
Args:
|
||||
model_name (str): Name of the model to load
|
||||
trust_remote_code (bool): Whether to allow execution of code from the model repository.
|
||||
Defaults to False for security. Only enable for trusted models.
|
||||
"""
|
||||
if not HF_AVAILABLE:
|
||||
raise ImportError("transformers not available")
|
||||
|
||||
if trust_remote_code:
|
||||
print(
|
||||
"⚠️ WARNING: Loading model with trust_remote_code=True. This can execute arbitrary code."
|
||||
)
|
||||
|
||||
print(f"Loading HF: {model_name}")
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=trust_remote_code)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
|
||||
device_map="auto",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
return tokenizer, model
|
||||
|
||||
|
||||
def load_vllm_model(model_name="Qwen/Qwen3-8B", trust_remote_code=False):
|
||||
"""Load vLLM model
|
||||
|
||||
Args:
|
||||
model_name (str): Name of the model to load
|
||||
trust_remote_code (bool): Whether to allow execution of code from the model repository.
|
||||
Defaults to False for security. Only enable for trusted models.
|
||||
"""
|
||||
if not VLLM_AVAILABLE:
|
||||
raise ImportError("vllm not available")
|
||||
|
||||
if trust_remote_code:
|
||||
print(
|
||||
"⚠️ WARNING: Loading model with trust_remote_code=True. This can execute arbitrary code."
|
||||
)
|
||||
|
||||
print(f"Loading vLLM: {model_name}")
|
||||
llm = LLM(model=model_name, trust_remote_code=trust_remote_code)
|
||||
|
||||
# Qwen3 specific config
|
||||
if is_qwen3_model(model_name):
|
||||
stop_tokens = ["<|im_end|>", "<|end_of_text|>"]
|
||||
max_tokens = 2048
|
||||
else:
|
||||
stop_tokens = None
|
||||
max_tokens = 1024
|
||||
|
||||
sampling_params = SamplingParams(temperature=0.7, max_tokens=max_tokens, stop=stop_tokens)
|
||||
return llm, sampling_params
|
||||
|
||||
|
||||
def generate_hf(tokenizer, model, prompt, max_tokens=None):
|
||||
"""Generate with HF - supports Qwen3 thinking models"""
|
||||
model_name = getattr(model, "name_or_path", "unknown")
|
||||
is_qwen3 = is_qwen3_model(model_name)
|
||||
|
||||
# Apply chat template for Qwen3
|
||||
if is_qwen3:
|
||||
prompt = apply_qwen3_chat_template(tokenizer, prompt)
|
||||
max_tokens = max_tokens or 2048
|
||||
else:
|
||||
max_tokens = max_tokens or 1024
|
||||
|
||||
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
||||
with torch.no_grad():
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=max_tokens,
|
||||
temperature=0.7,
|
||||
do_sample=True,
|
||||
pad_token_id=tokenizer.eos_token_id,
|
||||
)
|
||||
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||||
response = response[len(prompt) :].strip()
|
||||
|
||||
# Extract final answer for thinking models
|
||||
if is_qwen3:
|
||||
return extract_thinking_answer(response)
|
||||
return response
|
||||
|
||||
|
||||
def generate_vllm(llm, sampling_params, prompt):
|
||||
"""Generate with vLLM - supports Qwen3 thinking models"""
|
||||
outputs = llm.generate([prompt], sampling_params)
|
||||
response = outputs[0].outputs[0].text.strip()
|
||||
|
||||
# Extract final answer for Qwen3 thinking models
|
||||
model_name = str(llm.llm_engine.model_config.model)
|
||||
if is_qwen3_model(model_name):
|
||||
return extract_thinking_answer(response)
|
||||
return response
|
||||
|
||||
|
||||
def create_prompt(context, query, domain="default"):
|
||||
"""Create RAG prompt"""
|
||||
if domain == "emails":
|
||||
return f"Email content:\n{context}\n\nQuestion: {query}\n\nAnswer:"
|
||||
elif domain == "finance":
|
||||
return f"Financial content:\n{context}\n\nQuestion: {query}\n\nAnswer:"
|
||||
elif domain == "multimodal":
|
||||
return f"Image context:\n{context}\n\nQuestion: {query}\n\nAnswer:"
|
||||
else:
|
||||
return f"Context: {context}\n\nQuestion: {query}\n\nAnswer:"
|
||||
|
||||
|
||||
def evaluate_rag(searcher, llm_func, queries, domain="default", top_k=3, complexity=64):
|
||||
"""Simple RAG evaluation with timing"""
|
||||
search_times = []
|
||||
gen_times = []
|
||||
results = []
|
||||
|
||||
for i, query in enumerate(queries):
|
||||
# Search
|
||||
start = time.time()
|
||||
docs = searcher.search(query, top_k=top_k, complexity=complexity)
|
||||
search_time = time.time() - start
|
||||
|
||||
# Generate
|
||||
context = "\n\n".join([doc.text for doc in docs])
|
||||
prompt = create_prompt(context, query, domain)
|
||||
|
||||
start = time.time()
|
||||
response = llm_func(prompt)
|
||||
gen_time = time.time() - start
|
||||
|
||||
search_times.append(search_time)
|
||||
gen_times.append(gen_time)
|
||||
results.append(response)
|
||||
|
||||
if i < 3:
|
||||
print(f"Q{i + 1}: Search={search_time:.3f}s, Gen={gen_time:.3f}s")
|
||||
|
||||
return {
|
||||
"avg_search_time": sum(search_times) / len(search_times),
|
||||
"avg_generation_time": sum(gen_times) / len(gen_times),
|
||||
"results": results,
|
||||
}
|
||||
|
||||
|
||||
def load_qwen_vl_model(model_name="Qwen/Qwen2.5-VL-7B-Instruct", trust_remote_code=False):
|
||||
"""Load Qwen2.5-VL multimodal model
|
||||
|
||||
Args:
|
||||
model_name (str): Name of the model to load
|
||||
trust_remote_code (bool): Whether to allow execution of code from the model repository.
|
||||
Defaults to False for security. Only enable for trusted models.
|
||||
"""
|
||||
if not HF_AVAILABLE:
|
||||
raise ImportError("transformers not available")
|
||||
|
||||
if trust_remote_code:
|
||||
print(
|
||||
"⚠️ WARNING: Loading model with trust_remote_code=True. This can execute arbitrary code."
|
||||
)
|
||||
|
||||
print(f"Loading Qwen2.5-VL: {model_name}")
|
||||
|
||||
try:
|
||||
from transformers import AutoModelForVision2Seq, AutoProcessor
|
||||
|
||||
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=trust_remote_code)
|
||||
model = AutoModelForVision2Seq.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
return processor, model
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to load with AutoModelForVision2Seq, trying specific class: {e}")
|
||||
|
||||
# Fallback to specific class
|
||||
try:
|
||||
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
|
||||
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
model_name, trust_remote_code=trust_remote_code
|
||||
)
|
||||
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
return processor, model
|
||||
|
||||
except Exception as e2:
|
||||
raise ImportError(f"Failed to load Qwen2.5-VL model: {e2}")
|
||||
|
||||
|
||||
def generate_qwen_vl(processor, model, prompt, image_path=None, max_tokens=512):
|
||||
"""Generate with Qwen2.5-VL multimodal model"""
|
||||
from PIL import Image
|
||||
|
||||
# Prepare inputs
|
||||
if image_path:
|
||||
image = Image.open(image_path)
|
||||
inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
|
||||
else:
|
||||
inputs = processor(text=prompt, return_tensors="pt").to(model.device)
|
||||
|
||||
# Generate
|
||||
with torch.no_grad():
|
||||
generated_ids = model.generate(
|
||||
**inputs, max_new_tokens=max_tokens, do_sample=False, temperature=0.1
|
||||
)
|
||||
|
||||
# Decode response
|
||||
generated_ids = generated_ids[:, inputs["input_ids"].shape[1] :]
|
||||
response = processor.decode(generated_ids[0], skip_special_tokens=True)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def create_multimodal_prompt(context, query, image_descriptions, task_type="images"):
|
||||
"""Create prompt for multimodal RAG"""
|
||||
if task_type == "images":
|
||||
return f"""Based on the retrieved images and their descriptions, answer the following question.
|
||||
|
||||
Retrieved Image Descriptions:
|
||||
{context}
|
||||
|
||||
Question: {query}
|
||||
|
||||
Provide a detailed answer based on the visual content described above."""
|
||||
|
||||
return f"Context: {context}\nQuestion: {query}\nAnswer:"
|
||||
|
||||
|
||||
def evaluate_multimodal_rag(searcher, queries, processor=None, model=None, complexity=64):
|
||||
"""Evaluate multimodal RAG with Qwen2.5-VL"""
|
||||
search_times = []
|
||||
gen_times = []
|
||||
results = []
|
||||
|
||||
for i, query_item in enumerate(queries):
|
||||
# Handle both string and dict formats for queries
|
||||
if isinstance(query_item, dict):
|
||||
query = query_item.get("query", "")
|
||||
image_path = query_item.get("image_path") # Optional reference image
|
||||
else:
|
||||
query = str(query_item)
|
||||
image_path = None
|
||||
|
||||
# Search
|
||||
start_time = time.time()
|
||||
search_results = searcher.search(query, top_k=3, complexity=complexity)
|
||||
search_time = time.time() - start_time
|
||||
search_times.append(search_time)
|
||||
|
||||
# Prepare context from search results
|
||||
context_parts = []
|
||||
for result in search_results:
|
||||
context_parts.append(f"- {result.text}")
|
||||
context = "\n".join(context_parts)
|
||||
|
||||
# Generate with multimodal model
|
||||
start_time = time.time()
|
||||
if processor and model:
|
||||
prompt = create_multimodal_prompt(context, query, context_parts)
|
||||
response = generate_qwen_vl(processor, model, prompt, image_path)
|
||||
else:
|
||||
response = f"Context: {context}"
|
||||
gen_time = time.time() - start_time
|
||||
|
||||
gen_times.append(gen_time)
|
||||
results.append(response)
|
||||
|
||||
if i < 3:
|
||||
print(f"Q{i + 1}: Search={search_time:.3f}s, Gen={gen_time:.3f}s")
|
||||
|
||||
return {
|
||||
"avg_search_time": sum(search_times) / len(search_times),
|
||||
"avg_generation_time": sum(gen_times) / len(gen_times),
|
||||
"results": results,
|
||||
}
|
||||
@@ -53,7 +53,7 @@ def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
|
||||
print(
|
||||
"Error: huggingface_hub is not installed. Please install it to download the data:"
|
||||
)
|
||||
print("uv pip install -e '.[dev]'")
|
||||
print("uv sync --only-group dev")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"An error occurred during data download: {e}")
|
||||
|
||||
143
benchmarks/update/README.md
Normal file
@@ -0,0 +1,143 @@
|
||||
# Update Benchmarks
|
||||
|
||||
This directory hosts two benchmark suites that exercise LEANN’s HNSW “update +
|
||||
search” pipeline under different assumptions:
|
||||
|
||||
1. **RNG recompute latency** – measure how random-neighbour pruning and cache
|
||||
settings influence incremental `add()` latency when embeddings are fetched
|
||||
over the ZMQ embedding server.
|
||||
2. **Update strategy comparison** – compare a fully sequential update pipeline
|
||||
against an offline approach that keeps the graph static and fuses results.
|
||||
|
||||
Both suites build a non-compact, `is_recompute=True` index so that new
|
||||
embeddings are pulled from the embedding server. Benchmark outputs are written
|
||||
under `.leann/bench/` by default and appended to CSV files for later plotting.
|
||||
|
||||
## Benchmarks
|
||||
|
||||
### 1. HNSW RNG Recompute Benchmark
|
||||
|
||||
`bench_hnsw_rng_recompute.py` evaluates incremental update latency under four
|
||||
random-neighbour (RNG) configurations. Each scenario uses the same dataset but
|
||||
changes the forward / reverse RNG pruning flags and whether the embedding cache
|
||||
is enabled:
|
||||
|
||||
| Scenario name | Forward RNG | Reverse RNG | ZMQ embedding cache |
|
||||
| ---------------------------------- | ----------- | ----------- | ------------------- |
|
||||
| `baseline` | Enabled | Enabled | Enabled |
|
||||
| `no_cache_baseline` | Enabled | Enabled | **Disabled** |
|
||||
| `disable_forward_rng` | **Disabled**| Enabled | Enabled |
|
||||
| `disable_forward_and_reverse_rng` | **Disabled**| **Disabled**| Enabled |
|
||||
|
||||
For each scenario the script:
|
||||
1. (Re)builds a `is_recompute=True` index and writes it to `.leann/bench/`.
|
||||
2. Starts `leann_backend_hnsw.hnsw_embedding_server` for remote embeddings.
|
||||
3. Appends the requested updates using the scenario’s RNG flags.
|
||||
4. Records total time, latency per passage, ZMQ fetch counts, and stage-level
|
||||
timings before appending a row to the CSV output.
|
||||
|
||||
**Run:**
|
||||
```bash
|
||||
LEANN_HNSW_LOG_PATH=.leann/bench/hnsw_server.log \
|
||||
LEANN_LOG_LEVEL=INFO \
|
||||
uv run -m benchmarks.update.bench_hnsw_rng_recompute \
|
||||
--runs 1 \
|
||||
--index-path .leann/bench/test.leann \
|
||||
--initial-files data/PrideandPrejudice.txt \
|
||||
--update-files data/huawei_pangu.md \
|
||||
--max-initial 300 \
|
||||
--max-updates 1 \
|
||||
--add-timeout 120
|
||||
```
|
||||
|
||||
**Output:**
|
||||
- `benchmarks/update/bench_results.csv` – per-scenario timing statistics
|
||||
(including ms/passage) for each run.
|
||||
- `.leann/bench/hnsw_server.log` – detailed ZMQ/server logs (path controlled by
|
||||
`LEANN_HNSW_LOG_PATH`).
|
||||
_The reference CSVs checked into this branch were generated on a workstation with an NVIDIA RTX 4090 GPU; throughput numbers will differ on other hardware._
|
||||
|
||||
### 2. Sequential vs. Offline Update Benchmark
|
||||
|
||||
`bench_update_vs_offline_search.py` compares two end-to-end strategies on the
|
||||
same dataset:
|
||||
|
||||
- **Scenario A – Sequential Update**
|
||||
- Start an embedding server.
|
||||
- Sequentially call `index.add()`; each call fetches embeddings via ZMQ and
|
||||
mutates the HNSW graph.
|
||||
- After all inserts, run a search on the updated graph.
|
||||
- Metrics recorded: update time (`add_total_s`), post-update search time
|
||||
(`search_time_s`), combined total (`total_time_s`), and per-passage
|
||||
latency.
|
||||
|
||||
- **Scenario B – Offline Embedding + Concurrent Search**
|
||||
- Stop Scenario A’s server and start a fresh embedding server.
|
||||
- Spawn two threads: one generates embeddings for the new passages offline
|
||||
(graph unchanged); the other computes the query embedding and searches the
|
||||
existing graph.
|
||||
- Merge offline similarities with the graph search results to emulate late
|
||||
fusion, then report the merged top‑k preview.
|
||||
- Metrics recorded: embedding time (`emb_time_s`), search time
|
||||
(`search_time_s`), concurrent makespan (`makespan_s`), and scenario total.
|
||||
|
||||
**Run (both scenarios):**
|
||||
```bash
|
||||
uv run -m benchmarks.update.bench_update_vs_offline_search \
|
||||
--index-path .leann/bench/offline_vs_update.leann \
|
||||
--max-initial 300 \
|
||||
--num-updates 1
|
||||
```
|
||||
|
||||
You can pass `--only A` or `--only B` to run a single scenario. The script will
|
||||
print timing summaries to stdout and append the results to CSV.
|
||||
|
||||
**Output:**
|
||||
- `benchmarks/update/offline_vs_update.csv` – per-scenario timing statistics for
|
||||
Scenario A and B.
|
||||
- Console output includes Scenario B’s merged top‑k preview for quick sanity
|
||||
checks.
|
||||
_The sample results committed here come from runs on an RTX 4090-equipped machine; expect variations if you benchmark on different GPUs._
|
||||
|
||||
### 3. Visualisation
|
||||
|
||||
`plot_bench_results.py` combines the RNG benchmark and the update strategy
|
||||
benchmark into a single two-panel plot.
|
||||
|
||||
**Run:**
|
||||
```bash
|
||||
uv run -m benchmarks.update.plot_bench_results \
|
||||
--csv benchmarks/update/bench_results.csv \
|
||||
--csv-right benchmarks/update/offline_vs_update.csv \
|
||||
--out benchmarks/update/bench_latency_from_csv.png
|
||||
```
|
||||
|
||||
**Options:**
|
||||
- `--broken-y` – Enable a broken Y-axis (default: true when appropriate).
|
||||
- `--csv` – RNG benchmark results CSV (left panel).
|
||||
- `--csv-right` – Update strategy results CSV (right panel).
|
||||
- `--out` – Output image path (PNG/PDF supported).
|
||||
|
||||
**Output:**
|
||||
- `benchmarks/update/bench_latency_from_csv.png` – visual comparison of the two
|
||||
suites.
|
||||
- `benchmarks/update/bench_latency_from_csv.pdf` – PDF version, suitable for
|
||||
slides/papers.
|
||||
|
||||
## Parameters & Environment
|
||||
|
||||
### Common CLI Flags
|
||||
- `--max-initial` – Number of initial passages used to seed the index.
|
||||
- `--max-updates` / `--num-updates` – Number of passages to treat as updates.
|
||||
- `--index-path` – Base path (without extension) where the LEANN index is stored.
|
||||
- `--runs` – Number of repetitions (RNG benchmark only).
|
||||
|
||||
### Environment Variables
|
||||
- `LEANN_HNSW_LOG_PATH` – File to receive embedding-server logs (optional).
|
||||
- `LEANN_LOG_LEVEL` – Logging verbosity (DEBUG/INFO/WARNING/ERROR).
|
||||
- `CUDA_VISIBLE_DEVICES` – Set to empty string if you want to force CPU
|
||||
execution of the embedding model.
|
||||
|
||||
With these scripts you can easily replicate LEANN’s update benchmarks, compare
|
||||
multiple RNG strategies, and evaluate whether sequential updates or offline
|
||||
fusion better match your latency/accuracy trade-offs.
|
||||
16
benchmarks/update/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""Benchmarks for LEANN update workflows."""
|
||||
|
||||
# Expose helper to locate repository root for other modules that need it.
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def find_repo_root() -> Path:
|
||||
"""Return the project root containing pyproject.toml."""
|
||||
current = Path(__file__).resolve()
|
||||
for parent in current.parents:
|
||||
if (parent / "pyproject.toml").exists():
|
||||
return parent
|
||||
return current.parents[1]
|
||||
|
||||
|
||||
__all__ = ["find_repo_root"]
|
||||
804
benchmarks/update/bench_hnsw_rng_recompute.py
Normal file
@@ -0,0 +1,804 @@
|
||||
"""Benchmark incremental HNSW add() under different RNG pruning modes with real
|
||||
embedding recomputation.
|
||||
|
||||
This script clones the structure of ``examples/dynamic_update_no_recompute.py``
|
||||
so that we build a non-compact ``is_recompute=True`` index, spin up the
|
||||
standard HNSW embedding server, and measure how long incremental ``add`` takes
|
||||
when RNG pruning is fully enabled vs. partially/fully disabled.
|
||||
|
||||
Example usage (run from the repo root; downloads the model on first run)::
|
||||
|
||||
uv run -m benchmarks.update.bench_hnsw_rng_recompute \
|
||||
--index-path .leann/bench/leann-demo.leann \
|
||||
--runs 1
|
||||
|
||||
You can tweak the input documents with ``--initial-files`` / ``--update-files``
|
||||
if you want a larger or different workload, and change the embedding model via
|
||||
``--model-name``.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import msgpack
|
||||
import numpy as np
|
||||
import zmq
|
||||
from leann.api import LeannBuilder
|
||||
|
||||
if os.environ.get("LEANN_FORCE_CPU", "").lower() in ("1", "true", "yes"):
|
||||
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
|
||||
|
||||
from leann.embedding_compute import compute_embeddings
|
||||
from leann.embedding_server_manager import EmbeddingServerManager
|
||||
from leann.registry import register_project_directory
|
||||
from leann_backend_hnsw import faiss # type: ignore
|
||||
from leann_backend_hnsw.convert_to_csr import prune_hnsw_embeddings_inplace
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
if not logging.getLogger().handlers:
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
def _find_repo_root() -> Path:
|
||||
"""Locate project root by walking up until pyproject.toml is found."""
|
||||
current = Path(__file__).resolve()
|
||||
for parent in current.parents:
|
||||
if (parent / "pyproject.toml").exists():
|
||||
return parent
|
||||
# Fallback: assume repo is two levels up (../..)
|
||||
return current.parents[2]
|
||||
|
||||
|
||||
REPO_ROOT = _find_repo_root()
|
||||
if str(REPO_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(REPO_ROOT))
|
||||
|
||||
from apps.chunking import create_text_chunks # noqa: E402
|
||||
|
||||
DEFAULT_INITIAL_FILES = [
|
||||
REPO_ROOT / "data" / "2501.14312v1 (1).pdf",
|
||||
REPO_ROOT / "data" / "huawei_pangu.md",
|
||||
]
|
||||
DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"]
|
||||
|
||||
DEFAULT_HNSW_LOG = Path(".leann/bench/hnsw_server.log")
|
||||
|
||||
|
||||
def load_chunks_from_files(paths: list[Path], limit: int | None = None) -> list[str]:
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
|
||||
documents = []
|
||||
for path in paths:
|
||||
p = path.expanduser().resolve()
|
||||
if not p.exists():
|
||||
raise FileNotFoundError(f"Input path not found: {p}")
|
||||
if p.is_dir():
|
||||
reader = SimpleDirectoryReader(str(p), recursive=False)
|
||||
documents.extend(reader.load_data(show_progress=True))
|
||||
else:
|
||||
reader = SimpleDirectoryReader(input_files=[str(p)])
|
||||
documents.extend(reader.load_data(show_progress=True))
|
||||
|
||||
if not documents:
|
||||
return []
|
||||
|
||||
chunks = create_text_chunks(
|
||||
documents,
|
||||
chunk_size=512,
|
||||
chunk_overlap=128,
|
||||
use_ast_chunking=False,
|
||||
)
|
||||
cleaned = [c for c in chunks if isinstance(c, str) and c.strip()]
|
||||
if limit is not None:
|
||||
cleaned = cleaned[:limit]
|
||||
return cleaned
|
||||
|
||||
|
||||
def ensure_index_dir(index_path: Path) -> None:
|
||||
index_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def cleanup_index_files(index_path: Path) -> None:
|
||||
parent = index_path.parent
|
||||
if not parent.exists():
|
||||
return
|
||||
stem = index_path.stem
|
||||
for file in parent.glob(f"{stem}*"):
|
||||
if file.is_file():
|
||||
file.unlink()
|
||||
|
||||
|
||||
def build_initial_index(
|
||||
index_path: Path,
|
||||
paragraphs: list[str],
|
||||
model_name: str,
|
||||
embedding_mode: str,
|
||||
distance_metric: str,
|
||||
ef_construction: int,
|
||||
) -> None:
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model=model_name,
|
||||
embedding_mode=embedding_mode,
|
||||
is_compact=False,
|
||||
is_recompute=True,
|
||||
distance_metric=distance_metric,
|
||||
backend_kwargs={
|
||||
"distance_metric": distance_metric,
|
||||
"is_compact": False,
|
||||
"is_recompute": True,
|
||||
"efConstruction": ef_construction,
|
||||
},
|
||||
)
|
||||
for idx, passage in enumerate(paragraphs):
|
||||
builder.add_text(passage, metadata={"id": str(idx)})
|
||||
builder.build_index(str(index_path))
|
||||
|
||||
|
||||
def prepare_new_chunks(paragraphs: list[str]) -> list[dict[str, Any]]:
|
||||
return [{"text": text, "metadata": {}} for text in paragraphs]
|
||||
|
||||
|
||||
def benchmark_update_with_mode(
|
||||
index_path: Path,
|
||||
new_chunks: list[dict[str, Any]],
|
||||
model_name: str,
|
||||
embedding_mode: str,
|
||||
distance_metric: str,
|
||||
disable_forward_rng: bool,
|
||||
disable_reverse_rng: bool,
|
||||
server_port: int,
|
||||
add_timeout: int,
|
||||
ef_construction: int,
|
||||
) -> tuple[float, float]:
|
||||
meta_path = index_path.parent / f"{index_path.name}.meta.json"
|
||||
passages_file = index_path.parent / f"{index_path.name}.passages.jsonl"
|
||||
offset_file = index_path.parent / f"{index_path.name}.passages.idx"
|
||||
index_file = index_path.parent / f"{index_path.stem}.index"
|
||||
|
||||
with open(meta_path, encoding="utf-8") as f:
|
||||
meta = json.load(f)
|
||||
|
||||
with open(offset_file, "rb") as f:
|
||||
offset_map: dict[str, int] = pickle.load(f)
|
||||
existing_ids = set(offset_map.keys())
|
||||
|
||||
valid_chunks: list[dict[str, Any]] = []
|
||||
for chunk in new_chunks:
|
||||
text = chunk.get("text", "")
|
||||
if not isinstance(text, str) or not text.strip():
|
||||
continue
|
||||
metadata = chunk.setdefault("metadata", {})
|
||||
passage_id = chunk.get("id") or metadata.get("id")
|
||||
if passage_id and passage_id in existing_ids:
|
||||
raise ValueError(f"Passage ID '{passage_id}' already exists in the index.")
|
||||
valid_chunks.append(chunk)
|
||||
|
||||
if not valid_chunks:
|
||||
raise ValueError("No valid chunks to append.")
|
||||
|
||||
texts_to_embed = [chunk["text"] for chunk in valid_chunks]
|
||||
embeddings = compute_embeddings(
|
||||
texts_to_embed,
|
||||
model_name,
|
||||
mode=embedding_mode,
|
||||
is_build=False,
|
||||
batch_size=16,
|
||||
)
|
||||
|
||||
embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
|
||||
if distance_metric == "cosine":
|
||||
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
|
||||
norms[norms == 0] = 1
|
||||
embeddings = embeddings / norms
|
||||
|
||||
index = faiss.read_index(str(index_file))
|
||||
index.is_recompute = True
|
||||
if getattr(index, "storage", None) is None:
|
||||
if index.metric_type == faiss.METRIC_INNER_PRODUCT:
|
||||
storage_index = faiss.IndexFlatIP(index.d)
|
||||
else:
|
||||
storage_index = faiss.IndexFlatL2(index.d)
|
||||
index.storage = storage_index
|
||||
index.own_fields = True
|
||||
try:
|
||||
storage_index.ntotal = index.ntotal
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
index.hnsw.set_disable_rng_during_add(disable_forward_rng)
|
||||
index.hnsw.set_disable_reverse_prune(disable_reverse_rng)
|
||||
if ef_construction is not None:
|
||||
index.hnsw.efConstruction = ef_construction
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
applied_forward = getattr(index.hnsw, "disable_rng_during_add", None)
|
||||
applied_reverse = getattr(index.hnsw, "disable_reverse_prune", None)
|
||||
logger.info(
|
||||
"HNSW RNG config -> requested forward=%s, reverse=%s | applied forward=%s, reverse=%s",
|
||||
disable_forward_rng,
|
||||
disable_reverse_rng,
|
||||
applied_forward,
|
||||
applied_reverse,
|
||||
)
|
||||
|
||||
base_id = index.ntotal
|
||||
for offset, chunk in enumerate(valid_chunks):
|
||||
new_id = str(base_id + offset)
|
||||
chunk.setdefault("metadata", {})["id"] = new_id
|
||||
chunk["id"] = new_id
|
||||
|
||||
rollback_size = passages_file.stat().st_size if passages_file.exists() else 0
|
||||
offset_map_backup = offset_map.copy()
|
||||
|
||||
try:
|
||||
with open(passages_file, "a", encoding="utf-8") as f:
|
||||
for chunk in valid_chunks:
|
||||
offset = f.tell()
|
||||
json.dump(
|
||||
{
|
||||
"id": chunk["id"],
|
||||
"text": chunk["text"],
|
||||
"metadata": chunk.get("metadata", {}),
|
||||
},
|
||||
f,
|
||||
ensure_ascii=False,
|
||||
)
|
||||
f.write("\n")
|
||||
offset_map[chunk["id"]] = offset
|
||||
|
||||
with open(offset_file, "wb") as f:
|
||||
pickle.dump(offset_map, f)
|
||||
|
||||
server_manager = EmbeddingServerManager(
|
||||
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
|
||||
)
|
||||
server_started, actual_port = server_manager.start_server(
|
||||
port=server_port,
|
||||
model_name=model_name,
|
||||
embedding_mode=embedding_mode,
|
||||
passages_file=str(meta_path),
|
||||
distance_metric=distance_metric,
|
||||
)
|
||||
if not server_started:
|
||||
raise RuntimeError("Failed to start embedding server.")
|
||||
|
||||
if hasattr(index.hnsw, "set_zmq_port"):
|
||||
index.hnsw.set_zmq_port(actual_port)
|
||||
elif hasattr(index, "set_zmq_port"):
|
||||
index.set_zmq_port(actual_port)
|
||||
|
||||
_warmup_embedding_server(actual_port)
|
||||
|
||||
total_start = time.time()
|
||||
add_elapsed = 0.0
|
||||
|
||||
try:
|
||||
import signal
|
||||
|
||||
def _timeout_handler(signum, frame):
|
||||
raise TimeoutError("incremental add timed out")
|
||||
|
||||
if add_timeout > 0:
|
||||
signal.signal(signal.SIGALRM, _timeout_handler)
|
||||
signal.alarm(add_timeout)
|
||||
|
||||
add_start = time.time()
|
||||
for i in range(embeddings.shape[0]):
|
||||
index.add(1, faiss.swig_ptr(embeddings[i : i + 1]))
|
||||
add_elapsed = time.time() - add_start
|
||||
if add_timeout > 0:
|
||||
signal.alarm(0)
|
||||
faiss.write_index(index, str(index_file))
|
||||
finally:
|
||||
server_manager.stop_server()
|
||||
|
||||
except TimeoutError:
|
||||
raise
|
||||
except Exception:
|
||||
if passages_file.exists():
|
||||
with open(passages_file, "rb+") as f:
|
||||
f.truncate(rollback_size)
|
||||
with open(offset_file, "wb") as f:
|
||||
pickle.dump(offset_map_backup, f)
|
||||
raise
|
||||
|
||||
prune_hnsw_embeddings_inplace(str(index_file))
|
||||
|
||||
meta["total_passages"] = len(offset_map)
|
||||
with open(meta_path, "w", encoding="utf-8") as f:
|
||||
json.dump(meta, f, indent=2)
|
||||
|
||||
# Reset toggles so the index on disk returns to baseline behaviour.
|
||||
try:
|
||||
index.hnsw.set_disable_rng_during_add(False)
|
||||
index.hnsw.set_disable_reverse_prune(False)
|
||||
except AttributeError:
|
||||
pass
|
||||
faiss.write_index(index, str(index_file))
|
||||
|
||||
total_elapsed = time.time() - total_start
|
||||
|
||||
return total_elapsed, add_elapsed
|
||||
|
||||
|
||||
def _total_zmq_nodes(log_path: Path) -> int:
|
||||
if not log_path.exists():
|
||||
return 0
|
||||
with log_path.open("r", encoding="utf-8") as log_file:
|
||||
text = log_file.read()
|
||||
return sum(int(match) for match in re.findall(r"ZMQ received (\d+) node IDs", text))
|
||||
|
||||
|
||||
def _warmup_embedding_server(port: int) -> None:
|
||||
"""Send a dummy REQ so the embedding server loads its model."""
|
||||
ctx = zmq.Context()
|
||||
try:
|
||||
sock = ctx.socket(zmq.REQ)
|
||||
sock.setsockopt(zmq.LINGER, 0)
|
||||
sock.setsockopt(zmq.RCVTIMEO, 5000)
|
||||
sock.setsockopt(zmq.SNDTIMEO, 5000)
|
||||
sock.connect(f"tcp://127.0.0.1:{port}")
|
||||
payload = msgpack.packb(["__WARMUP__"], use_bin_type=True)
|
||||
sock.send(payload)
|
||||
try:
|
||||
sock.recv()
|
||||
except zmq.error.Again:
|
||||
pass
|
||||
finally:
|
||||
sock.close()
|
||||
ctx.term()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"--index-path",
|
||||
type=Path,
|
||||
default=Path(".leann/bench/leann-demo.leann"),
|
||||
help="Output index base path (without extension).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--initial-files",
|
||||
nargs="*",
|
||||
type=Path,
|
||||
default=DEFAULT_INITIAL_FILES,
|
||||
help="Files used to build the initial index.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--update-files",
|
||||
nargs="*",
|
||||
type=Path,
|
||||
default=DEFAULT_UPDATE_FILES,
|
||||
help="Files appended during the benchmark.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--runs", type=int, default=1, help="How many times to repeat each scenario."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-name",
|
||||
default="sentence-transformers/all-MiniLM-L6-v2",
|
||||
help="Embedding model used for build/update.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--embedding-mode",
|
||||
default="sentence-transformers",
|
||||
help="Embedding mode passed to LeannBuilder/embedding server.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--distance-metric",
|
||||
default="mips",
|
||||
choices=["mips", "l2", "cosine"],
|
||||
help="Distance metric for HNSW backend.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ef-construction",
|
||||
type=int,
|
||||
default=200,
|
||||
help="efConstruction setting for initial build.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--server-port",
|
||||
type=int,
|
||||
default=5557,
|
||||
help="Port for the real embedding server.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-initial",
|
||||
type=int,
|
||||
default=300,
|
||||
help="Optional cap on initial passages (after chunking).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-updates",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Optional cap on update passages (after chunking).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--add-timeout",
|
||||
type=int,
|
||||
default=900,
|
||||
help="Timeout in seconds for the incremental add loop (0 = no timeout).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--plot-path",
|
||||
type=Path,
|
||||
default=Path("bench_latency.png"),
|
||||
help="Where to save the latency bar plot.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cap-y",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Cap Y-axis (ms). Bars above are hatched and annotated.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--broken-y",
|
||||
action="store_true",
|
||||
help="Use broken Y-axis (two stacked axes with gap). Overrides --cap-y unless both provided.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lower-cap-y",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Lower axes upper bound for broken Y (ms). Default=1.1x second-highest.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--upper-start-y",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Upper axes lower bound for broken Y (ms). Default=1.2x second-highest.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--csv-path",
|
||||
type=Path,
|
||||
default=Path("benchmarks/update/bench_results.csv"),
|
||||
help="Where to append per-scenario results as CSV.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
register_project_directory(REPO_ROOT)
|
||||
|
||||
initial_paragraphs = load_chunks_from_files(args.initial_files, args.max_initial)
|
||||
update_paragraphs = load_chunks_from_files(args.update_files, args.max_updates)
|
||||
if not update_paragraphs:
|
||||
raise ValueError("No update passages found; please provide --update-files with content.")
|
||||
|
||||
update_chunks = prepare_new_chunks(update_paragraphs)
|
||||
ensure_index_dir(args.index_path)
|
||||
|
||||
scenarios = [
|
||||
("baseline", False, False, True),
|
||||
("no_cache_baseline", False, False, False),
|
||||
("disable_forward_rng", True, False, True),
|
||||
("disable_forward_and_reverse_rng", True, True, True),
|
||||
]
|
||||
|
||||
log_path = Path(os.environ.get("LEANN_HNSW_LOG_PATH", DEFAULT_HNSW_LOG))
|
||||
log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
os.environ["LEANN_HNSW_LOG_PATH"] = str(log_path.resolve())
|
||||
os.environ.setdefault("LEANN_LOG_LEVEL", "INFO")
|
||||
|
||||
results_total: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
|
||||
results_add: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
|
||||
results_zmq: dict[str, list[int]] = {name: [] for name, *_ in scenarios}
|
||||
results_stageA: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
|
||||
results_stageBC: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
|
||||
results_ms_per_passage: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
|
||||
|
||||
# CSV setup
|
||||
import csv
|
||||
|
||||
run_id = time.strftime("%Y%m%d-%H%M%S")
|
||||
csv_fields = [
|
||||
"run_id",
|
||||
"scenario",
|
||||
"cache_enabled",
|
||||
"ef_construction",
|
||||
"max_initial",
|
||||
"max_updates",
|
||||
"total_time_s",
|
||||
"add_only_s",
|
||||
"latency_ms_per_passage",
|
||||
"zmq_nodes",
|
||||
"stageA_time_s",
|
||||
"stageBC_time_s",
|
||||
"model_name",
|
||||
"embedding_mode",
|
||||
"distance_metric",
|
||||
]
|
||||
# Create CSV with header if missing
|
||||
if args.csv_path:
|
||||
args.csv_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
if not args.csv_path.exists() or args.csv_path.stat().st_size == 0:
|
||||
with args.csv_path.open("w", newline="", encoding="utf-8") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=csv_fields)
|
||||
writer.writeheader()
|
||||
|
||||
for run in range(args.runs):
|
||||
print(f"\n=== Benchmark run {run + 1}/{args.runs} ===")
|
||||
for name, disable_forward, disable_reverse, cache_enabled in scenarios:
|
||||
print(f"\nScenario: {name}")
|
||||
cleanup_index_files(args.index_path)
|
||||
if log_path.exists():
|
||||
try:
|
||||
log_path.unlink()
|
||||
except OSError:
|
||||
pass
|
||||
os.environ["LEANN_ZMQ_EMBED_CACHE"] = "1" if cache_enabled else "0"
|
||||
build_initial_index(
|
||||
args.index_path,
|
||||
initial_paragraphs,
|
||||
args.model_name,
|
||||
args.embedding_mode,
|
||||
args.distance_metric,
|
||||
args.ef_construction,
|
||||
)
|
||||
|
||||
prev_size = log_path.stat().st_size if log_path.exists() else 0
|
||||
|
||||
try:
|
||||
total_elapsed, add_elapsed = benchmark_update_with_mode(
|
||||
args.index_path,
|
||||
update_chunks,
|
||||
args.model_name,
|
||||
args.embedding_mode,
|
||||
args.distance_metric,
|
||||
disable_forward,
|
||||
disable_reverse,
|
||||
args.server_port,
|
||||
args.add_timeout,
|
||||
args.ef_construction,
|
||||
)
|
||||
except TimeoutError as exc:
|
||||
print(f"Scenario {name} timed out: {exc}")
|
||||
continue
|
||||
|
||||
curr_size = log_path.stat().st_size if log_path.exists() else 0
|
||||
if curr_size < prev_size:
|
||||
prev_size = 0
|
||||
zmq_count = 0
|
||||
if log_path.exists():
|
||||
with log_path.open("r", encoding="utf-8") as log_file:
|
||||
log_file.seek(prev_size)
|
||||
new_entries = log_file.read()
|
||||
zmq_count = sum(
|
||||
int(match) for match in re.findall(r"ZMQ received (\d+) node IDs", new_entries)
|
||||
)
|
||||
stageA = sum(
|
||||
float(x)
|
||||
for x in re.findall(r"Distance calculation E2E time: ([0-9.]+)s", new_entries)
|
||||
)
|
||||
stageBC = sum(
|
||||
float(x) for x in re.findall(r"ZMQ E2E time: ([0-9.]+)s", new_entries)
|
||||
)
|
||||
else:
|
||||
stageA = 0.0
|
||||
stageBC = 0.0
|
||||
|
||||
per_chunk = add_elapsed / len(update_chunks)
|
||||
print(
|
||||
f"Total time: {total_elapsed:.3f} s | add-only: {add_elapsed:.3f} s "
|
||||
f"for {len(update_chunks)} passages => {per_chunk * 1e3:.3f} ms/passage"
|
||||
)
|
||||
print(f"ZMQ node fetch total: {zmq_count}")
|
||||
results_total[name].append(total_elapsed)
|
||||
results_add[name].append(add_elapsed)
|
||||
results_zmq[name].append(zmq_count)
|
||||
results_ms_per_passage[name].append(per_chunk * 1e3)
|
||||
results_stageA[name].append(stageA)
|
||||
results_stageBC[name].append(stageBC)
|
||||
|
||||
# Append row to CSV
|
||||
if args.csv_path:
|
||||
row = {
|
||||
"run_id": run_id,
|
||||
"scenario": name,
|
||||
"cache_enabled": 1 if cache_enabled else 0,
|
||||
"ef_construction": args.ef_construction,
|
||||
"max_initial": args.max_initial,
|
||||
"max_updates": args.max_updates,
|
||||
"total_time_s": round(total_elapsed, 6),
|
||||
"add_only_s": round(add_elapsed, 6),
|
||||
"latency_ms_per_passage": round(per_chunk * 1e3, 6),
|
||||
"zmq_nodes": int(zmq_count),
|
||||
"stageA_time_s": round(stageA, 6),
|
||||
"stageBC_time_s": round(stageBC, 6),
|
||||
"model_name": args.model_name,
|
||||
"embedding_mode": args.embedding_mode,
|
||||
"distance_metric": args.distance_metric,
|
||||
}
|
||||
with args.csv_path.open("a", newline="", encoding="utf-8") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=csv_fields)
|
||||
writer.writerow(row)
|
||||
|
||||
print("\n=== Summary ===")
|
||||
for name in results_add:
|
||||
add_values = results_add[name]
|
||||
total_values = results_total[name]
|
||||
zmq_values = results_zmq[name]
|
||||
latency_values = results_ms_per_passage[name]
|
||||
if not add_values:
|
||||
print(f"{name}: no successful runs")
|
||||
continue
|
||||
avg_add = sum(add_values) / len(add_values)
|
||||
avg_total = sum(total_values) / len(total_values)
|
||||
avg_zmq = sum(zmq_values) / len(zmq_values) if zmq_values else 0.0
|
||||
avg_latency = sum(latency_values) / len(latency_values) if latency_values else 0.0
|
||||
runs = len(add_values)
|
||||
print(
|
||||
f"{name}: add-only avg {avg_add:.3f} s | total avg {avg_total:.3f} s "
|
||||
f"| ZMQ avg {avg_zmq:.1f} node fetches | latency {avg_latency:.2f} ms/passage over {runs} run(s)"
|
||||
)
|
||||
|
||||
if args.plot_path:
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
labels = [name for name, *_ in scenarios]
|
||||
values = [
|
||||
sum(results_ms_per_passage[name]) / len(results_ms_per_passage[name])
|
||||
if results_ms_per_passage[name]
|
||||
else 0.0
|
||||
for name in labels
|
||||
]
|
||||
|
||||
def _auto_cap(vals: list[float]) -> float | None:
|
||||
s = sorted(vals, reverse=True)
|
||||
if len(s) < 2:
|
||||
return None
|
||||
if s[1] > 0 and s[0] >= 2.5 * s[1]:
|
||||
return s[1] * 1.1
|
||||
return None
|
||||
|
||||
def _fmt_ms(v: float) -> str:
|
||||
return f"{v / 1000:.1f}k" if v >= 1000 else f"{v:.1f}"
|
||||
|
||||
colors = ["#4e79a7", "#f28e2c", "#e15759", "#76b7b2"]
|
||||
|
||||
if args.broken_y:
|
||||
s = sorted(values, reverse=True)
|
||||
second = s[1] if len(s) >= 2 else (s[0] if s else 0.0)
|
||||
lower_cap = args.lower_cap_y if args.lower_cap_y is not None else second * 1.1
|
||||
upper_start = (
|
||||
args.upper_start_y
|
||||
if args.upper_start_y is not None
|
||||
else max(second * 1.2, lower_cap * 1.02)
|
||||
)
|
||||
ymax = max(values) * 1.10 if values else 1.0
|
||||
fig, (ax_top, ax_bottom) = plt.subplots(
|
||||
2,
|
||||
1,
|
||||
sharex=True,
|
||||
figsize=(7.4, 5.0),
|
||||
gridspec_kw={"height_ratios": [1, 3], "hspace": 0.05},
|
||||
)
|
||||
x = list(range(len(labels)))
|
||||
ax_bottom.bar(x, values, color=colors[: len(labels)], width=0.8)
|
||||
ax_top.bar(x, values, color=colors[: len(labels)], width=0.8)
|
||||
ax_bottom.set_ylim(0, lower_cap)
|
||||
ax_top.set_ylim(upper_start, ymax)
|
||||
for i, v in enumerate(values):
|
||||
if v <= lower_cap:
|
||||
ax_bottom.text(
|
||||
i,
|
||||
v + lower_cap * 0.02,
|
||||
_fmt_ms(v),
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontsize=9,
|
||||
)
|
||||
else:
|
||||
ax_top.text(i, v, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
|
||||
ax_top.spines["bottom"].set_visible(False)
|
||||
ax_bottom.spines["top"].set_visible(False)
|
||||
ax_top.tick_params(labeltop=False)
|
||||
ax_bottom.xaxis.tick_bottom()
|
||||
d = 0.015
|
||||
kwargs = {"transform": ax_top.transAxes, "color": "k", "clip_on": False}
|
||||
ax_top.plot((-d, +d), (-d, +d), **kwargs)
|
||||
ax_top.plot((1 - d, 1 + d), (-d, +d), **kwargs)
|
||||
kwargs.update({"transform": ax_bottom.transAxes})
|
||||
ax_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs)
|
||||
ax_bottom.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs)
|
||||
ax_bottom.set_xticks(range(len(labels)))
|
||||
ax_bottom.set_xticklabels(labels)
|
||||
ax = ax_bottom
|
||||
else:
|
||||
cap = args.cap_y or _auto_cap(values)
|
||||
plt.figure(figsize=(7.2, 4.2))
|
||||
ax = plt.gca()
|
||||
if cap is not None:
|
||||
show_vals = [min(v, cap) for v in values]
|
||||
bars = []
|
||||
for i, (v, show) in enumerate(zip(values, show_vals)):
|
||||
b = ax.bar(i, show, color=colors[i], width=0.8)
|
||||
bars.append(b[0])
|
||||
if v > cap:
|
||||
bars[-1].set_hatch("//")
|
||||
ax.text(i, cap * 1.02, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
|
||||
else:
|
||||
ax.text(
|
||||
i,
|
||||
show + max(1.0, 0.01 * (cap or show)),
|
||||
_fmt_ms(v),
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontsize=9,
|
||||
)
|
||||
ax.set_ylim(0, cap * 1.10)
|
||||
ax.plot(
|
||||
[0.02 - 0.02, 0.02 + 0.02],
|
||||
[0.98 + 0.02, 0.98 - 0.02],
|
||||
transform=ax.transAxes,
|
||||
color="k",
|
||||
lw=1,
|
||||
)
|
||||
ax.plot(
|
||||
[0.98 - 0.02, 0.98 + 0.02],
|
||||
[0.98 + 0.02, 0.98 - 0.02],
|
||||
transform=ax.transAxes,
|
||||
color="k",
|
||||
lw=1,
|
||||
)
|
||||
if any(v > cap for v in values):
|
||||
ax.legend(
|
||||
[bars[0]], ["capped"], fontsize=8, frameon=False, loc="upper right"
|
||||
)
|
||||
ax.set_xticks(range(len(labels)))
|
||||
ax.set_xticklabels(labels)
|
||||
else:
|
||||
ax.bar(labels, values, color=colors[: len(labels)])
|
||||
for idx, val in enumerate(values):
|
||||
ax.text(idx, val + 1.0, f"{val:.1f}", ha="center", va="bottom")
|
||||
|
||||
plt.ylabel("Average add latency (ms per passage)")
|
||||
plt.title(f"Initial passages {args.max_initial}, updates {args.max_updates}")
|
||||
plt.tight_layout()
|
||||
plt.savefig(args.plot_path)
|
||||
print(f"Saved latency bar plot to {args.plot_path}")
|
||||
# ZMQ time split (Stage A vs B/C)
|
||||
try:
|
||||
plt.figure(figsize=(6, 4))
|
||||
a_vals = [sum(results_stageA[n]) / max(1, len(results_stageA[n])) for n in labels]
|
||||
bc_vals = [
|
||||
sum(results_stageBC[n]) / max(1, len(results_stageBC[n])) for n in labels
|
||||
]
|
||||
ind = range(len(labels))
|
||||
plt.bar(ind, a_vals, color="#4e79a7", label="Stage A distance (s)")
|
||||
plt.bar(
|
||||
ind, bc_vals, bottom=a_vals, color="#e15759", label="Stage B/C embed-by-id (s)"
|
||||
)
|
||||
plt.xticks(list(ind), labels, rotation=10)
|
||||
plt.ylabel("Server ZMQ time (s)")
|
||||
plt.title(
|
||||
f"ZMQ time split (initial {args.max_initial}, updates {args.max_updates})"
|
||||
)
|
||||
plt.legend()
|
||||
out2 = args.plot_path.with_name(
|
||||
args.plot_path.stem + "_zmq_split" + args.plot_path.suffix
|
||||
)
|
||||
plt.tight_layout()
|
||||
plt.savefig(out2)
|
||||
print(f"Saved ZMQ time split plot to {out2}")
|
||||
except Exception as e:
|
||||
print("Failed to plot ZMQ split:", e)
|
||||
except ImportError:
|
||||
print("matplotlib not available; skipping plot generation")
|
||||
|
||||
# leave the last build on disk for inspection
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
5
benchmarks/update/bench_results.csv
Normal file
@@ -0,0 +1,5 @@
|
||||
run_id,scenario,cache_enabled,ef_construction,max_initial,max_updates,total_time_s,add_only_s,latency_ms_per_passage,zmq_nodes,stageA_time_s,stageBC_time_s,model_name,embedding_mode,distance_metric
|
||||
20251024-133101,baseline,1,200,300,1,3.391856,1.120359,1120.359421,126,0.507821,0.601608,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
||||
20251024-133101,no_cache_baseline,0,200,300,1,34.941514,32.91376,32913.760185,4033,0.506933,32.159928,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
||||
20251024-133101,disable_forward_rng,1,200,300,1,2.746756,0.8202,820.200443,66,0.474354,0.338454,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
||||
20251024-133101,disable_forward_and_reverse_rng,1,200,300,1,2.396566,0.521478,521.478415,1,0.508973,0.006938,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
||||
|
704
benchmarks/update/bench_update_vs_offline_search.py
Normal file
@@ -0,0 +1,704 @@
|
||||
"""
|
||||
Compare two latency models for small incremental updates vs. search:
|
||||
|
||||
Scenario A (sequential update then search):
|
||||
- Build initial HNSW (is_recompute=True)
|
||||
- Start embedding server (ZMQ) for recompute
|
||||
- Add N passages one-by-one (each triggers recompute over ZMQ)
|
||||
- Then run a search query on the updated index
|
||||
- Report total time = sum(add_i) + search_time, with breakdowns
|
||||
|
||||
Scenario B (offline embeds + concurrent search; no graph updates):
|
||||
- Do NOT insert the N passages into the graph
|
||||
- In parallel: (1) compute embeddings for the N passages; (2) compute query
|
||||
embedding and run a search on the existing index
|
||||
- After both finish, compute similarity between the query embedding and the N
|
||||
new passage embeddings, merge with the index search results by score, and
|
||||
report time = max(embed_time, search_time) (i.e., no blocking on updates)
|
||||
|
||||
This script reuses the model/data loading conventions of
|
||||
examples/bench_hnsw_rng_recompute.py but focuses on end-to-end latency
|
||||
comparison for the two execution strategies above.
|
||||
|
||||
Example (from the repository root):
|
||||
uv run -m benchmarks.update.bench_update_vs_offline_search \
|
||||
--index-path .leann/bench/offline_vs_update.leann \
|
||||
--max-initial 300 --num-updates 5 --k 10
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import psutil # type: ignore
|
||||
from leann.api import LeannBuilder
|
||||
|
||||
if os.environ.get("LEANN_FORCE_CPU", "").lower() in ("1", "true", "yes"):
|
||||
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
|
||||
|
||||
from leann.embedding_compute import compute_embeddings
|
||||
from leann.embedding_server_manager import EmbeddingServerManager
|
||||
from leann.registry import register_project_directory
|
||||
from leann_backend_hnsw import faiss # type: ignore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
if not logging.getLogger().handlers:
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
def _find_repo_root() -> Path:
|
||||
"""Locate project root by walking up until pyproject.toml is found."""
|
||||
current = Path(__file__).resolve()
|
||||
for parent in current.parents:
|
||||
if (parent / "pyproject.toml").exists():
|
||||
return parent
|
||||
# Fallback: assume repo is two levels up (../..)
|
||||
return current.parents[2]
|
||||
|
||||
|
||||
REPO_ROOT = _find_repo_root()
|
||||
if str(REPO_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(REPO_ROOT))
|
||||
|
||||
from apps.chunking import create_text_chunks # noqa: E402
|
||||
|
||||
DEFAULT_INITIAL_FILES = [
|
||||
REPO_ROOT / "data" / "2501.14312v1 (1).pdf",
|
||||
REPO_ROOT / "data" / "huawei_pangu.md",
|
||||
]
|
||||
DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"]
|
||||
|
||||
|
||||
def load_chunks_from_files(paths: list[Path], limit: int | None = None) -> list[str]:
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
|
||||
documents = []
|
||||
for path in paths:
|
||||
p = path.expanduser().resolve()
|
||||
if not p.exists():
|
||||
raise FileNotFoundError(f"Input path not found: {p}")
|
||||
if p.is_dir():
|
||||
reader = SimpleDirectoryReader(str(p), recursive=False)
|
||||
documents.extend(reader.load_data(show_progress=True))
|
||||
else:
|
||||
reader = SimpleDirectoryReader(input_files=[str(p)])
|
||||
documents.extend(reader.load_data(show_progress=True))
|
||||
|
||||
if not documents:
|
||||
return []
|
||||
|
||||
chunks = create_text_chunks(
|
||||
documents,
|
||||
chunk_size=512,
|
||||
chunk_overlap=128,
|
||||
use_ast_chunking=False,
|
||||
)
|
||||
cleaned = [c for c in chunks if isinstance(c, str) and c.strip()]
|
||||
if limit is not None:
|
||||
cleaned = cleaned[:limit]
|
||||
return cleaned
|
||||
|
||||
|
||||
def ensure_index_dir(index_path: Path) -> None:
|
||||
index_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def cleanup_index_files(index_path: Path) -> None:
|
||||
parent = index_path.parent
|
||||
if not parent.exists():
|
||||
return
|
||||
stem = index_path.stem
|
||||
for file in parent.glob(f"{stem}*"):
|
||||
if file.is_file():
|
||||
file.unlink()
|
||||
|
||||
|
||||
def build_initial_index(
|
||||
index_path: Path,
|
||||
paragraphs: list[str],
|
||||
model_name: str,
|
||||
embedding_mode: str,
|
||||
distance_metric: str,
|
||||
ef_construction: int,
|
||||
) -> None:
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model=model_name,
|
||||
embedding_mode=embedding_mode,
|
||||
is_compact=False,
|
||||
is_recompute=True,
|
||||
distance_metric=distance_metric,
|
||||
backend_kwargs={
|
||||
"distance_metric": distance_metric,
|
||||
"is_compact": False,
|
||||
"is_recompute": True,
|
||||
"efConstruction": ef_construction,
|
||||
},
|
||||
)
|
||||
for idx, passage in enumerate(paragraphs):
|
||||
builder.add_text(passage, metadata={"id": str(idx)})
|
||||
builder.build_index(str(index_path))
|
||||
|
||||
|
||||
def _maybe_norm_cosine(vecs: np.ndarray, metric: str) -> np.ndarray:
|
||||
if metric == "cosine":
|
||||
vecs = np.ascontiguousarray(vecs, dtype=np.float32)
|
||||
norms = np.linalg.norm(vecs, axis=1, keepdims=True)
|
||||
norms[norms == 0] = 1
|
||||
vecs = vecs / norms
|
||||
return vecs
|
||||
|
||||
|
||||
def _read_index_for_search(index_path: Path) -> Any:
|
||||
index_file = index_path.parent / f"{index_path.stem}.index"
|
||||
# Force-disable experimental disk cache when loading the index so that
|
||||
# incremental benchmarks don't pick up stale top-degree bitmaps.
|
||||
cfg = faiss.HNSWIndexConfig()
|
||||
cfg.is_recompute = True
|
||||
if hasattr(cfg, "disk_cache_ratio"):
|
||||
cfg.disk_cache_ratio = 0.0
|
||||
if hasattr(cfg, "external_storage_path"):
|
||||
cfg.external_storage_path = None
|
||||
io_flags = getattr(faiss, "IO_FLAG_MMAP", 0)
|
||||
index = faiss.read_index(str(index_file), io_flags, cfg)
|
||||
# ensure recompute mode persists after reload
|
||||
try:
|
||||
index.is_recompute = True
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
actual_ntotal = index.hnsw.levels.size()
|
||||
except AttributeError:
|
||||
actual_ntotal = index.ntotal
|
||||
if actual_ntotal != index.ntotal:
|
||||
print(
|
||||
f"[bench_update_vs_offline_search] Correcting ntotal from {index.ntotal} to {actual_ntotal}",
|
||||
flush=True,
|
||||
)
|
||||
index.ntotal = actual_ntotal
|
||||
if getattr(index, "storage", None) is None:
|
||||
if index.metric_type == faiss.METRIC_INNER_PRODUCT:
|
||||
storage_index = faiss.IndexFlatIP(index.d)
|
||||
else:
|
||||
storage_index = faiss.IndexFlatL2(index.d)
|
||||
index.storage = storage_index
|
||||
index.own_fields = True
|
||||
return index
|
||||
|
||||
|
||||
def _append_passages_for_updates(
|
||||
meta_path: Path,
|
||||
start_id: int,
|
||||
texts: list[str],
|
||||
) -> list[str]:
|
||||
"""Append update passages so the embedding server can serve recompute fetches."""
|
||||
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
index_dir = meta_path.parent
|
||||
meta_name = meta_path.name
|
||||
if not meta_name.endswith(".meta.json"):
|
||||
raise ValueError(f"Unexpected meta filename: {meta_path}")
|
||||
index_base = meta_name[: -len(".meta.json")]
|
||||
|
||||
passages_file = index_dir / f"{index_base}.passages.jsonl"
|
||||
offsets_file = index_dir / f"{index_base}.passages.idx"
|
||||
|
||||
if not passages_file.exists() or not offsets_file.exists():
|
||||
raise FileNotFoundError(
|
||||
"Passage store missing; cannot register update passages for recompute mode."
|
||||
)
|
||||
|
||||
with open(offsets_file, "rb") as f:
|
||||
offset_map: dict[str, int] = pickle.load(f)
|
||||
|
||||
assigned_ids: list[str] = []
|
||||
with open(passages_file, "a", encoding="utf-8") as f:
|
||||
for i, text in enumerate(texts):
|
||||
passage_id = str(start_id + i)
|
||||
offset = f.tell()
|
||||
json.dump({"id": passage_id, "text": text, "metadata": {}}, f, ensure_ascii=False)
|
||||
f.write("\n")
|
||||
offset_map[passage_id] = offset
|
||||
assigned_ids.append(passage_id)
|
||||
|
||||
with open(offsets_file, "wb") as f:
|
||||
pickle.dump(offset_map, f)
|
||||
|
||||
try:
|
||||
with open(meta_path, encoding="utf-8") as f:
|
||||
meta = json.load(f)
|
||||
except json.JSONDecodeError:
|
||||
meta = {}
|
||||
meta["total_passages"] = len(offset_map)
|
||||
with open(meta_path, "w", encoding="utf-8") as f:
|
||||
json.dump(meta, f, indent=2)
|
||||
|
||||
return assigned_ids
|
||||
|
||||
|
||||
def _search(index: Any, q: np.ndarray, k: int) -> tuple[np.ndarray, np.ndarray]:
|
||||
q = np.ascontiguousarray(q, dtype=np.float32)
|
||||
distances = np.zeros((1, k), dtype=np.float32)
|
||||
indices = np.zeros((1, k), dtype=np.int64)
|
||||
index.search(
|
||||
1,
|
||||
faiss.swig_ptr(q),
|
||||
k,
|
||||
faiss.swig_ptr(distances),
|
||||
faiss.swig_ptr(indices),
|
||||
)
|
||||
return distances[0], indices[0]
|
||||
|
||||
|
||||
def _score_for_metric(dist: float, metric: str) -> float:
|
||||
# Convert FAISS distance to a "higher is better" score
|
||||
if metric in ("mips", "cosine"):
|
||||
return float(dist)
|
||||
# l2 distance (smaller better) -> negative distance as score
|
||||
return -float(dist)
|
||||
|
||||
|
||||
def _merge_results(
|
||||
index_results: tuple[np.ndarray, np.ndarray],
|
||||
offline_scores: list[tuple[int, float]],
|
||||
k: int,
|
||||
metric: str,
|
||||
) -> list[tuple[str, float]]:
|
||||
distances, indices = index_results
|
||||
merged: list[tuple[str, float]] = []
|
||||
for distance, idx in zip(distances.tolist(), indices.tolist()):
|
||||
merged.append((f"idx:{idx}", _score_for_metric(distance, metric)))
|
||||
for j, s in offline_scores:
|
||||
merged.append((f"offline:{j}", s))
|
||||
merged.sort(key=lambda x: x[1], reverse=True)
|
||||
return merged[:k]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScenarioResult:
|
||||
name: str
|
||||
update_total_s: float
|
||||
search_s: float
|
||||
overall_s: float
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"--index-path",
|
||||
type=Path,
|
||||
default=Path(".leann/bench/offline-vs-update.leann"),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--initial-files",
|
||||
nargs="*",
|
||||
type=Path,
|
||||
default=DEFAULT_INITIAL_FILES,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--update-files",
|
||||
nargs="*",
|
||||
type=Path,
|
||||
default=DEFAULT_UPDATE_FILES,
|
||||
)
|
||||
parser.add_argument("--max-initial", type=int, default=300)
|
||||
parser.add_argument("--num-updates", type=int, default=5)
|
||||
parser.add_argument("--k", type=int, default=10, help="Top-k for search/merge")
|
||||
parser.add_argument(
|
||||
"--query",
|
||||
type=str,
|
||||
default="neural network",
|
||||
help="Query text used for the search benchmark.",
|
||||
)
|
||||
parser.add_argument("--server-port", type=int, default=5557)
|
||||
parser.add_argument("--add-timeout", type=int, default=600)
|
||||
parser.add_argument("--model-name", default="sentence-transformers/all-MiniLM-L6-v2")
|
||||
parser.add_argument("--embedding-mode", default="sentence-transformers")
|
||||
parser.add_argument(
|
||||
"--distance-metric",
|
||||
default="mips",
|
||||
choices=["mips", "l2", "cosine"],
|
||||
)
|
||||
parser.add_argument("--ef-construction", type=int, default=200)
|
||||
parser.add_argument(
|
||||
"--only",
|
||||
choices=["A", "B", "both"],
|
||||
default="both",
|
||||
help="Run only Scenario A, Scenario B, or both",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--csv-path",
|
||||
type=Path,
|
||||
default=Path("benchmarks/update/offline_vs_update.csv"),
|
||||
help="Where to append results (CSV).",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
register_project_directory(REPO_ROOT)
|
||||
|
||||
# Load data
|
||||
initial_paragraphs = load_chunks_from_files(args.initial_files, args.max_initial)
|
||||
update_paragraphs = load_chunks_from_files(args.update_files, None)
|
||||
if not update_paragraphs:
|
||||
raise ValueError("No update passages loaded from --update-files")
|
||||
update_paragraphs = update_paragraphs[: args.num_updates]
|
||||
if len(update_paragraphs) < args.num_updates:
|
||||
raise ValueError(
|
||||
f"Not enough update passages ({len(update_paragraphs)}) for --num-updates={args.num_updates}"
|
||||
)
|
||||
|
||||
ensure_index_dir(args.index_path)
|
||||
cleanup_index_files(args.index_path)
|
||||
|
||||
# Build initial index
|
||||
build_initial_index(
|
||||
args.index_path,
|
||||
initial_paragraphs,
|
||||
args.model_name,
|
||||
args.embedding_mode,
|
||||
args.distance_metric,
|
||||
args.ef_construction,
|
||||
)
|
||||
|
||||
# Prepare index object and meta
|
||||
meta_path = args.index_path.parent / f"{args.index_path.name}.meta.json"
|
||||
index = _read_index_for_search(args.index_path)
|
||||
|
||||
# CSV setup
|
||||
run_id = time.strftime("%Y%m%d-%H%M%S")
|
||||
if args.csv_path:
|
||||
args.csv_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
csv_fields = [
|
||||
"run_id",
|
||||
"scenario",
|
||||
"max_initial",
|
||||
"num_updates",
|
||||
"k",
|
||||
"total_time_s",
|
||||
"add_total_s",
|
||||
"search_time_s",
|
||||
"emb_time_s",
|
||||
"makespan_s",
|
||||
"model_name",
|
||||
"embedding_mode",
|
||||
"distance_metric",
|
||||
]
|
||||
if not args.csv_path.exists() or args.csv_path.stat().st_size == 0:
|
||||
with args.csv_path.open("w", newline="", encoding="utf-8") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=csv_fields)
|
||||
writer.writeheader()
|
||||
|
||||
# Debug: list existing HNSW server PIDs before starting
|
||||
try:
|
||||
existing = [
|
||||
p
|
||||
for p in psutil.process_iter(attrs=["pid", "cmdline"])
|
||||
if any(
|
||||
isinstance(arg, str) and "leann_backend_hnsw.hnsw_embedding_server" in arg
|
||||
for arg in (p.info.get("cmdline") or [])
|
||||
)
|
||||
]
|
||||
if existing:
|
||||
print("[debug] Found existing hnsw_embedding_server processes before run:")
|
||||
for p in existing:
|
||||
print(f"[debug] PID={p.info['pid']} cmd={' '.join(p.info.get('cmdline') or [])}")
|
||||
except Exception as _e:
|
||||
pass
|
||||
|
||||
add_total = 0.0
|
||||
search_after_add = 0.0
|
||||
total_seq = 0.0
|
||||
port_a = None
|
||||
if args.only in ("A", "both"):
|
||||
# Scenario A: sequential update then search
|
||||
start_id = index.ntotal
|
||||
assigned_ids = _append_passages_for_updates(meta_path, start_id, update_paragraphs)
|
||||
if assigned_ids:
|
||||
logger.debug(
|
||||
"Registered %d update passages starting at id %s",
|
||||
len(assigned_ids),
|
||||
assigned_ids[0],
|
||||
)
|
||||
server_manager = EmbeddingServerManager(
|
||||
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
|
||||
)
|
||||
ok, port = server_manager.start_server(
|
||||
port=args.server_port,
|
||||
model_name=args.model_name,
|
||||
embedding_mode=args.embedding_mode,
|
||||
passages_file=str(meta_path),
|
||||
distance_metric=args.distance_metric,
|
||||
)
|
||||
if not ok:
|
||||
raise RuntimeError("Failed to start embedding server")
|
||||
try:
|
||||
# Set ZMQ port for recompute mode
|
||||
if hasattr(index.hnsw, "set_zmq_port"):
|
||||
index.hnsw.set_zmq_port(port)
|
||||
elif hasattr(index, "set_zmq_port"):
|
||||
index.set_zmq_port(port)
|
||||
|
||||
# Start A overall timer BEFORE computing update embeddings
|
||||
t0 = time.time()
|
||||
|
||||
# Compute embeddings for updates (counted into A's overall)
|
||||
t_emb0 = time.time()
|
||||
upd_embs = compute_embeddings(
|
||||
update_paragraphs,
|
||||
args.model_name,
|
||||
mode=args.embedding_mode,
|
||||
is_build=False,
|
||||
batch_size=16,
|
||||
)
|
||||
emb_time_updates = time.time() - t_emb0
|
||||
upd_embs = np.asarray(upd_embs, dtype=np.float32)
|
||||
upd_embs = _maybe_norm_cosine(upd_embs, args.distance_metric)
|
||||
|
||||
# Perform sequential adds
|
||||
for i in range(upd_embs.shape[0]):
|
||||
t_add0 = time.time()
|
||||
index.add(1, faiss.swig_ptr(upd_embs[i : i + 1]))
|
||||
add_total += time.time() - t_add0
|
||||
# Don't persist index after adds to avoid contaminating Scenario B
|
||||
# index_file = args.index_path.parent / f"{args.index_path.stem}.index"
|
||||
# faiss.write_index(index, str(index_file))
|
||||
|
||||
# Search after updates
|
||||
q_emb = compute_embeddings(
|
||||
[args.query], args.model_name, mode=args.embedding_mode, is_build=False
|
||||
)
|
||||
q_emb = np.asarray(q_emb, dtype=np.float32)
|
||||
q_emb = _maybe_norm_cosine(q_emb, args.distance_metric)
|
||||
|
||||
# Warm up search with a dummy query first
|
||||
print("[DEBUG] Warming up search...")
|
||||
_ = _search(index, q_emb, 1)
|
||||
|
||||
t_s0 = time.time()
|
||||
D_upd, I_upd = _search(index, q_emb, args.k)
|
||||
search_after_add = time.time() - t_s0
|
||||
total_seq = time.time() - t0
|
||||
finally:
|
||||
server_manager.stop_server()
|
||||
port_a = port
|
||||
|
||||
print("\n=== Scenario A: update->search (sequential) ===")
|
||||
# emb_time_updates is defined only when A runs
|
||||
try:
|
||||
_emb_a = emb_time_updates
|
||||
except NameError:
|
||||
_emb_a = 0.0
|
||||
print(
|
||||
f"Adds: {args.num_updates} passages; embeds={_emb_a:.3f}s; add_total={add_total:.3f}s; "
|
||||
f"search={search_after_add:.3f}s; overall={total_seq:.3f}s"
|
||||
)
|
||||
# CSV row for A
|
||||
if args.csv_path:
|
||||
row_a = {
|
||||
"run_id": run_id,
|
||||
"scenario": "A",
|
||||
"max_initial": args.max_initial,
|
||||
"num_updates": args.num_updates,
|
||||
"k": args.k,
|
||||
"total_time_s": round(total_seq, 6),
|
||||
"add_total_s": round(add_total, 6),
|
||||
"search_time_s": round(search_after_add, 6),
|
||||
"emb_time_s": round(_emb_a, 6),
|
||||
"makespan_s": 0.0,
|
||||
"model_name": args.model_name,
|
||||
"embedding_mode": args.embedding_mode,
|
||||
"distance_metric": args.distance_metric,
|
||||
}
|
||||
with args.csv_path.open("a", newline="", encoding="utf-8") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=csv_fields)
|
||||
writer.writerow(row_a)
|
||||
|
||||
# Verify server cleanup
|
||||
try:
|
||||
# short sleep to allow signal handling to finish
|
||||
time.sleep(0.5)
|
||||
leftovers = [
|
||||
p
|
||||
for p in psutil.process_iter(attrs=["pid", "cmdline"])
|
||||
if any(
|
||||
isinstance(arg, str) and "leann_backend_hnsw.hnsw_embedding_server" in arg
|
||||
for arg in (p.info.get("cmdline") or [])
|
||||
)
|
||||
]
|
||||
if leftovers:
|
||||
print("[warn] hnsw_embedding_server process(es) still alive after A-stop:")
|
||||
for p in leftovers:
|
||||
print(
|
||||
f"[warn] PID={p.info['pid']} cmd={' '.join(p.info.get('cmdline') or [])}"
|
||||
)
|
||||
else:
|
||||
print("[debug] server cleanup confirmed: no hnsw_embedding_server found")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Scenario B: offline embeds + concurrent search (no graph updates)
|
||||
if args.only in ("B", "both"):
|
||||
# ensure a server is available for recompute search
|
||||
server_manager_b = EmbeddingServerManager(
|
||||
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
|
||||
)
|
||||
requested_port = args.server_port if port_a is None else port_a
|
||||
ok_b, port_b = server_manager_b.start_server(
|
||||
port=requested_port,
|
||||
model_name=args.model_name,
|
||||
embedding_mode=args.embedding_mode,
|
||||
passages_file=str(meta_path),
|
||||
distance_metric=args.distance_metric,
|
||||
)
|
||||
if not ok_b:
|
||||
raise RuntimeError("Failed to start embedding server for Scenario B")
|
||||
|
||||
# Wait for server to fully initialize
|
||||
print("[DEBUG] Waiting 2s for embedding server to fully initialize...")
|
||||
time.sleep(2)
|
||||
|
||||
try:
|
||||
# Read the index first
|
||||
index_no_update = _read_index_for_search(args.index_path) # unchanged index
|
||||
|
||||
# Then configure ZMQ port on the correct index object
|
||||
if hasattr(index_no_update.hnsw, "set_zmq_port"):
|
||||
index_no_update.hnsw.set_zmq_port(port_b)
|
||||
elif hasattr(index_no_update, "set_zmq_port"):
|
||||
index_no_update.set_zmq_port(port_b)
|
||||
|
||||
# Warmup the embedding model before benchmarking (do this for both --only B and --only both)
|
||||
# This ensures fair comparison as Scenario A has warmed up the model during update embeddings
|
||||
logger.info("Warming up embedding model for Scenario B...")
|
||||
_ = compute_embeddings(
|
||||
["warmup text"], args.model_name, mode=args.embedding_mode, is_build=False
|
||||
)
|
||||
|
||||
# Prepare worker A: compute embeddings for the same N passages
|
||||
emb_time = 0.0
|
||||
updates_embs_offline: np.ndarray | None = None
|
||||
|
||||
def _worker_emb():
|
||||
nonlocal emb_time, updates_embs_offline
|
||||
t = time.time()
|
||||
updates_embs_offline = compute_embeddings(
|
||||
update_paragraphs,
|
||||
args.model_name,
|
||||
mode=args.embedding_mode,
|
||||
is_build=False,
|
||||
batch_size=16,
|
||||
)
|
||||
emb_time = time.time() - t
|
||||
|
||||
# Pre-compute query embedding and warm up search outside of timed section.
|
||||
q_vec = compute_embeddings(
|
||||
[args.query], args.model_name, mode=args.embedding_mode, is_build=False
|
||||
)
|
||||
q_vec = np.asarray(q_vec, dtype=np.float32)
|
||||
q_vec = _maybe_norm_cosine(q_vec, args.distance_metric)
|
||||
print("[DEBUG B] Warming up search...")
|
||||
_ = _search(index_no_update, q_vec, 1)
|
||||
|
||||
# Worker B: timed search on the warmed index
|
||||
search_time = 0.0
|
||||
offline_elapsed = 0.0
|
||||
index_results: tuple[np.ndarray, np.ndarray] | None = None
|
||||
|
||||
def _worker_search():
|
||||
nonlocal search_time, index_results
|
||||
t = time.time()
|
||||
distances, indices = _search(index_no_update, q_vec, args.k)
|
||||
search_time = time.time() - t
|
||||
index_results = (distances, indices)
|
||||
|
||||
# Run two workers concurrently
|
||||
t0 = time.time()
|
||||
th1 = threading.Thread(target=_worker_emb)
|
||||
th2 = threading.Thread(target=_worker_search)
|
||||
th1.start()
|
||||
th2.start()
|
||||
th1.join()
|
||||
th2.join()
|
||||
offline_elapsed = time.time() - t0
|
||||
|
||||
# For mixing: compute query vs. offline update similarities (pure client-side)
|
||||
offline_scores: list[tuple[int, float]] = []
|
||||
if updates_embs_offline is not None:
|
||||
upd2 = np.asarray(updates_embs_offline, dtype=np.float32)
|
||||
upd2 = _maybe_norm_cosine(upd2, args.distance_metric)
|
||||
# For mips/cosine, score = dot; for l2, score = -||x-y||^2
|
||||
for j in range(upd2.shape[0]):
|
||||
if args.distance_metric in ("mips", "cosine"):
|
||||
s = float(np.dot(q_vec[0], upd2[j]))
|
||||
else:
|
||||
diff = q_vec[0] - upd2[j]
|
||||
s = -float(np.dot(diff, diff))
|
||||
offline_scores.append((j, s))
|
||||
|
||||
merged_topk = (
|
||||
_merge_results(index_results, offline_scores, args.k, args.distance_metric)
|
||||
if index_results
|
||||
else []
|
||||
)
|
||||
|
||||
print("\n=== Scenario B: offline embeds + concurrent search (no add) ===")
|
||||
print(
|
||||
f"embeddings({args.num_updates})={emb_time:.3f}s; search={search_time:.3f}s; makespan≈{offline_elapsed:.3f}s (≈max)"
|
||||
)
|
||||
if merged_topk:
|
||||
preview = ", ".join([f"{lab}:{score:.3f}" for lab, score in merged_topk[:5]])
|
||||
print(f"Merged top-5 preview: {preview}")
|
||||
# CSV row for B
|
||||
if args.csv_path:
|
||||
row_b = {
|
||||
"run_id": run_id,
|
||||
"scenario": "B",
|
||||
"max_initial": args.max_initial,
|
||||
"num_updates": args.num_updates,
|
||||
"k": args.k,
|
||||
"total_time_s": 0.0,
|
||||
"add_total_s": 0.0,
|
||||
"search_time_s": round(search_time, 6),
|
||||
"emb_time_s": round(emb_time, 6),
|
||||
"makespan_s": round(offline_elapsed, 6),
|
||||
"model_name": args.model_name,
|
||||
"embedding_mode": args.embedding_mode,
|
||||
"distance_metric": args.distance_metric,
|
||||
}
|
||||
with args.csv_path.open("a", newline="", encoding="utf-8") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=csv_fields)
|
||||
writer.writerow(row_b)
|
||||
|
||||
finally:
|
||||
server_manager_b.stop_server()
|
||||
|
||||
# Summary
|
||||
print("\n=== Summary ===")
|
||||
msg_a = (
|
||||
f"A: seq-add+search overall={total_seq:.3f}s (adds={add_total:.3f}s, search={search_after_add:.3f}s)"
|
||||
if args.only in ("A", "both")
|
||||
else "A: skipped"
|
||||
)
|
||||
msg_b = (
|
||||
f"B: offline+concurrent overall≈{offline_elapsed:.3f}s (emb={emb_time:.3f}s, search={search_time:.3f}s)"
|
||||
if args.only in ("B", "both")
|
||||
else "B: skipped"
|
||||
)
|
||||
print(msg_a + "\n" + msg_b)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
5
benchmarks/update/offline_vs_update.csv
Normal file
@@ -0,0 +1,5 @@
|
||||
run_id,scenario,max_initial,num_updates,k,total_time_s,add_total_s,search_time_s,emb_time_s,makespan_s,model_name,embedding_mode,distance_metric
|
||||
20251024-141607,A,300,1,10,3.273957,3.050168,0.097825,0.017339,0.0,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
||||
20251024-141607,B,300,1,10,0.0,0.0,0.111892,0.007869,0.112635,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
||||
20251025-160652,A,300,5,10,5.061945,4.805962,0.123271,0.015008,0.0,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
||||
20251025-160652,B,300,5,10,0.0,0.0,0.101809,0.008817,0.102447,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
||||
|
645
benchmarks/update/plot_bench_results.py
Normal file
@@ -0,0 +1,645 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Plot latency bars from the benchmark CSV produced by
|
||||
benchmarks/update/bench_hnsw_rng_recompute.py.
|
||||
|
||||
If you also provide an offline_vs_update.csv via --csv-right
|
||||
(from benchmarks/update/bench_update_vs_offline_search.py), this script will
|
||||
output a side-by-side figure:
|
||||
- Left: ms/passage bars (four RNG scenarios).
|
||||
- Right: seconds bars (Scenario A seq add+search vs Scenario B offline+search).
|
||||
|
||||
Usage:
|
||||
uv run python benchmarks/update/plot_bench_results.py \
|
||||
--csv benchmarks/update/bench_results.csv \
|
||||
--out benchmarks/update/bench_latency_from_csv.png
|
||||
|
||||
The script selects the latest run_id in the CSV and plots four bars for
|
||||
the default scenarios:
|
||||
- baseline
|
||||
- no_cache_baseline
|
||||
- disable_forward_rng
|
||||
- disable_forward_and_reverse_rng
|
||||
|
||||
If multiple rows exist per scenario for that run_id, the script averages
|
||||
their latency_ms_per_passage values.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
DEFAULT_SCENARIOS = [
|
||||
"no_cache_baseline",
|
||||
"baseline",
|
||||
"disable_forward_rng",
|
||||
"disable_forward_and_reverse_rng",
|
||||
]
|
||||
|
||||
SCENARIO_LABELS = {
|
||||
"baseline": "+ Cache",
|
||||
"no_cache_baseline": "Naive \n Recompute",
|
||||
"disable_forward_rng": "+ w/o \n Fwd RNG",
|
||||
"disable_forward_and_reverse_rng": "+ w/o \n Bwd RNG",
|
||||
}
|
||||
|
||||
# Paper-style colors and hatches for scenarios
|
||||
SCENARIO_STYLES = {
|
||||
"no_cache_baseline": {"edgecolor": "dimgrey", "hatch": "/////"},
|
||||
"baseline": {"edgecolor": "#63B8B6", "hatch": "xxxxx"},
|
||||
"disable_forward_rng": {"edgecolor": "green", "hatch": "....."},
|
||||
"disable_forward_and_reverse_rng": {"edgecolor": "tomato", "hatch": "\\\\\\\\\\"},
|
||||
}
|
||||
|
||||
|
||||
def load_latest_run(csv_path: Path):
|
||||
rows = []
|
||||
with csv_path.open("r", encoding="utf-8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
for row in reader:
|
||||
rows.append(row)
|
||||
if not rows:
|
||||
raise SystemExit("CSV is empty: no rows to plot")
|
||||
# Choose latest run_id lexicographically (YYYYMMDD-HHMMSS)
|
||||
run_ids = [r.get("run_id", "") for r in rows]
|
||||
latest = max(run_ids)
|
||||
latest_rows = [r for r in rows if r.get("run_id", "") == latest]
|
||||
if not latest_rows:
|
||||
# Fallback: take last 4 rows
|
||||
latest_rows = rows[-4:]
|
||||
latest = latest_rows[-1].get("run_id", "unknown")
|
||||
return latest, latest_rows
|
||||
|
||||
|
||||
def aggregate_latency(rows):
|
||||
acc = defaultdict(list)
|
||||
for r in rows:
|
||||
sc = r.get("scenario", "")
|
||||
try:
|
||||
val = float(r.get("latency_ms_per_passage", "nan"))
|
||||
except ValueError:
|
||||
continue
|
||||
acc[sc].append(val)
|
||||
avg = {k: (sum(v) / len(v) if v else 0.0) for k, v in acc.items()}
|
||||
return avg
|
||||
|
||||
|
||||
def _auto_cap(values: list[float]) -> float | None:
|
||||
if not values:
|
||||
return None
|
||||
sorted_vals = sorted(values, reverse=True)
|
||||
if len(sorted_vals) < 2:
|
||||
return None
|
||||
max_v, second = sorted_vals[0], sorted_vals[1]
|
||||
if second <= 0:
|
||||
return None
|
||||
# If the tallest bar dwarfs the second by 2.5x+, cap near the second
|
||||
if max_v >= 2.5 * second:
|
||||
return second * 1.1
|
||||
return None
|
||||
|
||||
|
||||
def _add_break_marker(ax, y, rel_x0=0.02, rel_x1=0.98, size=0.02):
|
||||
# Draw small diagonal ticks near left/right to signal cap
|
||||
x0, x1 = rel_x0, rel_x1
|
||||
ax.plot([x0 - size, x0 + size], [y + size, y - size], transform=ax.transAxes, color="k", lw=1)
|
||||
ax.plot([x1 - size, x1 + size], [y + size, y - size], transform=ax.transAxes, color="k", lw=1)
|
||||
|
||||
|
||||
def _fmt_ms(v: float) -> str:
|
||||
if v >= 1000:
|
||||
return f"{v / 1000:.1f}k"
|
||||
return f"{v:.1f}"
|
||||
|
||||
|
||||
def main():
|
||||
# Set LaTeX style for paper figures (matching paper_fig.py)
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
plt.rcParams["font.family"] = "Helvetica"
|
||||
plt.rcParams["ytick.direction"] = "in"
|
||||
plt.rcParams["hatch.linewidth"] = 1.5
|
||||
plt.rcParams["font.weight"] = "bold"
|
||||
plt.rcParams["axes.labelweight"] = "bold"
|
||||
plt.rcParams["text.usetex"] = True
|
||||
|
||||
ap = argparse.ArgumentParser(description=__doc__)
|
||||
ap.add_argument(
|
||||
"--csv",
|
||||
type=Path,
|
||||
default=Path("benchmarks/update/bench_results.csv"),
|
||||
help="Path to results CSV (defaults to bench_results.csv)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--out",
|
||||
type=Path,
|
||||
default=Path("add_ablation.pdf"),
|
||||
help="Output image path",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--csv-right",
|
||||
type=Path,
|
||||
default=Path("benchmarks/update/offline_vs_update.csv"),
|
||||
help="Optional: offline_vs_update.csv to render right subplot (A vs B)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--cap-y",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Cap Y-axis at this ms value; bars above are hatched and annotated.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--no-auto-cap",
|
||||
action="store_true",
|
||||
help="Disable auto-cap heuristic when --cap-y is not provided.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--broken-y",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Use a broken Y-axis (two stacked axes with a gap). Overrides --cap-y unless both provided.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--lower-cap-y",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Lower axes upper bound for broken Y (ms). Default = 1.1x second-highest.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--upper-start-y",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Upper axes lower bound for broken Y (ms). Default = 1.2x second-highest.",
|
||||
)
|
||||
args = ap.parse_args()
|
||||
|
||||
latest_run, latest_rows = load_latest_run(args.csv)
|
||||
avg = aggregate_latency(latest_rows)
|
||||
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
except Exception as e:
|
||||
raise SystemExit(f"matplotlib not available: {e}")
|
||||
|
||||
scenarios = DEFAULT_SCENARIOS
|
||||
values = [avg.get(name, 0.0) for name in scenarios]
|
||||
labels = [SCENARIO_LABELS.get(name, name) for name in scenarios]
|
||||
colors = ["#4e79a7", "#f28e2c", "#e15759", "#76b7b2"]
|
||||
|
||||
# If right CSV is provided, build side-by-side figure
|
||||
if args.csv_right is not None:
|
||||
try:
|
||||
right_rows_all = []
|
||||
with args.csv_right.open("r", encoding="utf-8") as f:
|
||||
rreader = csv.DictReader(f)
|
||||
right_rows_all = list(rreader)
|
||||
if right_rows_all:
|
||||
r_latest = max(r.get("run_id", "") for r in right_rows_all)
|
||||
right_rows = [r for r in right_rows_all if r.get("run_id", "") == r_latest]
|
||||
else:
|
||||
r_latest = None
|
||||
right_rows = []
|
||||
except Exception:
|
||||
r_latest = None
|
||||
right_rows = []
|
||||
|
||||
a_total = 0.0
|
||||
b_makespan = 0.0
|
||||
for r in right_rows:
|
||||
sc = (r.get("scenario", "") or "").strip().upper()
|
||||
if sc == "A":
|
||||
try:
|
||||
a_total = float(r.get("total_time_s", 0.0))
|
||||
except Exception:
|
||||
pass
|
||||
elif sc == "B":
|
||||
try:
|
||||
b_makespan = float(r.get("makespan_s", 0.0))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib import gridspec
|
||||
|
||||
# Left subplot (reuse current style, with optional cap)
|
||||
cap = args.cap_y
|
||||
if cap is None and not args.no_auto_cap:
|
||||
cap = _auto_cap(values)
|
||||
x = list(range(len(labels)))
|
||||
|
||||
if args.broken_y:
|
||||
# Use broken axis for left subplot
|
||||
# Auto-adjust width ratios: left has 4 bars, right has 2 bars
|
||||
fig = plt.figure(figsize=(4.8, 1.8)) # Scaled down to 80%
|
||||
gs = gridspec.GridSpec(
|
||||
2, 2, height_ratios=[1, 3], width_ratios=[1.5, 1], hspace=0.08, wspace=0.35
|
||||
)
|
||||
ax_left_top = fig.add_subplot(gs[0, 0])
|
||||
ax_left_bottom = fig.add_subplot(gs[1, 0], sharex=ax_left_top)
|
||||
ax_right = fig.add_subplot(gs[:, 1])
|
||||
|
||||
# Determine break points
|
||||
s = sorted(values, reverse=True)
|
||||
second = s[1] if len(s) >= 2 else (s[0] if s else 0.0)
|
||||
lower_cap = (
|
||||
args.lower_cap_y if args.lower_cap_y is not None else second * 1.4
|
||||
) # Increased to show more range
|
||||
upper_start = (
|
||||
args.upper_start_y
|
||||
if args.upper_start_y is not None
|
||||
else max(second * 1.5, lower_cap * 1.02)
|
||||
)
|
||||
ymax = (
|
||||
max(values) * 1.90 if values else 1.0
|
||||
) # Increase headroom to 1.90 for text label and tick range
|
||||
|
||||
# Draw bars on both axes
|
||||
ax_left_bottom.bar(x, values, color=colors[: len(labels)], width=0.8)
|
||||
ax_left_top.bar(x, values, color=colors[: len(labels)], width=0.8)
|
||||
|
||||
# Set limits
|
||||
ax_left_bottom.set_ylim(0, lower_cap)
|
||||
ax_left_top.set_ylim(upper_start, ymax)
|
||||
|
||||
# Annotate values (convert ms to s)
|
||||
values_s = [v / 1000.0 for v in values]
|
||||
lower_cap_s = lower_cap / 1000.0
|
||||
upper_start_s = upper_start / 1000.0
|
||||
ymax_s = ymax / 1000.0
|
||||
|
||||
ax_left_bottom.set_ylim(0, lower_cap_s)
|
||||
ax_left_top.set_ylim(upper_start_s, ymax_s)
|
||||
|
||||
# Redraw bars with s values (paper style: white fill + colored edge + hatch)
|
||||
ax_left_bottom.clear()
|
||||
ax_left_top.clear()
|
||||
bar_width = 0.50 # Reduced for wider spacing between bars
|
||||
for i, (scenario_name, v) in enumerate(zip(scenarios, values_s)):
|
||||
style = SCENARIO_STYLES.get(scenario_name, {"edgecolor": "black", "hatch": ""})
|
||||
# Draw in bottom axis for all bars
|
||||
ax_left_bottom.bar(
|
||||
i,
|
||||
v,
|
||||
width=bar_width,
|
||||
color="white",
|
||||
edgecolor=style["edgecolor"],
|
||||
hatch=style["hatch"],
|
||||
linewidth=1.2,
|
||||
)
|
||||
# Only draw in top axis if the bar is tall enough to reach the upper range
|
||||
if v > upper_start_s:
|
||||
ax_left_top.bar(
|
||||
i,
|
||||
v,
|
||||
width=bar_width,
|
||||
color="white",
|
||||
edgecolor=style["edgecolor"],
|
||||
hatch=style["hatch"],
|
||||
linewidth=1.2,
|
||||
)
|
||||
ax_left_bottom.set_ylim(0, lower_cap_s)
|
||||
ax_left_top.set_ylim(upper_start_s, ymax_s)
|
||||
|
||||
for i, v in enumerate(values_s):
|
||||
if v <= lower_cap_s:
|
||||
ax_left_bottom.text(
|
||||
i,
|
||||
v + lower_cap_s * 0.02,
|
||||
f"{v:.2f}",
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontsize=8,
|
||||
fontweight="bold",
|
||||
)
|
||||
else:
|
||||
ax_left_top.text(
|
||||
i,
|
||||
v + (ymax_s - upper_start_s) * 0.02,
|
||||
f"{v:.2f}",
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontsize=8,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
# Hide spines between axes
|
||||
ax_left_top.spines["bottom"].set_visible(False)
|
||||
ax_left_bottom.spines["top"].set_visible(False)
|
||||
ax_left_top.tick_params(
|
||||
labeltop=False, labelbottom=False, bottom=False
|
||||
) # Hide tick marks
|
||||
ax_left_bottom.xaxis.tick_bottom()
|
||||
ax_left_bottom.tick_params(top=False) # Hide top tick marks
|
||||
|
||||
# Draw break marks (matching paper_fig.py style)
|
||||
d = 0.015
|
||||
kwargs = {
|
||||
"transform": ax_left_top.transAxes,
|
||||
"color": "k",
|
||||
"clip_on": False,
|
||||
"linewidth": 0.8,
|
||||
"zorder": 10,
|
||||
}
|
||||
ax_left_top.plot((-d, +d), (-d, +d), **kwargs)
|
||||
ax_left_top.plot((1 - d, 1 + d), (-d, +d), **kwargs)
|
||||
kwargs.update({"transform": ax_left_bottom.transAxes})
|
||||
ax_left_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs)
|
||||
ax_left_bottom.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs)
|
||||
|
||||
ax_left_bottom.set_xticks(x)
|
||||
ax_left_bottom.set_xticklabels(labels, rotation=0, fontsize=7)
|
||||
# Don't set ylabel here - will use fig.text for alignment
|
||||
ax_left_bottom.tick_params(axis="y", labelsize=10)
|
||||
ax_left_top.tick_params(axis="y", labelsize=10)
|
||||
# Add subtle grid for better readability
|
||||
ax_left_bottom.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5)
|
||||
ax_left_top.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5)
|
||||
ax_left_top.set_title("Single Add Operation", fontsize=11, pad=10, fontweight="bold")
|
||||
|
||||
# Set x-axis limits to match bar width with right subplot
|
||||
ax_left_bottom.set_xlim(-0.6, 3.6)
|
||||
ax_left_top.set_xlim(-0.6, 3.6)
|
||||
|
||||
ax_left = ax_left_bottom # for compatibility
|
||||
else:
|
||||
# Regular side-by-side layout
|
||||
fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(8.4, 3.15))
|
||||
|
||||
if cap is not None:
|
||||
show_vals = [min(v, cap) for v in values]
|
||||
bars = ax_left.bar(x, show_vals, color=colors[: len(labels)], width=0.8)
|
||||
for i, (val, show) in enumerate(zip(values, show_vals)):
|
||||
if val > cap:
|
||||
bars[i].set_hatch("//")
|
||||
ax_left.text(
|
||||
i, cap * 1.02, _fmt_ms(val), ha="center", va="bottom", fontsize=9
|
||||
)
|
||||
else:
|
||||
ax_left.text(
|
||||
i,
|
||||
show + max(1.0, 0.01 * (cap or show)),
|
||||
_fmt_ms(val),
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontsize=9,
|
||||
)
|
||||
ax_left.set_ylim(0, cap * 1.10)
|
||||
_add_break_marker(ax_left, y=0.98)
|
||||
ax_left.set_xticks(x)
|
||||
ax_left.set_xticklabels(labels, rotation=0, fontsize=10)
|
||||
else:
|
||||
ax_left.bar(x, values, color=colors[: len(labels)], width=0.8)
|
||||
for i, v in enumerate(values):
|
||||
ax_left.text(i, v + 1.0, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
|
||||
ax_left.set_xticks(x)
|
||||
ax_left.set_xticklabels(labels, rotation=0, fontsize=10)
|
||||
ax_left.set_ylabel("Latency (ms per passage)")
|
||||
max_initial = latest_rows[0].get("max_initial", "?")
|
||||
max_updates = latest_rows[0].get("max_updates", "?")
|
||||
ax_left.set_title(
|
||||
f"HNSW RNG (run {latest_run}) | init={max_initial}, upd={max_updates}"
|
||||
)
|
||||
|
||||
# Right subplot (A vs B, seconds) - paper style
|
||||
r_labels = ["Sequential", "Delayed \n Add+Search"]
|
||||
r_values = [a_total or 0.0, b_makespan or 0.0]
|
||||
r_styles = [
|
||||
{"edgecolor": "#59a14f", "hatch": "xxxxx"},
|
||||
{"edgecolor": "#edc948", "hatch": "/////"},
|
||||
]
|
||||
# 2 bars, centered with proper spacing
|
||||
xr = [0, 1]
|
||||
bar_width = 0.50 # Reduced for wider spacing between bars
|
||||
for i, (v, style) in enumerate(zip(r_values, r_styles)):
|
||||
ax_right.bar(
|
||||
xr[i],
|
||||
v,
|
||||
width=bar_width,
|
||||
color="white",
|
||||
edgecolor=style["edgecolor"],
|
||||
hatch=style["hatch"],
|
||||
linewidth=1.2,
|
||||
)
|
||||
for i, v in enumerate(r_values):
|
||||
max_v = max(r_values) if r_values else 1.0
|
||||
offset = max(0.0002, 0.02 * max_v)
|
||||
ax_right.text(
|
||||
xr[i],
|
||||
v + offset,
|
||||
f"{v:.2f}",
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontsize=8,
|
||||
fontweight="bold",
|
||||
)
|
||||
ax_right.set_xticks(xr)
|
||||
ax_right.set_xticklabels(r_labels, rotation=0, fontsize=7)
|
||||
# Don't set ylabel here - will use fig.text for alignment
|
||||
ax_right.tick_params(axis="y", labelsize=10)
|
||||
# Add subtle grid for better readability
|
||||
ax_right.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5)
|
||||
ax_right.set_title("Batched Add Operation", fontsize=11, pad=10, fontweight="bold")
|
||||
|
||||
# Set x-axis limits to match left subplot's bar width visually
|
||||
# Accounting for width_ratios=[1.5, 1]:
|
||||
# Left: 4 bars, xlim(-0.6, 3.6), range=4.2, physical_width=1.5*unit
|
||||
# bar_width_visual = 0.72 * (1.5*unit / 4.2)
|
||||
# Right: 2 bars, need same visual width
|
||||
# 0.72 * (1.0*unit / range_right) = 0.72 * (1.5*unit / 4.2)
|
||||
# range_right = 4.2 / 1.5 = 2.8
|
||||
# For bars at 0, 1: padding = (2.8 - 1) / 2 = 0.9
|
||||
ax_right.set_xlim(-0.9, 1.9)
|
||||
|
||||
# Set y-axis limit with headroom for text labels
|
||||
if r_values:
|
||||
max_v = max(r_values)
|
||||
ax_right.set_ylim(0, max_v * 1.15)
|
||||
|
||||
# Format y-axis to avoid scientific notation
|
||||
ax_right.ticklabel_format(style="plain", axis="y")
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
# Add aligned ylabels using fig.text (after tight_layout)
|
||||
# Get the vertical center of the entire figure
|
||||
fig_center_y = 0.5
|
||||
# Left ylabel - closer to left plot
|
||||
left_x = 0.05
|
||||
fig.text(
|
||||
left_x,
|
||||
fig_center_y,
|
||||
"Latency (s)",
|
||||
va="center",
|
||||
rotation="vertical",
|
||||
fontsize=11,
|
||||
fontweight="bold",
|
||||
)
|
||||
# Right ylabel - closer to right plot
|
||||
right_bbox = ax_right.get_position()
|
||||
right_x = right_bbox.x0 - 0.07
|
||||
fig.text(
|
||||
right_x,
|
||||
fig_center_y,
|
||||
"Latency (s)",
|
||||
va="center",
|
||||
rotation="vertical",
|
||||
fontsize=11,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
plt.savefig(args.out, bbox_inches="tight", pad_inches=0.05)
|
||||
# Also save PDF for paper
|
||||
pdf_out = args.out.with_suffix(".pdf")
|
||||
plt.savefig(pdf_out, bbox_inches="tight", pad_inches=0.05)
|
||||
print(f"Saved: {args.out}")
|
||||
print(f"Saved: {pdf_out}")
|
||||
return
|
||||
|
||||
# Broken-Y mode
|
||||
if args.broken_y:
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
fig, (ax_top, ax_bottom) = plt.subplots(
|
||||
2,
|
||||
1,
|
||||
sharex=True,
|
||||
figsize=(7.5, 6.75),
|
||||
gridspec_kw={"height_ratios": [1, 3], "hspace": 0.08},
|
||||
)
|
||||
|
||||
# Determine default breaks from second-highest
|
||||
s = sorted(values, reverse=True)
|
||||
second = s[1] if len(s) >= 2 else (s[0] if s else 0.0)
|
||||
lower_cap = args.lower_cap_y if args.lower_cap_y is not None else second * 1.1
|
||||
upper_start = (
|
||||
args.upper_start_y
|
||||
if args.upper_start_y is not None
|
||||
else max(second * 1.2, lower_cap * 1.02)
|
||||
)
|
||||
ymax = max(values) * 1.10 if values else 1.0
|
||||
|
||||
x = list(range(len(labels)))
|
||||
ax_bottom.bar(x, values, color=colors[: len(labels)], width=0.8)
|
||||
ax_top.bar(x, values, color=colors[: len(labels)], width=0.8)
|
||||
|
||||
# Limits
|
||||
ax_bottom.set_ylim(0, lower_cap)
|
||||
ax_top.set_ylim(upper_start, ymax)
|
||||
|
||||
# Annotate values
|
||||
for i, v in enumerate(values):
|
||||
if v <= lower_cap:
|
||||
ax_bottom.text(
|
||||
i, v + lower_cap * 0.02, _fmt_ms(v), ha="center", va="bottom", fontsize=9
|
||||
)
|
||||
else:
|
||||
ax_top.text(i, v, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
|
||||
|
||||
# Hide spines between axes and draw diagonal break marks
|
||||
ax_top.spines["bottom"].set_visible(False)
|
||||
ax_bottom.spines["top"].set_visible(False)
|
||||
ax_top.tick_params(labeltop=False) # don't put tick labels at the top
|
||||
ax_bottom.xaxis.tick_bottom()
|
||||
|
||||
# Diagonal lines at the break (matching paper_fig.py style)
|
||||
d = 0.015
|
||||
kwargs = {
|
||||
"transform": ax_top.transAxes,
|
||||
"color": "k",
|
||||
"clip_on": False,
|
||||
"linewidth": 0.8,
|
||||
"zorder": 10,
|
||||
}
|
||||
ax_top.plot((-d, +d), (-d, +d), **kwargs) # top-left diagonal
|
||||
ax_top.plot((1 - d, 1 + d), (-d, +d), **kwargs) # top-right diagonal
|
||||
kwargs.update({"transform": ax_bottom.transAxes})
|
||||
ax_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs) # bottom-left diagonal
|
||||
ax_bottom.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs) # bottom-right diagonal
|
||||
|
||||
ax_bottom.set_xticks(x)
|
||||
ax_bottom.set_xticklabels(labels, rotation=0, fontsize=10)
|
||||
ax = ax_bottom # for labeling below
|
||||
else:
|
||||
cap = args.cap_y
|
||||
if cap is None and not args.no_auto_cap:
|
||||
cap = _auto_cap(values)
|
||||
|
||||
plt.figure(figsize=(5.4, 3.15))
|
||||
ax = plt.gca()
|
||||
|
||||
if cap is not None:
|
||||
show_vals = [min(v, cap) for v in values]
|
||||
bars = []
|
||||
for i, (_label, val, show) in enumerate(zip(labels, values, show_vals)):
|
||||
bar = ax.bar(i, show, color=colors[i], width=0.8)
|
||||
bars.append(bar[0])
|
||||
# Hatch and annotate when capped
|
||||
if val > cap:
|
||||
bars[-1].set_hatch("//")
|
||||
ax.text(i, cap * 1.02, f"{_fmt_ms(val)}", ha="center", va="bottom", fontsize=9)
|
||||
else:
|
||||
ax.text(
|
||||
i,
|
||||
show + max(1.0, 0.01 * (cap or show)),
|
||||
f"{_fmt_ms(val)}",
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontsize=9,
|
||||
)
|
||||
ax.set_ylim(0, cap * 1.10)
|
||||
_add_break_marker(ax, y=0.98)
|
||||
ax.legend([bars[1]], ["capped"], fontsize=8, frameon=False, loc="upper right") if any(
|
||||
v > cap for v in values
|
||||
) else None
|
||||
ax.set_xticks(range(len(labels)))
|
||||
ax.set_xticklabels(labels, fontsize=11, fontweight="bold")
|
||||
else:
|
||||
ax.bar(labels, values, color=colors[: len(labels)])
|
||||
for idx, val in enumerate(values):
|
||||
ax.text(
|
||||
idx,
|
||||
val + 1.0,
|
||||
f"{_fmt_ms(val)}",
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontsize=10,
|
||||
fontweight="bold",
|
||||
)
|
||||
ax.set_xticklabels(labels, fontsize=11, fontweight="bold")
|
||||
# Try to extract some context for title
|
||||
max_initial = latest_rows[0].get("max_initial", "?")
|
||||
max_updates = latest_rows[0].get("max_updates", "?")
|
||||
|
||||
if args.broken_y:
|
||||
fig.text(
|
||||
0.02,
|
||||
0.5,
|
||||
"Latency (s)",
|
||||
va="center",
|
||||
rotation="vertical",
|
||||
fontsize=11,
|
||||
fontweight="bold",
|
||||
)
|
||||
fig.suptitle(
|
||||
"Add Operation Latency",
|
||||
fontsize=11,
|
||||
y=0.98,
|
||||
fontweight="bold",
|
||||
)
|
||||
plt.tight_layout(rect=(0.03, 0.04, 1, 0.96))
|
||||
else:
|
||||
plt.ylabel("Latency (s)", fontsize=11, fontweight="bold")
|
||||
plt.title("Add Operation Latency", fontsize=11, fontweight="bold")
|
||||
plt.tight_layout()
|
||||
|
||||
plt.savefig(args.out, bbox_inches="tight", pad_inches=0.05)
|
||||
# Also save PDF for paper
|
||||
pdf_out = args.out.with_suffix(".pdf")
|
||||
plt.savefig(pdf_out, bbox_inches="tight", pad_inches=0.05)
|
||||
print(f"Saved: {args.out}")
|
||||
print(f"Saved: {pdf_out}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -53,9 +53,9 @@ We use pre-commit hooks to ensure code quality and consistency. This runs automa
|
||||
|
||||
### Setup Pre-commit
|
||||
|
||||
1. **Install pre-commit** (already included when you run `uv sync`):
|
||||
1. **Install pre-commit tools**:
|
||||
```bash
|
||||
uv pip install pre-commit
|
||||
uv sync lint
|
||||
```
|
||||
|
||||
2. **Install the git hooks**:
|
||||
@@ -65,7 +65,7 @@ We use pre-commit hooks to ensure code quality and consistency. This runs automa
|
||||
|
||||
3. **Run pre-commit manually** (optional):
|
||||
```bash
|
||||
pre-commit run --all-files
|
||||
uv run pre-commit run --all-files
|
||||
```
|
||||
|
||||
### Pre-commit Checks
|
||||
@@ -85,6 +85,9 @@ Our pre-commit configuration includes:
|
||||
### Running Tests
|
||||
|
||||
```bash
|
||||
# Install test tools only (no project runtime)
|
||||
uv sync --group test
|
||||
|
||||
# Run all tests
|
||||
uv run pytest
|
||||
|
||||
|
||||
@@ -83,6 +83,170 @@ ollama pull nomic-embed-text
|
||||
|
||||
</details>
|
||||
|
||||
## Local & Remote Inference Endpoints
|
||||
|
||||
> Applies to both LLMs (`leann ask`) and embeddings (`leann build`).
|
||||
|
||||
LEANN now treats Ollama, LM Studio, and other OpenAI-compatible runtimes as first-class providers. You can point LEANN at any compatible endpoint – either on the same machine or across the network – with a couple of flags or environment variables.
|
||||
|
||||
### One-Time Environment Setup
|
||||
|
||||
```bash
|
||||
# Works for OpenAI-compatible runtimes such as LM Studio, vLLM, SGLang, llamafile, etc.
|
||||
export OPENAI_API_KEY="your-key" # or leave unset for local servers that do not check keys
|
||||
export OPENAI_BASE_URL="http://localhost:1234/v1"
|
||||
|
||||
# Ollama-compatible runtimes (Ollama, Ollama on another host, llamacpp-server, etc.)
|
||||
export LEANN_OLLAMA_HOST="http://localhost:11434" # falls back to OLLAMA_HOST or LOCAL_LLM_ENDPOINT
|
||||
```
|
||||
|
||||
LEANN also recognises `LEANN_LOCAL_LLM_HOST` (highest priority), `LEANN_OPENAI_BASE_URL`, and `LOCAL_OPENAI_BASE_URL`, so existing scripts continue to work.
|
||||
|
||||
### Passing Hosts Per Command
|
||||
|
||||
```bash
|
||||
# Build an index with a remote embedding server
|
||||
leann build my-notes \
|
||||
--docs ./notes \
|
||||
--embedding-mode openai \
|
||||
--embedding-model text-embedding-qwen3-embedding-0.6b \
|
||||
--embedding-api-base http://192.168.1.50:1234/v1 \
|
||||
--embedding-api-key local-dev-key
|
||||
|
||||
# Query using a local LM Studio instance via OpenAI-compatible API
|
||||
leann ask my-notes \
|
||||
--llm openai \
|
||||
--llm-model qwen3-8b \
|
||||
--api-base http://localhost:1234/v1 \
|
||||
--api-key local-dev-key
|
||||
|
||||
# Query an Ollama instance running on another box
|
||||
leann ask my-notes \
|
||||
--llm ollama \
|
||||
--llm-model qwen3:14b \
|
||||
--host http://192.168.1.101:11434
|
||||
```
|
||||
|
||||
⚠️ **Make sure the endpoint is reachable**: when your inference server runs on a home/workstation and the index/search job runs in the cloud, the server must be able to reach the host you configured. Typical options include:
|
||||
|
||||
- Expose a public IP (and open the relevant port) on the machine that hosts LM Studio/Ollama.
|
||||
- Configure router or cloud provider port forwarding.
|
||||
- Tunnel traffic through tools like `tailscale`, `cloudflared`, or `ssh -R`.
|
||||
|
||||
When you set these options while building an index, LEANN stores them in `meta.json`. Any subsequent `leann ask` or searcher process automatically reuses the same provider settings – even when we spawn background embedding servers. This makes the “server without GPU talking to my local workstation” workflow from [issue #80](https://github.com/yichuan-w/LEANN/issues/80#issuecomment-2287230548) work out-of-the-box.
|
||||
|
||||
**Tip:** If your runtime does not require an API key (many local stacks don’t), leave `--api-key` unset. LEANN will skip injecting credentials.
|
||||
|
||||
### Python API Usage
|
||||
|
||||
You can pass the same configuration from Python:
|
||||
|
||||
```python
|
||||
from leann.api import LeannBuilder
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_mode="openai",
|
||||
embedding_model="text-embedding-qwen3-embedding-0.6b",
|
||||
embedding_options={
|
||||
"base_url": "http://192.168.1.50:1234/v1",
|
||||
"api_key": "local-dev-key",
|
||||
},
|
||||
)
|
||||
builder.build_index("./indexes/my-notes", chunks)
|
||||
```
|
||||
|
||||
`embedding_options` is persisted to the index `meta.json`, so subsequent `LeannSearcher` or `LeannChat` sessions automatically reuse the same provider settings (the embedding server manager forwards them to the provider for you).
|
||||
|
||||
## Optional Embedding Features
|
||||
|
||||
### Task-Specific Prompt Templates
|
||||
|
||||
Some embedding models are trained with task-specific prompts to differentiate between documents and queries. The most notable example is **Google's EmbeddingGemma**, which requires different prompts depending on the use case:
|
||||
|
||||
- **Indexing documents**: `"title: none | text: "`
|
||||
- **Search queries**: `"task: search result | query: "`
|
||||
|
||||
LEANN supports automatic prompt prepending via the `--embedding-prompt-template` flag:
|
||||
|
||||
```bash
|
||||
# Build index with EmbeddingGemma (via LM Studio or Ollama)
|
||||
leann build my-docs \
|
||||
--docs ./documents \
|
||||
--embedding-mode openai \
|
||||
--embedding-model text-embedding-embeddinggemma-300m-qat \
|
||||
--embedding-api-base http://localhost:1234/v1 \
|
||||
--embedding-prompt-template "title: none | text: " \
|
||||
--force
|
||||
|
||||
# Search with query-specific prompt
|
||||
leann search my-docs \
|
||||
--query "What is quantum computing?" \
|
||||
--embedding-prompt-template "task: search result | query: "
|
||||
```
|
||||
|
||||
**Important Notes:**
|
||||
- **Only use with compatible models**: EmbeddingGemma and similar task-specific models
|
||||
- **NOT for regular models**: Adding prompts to models like `nomic-embed-text`, `text-embedding-3-small`, or `bge-base-en-v1.5` will corrupt embeddings
|
||||
- **Template is saved**: Build-time templates are saved to `.meta.json` for reference
|
||||
- **Flexible prompts**: You can use any prompt string, or leave it empty (`""`)
|
||||
|
||||
**Python API:**
|
||||
```python
|
||||
from leann.api import LeannBuilder
|
||||
|
||||
builder = LeannBuilder(
|
||||
embedding_mode="openai",
|
||||
embedding_model="text-embedding-embeddinggemma-300m-qat",
|
||||
embedding_options={
|
||||
"base_url": "http://localhost:1234/v1",
|
||||
"api_key": "lm-studio",
|
||||
"prompt_template": "title: none | text: ",
|
||||
},
|
||||
)
|
||||
builder.build_index("./indexes/my-docs", chunks)
|
||||
```
|
||||
|
||||
**References:**
|
||||
- [HuggingFace Blog: EmbeddingGemma](https://huggingface.co/blog/embeddinggemma) - Technical details
|
||||
|
||||
### LM Studio Auto-Detection (Optional)
|
||||
|
||||
When using LM Studio with the OpenAI-compatible API, LEANN can optionally auto-detect model context lengths via the LM Studio SDK. This eliminates manual configuration for token limits.
|
||||
|
||||
**Prerequisites:**
|
||||
```bash
|
||||
# Install Node.js (if not already installed)
|
||||
# Then install the LM Studio SDK globally
|
||||
npm install -g @lmstudio/sdk
|
||||
```
|
||||
|
||||
**How it works:**
|
||||
1. LEANN detects LM Studio URLs (`:1234`, `lmstudio` in URL)
|
||||
2. Queries model metadata via Node.js subprocess
|
||||
3. Automatically unloads model after query (respects your JIT auto-evict settings)
|
||||
4. Falls back to static registry if SDK unavailable
|
||||
|
||||
**No configuration needed** - it works automatically when SDK is installed:
|
||||
|
||||
```bash
|
||||
leann build my-docs \
|
||||
--docs ./documents \
|
||||
--embedding-mode openai \
|
||||
--embedding-model text-embedding-nomic-embed-text-v1.5 \
|
||||
--embedding-api-base http://localhost:1234/v1
|
||||
# Context length auto-detected if SDK available
|
||||
# Falls back to registry (2048) if not
|
||||
```
|
||||
|
||||
**Benefits:**
|
||||
- ✅ Automatic token limit detection
|
||||
- ✅ Respects LM Studio JIT auto-evict settings
|
||||
- ✅ No manual registry maintenance
|
||||
- ✅ Graceful fallback if SDK unavailable
|
||||
|
||||
**Note:** This is completely optional. LEANN works perfectly fine without the SDK using the built-in token limit registry.
|
||||
|
||||
## Index Selection: Matching Your Scale
|
||||
|
||||
### HNSW (Hierarchical Navigable Small World)
|
||||
@@ -290,7 +454,7 @@ leann search my-index "your query" \
|
||||
|
||||
### 2) Run remote builds with SkyPilot (cloud GPU)
|
||||
|
||||
Offload embedding generation and index building to a GPU VM using [SkyPilot](https://skypilot.readthedocs.io/en/latest/). A template is provided at `sky/leann-build.yaml`.
|
||||
Offload embedding generation and index building to a GPU VM using [SkyPilot](https://docs.skypilot.co/en/latest/docs/index.html). A template is provided at `sky/leann-build.yaml`.
|
||||
|
||||
```bash
|
||||
# One-time: install and configure SkyPilot
|
||||
@@ -380,5 +544,5 @@ Conclusion:
|
||||
|
||||
- [Lessons Learned Developing LEANN](https://yichuan-w.github.io/blog/lessons_learned_in_dev_leann/)
|
||||
- [LEANN Technical Paper](https://arxiv.org/abs/2506.08276)
|
||||
- [DiskANN Original Paper](https://papers.nips.cc/paper/2019/file/09853c7fb1d3f8ee67a61b6bf4a7f8e6-Paper.pdf)
|
||||
- [DiskANN Original Paper](https://suhasjs.github.io/files/diskann_neurips19.pdf)
|
||||
- [SSD-based Graph Partitioning](https://github.com/SonglinLife/SSD_BASED_PLAN)
|
||||
|
||||
48
docs/faq.md
@@ -8,3 +8,51 @@ You can speed up the process by using a lightweight embedding model. Add this to
|
||||
--embedding-model sentence-transformers/all-MiniLM-L6-v2
|
||||
```
|
||||
**Model sizes:** `all-MiniLM-L6-v2` (30M parameters), `facebook/contriever` (~100M parameters), `Qwen3-0.6B` (600M parameters)
|
||||
|
||||
## 2. When should I use prompt templates?
|
||||
|
||||
**Use prompt templates ONLY with task-specific embedding models** like Google's EmbeddingGemma. These models are specially trained to use different prompts for documents vs queries.
|
||||
|
||||
**DO NOT use with regular models** like `nomic-embed-text`, `text-embedding-3-small`, or `bge-base-en-v1.5` - adding prompts to these models will corrupt the embeddings.
|
||||
|
||||
**Example usage with EmbeddingGemma:**
|
||||
```bash
|
||||
# Build with document prompt
|
||||
leann build my-docs --embedding-prompt-template "title: none | text: "
|
||||
|
||||
# Search with query prompt
|
||||
leann search my-docs --query "your question" --embedding-prompt-template "task: search result | query: "
|
||||
```
|
||||
|
||||
See the [Configuration Guide: Task-Specific Prompt Templates](configuration-guide.md#task-specific-prompt-templates) for detailed usage.
|
||||
|
||||
## 3. Why is LM Studio loading multiple copies of my model?
|
||||
|
||||
This was fixed in recent versions. LEANN now properly unloads models after querying metadata, respecting your LM Studio JIT auto-evict settings.
|
||||
|
||||
**If you still see duplicates:**
|
||||
- Update to the latest LEANN version
|
||||
- Restart LM Studio to clear loaded models
|
||||
- Check that you have JIT auto-evict enabled in LM Studio settings
|
||||
|
||||
**How it works now:**
|
||||
1. LEANN loads model temporarily to get context length
|
||||
2. Immediately unloads after query
|
||||
3. LM Studio JIT loads model on-demand for actual embeddings
|
||||
4. Auto-evicts per your settings
|
||||
|
||||
## 4. Do I need Node.js and @lmstudio/sdk?
|
||||
|
||||
**No, it's completely optional.** LEANN works perfectly fine without them using a built-in token limit registry.
|
||||
|
||||
**Benefits if you install it:**
|
||||
- Automatic context length detection for LM Studio models
|
||||
- No manual registry maintenance
|
||||
- Always gets accurate token limits from the model itself
|
||||
|
||||
**To install (optional):**
|
||||
```bash
|
||||
npm install -g @lmstudio/sdk
|
||||
```
|
||||
|
||||
See [Configuration Guide: LM Studio Auto-Detection](configuration-guide.md#lm-studio-auto-detection-optional) for details.
|
||||
|
||||
395
docs/slack-setup-guide.md
Normal file
@@ -0,0 +1,395 @@
|
||||
# Slack Integration Setup Guide
|
||||
|
||||
This guide provides step-by-step instructions for setting up Slack integration with LEANN.
|
||||
|
||||
## Overview
|
||||
|
||||
LEANN's Slack integration uses MCP (Model Context Protocol) servers to fetch and index your Slack messages for RAG (Retrieval-Augmented Generation). This allows you to search through your Slack conversations using natural language queries.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
1. **Slack Workspace Access**: You need admin or owner permissions in your Slack workspace to create apps and configure OAuth tokens.
|
||||
|
||||
2. **Slack MCP Server**: Install a Slack MCP server (e.g., `slack-mcp-server` via npm)
|
||||
|
||||
3. **LEANN**: Ensure you have LEANN installed and working
|
||||
|
||||
## Step 1: Create a Slack App
|
||||
|
||||
### 1.1 Go to Slack API Dashboard
|
||||
|
||||
1. Visit [https://api.slack.com/apps](https://api.slack.com/apps)
|
||||
2. Click **"Create New App"**
|
||||
3. Choose **"From scratch"**
|
||||
4. Enter your app name (e.g., "LEANN Slack Integration")
|
||||
5. Select your workspace
|
||||
6. Click **"Create App"**
|
||||
|
||||
### 1.2 Configure App Permissions
|
||||
|
||||
#### Token Scopes
|
||||
|
||||
1. In your app dashboard, go to **"OAuth & Permissions"** in the left sidebar
|
||||
2. Scroll down to **"Scopes"** section
|
||||
3. Under **"Bot Token Scopes & OAuth Scope"**, click **"Add an OAuth Scope"**
|
||||
4. Add the following scopes:
|
||||
- `channels:read` - Read public channel information
|
||||
- `channels:history` - Read messages in public channels
|
||||
- `groups:read` - Read private channel information
|
||||
- `groups:history` - Read messages in private channels
|
||||
- `im:read` - Read direct message information
|
||||
- `im:history` - Read direct messages
|
||||
- `mpim:read` - Read group direct message information
|
||||
- `mpim:history` - Read group direct messages
|
||||
- `users:read` - Read user information
|
||||
- `team:read` - Read workspace information
|
||||
|
||||
#### App-Level Tokens (Optional)
|
||||
|
||||
Some MCP servers may require app-level tokens:
|
||||
|
||||
1. Go to **"Basic Information"** in the left sidebar
|
||||
2. Scroll down to **"App-Level Tokens"**
|
||||
3. Click **"Generate Token and Scopes"**
|
||||
4. Enter a name (e.g., "LEANN Integration")
|
||||
5. Add the `connections:write` scope
|
||||
6. Click **"Generate"**
|
||||
7. Copy the token (starts with `xapp-`)
|
||||
|
||||
### 1.3 Install App to Workspace
|
||||
|
||||
1. Go to **"OAuth & Permissions"** in the left sidebar
|
||||
2. Click **"Install to Workspace"**
|
||||
3. Review the permissions and click **"Allow"**
|
||||
4. Copy the **"Bot User OAuth Token"** (starts with `xoxb-`)
|
||||
5. Copy the **"User OAuth Token"** (starts with `xoxp-`)
|
||||
|
||||
## Step 2: Install Slack MCP Server
|
||||
|
||||
### Option A: Using npm (Recommended)
|
||||
|
||||
```bash
|
||||
# Install globally
|
||||
npm install -g slack-mcp-server
|
||||
|
||||
# Or install locally
|
||||
npm install slack-mcp-server
|
||||
```
|
||||
|
||||
### Option B: Using npx (No installation required)
|
||||
|
||||
```bash
|
||||
# Use directly without installation
|
||||
npx slack-mcp-server
|
||||
```
|
||||
|
||||
## Step 3: Install and Configure Ollama (for Real LLM Responses)
|
||||
|
||||
### 3.1 Install Ollama
|
||||
|
||||
```bash
|
||||
# Install Ollama using Homebrew (macOS)
|
||||
brew install ollama
|
||||
|
||||
# Or download from https://ollama.ai/
|
||||
```
|
||||
|
||||
### 3.2 Start Ollama Service
|
||||
|
||||
```bash
|
||||
# Start Ollama as a service
|
||||
brew services start ollama
|
||||
|
||||
# Or start manually
|
||||
ollama serve
|
||||
```
|
||||
|
||||
### 3.3 Pull a Model
|
||||
|
||||
```bash
|
||||
# Pull a lightweight model for testing
|
||||
ollama pull llama3.2:1b
|
||||
|
||||
# Verify the model is available
|
||||
ollama list
|
||||
```
|
||||
|
||||
## Step 4: Configure Environment Variables
|
||||
|
||||
Create a `.env` file or set environment variables:
|
||||
|
||||
```bash
|
||||
# Required: User OAuth Token
|
||||
SLACK_OAUTH_TOKEN=xoxp-your-user-oauth-token-here
|
||||
|
||||
# Optional: App-Level Token (if your MCP server requires it)
|
||||
SLACK_APP_TOKEN=xapp-your-app-token-here
|
||||
|
||||
# Optional: Workspace-specific settings
|
||||
SLACK_WORKSPACE_ID=T1234567890 # Your workspace ID (optional)
|
||||
```
|
||||
|
||||
## Step 5: Test the Setup
|
||||
|
||||
### 5.1 Test MCP Server Connection
|
||||
|
||||
```bash
|
||||
python -m apps.slack_rag \
|
||||
--mcp-server "slack-mcp-server" \
|
||||
--test-connection \
|
||||
--workspace-name "Your Workspace Name"
|
||||
```
|
||||
|
||||
This will test the connection and list available tools without indexing any data.
|
||||
|
||||
### 5.2 Index a Specific Channel
|
||||
|
||||
```bash
|
||||
python -m apps.slack_rag \
|
||||
--mcp-server "slack-mcp-server" \
|
||||
--workspace-name "Your Workspace Name" \
|
||||
--channels general \
|
||||
--query "What did we discuss about the project?"
|
||||
```
|
||||
|
||||
### 5.3 Real RAG Query Examples
|
||||
|
||||
This section demonstrates successful Slack RAG integration queries against the Sky Lab Computing workspace's "random" channel. The system successfully retrieves actual conversation messages and performs semantic search with high relevance scores, including finding specific research paper announcements and technical discussions.
|
||||
|
||||
### Example 1: Advisor Models Query
|
||||
|
||||
**Query:** "train black-box models to adopt to your personal data"
|
||||
|
||||
This query demonstrates the system's ability to find specific research announcements about training black-box models for personal data adaptation.
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
### Example 2: Barbarians at the Gate Query
|
||||
|
||||
**Query:** "AI-driven research systems ADRS"
|
||||
|
||||
This query demonstrates the system's ability to find specific research announcements about AI-driven research systems and algorithm discovery.
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
### Prerequisites
|
||||
|
||||
- Bot is installed in the Sky Lab Computing workspace and invited to the target channel (run `/invite @YourBotName` in the channel if needed)
|
||||
- Bot token available and exported in the same terminal session
|
||||
|
||||
### Commands
|
||||
|
||||
1) Set the workspace token for this shell
|
||||
|
||||
```bash
|
||||
export SLACK_MCP_XOXP_TOKEN="xoxp-***-redacted-***"
|
||||
```
|
||||
|
||||
2) Run queries against the "random" channel by channel ID (C0GN5BX0F)
|
||||
|
||||
**Advisor Models Query:**
|
||||
```bash
|
||||
python -m apps.slack_rag \
|
||||
--mcp-server "slack-mcp-server" \
|
||||
--workspace-name "Sky Lab Computing" \
|
||||
--channels C0GN5BX0F \
|
||||
--max-messages-per-channel 100000 \
|
||||
--query "train black-box models to adopt to your personal data" \
|
||||
--llm ollama \
|
||||
--llm-model "llama3.2:1b" \
|
||||
--llm-host "http://localhost:11434" \
|
||||
--no-concatenate-conversations
|
||||
```
|
||||
|
||||
**Barbarians at the Gate Query:**
|
||||
```bash
|
||||
python -m apps.slack_rag \
|
||||
--mcp-server "slack-mcp-server" \
|
||||
--workspace-name "Sky Lab Computing" \
|
||||
--channels C0GN5BX0F \
|
||||
--max-messages-per-channel 100000 \
|
||||
--query "AI-driven research systems ADRS" \
|
||||
--llm ollama \
|
||||
--llm-model "llama3.2:1b" \
|
||||
--llm-host "http://localhost:11434" \
|
||||
--no-concatenate-conversations
|
||||
```
|
||||
|
||||
These examples demonstrate the system's ability to find and retrieve specific research announcements and technical discussions from the conversation history, showcasing the power of semantic search in Slack data.
|
||||
|
||||
3) Optional: Ask a broader question
|
||||
|
||||
```bash
|
||||
python test_channel_by_id_or_name.py \
|
||||
--channel-id C0GN5BX0F \
|
||||
--workspace-name "Sky Lab Computing" \
|
||||
--query "What is LEANN about?"
|
||||
```
|
||||
|
||||
Notes:
|
||||
- If you see `not_in_channel`, invite the bot to the channel and re-run.
|
||||
- If you see `channel_not_found`, confirm the channel ID and workspace.
|
||||
- Deep search via server-side “search” tools may require additional Slack scopes; the example above performs client-side filtering over retrieved history.
|
||||
|
||||
## Common Issues and Solutions
|
||||
|
||||
### Issue 1: "users cache is not ready yet" Error
|
||||
|
||||
**Problem**: You see this warning:
|
||||
```
|
||||
WARNING - Failed to fetch messages from channel random: Failed to fetch messages: {'code': -32603, 'message': 'users cache is not ready yet, sync process is still running... please wait'}
|
||||
```
|
||||
|
||||
**Solution**: This is a common timing issue. The LEANN integration now includes automatic retry logic:
|
||||
|
||||
1. **Wait and Retry**: The system will automatically retry with exponential backoff (2s, 4s, 8s, etc.)
|
||||
2. **Increase Retry Parameters**: If needed, you can customize retry behavior:
|
||||
```bash
|
||||
python -m apps.slack_rag \
|
||||
--mcp-server "slack-mcp-server" \
|
||||
--max-retries 10 \
|
||||
--retry-delay 3.0 \
|
||||
--channels general \
|
||||
--query "Your query here"
|
||||
```
|
||||
3. **Keep MCP Server Running**: Start the MCP server separately and keep it running:
|
||||
```bash
|
||||
# Terminal 1: Start MCP server
|
||||
slack-mcp-server
|
||||
|
||||
# Terminal 2: Run LEANN (it will connect to the running server)
|
||||
python -m apps.slack_rag --mcp-server "slack-mcp-server" --channels general --query "test"
|
||||
```
|
||||
|
||||
### Issue 2: "No message fetching tool found"
|
||||
|
||||
**Problem**: The MCP server doesn't have the expected tools.
|
||||
|
||||
**Solution**:
|
||||
1. Check if your MCP server is properly installed and configured
|
||||
2. Verify your Slack tokens are correct
|
||||
3. Try a different MCP server implementation
|
||||
4. Check the MCP server documentation for required configuration
|
||||
|
||||
### Issue 3: Permission Denied Errors
|
||||
|
||||
**Problem**: You get permission errors when trying to access channels.
|
||||
|
||||
**Solutions**:
|
||||
1. **Check Bot Permissions**: Ensure your bot has been added to the channels you want to access
|
||||
2. **Verify Token Scopes**: Make sure you have all required scopes configured
|
||||
3. **Channel Access**: For private channels, the bot needs to be explicitly invited
|
||||
4. **Workspace Permissions**: Ensure your Slack app has the necessary workspace permissions
|
||||
|
||||
### Issue 4: Empty Results
|
||||
|
||||
**Problem**: No messages are returned even though the channel has messages.
|
||||
|
||||
**Solutions**:
|
||||
1. **Check Channel Names**: Ensure channel names are correct (without the # symbol)
|
||||
2. **Verify Bot Access**: Make sure the bot can access the channels
|
||||
3. **Check Date Ranges**: Some MCP servers have limitations on message history
|
||||
4. **Increase Message Limits**: Try increasing the message limit:
|
||||
```bash
|
||||
python -m apps.slack_rag \
|
||||
--mcp-server "slack-mcp-server" \
|
||||
--channels general \
|
||||
--max-messages-per-channel 1000 \
|
||||
--query "test"
|
||||
```
|
||||
|
||||
## Advanced Configuration
|
||||
|
||||
### Custom MCP Server Commands
|
||||
|
||||
If you need to pass additional parameters to your MCP server:
|
||||
|
||||
```bash
|
||||
python -m apps.slack_rag \
|
||||
--mcp-server "slack-mcp-server --token-file /path/to/tokens.json" \
|
||||
--workspace-name "Your Workspace" \
|
||||
--channels general \
|
||||
--query "Your query"
|
||||
```
|
||||
|
||||
### Multiple Workspaces
|
||||
|
||||
To work with multiple Slack workspaces, you can:
|
||||
|
||||
1. Create separate apps for each workspace
|
||||
2. Use different environment variables
|
||||
3. Run separate instances with different configurations
|
||||
|
||||
### Performance Optimization
|
||||
|
||||
For better performance with large workspaces:
|
||||
|
||||
```bash
|
||||
python -m apps.slack_rag \
|
||||
--mcp-server "slack-mcp-server" \
|
||||
--workspace-name "Your Workspace" \
|
||||
--max-messages-per-channel 500 \
|
||||
--no-concatenate-conversations \
|
||||
--query "Your query"
|
||||
```
|
||||
---
|
||||
|
||||
## Troubleshooting Checklist
|
||||
|
||||
- [ ] Slack app created with proper permissions
|
||||
- [ ] Bot token (xoxb-) copied correctly
|
||||
- [ ] App-level token (xapp-) created if needed
|
||||
- [ ] MCP server installed and accessible
|
||||
- [ ] Ollama installed and running (`brew services start ollama`)
|
||||
- [ ] Ollama model pulled (`ollama pull llama3.2:1b`)
|
||||
- [ ] Environment variables set correctly
|
||||
- [ ] Bot invited to relevant channels
|
||||
- [ ] Channel names specified without # symbol
|
||||
- [ ] Sufficient retry attempts configured
|
||||
- [ ] Network connectivity to Slack APIs
|
||||
|
||||
## Getting Help
|
||||
|
||||
If you continue to have issues:
|
||||
|
||||
1. **Check Logs**: Look for detailed error messages in the console output
|
||||
2. **Test MCP Server**: Use `--test-connection` to verify the MCP server is working
|
||||
3. **Verify Tokens**: Double-check that your Slack tokens are valid and have the right scopes
|
||||
4. **Check Ollama**: Ensure Ollama is running (`ollama serve`) and the model is available (`ollama list`)
|
||||
5. **Community Support**: Reach out to the LEANN community for help
|
||||
|
||||
## Example Commands
|
||||
|
||||
### Basic Usage
|
||||
```bash
|
||||
# Test connection
|
||||
python -m apps.slack_rag --mcp-server "slack-mcp-server" --test-connection
|
||||
|
||||
# Index specific channels
|
||||
python -m apps.slack_rag \
|
||||
--mcp-server "slack-mcp-server" \
|
||||
--workspace-name "My Company" \
|
||||
--channels general random \
|
||||
--query "What did we decide about the project timeline?"
|
||||
```
|
||||
|
||||
### Advanced Usage
|
||||
```bash
|
||||
# With custom retry settings
|
||||
python -m apps.slack_rag \
|
||||
--mcp-server "slack-mcp-server" \
|
||||
--workspace-name "My Company" \
|
||||
--channels general \
|
||||
--max-retries 10 \
|
||||
--retry-delay 5.0 \
|
||||
--max-messages-per-channel 2000 \
|
||||
--query "Show me all decisions made in the last month"
|
||||
```
|
||||
BIN
docs/videos/slack_integration_1.1.png
Normal file
|
After Width: | Height: | Size: 445 KiB |
BIN
docs/videos/slack_integration_1.2.png
Normal file
|
After Width: | Height: | Size: 508 KiB |
BIN
docs/videos/slack_integration_1.3.png
Normal file
|
After Width: | Height: | Size: 437 KiB |
BIN
docs/videos/slack_integration_2.1.png
Normal file
|
After Width: | Height: | Size: 474 KiB |
BIN
docs/videos/slack_integration_2.2.png
Normal file
|
After Width: | Height: | Size: 501 KiB |
BIN
docs/videos/slack_integration_2.3.png
Normal file
|
After Width: | Height: | Size: 454 KiB |
0
examples/__init__.py
Normal file
429
examples/dynamic_update_no_recompute.py
Normal file
@@ -0,0 +1,429 @@
|
||||
"""Dynamic HNSW update demo without compact storage.
|
||||
|
||||
This script reproduces the minimal scenario we used while debugging on-the-fly
|
||||
recompute:
|
||||
|
||||
1. Build a non-compact HNSW index from the first few paragraphs of a text file.
|
||||
2. Print the top results with `recompute_embeddings=True`.
|
||||
3. Append additional paragraphs with :meth:`LeannBuilder.update_index`.
|
||||
4. Run the same query again to show the newly inserted passages.
|
||||
|
||||
Run it with ``uv`` (optionally pointing LEANN_HNSW_LOG_PATH at a file to inspect
|
||||
ZMQ activity)::
|
||||
|
||||
LEANN_HNSW_LOG_PATH=embedding_fetch.log \
|
||||
uv run -m examples.dynamic_update_no_recompute \
|
||||
--index-path .leann/examples/leann-demo.leann
|
||||
|
||||
By default the script builds an index from ``data/2501.14312v1 (1).pdf`` and
|
||||
then updates it with LEANN-related material from ``data/2506.08276v1.pdf``.
|
||||
It issues the query "What's LEANN?" before and after the update to show how the
|
||||
new passages become immediately searchable. The script uses the
|
||||
``sentence-transformers/all-MiniLM-L6-v2`` model with ``is_recompute=True`` so
|
||||
Faiss pulls existing vectors on demand via the ZMQ embedding server, while
|
||||
freshly added passages are embedded locally just like the initial build.
|
||||
|
||||
To make storage comparisons easy, the script can also build a matching
|
||||
``is_recompute=False`` baseline (enabled by default) and report the index size
|
||||
delta after the update. Disable the baseline run with
|
||||
``--skip-compare-no-recompute`` if you only need the recompute flow.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from collections.abc import Iterable
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from leann.api import LeannBuilder, LeannSearcher
|
||||
from leann.registry import register_project_directory
|
||||
|
||||
from apps.chunking import create_text_chunks
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||
|
||||
DEFAULT_QUERY = "What's LEANN?"
|
||||
DEFAULT_INITIAL_FILES = [
|
||||
REPO_ROOT / "data" / "2501.14312v1 (1).pdf",
|
||||
REPO_ROOT / "data" / "huawei_pangu.md",
|
||||
REPO_ROOT / "data" / "PrideandPrejudice.txt",
|
||||
]
|
||||
DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"]
|
||||
|
||||
|
||||
def load_chunks_from_files(paths: list[Path]) -> list[str]:
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
|
||||
documents = []
|
||||
for path in paths:
|
||||
p = path.expanduser().resolve()
|
||||
if not p.exists():
|
||||
raise FileNotFoundError(f"Input path not found: {p}")
|
||||
if p.is_dir():
|
||||
reader = SimpleDirectoryReader(str(p), recursive=False)
|
||||
documents.extend(reader.load_data(show_progress=True))
|
||||
else:
|
||||
reader = SimpleDirectoryReader(input_files=[str(p)])
|
||||
documents.extend(reader.load_data(show_progress=True))
|
||||
|
||||
if not documents:
|
||||
return []
|
||||
|
||||
chunks = create_text_chunks(
|
||||
documents,
|
||||
chunk_size=512,
|
||||
chunk_overlap=128,
|
||||
use_ast_chunking=False,
|
||||
)
|
||||
return [c for c in chunks if isinstance(c, str) and c.strip()]
|
||||
|
||||
|
||||
def run_search(index_path: Path, query: str, top_k: int, *, recompute_embeddings: bool) -> list:
|
||||
searcher = LeannSearcher(str(index_path))
|
||||
try:
|
||||
return searcher.search(
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
recompute_embeddings=recompute_embeddings,
|
||||
batch_size=16,
|
||||
)
|
||||
finally:
|
||||
searcher.cleanup()
|
||||
|
||||
|
||||
def print_results(title: str, results: Iterable) -> None:
|
||||
print(f"\n=== {title} ===")
|
||||
res_list = list(results)
|
||||
print(f"results count: {len(res_list)}")
|
||||
print("passages:")
|
||||
if not res_list:
|
||||
print(" (no passages returned)")
|
||||
for res in res_list:
|
||||
snippet = res.text.replace("\n", " ")[:120]
|
||||
print(f" - {res.id}: {snippet}... (score={res.score:.4f})")
|
||||
|
||||
|
||||
def build_initial_index(
|
||||
index_path: Path,
|
||||
paragraphs: list[str],
|
||||
model_name: str,
|
||||
embedding_mode: str,
|
||||
is_recompute: bool,
|
||||
) -> None:
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model=model_name,
|
||||
embedding_mode=embedding_mode,
|
||||
is_compact=False,
|
||||
is_recompute=is_recompute,
|
||||
)
|
||||
for idx, passage in enumerate(paragraphs):
|
||||
builder.add_text(passage, metadata={"id": str(idx)})
|
||||
builder.build_index(str(index_path))
|
||||
|
||||
|
||||
def update_index(
|
||||
index_path: Path,
|
||||
start_id: int,
|
||||
paragraphs: list[str],
|
||||
model_name: str,
|
||||
embedding_mode: str,
|
||||
is_recompute: bool,
|
||||
) -> None:
|
||||
updater = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model=model_name,
|
||||
embedding_mode=embedding_mode,
|
||||
is_compact=False,
|
||||
is_recompute=is_recompute,
|
||||
)
|
||||
for offset, passage in enumerate(paragraphs, start=start_id):
|
||||
updater.add_text(passage, metadata={"id": str(offset)})
|
||||
updater.update_index(str(index_path))
|
||||
|
||||
|
||||
def ensure_index_dir(index_path: Path) -> None:
|
||||
index_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def cleanup_index_files(index_path: Path) -> None:
|
||||
"""Remove leftover index artifacts for a clean rebuild."""
|
||||
|
||||
parent = index_path.parent
|
||||
if not parent.exists():
|
||||
return
|
||||
stem = index_path.stem
|
||||
for file in parent.glob(f"{stem}*"):
|
||||
if file.is_file():
|
||||
file.unlink()
|
||||
|
||||
|
||||
def index_file_size(index_path: Path) -> int:
|
||||
"""Return the size of the primary .index file for the given index path."""
|
||||
|
||||
index_file = index_path.parent / f"{index_path.stem}.index"
|
||||
return index_file.stat().st_size if index_file.exists() else 0
|
||||
|
||||
|
||||
def load_metadata_snapshot(index_path: Path) -> dict[str, Any] | None:
|
||||
meta_path = index_path.parent / f"{index_path.name}.meta.json"
|
||||
if not meta_path.exists():
|
||||
return None
|
||||
try:
|
||||
return json.loads(meta_path.read_text())
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
|
||||
def run_workflow(
|
||||
*,
|
||||
label: str,
|
||||
index_path: Path,
|
||||
initial_paragraphs: list[str],
|
||||
update_paragraphs: list[str],
|
||||
model_name: str,
|
||||
embedding_mode: str,
|
||||
is_recompute: bool,
|
||||
query: str,
|
||||
top_k: int,
|
||||
skip_search: bool,
|
||||
) -> dict[str, Any]:
|
||||
prefix = f"[{label}] " if label else ""
|
||||
|
||||
ensure_index_dir(index_path)
|
||||
cleanup_index_files(index_path)
|
||||
|
||||
print(f"{prefix}Building initial index...")
|
||||
build_initial_index(
|
||||
index_path,
|
||||
initial_paragraphs,
|
||||
model_name,
|
||||
embedding_mode,
|
||||
is_recompute=is_recompute,
|
||||
)
|
||||
|
||||
initial_size = index_file_size(index_path)
|
||||
if not skip_search:
|
||||
before_results = run_search(
|
||||
index_path,
|
||||
query,
|
||||
top_k,
|
||||
recompute_embeddings=is_recompute,
|
||||
)
|
||||
else:
|
||||
before_results = None
|
||||
|
||||
print(f"\n{prefix}Updating index with additional passages...")
|
||||
update_index(
|
||||
index_path,
|
||||
start_id=len(initial_paragraphs),
|
||||
paragraphs=update_paragraphs,
|
||||
model_name=model_name,
|
||||
embedding_mode=embedding_mode,
|
||||
is_recompute=is_recompute,
|
||||
)
|
||||
|
||||
if not skip_search:
|
||||
after_results = run_search(
|
||||
index_path,
|
||||
query,
|
||||
top_k,
|
||||
recompute_embeddings=is_recompute,
|
||||
)
|
||||
else:
|
||||
after_results = None
|
||||
updated_size = index_file_size(index_path)
|
||||
|
||||
return {
|
||||
"initial_size": initial_size,
|
||||
"updated_size": updated_size,
|
||||
"delta": updated_size - initial_size,
|
||||
"before_results": before_results if not skip_search else None,
|
||||
"after_results": after_results if not skip_search else None,
|
||||
"metadata": load_metadata_snapshot(index_path),
|
||||
}
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"--initial-files",
|
||||
type=Path,
|
||||
nargs="+",
|
||||
default=DEFAULT_INITIAL_FILES,
|
||||
help="Initial document files (PDF/TXT) used to build the base index",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--index-path",
|
||||
type=Path,
|
||||
default=Path(".leann/examples/leann-demo.leann"),
|
||||
help="Destination index path (default: .leann/examples/leann-demo.leann)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--initial-count",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Number of chunks to use from the initial documents (default: 8)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--update-files",
|
||||
type=Path,
|
||||
nargs="*",
|
||||
default=DEFAULT_UPDATE_FILES,
|
||||
help="Additional documents to add during update (PDF/TXT)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--update-count",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Number of chunks to append from update documents (default: 4)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--update-text",
|
||||
type=str,
|
||||
default=(
|
||||
"LEANN (Lightweight Embedding ANN) is an indexing toolkit focused on "
|
||||
"recompute-aware HNSW graphs, allowing embeddings to be regenerated "
|
||||
"on demand to keep disk usage minimal."
|
||||
),
|
||||
help="Fallback text to append if --update-files is omitted",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top-k",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Number of results to show for each search (default: 4)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--query",
|
||||
type=str,
|
||||
default=DEFAULT_QUERY,
|
||||
help="Query to run before/after the update",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--embedding-model",
|
||||
type=str,
|
||||
default="sentence-transformers/all-MiniLM-L6-v2",
|
||||
help="Embedding model name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--embedding-mode",
|
||||
type=str,
|
||||
default="sentence-transformers",
|
||||
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
||||
help="Embedding backend mode",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--compare-no-recompute",
|
||||
dest="compare_no_recompute",
|
||||
action="store_true",
|
||||
help="Also run a baseline with is_recompute=False and report its index growth.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip-compare-no-recompute",
|
||||
dest="compare_no_recompute",
|
||||
action="store_false",
|
||||
help="Skip building the no-recompute baseline.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip-search",
|
||||
dest="skip_search",
|
||||
action="store_true",
|
||||
help="Skip the search step.",
|
||||
)
|
||||
parser.set_defaults(compare_no_recompute=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
ensure_index_dir(args.index_path)
|
||||
register_project_directory(REPO_ROOT)
|
||||
|
||||
initial_chunks = load_chunks_from_files(list(args.initial_files))
|
||||
if not initial_chunks:
|
||||
raise ValueError("No text chunks extracted from the initial files.")
|
||||
|
||||
initial = initial_chunks[: args.initial_count]
|
||||
if not initial:
|
||||
raise ValueError("Initial chunk set is empty after applying --initial-count.")
|
||||
|
||||
if args.update_files:
|
||||
update_chunks = load_chunks_from_files(list(args.update_files))
|
||||
if not update_chunks:
|
||||
raise ValueError("No text chunks extracted from the update files.")
|
||||
to_add = update_chunks[: args.update_count]
|
||||
else:
|
||||
if not args.update_text:
|
||||
raise ValueError("Provide --update-files or --update-text for the update step.")
|
||||
to_add = [args.update_text]
|
||||
if not to_add:
|
||||
raise ValueError("Update chunk set is empty after applying --update-count.")
|
||||
|
||||
recompute_stats = run_workflow(
|
||||
label="recompute",
|
||||
index_path=args.index_path,
|
||||
initial_paragraphs=initial,
|
||||
update_paragraphs=to_add,
|
||||
model_name=args.embedding_model,
|
||||
embedding_mode=args.embedding_mode,
|
||||
is_recompute=True,
|
||||
query=args.query,
|
||||
top_k=args.top_k,
|
||||
skip_search=args.skip_search,
|
||||
)
|
||||
|
||||
if not args.skip_search:
|
||||
print_results("initial search", recompute_stats["before_results"])
|
||||
if not args.skip_search:
|
||||
print_results("after update", recompute_stats["after_results"])
|
||||
print(
|
||||
f"\n[recompute] Index file size change: {recompute_stats['initial_size']} -> {recompute_stats['updated_size']} bytes"
|
||||
f" (Δ {recompute_stats['delta']})"
|
||||
)
|
||||
|
||||
if recompute_stats["metadata"]:
|
||||
meta_view = {k: recompute_stats["metadata"].get(k) for k in ("is_compact", "is_pruned")}
|
||||
print("[recompute] metadata snapshot:")
|
||||
print(json.dumps(meta_view, indent=2))
|
||||
|
||||
if args.compare_no_recompute:
|
||||
baseline_path = (
|
||||
args.index_path.parent / f"{args.index_path.stem}-norecompute{args.index_path.suffix}"
|
||||
)
|
||||
baseline_stats = run_workflow(
|
||||
label="no-recompute",
|
||||
index_path=baseline_path,
|
||||
initial_paragraphs=initial,
|
||||
update_paragraphs=to_add,
|
||||
model_name=args.embedding_model,
|
||||
embedding_mode=args.embedding_mode,
|
||||
is_recompute=False,
|
||||
query=args.query,
|
||||
top_k=args.top_k,
|
||||
skip_search=args.skip_search,
|
||||
)
|
||||
|
||||
print(
|
||||
f"\n[no-recompute] Index file size change: {baseline_stats['initial_size']} -> {baseline_stats['updated_size']} bytes"
|
||||
f" (Δ {baseline_stats['delta']})"
|
||||
)
|
||||
|
||||
after_texts = (
|
||||
[res.text for res in recompute_stats["after_results"]] if not args.skip_search else None
|
||||
)
|
||||
baseline_after_texts = (
|
||||
[res.text for res in baseline_stats["after_results"]] if not args.skip_search else None
|
||||
)
|
||||
if after_texts == baseline_after_texts:
|
||||
print(
|
||||
"[no-recompute] Search results match recompute baseline; see above for the shared output."
|
||||
)
|
||||
else:
|
||||
print("[no-recompute] WARNING: search results differ from recompute baseline.")
|
||||
|
||||
if baseline_stats["metadata"]:
|
||||
meta_view = {k: baseline_stats["metadata"].get(k) for k in ("is_compact", "is_pruned")}
|
||||
print("[no-recompute] metadata snapshot:")
|
||||
print(json.dumps(meta_view, indent=2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
178
examples/mcp_integration_demo.py
Normal file
@@ -0,0 +1,178 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
MCP Integration Examples for LEANN
|
||||
|
||||
This script demonstrates how to use LEANN with different MCP servers for
|
||||
RAG on various platforms like Slack and Twitter.
|
||||
|
||||
Examples:
|
||||
1. Slack message RAG via MCP
|
||||
2. Twitter bookmark RAG via MCP
|
||||
3. Testing MCP server connections
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add the parent directory to the path so we can import from apps
|
||||
sys.path.append(str(Path(__file__).parent.parent))
|
||||
|
||||
|
||||
async def demo_slack_mcp():
|
||||
"""Demonstrate Slack MCP integration."""
|
||||
print("=" * 60)
|
||||
print("🔥 Slack MCP RAG Demo")
|
||||
print("=" * 60)
|
||||
|
||||
print("\n1. Testing Slack MCP server connection...")
|
||||
|
||||
# This would typically use a real MCP server command
|
||||
# For demo purposes, we show what the command would look like
|
||||
# slack_app = SlackMCPRAG() # Would be used for actual testing
|
||||
|
||||
# Simulate command line arguments for testing
|
||||
class MockArgs:
|
||||
mcp_server = "slack-mcp-server" # This would be the actual MCP server command
|
||||
workspace_name = "my-workspace"
|
||||
channels = ["general", "random", "dev-team"]
|
||||
no_concatenate_conversations = False
|
||||
max_messages_per_channel = 50
|
||||
test_connection = True
|
||||
|
||||
print(f"MCP Server Command: {MockArgs.mcp_server}")
|
||||
print(f"Workspace: {MockArgs.workspace_name}")
|
||||
print(f"Channels: {', '.join(MockArgs.channels)}")
|
||||
|
||||
# In a real scenario, you would run:
|
||||
# success = await slack_app.test_mcp_connection(MockArgs)
|
||||
|
||||
print("\n📝 Example usage:")
|
||||
print("python -m apps.slack_rag \\")
|
||||
print(" --mcp-server 'slack-mcp-server' \\")
|
||||
print(" --workspace-name 'my-team' \\")
|
||||
print(" --channels general dev-team \\")
|
||||
print(" --test-connection")
|
||||
|
||||
print("\n🔍 After indexing, you could query:")
|
||||
print("- 'What did the team discuss about the project deadline?'")
|
||||
print("- 'Find messages about the new feature launch'")
|
||||
print("- 'Show me conversations about budget planning'")
|
||||
|
||||
|
||||
async def demo_twitter_mcp():
|
||||
"""Demonstrate Twitter MCP integration."""
|
||||
print("\n" + "=" * 60)
|
||||
print("🐦 Twitter MCP RAG Demo")
|
||||
print("=" * 60)
|
||||
|
||||
print("\n1. Testing Twitter MCP server connection...")
|
||||
|
||||
# twitter_app = TwitterMCPRAG() # Would be used for actual testing
|
||||
|
||||
class MockArgs:
|
||||
mcp_server = "twitter-mcp-server"
|
||||
username = None # Fetch all bookmarks
|
||||
max_bookmarks = 500
|
||||
no_tweet_content = False
|
||||
no_metadata = False
|
||||
test_connection = True
|
||||
|
||||
print(f"MCP Server Command: {MockArgs.mcp_server}")
|
||||
print(f"Max Bookmarks: {MockArgs.max_bookmarks}")
|
||||
print(f"Include Content: {not MockArgs.no_tweet_content}")
|
||||
print(f"Include Metadata: {not MockArgs.no_metadata}")
|
||||
|
||||
print("\n📝 Example usage:")
|
||||
print("python -m apps.twitter_rag \\")
|
||||
print(" --mcp-server 'twitter-mcp-server' \\")
|
||||
print(" --max-bookmarks 1000 \\")
|
||||
print(" --test-connection")
|
||||
|
||||
print("\n🔍 After indexing, you could query:")
|
||||
print("- 'What AI articles did I bookmark last month?'")
|
||||
print("- 'Find tweets about machine learning techniques'")
|
||||
print("- 'Show me bookmarked threads about startup advice'")
|
||||
|
||||
|
||||
async def show_mcp_server_setup():
|
||||
"""Show how to set up MCP servers."""
|
||||
print("\n" + "=" * 60)
|
||||
print("⚙️ MCP Server Setup Guide")
|
||||
print("=" * 60)
|
||||
|
||||
print("\n🔧 Setting up Slack MCP Server:")
|
||||
print("1. Install a Slack MCP server (example commands):")
|
||||
print(" npm install -g slack-mcp-server")
|
||||
print(" # OR")
|
||||
print(" pip install slack-mcp-server")
|
||||
|
||||
print("\n2. Configure Slack credentials:")
|
||||
print(" export SLACK_BOT_TOKEN='xoxb-your-bot-token'")
|
||||
print(" export SLACK_APP_TOKEN='xapp-your-app-token'")
|
||||
|
||||
print("\n3. Test the server:")
|
||||
print(" slack-mcp-server --help")
|
||||
|
||||
print("\n🔧 Setting up Twitter MCP Server:")
|
||||
print("1. Install a Twitter MCP server:")
|
||||
print(" npm install -g twitter-mcp-server")
|
||||
print(" # OR")
|
||||
print(" pip install twitter-mcp-server")
|
||||
|
||||
print("\n2. Configure Twitter API credentials:")
|
||||
print(" export TWITTER_API_KEY='your-api-key'")
|
||||
print(" export TWITTER_API_SECRET='your-api-secret'")
|
||||
print(" export TWITTER_ACCESS_TOKEN='your-access-token'")
|
||||
print(" export TWITTER_ACCESS_TOKEN_SECRET='your-access-token-secret'")
|
||||
|
||||
print("\n3. Test the server:")
|
||||
print(" twitter-mcp-server --help")
|
||||
|
||||
|
||||
async def show_integration_benefits():
|
||||
"""Show the benefits of MCP integration."""
|
||||
print("\n" + "=" * 60)
|
||||
print("🌟 Benefits of MCP Integration")
|
||||
print("=" * 60)
|
||||
|
||||
benefits = [
|
||||
("🔄 Live Data Access", "Fetch real-time data from platforms without manual exports"),
|
||||
("🔌 Standardized Protocol", "Use any MCP-compatible server with minimal code changes"),
|
||||
("🚀 Easy Extension", "Add new platforms by implementing MCP readers"),
|
||||
("🔒 Secure Access", "MCP servers handle authentication and API management"),
|
||||
("📊 Rich Metadata", "Access full platform metadata (timestamps, engagement, etc.)"),
|
||||
("⚡ Efficient Processing", "Stream data directly into LEANN without intermediate files"),
|
||||
]
|
||||
|
||||
for title, description in benefits:
|
||||
print(f"\n{title}")
|
||||
print(f" {description}")
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main demo function."""
|
||||
print("🎯 LEANN MCP Integration Examples")
|
||||
print("This demo shows how to integrate LEANN with MCP servers for various platforms.")
|
||||
|
||||
await demo_slack_mcp()
|
||||
await demo_twitter_mcp()
|
||||
await show_mcp_server_setup()
|
||||
await show_integration_benefits()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("✨ Next Steps")
|
||||
print("=" * 60)
|
||||
print("1. Install and configure MCP servers for your platforms")
|
||||
print("2. Test connections using --test-connection flag")
|
||||
print("3. Run indexing to build your RAG knowledge base")
|
||||
print("4. Start querying your personal data!")
|
||||
|
||||
print("\n📚 For more information:")
|
||||
print("- Check the README for detailed setup instructions")
|
||||
print("- Look at the apps/slack_rag.py and apps/twitter_rag.py for implementation details")
|
||||
print("- Explore other MCP servers for additional platforms")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -343,7 +343,8 @@ class DiskannSearcher(BaseSearcher):
|
||||
"full_index_prefix": full_index_prefix,
|
||||
"num_threads": self.num_threads,
|
||||
"num_nodes_to_cache": kwargs.get("num_nodes_to_cache", 0),
|
||||
"cache_mechanism": 1,
|
||||
# 1 -> initialize cache using sample_data; 2 -> ready cache without init; others disable cache
|
||||
"cache_mechanism": kwargs.get("cache_mechanism", 1),
|
||||
"pq_prefix": "",
|
||||
"partition_prefix": partition_prefix,
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ import sys
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
import zmq
|
||||
@@ -32,6 +32,16 @@ if not logger.handlers:
|
||||
logger.propagate = False
|
||||
|
||||
|
||||
_RAW_PROVIDER_OPTIONS = os.getenv("LEANN_EMBEDDING_OPTIONS")
|
||||
try:
|
||||
PROVIDER_OPTIONS: dict[str, Any] = (
|
||||
json.loads(_RAW_PROVIDER_OPTIONS) if _RAW_PROVIDER_OPTIONS else {}
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Failed to parse LEANN_EMBEDDING_OPTIONS; ignoring provider options")
|
||||
PROVIDER_OPTIONS = {}
|
||||
|
||||
|
||||
def create_diskann_embedding_server(
|
||||
passages_file: Optional[str] = None,
|
||||
zmq_port: int = 5555,
|
||||
@@ -181,7 +191,12 @@ def create_diskann_embedding_server(
|
||||
logger.debug(f"Text lengths: {[len(t) for t in texts[:5]]}") # Show first 5
|
||||
|
||||
# Process embeddings using unified computation
|
||||
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
||||
embeddings = compute_embeddings(
|
||||
texts,
|
||||
model_name,
|
||||
mode=embedding_mode,
|
||||
provider_options=PROVIDER_OPTIONS,
|
||||
)
|
||||
logger.info(
|
||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||
)
|
||||
@@ -296,7 +311,12 @@ def create_diskann_embedding_server(
|
||||
continue
|
||||
|
||||
# Process the request
|
||||
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
||||
embeddings = compute_embeddings(
|
||||
texts,
|
||||
model_name,
|
||||
mode=embedding_mode,
|
||||
provider_options=PROVIDER_OPTIONS,
|
||||
)
|
||||
logger.info(f"Computed embeddings shape: {embeddings.shape}")
|
||||
|
||||
# Validation
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
[build-system]
|
||||
requires = ["scikit-build-core>=0.10", "pybind11>=2.12.0", "numpy"]
|
||||
requires = ["scikit-build-core>=0.10", "pybind11>=2.12.0", "numpy", "cmake>=3.30"]
|
||||
build-backend = "scikit_build_core.build"
|
||||
|
||||
[project]
|
||||
name = "leann-backend-diskann"
|
||||
version = "0.3.4"
|
||||
dependencies = ["leann-core==0.3.4", "numpy", "protobuf>=3.19.0"]
|
||||
version = "0.3.5"
|
||||
dependencies = ["leann-core==0.3.5", "numpy", "protobuf>=3.19.0"]
|
||||
|
||||
[tool.scikit-build]
|
||||
# Key: simplified CMake path
|
||||
|
||||
@@ -29,12 +29,25 @@ if(APPLE)
|
||||
set(CMAKE_OSX_DEPLOYMENT_TARGET "11.0" CACHE STRING "Minimum macOS version")
|
||||
endif()
|
||||
|
||||
# Use system ZeroMQ instead of building from source
|
||||
# Find ZMQ using pkg-config with IMPORTED_TARGET for automatic target creation
|
||||
find_package(PkgConfig REQUIRED)
|
||||
pkg_check_modules(ZMQ REQUIRED libzmq)
|
||||
|
||||
# On ARM64 macOS, ensure pkg-config finds ARM64 Homebrew packages first
|
||||
if(APPLE AND CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64|arm64")
|
||||
set(ENV{PKG_CONFIG_PATH} "/opt/homebrew/lib/pkgconfig:/opt/homebrew/share/pkgconfig:$ENV{PKG_CONFIG_PATH}")
|
||||
endif()
|
||||
|
||||
pkg_check_modules(ZMQ REQUIRED IMPORTED_TARGET libzmq)
|
||||
|
||||
# This creates PkgConfig::ZMQ target automatically with correct properties
|
||||
if(TARGET PkgConfig::ZMQ)
|
||||
message(STATUS "Found and configured ZMQ target: PkgConfig::ZMQ")
|
||||
else()
|
||||
message(FATAL_ERROR "pkg_check_modules did not create IMPORTED target for ZMQ.")
|
||||
endif()
|
||||
|
||||
# Add cppzmq headers
|
||||
include_directories(third_party/cppzmq)
|
||||
include_directories(SYSTEM third_party/cppzmq)
|
||||
|
||||
# Configure msgpack-c - disable boost dependency
|
||||
set(MSGPACK_USE_BOOST OFF CACHE BOOL "" FORCE)
|
||||
|
||||
@@ -5,6 +5,8 @@ import os
|
||||
import struct
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -237,6 +239,288 @@ def write_compact_format(
|
||||
f_out.write(storage_data)
|
||||
|
||||
|
||||
@dataclass
|
||||
class HNSWComponents:
|
||||
original_hnsw_data: dict[str, Any]
|
||||
assign_probas_np: np.ndarray
|
||||
cum_nneighbor_per_level_np: np.ndarray
|
||||
levels_np: np.ndarray
|
||||
is_compact: bool
|
||||
compact_level_ptr: Optional[np.ndarray] = None
|
||||
compact_node_offsets_np: Optional[np.ndarray] = None
|
||||
compact_neighbors_data: Optional[list[int]] = None
|
||||
offsets_np: Optional[np.ndarray] = None
|
||||
neighbors_np: Optional[np.ndarray] = None
|
||||
storage_fourcc: int = NULL_INDEX_FOURCC
|
||||
storage_data: bytes = b""
|
||||
|
||||
|
||||
def _read_hnsw_structure(f) -> HNSWComponents:
|
||||
original_hnsw_data: dict[str, Any] = {}
|
||||
|
||||
hnsw_index_fourcc = read_struct(f, "<I")
|
||||
if hnsw_index_fourcc not in EXPECTED_HNSW_FOURCCS:
|
||||
raise ValueError(
|
||||
f"Unexpected HNSW FourCC: {hnsw_index_fourcc:08x}. Expected one of {EXPECTED_HNSW_FOURCCS}."
|
||||
)
|
||||
|
||||
original_hnsw_data["index_fourcc"] = hnsw_index_fourcc
|
||||
original_hnsw_data["d"] = read_struct(f, "<i")
|
||||
original_hnsw_data["ntotal"] = read_struct(f, "<q")
|
||||
original_hnsw_data["dummy1"] = read_struct(f, "<q")
|
||||
original_hnsw_data["dummy2"] = read_struct(f, "<q")
|
||||
original_hnsw_data["is_trained"] = read_struct(f, "?")
|
||||
original_hnsw_data["metric_type"] = read_struct(f, "<i")
|
||||
original_hnsw_data["metric_arg"] = 0.0
|
||||
if original_hnsw_data["metric_type"] > 1:
|
||||
original_hnsw_data["metric_arg"] = read_struct(f, "<f")
|
||||
|
||||
assign_probas_np = read_numpy_vector(f, np.float64, "d")
|
||||
cum_nneighbor_per_level_np = read_numpy_vector(f, np.int32, "i")
|
||||
levels_np = read_numpy_vector(f, np.int32, "i")
|
||||
|
||||
ntotal = len(levels_np)
|
||||
if ntotal != original_hnsw_data["ntotal"]:
|
||||
original_hnsw_data["ntotal"] = ntotal
|
||||
|
||||
pos_before_compact = f.tell()
|
||||
is_compact_flag = None
|
||||
try:
|
||||
is_compact_flag = read_struct(f, "<?")
|
||||
except EOFError:
|
||||
is_compact_flag = None
|
||||
|
||||
if is_compact_flag:
|
||||
compact_level_ptr = read_numpy_vector(f, np.uint64, "Q")
|
||||
compact_node_offsets_np = read_numpy_vector(f, np.uint64, "Q")
|
||||
|
||||
original_hnsw_data["entry_point"] = read_struct(f, "<i")
|
||||
original_hnsw_data["max_level"] = read_struct(f, "<i")
|
||||
original_hnsw_data["efConstruction"] = read_struct(f, "<i")
|
||||
original_hnsw_data["efSearch"] = read_struct(f, "<i")
|
||||
original_hnsw_data["dummy_upper_beam"] = read_struct(f, "<i")
|
||||
|
||||
storage_fourcc = read_struct(f, "<I")
|
||||
compact_neighbors_data_np = read_numpy_vector(f, np.int32, "i")
|
||||
compact_neighbors_data = compact_neighbors_data_np.tolist()
|
||||
storage_data = f.read()
|
||||
|
||||
return HNSWComponents(
|
||||
original_hnsw_data=original_hnsw_data,
|
||||
assign_probas_np=assign_probas_np,
|
||||
cum_nneighbor_per_level_np=cum_nneighbor_per_level_np,
|
||||
levels_np=levels_np,
|
||||
is_compact=True,
|
||||
compact_level_ptr=compact_level_ptr,
|
||||
compact_node_offsets_np=compact_node_offsets_np,
|
||||
compact_neighbors_data=compact_neighbors_data,
|
||||
storage_fourcc=storage_fourcc,
|
||||
storage_data=storage_data,
|
||||
)
|
||||
|
||||
# Non-compact case
|
||||
f.seek(pos_before_compact)
|
||||
|
||||
pos_before_probe = f.tell()
|
||||
try:
|
||||
suspected_flag = read_struct(f, "<B")
|
||||
if suspected_flag != 0x00:
|
||||
f.seek(pos_before_probe)
|
||||
except EOFError:
|
||||
f.seek(pos_before_probe)
|
||||
|
||||
offsets_np = read_numpy_vector(f, np.uint64, "Q")
|
||||
neighbors_np = read_numpy_vector(f, np.int32, "i")
|
||||
|
||||
original_hnsw_data["entry_point"] = read_struct(f, "<i")
|
||||
original_hnsw_data["max_level"] = read_struct(f, "<i")
|
||||
original_hnsw_data["efConstruction"] = read_struct(f, "<i")
|
||||
original_hnsw_data["efSearch"] = read_struct(f, "<i")
|
||||
original_hnsw_data["dummy_upper_beam"] = read_struct(f, "<i")
|
||||
|
||||
storage_fourcc = NULL_INDEX_FOURCC
|
||||
storage_data = b""
|
||||
try:
|
||||
storage_fourcc = read_struct(f, "<I")
|
||||
storage_data = f.read()
|
||||
except EOFError:
|
||||
storage_fourcc = NULL_INDEX_FOURCC
|
||||
|
||||
return HNSWComponents(
|
||||
original_hnsw_data=original_hnsw_data,
|
||||
assign_probas_np=assign_probas_np,
|
||||
cum_nneighbor_per_level_np=cum_nneighbor_per_level_np,
|
||||
levels_np=levels_np,
|
||||
is_compact=False,
|
||||
offsets_np=offsets_np,
|
||||
neighbors_np=neighbors_np,
|
||||
storage_fourcc=storage_fourcc,
|
||||
storage_data=storage_data,
|
||||
)
|
||||
|
||||
|
||||
def _read_hnsw_structure_from_file(path: str) -> HNSWComponents:
|
||||
with open(path, "rb") as f:
|
||||
return _read_hnsw_structure(f)
|
||||
|
||||
|
||||
def write_original_format(
|
||||
f_out,
|
||||
original_hnsw_data,
|
||||
assign_probas_np,
|
||||
cum_nneighbor_per_level_np,
|
||||
levels_np,
|
||||
offsets_np,
|
||||
neighbors_np,
|
||||
storage_fourcc,
|
||||
storage_data,
|
||||
):
|
||||
"""Write non-compact HNSW data in original FAISS order."""
|
||||
|
||||
f_out.write(struct.pack("<I", original_hnsw_data["index_fourcc"]))
|
||||
f_out.write(struct.pack("<i", original_hnsw_data["d"]))
|
||||
f_out.write(struct.pack("<q", original_hnsw_data["ntotal"]))
|
||||
f_out.write(struct.pack("<q", original_hnsw_data["dummy1"]))
|
||||
f_out.write(struct.pack("<q", original_hnsw_data["dummy2"]))
|
||||
f_out.write(struct.pack("<?", original_hnsw_data["is_trained"]))
|
||||
f_out.write(struct.pack("<i", original_hnsw_data["metric_type"]))
|
||||
if original_hnsw_data["metric_type"] > 1:
|
||||
f_out.write(struct.pack("<f", original_hnsw_data["metric_arg"]))
|
||||
|
||||
write_numpy_vector(f_out, assign_probas_np, "d")
|
||||
write_numpy_vector(f_out, cum_nneighbor_per_level_np, "i")
|
||||
write_numpy_vector(f_out, levels_np, "i")
|
||||
|
||||
write_numpy_vector(f_out, offsets_np, "Q")
|
||||
write_numpy_vector(f_out, neighbors_np, "i")
|
||||
|
||||
f_out.write(struct.pack("<i", original_hnsw_data["entry_point"]))
|
||||
f_out.write(struct.pack("<i", original_hnsw_data["max_level"]))
|
||||
f_out.write(struct.pack("<i", original_hnsw_data["efConstruction"]))
|
||||
f_out.write(struct.pack("<i", original_hnsw_data["efSearch"]))
|
||||
f_out.write(struct.pack("<i", original_hnsw_data["dummy_upper_beam"]))
|
||||
|
||||
f_out.write(struct.pack("<I", storage_fourcc))
|
||||
if storage_fourcc != NULL_INDEX_FOURCC and storage_data:
|
||||
f_out.write(storage_data)
|
||||
|
||||
|
||||
def prune_hnsw_embeddings(input_filename: str, output_filename: str) -> bool:
|
||||
"""Rewrite an HNSW index while dropping the embedded storage section."""
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
with open(input_filename, "rb") as f_in, open(output_filename, "wb") as f_out:
|
||||
original_hnsw_data: dict[str, Any] = {}
|
||||
|
||||
hnsw_index_fourcc = read_struct(f_in, "<I")
|
||||
if hnsw_index_fourcc not in EXPECTED_HNSW_FOURCCS:
|
||||
print(
|
||||
f"Error: Expected HNSW Index FourCC ({list(EXPECTED_HNSW_FOURCCS)}), got {hnsw_index_fourcc:08x}.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return False
|
||||
|
||||
original_hnsw_data["index_fourcc"] = hnsw_index_fourcc
|
||||
original_hnsw_data["d"] = read_struct(f_in, "<i")
|
||||
original_hnsw_data["ntotal"] = read_struct(f_in, "<q")
|
||||
original_hnsw_data["dummy1"] = read_struct(f_in, "<q")
|
||||
original_hnsw_data["dummy2"] = read_struct(f_in, "<q")
|
||||
original_hnsw_data["is_trained"] = read_struct(f_in, "?")
|
||||
original_hnsw_data["metric_type"] = read_struct(f_in, "<i")
|
||||
original_hnsw_data["metric_arg"] = 0.0
|
||||
if original_hnsw_data["metric_type"] > 1:
|
||||
original_hnsw_data["metric_arg"] = read_struct(f_in, "<f")
|
||||
|
||||
assign_probas_np = read_numpy_vector(f_in, np.float64, "d")
|
||||
cum_nneighbor_per_level_np = read_numpy_vector(f_in, np.int32, "i")
|
||||
levels_np = read_numpy_vector(f_in, np.int32, "i")
|
||||
|
||||
ntotal = len(levels_np)
|
||||
if ntotal != original_hnsw_data["ntotal"]:
|
||||
original_hnsw_data["ntotal"] = ntotal
|
||||
|
||||
pos_before_compact = f_in.tell()
|
||||
is_compact_flag = None
|
||||
try:
|
||||
is_compact_flag = read_struct(f_in, "<?")
|
||||
except EOFError:
|
||||
is_compact_flag = None
|
||||
|
||||
if is_compact_flag:
|
||||
compact_level_ptr = read_numpy_vector(f_in, np.uint64, "Q")
|
||||
compact_node_offsets_np = read_numpy_vector(f_in, np.uint64, "Q")
|
||||
|
||||
original_hnsw_data["entry_point"] = read_struct(f_in, "<i")
|
||||
original_hnsw_data["max_level"] = read_struct(f_in, "<i")
|
||||
original_hnsw_data["efConstruction"] = read_struct(f_in, "<i")
|
||||
original_hnsw_data["efSearch"] = read_struct(f_in, "<i")
|
||||
original_hnsw_data["dummy_upper_beam"] = read_struct(f_in, "<i")
|
||||
|
||||
_storage_fourcc = read_struct(f_in, "<I")
|
||||
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, "i")
|
||||
compact_neighbors_data = compact_neighbors_data_np.tolist()
|
||||
_storage_data = f_in.read()
|
||||
|
||||
write_compact_format(
|
||||
f_out,
|
||||
original_hnsw_data,
|
||||
assign_probas_np,
|
||||
cum_nneighbor_per_level_np,
|
||||
levels_np,
|
||||
compact_level_ptr,
|
||||
compact_node_offsets_np,
|
||||
compact_neighbors_data,
|
||||
NULL_INDEX_FOURCC,
|
||||
b"",
|
||||
)
|
||||
else:
|
||||
f_in.seek(pos_before_compact)
|
||||
|
||||
pos_before_probe = f_in.tell()
|
||||
try:
|
||||
suspected_flag = read_struct(f_in, "<B")
|
||||
if suspected_flag != 0x00:
|
||||
f_in.seek(pos_before_probe)
|
||||
except EOFError:
|
||||
f_in.seek(pos_before_probe)
|
||||
|
||||
offsets_np = read_numpy_vector(f_in, np.uint64, "Q")
|
||||
neighbors_np = read_numpy_vector(f_in, np.int32, "i")
|
||||
|
||||
original_hnsw_data["entry_point"] = read_struct(f_in, "<i")
|
||||
original_hnsw_data["max_level"] = read_struct(f_in, "<i")
|
||||
original_hnsw_data["efConstruction"] = read_struct(f_in, "<i")
|
||||
original_hnsw_data["efSearch"] = read_struct(f_in, "<i")
|
||||
original_hnsw_data["dummy_upper_beam"] = read_struct(f_in, "<i")
|
||||
|
||||
_storage_fourcc = None
|
||||
_storage_data = b""
|
||||
try:
|
||||
_storage_fourcc = read_struct(f_in, "<I")
|
||||
_storage_data = f_in.read()
|
||||
except EOFError:
|
||||
_storage_fourcc = NULL_INDEX_FOURCC
|
||||
|
||||
write_original_format(
|
||||
f_out,
|
||||
original_hnsw_data,
|
||||
assign_probas_np,
|
||||
cum_nneighbor_per_level_np,
|
||||
levels_np,
|
||||
offsets_np,
|
||||
neighbors_np,
|
||||
NULL_INDEX_FOURCC,
|
||||
b"",
|
||||
)
|
||||
|
||||
print(f"[{time.time() - start_time:.2f}s] Pruned embeddings from {input_filename}")
|
||||
return True
|
||||
except Exception as exc:
|
||||
print(f"Failed to prune embeddings: {exc}", file=sys.stderr)
|
||||
return False
|
||||
|
||||
|
||||
# --- Main Conversion Logic ---
|
||||
|
||||
|
||||
@@ -700,6 +984,29 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
|
||||
pass
|
||||
|
||||
|
||||
def prune_hnsw_embeddings_inplace(index_filename: str) -> bool:
|
||||
"""Convenience wrapper to prune embeddings in-place."""
|
||||
|
||||
temp_path = f"{index_filename}.prune.tmp"
|
||||
success = prune_hnsw_embeddings(index_filename, temp_path)
|
||||
if success:
|
||||
try:
|
||||
os.replace(temp_path, index_filename)
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
logger.error(f"Failed to replace original index with pruned version: {exc}")
|
||||
try:
|
||||
os.remove(temp_path)
|
||||
except OSError:
|
||||
pass
|
||||
return False
|
||||
else:
|
||||
try:
|
||||
os.remove(temp_path)
|
||||
except OSError:
|
||||
pass
|
||||
return success
|
||||
|
||||
|
||||
# --- Script Execution ---
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
|
||||
@@ -14,7 +14,7 @@ from leann.interface import (
|
||||
from leann.registry import register_backend
|
||||
from leann.searcher_base import BaseSearcher
|
||||
|
||||
from .convert_to_csr import convert_hnsw_graph_to_csr
|
||||
from .convert_to_csr import convert_hnsw_graph_to_csr, prune_hnsw_embeddings_inplace
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -90,8 +90,19 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
||||
index_file = index_dir / f"{index_prefix}.index"
|
||||
faiss.write_index(index, str(index_file))
|
||||
|
||||
# Persist ID map so searcher can map FAISS integer labels back to passage IDs
|
||||
try:
|
||||
idmap_file = index_dir / f"{index_prefix}.ids.txt"
|
||||
with open(idmap_file, "w", encoding="utf-8") as f:
|
||||
for id_str in ids:
|
||||
f.write(str(id_str) + "\n")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to write ID map: {e}")
|
||||
|
||||
if self.is_compact:
|
||||
self._convert_to_csr(index_file)
|
||||
elif self.is_recompute:
|
||||
prune_hnsw_embeddings_inplace(str(index_file))
|
||||
|
||||
def _convert_to_csr(self, index_file: Path):
|
||||
"""Convert built index to CSR format"""
|
||||
@@ -133,10 +144,10 @@ class HNSWSearcher(BaseSearcher):
|
||||
if metric_enum is None:
|
||||
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
|
||||
|
||||
self.is_compact, self.is_pruned = (
|
||||
self.meta.get("is_compact", True),
|
||||
self.meta.get("is_pruned", True),
|
||||
)
|
||||
backend_meta_kwargs = self.meta.get("backend_kwargs", {})
|
||||
self.is_compact = self.meta.get("is_compact", backend_meta_kwargs.get("is_compact", True))
|
||||
default_pruned = backend_meta_kwargs.get("is_recompute", self.is_compact)
|
||||
self.is_pruned = bool(self.meta.get("is_pruned", default_pruned))
|
||||
|
||||
index_file = self.index_dir / f"{self.index_path.stem}.index"
|
||||
if not index_file.exists():
|
||||
@@ -150,6 +161,16 @@ class HNSWSearcher(BaseSearcher):
|
||||
|
||||
self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)
|
||||
|
||||
# Load ID map if available
|
||||
self._id_map: list[str] = []
|
||||
try:
|
||||
idmap_file = self.index_dir / f"{self.index_path.stem}.ids.txt"
|
||||
if idmap_file.exists():
|
||||
with open(idmap_file, encoding="utf-8") as f:
|
||||
self._id_map = [line.rstrip("\n") for line in f]
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load ID map: {e}")
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: np.ndarray,
|
||||
@@ -194,6 +215,8 @@ class HNSWSearcher(BaseSearcher):
|
||||
if recompute_embeddings:
|
||||
if zmq_port is None:
|
||||
raise ValueError("zmq_port must be provided if recompute_embeddings is True")
|
||||
if hasattr(self._index, "set_zmq_port"):
|
||||
self._index.set_zmq_port(zmq_port)
|
||||
|
||||
if query.dtype != np.float32:
|
||||
query = query.astype(np.float32)
|
||||
@@ -248,6 +271,19 @@ class HNSWSearcher(BaseSearcher):
|
||||
)
|
||||
search_time = time.time() - search_time
|
||||
logger.info(f" Search time in HNSWSearcher.search() backend: {search_time} seconds")
|
||||
string_labels = [[str(int_label) for int_label in batch_labels] for batch_labels in labels]
|
||||
if self._id_map:
|
||||
|
||||
def map_label(x: int) -> str:
|
||||
if 0 <= x < len(self._id_map):
|
||||
return self._id_map[x]
|
||||
return str(x)
|
||||
|
||||
string_labels = [
|
||||
[map_label(int(label)) for label in batch_labels] for batch_labels in labels
|
||||
]
|
||||
else:
|
||||
string_labels = [
|
||||
[str(int_label) for int_label in batch_labels] for batch_labels in labels
|
||||
]
|
||||
|
||||
return {"labels": string_labels, "distances": distances}
|
||||
|
||||
@@ -10,7 +10,7 @@ import sys
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import msgpack
|
||||
import numpy as np
|
||||
@@ -24,13 +24,35 @@ logger = logging.getLogger(__name__)
|
||||
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
|
||||
logger.setLevel(log_level)
|
||||
|
||||
# Ensure we have a handler if none exists
|
||||
# Ensure we have handlers if none exist
|
||||
if not logger.handlers:
|
||||
handler = logging.StreamHandler()
|
||||
stream_handler = logging.StreamHandler()
|
||||
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
|
||||
handler.setFormatter(formatter)
|
||||
logger.addHandler(handler)
|
||||
logger.propagate = False
|
||||
stream_handler.setFormatter(formatter)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
log_path = os.getenv("LEANN_HNSW_LOG_PATH")
|
||||
if log_path:
|
||||
try:
|
||||
file_handler = logging.FileHandler(log_path, mode="a", encoding="utf-8")
|
||||
file_formatter = logging.Formatter(
|
||||
"%(asctime)s - %(levelname)s - [pid=%(process)d] %(message)s"
|
||||
)
|
||||
file_handler.setFormatter(file_formatter)
|
||||
logger.addHandler(file_handler)
|
||||
except Exception as exc: # pragma: no cover - best effort logging
|
||||
logger.warning(f"Failed to attach file handler for log path {log_path}: {exc}")
|
||||
|
||||
logger.propagate = False
|
||||
|
||||
_RAW_PROVIDER_OPTIONS = os.getenv("LEANN_EMBEDDING_OPTIONS")
|
||||
try:
|
||||
PROVIDER_OPTIONS: dict[str, Any] = (
|
||||
json.loads(_RAW_PROVIDER_OPTIONS) if _RAW_PROVIDER_OPTIONS else {}
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Failed to parse LEANN_EMBEDDING_OPTIONS; ignoring provider options")
|
||||
PROVIDER_OPTIONS = {}
|
||||
|
||||
|
||||
def create_hnsw_embedding_server(
|
||||
@@ -92,6 +114,35 @@ def create_hnsw_embedding_server(
|
||||
embedding_dim = 0
|
||||
logger.info(f"Loaded PassageManager with {len(passages)} passages from metadata")
|
||||
|
||||
# Attempt to load ID map (maps FAISS integer labels -> passage IDs)
|
||||
id_map: list[str] = []
|
||||
try:
|
||||
meta_path = Path(passages_file)
|
||||
base = meta_path.name
|
||||
if base.endswith(".meta.json"):
|
||||
base = base[: -len(".meta.json")] # e.g., laion_index.leann
|
||||
if base.endswith(".leann"):
|
||||
base = base[: -len(".leann")] # e.g., laion_index
|
||||
idmap_file = meta_path.parent / f"{base}.ids.txt"
|
||||
if idmap_file.exists():
|
||||
with open(idmap_file, encoding="utf-8") as f:
|
||||
id_map = [line.rstrip("\n") for line in f]
|
||||
logger.info(f"Loaded ID map with {len(id_map)} entries from {idmap_file}")
|
||||
else:
|
||||
logger.warning(f"ID map file not found at {idmap_file}; will use raw labels")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load ID map: {e}")
|
||||
|
||||
def _map_node_id(nid) -> str:
|
||||
try:
|
||||
if id_map is not None and len(id_map) > 0 and isinstance(nid, (int, np.integer)):
|
||||
idx = int(nid)
|
||||
if 0 <= idx < len(id_map):
|
||||
return id_map[idx]
|
||||
except Exception:
|
||||
pass
|
||||
return str(nid)
|
||||
|
||||
# (legacy ZMQ thread removed; using shutdown-capable server only)
|
||||
|
||||
def zmq_server_thread_with_shutdown(shutdown_event):
|
||||
@@ -138,7 +189,12 @@ def create_hnsw_embedding_server(
|
||||
):
|
||||
last_request_type = "text"
|
||||
last_request_length = len(request)
|
||||
embeddings = compute_embeddings(request, model_name, mode=embedding_mode)
|
||||
embeddings = compute_embeddings(
|
||||
request,
|
||||
model_name,
|
||||
mode=embedding_mode,
|
||||
provider_options=PROVIDER_OPTIONS,
|
||||
)
|
||||
rep_socket.send(msgpack.packb(embeddings.tolist()))
|
||||
e2e_end = time.time()
|
||||
logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
|
||||
@@ -168,13 +224,14 @@ def create_hnsw_embedding_server(
|
||||
found_indices: list[int] = []
|
||||
for idx, nid in enumerate(node_ids):
|
||||
try:
|
||||
passage_data = passages.get_passage(str(nid))
|
||||
passage_id = _map_node_id(nid)
|
||||
passage_data = passages.get_passage(passage_id)
|
||||
txt = passage_data.get("text", "")
|
||||
if isinstance(txt, str) and len(txt) > 0:
|
||||
texts.append(txt)
|
||||
found_indices.append(idx)
|
||||
else:
|
||||
logger.error(f"Empty text for passage ID {nid}")
|
||||
logger.error(f"Empty text for passage ID {passage_id}")
|
||||
except KeyError:
|
||||
logger.error(f"Passage ID {nid} not found")
|
||||
except Exception as e:
|
||||
@@ -187,7 +244,10 @@ def create_hnsw_embedding_server(
|
||||
if texts:
|
||||
try:
|
||||
embeddings = compute_embeddings(
|
||||
texts, model_name, mode=embedding_mode
|
||||
texts,
|
||||
model_name,
|
||||
mode=embedding_mode,
|
||||
provider_options=PROVIDER_OPTIONS,
|
||||
)
|
||||
logger.info(
|
||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||
@@ -238,13 +298,14 @@ def create_hnsw_embedding_server(
|
||||
found_indices: list[int] = []
|
||||
for idx, nid in enumerate(node_ids):
|
||||
try:
|
||||
passage_data = passages.get_passage(str(nid))
|
||||
passage_id = _map_node_id(nid)
|
||||
passage_data = passages.get_passage(passage_id)
|
||||
txt = passage_data.get("text", "")
|
||||
if isinstance(txt, str) and len(txt) > 0:
|
||||
texts.append(txt)
|
||||
found_indices.append(idx)
|
||||
else:
|
||||
logger.error(f"Empty text for passage ID {nid}")
|
||||
logger.error(f"Empty text for passage ID {passage_id}")
|
||||
except KeyError:
|
||||
logger.error(f"Passage with ID {nid} not found")
|
||||
except Exception as e:
|
||||
@@ -252,7 +313,12 @@ def create_hnsw_embedding_server(
|
||||
|
||||
if texts:
|
||||
try:
|
||||
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
||||
embeddings = compute_embeddings(
|
||||
texts,
|
||||
model_name,
|
||||
mode=embedding_mode,
|
||||
provider_options=PROVIDER_OPTIONS,
|
||||
)
|
||||
logger.info(
|
||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||
)
|
||||
|
||||
@@ -6,10 +6,10 @@ build-backend = "scikit_build_core.build"
|
||||
|
||||
[project]
|
||||
name = "leann-backend-hnsw"
|
||||
version = "0.3.4"
|
||||
version = "0.3.5"
|
||||
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
|
||||
dependencies = [
|
||||
"leann-core==0.3.4",
|
||||
"leann-core==0.3.5",
|
||||
"numpy",
|
||||
"pyzmq>=23.0.0",
|
||||
"msgpack>=1.0.0",
|
||||
|
||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "leann-core"
|
||||
version = "0.3.4"
|
||||
version = "0.3.5"
|
||||
description = "Core API and plugin system for LEANN"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.9"
|
||||
@@ -18,14 +18,16 @@ dependencies = [
|
||||
"pyzmq>=23.0.0",
|
||||
"msgpack>=1.0.0",
|
||||
"torch>=2.0.0",
|
||||
"sentence-transformers>=2.2.0",
|
||||
"sentence-transformers>=3.0.0",
|
||||
"llama-index-core>=0.12.0",
|
||||
"llama-index-readers-file>=0.4.0", # Essential for document reading
|
||||
"llama-index-embeddings-huggingface>=0.5.5", # For embeddings
|
||||
"python-dotenv>=1.0.0",
|
||||
"openai>=1.0.0",
|
||||
"huggingface-hub>=0.20.0",
|
||||
"transformers>=4.30.0",
|
||||
# Keep transformers below 4.46: 4.46.0 adds Python 3.10-only return type syntax and
|
||||
# breaks Python 3.9 environments.
|
||||
"transformers>=4.30.0,<4.46",
|
||||
"requests>=2.25.0",
|
||||
"accelerate>=0.20.0",
|
||||
"PyPDF2>=3.0.0",
|
||||
@@ -40,7 +42,7 @@ dependencies = [
|
||||
[project.optional-dependencies]
|
||||
colab = [
|
||||
"torch>=2.0.0,<3.0.0", # Limit torch version to avoid conflicts
|
||||
"transformers>=4.30.0,<5.0.0", # Limit transformers version
|
||||
"transformers>=4.30.0,<4.46", # 4.46.0 switches to PEP 604 typing (int | None), breaks Py3.9
|
||||
"accelerate>=0.20.0,<1.0.0", # Limit accelerate version
|
||||
]
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ with the correct, original embedding logic from the user's reference code.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import re
|
||||
import subprocess
|
||||
@@ -15,10 +16,13 @@ from pathlib import Path
|
||||
from typing import Any, Literal, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from leann_backend_hnsw.convert_to_csr import prune_hnsw_embeddings_inplace
|
||||
|
||||
from leann.interactive_utils import create_api_session
|
||||
from leann.interface import LeannBackendSearcherInterface
|
||||
|
||||
from .chat import get_llm
|
||||
from .embedding_server_manager import EmbeddingServerManager
|
||||
from .interface import LeannBackendFactoryInterface
|
||||
from .metadata_filter import MetadataFilterEngine
|
||||
from .registry import BACKEND_REGISTRY
|
||||
@@ -38,6 +42,7 @@ def compute_embeddings(
|
||||
use_server: bool = True,
|
||||
port: Optional[int] = None,
|
||||
is_build=False,
|
||||
provider_options: Optional[dict[str, Any]] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Computes embeddings using different backends.
|
||||
@@ -71,6 +76,7 @@ def compute_embeddings(
|
||||
model_name,
|
||||
mode=mode,
|
||||
is_build=is_build,
|
||||
provider_options=provider_options,
|
||||
)
|
||||
|
||||
|
||||
@@ -277,6 +283,7 @@ class LeannBuilder:
|
||||
embedding_model: str = "facebook/contriever",
|
||||
dimensions: Optional[int] = None,
|
||||
embedding_mode: str = "sentence-transformers",
|
||||
embedding_options: Optional[dict[str, Any]] = None,
|
||||
**backend_kwargs,
|
||||
):
|
||||
self.backend_name = backend_name
|
||||
@@ -299,6 +306,7 @@ class LeannBuilder:
|
||||
self.embedding_model = embedding_model
|
||||
self.dimensions = dimensions
|
||||
self.embedding_mode = embedding_mode
|
||||
self.embedding_options = embedding_options or {}
|
||||
|
||||
# Check if we need to use cosine distance for normalized embeddings
|
||||
normalized_embeddings_models = {
|
||||
@@ -406,6 +414,7 @@ class LeannBuilder:
|
||||
self.embedding_model,
|
||||
self.embedding_mode,
|
||||
use_server=False,
|
||||
provider_options=self.embedding_options,
|
||||
)[0]
|
||||
)
|
||||
path = Path(index_path)
|
||||
@@ -445,8 +454,20 @@ class LeannBuilder:
|
||||
self.embedding_mode,
|
||||
use_server=False,
|
||||
is_build=True,
|
||||
provider_options=self.embedding_options,
|
||||
)
|
||||
string_ids = [chunk["id"] for chunk in self.chunks]
|
||||
# Persist ID map alongside index so backends that return integer labels can remap to passage IDs
|
||||
try:
|
||||
idmap_file = (
|
||||
index_dir
|
||||
/ f"{index_name[: -len('.leann')] if index_name.endswith('.leann') else index_name}.ids.txt"
|
||||
)
|
||||
with open(idmap_file, "w", encoding="utf-8") as f:
|
||||
for sid in string_ids:
|
||||
f.write(str(sid) + "\n")
|
||||
except Exception:
|
||||
pass
|
||||
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
|
||||
builder_instance = self.backend_factory.builder(**current_backend_kwargs)
|
||||
builder_instance.build(embeddings, string_ids, index_path, **current_backend_kwargs)
|
||||
@@ -471,14 +492,15 @@ class LeannBuilder:
|
||||
],
|
||||
}
|
||||
|
||||
if self.embedding_options:
|
||||
meta_data["embedding_options"] = self.embedding_options
|
||||
|
||||
# Add storage status flags for HNSW backend
|
||||
if self.backend_name == "hnsw":
|
||||
is_compact = self.backend_kwargs.get("is_compact", True)
|
||||
is_recompute = self.backend_kwargs.get("is_recompute", True)
|
||||
meta_data["is_compact"] = is_compact
|
||||
meta_data["is_pruned"] = (
|
||||
is_compact and is_recompute
|
||||
) # Pruned only if compact and recompute
|
||||
meta_data["is_pruned"] = bool(is_recompute)
|
||||
with open(leann_meta_path, "w", encoding="utf-8") as f:
|
||||
json.dump(meta_data, f, indent=2)
|
||||
|
||||
@@ -565,6 +587,17 @@ class LeannBuilder:
|
||||
|
||||
# Build the vector index using precomputed embeddings
|
||||
string_ids = [str(id_val) for id_val in ids]
|
||||
# Persist ID map (order == embeddings order)
|
||||
try:
|
||||
idmap_file = (
|
||||
index_dir
|
||||
/ f"{index_name[: -len('.leann')] if index_name.endswith('.leann') else index_name}.ids.txt"
|
||||
)
|
||||
with open(idmap_file, "w", encoding="utf-8") as f:
|
||||
for sid in string_ids:
|
||||
f.write(str(sid) + "\n")
|
||||
except Exception:
|
||||
pass
|
||||
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
|
||||
builder_instance = self.backend_factory.builder(**current_backend_kwargs)
|
||||
builder_instance.build(embeddings, string_ids, index_path)
|
||||
@@ -593,18 +626,242 @@ class LeannBuilder:
|
||||
"embeddings_source": str(embeddings_file),
|
||||
}
|
||||
|
||||
if self.embedding_options:
|
||||
meta_data["embedding_options"] = self.embedding_options
|
||||
|
||||
# Add storage status flags for HNSW backend
|
||||
if self.backend_name == "hnsw":
|
||||
is_compact = self.backend_kwargs.get("is_compact", True)
|
||||
is_recompute = self.backend_kwargs.get("is_recompute", True)
|
||||
meta_data["is_compact"] = is_compact
|
||||
meta_data["is_pruned"] = is_compact and is_recompute
|
||||
meta_data["is_pruned"] = bool(is_recompute)
|
||||
|
||||
with open(leann_meta_path, "w", encoding="utf-8") as f:
|
||||
json.dump(meta_data, f, indent=2)
|
||||
|
||||
logger.info(f"Index built successfully from precomputed embeddings: {index_path}")
|
||||
|
||||
def update_index(self, index_path: str):
|
||||
"""Append new passages and vectors to an existing HNSW index."""
|
||||
if not self.chunks:
|
||||
raise ValueError("No new chunks provided for update.")
|
||||
|
||||
path = Path(index_path)
|
||||
index_dir = path.parent
|
||||
index_name = path.name
|
||||
index_prefix = path.stem
|
||||
|
||||
meta_path = index_dir / f"{index_name}.meta.json"
|
||||
passages_file = index_dir / f"{index_name}.passages.jsonl"
|
||||
offset_file = index_dir / f"{index_name}.passages.idx"
|
||||
index_file = index_dir / f"{index_prefix}.index"
|
||||
|
||||
if not meta_path.exists() or not passages_file.exists() or not offset_file.exists():
|
||||
raise FileNotFoundError("Index metadata or passage files are missing; cannot update.")
|
||||
if not index_file.exists():
|
||||
raise FileNotFoundError(f"HNSW index file not found: {index_file}")
|
||||
|
||||
with open(meta_path, encoding="utf-8") as f:
|
||||
meta = json.load(f)
|
||||
backend_name = meta.get("backend_name")
|
||||
if backend_name != self.backend_name:
|
||||
raise ValueError(
|
||||
f"Index was built with backend '{backend_name}', cannot update with '{self.backend_name}'."
|
||||
)
|
||||
|
||||
meta_backend_kwargs = meta.get("backend_kwargs", {})
|
||||
index_is_compact = meta.get("is_compact", meta_backend_kwargs.get("is_compact", True))
|
||||
if index_is_compact:
|
||||
raise ValueError(
|
||||
"Compact HNSW indices do not support in-place updates. Rebuild required."
|
||||
)
|
||||
|
||||
distance_metric = meta_backend_kwargs.get(
|
||||
"distance_metric", self.backend_kwargs.get("distance_metric", "mips")
|
||||
).lower()
|
||||
needs_recompute = bool(
|
||||
meta.get("is_pruned")
|
||||
or meta_backend_kwargs.get("is_recompute")
|
||||
or self.backend_kwargs.get("is_recompute")
|
||||
)
|
||||
|
||||
with open(offset_file, "rb") as f:
|
||||
offset_map: dict[str, int] = pickle.load(f)
|
||||
existing_ids = set(offset_map.keys())
|
||||
|
||||
valid_chunks: list[dict[str, Any]] = []
|
||||
for chunk in self.chunks:
|
||||
text = chunk.get("text", "")
|
||||
if not isinstance(text, str) or not text.strip():
|
||||
continue
|
||||
metadata = chunk.setdefault("metadata", {})
|
||||
passage_id = chunk.get("id") or metadata.get("id")
|
||||
if passage_id and passage_id in existing_ids:
|
||||
raise ValueError(f"Passage ID '{passage_id}' already exists in the index.")
|
||||
valid_chunks.append(chunk)
|
||||
|
||||
if not valid_chunks:
|
||||
raise ValueError("No valid chunks to append.")
|
||||
|
||||
texts_to_embed = [chunk["text"] for chunk in valid_chunks]
|
||||
embeddings = compute_embeddings(
|
||||
texts_to_embed,
|
||||
self.embedding_model,
|
||||
self.embedding_mode,
|
||||
use_server=False,
|
||||
is_build=True,
|
||||
provider_options=self.embedding_options,
|
||||
)
|
||||
|
||||
embedding_dim = embeddings.shape[1]
|
||||
expected_dim = meta.get("dimensions")
|
||||
if expected_dim is not None and expected_dim != embedding_dim:
|
||||
raise ValueError(
|
||||
f"Dimension mismatch during update: existing index uses {expected_dim}, got {embedding_dim}."
|
||||
)
|
||||
|
||||
from leann_backend_hnsw import faiss # type: ignore
|
||||
|
||||
embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
|
||||
if distance_metric == "cosine":
|
||||
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
|
||||
norms[norms == 0] = 1
|
||||
embeddings = embeddings / norms
|
||||
|
||||
index = faiss.read_index(str(index_file))
|
||||
if hasattr(index, "is_recompute"):
|
||||
index.is_recompute = needs_recompute
|
||||
print(f"index.is_recompute: {index.is_recompute}")
|
||||
if getattr(index, "storage", None) is None:
|
||||
if index.metric_type == faiss.METRIC_INNER_PRODUCT:
|
||||
storage_index = faiss.IndexFlatIP(index.d)
|
||||
else:
|
||||
storage_index = faiss.IndexFlatL2(index.d)
|
||||
index.storage = storage_index
|
||||
index.own_fields = True
|
||||
# Faiss expects storage.ntotal to reflect the existing graph's
|
||||
# population (even if the vectors themselves were pruned from disk
|
||||
# for recompute mode). When we attach a fresh IndexFlat here its
|
||||
# ntotal starts at zero, which later causes IndexHNSW::add to
|
||||
# believe new "preset" levels were provided and trips the
|
||||
# `n0 + n == levels.size()` assertion. Seed the temporary storage
|
||||
# with the current ntotal so Faiss maintains the proper offset for
|
||||
# incoming vectors.
|
||||
try:
|
||||
storage_index.ntotal = index.ntotal
|
||||
except AttributeError:
|
||||
# Older Faiss builds may not expose ntotal as a writable
|
||||
# attribute; in that case we fall back to the default behaviour.
|
||||
pass
|
||||
if index.d != embedding_dim:
|
||||
raise ValueError(
|
||||
f"Existing index dimension ({index.d}) does not match new embeddings ({embedding_dim})."
|
||||
)
|
||||
|
||||
passage_meta_mode = meta.get("embedding_mode", self.embedding_mode)
|
||||
passage_provider_options = meta.get("embedding_options", self.embedding_options)
|
||||
|
||||
base_id = index.ntotal
|
||||
for offset, chunk in enumerate(valid_chunks):
|
||||
new_id = str(base_id + offset)
|
||||
chunk.setdefault("metadata", {})["id"] = new_id
|
||||
chunk["id"] = new_id
|
||||
|
||||
# Append passages/offsets before we attempt index.add so the ZMQ server
|
||||
# can resolve newly assigned IDs during recompute. Keep rollback hooks
|
||||
# so we can restore files if the update fails mid-way.
|
||||
rollback_passages_size = passages_file.stat().st_size if passages_file.exists() else 0
|
||||
offset_map_backup = offset_map.copy()
|
||||
|
||||
try:
|
||||
with open(passages_file, "a", encoding="utf-8") as f:
|
||||
for chunk in valid_chunks:
|
||||
offset = f.tell()
|
||||
json.dump(
|
||||
{
|
||||
"id": chunk["id"],
|
||||
"text": chunk["text"],
|
||||
"metadata": chunk.get("metadata", {}),
|
||||
},
|
||||
f,
|
||||
ensure_ascii=False,
|
||||
)
|
||||
f.write("\n")
|
||||
offset_map[chunk["id"]] = offset
|
||||
|
||||
with open(offset_file, "wb") as f:
|
||||
pickle.dump(offset_map, f)
|
||||
|
||||
server_manager: Optional[EmbeddingServerManager] = None
|
||||
server_started = False
|
||||
requested_zmq_port = int(os.getenv("LEANN_UPDATE_ZMQ_PORT", "5557"))
|
||||
|
||||
try:
|
||||
if needs_recompute:
|
||||
server_manager = EmbeddingServerManager(
|
||||
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
|
||||
)
|
||||
server_started, actual_port = server_manager.start_server(
|
||||
port=requested_zmq_port,
|
||||
model_name=self.embedding_model,
|
||||
embedding_mode=passage_meta_mode,
|
||||
passages_file=str(meta_path),
|
||||
distance_metric=distance_metric,
|
||||
provider_options=passage_provider_options,
|
||||
)
|
||||
if not server_started:
|
||||
raise RuntimeError(
|
||||
"Failed to start HNSW embedding server for recompute update."
|
||||
)
|
||||
if actual_port != requested_zmq_port:
|
||||
logger.warning(
|
||||
"Embedding server started on port %s instead of requested %s. "
|
||||
"Using reassigned port.",
|
||||
actual_port,
|
||||
requested_zmq_port,
|
||||
)
|
||||
if hasattr(index.hnsw, "set_zmq_port"):
|
||||
index.hnsw.set_zmq_port(actual_port)
|
||||
elif hasattr(index, "set_zmq_port"):
|
||||
index.set_zmq_port(actual_port)
|
||||
|
||||
if needs_recompute:
|
||||
for i in range(embeddings.shape[0]):
|
||||
print(f"add {i} embeddings")
|
||||
index.add(1, faiss.swig_ptr(embeddings[i : i + 1]))
|
||||
else:
|
||||
index.add(embeddings.shape[0], faiss.swig_ptr(embeddings))
|
||||
faiss.write_index(index, str(index_file))
|
||||
finally:
|
||||
if server_started and server_manager is not None:
|
||||
server_manager.stop_server()
|
||||
|
||||
except Exception:
|
||||
# Roll back appended passages/offset map to keep files consistent.
|
||||
if passages_file.exists():
|
||||
with open(passages_file, "rb+") as f:
|
||||
f.truncate(rollback_passages_size)
|
||||
offset_map = offset_map_backup
|
||||
with open(offset_file, "wb") as f:
|
||||
pickle.dump(offset_map, f)
|
||||
raise
|
||||
|
||||
meta["total_passages"] = len(offset_map)
|
||||
with open(meta_path, "w", encoding="utf-8") as f:
|
||||
json.dump(meta, f, indent=2)
|
||||
|
||||
logger.info(
|
||||
"Appended %d passages to index '%s'. New total: %d",
|
||||
len(valid_chunks),
|
||||
index_path,
|
||||
len(offset_map),
|
||||
)
|
||||
|
||||
self.chunks.clear()
|
||||
|
||||
if needs_recompute:
|
||||
prune_hnsw_embeddings_inplace(str(index_file))
|
||||
|
||||
|
||||
class LeannSearcher:
|
||||
def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
|
||||
@@ -628,6 +885,7 @@ class LeannSearcher:
|
||||
self.embedding_model = self.meta_data["embedding_model"]
|
||||
# Support both old and new format
|
||||
self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers")
|
||||
self.embedding_options = self.meta_data.get("embedding_options", {})
|
||||
# Delegate portability handling to PassageManager
|
||||
self.passage_manager = PassageManager(
|
||||
self.meta_data.get("passage_sources", []), metadata_file_path=self.meta_path_str
|
||||
@@ -639,6 +897,8 @@ class LeannSearcher:
|
||||
raise ValueError(f"Backend '{backend_name}' not found.")
|
||||
final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
|
||||
final_kwargs["enable_warmup"] = enable_warmup
|
||||
if self.embedding_options:
|
||||
final_kwargs.setdefault("embedding_options", self.embedding_options)
|
||||
self.backend_impl: LeannBackendSearcherInterface = backend_factory.searcher(
|
||||
index_path, **final_kwargs
|
||||
)
|
||||
@@ -656,6 +916,7 @@ class LeannSearcher:
|
||||
metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
|
||||
batch_size: int = 0,
|
||||
use_grep: bool = False,
|
||||
provider_options: Optional[dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
) -> list[SearchResult]:
|
||||
"""
|
||||
@@ -719,10 +980,24 @@ class LeannSearcher:
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Extract query template from stored embedding_options with fallback chain:
|
||||
# 1. Check provider_options override (highest priority)
|
||||
# 2. Check query_prompt_template (new format)
|
||||
# 3. Check prompt_template (old format for backward compat)
|
||||
# 4. None (no template)
|
||||
query_template = None
|
||||
if provider_options and "prompt_template" in provider_options:
|
||||
query_template = provider_options["prompt_template"]
|
||||
elif "query_prompt_template" in self.embedding_options:
|
||||
query_template = self.embedding_options["query_prompt_template"]
|
||||
elif "prompt_template" in self.embedding_options:
|
||||
query_template = self.embedding_options["prompt_template"]
|
||||
|
||||
query_embedding = self.backend_impl.compute_query_embedding(
|
||||
query,
|
||||
use_server_if_available=recompute_embeddings,
|
||||
zmq_port=zmq_port,
|
||||
query_template=query_template,
|
||||
)
|
||||
logger.info(f" Generated embedding shape: {query_embedding.shape}")
|
||||
embedding_time = time.time() - start_time
|
||||
@@ -976,6 +1251,17 @@ class LeannChat:
|
||||
"Please provide the best answer you can based on this context and your knowledge."
|
||||
)
|
||||
|
||||
logger.info("The context provided to the LLM is:")
|
||||
logger.info(f"{'Relevance':<10} | {'Chunk id':<10} | {'Content':<60} | {'Source':<80}")
|
||||
logger.info("-" * 150)
|
||||
for r in results:
|
||||
chunk_relevance = f"{r.score:.3f}"
|
||||
chunk_id = r.id
|
||||
chunk_content = r.text[:60]
|
||||
chunk_source = r.metadata.get("source", "")[:80]
|
||||
logger.info(
|
||||
f"{chunk_relevance:<10} | {chunk_id:<10} | {chunk_content:<60} | {chunk_source:<80}"
|
||||
)
|
||||
ask_time = time.time()
|
||||
ans = self.llm.ask(prompt, **llm_kwargs)
|
||||
ask_time = time.time() - ask_time
|
||||
@@ -983,19 +1269,14 @@ class LeannChat:
|
||||
return ans
|
||||
|
||||
def start_interactive(self):
|
||||
print("\nLeann Chat started (type 'quit' to exit)")
|
||||
while True:
|
||||
try:
|
||||
user_input = input("You: ").strip()
|
||||
if user_input.lower() in ["quit", "exit"]:
|
||||
break
|
||||
if not user_input:
|
||||
continue
|
||||
response = self.ask(user_input)
|
||||
print(f"Leann: {response}")
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print("\nGoodbye!")
|
||||
break
|
||||
"""Start interactive chat session."""
|
||||
session = create_api_session()
|
||||
|
||||
def handle_query(user_input: str):
|
||||
response = self.ask(user_input)
|
||||
print(f"Leann: {response}")
|
||||
|
||||
session.run_interactive_loop(handle_query)
|
||||
|
||||
def cleanup(self):
|
||||
"""Explicitly cleanup embedding server resources.
|
||||
|
||||
@@ -12,6 +12,14 @@ from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from .settings import (
|
||||
resolve_anthropic_api_key,
|
||||
resolve_anthropic_base_url,
|
||||
resolve_ollama_host,
|
||||
resolve_openai_api_key,
|
||||
resolve_openai_base_url,
|
||||
)
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -310,11 +318,12 @@ def search_hf_models(query: str, limit: int = 10) -> list[str]:
|
||||
|
||||
|
||||
def validate_model_and_suggest(
|
||||
model_name: str, llm_type: str, host: str = "http://localhost:11434"
|
||||
model_name: str, llm_type: str, host: Optional[str] = None
|
||||
) -> Optional[str]:
|
||||
"""Validate model name and provide suggestions if invalid"""
|
||||
if llm_type == "ollama":
|
||||
available_models = check_ollama_models(host)
|
||||
resolved_host = resolve_ollama_host(host)
|
||||
available_models = check_ollama_models(resolved_host)
|
||||
if available_models and model_name not in available_models:
|
||||
error_msg = f"Model '{model_name}' not found in your local Ollama installation."
|
||||
|
||||
@@ -457,19 +466,19 @@ class LLMInterface(ABC):
|
||||
class OllamaChat(LLMInterface):
|
||||
"""LLM interface for Ollama models."""
|
||||
|
||||
def __init__(self, model: str = "llama3:8b", host: str = "http://localhost:11434"):
|
||||
def __init__(self, model: str = "llama3:8b", host: Optional[str] = None):
|
||||
self.model = model
|
||||
self.host = host
|
||||
logger.info(f"Initializing OllamaChat with model='{model}' and host='{host}'")
|
||||
self.host = resolve_ollama_host(host)
|
||||
logger.info(f"Initializing OllamaChat with model='{model}' and host='{self.host}'")
|
||||
try:
|
||||
import requests
|
||||
|
||||
# Check if the Ollama server is responsive
|
||||
if host:
|
||||
requests.get(host)
|
||||
if self.host:
|
||||
requests.get(self.host)
|
||||
|
||||
# Pre-check model availability with helpful suggestions
|
||||
model_error = validate_model_and_suggest(model, "ollama", host)
|
||||
model_error = validate_model_and_suggest(model, "ollama", self.host)
|
||||
if model_error:
|
||||
raise ValueError(model_error)
|
||||
|
||||
@@ -478,9 +487,11 @@ class OllamaChat(LLMInterface):
|
||||
"The 'requests' library is required for Ollama. Please install it with 'pip install requests'."
|
||||
)
|
||||
except requests.exceptions.ConnectionError:
|
||||
logger.error(f"Could not connect to Ollama at {host}. Please ensure Ollama is running.")
|
||||
logger.error(
|
||||
f"Could not connect to Ollama at {self.host}. Please ensure Ollama is running."
|
||||
)
|
||||
raise ConnectionError(
|
||||
f"Could not connect to Ollama at {host}. Please ensure Ollama is running."
|
||||
f"Could not connect to Ollama at {self.host}. Please ensure Ollama is running."
|
||||
)
|
||||
|
||||
def ask(self, prompt: str, **kwargs) -> str:
|
||||
@@ -541,11 +552,30 @@ class OllamaChat(LLMInterface):
|
||||
|
||||
|
||||
class HFChat(LLMInterface):
|
||||
"""LLM interface for local Hugging Face Transformers models with proper chat templates."""
|
||||
"""LLM interface for local Hugging Face Transformers models with proper chat templates.
|
||||
|
||||
def __init__(self, model_name: str = "deepseek-ai/deepseek-llm-7b-chat"):
|
||||
Args:
|
||||
model_name (str): Name of the Hugging Face model to load.
|
||||
trust_remote_code (bool): Whether to allow execution of code from the model repository.
|
||||
Defaults to False for security. Only enable for trusted models as this can pose
|
||||
a security risk if the model repository is compromised.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, model_name: str = "deepseek-ai/deepseek-llm-7b-chat", trust_remote_code: bool = False
|
||||
):
|
||||
logger.info(f"Initializing HFChat with model='{model_name}'")
|
||||
|
||||
# Security warning when trust_remote_code is enabled
|
||||
if trust_remote_code:
|
||||
logger.warning(
|
||||
"SECURITY WARNING: trust_remote_code=True allows execution of arbitrary code from the model repository. "
|
||||
"Only enable this for models from trusted sources. This creates a potential security risk if the model "
|
||||
"repository is compromised."
|
||||
)
|
||||
|
||||
self.trust_remote_code = trust_remote_code
|
||||
|
||||
# Pre-check model availability with helpful suggestions
|
||||
model_error = validate_model_and_suggest(model_name, "hf")
|
||||
if model_error:
|
||||
@@ -583,14 +613,16 @@ class HFChat(LLMInterface):
|
||||
|
||||
try:
|
||||
logger.info(f"Loading tokenizer for {model_name}...")
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_name, trust_remote_code=self.trust_remote_code
|
||||
)
|
||||
|
||||
logger.info(f"Loading model {model_name}...")
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.float16 if self.device != "cpu" else torch.float32,
|
||||
device_map="auto" if self.device != "cpu" else None,
|
||||
trust_remote_code=True,
|
||||
trust_remote_code=self.trust_remote_code,
|
||||
)
|
||||
logger.info(f"Successfully loaded {model_name}")
|
||||
finally:
|
||||
@@ -737,21 +769,31 @@ class GeminiChat(LLMInterface):
|
||||
class OpenAIChat(LLMInterface):
|
||||
"""LLM interface for OpenAI models."""
|
||||
|
||||
def __init__(self, model: str = "gpt-4o", api_key: Optional[str] = None):
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "gpt-4o",
|
||||
api_key: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
):
|
||||
self.model = model
|
||||
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
|
||||
self.base_url = resolve_openai_base_url(base_url)
|
||||
self.api_key = resolve_openai_api_key(api_key)
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError(
|
||||
"OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter."
|
||||
)
|
||||
|
||||
logger.info(f"Initializing OpenAI Chat with model='{model}'")
|
||||
logger.info(
|
||||
"Initializing OpenAI Chat with model='%s' and base_url='%s'",
|
||||
model,
|
||||
self.base_url,
|
||||
)
|
||||
|
||||
try:
|
||||
import openai
|
||||
|
||||
self.client = openai.OpenAI(api_key=self.api_key)
|
||||
self.client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"The 'openai' library is required for OpenAI models. Please install it with 'pip install openai'."
|
||||
@@ -798,12 +840,92 @@ class OpenAIChat(LLMInterface):
|
||||
|
||||
try:
|
||||
response = self.client.chat.completions.create(**params)
|
||||
print(
|
||||
f"Total tokens = {response.usage.total_tokens}, prompt tokens = {response.usage.prompt_tokens}, completion tokens = {response.usage.completion_tokens}"
|
||||
)
|
||||
if response.choices[0].finish_reason == "length":
|
||||
print("The query is exceeding the maximum allowed number of tokens")
|
||||
return response.choices[0].message.content.strip()
|
||||
except Exception as e:
|
||||
logger.error(f"Error communicating with OpenAI: {e}")
|
||||
return f"Error: Could not get a response from OpenAI. Details: {e}"
|
||||
|
||||
|
||||
class AnthropicChat(LLMInterface):
|
||||
"""LLM interface for Anthropic Claude models."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "claude-haiku-4-5",
|
||||
api_key: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
):
|
||||
self.model = model
|
||||
self.base_url = resolve_anthropic_base_url(base_url)
|
||||
self.api_key = resolve_anthropic_api_key(api_key)
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError(
|
||||
"Anthropic API key is required. Set ANTHROPIC_API_KEY environment variable or pass api_key parameter."
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Initializing Anthropic Chat with model='%s' and base_url='%s'",
|
||||
model,
|
||||
self.base_url,
|
||||
)
|
||||
|
||||
try:
|
||||
import anthropic
|
||||
|
||||
# Allow custom Anthropic-compatible endpoints via base_url
|
||||
self.client = anthropic.Anthropic(
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url,
|
||||
)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"The 'anthropic' library is required for Anthropic models. Please install it with 'pip install anthropic'."
|
||||
)
|
||||
|
||||
def ask(self, prompt: str, **kwargs) -> str:
|
||||
logger.info(f"Sending request to Anthropic with model {self.model}")
|
||||
|
||||
try:
|
||||
# Anthropic API parameters
|
||||
params = {
|
||||
"model": self.model,
|
||||
"max_tokens": kwargs.get("max_tokens", 1000),
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
}
|
||||
|
||||
# Add optional parameters
|
||||
if "temperature" in kwargs:
|
||||
params["temperature"] = kwargs["temperature"]
|
||||
if "top_p" in kwargs:
|
||||
params["top_p"] = kwargs["top_p"]
|
||||
|
||||
response = self.client.messages.create(**params)
|
||||
|
||||
# Extract text from response
|
||||
response_text = response.content[0].text
|
||||
|
||||
# Log token usage
|
||||
print(
|
||||
f"Total tokens = {response.usage.input_tokens + response.usage.output_tokens}, "
|
||||
f"input tokens = {response.usage.input_tokens}, "
|
||||
f"output tokens = {response.usage.output_tokens}"
|
||||
)
|
||||
|
||||
if response.stop_reason == "max_tokens":
|
||||
print("The query is exceeding the maximum allowed number of tokens")
|
||||
|
||||
return response_text.strip()
|
||||
except Exception as e:
|
||||
logger.error(f"Error communicating with Anthropic: {e}")
|
||||
return f"Error: Could not get a response from Anthropic. Details: {e}"
|
||||
|
||||
|
||||
class SimulatedChat(LLMInterface):
|
||||
"""A simple simulated chat for testing and development."""
|
||||
|
||||
@@ -841,14 +963,27 @@ def get_llm(llm_config: Optional[dict[str, Any]] = None) -> LLMInterface:
|
||||
if llm_type == "ollama":
|
||||
return OllamaChat(
|
||||
model=model or "llama3:8b",
|
||||
host=llm_config.get("host", "http://localhost:11434"),
|
||||
host=llm_config.get("host"),
|
||||
)
|
||||
elif llm_type == "hf":
|
||||
return HFChat(model_name=model or "deepseek-ai/deepseek-llm-7b-chat")
|
||||
return HFChat(
|
||||
model_name=model or "deepseek-ai/deepseek-llm-7b-chat",
|
||||
trust_remote_code=llm_config.get("trust_remote_code", False),
|
||||
)
|
||||
elif llm_type == "openai":
|
||||
return OpenAIChat(model=model or "gpt-4o", api_key=llm_config.get("api_key"))
|
||||
return OpenAIChat(
|
||||
model=model or "gpt-4o",
|
||||
api_key=llm_config.get("api_key"),
|
||||
base_url=llm_config.get("base_url"),
|
||||
)
|
||||
elif llm_type == "gemini":
|
||||
return GeminiChat(model=model or "gemini-2.5-flash", api_key=llm_config.get("api_key"))
|
||||
elif llm_type == "anthropic":
|
||||
return AnthropicChat(
|
||||
model=model or "claude-3-5-sonnet-20241022",
|
||||
api_key=llm_config.get("api_key"),
|
||||
base_url=llm_config.get("base_url"),
|
||||
)
|
||||
elif llm_type == "simulated":
|
||||
return SimulatedChat()
|
||||
else:
|
||||
|
||||
@@ -5,12 +5,128 @@ Packaged within leann-core so installed wheels can import it reliably.
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Flag to ensure AST token warning only shown once per session
|
||||
_ast_token_warning_shown = False
|
||||
|
||||
|
||||
def estimate_token_count(text: str) -> int:
|
||||
"""
|
||||
Estimate token count for a text string.
|
||||
Uses conservative estimation: ~4 characters per token for natural text,
|
||||
~1.2 tokens per character for code (worse tokenization).
|
||||
|
||||
Args:
|
||||
text: Input text to estimate tokens for
|
||||
|
||||
Returns:
|
||||
Estimated token count
|
||||
"""
|
||||
try:
|
||||
import tiktoken
|
||||
|
||||
encoder = tiktoken.get_encoding("cl100k_base")
|
||||
return len(encoder.encode(text))
|
||||
except ImportError:
|
||||
# Fallback: Conservative character-based estimation
|
||||
# Assume worst case for code: 1.2 tokens per character
|
||||
return int(len(text) * 1.2)
|
||||
|
||||
|
||||
def calculate_safe_chunk_size(
|
||||
model_token_limit: int,
|
||||
overlap_tokens: int,
|
||||
chunking_mode: str = "traditional",
|
||||
safety_factor: float = 0.9,
|
||||
) -> int:
|
||||
"""
|
||||
Calculate safe chunk size accounting for overlap and safety margin.
|
||||
|
||||
Args:
|
||||
model_token_limit: Maximum tokens supported by embedding model
|
||||
overlap_tokens: Overlap size (tokens for traditional, chars for AST)
|
||||
chunking_mode: "traditional" (tokens) or "ast" (characters)
|
||||
safety_factor: Safety margin (0.9 = 10% safety margin)
|
||||
|
||||
Returns:
|
||||
Safe chunk size: tokens for traditional, characters for AST
|
||||
"""
|
||||
safe_limit = int(model_token_limit * safety_factor)
|
||||
|
||||
if chunking_mode == "traditional":
|
||||
# Traditional chunking uses tokens
|
||||
# Max chunk = chunk_size + overlap, so chunk_size = limit - overlap
|
||||
return max(1, safe_limit - overlap_tokens)
|
||||
else: # AST chunking
|
||||
# AST uses characters, need to convert
|
||||
# Conservative estimate: 1.2 tokens per char for code
|
||||
overlap_chars = int(overlap_tokens * 3) # ~3 chars per token for code
|
||||
safe_chars = int(safe_limit / 1.2)
|
||||
return max(1, safe_chars - overlap_chars)
|
||||
|
||||
|
||||
def validate_chunk_token_limits(chunks: list[str], max_tokens: int = 512) -> tuple[list[str], int]:
|
||||
"""
|
||||
Validate that chunks don't exceed token limits and truncate if necessary.
|
||||
|
||||
Args:
|
||||
chunks: List of text chunks to validate
|
||||
max_tokens: Maximum tokens allowed per chunk
|
||||
|
||||
Returns:
|
||||
Tuple of (validated_chunks, num_truncated)
|
||||
"""
|
||||
validated_chunks = []
|
||||
num_truncated = 0
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
estimated_tokens = estimate_token_count(chunk)
|
||||
|
||||
if estimated_tokens > max_tokens:
|
||||
# Truncate chunk to fit token limit
|
||||
try:
|
||||
import tiktoken
|
||||
|
||||
encoder = tiktoken.get_encoding("cl100k_base")
|
||||
tokens = encoder.encode(chunk)
|
||||
if len(tokens) > max_tokens:
|
||||
truncated_tokens = tokens[:max_tokens]
|
||||
truncated_chunk = encoder.decode(truncated_tokens)
|
||||
validated_chunks.append(truncated_chunk)
|
||||
num_truncated += 1
|
||||
logger.warning(
|
||||
f"Truncated chunk {i} from {len(tokens)} to {max_tokens} tokens "
|
||||
f"(from {len(chunk)} to {len(truncated_chunk)} characters)"
|
||||
)
|
||||
else:
|
||||
validated_chunks.append(chunk)
|
||||
except ImportError:
|
||||
# Fallback: Conservative character truncation
|
||||
char_limit = int(max_tokens / 1.2) # Conservative for code
|
||||
if len(chunk) > char_limit:
|
||||
truncated_chunk = chunk[:char_limit]
|
||||
validated_chunks.append(truncated_chunk)
|
||||
num_truncated += 1
|
||||
logger.warning(
|
||||
f"Truncated chunk {i} from {len(chunk)} to {char_limit} characters "
|
||||
f"(conservative estimate for {max_tokens} tokens)"
|
||||
)
|
||||
else:
|
||||
validated_chunks.append(chunk)
|
||||
else:
|
||||
validated_chunks.append(chunk)
|
||||
|
||||
if num_truncated > 0:
|
||||
logger.warning(f"Truncated {num_truncated}/{len(chunks)} chunks to fit token limits")
|
||||
|
||||
return validated_chunks, num_truncated
|
||||
|
||||
|
||||
# Code file extensions supported by astchunk
|
||||
CODE_EXTENSIONS = {
|
||||
".py": "python",
|
||||
@@ -61,27 +177,45 @@ def create_ast_chunks(
|
||||
max_chunk_size: int = 512,
|
||||
chunk_overlap: int = 64,
|
||||
metadata_template: str = "default",
|
||||
) -> list[str]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Create AST-aware chunks from code documents using astchunk.
|
||||
|
||||
Falls back to traditional chunking if astchunk is unavailable.
|
||||
|
||||
Returns:
|
||||
List of dicts with {"text": str, "metadata": dict}
|
||||
"""
|
||||
try:
|
||||
from astchunk import ASTChunkBuilder # optional dependency
|
||||
except ImportError as e:
|
||||
logger.error(f"astchunk not available: {e}")
|
||||
logger.info("Falling back to traditional chunking for code files")
|
||||
return create_traditional_chunks(documents, max_chunk_size, chunk_overlap)
|
||||
return _traditional_chunks_as_dicts(documents, max_chunk_size, chunk_overlap)
|
||||
|
||||
all_chunks = []
|
||||
for doc in documents:
|
||||
language = doc.metadata.get("language")
|
||||
if not language:
|
||||
logger.warning("No language detected; falling back to traditional chunking")
|
||||
all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap))
|
||||
all_chunks.extend(_traditional_chunks_as_dicts([doc], max_chunk_size, chunk_overlap))
|
||||
continue
|
||||
|
||||
try:
|
||||
# Warn once if AST chunk size + overlap might exceed common token limits
|
||||
# Note: Actual truncation happens at embedding time with dynamic model limits
|
||||
global _ast_token_warning_shown
|
||||
estimated_max_tokens = int(
|
||||
(max_chunk_size + chunk_overlap) * 1.2
|
||||
) # Conservative estimate
|
||||
if estimated_max_tokens > 512 and not _ast_token_warning_shown:
|
||||
logger.warning(
|
||||
f"AST chunk size ({max_chunk_size}) + overlap ({chunk_overlap}) = {max_chunk_size + chunk_overlap} chars "
|
||||
f"may exceed 512 token limit (~{estimated_max_tokens} tokens estimated). "
|
||||
f"Consider reducing --ast-chunk-size to {int(400 / 1.2)} or --ast-chunk-overlap to {int(50 / 1.2)}. "
|
||||
f"Note: Chunks will be auto-truncated at embedding time based on your model's actual token limit."
|
||||
)
|
||||
_ast_token_warning_shown = True
|
||||
|
||||
configs = {
|
||||
"max_chunk_size": max_chunk_size,
|
||||
"language": language,
|
||||
@@ -105,17 +239,40 @@ def create_ast_chunks(
|
||||
|
||||
chunks = chunk_builder.chunkify(code_content)
|
||||
for chunk in chunks:
|
||||
chunk_text = None
|
||||
astchunk_metadata = {}
|
||||
|
||||
if hasattr(chunk, "text"):
|
||||
chunk_text = chunk.text
|
||||
elif isinstance(chunk, dict) and "text" in chunk:
|
||||
chunk_text = chunk["text"]
|
||||
elif isinstance(chunk, str):
|
||||
chunk_text = chunk
|
||||
elif isinstance(chunk, dict):
|
||||
# Handle astchunk format: {"content": "...", "metadata": {...}}
|
||||
if "content" in chunk:
|
||||
chunk_text = chunk["content"]
|
||||
astchunk_metadata = chunk.get("metadata", {})
|
||||
elif "text" in chunk:
|
||||
chunk_text = chunk["text"]
|
||||
else:
|
||||
chunk_text = str(chunk) # Last resort
|
||||
else:
|
||||
chunk_text = str(chunk)
|
||||
|
||||
if chunk_text and chunk_text.strip():
|
||||
all_chunks.append(chunk_text.strip())
|
||||
# Extract document-level metadata
|
||||
doc_metadata = {
|
||||
"file_path": doc.metadata.get("file_path", ""),
|
||||
"file_name": doc.metadata.get("file_name", ""),
|
||||
}
|
||||
if "creation_date" in doc.metadata:
|
||||
doc_metadata["creation_date"] = doc.metadata["creation_date"]
|
||||
if "last_modified_date" in doc.metadata:
|
||||
doc_metadata["last_modified_date"] = doc.metadata["last_modified_date"]
|
||||
|
||||
# Merge document metadata + astchunk metadata
|
||||
combined_metadata = {**doc_metadata, **astchunk_metadata}
|
||||
|
||||
all_chunks.append({"text": chunk_text.strip(), "metadata": combined_metadata})
|
||||
|
||||
logger.info(
|
||||
f"Created {len(chunks)} AST chunks from {language} file: {doc.metadata.get('file_name', 'unknown')}"
|
||||
@@ -123,15 +280,19 @@ def create_ast_chunks(
|
||||
except Exception as e:
|
||||
logger.warning(f"AST chunking failed for {language} file: {e}")
|
||||
logger.info("Falling back to traditional chunking")
|
||||
all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap))
|
||||
all_chunks.extend(_traditional_chunks_as_dicts([doc], max_chunk_size, chunk_overlap))
|
||||
|
||||
return all_chunks
|
||||
|
||||
|
||||
def create_traditional_chunks(
|
||||
documents, chunk_size: int = 256, chunk_overlap: int = 128
|
||||
) -> list[str]:
|
||||
"""Create traditional text chunks using LlamaIndex SentenceSplitter."""
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Create traditional text chunks using LlamaIndex SentenceSplitter.
|
||||
|
||||
Returns:
|
||||
List of dicts with {"text": str, "metadata": dict}
|
||||
"""
|
||||
if chunk_size <= 0:
|
||||
logger.warning(f"Invalid chunk_size={chunk_size}, using default value of 256")
|
||||
chunk_size = 256
|
||||
@@ -147,19 +308,40 @@ def create_traditional_chunks(
|
||||
paragraph_separator="\n\n",
|
||||
)
|
||||
|
||||
all_texts = []
|
||||
result = []
|
||||
for doc in documents:
|
||||
# Extract document-level metadata
|
||||
doc_metadata = {
|
||||
"file_path": doc.metadata.get("file_path", ""),
|
||||
"file_name": doc.metadata.get("file_name", ""),
|
||||
}
|
||||
if "creation_date" in doc.metadata:
|
||||
doc_metadata["creation_date"] = doc.metadata["creation_date"]
|
||||
if "last_modified_date" in doc.metadata:
|
||||
doc_metadata["last_modified_date"] = doc.metadata["last_modified_date"]
|
||||
|
||||
try:
|
||||
nodes = node_parser.get_nodes_from_documents([doc])
|
||||
if nodes:
|
||||
all_texts.extend(node.get_content() for node in nodes)
|
||||
for node in nodes:
|
||||
result.append({"text": node.get_content(), "metadata": doc_metadata})
|
||||
except Exception as e:
|
||||
logger.error(f"Traditional chunking failed for document: {e}")
|
||||
content = doc.get_content()
|
||||
if content and content.strip():
|
||||
all_texts.append(content.strip())
|
||||
result.append({"text": content.strip(), "metadata": doc_metadata})
|
||||
|
||||
return all_texts
|
||||
return result
|
||||
|
||||
|
||||
def _traditional_chunks_as_dicts(
|
||||
documents, chunk_size: int = 256, chunk_overlap: int = 128
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Helper: Traditional chunking that returns dict format for consistency.
|
||||
|
||||
This is now just an alias for create_traditional_chunks for backwards compatibility.
|
||||
"""
|
||||
return create_traditional_chunks(documents, chunk_size, chunk_overlap)
|
||||
|
||||
|
||||
def create_text_chunks(
|
||||
@@ -171,8 +353,12 @@ def create_text_chunks(
|
||||
ast_chunk_overlap: int = 64,
|
||||
code_file_extensions: Optional[list[str]] = None,
|
||||
ast_fallback_traditional: bool = True,
|
||||
) -> list[str]:
|
||||
"""Create text chunks from documents with optional AST support for code files."""
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Create text chunks from documents with optional AST support for code files.
|
||||
|
||||
Returns:
|
||||
List of dicts with {"text": str, "metadata": dict}
|
||||
"""
|
||||
if not documents:
|
||||
logger.warning("No documents provided for chunking")
|
||||
return []
|
||||
@@ -207,14 +393,17 @@ def create_text_chunks(
|
||||
logger.error(f"AST chunking failed: {e}")
|
||||
if ast_fallback_traditional:
|
||||
all_chunks.extend(
|
||||
create_traditional_chunks(code_docs, chunk_size, chunk_overlap)
|
||||
_traditional_chunks_as_dicts(code_docs, chunk_size, chunk_overlap)
|
||||
)
|
||||
else:
|
||||
raise
|
||||
if text_docs:
|
||||
all_chunks.extend(create_traditional_chunks(text_docs, chunk_size, chunk_overlap))
|
||||
all_chunks.extend(_traditional_chunks_as_dicts(text_docs, chunk_size, chunk_overlap))
|
||||
else:
|
||||
all_chunks = create_traditional_chunks(documents, chunk_size, chunk_overlap)
|
||||
all_chunks = _traditional_chunks_as_dicts(documents, chunk_size, chunk_overlap)
|
||||
|
||||
logger.info(f"Total chunks created: {len(all_chunks)}")
|
||||
|
||||
# Note: Token truncation is now handled at embedding time with dynamic model limits
|
||||
# See get_model_token_limit() and truncate_to_token_limit() in embedding_compute.py
|
||||
return all_chunks
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
@@ -8,7 +9,14 @@ from llama_index.core.node_parser import SentenceSplitter
|
||||
from tqdm import tqdm
|
||||
|
||||
from .api import LeannBuilder, LeannChat, LeannSearcher
|
||||
from .interactive_utils import create_cli_session
|
||||
from .registry import register_project_directory
|
||||
from .settings import (
|
||||
resolve_anthropic_base_url,
|
||||
resolve_ollama_host,
|
||||
resolve_openai_api_key,
|
||||
resolve_openai_base_url,
|
||||
)
|
||||
|
||||
|
||||
def extract_pdf_text_with_pymupdf(file_path: str) -> str:
|
||||
@@ -104,7 +112,7 @@ Examples:
|
||||
help="Documents directories and/or files (default: current directory)",
|
||||
)
|
||||
build_parser.add_argument(
|
||||
"--backend",
|
||||
"--backend-name",
|
||||
type=str,
|
||||
default="hnsw",
|
||||
choices=["hnsw", "diskann"],
|
||||
@@ -123,6 +131,36 @@ Examples:
|
||||
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
||||
help="Embedding backend mode (default: sentence-transformers)",
|
||||
)
|
||||
build_parser.add_argument(
|
||||
"--embedding-host",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Override Ollama-compatible embedding host",
|
||||
)
|
||||
build_parser.add_argument(
|
||||
"--embedding-api-base",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Base URL for OpenAI-compatible embedding services",
|
||||
)
|
||||
build_parser.add_argument(
|
||||
"--embedding-api-key",
|
||||
type=str,
|
||||
default=None,
|
||||
help="API key for embedding service (defaults to OPENAI_API_KEY)",
|
||||
)
|
||||
build_parser.add_argument(
|
||||
"--embedding-prompt-template",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Prompt template to prepend to all texts for embedding (e.g., 'query: ' for search)",
|
||||
)
|
||||
build_parser.add_argument(
|
||||
"--query-prompt-template",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Prompt template for queries (different from build template for task-specific models)",
|
||||
)
|
||||
build_parser.add_argument(
|
||||
"--force", "-f", action="store_true", help="Force rebuild existing index"
|
||||
)
|
||||
@@ -160,25 +198,25 @@ Examples:
|
||||
"--doc-chunk-size",
|
||||
type=int,
|
||||
default=256,
|
||||
help="Document chunk size in tokens/characters (default: 256)",
|
||||
help="Document chunk size in TOKENS (default: 256). Final chunks may be larger due to overlap. For 512 token models: recommended 350 tokens (350 + 128 overlap = 478 max)",
|
||||
)
|
||||
build_parser.add_argument(
|
||||
"--doc-chunk-overlap",
|
||||
type=int,
|
||||
default=128,
|
||||
help="Document chunk overlap (default: 128)",
|
||||
help="Document chunk overlap in TOKENS (default: 128). Added to chunk size, not included in it",
|
||||
)
|
||||
build_parser.add_argument(
|
||||
"--code-chunk-size",
|
||||
type=int,
|
||||
default=512,
|
||||
help="Code chunk size in tokens/lines (default: 512)",
|
||||
help="Code chunk size in TOKENS (default: 512). Final chunks may be larger due to overlap. For 512 token models: recommended 400 tokens (400 + 50 overlap = 450 max)",
|
||||
)
|
||||
build_parser.add_argument(
|
||||
"--code-chunk-overlap",
|
||||
type=int,
|
||||
default=50,
|
||||
help="Code chunk overlap (default: 50)",
|
||||
help="Code chunk overlap in TOKENS (default: 50). Added to chunk size, not included in it",
|
||||
)
|
||||
build_parser.add_argument(
|
||||
"--use-ast-chunking",
|
||||
@@ -188,14 +226,14 @@ Examples:
|
||||
build_parser.add_argument(
|
||||
"--ast-chunk-size",
|
||||
type=int,
|
||||
default=768,
|
||||
help="AST chunk size in characters (default: 768)",
|
||||
default=300,
|
||||
help="AST chunk size in CHARACTERS (non-whitespace) (default: 300). Final chunks may be larger due to overlap and expansion. For 512 token models: recommended 300 chars (300 + 64 overlap ~= 480 tokens)",
|
||||
)
|
||||
build_parser.add_argument(
|
||||
"--ast-chunk-overlap",
|
||||
type=int,
|
||||
default=96,
|
||||
help="AST chunk overlap in characters (default: 96)",
|
||||
default=64,
|
||||
help="AST chunk overlap in CHARACTERS (default: 64). Added to chunk size, not included in it. ~1.2 tokens per character for code",
|
||||
)
|
||||
build_parser.add_argument(
|
||||
"--ast-fallback-traditional",
|
||||
@@ -234,21 +272,42 @@ Examples:
|
||||
action="store_true",
|
||||
help="Non-interactive mode: automatically select index without prompting",
|
||||
)
|
||||
search_parser.add_argument(
|
||||
"--show-metadata",
|
||||
action="store_true",
|
||||
help="Display file paths and metadata in search results",
|
||||
)
|
||||
search_parser.add_argument(
|
||||
"--embedding-prompt-template",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Prompt template to prepend to query for embedding (e.g., 'query: ' for search)",
|
||||
)
|
||||
|
||||
# Ask command
|
||||
ask_parser = subparsers.add_parser("ask", help="Ask questions")
|
||||
ask_parser.add_argument("index_name", help="Index name")
|
||||
ask_parser.add_argument(
|
||||
"query",
|
||||
nargs="?",
|
||||
help="Question to ask (omit for prompt or when using --interactive)",
|
||||
)
|
||||
ask_parser.add_argument(
|
||||
"--llm",
|
||||
type=str,
|
||||
default="ollama",
|
||||
choices=["simulated", "ollama", "hf", "openai"],
|
||||
choices=["simulated", "ollama", "hf", "openai", "anthropic"],
|
||||
help="LLM provider (default: ollama)",
|
||||
)
|
||||
ask_parser.add_argument(
|
||||
"--model", type=str, default="qwen3:8b", help="Model name (default: qwen3:8b)"
|
||||
)
|
||||
ask_parser.add_argument("--host", type=str, default="http://localhost:11434")
|
||||
ask_parser.add_argument(
|
||||
"--host",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Override Ollama-compatible host (defaults to LEANN_OLLAMA_HOST/OLLAMA_HOST)",
|
||||
)
|
||||
ask_parser.add_argument(
|
||||
"--interactive", "-i", action="store_true", help="Interactive chat mode"
|
||||
)
|
||||
@@ -277,6 +336,18 @@ Examples:
|
||||
default=None,
|
||||
help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.",
|
||||
)
|
||||
ask_parser.add_argument(
|
||||
"--api-base",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Base URL for OpenAI-compatible APIs (e.g., http://localhost:10000/v1)",
|
||||
)
|
||||
ask_parser.add_argument(
|
||||
"--api-key",
|
||||
type=str,
|
||||
default=None,
|
||||
help="API key for cloud LLM providers (OpenAI, Anthropic)",
|
||||
)
|
||||
|
||||
# List command
|
||||
subparsers.add_parser("list", help="List all indexes")
|
||||
@@ -1114,6 +1185,11 @@ Examples:
|
||||
print(f"Warning: Could not process {file_path}: {e}")
|
||||
|
||||
# Load other file types with default reader
|
||||
# Exclude PDFs from code_extensions if they were already processed separately
|
||||
other_file_extensions = code_extensions
|
||||
if should_process_pdfs and ".pdf" in code_extensions:
|
||||
other_file_extensions = [ext for ext in code_extensions if ext != ".pdf"]
|
||||
|
||||
try:
|
||||
# Create a custom file filter function using our PathSpec
|
||||
def file_filter(
|
||||
@@ -1129,21 +1205,26 @@ Examples:
|
||||
except (ValueError, OSError):
|
||||
return True # Include files that can't be processed
|
||||
|
||||
other_docs = SimpleDirectoryReader(
|
||||
docs_dir,
|
||||
recursive=True,
|
||||
encoding="utf-8",
|
||||
required_exts=code_extensions,
|
||||
file_extractor={}, # Use default extractors
|
||||
exclude_hidden=not include_hidden,
|
||||
filename_as_id=True,
|
||||
).load_data(show_progress=True)
|
||||
# Only load other file types if there are extensions to process
|
||||
if other_file_extensions:
|
||||
other_docs = SimpleDirectoryReader(
|
||||
docs_dir,
|
||||
recursive=True,
|
||||
encoding="utf-8",
|
||||
required_exts=other_file_extensions,
|
||||
file_extractor={}, # Use default extractors
|
||||
exclude_hidden=not include_hidden,
|
||||
filename_as_id=True,
|
||||
).load_data(show_progress=True)
|
||||
else:
|
||||
other_docs = []
|
||||
|
||||
# Filter documents after loading based on gitignore rules
|
||||
filtered_docs = []
|
||||
for doc in other_docs:
|
||||
file_path = doc.metadata.get("file_path", "")
|
||||
if file_filter(file_path):
|
||||
doc.metadata["source"] = file_path
|
||||
filtered_docs.append(doc)
|
||||
|
||||
documents.extend(filtered_docs)
|
||||
@@ -1219,7 +1300,7 @@ Examples:
|
||||
from .chunking_utils import create_text_chunks
|
||||
|
||||
# Use enhanced chunking with AST support
|
||||
all_texts = create_text_chunks(
|
||||
chunk_texts = create_text_chunks(
|
||||
documents,
|
||||
chunk_size=self.node_parser.chunk_size,
|
||||
chunk_overlap=self.node_parser.chunk_overlap,
|
||||
@@ -1230,6 +1311,9 @@ Examples:
|
||||
ast_fallback_traditional=getattr(args, "ast_fallback_traditional", True),
|
||||
)
|
||||
|
||||
# create_text_chunks now returns list[dict] with metadata preserved
|
||||
all_texts.extend(chunk_texts)
|
||||
|
||||
except ImportError as e:
|
||||
print(
|
||||
f"⚠️ AST chunking utilities not available in package ({e}), falling back to traditional chunking"
|
||||
@@ -1241,14 +1325,27 @@ Examples:
|
||||
for doc in tqdm(documents, desc="Chunking documents", unit="doc"):
|
||||
# Check if this is a code file based on source path
|
||||
source_path = doc.metadata.get("source", "")
|
||||
file_path = doc.metadata.get("file_path", "")
|
||||
is_code_file = any(source_path.endswith(ext) for ext in code_file_exts)
|
||||
|
||||
# Extract metadata to preserve with chunks
|
||||
chunk_metadata = {
|
||||
"file_path": file_path or source_path,
|
||||
"file_name": doc.metadata.get("file_name", ""),
|
||||
}
|
||||
|
||||
# Add optional metadata if available
|
||||
if "creation_date" in doc.metadata:
|
||||
chunk_metadata["creation_date"] = doc.metadata["creation_date"]
|
||||
if "last_modified_date" in doc.metadata:
|
||||
chunk_metadata["last_modified_date"] = doc.metadata["last_modified_date"]
|
||||
|
||||
# Use appropriate parser based on file type
|
||||
parser = self.code_parser if is_code_file else self.node_parser
|
||||
nodes = parser.get_nodes_from_documents([doc])
|
||||
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
all_texts.append({"text": node.get_content(), "metadata": chunk_metadata})
|
||||
|
||||
print(f"Loaded {len(documents)} documents, {len(all_texts)} chunks")
|
||||
return all_texts
|
||||
@@ -1323,12 +1420,30 @@ Examples:
|
||||
|
||||
index_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"Building index '{index_name}' with {args.backend} backend...")
|
||||
print(f"Building index '{index_name}' with {args.backend_name} backend...")
|
||||
|
||||
embedding_options: dict[str, Any] = {}
|
||||
if args.embedding_mode == "ollama":
|
||||
embedding_options["host"] = resolve_ollama_host(args.embedding_host)
|
||||
elif args.embedding_mode == "openai":
|
||||
embedding_options["base_url"] = resolve_openai_base_url(args.embedding_api_base)
|
||||
resolved_embedding_key = resolve_openai_api_key(args.embedding_api_key)
|
||||
if resolved_embedding_key:
|
||||
embedding_options["api_key"] = resolved_embedding_key
|
||||
if args.query_prompt_template:
|
||||
# New format: separate templates
|
||||
if args.embedding_prompt_template:
|
||||
embedding_options["build_prompt_template"] = args.embedding_prompt_template
|
||||
embedding_options["query_prompt_template"] = args.query_prompt_template
|
||||
elif args.embedding_prompt_template:
|
||||
# Old format: single template (backward compat)
|
||||
embedding_options["prompt_template"] = args.embedding_prompt_template
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name=args.backend,
|
||||
backend_name=args.backend_name,
|
||||
embedding_model=args.embedding_model,
|
||||
embedding_mode=args.embedding_mode,
|
||||
embedding_options=embedding_options or None,
|
||||
graph_degree=args.graph_degree,
|
||||
complexity=args.complexity,
|
||||
is_compact=args.compact,
|
||||
@@ -1336,8 +1451,8 @@ Examples:
|
||||
num_threads=args.num_threads,
|
||||
)
|
||||
|
||||
for chunk_text in all_texts:
|
||||
builder.add_text(chunk_text)
|
||||
for chunk in all_texts:
|
||||
builder.add_text(chunk["text"], metadata=chunk["metadata"])
|
||||
|
||||
builder.build_index(index_path)
|
||||
print(f"Index built at {index_path}")
|
||||
@@ -1444,6 +1559,11 @@ Examples:
|
||||
print("Invalid input. Aborting search.")
|
||||
return
|
||||
|
||||
# Build provider_options for runtime override
|
||||
provider_options = {}
|
||||
if args.embedding_prompt_template:
|
||||
provider_options["prompt_template"] = args.embedding_prompt_template
|
||||
|
||||
searcher = LeannSearcher(index_path=index_path)
|
||||
results = searcher.search(
|
||||
query,
|
||||
@@ -1453,12 +1573,31 @@ Examples:
|
||||
prune_ratio=args.prune_ratio,
|
||||
recompute_embeddings=args.recompute_embeddings,
|
||||
pruning_strategy=args.pruning_strategy,
|
||||
provider_options=provider_options if provider_options else None,
|
||||
)
|
||||
|
||||
print(f"Search results for '{query}' (top {len(results)}):")
|
||||
for i, result in enumerate(results, 1):
|
||||
print(f"{i}. Score: {result.score:.3f}")
|
||||
|
||||
# Display metadata if flag is set
|
||||
if args.show_metadata and result.metadata:
|
||||
file_path = result.metadata.get("file_path", "")
|
||||
if file_path:
|
||||
print(f" 📄 File: {file_path}")
|
||||
|
||||
file_name = result.metadata.get("file_name", "")
|
||||
if file_name and file_name != file_path:
|
||||
print(f" 📝 Name: {file_name}")
|
||||
|
||||
# Show timestamps if available
|
||||
if "creation_date" in result.metadata:
|
||||
print(f" 🕐 Created: {result.metadata['creation_date']}")
|
||||
if "last_modified_date" in result.metadata:
|
||||
print(f" 🕑 Modified: {result.metadata['last_modified_date']}")
|
||||
|
||||
print(f" {result.text[:200]}...")
|
||||
print(f" Source: {result.metadata.get('source', '')}")
|
||||
print()
|
||||
|
||||
async def ask_questions(self, args):
|
||||
@@ -1476,58 +1615,58 @@ Examples:
|
||||
|
||||
llm_config = {"type": args.llm, "model": args.model}
|
||||
if args.llm == "ollama":
|
||||
llm_config["host"] = args.host
|
||||
llm_config["host"] = resolve_ollama_host(args.host)
|
||||
elif args.llm == "openai":
|
||||
llm_config["base_url"] = resolve_openai_base_url(args.api_base)
|
||||
resolved_api_key = resolve_openai_api_key(args.api_key)
|
||||
if resolved_api_key:
|
||||
llm_config["api_key"] = resolved_api_key
|
||||
elif args.llm == "anthropic":
|
||||
# For Anthropic, pass base_url and API key if provided
|
||||
if args.api_base:
|
||||
llm_config["base_url"] = resolve_anthropic_base_url(args.api_base)
|
||||
if args.api_key:
|
||||
llm_config["api_key"] = args.api_key
|
||||
|
||||
chat = LeannChat(index_path=index_path, llm_config=llm_config)
|
||||
|
||||
llm_kwargs: dict[str, Any] = {}
|
||||
if args.thinking_budget:
|
||||
llm_kwargs["thinking_budget"] = args.thinking_budget
|
||||
|
||||
def _ask_once(prompt: str) -> None:
|
||||
query_start_time = time.time()
|
||||
response = chat.ask(
|
||||
prompt,
|
||||
top_k=args.top_k,
|
||||
complexity=args.complexity,
|
||||
beam_width=args.beam_width,
|
||||
prune_ratio=args.prune_ratio,
|
||||
recompute_embeddings=args.recompute_embeddings,
|
||||
pruning_strategy=args.pruning_strategy,
|
||||
llm_kwargs=llm_kwargs,
|
||||
)
|
||||
query_completion_time = time.time() - query_start_time
|
||||
print(f"LEANN: {response}")
|
||||
print(f"The query took {query_completion_time:.3f} seconds to finish")
|
||||
|
||||
initial_query = (args.query or "").strip()
|
||||
|
||||
if args.interactive:
|
||||
print("LEANN Assistant ready! Type 'quit' to exit")
|
||||
print("=" * 40)
|
||||
# Create interactive session
|
||||
session = create_cli_session(index_name)
|
||||
|
||||
while True:
|
||||
user_input = input("\nYou: ").strip()
|
||||
if user_input.lower() in ["quit", "exit", "q"]:
|
||||
print("Goodbye!")
|
||||
break
|
||||
if initial_query:
|
||||
_ask_once(initial_query)
|
||||
|
||||
if not user_input:
|
||||
continue
|
||||
|
||||
# Prepare LLM kwargs with thinking budget if specified
|
||||
llm_kwargs = {}
|
||||
if args.thinking_budget:
|
||||
llm_kwargs["thinking_budget"] = args.thinking_budget
|
||||
|
||||
response = chat.ask(
|
||||
user_input,
|
||||
top_k=args.top_k,
|
||||
complexity=args.complexity,
|
||||
beam_width=args.beam_width,
|
||||
prune_ratio=args.prune_ratio,
|
||||
recompute_embeddings=args.recompute_embeddings,
|
||||
pruning_strategy=args.pruning_strategy,
|
||||
llm_kwargs=llm_kwargs,
|
||||
)
|
||||
print(f"LEANN: {response}")
|
||||
session.run_interactive_loop(_ask_once)
|
||||
else:
|
||||
query = input("Enter your question: ").strip()
|
||||
if query:
|
||||
# Prepare LLM kwargs with thinking budget if specified
|
||||
llm_kwargs = {}
|
||||
if args.thinking_budget:
|
||||
llm_kwargs["thinking_budget"] = args.thinking_budget
|
||||
query = initial_query or input("Enter your question: ").strip()
|
||||
if not query:
|
||||
print("No question provided. Exiting.")
|
||||
return
|
||||
|
||||
response = chat.ask(
|
||||
query,
|
||||
top_k=args.top_k,
|
||||
complexity=args.complexity,
|
||||
beam_width=args.beam_width,
|
||||
prune_ratio=args.prune_ratio,
|
||||
recompute_embeddings=args.recompute_embeddings,
|
||||
pruning_strategy=args.pruning_strategy,
|
||||
llm_kwargs=llm_kwargs,
|
||||
)
|
||||
print(f"LEANN: {response}")
|
||||
_ask_once(query)
|
||||
|
||||
async def run(self, args=None):
|
||||
parser = self.create_parser()
|
||||
|
||||
@@ -4,20 +4,307 @@ Consolidates all embedding computation logic using SentenceTransformer
|
||||
Preserves all optimization parameters to ensure performance
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
import tiktoken
|
||||
import torch
|
||||
|
||||
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
|
||||
|
||||
# Set up logger with proper level
|
||||
logger = logging.getLogger(__name__)
|
||||
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
|
||||
logger.setLevel(log_level)
|
||||
|
||||
# Token limit registry for embedding models
|
||||
# Used as fallback when dynamic discovery fails (e.g., LM Studio, OpenAI)
|
||||
# Ollama models use dynamic discovery via /api/show
|
||||
EMBEDDING_MODEL_LIMITS = {
|
||||
# Nomic models (common across servers)
|
||||
"nomic-embed-text": 2048, # Corrected from 512 - verified via /api/show
|
||||
"nomic-embed-text-v1.5": 2048,
|
||||
"nomic-embed-text-v2": 512,
|
||||
# Other embedding models
|
||||
"mxbai-embed-large": 512,
|
||||
"all-minilm": 512,
|
||||
"bge-m3": 8192,
|
||||
"snowflake-arctic-embed": 512,
|
||||
# OpenAI models
|
||||
"text-embedding-3-small": 8192,
|
||||
"text-embedding-3-large": 8192,
|
||||
"text-embedding-ada-002": 8192,
|
||||
}
|
||||
|
||||
# Runtime cache for dynamically discovered token limits
|
||||
# Key: (model_name, base_url), Value: token_limit
|
||||
# Prevents repeated SDK/API calls for the same model
|
||||
_token_limit_cache: dict[tuple[str, str], int] = {}
|
||||
|
||||
|
||||
def get_model_token_limit(
|
||||
model_name: str,
|
||||
base_url: Optional[str] = None,
|
||||
default: int = 2048,
|
||||
) -> int:
|
||||
"""
|
||||
Get token limit for a given embedding model.
|
||||
Uses hybrid approach: dynamic discovery for Ollama, registry fallback for others.
|
||||
Caches discovered limits to prevent repeated API/SDK calls.
|
||||
|
||||
Args:
|
||||
model_name: Name of the embedding model
|
||||
base_url: Base URL of the embedding server (for dynamic discovery)
|
||||
default: Default token limit if model not found
|
||||
|
||||
Returns:
|
||||
Token limit for the model in tokens
|
||||
"""
|
||||
# Check cache first to avoid repeated SDK/API calls
|
||||
cache_key = (model_name, base_url or "")
|
||||
if cache_key in _token_limit_cache:
|
||||
cached_limit = _token_limit_cache[cache_key]
|
||||
logger.debug(f"Using cached token limit for {model_name}: {cached_limit}")
|
||||
return cached_limit
|
||||
|
||||
# Try Ollama dynamic discovery if base_url provided
|
||||
if base_url:
|
||||
# Detect Ollama servers by port or "ollama" in URL
|
||||
if "11434" in base_url or "ollama" in base_url.lower():
|
||||
limit = _query_ollama_context_limit(model_name, base_url)
|
||||
if limit:
|
||||
_token_limit_cache[cache_key] = limit
|
||||
return limit
|
||||
|
||||
# Try LM Studio SDK discovery
|
||||
if "1234" in base_url or "lmstudio" in base_url.lower() or "lm.studio" in base_url.lower():
|
||||
# Convert HTTP to WebSocket URL
|
||||
ws_url = base_url.replace("https://", "wss://").replace("http://", "ws://")
|
||||
# Remove /v1 suffix if present
|
||||
if ws_url.endswith("/v1"):
|
||||
ws_url = ws_url[:-3]
|
||||
|
||||
limit = _query_lmstudio_context_limit(model_name, ws_url)
|
||||
if limit:
|
||||
_token_limit_cache[cache_key] = limit
|
||||
return limit
|
||||
|
||||
# Fallback to known model registry with version handling (from PR #154)
|
||||
# Handle versioned model names (e.g., "nomic-embed-text:latest" -> "nomic-embed-text")
|
||||
base_model_name = model_name.split(":")[0]
|
||||
|
||||
# Check exact match first
|
||||
if model_name in EMBEDDING_MODEL_LIMITS:
|
||||
limit = EMBEDDING_MODEL_LIMITS[model_name]
|
||||
_token_limit_cache[cache_key] = limit
|
||||
return limit
|
||||
|
||||
# Check base name match
|
||||
if base_model_name in EMBEDDING_MODEL_LIMITS:
|
||||
limit = EMBEDDING_MODEL_LIMITS[base_model_name]
|
||||
_token_limit_cache[cache_key] = limit
|
||||
return limit
|
||||
|
||||
# Check partial matches for common patterns
|
||||
for known_model, registry_limit in EMBEDDING_MODEL_LIMITS.items():
|
||||
if known_model in base_model_name or base_model_name in known_model:
|
||||
_token_limit_cache[cache_key] = registry_limit
|
||||
return registry_limit
|
||||
|
||||
# Default fallback
|
||||
logger.warning(f"Unknown model '{model_name}', using default {default} token limit")
|
||||
_token_limit_cache[cache_key] = default
|
||||
return default
|
||||
|
||||
|
||||
def truncate_to_token_limit(texts: list[str], token_limit: int) -> list[str]:
|
||||
"""
|
||||
Truncate texts to fit within token limit using tiktoken.
|
||||
|
||||
Args:
|
||||
texts: List of text strings to truncate
|
||||
token_limit: Maximum number of tokens allowed
|
||||
|
||||
Returns:
|
||||
List of truncated texts (same length as input)
|
||||
"""
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
# Use tiktoken with cl100k_base encoding
|
||||
enc = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
truncated_texts = []
|
||||
truncation_count = 0
|
||||
total_tokens_removed = 0
|
||||
max_original_length = 0
|
||||
|
||||
for i, text in enumerate(texts):
|
||||
tokens = enc.encode(text)
|
||||
original_length = len(tokens)
|
||||
|
||||
if original_length <= token_limit:
|
||||
# Text is within limit, keep as is
|
||||
truncated_texts.append(text)
|
||||
else:
|
||||
# Truncate to token_limit
|
||||
truncated_tokens = tokens[:token_limit]
|
||||
truncated_text = enc.decode(truncated_tokens)
|
||||
truncated_texts.append(truncated_text)
|
||||
|
||||
# Track truncation statistics
|
||||
truncation_count += 1
|
||||
tokens_removed = original_length - token_limit
|
||||
total_tokens_removed += tokens_removed
|
||||
max_original_length = max(max_original_length, original_length)
|
||||
|
||||
# Log individual truncation at WARNING level (first few only)
|
||||
if truncation_count <= 3:
|
||||
logger.warning(
|
||||
f"Text {i + 1} truncated: {original_length} → {token_limit} tokens "
|
||||
f"({tokens_removed} tokens removed)"
|
||||
)
|
||||
elif truncation_count == 4:
|
||||
logger.warning("Further truncation warnings suppressed...")
|
||||
|
||||
# Log summary at INFO level
|
||||
if truncation_count > 0:
|
||||
logger.warning(
|
||||
f"Truncation summary: {truncation_count}/{len(texts)} texts truncated "
|
||||
f"(removed {total_tokens_removed} tokens total, longest was {max_original_length} tokens)"
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
f"No truncation needed - all {len(texts)} texts within {token_limit} token limit"
|
||||
)
|
||||
|
||||
return truncated_texts
|
||||
|
||||
|
||||
def _query_ollama_context_limit(model_name: str, base_url: str) -> Optional[int]:
|
||||
"""
|
||||
Query Ollama /api/show for model context limit.
|
||||
|
||||
Args:
|
||||
model_name: Name of the Ollama model
|
||||
base_url: Base URL of the Ollama server
|
||||
|
||||
Returns:
|
||||
Context limit in tokens if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
import requests
|
||||
|
||||
response = requests.post(
|
||||
f"{base_url}/api/show",
|
||||
json={"name": model_name},
|
||||
timeout=5,
|
||||
)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
if "model_info" in data:
|
||||
# Look for *.context_length in model_info
|
||||
for key, value in data["model_info"].items():
|
||||
if "context_length" in key and isinstance(value, int):
|
||||
logger.info(f"Detected {model_name} context limit: {value} tokens")
|
||||
return value
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to query Ollama context limit: {e}")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _query_lmstudio_context_limit(model_name: str, base_url: str) -> Optional[int]:
|
||||
"""
|
||||
Query LM Studio SDK for model context length via Node.js subprocess.
|
||||
|
||||
Args:
|
||||
model_name: Name of the LM Studio model
|
||||
base_url: Base URL of the LM Studio server (WebSocket format, e.g., "ws://localhost:1234")
|
||||
|
||||
Returns:
|
||||
Context limit in tokens if found, None otherwise
|
||||
"""
|
||||
# Inline JavaScript using @lmstudio/sdk
|
||||
# Note: Load model temporarily for metadata, then unload to respect JIT auto-evict
|
||||
js_code = f"""
|
||||
const {{ LMStudioClient }} = require('@lmstudio/sdk');
|
||||
(async () => {{
|
||||
try {{
|
||||
const client = new LMStudioClient({{ baseUrl: '{base_url}' }});
|
||||
const model = await client.embedding.load('{model_name}', {{ verbose: false }});
|
||||
const contextLength = await model.getContextLength();
|
||||
await model.unload(); // Unload immediately to respect JIT auto-evict settings
|
||||
console.log(JSON.stringify({{ contextLength, identifier: '{model_name}' }}));
|
||||
}} catch (error) {{
|
||||
console.error(JSON.stringify({{ error: error.message }}));
|
||||
process.exit(1);
|
||||
}}
|
||||
}})();
|
||||
"""
|
||||
|
||||
try:
|
||||
# Set NODE_PATH to include global modules for @lmstudio/sdk resolution
|
||||
env = os.environ.copy()
|
||||
|
||||
# Try to get npm global root (works with nvm, brew node, etc.)
|
||||
try:
|
||||
npm_root = subprocess.run(
|
||||
["npm", "root", "-g"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
)
|
||||
if npm_root.returncode == 0:
|
||||
global_modules = npm_root.stdout.strip()
|
||||
# Append to existing NODE_PATH if present
|
||||
existing_node_path = env.get("NODE_PATH", "")
|
||||
env["NODE_PATH"] = (
|
||||
f"{global_modules}:{existing_node_path}"
|
||||
if existing_node_path
|
||||
else global_modules
|
||||
)
|
||||
except Exception:
|
||||
# If npm not available, continue with existing NODE_PATH
|
||||
pass
|
||||
|
||||
result = subprocess.run(
|
||||
["node", "-e", js_code],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
env=env,
|
||||
)
|
||||
|
||||
if result.returncode != 0:
|
||||
logger.debug(f"LM Studio SDK error: {result.stderr}")
|
||||
return None
|
||||
|
||||
data = json.loads(result.stdout)
|
||||
context_length = data.get("contextLength")
|
||||
|
||||
if context_length and context_length > 0:
|
||||
logger.info(f"LM Studio SDK detected {model_name} context length: {context_length}")
|
||||
return context_length
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.debug("Node.js not found - install Node.js for LM Studio SDK features")
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.debug("LM Studio SDK query timeout")
|
||||
except json.JSONDecodeError:
|
||||
logger.debug("LM Studio SDK returned invalid JSON")
|
||||
except Exception as e:
|
||||
logger.debug(f"LM Studio SDK query failed: {e}")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# Global model cache to avoid repeated loading
|
||||
_model_cache: dict[str, Any] = {}
|
||||
|
||||
@@ -31,6 +318,7 @@ def compute_embeddings(
|
||||
adaptive_optimization: bool = True,
|
||||
manual_tokenize: bool = False,
|
||||
max_length: int = 512,
|
||||
provider_options: Optional[dict[str, Any]] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Unified embedding computation entry point
|
||||
@@ -46,6 +334,8 @@ def compute_embeddings(
|
||||
Returns:
|
||||
Normalized embeddings array, shape: (len(texts), embedding_dim)
|
||||
"""
|
||||
provider_options = provider_options or {}
|
||||
|
||||
if mode == "sentence-transformers":
|
||||
return compute_embeddings_sentence_transformers(
|
||||
texts,
|
||||
@@ -57,11 +347,23 @@ def compute_embeddings(
|
||||
max_length=max_length,
|
||||
)
|
||||
elif mode == "openai":
|
||||
return compute_embeddings_openai(texts, model_name)
|
||||
return compute_embeddings_openai(
|
||||
texts,
|
||||
model_name,
|
||||
base_url=provider_options.get("base_url"),
|
||||
api_key=provider_options.get("api_key"),
|
||||
provider_options=provider_options,
|
||||
)
|
||||
elif mode == "mlx":
|
||||
return compute_embeddings_mlx(texts, model_name)
|
||||
elif mode == "ollama":
|
||||
return compute_embeddings_ollama(texts, model_name, is_build=is_build)
|
||||
return compute_embeddings_ollama(
|
||||
texts,
|
||||
model_name,
|
||||
is_build=is_build,
|
||||
host=provider_options.get("host"),
|
||||
provider_options=provider_options,
|
||||
)
|
||||
elif mode == "gemini":
|
||||
return compute_embeddings_gemini(texts, model_name, is_build=is_build)
|
||||
else:
|
||||
@@ -168,32 +470,73 @@ def compute_embeddings_sentence_transformers(
|
||||
}
|
||||
|
||||
try:
|
||||
# Try local loading first
|
||||
model_kwargs["local_files_only"] = True
|
||||
tokenizer_kwargs["local_files_only"] = True
|
||||
# Try loading with advanced parameters first (newer versions)
|
||||
local_model_kwargs = model_kwargs.copy()
|
||||
local_tokenizer_kwargs = tokenizer_kwargs.copy()
|
||||
local_model_kwargs["local_files_only"] = True
|
||||
local_tokenizer_kwargs["local_files_only"] = True
|
||||
|
||||
model = SentenceTransformer(
|
||||
model_name,
|
||||
device=device,
|
||||
model_kwargs=model_kwargs,
|
||||
tokenizer_kwargs=tokenizer_kwargs,
|
||||
model_kwargs=local_model_kwargs,
|
||||
tokenizer_kwargs=local_tokenizer_kwargs,
|
||||
local_files_only=True,
|
||||
)
|
||||
logger.info("Model loaded successfully! (local + optimized)")
|
||||
except TypeError as e:
|
||||
if "model_kwargs" in str(e) or "tokenizer_kwargs" in str(e):
|
||||
logger.warning(
|
||||
f"Advanced parameters not supported ({e}), using basic initialization..."
|
||||
)
|
||||
# Fallback to basic initialization for older versions
|
||||
try:
|
||||
model = SentenceTransformer(
|
||||
model_name,
|
||||
device=device,
|
||||
local_files_only=True,
|
||||
)
|
||||
logger.info("Model loaded successfully! (local + basic)")
|
||||
except Exception as e2:
|
||||
logger.warning(f"Local loading failed ({e2}), trying network download...")
|
||||
model = SentenceTransformer(
|
||||
model_name,
|
||||
device=device,
|
||||
local_files_only=False,
|
||||
)
|
||||
logger.info("Model loaded successfully! (network + basic)")
|
||||
else:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(f"Local loading failed ({e}), trying network download...")
|
||||
# Fallback to network loading
|
||||
model_kwargs["local_files_only"] = False
|
||||
tokenizer_kwargs["local_files_only"] = False
|
||||
# Fallback to network loading with advanced parameters
|
||||
try:
|
||||
network_model_kwargs = model_kwargs.copy()
|
||||
network_tokenizer_kwargs = tokenizer_kwargs.copy()
|
||||
network_model_kwargs["local_files_only"] = False
|
||||
network_tokenizer_kwargs["local_files_only"] = False
|
||||
|
||||
model = SentenceTransformer(
|
||||
model_name,
|
||||
device=device,
|
||||
model_kwargs=model_kwargs,
|
||||
tokenizer_kwargs=tokenizer_kwargs,
|
||||
local_files_only=False,
|
||||
)
|
||||
logger.info("Model loaded successfully! (network + optimized)")
|
||||
model = SentenceTransformer(
|
||||
model_name,
|
||||
device=device,
|
||||
model_kwargs=network_model_kwargs,
|
||||
tokenizer_kwargs=network_tokenizer_kwargs,
|
||||
local_files_only=False,
|
||||
)
|
||||
logger.info("Model loaded successfully! (network + optimized)")
|
||||
except TypeError as e2:
|
||||
if "model_kwargs" in str(e2) or "tokenizer_kwargs" in str(e2):
|
||||
logger.warning(
|
||||
f"Advanced parameters not supported ({e2}), using basic network loading..."
|
||||
)
|
||||
model = SentenceTransformer(
|
||||
model_name,
|
||||
device=device,
|
||||
local_files_only=False,
|
||||
)
|
||||
logger.info("Model loaded successfully! (network + basic)")
|
||||
else:
|
||||
raise
|
||||
|
||||
# Apply additional optimizations based on mode
|
||||
if use_fp16 and device in ["cuda", "mps"]:
|
||||
@@ -353,12 +696,16 @@ def compute_embeddings_sentence_transformers(
|
||||
return embeddings
|
||||
|
||||
|
||||
def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray:
|
||||
def compute_embeddings_openai(
|
||||
texts: list[str],
|
||||
model_name: str,
|
||||
base_url: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
provider_options: Optional[dict[str, Any]] = None,
|
||||
) -> np.ndarray:
|
||||
# TODO: @yichuan-w add progress bar only in build mode
|
||||
"""Compute embeddings using OpenAI API"""
|
||||
try:
|
||||
import os
|
||||
|
||||
import openai
|
||||
except ImportError as e:
|
||||
raise ImportError(f"OpenAI package not installed: {e}")
|
||||
@@ -373,24 +720,40 @@ def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray:
|
||||
f"Found {invalid_count} empty/invalid text(s) in input. Upstream should filter before calling OpenAI."
|
||||
)
|
||||
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
if not api_key:
|
||||
# Extract base_url and api_key from provider_options if not provided directly
|
||||
provider_options = provider_options or {}
|
||||
effective_base_url = base_url or provider_options.get("base_url")
|
||||
effective_api_key = api_key or provider_options.get("api_key")
|
||||
|
||||
resolved_base_url = resolve_openai_base_url(effective_base_url)
|
||||
resolved_api_key = resolve_openai_api_key(effective_api_key)
|
||||
|
||||
if not resolved_api_key:
|
||||
raise RuntimeError("OPENAI_API_KEY environment variable not set")
|
||||
|
||||
# Cache OpenAI client
|
||||
cache_key = "openai_client"
|
||||
if cache_key in _model_cache:
|
||||
client = _model_cache[cache_key]
|
||||
else:
|
||||
client = openai.OpenAI(api_key=api_key)
|
||||
_model_cache[cache_key] = client
|
||||
logger.info("OpenAI client cached")
|
||||
# Create OpenAI client
|
||||
client = openai.OpenAI(api_key=resolved_api_key, base_url=resolved_base_url)
|
||||
|
||||
logger.info(
|
||||
f"Computing embeddings for {len(texts)} texts using OpenAI API, model: '{model_name}'"
|
||||
)
|
||||
print(f"len of texts: {len(texts)}")
|
||||
|
||||
# Apply prompt template if provided
|
||||
# Priority: build_prompt_template (new format) > prompt_template (old format)
|
||||
prompt_template = provider_options.get("build_prompt_template") or provider_options.get(
|
||||
"prompt_template"
|
||||
)
|
||||
|
||||
if prompt_template:
|
||||
logger.warning(f"Applying prompt template: '{prompt_template}'")
|
||||
texts = [f"{prompt_template}{text}" for text in texts]
|
||||
|
||||
# Query token limit and apply truncation
|
||||
token_limit = get_model_token_limit(model_name, base_url=effective_base_url)
|
||||
logger.info(f"Using token limit: {token_limit} for model '{model_name}'")
|
||||
texts = truncate_to_token_limit(texts, token_limit)
|
||||
|
||||
# OpenAI has limits on batch size and input length
|
||||
max_batch_size = 800 # Conservative batch size because the token limit is 300K
|
||||
all_embeddings = []
|
||||
@@ -421,7 +784,15 @@ def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray:
|
||||
try:
|
||||
response = client.embeddings.create(model=model_name, input=batch_texts)
|
||||
batch_embeddings = [embedding.embedding for embedding in response.data]
|
||||
all_embeddings.extend(batch_embeddings)
|
||||
|
||||
# Verify we got the expected number of embeddings
|
||||
if len(batch_embeddings) != len(batch_texts):
|
||||
logger.warning(
|
||||
f"Expected {len(batch_texts)} embeddings but got {len(batch_embeddings)}"
|
||||
)
|
||||
|
||||
# Only take the number of embeddings that match the batch size
|
||||
all_embeddings.extend(batch_embeddings[: len(batch_texts)])
|
||||
except Exception as e:
|
||||
logger.error(f"Batch {i} failed: {e}")
|
||||
raise
|
||||
@@ -507,18 +878,24 @@ def compute_embeddings_mlx(chunks: list[str], model_name: str, batch_size: int =
|
||||
|
||||
|
||||
def compute_embeddings_ollama(
|
||||
texts: list[str], model_name: str, is_build: bool = False, host: str = "http://localhost:11434"
|
||||
texts: list[str],
|
||||
model_name: str,
|
||||
is_build: bool = False,
|
||||
host: Optional[str] = None,
|
||||
provider_options: Optional[dict[str, Any]] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Compute embeddings using Ollama API with simplified batch processing.
|
||||
Compute embeddings using Ollama API with true batch processing.
|
||||
|
||||
Uses batch size of 32 for MPS/CPU and 128 for CUDA to optimize performance.
|
||||
Uses the /api/embed endpoint which supports batch inputs.
|
||||
Batch size: 32 for MPS/CPU, 128 for CUDA to optimize performance.
|
||||
|
||||
Args:
|
||||
texts: List of texts to compute embeddings for
|
||||
model_name: Ollama model name (e.g., "nomic-embed-text", "mxbai-embed-large")
|
||||
is_build: Whether this is a build operation (shows progress bar)
|
||||
host: Ollama host URL (default: http://localhost:11434)
|
||||
host: Ollama host URL (defaults to environment or http://localhost:11434)
|
||||
provider_options: Optional provider-specific options (e.g., prompt_template)
|
||||
|
||||
Returns:
|
||||
Normalized embeddings array, shape: (len(texts), embedding_dim)
|
||||
@@ -533,17 +910,19 @@ def compute_embeddings_ollama(
|
||||
if not texts:
|
||||
raise ValueError("Cannot compute embeddings for empty text list")
|
||||
|
||||
resolved_host = resolve_ollama_host(host)
|
||||
|
||||
logger.info(
|
||||
f"Computing embeddings for {len(texts)} texts using Ollama API, model: '{model_name}'"
|
||||
f"Computing embeddings for {len(texts)} texts using Ollama API, model: '{model_name}', host: '{resolved_host}'"
|
||||
)
|
||||
|
||||
# Check if Ollama is running
|
||||
try:
|
||||
response = requests.get(f"{host}/api/version", timeout=5)
|
||||
response = requests.get(f"{resolved_host}/api/version", timeout=5)
|
||||
response.raise_for_status()
|
||||
except requests.exceptions.ConnectionError:
|
||||
error_msg = (
|
||||
f"❌ Could not connect to Ollama at {host}.\n\n"
|
||||
f"❌ Could not connect to Ollama at {resolved_host}.\n\n"
|
||||
"Please ensure Ollama is running:\n"
|
||||
" • macOS/Linux: ollama serve\n"
|
||||
" • Windows: Make sure Ollama is running in the system tray\n\n"
|
||||
@@ -555,7 +934,7 @@ def compute_embeddings_ollama(
|
||||
|
||||
# Check if model exists and provide helpful suggestions
|
||||
try:
|
||||
response = requests.get(f"{host}/api/tags", timeout=5)
|
||||
response = requests.get(f"{resolved_host}/api/tags", timeout=5)
|
||||
response.raise_for_status()
|
||||
models = response.json()
|
||||
model_names = [model["name"] for model in models.get("models", [])]
|
||||
@@ -615,10 +994,12 @@ def compute_embeddings_ollama(
|
||||
logger.info(f"Resolved model name '{model_name}' to '{resolved_model_name}'")
|
||||
model_name = resolved_model_name
|
||||
|
||||
# Verify the model supports embeddings by testing it
|
||||
# Verify the model supports embeddings by testing it with /api/embed
|
||||
try:
|
||||
test_response = requests.post(
|
||||
f"{host}/api/embeddings", json={"model": model_name, "prompt": "test"}, timeout=10
|
||||
f"{resolved_host}/api/embed",
|
||||
json={"model": model_name, "input": "test"},
|
||||
timeout=10,
|
||||
)
|
||||
if test_response.status_code != 200:
|
||||
error_msg = (
|
||||
@@ -649,56 +1030,82 @@ def compute_embeddings_ollama(
|
||||
# If torch is not available, use conservative batch size
|
||||
batch_size = 32
|
||||
|
||||
logger.info(f"Using batch size: {batch_size}")
|
||||
logger.info(f"Using batch size: {batch_size} for true batch processing")
|
||||
|
||||
# Apply prompt template if provided
|
||||
provider_options = provider_options or {}
|
||||
# Priority: build_prompt_template (new format) > prompt_template (old format)
|
||||
prompt_template = provider_options.get("build_prompt_template") or provider_options.get(
|
||||
"prompt_template"
|
||||
)
|
||||
|
||||
if prompt_template:
|
||||
logger.warning(f"Applying prompt template: '{prompt_template}'")
|
||||
texts = [f"{prompt_template}{text}" for text in texts]
|
||||
|
||||
# Get model token limit and apply truncation before batching
|
||||
token_limit = get_model_token_limit(model_name, base_url=resolved_host)
|
||||
logger.info(f"Model '{model_name}' token limit: {token_limit}")
|
||||
|
||||
# Apply truncation to all texts before batch processing
|
||||
# Function logs truncation details internally
|
||||
texts = truncate_to_token_limit(texts, token_limit)
|
||||
|
||||
def get_batch_embeddings(batch_texts):
|
||||
"""Get embeddings for a batch of texts."""
|
||||
all_embeddings = []
|
||||
failed_indices = []
|
||||
"""Get embeddings for a batch of texts using /api/embed endpoint."""
|
||||
max_retries = 3
|
||||
retry_count = 0
|
||||
|
||||
for i, text in enumerate(batch_texts):
|
||||
max_retries = 3
|
||||
retry_count = 0
|
||||
# Texts are already truncated to token limit by the outer function
|
||||
while retry_count < max_retries:
|
||||
try:
|
||||
# Use /api/embed endpoint with "input" parameter for batch processing
|
||||
response = requests.post(
|
||||
f"{resolved_host}/api/embed",
|
||||
json={"model": model_name, "input": batch_texts},
|
||||
timeout=60, # Increased timeout for batch processing
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
# Truncate very long texts to avoid API issues
|
||||
truncated_text = text[:8000] if len(text) > 8000 else text
|
||||
while retry_count < max_retries:
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{host}/api/embeddings",
|
||||
json={"model": model_name, "prompt": truncated_text},
|
||||
timeout=30,
|
||||
result = response.json()
|
||||
batch_embeddings = result.get("embeddings")
|
||||
|
||||
if batch_embeddings is None:
|
||||
raise ValueError("No embeddings returned from API")
|
||||
|
||||
if not isinstance(batch_embeddings, list):
|
||||
raise ValueError(f"Invalid embeddings format: {type(batch_embeddings)}")
|
||||
|
||||
if len(batch_embeddings) != len(batch_texts):
|
||||
raise ValueError(
|
||||
f"Mismatch: requested {len(batch_texts)} embeddings, got {len(batch_embeddings)}"
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
result = response.json()
|
||||
embedding = result.get("embedding")
|
||||
return batch_embeddings, []
|
||||
|
||||
if embedding is None:
|
||||
raise ValueError(f"No embedding returned for text {i}")
|
||||
except requests.exceptions.Timeout:
|
||||
retry_count += 1
|
||||
if retry_count >= max_retries:
|
||||
logger.warning(f"Timeout for batch after {max_retries} retries")
|
||||
return None, list(range(len(batch_texts)))
|
||||
|
||||
if not isinstance(embedding, list) or len(embedding) == 0:
|
||||
raise ValueError(f"Invalid embedding format for text {i}")
|
||||
except Exception as e:
|
||||
retry_count += 1
|
||||
if retry_count >= max_retries:
|
||||
# Enhanced error detection for token limit violations
|
||||
error_msg = str(e).lower()
|
||||
if "token" in error_msg and (
|
||||
"limit" in error_msg or "exceed" in error_msg or "length" in error_msg
|
||||
):
|
||||
logger.error(
|
||||
f"Token limit exceeded for batch. Error: {e}. "
|
||||
f"Consider reducing chunk sizes or check token truncation."
|
||||
)
|
||||
else:
|
||||
logger.error(f"Failed to get embeddings for batch: {e}")
|
||||
return None, list(range(len(batch_texts)))
|
||||
|
||||
all_embeddings.append(embedding)
|
||||
break
|
||||
|
||||
except requests.exceptions.Timeout:
|
||||
retry_count += 1
|
||||
if retry_count >= max_retries:
|
||||
logger.warning(f"Timeout for text {i} after {max_retries} retries")
|
||||
failed_indices.append(i)
|
||||
all_embeddings.append(None)
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
retry_count += 1
|
||||
if retry_count >= max_retries:
|
||||
logger.error(f"Failed to get embedding for text {i}: {e}")
|
||||
failed_indices.append(i)
|
||||
all_embeddings.append(None)
|
||||
break
|
||||
return all_embeddings, failed_indices
|
||||
return None, list(range(len(batch_texts)))
|
||||
|
||||
# Process texts in batches
|
||||
all_embeddings = []
|
||||
@@ -716,7 +1123,7 @@ def compute_embeddings_ollama(
|
||||
num_batches = (len(texts) + batch_size - 1) // batch_size
|
||||
|
||||
if show_progress:
|
||||
batch_iterator = tqdm(range(num_batches), desc="Computing Ollama embeddings")
|
||||
batch_iterator = tqdm(range(num_batches), desc="Computing Ollama embeddings (batched)")
|
||||
else:
|
||||
batch_iterator = range(num_batches)
|
||||
|
||||
@@ -727,10 +1134,14 @@ def compute_embeddings_ollama(
|
||||
|
||||
batch_embeddings, batch_failed = get_batch_embeddings(batch_texts)
|
||||
|
||||
# Adjust failed indices to global indices
|
||||
global_failed = [start_idx + idx for idx in batch_failed]
|
||||
all_failed_indices.extend(global_failed)
|
||||
all_embeddings.extend(batch_embeddings)
|
||||
if batch_embeddings is not None:
|
||||
all_embeddings.extend(batch_embeddings)
|
||||
else:
|
||||
# Entire batch failed, add None placeholders
|
||||
all_embeddings.extend([None] * len(batch_texts))
|
||||
# Adjust failed indices to global indices
|
||||
global_failed = [start_idx + idx for idx in batch_failed]
|
||||
all_failed_indices.extend(global_failed)
|
||||
|
||||
# Handle failed embeddings
|
||||
if all_failed_indices:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import atexit
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
@@ -8,6 +9,8 @@ import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from .settings import encode_provider_options
|
||||
|
||||
# Lightweight, self-contained server manager with no cross-process inspection
|
||||
|
||||
# Set up logging based on environment variable
|
||||
@@ -46,6 +49,85 @@ def _check_port(port: int) -> bool:
|
||||
# Note: All cross-process scanning helpers removed for simplicity
|
||||
|
||||
|
||||
def _safe_resolve(path: Path) -> str:
|
||||
"""Resolve paths safely even if the target does not yet exist."""
|
||||
try:
|
||||
return str(path.resolve(strict=False))
|
||||
except Exception:
|
||||
return str(path)
|
||||
|
||||
|
||||
def _safe_stat_signature(path: Path) -> dict:
|
||||
"""Return a lightweight signature describing the current state of a path."""
|
||||
signature: dict[str, object] = {"path": _safe_resolve(path)}
|
||||
try:
|
||||
stat = path.stat()
|
||||
except FileNotFoundError:
|
||||
signature["missing"] = True
|
||||
except Exception as exc: # pragma: no cover - unexpected filesystem errors
|
||||
signature["error"] = str(exc)
|
||||
else:
|
||||
signature["mtime_ns"] = stat.st_mtime_ns
|
||||
signature["size"] = stat.st_size
|
||||
return signature
|
||||
|
||||
|
||||
def _build_passages_signature(passages_file: Optional[str]) -> Optional[dict]:
|
||||
"""Collect modification signatures for metadata and referenced passage files."""
|
||||
if not passages_file:
|
||||
return None
|
||||
|
||||
meta_path = Path(passages_file)
|
||||
signature: dict[str, object] = {"meta": _safe_stat_signature(meta_path)}
|
||||
|
||||
try:
|
||||
with meta_path.open(encoding="utf-8") as fh:
|
||||
meta = json.load(fh)
|
||||
except FileNotFoundError:
|
||||
signature["meta_missing"] = True
|
||||
signature["sources"] = []
|
||||
return signature
|
||||
except json.JSONDecodeError as exc:
|
||||
signature["meta_error"] = f"json_error:{exc}"
|
||||
signature["sources"] = []
|
||||
return signature
|
||||
except Exception as exc: # pragma: no cover - unexpected errors
|
||||
signature["meta_error"] = str(exc)
|
||||
signature["sources"] = []
|
||||
return signature
|
||||
|
||||
base_dir = meta_path.parent
|
||||
seen_paths: set[str] = set()
|
||||
source_signatures: list[dict[str, object]] = []
|
||||
|
||||
for source in meta.get("passage_sources", []):
|
||||
for key, kind in (
|
||||
("path", "passages"),
|
||||
("path_relative", "passages"),
|
||||
("index_path", "index"),
|
||||
("index_path_relative", "index"),
|
||||
):
|
||||
raw_path = source.get(key)
|
||||
if not raw_path:
|
||||
continue
|
||||
candidate = Path(raw_path)
|
||||
if not candidate.is_absolute():
|
||||
candidate = base_dir / candidate
|
||||
resolved = _safe_resolve(candidate)
|
||||
if resolved in seen_paths:
|
||||
continue
|
||||
seen_paths.add(resolved)
|
||||
sig = _safe_stat_signature(candidate)
|
||||
sig["kind"] = kind
|
||||
source_signatures.append(sig)
|
||||
|
||||
signature["sources"] = source_signatures
|
||||
return signature
|
||||
|
||||
|
||||
# Note: All cross-process scanning helpers removed for simplicity
|
||||
|
||||
|
||||
class EmbeddingServerManager:
|
||||
"""
|
||||
A simplified manager for embedding server processes that avoids complex update mechanisms.
|
||||
@@ -82,16 +164,42 @@ class EmbeddingServerManager:
|
||||
) -> tuple[bool, int]:
|
||||
"""Start the embedding server."""
|
||||
# passages_file may be present in kwargs for server CLI, but we don't need it here
|
||||
provider_options = kwargs.pop("provider_options", None)
|
||||
passages_file = kwargs.get("passages_file", "")
|
||||
|
||||
config_signature = self._build_config_signature(
|
||||
model_name=model_name,
|
||||
embedding_mode=embedding_mode,
|
||||
provider_options=provider_options,
|
||||
passages_file=passages_file,
|
||||
)
|
||||
|
||||
# If this manager already has a live server, just reuse it
|
||||
if self.server_process and self.server_process.poll() is None and self.server_port:
|
||||
if (
|
||||
self.server_process
|
||||
and self.server_process.poll() is None
|
||||
and self.server_port
|
||||
and self._server_config == config_signature
|
||||
):
|
||||
logger.info("Reusing in-process server")
|
||||
return True, self.server_port
|
||||
|
||||
# Configuration changed, stop existing server before starting a new one
|
||||
if self.server_process and self.server_process.poll() is None:
|
||||
logger.info("Existing server configuration differs; restarting embedding server")
|
||||
self.stop_server()
|
||||
|
||||
# For Colab environment, use a different strategy
|
||||
if _is_colab_environment():
|
||||
logger.info("Detected Colab environment, using alternative startup strategy")
|
||||
return self._start_server_colab(port, model_name, embedding_mode, **kwargs)
|
||||
return self._start_server_colab(
|
||||
port,
|
||||
model_name,
|
||||
embedding_mode,
|
||||
config_signature=config_signature,
|
||||
provider_options=provider_options,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Always pick a fresh available port
|
||||
try:
|
||||
@@ -101,13 +209,40 @@ class EmbeddingServerManager:
|
||||
return False, port
|
||||
|
||||
# Start a new server
|
||||
return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs)
|
||||
return self._start_new_server(
|
||||
actual_port,
|
||||
model_name,
|
||||
embedding_mode,
|
||||
provider_options=provider_options,
|
||||
config_signature=config_signature,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _build_config_signature(
|
||||
self,
|
||||
*,
|
||||
model_name: str,
|
||||
embedding_mode: str,
|
||||
provider_options: Optional[dict],
|
||||
passages_file: Optional[str],
|
||||
) -> dict:
|
||||
"""Create a signature describing the current server configuration."""
|
||||
return {
|
||||
"model_name": model_name,
|
||||
"passages_file": passages_file or "",
|
||||
"embedding_mode": embedding_mode,
|
||||
"provider_options": provider_options or {},
|
||||
"passages_signature": _build_passages_signature(passages_file),
|
||||
}
|
||||
|
||||
def _start_server_colab(
|
||||
self,
|
||||
port: int,
|
||||
model_name: str,
|
||||
embedding_mode: str = "sentence-transformers",
|
||||
*,
|
||||
config_signature: Optional[dict] = None,
|
||||
provider_options: Optional[dict] = None,
|
||||
**kwargs,
|
||||
) -> tuple[bool, int]:
|
||||
"""Start server with Colab-specific configuration."""
|
||||
@@ -125,8 +260,21 @@ class EmbeddingServerManager:
|
||||
|
||||
try:
|
||||
# In Colab, we'll use a more direct approach
|
||||
self._launch_server_process_colab(command, actual_port)
|
||||
return self._wait_for_server_ready_colab(actual_port)
|
||||
self._launch_server_process_colab(
|
||||
command,
|
||||
actual_port,
|
||||
provider_options=provider_options,
|
||||
config_signature=config_signature,
|
||||
)
|
||||
started, ready_port = self._wait_for_server_ready_colab(actual_port)
|
||||
if started:
|
||||
self._server_config = config_signature or {
|
||||
"model_name": model_name,
|
||||
"passages_file": kwargs.get("passages_file", ""),
|
||||
"embedding_mode": embedding_mode,
|
||||
"provider_options": provider_options or {},
|
||||
}
|
||||
return started, ready_port
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start embedding server in Colab: {e}")
|
||||
return False, actual_port
|
||||
@@ -134,7 +282,13 @@ class EmbeddingServerManager:
|
||||
# Note: No compatibility check needed; manager is per-searcher and configs are stable per instance
|
||||
|
||||
def _start_new_server(
|
||||
self, port: int, model_name: str, embedding_mode: str, **kwargs
|
||||
self,
|
||||
port: int,
|
||||
model_name: str,
|
||||
embedding_mode: str,
|
||||
provider_options: Optional[dict] = None,
|
||||
config_signature: Optional[dict] = None,
|
||||
**kwargs,
|
||||
) -> tuple[bool, int]:
|
||||
"""Start a new embedding server on the given port."""
|
||||
logger.info(f"Starting embedding server on port {port}...")
|
||||
@@ -142,8 +296,21 @@ class EmbeddingServerManager:
|
||||
command = self._build_server_command(port, model_name, embedding_mode, **kwargs)
|
||||
|
||||
try:
|
||||
self._launch_server_process(command, port)
|
||||
return self._wait_for_server_ready(port)
|
||||
self._launch_server_process(
|
||||
command,
|
||||
port,
|
||||
provider_options=provider_options,
|
||||
config_signature=config_signature,
|
||||
)
|
||||
started, ready_port = self._wait_for_server_ready(port)
|
||||
if started:
|
||||
self._server_config = config_signature or {
|
||||
"model_name": model_name,
|
||||
"passages_file": kwargs.get("passages_file", ""),
|
||||
"embedding_mode": embedding_mode,
|
||||
"provider_options": provider_options or {},
|
||||
}
|
||||
return started, ready_port
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start embedding server: {e}")
|
||||
return False, port
|
||||
@@ -173,7 +340,14 @@ class EmbeddingServerManager:
|
||||
|
||||
return command
|
||||
|
||||
def _launch_server_process(self, command: list, port: int) -> None:
|
||||
def _launch_server_process(
|
||||
self,
|
||||
command: list,
|
||||
port: int,
|
||||
*,
|
||||
provider_options: Optional[dict] = None,
|
||||
config_signature: Optional[dict] = None,
|
||||
) -> None:
|
||||
"""Launch the server process."""
|
||||
project_root = Path(__file__).parent.parent.parent.parent.parent
|
||||
logger.info(f"Command: {' '.join(command)}")
|
||||
@@ -193,32 +367,43 @@ class EmbeddingServerManager:
|
||||
|
||||
# Start embedding server subprocess
|
||||
logger.info(f"Starting server process with command: {' '.join(command)}")
|
||||
env = os.environ.copy()
|
||||
encoded_options = encode_provider_options(provider_options)
|
||||
if encoded_options:
|
||||
env["LEANN_EMBEDDING_OPTIONS"] = encoded_options
|
||||
|
||||
self.server_process = subprocess.Popen(
|
||||
command,
|
||||
cwd=project_root,
|
||||
stdout=stdout_target,
|
||||
stderr=stderr_target,
|
||||
env=env,
|
||||
)
|
||||
self.server_port = port
|
||||
# Record config for in-process reuse
|
||||
try:
|
||||
self._server_config = {
|
||||
"model_name": command[command.index("--model-name") + 1]
|
||||
if "--model-name" in command
|
||||
else "",
|
||||
"passages_file": command[command.index("--passages-file") + 1]
|
||||
if "--passages-file" in command
|
||||
else "",
|
||||
"embedding_mode": command[command.index("--embedding-mode") + 1]
|
||||
if "--embedding-mode" in command
|
||||
else "sentence-transformers",
|
||||
}
|
||||
except Exception:
|
||||
self._server_config = {
|
||||
"model_name": "",
|
||||
"passages_file": "",
|
||||
"embedding_mode": "sentence-transformers",
|
||||
}
|
||||
# Record config for in-process reuse (best effort; refined later when ready)
|
||||
if config_signature is not None:
|
||||
self._server_config = config_signature
|
||||
else: # Fallback for unexpected code paths
|
||||
try:
|
||||
self._server_config = {
|
||||
"model_name": command[command.index("--model-name") + 1]
|
||||
if "--model-name" in command
|
||||
else "",
|
||||
"passages_file": command[command.index("--passages-file") + 1]
|
||||
if "--passages-file" in command
|
||||
else "",
|
||||
"embedding_mode": command[command.index("--embedding-mode") + 1]
|
||||
if "--embedding-mode" in command
|
||||
else "sentence-transformers",
|
||||
"provider_options": provider_options or {},
|
||||
}
|
||||
except Exception:
|
||||
self._server_config = {
|
||||
"model_name": "",
|
||||
"passages_file": "",
|
||||
"embedding_mode": "sentence-transformers",
|
||||
"provider_options": provider_options or {},
|
||||
}
|
||||
logger.info(f"Server process started with PID: {self.server_process.pid}")
|
||||
|
||||
# Register atexit callback only when we actually start a process
|
||||
@@ -322,16 +507,29 @@ class EmbeddingServerManager:
|
||||
# Removed: cross-process adoption no longer supported
|
||||
return
|
||||
|
||||
def _launch_server_process_colab(self, command: list, port: int) -> None:
|
||||
def _launch_server_process_colab(
|
||||
self,
|
||||
command: list,
|
||||
port: int,
|
||||
*,
|
||||
provider_options: Optional[dict] = None,
|
||||
config_signature: Optional[dict] = None,
|
||||
) -> None:
|
||||
"""Launch the server process with Colab-specific settings."""
|
||||
logger.info(f"Colab Command: {' '.join(command)}")
|
||||
|
||||
# In Colab, we need to be more careful about process management
|
||||
env = os.environ.copy()
|
||||
encoded_options = encode_provider_options(provider_options)
|
||||
if encoded_options:
|
||||
env["LEANN_EMBEDDING_OPTIONS"] = encoded_options
|
||||
|
||||
self.server_process = subprocess.Popen(
|
||||
command,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
env=env,
|
||||
)
|
||||
self.server_port = port
|
||||
logger.info(f"Colab server process started with PID: {self.server_process.pid}")
|
||||
@@ -341,11 +539,15 @@ class EmbeddingServerManager:
|
||||
atexit.register(self._finalize_process)
|
||||
self._atexit_registered = True
|
||||
# Record config for in-process reuse is best-effort in Colab mode
|
||||
self._server_config = {
|
||||
"model_name": "",
|
||||
"passages_file": "",
|
||||
"embedding_mode": "sentence-transformers",
|
||||
}
|
||||
if config_signature is not None:
|
||||
self._server_config = config_signature
|
||||
else:
|
||||
self._server_config = {
|
||||
"model_name": "",
|
||||
"passages_file": "",
|
||||
"embedding_mode": "sentence-transformers",
|
||||
"provider_options": provider_options or {},
|
||||
}
|
||||
|
||||
def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]:
|
||||
"""Wait for the server to be ready with Colab-specific timeout."""
|
||||
|
||||
189
packages/leann-core/src/leann/interactive_utils.py
Normal file
@@ -0,0 +1,189 @@
|
||||
"""
|
||||
Interactive session utilities for LEANN applications.
|
||||
|
||||
Provides shared readline functionality and command handling across
|
||||
CLI, API, and RAG example interactive modes.
|
||||
"""
|
||||
|
||||
import atexit
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional
|
||||
|
||||
# Try to import readline with fallback for Windows
|
||||
try:
|
||||
import readline
|
||||
|
||||
HAS_READLINE = True
|
||||
except ImportError:
|
||||
# Windows doesn't have readline by default
|
||||
HAS_READLINE = False
|
||||
readline = None
|
||||
|
||||
|
||||
class InteractiveSession:
|
||||
"""Manages interactive session with optional readline support and common commands."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
history_name: str,
|
||||
prompt: str = "You: ",
|
||||
welcome_message: str = "",
|
||||
):
|
||||
"""
|
||||
Initialize interactive session with optional readline support.
|
||||
|
||||
Args:
|
||||
history_name: Name for history file (e.g., "cli", "api_chat")
|
||||
(ignored if readline not available)
|
||||
prompt: Input prompt to display
|
||||
welcome_message: Message to show when starting session
|
||||
|
||||
Note:
|
||||
On systems without readline (e.g., Windows), falls back to basic input()
|
||||
with limited functionality (no history, no line editing).
|
||||
"""
|
||||
self.history_name = history_name
|
||||
self.prompt = prompt
|
||||
self.welcome_message = welcome_message
|
||||
self._setup_complete = False
|
||||
|
||||
def setup_readline(self):
|
||||
"""Setup readline with history support (if available)."""
|
||||
if self._setup_complete:
|
||||
return
|
||||
|
||||
if not HAS_READLINE:
|
||||
# Readline not available (likely Windows), skip setup
|
||||
self._setup_complete = True
|
||||
return
|
||||
|
||||
# History file setup
|
||||
history_dir = Path.home() / ".leann" / "history"
|
||||
history_dir.mkdir(parents=True, exist_ok=True)
|
||||
history_file = history_dir / f"{self.history_name}.history"
|
||||
|
||||
# Load history if exists
|
||||
try:
|
||||
readline.read_history_file(str(history_file))
|
||||
readline.set_history_length(1000)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
# Save history on exit
|
||||
atexit.register(readline.write_history_file, str(history_file))
|
||||
|
||||
# Optional: Enable vi editing mode (commented out by default)
|
||||
# readline.parse_and_bind("set editing-mode vi")
|
||||
|
||||
self._setup_complete = True
|
||||
|
||||
def _show_help(self):
|
||||
"""Show available commands."""
|
||||
print("Commands:")
|
||||
print(" quit/exit/q - Exit the chat")
|
||||
print(" help - Show this help message")
|
||||
print(" clear - Clear screen")
|
||||
print(" history - Show command history")
|
||||
|
||||
def _show_history(self):
|
||||
"""Show command history."""
|
||||
if not HAS_READLINE:
|
||||
print(" History not available (readline not supported on this system)")
|
||||
return
|
||||
|
||||
history_length = readline.get_current_history_length()
|
||||
if history_length == 0:
|
||||
print(" No history available")
|
||||
return
|
||||
|
||||
for i in range(history_length):
|
||||
item = readline.get_history_item(i + 1)
|
||||
if item:
|
||||
print(f" {i + 1}: {item}")
|
||||
|
||||
def get_user_input(self) -> Optional[str]:
|
||||
"""
|
||||
Get user input with readline support.
|
||||
|
||||
Returns:
|
||||
User input string, or None if EOF (Ctrl+D)
|
||||
"""
|
||||
try:
|
||||
return input(self.prompt).strip()
|
||||
except KeyboardInterrupt:
|
||||
print("\n(Use 'quit' to exit)")
|
||||
return "" # Return empty string to continue
|
||||
except EOFError:
|
||||
print("\nGoodbye!")
|
||||
return None
|
||||
|
||||
def run_interactive_loop(self, handler_func: Callable[[str], None]):
|
||||
"""
|
||||
Run the interactive loop with a custom handler function.
|
||||
|
||||
Args:
|
||||
handler_func: Function to handle user input that's not a built-in command
|
||||
Should accept a string and handle the user's query
|
||||
"""
|
||||
self.setup_readline()
|
||||
|
||||
if self.welcome_message:
|
||||
print(self.welcome_message)
|
||||
|
||||
while True:
|
||||
user_input = self.get_user_input()
|
||||
|
||||
if user_input is None: # EOF (Ctrl+D)
|
||||
break
|
||||
|
||||
if not user_input: # Empty input or KeyboardInterrupt
|
||||
continue
|
||||
|
||||
# Handle built-in commands
|
||||
command = user_input.lower()
|
||||
if command in ["quit", "exit", "q"]:
|
||||
print("Goodbye!")
|
||||
break
|
||||
elif command == "help":
|
||||
self._show_help()
|
||||
elif command == "clear":
|
||||
os.system("clear" if os.name != "nt" else "cls")
|
||||
elif command == "history":
|
||||
self._show_history()
|
||||
else:
|
||||
# Regular user input - pass to handler
|
||||
try:
|
||||
handler_func(user_input)
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
|
||||
|
||||
def create_cli_session(index_name: str) -> InteractiveSession:
|
||||
"""Create an interactive session for CLI usage."""
|
||||
return InteractiveSession(
|
||||
history_name=index_name,
|
||||
prompt="\nYou: ",
|
||||
welcome_message="LEANN Assistant ready! Type 'quit' to exit, 'help' for commands\n"
|
||||
+ "=" * 40,
|
||||
)
|
||||
|
||||
|
||||
def create_api_session() -> InteractiveSession:
|
||||
"""Create an interactive session for API chat."""
|
||||
return InteractiveSession(
|
||||
history_name="api_chat",
|
||||
prompt="You: ",
|
||||
welcome_message="Leann Chat started (type 'quit' to exit, 'help' for commands)\n"
|
||||
+ "=" * 40,
|
||||
)
|
||||
|
||||
|
||||
def create_rag_session(app_name: str, data_description: str) -> InteractiveSession:
|
||||
"""Create an interactive session for RAG examples."""
|
||||
return InteractiveSession(
|
||||
history_name=f"{app_name}_rag",
|
||||
prompt="You: ",
|
||||
welcome_message=f"[Interactive Mode] Chat with your {data_description} data!\nType 'quit' or 'exit' to stop, 'help' for commands.\n"
|
||||
+ "=" * 40,
|
||||
)
|
||||
@@ -77,6 +77,7 @@ class LeannBackendSearcherInterface(ABC):
|
||||
query: str,
|
||||
use_server_if_available: bool = True,
|
||||
zmq_port: Optional[int] = None,
|
||||
query_template: Optional[str] = None,
|
||||
) -> np.ndarray:
|
||||
"""Compute embedding for a query string
|
||||
|
||||
@@ -84,6 +85,7 @@ class LeannBackendSearcherInterface(ABC):
|
||||
query: The query string to embed
|
||||
zmq_port: ZMQ port for embedding server
|
||||
use_server_if_available: Whether to try using embedding server first
|
||||
query_template: Optional prompt template to prepend to query
|
||||
|
||||
Returns:
|
||||
Query embedding as numpy array with shape (1, D)
|
||||
|
||||
@@ -60,6 +60,11 @@ def handle_request(request):
|
||||
"maximum": 128,
|
||||
"description": "Search complexity level. Use 16-32 for fast searches (recommended), 64+ for higher precision when needed.",
|
||||
},
|
||||
"show_metadata": {
|
||||
"type": "boolean",
|
||||
"default": False,
|
||||
"description": "Include file paths and metadata in search results. Useful for understanding which files contain the results.",
|
||||
},
|
||||
},
|
||||
"required": ["index_name", "query"],
|
||||
},
|
||||
@@ -104,6 +109,8 @@ def handle_request(request):
|
||||
f"--complexity={args.get('complexity', 32)}",
|
||||
"--non-interactive",
|
||||
]
|
||||
if args.get("show_metadata", False):
|
||||
cmd.append("--show-metadata")
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
|
||||
elif tool_name == "leann_list":
|
||||
|
||||
@@ -33,6 +33,8 @@ def autodiscover_backends():
|
||||
discovered_backends = []
|
||||
for dist in importlib.metadata.distributions():
|
||||
dist_name = dist.metadata["name"]
|
||||
if dist_name is None:
|
||||
continue
|
||||
if dist_name.startswith("leann-backend-"):
|
||||
backend_module_name = dist_name.replace("-", "_")
|
||||
discovered_backends.append(backend_module_name)
|
||||
|
||||
@@ -41,6 +41,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
print("WARNING: embedding_model not found in meta.json. Recompute will fail.")
|
||||
|
||||
self.embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
|
||||
self.embedding_options = self.meta.get("embedding_options", {})
|
||||
|
||||
self.embedding_server_manager = EmbeddingServerManager(
|
||||
backend_module_name=backend_module_name,
|
||||
@@ -70,6 +71,15 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
or "mips"
|
||||
)
|
||||
|
||||
# Filter out ALL prompt templates from provider_options during search
|
||||
# Templates are applied in compute_query_embedding (line 109-110) BEFORE server call
|
||||
# The server should never apply templates during search to avoid double-templating
|
||||
search_provider_options = {
|
||||
k: v
|
||||
for k, v in self.embedding_options.items()
|
||||
if k not in ("build_prompt_template", "query_prompt_template", "prompt_template")
|
||||
}
|
||||
|
||||
server_started, actual_port = self.embedding_server_manager.start_server(
|
||||
port=port,
|
||||
model_name=self.embedding_model,
|
||||
@@ -77,6 +87,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
passages_file=passages_source_file,
|
||||
distance_metric=distance_metric,
|
||||
enable_warmup=kwargs.get("enable_warmup", False),
|
||||
provider_options=search_provider_options,
|
||||
)
|
||||
if not server_started:
|
||||
raise RuntimeError(f"Failed to start embedding server on port {actual_port}")
|
||||
@@ -88,6 +99,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
query: str,
|
||||
use_server_if_available: bool = True,
|
||||
zmq_port: int = 5557,
|
||||
query_template: Optional[str] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Compute embedding for a query string.
|
||||
@@ -96,10 +108,16 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
query: The query string to embed
|
||||
zmq_port: ZMQ port for embedding server
|
||||
use_server_if_available: Whether to try using embedding server first
|
||||
query_template: Optional prompt template to prepend to query
|
||||
|
||||
Returns:
|
||||
Query embedding as numpy array
|
||||
"""
|
||||
# Apply query template BEFORE any computation path
|
||||
# This ensures template is applied consistently for both server and fallback paths
|
||||
if query_template:
|
||||
query = f"{query_template}{query}"
|
||||
|
||||
# Try to use embedding server if available and requested
|
||||
if use_server_if_available:
|
||||
try:
|
||||
@@ -125,7 +143,12 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
from .embedding_compute import compute_embeddings
|
||||
|
||||
embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
|
||||
return compute_embeddings([query], self.embedding_model, embedding_mode)
|
||||
return compute_embeddings(
|
||||
[query],
|
||||
self.embedding_model,
|
||||
embedding_mode,
|
||||
provider_options=self.embedding_options,
|
||||
)
|
||||
|
||||
def _compute_embedding_via_server(self, chunks: list, zmq_port: int) -> np.ndarray:
|
||||
"""Compute embeddings using the ZMQ embedding server."""
|
||||
|
||||
101
packages/leann-core/src/leann/settings.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""Runtime configuration helpers for LEANN."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
# Default fallbacks to preserve current behaviour while keeping them in one place.
|
||||
_DEFAULT_OLLAMA_HOST = "http://localhost:11434"
|
||||
_DEFAULT_OPENAI_BASE_URL = "https://api.openai.com/v1"
|
||||
_DEFAULT_ANTHROPIC_BASE_URL = "https://api.anthropic.com"
|
||||
|
||||
|
||||
def _clean_url(value: str) -> str:
|
||||
"""Normalize URL strings by stripping trailing slashes."""
|
||||
|
||||
return value.rstrip("/") if value else value
|
||||
|
||||
|
||||
def resolve_ollama_host(explicit: str | None = None) -> str:
|
||||
"""Resolve the Ollama-compatible endpoint to use."""
|
||||
|
||||
candidates = (
|
||||
explicit,
|
||||
os.getenv("LEANN_LOCAL_LLM_HOST"),
|
||||
os.getenv("LEANN_OLLAMA_HOST"),
|
||||
os.getenv("OLLAMA_HOST"),
|
||||
os.getenv("LOCAL_LLM_ENDPOINT"),
|
||||
)
|
||||
|
||||
for candidate in candidates:
|
||||
if candidate:
|
||||
return _clean_url(candidate)
|
||||
|
||||
return _clean_url(_DEFAULT_OLLAMA_HOST)
|
||||
|
||||
|
||||
def resolve_openai_base_url(explicit: str | None = None) -> str:
|
||||
"""Resolve the base URL for OpenAI-compatible services."""
|
||||
|
||||
candidates = (
|
||||
explicit,
|
||||
os.getenv("LEANN_OPENAI_BASE_URL"),
|
||||
os.getenv("OPENAI_BASE_URL"),
|
||||
os.getenv("LOCAL_OPENAI_BASE_URL"),
|
||||
)
|
||||
|
||||
for candidate in candidates:
|
||||
if candidate:
|
||||
return _clean_url(candidate)
|
||||
|
||||
return _clean_url(_DEFAULT_OPENAI_BASE_URL)
|
||||
|
||||
|
||||
def resolve_anthropic_base_url(explicit: str | None = None) -> str:
|
||||
"""Resolve the base URL for Anthropic-compatible services."""
|
||||
|
||||
candidates = (
|
||||
explicit,
|
||||
os.getenv("LEANN_ANTHROPIC_BASE_URL"),
|
||||
os.getenv("ANTHROPIC_BASE_URL"),
|
||||
os.getenv("LOCAL_ANTHROPIC_BASE_URL"),
|
||||
)
|
||||
|
||||
for candidate in candidates:
|
||||
if candidate:
|
||||
return _clean_url(candidate)
|
||||
|
||||
return _clean_url(_DEFAULT_ANTHROPIC_BASE_URL)
|
||||
|
||||
|
||||
def resolve_openai_api_key(explicit: str | None = None) -> str | None:
|
||||
"""Resolve the API key for OpenAI-compatible services."""
|
||||
|
||||
if explicit:
|
||||
return explicit
|
||||
|
||||
return os.getenv("OPENAI_API_KEY")
|
||||
|
||||
|
||||
def resolve_anthropic_api_key(explicit: str | None = None) -> str | None:
|
||||
"""Resolve the API key for Anthropic services."""
|
||||
|
||||
if explicit:
|
||||
return explicit
|
||||
|
||||
return os.getenv("ANTHROPIC_API_KEY")
|
||||
|
||||
|
||||
def encode_provider_options(options: dict[str, Any] | None) -> str | None:
|
||||
"""Serialize provider options for child processes."""
|
||||
|
||||
if not options:
|
||||
return None
|
||||
|
||||
try:
|
||||
return json.dumps(options)
|
||||
except (TypeError, ValueError):
|
||||
# Fall back to empty payload if serialization fails
|
||||
return None
|
||||
@@ -53,6 +53,11 @@ leann build my-project --docs $(git ls-files)
|
||||
# Start Claude Code
|
||||
claude
|
||||
```
|
||||
**Performance tip**: For maximum speed when storage space is not a concern, add the `--no-recompute` flag to your build command. This materializes all tensors and stores them on disk, avoiding recomputation on subsequent builds:
|
||||
|
||||
```bash
|
||||
leann build my-project --docs $(git ls-files) --no-recompute
|
||||
```
|
||||
|
||||
## 🚀 Advanced Usage Examples to build the index
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "leann"
|
||||
version = "0.3.4"
|
||||
version = "0.3.5"
|
||||
description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.9"
|
||||
|
||||