Compare commits
4 Commits
security/e
...
feature/cl
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f52bce23c3 | ||
|
|
f1355b70d8 | ||
|
|
2dd4147de2 | ||
|
|
be17980114 |
111
.github/workflows/build-reusable.yml
vendored
111
.github/workflows/build-reusable.yml
vendored
@@ -17,17 +17,26 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ inputs.ref }}
|
||||
submodules: recursive
|
||||
|
||||
- name: Install uv and Python
|
||||
uses: astral-sh/setup-uv@v6
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
- name: Run pre-commit with only lint group (no project deps)
|
||||
run: |
|
||||
uv run --only-group lint pre-commit run --all-files --show-diff-on-failure
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v4
|
||||
|
||||
- name: Install ruff
|
||||
run: |
|
||||
uv tool install ruff
|
||||
|
||||
- name: Run ruff check
|
||||
run: |
|
||||
ruff check .
|
||||
|
||||
- name: Run ruff format check
|
||||
run: |
|
||||
ruff format --check .
|
||||
|
||||
build:
|
||||
needs: lint
|
||||
@@ -94,11 +103,14 @@ jobs:
|
||||
ref: ${{ inputs.ref }}
|
||||
submodules: recursive
|
||||
|
||||
- name: Install uv and Python
|
||||
uses: astral-sh/setup-uv@v6
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python }}
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
|
||||
- name: Install system dependencies (Ubuntu)
|
||||
if: runner.os == 'Linux'
|
||||
run: |
|
||||
@@ -156,24 +168,11 @@ jobs:
|
||||
|
||||
- name: Install build dependencies
|
||||
run: |
|
||||
uv python install ${{ matrix.python }}
|
||||
uv venv --python ${{ matrix.python }} .uv-build
|
||||
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
||||
BUILD_PY=".uv-build\\Scripts\\python.exe"
|
||||
else
|
||||
BUILD_PY=".uv-build/bin/python"
|
||||
fi
|
||||
uv pip install --python "$BUILD_PY" scikit-build-core numpy swig Cython pybind11
|
||||
uv pip install --system scikit-build-core numpy swig Cython pybind11
|
||||
if [[ "$RUNNER_OS" == "Linux" ]]; then
|
||||
uv pip install --python "$BUILD_PY" auditwheel
|
||||
uv pip install --system auditwheel
|
||||
else
|
||||
uv pip install --python "$BUILD_PY" delocate
|
||||
fi
|
||||
|
||||
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
||||
echo "$(pwd)\\.uv-build\\Scripts" >> $GITHUB_PATH
|
||||
else
|
||||
echo "$(pwd)/.uv-build/bin" >> $GITHUB_PATH
|
||||
uv pip install --system delocate
|
||||
fi
|
||||
|
||||
- name: Set macOS environment variables
|
||||
@@ -309,66 +308,18 @@ jobs:
|
||||
|
||||
- name: Install built packages for testing
|
||||
run: |
|
||||
# Create uv-managed virtual environment with the requested interpreter
|
||||
uv python install ${{ matrix.python }}
|
||||
# Create a virtual environment with the correct Python version
|
||||
uv venv --python ${{ matrix.python }}
|
||||
source .venv/bin/activate || source .venv/Scripts/activate
|
||||
|
||||
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
||||
UV_PY=".venv\\Scripts\\python.exe"
|
||||
else
|
||||
UV_PY=".venv/bin/python"
|
||||
fi
|
||||
# Install packages using --find-links to prioritize local builds
|
||||
uv pip install --find-links packages/leann-core/dist --find-links packages/leann-backend-hnsw/dist --find-links packages/leann-backend-diskann/dist packages/leann-core/dist/*.whl || uv pip install --find-links packages/leann-core/dist packages/leann-core/dist/*.tar.gz
|
||||
uv pip install --find-links packages/leann-core/dist packages/leann-backend-hnsw/dist/*.whl
|
||||
uv pip install --find-links packages/leann-core/dist packages/leann-backend-diskann/dist/*.whl
|
||||
uv pip install packages/leann/dist/*.whl || uv pip install packages/leann/dist/*.tar.gz
|
||||
|
||||
# Install test dependency group only (avoids reinstalling project package)
|
||||
uv pip install --python "$UV_PY" --group test
|
||||
|
||||
# Install core wheel built in this job
|
||||
CORE_WHL=$(find packages/leann-core/dist -maxdepth 1 -name "*.whl" -print -quit)
|
||||
if [[ -n "$CORE_WHL" ]]; then
|
||||
uv pip install --python "$UV_PY" "$CORE_WHL"
|
||||
else
|
||||
uv pip install --python "$UV_PY" packages/leann-core/dist/*.tar.gz
|
||||
fi
|
||||
|
||||
PY_TAG=$($UV_PY -c "import sys; print(f'cp{sys.version_info[0]}{sys.version_info[1]}')")
|
||||
|
||||
if [[ "$RUNNER_OS" == "macOS" ]]; then
|
||||
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=13.3
|
||||
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=14.0
|
||||
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=15.0
|
||||
fi
|
||||
fi
|
||||
|
||||
HNSW_WHL=$(find packages/leann-backend-hnsw/dist -maxdepth 1 -name "*-${PY_TAG}-*.whl" -print -quit)
|
||||
if [[ -z "$HNSW_WHL" ]]; then
|
||||
HNSW_WHL=$(find packages/leann-backend-hnsw/dist -maxdepth 1 -name "*-py3-*.whl" -print -quit)
|
||||
fi
|
||||
if [[ -n "$HNSW_WHL" ]]; then
|
||||
uv pip install --python "$UV_PY" "$HNSW_WHL"
|
||||
else
|
||||
uv pip install --python "$UV_PY" ./packages/leann-backend-hnsw
|
||||
fi
|
||||
|
||||
DISKANN_WHL=$(find packages/leann-backend-diskann/dist -maxdepth 1 -name "*-${PY_TAG}-*.whl" -print -quit)
|
||||
if [[ -z "$DISKANN_WHL" ]]; then
|
||||
DISKANN_WHL=$(find packages/leann-backend-diskann/dist -maxdepth 1 -name "*-py3-*.whl" -print -quit)
|
||||
fi
|
||||
if [[ -n "$DISKANN_WHL" ]]; then
|
||||
uv pip install --python "$UV_PY" "$DISKANN_WHL"
|
||||
else
|
||||
uv pip install --python "$UV_PY" ./packages/leann-backend-diskann
|
||||
fi
|
||||
|
||||
LEANN_WHL=$(find packages/leann/dist -maxdepth 1 -name "*.whl" -print -quit)
|
||||
if [[ -n "$LEANN_WHL" ]]; then
|
||||
uv pip install --python "$UV_PY" "$LEANN_WHL"
|
||||
else
|
||||
uv pip install --python "$UV_PY" packages/leann/dist/*.tar.gz
|
||||
fi
|
||||
# Install test dependencies using extras
|
||||
uv pip install -e ".[test]"
|
||||
|
||||
- name: Run tests with pytest
|
||||
env:
|
||||
|
||||
16
.gitignore
vendored
16
.gitignore
vendored
@@ -18,7 +18,6 @@ demo/experiment_results/**/*.json
|
||||
*.eml
|
||||
*.emlx
|
||||
*.json
|
||||
*.png
|
||||
!.vscode/*.json
|
||||
*.sh
|
||||
*.txt
|
||||
@@ -95,13 +94,10 @@ packages/leann-backend-diskann/third_party/DiskANN/_deps/
|
||||
batchtest.py
|
||||
tests/__pytest_cache__/
|
||||
tests/__pycache__/
|
||||
paru-bin/
|
||||
|
||||
CLAUDE.md
|
||||
CLAUDE.local.md
|
||||
.claude/*.local.*
|
||||
.claude/local/*
|
||||
benchmarks/data/
|
||||
|
||||
## multi vector
|
||||
apps/multimodal/vision-based-pdf-multi-vector/multi-vector-colpali-native-weaviate.py
|
||||
|
||||
# Ignore all PDFs (keep data exceptions above) and do not track demo PDFs
|
||||
# If you need to commit a specific demo PDF, remove this negation locally.
|
||||
# The following line used to force-add a large demo PDF; remove it to satisfy pre-commit:
|
||||
# !apps/multimodal/vision-based-pdf-multi-vector/pdfs/2004.12832v2.pdf
|
||||
!apps/multimodal/vision-based-pdf-multi-vector/fig/*
|
||||
|
||||
237
README.md
237
README.md
@@ -20,7 +20,7 @@ LEANN is an innovative vector database that democratizes personal AI. Transform
|
||||
|
||||
LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration Fig →](#️-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
|
||||
|
||||
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can semantic search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)** ([WeChat](#-wechat-detective-unlock-your-golden-memories), [iMessage](#-imessage-history-your-personal-conversation-archive)), **[agent memory](#-chatgpt-chat-history-your-personal-ai-conversation-archive)** ([ChatGPT](#-chatgpt-chat-history-your-personal-ai-conversation-archive), [Claude](#-claude-chat-history-your-personal-ai-conversation-archive)), **[codebase](#-claude-code-integration-transform-your-development-workflow)**\* , or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
|
||||
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can semantic search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)**, **[codebase](#-claude-code-integration-transform-your-development-workflow)**\* , or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
|
||||
|
||||
|
||||
\* Claude Code only supports basic `grep`-style keyword search. **LEANN** is a drop-in **semantic search MCP service fully compatible with Claude Code**, unlocking intelligent retrieval without changing your workflow. 🔥 Check out [the easy setup →](packages/leann-mcp/README.md)
|
||||
@@ -176,16 +176,13 @@ response = chat.ask("How much storage does LEANN save?", top_k=1)
|
||||
|
||||
## RAG on Everything!
|
||||
|
||||
LEANN supports RAG on various data sources including documents (`.pdf`, `.txt`, `.md`), Apple Mail, Google Search History, WeChat, ChatGPT conversations, Claude conversations, iMessage conversations, and more.
|
||||
LEANN supports RAG on various data sources including documents (`.pdf`, `.txt`, `.md`), Apple Mail, Google Search History, WeChat, Claude conversations, and more.
|
||||
|
||||
|
||||
|
||||
### Generation Model Setup
|
||||
|
||||
#### LLM Backend
|
||||
|
||||
LEANN supports many LLM providers for text generation (HuggingFace, Ollama, and Any OpenAI compatible API).
|
||||
|
||||
LEANN supports multiple LLM providers for text generation (OpenAI API, HuggingFace, Ollama).
|
||||
|
||||
<details>
|
||||
<summary><strong>🔑 OpenAI API Setup (Default)</strong></summary>
|
||||
@@ -196,68 +193,6 @@ Set your OpenAI API key as an environment variable:
|
||||
export OPENAI_API_KEY="your-api-key-here"
|
||||
```
|
||||
|
||||
Make sure to use `--llm openai` flag when using the CLI.
|
||||
You can also specify the model name with `--llm-model <model-name>` flag.
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><strong>🛠️ Supported LLM & Embedding Providers (via OpenAI Compatibility)</strong></summary>
|
||||
|
||||
Thanks to the widespread adoption of the OpenAI API format, LEANN is compatible out-of-the-box with a vast array of LLM and embedding providers. Simply set the `OPENAI_BASE_URL` and `OPENAI_API_KEY` environment variables to connect to your preferred service.
|
||||
|
||||
```sh
|
||||
export OPENAI_API_KEY="xxx"
|
||||
export OPENAI_BASE_URL="http://localhost:1234/v1" # base url of the provider
|
||||
```
|
||||
|
||||
To use OpenAI compatible endpoint with the CLI interface:
|
||||
|
||||
If you are using it for text generation, make sure to use `--llm openai` flag and specify the model name with `--llm-model <model-name>` flag.
|
||||
|
||||
If you are using it for embedding, set the `--embedding-mode openai` flag and specify the model name with `--embedding-model <MODEL>`.
|
||||
|
||||
-----
|
||||
|
||||
|
||||
Below is a list of base URLs for common providers to get you started.
|
||||
|
||||
|
||||
### 🖥️ Local Inference Engines (Recommended for full privacy)
|
||||
|
||||
| Provider | Sample Base URL |
|
||||
| ---------------- | --------------------------- |
|
||||
| **Ollama** | `http://localhost:11434/v1` |
|
||||
| **LM Studio** | `http://localhost:1234/v1` |
|
||||
| **vLLM** | `http://localhost:8000/v1` |
|
||||
| **llama.cpp** | `http://localhost:8080/v1` |
|
||||
| **SGLang** | `http://localhost:30000/v1` |
|
||||
| **LiteLLM** | `http://localhost:4000` |
|
||||
|
||||
-----
|
||||
|
||||
### ☁️ Cloud Providers
|
||||
|
||||
> **🚨 A Note on Privacy:** Before choosing a cloud provider, carefully review their privacy and data retention policies. Depending on their terms, your data may be used for their own purposes, including but not limited to human reviews and model training, which can lead to serious consequences if not handled properly.
|
||||
|
||||
|
||||
| Provider | Base URL |
|
||||
| ---------------- | ---------------------------------------------------------- |
|
||||
| **OpenAI** | `https://api.openai.com/v1` |
|
||||
| **OpenRouter** | `https://openrouter.ai/api/v1` |
|
||||
| **Gemini** | `https://generativelanguage.googleapis.com/v1beta/openai/` |
|
||||
| **x.AI (Grok)** | `https://api.x.ai/v1` |
|
||||
| **Groq AI** | `https://api.groq.com/openai/v1` |
|
||||
| **DeepSeek** | `https://api.deepseek.com/v1` |
|
||||
| **SiliconFlow** | `https://api.siliconflow.cn/v1` |
|
||||
| **Zhipu (BigModel)** | `https://open.bigmodel.cn/api/paas/v4/` |
|
||||
| **Mistral AI** | `https://api.mistral.ai/v1` |
|
||||
|
||||
|
||||
|
||||
|
||||
If your provider isn't on this list, don't worry! Check their documentation for an OpenAI-compatible endpoint—chances are, it's OpenAI Compatible too!
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
@@ -542,80 +477,6 @@ Once the index is built, you can ask questions like:
|
||||
|
||||
</details>
|
||||
|
||||
### 🤖 ChatGPT Chat History: Your Personal AI Conversation Archive!
|
||||
|
||||
Transform your ChatGPT conversations into a searchable knowledge base! Search through all your ChatGPT discussions about coding, research, brainstorming, and more.
|
||||
|
||||
```bash
|
||||
python -m apps.chatgpt_rag --export-path chatgpt_export.html --query "How do I create a list in Python?"
|
||||
```
|
||||
|
||||
**Unlock your AI conversation history.** Never lose track of valuable insights from your ChatGPT discussions again.
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: How to Export ChatGPT Data</strong></summary>
|
||||
|
||||
**Step-by-step export process:**
|
||||
|
||||
1. **Sign in to ChatGPT**
|
||||
2. **Click your profile icon** in the top right corner
|
||||
3. **Navigate to Settings** → **Data Controls**
|
||||
4. **Click "Export"** under Export Data
|
||||
5. **Confirm the export** request
|
||||
6. **Download the ZIP file** from the email link (expires in 24 hours)
|
||||
7. **Extract or use directly** with LEANN
|
||||
|
||||
**Supported formats:**
|
||||
- `.html` files from ChatGPT exports
|
||||
- `.zip` archives from ChatGPT
|
||||
- Directories with multiple export files
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: ChatGPT-Specific Arguments</strong></summary>
|
||||
|
||||
#### Parameters
|
||||
```bash
|
||||
--export-path PATH # Path to ChatGPT export file (.html/.zip) or directory (default: ./chatgpt_export)
|
||||
--separate-messages # Process each message separately instead of concatenated conversations
|
||||
--chunk-size N # Text chunk size (default: 512)
|
||||
--chunk-overlap N # Overlap between chunks (default: 128)
|
||||
```
|
||||
|
||||
#### Example Commands
|
||||
```bash
|
||||
# Basic usage with HTML export
|
||||
python -m apps.chatgpt_rag --export-path conversations.html
|
||||
|
||||
# Process ZIP archive from ChatGPT
|
||||
python -m apps.chatgpt_rag --export-path chatgpt_export.zip
|
||||
|
||||
# Search with specific query
|
||||
python -m apps.chatgpt_rag --export-path chatgpt_data.html --query "Python programming help"
|
||||
|
||||
# Process individual messages for fine-grained search
|
||||
python -m apps.chatgpt_rag --separate-messages --export-path chatgpt_export.html
|
||||
|
||||
# Process directory containing multiple exports
|
||||
python -m apps.chatgpt_rag --export-path ./chatgpt_exports/ --max-items 1000
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><strong>💡 Click to expand: Example queries you can try</strong></summary>
|
||||
|
||||
Once your ChatGPT conversations are indexed, you can search with queries like:
|
||||
- "What did I ask ChatGPT about Python programming?"
|
||||
- "Show me conversations about machine learning algorithms"
|
||||
- "Find discussions about web development frameworks"
|
||||
- "What coding advice did ChatGPT give me?"
|
||||
- "Search for conversations about debugging techniques"
|
||||
- "Find ChatGPT's recommendations for learning resources"
|
||||
|
||||
</details>
|
||||
|
||||
### 🤖 Claude Chat History: Your Personal AI Conversation Archive!
|
||||
|
||||
Transform your Claude conversations into a searchable knowledge base! Search through all your Claude discussions about coding, research, brainstorming, and more.
|
||||
@@ -690,90 +551,6 @@ Once your Claude conversations are indexed, you can search with queries like:
|
||||
|
||||
</details>
|
||||
|
||||
### 💬 iMessage History: Your Personal Conversation Archive!
|
||||
|
||||
Transform your iMessage conversations into a searchable knowledge base! Search through all your text messages, group chats, and conversations with friends, family, and colleagues.
|
||||
|
||||
```bash
|
||||
python -m apps.imessage_rag --query "What did we discuss about the weekend plans?"
|
||||
```
|
||||
|
||||
**Unlock your message history.** Never lose track of important conversations, shared links, or memorable moments from your iMessage history.
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: How to Access iMessage Data</strong></summary>
|
||||
|
||||
**iMessage data location:**
|
||||
|
||||
iMessage conversations are stored in a SQLite database on your Mac at:
|
||||
```
|
||||
~/Library/Messages/chat.db
|
||||
```
|
||||
|
||||
**Important setup requirements:**
|
||||
|
||||
1. **Grant Full Disk Access** to your terminal or IDE:
|
||||
- Open **System Preferences** → **Security & Privacy** → **Privacy**
|
||||
- Select **Full Disk Access** from the left sidebar
|
||||
- Click the **+** button and add your terminal app (Terminal, iTerm2) or IDE (VS Code, etc.)
|
||||
- Restart your terminal/IDE after granting access
|
||||
|
||||
2. **Alternative: Use a backup database**
|
||||
- If you have Time Machine backups or manual copies of the database
|
||||
- Use `--db-path` to specify a custom location
|
||||
|
||||
**Supported formats:**
|
||||
- Direct access to `~/Library/Messages/chat.db` (default)
|
||||
- Custom database path with `--db-path`
|
||||
- Works with backup copies of the database
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: iMessage-Specific Arguments</strong></summary>
|
||||
|
||||
#### Parameters
|
||||
```bash
|
||||
--db-path PATH # Path to chat.db file (default: ~/Library/Messages/chat.db)
|
||||
--concatenate-conversations # Group messages by conversation (default: True)
|
||||
--no-concatenate-conversations # Process each message individually
|
||||
--chunk-size N # Text chunk size (default: 1000)
|
||||
--chunk-overlap N # Overlap between chunks (default: 200)
|
||||
```
|
||||
|
||||
#### Example Commands
|
||||
```bash
|
||||
# Basic usage (requires Full Disk Access)
|
||||
python -m apps.imessage_rag
|
||||
|
||||
# Search with specific query
|
||||
python -m apps.imessage_rag --query "family dinner plans"
|
||||
|
||||
# Use custom database path
|
||||
python -m apps.imessage_rag --db-path /path/to/backup/chat.db
|
||||
|
||||
# Process individual messages instead of conversations
|
||||
python -m apps.imessage_rag --no-concatenate-conversations
|
||||
|
||||
# Limit processing for testing
|
||||
python -m apps.imessage_rag --max-items 100 --query "weekend"
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><strong>💡 Click to expand: Example queries you can try</strong></summary>
|
||||
|
||||
Once your iMessage conversations are indexed, you can search with queries like:
|
||||
- "What did we discuss about vacation plans?"
|
||||
- "Find messages about restaurant recommendations"
|
||||
- "Show me conversations with John about the project"
|
||||
- "Search for shared links about technology"
|
||||
- "Find group chat discussions about weekend events"
|
||||
- "What did mom say about the family gathering?"
|
||||
|
||||
</details>
|
||||
|
||||
### 🚀 Claude Code Integration: Transform Your Development Workflow!
|
||||
|
||||
<details>
|
||||
@@ -843,9 +620,6 @@ leann search my-docs "machine learning concepts"
|
||||
# Interactive chat with your documents
|
||||
leann ask my-docs --interactive
|
||||
|
||||
# Ask a single question (non-interactive)
|
||||
leann ask my-docs "Where are prompts configured?"
|
||||
|
||||
# List all your indexes
|
||||
leann list
|
||||
|
||||
@@ -1006,8 +780,9 @@ results = searcher.search("banana‑crocodile", use_grep=True, top_k=1)
|
||||
## Reproduce Our Results
|
||||
|
||||
```bash
|
||||
uv run benchmarks/run_evaluation.py # Will auto-download evaluation data and run benchmarks
|
||||
uv run benchmarks/run_evaluation.py benchmarks/data/indices/rpj_wiki/rpj_wiki --num-queries 2000 # After downloading data, you can run the benchmark with our biggest index
|
||||
uv pip install -e ".[dev]" # Install dev dependencies
|
||||
python benchmarks/run_evaluation.py # Will auto-download evaluation data and run benchmarks
|
||||
python benchmarks/run_evaluation.py benchmarks/data/indices/rpj_wiki/rpj_wiki --num-queries 2000 # After downloading data, you can run the benchmark with our biggest index
|
||||
```
|
||||
|
||||
The evaluation script downloads data automatically on first run. The last three results were tested with partial personal data, and you can reproduce them with your own data!
|
||||
|
||||
@@ -10,9 +10,7 @@ from typing import Any
|
||||
|
||||
import dotenv
|
||||
from leann.api import LeannBuilder, LeannChat
|
||||
from leann.interactive_utils import create_rag_session
|
||||
from leann.registry import register_project_directory
|
||||
from leann.settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
@@ -80,24 +78,6 @@ class BaseRAGExample(ABC):
|
||||
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
||||
help="Embedding backend mode (default: sentence-transformers), we provide sentence-transformers, openai, mlx, or ollama",
|
||||
)
|
||||
embedding_group.add_argument(
|
||||
"--embedding-host",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Override Ollama-compatible embedding host",
|
||||
)
|
||||
embedding_group.add_argument(
|
||||
"--embedding-api-base",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Base URL for OpenAI-compatible embedding services",
|
||||
)
|
||||
embedding_group.add_argument(
|
||||
"--embedding-api-key",
|
||||
type=str,
|
||||
default=None,
|
||||
help="API key for embedding service (defaults to OPENAI_API_KEY)",
|
||||
)
|
||||
|
||||
# LLM parameters
|
||||
llm_group = parser.add_argument_group("LLM Parameters")
|
||||
@@ -117,8 +97,8 @@ class BaseRAGExample(ABC):
|
||||
llm_group.add_argument(
|
||||
"--llm-host",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Host for Ollama-compatible APIs (defaults to LEANN_OLLAMA_HOST/OLLAMA_HOST)",
|
||||
default="http://localhost:11434",
|
||||
help="Host for Ollama API (default: http://localhost:11434)",
|
||||
)
|
||||
llm_group.add_argument(
|
||||
"--thinking-budget",
|
||||
@@ -127,18 +107,6 @@ class BaseRAGExample(ABC):
|
||||
default=None,
|
||||
help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.",
|
||||
)
|
||||
llm_group.add_argument(
|
||||
"--llm-api-base",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Base URL for OpenAI-compatible APIs",
|
||||
)
|
||||
llm_group.add_argument(
|
||||
"--llm-api-key",
|
||||
type=str,
|
||||
default=None,
|
||||
help="API key for OpenAI-compatible APIs (defaults to OPENAI_API_KEY)",
|
||||
)
|
||||
|
||||
# AST Chunking parameters
|
||||
ast_group = parser.add_argument_group("AST Chunking Parameters")
|
||||
@@ -237,13 +205,9 @@ class BaseRAGExample(ABC):
|
||||
|
||||
if args.llm == "openai":
|
||||
config["model"] = args.llm_model or "gpt-4o"
|
||||
config["base_url"] = resolve_openai_base_url(args.llm_api_base)
|
||||
resolved_key = resolve_openai_api_key(args.llm_api_key)
|
||||
if resolved_key:
|
||||
config["api_key"] = resolved_key
|
||||
elif args.llm == "ollama":
|
||||
config["model"] = args.llm_model or "llama3.2:1b"
|
||||
config["host"] = resolve_ollama_host(args.llm_host)
|
||||
config["host"] = args.llm_host
|
||||
elif args.llm == "hf":
|
||||
config["model"] = args.llm_model or "Qwen/Qwen2.5-1.5B-Instruct"
|
||||
elif args.llm == "simulated":
|
||||
@@ -259,20 +223,10 @@ class BaseRAGExample(ABC):
|
||||
print(f"\n[Building Index] Creating {self.name} index...")
|
||||
print(f"Total text chunks: {len(texts)}")
|
||||
|
||||
embedding_options: dict[str, Any] = {}
|
||||
if args.embedding_mode == "ollama":
|
||||
embedding_options["host"] = resolve_ollama_host(args.embedding_host)
|
||||
elif args.embedding_mode == "openai":
|
||||
embedding_options["base_url"] = resolve_openai_base_url(args.embedding_api_base)
|
||||
resolved_embedding_key = resolve_openai_api_key(args.embedding_api_key)
|
||||
if resolved_embedding_key:
|
||||
embedding_options["api_key"] = resolved_embedding_key
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name=args.backend_name,
|
||||
embedding_model=args.embedding_model,
|
||||
embedding_mode=args.embedding_mode,
|
||||
embedding_options=embedding_options or None,
|
||||
graph_degree=args.graph_degree,
|
||||
complexity=args.build_complexity,
|
||||
is_compact=not args.no_compact,
|
||||
@@ -308,26 +262,37 @@ class BaseRAGExample(ABC):
|
||||
complexity=args.search_complexity,
|
||||
)
|
||||
|
||||
# Create interactive session
|
||||
session = create_rag_session(
|
||||
app_name=self.name.lower().replace(" ", "_"), data_description=self.name
|
||||
)
|
||||
print(f"\n[Interactive Mode] Chat with your {self.name} data!")
|
||||
print("Type 'quit' or 'exit' to stop.\n")
|
||||
|
||||
def handle_query(query: str):
|
||||
# Prepare LLM kwargs with thinking budget if specified
|
||||
llm_kwargs = {}
|
||||
if hasattr(args, "thinking_budget") and args.thinking_budget:
|
||||
llm_kwargs["thinking_budget"] = args.thinking_budget
|
||||
while True:
|
||||
try:
|
||||
query = input("You: ").strip()
|
||||
if query.lower() in ["quit", "exit", "q"]:
|
||||
print("Goodbye!")
|
||||
break
|
||||
|
||||
response = chat.ask(
|
||||
query,
|
||||
top_k=args.top_k,
|
||||
complexity=args.search_complexity,
|
||||
llm_kwargs=llm_kwargs,
|
||||
)
|
||||
print(f"\nAssistant: {response}\n")
|
||||
if not query:
|
||||
continue
|
||||
|
||||
session.run_interactive_loop(handle_query)
|
||||
# Prepare LLM kwargs with thinking budget if specified
|
||||
llm_kwargs = {}
|
||||
if hasattr(args, "thinking_budget") and args.thinking_budget:
|
||||
llm_kwargs["thinking_budget"] = args.thinking_budget
|
||||
|
||||
response = chat.ask(
|
||||
query,
|
||||
top_k=args.top_k,
|
||||
complexity=args.search_complexity,
|
||||
llm_kwargs=llm_kwargs,
|
||||
)
|
||||
print(f"\nAssistant: {response}\n")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\nGoodbye!")
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
|
||||
async def run_single_query(self, args, index_path: str, query: str):
|
||||
"""Run a single query against the index."""
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
"""iMessage data processing module."""
|
||||
@@ -1,342 +0,0 @@
|
||||
"""
|
||||
iMessage data reader.
|
||||
|
||||
Reads and processes iMessage conversation data from the macOS Messages database.
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from llama_index.core import Document
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
|
||||
|
||||
class IMessageReader(BaseReader):
|
||||
"""
|
||||
iMessage data reader.
|
||||
|
||||
Reads iMessage conversation data from the macOS Messages database (chat.db).
|
||||
Processes conversations into structured documents with metadata.
|
||||
"""
|
||||
|
||||
def __init__(self, concatenate_conversations: bool = True) -> None:
|
||||
"""
|
||||
Initialize.
|
||||
|
||||
Args:
|
||||
concatenate_conversations: Whether to concatenate messages within conversations for better context
|
||||
"""
|
||||
self.concatenate_conversations = concatenate_conversations
|
||||
|
||||
def _get_default_chat_db_path(self) -> Path:
|
||||
"""
|
||||
Get the default path to the iMessage chat database.
|
||||
|
||||
Returns:
|
||||
Path to the chat.db file
|
||||
"""
|
||||
home = Path.home()
|
||||
return home / "Library" / "Messages" / "chat.db"
|
||||
|
||||
def _convert_cocoa_timestamp(self, cocoa_timestamp: int) -> str:
|
||||
"""
|
||||
Convert Cocoa timestamp to readable format.
|
||||
|
||||
Args:
|
||||
cocoa_timestamp: Timestamp in Cocoa format (nanoseconds since 2001-01-01)
|
||||
|
||||
Returns:
|
||||
Formatted timestamp string
|
||||
"""
|
||||
if cocoa_timestamp == 0:
|
||||
return "Unknown"
|
||||
|
||||
try:
|
||||
# Cocoa timestamp is nanoseconds since 2001-01-01 00:00:00 UTC
|
||||
# Convert to seconds and add to Unix epoch
|
||||
cocoa_epoch = datetime(2001, 1, 1)
|
||||
unix_timestamp = cocoa_timestamp / 1_000_000_000 # Convert nanoseconds to seconds
|
||||
message_time = cocoa_epoch.timestamp() + unix_timestamp
|
||||
return datetime.fromtimestamp(message_time).strftime("%Y-%m-%d %H:%M:%S")
|
||||
except (ValueError, OSError):
|
||||
return "Unknown"
|
||||
|
||||
def _get_contact_name(self, handle_id: str) -> str:
|
||||
"""
|
||||
Get a readable contact name from handle ID.
|
||||
|
||||
Args:
|
||||
handle_id: The handle ID (phone number or email)
|
||||
|
||||
Returns:
|
||||
Formatted contact name
|
||||
"""
|
||||
if not handle_id:
|
||||
return "Unknown"
|
||||
|
||||
# Clean up phone numbers and emails for display
|
||||
if "@" in handle_id:
|
||||
return handle_id # Email address
|
||||
elif handle_id.startswith("+"):
|
||||
return handle_id # International phone number
|
||||
else:
|
||||
# Try to format as phone number
|
||||
digits = "".join(filter(str.isdigit, handle_id))
|
||||
if len(digits) == 10:
|
||||
return f"({digits[:3]}) {digits[3:6]}-{digits[6:]}"
|
||||
elif len(digits) == 11 and digits[0] == "1":
|
||||
return f"+1 ({digits[1:4]}) {digits[4:7]}-{digits[7:]}"
|
||||
else:
|
||||
return handle_id
|
||||
|
||||
def _read_messages_from_db(self, db_path: Path) -> list[dict]:
|
||||
"""
|
||||
Read messages from the iMessage database.
|
||||
|
||||
Args:
|
||||
db_path: Path to the chat.db file
|
||||
|
||||
Returns:
|
||||
List of message dictionaries
|
||||
"""
|
||||
if not db_path.exists():
|
||||
print(f"iMessage database not found at: {db_path}")
|
||||
return []
|
||||
|
||||
try:
|
||||
# Connect to the database
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Query to get messages with chat and handle information
|
||||
query = """
|
||||
SELECT
|
||||
m.ROWID as message_id,
|
||||
m.text,
|
||||
m.date,
|
||||
m.is_from_me,
|
||||
m.service,
|
||||
c.chat_identifier,
|
||||
c.display_name as chat_display_name,
|
||||
h.id as handle_id,
|
||||
c.ROWID as chat_id
|
||||
FROM message m
|
||||
LEFT JOIN chat_message_join cmj ON m.ROWID = cmj.message_id
|
||||
LEFT JOIN chat c ON cmj.chat_id = c.ROWID
|
||||
LEFT JOIN handle h ON m.handle_id = h.ROWID
|
||||
WHERE m.text IS NOT NULL AND m.text != ''
|
||||
ORDER BY c.ROWID, m.date
|
||||
"""
|
||||
|
||||
cursor.execute(query)
|
||||
rows = cursor.fetchall()
|
||||
|
||||
messages = []
|
||||
for row in rows:
|
||||
(
|
||||
message_id,
|
||||
text,
|
||||
date,
|
||||
is_from_me,
|
||||
service,
|
||||
chat_identifier,
|
||||
chat_display_name,
|
||||
handle_id,
|
||||
chat_id,
|
||||
) = row
|
||||
|
||||
message = {
|
||||
"message_id": message_id,
|
||||
"text": text,
|
||||
"timestamp": self._convert_cocoa_timestamp(date),
|
||||
"is_from_me": bool(is_from_me),
|
||||
"service": service or "iMessage",
|
||||
"chat_identifier": chat_identifier or "Unknown",
|
||||
"chat_display_name": chat_display_name or "Unknown Chat",
|
||||
"handle_id": handle_id or "Unknown",
|
||||
"contact_name": self._get_contact_name(handle_id or ""),
|
||||
"chat_id": chat_id,
|
||||
}
|
||||
messages.append(message)
|
||||
|
||||
conn.close()
|
||||
print(f"Found {len(messages)} messages in database")
|
||||
return messages
|
||||
|
||||
except sqlite3.Error as e:
|
||||
print(f"Error reading iMessage database: {e}")
|
||||
return []
|
||||
except Exception as e:
|
||||
print(f"Unexpected error reading iMessage database: {e}")
|
||||
return []
|
||||
|
||||
def _group_messages_by_chat(self, messages: list[dict]) -> dict[int, list[dict]]:
|
||||
"""
|
||||
Group messages by chat ID.
|
||||
|
||||
Args:
|
||||
messages: List of message dictionaries
|
||||
|
||||
Returns:
|
||||
Dictionary mapping chat_id to list of messages
|
||||
"""
|
||||
chats = {}
|
||||
for message in messages:
|
||||
chat_id = message["chat_id"]
|
||||
if chat_id not in chats:
|
||||
chats[chat_id] = []
|
||||
chats[chat_id].append(message)
|
||||
|
||||
return chats
|
||||
|
||||
def _create_concatenated_content(self, chat_id: int, messages: list[dict]) -> str:
|
||||
"""
|
||||
Create concatenated content from chat messages.
|
||||
|
||||
Args:
|
||||
chat_id: The chat ID
|
||||
messages: List of messages in the chat
|
||||
|
||||
Returns:
|
||||
Concatenated text content
|
||||
"""
|
||||
if not messages:
|
||||
return ""
|
||||
|
||||
# Get chat info from first message
|
||||
first_msg = messages[0]
|
||||
chat_name = first_msg["chat_display_name"]
|
||||
chat_identifier = first_msg["chat_identifier"]
|
||||
|
||||
# Build message content
|
||||
message_parts = []
|
||||
for message in messages:
|
||||
timestamp = message["timestamp"]
|
||||
is_from_me = message["is_from_me"]
|
||||
text = message["text"]
|
||||
contact_name = message["contact_name"]
|
||||
|
||||
if is_from_me:
|
||||
prefix = "[You]"
|
||||
else:
|
||||
prefix = f"[{contact_name}]"
|
||||
|
||||
if timestamp != "Unknown":
|
||||
prefix += f" ({timestamp})"
|
||||
|
||||
message_parts.append(f"{prefix}: {text}")
|
||||
|
||||
concatenated_text = "\n\n".join(message_parts)
|
||||
|
||||
doc_content = f"""Chat: {chat_name}
|
||||
Identifier: {chat_identifier}
|
||||
Messages ({len(messages)} messages):
|
||||
|
||||
{concatenated_text}
|
||||
"""
|
||||
return doc_content
|
||||
|
||||
def _create_individual_content(self, message: dict) -> str:
|
||||
"""
|
||||
Create content for individual message.
|
||||
|
||||
Args:
|
||||
message: Message dictionary
|
||||
|
||||
Returns:
|
||||
Formatted message content
|
||||
"""
|
||||
timestamp = message["timestamp"]
|
||||
is_from_me = message["is_from_me"]
|
||||
text = message["text"]
|
||||
contact_name = message["contact_name"]
|
||||
chat_name = message["chat_display_name"]
|
||||
|
||||
sender = "You" if is_from_me else contact_name
|
||||
|
||||
return f"""Message from {sender} in chat "{chat_name}"
|
||||
Time: {timestamp}
|
||||
Content: {text}
|
||||
"""
|
||||
|
||||
def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
|
||||
"""
|
||||
Load iMessage data and return as documents.
|
||||
|
||||
Args:
|
||||
input_dir: Optional path to directory containing chat.db file.
|
||||
If not provided, uses default macOS location.
|
||||
**load_kwargs: Additional arguments (unused)
|
||||
|
||||
Returns:
|
||||
List of Document objects containing iMessage data
|
||||
"""
|
||||
docs = []
|
||||
|
||||
# Determine database path
|
||||
if input_dir:
|
||||
db_path = Path(input_dir) / "chat.db"
|
||||
else:
|
||||
db_path = self._get_default_chat_db_path()
|
||||
|
||||
print(f"Reading iMessage database from: {db_path}")
|
||||
|
||||
# Read messages from database
|
||||
messages = self._read_messages_from_db(db_path)
|
||||
if not messages:
|
||||
return docs
|
||||
|
||||
if self.concatenate_conversations:
|
||||
# Group messages by chat and create concatenated documents
|
||||
chats = self._group_messages_by_chat(messages)
|
||||
|
||||
for chat_id, chat_messages in chats.items():
|
||||
if not chat_messages:
|
||||
continue
|
||||
|
||||
content = self._create_concatenated_content(chat_id, chat_messages)
|
||||
|
||||
# Create metadata
|
||||
first_msg = chat_messages[0]
|
||||
last_msg = chat_messages[-1]
|
||||
|
||||
metadata = {
|
||||
"source": "iMessage",
|
||||
"chat_id": chat_id,
|
||||
"chat_name": first_msg["chat_display_name"],
|
||||
"chat_identifier": first_msg["chat_identifier"],
|
||||
"message_count": len(chat_messages),
|
||||
"first_message_date": first_msg["timestamp"],
|
||||
"last_message_date": last_msg["timestamp"],
|
||||
"participants": list(
|
||||
{msg["contact_name"] for msg in chat_messages if not msg["is_from_me"]}
|
||||
),
|
||||
}
|
||||
|
||||
doc = Document(text=content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
|
||||
else:
|
||||
# Create individual documents for each message
|
||||
for message in messages:
|
||||
content = self._create_individual_content(message)
|
||||
|
||||
metadata = {
|
||||
"source": "iMessage",
|
||||
"message_id": message["message_id"],
|
||||
"chat_id": message["chat_id"],
|
||||
"chat_name": message["chat_display_name"],
|
||||
"chat_identifier": message["chat_identifier"],
|
||||
"timestamp": message["timestamp"],
|
||||
"is_from_me": message["is_from_me"],
|
||||
"contact_name": message["contact_name"],
|
||||
"service": message["service"],
|
||||
}
|
||||
|
||||
doc = Document(text=content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
|
||||
print(f"Created {len(docs)} documents from iMessage data")
|
||||
return docs
|
||||
@@ -1,125 +0,0 @@
|
||||
"""
|
||||
iMessage RAG Example.
|
||||
|
||||
This example demonstrates how to build a RAG system on your iMessage conversation history.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
from leann.chunking_utils import create_text_chunks
|
||||
|
||||
from apps.base_rag_example import BaseRAGExample
|
||||
from apps.imessage_data.imessage_reader import IMessageReader
|
||||
|
||||
|
||||
class IMessageRAG(BaseRAGExample):
|
||||
"""RAG example for iMessage conversation history."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="iMessage",
|
||||
description="RAG on your iMessage conversation history",
|
||||
default_index_name="imessage_index",
|
||||
)
|
||||
|
||||
def _add_specific_arguments(self, parser):
|
||||
"""Add iMessage-specific arguments."""
|
||||
imessage_group = parser.add_argument_group("iMessage Parameters")
|
||||
imessage_group.add_argument(
|
||||
"--db-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to iMessage chat.db file (default: ~/Library/Messages/chat.db)",
|
||||
)
|
||||
imessage_group.add_argument(
|
||||
"--concatenate-conversations",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Concatenate messages within conversations for better context (default: True)",
|
||||
)
|
||||
imessage_group.add_argument(
|
||||
"--no-concatenate-conversations",
|
||||
action="store_true",
|
||||
help="Process each message individually instead of concatenating by conversation",
|
||||
)
|
||||
imessage_group.add_argument(
|
||||
"--chunk-size",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Maximum characters per text chunk (default: 1000)",
|
||||
)
|
||||
imessage_group.add_argument(
|
||||
"--chunk-overlap",
|
||||
type=int,
|
||||
default=200,
|
||||
help="Overlap between text chunks (default: 200)",
|
||||
)
|
||||
|
||||
async def load_data(self, args) -> list[str]:
|
||||
"""Load iMessage history and convert to text chunks."""
|
||||
print("Loading iMessage conversation history...")
|
||||
|
||||
# Determine concatenation setting
|
||||
concatenate = args.concatenate_conversations and not args.no_concatenate_conversations
|
||||
|
||||
# Initialize iMessage reader
|
||||
reader = IMessageReader(concatenate_conversations=concatenate)
|
||||
|
||||
# Load documents
|
||||
try:
|
||||
if args.db_path:
|
||||
# Use custom database path
|
||||
db_dir = str(Path(args.db_path).parent)
|
||||
documents = reader.load_data(input_dir=db_dir)
|
||||
else:
|
||||
# Use default macOS location
|
||||
documents = reader.load_data()
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error loading iMessage data: {e}")
|
||||
print("\nTroubleshooting tips:")
|
||||
print("1. Make sure you have granted Full Disk Access to your terminal/IDE")
|
||||
print("2. Check that the iMessage database exists at ~/Library/Messages/chat.db")
|
||||
print("3. Try specifying a custom path with --db-path if you have a backup")
|
||||
return []
|
||||
|
||||
if not documents:
|
||||
print("No iMessage conversations found!")
|
||||
return []
|
||||
|
||||
print(f"Loaded {len(documents)} iMessage documents")
|
||||
|
||||
# Show some statistics
|
||||
total_messages = sum(doc.metadata.get("message_count", 1) for doc in documents)
|
||||
print(f"Total messages: {total_messages}")
|
||||
|
||||
if concatenate:
|
||||
# Show chat statistics
|
||||
chat_names = [doc.metadata.get("chat_name", "Unknown") for doc in documents]
|
||||
unique_chats = len(set(chat_names))
|
||||
print(f"Unique conversations: {unique_chats}")
|
||||
|
||||
# Convert to text chunks
|
||||
all_texts = create_text_chunks(
|
||||
documents,
|
||||
chunk_size=args.chunk_size,
|
||||
chunk_overlap=args.chunk_overlap,
|
||||
)
|
||||
|
||||
# Apply max_items limit if specified
|
||||
if args.max_items > 0:
|
||||
all_texts = all_texts[: args.max_items]
|
||||
print(f"Limited to {len(all_texts)} text chunks (max_items={args.max_items})")
|
||||
|
||||
return all_texts
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main entry point."""
|
||||
app = IMessageRAG()
|
||||
await app.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -1,113 +0,0 @@
|
||||
## Vision-based PDF Multi-Vector Demos (macOS/MPS)
|
||||
|
||||
This folder contains two demos to index PDF pages as images and run multi-vector retrieval with ColPali/ColQwen2, plus optional similarity map visualization and answer generation.
|
||||
|
||||
### What you’ll run
|
||||
- `multi-vector-leann-paper-example.py`: local PDF → pages → embed → build HNSW index → search.
|
||||
- `multi-vector-leann-similarity-map.py`: HF dataset (default) or local pages → embed → index → retrieve → similarity maps → optional Qwen-VL answer.
|
||||
|
||||
## Prerequisites (macOS)
|
||||
|
||||
### 1) Homebrew poppler (for pdf2image)
|
||||
```bash
|
||||
brew install poppler
|
||||
which pdfinfo && pdfinfo -v
|
||||
```
|
||||
|
||||
### 2) Python environment
|
||||
Use uv (recommended) or pip. Python 3.9+.
|
||||
|
||||
Using uv:
|
||||
```bash
|
||||
uv pip install \
|
||||
colpali_engine \
|
||||
pdf2image \
|
||||
pillow \
|
||||
matplotlib qwen_vl_utils \
|
||||
einops \
|
||||
seaborn
|
||||
```
|
||||
|
||||
Notes:
|
||||
- On first run, models download from Hugging Face. Login/config if needed.
|
||||
- The scripts auto-select device: CUDA > MPS > CPU. Verify MPS:
|
||||
```bash
|
||||
python -c "import torch; print('MPS available:', bool(getattr(torch.backends, 'mps', None) and torch.backends.mps.is_available()))"
|
||||
```
|
||||
|
||||
## Run the demos
|
||||
|
||||
### A) Local PDF example
|
||||
Converts a local PDF into page images, embeds them, builds an index, and searches.
|
||||
|
||||
```bash
|
||||
cd apps/multimodal/vision-based-pdf-multi-vector
|
||||
# If you don't have the sample PDF locally, download it (ignored by Git)
|
||||
mkdir -p pdfs
|
||||
curl -L -o pdfs/2004.12832v2.pdf https://arxiv.org/pdf/2004.12832.pdf
|
||||
ls pdfs/2004.12832v2.pdf
|
||||
# Ensure output dir exists
|
||||
mkdir -p pages
|
||||
python multi-vector-leann-paper-example.py
|
||||
```
|
||||
Expected:
|
||||
- Page images in `pages/`.
|
||||
- Console prints like `Using device=mps, dtype=...` and retrieved file paths for queries.
|
||||
|
||||
To use your own PDF: edit `pdf_path` near the top of the script.
|
||||
|
||||
### B) Similarity map + answer demo
|
||||
Uses HF dataset `weaviate/arXiv-AI-papers-multi-vector` by default; can switch to local pages.
|
||||
|
||||
```bash
|
||||
cd apps/multimodal/vision-based-pdf-multi-vector
|
||||
python multi-vector-leann-similarity-map.py
|
||||
```
|
||||
Artifacts (when enabled):
|
||||
- Retrieved pages: `./figures/retrieved_page_rank{K}.png`
|
||||
- Similarity maps: `./figures/similarity_map_rank{K}.png`
|
||||
|
||||
Key knobs in the script (top of file):
|
||||
- `QUERY`: your question
|
||||
- `MODEL`: `"colqwen2"` or `"colpali"`
|
||||
- `USE_HF_DATASET`: set `False` to use local pages
|
||||
- `PDF`, `PAGES_DIR`: for local mode
|
||||
- `INDEX_PATH`, `TOPK`, `FIRST_STAGE_K`, `REBUILD_INDEX`
|
||||
- `SIMILARITY_MAP`, `SIM_TOKEN_IDX`, `SIM_OUTPUT`
|
||||
- `ANSWER`, `MAX_NEW_TOKENS` (Qwen-VL)
|
||||
|
||||
## Troubleshooting
|
||||
- pdf2image errors on macOS: ensure `brew install poppler` and `pdfinfo` works in terminal.
|
||||
- Slow or OOM on MPS: reduce dataset size (e.g., set `MAX_DOCS`) or switch to CPU.
|
||||
- NaNs on MPS: keep fp32 on MPS (default in similarity-map script); avoid fp16 there.
|
||||
- First-run model downloads can be large; ensure network access (HF mirrors if needed).
|
||||
|
||||
## Notes
|
||||
- Index files are under `./indexes/`. Delete or set `REBUILD_INDEX=True` to rebuild.
|
||||
- For local PDFs, page images go to `./pages/`.
|
||||
|
||||
|
||||
### Retrieval and Visualization Example
|
||||
|
||||
Example settings in `multi-vector-leann-similarity-map.py`:
|
||||
- `QUERY = "How does DeepSeek-V2 compare against the LLaMA family of LLMs?"`
|
||||
- `SIMILARITY_MAP = True` (to generate heatmaps)
|
||||
- `TOPK = 1` (save the top retrieved page and its similarity map)
|
||||
|
||||
Run:
|
||||
```bash
|
||||
cd apps/multimodal/vision-based-pdf-multi-vector
|
||||
python multi-vector-leann-similarity-map.py
|
||||
```
|
||||
|
||||
Outputs (by default):
|
||||
- Retrieved page: `./figures/retrieved_page_rank1.png`
|
||||
- Similarity map: `./figures/similarity_map_rank1.png`
|
||||
|
||||
Sample visualization (example result, and the query is "QUERY = "How does Vim model performance and efficiency compared to other models?"
|
||||
"):
|
||||

|
||||
|
||||
Notes:
|
||||
- Set `SIM_TOKEN_IDX` to visualize a specific token index; set `-1` to auto-select the most salient token.
|
||||
- If you change `SIM_OUTPUT` to a file path (e.g., `./figures/my_map.png`), multiple ranks are saved as `my_map_rank{K}.png`.
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 166 KiB |
@@ -1,182 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _ensure_repo_paths_importable(current_file: str) -> None:
|
||||
_repo_root = Path(current_file).resolve().parents[3]
|
||||
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
|
||||
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
|
||||
if str(_leann_core_src) not in sys.path:
|
||||
sys.path.append(str(_leann_core_src))
|
||||
if str(_leann_hnsw_pkg) not in sys.path:
|
||||
sys.path.append(str(_leann_hnsw_pkg))
|
||||
|
||||
|
||||
_ensure_repo_paths_importable(__file__)
|
||||
|
||||
from leann_backend_hnsw.hnsw_backend import HNSWBuilder, HNSWSearcher # noqa: E402
|
||||
|
||||
|
||||
class LeannMultiVector:
|
||||
def __init__(
|
||||
self,
|
||||
index_path: str,
|
||||
dim: int = 128,
|
||||
distance_metric: str = "mips",
|
||||
m: int = 16,
|
||||
ef_construction: int = 500,
|
||||
is_compact: bool = False,
|
||||
is_recompute: bool = False,
|
||||
embedding_model_name: str = "colvision",
|
||||
) -> None:
|
||||
self.index_path = index_path
|
||||
self.dim = dim
|
||||
self.embedding_model_name = embedding_model_name
|
||||
self._pending_items: list[dict] = []
|
||||
self._backend_kwargs = {
|
||||
"distance_metric": distance_metric,
|
||||
"M": m,
|
||||
"efConstruction": ef_construction,
|
||||
"is_compact": is_compact,
|
||||
"is_recompute": is_recompute,
|
||||
}
|
||||
self._labels_meta: list[dict] = []
|
||||
|
||||
def _meta_dict(self) -> dict:
|
||||
return {
|
||||
"version": "1.0",
|
||||
"backend_name": "hnsw",
|
||||
"embedding_model": self.embedding_model_name,
|
||||
"embedding_mode": "custom",
|
||||
"dimensions": self.dim,
|
||||
"backend_kwargs": self._backend_kwargs,
|
||||
"is_compact": self._backend_kwargs.get("is_compact", True),
|
||||
"is_pruned": self._backend_kwargs.get("is_compact", True)
|
||||
and self._backend_kwargs.get("is_recompute", True),
|
||||
}
|
||||
|
||||
def create_collection(self) -> None:
|
||||
path = Path(self.index_path)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def insert(self, data: dict) -> None:
|
||||
self._pending_items.append(
|
||||
{
|
||||
"doc_id": int(data["doc_id"]),
|
||||
"filepath": data.get("filepath", ""),
|
||||
"colbert_vecs": [np.asarray(v, dtype=np.float32) for v in data["colbert_vecs"]],
|
||||
}
|
||||
)
|
||||
|
||||
def _labels_path(self) -> Path:
|
||||
index_path_obj = Path(self.index_path)
|
||||
return index_path_obj.parent / f"{index_path_obj.name}.labels.json"
|
||||
|
||||
def _meta_path(self) -> Path:
|
||||
index_path_obj = Path(self.index_path)
|
||||
return index_path_obj.parent / f"{index_path_obj.name}.meta.json"
|
||||
|
||||
def create_index(self) -> None:
|
||||
if not self._pending_items:
|
||||
return
|
||||
|
||||
embeddings: list[np.ndarray] = []
|
||||
labels_meta: list[dict] = []
|
||||
|
||||
for item in self._pending_items:
|
||||
doc_id = int(item["doc_id"])
|
||||
filepath = item.get("filepath", "")
|
||||
colbert_vecs = item["colbert_vecs"]
|
||||
for seq_id, vec in enumerate(colbert_vecs):
|
||||
vec_np = np.asarray(vec, dtype=np.float32)
|
||||
embeddings.append(vec_np)
|
||||
labels_meta.append(
|
||||
{
|
||||
"id": f"{doc_id}:{seq_id}",
|
||||
"doc_id": doc_id,
|
||||
"seq_id": int(seq_id),
|
||||
"filepath": filepath,
|
||||
}
|
||||
)
|
||||
|
||||
if not embeddings:
|
||||
return
|
||||
|
||||
embeddings_np = np.vstack(embeddings).astype(np.float32)
|
||||
# print shape of embeddings_np
|
||||
print(embeddings_np.shape)
|
||||
|
||||
builder = HNSWBuilder(**{**self._backend_kwargs, "dimensions": self.dim})
|
||||
ids = [str(i) for i in range(embeddings_np.shape[0])]
|
||||
builder.build(embeddings_np, ids, self.index_path)
|
||||
|
||||
import json as _json
|
||||
|
||||
with open(self._meta_path(), "w", encoding="utf-8") as f:
|
||||
_json.dump(self._meta_dict(), f, indent=2)
|
||||
with open(self._labels_path(), "w", encoding="utf-8") as f:
|
||||
_json.dump(labels_meta, f)
|
||||
|
||||
self._labels_meta = labels_meta
|
||||
|
||||
def _load_labels_meta_if_needed(self) -> None:
|
||||
if self._labels_meta:
|
||||
return
|
||||
labels_path = self._labels_path()
|
||||
if labels_path.exists():
|
||||
import json as _json
|
||||
|
||||
with open(labels_path, encoding="utf-8") as f:
|
||||
self._labels_meta = _json.load(f)
|
||||
|
||||
def search(
|
||||
self, data: np.ndarray, topk: int, first_stage_k: int = 50
|
||||
) -> list[tuple[float, int]]:
|
||||
if data.ndim == 1:
|
||||
data = data.reshape(1, -1)
|
||||
if data.dtype != np.float32:
|
||||
data = data.astype(np.float32)
|
||||
|
||||
self._load_labels_meta_if_needed()
|
||||
|
||||
searcher = HNSWSearcher(self.index_path, meta=self._meta_dict())
|
||||
raw = searcher.search(
|
||||
data,
|
||||
first_stage_k,
|
||||
recompute_embeddings=False,
|
||||
complexity=128,
|
||||
beam_width=1,
|
||||
prune_ratio=0.0,
|
||||
batch_size=0,
|
||||
)
|
||||
|
||||
labels = raw.get("labels")
|
||||
distances = raw.get("distances")
|
||||
if labels is None or distances is None:
|
||||
return []
|
||||
|
||||
doc_scores: dict[int, float] = {}
|
||||
B = len(labels)
|
||||
for b in range(B):
|
||||
per_doc_best: dict[int, float] = {}
|
||||
for k, sid in enumerate(labels[b]):
|
||||
try:
|
||||
idx = int(sid)
|
||||
except Exception:
|
||||
continue
|
||||
if 0 <= idx < len(self._labels_meta):
|
||||
doc_id = int(self._labels_meta[idx]["doc_id"]) # type: ignore[index]
|
||||
else:
|
||||
continue
|
||||
score = float(distances[b][k])
|
||||
if (doc_id not in per_doc_best) or (score > per_doc_best[doc_id]):
|
||||
per_doc_best[doc_id] = score
|
||||
for doc_id, best_score in per_doc_best.items():
|
||||
doc_scores[doc_id] = doc_scores.get(doc_id, 0.0) + best_score
|
||||
|
||||
scores = sorted(((v, k) for k, v in doc_scores.items()), key=lambda x: x[0], reverse=True)
|
||||
return scores[:topk] if len(scores) >= topk else scores
|
||||
@@ -1,112 +0,0 @@
|
||||
# pip install pdf2image
|
||||
# pip install pymilvus
|
||||
# pip install colpali_engine
|
||||
# pip install tqdm
|
||||
# pip install pillow
|
||||
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
# Ensure local leann packages are importable before importing them
|
||||
_repo_root = Path(__file__).resolve().parents[3]
|
||||
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
|
||||
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
|
||||
if str(_leann_core_src) not in sys.path:
|
||||
sys.path.append(str(_leann_core_src))
|
||||
if str(_leann_hnsw_pkg) not in sys.path:
|
||||
sys.path.append(str(_leann_hnsw_pkg))
|
||||
|
||||
|
||||
import torch
|
||||
from colpali_engine.models import ColPali
|
||||
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
|
||||
from colpali_engine.utils.torch_utils import ListDataset, get_torch_device
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
# Auto-select device: CUDA > MPS (mac) > CPU
|
||||
_device_str = (
|
||||
"cuda"
|
||||
if torch.cuda.is_available()
|
||||
else (
|
||||
"mps"
|
||||
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
|
||||
else "cpu"
|
||||
)
|
||||
)
|
||||
device = get_torch_device(_device_str)
|
||||
# Prefer fp16 on GPU/MPS, bfloat16 on CPU
|
||||
_dtype = torch.float16 if _device_str in ("cuda", "mps") else torch.bfloat16
|
||||
model_name = "vidore/colpali-v1.2"
|
||||
|
||||
model = ColPali.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=_dtype,
|
||||
device_map=device,
|
||||
).eval()
|
||||
print(f"Using device={_device_str}, dtype={_dtype}")
|
||||
|
||||
queries = [
|
||||
"How to end-to-end retrieval with ColBert",
|
||||
"Where is ColBERT performance Table, including text representation results?",
|
||||
]
|
||||
|
||||
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
|
||||
|
||||
dataloader = DataLoader(
|
||||
dataset=ListDataset[str](queries),
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
collate_fn=lambda x: processor.process_queries(x),
|
||||
)
|
||||
|
||||
qs: list[torch.Tensor] = []
|
||||
for batch_query in dataloader:
|
||||
with torch.no_grad():
|
||||
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
|
||||
embeddings_query = model(**batch_query)
|
||||
qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
|
||||
print(qs[0].shape)
|
||||
# %%
|
||||
page_filenames = sorted(os.listdir("./pages"), key=lambda n: int(re.search(r"\d+", n).group()))
|
||||
images = [Image.open(os.path.join("./pages", name)) for name in page_filenames]
|
||||
|
||||
dataloader = DataLoader(
|
||||
dataset=ListDataset[str](images),
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
collate_fn=lambda x: processor.process_images(x),
|
||||
)
|
||||
|
||||
ds: list[torch.Tensor] = []
|
||||
for batch_doc in tqdm(dataloader):
|
||||
with torch.no_grad():
|
||||
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
|
||||
embeddings_doc = model(**batch_doc)
|
||||
ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
|
||||
|
||||
print(ds[0].shape)
|
||||
|
||||
# %%
|
||||
# Build HNSW index via LeannRetriever primitives and run search
|
||||
index_path = "./indexes/colpali.leann"
|
||||
retriever = LeannRetriever(index_path=index_path, dim=int(ds[0].shape[-1]))
|
||||
retriever.create_collection()
|
||||
filepaths = [os.path.join("./pages", name) for name in page_filenames]
|
||||
for i in range(len(filepaths)):
|
||||
data = {
|
||||
"colbert_vecs": ds[i].float().numpy(),
|
||||
"doc_id": i,
|
||||
"filepath": filepaths[i],
|
||||
}
|
||||
retriever.insert(data)
|
||||
retriever.create_index()
|
||||
for query in qs:
|
||||
query_np = query.float().numpy()
|
||||
result = retriever.search(query_np, topk=1)
|
||||
print(filepaths[result[0][1]])
|
||||
@@ -1,477 +0,0 @@
|
||||
## Jupyter-style notebook script
|
||||
# %%
|
||||
# uv pip install matplotlib qwen_vl_utils
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def _ensure_repo_paths_importable(current_file: str) -> None:
|
||||
"""Make local leann packages importable without installing (mirrors multi-vector-leann.py)."""
|
||||
_repo_root = Path(current_file).resolve().parents[3]
|
||||
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
|
||||
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
|
||||
if str(_leann_core_src) not in sys.path:
|
||||
sys.path.append(str(_leann_core_src))
|
||||
if str(_leann_hnsw_pkg) not in sys.path:
|
||||
sys.path.append(str(_leann_hnsw_pkg))
|
||||
|
||||
|
||||
_ensure_repo_paths_importable(__file__)
|
||||
|
||||
from leann_multi_vector import LeannMultiVector # noqa: E402
|
||||
|
||||
# %%
|
||||
# Config
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
QUERY = "How does DeepSeek-V2 compare against the LLaMA family of LLMs?"
|
||||
MODEL: str = "colqwen2" # "colpali" or "colqwen2"
|
||||
|
||||
# Data source: set to True to use the Hugging Face dataset example (recommended)
|
||||
USE_HF_DATASET: bool = True
|
||||
DATASET_NAME: str = "weaviate/arXiv-AI-papers-multi-vector"
|
||||
DATASET_SPLIT: str = "train"
|
||||
MAX_DOCS: Optional[int] = None # limit number of pages to index; None = all
|
||||
|
||||
# Local pages (used when USE_HF_DATASET == False)
|
||||
PDF: Optional[str] = None # e.g., "./pdfs/2004.12832v2.pdf"
|
||||
PAGES_DIR: str = "./pages"
|
||||
|
||||
# Index + retrieval settings
|
||||
INDEX_PATH: str = "./indexes/colvision.leann"
|
||||
TOPK: int = 1
|
||||
FIRST_STAGE_K: int = 500
|
||||
REBUILD_INDEX: bool = False
|
||||
|
||||
# Artifacts
|
||||
SAVE_TOP_IMAGE: Optional[str] = "./figures/retrieved_page.png"
|
||||
SIMILARITY_MAP: bool = True
|
||||
SIM_TOKEN_IDX: int = 13 # -1 means auto-select the most salient token
|
||||
SIM_OUTPUT: str = "./figures/similarity_map.png"
|
||||
ANSWER: bool = True
|
||||
MAX_NEW_TOKENS: int = 128
|
||||
|
||||
|
||||
# %%
|
||||
# Helpers
|
||||
def _natural_sort_key(name: str) -> int:
|
||||
m = re.search(r"\d+", name)
|
||||
return int(m.group()) if m else 0
|
||||
|
||||
|
||||
def _load_images_from_dir(pages_dir: str) -> tuple[list[str], list[Image.Image]]:
|
||||
filenames = [n for n in os.listdir(pages_dir) if n.lower().endswith((".png", ".jpg", ".jpeg"))]
|
||||
filenames = sorted(filenames, key=_natural_sort_key)
|
||||
filepaths = [os.path.join(pages_dir, n) for n in filenames]
|
||||
images = [Image.open(p) for p in filepaths]
|
||||
return filepaths, images
|
||||
|
||||
|
||||
def _maybe_convert_pdf_to_images(pdf_path: Optional[str], pages_dir: str, dpi: int = 200) -> None:
|
||||
if not pdf_path:
|
||||
return
|
||||
os.makedirs(pages_dir, exist_ok=True)
|
||||
try:
|
||||
from pdf2image import convert_from_path
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
"pdf2image is required to convert PDF to images. Install via pip install pdf2image"
|
||||
) from e
|
||||
images = convert_from_path(pdf_path, dpi=dpi)
|
||||
for i, image in enumerate(images):
|
||||
image.save(os.path.join(pages_dir, f"page_{i + 1}.png"), "PNG")
|
||||
|
||||
|
||||
def _select_device_and_dtype():
|
||||
import torch
|
||||
from colpali_engine.utils.torch_utils import get_torch_device
|
||||
|
||||
device_str = (
|
||||
"cuda"
|
||||
if torch.cuda.is_available()
|
||||
else (
|
||||
"mps"
|
||||
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
|
||||
else "cpu"
|
||||
)
|
||||
)
|
||||
device = get_torch_device(device_str)
|
||||
# Stable dtype selection to avoid NaNs:
|
||||
# - CUDA: prefer bfloat16 if supported, else float16
|
||||
# - MPS: use float32 (fp16 on MPS can produce NaNs in some ops)
|
||||
# - CPU: float32
|
||||
if device_str == "cuda":
|
||||
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
||||
try:
|
||||
torch.backends.cuda.matmul.allow_tf32 = True # Better stability/perf on Ampere+
|
||||
except Exception:
|
||||
pass
|
||||
elif device_str == "mps":
|
||||
dtype = torch.float32
|
||||
else:
|
||||
dtype = torch.float32
|
||||
return device_str, device, dtype
|
||||
|
||||
|
||||
def _load_colvision(model_choice: str):
|
||||
import torch
|
||||
from colpali_engine.models import ColPali, ColQwen2, ColQwen2Processor
|
||||
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
|
||||
from transformers.utils.import_utils import is_flash_attn_2_available
|
||||
|
||||
device_str, device, dtype = _select_device_and_dtype()
|
||||
|
||||
if model_choice == "colqwen2":
|
||||
model_name = "vidore/colqwen2-v1.0"
|
||||
# On CPU/MPS we must avoid flash-attn and stay eager; on CUDA prefer flash-attn if available
|
||||
attn_implementation = (
|
||||
"flash_attention_2"
|
||||
if (device_str == "cuda" and is_flash_attn_2_available())
|
||||
else "eager"
|
||||
)
|
||||
model = ColQwen2.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map=device,
|
||||
attn_implementation=attn_implementation,
|
||||
).eval()
|
||||
processor = ColQwen2Processor.from_pretrained(model_name)
|
||||
else:
|
||||
model_name = "vidore/colpali-v1.2"
|
||||
model = ColPali.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map=device,
|
||||
).eval()
|
||||
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
|
||||
|
||||
return model_name, model, processor, device_str, device, dtype
|
||||
|
||||
|
||||
def _embed_images(model, processor, images: list[Image.Image]) -> list[Any]:
|
||||
import torch
|
||||
from colpali_engine.utils.torch_utils import ListDataset
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
# Ensure deterministic eval and autocast for stability
|
||||
model.eval()
|
||||
|
||||
dataloader = DataLoader(
|
||||
dataset=ListDataset[Image.Image](images),
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
collate_fn=lambda x: processor.process_images(x),
|
||||
)
|
||||
|
||||
doc_vecs: list[Any] = []
|
||||
for batch_doc in tqdm(dataloader, desc="Embedding images"):
|
||||
with torch.no_grad():
|
||||
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
|
||||
# autocast on CUDA for bf16/fp16; on CPU/MPS stay in fp32
|
||||
if model.device.type == "cuda":
|
||||
with torch.autocast(
|
||||
device_type="cuda",
|
||||
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
|
||||
):
|
||||
embeddings_doc = model(**batch_doc)
|
||||
else:
|
||||
embeddings_doc = model(**batch_doc)
|
||||
doc_vecs.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
|
||||
return doc_vecs
|
||||
|
||||
|
||||
def _embed_queries(model, processor, queries: list[str]) -> list[Any]:
|
||||
import torch
|
||||
from colpali_engine.utils.torch_utils import ListDataset
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
model.eval()
|
||||
|
||||
dataloader = DataLoader(
|
||||
dataset=ListDataset[str](queries),
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
collate_fn=lambda x: processor.process_queries(x),
|
||||
)
|
||||
|
||||
q_vecs: list[Any] = []
|
||||
for batch_query in tqdm(dataloader, desc="Embedding queries"):
|
||||
with torch.no_grad():
|
||||
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
|
||||
if model.device.type == "cuda":
|
||||
with torch.autocast(
|
||||
device_type="cuda",
|
||||
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
|
||||
):
|
||||
embeddings_query = model(**batch_query)
|
||||
else:
|
||||
embeddings_query = model(**batch_query)
|
||||
q_vecs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
|
||||
return q_vecs
|
||||
|
||||
|
||||
def _build_index(index_path: str, doc_vecs: list[Any], filepaths: list[str]) -> LeannMultiVector:
|
||||
dim = int(doc_vecs[0].shape[-1])
|
||||
retriever = LeannMultiVector(index_path=index_path, dim=dim)
|
||||
retriever.create_collection()
|
||||
for i, vec in enumerate(doc_vecs):
|
||||
data = {
|
||||
"colbert_vecs": vec.float().numpy(),
|
||||
"doc_id": i,
|
||||
"filepath": filepaths[i],
|
||||
}
|
||||
retriever.insert(data)
|
||||
retriever.create_index()
|
||||
return retriever
|
||||
|
||||
|
||||
def _load_retriever_if_index_exists(index_path: str, dim: int) -> Optional[LeannMultiVector]:
|
||||
index_base = Path(index_path)
|
||||
# Rough heuristic: index dir exists AND meta+labels files exist
|
||||
meta = index_base.parent / f"{index_base.name}.meta.json"
|
||||
labels = index_base.parent / f"{index_base.name}.labels.json"
|
||||
if index_base.exists() and meta.exists() and labels.exists():
|
||||
return LeannMultiVector(index_path=index_path, dim=dim)
|
||||
return None
|
||||
|
||||
|
||||
def _generate_similarity_map(
|
||||
model,
|
||||
processor,
|
||||
image: Image.Image,
|
||||
query: str,
|
||||
token_idx: Optional[int] = None,
|
||||
output_path: Optional[str] = None,
|
||||
) -> tuple[int, float]:
|
||||
import torch
|
||||
from colpali_engine.interpretability import (
|
||||
get_similarity_maps_from_embeddings,
|
||||
plot_similarity_map,
|
||||
)
|
||||
|
||||
batch_images = processor.process_images([image]).to(model.device)
|
||||
batch_queries = processor.process_queries([query]).to(model.device)
|
||||
|
||||
with torch.no_grad():
|
||||
image_embeddings = model.forward(**batch_images)
|
||||
query_embeddings = model.forward(**batch_queries)
|
||||
|
||||
n_patches = processor.get_n_patches(
|
||||
image_size=image.size,
|
||||
spatial_merge_size=getattr(model, "spatial_merge_size", None),
|
||||
)
|
||||
image_mask = processor.get_image_mask(batch_images)
|
||||
|
||||
batched_similarity_maps = get_similarity_maps_from_embeddings(
|
||||
image_embeddings=image_embeddings,
|
||||
query_embeddings=query_embeddings,
|
||||
n_patches=n_patches,
|
||||
image_mask=image_mask,
|
||||
)
|
||||
|
||||
similarity_maps = batched_similarity_maps[0]
|
||||
|
||||
# Determine token index if not provided: choose the token with highest max score
|
||||
if token_idx is None:
|
||||
per_token_max = similarity_maps.view(similarity_maps.shape[0], -1).max(dim=1).values
|
||||
token_idx = int(per_token_max.argmax().item())
|
||||
|
||||
max_sim_score = similarity_maps[token_idx, :, :].max().item()
|
||||
|
||||
if output_path:
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
fig, ax = plot_similarity_map(
|
||||
image=image,
|
||||
similarity_map=similarity_maps[token_idx],
|
||||
figsize=(14, 14),
|
||||
show_colorbar=False,
|
||||
)
|
||||
ax.set_title(f"Token #{token_idx}. MaxSim score: {max_sim_score:.2f}", fontsize=12)
|
||||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||
plt.savefig(output_path, bbox_inches="tight")
|
||||
plt.close(fig)
|
||||
|
||||
return token_idx, float(max_sim_score)
|
||||
|
||||
|
||||
class QwenVL:
|
||||
def __init__(self, device: str):
|
||||
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
|
||||
from transformers.utils.import_utils import is_flash_attn_2_available
|
||||
|
||||
attn_implementation = "flash_attention_2" if is_flash_attn_2_available() else "eager"
|
||||
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||
"Qwen/Qwen2.5-VL-3B-Instruct",
|
||||
torch_dtype="auto",
|
||||
device_map=device,
|
||||
attn_implementation=attn_implementation,
|
||||
)
|
||||
|
||||
min_pixels = 256 * 28 * 28
|
||||
max_pixels = 1280 * 28 * 28
|
||||
self.processor = AutoProcessor.from_pretrained(
|
||||
"Qwen/Qwen2.5-VL-3B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels
|
||||
)
|
||||
|
||||
def answer(self, query: str, images: list[Image.Image], max_new_tokens: int = 128) -> str:
|
||||
import base64
|
||||
from io import BytesIO
|
||||
|
||||
from qwen_vl_utils import process_vision_info
|
||||
|
||||
content = []
|
||||
for img in images:
|
||||
buffer = BytesIO()
|
||||
img.save(buffer, format="jpeg")
|
||||
img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
content.append({"type": "image", "image": f"data:image;base64,{img_base64}"})
|
||||
content.append({"type": "text", "text": query})
|
||||
messages = [{"role": "user", "content": content}]
|
||||
|
||||
text = self.processor.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
image_inputs, video_inputs = process_vision_info(messages)
|
||||
inputs = self.processor(
|
||||
text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt"
|
||||
)
|
||||
inputs = inputs.to(self.model.device)
|
||||
|
||||
generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
|
||||
generated_ids_trimmed = [
|
||||
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
||||
]
|
||||
return self.processor.batch_decode(
|
||||
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
)[0]
|
||||
|
||||
|
||||
# %%
|
||||
|
||||
# Step 1: Prepare data
|
||||
if USE_HF_DATASET:
|
||||
from datasets import load_dataset
|
||||
|
||||
dataset = load_dataset(DATASET_NAME, split=DATASET_SPLIT)
|
||||
N = len(dataset) if MAX_DOCS is None else min(MAX_DOCS, len(dataset))
|
||||
filepaths: list[str] = []
|
||||
images: list[Image.Image] = []
|
||||
for i in tqdm(range(N), desc="Loading dataset", total=N ):
|
||||
p = dataset[i]
|
||||
# Compose a descriptive identifier for printing later
|
||||
identifier = f"arXiv:{p['paper_arxiv_id']}|title:{p['paper_title']}|page:{int(p['page_number'])}|id:{p['page_id']}"
|
||||
print(identifier)
|
||||
filepaths.append(identifier)
|
||||
images.append(p["page_image"]) # PIL Image
|
||||
else:
|
||||
_maybe_convert_pdf_to_images(PDF, PAGES_DIR)
|
||||
filepaths, images = _load_images_from_dir(PAGES_DIR)
|
||||
if not images:
|
||||
raise RuntimeError(
|
||||
f"No images found in {PAGES_DIR}. Provide PDF path in PDF variable or ensure images exist."
|
||||
)
|
||||
|
||||
|
||||
# %%
|
||||
# Step 2: Load model and processor
|
||||
model_name, model, processor, device_str, device, dtype = _load_colvision(MODEL)
|
||||
print(f"Using model={model_name}, device={device_str}, dtype={dtype}")
|
||||
|
||||
|
||||
# %%
|
||||
|
||||
# %%
|
||||
# Step 3: Build or load index
|
||||
retriever: Optional[LeannMultiVector] = None
|
||||
if not REBUILD_INDEX:
|
||||
try:
|
||||
one_vec = _embed_images(model, processor, [images[0]])[0]
|
||||
retriever = _load_retriever_if_index_exists(INDEX_PATH, dim=int(one_vec.shape[-1]))
|
||||
except Exception:
|
||||
retriever = None
|
||||
|
||||
if retriever is None:
|
||||
doc_vecs = _embed_images(model, processor, images)
|
||||
retriever = _build_index(INDEX_PATH, doc_vecs, filepaths)
|
||||
|
||||
|
||||
# %%
|
||||
# Step 4: Embed query and search
|
||||
q_vec = _embed_queries(model, processor, [QUERY])[0]
|
||||
results = retriever.search(q_vec.float().numpy(), topk=TOPK, first_stage_k=FIRST_STAGE_K)
|
||||
if not results:
|
||||
print("No results found.")
|
||||
else:
|
||||
print(f'Top {len(results)} results for query: "{QUERY}"')
|
||||
top_images: list[Image.Image] = []
|
||||
for rank, (score, doc_id) in enumerate(results, start=1):
|
||||
path = filepaths[doc_id]
|
||||
# For HF dataset, path is a descriptive identifier, not a real file path
|
||||
print(f"{rank}) MaxSim: {score:.4f}, Page: {path}")
|
||||
top_images.append(images[doc_id])
|
||||
|
||||
if SAVE_TOP_IMAGE:
|
||||
from pathlib import Path as _Path
|
||||
|
||||
base = _Path(SAVE_TOP_IMAGE)
|
||||
base.parent.mkdir(parents=True, exist_ok=True)
|
||||
for rank, img in enumerate(top_images[:TOPK], start=1):
|
||||
if base.suffix:
|
||||
out_path = base.parent / f"{base.stem}_rank{rank}{base.suffix}"
|
||||
else:
|
||||
out_path = base / f"retrieved_page_rank{rank}.png"
|
||||
img.save(str(out_path))
|
||||
print(f"Saved retrieved page (rank {rank}) to: {out_path}")
|
||||
|
||||
## TODO stange results of second page of DeepSeek-V2 rather than the first page
|
||||
|
||||
# %%
|
||||
# Step 5: Similarity maps for top-K results
|
||||
if results and SIMILARITY_MAP:
|
||||
token_idx = None if SIM_TOKEN_IDX < 0 else int(SIM_TOKEN_IDX)
|
||||
from pathlib import Path as _Path
|
||||
|
||||
output_base = _Path(SIM_OUTPUT) if SIM_OUTPUT else None
|
||||
for rank, img in enumerate(top_images[:TOPK], start=1):
|
||||
if output_base:
|
||||
if output_base.suffix:
|
||||
out_dir = output_base.parent
|
||||
out_name = f"{output_base.stem}_rank{rank}{output_base.suffix}"
|
||||
out_path = str(out_dir / out_name)
|
||||
else:
|
||||
out_dir = output_base
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
out_path = str(out_dir / f"similarity_map_rank{rank}.png")
|
||||
else:
|
||||
out_path = None
|
||||
chosen_idx, max_sim = _generate_similarity_map(
|
||||
model=model,
|
||||
processor=processor,
|
||||
image=img,
|
||||
query=QUERY,
|
||||
token_idx=token_idx,
|
||||
output_path=out_path,
|
||||
)
|
||||
if out_path:
|
||||
print(
|
||||
f"Saved similarity map for rank {rank}, token #{chosen_idx} (max={max_sim:.2f}) to: {out_path}"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"Computed similarity map for rank {rank}, token #{chosen_idx} (max={max_sim:.2f})"
|
||||
)
|
||||
|
||||
|
||||
# %%
|
||||
# Step 6: Optional answer generation
|
||||
if results and ANSWER:
|
||||
qwen = QwenVL(device=device_str)
|
||||
response = qwen.answer(QUERY, top_images[:TOPK], max_new_tokens=MAX_NEW_TOKENS)
|
||||
print("\nAnswer:")
|
||||
print(response)
|
||||
@@ -1,183 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
import re
|
||||
import sys
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
|
||||
from leann import LeannSearcher
|
||||
|
||||
INDEX_PATH = str(Path("./").resolve() / "demo.leann")
|
||||
|
||||
|
||||
class TimeParser:
|
||||
def __init__(self):
|
||||
# Main pattern: captures optional fuzzy modifier, number, unit, and optional "ago"
|
||||
self.pattern = r"(?:(around|about|roughly|approximately)\s+)?(\d+)\s+(hour|day|week|month|year)s?(?:\s+ago)?"
|
||||
|
||||
# Compile for performance
|
||||
self.regex = re.compile(self.pattern, re.IGNORECASE)
|
||||
|
||||
# Stop words to remove before regex parsing
|
||||
self.stop_words = {
|
||||
"in",
|
||||
"at",
|
||||
"of",
|
||||
"by",
|
||||
"as",
|
||||
"me",
|
||||
"the",
|
||||
"a",
|
||||
"an",
|
||||
"and",
|
||||
"any",
|
||||
"find",
|
||||
"search",
|
||||
"list",
|
||||
"ago",
|
||||
"back",
|
||||
"past",
|
||||
"earlier",
|
||||
}
|
||||
|
||||
def clean_text(self, text):
|
||||
"""Remove stop words from text"""
|
||||
words = text.split()
|
||||
cleaned = " ".join(word for word in words if word.lower() not in self.stop_words)
|
||||
return cleaned
|
||||
|
||||
def parse(self, text):
|
||||
"""Extract all time expressions from text"""
|
||||
# Clean text first
|
||||
cleaned_text = self.clean_text(text)
|
||||
|
||||
matches = []
|
||||
for match in self.regex.finditer(cleaned_text):
|
||||
fuzzy = match.group(1) # "around", "about", etc.
|
||||
number = int(match.group(2))
|
||||
unit = match.group(3).lower()
|
||||
|
||||
matches.append(
|
||||
{
|
||||
"full_match": match.group(0),
|
||||
"fuzzy": bool(fuzzy),
|
||||
"number": number,
|
||||
"unit": unit,
|
||||
"range": self.calculate_range(number, unit, bool(fuzzy)),
|
||||
}
|
||||
)
|
||||
|
||||
return matches
|
||||
|
||||
def calculate_range(self, number, unit, is_fuzzy):
|
||||
"""Convert to actual datetime range and return ISO format strings"""
|
||||
units = {
|
||||
"hour": timedelta(hours=number),
|
||||
"day": timedelta(days=number),
|
||||
"week": timedelta(weeks=number),
|
||||
"month": timedelta(days=number * 30),
|
||||
"year": timedelta(days=number * 365),
|
||||
}
|
||||
|
||||
delta = units[unit]
|
||||
now = datetime.now()
|
||||
target = now - delta
|
||||
|
||||
if is_fuzzy:
|
||||
buffer = delta * 0.2 # 20% buffer for fuzzy
|
||||
start = (target - buffer).isoformat()
|
||||
end = (target + buffer).isoformat()
|
||||
else:
|
||||
start = target.isoformat()
|
||||
end = now.isoformat()
|
||||
|
||||
return (start, end)
|
||||
|
||||
|
||||
def search_files(query, top_k=15):
|
||||
"""Search the index and return results"""
|
||||
# Parse time expressions
|
||||
parser = TimeParser()
|
||||
time_matches = parser.parse(query)
|
||||
|
||||
# Remove time expressions from query for semantic search
|
||||
clean_query = query
|
||||
if time_matches:
|
||||
for match in time_matches:
|
||||
clean_query = clean_query.replace(match["full_match"], "").strip()
|
||||
|
||||
# Check if clean_query is less than 4 characters
|
||||
if len(clean_query) < 4:
|
||||
print("Error: add more input for accurate results.")
|
||||
return
|
||||
|
||||
# Single query to vector DB
|
||||
searcher = LeannSearcher(INDEX_PATH)
|
||||
results = searcher.search(
|
||||
clean_query if clean_query else query, top_k=top_k, recompute_embeddings=False
|
||||
)
|
||||
|
||||
# Filter by time if time expression found
|
||||
if time_matches:
|
||||
time_range = time_matches[0]["range"] # Use first time expression
|
||||
start_time, end_time = time_range
|
||||
|
||||
filtered_results = []
|
||||
for result in results:
|
||||
# Access metadata attribute directly (not .get())
|
||||
metadata = result.metadata if hasattr(result, "metadata") else {}
|
||||
|
||||
if metadata:
|
||||
# Check modification date first, fall back to creation date
|
||||
date_str = metadata.get("modification_date") or metadata.get("creation_date")
|
||||
|
||||
if date_str:
|
||||
# Convert strings to datetime objects for proper comparison
|
||||
try:
|
||||
file_date = datetime.fromisoformat(date_str)
|
||||
start_dt = datetime.fromisoformat(start_time)
|
||||
end_dt = datetime.fromisoformat(end_time)
|
||||
|
||||
# Compare dates properly
|
||||
if start_dt <= file_date <= end_dt:
|
||||
filtered_results.append(result)
|
||||
except (ValueError, TypeError):
|
||||
# Handle invalid date formats
|
||||
print(f"Warning: Invalid date format in metadata: {date_str}")
|
||||
continue
|
||||
|
||||
results = filtered_results
|
||||
|
||||
# Print results
|
||||
print(f"\nSearch results for: '{query}'")
|
||||
if time_matches:
|
||||
print(
|
||||
f"Time filter: {time_matches[0]['number']} {time_matches[0]['unit']}(s) {'(fuzzy)' if time_matches[0]['fuzzy'] else ''}"
|
||||
)
|
||||
print(
|
||||
f"Date range: {time_matches[0]['range'][0][:10]} to {time_matches[0]['range'][1][:10]}"
|
||||
)
|
||||
print("-" * 80)
|
||||
|
||||
for i, result in enumerate(results, 1):
|
||||
print(f"\n[{i}] Score: {result.score:.4f}")
|
||||
print(f"Content: {result.text}")
|
||||
|
||||
# Show metadata if present
|
||||
metadata = result.metadata if hasattr(result, "metadata") else None
|
||||
if metadata:
|
||||
if "creation_date" in metadata:
|
||||
print(f"Created: {metadata['creation_date']}")
|
||||
if "modification_date" in metadata:
|
||||
print(f"Modified: {metadata['modification_date']}")
|
||||
print("-" * 80)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) < 2:
|
||||
print('Usage: python search_index.py "<search query>" [top_k]')
|
||||
sys.exit(1)
|
||||
|
||||
query = sys.argv[1]
|
||||
top_k = int(sys.argv[2]) if len(sys.argv) > 2 else 15
|
||||
|
||||
search_files(query, top_k)
|
||||
@@ -1,82 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from leann import LeannBuilder
|
||||
|
||||
|
||||
def process_json_items(json_file_path):
|
||||
"""Load and process JSON file with metadata items"""
|
||||
|
||||
with open(json_file_path, encoding="utf-8") as f:
|
||||
items = json.load(f)
|
||||
|
||||
# Guard against empty JSON
|
||||
if not items:
|
||||
print("⚠️ No items found in the JSON file. Exiting gracefully.")
|
||||
return
|
||||
|
||||
INDEX_PATH = str(Path("./").resolve() / "demo.leann")
|
||||
builder = LeannBuilder(backend_name="hnsw", is_recompute=False)
|
||||
|
||||
total_items = len(items)
|
||||
items_added = 0
|
||||
print(f"Processing {total_items} items...")
|
||||
|
||||
for idx, item in enumerate(items):
|
||||
try:
|
||||
# Create embedding text sentence
|
||||
embedding_text = f"{item.get('Name', 'unknown')} located at {item.get('Path', 'unknown')} and size {item.get('Size', 'unknown')} bytes with content type {item.get('ContentType', 'unknown')} and kind {item.get('Kind', 'unknown')}"
|
||||
|
||||
# Prepare metadata with dates
|
||||
metadata = {}
|
||||
if "CreationDate" in item:
|
||||
metadata["creation_date"] = item["CreationDate"]
|
||||
if "ContentChangeDate" in item:
|
||||
metadata["modification_date"] = item["ContentChangeDate"]
|
||||
|
||||
# Add to builder
|
||||
builder.add_text(embedding_text, metadata=metadata)
|
||||
items_added += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n⚠️ Warning: Failed to process item {idx}: {e}")
|
||||
continue
|
||||
|
||||
# Show progress
|
||||
progress = (idx + 1) / total_items * 100
|
||||
sys.stdout.write(f"\rProgress: {idx + 1}/{total_items} ({progress:.1f}%)")
|
||||
sys.stdout.flush()
|
||||
|
||||
print() # New line after progress
|
||||
|
||||
# Guard against no successfully added items
|
||||
if items_added == 0:
|
||||
print("⚠️ No items were successfully added to the index. Exiting gracefully.")
|
||||
return
|
||||
|
||||
print(f"\n✅ Successfully processed {items_added}/{total_items} items")
|
||||
print("Building index...")
|
||||
|
||||
try:
|
||||
builder.build_index(INDEX_PATH)
|
||||
print(f"✓ Index saved to {INDEX_PATH}")
|
||||
except ValueError as e:
|
||||
if "No chunks added" in str(e):
|
||||
print("⚠️ No chunks were added to the builder. Index not created.")
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) != 2:
|
||||
print("Usage: python build_index.py <json_file>")
|
||||
sys.exit(1)
|
||||
|
||||
json_file = sys.argv[1]
|
||||
if not Path(json_file).exists():
|
||||
print(f"Error: File {json_file} not found")
|
||||
sys.exit(1)
|
||||
|
||||
process_json_items(json_file)
|
||||
@@ -1,265 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Spotlight Metadata Dumper for Vector DB
|
||||
Extracts only essential metadata for semantic search embeddings
|
||||
Output is optimized for vector database storage with minimal fields
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
from datetime import datetime
|
||||
|
||||
# Check platform before importing macOS-specific modules
|
||||
if sys.platform != "darwin":
|
||||
print("This script requires macOS (uses Spotlight)")
|
||||
sys.exit(1)
|
||||
|
||||
from Foundation import NSDate, NSMetadataQuery, NSPredicate, NSRunLoop
|
||||
|
||||
# EDIT THIS LIST: Add or remove folders to search
|
||||
# Can be either:
|
||||
# - Folder names relative to home directory (e.g., "Desktop", "Downloads")
|
||||
# - Absolute paths (e.g., "/Applications", "/System/Library")
|
||||
SEARCH_FOLDERS = [
|
||||
"Desktop",
|
||||
"Downloads",
|
||||
"Documents",
|
||||
"Music",
|
||||
"Pictures",
|
||||
"Movies",
|
||||
# "Library", # Uncomment to include
|
||||
# "/Applications", # Absolute path example
|
||||
# "Code/Projects", # Subfolder example
|
||||
# Add any other folders here
|
||||
]
|
||||
|
||||
|
||||
def convert_to_serializable(obj):
|
||||
"""Convert NS objects to Python serializable types"""
|
||||
if obj is None:
|
||||
return None
|
||||
|
||||
# Handle NSDate
|
||||
if hasattr(obj, "timeIntervalSince1970"):
|
||||
return datetime.fromtimestamp(obj.timeIntervalSince1970()).isoformat()
|
||||
|
||||
# Handle NSArray
|
||||
if hasattr(obj, "count") and hasattr(obj, "objectAtIndex_"):
|
||||
return [convert_to_serializable(obj.objectAtIndex_(i)) for i in range(obj.count())]
|
||||
|
||||
# Convert to string
|
||||
try:
|
||||
return str(obj)
|
||||
except Exception:
|
||||
return repr(obj)
|
||||
|
||||
|
||||
def dump_spotlight_data(max_items=10, output_file="spotlight_dump.json"):
|
||||
"""
|
||||
Dump Spotlight data using public.item predicate
|
||||
"""
|
||||
# Build full paths from SEARCH_FOLDERS
|
||||
import os
|
||||
|
||||
home_dir = os.path.expanduser("~")
|
||||
search_paths = []
|
||||
|
||||
print("Search locations:")
|
||||
for folder in SEARCH_FOLDERS:
|
||||
# Check if it's an absolute path or relative
|
||||
if folder.startswith("/"):
|
||||
full_path = folder
|
||||
else:
|
||||
full_path = os.path.join(home_dir, folder)
|
||||
|
||||
if os.path.exists(full_path):
|
||||
search_paths.append(full_path)
|
||||
print(f" ✓ {full_path}")
|
||||
else:
|
||||
print(f" ✗ {full_path} (not found)")
|
||||
|
||||
if not search_paths:
|
||||
print("No valid search paths found!")
|
||||
return []
|
||||
|
||||
print(f"\nDumping {max_items} items from Spotlight (public.item)...")
|
||||
|
||||
# Create query with public.item predicate
|
||||
query = NSMetadataQuery.alloc().init()
|
||||
predicate = NSPredicate.predicateWithFormat_("kMDItemContentTypeTree CONTAINS 'public.item'")
|
||||
query.setPredicate_(predicate)
|
||||
|
||||
# Set search scopes to our specific folders
|
||||
query.setSearchScopes_(search_paths)
|
||||
|
||||
print("Starting query...")
|
||||
query.startQuery()
|
||||
|
||||
# Wait for gathering to complete
|
||||
run_loop = NSRunLoop.currentRunLoop()
|
||||
print("Gathering results...")
|
||||
|
||||
# Let it gather for a few seconds
|
||||
for i in range(50): # 5 seconds max
|
||||
run_loop.runMode_beforeDate_(
|
||||
"NSDefaultRunLoopMode", NSDate.dateWithTimeIntervalSinceNow_(0.1)
|
||||
)
|
||||
# Check gathering status periodically
|
||||
if i % 10 == 0:
|
||||
current_count = query.resultCount()
|
||||
if current_count > 0:
|
||||
print(f" Found {current_count} items so far...")
|
||||
|
||||
# Continue while still gathering (up to 2 more seconds)
|
||||
timeout = NSDate.dateWithTimeIntervalSinceNow_(2.0)
|
||||
while query.isGathering() and timeout.timeIntervalSinceNow() > 0:
|
||||
run_loop.runMode_beforeDate_(
|
||||
"NSDefaultRunLoopMode", NSDate.dateWithTimeIntervalSinceNow_(0.1)
|
||||
)
|
||||
|
||||
query.stopQuery()
|
||||
|
||||
total_results = query.resultCount()
|
||||
print(f"Found {total_results} total items")
|
||||
|
||||
if total_results == 0:
|
||||
print("No results found")
|
||||
return []
|
||||
|
||||
# Process items
|
||||
items_to_process = min(total_results, max_items)
|
||||
results = []
|
||||
|
||||
# ONLY relevant attributes for vector embeddings
|
||||
# These provide essential context for semantic search without bloat
|
||||
attributes = [
|
||||
"kMDItemPath", # Full path for file retrieval
|
||||
"kMDItemFSName", # Filename for display & embedding
|
||||
"kMDItemFSSize", # Size for filtering/ranking
|
||||
"kMDItemContentType", # File type for categorization
|
||||
"kMDItemKind", # Human-readable type for embedding
|
||||
"kMDItemFSCreationDate", # Temporal context
|
||||
"kMDItemFSContentChangeDate", # Recency for ranking
|
||||
]
|
||||
|
||||
print(f"Processing {items_to_process} items...")
|
||||
|
||||
for i in range(items_to_process):
|
||||
try:
|
||||
item = query.resultAtIndex_(i)
|
||||
metadata = {}
|
||||
|
||||
# Extract ONLY the relevant attributes
|
||||
for attr in attributes:
|
||||
try:
|
||||
value = item.valueForAttribute_(attr)
|
||||
if value is not None:
|
||||
# Keep the attribute name clean (remove kMDItem prefix for cleaner JSON)
|
||||
clean_key = attr.replace("kMDItem", "").replace("FS", "")
|
||||
metadata[clean_key] = convert_to_serializable(value)
|
||||
except (AttributeError, ValueError, TypeError):
|
||||
continue
|
||||
|
||||
# Only add if we have at least a path
|
||||
if metadata.get("Path"):
|
||||
results.append(metadata)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing item {i}: {e}")
|
||||
continue
|
||||
|
||||
# Save to JSON
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
json.dump(results, f, indent=2, ensure_ascii=False)
|
||||
|
||||
print(f"\n✓ Saved {len(results)} items to {output_file}")
|
||||
|
||||
# Show summary
|
||||
print("\nSample items:")
|
||||
import os
|
||||
|
||||
home_dir = os.path.expanduser("~")
|
||||
|
||||
for i, item in enumerate(results[:3]):
|
||||
print(f"\n[Item {i + 1}]")
|
||||
print(f" Path: {item.get('Path', 'N/A')}")
|
||||
print(f" Name: {item.get('Name', 'N/A')}")
|
||||
print(f" Type: {item.get('ContentType', 'N/A')}")
|
||||
print(f" Kind: {item.get('Kind', 'N/A')}")
|
||||
|
||||
# Handle size properly
|
||||
size = item.get("Size")
|
||||
if size:
|
||||
try:
|
||||
size_int = int(size)
|
||||
if size_int > 1024 * 1024:
|
||||
print(f" Size: {size_int / (1024 * 1024):.2f} MB")
|
||||
elif size_int > 1024:
|
||||
print(f" Size: {size_int / 1024:.2f} KB")
|
||||
else:
|
||||
print(f" Size: {size_int} bytes")
|
||||
except (ValueError, TypeError):
|
||||
print(f" Size: {size}")
|
||||
|
||||
# Show dates
|
||||
if "CreationDate" in item:
|
||||
print(f" Created: {item['CreationDate']}")
|
||||
if "ContentChangeDate" in item:
|
||||
print(f" Modified: {item['ContentChangeDate']}")
|
||||
|
||||
# Count by type
|
||||
type_counts = {}
|
||||
for item in results:
|
||||
content_type = item.get("ContentType", "unknown")
|
||||
type_counts[content_type] = type_counts.get(content_type, 0) + 1
|
||||
|
||||
print(f"\nTotal items saved: {len(results)}")
|
||||
|
||||
if type_counts:
|
||||
print("\nTop content types:")
|
||||
for ct, count in sorted(type_counts.items(), key=lambda x: x[1], reverse=True)[:5]:
|
||||
print(f" {ct}: {count} items")
|
||||
|
||||
# Count by folder
|
||||
folder_counts = {}
|
||||
for item in results:
|
||||
path = item.get("Path", "")
|
||||
for folder in SEARCH_FOLDERS:
|
||||
# Build the full folder path
|
||||
if folder.startswith("/"):
|
||||
folder_path = folder
|
||||
else:
|
||||
folder_path = os.path.join(home_dir, folder)
|
||||
|
||||
if path.startswith(folder_path):
|
||||
folder_counts[folder] = folder_counts.get(folder, 0) + 1
|
||||
break
|
||||
|
||||
if folder_counts:
|
||||
print("\nItems by location:")
|
||||
for folder, count in sorted(folder_counts.items(), key=lambda x: x[1], reverse=True):
|
||||
print(f" {folder}: {count} items")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def main():
|
||||
# Parse arguments
|
||||
if len(sys.argv) > 1:
|
||||
try:
|
||||
max_items = int(sys.argv[1])
|
||||
except ValueError:
|
||||
print("Usage: python spot.py [number_of_items]")
|
||||
print("Default: 10 items")
|
||||
sys.exit(1)
|
||||
else:
|
||||
max_items = 10
|
||||
|
||||
output_file = sys.argv[2] if len(sys.argv) > 2 else "spotlight_dump.json"
|
||||
|
||||
# Run dump
|
||||
dump_spotlight_data(max_items=max_items, output_file=output_file)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,23 +0,0 @@
|
||||
BM25 vs DiskANN Baselines
|
||||
|
||||
```bash
|
||||
aws s3 sync s3://powerrag-diskann-rpj-wiki-20250824-224037-194d640c/bm25_rpj_wiki/index_en_only/ benchmarks/data/indices/bm25_index/
|
||||
aws s3 sync s3://powerrag-diskann-rpj-wiki-20250824-224037-194d640c/diskann_rpj_wiki/ benchmarks/data/indices/diskann_rpj_wiki/
|
||||
```
|
||||
|
||||
- Dataset: `benchmarks/data/queries/nq_open.jsonl` (Natural Questions)
|
||||
- Machine-specific; results measured locally with the current repo.
|
||||
|
||||
DiskANN (NQ queries, search-only)
|
||||
- Command: `uv run --script benchmarks/bm25_diskann_baselines/run_diskann.py`
|
||||
- Settings: `recompute_embeddings=False`, embeddings precomputed (excluded from timing), batching off, caching off (`cache_mechanism=2`, `num_nodes_to_cache=0`)
|
||||
- Result: avg 0.011093 s/query, QPS 90.15 (p50 0.010731 s, p95 0.015000 s)
|
||||
|
||||
BM25
|
||||
- Command: `uv run --script benchmarks/bm25_diskann_baselines/run_bm25.py`
|
||||
- Settings: `k=10`, `k1=0.9`, `b=0.4`, queries=100
|
||||
- Result: avg 0.028589 s/query, QPS 34.97 (p50 0.026060 s, p90 0.043695 s, p95 0.053260 s, p99 0.055257 s)
|
||||
|
||||
Notes
|
||||
- DiskANN measures search-only latency on real NQ queries (embeddings computed beforehand and excluded from timing).
|
||||
- Use `benchmarks/bm25_diskann_baselines/run_diskann.py` for DiskANN; `benchmarks/bm25_diskann_baselines/run_bm25.py` for BM25.
|
||||
|
Before Width: | Height: | Size: 1.3 KiB |
@@ -1,183 +0,0 @@
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "pyserini"
|
||||
# ]
|
||||
# ///
|
||||
# sudo pacman -S jdk21-openjdk
|
||||
# export JAVA_HOME=/usr/lib/jvm/java-21-openjdk
|
||||
# sudo archlinux-java status
|
||||
# sudo archlinux-java set java-21-openjdk
|
||||
# set -Ux JAVA_HOME /usr/lib/jvm/java-21-openjdk
|
||||
# fish_add_path --global $JAVA_HOME/bin
|
||||
# set -Ux LD_LIBRARY_PATH $JAVA_HOME/lib/server $LD_LIBRARY_PATH
|
||||
# which javac # Should be /usr/lib/jvm/java-21-openjdk/bin/javac
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from statistics import mean
|
||||
|
||||
|
||||
def load_queries(path: str, limit: int | None) -> list[str]:
|
||||
queries: list[str] = []
|
||||
# Try JSONL with a 'query' or 'text' field; fallback to plain text (one query per line)
|
||||
_, ext = os.path.splitext(path)
|
||||
if ext.lower() in {".jsonl", ".json"}:
|
||||
with open(path, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
obj = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
# Not strict JSONL? treat the whole line as the query
|
||||
queries.append(line)
|
||||
continue
|
||||
q = obj.get("query") or obj.get("text") or obj.get("question")
|
||||
if q:
|
||||
queries.append(str(q))
|
||||
else:
|
||||
with open(path, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
s = line.strip()
|
||||
if s:
|
||||
queries.append(s)
|
||||
|
||||
if limit is not None and limit > 0:
|
||||
queries = queries[:limit]
|
||||
return queries
|
||||
|
||||
|
||||
def percentile(values: list[float], p: float) -> float:
|
||||
if not values:
|
||||
return 0.0
|
||||
s = sorted(values)
|
||||
k = (len(s) - 1) * (p / 100.0)
|
||||
f = int(k)
|
||||
c = min(f + 1, len(s) - 1)
|
||||
if f == c:
|
||||
return s[f]
|
||||
return s[f] + (s[c] - s[f]) * (k - f)
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser(description="Standalone BM25 latency benchmark (Pyserini)")
|
||||
ap.add_argument(
|
||||
"--bm25-index",
|
||||
default="benchmarks/data/indices/bm25_index",
|
||||
help="Path to Pyserini Lucene index directory",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--queries",
|
||||
default="benchmarks/data/queries/nq_open.jsonl",
|
||||
help="Path to queries file (JSONL with 'query'/'text' or plain txt one-per-line)",
|
||||
)
|
||||
ap.add_argument("--k", type=int, default=10, help="Top-k to retrieve (default: 10)")
|
||||
ap.add_argument("--k1", type=float, default=0.9, help="BM25 k1 (default: 0.9)")
|
||||
ap.add_argument("--b", type=float, default=0.4, help="BM25 b (default: 0.4)")
|
||||
ap.add_argument("--limit", type=int, default=100, help="Max queries to run (default: 100)")
|
||||
ap.add_argument(
|
||||
"--warmup", type=int, default=5, help="Warmup queries not counted in latency (default: 5)"
|
||||
)
|
||||
ap.add_argument(
|
||||
"--fetch-docs", action="store_true", help="Also fetch doc contents (slower; default: off)"
|
||||
)
|
||||
ap.add_argument("--report", type=str, default=None, help="Optional JSON report path")
|
||||
args = ap.parse_args()
|
||||
|
||||
try:
|
||||
from pyserini.search.lucene import LuceneSearcher
|
||||
except Exception:
|
||||
print("Pyserini not found. Install with: pip install pyserini", file=sys.stderr)
|
||||
raise
|
||||
|
||||
if not os.path.isdir(args.bm25_index):
|
||||
print(f"Index directory not found: {args.bm25_index}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
queries = load_queries(args.queries, args.limit)
|
||||
if not queries:
|
||||
print("No queries loaded.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
print(f"Loaded {len(queries)} queries from {args.queries}")
|
||||
print(f"Opening BM25 index: {args.bm25_index}")
|
||||
searcher = LuceneSearcher(args.bm25_index)
|
||||
# Some builds of pyserini require explicit set_bm25; others ignore
|
||||
try:
|
||||
searcher.set_bm25(k1=args.k1, b=args.b)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
latencies: list[float] = []
|
||||
total_searches = 0
|
||||
|
||||
# Warmup
|
||||
for i in range(min(args.warmup, len(queries))):
|
||||
_ = searcher.search(queries[i], k=args.k)
|
||||
|
||||
t0 = time.time()
|
||||
for i, q in enumerate(queries):
|
||||
t1 = time.time()
|
||||
hits = searcher.search(q, k=args.k)
|
||||
t2 = time.time()
|
||||
latencies.append(t2 - t1)
|
||||
total_searches += 1
|
||||
|
||||
if args.fetch_docs:
|
||||
# Optional doc fetch to include I/O time
|
||||
for h in hits:
|
||||
try:
|
||||
_ = searcher.doc(h.docid)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if (i + 1) % 50 == 0:
|
||||
print(f"Processed {i + 1}/{len(queries)} queries")
|
||||
|
||||
t1 = time.time()
|
||||
total_time = t1 - t0
|
||||
|
||||
if latencies:
|
||||
avg = mean(latencies)
|
||||
p50 = percentile(latencies, 50)
|
||||
p90 = percentile(latencies, 90)
|
||||
p95 = percentile(latencies, 95)
|
||||
p99 = percentile(latencies, 99)
|
||||
qps = total_searches / total_time if total_time > 0 else 0.0
|
||||
else:
|
||||
avg = p50 = p90 = p95 = p99 = qps = 0.0
|
||||
|
||||
print("BM25 Latency Report")
|
||||
print(f" queries: {total_searches}")
|
||||
print(f" k: {args.k}, k1: {args.k1}, b: {args.b}")
|
||||
print(f" avg per query: {avg:.6f} s")
|
||||
print(f" p50/p90/p95/p99: {p50:.6f}/{p90:.6f}/{p95:.6f}/{p99:.6f} s")
|
||||
print(f" total time: {total_time:.3f} s, qps: {qps:.2f}")
|
||||
|
||||
if args.report:
|
||||
payload = {
|
||||
"queries": total_searches,
|
||||
"k": args.k,
|
||||
"k1": args.k1,
|
||||
"b": args.b,
|
||||
"avg_s": avg,
|
||||
"p50_s": p50,
|
||||
"p90_s": p90,
|
||||
"p95_s": p95,
|
||||
"p99_s": p99,
|
||||
"total_time_s": total_time,
|
||||
"qps": qps,
|
||||
"index_dir": os.path.abspath(args.bm25_index),
|
||||
"fetch_docs": bool(args.fetch_docs),
|
||||
}
|
||||
with open(args.report, "w", encoding="utf-8") as f:
|
||||
json.dump(payload, f, indent=2)
|
||||
print(f"Saved report to {args.report}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,124 +0,0 @@
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "leann-backend-diskann"
|
||||
# ]
|
||||
# ///
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def load_queries(path: Path, limit: int | None) -> list[str]:
|
||||
out: list[str] = []
|
||||
with open(path, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
obj = json.loads(line)
|
||||
out.append(obj["query"])
|
||||
if limit and len(out) >= limit:
|
||||
break
|
||||
return out
|
||||
|
||||
|
||||
def main() -> None:
|
||||
ap = argparse.ArgumentParser(
|
||||
description="DiskANN baseline on real NQ queries (search-only timing)"
|
||||
)
|
||||
ap.add_argument(
|
||||
"--index-dir",
|
||||
default="benchmarks/data/indices/diskann_rpj_wiki",
|
||||
help="Directory containing DiskANN files",
|
||||
)
|
||||
ap.add_argument("--index-prefix", default="ann")
|
||||
ap.add_argument("--queries-file", default="benchmarks/data/queries/nq_open.jsonl")
|
||||
ap.add_argument("--num-queries", type=int, default=200)
|
||||
ap.add_argument("--top-k", type=int, default=10)
|
||||
ap.add_argument("--complexity", type=int, default=62)
|
||||
ap.add_argument("--threads", type=int, default=1)
|
||||
ap.add_argument("--beam-width", type=int, default=1)
|
||||
ap.add_argument("--cache-mechanism", type=int, default=2)
|
||||
ap.add_argument("--num-nodes-to-cache", type=int, default=0)
|
||||
args = ap.parse_args()
|
||||
|
||||
index_dir = Path(args.index_dir).resolve()
|
||||
if not index_dir.is_dir():
|
||||
raise SystemExit(f"Index dir not found: {index_dir}")
|
||||
|
||||
qpath = Path(args.queries_file).resolve()
|
||||
if not qpath.exists():
|
||||
raise SystemExit(f"Queries file not found: {qpath}")
|
||||
|
||||
queries = load_queries(qpath, args.num_queries)
|
||||
print(f"Loaded {len(queries)} queries from {qpath}")
|
||||
|
||||
# Compute embeddings once (exclude from timing)
|
||||
from leann.api import compute_embeddings as _compute
|
||||
|
||||
embs = _compute(
|
||||
queries,
|
||||
model_name="facebook/contriever-msmarco",
|
||||
mode="sentence-transformers",
|
||||
use_server=False,
|
||||
).astype(np.float32)
|
||||
if embs.ndim != 2:
|
||||
raise SystemExit("Embedding compute failed or returned wrong shape")
|
||||
|
||||
# Build searcher
|
||||
from leann_backend_diskann.diskann_backend import DiskannSearcher as _DiskannSearcher
|
||||
|
||||
index_prefix_path = str(index_dir / args.index_prefix)
|
||||
searcher = _DiskannSearcher(
|
||||
index_prefix_path,
|
||||
num_threads=int(args.threads),
|
||||
cache_mechanism=int(args.cache_mechanism),
|
||||
num_nodes_to_cache=int(args.num_nodes_to_cache),
|
||||
)
|
||||
|
||||
# Warmup (not timed)
|
||||
_ = searcher.search(
|
||||
embs[0:1],
|
||||
top_k=args.top_k,
|
||||
complexity=args.complexity,
|
||||
beam_width=args.beam_width,
|
||||
prune_ratio=0.0,
|
||||
recompute_embeddings=False,
|
||||
batch_recompute=False,
|
||||
dedup_node_dis=False,
|
||||
)
|
||||
|
||||
# Timed loop
|
||||
times: list[float] = []
|
||||
for i in range(embs.shape[0]):
|
||||
t0 = time.time()
|
||||
_ = searcher.search(
|
||||
embs[i : i + 1],
|
||||
top_k=args.top_k,
|
||||
complexity=args.complexity,
|
||||
beam_width=args.beam_width,
|
||||
prune_ratio=0.0,
|
||||
recompute_embeddings=False,
|
||||
batch_recompute=False,
|
||||
dedup_node_dis=False,
|
||||
)
|
||||
times.append(time.time() - t0)
|
||||
|
||||
times_sorted = sorted(times)
|
||||
avg = float(sum(times) / len(times))
|
||||
p50 = times_sorted[len(times) // 2]
|
||||
p95 = times_sorted[max(0, int(len(times) * 0.95) - 1)]
|
||||
|
||||
print("\nDiskANN (NQ, search-only) Report")
|
||||
print(f" queries: {len(times)}")
|
||||
print(
|
||||
f" k: {args.top_k}, complexity: {args.complexity}, beam_width: {args.beam_width}, threads: {args.threads}"
|
||||
)
|
||||
print(f" avg per query: {avg:.6f} s")
|
||||
print(f" p50/p95: {p50:.6f}/{p95:.6f} s")
|
||||
print(f" QPS: {1.0 / avg:.2f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,141 +0,0 @@
|
||||
# Enron Emails Benchmark
|
||||
|
||||
A comprehensive RAG benchmark for evaluating LEANN search and generation on the Enron email corpus. It mirrors the structure and CLI of the existing FinanceBench and LAION benches, using stage-based evaluation with Recall@3 and generation timing.
|
||||
|
||||
- Dataset: Enron email CSV (e.g., Kaggle wcukierski/enron-email-dataset) for passages
|
||||
- Queries: corbt/enron_emails_sample_questions (filtered for realistic questions)
|
||||
- Metrics: Recall@3 vs FAISS Flat baseline + Generation evaluation with Qwen3-8B
|
||||
|
||||
## Layout
|
||||
|
||||
benchmarks/enron_emails/
|
||||
- setup_enron_emails.py: Prepare passages, build LEANN index, build FAISS baseline
|
||||
- evaluate_enron_emails.py: Evaluate retrieval recall (Stages 2-5) + generation with Qwen3-8B
|
||||
- data/: Generated passages, queries, embeddings-related files
|
||||
- baseline/: FAISS Flat baseline files
|
||||
- llm_utils.py: LLM utilities for Qwen3-8B generation (in parent directory)
|
||||
|
||||
## Quickstart
|
||||
|
||||
1) Prepare the data and index
|
||||
|
||||
cd benchmarks/enron_emails
|
||||
python setup_enron_emails.py --data-dir data
|
||||
|
||||
Notes:
|
||||
- If `--emails-csv` is omitted, the script attempts to download from Kaggle dataset `wcukierski/enron-email-dataset` using Kaggle API (requires `KAGGLE_USERNAME` and `KAGGLE_KEY`).
|
||||
Alternatively, pass a local path to `--emails-csv`.
|
||||
|
||||
Notes:
|
||||
- The script parses emails, chunks header/body into passages, builds a compact LEANN index, and then builds a FAISS Flat baseline from the same passages and embedding model.
|
||||
- Optionally, it will also create evaluation queries from HuggingFace dataset `corbt/enron_emails_sample_questions`.
|
||||
|
||||
2) Run recall evaluation (Stage 2)
|
||||
|
||||
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 2
|
||||
|
||||
3) Complexity sweep (Stage 3)
|
||||
|
||||
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 3 --target-recall 0.90 --max-queries 200
|
||||
|
||||
Stage 3 uses binary search over complexity to find the minimal value achieving the target Recall@3 (assumes recall is non-decreasing with complexity). The search expands the upper bound as needed and snaps complexity to multiples of 8.
|
||||
|
||||
4) Index comparison (Stage 4)
|
||||
|
||||
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 4 --complexity 88 --max-queries 100 --output results.json
|
||||
|
||||
5) Generation evaluation (Stage 5)
|
||||
|
||||
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 5 --complexity 88 --llm-backend hf --model-name Qwen/Qwen3-8B
|
||||
|
||||
6) Combined index + generation evaluation (Stages 4+5, recommended)
|
||||
|
||||
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 45 --complexity 88 --llm-backend hf
|
||||
|
||||
Notes:
|
||||
- Minimal CLI: you can run from repo root with only `--index`, defaults match financebench/laion patterns:
|
||||
- `--stage` defaults to `all` (runs 2, 3, 4, 5)
|
||||
- `--baseline-dir` defaults to `baseline`
|
||||
- `--queries` defaults to `data/evaluation_queries.jsonl` (or falls back to the index directory)
|
||||
- `--llm-backend` defaults to `hf` (HuggingFace), can use `vllm`
|
||||
- `--model-name` defaults to `Qwen/Qwen3-8B`
|
||||
- Fail-fast behavior: no silent fallbacks. If compact index cannot run with recompute, it errors out.
|
||||
- Stage 5 requires Stage 4 retrieval results. Use `--stage 45` to run both efficiently.
|
||||
|
||||
Optional flags:
|
||||
- --queries data/evaluation_queries.jsonl (custom queries file)
|
||||
- --baseline-dir baseline (where FAISS baseline lives)
|
||||
- --complexity 88 (LEANN complexity parameter, optimal for 90% recall)
|
||||
- --llm-backend hf|vllm (LLM backend for generation)
|
||||
- --model-name Qwen/Qwen3-8B (LLM model for generation)
|
||||
- --max-queries 1000 (limit number of queries for evaluation)
|
||||
|
||||
## Files Produced
|
||||
- data/enron_passages_preview.jsonl: Small preview of passages used (for inspection)
|
||||
- data/enron_index_hnsw.leann.*: LEANN index files
|
||||
- baseline/faiss_flat.index + baseline/metadata.pkl: FAISS baseline with passage IDs
|
||||
- data/evaluation_queries.jsonl: Query file (id + query; includes GT IDs for reference)
|
||||
|
||||
## Notes
|
||||
- Evaluates both retrieval Recall@3 and generation timing with Qwen3-8B thinking model.
|
||||
- The emails CSV must contain a column named "message" (raw RFC822 email) and a column named "file" for source identifier. Message-ID headers are parsed as canonical message IDs when present.
|
||||
- Qwen3-8B requires special handling for thinking models with chat templates and <think></think> tag processing.
|
||||
|
||||
## Stages Summary
|
||||
|
||||
- Stage 2 (Recall@3):
|
||||
- Compares LEANN vs FAISS Flat baseline on Recall@3.
|
||||
- Compact index runs with `recompute_embeddings=True`.
|
||||
|
||||
- Stage 3 (Binary Search for Complexity):
|
||||
- Builds a non-compact index (`<index>_noncompact.leann`) and runs binary search with `recompute_embeddings=False` to find the minimal complexity achieving target Recall@3 (default 90%).
|
||||
|
||||
- Stage 4 (Index Comparison):
|
||||
- Reports .index-only sizes for compact vs non-compact.
|
||||
- Measures timings on queries by default: non-compact (no recompute) vs compact (with recompute).
|
||||
- Stores retrieval results for Stage 5 generation evaluation.
|
||||
- Fails fast if compact recompute cannot run.
|
||||
- If `--complexity` is not provided, the script tries to use the best complexity from Stage 3:
|
||||
- First from the current run (when running `--stage all`), otherwise
|
||||
- From `enron_stage3_results.json` saved next to the index during the last Stage 3 run.
|
||||
- If neither exists, Stage 4 will error and ask you to run Stage 3 or pass `--complexity`.
|
||||
|
||||
- Stage 5 (Generation Evaluation):
|
||||
- Uses Qwen3-8B thinking model for RAG generation on retrieved documents from Stage 4.
|
||||
- Supports HuggingFace (`hf`) and vLLM (`vllm`) backends.
|
||||
- Measures generation timing separately from search timing.
|
||||
- Requires Stage 4 results (no additional searching performed).
|
||||
|
||||
## Example Results
|
||||
|
||||
These are sample results obtained on Enron data using all-mpnet-base-v2 and Qwen3-8B.
|
||||
|
||||
- Stage 3 (Binary Search):
|
||||
- Minimal complexity achieving 90% Recall@3: 88
|
||||
- Sampled points:
|
||||
- C=8 → 59.9% Recall@3
|
||||
- C=72 → 89.4% Recall@3
|
||||
- C=88 → 90.2% Recall@3
|
||||
- C=96 → 90.7% Recall@3
|
||||
- C=112 → 91.1% Recall@3
|
||||
- C=136 → 91.3% Recall@3
|
||||
- C=256 → 92.0% Recall@3
|
||||
|
||||
- Stage 4 (Index Sizes, .index only):
|
||||
- Compact: ~2.2 MB
|
||||
- Non-compact: ~82.0 MB
|
||||
- Storage saving by compact: ~97.3%
|
||||
|
||||
- Stage 4 (Search Timing, 988 queries, complexity=88):
|
||||
- Non-compact (no recompute): ~0.0075 s avg per query
|
||||
- Compact (with recompute): ~1.981 s avg per query
|
||||
- Speed ratio (non-compact/compact): ~0.0038x
|
||||
|
||||
- Stage 5 (RAG Generation, 988 queries, Qwen3-8B):
|
||||
- Average generation time: ~22.302 s per query
|
||||
- Total queries processed: 988
|
||||
- LLM backend: HuggingFace transformers
|
||||
- Model: Qwen/Qwen3-8B (thinking model with <think></think> processing)
|
||||
|
||||
Full JSON output is saved by the script (see `--output`), e.g.:
|
||||
`benchmarks/enron_emails/results_enron_stage45.json`.
|
||||
1
benchmarks/enron_emails/data/.gitignore
vendored
1
benchmarks/enron_emails/data/.gitignore
vendored
@@ -1 +0,0 @@
|
||||
downloads/
|
||||
@@ -1,614 +0,0 @@
|
||||
"""
|
||||
Enron Emails Benchmark Evaluation - Retrieval Recall@3 (Stages 2/3/4)
|
||||
Follows the style of FinanceBench/LAION: Stage 2 recall vs FAISS baseline,
|
||||
Stage 3 complexity sweep to target recall, Stage 4 index comparison.
|
||||
On errors, fail fast without fallbacks.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from leann import LeannBuilder, LeannSearcher
|
||||
from leann_backend_hnsw import faiss
|
||||
|
||||
from ..llm_utils import generate_hf, generate_vllm, load_hf_model, load_vllm_model
|
||||
|
||||
# Setup logging to reduce verbose output
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
logging.getLogger("leann.api").setLevel(logging.WARNING)
|
||||
logging.getLogger("leann_backend_hnsw").setLevel(logging.WARNING)
|
||||
|
||||
|
||||
class RecallEvaluator:
|
||||
"""Stage 2: Evaluate Recall@3 (LEANN vs FAISS)"""
|
||||
|
||||
def __init__(self, index_path: str, baseline_dir: str):
|
||||
self.index_path = index_path
|
||||
self.baseline_dir = baseline_dir
|
||||
self.searcher = LeannSearcher(index_path)
|
||||
|
||||
baseline_index_path = os.path.join(baseline_dir, "faiss_flat.index")
|
||||
metadata_path = os.path.join(baseline_dir, "metadata.pkl")
|
||||
|
||||
self.faiss_index = faiss.read_index(baseline_index_path)
|
||||
with open(metadata_path, "rb") as f:
|
||||
self.passage_ids = pickle.load(f)
|
||||
|
||||
print(f"📚 Loaded FAISS flat baseline with {self.faiss_index.ntotal} vectors")
|
||||
|
||||
# No fallbacks here; if embedding server is needed but fails, the caller will see the error.
|
||||
|
||||
def evaluate_recall_at_3(
|
||||
self, queries: list[str], complexity: int = 64, recompute_embeddings: bool = True
|
||||
) -> float:
|
||||
"""Evaluate recall@3 using FAISS Flat as ground truth"""
|
||||
from leann.api import compute_embeddings
|
||||
|
||||
recompute_str = "with recompute" if recompute_embeddings else "no recompute"
|
||||
print(f"🔍 Evaluating recall@3 with complexity={complexity} ({recompute_str})...")
|
||||
|
||||
total_recall = 0.0
|
||||
for i, query in enumerate(queries):
|
||||
# Compute query embedding with the same model/mode as the index
|
||||
q_emb = compute_embeddings(
|
||||
[query],
|
||||
self.searcher.embedding_model,
|
||||
mode=self.searcher.embedding_mode,
|
||||
use_server=False,
|
||||
).astype(np.float32)
|
||||
|
||||
# Search FAISS Flat ground truth
|
||||
n = q_emb.shape[0]
|
||||
k = 3
|
||||
distances = np.zeros((n, k), dtype=np.float32)
|
||||
labels = np.zeros((n, k), dtype=np.int64)
|
||||
self.faiss_index.search(
|
||||
n,
|
||||
faiss.swig_ptr(q_emb),
|
||||
k,
|
||||
faiss.swig_ptr(distances),
|
||||
faiss.swig_ptr(labels),
|
||||
)
|
||||
|
||||
baseline_ids = {self.passage_ids[idx] for idx in labels[0]}
|
||||
|
||||
# Search with LEANN (may require embedding server depending on index configuration)
|
||||
results = self.searcher.search(
|
||||
query,
|
||||
top_k=3,
|
||||
complexity=complexity,
|
||||
recompute_embeddings=recompute_embeddings,
|
||||
)
|
||||
test_ids = {r.id for r in results}
|
||||
|
||||
intersection = test_ids.intersection(baseline_ids)
|
||||
recall = len(intersection) / 3.0
|
||||
total_recall += recall
|
||||
|
||||
if i < 3:
|
||||
print(f" Q{i + 1}: '{query[:60]}...' -> Recall@3: {recall:.3f}")
|
||||
print(f" FAISS: {list(baseline_ids)}")
|
||||
print(f" LEANN: {list(test_ids)}")
|
||||
print(f" ∩: {list(intersection)}")
|
||||
|
||||
avg = total_recall / max(1, len(queries))
|
||||
print(f"📊 Average Recall@3: {avg:.3f} ({avg * 100:.1f}%)")
|
||||
return avg
|
||||
|
||||
def cleanup(self):
|
||||
if hasattr(self, "searcher"):
|
||||
self.searcher.cleanup()
|
||||
|
||||
|
||||
class EnronEvaluator:
|
||||
def __init__(self, index_path: str):
|
||||
self.index_path = index_path
|
||||
self.searcher = LeannSearcher(index_path)
|
||||
|
||||
def load_queries(self, queries_file: str) -> list[str]:
|
||||
queries: list[str] = []
|
||||
with open(queries_file, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if not line.strip():
|
||||
continue
|
||||
data = json.loads(line)
|
||||
if "query" in data:
|
||||
queries.append(data["query"])
|
||||
print(f"📊 Loaded {len(queries)} queries from {queries_file}")
|
||||
return queries
|
||||
|
||||
def cleanup(self):
|
||||
if self.searcher:
|
||||
self.searcher.cleanup()
|
||||
|
||||
def analyze_index_sizes(self) -> dict:
|
||||
"""Analyze index sizes (.index only), similar to LAION bench."""
|
||||
|
||||
print("📏 Analyzing index sizes (.index only)...")
|
||||
index_path = Path(self.index_path)
|
||||
index_dir = index_path.parent
|
||||
index_name = index_path.stem
|
||||
|
||||
sizes: dict[str, float] = {}
|
||||
index_file = index_dir / f"{index_name}.index"
|
||||
meta_file = index_dir / f"{index_path.name}.meta.json"
|
||||
passages_file = index_dir / f"{index_path.name}.passages.jsonl"
|
||||
passages_idx_file = index_dir / f"{index_path.name}.passages.idx"
|
||||
|
||||
sizes["index_only_mb"] = (
|
||||
index_file.stat().st_size / (1024 * 1024) if index_file.exists() else 0.0
|
||||
)
|
||||
sizes["metadata_mb"] = (
|
||||
meta_file.stat().st_size / (1024 * 1024) if meta_file.exists() else 0.0
|
||||
)
|
||||
sizes["passages_text_mb"] = (
|
||||
passages_file.stat().st_size / (1024 * 1024) if passages_file.exists() else 0.0
|
||||
)
|
||||
sizes["passages_index_mb"] = (
|
||||
passages_idx_file.stat().st_size / (1024 * 1024) if passages_idx_file.exists() else 0.0
|
||||
)
|
||||
|
||||
print(f" 📁 .index size: {sizes['index_only_mb']:.1f} MB")
|
||||
return sizes
|
||||
|
||||
def create_non_compact_index_for_comparison(self, non_compact_index_path: str) -> dict:
|
||||
"""Create a non-compact index for comparison using current passages and embeddings."""
|
||||
|
||||
current_index_path = Path(self.index_path)
|
||||
current_index_dir = current_index_path.parent
|
||||
current_index_name = current_index_path.name
|
||||
|
||||
# Read metadata to get passage source and embedding model
|
||||
meta_path = current_index_dir / f"{current_index_name}.meta.json"
|
||||
with open(meta_path, encoding="utf-8") as f:
|
||||
meta = json.load(f)
|
||||
|
||||
passage_source = meta["passage_sources"][0]
|
||||
passage_file = passage_source["path"]
|
||||
|
||||
# Convert relative path to absolute
|
||||
if not Path(passage_file).is_absolute():
|
||||
passage_file = current_index_dir / Path(passage_file).name
|
||||
|
||||
# Load all passages and ids
|
||||
ids: list[str] = []
|
||||
texts: list[str] = []
|
||||
with open(passage_file, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
data = json.loads(line)
|
||||
ids.append(str(data["id"]))
|
||||
texts.append(data["text"])
|
||||
|
||||
# Compute embeddings using the same method as LEANN
|
||||
from leann.api import compute_embeddings
|
||||
|
||||
embeddings = compute_embeddings(
|
||||
texts,
|
||||
meta["embedding_model"],
|
||||
mode=meta.get("embedding_mode", "sentence-transformers"),
|
||||
use_server=False,
|
||||
).astype(np.float32)
|
||||
|
||||
# Build non-compact index with same passages and embeddings
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model=meta["embedding_model"],
|
||||
embedding_mode=meta.get("embedding_mode", "sentence-transformers"),
|
||||
is_recompute=False,
|
||||
is_compact=False,
|
||||
**{
|
||||
k: v
|
||||
for k, v in meta.get("backend_kwargs", {}).items()
|
||||
if k not in ["is_recompute", "is_compact"]
|
||||
},
|
||||
)
|
||||
|
||||
# Persist a pickle for build_index_from_embeddings
|
||||
pkl_path = current_index_dir / f"{Path(non_compact_index_path).stem}_embeddings.pkl"
|
||||
with open(pkl_path, "wb") as pf:
|
||||
pickle.dump((ids, embeddings), pf)
|
||||
|
||||
print(
|
||||
f"🔨 Building non-compact index at {non_compact_index_path} from precomputed embeddings..."
|
||||
)
|
||||
builder.build_index_from_embeddings(non_compact_index_path, str(pkl_path))
|
||||
|
||||
# Analyze the non-compact index size
|
||||
temp_evaluator = EnronEvaluator(non_compact_index_path)
|
||||
non_compact_sizes = temp_evaluator.analyze_index_sizes()
|
||||
non_compact_sizes["index_type"] = "non_compact"
|
||||
|
||||
return non_compact_sizes
|
||||
|
||||
def compare_index_performance(
|
||||
self, non_compact_path: str, compact_path: str, test_queries: list[str], complexity: int
|
||||
) -> dict:
|
||||
"""Compare search speed for non-compact vs compact indexes."""
|
||||
import time
|
||||
|
||||
results: dict = {
|
||||
"non_compact": {"search_times": []},
|
||||
"compact": {"search_times": []},
|
||||
"avg_search_times": {},
|
||||
"speed_ratio": 0.0,
|
||||
"retrieval_results": [], # Store retrieval results for Stage 5
|
||||
}
|
||||
|
||||
print("⚡ Comparing search performance between indexes...")
|
||||
# Non-compact (no recompute)
|
||||
print(" 🔍 Testing non-compact index (no recompute)...")
|
||||
non_compact_searcher = LeannSearcher(non_compact_path)
|
||||
for q in test_queries:
|
||||
t0 = time.time()
|
||||
_ = non_compact_searcher.search(
|
||||
q, top_k=3, complexity=complexity, recompute_embeddings=False
|
||||
)
|
||||
results["non_compact"]["search_times"].append(time.time() - t0)
|
||||
|
||||
# Compact (with recompute). Fail fast if it cannot run.
|
||||
print(" 🔍 Testing compact index (with recompute)...")
|
||||
compact_searcher = LeannSearcher(compact_path)
|
||||
for q in test_queries:
|
||||
t0 = time.time()
|
||||
docs = compact_searcher.search(
|
||||
q, top_k=3, complexity=complexity, recompute_embeddings=True
|
||||
)
|
||||
results["compact"]["search_times"].append(time.time() - t0)
|
||||
|
||||
# Store retrieval results for Stage 5
|
||||
results["retrieval_results"].append(
|
||||
{"query": q, "retrieved_docs": [{"id": doc.id, "text": doc.text} for doc in docs]}
|
||||
)
|
||||
compact_searcher.cleanup()
|
||||
|
||||
if results["non_compact"]["search_times"]:
|
||||
results["avg_search_times"]["non_compact"] = sum(
|
||||
results["non_compact"]["search_times"]
|
||||
) / len(results["non_compact"]["search_times"])
|
||||
if results["compact"]["search_times"]:
|
||||
results["avg_search_times"]["compact"] = sum(results["compact"]["search_times"]) / len(
|
||||
results["compact"]["search_times"]
|
||||
)
|
||||
if results["avg_search_times"].get("compact", 0) > 0:
|
||||
results["speed_ratio"] = (
|
||||
results["avg_search_times"]["non_compact"] / results["avg_search_times"]["compact"]
|
||||
)
|
||||
else:
|
||||
results["speed_ratio"] = 0.0
|
||||
|
||||
non_compact_searcher.cleanup()
|
||||
return results
|
||||
|
||||
def evaluate_complexity(
|
||||
self,
|
||||
recall_eval: "RecallEvaluator",
|
||||
queries: list[str],
|
||||
target: float = 0.90,
|
||||
c_min: int = 8,
|
||||
c_max: int = 256,
|
||||
max_iters: int = 10,
|
||||
recompute: bool = False,
|
||||
) -> dict:
|
||||
"""Binary search minimal complexity achieving target recall (monotonic assumption)."""
|
||||
|
||||
def round_c(x: int) -> int:
|
||||
# snap to multiple of 8 like other benches typically do
|
||||
return max(1, int((x + 7) // 8) * 8)
|
||||
|
||||
metrics: list[dict] = []
|
||||
|
||||
lo = round_c(c_min)
|
||||
hi = round_c(c_max)
|
||||
|
||||
print(
|
||||
f"🧪 Binary search complexity in [{lo}, {hi}] for target Recall@3>={int(target * 100)}%..."
|
||||
)
|
||||
|
||||
# Ensure upper bound can reach target; expand if needed (up to a cap)
|
||||
r_lo = recall_eval.evaluate_recall_at_3(
|
||||
queries, complexity=lo, recompute_embeddings=recompute
|
||||
)
|
||||
metrics.append({"complexity": lo, "recall_at_3": r_lo})
|
||||
r_hi = recall_eval.evaluate_recall_at_3(
|
||||
queries, complexity=hi, recompute_embeddings=recompute
|
||||
)
|
||||
metrics.append({"complexity": hi, "recall_at_3": r_hi})
|
||||
|
||||
cap = 1024
|
||||
while r_hi < target and hi < cap:
|
||||
lo = hi
|
||||
r_lo = r_hi
|
||||
hi = round_c(hi * 2)
|
||||
r_hi = recall_eval.evaluate_recall_at_3(
|
||||
queries, complexity=hi, recompute_embeddings=recompute
|
||||
)
|
||||
metrics.append({"complexity": hi, "recall_at_3": r_hi})
|
||||
|
||||
if r_hi < target:
|
||||
print(f"⚠️ Max complexity {hi} did not reach target recall {target:.2f}.")
|
||||
print("📈 Observations:")
|
||||
for m in metrics:
|
||||
print(f" C={m['complexity']:>4} -> Recall@3={m['recall_at_3'] * 100:.1f}%")
|
||||
return {"metrics": metrics, "best_complexity": None, "target_recall": target}
|
||||
|
||||
# Binary search within [lo, hi]
|
||||
best = hi
|
||||
iters = 0
|
||||
while lo < hi and iters < max_iters:
|
||||
mid = round_c((lo + hi) // 2)
|
||||
r_mid = recall_eval.evaluate_recall_at_3(
|
||||
queries, complexity=mid, recompute_embeddings=recompute
|
||||
)
|
||||
metrics.append({"complexity": mid, "recall_at_3": r_mid})
|
||||
if r_mid >= target:
|
||||
best = mid
|
||||
hi = mid
|
||||
else:
|
||||
lo = mid + 8 # move past mid, respecting multiple-of-8 step
|
||||
iters += 1
|
||||
|
||||
print("📈 Binary search results (sampled points):")
|
||||
# Print unique complexity entries ordered by complexity
|
||||
for m in sorted(
|
||||
{m["complexity"]: m for m in metrics}.values(), key=lambda x: x["complexity"]
|
||||
):
|
||||
print(f" C={m['complexity']:>4} -> Recall@3={m['recall_at_3'] * 100:.1f}%")
|
||||
print(f"✅ Minimal complexity achieving {int(target * 100)}% recall: {best}")
|
||||
return {"metrics": metrics, "best_complexity": best, "target_recall": target}
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Enron Emails Benchmark Evaluation")
|
||||
parser.add_argument("--index", required=True, help="Path to LEANN index")
|
||||
parser.add_argument(
|
||||
"--queries", default="data/evaluation_queries.jsonl", help="Path to evaluation queries"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--stage",
|
||||
choices=["2", "3", "4", "5", "all", "45"],
|
||||
default="all",
|
||||
help="Which stage to run (2=recall, 3=complexity, 4=index comparison, 5=generation)",
|
||||
)
|
||||
parser.add_argument("--complexity", type=int, default=None, help="LEANN search complexity")
|
||||
parser.add_argument("--baseline-dir", default="baseline", help="Baseline output directory")
|
||||
parser.add_argument(
|
||||
"--max-queries", type=int, help="Limit number of queries to evaluate", default=1000
|
||||
)
|
||||
parser.add_argument(
|
||||
"--target-recall", type=float, default=0.90, help="Target Recall@3 for Stage 3"
|
||||
)
|
||||
parser.add_argument("--output", help="Save results to JSON file")
|
||||
parser.add_argument("--llm-backend", choices=["hf", "vllm"], default="hf", help="LLM backend")
|
||||
parser.add_argument("--model-name", default="Qwen/Qwen3-8B", help="Model name")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Resolve queries file: if default path not found, fall back to index's directory
|
||||
if not os.path.exists(args.queries):
|
||||
from pathlib import Path
|
||||
|
||||
idx_dir = Path(args.index).parent
|
||||
fallback_q = idx_dir / "evaluation_queries.jsonl"
|
||||
if fallback_q.exists():
|
||||
args.queries = str(fallback_q)
|
||||
|
||||
baseline_index_path = os.path.join(args.baseline_dir, "faiss_flat.index")
|
||||
if not os.path.exists(baseline_index_path):
|
||||
print(f"❌ FAISS baseline not found at {baseline_index_path}")
|
||||
print("💡 Please run setup_enron_emails.py first to build the baseline")
|
||||
raise SystemExit(1)
|
||||
|
||||
results_out: dict = {}
|
||||
|
||||
if args.stage in ("2", "all"):
|
||||
print("🚀 Starting Stage 2: Recall@3 evaluation")
|
||||
evaluator = RecallEvaluator(args.index, args.baseline_dir)
|
||||
|
||||
enron_eval = EnronEvaluator(args.index)
|
||||
queries = enron_eval.load_queries(args.queries)
|
||||
queries = queries[:10]
|
||||
print(f"🧪 Using first {len(queries)} queries")
|
||||
|
||||
complexity = args.complexity or 64
|
||||
r = evaluator.evaluate_recall_at_3(queries, complexity)
|
||||
results_out["stage2"] = {"complexity": complexity, "recall_at_3": r}
|
||||
evaluator.cleanup()
|
||||
enron_eval.cleanup()
|
||||
print("✅ Stage 2 completed!\n")
|
||||
|
||||
if args.stage in ("3", "all"):
|
||||
print("🚀 Starting Stage 3: Binary search for target recall (no recompute)")
|
||||
enron_eval = EnronEvaluator(args.index)
|
||||
queries = enron_eval.load_queries(args.queries)
|
||||
queries = queries[: args.max_queries]
|
||||
print(f"🧪 Using first {len(queries)} queries")
|
||||
|
||||
# Build non-compact index for fast binary search (recompute_embeddings=False)
|
||||
from pathlib import Path
|
||||
|
||||
index_path = Path(args.index)
|
||||
non_compact_index_path = str(index_path.parent / f"{index_path.stem}_noncompact.leann")
|
||||
enron_eval.create_non_compact_index_for_comparison(non_compact_index_path)
|
||||
|
||||
# Use non-compact evaluator for binary search with recompute=False
|
||||
evaluator_nc = RecallEvaluator(non_compact_index_path, args.baseline_dir)
|
||||
sweep = enron_eval.evaluate_complexity(
|
||||
evaluator_nc, queries, target=args.target_recall, recompute=False
|
||||
)
|
||||
results_out["stage3"] = sweep
|
||||
# Persist default stage 3 results near the index for Stage 4 auto-pickup
|
||||
from pathlib import Path
|
||||
|
||||
default_stage3_path = Path(args.index).parent / "enron_stage3_results.json"
|
||||
with open(default_stage3_path, "w", encoding="utf-8") as f:
|
||||
json.dump({"stage3": sweep}, f, indent=2)
|
||||
print(f"📝 Saved Stage 3 summary to {default_stage3_path}")
|
||||
evaluator_nc.cleanup()
|
||||
enron_eval.cleanup()
|
||||
print("✅ Stage 3 completed!\n")
|
||||
|
||||
if args.stage in ("4", "all", "45"):
|
||||
print("🚀 Starting Stage 4: Index size + performance comparison")
|
||||
evaluator = RecallEvaluator(args.index, args.baseline_dir)
|
||||
enron_eval = EnronEvaluator(args.index)
|
||||
queries = enron_eval.load_queries(args.queries)
|
||||
test_q = queries[: min(args.max_queries, len(queries))]
|
||||
|
||||
current_sizes = enron_eval.analyze_index_sizes()
|
||||
# Build non-compact index for comparison (no fallback)
|
||||
from pathlib import Path
|
||||
|
||||
index_path = Path(args.index)
|
||||
non_compact_path = str(index_path.parent / f"{index_path.stem}_noncompact.leann")
|
||||
non_compact_sizes = enron_eval.create_non_compact_index_for_comparison(non_compact_path)
|
||||
nc_eval = EnronEvaluator(non_compact_path)
|
||||
|
||||
if (
|
||||
current_sizes.get("index_only_mb", 0) > 0
|
||||
and non_compact_sizes.get("index_only_mb", 0) > 0
|
||||
):
|
||||
storage_saving_percent = max(
|
||||
0.0,
|
||||
100.0 * (1.0 - current_sizes["index_only_mb"] / non_compact_sizes["index_only_mb"]),
|
||||
)
|
||||
else:
|
||||
storage_saving_percent = 0.0
|
||||
|
||||
if args.complexity is None:
|
||||
# Prefer in-session Stage 3 result
|
||||
if "stage3" in results_out and results_out["stage3"].get("best_complexity") is not None:
|
||||
complexity = results_out["stage3"]["best_complexity"]
|
||||
print(f"📥 Using best complexity from Stage 3 in-session: {complexity}")
|
||||
else:
|
||||
# Try to load last saved Stage 3 result near index
|
||||
default_stage3_path = Path(args.index).parent / "enron_stage3_results.json"
|
||||
if default_stage3_path.exists():
|
||||
with open(default_stage3_path, encoding="utf-8") as f:
|
||||
prev = json.load(f)
|
||||
complexity = prev.get("stage3", {}).get("best_complexity")
|
||||
if complexity is None:
|
||||
raise SystemExit(
|
||||
"❌ Stage 4: No --complexity and no best_complexity found in saved Stage 3 results"
|
||||
)
|
||||
print(f"📥 Using best complexity from saved Stage 3: {complexity}")
|
||||
else:
|
||||
raise SystemExit(
|
||||
"❌ Stage 4 requires --complexity if Stage 3 hasn't been run. Run stage 3 first or pass --complexity."
|
||||
)
|
||||
else:
|
||||
complexity = args.complexity
|
||||
|
||||
comp = enron_eval.compare_index_performance(
|
||||
non_compact_path, args.index, test_q, complexity=complexity
|
||||
)
|
||||
results_out["stage4"] = {
|
||||
"current_index": current_sizes,
|
||||
"non_compact_index": non_compact_sizes,
|
||||
"storage_saving_percent": storage_saving_percent,
|
||||
"performance_comparison": comp,
|
||||
}
|
||||
nc_eval.cleanup()
|
||||
evaluator.cleanup()
|
||||
enron_eval.cleanup()
|
||||
print("✅ Stage 4 completed!\n")
|
||||
|
||||
if args.stage in ("5", "all"):
|
||||
print("🚀 Starting Stage 5: Generation evaluation with Qwen3-8B")
|
||||
|
||||
# Check if Stage 4 results exist
|
||||
if "stage4" not in results_out or "performance_comparison" not in results_out["stage4"]:
|
||||
print("❌ Stage 5 requires Stage 4 retrieval results")
|
||||
print("💡 Run Stage 4 first or use --stage all")
|
||||
raise SystemExit(1)
|
||||
|
||||
retrieval_results = results_out["stage4"]["performance_comparison"]["retrieval_results"]
|
||||
if not retrieval_results:
|
||||
print("❌ No retrieval results found from Stage 4")
|
||||
raise SystemExit(1)
|
||||
|
||||
print(f"📁 Using {len(retrieval_results)} retrieval results from Stage 4")
|
||||
|
||||
# Load LLM
|
||||
try:
|
||||
if args.llm_backend == "hf":
|
||||
tokenizer, model = load_hf_model(args.model_name)
|
||||
|
||||
def llm_func(prompt):
|
||||
return generate_hf(tokenizer, model, prompt)
|
||||
else: # vllm
|
||||
llm, sampling_params = load_vllm_model(args.model_name)
|
||||
|
||||
def llm_func(prompt):
|
||||
return generate_vllm(llm, sampling_params, prompt)
|
||||
|
||||
# Run generation using stored retrieval results
|
||||
import time
|
||||
|
||||
from llm_utils import create_prompt
|
||||
|
||||
generation_times = []
|
||||
responses = []
|
||||
|
||||
print("🤖 Running generation on pre-retrieved results...")
|
||||
for i, item in enumerate(retrieval_results):
|
||||
query = item["query"]
|
||||
retrieved_docs = item["retrieved_docs"]
|
||||
|
||||
# Prepare context from retrieved docs
|
||||
context = "\n\n".join([doc["text"] for doc in retrieved_docs])
|
||||
prompt = create_prompt(context, query, "emails")
|
||||
|
||||
# Time generation only
|
||||
gen_start = time.time()
|
||||
response = llm_func(prompt)
|
||||
gen_time = time.time() - gen_start
|
||||
|
||||
generation_times.append(gen_time)
|
||||
responses.append(response)
|
||||
|
||||
if i < 3:
|
||||
print(f" Q{i + 1}: Gen={gen_time:.3f}s")
|
||||
|
||||
avg_gen_time = sum(generation_times) / len(generation_times)
|
||||
|
||||
print("\n📊 Generation Results:")
|
||||
print(f" Total Queries: {len(retrieval_results)}")
|
||||
print(f" Avg Generation Time: {avg_gen_time:.3f}s")
|
||||
print(" (Search time from Stage 4)")
|
||||
|
||||
results_out["stage5"] = {
|
||||
"total_queries": len(retrieval_results),
|
||||
"avg_generation_time": avg_gen_time,
|
||||
"generation_times": generation_times,
|
||||
"responses": responses,
|
||||
}
|
||||
|
||||
# Show sample results
|
||||
print("\n📝 Sample Results:")
|
||||
for i in range(min(3, len(retrieval_results))):
|
||||
query = retrieval_results[i]["query"]
|
||||
response = responses[i]
|
||||
print(f" Q{i + 1}: {query[:60]}...")
|
||||
print(f" A{i + 1}: {response[:100]}...")
|
||||
print()
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Generation evaluation failed: {e}")
|
||||
print("💡 Make sure transformers/vllm is installed and model is available")
|
||||
|
||||
print("✅ Stage 5 completed!\n")
|
||||
|
||||
if args.output and results_out:
|
||||
with open(args.output, "w", encoding="utf-8") as f:
|
||||
json.dump(results_out, f, indent=2)
|
||||
print(f"📝 Saved results to {args.output}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,359 +0,0 @@
|
||||
"""
|
||||
Enron Emails Benchmark Setup Script
|
||||
Prepares passages from emails.csv, builds LEANN index, and FAISS Flat baseline
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from collections.abc import Iterable
|
||||
from email import message_from_string
|
||||
from email.policy import default
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from leann import LeannBuilder
|
||||
|
||||
|
||||
class EnronSetup:
|
||||
def __init__(self, data_dir: str = "data"):
|
||||
self.data_dir = Path(data_dir)
|
||||
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.passages_preview = self.data_dir / "enron_passages_preview.jsonl"
|
||||
self.index_path = self.data_dir / "enron_index_hnsw.leann"
|
||||
self.queries_file = self.data_dir / "evaluation_queries.jsonl"
|
||||
self.downloads_dir = self.data_dir / "downloads"
|
||||
self.downloads_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# ----------------------------
|
||||
# Dataset acquisition
|
||||
# ----------------------------
|
||||
def ensure_emails_csv(self, emails_csv: Optional[str]) -> str:
|
||||
"""Return a path to emails.csv, downloading from Kaggle if needed."""
|
||||
if emails_csv:
|
||||
p = Path(emails_csv)
|
||||
if not p.exists():
|
||||
raise FileNotFoundError(f"emails.csv not found: {emails_csv}")
|
||||
return str(p)
|
||||
|
||||
print(
|
||||
"📥 Trying to download Enron emails.csv from Kaggle (wcukierski/enron-email-dataset)..."
|
||||
)
|
||||
try:
|
||||
from kaggle.api.kaggle_api_extended import KaggleApi
|
||||
|
||||
api = KaggleApi()
|
||||
api.authenticate()
|
||||
api.dataset_download_files(
|
||||
"wcukierski/enron-email-dataset", path=str(self.downloads_dir), unzip=True
|
||||
)
|
||||
candidate = self.downloads_dir / "emails.csv"
|
||||
if candidate.exists():
|
||||
print(f"✅ Downloaded emails.csv: {candidate}")
|
||||
return str(candidate)
|
||||
else:
|
||||
raise FileNotFoundError(
|
||||
f"emails.csv was not found in {self.downloads_dir} after Kaggle download"
|
||||
)
|
||||
except Exception as e:
|
||||
print(
|
||||
"❌ Could not download via Kaggle automatically. Provide --emails-csv or configure Kaggle API."
|
||||
)
|
||||
print(
|
||||
" Set KAGGLE_USERNAME and KAGGLE_KEY env vars, or place emails.csv locally and pass --emails-csv."
|
||||
)
|
||||
raise e
|
||||
|
||||
# ----------------------------
|
||||
# Data preparation
|
||||
# ----------------------------
|
||||
@staticmethod
|
||||
def _extract_message_id(raw_email: str) -> str:
|
||||
msg = message_from_string(raw_email, policy=default)
|
||||
val = msg.get("Message-ID", "")
|
||||
if val.startswith("<") and val.endswith(">"):
|
||||
val = val[1:-1]
|
||||
return val or ""
|
||||
|
||||
@staticmethod
|
||||
def _split_header_body(raw_email: str) -> tuple[str, str]:
|
||||
parts = raw_email.split("\n\n", 1)
|
||||
if len(parts) == 2:
|
||||
return parts[0].strip(), parts[1].strip()
|
||||
# Heuristic fallback
|
||||
first_lines = raw_email.splitlines()
|
||||
if first_lines and ":" in first_lines[0]:
|
||||
return raw_email.strip(), ""
|
||||
return "", raw_email.strip()
|
||||
|
||||
@staticmethod
|
||||
def _split_fixed_words(text: str, chunk_words: int, keep_last: bool) -> list[str]:
|
||||
text = (text or "").strip()
|
||||
if not text:
|
||||
return []
|
||||
if chunk_words <= 0:
|
||||
return [text]
|
||||
words = text.split()
|
||||
if not words:
|
||||
return []
|
||||
limit = len(words)
|
||||
if not keep_last:
|
||||
limit = (len(words) // chunk_words) * chunk_words
|
||||
if limit == 0:
|
||||
return []
|
||||
chunks = [" ".join(words[i : i + chunk_words]) for i in range(0, limit, chunk_words)]
|
||||
return [c for c in (s.strip() for s in chunks) if c]
|
||||
|
||||
def _iter_passages_from_csv(
|
||||
self,
|
||||
emails_csv: Path,
|
||||
chunk_words: int = 256,
|
||||
keep_last_header: bool = True,
|
||||
keep_last_body: bool = True,
|
||||
max_emails: int | None = None,
|
||||
) -> Iterable[dict]:
|
||||
with open(emails_csv, encoding="utf-8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
count = 0
|
||||
for i, row in enumerate(reader):
|
||||
if max_emails is not None and count >= max_emails:
|
||||
break
|
||||
|
||||
raw_message = row.get("message", "")
|
||||
email_file_id = row.get("file", "")
|
||||
|
||||
if not raw_message.strip():
|
||||
continue
|
||||
|
||||
message_id = self._extract_message_id(raw_message)
|
||||
if not message_id:
|
||||
# Fallback ID based on CSV position and file path
|
||||
safe_file = re.sub(r"[^A-Za-z0-9_.-]", "_", email_file_id)
|
||||
message_id = f"enron_{i}_{safe_file}"
|
||||
|
||||
header, body = self._split_header_body(raw_message)
|
||||
|
||||
# Header chunks
|
||||
for chunk in self._split_fixed_words(header, chunk_words, keep_last_header):
|
||||
yield {
|
||||
"text": chunk,
|
||||
"metadata": {
|
||||
"message_id": message_id,
|
||||
"is_header": True,
|
||||
"email_file_id": email_file_id,
|
||||
},
|
||||
}
|
||||
|
||||
# Body chunks
|
||||
for chunk in self._split_fixed_words(body, chunk_words, keep_last_body):
|
||||
yield {
|
||||
"text": chunk,
|
||||
"metadata": {
|
||||
"message_id": message_id,
|
||||
"is_header": False,
|
||||
"email_file_id": email_file_id,
|
||||
},
|
||||
}
|
||||
|
||||
count += 1
|
||||
|
||||
# ----------------------------
|
||||
# Build LEANN index and FAISS baseline
|
||||
# ----------------------------
|
||||
def build_leann_index(
|
||||
self,
|
||||
emails_csv: Optional[str],
|
||||
backend: str = "hnsw",
|
||||
embedding_model: str = "sentence-transformers/all-mpnet-base-v2",
|
||||
chunk_words: int = 256,
|
||||
max_emails: int | None = None,
|
||||
) -> str:
|
||||
emails_csv_path = self.ensure_emails_csv(emails_csv)
|
||||
print(f"🏗️ Building LEANN index from {emails_csv_path}...")
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name=backend,
|
||||
embedding_model=embedding_model,
|
||||
embedding_mode="sentence-transformers",
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_recompute=True,
|
||||
is_compact=True,
|
||||
num_threads=4,
|
||||
)
|
||||
|
||||
# Stream passages and add to builder
|
||||
preview_written = 0
|
||||
with open(self.passages_preview, "w", encoding="utf-8") as preview_out:
|
||||
for p in self._iter_passages_from_csv(
|
||||
Path(emails_csv_path), chunk_words=chunk_words, max_emails=max_emails
|
||||
):
|
||||
builder.add_text(p["text"], metadata=p["metadata"])
|
||||
if preview_written < 200:
|
||||
preview_out.write(json.dumps({"text": p["text"][:200], **p["metadata"]}) + "\n")
|
||||
preview_written += 1
|
||||
|
||||
print(f"🔨 Building index at {self.index_path}...")
|
||||
builder.build_index(str(self.index_path))
|
||||
print("✅ LEANN index built!")
|
||||
return str(self.index_path)
|
||||
|
||||
def build_faiss_flat_baseline(self, index_path: str, output_dir: str = "baseline") -> str:
|
||||
print("🔨 Building FAISS Flat baseline from LEANN passages...")
|
||||
|
||||
import pickle
|
||||
|
||||
import numpy as np
|
||||
from leann.api import compute_embeddings
|
||||
from leann_backend_hnsw import faiss
|
||||
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
baseline_path = os.path.join(output_dir, "faiss_flat.index")
|
||||
metadata_path = os.path.join(output_dir, "metadata.pkl")
|
||||
|
||||
if os.path.exists(baseline_path) and os.path.exists(metadata_path):
|
||||
print(f"✅ Baseline already exists at {baseline_path}")
|
||||
return baseline_path
|
||||
|
||||
# Read meta for passage source and embedding model
|
||||
meta_path = f"{index_path}.meta.json"
|
||||
with open(meta_path, encoding="utf-8") as f:
|
||||
meta = json.load(f)
|
||||
|
||||
embedding_model = meta["embedding_model"]
|
||||
passage_source = meta["passage_sources"][0]
|
||||
passage_file = passage_source["path"]
|
||||
|
||||
if not os.path.isabs(passage_file):
|
||||
index_dir = os.path.dirname(index_path)
|
||||
passage_file = os.path.join(index_dir, os.path.basename(passage_file))
|
||||
|
||||
# Load passages from builder output so IDs match LEANN
|
||||
passages: list[str] = []
|
||||
passage_ids: list[str] = []
|
||||
with open(passage_file, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if not line.strip():
|
||||
continue
|
||||
data = json.loads(line)
|
||||
passages.append(data["text"])
|
||||
passage_ids.append(data["id"]) # builder-assigned ID
|
||||
|
||||
print(f"📄 Loaded {len(passages)} passages for baseline")
|
||||
print(f"🤖 Embedding model: {embedding_model}")
|
||||
|
||||
embeddings = compute_embeddings(
|
||||
passages,
|
||||
embedding_model,
|
||||
mode="sentence-transformers",
|
||||
use_server=False,
|
||||
)
|
||||
|
||||
# Build FAISS IndexFlatIP
|
||||
dim = embeddings.shape[1]
|
||||
index = faiss.IndexFlatIP(dim)
|
||||
emb_f32 = embeddings.astype(np.float32)
|
||||
index.add(emb_f32.shape[0], faiss.swig_ptr(emb_f32))
|
||||
|
||||
faiss.write_index(index, baseline_path)
|
||||
with open(metadata_path, "wb") as pf:
|
||||
pickle.dump(passage_ids, pf)
|
||||
|
||||
print(f"✅ FAISS baseline saved: {baseline_path}")
|
||||
print(f"✅ Metadata saved: {metadata_path}")
|
||||
print(f"📊 Total vectors: {index.ntotal}")
|
||||
return baseline_path
|
||||
|
||||
# ----------------------------
|
||||
# Queries (optional): prepare evaluation queries file
|
||||
# ----------------------------
|
||||
def prepare_queries(self, min_realism: float = 0.85) -> Path:
|
||||
print(
|
||||
"📝 Preparing evaluation queries from HuggingFace dataset corbt/enron_emails_sample_questions ..."
|
||||
)
|
||||
try:
|
||||
from datasets import load_dataset
|
||||
|
||||
ds = load_dataset("corbt/enron_emails_sample_questions", split="train")
|
||||
except Exception as e:
|
||||
print(f"⚠️ Failed to load dataset: {e}")
|
||||
return self.queries_file
|
||||
|
||||
kept = 0
|
||||
with open(self.queries_file, "w", encoding="utf-8") as out:
|
||||
for i, item in enumerate(ds):
|
||||
how_realistic = float(item.get("how_realistic", 0.0))
|
||||
if how_realistic < min_realism:
|
||||
continue
|
||||
qid = str(item.get("id", f"enron_q_{i}"))
|
||||
query = item.get("question", "")
|
||||
if not query:
|
||||
continue
|
||||
record = {
|
||||
"id": qid,
|
||||
"query": query,
|
||||
# For reference only, not used in recall metric below
|
||||
"gt_message_ids": item.get("message_ids", []),
|
||||
}
|
||||
out.write(json.dumps(record) + "\n")
|
||||
kept += 1
|
||||
print(f"✅ Wrote {kept} queries to {self.queries_file}")
|
||||
return self.queries_file
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Setup Enron Emails Benchmark")
|
||||
parser.add_argument(
|
||||
"--emails-csv",
|
||||
help="Path to emails.csv (Enron dataset). If omitted, attempt Kaggle download.",
|
||||
)
|
||||
parser.add_argument("--data-dir", default="data", help="Data directory")
|
||||
parser.add_argument("--backend", choices=["hnsw", "diskann"], default="hnsw")
|
||||
parser.add_argument(
|
||||
"--embedding-model",
|
||||
default="sentence-transformers/all-mpnet-base-v2",
|
||||
help="Embedding model for LEANN",
|
||||
)
|
||||
parser.add_argument("--chunk-words", type=int, default=256, help="Fixed word chunk size")
|
||||
parser.add_argument("--max-emails", type=int, help="Limit number of emails to process")
|
||||
parser.add_argument("--skip-queries", action="store_true", help="Skip creating queries file")
|
||||
parser.add_argument("--skip-build", action="store_true", help="Skip building LEANN index")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
setup = EnronSetup(args.data_dir)
|
||||
|
||||
# Build index
|
||||
if not args.skip_build:
|
||||
index_path = setup.build_leann_index(
|
||||
emails_csv=args.emails_csv,
|
||||
backend=args.backend,
|
||||
embedding_model=args.embedding_model,
|
||||
chunk_words=args.chunk_words,
|
||||
max_emails=args.max_emails,
|
||||
)
|
||||
|
||||
# Build FAISS baseline from the same passages & embeddings
|
||||
setup.build_faiss_flat_baseline(index_path)
|
||||
else:
|
||||
print("⏭️ Skipping LEANN index build and baseline")
|
||||
|
||||
# Queries file (optional)
|
||||
if not args.skip_queries:
|
||||
setup.prepare_queries()
|
||||
else:
|
||||
print("⏭️ Skipping query preparation")
|
||||
|
||||
print("\n🎉 Enron Emails setup completed!")
|
||||
print(f"📁 Data directory: {setup.data_dir.absolute()}")
|
||||
print("Next steps:")
|
||||
print(
|
||||
"1) Evaluate recall: python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 2"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,115 +0,0 @@
|
||||
# FinanceBench Benchmark for LEANN-RAG
|
||||
|
||||
FinanceBench is a benchmark for evaluating retrieval-augmented generation (RAG) systems on financial document question-answering tasks.
|
||||
|
||||
## Dataset
|
||||
|
||||
- **Source**: [PatronusAI/financebench](https://huggingface.co/datasets/PatronusAI/financebench)
|
||||
- **Questions**: 150 financial Q&A examples
|
||||
- **Documents**: 368 PDF files (10-K, 10-Q, 8-K, earnings reports)
|
||||
- **Companies**: Major public companies (3M, Apple, Microsoft, Amazon, etc.)
|
||||
- **Paper**: [FinanceBench: A New Benchmark for Financial Question Answering](https://arxiv.org/abs/2311.11944)
|
||||
|
||||
## Structure
|
||||
|
||||
```
|
||||
benchmarks/financebench/
|
||||
├── setup_financebench.py # Downloads PDFs and builds index
|
||||
├── evaluate_financebench.py # Intelligent evaluation script
|
||||
├── data/
|
||||
│ ├── financebench_merged.jsonl # Q&A dataset
|
||||
│ ├── pdfs/ # Downloaded financial documents
|
||||
│ └── index/ # LEANN indexes
|
||||
│ └── financebench_full_hnsw.leann
|
||||
└── README.md
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### 1. Setup (Download & Build Index)
|
||||
|
||||
```bash
|
||||
cd benchmarks/financebench
|
||||
python setup_financebench.py
|
||||
```
|
||||
|
||||
This will:
|
||||
- Download the 150 Q&A examples
|
||||
- Download all 368 PDF documents (parallel processing)
|
||||
- Build a LEANN index from 53K+ text chunks
|
||||
- Verify setup with test query
|
||||
|
||||
### 2. Evaluation
|
||||
|
||||
```bash
|
||||
# Basic retrieval evaluation
|
||||
python evaluate_financebench.py --index data/index/financebench_full_hnsw.leann
|
||||
|
||||
|
||||
# RAG generation evaluation with Qwen3-8B
|
||||
python evaluate_financebench.py --index data/index/financebench_full_hnsw.leann --stage 4 --complexity 64 --llm-backend hf --model-name Qwen/Qwen3-8B --output results_qwen3.json
|
||||
```
|
||||
|
||||
## Evaluation Methods
|
||||
|
||||
### Retrieval Evaluation
|
||||
Uses intelligent matching with three strategies:
|
||||
1. **Exact text overlap** - Direct substring matches
|
||||
2. **Number matching** - Key financial figures ($1,577, 1.2B, etc.)
|
||||
3. **Semantic similarity** - Word overlap with 20% threshold
|
||||
|
||||
### QA Evaluation
|
||||
LLM-based answer evaluation using GPT-4o:
|
||||
- Handles numerical rounding and equivalent representations
|
||||
- Considers fractions, percentages, and decimal equivalents
|
||||
- Evaluates semantic meaning rather than exact text match
|
||||
|
||||
## Benchmark Results
|
||||
|
||||
### LEANN-RAG Performance (sentence-transformers/all-mpnet-base-v2)
|
||||
|
||||
**Retrieval Metrics:**
|
||||
- **Question Coverage**: 100.0% (all questions retrieve relevant docs)
|
||||
- **Exact Match Rate**: 0.7% (substring overlap with evidence)
|
||||
- **Number Match Rate**: 120.7% (key financial figures matched)*
|
||||
- **Semantic Match Rate**: 4.7% (word overlap ≥20%)
|
||||
- **Average Search Time**: 0.097s
|
||||
|
||||
**QA Metrics:**
|
||||
- **Accuracy**: 42.7% (LLM-evaluated answer correctness)
|
||||
- **Average QA Time**: 4.71s (end-to-end response time)
|
||||
|
||||
**System Performance:**
|
||||
- **Index Size**: 53,985 chunks from 368 PDFs
|
||||
- **Build Time**: ~5-10 minutes with sentence-transformers/all-mpnet-base-v2
|
||||
|
||||
*Note: Number match rate >100% indicates multiple retrieved documents contain the same financial figures, which is expected behavior for financial data appearing across multiple document sections.
|
||||
|
||||
### LEANN-RAG Generation Performance (Qwen3-8B)
|
||||
|
||||
- **Stage 4 (Index Comparison):**
|
||||
- Compact Index: 5.0 MB
|
||||
- Non-compact Index: 172.2 MB
|
||||
- **Storage Saving**: 97.1%
|
||||
- **Search Performance**:
|
||||
- Non-compact (no recompute): 0.009s avg per query
|
||||
- Compact (with recompute): 2.203s avg per query
|
||||
- Speed ratio: 0.004x
|
||||
|
||||
**Generation Evaluation (20 queries, complexity=64):**
|
||||
- **Average Search Time**: 1.638s per query
|
||||
- **Average Generation Time**: 45.957s per query
|
||||
- **LLM Backend**: HuggingFace transformers
|
||||
- **Model**: Qwen/Qwen3-8B (thinking model with <think></think> processing)
|
||||
- **Total Questions Processed**: 20
|
||||
|
||||
## Options
|
||||
|
||||
```bash
|
||||
# Use different backends
|
||||
python setup_financebench.py --backend diskann
|
||||
python evaluate_financebench.py --index data/index/financebench_full_diskann.leann
|
||||
|
||||
# Use different embedding models
|
||||
python setup_financebench.py --embedding-model facebook/contriever
|
||||
```
|
||||
@@ -1,923 +0,0 @@
|
||||
"""
|
||||
FinanceBench Evaluation Script - Modular Recall-based Evaluation
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import openai
|
||||
from leann import LeannChat, LeannSearcher
|
||||
from leann_backend_hnsw import faiss
|
||||
|
||||
from ..llm_utils import evaluate_rag, generate_hf, generate_vllm, load_hf_model, load_vllm_model
|
||||
|
||||
# Setup logging to reduce verbose output
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
logging.getLogger("leann.api").setLevel(logging.WARNING)
|
||||
logging.getLogger("leann_backend_hnsw").setLevel(logging.WARNING)
|
||||
|
||||
|
||||
class RecallEvaluator:
|
||||
"""Stage 2: Evaluate Recall@3 (searcher vs baseline)"""
|
||||
|
||||
def __init__(self, index_path: str, baseline_dir: str):
|
||||
self.index_path = index_path
|
||||
self.baseline_dir = baseline_dir
|
||||
self.searcher = LeannSearcher(index_path)
|
||||
|
||||
# Load FAISS flat baseline
|
||||
baseline_index_path = os.path.join(baseline_dir, "faiss_flat.index")
|
||||
metadata_path = os.path.join(baseline_dir, "metadata.pkl")
|
||||
|
||||
self.faiss_index = faiss.read_index(baseline_index_path)
|
||||
with open(metadata_path, "rb") as f:
|
||||
self.passage_ids = pickle.load(f)
|
||||
print(f"📚 Loaded FAISS flat baseline with {self.faiss_index.ntotal} vectors")
|
||||
|
||||
def evaluate_recall_at_3(
|
||||
self, queries: list[str], complexity: int = 64, recompute_embeddings: bool = True
|
||||
) -> float:
|
||||
"""Evaluate recall@3 for given queries at specified complexity"""
|
||||
recompute_str = "with recompute" if recompute_embeddings else "no recompute"
|
||||
print(f"🔍 Evaluating recall@3 with complexity={complexity} ({recompute_str})...")
|
||||
|
||||
total_recall = 0.0
|
||||
num_queries = len(queries)
|
||||
|
||||
for i, query in enumerate(queries):
|
||||
# Get ground truth: search with FAISS flat
|
||||
from leann.api import compute_embeddings
|
||||
|
||||
query_embedding = compute_embeddings(
|
||||
[query],
|
||||
self.searcher.embedding_model,
|
||||
mode=self.searcher.embedding_mode,
|
||||
use_server=False,
|
||||
).astype(np.float32)
|
||||
|
||||
# Search FAISS flat for ground truth using LEANN's modified faiss API
|
||||
n = query_embedding.shape[0] # Number of queries
|
||||
k = 3 # Number of nearest neighbors
|
||||
distances = np.zeros((n, k), dtype=np.float32)
|
||||
labels = np.zeros((n, k), dtype=np.int64)
|
||||
|
||||
self.faiss_index.search(
|
||||
n,
|
||||
faiss.swig_ptr(query_embedding),
|
||||
k,
|
||||
faiss.swig_ptr(distances),
|
||||
faiss.swig_ptr(labels),
|
||||
)
|
||||
|
||||
# Extract the results
|
||||
baseline_ids = {self.passage_ids[idx] for idx in labels[0]}
|
||||
|
||||
# Search with LEANN at specified complexity
|
||||
test_results = self.searcher.search(
|
||||
query,
|
||||
top_k=3,
|
||||
complexity=complexity,
|
||||
recompute_embeddings=recompute_embeddings,
|
||||
)
|
||||
test_ids = {result.id for result in test_results}
|
||||
|
||||
# Calculate recall@3 = |intersection| / |ground_truth|
|
||||
intersection = test_ids.intersection(baseline_ids)
|
||||
recall = len(intersection) / 3.0 # Ground truth size is 3
|
||||
total_recall += recall
|
||||
|
||||
if i < 3: # Show first few examples
|
||||
print(f" Query {i + 1}: '{query[:50]}...' -> Recall@3: {recall:.3f}")
|
||||
print(f" FAISS ground truth: {list(baseline_ids)}")
|
||||
print(f" LEANN results (C={complexity}, {recompute_str}): {list(test_ids)}")
|
||||
print(f" Intersection: {list(intersection)}")
|
||||
|
||||
avg_recall = total_recall / num_queries
|
||||
print(f"📊 Average Recall@3: {avg_recall:.3f} ({avg_recall * 100:.1f}%)")
|
||||
return avg_recall
|
||||
|
||||
def cleanup(self):
|
||||
"""Cleanup resources"""
|
||||
if hasattr(self, "searcher"):
|
||||
self.searcher.cleanup()
|
||||
|
||||
|
||||
class FinanceBenchEvaluator:
|
||||
def __init__(self, index_path: str, openai_api_key: Optional[str] = None):
|
||||
self.index_path = index_path
|
||||
self.openai_client = openai.OpenAI(api_key=openai_api_key) if openai_api_key else None
|
||||
|
||||
self.searcher = LeannSearcher(index_path)
|
||||
self.chat = LeannChat(index_path) if openai_api_key else None
|
||||
|
||||
def load_dataset(self, dataset_path: str = "data/financebench_merged.jsonl"):
|
||||
"""Load FinanceBench dataset"""
|
||||
data = []
|
||||
with open(dataset_path, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
data.append(json.loads(line))
|
||||
|
||||
print(f"📊 Loaded {len(data)} FinanceBench examples")
|
||||
return data
|
||||
|
||||
def analyze_index_sizes(self) -> dict:
|
||||
"""Analyze index sizes with and without embeddings"""
|
||||
|
||||
print("📏 Analyzing index sizes...")
|
||||
|
||||
# Get all index-related files
|
||||
index_path = Path(self.index_path)
|
||||
index_dir = index_path.parent
|
||||
index_name = index_path.stem # Remove .leann extension
|
||||
|
||||
sizes = {}
|
||||
total_with_embeddings = 0
|
||||
|
||||
# Core index files
|
||||
index_file = index_dir / f"{index_name}.index"
|
||||
meta_file = index_dir / f"{index_path.name}.meta.json" # Keep .leann for meta file
|
||||
passages_file = index_dir / f"{index_path.name}.passages.jsonl" # Keep .leann for passages
|
||||
passages_idx_file = index_dir / f"{index_path.name}.passages.idx" # Keep .leann for idx
|
||||
|
||||
for file_path, name in [
|
||||
(index_file, "index"),
|
||||
(meta_file, "metadata"),
|
||||
(passages_file, "passages_text"),
|
||||
(passages_idx_file, "passages_index"),
|
||||
]:
|
||||
if file_path.exists():
|
||||
size_mb = file_path.stat().st_size / (1024 * 1024)
|
||||
sizes[name] = size_mb
|
||||
total_with_embeddings += size_mb
|
||||
|
||||
else:
|
||||
sizes[name] = 0
|
||||
|
||||
sizes["total_with_embeddings"] = total_with_embeddings
|
||||
sizes["index_only_mb"] = sizes["index"] # Just the .index file for fair comparison
|
||||
|
||||
print(f" 📁 Total index size: {total_with_embeddings:.1f} MB")
|
||||
print(f" 📁 Index file only: {sizes['index']:.1f} MB")
|
||||
|
||||
return sizes
|
||||
|
||||
def create_compact_index_for_comparison(self, compact_index_path: str) -> dict:
|
||||
"""Create a compact index for comparison purposes"""
|
||||
print("🏗️ Building compact index from existing passages...")
|
||||
|
||||
# Load existing passages from current index
|
||||
|
||||
from leann import LeannBuilder
|
||||
|
||||
current_index_path = Path(self.index_path)
|
||||
current_index_dir = current_index_path.parent
|
||||
current_index_name = current_index_path.name
|
||||
|
||||
# Read metadata to get passage source
|
||||
meta_path = current_index_dir / f"{current_index_name}.meta.json"
|
||||
with open(meta_path) as f:
|
||||
import json
|
||||
|
||||
meta = json.load(f)
|
||||
|
||||
passage_source = meta["passage_sources"][0]
|
||||
passage_file = passage_source["path"]
|
||||
|
||||
# Convert relative path to absolute
|
||||
if not Path(passage_file).is_absolute():
|
||||
passage_file = current_index_dir / Path(passage_file).name
|
||||
|
||||
print(f"📄 Loading passages from {passage_file}...")
|
||||
|
||||
# Build compact index with same passages
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model=meta["embedding_model"],
|
||||
embedding_mode=meta.get("embedding_mode", "sentence-transformers"),
|
||||
is_recompute=True, # Enable recompute (no stored embeddings)
|
||||
is_compact=True, # Enable compact storage
|
||||
**meta.get("backend_kwargs", {}),
|
||||
)
|
||||
|
||||
# Load all passages
|
||||
with open(passage_file, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
data = json.loads(line)
|
||||
builder.add_text(data["text"], metadata=data.get("metadata", {}))
|
||||
|
||||
print(f"🔨 Building compact index at {compact_index_path}...")
|
||||
builder.build_index(compact_index_path)
|
||||
|
||||
# Analyze the compact index size
|
||||
temp_evaluator = FinanceBenchEvaluator(compact_index_path)
|
||||
compact_sizes = temp_evaluator.analyze_index_sizes()
|
||||
compact_sizes["index_type"] = "compact"
|
||||
|
||||
return compact_sizes
|
||||
|
||||
def create_non_compact_index_for_comparison(self, non_compact_index_path: str) -> dict:
|
||||
"""Create a non-compact index for comparison purposes"""
|
||||
print("🏗️ Building non-compact index from existing passages...")
|
||||
|
||||
# Load existing passages from current index
|
||||
|
||||
from leann import LeannBuilder
|
||||
|
||||
current_index_path = Path(self.index_path)
|
||||
current_index_dir = current_index_path.parent
|
||||
current_index_name = current_index_path.name
|
||||
|
||||
# Read metadata to get passage source
|
||||
meta_path = current_index_dir / f"{current_index_name}.meta.json"
|
||||
with open(meta_path) as f:
|
||||
import json
|
||||
|
||||
meta = json.load(f)
|
||||
|
||||
passage_source = meta["passage_sources"][0]
|
||||
passage_file = passage_source["path"]
|
||||
|
||||
# Convert relative path to absolute
|
||||
if not Path(passage_file).is_absolute():
|
||||
passage_file = current_index_dir / Path(passage_file).name
|
||||
|
||||
print(f"📄 Loading passages from {passage_file}...")
|
||||
|
||||
# Build non-compact index with same passages
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model=meta["embedding_model"],
|
||||
embedding_mode=meta.get("embedding_mode", "sentence-transformers"),
|
||||
is_recompute=False, # Disable recompute (store embeddings)
|
||||
is_compact=False, # Disable compact storage
|
||||
**{
|
||||
k: v
|
||||
for k, v in meta.get("backend_kwargs", {}).items()
|
||||
if k not in ["is_recompute", "is_compact"]
|
||||
},
|
||||
)
|
||||
|
||||
# Load all passages
|
||||
with open(passage_file, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
data = json.loads(line)
|
||||
builder.add_text(data["text"], metadata=data.get("metadata", {}))
|
||||
|
||||
print(f"🔨 Building non-compact index at {non_compact_index_path}...")
|
||||
builder.build_index(non_compact_index_path)
|
||||
|
||||
# Analyze the non-compact index size
|
||||
temp_evaluator = FinanceBenchEvaluator(non_compact_index_path)
|
||||
non_compact_sizes = temp_evaluator.analyze_index_sizes()
|
||||
non_compact_sizes["index_type"] = "non_compact"
|
||||
|
||||
return non_compact_sizes
|
||||
|
||||
def compare_index_performance(
|
||||
self, non_compact_path: str, compact_path: str, test_data: list, complexity: int
|
||||
) -> dict:
|
||||
"""Compare performance between non-compact and compact indexes"""
|
||||
print("⚡ Comparing search performance between indexes...")
|
||||
|
||||
import time
|
||||
|
||||
from leann import LeannSearcher
|
||||
|
||||
# Test queries
|
||||
test_queries = [item["question"] for item in test_data[:5]]
|
||||
|
||||
results = {
|
||||
"non_compact": {"search_times": []},
|
||||
"compact": {"search_times": []},
|
||||
"avg_search_times": {},
|
||||
"speed_ratio": 0.0,
|
||||
}
|
||||
|
||||
# Test non-compact index (no recompute)
|
||||
print(" 🔍 Testing non-compact index (no recompute)...")
|
||||
non_compact_searcher = LeannSearcher(non_compact_path)
|
||||
|
||||
for query in test_queries:
|
||||
start_time = time.time()
|
||||
_ = non_compact_searcher.search(
|
||||
query, top_k=3, complexity=complexity, recompute_embeddings=False
|
||||
)
|
||||
search_time = time.time() - start_time
|
||||
results["non_compact"]["search_times"].append(search_time)
|
||||
|
||||
# Test compact index (with recompute)
|
||||
print(" 🔍 Testing compact index (with recompute)...")
|
||||
compact_searcher = LeannSearcher(compact_path)
|
||||
|
||||
for query in test_queries:
|
||||
start_time = time.time()
|
||||
_ = compact_searcher.search(
|
||||
query, top_k=3, complexity=complexity, recompute_embeddings=True
|
||||
)
|
||||
search_time = time.time() - start_time
|
||||
results["compact"]["search_times"].append(search_time)
|
||||
|
||||
# Calculate averages
|
||||
results["avg_search_times"]["non_compact"] = sum(
|
||||
results["non_compact"]["search_times"]
|
||||
) / len(results["non_compact"]["search_times"])
|
||||
results["avg_search_times"]["compact"] = sum(results["compact"]["search_times"]) / len(
|
||||
results["compact"]["search_times"]
|
||||
)
|
||||
|
||||
# Performance ratio
|
||||
if results["avg_search_times"]["compact"] > 0:
|
||||
results["speed_ratio"] = (
|
||||
results["avg_search_times"]["non_compact"] / results["avg_search_times"]["compact"]
|
||||
)
|
||||
else:
|
||||
results["speed_ratio"] = float("inf")
|
||||
|
||||
print(
|
||||
f" Non-compact (no recompute): {results['avg_search_times']['non_compact']:.3f}s avg"
|
||||
)
|
||||
print(f" Compact (with recompute): {results['avg_search_times']['compact']:.3f}s avg")
|
||||
print(f" Speed ratio: {results['speed_ratio']:.2f}x")
|
||||
|
||||
# Cleanup
|
||||
non_compact_searcher.cleanup()
|
||||
compact_searcher.cleanup()
|
||||
|
||||
return results
|
||||
|
||||
def evaluate_timing_breakdown(
|
||||
self, data: list[dict], max_samples: Optional[int] = None
|
||||
) -> dict:
|
||||
"""Evaluate timing breakdown and accuracy by hacking LeannChat.ask() for separated timing"""
|
||||
if not self.chat or not self.openai_client:
|
||||
print("⚠️ Skipping timing evaluation (no OpenAI API key provided)")
|
||||
return {
|
||||
"total_questions": 0,
|
||||
"avg_search_time": 0.0,
|
||||
"avg_generation_time": 0.0,
|
||||
"avg_total_time": 0.0,
|
||||
"accuracy": 0.0,
|
||||
}
|
||||
|
||||
print("🔍🤖 Evaluating timing breakdown and accuracy (search + generation)...")
|
||||
|
||||
if max_samples:
|
||||
data = data[:max_samples]
|
||||
print(f"📝 Using first {max_samples} samples for timing evaluation")
|
||||
|
||||
search_times = []
|
||||
generation_times = []
|
||||
total_times = []
|
||||
correct_answers = 0
|
||||
|
||||
for i, item in enumerate(data):
|
||||
question = item["question"]
|
||||
ground_truth = item["answer"]
|
||||
|
||||
try:
|
||||
# Hack: Monkey-patch the ask method to capture internal timing
|
||||
original_ask = self.chat.ask
|
||||
captured_search_time = None
|
||||
captured_generation_time = None
|
||||
|
||||
def patched_ask(*args, **kwargs):
|
||||
nonlocal captured_search_time, captured_generation_time
|
||||
|
||||
# Time the search part
|
||||
search_start = time.time()
|
||||
results = self.chat.searcher.search(args[0], top_k=3, complexity=64)
|
||||
captured_search_time = time.time() - search_start
|
||||
|
||||
# Time the generation part
|
||||
context = "\n\n".join([r.text for r in results])
|
||||
prompt = (
|
||||
"Here is some retrieved context that might help answer your question:\n\n"
|
||||
f"{context}\n\n"
|
||||
f"Question: {args[0]}\n\n"
|
||||
"Please provide the best answer you can based on this context and your knowledge."
|
||||
)
|
||||
|
||||
generation_start = time.time()
|
||||
answer = self.chat.llm.ask(prompt)
|
||||
captured_generation_time = time.time() - generation_start
|
||||
|
||||
return answer
|
||||
|
||||
# Apply the patch
|
||||
self.chat.ask = patched_ask
|
||||
|
||||
# Time the total QA
|
||||
total_start = time.time()
|
||||
generated_answer = self.chat.ask(question)
|
||||
total_time = time.time() - total_start
|
||||
|
||||
# Restore original method
|
||||
self.chat.ask = original_ask
|
||||
|
||||
# Store the timings
|
||||
search_times.append(captured_search_time)
|
||||
generation_times.append(captured_generation_time)
|
||||
total_times.append(total_time)
|
||||
|
||||
# Check accuracy using LLM as judge
|
||||
is_correct = self._check_answer_accuracy(generated_answer, ground_truth, question)
|
||||
if is_correct:
|
||||
correct_answers += 1
|
||||
|
||||
status = "✅" if is_correct else "❌"
|
||||
print(
|
||||
f"Question {i + 1}/{len(data)}: {status} Search={captured_search_time:.3f}s, Gen={captured_generation_time:.3f}s, Total={total_time:.3f}s"
|
||||
)
|
||||
print(f" GT: {ground_truth}")
|
||||
print(f" Gen: {generated_answer[:100]}...")
|
||||
|
||||
except Exception as e:
|
||||
print(f" ❌ Error: {e}")
|
||||
search_times.append(0.0)
|
||||
generation_times.append(0.0)
|
||||
total_times.append(0.0)
|
||||
|
||||
accuracy = correct_answers / len(data) if data else 0.0
|
||||
|
||||
metrics = {
|
||||
"total_questions": len(data),
|
||||
"avg_search_time": sum(search_times) / len(search_times) if search_times else 0.0,
|
||||
"avg_generation_time": sum(generation_times) / len(generation_times)
|
||||
if generation_times
|
||||
else 0.0,
|
||||
"avg_total_time": sum(total_times) / len(total_times) if total_times else 0.0,
|
||||
"accuracy": accuracy,
|
||||
"correct_answers": correct_answers,
|
||||
"search_times": search_times,
|
||||
"generation_times": generation_times,
|
||||
"total_times": total_times,
|
||||
}
|
||||
|
||||
return metrics
|
||||
|
||||
def _check_answer_accuracy(
|
||||
self, generated_answer: str, ground_truth: str, question: str
|
||||
) -> bool:
|
||||
"""Check if generated answer matches ground truth using LLM as judge"""
|
||||
judge_prompt = f"""You are an expert judge evaluating financial question answering.
|
||||
|
||||
Question: {question}
|
||||
|
||||
Ground Truth Answer: {ground_truth}
|
||||
|
||||
Generated Answer: {generated_answer}
|
||||
|
||||
Task: Determine if the generated answer is factually correct compared to the ground truth. Focus on:
|
||||
1. Numerical accuracy (exact values, units, currency)
|
||||
2. Key financial concepts and terminology
|
||||
3. Overall factual correctness
|
||||
|
||||
For financial data, small formatting differences are OK (e.g., "$1,577" vs "1577 million" vs "$1.577 billion"), but the core numerical value must match.
|
||||
|
||||
Respond with exactly one word: "CORRECT" if the generated answer is factually accurate, or "INCORRECT" if it's wrong or significantly different."""
|
||||
|
||||
try:
|
||||
judge_response = self.openai_client.chat.completions.create(
|
||||
model="gpt-4o-mini",
|
||||
messages=[{"role": "user", "content": judge_prompt}],
|
||||
max_tokens=10,
|
||||
temperature=0,
|
||||
)
|
||||
judgment = judge_response.choices[0].message.content.strip().upper()
|
||||
return judgment == "CORRECT"
|
||||
except Exception as e:
|
||||
print(f" ⚠️ Judge error: {e}, falling back to string matching")
|
||||
# Fallback to simple string matching
|
||||
gen_clean = generated_answer.strip().lower().replace("$", "").replace(",", "")
|
||||
gt_clean = ground_truth.strip().lower().replace("$", "").replace(",", "")
|
||||
return gt_clean in gen_clean
|
||||
|
||||
def _print_results(self, timing_metrics: dict):
|
||||
"""Print evaluation results"""
|
||||
print("\n🎯 EVALUATION RESULTS")
|
||||
print("=" * 50)
|
||||
|
||||
# Index comparison analysis
|
||||
if "current_index" in timing_metrics and "non_compact_index" in timing_metrics:
|
||||
print("\n📏 Index Comparison Analysis:")
|
||||
current = timing_metrics["current_index"]
|
||||
non_compact = timing_metrics["non_compact_index"]
|
||||
|
||||
print(f" Compact index (current): {current.get('total_with_embeddings', 0):.1f} MB")
|
||||
print(
|
||||
f" Non-compact index (with embeddings): {non_compact.get('total_with_embeddings', 0):.1f} MB"
|
||||
)
|
||||
print(
|
||||
f" Storage saving by compact: {timing_metrics.get('storage_saving_percent', 0):.1f}%"
|
||||
)
|
||||
|
||||
print(" Component breakdown (non-compact):")
|
||||
print(f" - Main index: {non_compact.get('index', 0):.1f} MB")
|
||||
print(f" - Passages text: {non_compact.get('passages_text', 0):.1f} MB")
|
||||
print(f" - Passages index: {non_compact.get('passages_index', 0):.1f} MB")
|
||||
print(f" - Metadata: {non_compact.get('metadata', 0):.1f} MB")
|
||||
|
||||
# Performance comparison
|
||||
if "performance_comparison" in timing_metrics:
|
||||
perf = timing_metrics["performance_comparison"]
|
||||
print("\n⚡ Performance Comparison:")
|
||||
print(
|
||||
f" Non-compact (no recompute): {perf.get('avg_search_times', {}).get('non_compact', 0):.3f}s avg"
|
||||
)
|
||||
print(
|
||||
f" Compact (with recompute): {perf.get('avg_search_times', {}).get('compact', 0):.3f}s avg"
|
||||
)
|
||||
print(f" Speed ratio: {perf.get('speed_ratio', 0):.2f}x")
|
||||
|
||||
# Legacy single index analysis (fallback)
|
||||
if "total_with_embeddings" in timing_metrics and "current_index" not in timing_metrics:
|
||||
print("\n📏 Index Size Analysis:")
|
||||
print(f" Total index size: {timing_metrics.get('total_with_embeddings', 0):.1f} MB")
|
||||
|
||||
print("\n📊 Accuracy:")
|
||||
print(f" Accuracy: {timing_metrics.get('accuracy', 0) * 100:.1f}%")
|
||||
print(
|
||||
f" Correct Answers: {timing_metrics.get('correct_answers', 0)}/{timing_metrics.get('total_questions', 0)}"
|
||||
)
|
||||
|
||||
print("\n📊 Timing Breakdown:")
|
||||
print(f" Total Questions: {timing_metrics.get('total_questions', 0)}")
|
||||
print(f" Avg Search Time: {timing_metrics.get('avg_search_time', 0):.3f}s")
|
||||
print(f" Avg Generation Time: {timing_metrics.get('avg_generation_time', 0):.3f}s")
|
||||
print(f" Avg Total Time: {timing_metrics.get('avg_total_time', 0):.3f}s")
|
||||
|
||||
if timing_metrics.get("avg_total_time", 0) > 0:
|
||||
search_pct = (
|
||||
timing_metrics.get("avg_search_time", 0)
|
||||
/ timing_metrics.get("avg_total_time", 1)
|
||||
* 100
|
||||
)
|
||||
gen_pct = (
|
||||
timing_metrics.get("avg_generation_time", 0)
|
||||
/ timing_metrics.get("avg_total_time", 1)
|
||||
* 100
|
||||
)
|
||||
print("\n📈 Time Distribution:")
|
||||
print(f" Search: {search_pct:.1f}%")
|
||||
print(f" Generation: {gen_pct:.1f}%")
|
||||
|
||||
def cleanup(self):
|
||||
"""Cleanup resources"""
|
||||
if self.searcher:
|
||||
self.searcher.cleanup()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Modular FinanceBench Evaluation")
|
||||
parser.add_argument("--index", required=True, help="Path to LEANN index")
|
||||
parser.add_argument("--dataset", default="data/financebench_merged.jsonl", help="Dataset path")
|
||||
parser.add_argument(
|
||||
"--stage",
|
||||
choices=["2", "3", "4", "all"],
|
||||
default="all",
|
||||
help="Which stage to run (2=recall, 3=complexity, 4=generation)",
|
||||
)
|
||||
parser.add_argument("--complexity", type=int, default=None, help="Complexity for search")
|
||||
parser.add_argument("--baseline-dir", default="baseline", help="Baseline output directory")
|
||||
parser.add_argument("--openai-api-key", help="OpenAI API key for generation evaluation")
|
||||
parser.add_argument("--output", help="Save results to JSON file")
|
||||
parser.add_argument(
|
||||
"--llm-backend", choices=["openai", "hf", "vllm"], default="openai", help="LLM backend"
|
||||
)
|
||||
parser.add_argument("--model-name", default="Qwen3-8B", help="Model name for HF/vLLM")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
# Check if baseline exists
|
||||
baseline_index_path = os.path.join(args.baseline_dir, "faiss_flat.index")
|
||||
if not os.path.exists(baseline_index_path):
|
||||
print(f"❌ FAISS baseline not found at {baseline_index_path}")
|
||||
print("💡 Please run setup_financebench.py first to build the baseline")
|
||||
exit(1)
|
||||
|
||||
if args.stage == "2" or args.stage == "all":
|
||||
# Stage 2: Recall@3 evaluation
|
||||
print("🚀 Starting Stage 2: Recall@3 evaluation")
|
||||
|
||||
evaluator = RecallEvaluator(args.index, args.baseline_dir)
|
||||
|
||||
# Load FinanceBench queries for testing
|
||||
print("📖 Loading FinanceBench dataset...")
|
||||
queries = []
|
||||
with open(args.dataset, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
data = json.loads(line)
|
||||
queries.append(data["question"])
|
||||
|
||||
# Test with more queries for robust measurement
|
||||
test_queries = queries[:2000]
|
||||
print(f"🧪 Testing with {len(test_queries)} queries")
|
||||
|
||||
# Test with complexity 64
|
||||
complexity = 64
|
||||
recall = evaluator.evaluate_recall_at_3(test_queries, complexity)
|
||||
print(f"📈 Recall@3 at complexity {complexity}: {recall * 100:.1f}%")
|
||||
|
||||
evaluator.cleanup()
|
||||
print("✅ Stage 2 completed!\n")
|
||||
|
||||
# Shared non-compact index path for Stage 3 and 4
|
||||
non_compact_index_path = args.index.replace(".leann", "_noncompact.leann")
|
||||
complexity = args.complexity
|
||||
|
||||
if args.stage == "3" or args.stage == "all":
|
||||
# Stage 3: Binary search for 90% recall complexity (using non-compact index for speed)
|
||||
print("🚀 Starting Stage 3: Binary search for 90% recall complexity")
|
||||
print(
|
||||
"💡 Creating non-compact index for fast binary search with recompute_embeddings=False"
|
||||
)
|
||||
|
||||
# Create non-compact index for binary search (will be reused in Stage 4)
|
||||
print("🏗️ Creating non-compact index for binary search...")
|
||||
evaluator = FinanceBenchEvaluator(args.index)
|
||||
evaluator.create_non_compact_index_for_comparison(non_compact_index_path)
|
||||
|
||||
# Use non-compact index for binary search
|
||||
binary_search_evaluator = RecallEvaluator(non_compact_index_path, args.baseline_dir)
|
||||
|
||||
# Load queries for testing
|
||||
print("📖 Loading FinanceBench dataset...")
|
||||
queries = []
|
||||
with open(args.dataset, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
data = json.loads(line)
|
||||
queries.append(data["question"])
|
||||
|
||||
# Use more queries for robust measurement
|
||||
test_queries = queries[:200]
|
||||
print(f"🧪 Testing with {len(test_queries)} queries")
|
||||
|
||||
# Binary search for 90% recall complexity (without recompute for speed)
|
||||
target_recall = 0.9
|
||||
min_complexity, max_complexity = 1, 32
|
||||
|
||||
print(f"🔍 Binary search for {target_recall * 100}% recall complexity...")
|
||||
print(f"Search range: {min_complexity} to {max_complexity}")
|
||||
|
||||
best_complexity = None
|
||||
best_recall = 0.0
|
||||
|
||||
while min_complexity <= max_complexity:
|
||||
mid_complexity = (min_complexity + max_complexity) // 2
|
||||
|
||||
print(
|
||||
f"\n🧪 Testing complexity {mid_complexity} (no recompute, non-compact index)..."
|
||||
)
|
||||
# Use recompute_embeddings=False on non-compact index for fast binary search
|
||||
recall = binary_search_evaluator.evaluate_recall_at_3(
|
||||
test_queries, mid_complexity, recompute_embeddings=False
|
||||
)
|
||||
|
||||
print(
|
||||
f" Complexity {mid_complexity}: Recall@3 = {recall:.3f} ({recall * 100:.1f}%)"
|
||||
)
|
||||
|
||||
if recall >= target_recall:
|
||||
best_complexity = mid_complexity
|
||||
best_recall = recall
|
||||
max_complexity = mid_complexity - 1
|
||||
print(" ✅ Target reached! Searching for lower complexity...")
|
||||
else:
|
||||
min_complexity = mid_complexity + 1
|
||||
print(" ❌ Below target. Searching for higher complexity...")
|
||||
|
||||
if best_complexity is not None:
|
||||
print("\n🎯 Optimal complexity found!")
|
||||
print(f" Complexity: {best_complexity}")
|
||||
print(f" Recall@3: {best_recall:.3f} ({best_recall * 100:.1f}%)")
|
||||
|
||||
# Test a few complexities around the optimal one for verification
|
||||
print("\n🔬 Verification test around optimal complexity:")
|
||||
verification_complexities = [
|
||||
max(1, best_complexity - 2),
|
||||
max(1, best_complexity - 1),
|
||||
best_complexity,
|
||||
best_complexity + 1,
|
||||
best_complexity + 2,
|
||||
]
|
||||
|
||||
for complexity in verification_complexities:
|
||||
if complexity <= 512: # reasonable upper bound
|
||||
recall = binary_search_evaluator.evaluate_recall_at_3(
|
||||
test_queries, complexity, recompute_embeddings=False
|
||||
)
|
||||
status = "✅" if recall >= target_recall else "❌"
|
||||
print(f" {status} Complexity {complexity:3d}: {recall * 100:5.1f}%")
|
||||
|
||||
# Now test the optimal complexity with compact index and recompute for comparison
|
||||
print(
|
||||
f"\n🔄 Testing optimal complexity {best_complexity} on compact index WITH recompute..."
|
||||
)
|
||||
compact_evaluator = RecallEvaluator(args.index, args.baseline_dir)
|
||||
recall_with_recompute = compact_evaluator.evaluate_recall_at_3(
|
||||
test_queries[:10], best_complexity, recompute_embeddings=True
|
||||
)
|
||||
print(
|
||||
f" ✅ Complexity {best_complexity} (compact index with recompute): {recall_with_recompute * 100:.1f}%"
|
||||
)
|
||||
complexity = best_complexity
|
||||
print(
|
||||
f" 📊 Recall difference: {abs(best_recall - recall_with_recompute) * 100:.2f}%"
|
||||
)
|
||||
compact_evaluator.cleanup()
|
||||
else:
|
||||
print(f"\n❌ Could not find complexity achieving {target_recall * 100}% recall")
|
||||
print("All tested complexities were below target.")
|
||||
|
||||
# Cleanup evaluators (keep non-compact index for Stage 4)
|
||||
binary_search_evaluator.cleanup()
|
||||
evaluator.cleanup()
|
||||
|
||||
print("✅ Stage 3 completed! Non-compact index saved for Stage 4.\n")
|
||||
|
||||
if args.stage == "4" or args.stage == "all":
|
||||
# Stage 4: Comprehensive evaluation with dual index comparison
|
||||
print("🚀 Starting Stage 4: Comprehensive evaluation with dual index comparison")
|
||||
|
||||
# Use FinanceBench evaluator for QA evaluation
|
||||
evaluator = FinanceBenchEvaluator(
|
||||
args.index, args.openai_api_key if args.llm_backend == "openai" else None
|
||||
)
|
||||
|
||||
print("📖 Loading FinanceBench dataset...")
|
||||
data = evaluator.load_dataset(args.dataset)
|
||||
|
||||
# Step 1: Analyze current (compact) index
|
||||
print("\n📏 Analyzing current index (compact, pruned)...")
|
||||
compact_size_metrics = evaluator.analyze_index_sizes()
|
||||
compact_size_metrics["index_type"] = "compact"
|
||||
|
||||
# Step 2: Use existing non-compact index or create if needed
|
||||
from pathlib import Path
|
||||
|
||||
if Path(non_compact_index_path).exists():
|
||||
print(
|
||||
f"\n📁 Using existing non-compact index from Stage 3: {non_compact_index_path}"
|
||||
)
|
||||
temp_evaluator = FinanceBenchEvaluator(non_compact_index_path)
|
||||
non_compact_size_metrics = temp_evaluator.analyze_index_sizes()
|
||||
non_compact_size_metrics["index_type"] = "non_compact"
|
||||
else:
|
||||
print("\n🏗️ Creating non-compact index (with embeddings) for comparison...")
|
||||
non_compact_size_metrics = evaluator.create_non_compact_index_for_comparison(
|
||||
non_compact_index_path
|
||||
)
|
||||
|
||||
# Step 3: Compare index sizes
|
||||
print("\n📊 Index size comparison:")
|
||||
print(
|
||||
f" Compact index (current): {compact_size_metrics['total_with_embeddings']:.1f} MB"
|
||||
)
|
||||
print(
|
||||
f" Non-compact index: {non_compact_size_metrics['total_with_embeddings']:.1f} MB"
|
||||
)
|
||||
print("\n📊 Index-only size comparison (.index file only):")
|
||||
print(f" Compact index: {compact_size_metrics['index_only_mb']:.1f} MB")
|
||||
print(f" Non-compact index: {non_compact_size_metrics['index_only_mb']:.1f} MB")
|
||||
# Use index-only size for fair comparison (same as Enron emails)
|
||||
storage_saving = (
|
||||
(non_compact_size_metrics["index_only_mb"] - compact_size_metrics["index_only_mb"])
|
||||
/ non_compact_size_metrics["index_only_mb"]
|
||||
* 100
|
||||
)
|
||||
print(f" Storage saving by compact: {storage_saving:.1f}%")
|
||||
|
||||
# Step 4: Performance comparison between the two indexes
|
||||
if complexity is None:
|
||||
raise ValueError("Complexity is required for performance comparison")
|
||||
|
||||
print("\n⚡ Performance comparison between indexes...")
|
||||
performance_metrics = evaluator.compare_index_performance(
|
||||
non_compact_index_path, args.index, data[:10], complexity=complexity
|
||||
)
|
||||
|
||||
# Step 5: Generation evaluation
|
||||
test_samples = 20
|
||||
print(f"\n🧪 Testing with first {test_samples} samples for generation analysis")
|
||||
|
||||
if args.llm_backend == "openai" and args.openai_api_key:
|
||||
print("🔍🤖 Running OpenAI-based generation evaluation...")
|
||||
evaluation_start = time.time()
|
||||
timing_metrics = evaluator.evaluate_timing_breakdown(data[:test_samples])
|
||||
evaluation_time = time.time() - evaluation_start
|
||||
else:
|
||||
print(
|
||||
f"🔍🤖 Running {args.llm_backend} generation evaluation with {args.model_name}..."
|
||||
)
|
||||
try:
|
||||
# Load LLM
|
||||
if args.llm_backend == "hf":
|
||||
tokenizer, model = load_hf_model(args.model_name)
|
||||
|
||||
def llm_func(prompt):
|
||||
return generate_hf(tokenizer, model, prompt)
|
||||
else: # vllm
|
||||
llm, sampling_params = load_vllm_model(args.model_name)
|
||||
|
||||
def llm_func(prompt):
|
||||
return generate_vllm(llm, sampling_params, prompt)
|
||||
|
||||
# Simple generation evaluation
|
||||
queries = [item["question"] for item in data[:test_samples]]
|
||||
gen_results = evaluate_rag(
|
||||
evaluator.searcher,
|
||||
llm_func,
|
||||
queries,
|
||||
domain="finance",
|
||||
complexity=complexity,
|
||||
)
|
||||
|
||||
timing_metrics = {
|
||||
"total_questions": len(queries),
|
||||
"avg_search_time": gen_results["avg_search_time"],
|
||||
"avg_generation_time": gen_results["avg_generation_time"],
|
||||
"results": gen_results["results"],
|
||||
}
|
||||
evaluation_time = time.time()
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Generation evaluation failed: {e}")
|
||||
timing_metrics = {
|
||||
"total_questions": 0,
|
||||
"avg_search_time": 0,
|
||||
"avg_generation_time": 0,
|
||||
}
|
||||
evaluation_time = 0
|
||||
|
||||
# Combine all metrics
|
||||
combined_metrics = {
|
||||
**timing_metrics,
|
||||
"total_evaluation_time": evaluation_time,
|
||||
"current_index": compact_size_metrics,
|
||||
"non_compact_index": non_compact_size_metrics,
|
||||
"performance_comparison": performance_metrics,
|
||||
"storage_saving_percent": storage_saving,
|
||||
}
|
||||
|
||||
# Print results
|
||||
print("\n📊 Generation Results:")
|
||||
print(f" Total Questions: {timing_metrics.get('total_questions', 0)}")
|
||||
print(f" Avg Search Time: {timing_metrics.get('avg_search_time', 0):.3f}s")
|
||||
print(f" Avg Generation Time: {timing_metrics.get('avg_generation_time', 0):.3f}s")
|
||||
|
||||
# Save results if requested
|
||||
if args.output:
|
||||
print(f"\n💾 Saving results to {args.output}...")
|
||||
with open(args.output, "w") as f:
|
||||
json.dump(combined_metrics, f, indent=2, default=str)
|
||||
print(f"✅ Results saved to {args.output}")
|
||||
|
||||
evaluator.cleanup()
|
||||
print("✅ Stage 4 completed!\n")
|
||||
|
||||
if args.stage == "all":
|
||||
print("🎉 All evaluation stages completed successfully!")
|
||||
print("\n📋 Summary:")
|
||||
print(" Stage 2: ✅ Recall@3 evaluation completed")
|
||||
print(" Stage 3: ✅ Optimal complexity found")
|
||||
print(" Stage 4: ✅ Generation accuracy & timing evaluation completed")
|
||||
print("\n🔧 Recommended next steps:")
|
||||
print(" - Use optimal complexity for best speed/accuracy balance")
|
||||
print(" - Review accuracy and timing breakdown for performance optimization")
|
||||
print(" - Run full evaluation on complete dataset if needed")
|
||||
|
||||
# Clean up non-compact index after all stages complete
|
||||
print("\n🧹 Cleaning up temporary non-compact index...")
|
||||
from pathlib import Path
|
||||
|
||||
if Path(non_compact_index_path).exists():
|
||||
temp_index_dir = Path(non_compact_index_path).parent
|
||||
temp_index_name = Path(non_compact_index_path).name
|
||||
for temp_file in temp_index_dir.glob(f"{temp_index_name}*"):
|
||||
temp_file.unlink()
|
||||
print(f"✅ Cleaned up {non_compact_index_path}")
|
||||
else:
|
||||
print("📝 No temporary index to clean up")
|
||||
except KeyboardInterrupt:
|
||||
print("\n⚠️ Evaluation interrupted by user")
|
||||
exit(1)
|
||||
except Exception as e:
|
||||
print(f"\n❌ Stage {args.stage} failed: {e}")
|
||||
exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,462 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
FinanceBench Complete Setup Script
|
||||
Downloads all PDFs and builds full LEANN datastore
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
|
||||
import pymupdf
|
||||
import requests
|
||||
from leann import LeannBuilder, LeannSearcher
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class FinanceBenchSetup:
|
||||
def __init__(self, data_dir: str = "data"):
|
||||
self.base_dir = Path(__file__).parent # benchmarks/financebench/
|
||||
self.data_dir = self.base_dir / data_dir
|
||||
self.pdf_dir = self.data_dir / "pdfs"
|
||||
self.dataset_file = self.data_dir / "financebench_merged.jsonl"
|
||||
self.index_dir = self.data_dir / "index"
|
||||
self.download_lock = Lock()
|
||||
|
||||
def download_dataset(self):
|
||||
"""Download the main FinanceBench dataset"""
|
||||
print("📊 Downloading FinanceBench dataset...")
|
||||
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if self.dataset_file.exists():
|
||||
print(f"✅ Dataset already exists: {self.dataset_file}")
|
||||
return
|
||||
|
||||
url = "https://huggingface.co/datasets/PatronusAI/financebench/raw/main/financebench_merged.jsonl"
|
||||
response = requests.get(url, stream=True)
|
||||
response.raise_for_status()
|
||||
|
||||
with open(self.dataset_file, "wb") as f:
|
||||
for chunk in response.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
|
||||
print(f"✅ Dataset downloaded: {self.dataset_file}")
|
||||
|
||||
def get_pdf_list(self):
|
||||
"""Get list of all PDF files from GitHub"""
|
||||
print("📋 Fetching PDF list from GitHub...")
|
||||
|
||||
response = requests.get(
|
||||
"https://api.github.com/repos/patronus-ai/financebench/contents/pdfs"
|
||||
)
|
||||
response.raise_for_status()
|
||||
pdf_files = response.json()
|
||||
|
||||
print(f"Found {len(pdf_files)} PDF files")
|
||||
return pdf_files
|
||||
|
||||
def download_single_pdf(self, pdf_info, position):
|
||||
"""Download a single PDF file"""
|
||||
pdf_name = pdf_info["name"]
|
||||
pdf_path = self.pdf_dir / pdf_name
|
||||
|
||||
# Skip if already downloaded
|
||||
if pdf_path.exists() and pdf_path.stat().st_size > 0:
|
||||
return f"✅ {pdf_name} (cached)"
|
||||
|
||||
try:
|
||||
# Download PDF
|
||||
response = requests.get(pdf_info["download_url"], timeout=60)
|
||||
response.raise_for_status()
|
||||
|
||||
# Write to file
|
||||
with self.download_lock:
|
||||
with open(pdf_path, "wb") as f:
|
||||
f.write(response.content)
|
||||
|
||||
return f"✅ {pdf_name} ({len(response.content) // 1024}KB)"
|
||||
|
||||
except Exception as e:
|
||||
return f"❌ {pdf_name}: {e!s}"
|
||||
|
||||
def download_all_pdfs(self, max_workers: int = 5):
|
||||
"""Download all PDF files with parallel processing"""
|
||||
self.pdf_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
pdf_files = self.get_pdf_list()
|
||||
|
||||
print(f"📥 Downloading {len(pdf_files)} PDFs with {max_workers} workers...")
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
# Submit all download tasks
|
||||
future_to_pdf = {
|
||||
executor.submit(self.download_single_pdf, pdf_info, i): pdf_info["name"]
|
||||
for i, pdf_info in enumerate(pdf_files)
|
||||
}
|
||||
|
||||
# Process completed downloads with progress bar
|
||||
with tqdm(total=len(pdf_files), desc="Downloading PDFs") as pbar:
|
||||
for future in as_completed(future_to_pdf):
|
||||
result = future.result()
|
||||
pbar.set_postfix_str(result.split()[-1] if "✅" in result else "Error")
|
||||
pbar.update(1)
|
||||
|
||||
# Verify downloads
|
||||
downloaded_pdfs = list(self.pdf_dir.glob("*.pdf"))
|
||||
print(f"✅ Successfully downloaded {len(downloaded_pdfs)}/{len(pdf_files)} PDFs")
|
||||
|
||||
# Show any failures
|
||||
missing_pdfs = []
|
||||
for pdf_info in pdf_files:
|
||||
pdf_path = self.pdf_dir / pdf_info["name"]
|
||||
if not pdf_path.exists() or pdf_path.stat().st_size == 0:
|
||||
missing_pdfs.append(pdf_info["name"])
|
||||
|
||||
if missing_pdfs:
|
||||
print(f"⚠️ Failed to download {len(missing_pdfs)} PDFs:")
|
||||
for pdf in missing_pdfs[:5]: # Show first 5
|
||||
print(f" - {pdf}")
|
||||
if len(missing_pdfs) > 5:
|
||||
print(f" ... and {len(missing_pdfs) - 5} more")
|
||||
|
||||
def build_leann_index(
|
||||
self,
|
||||
backend: str = "hnsw",
|
||||
embedding_model: str = "sentence-transformers/all-mpnet-base-v2",
|
||||
):
|
||||
"""Build LEANN index from all PDFs"""
|
||||
print(f"🏗️ Building LEANN index with {backend} backend...")
|
||||
|
||||
# Check if we have PDFs
|
||||
pdf_files = list(self.pdf_dir.glob("*.pdf"))
|
||||
if not pdf_files:
|
||||
raise RuntimeError("No PDF files found! Run download first.")
|
||||
|
||||
print(f"Found {len(pdf_files)} PDF files to process")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Initialize builder with standard compact configuration
|
||||
builder = LeannBuilder(
|
||||
backend_name=backend,
|
||||
embedding_model=embedding_model,
|
||||
embedding_mode="sentence-transformers",
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_recompute=True, # Enable recompute (no stored embeddings)
|
||||
is_compact=True, # Enable compact storage (pruned)
|
||||
num_threads=4,
|
||||
)
|
||||
|
||||
# Process PDFs and extract text
|
||||
total_chunks = 0
|
||||
failed_pdfs = []
|
||||
|
||||
for pdf_path in tqdm(pdf_files, desc="Processing PDFs"):
|
||||
try:
|
||||
chunks = self.extract_pdf_text(pdf_path)
|
||||
for chunk in chunks:
|
||||
builder.add_text(chunk["text"], metadata=chunk["metadata"])
|
||||
total_chunks += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Failed to process {pdf_path.name}: {e}")
|
||||
failed_pdfs.append(pdf_path.name)
|
||||
continue
|
||||
|
||||
# Build index in index directory
|
||||
self.index_dir.mkdir(parents=True, exist_ok=True)
|
||||
index_path = self.index_dir / f"financebench_full_{backend}.leann"
|
||||
print(f"🔨 Building index: {index_path}")
|
||||
builder.build_index(str(index_path))
|
||||
|
||||
build_time = time.time() - start_time
|
||||
|
||||
print("✅ Index built successfully!")
|
||||
print(f" 📁 Index path: {index_path}")
|
||||
print(f" 📊 Total chunks: {total_chunks:,}")
|
||||
print(f" 📄 Processed PDFs: {len(pdf_files) - len(failed_pdfs)}/{len(pdf_files)}")
|
||||
print(f" ⏱️ Build time: {build_time:.1f}s")
|
||||
|
||||
if failed_pdfs:
|
||||
print(f" ⚠️ Failed PDFs: {failed_pdfs}")
|
||||
|
||||
return str(index_path)
|
||||
|
||||
def build_faiss_flat_baseline(self, index_path: str, output_dir: str = "baseline"):
|
||||
"""Build FAISS flat baseline using the same embeddings as LEANN index"""
|
||||
print("🔨 Building FAISS Flat baseline...")
|
||||
|
||||
import os
|
||||
import pickle
|
||||
|
||||
import numpy as np
|
||||
from leann.api import compute_embeddings
|
||||
from leann_backend_hnsw import faiss
|
||||
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
baseline_path = os.path.join(output_dir, "faiss_flat.index")
|
||||
metadata_path = os.path.join(output_dir, "metadata.pkl")
|
||||
|
||||
if os.path.exists(baseline_path) and os.path.exists(metadata_path):
|
||||
print(f"✅ Baseline already exists at {baseline_path}")
|
||||
return baseline_path
|
||||
|
||||
# Read metadata from the built index
|
||||
meta_path = f"{index_path}.meta.json"
|
||||
with open(meta_path) as f:
|
||||
import json
|
||||
|
||||
meta = json.loads(f.read())
|
||||
|
||||
embedding_model = meta["embedding_model"]
|
||||
passage_source = meta["passage_sources"][0]
|
||||
passage_file = passage_source["path"]
|
||||
|
||||
# Convert relative path to absolute
|
||||
if not os.path.isabs(passage_file):
|
||||
index_dir = os.path.dirname(index_path)
|
||||
passage_file = os.path.join(index_dir, os.path.basename(passage_file))
|
||||
|
||||
print(f"📊 Loading passages from {passage_file}...")
|
||||
print(f"🤖 Using embedding model: {embedding_model}")
|
||||
|
||||
# Load all passages for baseline
|
||||
passages = []
|
||||
passage_ids = []
|
||||
with open(passage_file, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
data = json.loads(line)
|
||||
passages.append(data["text"])
|
||||
passage_ids.append(data["id"])
|
||||
|
||||
print(f"📄 Loaded {len(passages)} passages")
|
||||
|
||||
# Compute embeddings using the same method as LEANN
|
||||
print("🧮 Computing embeddings...")
|
||||
embeddings = compute_embeddings(
|
||||
passages,
|
||||
embedding_model,
|
||||
mode="sentence-transformers",
|
||||
use_server=False,
|
||||
)
|
||||
|
||||
print(f"📐 Embedding shape: {embeddings.shape}")
|
||||
|
||||
# Build FAISS flat index
|
||||
print("🏗️ Building FAISS IndexFlatIP...")
|
||||
dimension = embeddings.shape[1]
|
||||
index = faiss.IndexFlatIP(dimension)
|
||||
|
||||
# Add embeddings to flat index
|
||||
embeddings_f32 = embeddings.astype(np.float32)
|
||||
index.add(embeddings_f32.shape[0], faiss.swig_ptr(embeddings_f32))
|
||||
|
||||
# Save index and metadata
|
||||
faiss.write_index(index, baseline_path)
|
||||
with open(metadata_path, "wb") as f:
|
||||
pickle.dump(passage_ids, f)
|
||||
|
||||
print(f"✅ FAISS baseline saved to {baseline_path}")
|
||||
print(f"✅ Metadata saved to {metadata_path}")
|
||||
print(f"📊 Total vectors: {index.ntotal}")
|
||||
|
||||
return baseline_path
|
||||
|
||||
def extract_pdf_text(self, pdf_path: Path) -> list[dict]:
|
||||
"""Extract and chunk text from a PDF file"""
|
||||
chunks = []
|
||||
doc = pymupdf.open(pdf_path)
|
||||
|
||||
for page_num in range(len(doc)):
|
||||
page = doc[page_num]
|
||||
text = page.get_text() # type: ignore
|
||||
|
||||
if not text.strip():
|
||||
continue
|
||||
|
||||
# Create metadata
|
||||
metadata = {
|
||||
"source_file": pdf_path.name,
|
||||
"page_number": page_num + 1,
|
||||
"document_type": "10K" if "10K" in pdf_path.name else "10Q",
|
||||
"company": pdf_path.name.split("_")[0],
|
||||
"doc_period": self.extract_year_from_filename(pdf_path.name),
|
||||
}
|
||||
|
||||
# Use recursive character splitting like LangChain
|
||||
if len(text.split()) > 500:
|
||||
# Split by double newlines (paragraphs)
|
||||
paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
|
||||
|
||||
current_chunk = ""
|
||||
for para in paragraphs:
|
||||
# If adding this paragraph would make chunk too long, save current chunk
|
||||
if current_chunk and len((current_chunk + " " + para).split()) > 300:
|
||||
if current_chunk.strip():
|
||||
chunks.append(
|
||||
{
|
||||
"text": current_chunk.strip(),
|
||||
"metadata": {
|
||||
**metadata,
|
||||
"chunk_id": f"page_{page_num + 1}_chunk_{len(chunks)}",
|
||||
},
|
||||
}
|
||||
)
|
||||
current_chunk = para
|
||||
else:
|
||||
current_chunk = (current_chunk + " " + para).strip()
|
||||
|
||||
# Add the last chunk
|
||||
if current_chunk.strip():
|
||||
chunks.append(
|
||||
{
|
||||
"text": current_chunk.strip(),
|
||||
"metadata": {
|
||||
**metadata,
|
||||
"chunk_id": f"page_{page_num + 1}_chunk_{len(chunks)}",
|
||||
},
|
||||
}
|
||||
)
|
||||
else:
|
||||
# Page is short enough, use as single chunk
|
||||
chunks.append(
|
||||
{
|
||||
"text": text.strip(),
|
||||
"metadata": {**metadata, "chunk_id": f"page_{page_num + 1}"},
|
||||
}
|
||||
)
|
||||
|
||||
doc.close()
|
||||
return chunks
|
||||
|
||||
def extract_year_from_filename(self, filename: str) -> str:
|
||||
"""Extract year from PDF filename"""
|
||||
# Try to find 4-digit year in filename
|
||||
|
||||
match = re.search(r"(\d{4})", filename)
|
||||
return match.group(1) if match else "unknown"
|
||||
|
||||
def verify_setup(self, index_path: str):
|
||||
"""Verify the setup by testing a simple query"""
|
||||
print("🧪 Verifying setup with test query...")
|
||||
|
||||
try:
|
||||
searcher = LeannSearcher(index_path)
|
||||
|
||||
# Test query
|
||||
test_query = "What is the capital expenditure for 3M in 2018?"
|
||||
results = searcher.search(test_query, top_k=3)
|
||||
|
||||
print(f"✅ Test query successful! Found {len(results)} results:")
|
||||
for i, result in enumerate(results, 1):
|
||||
company = result.metadata.get("company", "Unknown")
|
||||
year = result.metadata.get("doc_period", "Unknown")
|
||||
page = result.metadata.get("page_number", "Unknown")
|
||||
print(f" {i}. {company} {year} (page {page}) - Score: {result.score:.3f}")
|
||||
print(f" {result.text[:100]}...")
|
||||
|
||||
searcher.cleanup()
|
||||
print("✅ Setup verification completed successfully!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Setup verification failed: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Setup FinanceBench with full PDF datastore")
|
||||
parser.add_argument("--data-dir", default="data", help="Data directory")
|
||||
parser.add_argument(
|
||||
"--backend", choices=["hnsw", "diskann"], default="hnsw", help="LEANN backend"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--embedding-model",
|
||||
default="sentence-transformers/all-mpnet-base-v2",
|
||||
help="Embedding model",
|
||||
)
|
||||
parser.add_argument("--max-workers", type=int, default=5, help="Parallel download workers")
|
||||
parser.add_argument("--skip-download", action="store_true", help="Skip PDF download")
|
||||
parser.add_argument("--skip-build", action="store_true", help="Skip index building")
|
||||
parser.add_argument(
|
||||
"--build-baseline-only",
|
||||
action="store_true",
|
||||
help="Only build FAISS baseline from existing index",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print("🏦 FinanceBench Complete Setup")
|
||||
print("=" * 50)
|
||||
|
||||
setup = FinanceBenchSetup(args.data_dir)
|
||||
|
||||
try:
|
||||
if args.build_baseline_only:
|
||||
# Only build baseline from existing index
|
||||
index_path = setup.index_dir / f"financebench_full_{args.backend}"
|
||||
index_file = f"{index_path}.index"
|
||||
meta_file = f"{index_path}.leann.meta.json"
|
||||
|
||||
if not os.path.exists(index_file) or not os.path.exists(meta_file):
|
||||
print("❌ Index files not found:")
|
||||
print(f" Index: {index_file}")
|
||||
print(f" Meta: {meta_file}")
|
||||
print("💡 Run without --build-baseline-only to build the index first")
|
||||
exit(1)
|
||||
|
||||
print(f"🔨 Building baseline from existing index: {index_path}")
|
||||
baseline_path = setup.build_faiss_flat_baseline(str(index_path))
|
||||
print(f"✅ Baseline built at {baseline_path}")
|
||||
return
|
||||
|
||||
# Step 1: Download dataset
|
||||
setup.download_dataset()
|
||||
|
||||
# Step 2: Download PDFs
|
||||
if not args.skip_download:
|
||||
setup.download_all_pdfs(max_workers=args.max_workers)
|
||||
else:
|
||||
print("⏭️ Skipping PDF download")
|
||||
|
||||
# Step 3: Build LEANN index
|
||||
if not args.skip_build:
|
||||
index_path = setup.build_leann_index(
|
||||
backend=args.backend, embedding_model=args.embedding_model
|
||||
)
|
||||
|
||||
# Step 4: Build FAISS flat baseline
|
||||
print("\n🔨 Building FAISS flat baseline...")
|
||||
baseline_path = setup.build_faiss_flat_baseline(index_path)
|
||||
print(f"✅ Baseline built at {baseline_path}")
|
||||
|
||||
# Step 5: Verify setup
|
||||
setup.verify_setup(index_path)
|
||||
else:
|
||||
print("⏭️ Skipping index building")
|
||||
|
||||
print("\n🎉 FinanceBench setup completed!")
|
||||
print(f"📁 Data directory: {setup.data_dir.absolute()}")
|
||||
print("\nNext steps:")
|
||||
print(
|
||||
"1. Run evaluation: python evaluate_financebench.py --index data/index/financebench_full_hnsw.leann"
|
||||
)
|
||||
print(
|
||||
"2. Or test manually: python -c \"from leann import LeannSearcher; s = LeannSearcher('data/index/financebench_full_hnsw.leann'); print(s.search('3M capital expenditure 2018'))\""
|
||||
)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n⚠️ Setup interrupted by user")
|
||||
exit(1)
|
||||
except Exception as e:
|
||||
print(f"\n❌ Setup failed: {e}")
|
||||
exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,214 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# /// script
|
||||
# requires-python = ">=3.9"
|
||||
# dependencies = [
|
||||
# "faiss-cpu",
|
||||
# "numpy",
|
||||
# "sentence-transformers",
|
||||
# "torch",
|
||||
# "tqdm",
|
||||
# ]
|
||||
# ///
|
||||
|
||||
"""
|
||||
Independent recall verification script using standard FAISS.
|
||||
Creates two indexes (HNSW and Flat) and compares recall@3 at different complexities.
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import faiss
|
||||
import numpy as np
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def compute_embeddings_direct(chunks: list[str], model_name: str) -> np.ndarray:
|
||||
"""
|
||||
Direct embedding computation using sentence-transformers.
|
||||
Copied logic to avoid dependency issues.
|
||||
"""
|
||||
print(f"Loading model: {model_name}")
|
||||
model = SentenceTransformer(model_name)
|
||||
|
||||
print(f"Computing embeddings for {len(chunks)} chunks...")
|
||||
embeddings = model.encode(
|
||||
chunks,
|
||||
show_progress_bar=True,
|
||||
batch_size=32,
|
||||
convert_to_numpy=True,
|
||||
normalize_embeddings=False,
|
||||
)
|
||||
|
||||
return embeddings.astype(np.float32)
|
||||
|
||||
|
||||
def load_financebench_queries(dataset_path: str, max_queries: int = 200) -> list[str]:
|
||||
"""Load FinanceBench queries from dataset"""
|
||||
queries = []
|
||||
with open(dataset_path, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
data = json.loads(line)
|
||||
queries.append(data["question"])
|
||||
if len(queries) >= max_queries:
|
||||
break
|
||||
return queries
|
||||
|
||||
|
||||
def load_passages_from_leann_index(index_path: str) -> tuple[list[str], list[str]]:
|
||||
"""Load passages from LEANN index structure"""
|
||||
meta_path = f"{index_path}.meta.json"
|
||||
with open(meta_path) as f:
|
||||
meta = json.load(f)
|
||||
|
||||
passage_source = meta["passage_sources"][0]
|
||||
passage_file = passage_source["path"]
|
||||
|
||||
# Convert relative path to absolute
|
||||
if not Path(passage_file).is_absolute():
|
||||
index_dir = Path(index_path).parent
|
||||
passage_file = index_dir / Path(passage_file).name
|
||||
|
||||
print(f"Loading passages from {passage_file}")
|
||||
|
||||
passages = []
|
||||
passage_ids = []
|
||||
with open(passage_file, encoding="utf-8") as f:
|
||||
for line in tqdm(f, desc="Loading passages"):
|
||||
if line.strip():
|
||||
data = json.loads(line)
|
||||
passages.append(data["text"])
|
||||
passage_ids.append(data["id"])
|
||||
|
||||
print(f"Loaded {len(passages)} passages")
|
||||
return passages, passage_ids
|
||||
|
||||
|
||||
def build_faiss_indexes(embeddings: np.ndarray) -> tuple[faiss.Index, faiss.Index]:
|
||||
"""Build FAISS indexes: Flat (ground truth) and HNSW"""
|
||||
dimension = embeddings.shape[1]
|
||||
|
||||
# Build Flat index (ground truth)
|
||||
print("Building FAISS IndexFlatIP (ground truth)...")
|
||||
flat_index = faiss.IndexFlatIP(dimension)
|
||||
flat_index.add(embeddings)
|
||||
|
||||
# Build HNSW index
|
||||
print("Building FAISS IndexHNSWFlat...")
|
||||
M = 32 # Same as LEANN default
|
||||
hnsw_index = faiss.IndexHNSWFlat(dimension, M, faiss.METRIC_INNER_PRODUCT)
|
||||
hnsw_index.hnsw.efConstruction = 200 # Same as LEANN default
|
||||
hnsw_index.add(embeddings)
|
||||
|
||||
print(f"Built indexes with {flat_index.ntotal} vectors, dimension {dimension}")
|
||||
return flat_index, hnsw_index
|
||||
|
||||
|
||||
def evaluate_recall_at_k(
|
||||
query_embeddings: np.ndarray,
|
||||
flat_index: faiss.Index,
|
||||
hnsw_index: faiss.Index,
|
||||
passage_ids: list[str],
|
||||
k: int = 3,
|
||||
ef_search: int = 64,
|
||||
) -> float:
|
||||
"""Evaluate recall@k comparing HNSW vs Flat"""
|
||||
|
||||
# Set search parameters for HNSW
|
||||
hnsw_index.hnsw.efSearch = ef_search
|
||||
|
||||
total_recall = 0.0
|
||||
num_queries = query_embeddings.shape[0]
|
||||
|
||||
for i in range(num_queries):
|
||||
query = query_embeddings[i : i + 1] # Keep 2D shape
|
||||
|
||||
# Get ground truth from Flat index (standard FAISS API)
|
||||
flat_distances, flat_indices = flat_index.search(query, k)
|
||||
ground_truth_ids = {passage_ids[idx] for idx in flat_indices[0]}
|
||||
|
||||
# Get results from HNSW index (standard FAISS API)
|
||||
hnsw_distances, hnsw_indices = hnsw_index.search(query, k)
|
||||
hnsw_ids = {passage_ids[idx] for idx in hnsw_indices[0]}
|
||||
|
||||
# Calculate recall
|
||||
intersection = ground_truth_ids.intersection(hnsw_ids)
|
||||
recall = len(intersection) / k
|
||||
total_recall += recall
|
||||
|
||||
if i < 3: # Show first few examples
|
||||
print(f" Query {i + 1}: Recall@{k} = {recall:.3f}")
|
||||
print(f" Flat: {list(ground_truth_ids)}")
|
||||
print(f" HNSW: {list(hnsw_ids)}")
|
||||
print(f" Intersection: {list(intersection)}")
|
||||
|
||||
avg_recall = total_recall / num_queries
|
||||
return avg_recall
|
||||
|
||||
|
||||
def main():
|
||||
# Configuration
|
||||
dataset_path = "data/financebench_merged.jsonl"
|
||||
index_path = "data/index/financebench_full_hnsw.leann"
|
||||
embedding_model = "sentence-transformers/all-mpnet-base-v2"
|
||||
|
||||
print("🔍 FAISS Recall Verification")
|
||||
print("=" * 50)
|
||||
|
||||
# Check if files exist
|
||||
if not Path(dataset_path).exists():
|
||||
print(f"❌ Dataset not found: {dataset_path}")
|
||||
return
|
||||
if not Path(f"{index_path}.meta.json").exists():
|
||||
print(f"❌ Index metadata not found: {index_path}.meta.json")
|
||||
return
|
||||
|
||||
# Load data
|
||||
print("📖 Loading FinanceBench queries...")
|
||||
queries = load_financebench_queries(dataset_path, max_queries=50)
|
||||
print(f"Loaded {len(queries)} queries")
|
||||
|
||||
print("📄 Loading passages from LEANN index...")
|
||||
passages, passage_ids = load_passages_from_leann_index(index_path)
|
||||
|
||||
# Compute embeddings
|
||||
print("🧮 Computing passage embeddings...")
|
||||
passage_embeddings = compute_embeddings_direct(passages, embedding_model)
|
||||
|
||||
print("🧮 Computing query embeddings...")
|
||||
query_embeddings = compute_embeddings_direct(queries, embedding_model)
|
||||
|
||||
# Build FAISS indexes
|
||||
print("🏗️ Building FAISS indexes...")
|
||||
flat_index, hnsw_index = build_faiss_indexes(passage_embeddings)
|
||||
|
||||
# Test different efSearch values (equivalent to LEANN complexity)
|
||||
print("\n📊 Evaluating Recall@3 at different efSearch values...")
|
||||
ef_search_values = [16, 32, 64, 128, 256]
|
||||
|
||||
for ef_search in ef_search_values:
|
||||
print(f"\n🧪 Testing efSearch = {ef_search}")
|
||||
start_time = time.time()
|
||||
|
||||
recall = evaluate_recall_at_k(
|
||||
query_embeddings, flat_index, hnsw_index, passage_ids, k=3, ef_search=ef_search
|
||||
)
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
print(
|
||||
f"📈 efSearch {ef_search}: Recall@3 = {recall:.3f} ({recall * 100:.1f}%) in {elapsed:.2f}s"
|
||||
)
|
||||
|
||||
print("\n✅ Verification completed!")
|
||||
print("\n📋 Summary:")
|
||||
print(" - Built independent FAISS Flat and HNSW indexes")
|
||||
print(" - Compared recall@3 at different efSearch values")
|
||||
print(" - Used same embedding model as LEANN")
|
||||
print(" - This validates LEANN's recall measurements")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
1
benchmarks/laion/.gitignore
vendored
1
benchmarks/laion/.gitignore
vendored
@@ -1 +0,0 @@
|
||||
data/
|
||||
@@ -1,199 +0,0 @@
|
||||
# LAION Multimodal Benchmark
|
||||
|
||||
A multimodal benchmark for evaluating image retrieval and generation performance using LEANN with CLIP embeddings and Qwen2.5-VL for multimodal generation on LAION dataset subset.
|
||||
|
||||
## Overview
|
||||
|
||||
This benchmark evaluates:
|
||||
- **Image retrieval timing** using caption-based queries
|
||||
- **Recall@K performance** for image search
|
||||
- **Complexity analysis** across different search parameters
|
||||
- **Index size and storage efficiency**
|
||||
- **Multimodal generation** with Qwen2.5-VL for image understanding and description
|
||||
|
||||
## Dataset Configuration
|
||||
|
||||
- **Dataset**: LAION-400M subset (10,000 images)
|
||||
- **Embeddings**: Pre-computed CLIP ViT-B/32 (512 dimensions)
|
||||
- **Queries**: 200 random captions from the dataset
|
||||
- **Ground Truth**: Self-recall (query caption → original image)
|
||||
|
||||
## Quick Start
|
||||
|
||||
### 1. Setup the benchmark
|
||||
|
||||
```bash
|
||||
cd benchmarks/laion
|
||||
python setup_laion.py --num-samples 10000 --num-queries 200
|
||||
```
|
||||
|
||||
This will:
|
||||
- Create dummy LAION data (10K samples)
|
||||
- Generate CLIP embeddings (512-dim)
|
||||
- Build LEANN index with HNSW backend
|
||||
- Create 200 evaluation queries
|
||||
|
||||
### 2. Run evaluation
|
||||
|
||||
```bash
|
||||
# Run all evaluation stages
|
||||
python evaluate_laion.py --index data/laion_index.leann
|
||||
|
||||
# Run specific stages
|
||||
python evaluate_laion.py --index data/laion_index.leann --stage 2 # Recall evaluation
|
||||
python evaluate_laion.py --index data/laion_index.leann --stage 3 # Complexity analysis
|
||||
python evaluate_laion.py --index data/laion_index.leann --stage 4 # Index comparison
|
||||
python evaluate_laion.py --index data/laion_index.leann --stage 5 # Multimodal generation
|
||||
|
||||
# Multimodal generation with Qwen2.5-VL
|
||||
python evaluate_laion.py --index data/laion_index.leann --stage 5 --model-name Qwen/Qwen2.5-VL-7B-Instruct
|
||||
```
|
||||
|
||||
### 3. Save results
|
||||
|
||||
```bash
|
||||
python evaluate_laion.py --index data/laion_index.leann --output results.json
|
||||
```
|
||||
|
||||
## Configuration Options
|
||||
|
||||
### Setup Options
|
||||
```bash
|
||||
python setup_laion.py \
|
||||
--num-samples 10000 \
|
||||
--num-queries 200 \
|
||||
--index-path data/laion_index.leann \
|
||||
--backend hnsw
|
||||
```
|
||||
|
||||
### Evaluation Options
|
||||
```bash
|
||||
python evaluate_laion.py \
|
||||
--index data/laion_index.leann \
|
||||
--queries data/evaluation_queries.jsonl \
|
||||
--complexity 64 \
|
||||
--top-k 3 \
|
||||
--num-samples 100 \
|
||||
--stage all
|
||||
```
|
||||
|
||||
## Evaluation Stages
|
||||
|
||||
### Stage 2: Recall Evaluation
|
||||
- Evaluates Recall@3 for multimodal retrieval
|
||||
- Compares LEANN vs FAISS baseline performance
|
||||
- Self-recall: query caption should retrieve original image
|
||||
|
||||
### Stage 3: Complexity Analysis
|
||||
- Binary search for optimal complexity (90% recall target)
|
||||
- Tests performance across different complexity levels
|
||||
- Analyzes speed vs. accuracy tradeoffs
|
||||
|
||||
### Stage 4: Index Comparison
|
||||
- Compares compact vs non-compact index sizes
|
||||
- Measures search performance differences
|
||||
- Reports storage efficiency and speed ratios
|
||||
|
||||
### Stage 5: Multimodal Generation
|
||||
- Uses Qwen2.5-VL for image understanding and description
|
||||
- Retrieval-Augmented Generation (RAG) with multimodal context
|
||||
- Measures both search and generation timing
|
||||
|
||||
## Output Metrics
|
||||
|
||||
### Timing Metrics
|
||||
- Average/median/min/max search time
|
||||
- Standard deviation
|
||||
- Searches per second
|
||||
- Latency in milliseconds
|
||||
|
||||
### Recall Metrics
|
||||
- Recall@3 percentage for image retrieval
|
||||
- Number of queries with ground truth
|
||||
|
||||
### Index Metrics
|
||||
- Total index size (MB)
|
||||
- Component breakdown (index, passages, metadata)
|
||||
- Storage savings (compact vs non-compact)
|
||||
- Backend and embedding model info
|
||||
|
||||
### Generation Metrics (Stage 5)
|
||||
- Average search time per query
|
||||
- Average generation time per query
|
||||
- Time distribution (search vs generation)
|
||||
- Sample multimodal responses
|
||||
- Model: Qwen2.5-VL performance
|
||||
|
||||
## Benchmark Results
|
||||
|
||||
### LEANN-RAG Performance (CLIP ViT-L/14 + Qwen2.5-VL)
|
||||
|
||||
**Stage 3: Optimal Complexity Analysis**
|
||||
- **Optimal Complexity**: 85 (achieving 90% Recall@3)
|
||||
- **Binary Search Range**: 1-128
|
||||
- **Target Recall**: 90%
|
||||
- **Index Type**: Non-compact (for fast binary search)
|
||||
|
||||
**Stage 5: Multimodal Generation Performance (Qwen2.5-VL)**
|
||||
- **Total Queries**: 20
|
||||
- **Average Search Time**: 1.200s per query
|
||||
- **Average Generation Time**: 6.558s per query
|
||||
- **Time Distribution**: Search 15.5%, Generation 84.5%
|
||||
- **LLM Backend**: HuggingFace transformers
|
||||
- **Model**: Qwen/Qwen2.5-VL-7B-Instruct
|
||||
- **Optimal Complexity**: 85
|
||||
|
||||
**System Performance:**
|
||||
- **Index Size**: ~10,000 image embeddings from LAION subset
|
||||
- **Embedding Model**: CLIP ViT-L/14 (768 dimensions)
|
||||
- **Backend**: HNSW with cosine distance
|
||||
|
||||
### Example Results
|
||||
|
||||
```
|
||||
🎯 LAION MULTIMODAL BENCHMARK RESULTS
|
||||
============================================================
|
||||
|
||||
📊 Multimodal Generation Results:
|
||||
Total Queries: 20
|
||||
Avg Search Time: 1.200s
|
||||
Avg Generation Time: 6.558s
|
||||
Time Distribution: Search 15.5%, Generation 84.5%
|
||||
LLM Backend: HuggingFace transformers
|
||||
Model: Qwen/Qwen2.5-VL-7B-Instruct
|
||||
|
||||
⚙️ Optimal Complexity Analysis:
|
||||
Target Recall: 90%
|
||||
Optimal Complexity: 85
|
||||
Binary Search Range: 1-128
|
||||
Non-compact Index (fast search, no recompute)
|
||||
|
||||
🚀 Performance Summary:
|
||||
Multimodal RAG: 7.758s total per query
|
||||
Search: 15.5% of total time
|
||||
Generation: 84.5% of total time
|
||||
```
|
||||
|
||||
## Directory Structure
|
||||
|
||||
```
|
||||
benchmarks/laion/
|
||||
├── setup_laion.py # Setup script
|
||||
├── evaluate_laion.py # Evaluation script
|
||||
├── README.md # This file
|
||||
└── data/ # Generated data
|
||||
├── laion_images/ # Image files (placeholder)
|
||||
├── laion_metadata.jsonl # Image metadata
|
||||
├── laion_passages.jsonl # LEANN passages
|
||||
├── laion_embeddings.npy # CLIP embeddings
|
||||
├── evaluation_queries.jsonl # Evaluation queries
|
||||
└── laion_index.leann/ # LEANN index files
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
- Current implementation uses dummy data for demonstration
|
||||
- For real LAION data, implement actual download logic in `setup_laion.py`
|
||||
- CLIP embeddings are randomly generated - replace with real CLIP model for production
|
||||
- Adjust `num_samples` and `num_queries` based on available resources
|
||||
- Consider using `--num-samples` during evaluation for faster testing
|
||||
@@ -1,725 +0,0 @@
|
||||
"""
|
||||
LAION Multimodal Benchmark Evaluation Script - Modular Recall-based Evaluation
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from leann import LeannSearcher
|
||||
from leann_backend_hnsw import faiss
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
from ..llm_utils import evaluate_multimodal_rag, load_qwen_vl_model
|
||||
|
||||
# Setup logging to reduce verbose output
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
logging.getLogger("leann.api").setLevel(logging.WARNING)
|
||||
logging.getLogger("leann_backend_hnsw").setLevel(logging.WARNING)
|
||||
|
||||
|
||||
class RecallEvaluator:
|
||||
"""Stage 2: Evaluate Recall@3 (LEANN vs FAISS baseline for multimodal retrieval)"""
|
||||
|
||||
def __init__(self, index_path: str, baseline_dir: str):
|
||||
self.index_path = index_path
|
||||
self.baseline_dir = baseline_dir
|
||||
self.searcher = LeannSearcher(index_path)
|
||||
|
||||
# Load FAISS flat baseline (image embeddings)
|
||||
baseline_index_path = os.path.join(baseline_dir, "faiss_flat.index")
|
||||
metadata_path = os.path.join(baseline_dir, "metadata.pkl")
|
||||
|
||||
self.faiss_index = faiss.read_index(baseline_index_path)
|
||||
with open(metadata_path, "rb") as f:
|
||||
self.image_ids = pickle.load(f)
|
||||
print(f"📚 Loaded FAISS flat baseline with {self.faiss_index.ntotal} image vectors")
|
||||
|
||||
# Load sentence-transformers CLIP for text embedding (ViT-L/14)
|
||||
self.st_clip = SentenceTransformer("clip-ViT-L-14")
|
||||
|
||||
def evaluate_recall_at_3(
|
||||
self, captions: list[str], complexity: int = 64, recompute_embeddings: bool = True
|
||||
) -> float:
|
||||
"""Evaluate recall@3 for multimodal retrieval: caption queries -> image results"""
|
||||
recompute_str = "with recompute" if recompute_embeddings else "no recompute"
|
||||
print(f"🔍 Evaluating recall@3 with complexity={complexity} ({recompute_str})...")
|
||||
|
||||
total_recall = 0.0
|
||||
num_queries = len(captions)
|
||||
|
||||
for i, caption in enumerate(captions):
|
||||
# Get ground truth: search with FAISS flat using caption text embedding
|
||||
# Generate CLIP text embedding for caption via sentence-transformers (normalized)
|
||||
query_embedding = self.st_clip.encode(
|
||||
[caption], convert_to_numpy=True, normalize_embeddings=True, show_progress_bar=False
|
||||
).astype(np.float32)
|
||||
|
||||
# Search FAISS flat for ground truth using LEANN's modified faiss API
|
||||
n = query_embedding.shape[0] # Number of queries
|
||||
k = 3 # Number of nearest neighbors
|
||||
distances = np.zeros((n, k), dtype=np.float32)
|
||||
labels = np.zeros((n, k), dtype=np.int64)
|
||||
|
||||
self.faiss_index.search(
|
||||
n,
|
||||
faiss.swig_ptr(query_embedding),
|
||||
k,
|
||||
faiss.swig_ptr(distances),
|
||||
faiss.swig_ptr(labels),
|
||||
)
|
||||
|
||||
# Extract the results (image IDs from FAISS)
|
||||
baseline_ids = {self.image_ids[idx] for idx in labels[0]}
|
||||
|
||||
# Search with LEANN at specified complexity (using caption as text query)
|
||||
test_results = self.searcher.search(
|
||||
caption,
|
||||
top_k=3,
|
||||
complexity=complexity,
|
||||
recompute_embeddings=recompute_embeddings,
|
||||
)
|
||||
test_ids = {result.id for result in test_results}
|
||||
|
||||
# Calculate recall@3 = |intersection| / |ground_truth|
|
||||
intersection = test_ids.intersection(baseline_ids)
|
||||
recall = len(intersection) / 3.0 # Ground truth size is 3
|
||||
total_recall += recall
|
||||
|
||||
if i < 3: # Show first few examples
|
||||
print(f" Query {i + 1}: '{caption[:50]}...' -> Recall@3: {recall:.3f}")
|
||||
print(f" FAISS ground truth: {list(baseline_ids)}")
|
||||
print(f" LEANN results (C={complexity}, {recompute_str}): {list(test_ids)}")
|
||||
print(f" Intersection: {list(intersection)}")
|
||||
|
||||
avg_recall = total_recall / num_queries
|
||||
print(f"📊 Average Recall@3: {avg_recall:.3f} ({avg_recall * 100:.1f}%)")
|
||||
return avg_recall
|
||||
|
||||
def cleanup(self):
|
||||
"""Cleanup resources"""
|
||||
if hasattr(self, "searcher"):
|
||||
self.searcher.cleanup()
|
||||
|
||||
|
||||
class LAIONEvaluator:
|
||||
def __init__(self, index_path: str):
|
||||
self.index_path = index_path
|
||||
self.searcher = LeannSearcher(index_path)
|
||||
|
||||
def load_queries(self, queries_file: str) -> list[str]:
|
||||
"""Load caption queries from evaluation file"""
|
||||
captions = []
|
||||
with open(queries_file, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
query_data = json.loads(line)
|
||||
captions.append(query_data["query"])
|
||||
|
||||
print(f"📊 Loaded {len(captions)} caption queries")
|
||||
return captions
|
||||
|
||||
def analyze_index_sizes(self) -> dict:
|
||||
"""Analyze index sizes, emphasizing .index only (exclude passages)."""
|
||||
print("📏 Analyzing index sizes (.index only)...")
|
||||
|
||||
# Get all index-related files
|
||||
index_path = Path(self.index_path)
|
||||
index_dir = index_path.parent
|
||||
index_name = index_path.stem # Remove .leann extension
|
||||
|
||||
sizes: dict[str, float] = {}
|
||||
|
||||
# Core index files
|
||||
index_file = index_dir / f"{index_name}.index"
|
||||
meta_file = index_dir / f"{index_path.name}.meta.json" # Keep .leann for meta file
|
||||
passages_file = index_dir / f"{index_path.name}.passages.jsonl" # Keep .leann for passages
|
||||
passages_idx_file = index_dir / f"{index_path.name}.passages.idx" # Keep .leann for idx
|
||||
|
||||
# Core index size (.index only)
|
||||
index_mb = index_file.stat().st_size / (1024 * 1024) if index_file.exists() else 0.0
|
||||
sizes["index_only_mb"] = index_mb
|
||||
|
||||
# Other files for reference (not counted in index_only_mb)
|
||||
sizes["metadata_mb"] = (
|
||||
meta_file.stat().st_size / (1024 * 1024) if meta_file.exists() else 0.0
|
||||
)
|
||||
sizes["passages_text_mb"] = (
|
||||
passages_file.stat().st_size / (1024 * 1024) if passages_file.exists() else 0.0
|
||||
)
|
||||
sizes["passages_index_mb"] = (
|
||||
passages_idx_file.stat().st_size / (1024 * 1024) if passages_idx_file.exists() else 0.0
|
||||
)
|
||||
|
||||
print(f" 📁 .index size: {index_mb:.1f} MB")
|
||||
if sizes["metadata_mb"]:
|
||||
print(f" 🧾 metadata: {sizes['metadata_mb']:.3f} MB")
|
||||
if sizes["passages_text_mb"] or sizes["passages_index_mb"]:
|
||||
print(
|
||||
f" (passages excluded) text: {sizes['passages_text_mb']:.1f} MB, idx: {sizes['passages_index_mb']:.1f} MB"
|
||||
)
|
||||
|
||||
return sizes
|
||||
|
||||
def create_non_compact_index_for_comparison(self, non_compact_index_path: str) -> dict:
|
||||
"""Create a non-compact index for comparison purposes"""
|
||||
print("🏗️ Building non-compact index from existing passages...")
|
||||
|
||||
# Load existing passages from current index
|
||||
from leann import LeannBuilder
|
||||
|
||||
current_index_path = Path(self.index_path)
|
||||
current_index_dir = current_index_path.parent
|
||||
current_index_name = current_index_path.name
|
||||
|
||||
# Read metadata to get passage source
|
||||
meta_path = current_index_dir / f"{current_index_name}.meta.json"
|
||||
with open(meta_path) as f:
|
||||
meta = json.load(f)
|
||||
|
||||
passage_source = meta["passage_sources"][0]
|
||||
passage_file = passage_source["path"]
|
||||
|
||||
# Convert relative path to absolute
|
||||
if not Path(passage_file).is_absolute():
|
||||
passage_file = current_index_dir / Path(passage_file).name
|
||||
|
||||
print(f"📄 Loading passages from {passage_file}...")
|
||||
|
||||
# Load CLIP embeddings
|
||||
embeddings_file = current_index_dir / "clip_image_embeddings.npy"
|
||||
embeddings = np.load(embeddings_file)
|
||||
print(f"📐 Loaded embeddings shape: {embeddings.shape}")
|
||||
|
||||
# Build non-compact index with same passages and embeddings
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
# Use CLIP text encoder (ViT-L/14) to match image embeddings (768-dim)
|
||||
embedding_model="clip-ViT-L-14",
|
||||
embedding_mode="sentence-transformers",
|
||||
is_recompute=False, # Disable recompute (store embeddings)
|
||||
is_compact=False, # Disable compact storage
|
||||
distance_metric="cosine",
|
||||
**{
|
||||
k: v
|
||||
for k, v in meta.get("backend_kwargs", {}).items()
|
||||
if k not in ["is_recompute", "is_compact", "distance_metric"]
|
||||
},
|
||||
)
|
||||
|
||||
# Prepare ids and add passages
|
||||
ids: list[str] = []
|
||||
with open(passage_file, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
data = json.loads(line)
|
||||
ids.append(str(data["id"]))
|
||||
# Ensure metadata contains the id used by the vector index
|
||||
metadata = {**data.get("metadata", {}), "id": data["id"]}
|
||||
builder.add_text(text=data["text"], metadata=metadata)
|
||||
|
||||
if len(ids) != embeddings.shape[0]:
|
||||
raise ValueError(
|
||||
f"IDs count ({len(ids)}) does not match embeddings ({embeddings.shape[0]})."
|
||||
)
|
||||
|
||||
# Persist a pickle for build_index_from_embeddings
|
||||
pkl_path = current_index_dir / "clip_image_embeddings.pkl"
|
||||
with open(pkl_path, "wb") as pf:
|
||||
pickle.dump((ids, embeddings.astype(np.float32)), pf)
|
||||
|
||||
print(
|
||||
f"🔨 Building non-compact index at {non_compact_index_path} from precomputed embeddings..."
|
||||
)
|
||||
builder.build_index_from_embeddings(non_compact_index_path, str(pkl_path))
|
||||
|
||||
# Analyze the non-compact index size
|
||||
temp_evaluator = LAIONEvaluator(non_compact_index_path)
|
||||
non_compact_sizes = temp_evaluator.analyze_index_sizes()
|
||||
non_compact_sizes["index_type"] = "non_compact"
|
||||
|
||||
return non_compact_sizes
|
||||
|
||||
def compare_index_performance(
|
||||
self, non_compact_path: str, compact_path: str, test_captions: list, complexity: int
|
||||
) -> dict:
|
||||
"""Compare performance between non-compact and compact indexes"""
|
||||
print("⚡ Comparing search performance between indexes...")
|
||||
|
||||
# Test queries
|
||||
test_queries = test_captions[:5]
|
||||
|
||||
results = {
|
||||
"non_compact": {"search_times": []},
|
||||
"compact": {"search_times": []},
|
||||
"avg_search_times": {},
|
||||
"speed_ratio": 0.0,
|
||||
}
|
||||
|
||||
# Test non-compact index (no recompute)
|
||||
print(" 🔍 Testing non-compact index (no recompute)...")
|
||||
non_compact_searcher = LeannSearcher(non_compact_path)
|
||||
|
||||
for caption in test_queries:
|
||||
start_time = time.time()
|
||||
_ = non_compact_searcher.search(
|
||||
caption, top_k=3, complexity=complexity, recompute_embeddings=False
|
||||
)
|
||||
search_time = time.time() - start_time
|
||||
results["non_compact"]["search_times"].append(search_time)
|
||||
|
||||
# Test compact index (with recompute)
|
||||
print(" 🔍 Testing compact index (with recompute)...")
|
||||
compact_searcher = LeannSearcher(compact_path)
|
||||
|
||||
for caption in test_queries:
|
||||
start_time = time.time()
|
||||
_ = compact_searcher.search(
|
||||
caption, top_k=3, complexity=complexity, recompute_embeddings=True
|
||||
)
|
||||
search_time = time.time() - start_time
|
||||
results["compact"]["search_times"].append(search_time)
|
||||
|
||||
# Calculate averages
|
||||
results["avg_search_times"]["non_compact"] = sum(
|
||||
results["non_compact"]["search_times"]
|
||||
) / len(results["non_compact"]["search_times"])
|
||||
results["avg_search_times"]["compact"] = sum(results["compact"]["search_times"]) / len(
|
||||
results["compact"]["search_times"]
|
||||
)
|
||||
|
||||
# Performance ratio
|
||||
if results["avg_search_times"]["compact"] > 0:
|
||||
results["speed_ratio"] = (
|
||||
results["avg_search_times"]["non_compact"] / results["avg_search_times"]["compact"]
|
||||
)
|
||||
else:
|
||||
results["speed_ratio"] = float("inf")
|
||||
|
||||
print(
|
||||
f" Non-compact (no recompute): {results['avg_search_times']['non_compact']:.3f}s avg"
|
||||
)
|
||||
print(f" Compact (with recompute): {results['avg_search_times']['compact']:.3f}s avg")
|
||||
print(f" Speed ratio: {results['speed_ratio']:.2f}x")
|
||||
|
||||
# Cleanup
|
||||
non_compact_searcher.cleanup()
|
||||
compact_searcher.cleanup()
|
||||
|
||||
return results
|
||||
|
||||
def _print_results(self, timing_metrics: dict):
|
||||
"""Print evaluation results"""
|
||||
print("\n🎯 LAION MULTIMODAL BENCHMARK RESULTS")
|
||||
print("=" * 60)
|
||||
|
||||
# Index comparison analysis (prefer .index-only view if present)
|
||||
if "current_index" in timing_metrics and "non_compact_index" in timing_metrics:
|
||||
current = timing_metrics["current_index"]
|
||||
non_compact = timing_metrics["non_compact_index"]
|
||||
|
||||
if "index_only_mb" in current and "index_only_mb" in non_compact:
|
||||
print("\n📏 Index Comparison Analysis (.index only):")
|
||||
print(f" Compact index (current): {current.get('index_only_mb', 0):.1f} MB")
|
||||
print(f" Non-compact index: {non_compact.get('index_only_mb', 0):.1f} MB")
|
||||
print(
|
||||
f" Storage saving by compact: {timing_metrics.get('storage_saving_percent', 0):.1f}%"
|
||||
)
|
||||
# Show excluded components for reference if available
|
||||
if any(
|
||||
k in non_compact
|
||||
for k in ("passages_text_mb", "passages_index_mb", "metadata_mb")
|
||||
):
|
||||
print(" (passages excluded in totals, shown for reference):")
|
||||
print(
|
||||
f" - Passages text: {non_compact.get('passages_text_mb', 0):.1f} MB, "
|
||||
f"Passages index: {non_compact.get('passages_index_mb', 0):.1f} MB, "
|
||||
f"Metadata: {non_compact.get('metadata_mb', 0):.3f} MB"
|
||||
)
|
||||
else:
|
||||
# Fallback to legacy totals if running with older metrics
|
||||
print("\n📏 Index Comparison Analysis:")
|
||||
print(
|
||||
f" Compact index (current): {current.get('total_with_embeddings', 0):.1f} MB"
|
||||
)
|
||||
print(
|
||||
f" Non-compact index (with embeddings): {non_compact.get('total_with_embeddings', 0):.1f} MB"
|
||||
)
|
||||
print(
|
||||
f" Storage saving by compact: {timing_metrics.get('storage_saving_percent', 0):.1f}%"
|
||||
)
|
||||
print(" Component breakdown (non-compact):")
|
||||
print(f" - Main index: {non_compact.get('index', 0):.1f} MB")
|
||||
print(f" - Passages text: {non_compact.get('passages_text', 0):.1f} MB")
|
||||
print(f" - Passages index: {non_compact.get('passages_index', 0):.1f} MB")
|
||||
print(f" - Metadata: {non_compact.get('metadata', 0):.1f} MB")
|
||||
|
||||
# Performance comparison
|
||||
if "performance_comparison" in timing_metrics:
|
||||
perf = timing_metrics["performance_comparison"]
|
||||
print("\n⚡ Performance Comparison:")
|
||||
print(
|
||||
f" Non-compact (no recompute): {perf.get('avg_search_times', {}).get('non_compact', 0):.3f}s avg"
|
||||
)
|
||||
print(
|
||||
f" Compact (with recompute): {perf.get('avg_search_times', {}).get('compact', 0):.3f}s avg"
|
||||
)
|
||||
print(f" Speed ratio: {perf.get('speed_ratio', 0):.2f}x")
|
||||
|
||||
# Legacy single index analysis (fallback)
|
||||
if "total_with_embeddings" in timing_metrics and "current_index" not in timing_metrics:
|
||||
print("\n📏 Index Size Analysis:")
|
||||
print(
|
||||
f" Index with embeddings: {timing_metrics.get('total_with_embeddings', 0):.1f} MB"
|
||||
)
|
||||
print(
|
||||
f" Estimated pruned index: {timing_metrics.get('total_without_embeddings', 0):.1f} MB"
|
||||
)
|
||||
print(f" Compression ratio: {timing_metrics.get('compression_ratio', 0):.2f}x")
|
||||
|
||||
def cleanup(self):
|
||||
"""Cleanup resources"""
|
||||
if self.searcher:
|
||||
self.searcher.cleanup()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="LAION Multimodal Benchmark Evaluation")
|
||||
parser.add_argument("--index", required=True, help="Path to LEANN index")
|
||||
parser.add_argument(
|
||||
"--queries", default="data/evaluation_queries.jsonl", help="Path to evaluation queries"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--stage",
|
||||
choices=["2", "3", "4", "5", "all"],
|
||||
default="all",
|
||||
help="Which stage to run (2=recall, 3=complexity, 4=index comparison, 5=generation)",
|
||||
)
|
||||
parser.add_argument("--complexity", type=int, default=None, help="Complexity for search")
|
||||
parser.add_argument("--baseline-dir", default="baseline", help="Baseline output directory")
|
||||
parser.add_argument("--output", help="Save results to JSON file")
|
||||
parser.add_argument(
|
||||
"--llm-backend",
|
||||
choices=["hf"],
|
||||
default="hf",
|
||||
help="LLM backend (Qwen2.5-VL only supports HF)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-name", default="Qwen/Qwen2.5-VL-7B-Instruct", help="Multimodal model name"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
# Check if baseline exists
|
||||
baseline_index_path = os.path.join(args.baseline_dir, "faiss_flat.index")
|
||||
if not os.path.exists(baseline_index_path):
|
||||
print(f"❌ FAISS baseline not found at {baseline_index_path}")
|
||||
print("💡 Please run setup_laion.py first to build the baseline")
|
||||
exit(1)
|
||||
|
||||
if args.stage == "2" or args.stage == "all":
|
||||
# Stage 2: Recall@3 evaluation
|
||||
print("🚀 Starting Stage 2: Recall@3 evaluation for multimodal retrieval")
|
||||
|
||||
evaluator = RecallEvaluator(args.index, args.baseline_dir)
|
||||
|
||||
# Load caption queries for testing
|
||||
laion_evaluator = LAIONEvaluator(args.index)
|
||||
captions = laion_evaluator.load_queries(args.queries)
|
||||
|
||||
# Test with queries for robust measurement
|
||||
test_captions = captions[:100] # Use subset for speed
|
||||
print(f"🧪 Testing with {len(test_captions)} caption queries")
|
||||
|
||||
# Test with complexity 64
|
||||
complexity = 64
|
||||
recall = evaluator.evaluate_recall_at_3(test_captions, complexity)
|
||||
print(f"📈 Recall@3 at complexity {complexity}: {recall * 100:.1f}%")
|
||||
|
||||
evaluator.cleanup()
|
||||
print("✅ Stage 2 completed!\n")
|
||||
|
||||
# Shared non-compact index path for Stage 3 and 4
|
||||
non_compact_index_path = args.index.replace(".leann", "_noncompact.leann")
|
||||
complexity = args.complexity
|
||||
|
||||
if args.stage == "3" or args.stage == "all":
|
||||
# Stage 3: Binary search for 90% recall complexity
|
||||
print("🚀 Starting Stage 3: Binary search for 90% recall complexity")
|
||||
print(
|
||||
"💡 Creating non-compact index for fast binary search with recompute_embeddings=False"
|
||||
)
|
||||
|
||||
# Create non-compact index for binary search
|
||||
print("🏗️ Creating non-compact index for binary search...")
|
||||
evaluator = LAIONEvaluator(args.index)
|
||||
evaluator.create_non_compact_index_for_comparison(non_compact_index_path)
|
||||
|
||||
# Use non-compact index for binary search
|
||||
binary_search_evaluator = RecallEvaluator(non_compact_index_path, args.baseline_dir)
|
||||
|
||||
# Load caption queries for testing
|
||||
captions = evaluator.load_queries(args.queries)
|
||||
|
||||
# Use subset for robust measurement
|
||||
test_captions = captions[:50] # Smaller subset for binary search speed
|
||||
print(f"🧪 Testing with {len(test_captions)} caption queries")
|
||||
|
||||
# Binary search for 90% recall complexity
|
||||
target_recall = 0.9
|
||||
min_complexity, max_complexity = 1, 128
|
||||
|
||||
print(f"🔍 Binary search for {target_recall * 100}% recall complexity...")
|
||||
print(f"Search range: {min_complexity} to {max_complexity}")
|
||||
|
||||
best_complexity = None
|
||||
best_recall = 0.0
|
||||
|
||||
while min_complexity <= max_complexity:
|
||||
mid_complexity = (min_complexity + max_complexity) // 2
|
||||
|
||||
print(
|
||||
f"\n🧪 Testing complexity {mid_complexity} (no recompute, non-compact index)..."
|
||||
)
|
||||
# Use recompute_embeddings=False on non-compact index for fast binary search
|
||||
recall = binary_search_evaluator.evaluate_recall_at_3(
|
||||
test_captions, mid_complexity, recompute_embeddings=False
|
||||
)
|
||||
|
||||
print(
|
||||
f" Complexity {mid_complexity}: Recall@3 = {recall:.3f} ({recall * 100:.1f}%)"
|
||||
)
|
||||
|
||||
if recall >= target_recall:
|
||||
best_complexity = mid_complexity
|
||||
best_recall = recall
|
||||
max_complexity = mid_complexity - 1
|
||||
print(" ✅ Target reached! Searching for lower complexity...")
|
||||
else:
|
||||
min_complexity = mid_complexity + 1
|
||||
print(" ❌ Below target. Searching for higher complexity...")
|
||||
|
||||
if best_complexity is not None:
|
||||
print("\n🎯 Optimal complexity found!")
|
||||
print(f" Complexity: {best_complexity}")
|
||||
print(f" Recall@3: {best_recall:.3f} ({best_recall * 100:.1f}%)")
|
||||
|
||||
# Test a few complexities around the optimal one for verification
|
||||
print("\n🔬 Verification test around optimal complexity:")
|
||||
verification_complexities = [
|
||||
max(1, best_complexity - 2),
|
||||
max(1, best_complexity - 1),
|
||||
best_complexity,
|
||||
best_complexity + 1,
|
||||
best_complexity + 2,
|
||||
]
|
||||
|
||||
for complexity in verification_complexities:
|
||||
if complexity <= 512: # reasonable upper bound
|
||||
recall = binary_search_evaluator.evaluate_recall_at_3(
|
||||
test_captions, complexity, recompute_embeddings=False
|
||||
)
|
||||
status = "✅" if recall >= target_recall else "❌"
|
||||
print(f" {status} Complexity {complexity:3d}: {recall * 100:5.1f}%")
|
||||
|
||||
# Now test the optimal complexity with compact index and recompute for comparison
|
||||
print(
|
||||
f"\n🔄 Testing optimal complexity {best_complexity} on compact index WITH recompute..."
|
||||
)
|
||||
compact_evaluator = RecallEvaluator(args.index, args.baseline_dir)
|
||||
recall_with_recompute = compact_evaluator.evaluate_recall_at_3(
|
||||
test_captions[:10], best_complexity, recompute_embeddings=True
|
||||
)
|
||||
print(
|
||||
f" ✅ Complexity {best_complexity} (compact index with recompute): {recall_with_recompute * 100:.1f}%"
|
||||
)
|
||||
complexity = best_complexity
|
||||
print(
|
||||
f" 📊 Recall difference: {abs(best_recall - recall_with_recompute) * 100:.2f}%"
|
||||
)
|
||||
compact_evaluator.cleanup()
|
||||
else:
|
||||
print(f"\n❌ Could not find complexity achieving {target_recall * 100}% recall")
|
||||
print("All tested complexities were below target.")
|
||||
|
||||
# Cleanup evaluators (keep non-compact index for Stage 4)
|
||||
binary_search_evaluator.cleanup()
|
||||
evaluator.cleanup()
|
||||
|
||||
print("✅ Stage 3 completed! Non-compact index saved for Stage 4.\n")
|
||||
|
||||
if args.stage == "4" or args.stage == "all":
|
||||
# Stage 4: Index comparison (without LLM generation)
|
||||
print("🚀 Starting Stage 4: Index comparison analysis")
|
||||
|
||||
# Use LAION evaluator for index comparison
|
||||
evaluator = LAIONEvaluator(args.index)
|
||||
|
||||
# Load caption queries
|
||||
captions = evaluator.load_queries(args.queries)
|
||||
|
||||
# Step 1: Analyze current (compact) index
|
||||
print("\n📏 Analyzing current index (compact, pruned)...")
|
||||
compact_size_metrics = evaluator.analyze_index_sizes()
|
||||
compact_size_metrics["index_type"] = "compact"
|
||||
|
||||
# Step 2: Use existing non-compact index or create if needed
|
||||
if Path(non_compact_index_path).exists():
|
||||
print(
|
||||
f"\n📁 Using existing non-compact index from Stage 3: {non_compact_index_path}"
|
||||
)
|
||||
temp_evaluator = LAIONEvaluator(non_compact_index_path)
|
||||
non_compact_size_metrics = temp_evaluator.analyze_index_sizes()
|
||||
non_compact_size_metrics["index_type"] = "non_compact"
|
||||
else:
|
||||
print("\n🏗️ Creating non-compact index (with embeddings) for comparison...")
|
||||
non_compact_size_metrics = evaluator.create_non_compact_index_for_comparison(
|
||||
non_compact_index_path
|
||||
)
|
||||
|
||||
# Step 3: Compare index sizes (.index only)
|
||||
print("\n📊 Index size comparison (.index only):")
|
||||
print(
|
||||
f" Compact index (current): {compact_size_metrics.get('index_only_mb', 0):.1f} MB"
|
||||
)
|
||||
print(f" Non-compact index: {non_compact_size_metrics.get('index_only_mb', 0):.1f} MB")
|
||||
|
||||
storage_saving = 0.0
|
||||
if non_compact_size_metrics.get("index_only_mb", 0) > 0:
|
||||
storage_saving = (
|
||||
(
|
||||
non_compact_size_metrics.get("index_only_mb", 0)
|
||||
- compact_size_metrics.get("index_only_mb", 0)
|
||||
)
|
||||
/ non_compact_size_metrics.get("index_only_mb", 1)
|
||||
* 100
|
||||
)
|
||||
print(f" Storage saving by compact: {storage_saving:.1f}%")
|
||||
|
||||
# Step 4: Performance comparison between the two indexes
|
||||
if complexity is None:
|
||||
raise ValueError("Complexity is required for index comparison")
|
||||
|
||||
print("\n⚡ Performance comparison between indexes...")
|
||||
performance_metrics = evaluator.compare_index_performance(
|
||||
non_compact_index_path, args.index, captions[:10], complexity=complexity
|
||||
)
|
||||
|
||||
# Combine all metrics
|
||||
combined_metrics = {
|
||||
"current_index": compact_size_metrics,
|
||||
"non_compact_index": non_compact_size_metrics,
|
||||
"performance_comparison": performance_metrics,
|
||||
"storage_saving_percent": storage_saving,
|
||||
}
|
||||
|
||||
# Print comprehensive results
|
||||
evaluator._print_results(combined_metrics)
|
||||
|
||||
# Save results if requested
|
||||
if args.output:
|
||||
print(f"\n💾 Saving results to {args.output}...")
|
||||
with open(args.output, "w") as f:
|
||||
json.dump(combined_metrics, f, indent=2, default=str)
|
||||
print(f"✅ Results saved to {args.output}")
|
||||
|
||||
evaluator.cleanup()
|
||||
print("✅ Stage 4 completed!\n")
|
||||
|
||||
if args.stage in ("5", "all"):
|
||||
print("🚀 Starting Stage 5: Multimodal generation with Qwen2.5-VL")
|
||||
evaluator = LAIONEvaluator(args.index)
|
||||
captions = evaluator.load_queries(args.queries)
|
||||
test_captions = captions[: min(20, len(captions))] # Use subset for generation
|
||||
|
||||
print(f"🧪 Testing multimodal generation with {len(test_captions)} queries")
|
||||
|
||||
# Load Qwen2.5-VL model
|
||||
try:
|
||||
print("Loading Qwen2.5-VL model...")
|
||||
processor, model = load_qwen_vl_model(args.model_name)
|
||||
|
||||
# Run multimodal generation evaluation
|
||||
complexity = args.complexity or 64
|
||||
gen_results = evaluate_multimodal_rag(
|
||||
evaluator.searcher,
|
||||
test_captions,
|
||||
processor=processor,
|
||||
model=model,
|
||||
complexity=complexity,
|
||||
)
|
||||
|
||||
print("\n📊 Multimodal Generation Results:")
|
||||
print(f" Total Queries: {len(test_captions)}")
|
||||
print(f" Avg Search Time: {gen_results['avg_search_time']:.3f}s")
|
||||
print(f" Avg Generation Time: {gen_results['avg_generation_time']:.3f}s")
|
||||
total_time = gen_results["avg_search_time"] + gen_results["avg_generation_time"]
|
||||
search_pct = (gen_results["avg_search_time"] / total_time) * 100
|
||||
gen_pct = (gen_results["avg_generation_time"] / total_time) * 100
|
||||
print(f" Time Distribution: Search {search_pct:.1f}%, Generation {gen_pct:.1f}%")
|
||||
print(" LLM Backend: HuggingFace transformers")
|
||||
print(f" Model: {args.model_name}")
|
||||
|
||||
# Show sample results
|
||||
print("\n📝 Sample Multimodal Generations:")
|
||||
for i, response in enumerate(gen_results["results"][:3]):
|
||||
# Handle both string and dict formats for captions
|
||||
if isinstance(test_captions[i], dict):
|
||||
caption_text = test_captions[i].get("query", str(test_captions[i]))
|
||||
else:
|
||||
caption_text = str(test_captions[i])
|
||||
print(f" Query {i + 1}: {caption_text[:60]}...")
|
||||
print(f" Response {i + 1}: {response[:100]}...")
|
||||
print()
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Multimodal generation evaluation failed: {e}")
|
||||
print("💡 Make sure transformers and Qwen2.5-VL are installed")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
evaluator.cleanup()
|
||||
print("✅ Stage 5 completed!\n")
|
||||
|
||||
if args.stage == "all":
|
||||
print("🎉 All evaluation stages completed successfully!")
|
||||
print("\n📋 Summary:")
|
||||
print(" Stage 2: ✅ Multimodal Recall@3 evaluation completed")
|
||||
print(" Stage 3: ✅ Optimal complexity found")
|
||||
print(" Stage 4: ✅ Index comparison analysis completed")
|
||||
print(" Stage 5: ✅ Multimodal generation evaluation completed")
|
||||
print("\n🔧 Recommended next steps:")
|
||||
print(" - Use optimal complexity for best speed/accuracy balance")
|
||||
print(" - Review index comparison for storage vs performance tradeoffs")
|
||||
|
||||
# Clean up non-compact index after all stages complete
|
||||
print("\n🧹 Cleaning up temporary non-compact index...")
|
||||
if Path(non_compact_index_path).exists():
|
||||
temp_index_dir = Path(non_compact_index_path).parent
|
||||
temp_index_name = Path(non_compact_index_path).name
|
||||
for temp_file in temp_index_dir.glob(f"{temp_index_name}*"):
|
||||
temp_file.unlink()
|
||||
print(f"✅ Cleaned up {non_compact_index_path}")
|
||||
else:
|
||||
print("📝 No temporary index to clean up")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n⚠️ Evaluation interrupted by user")
|
||||
exit(1)
|
||||
except Exception as e:
|
||||
print(f"\n❌ Stage {args.stage} failed: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,576 +0,0 @@
|
||||
"""
|
||||
LAION Multimodal Benchmark Setup Script
|
||||
Downloads LAION subset and builds LEANN index with sentence embeddings
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import aiohttp
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
from leann import LeannBuilder
|
||||
from PIL import Image
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class LAIONSetup:
|
||||
def __init__(self, data_dir: str = "data"):
|
||||
self.data_dir = Path(data_dir)
|
||||
self.images_dir = self.data_dir / "laion_images"
|
||||
self.metadata_file = self.data_dir / "laion_metadata.jsonl"
|
||||
|
||||
# Create directories
|
||||
self.data_dir.mkdir(exist_ok=True)
|
||||
self.images_dir.mkdir(exist_ok=True)
|
||||
|
||||
async def download_single_image(self, session, sample_data, semaphore, progress_bar):
|
||||
"""Download a single image asynchronously"""
|
||||
async with semaphore: # Limit concurrent downloads
|
||||
try:
|
||||
image_url = sample_data["url"]
|
||||
image_path = sample_data["image_path"]
|
||||
|
||||
# Skip if already exists
|
||||
if os.path.exists(image_path):
|
||||
progress_bar.update(1)
|
||||
return sample_data
|
||||
|
||||
async with session.get(image_url, timeout=10) as response:
|
||||
if response.status == 200:
|
||||
content = await response.read()
|
||||
|
||||
# Verify it's a valid image
|
||||
try:
|
||||
img = Image.open(io.BytesIO(content))
|
||||
img = img.convert("RGB")
|
||||
img.save(image_path, "JPEG")
|
||||
progress_bar.update(1)
|
||||
return sample_data
|
||||
except Exception:
|
||||
progress_bar.update(1)
|
||||
return None # Skip invalid images
|
||||
else:
|
||||
progress_bar.update(1)
|
||||
return None
|
||||
|
||||
except Exception:
|
||||
progress_bar.update(1)
|
||||
return None
|
||||
|
||||
def download_laion_subset(self, num_samples: int = 1000):
|
||||
"""Download LAION subset from HuggingFace datasets with async parallel downloading"""
|
||||
print(f"📥 Downloading LAION subset ({num_samples} samples)...")
|
||||
|
||||
# Load LAION-400M subset from HuggingFace
|
||||
print("🤗 Loading from HuggingFace datasets...")
|
||||
dataset = load_dataset("laion/laion400m", split="train", streaming=True)
|
||||
|
||||
# Collect sample metadata first (fast)
|
||||
print("📋 Collecting sample metadata...")
|
||||
candidates = []
|
||||
for sample in dataset:
|
||||
if len(candidates) >= num_samples * 3: # Get 3x more candidates in case some fail
|
||||
break
|
||||
|
||||
image_url = sample.get("url", "")
|
||||
caption = sample.get("caption", "")
|
||||
|
||||
if not image_url or not caption:
|
||||
continue
|
||||
|
||||
image_filename = f"laion_{len(candidates):06d}.jpg"
|
||||
image_path = self.images_dir / image_filename
|
||||
|
||||
candidate = {
|
||||
"id": f"laion_{len(candidates):06d}",
|
||||
"url": image_url,
|
||||
"caption": caption,
|
||||
"image_path": str(image_path),
|
||||
"width": sample.get("original_width", 512),
|
||||
"height": sample.get("original_height", 512),
|
||||
"similarity": sample.get("similarity", 0.0),
|
||||
}
|
||||
candidates.append(candidate)
|
||||
|
||||
print(
|
||||
f"📊 Collected {len(candidates)} candidates, downloading {num_samples} in parallel..."
|
||||
)
|
||||
|
||||
# Download images in parallel
|
||||
async def download_batch():
|
||||
semaphore = asyncio.Semaphore(20) # Limit to 20 concurrent downloads
|
||||
connector = aiohttp.TCPConnector(limit=100, limit_per_host=20)
|
||||
timeout = aiohttp.ClientTimeout(total=30)
|
||||
|
||||
progress_bar = tqdm(total=len(candidates[: num_samples * 2]), desc="Downloading images")
|
||||
|
||||
async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session:
|
||||
tasks = []
|
||||
for candidate in candidates[: num_samples * 2]: # Try 2x more than needed
|
||||
task = self.download_single_image(session, candidate, semaphore, progress_bar)
|
||||
tasks.append(task)
|
||||
|
||||
# Wait for all downloads
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
progress_bar.close()
|
||||
|
||||
# Filter successful downloads
|
||||
successful = [r for r in results if r is not None and not isinstance(r, Exception)]
|
||||
return successful[:num_samples]
|
||||
|
||||
# Run async download
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
samples = loop.run_until_complete(download_batch())
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
# Save metadata
|
||||
with open(self.metadata_file, "w", encoding="utf-8") as f:
|
||||
for sample in samples:
|
||||
f.write(json.dumps(sample) + "\n")
|
||||
|
||||
print(f"✅ Downloaded {len(samples)} real LAION samples with async parallel downloading")
|
||||
return samples
|
||||
|
||||
def generate_clip_image_embeddings(self, samples: list[dict]):
|
||||
"""Generate CLIP image embeddings for downloaded images"""
|
||||
print("🔍 Generating CLIP image embeddings...")
|
||||
|
||||
# Load sentence-transformers CLIP (ViT-L/14, 768-dim) for image embeddings
|
||||
# This single model can encode both images and text.
|
||||
model = SentenceTransformer("clip-ViT-L-14")
|
||||
|
||||
embeddings = []
|
||||
valid_samples = []
|
||||
|
||||
for sample in tqdm(samples, desc="Processing images"):
|
||||
try:
|
||||
# Load image
|
||||
image_path = sample["image_path"]
|
||||
image = Image.open(image_path).convert("RGB")
|
||||
|
||||
# Encode image to 768-dim embedding via sentence-transformers (normalized)
|
||||
vec = model.encode(
|
||||
[image],
|
||||
convert_to_numpy=True,
|
||||
normalize_embeddings=True,
|
||||
batch_size=1,
|
||||
show_progress_bar=False,
|
||||
)[0]
|
||||
embeddings.append(vec.astype(np.float32))
|
||||
valid_samples.append(sample)
|
||||
|
||||
except Exception as e:
|
||||
print(f" ⚠️ Failed to process {sample['id']}: {e}")
|
||||
# Skip invalid images
|
||||
|
||||
embeddings = np.array(embeddings, dtype=np.float32)
|
||||
|
||||
# Save embeddings
|
||||
embeddings_file = self.data_dir / "clip_image_embeddings.npy"
|
||||
np.save(embeddings_file, embeddings)
|
||||
print(f"✅ Generated {len(embeddings)} image embeddings, shape: {embeddings.shape}")
|
||||
|
||||
return embeddings, valid_samples
|
||||
|
||||
def build_faiss_baseline(
|
||||
self, embeddings: np.ndarray, samples: list[dict], output_dir: str = "baseline"
|
||||
):
|
||||
"""Build FAISS flat baseline using CLIP image embeddings"""
|
||||
print("🔨 Building FAISS Flat baseline...")
|
||||
|
||||
from leann_backend_hnsw import faiss
|
||||
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
baseline_path = os.path.join(output_dir, "faiss_flat.index")
|
||||
metadata_path = os.path.join(output_dir, "metadata.pkl")
|
||||
|
||||
if os.path.exists(baseline_path) and os.path.exists(metadata_path):
|
||||
print(f"✅ Baseline already exists at {baseline_path}")
|
||||
return baseline_path
|
||||
|
||||
# Extract image IDs (must be present)
|
||||
if not samples or "id" not in samples[0]:
|
||||
raise KeyError("samples missing 'id' field for FAISS baseline")
|
||||
image_ids: list[str] = [str(sample["id"]) for sample in samples]
|
||||
|
||||
print(f"📐 Embedding shape: {embeddings.shape}")
|
||||
print(f"📄 Processing {len(image_ids)} images")
|
||||
|
||||
# Build FAISS flat index
|
||||
print("🏗️ Building FAISS IndexFlatIP...")
|
||||
dimension = embeddings.shape[1]
|
||||
index = faiss.IndexFlatIP(dimension)
|
||||
|
||||
# Add embeddings to flat index
|
||||
embeddings_f32 = embeddings.astype(np.float32)
|
||||
index.add(embeddings_f32.shape[0], faiss.swig_ptr(embeddings_f32))
|
||||
|
||||
# Save index and metadata
|
||||
faiss.write_index(index, baseline_path)
|
||||
with open(metadata_path, "wb") as f:
|
||||
pickle.dump(image_ids, f)
|
||||
|
||||
print(f"✅ FAISS baseline saved to {baseline_path}")
|
||||
print(f"✅ Metadata saved to {metadata_path}")
|
||||
print(f"📊 Total vectors: {index.ntotal}")
|
||||
|
||||
return baseline_path
|
||||
|
||||
def create_leann_passages(self, samples: list[dict]):
|
||||
"""Create LEANN-compatible passages from LAION data"""
|
||||
print("📝 Creating LEANN passages...")
|
||||
|
||||
passages_file = self.data_dir / "laion_passages.jsonl"
|
||||
|
||||
with open(passages_file, "w", encoding="utf-8") as f:
|
||||
for i, sample in enumerate(samples):
|
||||
passage = {
|
||||
"id": sample["id"],
|
||||
"text": sample["caption"], # Use caption as searchable text
|
||||
"metadata": {
|
||||
"image_url": sample["url"],
|
||||
"image_path": sample.get("image_path", ""),
|
||||
"width": sample["width"],
|
||||
"height": sample["height"],
|
||||
"similarity": sample["similarity"],
|
||||
"image_index": i, # Index for embedding lookup
|
||||
},
|
||||
}
|
||||
f.write(json.dumps(passage) + "\n")
|
||||
|
||||
print(f"✅ Created {len(samples)} passages")
|
||||
return passages_file
|
||||
|
||||
def build_compact_index(
|
||||
self, passages_file: Path, embeddings: np.ndarray, index_path: str, backend: str = "hnsw"
|
||||
):
|
||||
"""Build compact LEANN index with CLIP embeddings (recompute=True, compact=True)"""
|
||||
print(f"🏗️ Building compact LEANN index with {backend} backend...")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Save CLIP embeddings (npy) and also a pickle with (ids, embeddings)
|
||||
npy_path = self.data_dir / "clip_image_embeddings.npy"
|
||||
np.save(npy_path, embeddings)
|
||||
print(f"💾 Saved CLIP embeddings to {npy_path}")
|
||||
|
||||
# Prepare ids in the same order as passages_file (matches embeddings order)
|
||||
ids: list[str] = []
|
||||
with open(passages_file, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
rec = json.loads(line)
|
||||
ids.append(str(rec["id"]))
|
||||
|
||||
if len(ids) != embeddings.shape[0]:
|
||||
raise ValueError(
|
||||
f"IDs count ({len(ids)}) does not match embeddings ({embeddings.shape[0]})."
|
||||
)
|
||||
|
||||
pkl_path = self.data_dir / "clip_image_embeddings.pkl"
|
||||
with open(pkl_path, "wb") as pf:
|
||||
pickle.dump((ids, embeddings.astype(np.float32)), pf)
|
||||
print(f"💾 Saved (ids, embeddings) pickle to {pkl_path}")
|
||||
|
||||
# Initialize builder - compact with recompute
|
||||
# Note: For multimodal case, we need to handle embeddings differently
|
||||
# Let's try using sentence-transformers mode but with custom embeddings
|
||||
builder = LeannBuilder(
|
||||
backend_name=backend,
|
||||
# Use CLIP text encoder (ViT-L/14) to match image space (768-dim)
|
||||
embedding_model="clip-ViT-L-14",
|
||||
embedding_mode="sentence-transformers",
|
||||
# HNSW params (or forwarded to chosen backend)
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
# Compact/pruned with recompute at query time
|
||||
is_recompute=True,
|
||||
is_compact=True,
|
||||
distance_metric="cosine", # CLIP uses normalized vectors; cosine is appropriate
|
||||
num_threads=4,
|
||||
)
|
||||
|
||||
# Add passages (text + metadata)
|
||||
print("📚 Adding passages...")
|
||||
self._add_passages_with_embeddings(builder, passages_file, embeddings)
|
||||
|
||||
print(f"🔨 Building compact index at {index_path} from precomputed embeddings...")
|
||||
builder.build_index_from_embeddings(index_path, str(pkl_path))
|
||||
|
||||
build_time = time.time() - start_time
|
||||
print(f"✅ Compact index built in {build_time:.2f}s")
|
||||
|
||||
# Analyze index size
|
||||
self._analyze_index_size(index_path)
|
||||
|
||||
return index_path
|
||||
|
||||
def build_non_compact_index(
|
||||
self, passages_file: Path, embeddings: np.ndarray, index_path: str, backend: str = "hnsw"
|
||||
):
|
||||
"""Build non-compact LEANN index with CLIP embeddings (recompute=False, compact=False)"""
|
||||
print(f"🏗️ Building non-compact LEANN index with {backend} backend...")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Ensure embeddings are saved (npy + pickle)
|
||||
npy_path = self.data_dir / "clip_image_embeddings.npy"
|
||||
if not npy_path.exists():
|
||||
np.save(npy_path, embeddings)
|
||||
print(f"💾 Saved CLIP embeddings to {npy_path}")
|
||||
# Prepare ids in same order as passages_file
|
||||
ids: list[str] = []
|
||||
with open(passages_file, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
rec = json.loads(line)
|
||||
ids.append(str(rec["id"]))
|
||||
if len(ids) != embeddings.shape[0]:
|
||||
raise ValueError(
|
||||
f"IDs count ({len(ids)}) does not match embeddings ({embeddings.shape[0]})."
|
||||
)
|
||||
pkl_path = self.data_dir / "clip_image_embeddings.pkl"
|
||||
if not pkl_path.exists():
|
||||
with open(pkl_path, "wb") as pf:
|
||||
pickle.dump((ids, embeddings.astype(np.float32)), pf)
|
||||
print(f"💾 Saved (ids, embeddings) pickle to {pkl_path}")
|
||||
|
||||
# Initialize builder - non-compact without recompute
|
||||
builder = LeannBuilder(
|
||||
backend_name=backend,
|
||||
embedding_model="clip-ViT-L-14",
|
||||
embedding_mode="sentence-transformers",
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_recompute=False, # Store embeddings (no recompute needed)
|
||||
is_compact=False, # Store full index (not pruned)
|
||||
distance_metric="cosine",
|
||||
num_threads=4,
|
||||
)
|
||||
|
||||
# Add passages - embeddings will be loaded from file
|
||||
print("📚 Adding passages...")
|
||||
self._add_passages_with_embeddings(builder, passages_file, embeddings)
|
||||
|
||||
print(f"🔨 Building non-compact index at {index_path} from precomputed embeddings...")
|
||||
builder.build_index_from_embeddings(index_path, str(pkl_path))
|
||||
|
||||
build_time = time.time() - start_time
|
||||
print(f"✅ Non-compact index built in {build_time:.2f}s")
|
||||
|
||||
# Analyze index size
|
||||
self._analyze_index_size(index_path)
|
||||
|
||||
return index_path
|
||||
|
||||
def _add_passages_with_embeddings(self, builder, passages_file: Path, embeddings: np.ndarray):
|
||||
"""Helper to add passages with pre-computed CLIP embeddings"""
|
||||
with open(passages_file, encoding="utf-8") as f:
|
||||
for line in tqdm(f, desc="Adding passages"):
|
||||
if line.strip():
|
||||
passage = json.loads(line)
|
||||
|
||||
# Add image metadata - LEANN will handle embeddings separately
|
||||
# Note: We store image metadata and caption text for searchability
|
||||
# Important: ensure passage ID in metadata matches vector ID
|
||||
builder.add_text(
|
||||
text=passage["text"], # Image caption for searchability
|
||||
metadata={**passage["metadata"], "id": passage["id"]},
|
||||
)
|
||||
|
||||
def _analyze_index_size(self, index_path: str):
|
||||
"""Analyze index file sizes"""
|
||||
print("📏 Analyzing index sizes...")
|
||||
|
||||
index_path = Path(index_path)
|
||||
index_dir = index_path.parent
|
||||
index_name = index_path.name # e.g., laion_index.leann
|
||||
index_prefix = index_path.stem # e.g., laion_index
|
||||
|
||||
files = [
|
||||
(f"{index_prefix}.index", ".index", "core"),
|
||||
(f"{index_name}.meta.json", ".meta.json", "core"),
|
||||
(f"{index_name}.ids.txt", ".ids.txt", "core"),
|
||||
(f"{index_name}.passages.jsonl", ".passages.jsonl", "passages"),
|
||||
(f"{index_name}.passages.idx", ".passages.idx", "passages"),
|
||||
]
|
||||
|
||||
def _fmt_size(bytes_val: int) -> str:
|
||||
if bytes_val < 1024:
|
||||
return f"{bytes_val} B"
|
||||
kb = bytes_val / 1024
|
||||
if kb < 1024:
|
||||
return f"{kb:.1f} KB"
|
||||
mb = kb / 1024
|
||||
if mb < 1024:
|
||||
return f"{mb:.2f} MB"
|
||||
gb = mb / 1024
|
||||
return f"{gb:.2f} GB"
|
||||
|
||||
total_index_only_mb = 0.0
|
||||
total_all_mb = 0.0
|
||||
for filename, label, group in files:
|
||||
file_path = index_dir / filename
|
||||
if file_path.exists():
|
||||
size_bytes = file_path.stat().st_size
|
||||
print(f" {label}: {_fmt_size(size_bytes)}")
|
||||
size_mb = size_bytes / (1024 * 1024)
|
||||
total_all_mb += size_mb
|
||||
if group == "core":
|
||||
total_index_only_mb += size_mb
|
||||
else:
|
||||
print(f" {label}: (missing)")
|
||||
print(f" Total (index only, exclude passages): {total_index_only_mb:.2f} MB")
|
||||
print(f" Total (including passages): {total_all_mb:.2f} MB")
|
||||
|
||||
def create_evaluation_queries(self, samples: list[dict], num_queries: int = 200):
|
||||
"""Create evaluation queries from captions"""
|
||||
print(f"📝 Creating {num_queries} evaluation queries...")
|
||||
|
||||
# Sample random captions as queries
|
||||
import random
|
||||
|
||||
random.seed(42) # For reproducibility
|
||||
|
||||
query_samples = random.sample(samples, min(num_queries, len(samples)))
|
||||
|
||||
queries_file = self.data_dir / "evaluation_queries.jsonl"
|
||||
with open(queries_file, "w", encoding="utf-8") as f:
|
||||
for sample in query_samples:
|
||||
query = {
|
||||
"id": sample["id"],
|
||||
"query": sample["caption"],
|
||||
"ground_truth_id": sample["id"], # For potential recall evaluation
|
||||
}
|
||||
f.write(json.dumps(query) + "\n")
|
||||
|
||||
print(f"✅ Created {len(query_samples)} evaluation queries")
|
||||
return queries_file
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Setup LAION Multimodal Benchmark")
|
||||
parser.add_argument("--data-dir", default="data", help="Data directory")
|
||||
parser.add_argument("--num-samples", type=int, default=1000, help="Number of LAION samples")
|
||||
parser.add_argument("--num-queries", type=int, default=50, help="Number of evaluation queries")
|
||||
parser.add_argument("--index-path", default="data/laion_index.leann", help="Output index path")
|
||||
parser.add_argument(
|
||||
"--backend", default="hnsw", choices=["hnsw", "diskann"], help="LEANN backend"
|
||||
)
|
||||
parser.add_argument("--skip-download", action="store_true", help="Skip LAION dataset download")
|
||||
parser.add_argument("--skip-build", action="store_true", help="Skip index building")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print("🚀 Setting up LAION Multimodal Benchmark")
|
||||
print("=" * 50)
|
||||
|
||||
try:
|
||||
# Initialize setup
|
||||
setup = LAIONSetup(args.data_dir)
|
||||
|
||||
# Step 1: Download LAION subset
|
||||
if not args.skip_download:
|
||||
print("\n📦 Step 1: Download LAION subset")
|
||||
samples = setup.download_laion_subset(args.num_samples)
|
||||
|
||||
# Step 2: Generate CLIP image embeddings
|
||||
print("\n🔍 Step 2: Generate CLIP image embeddings")
|
||||
embeddings, valid_samples = setup.generate_clip_image_embeddings(samples)
|
||||
|
||||
# Step 3: Create LEANN passages (image metadata with embeddings)
|
||||
print("\n📝 Step 3: Create LEANN passages")
|
||||
passages_file = setup.create_leann_passages(valid_samples)
|
||||
else:
|
||||
print("⏭️ Skipping LAION dataset download")
|
||||
# Load existing data
|
||||
passages_file = setup.data_dir / "laion_passages.jsonl"
|
||||
embeddings_file = setup.data_dir / "clip_image_embeddings.npy"
|
||||
|
||||
if not passages_file.exists() or not embeddings_file.exists():
|
||||
raise FileNotFoundError(
|
||||
"Passages or embeddings file not found. Run without --skip-download first."
|
||||
)
|
||||
|
||||
embeddings = np.load(embeddings_file)
|
||||
print(f"📊 Loaded {len(embeddings)} embeddings from {embeddings_file}")
|
||||
|
||||
# Step 4: Build LEANN indexes (both compact and non-compact)
|
||||
if not args.skip_build:
|
||||
print("\n🏗️ Step 4: Build LEANN indexes with CLIP image embeddings")
|
||||
|
||||
# Build compact index (production mode - small, recompute required)
|
||||
compact_index_path = args.index_path
|
||||
print(f"Building compact index: {compact_index_path}")
|
||||
setup.build_compact_index(passages_file, embeddings, compact_index_path, args.backend)
|
||||
|
||||
# Build non-compact index (comparison mode - large, fast search)
|
||||
non_compact_index_path = args.index_path.replace(".leann", "_noncompact.leann")
|
||||
print(f"Building non-compact index: {non_compact_index_path}")
|
||||
setup.build_non_compact_index(
|
||||
passages_file, embeddings, non_compact_index_path, args.backend
|
||||
)
|
||||
|
||||
# Step 5: Build FAISS flat baseline
|
||||
print("\n🔨 Step 5: Build FAISS flat baseline")
|
||||
if not args.skip_download:
|
||||
baseline_path = setup.build_faiss_baseline(embeddings, valid_samples)
|
||||
else:
|
||||
# Load valid_samples from passages file for FAISS baseline
|
||||
valid_samples = []
|
||||
with open(passages_file, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
passage = json.loads(line)
|
||||
valid_samples.append({"id": passage["id"], "caption": passage["text"]})
|
||||
baseline_path = setup.build_faiss_baseline(embeddings, valid_samples)
|
||||
|
||||
# Step 6: Create evaluation queries
|
||||
print("\n📝 Step 6: Create evaluation queries")
|
||||
queries_file = setup.create_evaluation_queries(valid_samples, args.num_queries)
|
||||
else:
|
||||
print("⏭️ Skipping index building")
|
||||
baseline_path = "data/baseline/faiss_index.bin"
|
||||
queries_file = setup.data_dir / "evaluation_queries.jsonl"
|
||||
|
||||
print("\n🎉 Setup completed successfully!")
|
||||
print("📊 Summary:")
|
||||
if not args.skip_download:
|
||||
print(f" Downloaded samples: {len(samples)}")
|
||||
print(f" Valid samples with embeddings: {len(valid_samples)}")
|
||||
else:
|
||||
print(f" Loaded {len(embeddings)} embeddings")
|
||||
|
||||
if not args.skip_build:
|
||||
print(f" Compact index: {compact_index_path}")
|
||||
print(f" Non-compact index: {non_compact_index_path}")
|
||||
print(f" FAISS baseline: {baseline_path}")
|
||||
print(f" Queries: {queries_file}")
|
||||
|
||||
print("\n🔧 Next steps:")
|
||||
print(f" Run evaluation: python evaluate_laion.py --index {compact_index_path}")
|
||||
print(f" Or compare with: python evaluate_laion.py --index {non_compact_index_path}")
|
||||
else:
|
||||
print(" Skipped building indexes")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n⚠️ Setup interrupted by user")
|
||||
exit(1)
|
||||
except Exception as e:
|
||||
print(f"\n❌ Setup failed: {e}")
|
||||
exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,342 +0,0 @@
|
||||
"""
|
||||
LLM utils for RAG benchmarks with Qwen3-8B and Qwen2.5-VL (multimodal)
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
try:
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
HF_AVAILABLE = True
|
||||
except ImportError:
|
||||
HF_AVAILABLE = False
|
||||
|
||||
try:
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
VLLM_AVAILABLE = True
|
||||
except ImportError:
|
||||
VLLM_AVAILABLE = False
|
||||
|
||||
|
||||
def is_qwen3_model(model_name):
|
||||
"""Check if model is Qwen3"""
|
||||
return "Qwen3" in model_name or "qwen3" in model_name.lower()
|
||||
|
||||
|
||||
def is_qwen_vl_model(model_name):
|
||||
"""Check if model is Qwen2.5-VL"""
|
||||
return "Qwen2.5-VL" in model_name or "qwen2.5-vl" in model_name.lower()
|
||||
|
||||
|
||||
def apply_qwen3_chat_template(tokenizer, prompt):
|
||||
"""Apply Qwen3 chat template with thinking enabled"""
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
return tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=True,
|
||||
)
|
||||
|
||||
|
||||
def extract_thinking_answer(response):
|
||||
"""Extract final answer from Qwen3 thinking model response"""
|
||||
if "<think>" in response and "</think>" in response:
|
||||
try:
|
||||
think_end = response.index("</think>") + len("</think>")
|
||||
final_answer = response[think_end:].strip()
|
||||
return final_answer
|
||||
except (ValueError, IndexError):
|
||||
pass
|
||||
|
||||
return response.strip()
|
||||
|
||||
|
||||
def load_hf_model(model_name="Qwen/Qwen3-8B", trust_remote_code=False):
|
||||
"""Load HuggingFace model
|
||||
|
||||
Args:
|
||||
model_name (str): Name of the model to load
|
||||
trust_remote_code (bool): Whether to allow execution of code from the model repository.
|
||||
Defaults to False for security. Only enable for trusted models.
|
||||
"""
|
||||
if not HF_AVAILABLE:
|
||||
raise ImportError("transformers not available")
|
||||
|
||||
if trust_remote_code:
|
||||
print(
|
||||
"⚠️ WARNING: Loading model with trust_remote_code=True. This can execute arbitrary code."
|
||||
)
|
||||
|
||||
print(f"Loading HF: {model_name}")
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=trust_remote_code)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
|
||||
device_map="auto",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
return tokenizer, model
|
||||
|
||||
|
||||
def load_vllm_model(model_name="Qwen/Qwen3-8B", trust_remote_code=False):
|
||||
"""Load vLLM model
|
||||
|
||||
Args:
|
||||
model_name (str): Name of the model to load
|
||||
trust_remote_code (bool): Whether to allow execution of code from the model repository.
|
||||
Defaults to False for security. Only enable for trusted models.
|
||||
"""
|
||||
if not VLLM_AVAILABLE:
|
||||
raise ImportError("vllm not available")
|
||||
|
||||
if trust_remote_code:
|
||||
print(
|
||||
"⚠️ WARNING: Loading model with trust_remote_code=True. This can execute arbitrary code."
|
||||
)
|
||||
|
||||
print(f"Loading vLLM: {model_name}")
|
||||
llm = LLM(model=model_name, trust_remote_code=trust_remote_code)
|
||||
|
||||
# Qwen3 specific config
|
||||
if is_qwen3_model(model_name):
|
||||
stop_tokens = ["<|im_end|>", "<|end_of_text|>"]
|
||||
max_tokens = 2048
|
||||
else:
|
||||
stop_tokens = None
|
||||
max_tokens = 1024
|
||||
|
||||
sampling_params = SamplingParams(temperature=0.7, max_tokens=max_tokens, stop=stop_tokens)
|
||||
return llm, sampling_params
|
||||
|
||||
|
||||
def generate_hf(tokenizer, model, prompt, max_tokens=None):
|
||||
"""Generate with HF - supports Qwen3 thinking models"""
|
||||
model_name = getattr(model, "name_or_path", "unknown")
|
||||
is_qwen3 = is_qwen3_model(model_name)
|
||||
|
||||
# Apply chat template for Qwen3
|
||||
if is_qwen3:
|
||||
prompt = apply_qwen3_chat_template(tokenizer, prompt)
|
||||
max_tokens = max_tokens or 2048
|
||||
else:
|
||||
max_tokens = max_tokens or 1024
|
||||
|
||||
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
||||
with torch.no_grad():
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=max_tokens,
|
||||
temperature=0.7,
|
||||
do_sample=True,
|
||||
pad_token_id=tokenizer.eos_token_id,
|
||||
)
|
||||
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||||
response = response[len(prompt) :].strip()
|
||||
|
||||
# Extract final answer for thinking models
|
||||
if is_qwen3:
|
||||
return extract_thinking_answer(response)
|
||||
return response
|
||||
|
||||
|
||||
def generate_vllm(llm, sampling_params, prompt):
|
||||
"""Generate with vLLM - supports Qwen3 thinking models"""
|
||||
outputs = llm.generate([prompt], sampling_params)
|
||||
response = outputs[0].outputs[0].text.strip()
|
||||
|
||||
# Extract final answer for Qwen3 thinking models
|
||||
model_name = str(llm.llm_engine.model_config.model)
|
||||
if is_qwen3_model(model_name):
|
||||
return extract_thinking_answer(response)
|
||||
return response
|
||||
|
||||
|
||||
def create_prompt(context, query, domain="default"):
|
||||
"""Create RAG prompt"""
|
||||
if domain == "emails":
|
||||
return f"Email content:\n{context}\n\nQuestion: {query}\n\nAnswer:"
|
||||
elif domain == "finance":
|
||||
return f"Financial content:\n{context}\n\nQuestion: {query}\n\nAnswer:"
|
||||
elif domain == "multimodal":
|
||||
return f"Image context:\n{context}\n\nQuestion: {query}\n\nAnswer:"
|
||||
else:
|
||||
return f"Context: {context}\n\nQuestion: {query}\n\nAnswer:"
|
||||
|
||||
|
||||
def evaluate_rag(searcher, llm_func, queries, domain="default", top_k=3, complexity=64):
|
||||
"""Simple RAG evaluation with timing"""
|
||||
search_times = []
|
||||
gen_times = []
|
||||
results = []
|
||||
|
||||
for i, query in enumerate(queries):
|
||||
# Search
|
||||
start = time.time()
|
||||
docs = searcher.search(query, top_k=top_k, complexity=complexity)
|
||||
search_time = time.time() - start
|
||||
|
||||
# Generate
|
||||
context = "\n\n".join([doc.text for doc in docs])
|
||||
prompt = create_prompt(context, query, domain)
|
||||
|
||||
start = time.time()
|
||||
response = llm_func(prompt)
|
||||
gen_time = time.time() - start
|
||||
|
||||
search_times.append(search_time)
|
||||
gen_times.append(gen_time)
|
||||
results.append(response)
|
||||
|
||||
if i < 3:
|
||||
print(f"Q{i + 1}: Search={search_time:.3f}s, Gen={gen_time:.3f}s")
|
||||
|
||||
return {
|
||||
"avg_search_time": sum(search_times) / len(search_times),
|
||||
"avg_generation_time": sum(gen_times) / len(gen_times),
|
||||
"results": results,
|
||||
}
|
||||
|
||||
|
||||
def load_qwen_vl_model(model_name="Qwen/Qwen2.5-VL-7B-Instruct", trust_remote_code=False):
|
||||
"""Load Qwen2.5-VL multimodal model
|
||||
|
||||
Args:
|
||||
model_name (str): Name of the model to load
|
||||
trust_remote_code (bool): Whether to allow execution of code from the model repository.
|
||||
Defaults to False for security. Only enable for trusted models.
|
||||
"""
|
||||
if not HF_AVAILABLE:
|
||||
raise ImportError("transformers not available")
|
||||
|
||||
if trust_remote_code:
|
||||
print(
|
||||
"⚠️ WARNING: Loading model with trust_remote_code=True. This can execute arbitrary code."
|
||||
)
|
||||
|
||||
print(f"Loading Qwen2.5-VL: {model_name}")
|
||||
|
||||
try:
|
||||
from transformers import AutoModelForVision2Seq, AutoProcessor
|
||||
|
||||
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=trust_remote_code)
|
||||
model = AutoModelForVision2Seq.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
return processor, model
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to load with AutoModelForVision2Seq, trying specific class: {e}")
|
||||
|
||||
# Fallback to specific class
|
||||
try:
|
||||
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
|
||||
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
model_name, trust_remote_code=trust_remote_code
|
||||
)
|
||||
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
return processor, model
|
||||
|
||||
except Exception as e2:
|
||||
raise ImportError(f"Failed to load Qwen2.5-VL model: {e2}")
|
||||
|
||||
|
||||
def generate_qwen_vl(processor, model, prompt, image_path=None, max_tokens=512):
|
||||
"""Generate with Qwen2.5-VL multimodal model"""
|
||||
from PIL import Image
|
||||
|
||||
# Prepare inputs
|
||||
if image_path:
|
||||
image = Image.open(image_path)
|
||||
inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
|
||||
else:
|
||||
inputs = processor(text=prompt, return_tensors="pt").to(model.device)
|
||||
|
||||
# Generate
|
||||
with torch.no_grad():
|
||||
generated_ids = model.generate(
|
||||
**inputs, max_new_tokens=max_tokens, do_sample=False, temperature=0.1
|
||||
)
|
||||
|
||||
# Decode response
|
||||
generated_ids = generated_ids[:, inputs["input_ids"].shape[1] :]
|
||||
response = processor.decode(generated_ids[0], skip_special_tokens=True)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def create_multimodal_prompt(context, query, image_descriptions, task_type="images"):
|
||||
"""Create prompt for multimodal RAG"""
|
||||
if task_type == "images":
|
||||
return f"""Based on the retrieved images and their descriptions, answer the following question.
|
||||
|
||||
Retrieved Image Descriptions:
|
||||
{context}
|
||||
|
||||
Question: {query}
|
||||
|
||||
Provide a detailed answer based on the visual content described above."""
|
||||
|
||||
return f"Context: {context}\nQuestion: {query}\nAnswer:"
|
||||
|
||||
|
||||
def evaluate_multimodal_rag(searcher, queries, processor=None, model=None, complexity=64):
|
||||
"""Evaluate multimodal RAG with Qwen2.5-VL"""
|
||||
search_times = []
|
||||
gen_times = []
|
||||
results = []
|
||||
|
||||
for i, query_item in enumerate(queries):
|
||||
# Handle both string and dict formats for queries
|
||||
if isinstance(query_item, dict):
|
||||
query = query_item.get("query", "")
|
||||
image_path = query_item.get("image_path") # Optional reference image
|
||||
else:
|
||||
query = str(query_item)
|
||||
image_path = None
|
||||
|
||||
# Search
|
||||
start_time = time.time()
|
||||
search_results = searcher.search(query, top_k=3, complexity=complexity)
|
||||
search_time = time.time() - start_time
|
||||
search_times.append(search_time)
|
||||
|
||||
# Prepare context from search results
|
||||
context_parts = []
|
||||
for result in search_results:
|
||||
context_parts.append(f"- {result.text}")
|
||||
context = "\n".join(context_parts)
|
||||
|
||||
# Generate with multimodal model
|
||||
start_time = time.time()
|
||||
if processor and model:
|
||||
prompt = create_multimodal_prompt(context, query, context_parts)
|
||||
response = generate_qwen_vl(processor, model, prompt, image_path)
|
||||
else:
|
||||
response = f"Context: {context}"
|
||||
gen_time = time.time() - start_time
|
||||
|
||||
gen_times.append(gen_time)
|
||||
results.append(response)
|
||||
|
||||
if i < 3:
|
||||
print(f"Q{i + 1}: Search={search_time:.3f}s, Gen={gen_time:.3f}s")
|
||||
|
||||
return {
|
||||
"avg_search_time": sum(search_times) / len(search_times),
|
||||
"avg_generation_time": sum(gen_times) / len(gen_times),
|
||||
"results": results,
|
||||
}
|
||||
@@ -53,7 +53,7 @@ def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
|
||||
print(
|
||||
"Error: huggingface_hub is not installed. Please install it to download the data:"
|
||||
)
|
||||
print("uv sync --only-group dev")
|
||||
print("uv pip install -e '.[dev]'")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"An error occurred during data download: {e}")
|
||||
|
||||
@@ -53,9 +53,9 @@ We use pre-commit hooks to ensure code quality and consistency. This runs automa
|
||||
|
||||
### Setup Pre-commit
|
||||
|
||||
1. **Install pre-commit tools**:
|
||||
1. **Install pre-commit** (already included when you run `uv sync`):
|
||||
```bash
|
||||
uv sync lint
|
||||
uv pip install pre-commit
|
||||
```
|
||||
|
||||
2. **Install the git hooks**:
|
||||
@@ -65,7 +65,7 @@ We use pre-commit hooks to ensure code quality and consistency. This runs automa
|
||||
|
||||
3. **Run pre-commit manually** (optional):
|
||||
```bash
|
||||
uv run pre-commit run --all-files
|
||||
pre-commit run --all-files
|
||||
```
|
||||
|
||||
### Pre-commit Checks
|
||||
@@ -85,9 +85,6 @@ Our pre-commit configuration includes:
|
||||
### Running Tests
|
||||
|
||||
```bash
|
||||
# Install test tools only (no project runtime)
|
||||
uv sync --group test
|
||||
|
||||
# Run all tests
|
||||
uv run pytest
|
||||
|
||||
|
||||
@@ -83,81 +83,6 @@ ollama pull nomic-embed-text
|
||||
|
||||
</details>
|
||||
|
||||
## Local & Remote Inference Endpoints
|
||||
|
||||
> Applies to both LLMs (`leann ask`) and embeddings (`leann build`).
|
||||
|
||||
LEANN now treats Ollama, LM Studio, and other OpenAI-compatible runtimes as first-class providers. You can point LEANN at any compatible endpoint – either on the same machine or across the network – with a couple of flags or environment variables.
|
||||
|
||||
### One-Time Environment Setup
|
||||
|
||||
```bash
|
||||
# Works for OpenAI-compatible runtimes such as LM Studio, vLLM, SGLang, llamafile, etc.
|
||||
export OPENAI_API_KEY="your-key" # or leave unset for local servers that do not check keys
|
||||
export OPENAI_BASE_URL="http://localhost:1234/v1"
|
||||
|
||||
# Ollama-compatible runtimes (Ollama, Ollama on another host, llamacpp-server, etc.)
|
||||
export LEANN_OLLAMA_HOST="http://localhost:11434" # falls back to OLLAMA_HOST or LOCAL_LLM_ENDPOINT
|
||||
```
|
||||
|
||||
LEANN also recognises `LEANN_LOCAL_LLM_HOST` (highest priority), `LEANN_OPENAI_BASE_URL`, and `LOCAL_OPENAI_BASE_URL`, so existing scripts continue to work.
|
||||
|
||||
### Passing Hosts Per Command
|
||||
|
||||
```bash
|
||||
# Build an index with a remote embedding server
|
||||
leann build my-notes \
|
||||
--docs ./notes \
|
||||
--embedding-mode openai \
|
||||
--embedding-model text-embedding-qwen3-embedding-0.6b \
|
||||
--embedding-api-base http://192.168.1.50:1234/v1 \
|
||||
--embedding-api-key local-dev-key
|
||||
|
||||
# Query using a local LM Studio instance via OpenAI-compatible API
|
||||
leann ask my-notes \
|
||||
--llm openai \
|
||||
--llm-model qwen3-8b \
|
||||
--api-base http://localhost:1234/v1 \
|
||||
--api-key local-dev-key
|
||||
|
||||
# Query an Ollama instance running on another box
|
||||
leann ask my-notes \
|
||||
--llm ollama \
|
||||
--llm-model qwen3:14b \
|
||||
--host http://192.168.1.101:11434
|
||||
```
|
||||
|
||||
⚠️ **Make sure the endpoint is reachable**: when your inference server runs on a home/workstation and the index/search job runs in the cloud, the server must be able to reach the host you configured. Typical options include:
|
||||
|
||||
- Expose a public IP (and open the relevant port) on the machine that hosts LM Studio/Ollama.
|
||||
- Configure router or cloud provider port forwarding.
|
||||
- Tunnel traffic through tools like `tailscale`, `cloudflared`, or `ssh -R`.
|
||||
|
||||
When you set these options while building an index, LEANN stores them in `meta.json`. Any subsequent `leann ask` or searcher process automatically reuses the same provider settings – even when we spawn background embedding servers. This makes the “server without GPU talking to my local workstation” workflow from [issue #80](https://github.com/yichuan-w/LEANN/issues/80#issuecomment-2287230548) work out-of-the-box.
|
||||
|
||||
**Tip:** If your runtime does not require an API key (many local stacks don’t), leave `--api-key` unset. LEANN will skip injecting credentials.
|
||||
|
||||
### Python API Usage
|
||||
|
||||
You can pass the same configuration from Python:
|
||||
|
||||
```python
|
||||
from leann.api import LeannBuilder
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_mode="openai",
|
||||
embedding_model="text-embedding-qwen3-embedding-0.6b",
|
||||
embedding_options={
|
||||
"base_url": "http://192.168.1.50:1234/v1",
|
||||
"api_key": "local-dev-key",
|
||||
},
|
||||
)
|
||||
builder.build_index("./indexes/my-notes", chunks)
|
||||
```
|
||||
|
||||
`embedding_options` is persisted to the index `meta.json`, so subsequent `LeannSearcher` or `LeannChat` sessions automatically reuse the same provider settings (the embedding server manager forwards them to the provider for you).
|
||||
|
||||
## Index Selection: Matching Your Scale
|
||||
|
||||
### HNSW (Hierarchical Navigable Small World)
|
||||
@@ -455,5 +380,5 @@ Conclusion:
|
||||
|
||||
- [Lessons Learned Developing LEANN](https://yichuan-w.github.io/blog/lessons_learned_in_dev_leann/)
|
||||
- [LEANN Technical Paper](https://arxiv.org/abs/2506.08276)
|
||||
- [DiskANN Original Paper](https://suhasjs.github.io/files/diskann_neurips19.pdf)
|
||||
- [DiskANN Original Paper](https://papers.nips.cc/paper/2019/file/09853c7fb1d3f8ee67a61b6bf4a7f8e6-Paper.pdf)
|
||||
- [SSD-based Graph Partitioning](https://github.com/SonglinLife/SSD_BASED_PLAN)
|
||||
|
||||
@@ -43,11 +43,7 @@ from apps.chunking import create_text_chunks
|
||||
REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||
|
||||
DEFAULT_QUERY = "What's LEANN?"
|
||||
DEFAULT_INITIAL_FILES = [
|
||||
REPO_ROOT / "data" / "2501.14312v1 (1).pdf",
|
||||
REPO_ROOT / "data" / "huawei_pangu.md",
|
||||
REPO_ROOT / "data" / "PrideandPrejudice.txt",
|
||||
]
|
||||
DEFAULT_INITIAL_FILES = [REPO_ROOT / "data" / "2501.14312v1 (1).pdf"]
|
||||
DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"]
|
||||
|
||||
|
||||
@@ -186,7 +182,6 @@ def run_workflow(
|
||||
is_recompute: bool,
|
||||
query: str,
|
||||
top_k: int,
|
||||
skip_search: bool,
|
||||
) -> dict[str, Any]:
|
||||
prefix = f"[{label}] " if label else ""
|
||||
|
||||
@@ -203,15 +198,12 @@ def run_workflow(
|
||||
)
|
||||
|
||||
initial_size = index_file_size(index_path)
|
||||
if not skip_search:
|
||||
before_results = run_search(
|
||||
index_path,
|
||||
query,
|
||||
top_k,
|
||||
recompute_embeddings=is_recompute,
|
||||
)
|
||||
else:
|
||||
before_results = None
|
||||
before_results = run_search(
|
||||
index_path,
|
||||
query,
|
||||
top_k,
|
||||
recompute_embeddings=is_recompute,
|
||||
)
|
||||
|
||||
print(f"\n{prefix}Updating index with additional passages...")
|
||||
update_index(
|
||||
@@ -223,23 +215,20 @@ def run_workflow(
|
||||
is_recompute=is_recompute,
|
||||
)
|
||||
|
||||
if not skip_search:
|
||||
after_results = run_search(
|
||||
index_path,
|
||||
query,
|
||||
top_k,
|
||||
recompute_embeddings=is_recompute,
|
||||
)
|
||||
else:
|
||||
after_results = None
|
||||
after_results = run_search(
|
||||
index_path,
|
||||
query,
|
||||
top_k,
|
||||
recompute_embeddings=is_recompute,
|
||||
)
|
||||
updated_size = index_file_size(index_path)
|
||||
|
||||
return {
|
||||
"initial_size": initial_size,
|
||||
"updated_size": updated_size,
|
||||
"delta": updated_size - initial_size,
|
||||
"before_results": before_results if not skip_search else None,
|
||||
"after_results": after_results if not skip_search else None,
|
||||
"before_results": before_results,
|
||||
"after_results": after_results,
|
||||
"metadata": load_metadata_snapshot(index_path),
|
||||
}
|
||||
|
||||
@@ -325,12 +314,6 @@ def main() -> None:
|
||||
action="store_false",
|
||||
help="Skip building the no-recompute baseline.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip-search",
|
||||
dest="skip_search",
|
||||
action="store_true",
|
||||
help="Skip the search step.",
|
||||
)
|
||||
parser.set_defaults(compare_no_recompute=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -367,13 +350,10 @@ def main() -> None:
|
||||
is_recompute=True,
|
||||
query=args.query,
|
||||
top_k=args.top_k,
|
||||
skip_search=args.skip_search,
|
||||
)
|
||||
|
||||
if not args.skip_search:
|
||||
print_results("initial search", recompute_stats["before_results"])
|
||||
if not args.skip_search:
|
||||
print_results("after update", recompute_stats["after_results"])
|
||||
print_results("initial search", recompute_stats["before_results"])
|
||||
print_results("after update", recompute_stats["after_results"])
|
||||
print(
|
||||
f"\n[recompute] Index file size change: {recompute_stats['initial_size']} -> {recompute_stats['updated_size']} bytes"
|
||||
f" (Δ {recompute_stats['delta']})"
|
||||
@@ -398,7 +378,6 @@ def main() -> None:
|
||||
is_recompute=False,
|
||||
query=args.query,
|
||||
top_k=args.top_k,
|
||||
skip_search=args.skip_search,
|
||||
)
|
||||
|
||||
print(
|
||||
@@ -406,12 +385,8 @@ def main() -> None:
|
||||
f" (Δ {baseline_stats['delta']})"
|
||||
)
|
||||
|
||||
after_texts = (
|
||||
[res.text for res in recompute_stats["after_results"]] if not args.skip_search else None
|
||||
)
|
||||
baseline_after_texts = (
|
||||
[res.text for res in baseline_stats["after_results"]] if not args.skip_search else None
|
||||
)
|
||||
after_texts = [res.text for res in recompute_stats["after_results"]]
|
||||
baseline_after_texts = [res.text for res in baseline_stats["after_results"]]
|
||||
if after_texts == baseline_after_texts:
|
||||
print(
|
||||
"[no-recompute] Search results match recompute baseline; see above for the shared output."
|
||||
|
||||
@@ -343,8 +343,7 @@ class DiskannSearcher(BaseSearcher):
|
||||
"full_index_prefix": full_index_prefix,
|
||||
"num_threads": self.num_threads,
|
||||
"num_nodes_to_cache": kwargs.get("num_nodes_to_cache", 0),
|
||||
# 1 -> initialize cache using sample_data; 2 -> ready cache without init; others disable cache
|
||||
"cache_mechanism": kwargs.get("cache_mechanism", 1),
|
||||
"cache_mechanism": 1,
|
||||
"pq_prefix": "",
|
||||
"partition_prefix": partition_prefix,
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ import sys
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import zmq
|
||||
@@ -32,16 +32,6 @@ if not logger.handlers:
|
||||
logger.propagate = False
|
||||
|
||||
|
||||
_RAW_PROVIDER_OPTIONS = os.getenv("LEANN_EMBEDDING_OPTIONS")
|
||||
try:
|
||||
PROVIDER_OPTIONS: dict[str, Any] = (
|
||||
json.loads(_RAW_PROVIDER_OPTIONS) if _RAW_PROVIDER_OPTIONS else {}
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Failed to parse LEANN_EMBEDDING_OPTIONS; ignoring provider options")
|
||||
PROVIDER_OPTIONS = {}
|
||||
|
||||
|
||||
def create_diskann_embedding_server(
|
||||
passages_file: Optional[str] = None,
|
||||
zmq_port: int = 5555,
|
||||
@@ -191,12 +181,7 @@ def create_diskann_embedding_server(
|
||||
logger.debug(f"Text lengths: {[len(t) for t in texts[:5]]}") # Show first 5
|
||||
|
||||
# Process embeddings using unified computation
|
||||
embeddings = compute_embeddings(
|
||||
texts,
|
||||
model_name,
|
||||
mode=embedding_mode,
|
||||
provider_options=PROVIDER_OPTIONS,
|
||||
)
|
||||
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
||||
logger.info(
|
||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||
)
|
||||
@@ -311,12 +296,7 @@ def create_diskann_embedding_server(
|
||||
continue
|
||||
|
||||
# Process the request
|
||||
embeddings = compute_embeddings(
|
||||
texts,
|
||||
model_name,
|
||||
mode=embedding_mode,
|
||||
provider_options=PROVIDER_OPTIONS,
|
||||
)
|
||||
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
||||
logger.info(f"Computed embeddings shape: {embeddings.shape}")
|
||||
|
||||
# Validation
|
||||
|
||||
@@ -90,15 +90,6 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
||||
index_file = index_dir / f"{index_prefix}.index"
|
||||
faiss.write_index(index, str(index_file))
|
||||
|
||||
# Persist ID map so searcher can map FAISS integer labels back to passage IDs
|
||||
try:
|
||||
idmap_file = index_dir / f"{index_prefix}.ids.txt"
|
||||
with open(idmap_file, "w", encoding="utf-8") as f:
|
||||
for id_str in ids:
|
||||
f.write(str(id_str) + "\n")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to write ID map: {e}")
|
||||
|
||||
if self.is_compact:
|
||||
self._convert_to_csr(index_file)
|
||||
elif self.is_recompute:
|
||||
@@ -161,16 +152,6 @@ class HNSWSearcher(BaseSearcher):
|
||||
|
||||
self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)
|
||||
|
||||
# Load ID map if available
|
||||
self._id_map: list[str] = []
|
||||
try:
|
||||
idmap_file = self.index_dir / f"{self.index_path.stem}.ids.txt"
|
||||
if idmap_file.exists():
|
||||
with open(idmap_file, encoding="utf-8") as f:
|
||||
self._id_map = [line.rstrip("\n") for line in f]
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load ID map: {e}")
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: np.ndarray,
|
||||
@@ -269,19 +250,6 @@ class HNSWSearcher(BaseSearcher):
|
||||
)
|
||||
search_time = time.time() - search_time
|
||||
logger.info(f" Search time in HNSWSearcher.search() backend: {search_time} seconds")
|
||||
if self._id_map:
|
||||
|
||||
def map_label(x: int) -> str:
|
||||
if 0 <= x < len(self._id_map):
|
||||
return self._id_map[x]
|
||||
return str(x)
|
||||
|
||||
string_labels = [
|
||||
[map_label(int(label)) for label in batch_labels] for batch_labels in labels
|
||||
]
|
||||
else:
|
||||
string_labels = [
|
||||
[str(int_label) for int_label in batch_labels] for batch_labels in labels
|
||||
]
|
||||
string_labels = [[str(int_label) for int_label in batch_labels] for batch_labels in labels]
|
||||
|
||||
return {"labels": string_labels, "distances": distances}
|
||||
|
||||
@@ -10,7 +10,7 @@ import sys
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
from typing import Optional
|
||||
|
||||
import msgpack
|
||||
import numpy as np
|
||||
@@ -45,15 +45,6 @@ if log_path:
|
||||
|
||||
logger.propagate = False
|
||||
|
||||
_RAW_PROVIDER_OPTIONS = os.getenv("LEANN_EMBEDDING_OPTIONS")
|
||||
try:
|
||||
PROVIDER_OPTIONS: dict[str, Any] = (
|
||||
json.loads(_RAW_PROVIDER_OPTIONS) if _RAW_PROVIDER_OPTIONS else {}
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Failed to parse LEANN_EMBEDDING_OPTIONS; ignoring provider options")
|
||||
PROVIDER_OPTIONS = {}
|
||||
|
||||
|
||||
def create_hnsw_embedding_server(
|
||||
passages_file: Optional[str] = None,
|
||||
@@ -114,35 +105,6 @@ def create_hnsw_embedding_server(
|
||||
embedding_dim = 0
|
||||
logger.info(f"Loaded PassageManager with {len(passages)} passages from metadata")
|
||||
|
||||
# Attempt to load ID map (maps FAISS integer labels -> passage IDs)
|
||||
id_map: list[str] = []
|
||||
try:
|
||||
meta_path = Path(passages_file)
|
||||
base = meta_path.name
|
||||
if base.endswith(".meta.json"):
|
||||
base = base[: -len(".meta.json")] # e.g., laion_index.leann
|
||||
if base.endswith(".leann"):
|
||||
base = base[: -len(".leann")] # e.g., laion_index
|
||||
idmap_file = meta_path.parent / f"{base}.ids.txt"
|
||||
if idmap_file.exists():
|
||||
with open(idmap_file, encoding="utf-8") as f:
|
||||
id_map = [line.rstrip("\n") for line in f]
|
||||
logger.info(f"Loaded ID map with {len(id_map)} entries from {idmap_file}")
|
||||
else:
|
||||
logger.warning(f"ID map file not found at {idmap_file}; will use raw labels")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load ID map: {e}")
|
||||
|
||||
def _map_node_id(nid) -> str:
|
||||
try:
|
||||
if id_map is not None and len(id_map) > 0 and isinstance(nid, (int, np.integer)):
|
||||
idx = int(nid)
|
||||
if 0 <= idx < len(id_map):
|
||||
return id_map[idx]
|
||||
except Exception:
|
||||
pass
|
||||
return str(nid)
|
||||
|
||||
# (legacy ZMQ thread removed; using shutdown-capable server only)
|
||||
|
||||
def zmq_server_thread_with_shutdown(shutdown_event):
|
||||
@@ -189,12 +151,7 @@ def create_hnsw_embedding_server(
|
||||
):
|
||||
last_request_type = "text"
|
||||
last_request_length = len(request)
|
||||
embeddings = compute_embeddings(
|
||||
request,
|
||||
model_name,
|
||||
mode=embedding_mode,
|
||||
provider_options=PROVIDER_OPTIONS,
|
||||
)
|
||||
embeddings = compute_embeddings(request, model_name, mode=embedding_mode)
|
||||
rep_socket.send(msgpack.packb(embeddings.tolist()))
|
||||
e2e_end = time.time()
|
||||
logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
|
||||
@@ -224,14 +181,13 @@ def create_hnsw_embedding_server(
|
||||
found_indices: list[int] = []
|
||||
for idx, nid in enumerate(node_ids):
|
||||
try:
|
||||
passage_id = _map_node_id(nid)
|
||||
passage_data = passages.get_passage(passage_id)
|
||||
passage_data = passages.get_passage(str(nid))
|
||||
txt = passage_data.get("text", "")
|
||||
if isinstance(txt, str) and len(txt) > 0:
|
||||
texts.append(txt)
|
||||
found_indices.append(idx)
|
||||
else:
|
||||
logger.error(f"Empty text for passage ID {passage_id}")
|
||||
logger.error(f"Empty text for passage ID {nid}")
|
||||
except KeyError:
|
||||
logger.error(f"Passage ID {nid} not found")
|
||||
except Exception as e:
|
||||
@@ -244,10 +200,7 @@ def create_hnsw_embedding_server(
|
||||
if texts:
|
||||
try:
|
||||
embeddings = compute_embeddings(
|
||||
texts,
|
||||
model_name,
|
||||
mode=embedding_mode,
|
||||
provider_options=PROVIDER_OPTIONS,
|
||||
texts, model_name, mode=embedding_mode
|
||||
)
|
||||
logger.info(
|
||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||
@@ -298,14 +251,13 @@ def create_hnsw_embedding_server(
|
||||
found_indices: list[int] = []
|
||||
for idx, nid in enumerate(node_ids):
|
||||
try:
|
||||
passage_id = _map_node_id(nid)
|
||||
passage_data = passages.get_passage(passage_id)
|
||||
passage_data = passages.get_passage(str(nid))
|
||||
txt = passage_data.get("text", "")
|
||||
if isinstance(txt, str) and len(txt) > 0:
|
||||
texts.append(txt)
|
||||
found_indices.append(idx)
|
||||
else:
|
||||
logger.error(f"Empty text for passage ID {passage_id}")
|
||||
logger.error(f"Empty text for passage ID {nid}")
|
||||
except KeyError:
|
||||
logger.error(f"Passage with ID {nid} not found")
|
||||
except Exception as e:
|
||||
@@ -313,12 +265,7 @@ def create_hnsw_embedding_server(
|
||||
|
||||
if texts:
|
||||
try:
|
||||
embeddings = compute_embeddings(
|
||||
texts,
|
||||
model_name,
|
||||
mode=embedding_mode,
|
||||
provider_options=PROVIDER_OPTIONS,
|
||||
)
|
||||
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
||||
logger.info(
|
||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||
)
|
||||
|
||||
Submodule packages/leann-backend-hnsw/third_party/faiss updated: 5952745237...1d51f0c074
@@ -18,16 +18,14 @@ dependencies = [
|
||||
"pyzmq>=23.0.0",
|
||||
"msgpack>=1.0.0",
|
||||
"torch>=2.0.0",
|
||||
"sentence-transformers>=3.0.0",
|
||||
"sentence-transformers>=2.2.0",
|
||||
"llama-index-core>=0.12.0",
|
||||
"llama-index-readers-file>=0.4.0", # Essential for document reading
|
||||
"llama-index-embeddings-huggingface>=0.5.5", # For embeddings
|
||||
"python-dotenv>=1.0.0",
|
||||
"openai>=1.0.0",
|
||||
"huggingface-hub>=0.20.0",
|
||||
# Keep transformers below 4.46: 4.46.0 adds Python 3.10-only return type syntax and
|
||||
# breaks Python 3.9 environments.
|
||||
"transformers>=4.30.0,<4.46",
|
||||
"transformers>=4.30.0",
|
||||
"requests>=2.25.0",
|
||||
"accelerate>=0.20.0",
|
||||
"PyPDF2>=3.0.0",
|
||||
@@ -42,7 +40,7 @@ dependencies = [
|
||||
[project.optional-dependencies]
|
||||
colab = [
|
||||
"torch>=2.0.0,<3.0.0", # Limit torch version to avoid conflicts
|
||||
"transformers>=4.30.0,<4.46", # 4.46.0 switches to PEP 604 typing (int | None), breaks Py3.9
|
||||
"transformers>=4.30.0,<5.0.0", # Limit transformers version
|
||||
"accelerate>=0.20.0,<1.0.0", # Limit accelerate version
|
||||
]
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ with the correct, original embedding logic from the user's reference code.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import re
|
||||
import subprocess
|
||||
@@ -18,11 +17,9 @@ from typing import Any, Literal, Optional, Union
|
||||
import numpy as np
|
||||
from leann_backend_hnsw.convert_to_csr import prune_hnsw_embeddings_inplace
|
||||
|
||||
from leann.interactive_utils import create_api_session
|
||||
from leann.interface import LeannBackendSearcherInterface
|
||||
|
||||
from .chat import get_llm
|
||||
from .embedding_server_manager import EmbeddingServerManager
|
||||
from .interface import LeannBackendFactoryInterface
|
||||
from .metadata_filter import MetadataFilterEngine
|
||||
from .registry import BACKEND_REGISTRY
|
||||
@@ -42,7 +39,6 @@ def compute_embeddings(
|
||||
use_server: bool = True,
|
||||
port: Optional[int] = None,
|
||||
is_build=False,
|
||||
provider_options: Optional[dict[str, Any]] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Computes embeddings using different backends.
|
||||
@@ -76,7 +72,6 @@ def compute_embeddings(
|
||||
model_name,
|
||||
mode=mode,
|
||||
is_build=is_build,
|
||||
provider_options=provider_options,
|
||||
)
|
||||
|
||||
|
||||
@@ -283,7 +278,6 @@ class LeannBuilder:
|
||||
embedding_model: str = "facebook/contriever",
|
||||
dimensions: Optional[int] = None,
|
||||
embedding_mode: str = "sentence-transformers",
|
||||
embedding_options: Optional[dict[str, Any]] = None,
|
||||
**backend_kwargs,
|
||||
):
|
||||
self.backend_name = backend_name
|
||||
@@ -306,7 +300,6 @@ class LeannBuilder:
|
||||
self.embedding_model = embedding_model
|
||||
self.dimensions = dimensions
|
||||
self.embedding_mode = embedding_mode
|
||||
self.embedding_options = embedding_options or {}
|
||||
|
||||
# Check if we need to use cosine distance for normalized embeddings
|
||||
normalized_embeddings_models = {
|
||||
@@ -414,7 +407,6 @@ class LeannBuilder:
|
||||
self.embedding_model,
|
||||
self.embedding_mode,
|
||||
use_server=False,
|
||||
provider_options=self.embedding_options,
|
||||
)[0]
|
||||
)
|
||||
path = Path(index_path)
|
||||
@@ -454,20 +446,8 @@ class LeannBuilder:
|
||||
self.embedding_mode,
|
||||
use_server=False,
|
||||
is_build=True,
|
||||
provider_options=self.embedding_options,
|
||||
)
|
||||
string_ids = [chunk["id"] for chunk in self.chunks]
|
||||
# Persist ID map alongside index so backends that return integer labels can remap to passage IDs
|
||||
try:
|
||||
idmap_file = (
|
||||
index_dir
|
||||
/ f"{index_name[: -len('.leann')] if index_name.endswith('.leann') else index_name}.ids.txt"
|
||||
)
|
||||
with open(idmap_file, "w", encoding="utf-8") as f:
|
||||
for sid in string_ids:
|
||||
f.write(str(sid) + "\n")
|
||||
except Exception:
|
||||
pass
|
||||
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
|
||||
builder_instance = self.backend_factory.builder(**current_backend_kwargs)
|
||||
builder_instance.build(embeddings, string_ids, index_path, **current_backend_kwargs)
|
||||
@@ -492,9 +472,6 @@ class LeannBuilder:
|
||||
],
|
||||
}
|
||||
|
||||
if self.embedding_options:
|
||||
meta_data["embedding_options"] = self.embedding_options
|
||||
|
||||
# Add storage status flags for HNSW backend
|
||||
if self.backend_name == "hnsw":
|
||||
is_compact = self.backend_kwargs.get("is_compact", True)
|
||||
@@ -587,17 +564,6 @@ class LeannBuilder:
|
||||
|
||||
# Build the vector index using precomputed embeddings
|
||||
string_ids = [str(id_val) for id_val in ids]
|
||||
# Persist ID map (order == embeddings order)
|
||||
try:
|
||||
idmap_file = (
|
||||
index_dir
|
||||
/ f"{index_name[: -len('.leann')] if index_name.endswith('.leann') else index_name}.ids.txt"
|
||||
)
|
||||
with open(idmap_file, "w", encoding="utf-8") as f:
|
||||
for sid in string_ids:
|
||||
f.write(str(sid) + "\n")
|
||||
except Exception:
|
||||
pass
|
||||
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
|
||||
builder_instance = self.backend_factory.builder(**current_backend_kwargs)
|
||||
builder_instance.build(embeddings, string_ids, index_path)
|
||||
@@ -626,9 +592,6 @@ class LeannBuilder:
|
||||
"embeddings_source": str(embeddings_file),
|
||||
}
|
||||
|
||||
if self.embedding_options:
|
||||
meta_data["embedding_options"] = self.embedding_options
|
||||
|
||||
# Add storage status flags for HNSW backend
|
||||
if self.backend_name == "hnsw":
|
||||
is_compact = self.backend_kwargs.get("is_compact", True)
|
||||
@@ -710,7 +673,6 @@ class LeannBuilder:
|
||||
self.embedding_mode,
|
||||
use_server=False,
|
||||
is_build=True,
|
||||
provider_options=self.embedding_options,
|
||||
)
|
||||
|
||||
embedding_dim = embeddings.shape[1]
|
||||
@@ -731,7 +693,6 @@ class LeannBuilder:
|
||||
index = faiss.read_index(str(index_file))
|
||||
if hasattr(index, "is_recompute"):
|
||||
index.is_recompute = needs_recompute
|
||||
print(f"index.is_recompute: {index.is_recompute}")
|
||||
if getattr(index, "storage", None) is None:
|
||||
if index.metric_type == faiss.METRIC_INNER_PRODUCT:
|
||||
storage_index = faiss.IndexFlatIP(index.d)
|
||||
@@ -739,112 +700,37 @@ class LeannBuilder:
|
||||
storage_index = faiss.IndexFlatL2(index.d)
|
||||
index.storage = storage_index
|
||||
index.own_fields = True
|
||||
# Faiss expects storage.ntotal to reflect the existing graph's
|
||||
# population (even if the vectors themselves were pruned from disk
|
||||
# for recompute mode). When we attach a fresh IndexFlat here its
|
||||
# ntotal starts at zero, which later causes IndexHNSW::add to
|
||||
# believe new "preset" levels were provided and trips the
|
||||
# `n0 + n == levels.size()` assertion. Seed the temporary storage
|
||||
# with the current ntotal so Faiss maintains the proper offset for
|
||||
# incoming vectors.
|
||||
try:
|
||||
storage_index.ntotal = index.ntotal
|
||||
except AttributeError:
|
||||
# Older Faiss builds may not expose ntotal as a writable
|
||||
# attribute; in that case we fall back to the default behaviour.
|
||||
pass
|
||||
if index.d != embedding_dim:
|
||||
raise ValueError(
|
||||
f"Existing index dimension ({index.d}) does not match new embeddings ({embedding_dim})."
|
||||
)
|
||||
|
||||
passage_meta_mode = meta.get("embedding_mode", self.embedding_mode)
|
||||
passage_provider_options = meta.get("embedding_options", self.embedding_options)
|
||||
|
||||
base_id = index.ntotal
|
||||
for offset, chunk in enumerate(valid_chunks):
|
||||
new_id = str(base_id + offset)
|
||||
chunk.setdefault("metadata", {})["id"] = new_id
|
||||
chunk["id"] = new_id
|
||||
|
||||
# Append passages/offsets before we attempt index.add so the ZMQ server
|
||||
# can resolve newly assigned IDs during recompute. Keep rollback hooks
|
||||
# so we can restore files if the update fails mid-way.
|
||||
rollback_passages_size = passages_file.stat().st_size if passages_file.exists() else 0
|
||||
offset_map_backup = offset_map.copy()
|
||||
index.add(embeddings.shape[0], faiss.swig_ptr(embeddings))
|
||||
faiss.write_index(index, str(index_file))
|
||||
|
||||
try:
|
||||
with open(passages_file, "a", encoding="utf-8") as f:
|
||||
for chunk in valid_chunks:
|
||||
offset = f.tell()
|
||||
json.dump(
|
||||
{
|
||||
"id": chunk["id"],
|
||||
"text": chunk["text"],
|
||||
"metadata": chunk.get("metadata", {}),
|
||||
},
|
||||
f,
|
||||
ensure_ascii=False,
|
||||
)
|
||||
f.write("\n")
|
||||
offset_map[chunk["id"]] = offset
|
||||
with open(passages_file, "a", encoding="utf-8") as f:
|
||||
for chunk in valid_chunks:
|
||||
offset = f.tell()
|
||||
json.dump(
|
||||
{
|
||||
"id": chunk["id"],
|
||||
"text": chunk["text"],
|
||||
"metadata": chunk.get("metadata", {}),
|
||||
},
|
||||
f,
|
||||
ensure_ascii=False,
|
||||
)
|
||||
f.write("\n")
|
||||
offset_map[chunk["id"]] = offset
|
||||
|
||||
with open(offset_file, "wb") as f:
|
||||
pickle.dump(offset_map, f)
|
||||
|
||||
server_manager: Optional[EmbeddingServerManager] = None
|
||||
server_started = False
|
||||
requested_zmq_port = int(os.getenv("LEANN_UPDATE_ZMQ_PORT", "5557"))
|
||||
|
||||
try:
|
||||
if needs_recompute:
|
||||
server_manager = EmbeddingServerManager(
|
||||
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
|
||||
)
|
||||
server_started, actual_port = server_manager.start_server(
|
||||
port=requested_zmq_port,
|
||||
model_name=self.embedding_model,
|
||||
embedding_mode=passage_meta_mode,
|
||||
passages_file=str(meta_path),
|
||||
distance_metric=distance_metric,
|
||||
provider_options=passage_provider_options,
|
||||
)
|
||||
if not server_started:
|
||||
raise RuntimeError(
|
||||
"Failed to start HNSW embedding server for recompute update."
|
||||
)
|
||||
if actual_port != requested_zmq_port:
|
||||
logger.warning(
|
||||
"Embedding server started on port %s instead of requested %s. "
|
||||
"Using reassigned port.",
|
||||
actual_port,
|
||||
requested_zmq_port,
|
||||
)
|
||||
try:
|
||||
index.hnsw.zmq_port = actual_port
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
if needs_recompute:
|
||||
for i in range(embeddings.shape[0]):
|
||||
print(f"add {i} embeddings")
|
||||
index.add(1, faiss.swig_ptr(embeddings[i : i + 1]))
|
||||
else:
|
||||
index.add(embeddings.shape[0], faiss.swig_ptr(embeddings))
|
||||
faiss.write_index(index, str(index_file))
|
||||
finally:
|
||||
if server_started and server_manager is not None:
|
||||
server_manager.stop_server()
|
||||
|
||||
except Exception:
|
||||
# Roll back appended passages/offset map to keep files consistent.
|
||||
if passages_file.exists():
|
||||
with open(passages_file, "rb+") as f:
|
||||
f.truncate(rollback_passages_size)
|
||||
offset_map = offset_map_backup
|
||||
with open(offset_file, "wb") as f:
|
||||
pickle.dump(offset_map, f)
|
||||
raise
|
||||
with open(offset_file, "wb") as f:
|
||||
pickle.dump(offset_map, f)
|
||||
|
||||
meta["total_passages"] = len(offset_map)
|
||||
with open(meta_path, "w", encoding="utf-8") as f:
|
||||
@@ -885,7 +771,6 @@ class LeannSearcher:
|
||||
self.embedding_model = self.meta_data["embedding_model"]
|
||||
# Support both old and new format
|
||||
self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers")
|
||||
self.embedding_options = self.meta_data.get("embedding_options", {})
|
||||
# Delegate portability handling to PassageManager
|
||||
self.passage_manager = PassageManager(
|
||||
self.meta_data.get("passage_sources", []), metadata_file_path=self.meta_path_str
|
||||
@@ -897,8 +782,6 @@ class LeannSearcher:
|
||||
raise ValueError(f"Backend '{backend_name}' not found.")
|
||||
final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
|
||||
final_kwargs["enable_warmup"] = enable_warmup
|
||||
if self.embedding_options:
|
||||
final_kwargs.setdefault("embedding_options", self.embedding_options)
|
||||
self.backend_impl: LeannBackendSearcherInterface = backend_factory.searcher(
|
||||
index_path, **final_kwargs
|
||||
)
|
||||
@@ -1243,14 +1126,19 @@ class LeannChat:
|
||||
return ans
|
||||
|
||||
def start_interactive(self):
|
||||
"""Start interactive chat session."""
|
||||
session = create_api_session()
|
||||
|
||||
def handle_query(user_input: str):
|
||||
response = self.ask(user_input)
|
||||
print(f"Leann: {response}")
|
||||
|
||||
session.run_interactive_loop(handle_query)
|
||||
print("\nLeann Chat started (type 'quit' to exit)")
|
||||
while True:
|
||||
try:
|
||||
user_input = input("You: ").strip()
|
||||
if user_input.lower() in ["quit", "exit"]:
|
||||
break
|
||||
if not user_input:
|
||||
continue
|
||||
response = self.ask(user_input)
|
||||
print(f"Leann: {response}")
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print("\nGoodbye!")
|
||||
break
|
||||
|
||||
def cleanup(self):
|
||||
"""Explicitly cleanup embedding server resources.
|
||||
|
||||
@@ -12,8 +12,6 @@ from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -312,12 +310,11 @@ def search_hf_models(query: str, limit: int = 10) -> list[str]:
|
||||
|
||||
|
||||
def validate_model_and_suggest(
|
||||
model_name: str, llm_type: str, host: Optional[str] = None
|
||||
model_name: str, llm_type: str, host: str = "http://localhost:11434"
|
||||
) -> Optional[str]:
|
||||
"""Validate model name and provide suggestions if invalid"""
|
||||
if llm_type == "ollama":
|
||||
resolved_host = resolve_ollama_host(host)
|
||||
available_models = check_ollama_models(resolved_host)
|
||||
available_models = check_ollama_models(host)
|
||||
if available_models and model_name not in available_models:
|
||||
error_msg = f"Model '{model_name}' not found in your local Ollama installation."
|
||||
|
||||
@@ -460,19 +457,19 @@ class LLMInterface(ABC):
|
||||
class OllamaChat(LLMInterface):
|
||||
"""LLM interface for Ollama models."""
|
||||
|
||||
def __init__(self, model: str = "llama3:8b", host: Optional[str] = None):
|
||||
def __init__(self, model: str = "llama3:8b", host: str = "http://localhost:11434"):
|
||||
self.model = model
|
||||
self.host = resolve_ollama_host(host)
|
||||
logger.info(f"Initializing OllamaChat with model='{model}' and host='{self.host}'")
|
||||
self.host = host
|
||||
logger.info(f"Initializing OllamaChat with model='{model}' and host='{host}'")
|
||||
try:
|
||||
import requests
|
||||
|
||||
# Check if the Ollama server is responsive
|
||||
if self.host:
|
||||
requests.get(self.host)
|
||||
if host:
|
||||
requests.get(host)
|
||||
|
||||
# Pre-check model availability with helpful suggestions
|
||||
model_error = validate_model_and_suggest(model, "ollama", self.host)
|
||||
model_error = validate_model_and_suggest(model, "ollama", host)
|
||||
if model_error:
|
||||
raise ValueError(model_error)
|
||||
|
||||
@@ -481,11 +478,9 @@ class OllamaChat(LLMInterface):
|
||||
"The 'requests' library is required for Ollama. Please install it with 'pip install requests'."
|
||||
)
|
||||
except requests.exceptions.ConnectionError:
|
||||
logger.error(
|
||||
f"Could not connect to Ollama at {self.host}. Please ensure Ollama is running."
|
||||
)
|
||||
logger.error(f"Could not connect to Ollama at {host}. Please ensure Ollama is running.")
|
||||
raise ConnectionError(
|
||||
f"Could not connect to Ollama at {self.host}. Please ensure Ollama is running."
|
||||
f"Could not connect to Ollama at {host}. Please ensure Ollama is running."
|
||||
)
|
||||
|
||||
def ask(self, prompt: str, **kwargs) -> str:
|
||||
@@ -546,30 +541,11 @@ class OllamaChat(LLMInterface):
|
||||
|
||||
|
||||
class HFChat(LLMInterface):
|
||||
"""LLM interface for local Hugging Face Transformers models with proper chat templates.
|
||||
"""LLM interface for local Hugging Face Transformers models with proper chat templates."""
|
||||
|
||||
Args:
|
||||
model_name (str): Name of the Hugging Face model to load.
|
||||
trust_remote_code (bool): Whether to allow execution of code from the model repository.
|
||||
Defaults to False for security. Only enable for trusted models as this can pose
|
||||
a security risk if the model repository is compromised.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, model_name: str = "deepseek-ai/deepseek-llm-7b-chat", trust_remote_code: bool = False
|
||||
):
|
||||
def __init__(self, model_name: str = "deepseek-ai/deepseek-llm-7b-chat"):
|
||||
logger.info(f"Initializing HFChat with model='{model_name}'")
|
||||
|
||||
# Security warning when trust_remote_code is enabled
|
||||
if trust_remote_code:
|
||||
logger.warning(
|
||||
"SECURITY WARNING: trust_remote_code=True allows execution of arbitrary code from the model repository. "
|
||||
"Only enable this for models from trusted sources. This creates a potential security risk if the model "
|
||||
"repository is compromised."
|
||||
)
|
||||
|
||||
self.trust_remote_code = trust_remote_code
|
||||
|
||||
# Pre-check model availability with helpful suggestions
|
||||
model_error = validate_model_and_suggest(model_name, "hf")
|
||||
if model_error:
|
||||
@@ -607,16 +583,14 @@ class HFChat(LLMInterface):
|
||||
|
||||
try:
|
||||
logger.info(f"Loading tokenizer for {model_name}...")
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_name, trust_remote_code=self.trust_remote_code
|
||||
)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
|
||||
logger.info(f"Loading model {model_name}...")
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.float16 if self.device != "cpu" else torch.float32,
|
||||
device_map="auto" if self.device != "cpu" else None,
|
||||
trust_remote_code=self.trust_remote_code,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
logger.info(f"Successfully loaded {model_name}")
|
||||
finally:
|
||||
@@ -763,31 +737,21 @@ class GeminiChat(LLMInterface):
|
||||
class OpenAIChat(LLMInterface):
|
||||
"""LLM interface for OpenAI models."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "gpt-4o",
|
||||
api_key: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
):
|
||||
def __init__(self, model: str = "gpt-4o", api_key: Optional[str] = None):
|
||||
self.model = model
|
||||
self.base_url = resolve_openai_base_url(base_url)
|
||||
self.api_key = resolve_openai_api_key(api_key)
|
||||
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError(
|
||||
"OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter."
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Initializing OpenAI Chat with model='%s' and base_url='%s'",
|
||||
model,
|
||||
self.base_url,
|
||||
)
|
||||
logger.info(f"Initializing OpenAI Chat with model='{model}'")
|
||||
|
||||
try:
|
||||
import openai
|
||||
|
||||
self.client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url)
|
||||
self.client = openai.OpenAI(api_key=self.api_key)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"The 'openai' library is required for OpenAI models. Please install it with 'pip install openai'."
|
||||
@@ -877,19 +841,12 @@ def get_llm(llm_config: Optional[dict[str, Any]] = None) -> LLMInterface:
|
||||
if llm_type == "ollama":
|
||||
return OllamaChat(
|
||||
model=model or "llama3:8b",
|
||||
host=llm_config.get("host"),
|
||||
host=llm_config.get("host", "http://localhost:11434"),
|
||||
)
|
||||
elif llm_type == "hf":
|
||||
return HFChat(
|
||||
model_name=model or "deepseek-ai/deepseek-llm-7b-chat",
|
||||
trust_remote_code=llm_config.get("trust_remote_code", False),
|
||||
)
|
||||
return HFChat(model_name=model or "deepseek-ai/deepseek-llm-7b-chat")
|
||||
elif llm_type == "openai":
|
||||
return OpenAIChat(
|
||||
model=model or "gpt-4o",
|
||||
api_key=llm_config.get("api_key"),
|
||||
base_url=llm_config.get("base_url"),
|
||||
)
|
||||
return OpenAIChat(model=model or "gpt-4o", api_key=llm_config.get("api_key"))
|
||||
elif llm_type == "gemini":
|
||||
return GeminiChat(model=model or "gemini-2.5-flash", api_key=llm_config.get("api_key"))
|
||||
elif llm_type == "simulated":
|
||||
|
||||
@@ -8,9 +8,7 @@ from llama_index.core.node_parser import SentenceSplitter
|
||||
from tqdm import tqdm
|
||||
|
||||
from .api import LeannBuilder, LeannChat, LeannSearcher
|
||||
from .interactive_utils import create_cli_session
|
||||
from .registry import register_project_directory
|
||||
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
|
||||
|
||||
|
||||
def extract_pdf_text_with_pymupdf(file_path: str) -> str:
|
||||
@@ -125,24 +123,6 @@ Examples:
|
||||
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
||||
help="Embedding backend mode (default: sentence-transformers)",
|
||||
)
|
||||
build_parser.add_argument(
|
||||
"--embedding-host",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Override Ollama-compatible embedding host",
|
||||
)
|
||||
build_parser.add_argument(
|
||||
"--embedding-api-base",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Base URL for OpenAI-compatible embedding services",
|
||||
)
|
||||
build_parser.add_argument(
|
||||
"--embedding-api-key",
|
||||
type=str,
|
||||
default=None,
|
||||
help="API key for embedding service (defaults to OPENAI_API_KEY)",
|
||||
)
|
||||
build_parser.add_argument(
|
||||
"--force", "-f", action="store_true", help="Force rebuild existing index"
|
||||
)
|
||||
@@ -258,11 +238,6 @@ Examples:
|
||||
# Ask command
|
||||
ask_parser = subparsers.add_parser("ask", help="Ask questions")
|
||||
ask_parser.add_argument("index_name", help="Index name")
|
||||
ask_parser.add_argument(
|
||||
"query",
|
||||
nargs="?",
|
||||
help="Question to ask (omit for prompt or when using --interactive)",
|
||||
)
|
||||
ask_parser.add_argument(
|
||||
"--llm",
|
||||
type=str,
|
||||
@@ -273,12 +248,7 @@ Examples:
|
||||
ask_parser.add_argument(
|
||||
"--model", type=str, default="qwen3:8b", help="Model name (default: qwen3:8b)"
|
||||
)
|
||||
ask_parser.add_argument(
|
||||
"--host",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Override Ollama-compatible host (defaults to LEANN_OLLAMA_HOST/OLLAMA_HOST)",
|
||||
)
|
||||
ask_parser.add_argument("--host", type=str, default="http://localhost:11434")
|
||||
ask_parser.add_argument(
|
||||
"--interactive", "-i", action="store_true", help="Interactive chat mode"
|
||||
)
|
||||
@@ -307,18 +277,6 @@ Examples:
|
||||
default=None,
|
||||
help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.",
|
||||
)
|
||||
ask_parser.add_argument(
|
||||
"--api-base",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Base URL for OpenAI-compatible APIs (e.g., http://localhost:10000/v1)",
|
||||
)
|
||||
ask_parser.add_argument(
|
||||
"--api-key",
|
||||
type=str,
|
||||
default=None,
|
||||
help="API key for OpenAI-compatible APIs (defaults to OPENAI_API_KEY)",
|
||||
)
|
||||
|
||||
# List command
|
||||
subparsers.add_parser("list", help="List all indexes")
|
||||
@@ -1367,20 +1325,10 @@ Examples:
|
||||
|
||||
print(f"Building index '{index_name}' with {args.backend} backend...")
|
||||
|
||||
embedding_options: dict[str, Any] = {}
|
||||
if args.embedding_mode == "ollama":
|
||||
embedding_options["host"] = resolve_ollama_host(args.embedding_host)
|
||||
elif args.embedding_mode == "openai":
|
||||
embedding_options["base_url"] = resolve_openai_base_url(args.embedding_api_base)
|
||||
resolved_embedding_key = resolve_openai_api_key(args.embedding_api_key)
|
||||
if resolved_embedding_key:
|
||||
embedding_options["api_key"] = resolved_embedding_key
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name=args.backend,
|
||||
embedding_model=args.embedding_model,
|
||||
embedding_mode=args.embedding_mode,
|
||||
embedding_options=embedding_options or None,
|
||||
graph_degree=args.graph_degree,
|
||||
complexity=args.complexity,
|
||||
is_compact=args.compact,
|
||||
@@ -1528,49 +1476,58 @@ Examples:
|
||||
|
||||
llm_config = {"type": args.llm, "model": args.model}
|
||||
if args.llm == "ollama":
|
||||
llm_config["host"] = resolve_ollama_host(args.host)
|
||||
elif args.llm == "openai":
|
||||
llm_config["base_url"] = resolve_openai_base_url(args.api_base)
|
||||
resolved_api_key = resolve_openai_api_key(args.api_key)
|
||||
if resolved_api_key:
|
||||
llm_config["api_key"] = resolved_api_key
|
||||
llm_config["host"] = args.host
|
||||
|
||||
chat = LeannChat(index_path=index_path, llm_config=llm_config)
|
||||
|
||||
llm_kwargs: dict[str, Any] = {}
|
||||
if args.thinking_budget:
|
||||
llm_kwargs["thinking_budget"] = args.thinking_budget
|
||||
|
||||
def _ask_once(prompt: str) -> None:
|
||||
response = chat.ask(
|
||||
prompt,
|
||||
top_k=args.top_k,
|
||||
complexity=args.complexity,
|
||||
beam_width=args.beam_width,
|
||||
prune_ratio=args.prune_ratio,
|
||||
recompute_embeddings=args.recompute_embeddings,
|
||||
pruning_strategy=args.pruning_strategy,
|
||||
llm_kwargs=llm_kwargs,
|
||||
)
|
||||
print(f"LEANN: {response}")
|
||||
|
||||
initial_query = (args.query or "").strip()
|
||||
|
||||
if args.interactive:
|
||||
# Create interactive session
|
||||
session = create_cli_session(index_name)
|
||||
print("LEANN Assistant ready! Type 'quit' to exit")
|
||||
print("=" * 40)
|
||||
|
||||
if initial_query:
|
||||
_ask_once(initial_query)
|
||||
while True:
|
||||
user_input = input("\nYou: ").strip()
|
||||
if user_input.lower() in ["quit", "exit", "q"]:
|
||||
print("Goodbye!")
|
||||
break
|
||||
|
||||
session.run_interactive_loop(_ask_once)
|
||||
if not user_input:
|
||||
continue
|
||||
|
||||
# Prepare LLM kwargs with thinking budget if specified
|
||||
llm_kwargs = {}
|
||||
if args.thinking_budget:
|
||||
llm_kwargs["thinking_budget"] = args.thinking_budget
|
||||
|
||||
response = chat.ask(
|
||||
user_input,
|
||||
top_k=args.top_k,
|
||||
complexity=args.complexity,
|
||||
beam_width=args.beam_width,
|
||||
prune_ratio=args.prune_ratio,
|
||||
recompute_embeddings=args.recompute_embeddings,
|
||||
pruning_strategy=args.pruning_strategy,
|
||||
llm_kwargs=llm_kwargs,
|
||||
)
|
||||
print(f"LEANN: {response}")
|
||||
else:
|
||||
query = initial_query or input("Enter your question: ").strip()
|
||||
if not query:
|
||||
print("No question provided. Exiting.")
|
||||
return
|
||||
query = input("Enter your question: ").strip()
|
||||
if query:
|
||||
# Prepare LLM kwargs with thinking budget if specified
|
||||
llm_kwargs = {}
|
||||
if args.thinking_budget:
|
||||
llm_kwargs["thinking_budget"] = args.thinking_budget
|
||||
|
||||
_ask_once(query)
|
||||
response = chat.ask(
|
||||
query,
|
||||
top_k=args.top_k,
|
||||
complexity=args.complexity,
|
||||
beam_width=args.beam_width,
|
||||
prune_ratio=args.prune_ratio,
|
||||
recompute_embeddings=args.recompute_embeddings,
|
||||
pruning_strategy=args.pruning_strategy,
|
||||
llm_kwargs=llm_kwargs,
|
||||
)
|
||||
print(f"LEANN: {response}")
|
||||
|
||||
async def run(self, args=None):
|
||||
parser = self.create_parser()
|
||||
|
||||
@@ -7,13 +7,11 @@ Preserves all optimization parameters to ensure performance
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
|
||||
|
||||
# Set up logger with proper level
|
||||
logger = logging.getLogger(__name__)
|
||||
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||
@@ -33,7 +31,6 @@ def compute_embeddings(
|
||||
adaptive_optimization: bool = True,
|
||||
manual_tokenize: bool = False,
|
||||
max_length: int = 512,
|
||||
provider_options: Optional[dict[str, Any]] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Unified embedding computation entry point
|
||||
@@ -49,8 +46,6 @@ def compute_embeddings(
|
||||
Returns:
|
||||
Normalized embeddings array, shape: (len(texts), embedding_dim)
|
||||
"""
|
||||
provider_options = provider_options or {}
|
||||
|
||||
if mode == "sentence-transformers":
|
||||
return compute_embeddings_sentence_transformers(
|
||||
texts,
|
||||
@@ -62,21 +57,11 @@ def compute_embeddings(
|
||||
max_length=max_length,
|
||||
)
|
||||
elif mode == "openai":
|
||||
return compute_embeddings_openai(
|
||||
texts,
|
||||
model_name,
|
||||
base_url=provider_options.get("base_url"),
|
||||
api_key=provider_options.get("api_key"),
|
||||
)
|
||||
return compute_embeddings_openai(texts, model_name)
|
||||
elif mode == "mlx":
|
||||
return compute_embeddings_mlx(texts, model_name)
|
||||
elif mode == "ollama":
|
||||
return compute_embeddings_ollama(
|
||||
texts,
|
||||
model_name,
|
||||
is_build=is_build,
|
||||
host=provider_options.get("host"),
|
||||
)
|
||||
return compute_embeddings_ollama(texts, model_name, is_build=is_build)
|
||||
elif mode == "gemini":
|
||||
return compute_embeddings_gemini(texts, model_name, is_build=is_build)
|
||||
else:
|
||||
@@ -368,15 +353,12 @@ def compute_embeddings_sentence_transformers(
|
||||
return embeddings
|
||||
|
||||
|
||||
def compute_embeddings_openai(
|
||||
texts: list[str],
|
||||
model_name: str,
|
||||
base_url: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
) -> np.ndarray:
|
||||
def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray:
|
||||
# TODO: @yichuan-w add progress bar only in build mode
|
||||
"""Compute embeddings using OpenAI API"""
|
||||
try:
|
||||
import os
|
||||
|
||||
import openai
|
||||
except ImportError as e:
|
||||
raise ImportError(f"OpenAI package not installed: {e}")
|
||||
@@ -391,18 +373,16 @@ def compute_embeddings_openai(
|
||||
f"Found {invalid_count} empty/invalid text(s) in input. Upstream should filter before calling OpenAI."
|
||||
)
|
||||
|
||||
resolved_base_url = resolve_openai_base_url(base_url)
|
||||
resolved_api_key = resolve_openai_api_key(api_key)
|
||||
|
||||
if not resolved_api_key:
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
if not api_key:
|
||||
raise RuntimeError("OPENAI_API_KEY environment variable not set")
|
||||
|
||||
# Cache OpenAI client
|
||||
cache_key = f"openai_client::{resolved_base_url}"
|
||||
cache_key = "openai_client"
|
||||
if cache_key in _model_cache:
|
||||
client = _model_cache[cache_key]
|
||||
else:
|
||||
client = openai.OpenAI(api_key=resolved_api_key, base_url=resolved_base_url)
|
||||
client = openai.OpenAI(api_key=api_key)
|
||||
_model_cache[cache_key] = client
|
||||
logger.info("OpenAI client cached")
|
||||
|
||||
@@ -527,10 +507,7 @@ def compute_embeddings_mlx(chunks: list[str], model_name: str, batch_size: int =
|
||||
|
||||
|
||||
def compute_embeddings_ollama(
|
||||
texts: list[str],
|
||||
model_name: str,
|
||||
is_build: bool = False,
|
||||
host: Optional[str] = None,
|
||||
texts: list[str], model_name: str, is_build: bool = False, host: str = "http://localhost:11434"
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Compute embeddings using Ollama API with simplified batch processing.
|
||||
@@ -541,7 +518,7 @@ def compute_embeddings_ollama(
|
||||
texts: List of texts to compute embeddings for
|
||||
model_name: Ollama model name (e.g., "nomic-embed-text", "mxbai-embed-large")
|
||||
is_build: Whether this is a build operation (shows progress bar)
|
||||
host: Ollama host URL (defaults to environment or http://localhost:11434)
|
||||
host: Ollama host URL (default: http://localhost:11434)
|
||||
|
||||
Returns:
|
||||
Normalized embeddings array, shape: (len(texts), embedding_dim)
|
||||
@@ -556,19 +533,17 @@ def compute_embeddings_ollama(
|
||||
if not texts:
|
||||
raise ValueError("Cannot compute embeddings for empty text list")
|
||||
|
||||
resolved_host = resolve_ollama_host(host)
|
||||
|
||||
logger.info(
|
||||
f"Computing embeddings for {len(texts)} texts using Ollama API, model: '{model_name}', host: '{resolved_host}'"
|
||||
f"Computing embeddings for {len(texts)} texts using Ollama API, model: '{model_name}'"
|
||||
)
|
||||
|
||||
# Check if Ollama is running
|
||||
try:
|
||||
response = requests.get(f"{resolved_host}/api/version", timeout=5)
|
||||
response = requests.get(f"{host}/api/version", timeout=5)
|
||||
response.raise_for_status()
|
||||
except requests.exceptions.ConnectionError:
|
||||
error_msg = (
|
||||
f"❌ Could not connect to Ollama at {resolved_host}.\n\n"
|
||||
f"❌ Could not connect to Ollama at {host}.\n\n"
|
||||
"Please ensure Ollama is running:\n"
|
||||
" • macOS/Linux: ollama serve\n"
|
||||
" • Windows: Make sure Ollama is running in the system tray\n\n"
|
||||
@@ -580,7 +555,7 @@ def compute_embeddings_ollama(
|
||||
|
||||
# Check if model exists and provide helpful suggestions
|
||||
try:
|
||||
response = requests.get(f"{resolved_host}/api/tags", timeout=5)
|
||||
response = requests.get(f"{host}/api/tags", timeout=5)
|
||||
response.raise_for_status()
|
||||
models = response.json()
|
||||
model_names = [model["name"] for model in models.get("models", [])]
|
||||
@@ -643,9 +618,7 @@ def compute_embeddings_ollama(
|
||||
# Verify the model supports embeddings by testing it
|
||||
try:
|
||||
test_response = requests.post(
|
||||
f"{resolved_host}/api/embeddings",
|
||||
json={"model": model_name, "prompt": "test"},
|
||||
timeout=10,
|
||||
f"{host}/api/embeddings", json={"model": model_name, "prompt": "test"}, timeout=10
|
||||
)
|
||||
if test_response.status_code != 200:
|
||||
error_msg = (
|
||||
@@ -692,7 +665,7 @@ def compute_embeddings_ollama(
|
||||
while retry_count < max_retries:
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{resolved_host}/api/embeddings",
|
||||
f"{host}/api/embeddings",
|
||||
json={"model": model_name, "prompt": truncated_text},
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import atexit
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
@@ -9,8 +8,6 @@ import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from .settings import encode_provider_options
|
||||
|
||||
# Lightweight, self-contained server manager with no cross-process inspection
|
||||
|
||||
# Set up logging based on environment variable
|
||||
@@ -49,85 +46,6 @@ def _check_port(port: int) -> bool:
|
||||
# Note: All cross-process scanning helpers removed for simplicity
|
||||
|
||||
|
||||
def _safe_resolve(path: Path) -> str:
|
||||
"""Resolve paths safely even if the target does not yet exist."""
|
||||
try:
|
||||
return str(path.resolve(strict=False))
|
||||
except Exception:
|
||||
return str(path)
|
||||
|
||||
|
||||
def _safe_stat_signature(path: Path) -> dict:
|
||||
"""Return a lightweight signature describing the current state of a path."""
|
||||
signature: dict[str, object] = {"path": _safe_resolve(path)}
|
||||
try:
|
||||
stat = path.stat()
|
||||
except FileNotFoundError:
|
||||
signature["missing"] = True
|
||||
except Exception as exc: # pragma: no cover - unexpected filesystem errors
|
||||
signature["error"] = str(exc)
|
||||
else:
|
||||
signature["mtime_ns"] = stat.st_mtime_ns
|
||||
signature["size"] = stat.st_size
|
||||
return signature
|
||||
|
||||
|
||||
def _build_passages_signature(passages_file: Optional[str]) -> Optional[dict]:
|
||||
"""Collect modification signatures for metadata and referenced passage files."""
|
||||
if not passages_file:
|
||||
return None
|
||||
|
||||
meta_path = Path(passages_file)
|
||||
signature: dict[str, object] = {"meta": _safe_stat_signature(meta_path)}
|
||||
|
||||
try:
|
||||
with meta_path.open(encoding="utf-8") as fh:
|
||||
meta = json.load(fh)
|
||||
except FileNotFoundError:
|
||||
signature["meta_missing"] = True
|
||||
signature["sources"] = []
|
||||
return signature
|
||||
except json.JSONDecodeError as exc:
|
||||
signature["meta_error"] = f"json_error:{exc}"
|
||||
signature["sources"] = []
|
||||
return signature
|
||||
except Exception as exc: # pragma: no cover - unexpected errors
|
||||
signature["meta_error"] = str(exc)
|
||||
signature["sources"] = []
|
||||
return signature
|
||||
|
||||
base_dir = meta_path.parent
|
||||
seen_paths: set[str] = set()
|
||||
source_signatures: list[dict[str, object]] = []
|
||||
|
||||
for source in meta.get("passage_sources", []):
|
||||
for key, kind in (
|
||||
("path", "passages"),
|
||||
("path_relative", "passages"),
|
||||
("index_path", "index"),
|
||||
("index_path_relative", "index"),
|
||||
):
|
||||
raw_path = source.get(key)
|
||||
if not raw_path:
|
||||
continue
|
||||
candidate = Path(raw_path)
|
||||
if not candidate.is_absolute():
|
||||
candidate = base_dir / candidate
|
||||
resolved = _safe_resolve(candidate)
|
||||
if resolved in seen_paths:
|
||||
continue
|
||||
seen_paths.add(resolved)
|
||||
sig = _safe_stat_signature(candidate)
|
||||
sig["kind"] = kind
|
||||
source_signatures.append(sig)
|
||||
|
||||
signature["sources"] = source_signatures
|
||||
return signature
|
||||
|
||||
|
||||
# Note: All cross-process scanning helpers removed for simplicity
|
||||
|
||||
|
||||
class EmbeddingServerManager:
|
||||
"""
|
||||
A simplified manager for embedding server processes that avoids complex update mechanisms.
|
||||
@@ -164,42 +82,16 @@ class EmbeddingServerManager:
|
||||
) -> tuple[bool, int]:
|
||||
"""Start the embedding server."""
|
||||
# passages_file may be present in kwargs for server CLI, but we don't need it here
|
||||
provider_options = kwargs.pop("provider_options", None)
|
||||
passages_file = kwargs.get("passages_file", "")
|
||||
|
||||
config_signature = self._build_config_signature(
|
||||
model_name=model_name,
|
||||
embedding_mode=embedding_mode,
|
||||
provider_options=provider_options,
|
||||
passages_file=passages_file,
|
||||
)
|
||||
|
||||
# If this manager already has a live server, just reuse it
|
||||
if (
|
||||
self.server_process
|
||||
and self.server_process.poll() is None
|
||||
and self.server_port
|
||||
and self._server_config == config_signature
|
||||
):
|
||||
if self.server_process and self.server_process.poll() is None and self.server_port:
|
||||
logger.info("Reusing in-process server")
|
||||
return True, self.server_port
|
||||
|
||||
# Configuration changed, stop existing server before starting a new one
|
||||
if self.server_process and self.server_process.poll() is None:
|
||||
logger.info("Existing server configuration differs; restarting embedding server")
|
||||
self.stop_server()
|
||||
|
||||
# For Colab environment, use a different strategy
|
||||
if _is_colab_environment():
|
||||
logger.info("Detected Colab environment, using alternative startup strategy")
|
||||
return self._start_server_colab(
|
||||
port,
|
||||
model_name,
|
||||
embedding_mode,
|
||||
config_signature=config_signature,
|
||||
provider_options=provider_options,
|
||||
**kwargs,
|
||||
)
|
||||
return self._start_server_colab(port, model_name, embedding_mode, **kwargs)
|
||||
|
||||
# Always pick a fresh available port
|
||||
try:
|
||||
@@ -209,40 +101,13 @@ class EmbeddingServerManager:
|
||||
return False, port
|
||||
|
||||
# Start a new server
|
||||
return self._start_new_server(
|
||||
actual_port,
|
||||
model_name,
|
||||
embedding_mode,
|
||||
provider_options=provider_options,
|
||||
config_signature=config_signature,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _build_config_signature(
|
||||
self,
|
||||
*,
|
||||
model_name: str,
|
||||
embedding_mode: str,
|
||||
provider_options: Optional[dict],
|
||||
passages_file: Optional[str],
|
||||
) -> dict:
|
||||
"""Create a signature describing the current server configuration."""
|
||||
return {
|
||||
"model_name": model_name,
|
||||
"passages_file": passages_file or "",
|
||||
"embedding_mode": embedding_mode,
|
||||
"provider_options": provider_options or {},
|
||||
"passages_signature": _build_passages_signature(passages_file),
|
||||
}
|
||||
return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs)
|
||||
|
||||
def _start_server_colab(
|
||||
self,
|
||||
port: int,
|
||||
model_name: str,
|
||||
embedding_mode: str = "sentence-transformers",
|
||||
*,
|
||||
config_signature: Optional[dict] = None,
|
||||
provider_options: Optional[dict] = None,
|
||||
**kwargs,
|
||||
) -> tuple[bool, int]:
|
||||
"""Start server with Colab-specific configuration."""
|
||||
@@ -260,21 +125,8 @@ class EmbeddingServerManager:
|
||||
|
||||
try:
|
||||
# In Colab, we'll use a more direct approach
|
||||
self._launch_server_process_colab(
|
||||
command,
|
||||
actual_port,
|
||||
provider_options=provider_options,
|
||||
config_signature=config_signature,
|
||||
)
|
||||
started, ready_port = self._wait_for_server_ready_colab(actual_port)
|
||||
if started:
|
||||
self._server_config = config_signature or {
|
||||
"model_name": model_name,
|
||||
"passages_file": kwargs.get("passages_file", ""),
|
||||
"embedding_mode": embedding_mode,
|
||||
"provider_options": provider_options or {},
|
||||
}
|
||||
return started, ready_port
|
||||
self._launch_server_process_colab(command, actual_port)
|
||||
return self._wait_for_server_ready_colab(actual_port)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start embedding server in Colab: {e}")
|
||||
return False, actual_port
|
||||
@@ -282,13 +134,7 @@ class EmbeddingServerManager:
|
||||
# Note: No compatibility check needed; manager is per-searcher and configs are stable per instance
|
||||
|
||||
def _start_new_server(
|
||||
self,
|
||||
port: int,
|
||||
model_name: str,
|
||||
embedding_mode: str,
|
||||
provider_options: Optional[dict] = None,
|
||||
config_signature: Optional[dict] = None,
|
||||
**kwargs,
|
||||
self, port: int, model_name: str, embedding_mode: str, **kwargs
|
||||
) -> tuple[bool, int]:
|
||||
"""Start a new embedding server on the given port."""
|
||||
logger.info(f"Starting embedding server on port {port}...")
|
||||
@@ -296,21 +142,8 @@ class EmbeddingServerManager:
|
||||
command = self._build_server_command(port, model_name, embedding_mode, **kwargs)
|
||||
|
||||
try:
|
||||
self._launch_server_process(
|
||||
command,
|
||||
port,
|
||||
provider_options=provider_options,
|
||||
config_signature=config_signature,
|
||||
)
|
||||
started, ready_port = self._wait_for_server_ready(port)
|
||||
if started:
|
||||
self._server_config = config_signature or {
|
||||
"model_name": model_name,
|
||||
"passages_file": kwargs.get("passages_file", ""),
|
||||
"embedding_mode": embedding_mode,
|
||||
"provider_options": provider_options or {},
|
||||
}
|
||||
return started, ready_port
|
||||
self._launch_server_process(command, port)
|
||||
return self._wait_for_server_ready(port)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start embedding server: {e}")
|
||||
return False, port
|
||||
@@ -340,14 +173,7 @@ class EmbeddingServerManager:
|
||||
|
||||
return command
|
||||
|
||||
def _launch_server_process(
|
||||
self,
|
||||
command: list,
|
||||
port: int,
|
||||
*,
|
||||
provider_options: Optional[dict] = None,
|
||||
config_signature: Optional[dict] = None,
|
||||
) -> None:
|
||||
def _launch_server_process(self, command: list, port: int) -> None:
|
||||
"""Launch the server process."""
|
||||
project_root = Path(__file__).parent.parent.parent.parent.parent
|
||||
logger.info(f"Command: {' '.join(command)}")
|
||||
@@ -367,43 +193,32 @@ class EmbeddingServerManager:
|
||||
|
||||
# Start embedding server subprocess
|
||||
logger.info(f"Starting server process with command: {' '.join(command)}")
|
||||
env = os.environ.copy()
|
||||
encoded_options = encode_provider_options(provider_options)
|
||||
if encoded_options:
|
||||
env["LEANN_EMBEDDING_OPTIONS"] = encoded_options
|
||||
|
||||
self.server_process = subprocess.Popen(
|
||||
command,
|
||||
cwd=project_root,
|
||||
stdout=stdout_target,
|
||||
stderr=stderr_target,
|
||||
env=env,
|
||||
)
|
||||
self.server_port = port
|
||||
# Record config for in-process reuse (best effort; refined later when ready)
|
||||
if config_signature is not None:
|
||||
self._server_config = config_signature
|
||||
else: # Fallback for unexpected code paths
|
||||
try:
|
||||
self._server_config = {
|
||||
"model_name": command[command.index("--model-name") + 1]
|
||||
if "--model-name" in command
|
||||
else "",
|
||||
"passages_file": command[command.index("--passages-file") + 1]
|
||||
if "--passages-file" in command
|
||||
else "",
|
||||
"embedding_mode": command[command.index("--embedding-mode") + 1]
|
||||
if "--embedding-mode" in command
|
||||
else "sentence-transformers",
|
||||
"provider_options": provider_options or {},
|
||||
}
|
||||
except Exception:
|
||||
self._server_config = {
|
||||
"model_name": "",
|
||||
"passages_file": "",
|
||||
"embedding_mode": "sentence-transformers",
|
||||
"provider_options": provider_options or {},
|
||||
}
|
||||
# Record config for in-process reuse
|
||||
try:
|
||||
self._server_config = {
|
||||
"model_name": command[command.index("--model-name") + 1]
|
||||
if "--model-name" in command
|
||||
else "",
|
||||
"passages_file": command[command.index("--passages-file") + 1]
|
||||
if "--passages-file" in command
|
||||
else "",
|
||||
"embedding_mode": command[command.index("--embedding-mode") + 1]
|
||||
if "--embedding-mode" in command
|
||||
else "sentence-transformers",
|
||||
}
|
||||
except Exception:
|
||||
self._server_config = {
|
||||
"model_name": "",
|
||||
"passages_file": "",
|
||||
"embedding_mode": "sentence-transformers",
|
||||
}
|
||||
logger.info(f"Server process started with PID: {self.server_process.pid}")
|
||||
|
||||
# Register atexit callback only when we actually start a process
|
||||
@@ -507,29 +322,16 @@ class EmbeddingServerManager:
|
||||
# Removed: cross-process adoption no longer supported
|
||||
return
|
||||
|
||||
def _launch_server_process_colab(
|
||||
self,
|
||||
command: list,
|
||||
port: int,
|
||||
*,
|
||||
provider_options: Optional[dict] = None,
|
||||
config_signature: Optional[dict] = None,
|
||||
) -> None:
|
||||
def _launch_server_process_colab(self, command: list, port: int) -> None:
|
||||
"""Launch the server process with Colab-specific settings."""
|
||||
logger.info(f"Colab Command: {' '.join(command)}")
|
||||
|
||||
# In Colab, we need to be more careful about process management
|
||||
env = os.environ.copy()
|
||||
encoded_options = encode_provider_options(provider_options)
|
||||
if encoded_options:
|
||||
env["LEANN_EMBEDDING_OPTIONS"] = encoded_options
|
||||
|
||||
self.server_process = subprocess.Popen(
|
||||
command,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
env=env,
|
||||
)
|
||||
self.server_port = port
|
||||
logger.info(f"Colab server process started with PID: {self.server_process.pid}")
|
||||
@@ -539,15 +341,11 @@ class EmbeddingServerManager:
|
||||
atexit.register(self._finalize_process)
|
||||
self._atexit_registered = True
|
||||
# Record config for in-process reuse is best-effort in Colab mode
|
||||
if config_signature is not None:
|
||||
self._server_config = config_signature
|
||||
else:
|
||||
self._server_config = {
|
||||
"model_name": "",
|
||||
"passages_file": "",
|
||||
"embedding_mode": "sentence-transformers",
|
||||
"provider_options": provider_options or {},
|
||||
}
|
||||
self._server_config = {
|
||||
"model_name": "",
|
||||
"passages_file": "",
|
||||
"embedding_mode": "sentence-transformers",
|
||||
}
|
||||
|
||||
def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]:
|
||||
"""Wait for the server to be ready with Colab-specific timeout."""
|
||||
|
||||
@@ -1,189 +0,0 @@
|
||||
"""
|
||||
Interactive session utilities for LEANN applications.
|
||||
|
||||
Provides shared readline functionality and command handling across
|
||||
CLI, API, and RAG example interactive modes.
|
||||
"""
|
||||
|
||||
import atexit
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional
|
||||
|
||||
# Try to import readline with fallback for Windows
|
||||
try:
|
||||
import readline
|
||||
|
||||
HAS_READLINE = True
|
||||
except ImportError:
|
||||
# Windows doesn't have readline by default
|
||||
HAS_READLINE = False
|
||||
readline = None
|
||||
|
||||
|
||||
class InteractiveSession:
|
||||
"""Manages interactive session with optional readline support and common commands."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
history_name: str,
|
||||
prompt: str = "You: ",
|
||||
welcome_message: str = "",
|
||||
):
|
||||
"""
|
||||
Initialize interactive session with optional readline support.
|
||||
|
||||
Args:
|
||||
history_name: Name for history file (e.g., "cli", "api_chat")
|
||||
(ignored if readline not available)
|
||||
prompt: Input prompt to display
|
||||
welcome_message: Message to show when starting session
|
||||
|
||||
Note:
|
||||
On systems without readline (e.g., Windows), falls back to basic input()
|
||||
with limited functionality (no history, no line editing).
|
||||
"""
|
||||
self.history_name = history_name
|
||||
self.prompt = prompt
|
||||
self.welcome_message = welcome_message
|
||||
self._setup_complete = False
|
||||
|
||||
def setup_readline(self):
|
||||
"""Setup readline with history support (if available)."""
|
||||
if self._setup_complete:
|
||||
return
|
||||
|
||||
if not HAS_READLINE:
|
||||
# Readline not available (likely Windows), skip setup
|
||||
self._setup_complete = True
|
||||
return
|
||||
|
||||
# History file setup
|
||||
history_dir = Path.home() / ".leann" / "history"
|
||||
history_dir.mkdir(parents=True, exist_ok=True)
|
||||
history_file = history_dir / f"{self.history_name}.history"
|
||||
|
||||
# Load history if exists
|
||||
try:
|
||||
readline.read_history_file(str(history_file))
|
||||
readline.set_history_length(1000)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
# Save history on exit
|
||||
atexit.register(readline.write_history_file, str(history_file))
|
||||
|
||||
# Optional: Enable vi editing mode (commented out by default)
|
||||
# readline.parse_and_bind("set editing-mode vi")
|
||||
|
||||
self._setup_complete = True
|
||||
|
||||
def _show_help(self):
|
||||
"""Show available commands."""
|
||||
print("Commands:")
|
||||
print(" quit/exit/q - Exit the chat")
|
||||
print(" help - Show this help message")
|
||||
print(" clear - Clear screen")
|
||||
print(" history - Show command history")
|
||||
|
||||
def _show_history(self):
|
||||
"""Show command history."""
|
||||
if not HAS_READLINE:
|
||||
print(" History not available (readline not supported on this system)")
|
||||
return
|
||||
|
||||
history_length = readline.get_current_history_length()
|
||||
if history_length == 0:
|
||||
print(" No history available")
|
||||
return
|
||||
|
||||
for i in range(history_length):
|
||||
item = readline.get_history_item(i + 1)
|
||||
if item:
|
||||
print(f" {i + 1}: {item}")
|
||||
|
||||
def get_user_input(self) -> Optional[str]:
|
||||
"""
|
||||
Get user input with readline support.
|
||||
|
||||
Returns:
|
||||
User input string, or None if EOF (Ctrl+D)
|
||||
"""
|
||||
try:
|
||||
return input(self.prompt).strip()
|
||||
except KeyboardInterrupt:
|
||||
print("\n(Use 'quit' to exit)")
|
||||
return "" # Return empty string to continue
|
||||
except EOFError:
|
||||
print("\nGoodbye!")
|
||||
return None
|
||||
|
||||
def run_interactive_loop(self, handler_func: Callable[[str], None]):
|
||||
"""
|
||||
Run the interactive loop with a custom handler function.
|
||||
|
||||
Args:
|
||||
handler_func: Function to handle user input that's not a built-in command
|
||||
Should accept a string and handle the user's query
|
||||
"""
|
||||
self.setup_readline()
|
||||
|
||||
if self.welcome_message:
|
||||
print(self.welcome_message)
|
||||
|
||||
while True:
|
||||
user_input = self.get_user_input()
|
||||
|
||||
if user_input is None: # EOF (Ctrl+D)
|
||||
break
|
||||
|
||||
if not user_input: # Empty input or KeyboardInterrupt
|
||||
continue
|
||||
|
||||
# Handle built-in commands
|
||||
command = user_input.lower()
|
||||
if command in ["quit", "exit", "q"]:
|
||||
print("Goodbye!")
|
||||
break
|
||||
elif command == "help":
|
||||
self._show_help()
|
||||
elif command == "clear":
|
||||
os.system("clear" if os.name != "nt" else "cls")
|
||||
elif command == "history":
|
||||
self._show_history()
|
||||
else:
|
||||
# Regular user input - pass to handler
|
||||
try:
|
||||
handler_func(user_input)
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
|
||||
|
||||
def create_cli_session(index_name: str) -> InteractiveSession:
|
||||
"""Create an interactive session for CLI usage."""
|
||||
return InteractiveSession(
|
||||
history_name=index_name,
|
||||
prompt="\nYou: ",
|
||||
welcome_message="LEANN Assistant ready! Type 'quit' to exit, 'help' for commands\n"
|
||||
+ "=" * 40,
|
||||
)
|
||||
|
||||
|
||||
def create_api_session() -> InteractiveSession:
|
||||
"""Create an interactive session for API chat."""
|
||||
return InteractiveSession(
|
||||
history_name="api_chat",
|
||||
prompt="You: ",
|
||||
welcome_message="Leann Chat started (type 'quit' to exit, 'help' for commands)\n"
|
||||
+ "=" * 40,
|
||||
)
|
||||
|
||||
|
||||
def create_rag_session(app_name: str, data_description: str) -> InteractiveSession:
|
||||
"""Create an interactive session for RAG examples."""
|
||||
return InteractiveSession(
|
||||
history_name=f"{app_name}_rag",
|
||||
prompt="You: ",
|
||||
welcome_message=f"[Interactive Mode] Chat with your {data_description} data!\nType 'quit' or 'exit' to stop, 'help' for commands.\n"
|
||||
+ "=" * 40,
|
||||
)
|
||||
@@ -41,7 +41,6 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
print("WARNING: embedding_model not found in meta.json. Recompute will fail.")
|
||||
|
||||
self.embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
|
||||
self.embedding_options = self.meta.get("embedding_options", {})
|
||||
|
||||
self.embedding_server_manager = EmbeddingServerManager(
|
||||
backend_module_name=backend_module_name,
|
||||
@@ -78,7 +77,6 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
passages_file=passages_source_file,
|
||||
distance_metric=distance_metric,
|
||||
enable_warmup=kwargs.get("enable_warmup", False),
|
||||
provider_options=self.embedding_options,
|
||||
)
|
||||
if not server_started:
|
||||
raise RuntimeError(f"Failed to start embedding server on port {actual_port}")
|
||||
@@ -127,12 +125,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
from .embedding_compute import compute_embeddings
|
||||
|
||||
embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
|
||||
return compute_embeddings(
|
||||
[query],
|
||||
self.embedding_model,
|
||||
embedding_mode,
|
||||
provider_options=self.embedding_options,
|
||||
)
|
||||
return compute_embeddings([query], self.embedding_model, embedding_mode)
|
||||
|
||||
def _compute_embedding_via_server(self, chunks: list, zmq_port: int) -> np.ndarray:
|
||||
"""Compute embeddings using the ZMQ embedding server."""
|
||||
|
||||
@@ -1,74 +0,0 @@
|
||||
"""Runtime configuration helpers for LEANN."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
# Default fallbacks to preserve current behaviour while keeping them in one place.
|
||||
_DEFAULT_OLLAMA_HOST = "http://localhost:11434"
|
||||
_DEFAULT_OPENAI_BASE_URL = "https://api.openai.com/v1"
|
||||
|
||||
|
||||
def _clean_url(value: str) -> str:
|
||||
"""Normalize URL strings by stripping trailing slashes."""
|
||||
|
||||
return value.rstrip("/") if value else value
|
||||
|
||||
|
||||
def resolve_ollama_host(explicit: str | None = None) -> str:
|
||||
"""Resolve the Ollama-compatible endpoint to use."""
|
||||
|
||||
candidates = (
|
||||
explicit,
|
||||
os.getenv("LEANN_LOCAL_LLM_HOST"),
|
||||
os.getenv("LEANN_OLLAMA_HOST"),
|
||||
os.getenv("OLLAMA_HOST"),
|
||||
os.getenv("LOCAL_LLM_ENDPOINT"),
|
||||
)
|
||||
|
||||
for candidate in candidates:
|
||||
if candidate:
|
||||
return _clean_url(candidate)
|
||||
|
||||
return _clean_url(_DEFAULT_OLLAMA_HOST)
|
||||
|
||||
|
||||
def resolve_openai_base_url(explicit: str | None = None) -> str:
|
||||
"""Resolve the base URL for OpenAI-compatible services."""
|
||||
|
||||
candidates = (
|
||||
explicit,
|
||||
os.getenv("LEANN_OPENAI_BASE_URL"),
|
||||
os.getenv("OPENAI_BASE_URL"),
|
||||
os.getenv("LOCAL_OPENAI_BASE_URL"),
|
||||
)
|
||||
|
||||
for candidate in candidates:
|
||||
if candidate:
|
||||
return _clean_url(candidate)
|
||||
|
||||
return _clean_url(_DEFAULT_OPENAI_BASE_URL)
|
||||
|
||||
|
||||
def resolve_openai_api_key(explicit: str | None = None) -> str | None:
|
||||
"""Resolve the API key for OpenAI-compatible services."""
|
||||
|
||||
if explicit:
|
||||
return explicit
|
||||
|
||||
return os.getenv("OPENAI_API_KEY")
|
||||
|
||||
|
||||
def encode_provider_options(options: dict[str, Any] | None) -> str | None:
|
||||
"""Serialize provider options for child processes."""
|
||||
|
||||
if not options:
|
||||
return None
|
||||
|
||||
try:
|
||||
return json.dumps(options)
|
||||
except (TypeError, ValueError):
|
||||
# Fall back to empty payload if serialization fails
|
||||
return None
|
||||
@@ -22,10 +22,7 @@ dependencies = [
|
||||
"sglang",
|
||||
"ollama",
|
||||
"requests>=2.25.0",
|
||||
"sentence-transformers>=3.0.0",
|
||||
# Pin transformers below 4.46: 4.46.0 introduced Python 3.10-only typing (PEP 604) and
|
||||
# breaks our Python 3.9 test matrix when pulled in by sentence-transformers.
|
||||
"transformers<4.46",
|
||||
"sentence-transformers>=2.2.0",
|
||||
"openai>=1.0.0",
|
||||
# PDF parsing dependencies - essential for document processing
|
||||
"PyPDF2>=3.0.0",
|
||||
@@ -56,10 +53,27 @@ dependencies = [
|
||||
"tree-sitter-java>=0.20.0",
|
||||
"tree-sitter-c-sharp>=0.20.0",
|
||||
"tree-sitter-typescript>=0.20.0",
|
||||
"torchvision>=0.23.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest>=7.0",
|
||||
"pytest-cov>=4.0",
|
||||
"pytest-xdist>=3.0", # For parallel test execution
|
||||
"black>=23.0",
|
||||
"ruff==0.12.7", # Fixed version to ensure consistent formatting across all environments
|
||||
"matplotlib",
|
||||
"huggingface-hub>=0.20.0",
|
||||
"pre-commit>=3.5.0",
|
||||
]
|
||||
|
||||
test = [
|
||||
"pytest>=7.0",
|
||||
"pytest-timeout>=2.0",
|
||||
"llama-index-core>=0.12.0",
|
||||
"python-dotenv>=1.0.0",
|
||||
]
|
||||
|
||||
diskann = [
|
||||
"leann-backend-diskann",
|
||||
]
|
||||
@@ -87,36 +101,10 @@ leann-backend-diskann = { path = "packages/leann-backend-diskann", editable = tr
|
||||
leann-backend-hnsw = { path = "packages/leann-backend-hnsw", editable = true }
|
||||
astchunk = { path = "packages/astchunk-leann", editable = true }
|
||||
|
||||
[dependency-groups]
|
||||
# Minimal lint toolchain for CI and local hooks
|
||||
lint = [
|
||||
"pre-commit>=3.5.0",
|
||||
"ruff==0.12.7", # Fixed version to ensure consistent formatting across all environments
|
||||
]
|
||||
|
||||
# Test toolchain (no heavy project runtime deps)
|
||||
test = [
|
||||
"pytest>=7.0",
|
||||
"pytest-cov>=4.0",
|
||||
"pytest-xdist>=3.0",
|
||||
"pytest-timeout>=2.0",
|
||||
"python-dotenv>=1.0.0",
|
||||
]
|
||||
|
||||
# dependencies by apps/ should list here
|
||||
dev = [
|
||||
"matplotlib",
|
||||
"huggingface-hub>=0.20.0",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py39"
|
||||
line-length = 100
|
||||
extend-exclude = [
|
||||
"third_party",
|
||||
"apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann-paper-example.py",
|
||||
"apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann-similarity-map.py"
|
||||
]
|
||||
extend-exclude = ["third_party"]
|
||||
|
||||
|
||||
[tool.ruff.lint]
|
||||
|
||||
@@ -1,121 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Upload local evaluation data to Hugging Face Hub, excluding diskann_rpj_wiki.
|
||||
|
||||
Defaults:
|
||||
- repo_id: LEANN-RAG/leann-rag-evaluation-data (dataset)
|
||||
- folder_path: benchmarks/data
|
||||
- ignore_patterns: diskann_rpj_wiki/** and .cache/**
|
||||
|
||||
Requires authentication via `huggingface-cli login` or HF_TOKEN env var.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
try:
|
||||
from huggingface_hub import HfApi
|
||||
except Exception as e:
|
||||
raise SystemExit(
|
||||
"huggingface_hub is required. Install with: pip install huggingface_hub hf_transfer"
|
||||
) from e
|
||||
|
||||
|
||||
def _enable_transfer_accel_if_available() -> None:
|
||||
"""Best-effort enabling of accelerated transfers across hub versions.
|
||||
|
||||
Tries the public util if present; otherwise, falls back to env flag when
|
||||
hf_transfer is installed. Silently no-ops if unavailable.
|
||||
"""
|
||||
try:
|
||||
# Newer huggingface_hub exposes this under utils
|
||||
from huggingface_hub.utils import hf_hub_enable_hf_transfer # type: ignore
|
||||
|
||||
hf_hub_enable_hf_transfer()
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
# If hf_transfer is installed, set env flag recognized by the hub
|
||||
import hf_transfer # noqa: F401
|
||||
|
||||
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
|
||||
except Exception:
|
||||
# Acceleration not available; proceed without it
|
||||
pass
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
p = argparse.ArgumentParser(description="Upload local data to HF, excluding diskann_rpj_wiki")
|
||||
p.add_argument(
|
||||
"--repo-id",
|
||||
default="LEANN-RAG/leann-rag-evaluation-data",
|
||||
help="Target dataset repo id (namespace/name)",
|
||||
)
|
||||
p.add_argument(
|
||||
"--folder-path",
|
||||
default="benchmarks/data",
|
||||
help="Local folder to upload (default: benchmarks/data)",
|
||||
)
|
||||
p.add_argument(
|
||||
"--ignore",
|
||||
default=["diskann_rpj_wiki/**", ".cache/**"],
|
||||
nargs="+",
|
||||
help="Glob patterns to ignore (space-separated)",
|
||||
)
|
||||
p.add_argument(
|
||||
"--allow",
|
||||
default=["**"],
|
||||
nargs="+",
|
||||
help="Glob patterns to allow (space-separated). Defaults to everything.",
|
||||
)
|
||||
p.add_argument(
|
||||
"--message",
|
||||
default="sync local data (exclude diskann_rpj_wiki)",
|
||||
help="Commit message",
|
||||
)
|
||||
p.add_argument(
|
||||
"--no-transfer-accel",
|
||||
action="store_true",
|
||||
help="Disable hf_transfer accelerated uploads",
|
||||
)
|
||||
return p.parse_args()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
|
||||
if not args.no_transfer_accel:
|
||||
_enable_transfer_accel_if_available()
|
||||
|
||||
if not os.path.isdir(args.folder_path):
|
||||
raise SystemExit(f"Folder not found: {args.folder_path}")
|
||||
|
||||
print("Uploading to Hugging Face Hub:")
|
||||
print(f" repo_id: {args.repo_id}")
|
||||
print(" repo_type: dataset")
|
||||
print(f" folder_path: {args.folder_path}")
|
||||
print(f" allow_patterns: {args.allow}")
|
||||
print(f" ignore_patterns:{args.ignore}")
|
||||
|
||||
api = HfApi()
|
||||
|
||||
# Perform upload. This skips unchanged files by content hash.
|
||||
api.upload_folder(
|
||||
repo_id=args.repo_id,
|
||||
repo_type="dataset",
|
||||
folder_path=args.folder_path,
|
||||
path_in_repo=".",
|
||||
allow_patterns=args.allow,
|
||||
ignore_patterns=args.ignore,
|
||||
commit_message=args.message,
|
||||
)
|
||||
|
||||
print("Upload completed (unchanged files were skipped by the Hub).")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -40,8 +40,8 @@ Tests DiskANN graph partitioning functionality:
|
||||
|
||||
### Install test dependencies:
|
||||
```bash
|
||||
# Using uv dependency groups (tools only)
|
||||
uv sync --only-group test
|
||||
# Using extras
|
||||
uv pip install -e ".[test]"
|
||||
```
|
||||
|
||||
### Run all tests:
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
from leann.cli import LeannCLI
|
||||
|
||||
|
||||
def test_cli_ask_accepts_positional_query(tmp_path, monkeypatch):
|
||||
monkeypatch.chdir(tmp_path)
|
||||
|
||||
cli = LeannCLI()
|
||||
parser = cli.create_parser()
|
||||
|
||||
args = parser.parse_args(["ask", "my-docs", "Where are prompts configured?"])
|
||||
|
||||
assert args.command == "ask"
|
||||
assert args.index_name == "my-docs"
|
||||
assert args.query == "Where are prompts configured?"
|
||||
@@ -1,137 +0,0 @@
|
||||
import json
|
||||
import time
|
||||
|
||||
import pytest
|
||||
from leann.embedding_server_manager import EmbeddingServerManager
|
||||
|
||||
|
||||
class DummyProcess:
|
||||
def __init__(self):
|
||||
self.pid = 12345
|
||||
self._terminated = False
|
||||
|
||||
def poll(self):
|
||||
return 0 if self._terminated else None
|
||||
|
||||
def terminate(self):
|
||||
self._terminated = True
|
||||
|
||||
def kill(self):
|
||||
self._terminated = True
|
||||
|
||||
def wait(self, timeout=None):
|
||||
self._terminated = True
|
||||
return 0
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def embedding_manager(monkeypatch):
|
||||
manager = EmbeddingServerManager("leann_backend_hnsw.hnsw_embedding_server")
|
||||
|
||||
def fake_get_available_port(start_port):
|
||||
return start_port
|
||||
|
||||
monkeypatch.setattr(
|
||||
"leann.embedding_server_manager._get_available_port",
|
||||
fake_get_available_port,
|
||||
)
|
||||
|
||||
start_calls = []
|
||||
|
||||
def fake_start_new_server(self, port, model_name, embedding_mode, **kwargs):
|
||||
config_signature = kwargs.get("config_signature")
|
||||
start_calls.append(config_signature)
|
||||
self.server_process = DummyProcess()
|
||||
self.server_port = port
|
||||
self._server_config = config_signature
|
||||
return True, port
|
||||
|
||||
monkeypatch.setattr(
|
||||
EmbeddingServerManager,
|
||||
"_start_new_server",
|
||||
fake_start_new_server,
|
||||
)
|
||||
|
||||
# Ensure stop_server doesn't try to operate on real subprocesses
|
||||
def fake_stop_server(self):
|
||||
self.server_process = None
|
||||
self.server_port = None
|
||||
self._server_config = None
|
||||
|
||||
monkeypatch.setattr(EmbeddingServerManager, "stop_server", fake_stop_server)
|
||||
|
||||
return manager, start_calls
|
||||
|
||||
|
||||
def _write_meta(meta_path, passages_name, index_name, total):
|
||||
meta_path.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"backend_name": "hnsw",
|
||||
"embedding_model": "test-model",
|
||||
"embedding_mode": "sentence-transformers",
|
||||
"dimensions": 3,
|
||||
"backend_kwargs": {},
|
||||
"passage_sources": [
|
||||
{
|
||||
"type": "jsonl",
|
||||
"path": passages_name,
|
||||
"index_path": index_name,
|
||||
}
|
||||
],
|
||||
"total_passages": total,
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
|
||||
def test_server_restarts_when_metadata_changes(tmp_path, embedding_manager):
|
||||
manager, start_calls = embedding_manager
|
||||
|
||||
meta_path = tmp_path / "example.meta.json"
|
||||
passages_path = tmp_path / "example.passages.jsonl"
|
||||
index_path = tmp_path / "example.passages.idx"
|
||||
|
||||
passages_path.write_text("first\n", encoding="utf-8")
|
||||
index_path.write_bytes(b"index")
|
||||
_write_meta(meta_path, passages_path.name, index_path.name, total=1)
|
||||
|
||||
# Initial start populates signature
|
||||
ok, port = manager.start_server(
|
||||
port=6000,
|
||||
model_name="test-model",
|
||||
passages_file=str(meta_path),
|
||||
)
|
||||
assert ok
|
||||
assert port == 6000
|
||||
assert len(start_calls) == 1
|
||||
|
||||
initial_signature = start_calls[0]["passages_signature"]
|
||||
|
||||
# No metadata change => reuse existing server
|
||||
ok, port_again = manager.start_server(
|
||||
port=6000,
|
||||
model_name="test-model",
|
||||
passages_file=str(meta_path),
|
||||
)
|
||||
assert ok
|
||||
assert port_again == 6000
|
||||
assert len(start_calls) == 1
|
||||
|
||||
# Modify passage data and metadata to force signature change
|
||||
time.sleep(0.01) # Ensure filesystem timestamps move forward
|
||||
passages_path.write_text("second\n", encoding="utf-8")
|
||||
_write_meta(meta_path, passages_path.name, index_path.name, total=2)
|
||||
|
||||
ok, port_third = manager.start_server(
|
||||
port=6000,
|
||||
model_name="test-model",
|
||||
passages_file=str(meta_path),
|
||||
)
|
||||
assert ok
|
||||
assert port_third == 6000
|
||||
assert len(start_calls) == 2
|
||||
|
||||
updated_signature = start_calls[1]["passages_signature"]
|
||||
assert updated_signature != initial_signature
|
||||
Reference in New Issue
Block a user