Compare commits
17 Commits
feat/add-c
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
198044d033 | ||
|
|
a2e5f5294b | ||
|
|
8a2ea37871 | ||
|
|
7ddb4772c0 | ||
|
|
a1c21adbce | ||
|
|
d1b3c93a5a | ||
|
|
a6ee95b18a | ||
|
|
17cbd07b25 | ||
|
|
3629ccf8f7 | ||
|
|
0175bc9c20 | ||
|
|
af47dfdde7 | ||
|
|
f13bd02fbd | ||
|
|
a0bbf831db | ||
|
|
86287d8832 | ||
|
|
13beb98164 | ||
|
|
9b7353f336 | ||
|
|
9dd0e0b26f |
118
.github/workflows/build-reusable.yml
vendored
118
.github/workflows/build-reusable.yml
vendored
@@ -28,15 +28,36 @@ jobs:
|
||||
run: |
|
||||
uv run --only-group lint pre-commit run --all-files --show-diff-on-failure
|
||||
|
||||
type-check:
|
||||
name: Type Check with ty
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ inputs.ref }}
|
||||
submodules: recursive
|
||||
|
||||
- name: Install uv and Python
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
- name: Install ty
|
||||
run: uv tool install ty
|
||||
|
||||
- name: Run ty type checker
|
||||
run: |
|
||||
# Run ty on core packages, apps, and tests
|
||||
ty check packages/leann-core/src apps tests
|
||||
|
||||
build:
|
||||
needs: lint
|
||||
needs: [lint, type-check]
|
||||
name: Build ${{ matrix.os }} Python ${{ matrix.python }}
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- os: ubuntu-22.04
|
||||
python: '3.9'
|
||||
# Note: Python 3.9 dropped - uses PEP 604 union syntax (str | None)
|
||||
# which requires Python 3.10+
|
||||
- os: ubuntu-22.04
|
||||
python: '3.10'
|
||||
- os: ubuntu-22.04
|
||||
@@ -46,8 +67,6 @@ jobs:
|
||||
- os: ubuntu-22.04
|
||||
python: '3.13'
|
||||
# ARM64 Linux builds
|
||||
- os: ubuntu-24.04-arm
|
||||
python: '3.9'
|
||||
- os: ubuntu-24.04-arm
|
||||
python: '3.10'
|
||||
- os: ubuntu-24.04-arm
|
||||
@@ -56,8 +75,6 @@ jobs:
|
||||
python: '3.12'
|
||||
- os: ubuntu-24.04-arm
|
||||
python: '3.13'
|
||||
- os: macos-14
|
||||
python: '3.9'
|
||||
- os: macos-14
|
||||
python: '3.10'
|
||||
- os: macos-14
|
||||
@@ -66,8 +83,6 @@ jobs:
|
||||
python: '3.12'
|
||||
- os: macos-14
|
||||
python: '3.13'
|
||||
- os: macos-15
|
||||
python: '3.9'
|
||||
- os: macos-15
|
||||
python: '3.10'
|
||||
- os: macos-15
|
||||
@@ -76,16 +91,24 @@ jobs:
|
||||
python: '3.12'
|
||||
- os: macos-15
|
||||
python: '3.13'
|
||||
- os: macos-13
|
||||
python: '3.9'
|
||||
- os: macos-13
|
||||
# Intel Mac builds (x86_64) - replaces deprecated macos-13
|
||||
# Note: Python 3.13 excluded - PyTorch has no wheels for macOS x86_64 + Python 3.13
|
||||
# (PyTorch <=2.4.1 lacks cp313, PyTorch >=2.5.0 dropped Intel Mac support)
|
||||
- os: macos-15-intel
|
||||
python: '3.10'
|
||||
- os: macos-13
|
||||
- os: macos-15-intel
|
||||
python: '3.11'
|
||||
- os: macos-13
|
||||
- os: macos-15-intel
|
||||
python: '3.12'
|
||||
# Note: macos-13 + Python 3.13 excluded due to PyTorch compatibility
|
||||
# (PyTorch 2.5+ supports Python 3.13 but not Intel Mac x86_64)
|
||||
# macOS 26 (beta) - arm64
|
||||
- os: macos-26
|
||||
python: '3.10'
|
||||
- os: macos-26
|
||||
python: '3.11'
|
||||
- os: macos-26
|
||||
python: '3.12'
|
||||
- os: macos-26
|
||||
python: '3.13'
|
||||
runs-on: ${{ matrix.os }}
|
||||
|
||||
steps:
|
||||
@@ -204,13 +227,16 @@ jobs:
|
||||
# Use system clang for better compatibility
|
||||
export CC=clang
|
||||
export CXX=clang++
|
||||
# Homebrew libraries on each macOS version require matching minimum version
|
||||
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=13.0
|
||||
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=14.0
|
||||
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
|
||||
# Set deployment target based on runner
|
||||
# macos-15-intel runs macOS 15, so target 15.0 (system libraries require it)
|
||||
if [[ "${{ matrix.os }}" == "macos-15-intel" ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=15.0
|
||||
elif [[ "${{ matrix.os }}" == macos-14* ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=14.0
|
||||
elif [[ "${{ matrix.os }}" == macos-15* ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=15.0
|
||||
elif [[ "${{ matrix.os }}" == macos-26* ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=26.0
|
||||
fi
|
||||
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
|
||||
else
|
||||
@@ -224,14 +250,16 @@ jobs:
|
||||
# Use system clang for better compatibility
|
||||
export CC=clang
|
||||
export CXX=clang++
|
||||
# DiskANN requires macOS 13.3+ for sgesdd_ LAPACK function
|
||||
# But Homebrew libraries on each macOS version require matching minimum version
|
||||
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=13.3
|
||||
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=14.0
|
||||
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
|
||||
# Set deployment target based on runner
|
||||
# macos-15-intel runs macOS 15, so target 15.0 (system libraries require it)
|
||||
if [[ "${{ matrix.os }}" == "macos-15-intel" ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=15.0
|
||||
elif [[ "${{ matrix.os }}" == macos-14* ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=14.0
|
||||
elif [[ "${{ matrix.os }}" == macos-15* ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=15.0
|
||||
elif [[ "${{ matrix.os }}" == macos-26* ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=26.0
|
||||
fi
|
||||
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
|
||||
else
|
||||
@@ -269,16 +297,19 @@ jobs:
|
||||
if: runner.os == 'macOS'
|
||||
run: |
|
||||
# Determine deployment target based on runner OS
|
||||
# Must match the Homebrew libraries for each macOS version
|
||||
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
|
||||
HNSW_TARGET="13.0"
|
||||
DISKANN_TARGET="13.3"
|
||||
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
|
||||
HNSW_TARGET="14.0"
|
||||
DISKANN_TARGET="14.0"
|
||||
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
|
||||
# macos-15-intel runs macOS 15, so target 15.0 (system libraries require it)
|
||||
if [[ "${{ matrix.os }}" == "macos-15-intel" ]]; then
|
||||
HNSW_TARGET="15.0"
|
||||
DISKANN_TARGET="15.0"
|
||||
elif [[ "${{ matrix.os }}" == macos-14* ]]; then
|
||||
HNSW_TARGET="14.0"
|
||||
DISKANN_TARGET="14.0"
|
||||
elif [[ "${{ matrix.os }}" == macos-15* ]]; then
|
||||
HNSW_TARGET="15.0"
|
||||
DISKANN_TARGET="15.0"
|
||||
elif [[ "${{ matrix.os }}" == macos-26* ]]; then
|
||||
HNSW_TARGET="26.0"
|
||||
DISKANN_TARGET="26.0"
|
||||
fi
|
||||
|
||||
# Repair HNSW wheel
|
||||
@@ -334,12 +365,15 @@ jobs:
|
||||
PY_TAG=$($UV_PY -c "import sys; print(f'cp{sys.version_info[0]}{sys.version_info[1]}')")
|
||||
|
||||
if [[ "$RUNNER_OS" == "macOS" ]]; then
|
||||
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=13.3
|
||||
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=14.0
|
||||
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
|
||||
# macos-15-intel runs macOS 15, so target 15.0 (system libraries require it)
|
||||
if [[ "${{ matrix.os }}" == "macos-15-intel" ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=15.0
|
||||
elif [[ "${{ matrix.os }}" == macos-14* ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=14.0
|
||||
elif [[ "${{ matrix.os }}" == macos-15* ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=15.0
|
||||
elif [[ "${{ matrix.os }}" == macos-26* ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=26.0
|
||||
fi
|
||||
fi
|
||||
|
||||
|
||||
63
README.md
63
README.md
@@ -36,7 +36,7 @@ LEANN is an innovative vector database that democratizes personal AI. Transform
|
||||
|
||||
LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration Fig →](#️-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
|
||||
|
||||
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can semantic search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)** ([WeChat](#-wechat-detective-unlock-your-golden-memories), [iMessage](#-imessage-history-your-personal-conversation-archive)), **[agent memory](#-chatgpt-chat-history-your-personal-ai-conversation-archive)** ([ChatGPT](#-chatgpt-chat-history-your-personal-ai-conversation-archive), [Claude](#-claude-chat-history-your-personal-ai-conversation-archive)), **[live data](#mcp-integration-rag-on-live-data-from-any-platform)** ([Slack](#mcp-integration-rag-on-live-data-from-any-platform), [Twitter](#mcp-integration-rag-on-live-data-from-any-platform)), **[codebase](#-claude-code-integration-transform-your-development-workflow)**\* , or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
|
||||
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can semantic search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)** ([WeChat](#-wechat-detective-unlock-your-golden-memories), [iMessage](#-imessage-history-your-personal-conversation-archive)), **[agent memory](#-chatgpt-chat-history-your-personal-ai-conversation-archive)** ([ChatGPT](#-chatgpt-chat-history-your-personal-ai-conversation-archive), [Claude](#-claude-chat-history-your-personal-ai-conversation-archive)), **[live data](#mcp-integration-rag-on-live-data-from-any-platform)** ([Slack](#slack-messages-search-your-team-conversations), [Twitter](#-twitter-bookmarks-your-personal-tweet-library)), **[codebase](#-claude-code-integration-transform-your-development-workflow)**\* , or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
|
||||
|
||||
|
||||
\* Claude Code only supports basic `grep`-style keyword search. **LEANN** is a drop-in **semantic search MCP service fully compatible with Claude Code**, unlocking intelligent retrieval without changing your workflow. 🔥 Check out [the easy setup →](packages/leann-mcp/README.md)
|
||||
@@ -201,7 +201,7 @@ LEANN supports RAG on various data sources including documents (`.pdf`, `.txt`,
|
||||
|
||||
#### LLM Backend
|
||||
|
||||
LEANN supports many LLM providers for text generation (HuggingFace, Ollama, and Any OpenAI compatible API).
|
||||
LEANN supports many LLM providers for text generation (HuggingFace, Ollama, Anthropic, and Any OpenAI compatible API).
|
||||
|
||||
|
||||
<details>
|
||||
@@ -269,6 +269,7 @@ Below is a list of base URLs for common providers to get you started.
|
||||
| **SiliconFlow** | `https://api.siliconflow.cn/v1` |
|
||||
| **Zhipu (BigModel)** | `https://open.bigmodel.cn/api/paas/v4/` |
|
||||
| **Mistral AI** | `https://api.mistral.ai/v1` |
|
||||
| **Anthropic** | `https://api.anthropic.com/v1` |
|
||||
|
||||
|
||||
|
||||
@@ -328,7 +329,7 @@ All RAG examples share these common parameters. **Interactive mode** is availabl
|
||||
--embedding-mode MODE # sentence-transformers, openai, mlx, or ollama
|
||||
|
||||
# LLM Parameters (Text generation models)
|
||||
--llm TYPE # LLM backend: openai, ollama, or hf (default: openai)
|
||||
--llm TYPE # LLM backend: openai, ollama, hf, or anthropic (default: openai)
|
||||
--llm-model MODEL # Model name (default: gpt-4o) e.g., gpt-4o-mini, llama3.2:1b, Qwen/Qwen2.5-1.5B-Instruct
|
||||
--thinking-budget LEVEL # Thinking budget for reasoning models: low/medium/high (supported by o3, o3-mini, GPT-Oss:20b, and other reasoning models)
|
||||
|
||||
@@ -391,6 +392,54 @@ python -m apps.code_rag --repo-dir "./my_codebase" --query "How does authenticat
|
||||
|
||||
</details>
|
||||
|
||||
### 🎨 ColQwen: Multimodal PDF Retrieval with Vision-Language Models
|
||||
|
||||
Search through PDFs using both text and visual understanding with ColQwen2/ColPali models. Perfect for research papers, technical documents, and any PDFs with complex layouts, figures, or diagrams.
|
||||
|
||||
> **🍎 Mac Users**: ColQwen is optimized for Apple Silicon with MPS acceleration for faster inference!
|
||||
|
||||
```bash
|
||||
# Build index from PDFs
|
||||
python -m apps.colqwen_rag build --pdfs ./my_papers/ --index research_papers
|
||||
|
||||
# Search with text queries
|
||||
python -m apps.colqwen_rag search research_papers "How does attention mechanism work?"
|
||||
|
||||
# Interactive Q&A
|
||||
python -m apps.colqwen_rag ask research_papers --interactive
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: ColQwen Setup & Usage</strong></summary>
|
||||
|
||||
#### Prerequisites
|
||||
```bash
|
||||
# Install dependencies
|
||||
uv pip install colpali_engine pdf2image pillow matplotlib qwen_vl_utils einops seaborn
|
||||
brew install poppler # macOS only, for PDF processing
|
||||
```
|
||||
|
||||
#### Build Index
|
||||
```bash
|
||||
python -m apps.colqwen_rag build \
|
||||
--pdfs ./pdf_directory/ \
|
||||
--index my_index \
|
||||
--model colqwen2 # or colpali
|
||||
```
|
||||
|
||||
#### Search
|
||||
```bash
|
||||
python -m apps.colqwen_rag search my_index "your question here" --top-k 5
|
||||
```
|
||||
|
||||
#### Models
|
||||
- **ColQwen2** (`colqwen2`): Latest vision-language model with improved performance
|
||||
- **ColPali** (`colpali`): Proven multimodal retriever
|
||||
|
||||
For detailed usage, see the [ColQwen Guide](docs/COLQWEN_GUIDE.md).
|
||||
|
||||
</details>
|
||||
|
||||
### 📧 Your Personal Email Secretary: RAG on Apple Mail!
|
||||
|
||||
> **Note:** The examples below currently support macOS only. Windows support coming soon.
|
||||
@@ -1057,10 +1106,10 @@ Options:
|
||||
leann ask INDEX_NAME [OPTIONS]
|
||||
|
||||
Options:
|
||||
--llm {ollama,openai,hf} LLM provider (default: ollama)
|
||||
--model MODEL Model name (default: qwen3:8b)
|
||||
--interactive Interactive chat mode
|
||||
--top-k N Retrieval count (default: 20)
|
||||
--llm {ollama,openai,hf,anthropic} LLM provider (default: ollama)
|
||||
--model MODEL Model name (default: qwen3:8b)
|
||||
--interactive Interactive chat mode
|
||||
--top-k N Retrieval count (default: 20)
|
||||
```
|
||||
|
||||
**List Command:**
|
||||
|
||||
@@ -257,8 +257,8 @@ class BaseRAGExample(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def load_data(self, args) -> list[str]:
|
||||
"""Load data from the source. Returns list of text chunks."""
|
||||
async def load_data(self, args) -> list[dict[str, Any]]:
|
||||
"""Load data from the source. Returns list of text chunks as dicts with 'text' and 'metadata' keys."""
|
||||
pass
|
||||
|
||||
def get_llm_config(self, args) -> dict[str, Any]:
|
||||
@@ -282,8 +282,8 @@ class BaseRAGExample(ABC):
|
||||
|
||||
return config
|
||||
|
||||
async def build_index(self, args, texts: list[str]) -> str:
|
||||
"""Build LEANN index from texts."""
|
||||
async def build_index(self, args, texts: list[dict[str, Any]]) -> str:
|
||||
"""Build LEANN index from text chunks (dicts with 'text' and 'metadata' keys)."""
|
||||
index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")
|
||||
|
||||
print(f"\n[Building Index] Creating {self.name} index...")
|
||||
@@ -314,8 +314,14 @@ class BaseRAGExample(ABC):
|
||||
batch_size = 1000
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch = texts[i : i + batch_size]
|
||||
for text in batch:
|
||||
builder.add_text(text)
|
||||
for item in batch:
|
||||
# Handle both dict format (from create_text_chunks) and plain strings
|
||||
if isinstance(item, dict):
|
||||
text = item.get("text", "")
|
||||
metadata = item.get("metadata")
|
||||
builder.add_text(text, metadata)
|
||||
else:
|
||||
builder.add_text(item)
|
||||
print(f"Added {min(i + batch_size, len(texts))}/{len(texts)} texts...")
|
||||
|
||||
print("Building index structure...")
|
||||
|
||||
@@ -6,6 +6,7 @@ Supports Chrome browser history.
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
@@ -85,7 +86,7 @@ class BrowserRAG(BaseRAGExample):
|
||||
|
||||
return profiles
|
||||
|
||||
async def load_data(self, args) -> list[str]:
|
||||
async def load_data(self, args) -> list[dict[str, Any]]:
|
||||
"""Load browser history and convert to text chunks."""
|
||||
# Determine Chrome profiles
|
||||
if args.chrome_profile and not args.auto_find_profiles:
|
||||
|
||||
@@ -5,6 +5,7 @@ Supports ChatGPT export data from chat.html files.
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
@@ -80,7 +81,7 @@ class ChatGPTRAG(BaseRAGExample):
|
||||
|
||||
return export_files
|
||||
|
||||
async def load_data(self, args) -> list[str]:
|
||||
async def load_data(self, args) -> list[dict[str, Any]]:
|
||||
"""Load ChatGPT export data and convert to text chunks."""
|
||||
export_path = Path(args.export_path)
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ Supports Claude export data from JSON files.
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
@@ -80,7 +81,7 @@ class ClaudeRAG(BaseRAGExample):
|
||||
|
||||
return export_files
|
||||
|
||||
async def load_data(self, args) -> list[str]:
|
||||
async def load_data(self, args) -> list[dict[str, Any]]:
|
||||
"""Load Claude export data and convert to text chunks."""
|
||||
export_path = Path(args.export_path)
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ optimized chunking parameters.
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
@@ -77,7 +78,7 @@ class CodeRAG(BaseRAGExample):
|
||||
help="Try to preserve import statements in chunks (default: True)",
|
||||
)
|
||||
|
||||
async def load_data(self, args) -> list[str]:
|
||||
async def load_data(self, args) -> list[dict[str, Any]]:
|
||||
"""Load code files and convert to AST-aware chunks."""
|
||||
print(f"🔍 Scanning code repository: {args.repo_dir}")
|
||||
print(f"📁 Including extensions: {args.include_extensions}")
|
||||
@@ -88,14 +89,6 @@ class CodeRAG(BaseRAGExample):
|
||||
if not repo_path.exists():
|
||||
raise ValueError(f"Repository directory not found: {args.repo_dir}")
|
||||
|
||||
# Load code files with filtering
|
||||
reader_kwargs = {
|
||||
"recursive": True,
|
||||
"encoding": "utf-8",
|
||||
"required_exts": args.include_extensions,
|
||||
"exclude_hidden": True,
|
||||
}
|
||||
|
||||
# Create exclusion filter
|
||||
def file_filter(file_path: str) -> bool:
|
||||
"""Filter out unwanted files and directories."""
|
||||
@@ -120,8 +113,11 @@ class CodeRAG(BaseRAGExample):
|
||||
# Load documents with file filtering
|
||||
documents = SimpleDirectoryReader(
|
||||
args.repo_dir,
|
||||
file_extractor=None, # Use default extractors
|
||||
**reader_kwargs,
|
||||
file_extractor=None,
|
||||
recursive=True,
|
||||
encoding="utf-8",
|
||||
required_exts=args.include_extensions,
|
||||
exclude_hidden=True,
|
||||
).load_data(show_progress=True)
|
||||
|
||||
# Apply custom filtering
|
||||
|
||||
364
apps/colqwen_rag.py
Normal file
364
apps/colqwen_rag.py
Normal file
@@ -0,0 +1,364 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
ColQwen RAG - Easy-to-use multimodal PDF retrieval with ColQwen2/ColPali
|
||||
|
||||
Usage:
|
||||
python -m apps.colqwen_rag build --pdfs ./my_pdfs/ --index my_index
|
||||
python -m apps.colqwen_rag search my_index "How does attention work?"
|
||||
python -m apps.colqwen_rag ask my_index --interactive
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional, cast
|
||||
|
||||
# Add LEANN packages to path
|
||||
_repo_root = Path(__file__).resolve().parents[1]
|
||||
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
|
||||
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
|
||||
if str(_leann_core_src) not in sys.path:
|
||||
sys.path.append(str(_leann_core_src))
|
||||
if str(_leann_hnsw_pkg) not in sys.path:
|
||||
sys.path.append(str(_leann_hnsw_pkg))
|
||||
|
||||
import torch # noqa: E402
|
||||
from colpali_engine import ColPali, ColPaliProcessor, ColQwen2, ColQwen2Processor # noqa: E402
|
||||
from colpali_engine.utils.torch_utils import ListDataset # noqa: E402
|
||||
from pdf2image import convert_from_path # noqa: E402
|
||||
from PIL import Image # noqa: E402
|
||||
from torch.utils.data import DataLoader # noqa: E402
|
||||
from tqdm import tqdm # noqa: E402
|
||||
|
||||
# Import the existing multi-vector implementation
|
||||
sys.path.append(str(_repo_root / "apps" / "multimodal" / "vision-based-pdf-multi-vector"))
|
||||
from leann_multi_vector import LeannMultiVector # noqa: E402
|
||||
|
||||
|
||||
class ColQwenRAG:
|
||||
"""Easy-to-use ColQwen RAG system for multimodal PDF retrieval."""
|
||||
|
||||
def __init__(self, model_type: str = "colpali"):
|
||||
"""
|
||||
Initialize ColQwen RAG system.
|
||||
|
||||
Args:
|
||||
model_type: "colqwen2" or "colpali"
|
||||
"""
|
||||
self.model_type = model_type
|
||||
self.device = self._get_device()
|
||||
# Use float32 on MPS to avoid memory issues, float16 on CUDA, bfloat16 on CPU
|
||||
if self.device.type == "mps":
|
||||
self.dtype = torch.float32
|
||||
elif self.device.type == "cuda":
|
||||
self.dtype = torch.float16
|
||||
else:
|
||||
self.dtype = torch.bfloat16
|
||||
|
||||
print(f"🚀 Initializing {model_type.upper()} on {self.device} with {self.dtype}")
|
||||
|
||||
# Load model and processor with MPS-optimized settings
|
||||
try:
|
||||
if model_type == "colqwen2":
|
||||
self.model_name = "vidore/colqwen2-v1.0"
|
||||
if self.device.type == "mps":
|
||||
# For MPS, load on CPU first then move to avoid memory allocation issues
|
||||
self.model = ColQwen2.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=self.dtype,
|
||||
device_map="cpu",
|
||||
low_cpu_mem_usage=True,
|
||||
).eval()
|
||||
self.model = self.model.to(self.device)
|
||||
else:
|
||||
self.model = ColQwen2.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=self.dtype,
|
||||
device_map=self.device,
|
||||
low_cpu_mem_usage=True,
|
||||
).eval()
|
||||
self.processor = ColQwen2Processor.from_pretrained(self.model_name)
|
||||
else: # colpali
|
||||
self.model_name = "vidore/colpali-v1.2"
|
||||
if self.device.type == "mps":
|
||||
# For MPS, load on CPU first then move to avoid memory allocation issues
|
||||
self.model = ColPali.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=self.dtype,
|
||||
device_map="cpu",
|
||||
low_cpu_mem_usage=True,
|
||||
).eval()
|
||||
self.model = self.model.to(self.device)
|
||||
else:
|
||||
self.model = ColPali.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=self.dtype,
|
||||
device_map=self.device,
|
||||
low_cpu_mem_usage=True,
|
||||
).eval()
|
||||
self.processor = ColPaliProcessor.from_pretrained(self.model_name)
|
||||
except Exception as e:
|
||||
if "memory" in str(e).lower() or "offload" in str(e).lower():
|
||||
print(f"⚠️ Memory constraint on {self.device}, using CPU with optimizations...")
|
||||
self.device = torch.device("cpu")
|
||||
self.dtype = torch.float32
|
||||
|
||||
if model_type == "colqwen2":
|
||||
self.model = ColQwen2.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=self.dtype,
|
||||
device_map="cpu",
|
||||
low_cpu_mem_usage=True,
|
||||
).eval()
|
||||
else:
|
||||
self.model = ColPali.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=self.dtype,
|
||||
device_map="cpu",
|
||||
low_cpu_mem_usage=True,
|
||||
).eval()
|
||||
else:
|
||||
raise
|
||||
|
||||
def _get_device(self):
|
||||
"""Auto-select best available device."""
|
||||
if torch.cuda.is_available():
|
||||
return torch.device("cuda")
|
||||
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||
return torch.device("mps")
|
||||
else:
|
||||
return torch.device("cpu")
|
||||
|
||||
def build_index(self, pdf_paths: list[str], index_name: str, pages_dir: Optional[str] = None):
|
||||
"""
|
||||
Build multimodal index from PDF files.
|
||||
|
||||
Args:
|
||||
pdf_paths: List of PDF file paths
|
||||
index_name: Name for the index
|
||||
pages_dir: Directory to save page images (optional)
|
||||
"""
|
||||
print(f"Building index '{index_name}' from {len(pdf_paths)} PDFs...")
|
||||
|
||||
# Convert PDFs to images
|
||||
all_images = []
|
||||
all_metadata = []
|
||||
|
||||
if pages_dir:
|
||||
os.makedirs(pages_dir, exist_ok=True)
|
||||
|
||||
for pdf_path in tqdm(pdf_paths, desc="Converting PDFs"):
|
||||
try:
|
||||
images = convert_from_path(pdf_path, dpi=150)
|
||||
pdf_name = Path(pdf_path).stem
|
||||
|
||||
for i, image in enumerate(images):
|
||||
# Save image if pages_dir specified
|
||||
if pages_dir:
|
||||
image_path = Path(pages_dir) / f"{pdf_name}_page_{i + 1}.png"
|
||||
image.save(image_path)
|
||||
|
||||
all_images.append(image)
|
||||
all_metadata.append(
|
||||
{
|
||||
"pdf_path": pdf_path,
|
||||
"pdf_name": pdf_name,
|
||||
"page_number": i + 1,
|
||||
"image_path": str(image_path) if pages_dir else None,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error processing {pdf_path}: {e}")
|
||||
continue
|
||||
|
||||
print(f"📄 Converted {len(all_images)} pages from {len(pdf_paths)} PDFs")
|
||||
print(f"All metadata: {all_metadata}")
|
||||
|
||||
# Generate embeddings
|
||||
print("🧠 Generating embeddings...")
|
||||
embeddings = self._embed_images(all_images)
|
||||
|
||||
# Build LEANN index
|
||||
print("🔍 Building LEANN index...")
|
||||
leann_mv = LeannMultiVector(
|
||||
index_path=index_name,
|
||||
dim=embeddings.shape[-1],
|
||||
embedding_model_name=self.model_type,
|
||||
)
|
||||
|
||||
# Create collection and insert data
|
||||
leann_mv.create_collection()
|
||||
for i, (embedding, metadata) in enumerate(zip(embeddings, all_metadata)):
|
||||
data = {
|
||||
"doc_id": i,
|
||||
"filepath": metadata.get("image_path", ""),
|
||||
"colbert_vecs": embedding.numpy(), # Convert tensor to numpy
|
||||
}
|
||||
leann_mv.insert(data)
|
||||
|
||||
# Build the index
|
||||
leann_mv.create_index()
|
||||
print(f"✅ Index '{index_name}' built successfully!")
|
||||
|
||||
return leann_mv
|
||||
|
||||
def search(self, index_name: str, query: str, top_k: int = 5):
|
||||
"""
|
||||
Search the index with a text query.
|
||||
|
||||
Args:
|
||||
index_name: Name of the index to search
|
||||
query: Text query
|
||||
top_k: Number of results to return
|
||||
"""
|
||||
print(f"🔍 Searching '{index_name}' for: '{query}'")
|
||||
|
||||
# Load index
|
||||
leann_mv = LeannMultiVector(
|
||||
index_path=index_name,
|
||||
dim=128, # Will be updated when loading
|
||||
embedding_model_name=self.model_type,
|
||||
)
|
||||
|
||||
# Generate query embedding
|
||||
query_embedding = self._embed_query(query)
|
||||
|
||||
# Search (returns list of (score, doc_id) tuples)
|
||||
search_results = leann_mv.search(query_embedding.numpy(), topk=top_k)
|
||||
|
||||
# Display results
|
||||
print(f"\n📋 Top {len(search_results)} results:")
|
||||
for i, (score, doc_id) in enumerate(search_results, 1):
|
||||
# Get metadata for this doc_id (we need to load the metadata)
|
||||
print(f"{i}. Score: {score:.3f} | Doc ID: {doc_id}")
|
||||
|
||||
return search_results
|
||||
|
||||
def ask(self, index_name: str, interactive: bool = False):
|
||||
"""
|
||||
Interactive Q&A with the indexed documents.
|
||||
|
||||
Args:
|
||||
index_name: Name of the index to query
|
||||
interactive: Whether to run in interactive mode
|
||||
"""
|
||||
print(f"💬 ColQwen Chat with '{index_name}'")
|
||||
|
||||
if interactive:
|
||||
print("Type 'quit' to exit, 'help' for commands")
|
||||
while True:
|
||||
try:
|
||||
query = input("\n🤔 Your question: ").strip()
|
||||
if query.lower() in ["quit", "exit", "q"]:
|
||||
break
|
||||
elif query.lower() == "help":
|
||||
print("Commands: quit/exit/q (exit), help (this message)")
|
||||
continue
|
||||
elif not query:
|
||||
continue
|
||||
|
||||
self.search(index_name, query, top_k=3)
|
||||
|
||||
# TODO: Add answer generation with Qwen-VL
|
||||
print("\n💡 For detailed answers, we can integrate Qwen-VL here!")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n👋 Goodbye!")
|
||||
break
|
||||
else:
|
||||
query = input("🤔 Your question: ").strip()
|
||||
if query:
|
||||
self.search(index_name, query)
|
||||
|
||||
def _embed_images(self, images: list[Image.Image]) -> torch.Tensor:
|
||||
"""Generate embeddings for a list of images."""
|
||||
dataset = ListDataset(images)
|
||||
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=lambda x: x)
|
||||
|
||||
embeddings = []
|
||||
with torch.no_grad():
|
||||
for batch in tqdm(dataloader, desc="Embedding images"):
|
||||
batch_images = cast(list, batch)
|
||||
batch_inputs = self.processor.process_images(batch_images).to(self.device)
|
||||
batch_embeddings = self.model(**batch_inputs)
|
||||
embeddings.append(batch_embeddings.cpu())
|
||||
|
||||
return torch.cat(embeddings, dim=0)
|
||||
|
||||
def _embed_query(self, query: str) -> torch.Tensor:
|
||||
"""Generate embedding for a text query."""
|
||||
with torch.no_grad():
|
||||
query_inputs = self.processor.process_queries([query]).to(self.device)
|
||||
query_embedding = self.model(**query_inputs)
|
||||
return query_embedding.cpu()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="ColQwen RAG - Easy multimodal PDF retrieval")
|
||||
subparsers = parser.add_subparsers(dest="command", help="Available commands")
|
||||
|
||||
# Build command
|
||||
build_parser = subparsers.add_parser("build", help="Build index from PDFs")
|
||||
build_parser.add_argument("--pdfs", required=True, help="Directory containing PDF files")
|
||||
build_parser.add_argument("--index", required=True, help="Index name")
|
||||
build_parser.add_argument(
|
||||
"--model", choices=["colqwen2", "colpali"], default="colqwen2", help="Model to use"
|
||||
)
|
||||
build_parser.add_argument("--pages-dir", help="Directory to save page images")
|
||||
|
||||
# Search command
|
||||
search_parser = subparsers.add_parser("search", help="Search the index")
|
||||
search_parser.add_argument("index", help="Index name")
|
||||
search_parser.add_argument("query", help="Search query")
|
||||
search_parser.add_argument("--top-k", type=int, default=5, help="Number of results")
|
||||
search_parser.add_argument(
|
||||
"--model", choices=["colqwen2", "colpali"], default="colqwen2", help="Model to use"
|
||||
)
|
||||
|
||||
# Ask command
|
||||
ask_parser = subparsers.add_parser("ask", help="Interactive Q&A")
|
||||
ask_parser.add_argument("index", help="Index name")
|
||||
ask_parser.add_argument("--interactive", action="store_true", help="Interactive mode")
|
||||
ask_parser.add_argument(
|
||||
"--model", choices=["colqwen2", "colpali"], default="colqwen2", help="Model to use"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.command:
|
||||
parser.print_help()
|
||||
return
|
||||
|
||||
# Initialize ColQwen RAG
|
||||
if args.command == "build":
|
||||
colqwen = ColQwenRAG(args.model)
|
||||
|
||||
# Get PDF files
|
||||
pdf_dir = Path(args.pdfs)
|
||||
if pdf_dir.is_file() and pdf_dir.suffix.lower() == ".pdf":
|
||||
pdf_paths = [str(pdf_dir)]
|
||||
elif pdf_dir.is_dir():
|
||||
pdf_paths = [str(p) for p in pdf_dir.glob("*.pdf")]
|
||||
else:
|
||||
print(f"❌ Invalid PDF path: {args.pdfs}")
|
||||
return
|
||||
|
||||
if not pdf_paths:
|
||||
print(f"❌ No PDF files found in {args.pdfs}")
|
||||
return
|
||||
|
||||
colqwen.build_index(pdf_paths, args.index, args.pages_dir)
|
||||
|
||||
elif args.command == "search":
|
||||
colqwen = ColQwenRAG(args.model)
|
||||
colqwen.search(args.index, args.query, args.top_k)
|
||||
|
||||
elif args.command == "ask":
|
||||
colqwen = ColQwenRAG(args.model)
|
||||
colqwen.ask(args.index, args.interactive)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -5,6 +5,7 @@ Supports PDF, TXT, MD, and other document formats.
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
@@ -51,7 +52,7 @@ class DocumentRAG(BaseRAGExample):
|
||||
help="Enable AST-aware chunking for code files in the data directory",
|
||||
)
|
||||
|
||||
async def load_data(self, args) -> list[str]:
|
||||
async def load_data(self, args) -> list[dict[str, Any]]:
|
||||
"""Load documents and convert to text chunks."""
|
||||
print(f"Loading documents from: {args.data_dir}")
|
||||
if args.file_types:
|
||||
@@ -65,16 +66,12 @@ class DocumentRAG(BaseRAGExample):
|
||||
raise ValueError(f"Data directory not found: {args.data_dir}")
|
||||
|
||||
# Load documents
|
||||
reader_kwargs = {
|
||||
"recursive": True,
|
||||
"encoding": "utf-8",
|
||||
}
|
||||
if args.file_types:
|
||||
reader_kwargs["required_exts"] = args.file_types
|
||||
|
||||
documents = SimpleDirectoryReader(args.data_dir, **reader_kwargs).load_data(
|
||||
show_progress=True
|
||||
)
|
||||
documents = SimpleDirectoryReader(
|
||||
args.data_dir,
|
||||
recursive=True,
|
||||
encoding="utf-8",
|
||||
required_exts=args.file_types if args.file_types else None,
|
||||
).load_data(show_progress=True)
|
||||
|
||||
if not documents:
|
||||
print(f"No documents found in {args.data_dir} with extensions {args.file_types}")
|
||||
|
||||
@@ -127,11 +127,12 @@ class EmlxMboxReader(MboxReader):
|
||||
|
||||
def load_data(
|
||||
self,
|
||||
directory: Path,
|
||||
file: Path, # Note: for EmlxMboxReader, this is actually a directory
|
||||
extra_info: dict | None = None,
|
||||
fs: AbstractFileSystem | None = None,
|
||||
) -> list[Document]:
|
||||
"""Parse .emlx files from directory into strings using MboxReader logic."""
|
||||
directory = file # Rename for clarity - this is a directory of .emlx files
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ Supports Apple Mail on macOS.
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
@@ -64,7 +65,7 @@ class EmailRAG(BaseRAGExample):
|
||||
|
||||
return messages_dirs
|
||||
|
||||
async def load_data(self, args) -> list[str]:
|
||||
async def load_data(self, args) -> list[dict[str, Any]]:
|
||||
"""Load emails and convert to text chunks."""
|
||||
# Determine mail directories
|
||||
if args.mail_path:
|
||||
|
||||
@@ -86,7 +86,7 @@ class WeChatHistoryReader(BaseReader):
|
||||
text=True,
|
||||
timeout=5,
|
||||
)
|
||||
return result.returncode == 0 and result.stdout.strip()
|
||||
return result.returncode == 0 and bool(result.stdout.strip())
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@@ -314,7 +314,9 @@ class WeChatHistoryReader(BaseReader):
|
||||
|
||||
return concatenated_groups
|
||||
|
||||
def _create_concatenated_content(self, message_group: dict, contact_name: str) -> str:
|
||||
def _create_concatenated_content(
|
||||
self, message_group: dict, contact_name: str
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Create concatenated content from a group of messages.
|
||||
|
||||
|
||||
219
apps/image_rag.py
Normal file
219
apps/image_rag.py
Normal file
@@ -0,0 +1,219 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
CLIP Image RAG Application
|
||||
|
||||
This application enables RAG (Retrieval-Augmented Generation) on images using CLIP embeddings.
|
||||
You can index a directory of images and search them using text queries.
|
||||
|
||||
Usage:
|
||||
python -m apps.image_rag --image-dir ./my_images/ --query "a sunset over mountains"
|
||||
python -m apps.image_rag --image-dir ./my_images/ --interactive
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import pickle
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from tqdm import tqdm
|
||||
|
||||
from apps.base_rag_example import BaseRAGExample
|
||||
|
||||
|
||||
class ImageRAG(BaseRAGExample):
|
||||
"""
|
||||
RAG application for images using CLIP embeddings.
|
||||
|
||||
This class provides a complete RAG pipeline for image data, including
|
||||
CLIP embedding generation, indexing, and text-based image search.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="Image RAG",
|
||||
description="RAG application for images using CLIP embeddings",
|
||||
default_index_name="image_index",
|
||||
)
|
||||
# Override default embedding model to use CLIP
|
||||
self.embedding_model_default = "clip-ViT-L-14"
|
||||
self.embedding_mode_default = "sentence-transformers"
|
||||
self._image_data: list[dict] = []
|
||||
|
||||
def _add_specific_arguments(self, parser: argparse.ArgumentParser):
|
||||
"""Add image-specific arguments."""
|
||||
image_group = parser.add_argument_group("Image Parameters")
|
||||
image_group.add_argument(
|
||||
"--image-dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Directory containing images to index",
|
||||
)
|
||||
image_group.add_argument(
|
||||
"--image-extensions",
|
||||
type=str,
|
||||
nargs="+",
|
||||
default=[".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"],
|
||||
help="Image file extensions to process (default: .jpg .jpeg .png .gif .bmp .webp)",
|
||||
)
|
||||
image_group.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=32,
|
||||
help="Batch size for CLIP embedding generation (default: 32)",
|
||||
)
|
||||
|
||||
async def load_data(self, args) -> list[dict[str, Any]]:
|
||||
"""Load images, generate CLIP embeddings, and return text descriptions."""
|
||||
self._image_data = self._load_images_and_embeddings(args)
|
||||
return [entry["text"] for entry in self._image_data]
|
||||
|
||||
def _load_images_and_embeddings(self, args) -> list[dict]:
|
||||
"""Helper to process images and produce embeddings/metadata."""
|
||||
image_dir = Path(args.image_dir)
|
||||
if not image_dir.exists():
|
||||
raise ValueError(f"Image directory does not exist: {image_dir}")
|
||||
|
||||
print(f"📸 Loading images from {image_dir}...")
|
||||
|
||||
# Find all image files
|
||||
image_files = []
|
||||
for ext in args.image_extensions:
|
||||
image_files.extend(image_dir.rglob(f"*{ext}"))
|
||||
image_files.extend(image_dir.rglob(f"*{ext.upper()}"))
|
||||
|
||||
if not image_files:
|
||||
raise ValueError(
|
||||
f"No images found in {image_dir} with extensions {args.image_extensions}"
|
||||
)
|
||||
|
||||
print(f"✅ Found {len(image_files)} images")
|
||||
|
||||
# Limit if max_items is set
|
||||
if args.max_items > 0:
|
||||
image_files = image_files[: args.max_items]
|
||||
print(f"📊 Processing {len(image_files)} images (limited by --max-items)")
|
||||
|
||||
# Load CLIP model
|
||||
print("🔍 Loading CLIP model...")
|
||||
model = SentenceTransformer(self.embedding_model_default)
|
||||
|
||||
# Process images and generate embeddings
|
||||
print("🖼️ Processing images and generating embeddings...")
|
||||
image_data = []
|
||||
batch_images = []
|
||||
batch_paths = []
|
||||
|
||||
for image_path in tqdm(image_files, desc="Processing images"):
|
||||
try:
|
||||
image = Image.open(image_path).convert("RGB")
|
||||
batch_images.append(image)
|
||||
batch_paths.append(image_path)
|
||||
|
||||
# Process in batches
|
||||
if len(batch_images) >= args.batch_size:
|
||||
embeddings = model.encode(
|
||||
batch_images,
|
||||
convert_to_numpy=True,
|
||||
normalize_embeddings=True,
|
||||
batch_size=args.batch_size,
|
||||
show_progress_bar=False,
|
||||
)
|
||||
|
||||
for img_path, embedding in zip(batch_paths, embeddings):
|
||||
image_data.append(
|
||||
{
|
||||
"text": f"Image: {img_path.name}\nPath: {img_path}",
|
||||
"metadata": {
|
||||
"image_path": str(img_path),
|
||||
"image_name": img_path.name,
|
||||
"image_dir": str(image_dir),
|
||||
},
|
||||
"embedding": embedding.astype(np.float32),
|
||||
}
|
||||
)
|
||||
|
||||
batch_images = []
|
||||
batch_paths = []
|
||||
|
||||
except Exception as e:
|
||||
print(f"⚠️ Failed to process {image_path}: {e}")
|
||||
continue
|
||||
|
||||
# Process remaining images
|
||||
if batch_images:
|
||||
embeddings = model.encode(
|
||||
batch_images,
|
||||
convert_to_numpy=True,
|
||||
normalize_embeddings=True,
|
||||
batch_size=len(batch_images),
|
||||
show_progress_bar=False,
|
||||
)
|
||||
|
||||
for img_path, embedding in zip(batch_paths, embeddings):
|
||||
image_data.append(
|
||||
{
|
||||
"text": f"Image: {img_path.name}\nPath: {img_path}",
|
||||
"metadata": {
|
||||
"image_path": str(img_path),
|
||||
"image_name": img_path.name,
|
||||
"image_dir": str(image_dir),
|
||||
},
|
||||
"embedding": embedding.astype(np.float32),
|
||||
}
|
||||
)
|
||||
|
||||
print(f"✅ Processed {len(image_data)} images")
|
||||
return image_data
|
||||
|
||||
async def build_index(self, args, texts: list[dict[str, Any]]) -> str:
|
||||
"""Build index using pre-computed CLIP embeddings."""
|
||||
from leann.api import LeannBuilder
|
||||
|
||||
if not self._image_data or len(self._image_data) != len(texts):
|
||||
raise RuntimeError("No image data found. Make sure load_data() ran successfully.")
|
||||
|
||||
print("🔨 Building LEANN index with CLIP embeddings...")
|
||||
builder = LeannBuilder(
|
||||
backend_name=args.backend_name,
|
||||
embedding_model=self.embedding_model_default,
|
||||
embedding_mode=self.embedding_mode_default,
|
||||
is_recompute=False,
|
||||
distance_metric="cosine",
|
||||
graph_degree=args.graph_degree,
|
||||
build_complexity=args.build_complexity,
|
||||
is_compact=not args.no_compact,
|
||||
)
|
||||
|
||||
for text, data in zip(texts, self._image_data):
|
||||
builder.add_text(text=text, metadata=data["metadata"])
|
||||
|
||||
ids = [str(i) for i in range(len(self._image_data))]
|
||||
embeddings = np.array([data["embedding"] for data in self._image_data], dtype=np.float32)
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="wb", suffix=".pkl", delete=False) as f:
|
||||
pickle.dump((ids, embeddings), f)
|
||||
pkl_path = f.name
|
||||
|
||||
try:
|
||||
index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")
|
||||
builder.build_index_from_embeddings(index_path, pkl_path)
|
||||
print(f"✅ Index built successfully at {index_path}")
|
||||
return index_path
|
||||
finally:
|
||||
Path(pkl_path).unlink()
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point for the image RAG application."""
|
||||
import asyncio
|
||||
|
||||
app = ImageRAG()
|
||||
asyncio.run(app.run())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -6,6 +6,7 @@ This example demonstrates how to build a RAG system on your iMessage conversatio
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from leann.chunking_utils import create_text_chunks
|
||||
|
||||
@@ -56,7 +57,7 @@ class IMessageRAG(BaseRAGExample):
|
||||
help="Overlap between text chunks (default: 200)",
|
||||
)
|
||||
|
||||
async def load_data(self, args) -> list[str]:
|
||||
async def load_data(self, args) -> list[dict[str, Any]]:
|
||||
"""Load iMessage history and convert to text chunks."""
|
||||
print("Loading iMessage conversation history...")
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import concurrent.futures
|
||||
import glob
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
@@ -11,6 +13,8 @@ import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _ensure_repo_paths_importable(current_file: str) -> None:
|
||||
"""Make local leann packages importable without installing (mirrors multi-vector-leann.py)."""
|
||||
@@ -96,12 +100,63 @@ def _natural_sort_key(name: str) -> int:
|
||||
return int(m.group()) if m else 0
|
||||
|
||||
|
||||
def _load_images_from_dir(pages_dir: str) -> tuple[list[str], list[Image.Image]]:
|
||||
filenames = [n for n in os.listdir(pages_dir) if n.lower().endswith((".png", ".jpg", ".jpeg"))]
|
||||
filenames = sorted(filenames, key=_natural_sort_key)
|
||||
filepaths = [os.path.join(pages_dir, n) for n in filenames]
|
||||
images = [Image.open(p) for p in filepaths]
|
||||
return filepaths, images
|
||||
def _load_images_from_dir(
|
||||
pages_dir: str, recursive: bool = False
|
||||
) -> tuple[list[str], list[Image.Image]]:
|
||||
"""
|
||||
Load images from a directory.
|
||||
|
||||
Args:
|
||||
pages_dir: Directory path containing images
|
||||
recursive: If True, recursively search subdirectories (default: False)
|
||||
|
||||
Returns:
|
||||
Tuple of (filepaths, images)
|
||||
"""
|
||||
|
||||
# Supported image extensions
|
||||
extensions = ("*.png", "*.jpg", "*.jpeg", "*.PNG", "*.JPG", "*.JPEG", "*.webp", "*.WEBP")
|
||||
|
||||
if recursive:
|
||||
# Recursive search
|
||||
filepaths = []
|
||||
for ext in extensions:
|
||||
pattern = os.path.join(pages_dir, "**", ext)
|
||||
filepaths.extend(glob.glob(pattern, recursive=True))
|
||||
else:
|
||||
# Non-recursive search (only top-level directory)
|
||||
filepaths = []
|
||||
for ext in extensions:
|
||||
pattern = os.path.join(pages_dir, ext)
|
||||
filepaths.extend(glob.glob(pattern))
|
||||
|
||||
# Sort files naturally
|
||||
filepaths = sorted(filepaths, key=lambda x: _natural_sort_key(os.path.basename(x)))
|
||||
|
||||
# Load images with error handling
|
||||
images = []
|
||||
valid_filepaths = []
|
||||
failed_count = 0
|
||||
|
||||
for filepath in filepaths:
|
||||
try:
|
||||
img = Image.open(filepath)
|
||||
# Convert to RGB if necessary (handles RGBA, P, etc.)
|
||||
if img.mode != "RGB":
|
||||
img = img.convert("RGB")
|
||||
images.append(img)
|
||||
valid_filepaths.append(filepath)
|
||||
except Exception as e:
|
||||
failed_count += 1
|
||||
print(f"Warning: Failed to load image {filepath}: {e}")
|
||||
continue
|
||||
|
||||
if failed_count > 0:
|
||||
print(
|
||||
f"Warning: Failed to load {failed_count} image(s) out of {len(filepaths)} total files"
|
||||
)
|
||||
|
||||
return valid_filepaths, images
|
||||
|
||||
|
||||
def _maybe_convert_pdf_to_images(pdf_path: Optional[str], pages_dir: str, dpi: int = 200) -> None:
|
||||
@@ -151,36 +206,99 @@ def _select_device_and_dtype():
|
||||
|
||||
|
||||
def _load_colvision(model_choice: str):
|
||||
import os
|
||||
|
||||
import torch
|
||||
from colpali_engine.models import ColPali, ColQwen2, ColQwen2Processor
|
||||
from colpali_engine.models import (
|
||||
ColPali,
|
||||
ColQwen2,
|
||||
ColQwen2_5,
|
||||
ColQwen2_5_Processor,
|
||||
ColQwen2Processor,
|
||||
)
|
||||
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
|
||||
from transformers.utils.import_utils import is_flash_attn_2_available
|
||||
|
||||
# Force HuggingFace Hub to use HF endpoint, avoid Google Drive
|
||||
# Set environment variables to ensure models are downloaded from HuggingFace
|
||||
os.environ.setdefault("HF_ENDPOINT", "https://huggingface.co")
|
||||
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
|
||||
|
||||
# Log model loading info
|
||||
logger.info(f"Loading ColVision model: {model_choice}")
|
||||
logger.info(f"HF_ENDPOINT: {os.environ.get('HF_ENDPOINT', 'not set')}")
|
||||
logger.info("Models will be downloaded from HuggingFace Hub, not Google Drive")
|
||||
|
||||
device_str, device, dtype = _select_device_and_dtype()
|
||||
|
||||
# Determine model name and type
|
||||
# IMPORTANT: Check colqwen2.5 BEFORE colqwen2 to avoid false matches
|
||||
model_choice_lower = model_choice.lower()
|
||||
if model_choice == "colqwen2":
|
||||
model_name = "vidore/colqwen2-v1.0"
|
||||
# On CPU/MPS we must avoid flash-attn and stay eager; on CUDA prefer flash-attn if available
|
||||
attn_implementation = (
|
||||
"flash_attention_2"
|
||||
if (device_str == "cuda" and is_flash_attn_2_available())
|
||||
else "eager"
|
||||
)
|
||||
model_type = "colqwen2"
|
||||
elif model_choice == "colqwen2.5" or model_choice == "colqwen25":
|
||||
model_name = "vidore/colqwen2.5-v0.2"
|
||||
model_type = "colqwen2.5"
|
||||
elif model_choice == "colpali":
|
||||
model_name = "vidore/colpali-v1.2"
|
||||
model_type = "colpali"
|
||||
elif (
|
||||
"colqwen2.5" in model_choice_lower
|
||||
or "colqwen25" in model_choice_lower
|
||||
or "colqwen2_5" in model_choice_lower
|
||||
):
|
||||
# Handle HuggingFace model names like "vidore/colqwen2.5-v0.2"
|
||||
model_name = model_choice
|
||||
model_type = "colqwen2.5"
|
||||
elif "colqwen2" in model_choice_lower and "colqwen2-v1.0" in model_choice_lower:
|
||||
# Handle HuggingFace model names like "vidore/colqwen2-v1.0" (but not colqwen2.5)
|
||||
model_name = model_choice
|
||||
model_type = "colqwen2"
|
||||
elif "colpali" in model_choice_lower:
|
||||
# Handle HuggingFace model names like "vidore/colpali-v1.2"
|
||||
model_name = model_choice
|
||||
model_type = "colpali"
|
||||
else:
|
||||
# Default to colpali for backward compatibility
|
||||
model_name = "vidore/colpali-v1.2"
|
||||
model_type = "colpali"
|
||||
|
||||
# Load model based on type
|
||||
attn_implementation = (
|
||||
"flash_attention_2" if (device_str == "cuda" and is_flash_attn_2_available()) else "eager"
|
||||
)
|
||||
|
||||
# Load model from HuggingFace Hub (not Google Drive)
|
||||
# Use local_files_only=False to ensure download from HF if not cached
|
||||
if model_type == "colqwen2.5":
|
||||
model = ColQwen2_5.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map=device,
|
||||
attn_implementation=attn_implementation,
|
||||
local_files_only=False, # Ensure download from HuggingFace Hub
|
||||
).eval()
|
||||
processor = ColQwen2_5_Processor.from_pretrained(model_name, local_files_only=False)
|
||||
elif model_type == "colqwen2":
|
||||
model = ColQwen2.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map=device,
|
||||
attn_implementation=attn_implementation,
|
||||
local_files_only=False, # Ensure download from HuggingFace Hub
|
||||
).eval()
|
||||
processor = ColQwen2Processor.from_pretrained(model_name)
|
||||
else:
|
||||
model_name = "vidore/colpali-v1.2"
|
||||
processor = ColQwen2Processor.from_pretrained(model_name, local_files_only=False)
|
||||
else: # colpali
|
||||
model = ColPali.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map=device,
|
||||
local_files_only=False, # Ensure download from HuggingFace Hub
|
||||
).eval()
|
||||
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
|
||||
processor = cast(
|
||||
ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name, local_files_only=False)
|
||||
)
|
||||
|
||||
return model_name, model, processor, device_str, device, dtype
|
||||
|
||||
|
||||
@@ -18,10 +18,11 @@ _repo_root = Path(__file__).resolve().parents[3]
|
||||
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
|
||||
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
|
||||
if str(_leann_core_src) not in sys.path:
|
||||
sys.path.append(str(_leann_core_src))
|
||||
sys.path.insert(0, str(_leann_core_src))
|
||||
if str(_leann_hnsw_pkg) not in sys.path:
|
||||
sys.path.append(str(_leann_hnsw_pkg))
|
||||
sys.path.insert(0, str(_leann_hnsw_pkg))
|
||||
|
||||
from leann_multi_vector import LeannMultiVector
|
||||
|
||||
import torch
|
||||
from colpali_engine.models import ColPali
|
||||
@@ -93,9 +94,9 @@ for batch_doc in tqdm(dataloader):
|
||||
print(ds[0].shape)
|
||||
|
||||
# %%
|
||||
# Build HNSW index via LeannRetriever primitives and run search
|
||||
# Build HNSW index via LeannMultiVector primitives and run search
|
||||
index_path = "./indexes/colpali.leann"
|
||||
retriever = LeannRetriever(index_path=index_path, dim=int(ds[0].shape[-1]))
|
||||
retriever = LeannMultiVector(index_path=index_path, dim=int(ds[0].shape[-1]))
|
||||
retriever.create_collection()
|
||||
filepaths = [os.path.join("./pages", name) for name in page_filenames]
|
||||
for i in range(len(filepaths)):
|
||||
|
||||
@@ -5,7 +5,7 @@ import argparse
|
||||
import faulthandler
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
@@ -62,7 +62,7 @@ DATASET_NAME: str = "weaviate/arXiv-AI-papers-multi-vector"
|
||||
# DATASET_NAMES: Optional[list[str | tuple[str, Optional[str]]]] = None
|
||||
DATASET_NAMES = [
|
||||
"weaviate/arXiv-AI-papers-multi-vector",
|
||||
("lmms-lab/DocVQA", "DocVQA"), # Specify config name for datasets with multiple configs
|
||||
# ("lmms-lab/DocVQA", "DocVQA"), # Specify config name for datasets with multiple configs
|
||||
]
|
||||
# Load multiple splits to get more data (e.g., ["train", "test", "validation"])
|
||||
# Set to None to try loading all available splits automatically
|
||||
@@ -75,6 +75,11 @@ MAX_DOCS: Optional[int] = None # limit number of pages to index; None = all
|
||||
# Local pages (used when USE_HF_DATASET == False)
|
||||
PDF: Optional[str] = None # e.g., "./pdfs/2004.12832v2.pdf"
|
||||
PAGES_DIR: str = "./pages"
|
||||
# Custom folder path (takes precedence over USE_HF_DATASET and PAGES_DIR)
|
||||
# If set, images will be loaded directly from this folder
|
||||
CUSTOM_FOLDER_PATH: Optional[str] = None # e.g., "/home/ubuntu/dr-tulu/agent/screenshots"
|
||||
# Whether to recursively search subdirectories when loading from custom folder
|
||||
CUSTOM_FOLDER_RECURSIVE: bool = False # Set to True to search subdirectories
|
||||
|
||||
# Index + retrieval settings
|
||||
# Use a different index path for larger dataset to avoid overwriting existing index
|
||||
@@ -83,7 +88,7 @@ INDEX_PATH: str = "./indexes/colvision_large.leann"
|
||||
# These are now command-line arguments (see CLI overrides section)
|
||||
TOPK: int = 3
|
||||
FIRST_STAGE_K: int = 500
|
||||
REBUILD_INDEX: bool = True
|
||||
REBUILD_INDEX: bool = False # Set to True to force rebuild even if index exists
|
||||
|
||||
# Artifacts
|
||||
SAVE_TOP_IMAGE: Optional[str] = "./figures/retrieved_page.png"
|
||||
@@ -128,12 +133,33 @@ parser.add_argument(
|
||||
default=TOPK,
|
||||
help=f"Number of top results to retrieve. Default: {TOPK}",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--custom-folder",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to a custom folder containing images to search. Takes precedence over dataset loading. Default: None",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--recursive",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Recursively search subdirectories when loading images from custom folder. Default: False",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rebuild-index",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Force rebuild the index even if it already exists. Default: False (reuse existing index if available)",
|
||||
)
|
||||
cli_args, _unknown = parser.parse_known_args()
|
||||
SEARCH_METHOD: str = cli_args.search_method
|
||||
QUERY = cli_args.query # Override QUERY with CLI argument if provided
|
||||
USE_FAST_PLAID: bool = cli_args.use_fast_plaid
|
||||
FAST_PLAID_INDEX_PATH: str = cli_args.fast_plaid_index_path
|
||||
TOPK: int = cli_args.topk # Override TOPK with CLI argument if provided
|
||||
CUSTOM_FOLDER_PATH = cli_args.custom_folder if cli_args.custom_folder else CUSTOM_FOLDER_PATH # Override with CLI argument if provided
|
||||
CUSTOM_FOLDER_RECURSIVE = cli_args.recursive if cli_args.recursive else CUSTOM_FOLDER_RECURSIVE # Override with CLI argument if provided
|
||||
REBUILD_INDEX = cli_args.rebuild_index # Override REBUILD_INDEX with CLI argument
|
||||
|
||||
# %%
|
||||
|
||||
@@ -180,8 +206,24 @@ else:
|
||||
# Step 2: Load data only if we need to build the index
|
||||
if need_to_build_index:
|
||||
print("Loading dataset...")
|
||||
if USE_HF_DATASET:
|
||||
from datasets import load_dataset, concatenate_datasets, DatasetDict
|
||||
# Check for custom folder path first (takes precedence)
|
||||
if CUSTOM_FOLDER_PATH:
|
||||
if not os.path.isdir(CUSTOM_FOLDER_PATH):
|
||||
raise RuntimeError(f"Custom folder path does not exist: {CUSTOM_FOLDER_PATH}")
|
||||
print(f"Loading images from custom folder: {CUSTOM_FOLDER_PATH}")
|
||||
if CUSTOM_FOLDER_RECURSIVE:
|
||||
print(" (recursive mode: searching subdirectories)")
|
||||
filepaths, images = _load_images_from_dir(CUSTOM_FOLDER_PATH, recursive=CUSTOM_FOLDER_RECURSIVE)
|
||||
print(f" Found {len(filepaths)} image files")
|
||||
if not images:
|
||||
raise RuntimeError(
|
||||
f"No images found in {CUSTOM_FOLDER_PATH}. Ensure the folder contains image files (.png, .jpg, .jpeg, .webp)."
|
||||
)
|
||||
print(f" Successfully loaded {len(images)} images")
|
||||
# Use filenames as identifiers instead of full paths for cleaner metadata
|
||||
filepaths = [os.path.basename(fp) for fp in filepaths]
|
||||
elif USE_HF_DATASET:
|
||||
from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset
|
||||
|
||||
# Determine which datasets to load
|
||||
if DATASET_NAMES is not None:
|
||||
@@ -239,12 +281,12 @@ if need_to_build_index:
|
||||
splits_to_load = DATASET_SPLITS
|
||||
|
||||
# Load and concatenate multiple splits for this dataset
|
||||
datasets_to_concat = []
|
||||
datasets_to_concat: list[Dataset] = []
|
||||
for split in splits_to_load:
|
||||
if split not in dataset_dict:
|
||||
print(f" Warning: Split '{split}' not found in dataset. Available splits: {list(dataset_dict.keys())}")
|
||||
continue
|
||||
split_dataset = dataset_dict[split]
|
||||
split_dataset = cast(Dataset, dataset_dict[split])
|
||||
print(f" Loaded split '{split}': {len(split_dataset)} pages")
|
||||
datasets_to_concat.append(split_dataset)
|
||||
|
||||
@@ -621,7 +663,6 @@ else:
|
||||
except Exception:
|
||||
print(f"Saved retrieved page (rank {rank}) to: {out_path}")
|
||||
|
||||
## TODO stange results of second page of DeepSeek-V2 rather than the first page
|
||||
|
||||
# %%
|
||||
# Step 6: Similarity maps for top-K results
|
||||
|
||||
@@ -25,9 +25,9 @@ Usage:
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from typing import Optional
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from datasets import load_dataset
|
||||
from datasets import Dataset, load_dataset
|
||||
from leann_multi_vector import (
|
||||
ViDoReBenchmarkEvaluator,
|
||||
_ensure_repo_paths_importable,
|
||||
@@ -90,6 +90,51 @@ VIDORE_V1_TASKS = {
|
||||
},
|
||||
}
|
||||
|
||||
# Task name aliases (short names -> full names)
|
||||
TASK_ALIASES = {
|
||||
"arxivqa": "VidoreArxivQARetrieval",
|
||||
"docvqa": "VidoreDocVQARetrieval",
|
||||
"infovqa": "VidoreInfoVQARetrieval",
|
||||
"tabfquad": "VidoreTabfquadRetrieval",
|
||||
"tatdqa": "VidoreTatdqaRetrieval",
|
||||
"shiftproject": "VidoreShiftProjectRetrieval",
|
||||
"syntheticdocqa_ai": "VidoreSyntheticDocQAAIRetrieval",
|
||||
"syntheticdocqa_energy": "VidoreSyntheticDocQAEnergyRetrieval",
|
||||
"syntheticdocqa_government": "VidoreSyntheticDocQAGovernmentReportsRetrieval",
|
||||
"syntheticdocqa_healthcare": "VidoreSyntheticDocQAHealthcareIndustryRetrieval",
|
||||
}
|
||||
|
||||
|
||||
def normalize_task_name(task_name: str) -> str:
|
||||
"""Normalize task name (handle aliases)."""
|
||||
task_name_lower = task_name.lower()
|
||||
if task_name in VIDORE_V1_TASKS:
|
||||
return task_name
|
||||
if task_name_lower in TASK_ALIASES:
|
||||
return TASK_ALIASES[task_name_lower]
|
||||
# Try partial match
|
||||
for alias, full_name in TASK_ALIASES.items():
|
||||
if alias in task_name_lower or task_name_lower in alias:
|
||||
return full_name
|
||||
return task_name
|
||||
|
||||
|
||||
def get_safe_model_name(model_name: str) -> str:
|
||||
"""Get a safe model name for use in file paths."""
|
||||
import hashlib
|
||||
import os
|
||||
|
||||
# If it's a path, use basename or hash
|
||||
if os.path.exists(model_name) and os.path.isdir(model_name):
|
||||
# Use basename if it's reasonable, otherwise use hash
|
||||
basename = os.path.basename(model_name.rstrip("/"))
|
||||
if basename and len(basename) < 100 and not basename.startswith("."):
|
||||
return basename
|
||||
# Use hash for very long or problematic paths
|
||||
return hashlib.md5(model_name.encode()).hexdigest()[:16]
|
||||
# For HuggingFace model names, replace / with _
|
||||
return model_name.replace("/", "_").replace(":", "_")
|
||||
|
||||
|
||||
def load_vidore_v1_data(
|
||||
dataset_path: str,
|
||||
@@ -106,40 +151,43 @@ def load_vidore_v1_data(
|
||||
"""
|
||||
print(f"Loading dataset: {dataset_path} (split={split})")
|
||||
|
||||
# Load queries
|
||||
query_ds = load_dataset(dataset_path, "queries", split=split, revision=revision)
|
||||
# Load queries - cast to Dataset since we know split returns Dataset not DatasetDict
|
||||
query_ds = cast(Dataset, load_dataset(dataset_path, "queries", split=split, revision=revision))
|
||||
|
||||
queries = {}
|
||||
queries: dict[str, str] = {}
|
||||
for row in query_ds:
|
||||
query_id = f"query-{split}-{row['query-id']}"
|
||||
queries[query_id] = row["query"]
|
||||
row_dict = cast(dict[str, Any], row)
|
||||
query_id = f"query-{split}-{row_dict['query-id']}"
|
||||
queries[query_id] = row_dict["query"]
|
||||
|
||||
# Load corpus (images)
|
||||
corpus_ds = load_dataset(dataset_path, "corpus", split=split, revision=revision)
|
||||
# Load corpus (images) - cast to Dataset
|
||||
corpus_ds = cast(Dataset, load_dataset(dataset_path, "corpus", split=split, revision=revision))
|
||||
|
||||
corpus = {}
|
||||
corpus: dict[str, Any] = {}
|
||||
for row in corpus_ds:
|
||||
corpus_id = f"corpus-{split}-{row['corpus-id']}"
|
||||
row_dict = cast(dict[str, Any], row)
|
||||
corpus_id = f"corpus-{split}-{row_dict['corpus-id']}"
|
||||
# Extract image from the dataset row
|
||||
if "image" in row:
|
||||
corpus[corpus_id] = row["image"]
|
||||
elif "page_image" in row:
|
||||
corpus[corpus_id] = row["page_image"]
|
||||
if "image" in row_dict:
|
||||
corpus[corpus_id] = row_dict["image"]
|
||||
elif "page_image" in row_dict:
|
||||
corpus[corpus_id] = row_dict["page_image"]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"No image field found in corpus. Available fields: {list(row.keys())}"
|
||||
f"No image field found in corpus. Available fields: {list(row_dict.keys())}"
|
||||
)
|
||||
|
||||
# Load qrels (relevance judgments)
|
||||
qrels_ds = load_dataset(dataset_path, "qrels", split=split, revision=revision)
|
||||
# Load qrels (relevance judgments) - cast to Dataset
|
||||
qrels_ds = cast(Dataset, load_dataset(dataset_path, "qrels", split=split, revision=revision))
|
||||
|
||||
qrels = {}
|
||||
qrels: dict[str, dict[str, int]] = {}
|
||||
for row in qrels_ds:
|
||||
query_id = f"query-{split}-{row['query-id']}"
|
||||
corpus_id = f"corpus-{split}-{row['corpus-id']}"
|
||||
row_dict = cast(dict[str, Any], row)
|
||||
query_id = f"query-{split}-{row_dict['query-id']}"
|
||||
corpus_id = f"corpus-{split}-{row_dict['corpus-id']}"
|
||||
if query_id not in qrels:
|
||||
qrels[query_id] = {}
|
||||
qrels[query_id][corpus_id] = int(row["score"])
|
||||
qrels[query_id][corpus_id] = int(row_dict["score"])
|
||||
|
||||
print(
|
||||
f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings"
|
||||
@@ -181,13 +229,16 @@ def evaluate_task(
|
||||
print(f"Evaluating task: {task_name}")
|
||||
print(f"{'=' * 80}")
|
||||
|
||||
# Normalize task name (handle aliases)
|
||||
task_name = normalize_task_name(task_name)
|
||||
|
||||
# Get task config
|
||||
if task_name not in VIDORE_V1_TASKS:
|
||||
raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V1_TASKS.keys())}")
|
||||
|
||||
task_config = VIDORE_V1_TASKS[task_name]
|
||||
dataset_path = task_config["dataset_path"]
|
||||
revision = task_config["revision"]
|
||||
dataset_path = str(task_config["dataset_path"])
|
||||
revision = str(task_config["revision"])
|
||||
|
||||
# Load data
|
||||
corpus, queries, qrels = load_vidore_v1_data(
|
||||
@@ -223,11 +274,13 @@ def evaluate_task(
|
||||
)
|
||||
|
||||
# Build or load index
|
||||
# Use safe model name for index path (different models need different indexes)
|
||||
safe_model_name = get_safe_model_name(model_name)
|
||||
index_path_full = index_path if not use_fast_plaid else fast_plaid_index_path
|
||||
if index_path_full is None:
|
||||
index_path_full = f"./indexes/{task_name}_{model_name}"
|
||||
index_path_full = f"./indexes/{task_name}_{safe_model_name}"
|
||||
if use_fast_plaid:
|
||||
index_path_full = f"./indexes/{task_name}_{model_name}_fastplaid"
|
||||
index_path_full = f"./indexes/{task_name}_{safe_model_name}_fastplaid"
|
||||
|
||||
index_or_retriever, corpus_ids_ordered = evaluator.build_index_from_corpus(
|
||||
corpus=corpus,
|
||||
@@ -236,7 +289,7 @@ def evaluate_task(
|
||||
)
|
||||
|
||||
# Search queries
|
||||
task_prompt = task_config.get("prompt")
|
||||
task_prompt = cast(Optional[dict[str, str]], task_config.get("prompt"))
|
||||
results = evaluator.search_queries(
|
||||
queries=queries,
|
||||
corpus_ids=corpus_ids_ordered,
|
||||
@@ -281,8 +334,7 @@ def main():
|
||||
"--model",
|
||||
type=str,
|
||||
default="colqwen2",
|
||||
choices=["colqwen2", "colpali"],
|
||||
help="Model to use",
|
||||
help="Model to use: 'colqwen2', 'colpali', or path to a model directory (supports LoRA adapters)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task",
|
||||
@@ -350,11 +402,11 @@ def main():
|
||||
|
||||
# Determine tasks to evaluate
|
||||
if args.task:
|
||||
tasks_to_eval = [args.task]
|
||||
tasks_to_eval = [normalize_task_name(args.task)]
|
||||
elif args.tasks.lower() == "all":
|
||||
tasks_to_eval = list(VIDORE_V1_TASKS.keys())
|
||||
else:
|
||||
tasks_to_eval = [t.strip() for t in args.tasks.split(",")]
|
||||
tasks_to_eval = [normalize_task_name(t.strip()) for t in args.tasks.split(",")]
|
||||
|
||||
print(f"Tasks to evaluate: {tasks_to_eval}")
|
||||
|
||||
|
||||
@@ -25,9 +25,9 @@ Usage:
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from typing import Optional
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from datasets import load_dataset
|
||||
from datasets import Dataset, load_dataset
|
||||
from leann_multi_vector import (
|
||||
ViDoReBenchmarkEvaluator,
|
||||
_ensure_repo_paths_importable,
|
||||
@@ -91,8 +91,8 @@ def load_vidore_v2_data(
|
||||
"""
|
||||
print(f"Loading dataset: {dataset_path} (split={split}, language={language})")
|
||||
|
||||
# Load queries
|
||||
query_ds = load_dataset(dataset_path, "queries", split=split, revision=revision)
|
||||
# Load queries - cast to Dataset since we know split returns Dataset not DatasetDict
|
||||
query_ds = cast(Dataset, load_dataset(dataset_path, "queries", split=split, revision=revision))
|
||||
|
||||
# Check if dataset has language field before filtering
|
||||
has_language_field = len(query_ds) > 0 and "language" in query_ds.column_names
|
||||
@@ -112,8 +112,9 @@ def load_vidore_v2_data(
|
||||
if len(query_ds_filtered) == 0:
|
||||
# Try to get a sample to see actual language values
|
||||
try:
|
||||
sample_ds = load_dataset(
|
||||
dataset_path, "queries", split=split, revision=revision
|
||||
sample_ds = cast(
|
||||
Dataset,
|
||||
load_dataset(dataset_path, "queries", split=split, revision=revision),
|
||||
)
|
||||
if len(sample_ds) > 0 and "language" in sample_ds.column_names:
|
||||
sample_langs = set(sample_ds["language"])
|
||||
@@ -126,37 +127,40 @@ def load_vidore_v2_data(
|
||||
)
|
||||
query_ds = query_ds_filtered
|
||||
|
||||
queries = {}
|
||||
queries: dict[str, str] = {}
|
||||
for row in query_ds:
|
||||
query_id = f"query-{split}-{row['query-id']}"
|
||||
queries[query_id] = row["query"]
|
||||
row_dict = cast(dict[str, Any], row)
|
||||
query_id = f"query-{split}-{row_dict['query-id']}"
|
||||
queries[query_id] = row_dict["query"]
|
||||
|
||||
# Load corpus (images)
|
||||
corpus_ds = load_dataset(dataset_path, "corpus", split=split, revision=revision)
|
||||
# Load corpus (images) - cast to Dataset
|
||||
corpus_ds = cast(Dataset, load_dataset(dataset_path, "corpus", split=split, revision=revision))
|
||||
|
||||
corpus = {}
|
||||
corpus: dict[str, Any] = {}
|
||||
for row in corpus_ds:
|
||||
corpus_id = f"corpus-{split}-{row['corpus-id']}"
|
||||
row_dict = cast(dict[str, Any], row)
|
||||
corpus_id = f"corpus-{split}-{row_dict['corpus-id']}"
|
||||
# Extract image from the dataset row
|
||||
if "image" in row:
|
||||
corpus[corpus_id] = row["image"]
|
||||
elif "page_image" in row:
|
||||
corpus[corpus_id] = row["page_image"]
|
||||
if "image" in row_dict:
|
||||
corpus[corpus_id] = row_dict["image"]
|
||||
elif "page_image" in row_dict:
|
||||
corpus[corpus_id] = row_dict["page_image"]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"No image field found in corpus. Available fields: {list(row.keys())}"
|
||||
f"No image field found in corpus. Available fields: {list(row_dict.keys())}"
|
||||
)
|
||||
|
||||
# Load qrels (relevance judgments)
|
||||
qrels_ds = load_dataset(dataset_path, "qrels", split=split, revision=revision)
|
||||
# Load qrels (relevance judgments) - cast to Dataset
|
||||
qrels_ds = cast(Dataset, load_dataset(dataset_path, "qrels", split=split, revision=revision))
|
||||
|
||||
qrels = {}
|
||||
qrels: dict[str, dict[str, int]] = {}
|
||||
for row in qrels_ds:
|
||||
query_id = f"query-{split}-{row['query-id']}"
|
||||
corpus_id = f"corpus-{split}-{row['corpus-id']}"
|
||||
row_dict = cast(dict[str, Any], row)
|
||||
query_id = f"query-{split}-{row_dict['query-id']}"
|
||||
corpus_id = f"corpus-{split}-{row_dict['corpus-id']}"
|
||||
if query_id not in qrels:
|
||||
qrels[query_id] = {}
|
||||
qrels[query_id][corpus_id] = int(row["score"])
|
||||
qrels[query_id][corpus_id] = int(row_dict["score"])
|
||||
|
||||
print(
|
||||
f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings"
|
||||
@@ -204,13 +208,13 @@ def evaluate_task(
|
||||
raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V2_TASKS.keys())}")
|
||||
|
||||
task_config = VIDORE_V2_TASKS[task_name]
|
||||
dataset_path = task_config["dataset_path"]
|
||||
revision = task_config["revision"]
|
||||
dataset_path = str(task_config["dataset_path"])
|
||||
revision = str(task_config["revision"])
|
||||
|
||||
# Determine language
|
||||
if language is None:
|
||||
# Use first language if multiple available
|
||||
languages = task_config.get("languages")
|
||||
languages = cast(Optional[list[str]], task_config.get("languages"))
|
||||
if languages is None:
|
||||
# Task doesn't support language filtering (e.g., Vidore2ESGReportsHLRetrieval)
|
||||
language = None
|
||||
@@ -269,7 +273,7 @@ def evaluate_task(
|
||||
)
|
||||
|
||||
# Search queries
|
||||
task_prompt = task_config.get("prompt")
|
||||
task_prompt = cast(Optional[dict[str, str]], task_config.get("prompt"))
|
||||
results = evaluator.search_queries(
|
||||
queries=queries,
|
||||
corpus_ids=corpus_ids_ordered,
|
||||
|
||||
@@ -177,7 +177,9 @@ class SlackMCPReader:
|
||||
break
|
||||
|
||||
# If we get here, all retries failed or it's not a retryable error
|
||||
raise last_exception
|
||||
if last_exception is not None:
|
||||
raise last_exception
|
||||
raise RuntimeError("Unexpected error: no exception captured during retry loop")
|
||||
|
||||
async def fetch_slack_messages(
|
||||
self, channel: Optional[str] = None, limit: int = 100
|
||||
@@ -267,7 +269,10 @@ class SlackMCPReader:
|
||||
messages = json.loads(content["text"])
|
||||
except json.JSONDecodeError:
|
||||
# If not JSON, try to parse as CSV format (Slack MCP server format)
|
||||
messages = self._parse_csv_messages(content["text"], channel)
|
||||
text_content = content.get("text", "")
|
||||
messages = self._parse_csv_messages(
|
||||
text_content if text_content else "", channel or "unknown"
|
||||
)
|
||||
else:
|
||||
messages = result["content"]
|
||||
else:
|
||||
|
||||
@@ -11,6 +11,7 @@ Usage:
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from apps.base_rag_example import BaseRAGExample
|
||||
from apps.slack_data.slack_mcp_reader import SlackMCPReader
|
||||
@@ -139,7 +140,7 @@ class SlackMCPRAG(BaseRAGExample):
|
||||
print("4. Try running the MCP server command directly to test it")
|
||||
return False
|
||||
|
||||
async def load_data(self, args) -> list[str]:
|
||||
async def load_data(self, args) -> list[dict[str, Any]]:
|
||||
"""Load Slack messages via MCP server."""
|
||||
print(f"Connecting to Slack MCP server: {args.mcp_server}")
|
||||
|
||||
@@ -188,7 +189,8 @@ class SlackMCPRAG(BaseRAGExample):
|
||||
print(sample_text)
|
||||
print("-" * 40)
|
||||
|
||||
return texts
|
||||
# Convert strings to dict format expected by base class
|
||||
return [{"text": text, "metadata": {"source": "slack"}} for text in texts]
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error loading Slack data: {e}")
|
||||
|
||||
@@ -11,6 +11,7 @@ Usage:
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from apps.base_rag_example import BaseRAGExample
|
||||
from apps.twitter_data.twitter_mcp_reader import TwitterMCPReader
|
||||
@@ -116,7 +117,7 @@ class TwitterMCPRAG(BaseRAGExample):
|
||||
print("5. Try running the MCP server command directly to test it")
|
||||
return False
|
||||
|
||||
async def load_data(self, args) -> list[str]:
|
||||
async def load_data(self, args) -> list[dict[str, Any]]:
|
||||
"""Load Twitter bookmarks via MCP server."""
|
||||
print(f"Connecting to Twitter MCP server: {args.mcp_server}")
|
||||
|
||||
@@ -156,7 +157,8 @@ class TwitterMCPRAG(BaseRAGExample):
|
||||
print(sample_text)
|
||||
print("-" * 50)
|
||||
|
||||
return texts
|
||||
# Convert strings to dict format expected by base class
|
||||
return [{"text": text, "metadata": {"source": "twitter"}} for text in texts]
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error loading Twitter bookmarks: {e}")
|
||||
|
||||
@@ -6,6 +6,7 @@ Supports WeChat chat history export and search.
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
@@ -91,7 +92,7 @@ class WeChatRAG(BaseRAGExample):
|
||||
print(f"Export error: {e}")
|
||||
return False
|
||||
|
||||
async def load_data(self, args) -> list[str]:
|
||||
async def load_data(self, args) -> list[dict[str, Any]]:
|
||||
"""Load WeChat history and convert to text chunks."""
|
||||
# Initialize WeChat reader with export capabilities
|
||||
reader = WeChatHistoryReader()
|
||||
|
||||
200
docs/COLQWEN_GUIDE.md
Normal file
200
docs/COLQWEN_GUIDE.md
Normal file
@@ -0,0 +1,200 @@
|
||||
# ColQwen Integration Guide
|
||||
|
||||
Easy-to-use multimodal PDF retrieval with ColQwen2/ColPali models.
|
||||
|
||||
## Quick Start
|
||||
|
||||
> **🍎 Mac Users**: ColQwen is optimized for Apple Silicon with MPS acceleration for faster inference!
|
||||
|
||||
### 1. Install Dependencies
|
||||
```bash
|
||||
uv pip install colpali_engine pdf2image pillow matplotlib qwen_vl_utils einops seaborn
|
||||
brew install poppler # macOS only, for PDF processing
|
||||
```
|
||||
|
||||
### 2. Basic Usage
|
||||
```bash
|
||||
# Build index from PDFs
|
||||
python -m apps.colqwen_rag build --pdfs ./my_papers/ --index research_papers
|
||||
|
||||
# Search with text queries
|
||||
python -m apps.colqwen_rag search research_papers "How does attention mechanism work?"
|
||||
|
||||
# Interactive Q&A
|
||||
python -m apps.colqwen_rag ask research_papers --interactive
|
||||
```
|
||||
|
||||
## Commands
|
||||
|
||||
### Build Index
|
||||
```bash
|
||||
python -m apps.colqwen_rag build \
|
||||
--pdfs ./pdf_directory/ \
|
||||
--index my_index \
|
||||
--model colqwen2 \
|
||||
--pages-dir ./page_images/ # Optional: save page images
|
||||
```
|
||||
|
||||
**Options:**
|
||||
- `--pdfs`: Directory containing PDF files (or single PDF path)
|
||||
- `--index`: Name for the index (required)
|
||||
- `--model`: `colqwen2` (default) or `colpali`
|
||||
- `--pages-dir`: Directory to save page images (optional)
|
||||
|
||||
### Search Index
|
||||
```bash
|
||||
python -m apps.colqwen_rag search my_index "your question here" --top-k 5
|
||||
```
|
||||
|
||||
**Options:**
|
||||
- `--top-k`: Number of results to return (default: 5)
|
||||
- `--model`: Model used for search (should match build model)
|
||||
|
||||
### Interactive Q&A
|
||||
```bash
|
||||
python -m apps.colqwen_rag ask my_index --interactive
|
||||
```
|
||||
|
||||
**Commands in interactive mode:**
|
||||
- Type your questions naturally
|
||||
- `help`: Show available commands
|
||||
- `quit`/`exit`/`q`: Exit interactive mode
|
||||
|
||||
## 🧪 Test & Reproduce Results
|
||||
|
||||
Run the reproduction test for issue #119:
|
||||
```bash
|
||||
python test_colqwen_reproduction.py
|
||||
```
|
||||
|
||||
This will:
|
||||
1. ✅ Check dependencies
|
||||
2. 📥 Download sample PDF (Attention Is All You Need paper)
|
||||
3. 🏗️ Build test index
|
||||
4. 🔍 Run sample queries
|
||||
5. 📊 Show how to generate similarity maps
|
||||
|
||||
## 🎨 Advanced: Similarity Maps
|
||||
|
||||
For visual similarity analysis, use the existing advanced script:
|
||||
```bash
|
||||
cd apps/multimodal/vision-based-pdf-multi-vector/
|
||||
python multi-vector-leann-similarity-map.py
|
||||
```
|
||||
|
||||
Edit the script to customize:
|
||||
- `QUERY`: Your question
|
||||
- `MODEL`: "colqwen2" or "colpali"
|
||||
- `USE_HF_DATASET`: Use HuggingFace dataset or local PDFs
|
||||
- `SIMILARITY_MAP`: Generate heatmaps
|
||||
- `ANSWER`: Enable Qwen-VL answer generation
|
||||
|
||||
## 🔧 How It Works
|
||||
|
||||
### ColQwen2 vs ColPali
|
||||
- **ColQwen2** (`vidore/colqwen2-v1.0`): Latest vision-language model
|
||||
- **ColPali** (`vidore/colpali-v1.2`): Proven multimodal retriever
|
||||
|
||||
### Architecture
|
||||
1. **PDF → Images**: Convert PDF pages to images (150 DPI)
|
||||
2. **Vision Encoding**: Process images with ColQwen2/ColPali
|
||||
3. **Multi-Vector Index**: Build LEANN HNSW index with multiple embeddings per page
|
||||
4. **Query Processing**: Encode text queries with same model
|
||||
5. **Similarity Search**: Find most relevant pages/regions
|
||||
6. **Visual Maps**: Generate attention heatmaps (optional)
|
||||
|
||||
### Device Support
|
||||
- **CUDA**: Best performance with GPU acceleration
|
||||
- **MPS**: Apple Silicon Mac support
|
||||
- **CPU**: Fallback for any system (slower)
|
||||
|
||||
Auto-detection: CUDA > MPS > CPU
|
||||
|
||||
## 📊 Performance Tips
|
||||
|
||||
### For Best Performance:
|
||||
```bash
|
||||
# Use ColQwen2 for latest features
|
||||
--model colqwen2
|
||||
|
||||
# Save page images for reuse
|
||||
--pages-dir ./cached_pages/
|
||||
|
||||
# Adjust batch size based on GPU memory
|
||||
# (automatically handled)
|
||||
```
|
||||
|
||||
### For Large Document Sets:
|
||||
- Process PDFs in batches
|
||||
- Use SSD storage for index files
|
||||
- Consider using CUDA if available
|
||||
|
||||
## 🔗 Related Resources
|
||||
|
||||
- **Fast-PLAID**: https://github.com/lightonai/fast-plaid
|
||||
- **Pylate**: https://github.com/lightonai/pylate
|
||||
- **ColBERT**: https://github.com/stanford-futuredata/ColBERT
|
||||
- **ColPali Paper**: Vision-Language Models for Document Retrieval
|
||||
- **Issue #119**: https://github.com/yichuan-w/LEANN/issues/119
|
||||
|
||||
## 🐛 Troubleshooting
|
||||
|
||||
### PDF Conversion Issues (macOS)
|
||||
```bash
|
||||
# Install poppler
|
||||
brew install poppler
|
||||
which pdfinfo && pdfinfo -v
|
||||
```
|
||||
|
||||
### Memory Issues
|
||||
- Reduce batch size (automatically handled)
|
||||
- Use CPU instead of GPU: `export CUDA_VISIBLE_DEVICES=""`
|
||||
- Process fewer PDFs at once
|
||||
|
||||
### Model Download Issues
|
||||
- Ensure internet connection for first run
|
||||
- Models are cached after first download
|
||||
- Use HuggingFace mirrors if needed
|
||||
|
||||
### Import Errors
|
||||
```bash
|
||||
# Ensure all dependencies installed
|
||||
uv pip install colpali_engine pdf2image pillow matplotlib qwen_vl_utils einops seaborn
|
||||
|
||||
# Check PyTorch installation
|
||||
python -c "import torch; print(torch.__version__)"
|
||||
```
|
||||
|
||||
## 💡 Examples
|
||||
|
||||
### Research Paper Analysis
|
||||
```bash
|
||||
# Index your research papers
|
||||
python -m apps.colqwen_rag build --pdfs ~/Papers/AI/ --index ai_papers
|
||||
|
||||
# Ask research questions
|
||||
python -m apps.colqwen_rag search ai_papers "What are the limitations of transformer models?"
|
||||
python -m apps.colqwen_rag search ai_papers "How does BERT compare to GPT?"
|
||||
```
|
||||
|
||||
### Document Q&A
|
||||
```bash
|
||||
# Index business documents
|
||||
python -m apps.colqwen_rag build --pdfs ~/Documents/Reports/ --index reports
|
||||
|
||||
# Interactive analysis
|
||||
python -m apps.colqwen_rag ask reports --interactive
|
||||
```
|
||||
|
||||
### Visual Analysis
|
||||
```bash
|
||||
# Generate similarity maps for specific queries
|
||||
cd apps/multimodal/vision-based-pdf-multi-vector/
|
||||
# Edit multi-vector-leann-similarity-map.py with your query
|
||||
python multi-vector-leann-similarity-map.py
|
||||
# Check ./figures/ for generated heatmaps
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
**🎯 This integration makes ColQwen as easy to use as other LEANN features while maintaining the full power of multimodal document understanding!**
|
||||
@@ -454,7 +454,7 @@ leann search my-index "your query" \
|
||||
|
||||
### 2) Run remote builds with SkyPilot (cloud GPU)
|
||||
|
||||
Offload embedding generation and index building to a GPU VM using [SkyPilot](https://skypilot.readthedocs.io/en/latest/). A template is provided at `sky/leann-build.yaml`.
|
||||
Offload embedding generation and index building to a GPU VM using [SkyPilot](https://docs.skypilot.co/en/latest/docs/index.html). A template is provided at `sky/leann-build.yaml`.
|
||||
|
||||
```bash
|
||||
# One-time: install and configure SkyPilot
|
||||
|
||||
@@ -7,7 +7,7 @@ name = "leann-core"
|
||||
version = "0.3.5"
|
||||
description = "Core API and plugin system for LEANN"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.9"
|
||||
requires-python = ">=3.10"
|
||||
license = { text = "MIT" }
|
||||
|
||||
# All required dependencies included
|
||||
|
||||
@@ -1251,15 +1251,15 @@ class LeannChat:
|
||||
"Please provide the best answer you can based on this context and your knowledge."
|
||||
)
|
||||
|
||||
print("The context provided to the LLM is:")
|
||||
print(f"{'Relevance':<10} | {'Chunk id':<10} | {'Content':<60} | {'Source':<80}")
|
||||
print("-" * 150)
|
||||
logger.info("The context provided to the LLM is:")
|
||||
logger.info(f"{'Relevance':<10} | {'Chunk id':<10} | {'Content':<60} | {'Source':<80}")
|
||||
logger.info("-" * 150)
|
||||
for r in results:
|
||||
chunk_relevance = f"{r.score:.3f}"
|
||||
chunk_id = r.id
|
||||
chunk_content = r.text[:60]
|
||||
chunk_source = r.metadata.get("source", "")[:80]
|
||||
print(
|
||||
logger.info(
|
||||
f"{chunk_relevance:<10} | {chunk_id:<10} | {chunk_content:<60} | {chunk_source:<80}"
|
||||
)
|
||||
ask_time = time.time()
|
||||
|
||||
@@ -12,7 +12,13 @@ from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
|
||||
from .settings import (
|
||||
resolve_anthropic_api_key,
|
||||
resolve_anthropic_base_url,
|
||||
resolve_ollama_host,
|
||||
resolve_openai_api_key,
|
||||
resolve_openai_base_url,
|
||||
)
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
@@ -845,6 +851,81 @@ class OpenAIChat(LLMInterface):
|
||||
return f"Error: Could not get a response from OpenAI. Details: {e}"
|
||||
|
||||
|
||||
class AnthropicChat(LLMInterface):
|
||||
"""LLM interface for Anthropic Claude models."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "claude-haiku-4-5",
|
||||
api_key: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
):
|
||||
self.model = model
|
||||
self.base_url = resolve_anthropic_base_url(base_url)
|
||||
self.api_key = resolve_anthropic_api_key(api_key)
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError(
|
||||
"Anthropic API key is required. Set ANTHROPIC_API_KEY environment variable or pass api_key parameter."
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Initializing Anthropic Chat with model='%s' and base_url='%s'",
|
||||
model,
|
||||
self.base_url,
|
||||
)
|
||||
|
||||
try:
|
||||
import anthropic
|
||||
|
||||
# Allow custom Anthropic-compatible endpoints via base_url
|
||||
self.client = anthropic.Anthropic(
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url,
|
||||
)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"The 'anthropic' library is required for Anthropic models. Please install it with 'pip install anthropic'."
|
||||
)
|
||||
|
||||
def ask(self, prompt: str, **kwargs) -> str:
|
||||
logger.info(f"Sending request to Anthropic with model {self.model}")
|
||||
|
||||
try:
|
||||
# Anthropic API parameters
|
||||
params = {
|
||||
"model": self.model,
|
||||
"max_tokens": kwargs.get("max_tokens", 1000),
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
}
|
||||
|
||||
# Add optional parameters
|
||||
if "temperature" in kwargs:
|
||||
params["temperature"] = kwargs["temperature"]
|
||||
if "top_p" in kwargs:
|
||||
params["top_p"] = kwargs["top_p"]
|
||||
|
||||
response = self.client.messages.create(**params)
|
||||
|
||||
# Extract text from response
|
||||
response_text = response.content[0].text
|
||||
|
||||
# Log token usage
|
||||
print(
|
||||
f"Total tokens = {response.usage.input_tokens + response.usage.output_tokens}, "
|
||||
f"input tokens = {response.usage.input_tokens}, "
|
||||
f"output tokens = {response.usage.output_tokens}"
|
||||
)
|
||||
|
||||
if response.stop_reason == "max_tokens":
|
||||
print("The query is exceeding the maximum allowed number of tokens")
|
||||
|
||||
return response_text.strip()
|
||||
except Exception as e:
|
||||
logger.error(f"Error communicating with Anthropic: {e}")
|
||||
return f"Error: Could not get a response from Anthropic. Details: {e}"
|
||||
|
||||
|
||||
class SimulatedChat(LLMInterface):
|
||||
"""A simple simulated chat for testing and development."""
|
||||
|
||||
@@ -897,6 +978,12 @@ def get_llm(llm_config: Optional[dict[str, Any]] = None) -> LLMInterface:
|
||||
)
|
||||
elif llm_type == "gemini":
|
||||
return GeminiChat(model=model or "gemini-2.5-flash", api_key=llm_config.get("api_key"))
|
||||
elif llm_type == "anthropic":
|
||||
return AnthropicChat(
|
||||
model=model or "claude-3-5-sonnet-20241022",
|
||||
api_key=llm_config.get("api_key"),
|
||||
base_url=llm_config.get("base_url"),
|
||||
)
|
||||
elif llm_type == "simulated":
|
||||
return SimulatedChat()
|
||||
else:
|
||||
|
||||
@@ -239,11 +239,11 @@ def create_ast_chunks(
|
||||
|
||||
chunks = chunk_builder.chunkify(code_content)
|
||||
for chunk in chunks:
|
||||
chunk_text = None
|
||||
astchunk_metadata = {}
|
||||
chunk_text: str | None = None
|
||||
astchunk_metadata: dict[str, Any] = {}
|
||||
|
||||
if hasattr(chunk, "text"):
|
||||
chunk_text = chunk.text
|
||||
chunk_text = str(chunk.text) if chunk.text else None
|
||||
elif isinstance(chunk, str):
|
||||
chunk_text = chunk
|
||||
elif isinstance(chunk, dict):
|
||||
|
||||
@@ -11,10 +11,15 @@ from tqdm import tqdm
|
||||
from .api import LeannBuilder, LeannChat, LeannSearcher
|
||||
from .interactive_utils import create_cli_session
|
||||
from .registry import register_project_directory
|
||||
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
|
||||
from .settings import (
|
||||
resolve_anthropic_base_url,
|
||||
resolve_ollama_host,
|
||||
resolve_openai_api_key,
|
||||
resolve_openai_base_url,
|
||||
)
|
||||
|
||||
|
||||
def extract_pdf_text_with_pymupdf(file_path: str) -> str:
|
||||
def extract_pdf_text_with_pymupdf(file_path: str) -> str | None:
|
||||
"""Extract text from PDF using PyMuPDF for better quality."""
|
||||
try:
|
||||
import fitz # PyMuPDF
|
||||
@@ -30,7 +35,7 @@ def extract_pdf_text_with_pymupdf(file_path: str) -> str:
|
||||
return None
|
||||
|
||||
|
||||
def extract_pdf_text_with_pdfplumber(file_path: str) -> str:
|
||||
def extract_pdf_text_with_pdfplumber(file_path: str) -> str | None:
|
||||
"""Extract text from PDF using pdfplumber for better quality."""
|
||||
try:
|
||||
import pdfplumber
|
||||
@@ -291,7 +296,7 @@ Examples:
|
||||
"--llm",
|
||||
type=str,
|
||||
default="ollama",
|
||||
choices=["simulated", "ollama", "hf", "openai"],
|
||||
choices=["simulated", "ollama", "hf", "openai", "anthropic"],
|
||||
help="LLM provider (default: ollama)",
|
||||
)
|
||||
ask_parser.add_argument(
|
||||
@@ -341,7 +346,7 @@ Examples:
|
||||
"--api-key",
|
||||
type=str,
|
||||
default=None,
|
||||
help="API key for OpenAI-compatible APIs (defaults to OPENAI_API_KEY)",
|
||||
help="API key for cloud LLM providers (OpenAI, Anthropic)",
|
||||
)
|
||||
|
||||
# List command
|
||||
@@ -1616,6 +1621,12 @@ Examples:
|
||||
resolved_api_key = resolve_openai_api_key(args.api_key)
|
||||
if resolved_api_key:
|
||||
llm_config["api_key"] = resolved_api_key
|
||||
elif args.llm == "anthropic":
|
||||
# For Anthropic, pass base_url and API key if provided
|
||||
if args.api_base:
|
||||
llm_config["base_url"] = resolve_anthropic_base_url(args.api_base)
|
||||
if args.api_key:
|
||||
llm_config["api_key"] = args.api_key
|
||||
|
||||
chat = LeannChat(index_path=index_path, llm_config=llm_config)
|
||||
|
||||
|
||||
@@ -451,7 +451,8 @@ def compute_embeddings_sentence_transformers(
|
||||
# TODO: Haven't tested this yet
|
||||
torch.set_num_threads(min(8, os.cpu_count() or 4))
|
||||
try:
|
||||
torch.backends.mkldnn.enabled = True
|
||||
# PyTorch's ContextProp type is complex; cast for type checker
|
||||
torch.backends.mkldnn.enabled = True # type: ignore[assignment]
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
|
||||
@@ -11,14 +11,15 @@ from pathlib import Path
|
||||
from typing import Callable, Optional
|
||||
|
||||
# Try to import readline with fallback for Windows
|
||||
HAS_READLINE = False
|
||||
readline = None # type: ignore[assignment]
|
||||
try:
|
||||
import readline
|
||||
import readline # type: ignore[no-redef]
|
||||
|
||||
HAS_READLINE = True
|
||||
except ImportError:
|
||||
# Windows doesn't have readline by default
|
||||
HAS_READLINE = False
|
||||
readline = None
|
||||
pass
|
||||
|
||||
|
||||
class InteractiveSession:
|
||||
|
||||
@@ -7,7 +7,7 @@ operators for different data types including numbers, strings, booleans, and lis
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -47,7 +47,7 @@ class MetadataFilterEngine:
|
||||
}
|
||||
|
||||
def apply_filters(
|
||||
self, search_results: list[dict[str, Any]], metadata_filters: MetadataFilters
|
||||
self, search_results: list[dict[str, Any]], metadata_filters: Optional[MetadataFilters]
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Apply metadata filters to a list of search results.
|
||||
|
||||
@@ -56,7 +56,9 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
with open(meta_path, encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
|
||||
def _ensure_server_running(self, passages_source_file: str, port: int, **kwargs) -> int:
|
||||
def _ensure_server_running(
|
||||
self, passages_source_file: str, port: Optional[int], **kwargs
|
||||
) -> int:
|
||||
"""
|
||||
Ensures the embedding server is running if recompute is needed.
|
||||
This is a helper for subclasses.
|
||||
@@ -81,7 +83,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
}
|
||||
|
||||
server_started, actual_port = self.embedding_server_manager.start_server(
|
||||
port=port,
|
||||
port=port if port is not None else 5557,
|
||||
model_name=self.embedding_model,
|
||||
embedding_mode=self.embedding_mode,
|
||||
passages_file=passages_source_file,
|
||||
@@ -98,7 +100,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
self,
|
||||
query: str,
|
||||
use_server_if_available: bool = True,
|
||||
zmq_port: int = 5557,
|
||||
zmq_port: Optional[int] = None,
|
||||
query_template: Optional[str] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
|
||||
@@ -9,6 +9,7 @@ from typing import Any
|
||||
# Default fallbacks to preserve current behaviour while keeping them in one place.
|
||||
_DEFAULT_OLLAMA_HOST = "http://localhost:11434"
|
||||
_DEFAULT_OPENAI_BASE_URL = "https://api.openai.com/v1"
|
||||
_DEFAULT_ANTHROPIC_BASE_URL = "https://api.anthropic.com"
|
||||
|
||||
|
||||
def _clean_url(value: str) -> str:
|
||||
@@ -52,6 +53,23 @@ def resolve_openai_base_url(explicit: str | None = None) -> str:
|
||||
return _clean_url(_DEFAULT_OPENAI_BASE_URL)
|
||||
|
||||
|
||||
def resolve_anthropic_base_url(explicit: str | None = None) -> str:
|
||||
"""Resolve the base URL for Anthropic-compatible services."""
|
||||
|
||||
candidates = (
|
||||
explicit,
|
||||
os.getenv("LEANN_ANTHROPIC_BASE_URL"),
|
||||
os.getenv("ANTHROPIC_BASE_URL"),
|
||||
os.getenv("LOCAL_ANTHROPIC_BASE_URL"),
|
||||
)
|
||||
|
||||
for candidate in candidates:
|
||||
if candidate:
|
||||
return _clean_url(candidate)
|
||||
|
||||
return _clean_url(_DEFAULT_ANTHROPIC_BASE_URL)
|
||||
|
||||
|
||||
def resolve_openai_api_key(explicit: str | None = None) -> str | None:
|
||||
"""Resolve the API key for OpenAI-compatible services."""
|
||||
|
||||
@@ -61,6 +79,15 @@ def resolve_openai_api_key(explicit: str | None = None) -> str | None:
|
||||
return os.getenv("OPENAI_API_KEY")
|
||||
|
||||
|
||||
def resolve_anthropic_api_key(explicit: str | None = None) -> str | None:
|
||||
"""Resolve the API key for Anthropic services."""
|
||||
|
||||
if explicit:
|
||||
return explicit
|
||||
|
||||
return os.getenv("ANTHROPIC_API_KEY")
|
||||
|
||||
|
||||
def encode_provider_options(options: dict[str, Any] | None) -> str | None:
|
||||
"""Serialize provider options for child processes."""
|
||||
|
||||
|
||||
@@ -53,6 +53,11 @@ leann build my-project --docs $(git ls-files)
|
||||
# Start Claude Code
|
||||
claude
|
||||
```
|
||||
**Performance tip**: For maximum speed when storage space is not a concern, add the `--no-recompute` flag to your build command. This materializes all tensors and stores them on disk, avoiding recomputation on subsequent builds:
|
||||
|
||||
```bash
|
||||
leann build my-project --docs $(git ls-files) --no-recompute
|
||||
```
|
||||
|
||||
## 🚀 Advanced Usage Examples to build the index
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ name = "leann"
|
||||
version = "0.3.5"
|
||||
description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.9"
|
||||
requires-python = ">=3.10"
|
||||
license = { text = "MIT" }
|
||||
authors = [
|
||||
{ name = "LEANN Team" }
|
||||
@@ -18,10 +18,10 @@ classifiers = [
|
||||
"Intended Audience :: Developers",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Programming Language :: Python :: 3.13",
|
||||
]
|
||||
|
||||
# Default installation: core + hnsw + diskann
|
||||
|
||||
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
|
||||
[project]
|
||||
name = "leann-workspace"
|
||||
version = "0.1.0"
|
||||
requires-python = ">=3.9"
|
||||
requires-python = ">=3.10"
|
||||
|
||||
dependencies = [
|
||||
"leann-core",
|
||||
@@ -157,6 +157,19 @@ exclude = ["localhost", "127.0.0.1", "example.com"]
|
||||
exclude_path = [".git/", ".venv/", "__pycache__/", "third_party/"]
|
||||
scheme = ["https", "http"]
|
||||
|
||||
[tool.ty]
|
||||
# Type checking with ty (Astral's fast Python type checker)
|
||||
# ty is 10-100x faster than mypy. See: https://docs.astral.sh/ty/
|
||||
|
||||
[tool.ty.environment]
|
||||
python-version = "3.11"
|
||||
extra-paths = ["apps", "packages/leann-core/src"]
|
||||
|
||||
[tool.ty.rules]
|
||||
# Disable some noisy rules that have many false positives
|
||||
possibly-missing-attribute = "ignore"
|
||||
unresolved-import = "ignore" # Many optional dependencies
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
python_files = ["test_*.py"]
|
||||
|
||||
@@ -91,7 +91,7 @@ def test_large_index():
|
||||
builder.build_index(index_path)
|
||||
|
||||
searcher = LeannSearcher(index_path)
|
||||
results = searcher.search(["word10 word20"], top_k=10)
|
||||
assert len(results[0]) == 10
|
||||
results = searcher.search("word10 word20", top_k=10)
|
||||
assert len(results) == 10
|
||||
# Cleanup
|
||||
searcher.cleanup()
|
||||
|
||||
@@ -123,7 +123,7 @@ class TestPromptTemplateStoredInEmbeddingOptions:
|
||||
cli = LeannCLI()
|
||||
|
||||
# Mock load_documents to return a document so builder is created
|
||||
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
|
||||
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}]) # type: ignore[assignment]
|
||||
|
||||
parser = cli.create_parser()
|
||||
|
||||
@@ -175,7 +175,7 @@ class TestPromptTemplateStoredInEmbeddingOptions:
|
||||
cli = LeannCLI()
|
||||
|
||||
# Mock load_documents to return a document so builder is created
|
||||
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
|
||||
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}]) # type: ignore[assignment]
|
||||
|
||||
parser = cli.create_parser()
|
||||
|
||||
@@ -230,7 +230,7 @@ class TestPromptTemplateStoredInEmbeddingOptions:
|
||||
cli = LeannCLI()
|
||||
|
||||
# Mock load_documents to return a document so builder is created
|
||||
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
|
||||
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}]) # type: ignore[assignment]
|
||||
|
||||
parser = cli.create_parser()
|
||||
|
||||
@@ -307,7 +307,7 @@ class TestPromptTemplateStoredInEmbeddingOptions:
|
||||
cli = LeannCLI()
|
||||
|
||||
# Mock load_documents to return a document so builder is created
|
||||
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
|
||||
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}]) # type: ignore[assignment]
|
||||
|
||||
parser = cli.create_parser()
|
||||
|
||||
@@ -376,7 +376,7 @@ class TestPromptTemplateStoredInEmbeddingOptions:
|
||||
cli = LeannCLI()
|
||||
|
||||
# Mock load_documents to return a document so builder is created
|
||||
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
|
||||
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}]) # type: ignore[assignment]
|
||||
|
||||
parser = cli.create_parser()
|
||||
|
||||
@@ -432,7 +432,7 @@ class TestPromptTemplateFlowsToComputeEmbeddings:
|
||||
cli = LeannCLI()
|
||||
|
||||
# Mock load_documents to return a simple document
|
||||
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
|
||||
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}]) # type: ignore[assignment]
|
||||
|
||||
parser = cli.create_parser()
|
||||
|
||||
|
||||
@@ -67,7 +67,7 @@ def check_lmstudio_available() -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def get_lmstudio_first_model() -> str:
|
||||
def get_lmstudio_first_model() -> str | None:
|
||||
"""Get the first available model from LM Studio."""
|
||||
try:
|
||||
response = requests.get("http://localhost:1234/v1/models", timeout=5.0)
|
||||
@@ -91,6 +91,7 @@ class TestPromptTemplateOpenAI:
|
||||
model_name = get_lmstudio_first_model()
|
||||
if not model_name:
|
||||
pytest.skip("No models loaded in LM Studio")
|
||||
assert model_name is not None # Type narrowing for type checker
|
||||
|
||||
texts = ["artificial intelligence", "machine learning"]
|
||||
prompt_template = "search_query: "
|
||||
@@ -120,6 +121,7 @@ class TestPromptTemplateOpenAI:
|
||||
model_name = get_lmstudio_first_model()
|
||||
if not model_name:
|
||||
pytest.skip("No models loaded in LM Studio")
|
||||
assert model_name is not None # Type narrowing for type checker
|
||||
|
||||
text = "machine learning"
|
||||
base_url = "http://localhost:1234/v1"
|
||||
@@ -271,6 +273,7 @@ class TestLMStudioSDK:
|
||||
model_name = get_lmstudio_first_model()
|
||||
if not model_name:
|
||||
pytest.skip("No models loaded in LM Studio")
|
||||
assert model_name is not None # Type narrowing for type checker
|
||||
|
||||
try:
|
||||
from leann.embedding_compute import _query_lmstudio_context_limit
|
||||
|
||||
@@ -581,7 +581,18 @@ class TestQueryTemplateApplicationInComputeEmbedding:
|
||||
|
||||
# Create a concrete implementation for testing
|
||||
class TestSearcher(BaseSearcher):
|
||||
def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs):
|
||||
def search(
|
||||
self,
|
||||
query,
|
||||
top_k,
|
||||
complexity=64,
|
||||
beam_width=1,
|
||||
prune_ratio=0.0,
|
||||
recompute_embeddings=False,
|
||||
pruning_strategy="global",
|
||||
zmq_port=None,
|
||||
**kwargs,
|
||||
):
|
||||
return {"labels": [], "distances": []}
|
||||
|
||||
searcher = object.__new__(TestSearcher)
|
||||
@@ -625,7 +636,18 @@ class TestQueryTemplateApplicationInComputeEmbedding:
|
||||
|
||||
# Create a concrete implementation for testing
|
||||
class TestSearcher(BaseSearcher):
|
||||
def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs):
|
||||
def search(
|
||||
self,
|
||||
query,
|
||||
top_k,
|
||||
complexity=64,
|
||||
beam_width=1,
|
||||
prune_ratio=0.0,
|
||||
recompute_embeddings=False,
|
||||
pruning_strategy="global",
|
||||
zmq_port=None,
|
||||
**kwargs,
|
||||
):
|
||||
return {"labels": [], "distances": []}
|
||||
|
||||
searcher = object.__new__(TestSearcher)
|
||||
@@ -671,7 +693,18 @@ class TestQueryTemplateApplicationInComputeEmbedding:
|
||||
from leann.searcher_base import BaseSearcher
|
||||
|
||||
class TestSearcher(BaseSearcher):
|
||||
def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs):
|
||||
def search(
|
||||
self,
|
||||
query,
|
||||
top_k,
|
||||
complexity=64,
|
||||
beam_width=1,
|
||||
prune_ratio=0.0,
|
||||
recompute_embeddings=False,
|
||||
pruning_strategy="global",
|
||||
zmq_port=None,
|
||||
**kwargs,
|
||||
):
|
||||
return {"labels": [], "distances": []}
|
||||
|
||||
searcher = object.__new__(TestSearcher)
|
||||
@@ -710,7 +743,18 @@ class TestQueryTemplateApplicationInComputeEmbedding:
|
||||
from leann.searcher_base import BaseSearcher
|
||||
|
||||
class TestSearcher(BaseSearcher):
|
||||
def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs):
|
||||
def search(
|
||||
self,
|
||||
query,
|
||||
top_k,
|
||||
complexity=64,
|
||||
beam_width=1,
|
||||
prune_ratio=0.0,
|
||||
recompute_embeddings=False,
|
||||
pruning_strategy="global",
|
||||
zmq_port=None,
|
||||
**kwargs,
|
||||
):
|
||||
return {"labels": [], "distances": []}
|
||||
|
||||
searcher = object.__new__(TestSearcher)
|
||||
@@ -774,7 +818,18 @@ class TestQueryTemplateApplicationInComputeEmbedding:
|
||||
from leann.searcher_base import BaseSearcher
|
||||
|
||||
class TestSearcher(BaseSearcher):
|
||||
def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs):
|
||||
def search(
|
||||
self,
|
||||
query,
|
||||
top_k,
|
||||
complexity=64,
|
||||
beam_width=1,
|
||||
prune_ratio=0.0,
|
||||
recompute_embeddings=False,
|
||||
pruning_strategy="global",
|
||||
zmq_port=None,
|
||||
**kwargs,
|
||||
):
|
||||
return {"labels": [], "distances": []}
|
||||
|
||||
searcher = object.__new__(TestSearcher)
|
||||
|
||||
@@ -97,17 +97,17 @@ def test_backend_options():
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Use smaller model in CI to avoid memory issues
|
||||
if os.environ.get("CI") == "true":
|
||||
model_args = {
|
||||
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
|
||||
"dimensions": 384,
|
||||
}
|
||||
else:
|
||||
model_args = {}
|
||||
is_ci = os.environ.get("CI") == "true"
|
||||
embedding_model = (
|
||||
"sentence-transformers/all-MiniLM-L6-v2" if is_ci else "facebook/contriever"
|
||||
)
|
||||
dimensions = 384 if is_ci else None
|
||||
|
||||
# Test HNSW backend (as shown in README)
|
||||
hnsw_path = str(Path(temp_dir) / "test_hnsw.leann")
|
||||
builder_hnsw = LeannBuilder(backend_name="hnsw", **model_args)
|
||||
builder_hnsw = LeannBuilder(
|
||||
backend_name="hnsw", embedding_model=embedding_model, dimensions=dimensions
|
||||
)
|
||||
builder_hnsw.add_text("Test document for HNSW backend")
|
||||
builder_hnsw.build_index(hnsw_path)
|
||||
assert Path(hnsw_path).parent.exists()
|
||||
@@ -115,7 +115,9 @@ def test_backend_options():
|
||||
|
||||
# Test DiskANN backend (mentioned as available option)
|
||||
diskann_path = str(Path(temp_dir) / "test_diskann.leann")
|
||||
builder_diskann = LeannBuilder(backend_name="diskann", **model_args)
|
||||
builder_diskann = LeannBuilder(
|
||||
backend_name="diskann", embedding_model=embedding_model, dimensions=dimensions
|
||||
)
|
||||
builder_diskann.add_text("Test document for DiskANN backend")
|
||||
builder_diskann.build_index(diskann_path)
|
||||
assert Path(diskann_path).parent.exists()
|
||||
|
||||
Reference in New Issue
Block a user