Compare commits

28 Commits

Author SHA1 Message Date
Andy Lee
8e6aa34afd Exclude macos-15-intel + Python 3.13 (no PyTorch wheels available) 2025-12-25 01:22:32 +00:00
Andy Lee
5791367d13 Fix macos-15-intel deployment target
The macos-15-intel runner runs macOS 15.7, so Homebrew libraries are
built for macOS 14+. Setting MACOSX_DEPLOYMENT_TARGET=13.0 causes
delocate to fail because system libraries require newer macOS.

Fix by setting deployment target to 15.0 for macos-15-intel, matching
the actual OS version. Intel Mac users will need macOS 15+.
2025-12-24 05:35:36 +00:00
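The runner-to-target mapping that this commit settles on can be sketched in Python as follows. This is only an illustration of the selection logic; the workflow itself does this in shell, and the names `_TARGETS` and `deployment_target` are made up for the sketch.

```python
from fnmatch import fnmatchcase
from typing import Optional

# Runner-label patterns and the MACOSX_DEPLOYMENT_TARGET chosen for each,
# listed in the order the workflow checks them.
_TARGETS = [
    ("macos-15-intel", "15.0"),
    ("macos-14*", "14.0"),
    ("macos-15*", "15.0"),
    ("macos-26*", "26.0"),
]


def deployment_target(runner: str) -> Optional[str]:
    """Return the deployment target for a runner label, or None if nothing matches."""
    for pattern, target in _TARGETS:
        if fnmatchcase(runner, pattern):
            return target
    return None
```

A label that matches no pattern (for example a Linux runner) yields `None`, which corresponds to the shell conditionals leaving `MACOSX_DEPLOYMENT_TARGET` unset.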
Andy Lee
674977a950 Add macOS 26 (beta) to build matrix
Add macos-26 (arm64) runner to the build matrix for testing future
macOS compatibility. This is currently a beta runner that helps ensure
wheels work on upcoming macOS versions.
2025-12-24 01:48:09 +00:00
Andy Lee
56785d30ee Add macos-15-intel for Intel Mac builds (free runner)
Use macos-15-intel (free standard runner) instead of macos-15-large
(paid). This provides Intel Mac wheel support until Aug 2027.

- MACOSX_DEPLOYMENT_TARGET=13.0 for backward compatibility
- Replaces deprecated macos-13 runner
2025-12-24 01:44:12 +00:00
Andy Lee
a73640f95e Remove Intel Mac builds (macos-15-large requires paid plan)
Intel Mac users can build from source. This avoids:
- Paid GitHub Actions runners (macos-15-large)
- Complex cross-compilation setup
2025-12-24 01:06:07 +00:00
Andy Lee
47b91f7313 Set MACOSX_DEPLOYMENT_TARGET=13.x for Intel builds
Intel Mac wheels (macos-15-large) now target macOS 13.0/13.3 for
backward compatibility, allowing macOS 13/14/15 Intel users to
install pre-built wheels.
2025-12-24 01:03:15 +00:00
Andy Lee
7601e0b112 Add macos-15-large for Intel Mac builds
Replace deprecated macos-13 with macos-15-large (x86_64 Intel)
to continue supporting Intel Mac users.
2025-12-24 01:01:08 +00:00
Andy Lee
2a22ec1b26 Remove macos-13 from CI build matrix
GitHub Actions deprecated macos-13 runner (brownout started Sept 2025,
fully retired Dec 2025). See: https://github.blog/changelog/2025-09-19-github-actions-macos-13-runner-image-is-closing-down/
2025-12-24 00:59:15 +00:00
Andy Lee
530507d39d Drop Python 3.9 support, require Python 3.10+
Python 3.9 reached end-of-life and the codebase uses PEP 604 union
type syntax (str | None) which requires Python 3.10+.

Changes:
- Remove Python 3.9 from CI build matrix
- Update requires-python to >=3.10 in all pyproject.toml files
- Update classifiers to reflect supported Python versions (3.10-3.13)
2025-12-24 00:54:44 +00:00
Andy Lee
8a2ea37871 Fix: handle dict format from create_text_chunks (introduced in PR #157)
PR #157 changed create_text_chunks() to return list[dict] instead of
list[str] to preserve metadata, but base_rag_example.py was not updated
to handle the new format. This caused all chunks to fail validation
with "All provided chunks are empty or invalid".
2025-12-23 08:50:31 +00:00
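A minimal sketch of the kind of normalization this fix needs, assuming chunks are either plain strings or dicts carrying `text` and `metadata` keys (the keys used in the diff for this change). The helper name `split_chunk` and the `Chunk` alias are illustrative and do not appear in the repository.

```python
from typing import Any, Optional, Union

Chunk = Union[str, dict[str, Any]]


def split_chunk(item: Chunk) -> tuple[str, Optional[dict[str, Any]]]:
    """Return (text, metadata) for a plain string or a {'text', 'metadata'} dict."""
    if isinstance(item, dict):
        # Missing keys fall back to empty text and no metadata.
        return item.get("text", ""), item.get("metadata")
    return item, None
```

Plain strings pass through with `None` metadata, so callers written for the old `list[str]` format keep working.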
Yichuan Wang
7ddb4772c0 Feature/custom folder multi vector/ add Readme to LEANN MCP (#189)
* Add custom folder support and improve image loading for multi-vector retrieval

- Enhanced _load_images_from_dir with recursive search support and better error handling
- Added support for WebP format and RGB conversion for all image modes
- Added custom folder CLI arguments (--custom-folder, --recursive, --rebuild-index)
- Improved documentation and removed completed TODO comment

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

* Format code style in leann_multi_vector.py for better readability

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

* docs: polish README performance tip section

- Fix typo: 'matrilize' -> 'materialize'
- Improve clarity and formatting of --no-recompute flag explanation
- Add code block for better readability

* format

---------

Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
2025-12-19 17:29:14 -08:00
aakash
a1c21adbce Move COLQWEN_GUIDE.md to docs and remove test_colqwen_reproduction.py 2025-12-19 13:57:47 -08:00
Aakash Suresh
d1b3c93a5a Merge pull request #162 from yichuan-w/feature/colqwen-integration
add ColQwen multimodal PDF retrieval integration
2025-12-19 13:53:29 -08:00
Yichuan Wang
a6ee95b18a Add custom folder support and improve image loading for multi-vector … (#188)
* Add custom folder support and improve image loading for multi-vector retrieval

- Enhanced _load_images_from_dir with recursive search support and better error handling
- Added support for WebP format and RGB conversion for all image modes
- Added custom folder CLI arguments (--custom-folder, --recursive, --rebuild-index)
- Improved documentation and removed completed TODO comment

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

* Format code style in leann_multi_vector.py for better readability

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

---------

Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
2025-12-17 01:03:45 -08:00
Alex
17cbd07b25 Add Anthropic LLM support (#185)
* Add Anthropic LLM support

Signed-off-by: droctothorpe <mythicalsunlight@gmail.com>

* Update skypilot link

Signed-off-by: droctothorpe <mythicalsunlight@gmail.com>

* Handle anthropic base_url

Signed-off-by: droctothorpe <mythicalsunlight@gmail.com>

* Address ruff format finding

Signed-off-by: droctothorpe <mythicalsunlight@gmail.com>

---------

Signed-off-by: droctothorpe <mythicalsunlight@gmail.com>
2025-12-12 10:53:41 -08:00
Alex
3629ccf8f7 Use logger instead of print (#186)
Signed-off-by: droctothorpe <mythicalsunlight@gmail.com>
2025-12-10 13:48:57 -08:00
aakash
0175bc9c20 docs: Add ColQwen guide to docs directory
Add COLQWEN_GUIDE.md to docs/ directory for proper documentation structure.
This file is referenced in the README and needs to be tracked in git.
2025-12-07 09:57:14 -08:00
aakash
af47dfdde7 fix: Update ColQwen guide link to docs/ directory 2025-12-06 03:33:02 -08:00
aakash
f13bd02fbd docs: Add ColQwen multimodal PDF retrieval to README
Add brief introduction and usage guide for ColQwen integration,
similar to other RAG application sections in the README.

- Quick start examples for building, searching, and interactive Q&A
- Setup instructions with prerequisites
- Model options (ColQwen2 vs ColPali)
- Link to detailed ColQwen guide
2025-12-06 03:28:08 -08:00
Yichuan Wang
a0bbf831db Add ColQwen2.5 model support and improve model selection (#183)
- Add ColQwen2.5 and ColQwen2_5_Processor imports
- Implement smart model type detection for colqwen2, colqwen2.5, and colpali
- Add task name aliases for easier benchmark invocation
- Add safe model name handling for file paths and index naming
- Support custom model paths including LoRA adapters
- Improve model choice validation and error handling

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude <noreply@anthropic.com>
2025-12-05 03:36:55 -08:00
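The commit message does not show how detection or name sanitizing is done, so the following is only a guess at the general shape: substring checks on the model name, with the longer `colqwen2.5` spelling tested before `colqwen2`, and a regex that replaces path-unfriendly characters. Both function names and the exact rules are assumptions.

```python
import re


def detect_model_type(model_name: str) -> str:
    """Guess the model family from a name or path (hypothetical heuristic)."""
    name = model_name.lower()
    # Test the longer spellings first: "colqwen2" is a prefix of "colqwen2.5".
    if "colqwen2.5" in name or "colqwen2_5" in name:
        return "colqwen2.5"
    if "colqwen2" in name:
        return "colqwen2"
    if "colpali" in name:
        return "colpali"
    raise ValueError(f"Cannot infer model type from {model_name!r}")


def safe_model_name(model_name: str) -> str:
    """Replace characters that are awkward in file names with underscores."""
    return re.sub(r"[^A-Za-z0-9._-]+", "_", model_name).strip("_")
```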
aakash
86287d8832 Revert unnecessary faiss submodule update
Reset faiss submodule to match main branch to avoid unnecessary changes
2025-12-03 18:32:04 -08:00
Yichuan Wang
76cc798e3e Feat/multi vector timing and dataset improvements (#181)
* Add timing instrumentation and multi-dataset support for multi-vector retrieval

- Add timing measurements for search operations (load and core time)
- Increase embedding batch size from 1 to 32 for better performance
- Add explicit memory cleanup with del all_embeddings
- Support loading and merging multiple datasets with different splits
- Add CLI arguments for search method selection (ann/exact/exact-all)
- Auto-detect image field names across different dataset structures
- Print candidate doc counts for performance monitoring

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* update vidore

* reproduce docvqa results

* reproduce docvqa results and add debug file

* fix: format colqwen_forward.py to pass pre-commit checks

---------

Co-authored-by: Claude <noreply@anthropic.com>
2025-12-03 01:10:49 -08:00
Yichuan Wang
d599566fd7 Revert "[Multi-vector]Add timing instrumentation and multi-dataset support fo…" (#180)
This reverts commit 00770aebbb.
2025-12-03 01:09:39 -08:00
Yichuan Wang
00770aebbb [Multi-vector]Add timing instrumentation and multi-dataset support for multi-vector… (#161)
* Add timing instrumentation and multi-dataset support for multi-vector retrieval

- Add timing measurements for search operations (load and core time)
- Increase embedding batch size from 1 to 32 for better performance
- Add explicit memory cleanup with del all_embeddings
- Support loading and merging multiple datasets with different splits
- Add CLI arguments for search method selection (ann/exact/exact-all)
- Auto-detect image field names across different dataset structures
- Print candidate doc counts for performance monitoring

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* update vidore

* reproduce docvqa results

* reproduce docvqa results and add debug file

---------

Co-authored-by: Claude <noreply@anthropic.com>
2025-12-03 00:55:42 -08:00
Aakash Suresh
e268392d5b Fix: Prevent duplicate PDF processing when using --file-types .pdf (#179)
Fixes #175

Problem:
When --file-types .pdf is specified, PDFs were being processed twice:
1. Separately with PyMuPDF/pdfplumber extractors
2. Again in the 'other file types' section via SimpleDirectoryReader

This caused duplicate processing and potential conflicts.

Solution:
- Exclude .pdf from other_file_extensions when PDFs are already
  processed separately
- Only load other file types if there are extensions to process
- Prevents duplicate PDF processing

Changes:
- Added logic to filter out .pdf from code_extensions when loading
  other file types if PDFs were processed separately
- Updated SimpleDirectoryReader to use filtered extensions
- Added check to skip loading if no other extensions to process
2025-12-01 13:48:44 -08:00
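The filtering step described above can be sketched roughly like this. The function and parameter names are hypothetical; the commit message names `other_file_extensions` and `code_extensions`, but their exact shapes are not shown here.

```python
def other_extensions(requested: list[str], pdfs_handled_separately: bool) -> list[str]:
    """Extensions left for the generic loader once PDFs have had their own pass."""
    if not pdfs_handled_separately:
        return list(requested)
    # Drop ".pdf" (case-insensitively) so PDFs are not loaded a second time.
    return [ext for ext in requested if ext.lower() != ".pdf"]
```

An empty result corresponds to the "skip loading if no other extensions to process" check in the commit.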
aakash
13beb98164 Add CLIP-based image RAG application
- Add apps/image_rag.py for indexing and searching images using CLIP embeddings
- Supports text-based image search queries
- Uses CLIP ViT-L/14 model via sentence-transformers
- Follows the same pattern as other RAG apps in the apps directory
- Addresses feature request for CLIP support in apps (issue #94)
2025-11-17 13:52:44 -08:00
aakash
9b7353f336 Fix linting errors in colqwen_rag.py and test_colqwen_reproduction.py
- Add noqa comments for E402 errors (imports after sys.path modifications)
- Remove unused variable assignment in colqwen_rag.py
- Use importlib.util.find_spec for dependency checks instead of unused imports
- Fix import ordering in test_colqwen_reproduction.py
2025-11-11 05:12:49 -08:00
aakash
9dd0e0b26f feat: Add ColQwen multimodal PDF retrieval integration
- Add ColQwenRAG class with easy-to-use CLI for multimodal PDF retrieval
- Support for both ColQwen2 and ColPali models with automatic device selection
- MPS optimization for Apple Silicon with memory-efficient loading
- Complete pipeline: PDF→images→embeddings→HNSW index→search
- Multi-vector indexing for fine-grained document matching
- Comprehensive user guide and reproduction test script
- Resolves #119: ColQwen Doc and Support Management

Features:
- python -m apps.colqwen_rag build --pdfs ./pdfs/ --index my_index
- python -m apps.colqwen_rag search my_index "query text"
- python -m apps.colqwen_rag ask my_index --interactive
- Automatic CPU fallback for memory constraints
- Robust error handling and progress tracking
2025-11-10 13:31:58 -08:00
23 changed files with 3400 additions and 1238 deletions

View File

@@ -35,8 +35,8 @@ jobs:
     strategy:
       matrix:
         include:
-          - os: ubuntu-22.04
-            python: '3.9'
+          # Note: Python 3.9 dropped - uses PEP 604 union syntax (str | None)
+          # which requires Python 3.10+
           - os: ubuntu-22.04
             python: '3.10'
           - os: ubuntu-22.04
@@ -46,8 +46,6 @@ jobs:
           - os: ubuntu-22.04
             python: '3.13'
           # ARM64 Linux builds
-          - os: ubuntu-24.04-arm
-            python: '3.9'
           - os: ubuntu-24.04-arm
             python: '3.10'
           - os: ubuntu-24.04-arm
@@ -56,8 +54,6 @@ jobs:
             python: '3.12'
           - os: ubuntu-24.04-arm
             python: '3.13'
-          - os: macos-14
-            python: '3.9'
           - os: macos-14
             python: '3.10'
           - os: macos-14
@@ -66,8 +62,6 @@ jobs:
             python: '3.12'
           - os: macos-14
             python: '3.13'
-          - os: macos-15
-            python: '3.9'
           - os: macos-15
             python: '3.10'
           - os: macos-15
@@ -76,16 +70,24 @@ jobs:
             python: '3.12'
           - os: macos-15
             python: '3.13'
-          - os: macos-13
-            python: '3.9'
-          - os: macos-13
+          # Intel Mac builds (x86_64) - replaces deprecated macos-13
+          # Note: Python 3.13 excluded - PyTorch has no wheels for macOS x86_64 + Python 3.13
+          # (PyTorch <=2.4.1 lacks cp313, PyTorch >=2.5.0 dropped Intel Mac support)
+          - os: macos-15-intel
             python: '3.10'
-          - os: macos-13
+          - os: macos-15-intel
             python: '3.11'
-          - os: macos-13
+          - os: macos-15-intel
             python: '3.12'
-          # Note: macos-13 + Python 3.13 excluded due to PyTorch compatibility
-          # (PyTorch 2.5+ supports Python 3.13 but not Intel Mac x86_64)
+          # macOS 26 (beta) - arm64
+          - os: macos-26
+            python: '3.10'
+          - os: macos-26
+            python: '3.11'
+          - os: macos-26
+            python: '3.12'
+          - os: macos-26
+            python: '3.13'
     runs-on: ${{ matrix.os }}
     steps:
@@ -204,13 +206,16 @@ jobs:
   # Use system clang for better compatibility
   export CC=clang
   export CXX=clang++
-  # Homebrew libraries on each macOS version require matching minimum version
-  if [[ "${{ matrix.os }}" == "macos-13" ]]; then
-    export MACOSX_DEPLOYMENT_TARGET=13.0
-  elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
-    export MACOSX_DEPLOYMENT_TARGET=14.0
-  elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
+  # Set deployment target based on runner
+  # macos-15-intel runs macOS 15, so target 15.0 (system libraries require it)
+  if [[ "${{ matrix.os }}" == "macos-15-intel" ]]; then
     export MACOSX_DEPLOYMENT_TARGET=15.0
+  elif [[ "${{ matrix.os }}" == macos-14* ]]; then
+    export MACOSX_DEPLOYMENT_TARGET=14.0
+  elif [[ "${{ matrix.os }}" == macos-15* ]]; then
+    export MACOSX_DEPLOYMENT_TARGET=15.0
+  elif [[ "${{ matrix.os }}" == macos-26* ]]; then
+    export MACOSX_DEPLOYMENT_TARGET=26.0
   fi
   uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
 else
@@ -224,14 +229,16 @@ jobs:
   # Use system clang for better compatibility
   export CC=clang
   export CXX=clang++
-  # DiskANN requires macOS 13.3+ for sgesdd_ LAPACK function
-  # But Homebrew libraries on each macOS version require matching minimum version
-  if [[ "${{ matrix.os }}" == "macos-13" ]]; then
-    export MACOSX_DEPLOYMENT_TARGET=13.3
-  elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
-    export MACOSX_DEPLOYMENT_TARGET=14.0
-  elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
+  # Set deployment target based on runner
+  # macos-15-intel runs macOS 15, so target 15.0 (system libraries require it)
+  if [[ "${{ matrix.os }}" == "macos-15-intel" ]]; then
     export MACOSX_DEPLOYMENT_TARGET=15.0
+  elif [[ "${{ matrix.os }}" == macos-14* ]]; then
+    export MACOSX_DEPLOYMENT_TARGET=14.0
+  elif [[ "${{ matrix.os }}" == macos-15* ]]; then
+    export MACOSX_DEPLOYMENT_TARGET=15.0
+  elif [[ "${{ matrix.os }}" == macos-26* ]]; then
+    export MACOSX_DEPLOYMENT_TARGET=26.0
   fi
   uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
 else
@@ -269,16 +276,19 @@ jobs:
 if: runner.os == 'macOS'
 run: |
   # Determine deployment target based on runner OS
-  # Must match the Homebrew libraries for each macOS version
-  if [[ "${{ matrix.os }}" == "macos-13" ]]; then
-    HNSW_TARGET="13.0"
-    DISKANN_TARGET="13.3"
-  elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
-    HNSW_TARGET="14.0"
-    DISKANN_TARGET="14.0"
-  elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
+  # macos-15-intel runs macOS 15, so target 15.0 (system libraries require it)
+  if [[ "${{ matrix.os }}" == "macos-15-intel" ]]; then
     HNSW_TARGET="15.0"
     DISKANN_TARGET="15.0"
+  elif [[ "${{ matrix.os }}" == macos-14* ]]; then
+    HNSW_TARGET="14.0"
+    DISKANN_TARGET="14.0"
+  elif [[ "${{ matrix.os }}" == macos-15* ]]; then
+    HNSW_TARGET="15.0"
+    DISKANN_TARGET="15.0"
+  elif [[ "${{ matrix.os }}" == macos-26* ]]; then
+    HNSW_TARGET="26.0"
+    DISKANN_TARGET="26.0"
   fi
   # Repair HNSW wheel
@@ -334,12 +344,15 @@ jobs:
 PY_TAG=$($UV_PY -c "import sys; print(f'cp{sys.version_info[0]}{sys.version_info[1]}')")
 if [[ "$RUNNER_OS" == "macOS" ]]; then
-  if [[ "${{ matrix.os }}" == "macos-13" ]]; then
-    export MACOSX_DEPLOYMENT_TARGET=13.3
-  elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
-    export MACOSX_DEPLOYMENT_TARGET=14.0
-  elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
+  # macos-15-intel runs macOS 15, so target 15.0 (system libraries require it)
+  if [[ "${{ matrix.os }}" == "macos-15-intel" ]]; then
     export MACOSX_DEPLOYMENT_TARGET=15.0
+  elif [[ "${{ matrix.os }}" == macos-14* ]]; then
+    export MACOSX_DEPLOYMENT_TARGET=14.0
+  elif [[ "${{ matrix.os }}" == macos-15* ]]; then
+    export MACOSX_DEPLOYMENT_TARGET=15.0
+  elif [[ "${{ matrix.os }}" == macos-26* ]]; then
+    export MACOSX_DEPLOYMENT_TARGET=26.0
   fi
 fi

.gitignore

@@ -91,7 +91,8 @@ packages/leann-backend-diskann/third_party/DiskANN/_deps/
*.meta.json *.meta.json
*.passages.json *.passages.json
*.npy
*.db
batchtest.py batchtest.py
tests/__pytest_cache__/ tests/__pytest_cache__/
tests/__pycache__/ tests/__pycache__/

View File

@@ -36,7 +36,7 @@ LEANN is an innovative vector database that democratizes personal AI. Transform
 LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration Fig →](#-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
-**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can semantic search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)** ([WeChat](#-wechat-detective-unlock-your-golden-memories), [iMessage](#-imessage-history-your-personal-conversation-archive)), **[agent memory](#-chatgpt-chat-history-your-personal-ai-conversation-archive)** ([ChatGPT](#-chatgpt-chat-history-your-personal-ai-conversation-archive), [Claude](#-claude-chat-history-your-personal-ai-conversation-archive)), **[live data](#mcp-integration-rag-on-live-data-from-any-platform)** ([Slack](#mcp-integration-rag-on-live-data-from-any-platform), [Twitter](#mcp-integration-rag-on-live-data-from-any-platform)), **[codebase](#-claude-code-integration-transform-your-development-workflow)**\* , or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
+**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can semantic search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)** ([WeChat](#-wechat-detective-unlock-your-golden-memories), [iMessage](#-imessage-history-your-personal-conversation-archive)), **[agent memory](#-chatgpt-chat-history-your-personal-ai-conversation-archive)** ([ChatGPT](#-chatgpt-chat-history-your-personal-ai-conversation-archive), [Claude](#-claude-chat-history-your-personal-ai-conversation-archive)), **[live data](#mcp-integration-rag-on-live-data-from-any-platform)** ([Slack](#slack-messages-search-your-team-conversations), [Twitter](#-twitter-bookmarks-your-personal-tweet-library)), **[codebase](#-claude-code-integration-transform-your-development-workflow)**\* , or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
 \* Claude Code only supports basic `grep`-style keyword search. **LEANN** is a drop-in **semantic search MCP service fully compatible with Claude Code**, unlocking intelligent retrieval without changing your workflow. 🔥 Check out [the easy setup →](packages/leann-mcp/README.md)
@@ -201,7 +201,7 @@ LEANN supports RAG on various data sources including documents (`.pdf`, `.txt`,
 #### LLM Backend
-LEANN supports many LLM providers for text generation (HuggingFace, Ollama, and Any OpenAI compatible API).
+LEANN supports many LLM providers for text generation (HuggingFace, Ollama, Anthropic, and Any OpenAI compatible API).
 <details>
@@ -269,6 +269,7 @@ Below is a list of base URLs for common providers to get you started.
 | **SiliconFlow** | `https://api.siliconflow.cn/v1` |
 | **Zhipu (BigModel)** | `https://open.bigmodel.cn/api/paas/v4/` |
 | **Mistral AI** | `https://api.mistral.ai/v1` |
+| **Anthropic** | `https://api.anthropic.com/v1` |
@@ -328,7 +329,7 @@ All RAG examples share these common parameters. **Interactive mode** is availabl
 --embedding-mode MODE # sentence-transformers, openai, mlx, or ollama
 # LLM Parameters (Text generation models)
---llm TYPE # LLM backend: openai, ollama, or hf (default: openai)
+--llm TYPE # LLM backend: openai, ollama, hf, or anthropic (default: openai)
 --llm-model MODEL # Model name (default: gpt-4o) e.g., gpt-4o-mini, llama3.2:1b, Qwen/Qwen2.5-1.5B-Instruct
 --thinking-budget LEVEL # Thinking budget for reasoning models: low/medium/high (supported by o3, o3-mini, GPT-Oss:20b, and other reasoning models)
@@ -391,6 +392,54 @@ python -m apps.code_rag --repo-dir "./my_codebase" --query "How does authenticat
</details> </details>
### 🎨 ColQwen: Multimodal PDF Retrieval with Vision-Language Models
Search through PDFs using both text and visual understanding with ColQwen2/ColPali models. Perfect for research papers, technical documents, and any PDFs with complex layouts, figures, or diagrams.
> **🍎 Mac Users**: ColQwen is optimized for Apple Silicon with MPS acceleration for faster inference!
```bash
# Build index from PDFs
python -m apps.colqwen_rag build --pdfs ./my_papers/ --index research_papers
# Search with text queries
python -m apps.colqwen_rag search research_papers "How does attention mechanism work?"
# Interactive Q&A
python -m apps.colqwen_rag ask research_papers --interactive
```
<details>
<summary><strong>📋 Click to expand: ColQwen Setup & Usage</strong></summary>
#### Prerequisites
```bash
# Install dependencies
uv pip install colpali_engine pdf2image pillow matplotlib qwen_vl_utils einops seaborn
brew install poppler # macOS only, for PDF processing
```
#### Build Index
```bash
python -m apps.colqwen_rag build \
--pdfs ./pdf_directory/ \
--index my_index \
--model colqwen2 # or colpali
```
#### Search
```bash
python -m apps.colqwen_rag search my_index "your question here" --top-k 5
```
#### Models
- **ColQwen2** (`colqwen2`): Latest vision-language model with improved performance
- **ColPali** (`colpali`): Proven multimodal retriever
For detailed usage, see the [ColQwen Guide](docs/COLQWEN_GUIDE.md).
</details>
### 📧 Your Personal Email Secretary: RAG on Apple Mail! ### 📧 Your Personal Email Secretary: RAG on Apple Mail!
> **Note:** The examples below currently support macOS only. Windows support coming soon. > **Note:** The examples below currently support macOS only. Windows support coming soon.
@@ -1057,7 +1106,7 @@ Options:
 leann ask INDEX_NAME [OPTIONS]
 Options:
-  --llm {ollama,openai,hf} LLM provider (default: ollama)
+  --llm {ollama,openai,hf,anthropic} LLM provider (default: ollama)
   --model MODEL Model name (default: qwen3:8b)
   --interactive Interactive chat mode
   --top-k N Retrieval count (default: 20)

View File

@@ -6,7 +6,7 @@ Provides common parameters and functionality for all RAG examples.
 import argparse
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Any
+from typing import Any, Union
 import dotenv
 from leann.api import LeannBuilder, LeannChat
@@ -257,8 +257,8 @@ class BaseRAGExample(ABC):
         pass
     @abstractmethod
-    async def load_data(self, args) -> list[str]:
-        """Load data from the source. Returns list of text chunks."""
+    async def load_data(self, args) -> list[Union[str, dict[str, Any]]]:
+        """Load data from the source. Returns list of text chunks (strings or dicts with 'text' key)."""
         pass
     def get_llm_config(self, args) -> dict[str, Any]:
@@ -282,8 +282,8 @@ class BaseRAGExample(ABC):
         return config
-    async def build_index(self, args, texts: list[str]) -> str:
-        """Build LEANN index from texts."""
+    async def build_index(self, args, texts: list[Union[str, dict[str, Any]]]) -> str:
+        """Build LEANN index from texts (accepts strings or dicts with 'text' key)."""
         index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")
         print(f"\n[Building Index] Creating {self.name} index...")
@@ -314,8 +314,14 @@ class BaseRAGExample(ABC):
         batch_size = 1000
         for i in range(0, len(texts), batch_size):
             batch = texts[i : i + batch_size]
-            for text in batch:
-                builder.add_text(text)
+            for item in batch:
+                # Handle both dict format (from create_text_chunks) and plain strings
+                if isinstance(item, dict):
+                    text = item.get("text", "")
+                    metadata = item.get("metadata")
+                    builder.add_text(text, metadata)
+                else:
+                    builder.add_text(item)
             print(f"Added {min(i + batch_size, len(texts))}/{len(texts)} texts...")
         print("Building index structure...")

apps/colqwen_rag.py (new file)

@@ -0,0 +1,364 @@
+#!/usr/bin/env python3
+"""
+ColQwen RAG - Easy-to-use multimodal PDF retrieval with ColQwen2/ColPali
+Usage:
+    python -m apps.colqwen_rag build --pdfs ./my_pdfs/ --index my_index
+    python -m apps.colqwen_rag search my_index "How does attention work?"
+    python -m apps.colqwen_rag ask my_index --interactive
+"""
+import argparse
+import os
+import sys
+from pathlib import Path
+from typing import Optional, cast
+# Add LEANN packages to path
+_repo_root = Path(__file__).resolve().parents[1]
+_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
+_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
+if str(_leann_core_src) not in sys.path:
+    sys.path.append(str(_leann_core_src))
+if str(_leann_hnsw_pkg) not in sys.path:
+    sys.path.append(str(_leann_hnsw_pkg))
+import torch  # noqa: E402
+from colpali_engine import ColPali, ColPaliProcessor, ColQwen2, ColQwen2Processor  # noqa: E402
+from colpali_engine.utils.torch_utils import ListDataset  # noqa: E402
+from pdf2image import convert_from_path  # noqa: E402
+from PIL import Image  # noqa: E402
+from torch.utils.data import DataLoader  # noqa: E402
+from tqdm import tqdm  # noqa: E402
+# Import the existing multi-vector implementation
+sys.path.append(str(_repo_root / "apps" / "multimodal" / "vision-based-pdf-multi-vector"))
+from leann_multi_vector import LeannMultiVector  # noqa: E402
+class ColQwenRAG:
+    """Easy-to-use ColQwen RAG system for multimodal PDF retrieval."""
+    def __init__(self, model_type: str = "colpali"):
+        """
+        Initialize ColQwen RAG system.
+        Args:
+            model_type: "colqwen2" or "colpali"
+        """
+        self.model_type = model_type
+        self.device = self._get_device()
+        # Use float32 on MPS to avoid memory issues, float16 on CUDA, bfloat16 on CPU
+        if self.device.type == "mps":
+            self.dtype = torch.float32
+        elif self.device.type == "cuda":
+            self.dtype = torch.float16
+        else:
+            self.dtype = torch.bfloat16
+        print(f"🚀 Initializing {model_type.upper()} on {self.device} with {self.dtype}")
+        # Load model and processor with MPS-optimized settings
+        try:
+            if model_type == "colqwen2":
+                self.model_name = "vidore/colqwen2-v1.0"
+                if self.device.type == "mps":
+                    # For MPS, load on CPU first then move to avoid memory allocation issues
+                    self.model = ColQwen2.from_pretrained(
+                        self.model_name,
+                        torch_dtype=self.dtype,
+                        device_map="cpu",
+                        low_cpu_mem_usage=True,
+                    ).eval()
+                    self.model = self.model.to(self.device)
+                else:
+                    self.model = ColQwen2.from_pretrained(
+                        self.model_name,
+                        torch_dtype=self.dtype,
+                        device_map=self.device,
+                        low_cpu_mem_usage=True,
+                    ).eval()
+                self.processor = ColQwen2Processor.from_pretrained(self.model_name)
+            else:  # colpali
+                self.model_name = "vidore/colpali-v1.2"
+                if self.device.type == "mps":
+                    # For MPS, load on CPU first then move to avoid memory allocation issues
+                    self.model = ColPali.from_pretrained(
+                        self.model_name,
+                        torch_dtype=self.dtype,
+                        device_map="cpu",
+                        low_cpu_mem_usage=True,
+                    ).eval()
+                    self.model = self.model.to(self.device)
+                else:
+                    self.model = ColPali.from_pretrained(
+                        self.model_name,
+                        torch_dtype=self.dtype,
+                        device_map=self.device,
+                        low_cpu_mem_usage=True,
+                    ).eval()
+                self.processor = ColPaliProcessor.from_pretrained(self.model_name)
+        except Exception as e:
+            if "memory" in str(e).lower() or "offload" in str(e).lower():
+                print(f"⚠️ Memory constraint on {self.device}, using CPU with optimizations...")
+                self.device = torch.device("cpu")
+                self.dtype = torch.float32
+                if model_type == "colqwen2":
+                    self.model = ColQwen2.from_pretrained(
+                        self.model_name,
+                        torch_dtype=self.dtype,
+                        device_map="cpu",
+                        low_cpu_mem_usage=True,
+                    ).eval()
+                else:
+                    self.model = ColPali.from_pretrained(
+                        self.model_name,
+                        torch_dtype=self.dtype,
+                        device_map="cpu",
+                        low_cpu_mem_usage=True,
+                    ).eval()
+            else:
+                raise
def _get_device(self):
"""Auto-select best available device."""
if torch.cuda.is_available():
return torch.device("cuda")
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
return torch.device("mps")
else:
return torch.device("cpu")
def build_index(self, pdf_paths: list[str], index_name: str, pages_dir: Optional[str] = None):
"""
Build multimodal index from PDF files.
Args:
pdf_paths: List of PDF file paths
index_name: Name for the index
pages_dir: Directory to save page images (optional)
"""
print(f"Building index '{index_name}' from {len(pdf_paths)} PDFs...")
# Convert PDFs to images
all_images = []
all_metadata = []
if pages_dir:
os.makedirs(pages_dir, exist_ok=True)
for pdf_path in tqdm(pdf_paths, desc="Converting PDFs"):
try:
images = convert_from_path(pdf_path, dpi=150)
pdf_name = Path(pdf_path).stem
for i, image in enumerate(images):
# Save image if pages_dir specified
if pages_dir:
image_path = Path(pages_dir) / f"{pdf_name}_page_{i + 1}.png"
image.save(image_path)
all_images.append(image)
all_metadata.append(
{
"pdf_path": pdf_path,
"pdf_name": pdf_name,
"page_number": i + 1,
"image_path": str(image_path) if pages_dir else None,
}
)
except Exception as e:
print(f"❌ Error processing {pdf_path}: {e}")
continue
print(f"📄 Converted {len(all_images)} pages from {len(pdf_paths)} PDFs")
print(f"All metadata: {all_metadata}")
# Generate embeddings
print("🧠 Generating embeddings...")
embeddings = self._embed_images(all_images)
# Build LEANN index
print("🔍 Building LEANN index...")
leann_mv = LeannMultiVector(
index_path=index_name,
dim=embeddings.shape[-1],
embedding_model_name=self.model_type,
)
# Create collection and insert data
leann_mv.create_collection()
for i, (embedding, metadata) in enumerate(zip(embeddings, all_metadata)):
data = {
"doc_id": i,
"filepath": metadata.get("image_path") or "",  # image_path is None when pages_dir is unset
"colbert_vecs": embedding.float().numpy(),  # Cast to float32 first: NumPy has no bfloat16
}
leann_mv.insert(data)
# Build the index
leann_mv.create_index()
print(f"✅ Index '{index_name}' built successfully!")
return leann_mv
def search(self, index_name: str, query: str, top_k: int = 5):
"""
Search the index with a text query.
Args:
index_name: Name of the index to search
query: Text query
top_k: Number of results to return
"""
print(f"🔍 Searching '{index_name}' for: '{query}'")
# Load index
leann_mv = LeannMultiVector(
index_path=index_name,
dim=128, # Will be updated when loading
embedding_model_name=self.model_type,
)
# Generate query embedding
query_embedding = self._embed_query(query)
# Search (returns list of (score, doc_id) tuples)
search_results = leann_mv.search(query_embedding.float().numpy(), topk=top_k)
# Display results
print(f"\n📋 Top {len(search_results)} results:")
for i, (score, doc_id) in enumerate(search_results, 1):
# Get metadata for this doc_id (we need to load the metadata)
print(f"{i}. Score: {score:.3f} | Doc ID: {doc_id}")
return search_results
def ask(self, index_name: str, interactive: bool = False):
"""
Interactive Q&A with the indexed documents.
Args:
index_name: Name of the index to query
interactive: Whether to run in interactive mode
"""
print(f"💬 ColQwen Chat with '{index_name}'")
if interactive:
print("Type 'quit' to exit, 'help' for commands")
while True:
try:
query = input("\n🤔 Your question: ").strip()
if query.lower() in ["quit", "exit", "q"]:
break
elif query.lower() == "help":
print("Commands: quit/exit/q (exit), help (this message)")
continue
elif not query:
continue
self.search(index_name, query, top_k=3)
# TODO: Add answer generation with Qwen-VL
print("\n💡 For detailed answers, we can integrate Qwen-VL here!")
except KeyboardInterrupt:
print("\n👋 Goodbye!")
break
else:
query = input("🤔 Your question: ").strip()
if query:
self.search(index_name, query)
def _embed_images(self, images: list[Image.Image]) -> torch.Tensor:
"""Generate embeddings for a list of images."""
dataset = ListDataset(images)
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=lambda x: x)
embeddings = []
with torch.no_grad():
for batch in tqdm(dataloader, desc="Embedding images"):
batch_images = cast(list, batch)
batch_inputs = self.processor.process_images(batch_images).to(self.device)
batch_embeddings = self.model(**batch_inputs)
embeddings.append(batch_embeddings.cpu())
return torch.cat(embeddings, dim=0)
def _embed_query(self, query: str) -> torch.Tensor:
"""Generate embedding for a text query."""
with torch.no_grad():
query_inputs = self.processor.process_queries([query]).to(self.device)
query_embedding = self.model(**query_inputs)
return query_embedding.cpu()
def main():
parser = argparse.ArgumentParser(description="ColQwen RAG - Easy multimodal PDF retrieval")
subparsers = parser.add_subparsers(dest="command", help="Available commands")
# Build command
build_parser = subparsers.add_parser("build", help="Build index from PDFs")
build_parser.add_argument("--pdfs", required=True, help="Directory containing PDF files")
build_parser.add_argument("--index", required=True, help="Index name")
build_parser.add_argument(
"--model", choices=["colqwen2", "colpali"], default="colqwen2", help="Model to use"
)
build_parser.add_argument("--pages-dir", help="Directory to save page images")
# Search command
search_parser = subparsers.add_parser("search", help="Search the index")
search_parser.add_argument("index", help="Index name")
search_parser.add_argument("query", help="Search query")
search_parser.add_argument("--top-k", type=int, default=5, help="Number of results")
search_parser.add_argument(
"--model", choices=["colqwen2", "colpali"], default="colqwen2", help="Model to use"
)
# Ask command
ask_parser = subparsers.add_parser("ask", help="Interactive Q&A")
ask_parser.add_argument("index", help="Index name")
ask_parser.add_argument("--interactive", action="store_true", help="Interactive mode")
ask_parser.add_argument(
"--model", choices=["colqwen2", "colpali"], default="colqwen2", help="Model to use"
)
args = parser.parse_args()
if not args.command:
parser.print_help()
return
# Initialize ColQwen RAG
if args.command == "build":
colqwen = ColQwenRAG(args.model)
# Get PDF files
pdf_dir = Path(args.pdfs)
if pdf_dir.is_file() and pdf_dir.suffix.lower() == ".pdf":
pdf_paths = [str(pdf_dir)]
elif pdf_dir.is_dir():
pdf_paths = [str(p) for p in pdf_dir.glob("*.pdf")]
else:
print(f"❌ Invalid PDF path: {args.pdfs}")
return
if not pdf_paths:
print(f"❌ No PDF files found in {args.pdfs}")
return
colqwen.build_index(pdf_paths, args.index, args.pages_dir)
elif args.command == "search":
colqwen = ColQwenRAG(args.model)
colqwen.search(args.index, args.query, args.top_k)
elif args.command == "ask":
colqwen = ColQwenRAG(args.model)
colqwen.ask(args.index, args.interactive)
if __name__ == "__main__":
main()
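
The `build` / `search` / `ask` layout above relies on `argparse` subparsers with `dest="command"`. The following is a minimal, self-contained sketch of that pattern, trimmed to two subcommands; the `build_parser` helper name and the sample arguments are illustrative assumptions, not part of the script.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical helper: a trimmed-down version of the CLI defined in main()."""
    parser = argparse.ArgumentParser(description="ColQwen RAG (sketch)")
    # dest="command" stores the chosen subcommand name on the parsed namespace.
    sub = parser.add_subparsers(dest="command", help="Available commands")

    build = sub.add_parser("build", help="Build index from PDFs")
    build.add_argument("--pdfs", required=True)
    build.add_argument("--index", required=True)
    build.add_argument("--model", choices=["colqwen2", "colpali"], default="colqwen2")

    search = sub.add_parser("search", help="Search the index")
    search.add_argument("index")
    search.add_argument("query")
    search.add_argument("--top-k", type=int, default=5)
    return parser


# Parse an explicit argv list instead of sys.argv so the sketch runs anywhere.
args = build_parser().parse_args(["search", "my_index", "revenue table", "--top-k", "3"])
print(args.command, args.index, args.query, args.top_k)
```

Because subparsers are optional by default in Python 3, calling the parser with no arguments leaves `command` as `None`, which is why `main()` checks `if not args.command` before dispatching.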


@@ -5,6 +5,7 @@ Supports PDF, TXT, MD, and other document formats.
import sys import sys
from pathlib import Path from pathlib import Path
from typing import Any, Union
# Add parent directory to path for imports # Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent)) sys.path.insert(0, str(Path(__file__).parent))
@@ -51,7 +52,7 @@ class DocumentRAG(BaseRAGExample):
help="Enable AST-aware chunking for code files in the data directory", help="Enable AST-aware chunking for code files in the data directory",
) )
async def load_data(self, args) -> list[str]: async def load_data(self, args) -> list[Union[str, dict[str, Any]]]:
"""Load documents and convert to text chunks.""" """Load documents and convert to text chunks."""
print(f"Loading documents from: {args.data_dir}") print(f"Loading documents from: {args.data_dir}")
if args.file_types: if args.file_types:

apps/image_rag.py Normal file

@@ -0,0 +1,218 @@
#!/usr/bin/env python3
"""
CLIP Image RAG Application
This application enables RAG (Retrieval-Augmented Generation) on images using CLIP embeddings.
You can index a directory of images and search them using text queries.
Usage:
python -m apps.image_rag --image-dir ./my_images/ --query "a sunset over mountains"
python -m apps.image_rag --image-dir ./my_images/ --interactive
"""
import argparse
import pickle
import tempfile
from pathlib import Path
import numpy as np
from PIL import Image
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
from apps.base_rag_example import BaseRAGExample
class ImageRAG(BaseRAGExample):
"""
RAG application for images using CLIP embeddings.
This class provides a complete RAG pipeline for image data, including
CLIP embedding generation, indexing, and text-based image search.
"""
def __init__(self):
super().__init__(
name="Image RAG",
description="RAG application for images using CLIP embeddings",
default_index_name="image_index",
)
# Override default embedding model to use CLIP
self.embedding_model_default = "clip-ViT-L-14"
self.embedding_mode_default = "sentence-transformers"
self._image_data: list[dict] = []
def _add_specific_arguments(self, parser: argparse.ArgumentParser):
"""Add image-specific arguments."""
image_group = parser.add_argument_group("Image Parameters")
image_group.add_argument(
"--image-dir",
type=str,
required=True,
help="Directory containing images to index",
)
image_group.add_argument(
"--image-extensions",
type=str,
nargs="+",
default=[".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"],
help="Image file extensions to process (default: .jpg .jpeg .png .gif .bmp .webp)",
)
image_group.add_argument(
"--batch-size",
type=int,
default=32,
help="Batch size for CLIP embedding generation (default: 32)",
)
async def load_data(self, args) -> list[str]:
"""Load images, generate CLIP embeddings, and return text descriptions."""
self._image_data = self._load_images_and_embeddings(args)
return [entry["text"] for entry in self._image_data]
def _load_images_and_embeddings(self, args) -> list[dict]:
"""Helper to process images and produce embeddings/metadata."""
image_dir = Path(args.image_dir)
if not image_dir.exists():
raise ValueError(f"Image directory does not exist: {image_dir}")
print(f"📸 Loading images from {image_dir}...")
# Find all image files
image_files = []
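for ext in args.image_extensions:
image_files.extend(image_dir.rglob(f"*{ext}"))
image_files.extend(image_dir.rglob(f"*{ext.upper()}"))
# De-duplicate (case-insensitive filesystems match both patterns) and fix the order
image_files = sorted(set(image_files))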
if not image_files:
raise ValueError(
f"No images found in {image_dir} with extensions {args.image_extensions}"
)
print(f"✅ Found {len(image_files)} images")
# Limit if max_items is set
if args.max_items > 0:
image_files = image_files[: args.max_items]
print(f"📊 Processing {len(image_files)} images (limited by --max-items)")
# Load CLIP model
print("🔍 Loading CLIP model...")
model = SentenceTransformer(self.embedding_model_default)
# Process images and generate embeddings
print("🖼️ Processing images and generating embeddings...")
image_data = []
batch_images = []
batch_paths = []
for image_path in tqdm(image_files, desc="Processing images"):
try:
image = Image.open(image_path).convert("RGB")
batch_images.append(image)
batch_paths.append(image_path)
# Process in batches
if len(batch_images) >= args.batch_size:
embeddings = model.encode(
batch_images,
convert_to_numpy=True,
normalize_embeddings=True,
batch_size=args.batch_size,
show_progress_bar=False,
)
for img_path, embedding in zip(batch_paths, embeddings):
image_data.append(
{
"text": f"Image: {img_path.name}\nPath: {img_path}",
"metadata": {
"image_path": str(img_path),
"image_name": img_path.name,
"image_dir": str(image_dir),
},
"embedding": embedding.astype(np.float32),
}
)
batch_images = []
batch_paths = []
except Exception as e:
print(f"⚠️ Failed to process {image_path}: {e}")
continue
# Process remaining images
if batch_images:
embeddings = model.encode(
batch_images,
convert_to_numpy=True,
normalize_embeddings=True,
batch_size=len(batch_images),
show_progress_bar=False,
)
for img_path, embedding in zip(batch_paths, embeddings):
image_data.append(
{
"text": f"Image: {img_path.name}\nPath: {img_path}",
"metadata": {
"image_path": str(img_path),
"image_name": img_path.name,
"image_dir": str(image_dir),
},
"embedding": embedding.astype(np.float32),
}
)
print(f"✅ Processed {len(image_data)} images")
return image_data
async def build_index(self, args, texts: list[str]) -> str:
"""Build index using pre-computed CLIP embeddings."""
from leann.api import LeannBuilder
if not self._image_data or len(self._image_data) != len(texts):
raise RuntimeError("No image data found. Make sure load_data() ran successfully.")
print("🔨 Building LEANN index with CLIP embeddings...")
builder = LeannBuilder(
backend_name=args.backend_name,
embedding_model=self.embedding_model_default,
embedding_mode=self.embedding_mode_default,
is_recompute=False,
distance_metric="cosine",
graph_degree=args.graph_degree,
build_complexity=args.build_complexity,
is_compact=not args.no_compact,
)
for text, data in zip(texts, self._image_data):
builder.add_text(text=text, metadata=data["metadata"])
ids = [str(i) for i in range(len(self._image_data))]
embeddings = np.array([data["embedding"] for data in self._image_data], dtype=np.float32)
with tempfile.NamedTemporaryFile(mode="wb", suffix=".pkl", delete=False) as f:
pickle.dump((ids, embeddings), f)
pkl_path = f.name
try:
index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")
builder.build_index_from_embeddings(index_path, pkl_path)
print(f"✅ Index built successfully at {index_path}")
return index_path
finally:
Path(pkl_path).unlink()
def main():
"""Main entry point for the image RAG application."""
import asyncio
app = ImageRAG()
asyncio.run(app.run())
if __name__ == "__main__":
main()
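
`_load_images_and_embeddings` repeats the same encode-and-append block twice: once for full batches and once for the remainder. One way to avoid the duplication is to chunk the list up front. The sketch below is an assumption about how that could look; `iter_batches` is a hypothetical helper, and it does not reproduce the per-image `try`/`except` that skips unreadable files before batching.

```python
from collections.abc import Iterator, Sequence
from typing import TypeVar

T = TypeVar("T")


def iter_batches(items: Sequence[T], batch_size: int) -> Iterator[Sequence[T]]:
    """Yield consecutive slices of at most batch_size items (the last may be shorter)."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    for start in range(0, len(items), batch_size):
        yield items[start : start + batch_size]


# Stand-in file names; the real code would pass opened PIL images to model.encode().
paths = [f"img_{i}.png" for i in range(7)]
batches = list(iter_batches(paths, 3))
print([len(b) for b in batches])
```

With this shape, a single loop body handles both the full batches and the trailing partial batch.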


@@ -0,0 +1,132 @@
#!/usr/bin/env python3
"""Simple test script to test colqwen2 forward pass with a single image."""
import os
import sys
from pathlib import Path
# Add the current directory to path to import leann_multi_vector
sys.path.insert(0, str(Path(__file__).parent))
import torch
from leann_multi_vector import _embed_images, _ensure_repo_paths_importable, _load_colvision
from PIL import Image
# Ensure repo paths are importable
_ensure_repo_paths_importable(__file__)
# Set environment variable
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def create_test_image():
"""Create a simple test image."""
# Create a simple RGB image (800x600)
img = Image.new("RGB", (800, 600), color="white")
return img
def load_test_image_from_file():
"""Try to load an image from the indexes directory if available."""
# Try to find an existing image in the indexes directory
indexes_dir = Path(__file__).parent / "indexes"
# Look for images in common locations
possible_paths = [
indexes_dir / "vidore_fastplaid" / "images",
indexes_dir / "colvision_large.leann.images",
indexes_dir / "colvision.leann.images",
]
for img_dir in possible_paths:
if img_dir.exists():
# Find first image file
for ext in [".png", ".jpg", ".jpeg"]:
for img_file in img_dir.glob(f"*{ext}"):
print(f"Loading test image from: {img_file}")
return Image.open(img_file)
return None
def main():
print("=" * 60)
print("Testing ColQwen2 Forward Pass")
print("=" * 60)
# Step 1: Load or create test image
print("\n[Step 1] Loading test image...")
test_image = load_test_image_from_file()
if test_image is None:
print("No existing image found, creating a simple test image...")
test_image = create_test_image()
else:
print(f"✓ Loaded image: {test_image.size} ({test_image.mode})")
# Convert to RGB if needed
if test_image.mode != "RGB":
test_image = test_image.convert("RGB")
print(f"✓ Converted to RGB: {test_image.size}")
# Step 2: Load model
print("\n[Step 2] Loading ColQwen2 model...")
try:
model_name, model, processor, device_str, device, dtype = _load_colvision("colqwen2")
print(f"✓ Model loaded: {model_name}")
print(f"✓ Device: {device_str}, dtype: {dtype}")
# Print model info
if hasattr(model, "device"):
print(f"✓ Model device: {model.device}")
if hasattr(model, "dtype"):
print(f"✓ Model dtype: {model.dtype}")
except Exception as e:
print(f"✗ Error loading model: {e}")
import traceback
traceback.print_exc()
return
# Step 3: Test forward pass
print("\n[Step 3] Running forward pass...")
try:
# Use the _embed_images function which handles batching and forward pass
images = [test_image]
print(f"Processing {len(images)} image(s)...")
doc_vecs = _embed_images(model, processor, images)
print("✓ Forward pass completed!")
print(f"✓ Number of embeddings: {len(doc_vecs)}")
if len(doc_vecs) > 0:
emb = doc_vecs[0]
print(f"✓ Embedding shape: {emb.shape}")
print(f"✓ Embedding dtype: {emb.dtype}")
print("✓ Embedding stats:")
print(f" - Min: {emb.min().item():.4f}")
print(f" - Max: {emb.max().item():.4f}")
print(f" - Mean: {emb.mean().item():.4f}")
print(f" - Std: {emb.std().item():.4f}")
# Check for NaN or Inf
if torch.isnan(emb).any():
print("⚠ Warning: Embedding contains NaN values!")
if torch.isinf(emb).any():
print("⚠ Warning: Embedding contains Inf values!")
except Exception as e:
print(f"✗ Error during forward pass: {e}")
import traceback
traceback.print_exc()
return
print("\n" + "=" * 60)
print("Test completed successfully!")
print("=" * 60)
if __name__ == "__main__":
main()


@@ -1,8 +1,11 @@
import concurrent.futures import concurrent.futures
import glob
import json import json
import logging
import os import os
import re import re
import sys import sys
import time
from pathlib import Path from pathlib import Path
from typing import Any, Optional, cast from typing import Any, Optional, cast
@@ -10,6 +13,8 @@ import numpy as np
from PIL import Image from PIL import Image
from tqdm import tqdm from tqdm import tqdm
logger = logging.getLogger(__name__)
def _ensure_repo_paths_importable(current_file: str) -> None: def _ensure_repo_paths_importable(current_file: str) -> None:
"""Make local leann packages importable without installing (mirrors multi-vector-leann.py).""" """Make local leann packages importable without installing (mirrors multi-vector-leann.py)."""
@@ -95,12 +100,63 @@ def _natural_sort_key(name: str) -> int:
return int(m.group()) if m else 0 return int(m.group()) if m else 0
def _load_images_from_dir(pages_dir: str) -> tuple[list[str], list[Image.Image]]: def _load_images_from_dir(
filenames = [n for n in os.listdir(pages_dir) if n.lower().endswith((".png", ".jpg", ".jpeg"))] pages_dir: str, recursive: bool = False
filenames = sorted(filenames, key=_natural_sort_key) ) -> tuple[list[str], list[Image.Image]]:
filepaths = [os.path.join(pages_dir, n) for n in filenames] """
images = [Image.open(p) for p in filepaths] Load images from a directory.
return filepaths, images
Args:
pages_dir: Directory path containing images
recursive: If True, recursively search subdirectories (default: False)
Returns:
Tuple of (filepaths, images)
"""
# Supported image extensions
extensions = ("*.png", "*.jpg", "*.jpeg", "*.PNG", "*.JPG", "*.JPEG", "*.webp", "*.WEBP")
if recursive:
# Recursive search
filepaths = []
for ext in extensions:
pattern = os.path.join(pages_dir, "**", ext)
filepaths.extend(glob.glob(pattern, recursive=True))
else:
# Non-recursive search (only top-level directory)
filepaths = []
for ext in extensions:
pattern = os.path.join(pages_dir, ext)
filepaths.extend(glob.glob(pattern))
# Sort files naturally
filepaths = sorted(filepaths, key=lambda x: _natural_sort_key(os.path.basename(x)))
# Load images with error handling
images = []
valid_filepaths = []
failed_count = 0
for filepath in filepaths:
try:
img = Image.open(filepath)
# Convert to RGB if necessary (handles RGBA, P, etc.)
if img.mode != "RGB":
img = img.convert("RGB")
images.append(img)
valid_filepaths.append(filepath)
except Exception as e:
failed_count += 1
print(f"Warning: Failed to load image {filepath}: {e}")
continue
if failed_count > 0:
print(
f"Warning: Failed to load {failed_count} image(s) out of {len(filepaths)} total files"
)
return valid_filepaths, images
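
The recursive and non-recursive branches of `_load_images_from_dir` differ only in the glob pattern. The sketch below shows that behaviour on a temporary directory using only the standard library. `natural_key` and `list_images` are hypothetical names, the regex in `natural_key` is an assumption (only the last line of `_natural_sort_key` is visible in this diff), and the `set` is an addition that guards against duplicate matches on case-insensitive filesystems.

```python
import glob
import os
import re
import tempfile


def natural_key(name: str) -> int:
    """Sort by the first run of digits in the file name (assumed behaviour)."""
    m = re.search(r"\d+", name)
    return int(m.group()) if m else 0


def list_images(root: str, recursive: bool = False) -> list[str]:
    """Collect image paths under root, optionally descending into subdirectories."""
    found: set[str] = set()
    for pattern in ("*.png", "*.jpg"):
        parts = (root, "**", pattern) if recursive else (root, pattern)
        # With recursive=True, "**" matches zero or more directory levels.
        found.update(glob.glob(os.path.join(*parts), recursive=recursive))
    return sorted(found, key=lambda p: natural_key(os.path.basename(p)))


with tempfile.TemporaryDirectory() as root:
    os.makedirs(os.path.join(root, "sub"))
    for rel in ("page_10.png", "page_2.png", os.path.join("sub", "page_1.jpg")):
        open(os.path.join(root, rel), "wb").close()
    top = [os.path.basename(p) for p in list_images(root)]
    everything = [os.path.basename(p) for p in list_images(root, recursive=True)]

print(top)
print(everything)
```

The natural key puts `page_2.png` before `page_10.png`, which a plain string sort would not.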
def _maybe_convert_pdf_to_images(pdf_path: Optional[str], pages_dir: str, dpi: int = 200) -> None: def _maybe_convert_pdf_to_images(pdf_path: Optional[str], pages_dir: str, dpi: int = 200) -> None:
@@ -150,36 +206,99 @@ def _select_device_and_dtype():
def _load_colvision(model_choice: str): def _load_colvision(model_choice: str):
import os
import torch import torch
from colpali_engine.models import ColPali, ColQwen2, ColQwen2Processor from colpali_engine.models import (
ColPali,
ColQwen2,
ColQwen2_5,
ColQwen2_5_Processor,
ColQwen2Processor,
)
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
from transformers.utils.import_utils import is_flash_attn_2_available from transformers.utils.import_utils import is_flash_attn_2_available
# Force HuggingFace Hub to use HF endpoint, avoid Google Drive
# Set environment variables to ensure models are downloaded from HuggingFace
os.environ.setdefault("HF_ENDPOINT", "https://huggingface.co")
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
# Log model loading info
logger.info(f"Loading ColVision model: {model_choice}")
logger.info(f"HF_ENDPOINT: {os.environ.get('HF_ENDPOINT', 'not set')}")
logger.info("Models will be downloaded from HuggingFace Hub, not Google Drive")
device_str, device, dtype = _select_device_and_dtype() device_str, device, dtype = _select_device_and_dtype()
# Determine model name and type
# IMPORTANT: Check colqwen2.5 BEFORE colqwen2 to avoid false matches
model_choice_lower = model_choice.lower()
if model_choice == "colqwen2": if model_choice == "colqwen2":
model_name = "vidore/colqwen2-v1.0" model_name = "vidore/colqwen2-v1.0"
# On CPU/MPS we must avoid flash-attn and stay eager; on CUDA prefer flash-attn if available model_type = "colqwen2"
elif model_choice == "colqwen2.5" or model_choice == "colqwen25":
model_name = "vidore/colqwen2.5-v0.2"
model_type = "colqwen2.5"
elif model_choice == "colpali":
model_name = "vidore/colpali-v1.2"
model_type = "colpali"
elif (
"colqwen2.5" in model_choice_lower
or "colqwen25" in model_choice_lower
or "colqwen2_5" in model_choice_lower
):
# Handle HuggingFace model names like "vidore/colqwen2.5-v0.2"
model_name = model_choice
model_type = "colqwen2.5"
elif "colqwen2" in model_choice_lower and "colqwen2-v1.0" in model_choice_lower:
# Handle HuggingFace model names like "vidore/colqwen2-v1.0" (but not colqwen2.5)
model_name = model_choice
model_type = "colqwen2"
elif "colpali" in model_choice_lower:
# Handle HuggingFace model names like "vidore/colpali-v1.2"
model_name = model_choice
model_type = "colpali"
else:
# Default to colpali for backward compatibility
model_name = "vidore/colpali-v1.2"
model_type = "colpali"
# Load model based on type
attn_implementation = ( attn_implementation = (
"flash_attention_2" "flash_attention_2" if (device_str == "cuda" and is_flash_attn_2_available()) else "eager"
if (device_str == "cuda" and is_flash_attn_2_available())
else "eager"
) )
# Load model from HuggingFace Hub (not Google Drive)
# Use local_files_only=False to ensure download from HF if not cached
if model_type == "colqwen2.5":
model = ColQwen2_5.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map=device,
attn_implementation=attn_implementation,
local_files_only=False, # Ensure download from HuggingFace Hub
).eval()
processor = ColQwen2_5_Processor.from_pretrained(model_name, local_files_only=False)
elif model_type == "colqwen2":
model = ColQwen2.from_pretrained( model = ColQwen2.from_pretrained(
model_name, model_name,
torch_dtype=torch.bfloat16, torch_dtype=torch.bfloat16,
device_map=device, device_map=device,
attn_implementation=attn_implementation, attn_implementation=attn_implementation,
local_files_only=False, # Ensure download from HuggingFace Hub
).eval() ).eval()
processor = ColQwen2Processor.from_pretrained(model_name) processor = ColQwen2Processor.from_pretrained(model_name, local_files_only=False)
else: else: # colpali
model_name = "vidore/colpali-v1.2"
model = ColPali.from_pretrained( model = ColPali.from_pretrained(
model_name, model_name,
torch_dtype=torch.bfloat16, torch_dtype=torch.bfloat16,
device_map=device, device_map=device,
local_files_only=False, # Ensure download from HuggingFace Hub
).eval() ).eval()
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name)) processor = cast(
ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name, local_files_only=False)
)
return model_name, model, processor, device_str, device, dtype return model_name, model, processor, device_str, device, dtype
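
The alias and substring checks in `_load_colvision` can be read as a small resolution function. The sketch below restates that ordering in isolation. `resolve_model` is a hypothetical name, and it deliberately differs from the diff in two ways that should be treated as assumptions: aliases are matched case-insensitively, and any repo id containing `colqwen2` (not only `colqwen2-v1.0`) resolves to the `colqwen2` type.

```python
def resolve_model(choice: str) -> tuple[str, str]:
    """Return (model_name, model_type) for a short alias or a Hub repo id."""
    aliases = {
        "colqwen2": ("vidore/colqwen2-v1.0", "colqwen2"),
        "colqwen2.5": ("vidore/colqwen2.5-v0.2", "colqwen2.5"),
        "colqwen25": ("vidore/colqwen2.5-v0.2", "colqwen2.5"),
        "colpali": ("vidore/colpali-v1.2", "colpali"),
    }
    lowered = choice.lower()
    if lowered in aliases:
        return aliases[lowered]
    # Check the 2.5 spellings first: "colqwen2" is a substring of "colqwen2.5".
    if any(tag in lowered for tag in ("colqwen2.5", "colqwen25", "colqwen2_5")):
        return choice, "colqwen2.5"
    if "colqwen2" in lowered:
        return choice, "colqwen2"
    if "colpali" in lowered:
        return choice, "colpali"
    # Unknown names fall back to ColPali, as in the diff.
    return aliases["colpali"]


print(resolve_model("colqwen2"))
print(resolve_model("vidore/colqwen2.5-v0.2"))
print(resolve_model("something-else"))
```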
@@ -194,7 +313,7 @@ def _embed_images(model, processor, images: list[Image.Image]) -> list[Any]:
dataloader = DataLoader( dataloader = DataLoader(
dataset=ListDataset[Image.Image](images), dataset=ListDataset[Image.Image](images),
batch_size=1, batch_size=32,
shuffle=False, shuffle=False,
collate_fn=lambda x: processor.process_images(x), collate_fn=lambda x: processor.process_images(x),
) )
@@ -218,32 +337,47 @@ def _embed_images(model, processor, images: list[Image.Image]) -> list[Any]:
def _embed_queries(model, processor, queries: list[str]) -> list[Any]: def _embed_queries(model, processor, queries: list[str]) -> list[Any]:
import torch import torch
from colpali_engine.utils.torch_utils import ListDataset
from torch.utils.data import DataLoader
model.eval() model.eval()
dataloader = DataLoader( # Match MTEB's exact query processing from ColPaliEngineWrapper.get_text_embeddings:
dataset=ListDataset[str](queries), # 1. MTEB receives batch["text"] which already includes instruction/prompt (from _combine_queries_with_instruction_text)
batch_size=1, # 2. Manually adds: query_prefix + text + query_augmentation_token * 10
shuffle=False, # 3. Calls processor.process_queries(batch) where batch is now a list of strings
collate_fn=lambda x: processor.process_queries(x), # 4. process_queries adds: query_prefix + text + suffix (suffix = query_augmentation_token * 10)
) #
# This results in duplicate addition: query_prefix is added twice, query_augmentation_token * 20 total
# We need to match this exactly to reproduce MTEB results
all_embeds = []
batch_size = 32 # Match MTEB's default batch_size
q_vecs: list[Any] = []
for batch_query in tqdm(dataloader, desc="Embedding queries"):
with torch.no_grad(): with torch.no_grad():
batch_query = {k: v.to(model.device) for k, v in batch_query.items()} for i in tqdm(range(0, len(queries), batch_size), desc="Embedding queries"):
batch_queries = queries[i : i + batch_size]
# Match MTEB: manually add query_prefix + text + query_augmentation_token * 10
# Then process_queries will add them again (resulting in 20 augmentation tokens total)
batch = [
processor.query_prefix + t + processor.query_augmentation_token * 10
for t in batch_queries
]
inputs = processor.process_queries(batch)
inputs = {k: v.to(model.device) for k, v in inputs.items()}
if model.device.type == "cuda": if model.device.type == "cuda":
with torch.autocast( with torch.autocast(
device_type="cuda", device_type="cuda",
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16, dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
): ):
embeddings_query = model(**batch_query) outs = model(**inputs)
else: else:
embeddings_query = model(**batch_query) outs = model(**inputs)
q_vecs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
return q_vecs # Match MTEB: convert to float32 on CPU
all_embeds.extend(list(torch.unbind(outs.cpu().to(torch.float32))))
return all_embeds
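
The comments in `_embed_queries` describe the prefix and augmentation tokens being applied twice: once by hand and once inside `process_queries`. The string-level effect can be sketched without loading a model. Everything below is a stand-in: `QUERY_PREFIX`, `AUG_TOKEN`, and `process_queries_text` are hypothetical placeholders for the processor's real attributes and tokenizing method, whose actual values are not shown in this diff.

```python
QUERY_PREFIX = "Query: "  # placeholder, not the processor's real prefix
AUG_TOKEN = "<pad>"       # placeholder, not the processor's real augmentation token


def process_queries_text(texts: list[str]) -> list[str]:
    """Stand-in for the string formatting that process_queries is described as doing."""
    return [QUERY_PREFIX + t + AUG_TOKEN * 10 for t in texts]


def mteb_style(texts: list[str]) -> list[str]:
    """Apply the prefix and suffix by hand, then let the processor add them again."""
    pre_formatted = [QUERY_PREFIX + t + AUG_TOKEN * 10 for t in texts]
    return process_queries_text(pre_formatted)


formatted = mteb_style(["total revenue 2023"])[0]
print(formatted.count(QUERY_PREFIX), formatted.count(AUG_TOKEN))
```

Under these assumptions the final string carries the prefix twice and 20 augmentation tokens, which is the duplication the comments say is needed to reproduce MTEB results.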
def _build_index( def _build_index(
@@ -283,6 +417,279 @@ def _load_retriever_if_index_exists(index_path: str) -> Optional[Any]:
return None return None
def _build_fast_plaid_index(
index_path: str,
doc_vecs: list[Any],
filepaths: list[str],
images: list[Image.Image],
) -> tuple[Any, float]:
"""
Build a Fast-Plaid index from document embeddings.
Args:
index_path: Path to save the Fast-Plaid index
doc_vecs: List of document embeddings (each is a tensor with shape [num_tokens, embedding_dim])
filepaths: List of filepath identifiers for each document
images: List of PIL Images corresponding to each document
Returns:
Tuple of (FastPlaid index object, build_time_in_seconds)
"""
import torch
from fast_plaid import search as fast_plaid_search
print(f" Preparing {len(doc_vecs)} document embeddings for Fast-Plaid...")
_t0 = time.perf_counter()
# Convert doc_vecs to list of tensors
documents_embeddings = []
for i, vec in enumerate(doc_vecs):
if i % 1000 == 0:
print(f" Converting embedding {i}/{len(doc_vecs)}...")
if not isinstance(vec, torch.Tensor):
vec = (
torch.tensor(vec)
if isinstance(vec, np.ndarray)
else torch.from_numpy(np.array(vec))
)
# Ensure float32 for Fast-Plaid
if vec.dtype != torch.float32:
vec = vec.float()
documents_embeddings.append(vec)
print(f" Converted {len(documents_embeddings)} embeddings")
if len(documents_embeddings) > 0:
print(f" First embedding shape: {documents_embeddings[0].shape}")
print(f" First embedding dtype: {documents_embeddings[0].dtype}")
# Prepare metadata for Fast-Plaid
print(f" Preparing metadata for {len(filepaths)} documents...")
metadata_list = []
for i, filepath in enumerate(filepaths):
metadata_list.append(
{
"filepath": filepath,
"index": i,
}
)
# Create Fast-Plaid index
print(f" Creating FastPlaid object with index path: {index_path}")
try:
fast_plaid_index = fast_plaid_search.FastPlaid(index=index_path)
print(" FastPlaid object created successfully")
except Exception as e:
print(f" Error creating FastPlaid object: {type(e).__name__}: {e}")
import traceback
traceback.print_exc()
raise
print(f" Calling fast_plaid_index.create() with {len(documents_embeddings)} documents...")
try:
fast_plaid_index.create(
documents_embeddings=documents_embeddings,
metadata=metadata_list,
)
print(" Fast-Plaid index created successfully")
except Exception as e:
print(f" Error creating Fast-Plaid index: {type(e).__name__}: {e}")
import traceback
traceback.print_exc()
raise
build_secs = time.perf_counter() - _t0
# Save images separately (Fast-Plaid doesn't store images)
print(f" Saving {len(images)} images...")
images_dir = Path(index_path) / "images"
images_dir.mkdir(parents=True, exist_ok=True)
for i, img in enumerate(tqdm(images, desc="Saving images")):
img_path = images_dir / f"doc_{i}.png"
img.save(str(img_path))
return fast_plaid_index, build_secs
def _fast_plaid_index_exists(index_path: str) -> bool:
"""
Check if Fast-Plaid index exists by checking for key files.
This avoids creating the FastPlaid object which may trigger memory allocation.
Args:
index_path: Path to the Fast-Plaid index
Returns:
True if index appears to exist, False otherwise
"""
index_path_obj = Path(index_path)
if not index_path_obj.exists() or not index_path_obj.is_dir():
return False
# Fast-Plaid creates a SQLite database file for metadata
# Check for metadata.db as the most reliable indicator
metadata_db = index_path_obj / "metadata.db"
if metadata_db.exists() and metadata_db.stat().st_size > 0:
return True
# Also check if directory has any files (might be incomplete index)
try:
if any(index_path_obj.iterdir()):
return True
except Exception:
pass
return False
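
The existence check above inspects the filesystem only, so it can be exercised without `fast_plaid` installed. The sketch below mirrors its three outcomes on a temporary directory; `index_dir_looks_built` is a hypothetical name, and the `metadata.db` file name is taken from the comment in the diff rather than from Fast-Plaid's documentation.

```python
import tempfile
from pathlib import Path


def index_dir_looks_built(index_path: str) -> bool:
    """Return True if the directory exists and holds a non-empty metadata.db or any entry."""
    root = Path(index_path)
    if not root.is_dir():
        return False
    db = root / "metadata.db"
    if db.is_file() and db.stat().st_size > 0:
        return True
    # Fallback from the diff: any entry at all counts, even a partial index.
    return any(root.iterdir())


with tempfile.TemporaryDirectory() as tmp:
    missing = index_dir_looks_built(str(Path(tmp) / "nope"))
    empty = index_dir_looks_built(tmp)
    (Path(tmp) / "metadata.db").write_bytes(b"x")
    built = index_dir_looks_built(tmp)

print(missing, empty, built)
```

Note that the fallback branch reports an index as present even when only unrelated files (for example the saved `images` directory) exist, so a caller may still need to handle a load failure.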
def _load_fast_plaid_index_if_exists(index_path: str) -> Optional[Any]:
"""
Load Fast-Plaid index if it exists.
First checks if index files exist, then creates the FastPlaid object.
The actual index data loading happens lazily when search is called.
Args:
index_path: Path to the Fast-Plaid index
Returns:
FastPlaid index object if exists, None otherwise
"""
try:
from fast_plaid import search as fast_plaid_search
# First check if index files exist without creating the object
if not _fast_plaid_index_exists(index_path):
return None
# Now try to create FastPlaid object
# This may trigger some memory allocation, but the full index loading is deferred
fast_plaid_index = fast_plaid_search.FastPlaid(index=index_path)
return fast_plaid_index
except ImportError:
# fast-plaid not installed
return None
except Exception as e:
# Any other error (including memory errors from the Rust backend): return None
# so that the caller treats the index as missing and rebuilds it
print(f"Warning: Could not load Fast-Plaid index: {type(e).__name__}: {e}")
return None
def _search_fast_plaid(
fast_plaid_index: Any,
query_vec: Any,
top_k: int,
) -> tuple[list[tuple[float, int]], float]:
"""
Search Fast-Plaid index with a query embedding.
Args:
fast_plaid_index: FastPlaid index object
query_vec: Query embedding tensor with shape [num_tokens, embedding_dim]
top_k: Number of top results to return
Returns:
Tuple of (results_list, search_time_in_seconds)
results_list: List of (score, doc_id) tuples
"""
import torch
_t0 = time.perf_counter()
# Ensure query is a torch tensor
if not isinstance(query_vec, torch.Tensor):
q_vec_tensor = (
torch.tensor(query_vec)
if isinstance(query_vec, np.ndarray)
else torch.from_numpy(np.array(query_vec))
)
else:
q_vec_tensor = query_vec
# Fast-Plaid expects shape [num_queries, num_tokens, embedding_dim]
if q_vec_tensor.dim() == 2:
q_vec_tensor = q_vec_tensor.unsqueeze(0) # [1, num_tokens, embedding_dim]
# Perform search
scores = fast_plaid_index.search(
queries_embeddings=q_vec_tensor,
top_k=top_k,
show_progress=True,
)
search_secs = time.perf_counter() - _t0
# Convert Fast-Plaid results to same format as LEANN: list of (score, doc_id) tuples
results = []
if scores and len(scores) > 0:
query_results = scores[0]
# Fast-Plaid returns (doc_id, score), convert to (score, doc_id) to match LEANN format
results = [(float(score), int(doc_id)) for doc_id, score in query_results]
return results, search_secs
def _get_fast_plaid_image(index_path: str, doc_id: int) -> Optional[Image.Image]:
"""
Retrieve image for a document from Fast-Plaid index.
Args:
index_path: Path to the Fast-Plaid index
doc_id: Document ID returned by Fast-Plaid search
Returns:
PIL Image if found, None otherwise
Note: Uses metadata['index'] to get the actual file index, as Fast-Plaid
doc_id may differ from the file naming index.
"""
# First get metadata to find the actual index used for file naming
metadata = _get_fast_plaid_metadata(index_path, doc_id)
if metadata is None:
# Fallback: try using doc_id directly
file_index = doc_id
else:
# Use the 'index' field from metadata, which matches the file naming
file_index = metadata.get("index", doc_id)
images_dir = Path(index_path) / "images"
image_path = images_dir / f"doc_{file_index}.png"
if image_path.exists():
return Image.open(image_path)
# If not found with index, try doc_id as fallback
if file_index != doc_id:
fallback_path = images_dir / f"doc_{doc_id}.png"
if fallback_path.exists():
return Image.open(fallback_path)
return None
def _get_fast_plaid_metadata(index_path: str, doc_id: int) -> Optional[dict]:
"""
Retrieve metadata for a document from Fast-Plaid index.
Args:
index_path: Path to the Fast-Plaid index
doc_id: Document ID
Returns:
Dictionary with metadata if found, None otherwise
"""
try:
from fast_plaid import filtering
metadata_list = filtering.get(index=index_path, subset=[doc_id])
if metadata_list and len(metadata_list) > 0:
return metadata_list[0]
except Exception:
pass
return None
def _generate_similarity_map( def _generate_similarity_map(
model, model,
processor, processor,
@@ -678,11 +1085,15 @@ class LeannMultiVector:
return (float(score), doc_id) return (float(score), doc_id)
scores: list[tuple[float, int]] = [] scores: list[tuple[float, int]] = []
# Time embedding loading plus scoring
start_time = time.time()
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex: with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
futures = [ex.submit(_score_one, doc_id) for doc_id in candidate_doc_ids] futures = [ex.submit(_score_one, doc_id) for doc_id in candidate_doc_ids]
for fut in concurrent.futures.as_completed(futures): for fut in concurrent.futures.as_completed(futures):
scores.append(fut.result()) scores.append(fut.result())
end_time = time.time()
print(f"Number of candidate doc ids: {len(candidate_doc_ids)}")
print(f"Load + scoring time: {end_time - start_time} seconds")
scores.sort(key=lambda x: x[0], reverse=True) scores.sort(key=lambda x: x[0], reverse=True)
return scores[:topk] if len(scores) >= topk else scores return scores[:topk] if len(scores) >= topk else scores
@@ -710,7 +1121,6 @@ class LeannMultiVector:
emb_path = self._embeddings_path() emb_path = self._embeddings_path()
if not emb_path.exists(): if not emb_path.exists():
return self.search(data, topk) return self.search(data, topk)
all_embeddings = np.load(emb_path, mmap_mode="r") all_embeddings = np.load(emb_path, mmap_mode="r")
if all_embeddings.dtype != np.float32: if all_embeddings.dtype != np.float32:
all_embeddings = all_embeddings.astype(np.float32) all_embeddings = all_embeddings.astype(np.float32)
@@ -718,23 +1128,29 @@ class LeannMultiVector:
assert self._docid_to_indices is not None assert self._docid_to_indices is not None
candidate_doc_ids = list(self._docid_to_indices.keys()) candidate_doc_ids = list(self._docid_to_indices.keys())
def _score_one(doc_id: int) -> tuple[float, int]: def _score_one(doc_id: int, _all_embeddings=all_embeddings) -> tuple[float, int]:
token_indices = self._docid_to_indices.get(doc_id, []) token_indices = self._docid_to_indices.get(doc_id, [])
if not token_indices: if not token_indices:
return (0.0, doc_id) return (0.0, doc_id)
doc_vecs = np.asarray(all_embeddings[token_indices], dtype=np.float32) doc_vecs = np.asarray(_all_embeddings[token_indices], dtype=np.float32)
sim = np.dot(data, doc_vecs.T) sim = np.dot(data, doc_vecs.T)
sim = np.nan_to_num(sim, nan=-1e30, posinf=1e30, neginf=-1e30) sim = np.nan_to_num(sim, nan=-1e30, posinf=1e30, neginf=-1e30)
score = sim.max(axis=2).sum(axis=1) if sim.ndim == 3 else sim.max(axis=1).sum() score = sim.max(axis=2).sum(axis=1) if sim.ndim == 3 else sim.max(axis=1).sum()
return (float(score), doc_id) return (float(score), doc_id)
scores: list[tuple[float, int]] = [] scores: list[tuple[float, int]] = []
# Time embedding loading plus scoring
start_time = time.time()
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex: with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
futures = [ex.submit(_score_one, d) for d in candidate_doc_ids] futures = [ex.submit(_score_one, d) for d in candidate_doc_ids]
for fut in concurrent.futures.as_completed(futures): for fut in concurrent.futures.as_completed(futures):
scores.append(fut.result()) scores.append(fut.result())
end_time = time.time()
# print number of candidate doc ids
print(f"Number of candidate doc ids: {len(candidate_doc_ids)}")
print(f"Load + scoring time: {end_time - start_time} seconds")
scores.sort(key=lambda x: x[0], reverse=True) scores.sort(key=lambda x: x[0], reverse=True)
del all_embeddings
return scores[:topk] if len(scores) >= topk else scores return scores[:topk] if len(scores) >= topk else scores
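The scoring line in `_score_one` is late-interaction (MaxSim) scoring: for each query token, take the similarity of its best-matching document token, then sum over query tokens. Below is a dependency-free sketch of the 2-D case. The helper name `maxsim` and the vectors are made up for illustration; the real code does the same thing with `np.dot` and `max(axis=1).sum()`.

```python
def maxsim(query_vecs: list[list[float]], doc_vecs: list[list[float]]) -> float:
    """Sum over query tokens of the max dot product against any document token."""
    total = 0.0
    for q in query_vecs:
        total += max(sum(a * b for a, b in zip(q, d)) for d in doc_vecs)
    return total


# Two query tokens, two documents with two token vectors each (toy values)
query = [[1.0, 0.0], [0.0, 1.0]]
doc_a = [[1.0, 0.0], [0.5, 0.5]]
doc_b = [[0.0, 1.0], [0.0, 1.0]]

# Same (score, doc_id) tuple layout and descending sort as in the diff
scores = sorted(
    [(maxsim(query, doc_a), 0), (maxsim(query, doc_b), 1)],
    key=lambda x: x[0],
    reverse=True,
)
print(scores)
```

Here `doc_a` scores 1.0 + 0.5 and `doc_b` scores 0.0 + 1.0, so `doc_a` ranks first.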
def get_image(self, doc_id: int) -> Optional[Image.Image]: def get_image(self, doc_id: int) -> Optional[Image.Image]:
@@ -778,3 +1194,259 @@ class LeannMultiVector:
"image_path": meta.get("image_path", ""), "image_path": meta.get("image_path", ""),
} }
return None return None
class ViDoReBenchmarkEvaluator:
"""
A reusable class for evaluating ViDoRe benchmarks (v1 and v2).
This class encapsulates common functionality for building indexes, searching, and evaluating.
"""
def __init__(
self,
model_name: str,
use_fast_plaid: bool = False,
top_k: int = 100,
first_stage_k: int = 500,
k_values: Optional[list[int]] = None,
):
"""
Initialize the evaluator.
Args:
model_name: Model name ("colqwen2" or "colpali")
use_fast_plaid: Whether to use Fast-Plaid instead of LEANN
top_k: Top-k results to retrieve
first_stage_k: First stage k for LEANN search
k_values: List of k values for evaluation metrics
"""
self.model_name = model_name
self.use_fast_plaid = use_fast_plaid
self.top_k = top_k
self.first_stage_k = first_stage_k
self.k_values = k_values if k_values is not None else [1, 3, 5, 10, 100]
# Load model once (can be reused across tasks)
self._model = None
self._processor = None
self._model_name_actual = None
def _load_model_if_needed(self):
"""Lazy load the model."""
if self._model is None:
print(f"\nLoading model: {self.model_name}")
self._model_name_actual, self._model, self._processor, _, _, _ = _load_colvision(
self.model_name
)
print(f"Model loaded: {self._model_name_actual}")
def build_index_from_corpus(
self,
corpus: dict[str, Image.Image],
index_path: str,
rebuild: bool = False,
) -> tuple[Any, list[str]]:
"""
Build index from corpus images.
Args:
corpus: dict mapping corpus_id to PIL Image
index_path: Path to save/load the index
rebuild: Whether to rebuild even if index exists
Returns:
tuple: (retriever or fast_plaid_index object, list of corpus_ids in order)
"""
self._load_model_if_needed()
# Ensure consistent ordering
corpus_ids = sorted(corpus.keys())
images = [corpus[cid] for cid in corpus_ids]
if self.use_fast_plaid:
# Check if Fast-Plaid index exists (load once and reuse the object)
existing_index = None if rebuild else _load_fast_plaid_index_if_exists(index_path)
if existing_index is not None:
print(f"Fast-Plaid index already exists at {index_path}")
return existing_index, corpus_ids
print(f"Building Fast-Plaid index at {index_path}...")
print("Embedding images...")
doc_vecs = _embed_images(self._model, self._processor, images)
fast_plaid_index, build_time = _build_fast_plaid_index(
index_path, doc_vecs, corpus_ids, images
)
print(f"Fast-Plaid index built in {build_time:.2f}s")
return fast_plaid_index, corpus_ids
else:
# Check if LEANN index exists
if not rebuild:
retriever = _load_retriever_if_index_exists(index_path)
if retriever is not None:
print(f"LEANN index already exists at {index_path}")
return retriever, corpus_ids
print(f"Building LEANN index at {index_path}...")
print("Embedding images...")
doc_vecs = _embed_images(self._model, self._processor, images)
retriever = _build_index(index_path, doc_vecs, corpus_ids, images)
print("LEANN index built")
return retriever, corpus_ids
def search_queries(
self,
queries: dict[str, str],
corpus_ids: list[str],
index_or_retriever: Any,
fast_plaid_index_path: Optional[str] = None,
task_prompt: Optional[dict[str, str]] = None,
) -> dict[str, dict[str, float]]:
"""
Search queries against the index.
Args:
queries: dict mapping query_id to query text
corpus_ids: list of corpus_ids in the same order as the index
index_or_retriever: index or retriever object
fast_plaid_index_path: path to Fast-Plaid index (for metadata)
task_prompt: Optional dict with prompt for query (e.g., {"query": "..."})
Returns:
results: dict mapping query_id to dict of {corpus_id: score}
"""
self._load_model_if_needed()
print(f"Searching {len(queries)} queries (top_k={self.top_k})...")
query_ids = list(queries.keys())
query_texts = [queries[qid] for qid in query_ids]
# Note: ColPaliEngineWrapper does NOT use task prompt from metadata
# It uses query_prefix + text + query_augmentation_token (handled in _embed_queries)
# So we don't append task_prompt here to match MTEB behavior
# Embed queries
print("Embedding queries...")
query_vecs = _embed_queries(self._model, self._processor, query_texts)
results = {}
for query_id, query_vec in zip(tqdm(query_ids, desc="Searching"), query_vecs):
if self.use_fast_plaid:
# Fast-Plaid search
search_results, _ = _search_fast_plaid(index_or_retriever, query_vec, self.top_k)
query_results = {}
for score, doc_id in search_results:
if doc_id < len(corpus_ids):
corpus_id = corpus_ids[doc_id]
query_results[corpus_id] = float(score)
else:
# LEANN search
import torch
query_np = (
query_vec.detach().cpu().float().numpy() if isinstance(query_vec, torch.Tensor) else query_vec
)
search_results = index_or_retriever.search_exact(query_np, topk=self.top_k)
query_results = {}
for score, doc_id in search_results:
if doc_id < len(corpus_ids):
corpus_id = corpus_ids[doc_id]
query_results[corpus_id] = float(score)
results[query_id] = query_results
return results
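Both branches of `search_queries` end with the same step: integer `doc_id`s, which are positions in the sorted `corpus_ids` list, are mapped back to corpus identifiers, and out-of-range ids are dropped. Below is a small sketch of that step together with the `(doc_id, score)` to `(score, doc_id)` swap that the diff's comment describes for Fast-Plaid results. The helper name `to_corpus_scores` and the sample values are hypothetical.

```python
def to_corpus_scores(
    search_results: list[tuple[float, int]], corpus_ids: list[str]
) -> dict[str, float]:
    """Map (score, doc_id) pairs to {corpus_id: score}, skipping out-of-range ids."""
    mapped: dict[str, float] = {}
    for score, doc_id in search_results:
        if doc_id < len(corpus_ids):
            mapped[corpus_ids[doc_id]] = float(score)
    return mapped


# Sorted order defines which corpus_id each integer doc_id refers to
corpus_ids = sorted(["page-b", "page-a", "page-c"])

# Toy hits in (doc_id, score) order; doc_id 7 is deliberately out of range
raw_hits = [(2, 0.91), (0, 0.75), (7, 0.10)]
hits = [(float(score), int(doc_id)) for doc_id, score in raw_hits]

mapped = to_corpus_scores(hits, corpus_ids)
print(mapped)
```

The mapping is only valid if the index was built from the same sorted `corpus_ids`, which is why `build_index_from_corpus` returns that list alongside the index.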
@staticmethod
def evaluate_results(
results: dict[str, dict[str, float]],
qrels: dict[str, dict[str, int]],
k_values: Optional[list[int]] = None,
) -> dict[str, float]:
"""
Evaluate retrieval results using NDCG and other metrics.
Args:
results: dict mapping query_id to dict of {corpus_id: score}
qrels: dict mapping query_id to dict of {corpus_id: relevance_score}
k_values: List of k values for evaluation metrics
Returns:
Dictionary of metric scores
"""
try:
from mteb._evaluators.retrieval_metrics import (
calculate_retrieval_scores,
make_score_dict,
)
except ImportError as e:
raise ImportError(
"mteb is required for evaluation (it provides the retrieval metrics used here). "
"Install with: pip install mteb"
) from e
if k_values is None:
k_values = [1, 3, 5, 10, 100]
# Check if we have any queries to evaluate
if len(results) == 0:
print("Warning: No queries to evaluate. Returning zero scores.")
scores = {}
for k in k_values:
scores[f"ndcg_at_{k}"] = 0.0
scores[f"map_at_{k}"] = 0.0
scores[f"recall_at_{k}"] = 0.0
scores[f"precision_at_{k}"] = 0.0
scores[f"mrr_at_{k}"] = 0.0
return scores
print(f"Evaluating results with k_values={k_values}...")
print(f"Before filtering: {len(results)} results, {len(qrels)} qrels")
# Filter to ensure qrels and results have the same query set
# This matches MTEB behavior: only evaluate queries that exist in both
# pytrec_eval only evaluates queries in qrels, so we need to ensure
# results contains all queries in qrels, and filter out queries not in qrels
results_filtered = {qid: res for qid, res in results.items() if qid in qrels}
qrels_filtered = {
qid: rel_docs for qid, rel_docs in qrels.items() if qid in results_filtered
}
print(f"After filtering: {len(results_filtered)} results, {len(qrels_filtered)} qrels")
# After filtering, both dicts hold exactly the same query ids, so compare
# against the unfiltered inputs to detect queries that were dropped
missing_in_results = set(qrels.keys()) - set(results.keys())
if missing_in_results:
print(
f"Warning: {len(missing_in_results)} queries in qrels have no results and will be skipped"
)
print(f"First 5 missing queries: {list(missing_in_results)[:5]}")
# Convert qrels to pytrec_eval format
qrels_pytrec = {}
for qid, rel_docs in qrels_filtered.items():
qrels_pytrec[qid] = dict(rel_docs.items())
# Evaluate
eval_result = calculate_retrieval_scores(
results=results_filtered,
qrels=qrels_pytrec,
k_values=k_values,
)
# Format scores
scores = make_score_dict(
ndcg=eval_result.ndcg,
_map=eval_result.map,
recall=eval_result.recall,
precision=eval_result.precision,
mrr=eval_result.mrr,
naucs=eval_result.naucs,
naucs_mrr=eval_result.naucs_mrr,
cv_recall=eval_result.cv_recall,
task_scores={},
)
return scores
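The filtering step in `evaluate_results` reduces both inputs to the intersection of their query ids. A minimal sketch with toy data (the helper name `align` and the ids are illustrative) shows that after filtering the two dicts always share the same key set, so queries that were dropped can only be detected by comparing against the unfiltered inputs.

```python
def align(
    results: dict[str, dict[str, float]], qrels: dict[str, dict[str, int]]
) -> tuple[dict[str, dict[str, float]], dict[str, dict[str, int]]]:
    """Keep only the queries present in both results and qrels."""
    results_f = {qid: res for qid, res in results.items() if qid in qrels}
    qrels_f = {qid: rel for qid, rel in qrels.items() if qid in results_f}
    return results_f, qrels_f


results = {"q1": {"d1": 0.9}, "q2": {"d2": 0.4}, "q9": {"d3": 0.2}}
qrels = {"q1": {"d1": 1}, "q2": {"d5": 1}, "q3": {"d7": 1}}

results_f, qrels_f = align(results, qrels)

# Dropped queries are only visible relative to the original inputs
missing_in_results = sorted(set(qrels) - set(results))
extra_in_results = sorted(set(results) - set(qrels))

print(sorted(results_f), sorted(qrels_f), missing_in_results, extra_in_results)
```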


@@ -1,12 +1,19 @@
## Jupyter-style notebook script ## Jupyter-style notebook script
# %% # %%
# uv pip install matplotlib qwen_vl_utils # uv pip install matplotlib qwen_vl_utils
import argparse
import faulthandler
import os import os
import time
from typing import Any, Optional from typing import Any, Optional
import numpy as np
from PIL import Image from PIL import Image
from tqdm import tqdm from tqdm import tqdm
# Enable faulthandler to get stack trace on segfault
faulthandler.enable()
from leann_multi_vector import ( # utility functions/classes from leann_multi_vector import ( # utility functions/classes
_ensure_repo_paths_importable, _ensure_repo_paths_importable,
@@ -18,6 +25,11 @@ from leann_multi_vector import ( # utility functions/classes
_build_index, _build_index,
_load_retriever_if_index_exists, _load_retriever_if_index_exists,
_generate_similarity_map, _generate_similarity_map,
_build_fast_plaid_index,
_load_fast_plaid_index_if_exists,
_search_fast_plaid,
_get_fast_plaid_image,
_get_fast_plaid_metadata,
QwenVL, QwenVL,
) )
@@ -31,19 +43,52 @@ MODEL: str = "colqwen2" # "colpali" or "colqwen2"
# Data source: set to True to use the Hugging Face dataset example (recommended) # Data source: set to True to use the Hugging Face dataset example (recommended)
USE_HF_DATASET: bool = True USE_HF_DATASET: bool = True
# Single dataset name (used when DATASET_NAMES is None)
DATASET_NAME: str = "weaviate/arXiv-AI-papers-multi-vector" DATASET_NAME: str = "weaviate/arXiv-AI-papers-multi-vector"
DATASET_SPLIT: str = "train" # Multiple datasets to combine (if provided, DATASET_NAME is ignored)
# Can be:
# - List of strings: ["dataset1", "dataset2"]
# - List of tuples: [("dataset1", "config1"), ("dataset2", None)] # None = no config needed
# - Mixed: ["dataset1", ("dataset2", "config2")]
#
# Some potential datasets with images (may need IMAGE_FIELD_NAME adjustment):
# - "weaviate/arXiv-AI-papers-multi-vector" (current, has "page_image" field)
# - ("lmms-lab/DocVQA", "DocVQA") (has "image" field, document images, needs config)
# - ("lmms-lab/DocVQA", "InfographicVQA") (has "image" field, infographic images)
# - "pixparse/arxiv-papers" (if available, arXiv papers)
# - "allenai/ai2d" (AI2D diagram dataset, has "image" field)
# - "huggingface/document-images" (if available)
# Note: Check dataset structure first - some may need IMAGE_FIELD_NAME specified
# DATASET_NAMES: Optional[list[str | tuple[str, Optional[str]]]] = None
DATASET_NAMES = [
"weaviate/arXiv-AI-papers-multi-vector",
# ("lmms-lab/DocVQA", "DocVQA"), # Specify config name for datasets with multiple configs
]
# Load multiple splits to get more data (e.g., ["train", "test", "validation"])
# Set to None to try loading all available splits automatically
DATASET_SPLITS: Optional[list[str]] = ["train", "test"] # None = auto-detect all splits
# Image field name in the dataset (auto-detect if None)
# Common names: "page_image", "image", "images", "img"
IMAGE_FIELD_NAME: Optional[str] = None # None = auto-detect
MAX_DOCS: Optional[int] = None # limit number of pages to index; None = all MAX_DOCS: Optional[int] = None # limit number of pages to index; None = all
# Local pages (used when USE_HF_DATASET == False) # Local pages (used when USE_HF_DATASET == False)
PDF: Optional[str] = None # e.g., "./pdfs/2004.12832v2.pdf" PDF: Optional[str] = None # e.g., "./pdfs/2004.12832v2.pdf"
PAGES_DIR: str = "./pages" PAGES_DIR: str = "./pages"
# Custom folder path (takes precedence over USE_HF_DATASET and PAGES_DIR)
# If set, images will be loaded directly from this folder
CUSTOM_FOLDER_PATH: Optional[str] = None # e.g., "/home/ubuntu/dr-tulu/agent/screenshots"
# Whether to recursively search subdirectories when loading from custom folder
CUSTOM_FOLDER_RECURSIVE: bool = False # Set to True to search subdirectories
# Index + retrieval settings # Index + retrieval settings
INDEX_PATH: str = "./indexes/colvision.leann" # Use a different index path for larger dataset to avoid overwriting existing index
INDEX_PATH: str = "./indexes/colvision_large.leann"
# Fast-Plaid index settings (alternative to LEANN index)
# These are now command-line arguments (see CLI overrides section)
TOPK: int = 3 TOPK: int = 3
FIRST_STAGE_K: int = 500 FIRST_STAGE_K: int = 500
REBUILD_INDEX: bool = False REBUILD_INDEX: bool = False # Set to True to force rebuild even if index exists
# Artifacts # Artifacts
SAVE_TOP_IMAGE: Optional[str] = "./figures/retrieved_page.png" SAVE_TOP_IMAGE: Optional[str] = "./figures/retrieved_page.png"
@@ -54,12 +99,97 @@ ANSWER: bool = True
MAX_NEW_TOKENS: int = 1024 MAX_NEW_TOKENS: int = 1024
# %%
# CLI overrides
parser = argparse.ArgumentParser(description="Multi-vector LEANN similarity map demo")
parser.add_argument(
"--search-method",
type=str,
choices=["ann", "exact", "exact-all"],
default="ann",
help="Which search method to use: 'ann' (fast ANN), 'exact' (ANN + exact rerank), or 'exact-all' (exact over all docs).",
)
parser.add_argument(
"--query",
type=str,
default=QUERY,
help=f"Query string to search for. Default: '{QUERY}'",
)
parser.add_argument(
"--use-fast-plaid",
action="store_true",
default=False,
help="Use Fast-Plaid instead of LEANN. Off unless the flag is passed.",
)
parser.add_argument(
"--fast-plaid-index-path",
type=str,
default="./indexes/colvision_fastplaid",
help="Path to the Fast-Plaid index. Default: './indexes/colvision_fastplaid'",
)
parser.add_argument(
"--topk",
type=int,
default=TOPK,
help=f"Number of top results to retrieve. Default: {TOPK}",
)
parser.add_argument(
"--custom-folder",
type=str,
default=None,
help="Path to a custom folder containing images to search. Takes precedence over dataset loading. Default: None",
)
parser.add_argument(
"--recursive",
action="store_true",
default=False,
help="Recursively search subdirectories when loading images from custom folder. Default: False",
)
parser.add_argument(
"--rebuild-index",
action="store_true",
default=False,
help="Force rebuild the index even if it already exists. Default: False (reuse existing index if available)",
)
cli_args, _unknown = parser.parse_known_args()
SEARCH_METHOD: str = cli_args.search_method
QUERY = cli_args.query # Override QUERY with CLI argument if provided
USE_FAST_PLAID: bool = cli_args.use_fast_plaid
FAST_PLAID_INDEX_PATH: str = cli_args.fast_plaid_index_path
TOPK: int = cli_args.topk # Override TOPK with CLI argument if provided
CUSTOM_FOLDER_PATH = cli_args.custom_folder if cli_args.custom_folder else CUSTOM_FOLDER_PATH # Override with CLI argument if provided
CUSTOM_FOLDER_RECURSIVE = cli_args.recursive if cli_args.recursive else CUSTOM_FOLDER_RECURSIVE # Override with CLI argument if provided
REBUILD_INDEX = cli_args.rebuild_index # Override REBUILD_INDEX with CLI argument
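The CLI section relies on two `argparse` behaviors: `action="store_true"` flags default to `False` and take no value, and `parse_known_args` returns unrecognized arguments instead of exiting, which matters when the script runs under a launcher that passes its own arguments. Below is a reduced sketch with a subset of the options above. The `argv` list, including `--unrelated`, is made up for illustration.

```python
import argparse

parser = argparse.ArgumentParser(description="parse_known_args sketch")
parser.add_argument("--topk", type=int, default=3)
parser.add_argument("--use-fast-plaid", action="store_true")
parser.add_argument("--rebuild-index", action="store_true")

# Hypothetical command line: two known options plus one the parser does not define
argv = ["--use-fast-plaid", "--topk", "5", "--unrelated", "value"]
args, unknown = parser.parse_known_args(argv)

print(args.use_fast_plaid, args.topk, args.rebuild_index, unknown)
```

Dashes in option names become underscores on the namespace, which is why the script reads `cli_args.use_fast_plaid` and `cli_args.rebuild_index`.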
# %% # %%
# Step 1: Check if we can skip data loading (index already exists) # Step 1: Check if we can skip data loading (index already exists)
retriever: Optional[Any] = None retriever: Optional[Any] = None
fast_plaid_index: Optional[Any] = None
need_to_build_index = REBUILD_INDEX need_to_build_index = REBUILD_INDEX
if USE_FAST_PLAID:
# Fast-Plaid index handling
if not REBUILD_INDEX:
try:
fast_plaid_index = _load_fast_plaid_index_if_exists(FAST_PLAID_INDEX_PATH)
if fast_plaid_index is not None:
print(f"✓ Fast-Plaid index found at {FAST_PLAID_INDEX_PATH}")
need_to_build_index = False
else:
print("Fast-Plaid index not found, will build new index")
need_to_build_index = True
except Exception as e:
# If loading fails (e.g., memory error, corrupted index), rebuild
print(f"Warning: Failed to load Fast-Plaid index: {e}")
print("Will rebuild the index...")
need_to_build_index = True
fast_plaid_index = None
else:
print("REBUILD_INDEX=True, will rebuild Fast-Plaid index")
need_to_build_index = True
else:
# Original LEANN index handling
if not REBUILD_INDEX: if not REBUILD_INDEX:
retriever = _load_retriever_if_index_exists(INDEX_PATH) retriever = _load_retriever_if_index_exists(INDEX_PATH)
if retriever is not None: if retriever is not None:
@@ -69,23 +199,247 @@ if not REBUILD_INDEX:
else: else:
print(f"Index not found, will build new index") print(f"Index not found, will build new index")
need_to_build_index = True need_to_build_index = True
else:
print("REBUILD_INDEX=True, will rebuild index")
need_to_build_index = True
# Step 2: Load data only if we need to build the index # Step 2: Load data only if we need to build the index
if need_to_build_index: if need_to_build_index:
print("Loading dataset...") print("Loading dataset...")
if USE_HF_DATASET: # Check for custom folder path first (takes precedence)
from datasets import load_dataset if CUSTOM_FOLDER_PATH:
if not os.path.isdir(CUSTOM_FOLDER_PATH):
raise RuntimeError(f"Custom folder path does not exist: {CUSTOM_FOLDER_PATH}")
print(f"Loading images from custom folder: {CUSTOM_FOLDER_PATH}")
if CUSTOM_FOLDER_RECURSIVE:
print(" (recursive mode: searching subdirectories)")
filepaths, images = _load_images_from_dir(CUSTOM_FOLDER_PATH, recursive=CUSTOM_FOLDER_RECURSIVE)
print(f" Found {len(filepaths)} image files")
if not images:
raise RuntimeError(
f"No images found in {CUSTOM_FOLDER_PATH}. Ensure the folder contains image files (.png, .jpg, .jpeg, .webp)."
)
print(f" Successfully loaded {len(images)} images")
# Use filenames as identifiers instead of full paths for cleaner metadata
filepaths = [os.path.basename(fp) for fp in filepaths]
elif USE_HF_DATASET:
from datasets import load_dataset, concatenate_datasets, DatasetDict
dataset = load_dataset(DATASET_NAME, split=DATASET_SPLIT) # Determine which datasets to load
if DATASET_NAMES is not None:
dataset_names_to_load = DATASET_NAMES
print(f"Loading {len(dataset_names_to_load)} datasets: {dataset_names_to_load}")
else:
dataset_names_to_load = [DATASET_NAME]
print(f"Loading single dataset: {DATASET_NAME}")
# Load and combine datasets
all_datasets_to_concat = []
for dataset_entry in dataset_names_to_load:
# Handle both string and tuple formats
if isinstance(dataset_entry, tuple):
dataset_name, config_name = dataset_entry
else:
dataset_name = dataset_entry
config_name = None
print(f"\nProcessing dataset: {dataset_name}" + (f" (config: {config_name})" if config_name else ""))
# Load dataset to check available splits
# If config_name is provided, use it; otherwise try without config
try:
if config_name:
dataset_dict = load_dataset(dataset_name, config_name)
else:
dataset_dict = load_dataset(dataset_name)
except ValueError as e:
if "Config name is missing" in str(e):
# Try to get available configs and suggest
from datasets import get_dataset_config_names
# Look up the configs first; raising inside this try would be swallowed
# by the broad except below and the config list would never be shown
try:
available_configs = get_dataset_config_names(dataset_name)
except Exception:
available_configs = None
configs_hint = f" Available configs: {available_configs}." if available_configs else ""
raise ValueError(
f"Dataset '{dataset_name}' requires a config name.{configs_hint} "
f"Please specify as: ('{dataset_name}', 'config_name')"
) from e
raise
# Determine which splits to load
if DATASET_SPLITS is None:
# Auto-detect: try to load all available splits
available_splits = list(dataset_dict.keys())
print(f" Auto-detected splits: {available_splits}")
splits_to_load = available_splits
else:
splits_to_load = DATASET_SPLITS
# Load and concatenate multiple splits for this dataset
datasets_to_concat = []
for split in splits_to_load:
if split not in dataset_dict:
print(f" Warning: Split '{split}' not found in dataset. Available splits: {list(dataset_dict.keys())}")
continue
split_dataset = dataset_dict[split]
print(f" Loaded split '{split}': {len(split_dataset)} pages")
datasets_to_concat.append(split_dataset)
if not datasets_to_concat:
print(f" Warning: No valid splits found for {dataset_name}. Skipping.")
continue
# Concatenate splits for this dataset
if len(datasets_to_concat) > 1:
combined_dataset = concatenate_datasets(datasets_to_concat)
print(f" Concatenated {len(datasets_to_concat)} splits into {len(combined_dataset)} pages")
else:
combined_dataset = datasets_to_concat[0]
all_datasets_to_concat.append(combined_dataset)
if not all_datasets_to_concat:
raise RuntimeError("No valid datasets or splits found.")
# Concatenate all datasets
if len(all_datasets_to_concat) > 1:
dataset = concatenate_datasets(all_datasets_to_concat)
print(f"\nConcatenated {len(all_datasets_to_concat)} datasets into {len(dataset)} total pages")
else:
dataset = all_datasets_to_concat[0]
# Apply MAX_DOCS limit if specified
N = len(dataset) if MAX_DOCS is None else min(MAX_DOCS, len(dataset)) N = len(dataset) if MAX_DOCS is None else min(MAX_DOCS, len(dataset))
if N < len(dataset):
print(f"Limiting to {N} pages (from {len(dataset)} total)")
dataset = dataset.select(range(N))
# Auto-detect image field name if not specified
if IMAGE_FIELD_NAME is None:
# Check multiple samples to find the most common image field
# (useful when datasets are merged and may have different field names)
possible_image_fields = ["page_image", "image", "images", "img", "page", "document_image"]
field_counts = {}
# Check first few samples to find image fields
num_samples_to_check = min(10, len(dataset))
for sample_idx in range(num_samples_to_check):
sample = dataset[sample_idx]
for field in possible_image_fields:
if field in sample and sample[field] is not None:
value = sample[field]
if isinstance(value, Image.Image) or (hasattr(value, 'size') and hasattr(value, 'mode')):
field_counts[field] = field_counts.get(field, 0) + 1
# Choose the most common field, or first found if tied
if field_counts:
image_field = max(field_counts.items(), key=lambda x: x[1])[0]
print(f"Auto-detected image field: '{image_field}' (found in {field_counts[image_field]}/{num_samples_to_check} samples)")
else:
# Fallback: check first sample only
sample = dataset[0]
image_field = None
for field in possible_image_fields:
if field in sample:
value = sample[field]
if isinstance(value, Image.Image) or (hasattr(value, 'size') and hasattr(value, 'mode')):
image_field = field
break
if image_field is None:
raise RuntimeError(
f"Could not auto-detect image field. Available fields: {list(sample.keys())}. "
f"Please specify IMAGE_FIELD_NAME manually."
)
print(f"Auto-detected image field: '{image_field}'")
else:
image_field = IMAGE_FIELD_NAME
if image_field not in dataset[0]:
raise RuntimeError(
f"Image field '{image_field}' not found. Available fields: {list(dataset[0].keys())}"
)
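The auto-detection above is a vote: each of the first few samples is checked for candidate field names, a value counts if it looks like an image (has `size` and `mode`), and the field with the most hits wins. Below is a dependency-free sketch of that vote. `FakeImage` stands in for `PIL.Image.Image`, and the helper names and sample rows are hypothetical.

```python
from typing import Any, Optional


class FakeImage:
    """Stand-in for PIL.Image.Image with only the attributes the check looks at."""

    size = (8, 8)
    mode = "RGB"


def looks_like_image(value: Any) -> bool:
    # Same duck-typing idea as the diff: an image exposes size and mode
    return hasattr(value, "size") and hasattr(value, "mode")


def detect_image_field(samples: list[dict], candidates: list[str]) -> Optional[str]:
    """Return the candidate field that holds an image in the most samples."""
    counts: dict[str, int] = {}
    for sample in samples:
        for field in candidates:
            if looks_like_image(sample.get(field)):
                counts[field] = counts.get(field, 0) + 1
    if not counts:
        return None
    # max() keeps the first field reaching the top count, so ties go to the earliest seen
    return max(counts.items(), key=lambda kv: kv[1])[0]


candidates = ["page_image", "image", "images", "img"]
samples = [
    {"page_image": FakeImage(), "title": "a"},
    {"image": FakeImage(), "title": "b"},
    {"page_image": FakeImage(), "title": "c"},
]

detected = detect_image_field(samples, candidates)
none_found = detect_image_field([{"title": "x"}], candidates)
print(detected, none_found)
```

When merged datasets use different field names, the minority rows are not covered by the winning field, which is what the per-row fallback over common image field names later in the loading loop handles.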
filepaths: list[str] = [] filepaths: list[str] = []
images: list[Image.Image] = [] images: list[Image.Image] = []
for i in tqdm(range(N), desc="Loading dataset", total=N): for i in tqdm(range(len(dataset)), desc="Loading dataset", total=len(dataset)):
p = dataset[i] p = dataset[i]
# Compose a descriptive identifier for printing later # Try to compose a descriptive identifier
identifier = f"arXiv:{p['paper_arxiv_id']}|title:{p['paper_title']}|page:{int(p['page_number'])}|id:{p['page_id']}" # Handle different dataset structures
identifier_parts = []
# Helper function to safely get field value
def safe_get(field_name, default=None):
if field_name in p and p[field_name] is not None:
return p[field_name]
return default
# Try to get various identifier fields
if safe_get("paper_arxiv_id"):
identifier_parts.append(f"arXiv:{p['paper_arxiv_id']}")
if safe_get("paper_title"):
identifier_parts.append(f"title:{p['paper_title']}")
if safe_get("page_number") is not None:
try:
identifier_parts.append(f"page:{int(p['page_number'])}")
except (ValueError, TypeError):
# If conversion fails, use the raw value or skip
if p['page_number']:
identifier_parts.append(f"page:{p['page_number']}")
if safe_get("page_id"):
identifier_parts.append(f"id:{p['page_id']}")
elif safe_get("questionId"):
identifier_parts.append(f"qid:{p['questionId']}")
elif safe_get("docId"):
identifier_parts.append(f"docId:{p['docId']}")
elif safe_get("id"):
identifier_parts.append(f"id:{p['id']}")
# If no identifier parts found, create one from index
if identifier_parts:
identifier = "|".join(identifier_parts)
else:
# Create identifier from available fields or index
fallback_parts = []
# Try common fields that might exist
for field in ["ucsf_document_id", "docId", "questionId", "id"]:
if safe_get(field):
fallback_parts.append(f"{field}:{p[field]}")
break
if fallback_parts:
identifier = "|".join(fallback_parts) + f"|idx:{i}"
else:
identifier = f"doc_{i}"
filepaths.append(identifier) filepaths.append(identifier)
images.append(p["page_image"]) # PIL Image
# Get image - try detected field first, then fallback to other common fields
img = None
if image_field in p and p[image_field] is not None:
img = p[image_field]
else:
# Fallback: try other common image field names
for fallback_field in ["image", "page_image", "images", "img"]:
if fallback_field in p and p[fallback_field] is not None:
img = p[fallback_field]
break
if img is None:
raise RuntimeError(
f"No image found for sample {i}. Available fields: {list(p.keys())}. "
f"Expected field: {image_field}"
)
# Ensure it's a PIL Image
if not isinstance(img, Image.Image):
if hasattr(img, 'convert'):
img = img.convert('RGB')
else:
img = Image.fromarray(img) if hasattr(img, '__array__') else Image.open(img)
images.append(img)
else:
_maybe_convert_pdf_to_images(PDF, PAGES_DIR)
filepaths, images = _load_images_from_dir(PAGES_DIR)
@@ -94,6 +448,19 @@ if need_to_build_index:
f"No images found in {PAGES_DIR}. Provide PDF path in PDF variable or ensure images exist."
)
print(f"Loaded {len(images)} images")
# Memory check before loading model
try:
import psutil
import torch
process = psutil.Process(os.getpid())
mem_info = process.memory_info()
print(f"Memory usage after loading images: {mem_info.rss / 1024 / 1024 / 1024:.2f} GB")
if torch.cuda.is_available():
print(f"GPU memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
print(f"GPU memory reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
except ImportError:
pass
else:
print("Skipping dataset loading (using existing index)")
filepaths = [] # Not needed when using existing index
@@ -102,36 +469,152 @@ else:
# %%
# Step 3: Load model and processor (only if we need to build index or perform search)
print("Step 3: Loading model and processor...")
print(f" Model: {MODEL}")
try:
import sys
print(f" Python version: {sys.version}")
print(f" Python executable: {sys.executable}")
model_name, model, processor, device_str, device, dtype = _load_colvision(MODEL)
print(f"Using model={model_name}, device={device_str}, dtype={dtype}")
# Memory check after loading model
try:
import psutil
import torch
process = psutil.Process(os.getpid())
mem_info = process.memory_info()
print(f" Memory usage after loading model: {mem_info.rss / 1024 / 1024 / 1024:.2f} GB")
if torch.cuda.is_available():
print(f" GPU memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
print(f" GPU memory reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
except ImportError:
pass
except Exception as e:
print(f"✗ Error loading model: {type(e).__name__}: {e}")
import traceback
traceback.print_exc()
raise
# %%
# %%
# Step 4: Build index if needed
if need_to_build_index and retriever is None:
print("Building index...")
if need_to_build_index:
print("Step 4: Building index...")
print(f" Number of images: {len(images)}")
print(f" Number of filepaths: {len(filepaths)}")
try:
print(" Embedding images...")
doc_vecs = _embed_images(model, processor, images)
print(f" Embedded {len(doc_vecs)} documents")
print(f" First doc vec shape: {doc_vecs[0].shape if len(doc_vecs) > 0 else 'N/A'}")
except Exception as e:
print(f"Error embedding images: {type(e).__name__}: {e}")
import traceback
traceback.print_exc()
raise
if USE_FAST_PLAID:
# Build Fast-Plaid index
print(" Building Fast-Plaid index...")
try:
fast_plaid_index, build_secs = _build_fast_plaid_index(
FAST_PLAID_INDEX_PATH, doc_vecs, filepaths, images
)
from pathlib import Path
print(f"✓ Fast-Plaid index built in {build_secs:.3f}s")
print(f"✓ Index saved to: {FAST_PLAID_INDEX_PATH}")
print(f"✓ Images saved to: {Path(FAST_PLAID_INDEX_PATH) / 'images'}")
except Exception as e:
print(f"Error building Fast-Plaid index: {type(e).__name__}: {e}")
import traceback
traceback.print_exc()
raise
finally:
# Clear memory
print(" Clearing memory...")
del images, filepaths, doc_vecs
else:
# Build original LEANN index
try:
retriever = _build_index(INDEX_PATH, doc_vecs, filepaths, images)
print(f"✓ Index built and images saved to: {retriever._images_dir_path()}")
except Exception as e:
print(f"Error building LEANN index: {type(e).__name__}: {e}")
import traceback
traceback.print_exc()
raise
finally:
# Clear memory
print(" Clearing memory...")
del images, filepaths, doc_vecs
# Note: Images are now stored in the index, retriever will load them on-demand from disk
# Note: Images are now stored separately, retriever/fast_plaid_index will reference them
# %%
# Step 5: Embed query and search
_t0 = time.perf_counter()
q_vec = _embed_queries(model, processor, [QUERY])[0]
results = retriever.search(q_vec.float().numpy(), topk=TOPK)
query_embed_secs = time.perf_counter() - _t0
print(f"[Search] Method: {SEARCH_METHOD}")
print(f"[Timing] Query embedding: {query_embed_secs:.3f}s")
# Run the selected search method and time it
if USE_FAST_PLAID:
# Fast-Plaid search
if fast_plaid_index is None:
fast_plaid_index = _load_fast_plaid_index_if_exists(FAST_PLAID_INDEX_PATH)
if fast_plaid_index is None:
raise RuntimeError(f"Fast-Plaid index not found at {FAST_PLAID_INDEX_PATH}")
results, search_secs = _search_fast_plaid(fast_plaid_index, q_vec, TOPK)
print(f"[Timing] Fast-Plaid Search: {search_secs:.3f}s")
else:
# Original LEANN search
query_np = q_vec.float().numpy()
_t0 = time.perf_counter() # Restart the timer so search timing excludes query embedding
if SEARCH_METHOD == "ann":
results = retriever.search(query_np, topk=TOPK, first_stage_k=FIRST_STAGE_K)
search_secs = time.perf_counter() - _t0
print(f"[Timing] Search (ANN): {search_secs:.3f}s (first_stage_k={FIRST_STAGE_K})")
elif SEARCH_METHOD == "exact":
results = retriever.search_exact(query_np, topk=TOPK, first_stage_k=FIRST_STAGE_K)
search_secs = time.perf_counter() - _t0
print(f"[Timing] Search (Exact rerank): {search_secs:.3f}s (first_stage_k={FIRST_STAGE_K})")
elif SEARCH_METHOD == "exact-all":
results = retriever.search_exact_all(query_np, topk=TOPK)
search_secs = time.perf_counter() - _t0
print(f"[Timing] Search (Exact all): {search_secs:.3f}s")
else:
results = []
if not results:
print("No results found.")
else:
print(f'Top {len(results)} results for query: "{QUERY}"')
print("\n[DEBUG] Retrieval details:")
top_images: list[Image.Image] = []
image_hashes = {} # Track image hashes to detect duplicates
for rank, (score, doc_id) in enumerate(results, start=1):
# Retrieve image from index instead of memory
# Retrieve image and metadata based on index type
if USE_FAST_PLAID:
# Fast-Plaid: load image and get metadata
image = _get_fast_plaid_image(FAST_PLAID_INDEX_PATH, doc_id)
if image is None:
print(f"Warning: Could not find image for doc_id {doc_id}")
continue
metadata = _get_fast_plaid_metadata(FAST_PLAID_INDEX_PATH, doc_id)
path = metadata.get("filepath", f"doc_{doc_id}") if metadata else f"doc_{doc_id}"
top_images.append(image)
else:
# Original LEANN: retrieve from retriever
image = retriever.get_image(doc_id)
if image is None:
print(f"Warning: Could not retrieve image for doc_id {doc_id}")
@@ -139,10 +622,29 @@ else:
metadata = retriever.get_metadata(doc_id)
path = metadata.get("filepath", "unknown") if metadata else "unknown"
# For HF dataset, path is a descriptive identifier, not a real file path
print(f"{rank}) MaxSim: {score:.4f}, Page: {path}")
top_images.append(image)
# Calculate image hash to detect duplicates
import hashlib
import io
# Convert image to bytes for hashing
img_bytes = io.BytesIO()
image.save(img_bytes, format='PNG')
image_bytes = img_bytes.getvalue()
image_hash = hashlib.md5(image_bytes).hexdigest()[:8]
# Check if this image was already seen
duplicate_info = ""
if image_hash in image_hashes:
duplicate_info = f" [DUPLICATE of rank {image_hashes[image_hash]}]"
else:
image_hashes[image_hash] = rank
# Print detailed information
print(f"{rank}) doc_id={doc_id}, MaxSim={score:.4f}, Page={path}, ImageHash={image_hash}{duplicate_info}")
if metadata:
print(f" Metadata: {metadata}")
if SAVE_TOP_IMAGE:
from pathlib import Path as _Path
@@ -161,7 +663,6 @@ else:
except Exception:
print(f"Saved retrieved page (rank {rank}) to: {out_path}")
## TODO strange results of second page of DeepSeek-V2 rather than the first page
# %%
# Step 6: Similarity maps for top-K results
@@ -204,6 +705,9 @@ if results and SIMILARITY_MAP:
# Step 7: Optional answer generation
if results and ANSWER:
qwen = QwenVL(device=device_str)
_t0 = time.perf_counter()
response = qwen.answer(QUERY, top_images[:TOPK], max_new_tokens=MAX_NEW_TOKENS)
gen_secs = time.perf_counter() - _t0
print(f"[Timing] Generation: {gen_secs:.3f}s")
print("\nAnswer:")
print(response)
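
Review note: the `[Timing]` lines in this file are measured with `time.perf_counter()` around each stage. One way to keep every stage on its own clock is a small context manager. The sketch below is illustrative only: `stage_timer`, the `timings` dict, and the stand-in workloads are hypothetical and not part of the script.

```python
import time
from contextlib import contextmanager


@contextmanager
def stage_timer(label, timings):
    """Store the wall-clock duration of one stage under `label` (hypothetical helper)."""
    start = time.perf_counter()
    try:
        yield
    finally:
        # Recorded even if the stage raises, so partial runs still report timings.
        timings[label] = time.perf_counter() - start


timings = {}

with stage_timer("query_embedding", timings):
    sum(range(10_000))  # stand-in for the query embedding call

with stage_timer("search", timings):
    sorted(range(10_000), reverse=True)  # stand-in for the search call

for label, secs in timings.items():
    print(f"[Timing] {label}: {secs:.3f}s")
```

Because each `with` block starts its own clock, a later stage cannot accidentally include the time spent in an earlier one.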


@@ -0,0 +1,448 @@
#!/usr/bin/env python3
"""
Modular script to reproduce NDCG results for ViDoRe v1 benchmark.
This script uses the interface from leann_multi_vector.py to:
1. Download ViDoRe v1 datasets
2. Build indexes (LEANN or Fast-Plaid)
3. Perform retrieval
4. Evaluate using NDCG metrics
Usage:
# Evaluate all ViDoRe v1 tasks
python vidore_v1_benchmark.py --model colqwen2 --tasks all
# Evaluate specific task
python vidore_v1_benchmark.py --model colqwen2 --task VidoreArxivQARetrieval
# Use Fast-Plaid index
python vidore_v1_benchmark.py --model colqwen2 --use-fast-plaid --fast-plaid-index-path ./indexes/vidore_fastplaid
# Rebuild index
python vidore_v1_benchmark.py --model colqwen2 --rebuild-index
"""
import argparse
import json
import os
from typing import Optional
from datasets import load_dataset
from leann_multi_vector import (
ViDoReBenchmarkEvaluator,
_ensure_repo_paths_importable,
)
_ensure_repo_paths_importable(__file__)
# ViDoRe v1 task configurations
# Prompts match MTEB task metadata prompts
VIDORE_V1_TASKS = {
"VidoreArxivQARetrieval": {
"dataset_path": "vidore/arxivqa_test_subsampled_beir",
"revision": "7d94d570960eac2408d3baa7a33f9de4822ae3e4",
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
"VidoreDocVQARetrieval": {
"dataset_path": "vidore/docvqa_test_subsampled_beir",
"revision": "162ba2fc1a8437eda8b6c37b240bc1c0f0deb092",
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
"VidoreInfoVQARetrieval": {
"dataset_path": "vidore/infovqa_test_subsampled_beir",
"revision": "b802cc5fd6c605df2d673a963667d74881d2c9a4",
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
"VidoreTabfquadRetrieval": {
"dataset_path": "vidore/tabfquad_test_subsampled_beir",
"revision": "61a2224bcd29b7b261a4892ff4c8bea353527a31",
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
"VidoreTatdqaRetrieval": {
"dataset_path": "vidore/tatdqa_test_beir",
"revision": "5feb5630fdff4d8d189ffedb2dba56862fdd45c0",
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
"VidoreShiftProjectRetrieval": {
"dataset_path": "vidore/shiftproject_test_beir",
"revision": "84a382e05c4473fed9cff2bbae95fe2379416117",
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
"VidoreSyntheticDocQAAIRetrieval": {
"dataset_path": "vidore/syntheticDocQA_artificial_intelligence_test_beir",
"revision": "2d9ebea5a1c6e9ef4a3b902a612f605dca11261c",
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
"VidoreSyntheticDocQAEnergyRetrieval": {
"dataset_path": "vidore/syntheticDocQA_energy_test_beir",
"revision": "9935aadbad5c8deec30910489db1b2c7133ae7a7",
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
"VidoreSyntheticDocQAGovernmentReportsRetrieval": {
"dataset_path": "vidore/syntheticDocQA_government_reports_test_beir",
"revision": "b4909afa930f81282fd20601e860668073ad02aa",
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
"VidoreSyntheticDocQAHealthcareIndustryRetrieval": {
"dataset_path": "vidore/syntheticDocQA_healthcare_industry_test_beir",
"revision": "f9e25d5b6e13e1ad9f5c3cce202565031b3ab164",
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
}
# Task name aliases (short names -> full names)
TASK_ALIASES = {
"arxivqa": "VidoreArxivQARetrieval",
"docvqa": "VidoreDocVQARetrieval",
"infovqa": "VidoreInfoVQARetrieval",
"tabfquad": "VidoreTabfquadRetrieval",
"tatdqa": "VidoreTatdqaRetrieval",
"shiftproject": "VidoreShiftProjectRetrieval",
"syntheticdocqa_ai": "VidoreSyntheticDocQAAIRetrieval",
"syntheticdocqa_energy": "VidoreSyntheticDocQAEnergyRetrieval",
"syntheticdocqa_government": "VidoreSyntheticDocQAGovernmentReportsRetrieval",
"syntheticdocqa_healthcare": "VidoreSyntheticDocQAHealthcareIndustryRetrieval",
}
def normalize_task_name(task_name: str) -> str:
"""Normalize task name (handle aliases)."""
task_name_lower = task_name.lower()
if task_name in VIDORE_V1_TASKS:
return task_name
if task_name_lower in TASK_ALIASES:
return TASK_ALIASES[task_name_lower]
# Try partial match
for alias, full_name in TASK_ALIASES.items():
if alias in task_name_lower or task_name_lower in alias:
return full_name
return task_name
def get_safe_model_name(model_name: str) -> str:
"""Get a safe model name for use in file paths."""
import hashlib
import os
# If it's a path, use basename or hash
if os.path.exists(model_name) and os.path.isdir(model_name):
# Use basename if it's reasonable, otherwise use hash
basename = os.path.basename(model_name.rstrip("/"))
if basename and len(basename) < 100 and not basename.startswith("."):
return basename
# Use hash for very long or problematic paths
return hashlib.md5(model_name.encode()).hexdigest()[:16]
# For HuggingFace model names, replace / with _
return model_name.replace("/", "_").replace(":", "_")
def load_vidore_v1_data(
dataset_path: str,
revision: Optional[str] = None,
split: str = "test",
):
"""
Load ViDoRe v1 dataset.
Returns:
corpus: dict mapping corpus_id to PIL Image
queries: dict mapping query_id to query text
qrels: dict mapping query_id to dict of {corpus_id: relevance_score}
"""
print(f"Loading dataset: {dataset_path} (split={split})")
# Load queries
query_ds = load_dataset(dataset_path, "queries", split=split, revision=revision)
queries = {}
for row in query_ds:
query_id = f"query-{split}-{row['query-id']}"
queries[query_id] = row["query"]
# Load corpus (images)
corpus_ds = load_dataset(dataset_path, "corpus", split=split, revision=revision)
corpus = {}
for row in corpus_ds:
corpus_id = f"corpus-{split}-{row['corpus-id']}"
# Extract image from the dataset row
if "image" in row:
corpus[corpus_id] = row["image"]
elif "page_image" in row:
corpus[corpus_id] = row["page_image"]
else:
raise ValueError(
f"No image field found in corpus. Available fields: {list(row.keys())}"
)
# Load qrels (relevance judgments)
qrels_ds = load_dataset(dataset_path, "qrels", split=split, revision=revision)
qrels = {}
for row in qrels_ds:
query_id = f"query-{split}-{row['query-id']}"
corpus_id = f"corpus-{split}-{row['corpus-id']}"
if query_id not in qrels:
qrels[query_id] = {}
qrels[query_id][corpus_id] = int(row["score"])
print(
f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings"
)
# Filter qrels to only include queries that exist
qrels = {qid: rel_docs for qid, rel_docs in qrels.items() if qid in queries}
# Filter out queries without any relevant documents (matching MTEB behavior)
# This is important for correct NDCG calculation
qrels_filtered = {qid: rel_docs for qid, rel_docs in qrels.items() if len(rel_docs) > 0}
queries_filtered = {
qid: query_text for qid, query_text in queries.items() if qid in qrels_filtered
}
print(
f"After filtering queries without positives: {len(queries_filtered)} queries, {len(qrels_filtered)} query-relevance mappings"
)
return corpus, queries_filtered, qrels_filtered
def evaluate_task(
task_name: str,
model_name: str,
index_path: str,
use_fast_plaid: bool = False,
fast_plaid_index_path: Optional[str] = None,
rebuild_index: bool = False,
top_k: int = 1000,
first_stage_k: int = 500,
k_values: Optional[list[int]] = None,
output_dir: Optional[str] = None,
):
"""
Evaluate a single ViDoRe v1 task.
"""
print(f"\n{'=' * 80}")
print(f"Evaluating task: {task_name}")
print(f"{'=' * 80}")
# Normalize task name (handle aliases)
task_name = normalize_task_name(task_name)
# Get task config
if task_name not in VIDORE_V1_TASKS:
raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V1_TASKS.keys())}")
task_config = VIDORE_V1_TASKS[task_name]
dataset_path = task_config["dataset_path"]
revision = task_config["revision"]
# Load data
corpus, queries, qrels = load_vidore_v1_data(
dataset_path=dataset_path,
revision=revision,
split="test",
)
# Initialize k_values if not provided
if k_values is None:
k_values = [1, 3, 5, 10, 20, 100, 1000]
# Check if we have any queries
if len(queries) == 0:
print(f"\nWarning: No queries found for task {task_name}. Skipping evaluation.")
# Return zero scores
scores = {}
for k in k_values:
scores[f"ndcg_at_{k}"] = 0.0
scores[f"map_at_{k}"] = 0.0
scores[f"recall_at_{k}"] = 0.0
scores[f"precision_at_{k}"] = 0.0
scores[f"mrr_at_{k}"] = 0.0
return scores
# Initialize evaluator
evaluator = ViDoReBenchmarkEvaluator(
model_name=model_name,
use_fast_plaid=use_fast_plaid,
top_k=top_k,
first_stage_k=first_stage_k,
k_values=k_values,
)
# Build or load index
# Use safe model name for index path (different models need different indexes)
safe_model_name = get_safe_model_name(model_name)
index_path_full = index_path if not use_fast_plaid else fast_plaid_index_path
if index_path_full is None:
index_path_full = f"./indexes/{task_name}_{safe_model_name}"
if use_fast_plaid:
index_path_full = f"./indexes/{task_name}_{safe_model_name}_fastplaid"
index_or_retriever, corpus_ids_ordered = evaluator.build_index_from_corpus(
corpus=corpus,
index_path=index_path_full,
rebuild=rebuild_index,
)
# Search queries
task_prompt = task_config.get("prompt")
results = evaluator.search_queries(
queries=queries,
corpus_ids=corpus_ids_ordered,
index_or_retriever=index_or_retriever,
fast_plaid_index_path=fast_plaid_index_path,
task_prompt=task_prompt,
)
# Evaluate
scores = evaluator.evaluate_results(results, qrels, k_values=k_values)
# Print results
print(f"\n{'=' * 80}")
print(f"Results for {task_name}:")
print(f"{'=' * 80}")
for metric, value in scores.items():
if isinstance(value, (int, float)):
print(f" {metric}: {value:.5f}")
# Save results
if output_dir:
os.makedirs(output_dir, exist_ok=True)
results_file = os.path.join(output_dir, f"{task_name}_results.json")
scores_file = os.path.join(output_dir, f"{task_name}_scores.json")
with open(results_file, "w") as f:
json.dump(results, f, indent=2)
print(f"\nSaved results to: {results_file}")
with open(scores_file, "w") as f:
json.dump(scores, f, indent=2)
print(f"Saved scores to: {scores_file}")
return scores
def main():
parser = argparse.ArgumentParser(
description="Evaluate ViDoRe v1 benchmark using LEANN/Fast-Plaid indexing"
)
parser.add_argument(
"--model",
type=str,
default="colqwen2",
help="Model to use: 'colqwen2', 'colpali', or path to a model directory (supports LoRA adapters)",
)
parser.add_argument(
"--task",
type=str,
default=None,
help="Specific task to evaluate (overrides --tasks)",
)
parser.add_argument(
"--tasks",
type=str,
default="all",
help="Tasks to evaluate: 'all' or comma-separated list",
)
parser.add_argument(
"--index-path",
type=str,
default=None,
help="Path to LEANN index (auto-generated if not provided)",
)
parser.add_argument(
"--use-fast-plaid",
action="store_true",
help="Use Fast-Plaid instead of LEANN",
)
parser.add_argument(
"--fast-plaid-index-path",
type=str,
default=None,
help="Path to Fast-Plaid index (auto-generated if not provided)",
)
parser.add_argument(
"--rebuild-index",
action="store_true",
help="Rebuild index even if it exists",
)
parser.add_argument(
"--top-k",
type=int,
default=1000,
help="Top-k results to retrieve (MTEB default is max(k_values)=1000)",
)
parser.add_argument(
"--first-stage-k",
type=int,
default=500,
help="First stage k for LEANN search",
)
parser.add_argument(
"--k-values",
type=str,
default="1,3,5,10,20,100,1000",
help="Comma-separated k values for evaluation (e.g., '1,3,5,10,100')",
)
parser.add_argument(
"--output-dir",
type=str,
default="./vidore_v1_results",
help="Output directory for results",
)
args = parser.parse_args()
# Parse k_values
k_values = [int(k.strip()) for k in args.k_values.split(",")]
# Determine tasks to evaluate
if args.task:
tasks_to_eval = [normalize_task_name(args.task)]
elif args.tasks.lower() == "all":
tasks_to_eval = list(VIDORE_V1_TASKS.keys())
else:
tasks_to_eval = [normalize_task_name(t.strip()) for t in args.tasks.split(",")]
print(f"Tasks to evaluate: {tasks_to_eval}")
# Evaluate each task
all_scores = {}
for task_name in tasks_to_eval:
try:
scores = evaluate_task(
task_name=task_name,
model_name=args.model,
index_path=args.index_path,
use_fast_plaid=args.use_fast_plaid,
fast_plaid_index_path=args.fast_plaid_index_path,
rebuild_index=args.rebuild_index,
top_k=args.top_k,
first_stage_k=args.first_stage_k,
k_values=k_values,
output_dir=args.output_dir,
)
all_scores[task_name] = scores
except Exception as e:
print(f"\nError evaluating {task_name}: {e}")
import traceback
traceback.print_exc()
continue
# Print summary
if all_scores:
print(f"\n{'=' * 80}")
print("SUMMARY")
print(f"{'=' * 80}")
for task_name, scores in all_scores.items():
print(f"\n{task_name}:")
# Print main metrics
for metric in ["ndcg_at_5", "ndcg_at_10", "ndcg_at_100", "map_at_10", "recall_at_10"]:
if metric in scores:
print(f" {metric}: {scores[metric]:.5f}")
if __name__ == "__main__":
main()
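
Review note: the script reports `ndcg_at_k` values, but the metric code lives in `ViDoReBenchmarkEvaluator.evaluate_results`, which this diff does not show. As a reference point, here is a minimal sketch of the textbook linear-gain NDCG@k for a single query, using the same `qrels` shape that `load_vidore_v1_data` returns. The function names and the toy ranking are assumptions for illustration, and the evaluator's actual implementation may differ in details such as tie handling.

```python
import math


def dcg_at_k(relevances, k):
    """Discounted cumulative gain over the first k relevance scores (linear gain)."""
    return sum(rel / math.log2(rank + 1) for rank, rel in enumerate(relevances[:k], start=1))


def ndcg_at_k(ranked_ids, rel_docs, k):
    """NDCG@k for one query.

    ranked_ids: corpus ids in retrieved order.
    rel_docs: dict mapping corpus_id -> relevance score (one entry of qrels).
    """
    gains = [rel_docs.get(doc_id, 0) for doc_id in ranked_ids]
    ideal = sorted(rel_docs.values(), reverse=True)
    idcg = dcg_at_k(ideal, k)
    return dcg_at_k(gains, k) / idcg if idcg > 0 else 0.0


# Toy data: one query whose single relevant page is retrieved at rank 2.
qrels = {"query-test-1": {"corpus-test-7": 1}}
ranking = ["corpus-test-3", "corpus-test-7", "corpus-test-9"]

score = ndcg_at_k(ranking, qrels["query-test-1"], k=5)
print(f"ndcg_at_5: {score:.5f}")
```

A query with no relevant documents has an ideal DCG of zero, so the sketch returns 0.0 for it. That is one reason the loader drops such queries before evaluation: they would otherwise pull the average down without reflecting retrieval quality.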


@@ -0,0 +1,439 @@
#!/usr/bin/env python3
"""
Modular script to reproduce NDCG results for ViDoRe v2 benchmark.
This script uses the interface from leann_multi_vector.py to:
1. Download ViDoRe v2 datasets
2. Build indexes (LEANN or Fast-Plaid)
3. Perform retrieval
4. Evaluate using NDCG metrics
Usage:
# Evaluate all ViDoRe v2 tasks
python vidore_v2_benchmark.py --model colqwen2 --tasks all
# Evaluate specific task
python vidore_v2_benchmark.py --model colqwen2 --task Vidore2ESGReportsRetrieval
# Use Fast-Plaid index
python vidore_v2_benchmark.py --model colqwen2 --use-fast-plaid --fast-plaid-index-path ./indexes/vidore_fastplaid
# Rebuild index
python vidore_v2_benchmark.py --model colqwen2 --rebuild-index
"""
import argparse
import json
import os
from typing import Optional
from datasets import load_dataset
from leann_multi_vector import (
ViDoReBenchmarkEvaluator,
_ensure_repo_paths_importable,
)
_ensure_repo_paths_importable(__file__)
# Language name to dataset language field value mapping
# Dataset uses ISO 639-3 + ISO 15924 format (e.g., "eng-Latn")
LANGUAGE_MAPPING = {
"english": "eng-Latn",
"french": "fra-Latn",
"spanish": "spa-Latn",
"german": "deu-Latn",
}
# ViDoRe v2 task configurations
# Prompts match MTEB task metadata prompts
VIDORE_V2_TASKS = {
"Vidore2ESGReportsRetrieval": {
"dataset_path": "vidore/esg_reports_v2",
"revision": "0542c0d03da0ec1c8cbc517c8d78e7e95c75d3d3",
"languages": ["french", "spanish", "english", "german"],
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
"Vidore2EconomicsReportsRetrieval": {
"dataset_path": "vidore/economics_reports_v2",
"revision": "b3e3a04b07fbbaffe79be49dabf92f691fbca252",
"languages": ["french", "spanish", "english", "german"],
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
"Vidore2BioMedicalLecturesRetrieval": {
"dataset_path": "vidore/biomedical_lectures_v2",
"revision": "a29202f0da409034d651614d87cd8938d254e2ea",
"languages": ["french", "spanish", "english", "german"],
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
"Vidore2ESGReportsHLRetrieval": {
"dataset_path": "vidore/esg_reports_human_labeled_v2",
"revision": "6d467dedb09a75144ede1421747e47cf036857dd",
# Note: This dataset doesn't have language filtering - all queries are English
"languages": None, # No language filtering needed
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
},
}
def load_vidore_v2_data(
dataset_path: str,
revision: Optional[str] = None,
split: str = "test",
language: Optional[str] = None,
):
"""
Load ViDoRe v2 dataset.
Returns:
corpus: dict mapping corpus_id to PIL Image
queries: dict mapping query_id to query text
qrels: dict mapping query_id to dict of {corpus_id: relevance_score}
"""
print(f"Loading dataset: {dataset_path} (split={split}, language={language})")
# Load queries
query_ds = load_dataset(dataset_path, "queries", split=split, revision=revision)
# Check if dataset has language field before filtering
has_language_field = len(query_ds) > 0 and "language" in query_ds.column_names
if language and has_language_field:
# Map language name to dataset language field value (e.g., "english" -> "eng-Latn")
dataset_language = LANGUAGE_MAPPING.get(language, language)
query_ds_filtered = query_ds.filter(lambda x: x.get("language") == dataset_language)
# Check if filtering resulted in empty dataset
if len(query_ds_filtered) == 0:
print(
f"Warning: No queries found after filtering by language '{language}' (mapped to '{dataset_language}')."
)
# Try with original language value (dataset might use simple names like 'english')
print(f"Trying with original language value '{language}'...")
query_ds_filtered = query_ds.filter(lambda x: x.get("language") == language)
if len(query_ds_filtered) == 0:
# Try to get a sample to see actual language values
try:
sample_ds = load_dataset(
dataset_path, "queries", split=split, revision=revision
)
if len(sample_ds) > 0 and "language" in sample_ds.column_names:
sample_langs = set(sample_ds["language"])
print(f"Available language values in dataset: {sample_langs}")
except Exception:
pass
else:
print(
f"Found {len(query_ds_filtered)} queries using original language value '{language}'"
)
query_ds = query_ds_filtered
queries = {}
for row in query_ds:
query_id = f"query-{split}-{row['query-id']}"
queries[query_id] = row["query"]
# Load corpus (images)
corpus_ds = load_dataset(dataset_path, "corpus", split=split, revision=revision)
corpus = {}
for row in corpus_ds:
corpus_id = f"corpus-{split}-{row['corpus-id']}"
# Extract image from the dataset row
if "image" in row:
corpus[corpus_id] = row["image"]
elif "page_image" in row:
corpus[corpus_id] = row["page_image"]
else:
raise ValueError(
f"No image field found in corpus. Available fields: {list(row.keys())}"
)
# Load qrels (relevance judgments)
qrels_ds = load_dataset(dataset_path, "qrels", split=split, revision=revision)
qrels = {}
for row in qrels_ds:
query_id = f"query-{split}-{row['query-id']}"
corpus_id = f"corpus-{split}-{row['corpus-id']}"
if query_id not in qrels:
qrels[query_id] = {}
qrels[query_id][corpus_id] = int(row["score"])
print(
f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings"
)
# Filter qrels to only include queries that exist
qrels = {qid: rel_docs for qid, rel_docs in qrels.items() if qid in queries}
# Filter out queries without any relevant documents (matching MTEB behavior)
# This is important for correct NDCG calculation
qrels_filtered = {qid: rel_docs for qid, rel_docs in qrels.items() if len(rel_docs) > 0}
queries_filtered = {
qid: query_text for qid, query_text in queries.items() if qid in qrels_filtered
}
print(
f"After filtering queries without positives: {len(queries_filtered)} queries, {len(qrels_filtered)} query-relevance mappings"
)
return corpus, queries_filtered, qrels_filtered
def evaluate_task(
task_name: str,
model_name: str,
index_path: str,
use_fast_plaid: bool = False,
fast_plaid_index_path: Optional[str] = None,
language: Optional[str] = None,
rebuild_index: bool = False,
top_k: int = 100,
first_stage_k: int = 500,
k_values: Optional[list[int]] = None,
output_dir: Optional[str] = None,
):
"""
Evaluate a single ViDoRe v2 task.
"""
print(f"\n{'=' * 80}")
print(f"Evaluating task: {task_name}")
print(f"{'=' * 80}")
# Get task config
if task_name not in VIDORE_V2_TASKS:
raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V2_TASKS.keys())}")
task_config = VIDORE_V2_TASKS[task_name]
dataset_path = task_config["dataset_path"]
revision = task_config["revision"]
# Determine language
if language is None:
# Default: use the task's only language if it lists exactly one; otherwise keep None (no language filtering)
languages = task_config.get("languages")
if languages is None:
# Task doesn't support language filtering (e.g., Vidore2ESGReportsHLRetrieval)
language = None
elif len(languages) == 1:
language = languages[0]
else:
language = None
# Initialize k_values if not provided
if k_values is None:
k_values = [1, 3, 5, 10, 100]
# Load data
corpus, queries, qrels = load_vidore_v2_data(
dataset_path=dataset_path,
revision=revision,
split="test",
language=language,
)
# Check if we have any queries
if len(queries) == 0:
print(
f"\nWarning: No queries found for task {task_name} with language {language}. Skipping evaluation."
)
# Return zero scores
scores = {}
for k in k_values:
scores[f"ndcg_at_{k}"] = 0.0
scores[f"map_at_{k}"] = 0.0
scores[f"recall_at_{k}"] = 0.0
scores[f"precision_at_{k}"] = 0.0
scores[f"mrr_at_{k}"] = 0.0
return scores
# Initialize evaluator
evaluator = ViDoReBenchmarkEvaluator(
model_name=model_name,
use_fast_plaid=use_fast_plaid,
top_k=top_k,
first_stage_k=first_stage_k,
k_values=k_values,
)
# Build or load index
index_path_full = index_path if not use_fast_plaid else fast_plaid_index_path
if index_path_full is None:
index_path_full = f"./indexes/{task_name}_{model_name}"
if use_fast_plaid:
index_path_full = f"./indexes/{task_name}_{model_name}_fastplaid"
index_or_retriever, corpus_ids_ordered = evaluator.build_index_from_corpus(
corpus=corpus,
index_path=index_path_full,
rebuild=rebuild_index,
)
# Search queries
task_prompt = task_config.get("prompt")
results = evaluator.search_queries(
queries=queries,
corpus_ids=corpus_ids_ordered,
index_or_retriever=index_or_retriever,
fast_plaid_index_path=fast_plaid_index_path,
task_prompt=task_prompt,
)
# Evaluate
scores = evaluator.evaluate_results(results, qrels, k_values=k_values)
# Print results
print(f"\n{'=' * 80}")
print(f"Results for {task_name}:")
print(f"{'=' * 80}")
for metric, value in scores.items():
if isinstance(value, (int, float)):
print(f" {metric}: {value:.5f}")
# Save results
if output_dir:
os.makedirs(output_dir, exist_ok=True)
results_file = os.path.join(output_dir, f"{task_name}_results.json")
scores_file = os.path.join(output_dir, f"{task_name}_scores.json")
with open(results_file, "w") as f:
json.dump(results, f, indent=2)
print(f"\nSaved results to: {results_file}")
with open(scores_file, "w") as f:
json.dump(scores, f, indent=2)
print(f"Saved scores to: {scores_file}")
return scores
def main():
parser = argparse.ArgumentParser(
description="Evaluate ViDoRe v2 benchmark using LEANN/Fast-Plaid indexing"
)
parser.add_argument(
"--model",
type=str,
default="colqwen2",
choices=["colqwen2", "colpali"],
help="Model to use",
)
parser.add_argument(
"--task",
type=str,
default=None,
help="Specific task to evaluate (overrides --tasks)",
)
parser.add_argument(
"--tasks",
type=str,
default="all",
help="Tasks to evaluate: 'all' or comma-separated list",
)
parser.add_argument(
"--index-path",
type=str,
default=None,
help="Path to LEANN index (auto-generated if not provided)",
)
parser.add_argument(
"--use-fast-plaid",
action="store_true",
help="Use Fast-Plaid instead of LEANN",
)
parser.add_argument(
"--fast-plaid-index-path",
type=str,
default=None,
help="Path to Fast-Plaid index (auto-generated if not provided)",
)
parser.add_argument(
"--rebuild-index",
action="store_true",
help="Rebuild index even if it exists",
)
parser.add_argument(
"--language",
type=str,
default=None,
help="Language to evaluate (default: first available)",
)
parser.add_argument(
"--top-k",
type=int,
default=100,
help="Top-k results to retrieve",
)
parser.add_argument(
"--first-stage-k",
type=int,
default=500,
help="First stage k for LEANN search",
)
parser.add_argument(
"--k-values",
type=str,
default="1,3,5,10,100",
help="Comma-separated k values for evaluation (e.g., '1,3,5,10,100')",
)
parser.add_argument(
"--output-dir",
type=str,
default="./vidore_v2_results",
help="Output directory for results",
)
args = parser.parse_args()
# Parse k_values
k_values = [int(k.strip()) for k in args.k_values.split(",")]
# Determine tasks to evaluate
if args.task:
tasks_to_eval = [args.task]
elif args.tasks.lower() == "all":
tasks_to_eval = list(VIDORE_V2_TASKS.keys())
else:
tasks_to_eval = [t.strip() for t in args.tasks.split(",")]
print(f"Tasks to evaluate: {tasks_to_eval}")
# Evaluate each task
all_scores = {}
for task_name in tasks_to_eval:
try:
scores = evaluate_task(
task_name=task_name,
model_name=args.model,
index_path=args.index_path,
use_fast_plaid=args.use_fast_plaid,
fast_plaid_index_path=args.fast_plaid_index_path,
language=args.language,
rebuild_index=args.rebuild_index,
top_k=args.top_k,
first_stage_k=args.first_stage_k,
k_values=k_values,
output_dir=args.output_dir,
)
all_scores[task_name] = scores
except Exception as e:
print(f"\nError evaluating {task_name}: {e}")
import traceback
traceback.print_exc()
continue
# Print summary
if all_scores:
print(f"\n{'=' * 80}")
print("SUMMARY")
print(f"{'=' * 80}")
for task_name, scores in all_scores.items():
print(f"\n{task_name}:")
# Print main metrics
for metric in ["ndcg_at_5", "ndcg_at_10", "ndcg_at_100", "map_at_10", "recall_at_10"]:
if metric in scores:
print(f" {metric}: {scores[metric]:.5f}")
if __name__ == "__main__":
main()
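
Review note: the argument handling in `main()` accepts any string for `--k-values` and any task name, so a typo only surfaces once `evaluate_task` fails. A minimal sketch of stricter parsing, assuming hypothetical helpers `parse_k_values` and `select_tasks` (neither exists in the script) and a `known` collection that would come from `VIDORE_V2_TASKS`:

```python
from typing import Iterable, Optional


def parse_k_values(raw: str) -> list[int]:
    """Parse '1,3,5' into a sorted list of unique positive integers."""
    values = {int(part) for part in raw.split(",") if part.strip()}
    if not values or min(values) < 1:
        raise ValueError(f"k values must be positive integers, got {raw!r}")
    return sorted(values)


def select_tasks(task: Optional[str], tasks: str, known: Iterable[str]) -> list[str]:
    """Mirror the CLI precedence: --task wins, then --tasks ('all' or a list)."""
    known = list(known)
    if task:
        chosen = [task]
    elif tasks.strip().lower() == "all":
        chosen = known
    else:
        chosen = [t.strip() for t in tasks.split(",") if t.strip()]
    # Fail early on names that are not in the task registry.
    unknown = [t for t in chosen if t not in known]
    if unknown:
        raise ValueError(f"Unknown task(s): {unknown}")
    return chosen
```

With helpers like these, `main()` could reject bad input before any index is built; whether that is worth the extra code is a judgment call for the script's maintainers.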

200
docs/COLQWEN_GUIDE.md Normal file

@@ -0,0 +1,200 @@
# ColQwen Integration Guide
Easy-to-use multimodal PDF retrieval with ColQwen2/ColPali models.
## Quick Start
> **🍎 Mac Users**: ColQwen runs on Apple Silicon with MPS acceleration, which is faster than CPU-only inference.
### 1. Install Dependencies
```bash
uv pip install colpali_engine pdf2image pillow matplotlib qwen_vl_utils einops seaborn
brew install poppler # macOS only, for PDF processing
```
### 2. Basic Usage
```bash
# Build index from PDFs
python -m apps.colqwen_rag build --pdfs ./my_papers/ --index research_papers
# Search with text queries
python -m apps.colqwen_rag search research_papers "How does attention mechanism work?"
# Interactive Q&A
python -m apps.colqwen_rag ask research_papers --interactive
```
## Commands
### Build Index
```bash
python -m apps.colqwen_rag build \
--pdfs ./pdf_directory/ \
--index my_index \
--model colqwen2 \
--pages-dir ./page_images/ # Optional: save page images
```
**Options:**
- `--pdfs`: Directory containing PDF files (or single PDF path)
- `--index`: Name for the index (required)
- `--model`: `colqwen2` (default) or `colpali`
- `--pages-dir`: Directory to save page images (optional)
### Search Index
```bash
python -m apps.colqwen_rag search my_index "your question here" --top-k 5
```
**Options:**
- `--top-k`: Number of results to return (default: 5)
- `--model`: Model used for search (should match build model)
### Interactive Q&A
```bash
python -m apps.colqwen_rag ask my_index --interactive
```
**Commands in interactive mode:**
- Type your questions naturally
- `help`: Show available commands
- `quit`/`exit`/`q`: Exit interactive mode
## 🧪 Test & Reproduce Results
Run the reproduction test for issue #119:
```bash
python test_colqwen_reproduction.py
```
This will:
1. ✅ Check dependencies
2. 📥 Download sample PDF (Attention Is All You Need paper)
3. 🏗️ Build test index
4. 🔍 Run sample queries
5. 📊 Show how to generate similarity maps
## 🎨 Advanced: Similarity Maps
For visual similarity analysis, use the existing advanced script:
```bash
cd apps/multimodal/vision-based-pdf-multi-vector/
python multi-vector-leann-similarity-map.py
```
Edit the script to customize:
- `QUERY`: Your question
- `MODEL`: "colqwen2" or "colpali"
- `USE_HF_DATASET`: Use HuggingFace dataset or local PDFs
- `SIMILARITY_MAP`: Generate heatmaps
- `ANSWER`: Enable Qwen-VL answer generation
## 🔧 How It Works
### ColQwen2 vs ColPali
- **ColQwen2** (`vidore/colqwen2-v1.0`): Latest vision-language model
- **ColPali** (`vidore/colpali-v1.2`): Proven multimodal retriever
### Architecture
1. **PDF → Images**: Convert PDF pages to images (150 DPI)
2. **Vision Encoding**: Process images with ColQwen2/ColPali
3. **Multi-Vector Index**: Build LEANN HNSW index with multiple embeddings per page
4. **Query Processing**: Encode text queries with same model
5. **Similarity Search**: Find most relevant pages/regions
6. **Visual Maps**: Generate attention heatmaps (optional)
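
The "multiple embeddings per page" in step 3 and the similarity search in step 5 rest on late-interaction (MaxSim) scoring, the scheme used by ColBERT-style models such as ColPali. The sketch below shows only that scoring rule on plain Python lists; it is not the LEANN implementation, and `maxsim_score` and `rank_pages` are illustrative names rather than functions from this repository.

```python
from typing import Mapping, Sequence

Vector = Sequence[float]


def dot(a: Vector, b: Vector) -> float:
    """Dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def maxsim_score(query_vecs: Sequence[Vector], page_vecs: Sequence[Vector]) -> float:
    """For each query vector take its best-matching page vector, then sum the maxima."""
    return sum(max(dot(q, p) for p in page_vecs) for q in query_vecs)


def rank_pages(
    query_vecs: Sequence[Vector],
    pages: Mapping[str, Sequence[Vector]],
    top_k: int = 5,
) -> list[tuple[str, float]]:
    """Score every page against the query and return the top_k (page_id, score) pairs."""
    scored = [(page_id, maxsim_score(query_vecs, vecs)) for page_id, vecs in pages.items()]
    scored.sort(key=lambda item: item[1], reverse=True)
    return scored[:top_k]


if __name__ == "__main__":
    query = [[1.0, 0.0], [0.0, 1.0]]
    pages = {"page_a": [[1.0, 0.0], [0.0, 1.0]], "page_b": [[1.0, 0.0]]}
    print(rank_pages(query, pages))
```

Scoring every page exhaustively like this is only practical for small collections; an approximate index such as HNSW would typically be used to narrow the candidates first.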
### Device Support
- **CUDA**: Best performance with GPU acceleration
- **MPS**: Apple Silicon Mac support
- **CPU**: Fallback for any system (slower)
Auto-detection: CUDA > MPS > CPU
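
The CUDA > MPS > CPU priority can be sketched as follows. This is an illustration of the ordering stated above, not the code used by `apps.colqwen_rag`; `pick_device` and `detect_device` are hypothetical names, and the sketch assumes PyTorch 1.12 or later for `torch.backends.mps`.

```python
def pick_device(cuda_available: bool, mps_available: bool) -> str:
    """Apply the priority order CUDA > MPS > CPU."""
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


def detect_device() -> str:
    """Query PyTorch for available backends; fall back to CPU if torch is missing."""
    try:
        import torch
    except ImportError:
        return "cpu"
    # torch.backends.mps is absent on older PyTorch builds, so look it up defensively.
    mps = getattr(torch.backends, "mps", None)
    mps_available = bool(mps is not None and mps.is_available())
    return pick_device(torch.cuda.is_available(), mps_available)


if __name__ == "__main__":
    print(f"Selected device: {detect_device()}")
```

Keeping the priority rule in a separate pure function makes it testable on machines without a GPU.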
## 📊 Performance Tips
### For Best Performance:
```bash
# Use ColQwen2 for latest features
--model colqwen2
# Save page images for reuse
--pages-dir ./cached_pages/
# Adjust batch size based on GPU memory
# (automatically handled)
```
### For Large Document Sets:
- Process PDFs in batches
- Use SSD storage for index files
- Consider using CUDA if available
## 🔗 Related Resources
- **Fast-PLAID**: https://github.com/lightonai/fast-plaid
- **PyLate**: https://github.com/lightonai/pylate
- **ColBERT**: https://github.com/stanford-futuredata/ColBERT
- **ColPali Paper**: ColPali: Efficient Document Retrieval with Vision Language Models
- **Issue #119**: https://github.com/yichuan-w/LEANN/issues/119
## 🐛 Troubleshooting
### PDF Conversion Issues (macOS)
```bash
# Install poppler
brew install poppler
# Verify that the poppler tools are on your PATH
which pdfinfo && pdfinfo -v
```
### Memory Issues
- Reduce batch size (automatically handled)
- Use CPU instead of GPU: `export CUDA_VISIBLE_DEVICES=""`
- Process fewer PDFs at once
### Model Download Issues
- Ensure internet connection for first run
- Models are cached after first download
- Use HuggingFace mirrors if needed
### Import Errors
```bash
# Ensure all dependencies installed
uv pip install colpali_engine pdf2image pillow matplotlib qwen_vl_utils einops seaborn
# Check PyTorch installation
python -c "import torch; print(torch.__version__)"
```
## 💡 Examples
### Research Paper Analysis
```bash
# Index your research papers
python -m apps.colqwen_rag build --pdfs ~/Papers/AI/ --index ai_papers
# Ask research questions
python -m apps.colqwen_rag search ai_papers "What are the limitations of transformer models?"
python -m apps.colqwen_rag search ai_papers "How does BERT compare to GPT?"
```
### Document Q&A
```bash
# Index business documents
python -m apps.colqwen_rag build --pdfs ~/Documents/Reports/ --index reports
# Interactive analysis
python -m apps.colqwen_rag ask reports --interactive
```
### Visual Analysis
```bash
# Generate similarity maps for specific queries
cd apps/multimodal/vision-based-pdf-multi-vector/
# Edit multi-vector-leann-similarity-map.py with your query
python multi-vector-leann-similarity-map.py
# Check ./figures/ for generated heatmaps
```
---
**🎯 This integration makes ColQwen as easy to use as other LEANN features while maintaining the full power of multimodal document understanding!**


@@ -454,7 +454,7 @@ leann search my-index "your query" \
 ### 2) Run remote builds with SkyPilot (cloud GPU)
-Offload embedding generation and index building to a GPU VM using [SkyPilot](https://skypilot.readthedocs.io/en/latest/). A template is provided at `sky/leann-build.yaml`.
+Offload embedding generation and index building to a GPU VM using [SkyPilot](https://docs.skypilot.co/en/latest/docs/index.html). A template is provided at `sky/leann-build.yaml`.
```bash ```bash
 # One-time: install and configure SkyPilot


@@ -7,7 +7,7 @@ name = "leann-core"
 version = "0.3.5"
 description = "Core API and plugin system for LEANN"
 readme = "README.md"
-requires-python = ">=3.9"
+requires-python = ">=3.10"
 license = { text = "MIT" }
 # All required dependencies included


@@ -1251,15 +1251,15 @@ class LeannChat:
             "Please provide the best answer you can based on this context and your knowledge."
         )
-        print("The context provided to the LLM is:")
-        print(f"{'Relevance':<10} | {'Chunk id':<10} | {'Content':<60} | {'Source':<80}")
-        print("-" * 150)
+        logger.info("The context provided to the LLM is:")
+        logger.info(f"{'Relevance':<10} | {'Chunk id':<10} | {'Content':<60} | {'Source':<80}")
+        logger.info("-" * 150)
         for r in results:
             chunk_relevance = f"{r.score:.3f}"
             chunk_id = r.id
             chunk_content = r.text[:60]
             chunk_source = r.metadata.get("source", "")[:80]
-            print(
+            logger.info(
                 f"{chunk_relevance:<10} | {chunk_id:<10} | {chunk_content:<60} | {chunk_source:<80}"
             )
         ask_time = time.time()


@@ -12,7 +12,13 @@ from typing import Any, Optional
 import torch
-from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
+from .settings import (
+    resolve_anthropic_api_key,
+    resolve_anthropic_base_url,
+    resolve_ollama_host,
+    resolve_openai_api_key,
+    resolve_openai_base_url,
+)
 # Configure logging
 logging.basicConfig(level=logging.INFO)
@@ -845,6 +851,81 @@ class OpenAIChat(LLMInterface):
             return f"Error: Could not get a response from OpenAI. Details: {e}"
+class AnthropicChat(LLMInterface):
+    """LLM interface for Anthropic Claude models."""
+    def __init__(
+        self,
+        model: str = "claude-haiku-4-5",
+        api_key: Optional[str] = None,
+        base_url: Optional[str] = None,
+    ):
+        self.model = model
+        self.base_url = resolve_anthropic_base_url(base_url)
+        self.api_key = resolve_anthropic_api_key(api_key)
+        if not self.api_key:
+            raise ValueError(
+                "Anthropic API key is required. Set ANTHROPIC_API_KEY environment variable or pass api_key parameter."
+            )
+        logger.info(
+            "Initializing Anthropic Chat with model='%s' and base_url='%s'",
+            model,
+            self.base_url,
+        )
+        try:
+            import anthropic
+            # Allow custom Anthropic-compatible endpoints via base_url
+            self.client = anthropic.Anthropic(
+                api_key=self.api_key,
+                base_url=self.base_url,
+            )
+        except ImportError:
+            raise ImportError(
+                "The 'anthropic' library is required for Anthropic models. Please install it with 'pip install anthropic'."
+            )
+    def ask(self, prompt: str, **kwargs) -> str:
+        logger.info(f"Sending request to Anthropic with model {self.model}")
+        try:
+            # Anthropic API parameters
+            params = {
+                "model": self.model,
+                "max_tokens": kwargs.get("max_tokens", 1000),
+                "messages": [{"role": "user", "content": prompt}],
+            }
+            # Add optional parameters
+            if "temperature" in kwargs:
+                params["temperature"] = kwargs["temperature"]
+            if "top_p" in kwargs:
+                params["top_p"] = kwargs["top_p"]
+            response = self.client.messages.create(**params)
+            # Extract text from response
+            response_text = response.content[0].text
+            # Log token usage
+            print(
+                f"Total tokens = {response.usage.input_tokens + response.usage.output_tokens}, "
+                f"input tokens = {response.usage.input_tokens}, "
+                f"output tokens = {response.usage.output_tokens}"
+            )
+            if response.stop_reason == "max_tokens":
+                print("The query is exceeding the maximum allowed number of tokens")
+            return response_text.strip()
+        except Exception as e:
+            logger.error(f"Error communicating with Anthropic: {e}")
+            return f"Error: Could not get a response from Anthropic. Details: {e}"
 class SimulatedChat(LLMInterface):
     """A simple simulated chat for testing and development."""
@@ -897,6 +978,12 @@ def get_llm(llm_config: Optional[dict[str, Any]] = None) -> LLMInterface:
         )
     elif llm_type == "gemini":
         return GeminiChat(model=model or "gemini-2.5-flash", api_key=llm_config.get("api_key"))
+    elif llm_type == "anthropic":
+        return AnthropicChat(
+            model=model or "claude-3-5-sonnet-20241022",
+            api_key=llm_config.get("api_key"),
+            base_url=llm_config.get("base_url"),
+        )
     elif llm_type == "simulated":
         return SimulatedChat()
     else:


@@ -11,7 +11,12 @@ from tqdm import tqdm
 from .api import LeannBuilder, LeannChat, LeannSearcher
 from .interactive_utils import create_cli_session
 from .registry import register_project_directory
-from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
+from .settings import (
+    resolve_anthropic_base_url,
+    resolve_ollama_host,
+    resolve_openai_api_key,
+    resolve_openai_base_url,
+)
 def extract_pdf_text_with_pymupdf(file_path: str) -> str:
@@ -291,7 +296,7 @@ Examples:
             "--llm",
             type=str,
             default="ollama",
-            choices=["simulated", "ollama", "hf", "openai"],
+            choices=["simulated", "ollama", "hf", "openai", "anthropic"],
             help="LLM provider (default: ollama)",
         )
         ask_parser.add_argument(
@@ -341,7 +346,7 @@ Examples:
             "--api-key",
             type=str,
             default=None,
-            help="API key for OpenAI-compatible APIs (defaults to OPENAI_API_KEY)",
+            help="API key for cloud LLM providers (OpenAI, Anthropic)",
         )
         # List command
@@ -1616,6 +1621,12 @@ Examples:
                 resolved_api_key = resolve_openai_api_key(args.api_key)
                 if resolved_api_key:
                     llm_config["api_key"] = resolved_api_key
+            elif args.llm == "anthropic":
+                # For Anthropic, pass base_url and API key if provided
+                if args.api_base:
+                    llm_config["base_url"] = resolve_anthropic_base_url(args.api_base)
+                if args.api_key:
+                    llm_config["api_key"] = args.api_key
             chat = LeannChat(index_path=index_path, llm_config=llm_config)


@@ -9,6 +9,7 @@ from typing import Any
 # Default fallbacks to preserve current behaviour while keeping them in one place.
 _DEFAULT_OLLAMA_HOST = "http://localhost:11434"
 _DEFAULT_OPENAI_BASE_URL = "https://api.openai.com/v1"
+_DEFAULT_ANTHROPIC_BASE_URL = "https://api.anthropic.com"
 def _clean_url(value: str) -> str:
@@ -52,6 +53,23 @@ def resolve_openai_base_url(explicit: str | None = None) -> str:
     return _clean_url(_DEFAULT_OPENAI_BASE_URL)
+def resolve_anthropic_base_url(explicit: str | None = None) -> str:
+    """Resolve the base URL for Anthropic-compatible services."""
+    candidates = (
+        explicit,
+        os.getenv("LEANN_ANTHROPIC_BASE_URL"),
+        os.getenv("ANTHROPIC_BASE_URL"),
+        os.getenv("LOCAL_ANTHROPIC_BASE_URL"),
+    )
+    for candidate in candidates:
+        if candidate:
+            return _clean_url(candidate)
+    return _clean_url(_DEFAULT_ANTHROPIC_BASE_URL)
 def resolve_openai_api_key(explicit: str | None = None) -> str | None:
     """Resolve the API key for OpenAI-compatible services."""
@@ -61,6 +79,15 @@ def resolve_openai_api_key(explicit: str | None = None) -> str | None:
     return os.getenv("OPENAI_API_KEY")
+def resolve_anthropic_api_key(explicit: str | None = None) -> str | None:
+    """Resolve the API key for Anthropic services."""
+    if explicit:
+        return explicit
+    return os.getenv("ANTHROPIC_API_KEY")
 def encode_provider_options(options: dict[str, Any] | None) -> str | None:
     """Serialize provider options for child processes."""


@@ -53,6 +53,11 @@ leann build my-project --docs $(git ls-files)
 # Start Claude Code
 claude
``` ```
+**Performance tip**: For maximum speed when storage space is not a concern, add the `--no-recompute` flag to your build command. This materializes all tensors and stores them on disk, avoiding recomputation on subsequent builds:
+```bash
+leann build my-project --docs $(git ls-files) --no-recompute
+```
 ## 🚀 Advanced Usage Examples to build the index


@@ -7,7 +7,7 @@ name = "leann"
 version = "0.3.5"
 description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!"
 readme = "README.md"
-requires-python = ">=3.9"
+requires-python = ">=3.10"
 license = { text = "MIT" }
 authors = [
     { name = "LEANN Team" }
@@ -18,10 +18,10 @@ classifiers = [
     "Intended Audience :: Developers",
     "License :: OSI Approved :: MIT License",
     "Programming Language :: Python :: 3",
-    "Programming Language :: Python :: 3.9",
     "Programming Language :: Python :: 3.10",
     "Programming Language :: Python :: 3.11",
     "Programming Language :: Python :: 3.12",
+    "Programming Language :: Python :: 3.13",
 ]
# Default installation: core + hnsw + diskann # Default installation: core + hnsw + diskann


@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "leann-workspace"
 version = "0.1.0"
-requires-python = ">=3.9"
+requires-python = ">=3.10"
 dependencies = [
     "leann-core",

1163
uv.lock generated

File diff suppressed because it is too large