Compare commits
4 Commits
embed-laun
...
fix-update
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fd5c052bd8 | ||
|
|
2f77d0185c | ||
|
|
82d536b2ae | ||
|
|
f42e086383 |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -105,6 +105,3 @@ apps/multimodal/vision-based-pdf-multi-vector/multi-vector-colpali-native-weavia
|
|||||||
# The following line used to force-add a large demo PDF; remove it to satisfy pre-commit:
|
# The following line used to force-add a large demo PDF; remove it to satisfy pre-commit:
|
||||||
# !apps/multimodal/vision-based-pdf-multi-vector/pdfs/2004.12832v2.pdf
|
# !apps/multimodal/vision-based-pdf-multi-vector/pdfs/2004.12832v2.pdf
|
||||||
!apps/multimodal/vision-based-pdf-multi-vector/fig/*
|
!apps/multimodal/vision-based-pdf-multi-vector/fig/*
|
||||||
|
|
||||||
# AUR build directory (Arch Linux)
|
|
||||||
paru-bin/
|
|
||||||
|
|||||||
401
README.md
401
README.md
@@ -8,12 +8,8 @@
|
|||||||
<img src="https://img.shields.io/badge/Platform-Ubuntu%20%26%20Arch%20%26%20WSL%20%7C%20macOS%20(ARM64%2FIntel)-lightgrey" alt="Platform">
|
<img src="https://img.shields.io/badge/Platform-Ubuntu%20%26%20Arch%20%26%20WSL%20%7C%20macOS%20(ARM64%2FIntel)-lightgrey" alt="Platform">
|
||||||
<img src="https://img.shields.io/badge/License-MIT-green.svg" alt="MIT License">
|
<img src="https://img.shields.io/badge/License-MIT-green.svg" alt="MIT License">
|
||||||
<img src="https://img.shields.io/badge/MCP-Native%20Integration-blue" alt="MCP Integration">
|
<img src="https://img.shields.io/badge/MCP-Native%20Integration-blue" alt="MCP Integration">
|
||||||
<a href="https://join.slack.com/t/leann-e2u9779/shared_invite/zt-3ckd2f6w1-OX08~NN4gkWhh10PRVBj1Q">
|
<a href="https://join.slack.com/t/leann-e2u9779/shared_invite/zt-3ckd2f6w1-OX08~NN4gkWhh10PRVBj1Q"><img src="https://img.shields.io/badge/Slack-Join-4A154B?logo=slack&logoColor=white" alt="Join Slack">
|
||||||
<img src="https://img.shields.io/badge/Slack-Join-4A154B?logo=slack&logoColor=white" alt="Join Slack">
|
<a href="assets/wechat_user_group.JPG" title="Join WeChat group"><img src="https://img.shields.io/badge/WeChat-Join-2DC100?logo=wechat&logoColor=white" alt="Join WeChat group"></a>
|
||||||
</a>
|
|
||||||
<a href="assets/wechat_user_group.JPG" title="Join WeChat group">
|
|
||||||
<img src="https://img.shields.io/badge/WeChat-Join-2DC100?logo=wechat&logoColor=white" alt="Join WeChat group">
|
|
||||||
</a>
|
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
<h2 align="center" tabindex="-1" class="heading-element" dir="auto">
|
<h2 align="center" tabindex="-1" class="heading-element" dir="auto">
|
||||||
@@ -24,7 +20,7 @@ LEANN is an innovative vector database that democratizes personal AI. Transform
|
|||||||
|
|
||||||
LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration Fig →](#️-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
|
LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration Fig →](#️-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
|
||||||
|
|
||||||
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can semantic search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)** ([WeChat](#-wechat-detective-unlock-your-golden-memories), [iMessage](#-imessage-history-your-personal-conversation-archive)), **[agent memory](#-chatgpt-chat-history-your-personal-ai-conversation-archive)** ([ChatGPT](#-chatgpt-chat-history-your-personal-ai-conversation-archive), [Claude](#-claude-chat-history-your-personal-ai-conversation-archive)), **[live data](#mcp-integration-rag-on-live-data-from-any-platform)** ([Slack](#mcp-integration-rag-on-live-data-from-any-platform), [Twitter](#mcp-integration-rag-on-live-data-from-any-platform)), **[codebase](#-claude-code-integration-transform-your-development-workflow)**\* , or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
|
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can semantic search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)**, **[codebase](#-claude-code-integration-transform-your-development-workflow)**\* , or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
|
||||||
|
|
||||||
|
|
||||||
\* Claude Code only supports basic `grep`-style keyword search. **LEANN** is a drop-in **semantic search MCP service fully compatible with Claude Code**, unlocking intelligent retrieval without changing your workflow. 🔥 Check out [the easy setup →](packages/leann-mcp/README.md)
|
\* Claude Code only supports basic `grep`-style keyword search. **LEANN** is a drop-in **semantic search MCP service fully compatible with Claude Code**, unlocking intelligent retrieval without changing your workflow. 🔥 Check out [the easy setup →](packages/leann-mcp/README.md)
|
||||||
@@ -76,9 +72,8 @@ uv venv
|
|||||||
source .venv/bin/activate
|
source .venv/bin/activate
|
||||||
uv pip install leann
|
uv pip install leann
|
||||||
```
|
```
|
||||||
|
|
||||||
<!--
|
<!--
|
||||||
> Low-resource? See "Low-resource setups" in the [Configuration Guide](docs/configuration-guide.md#low-resource-setups). -->
|
> Low-resource? See “Low-resource setups” in the [Configuration Guide](docs/configuration-guide.md#low-resource-setups). -->
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary>
|
<summary>
|
||||||
@@ -181,7 +176,7 @@ response = chat.ask("How much storage does LEANN save?", top_k=1)
|
|||||||
|
|
||||||
## RAG on Everything!
|
## RAG on Everything!
|
||||||
|
|
||||||
LEANN supports RAG on various data sources including documents (`.pdf`, `.txt`, `.md`), Apple Mail, Google Search History, WeChat, ChatGPT conversations, Claude conversations, iMessage conversations, and **live data from any platform through MCP (Model Context Protocol) servers** - including Slack, Twitter, and more.
|
LEANN supports RAG on various data sources including documents (`.pdf`, `.txt`, `.md`), Apple Mail, Google Search History, WeChat, and more.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -547,386 +542,10 @@ Once the index is built, you can ask questions like:
|
|||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
### 🤖 ChatGPT Chat History: Your Personal AI Conversation Archive!
|
|
||||||
|
|
||||||
Transform your ChatGPT conversations into a searchable knowledge base! Search through all your ChatGPT discussions about coding, research, brainstorming, and more.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python -m apps.chatgpt_rag --export-path chatgpt_export.html --query "How do I create a list in Python?"
|
|
||||||
```
|
|
||||||
|
|
||||||
**Unlock your AI conversation history.** Never lose track of valuable insights from your ChatGPT discussions again.
|
|
||||||
|
|
||||||
<details>
|
|
||||||
<summary><strong>📋 Click to expand: How to Export ChatGPT Data</strong></summary>
|
|
||||||
|
|
||||||
**Step-by-step export process:**
|
|
||||||
|
|
||||||
1. **Sign in to ChatGPT**
|
|
||||||
2. **Click your profile icon** in the top right corner
|
|
||||||
3. **Navigate to Settings** → **Data Controls**
|
|
||||||
4. **Click "Export"** under Export Data
|
|
||||||
5. **Confirm the export** request
|
|
||||||
6. **Download the ZIP file** from the email link (expires in 24 hours)
|
|
||||||
7. **Extract or use directly** with LEANN
|
|
||||||
|
|
||||||
**Supported formats:**
|
|
||||||
- `.html` files from ChatGPT exports
|
|
||||||
- `.zip` archives from ChatGPT
|
|
||||||
- Directories with multiple export files
|
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
<details>
|
|
||||||
<summary><strong>📋 Click to expand: ChatGPT-Specific Arguments</strong></summary>
|
|
||||||
|
|
||||||
#### Parameters
|
|
||||||
```bash
|
|
||||||
--export-path PATH # Path to ChatGPT export file (.html/.zip) or directory (default: ./chatgpt_export)
|
|
||||||
--separate-messages # Process each message separately instead of concatenated conversations
|
|
||||||
--chunk-size N # Text chunk size (default: 512)
|
|
||||||
--chunk-overlap N # Overlap between chunks (default: 128)
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Example Commands
|
|
||||||
```bash
|
|
||||||
# Basic usage with HTML export
|
|
||||||
python -m apps.chatgpt_rag --export-path conversations.html
|
|
||||||
|
|
||||||
# Process ZIP archive from ChatGPT
|
|
||||||
python -m apps.chatgpt_rag --export-path chatgpt_export.zip
|
|
||||||
|
|
||||||
# Search with specific query
|
|
||||||
python -m apps.chatgpt_rag --export-path chatgpt_data.html --query "Python programming help"
|
|
||||||
|
|
||||||
# Process individual messages for fine-grained search
|
|
||||||
python -m apps.chatgpt_rag --separate-messages --export-path chatgpt_export.html
|
|
||||||
|
|
||||||
# Process directory containing multiple exports
|
|
||||||
python -m apps.chatgpt_rag --export-path ./chatgpt_exports/ --max-items 1000
|
|
||||||
```
|
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
<details>
|
|
||||||
<summary><strong>💡 Click to expand: Example queries you can try</strong></summary>
|
|
||||||
|
|
||||||
Once your ChatGPT conversations are indexed, you can search with queries like:
|
|
||||||
- "What did I ask ChatGPT about Python programming?"
|
|
||||||
- "Show me conversations about machine learning algorithms"
|
|
||||||
- "Find discussions about web development frameworks"
|
|
||||||
- "What coding advice did ChatGPT give me?"
|
|
||||||
- "Search for conversations about debugging techniques"
|
|
||||||
- "Find ChatGPT's recommendations for learning resources"
|
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
### 🤖 Claude Chat History: Your Personal AI Conversation Archive!
|
|
||||||
|
|
||||||
Transform your Claude conversations into a searchable knowledge base! Search through all your Claude discussions about coding, research, brainstorming, and more.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python -m apps.claude_rag --export-path claude_export.json --query "What did I ask about Python dictionaries?"
|
|
||||||
```
|
|
||||||
|
|
||||||
**Unlock your AI conversation history.** Never lose track of valuable insights from your Claude discussions again.
|
|
||||||
|
|
||||||
<details>
|
|
||||||
<summary><strong>📋 Click to expand: How to Export Claude Data</strong></summary>
|
|
||||||
|
|
||||||
**Step-by-step export process:**
|
|
||||||
|
|
||||||
1. **Open Claude** in your browser
|
|
||||||
2. **Navigate to Settings** (look for gear icon or settings menu)
|
|
||||||
3. **Find Export/Download** options in your account settings
|
|
||||||
4. **Download conversation data** (usually in JSON format)
|
|
||||||
5. **Place the file** in your project directory
|
|
||||||
|
|
||||||
*Note: Claude export methods may vary depending on the interface you're using. Check Claude's help documentation for the most current export instructions.*
|
|
||||||
|
|
||||||
**Supported formats:**
|
|
||||||
- `.json` files (recommended)
|
|
||||||
- `.zip` archives containing JSON data
|
|
||||||
- Directories with multiple export files
|
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
<details>
|
|
||||||
<summary><strong>📋 Click to expand: Claude-Specific Arguments</strong></summary>
|
|
||||||
|
|
||||||
#### Parameters
|
|
||||||
```bash
|
|
||||||
--export-path PATH # Path to Claude export file (.json/.zip) or directory (default: ./claude_export)
|
|
||||||
--separate-messages # Process each message separately instead of concatenated conversations
|
|
||||||
--chunk-size N # Text chunk size (default: 512)
|
|
||||||
--chunk-overlap N # Overlap between chunks (default: 128)
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Example Commands
|
|
||||||
```bash
|
|
||||||
# Basic usage with JSON export
|
|
||||||
python -m apps.claude_rag --export-path my_claude_conversations.json
|
|
||||||
|
|
||||||
# Process ZIP archive from Claude
|
|
||||||
python -m apps.claude_rag --export-path claude_export.zip
|
|
||||||
|
|
||||||
# Search with specific query
|
|
||||||
python -m apps.claude_rag --export-path claude_data.json --query "machine learning advice"
|
|
||||||
|
|
||||||
# Process individual messages for fine-grained search
|
|
||||||
python -m apps.claude_rag --separate-messages --export-path claude_export.json
|
|
||||||
|
|
||||||
# Process directory containing multiple exports
|
|
||||||
python -m apps.claude_rag --export-path ./claude_exports/ --max-items 1000
|
|
||||||
```
|
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
<details>
|
|
||||||
<summary><strong>💡 Click to expand: Example queries you can try</strong></summary>
|
|
||||||
|
|
||||||
Once your Claude conversations are indexed, you can search with queries like:
|
|
||||||
- "What did I ask Claude about Python programming?"
|
|
||||||
- "Show me conversations about machine learning algorithms"
|
|
||||||
- "Find discussions about software architecture patterns"
|
|
||||||
- "What debugging advice did Claude give me?"
|
|
||||||
- "Search for conversations about data structures"
|
|
||||||
- "Find Claude's recommendations for learning resources"
|
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
### 💬 iMessage History: Your Personal Conversation Archive!
|
|
||||||
|
|
||||||
Transform your iMessage conversations into a searchable knowledge base! Search through all your text messages, group chats, and conversations with friends, family, and colleagues.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python -m apps.imessage_rag --query "What did we discuss about the weekend plans?"
|
|
||||||
```
|
|
||||||
|
|
||||||
**Unlock your message history.** Never lose track of important conversations, shared links, or memorable moments from your iMessage history.
|
|
||||||
|
|
||||||
<details>
|
|
||||||
<summary><strong>📋 Click to expand: How to Access iMessage Data</strong></summary>
|
|
||||||
|
|
||||||
**iMessage data location:**
|
|
||||||
|
|
||||||
iMessage conversations are stored in a SQLite database on your Mac at:
|
|
||||||
```
|
|
||||||
~/Library/Messages/chat.db
|
|
||||||
```
|
|
||||||
|
|
||||||
**Important setup requirements:**
|
|
||||||
|
|
||||||
1. **Grant Full Disk Access** to your terminal or IDE:
|
|
||||||
- Open **System Preferences** → **Security & Privacy** → **Privacy**
|
|
||||||
- Select **Full Disk Access** from the left sidebar
|
|
||||||
- Click the **+** button and add your terminal app (Terminal, iTerm2) or IDE (VS Code, etc.)
|
|
||||||
- Restart your terminal/IDE after granting access
|
|
||||||
|
|
||||||
2. **Alternative: Use a backup database**
|
|
||||||
- If you have Time Machine backups or manual copies of the database
|
|
||||||
- Use `--db-path` to specify a custom location
|
|
||||||
|
|
||||||
**Supported formats:**
|
|
||||||
- Direct access to `~/Library/Messages/chat.db` (default)
|
|
||||||
- Custom database path with `--db-path`
|
|
||||||
- Works with backup copies of the database
|
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
<details>
|
|
||||||
<summary><strong>📋 Click to expand: iMessage-Specific Arguments</strong></summary>
|
|
||||||
|
|
||||||
#### Parameters
|
|
||||||
```bash
|
|
||||||
--db-path PATH # Path to chat.db file (default: ~/Library/Messages/chat.db)
|
|
||||||
--concatenate-conversations # Group messages by conversation (default: True)
|
|
||||||
--no-concatenate-conversations # Process each message individually
|
|
||||||
--chunk-size N # Text chunk size (default: 1000)
|
|
||||||
--chunk-overlap N # Overlap between chunks (default: 200)
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Example Commands
|
|
||||||
```bash
|
|
||||||
# Basic usage (requires Full Disk Access)
|
|
||||||
python -m apps.imessage_rag
|
|
||||||
|
|
||||||
# Search with specific query
|
|
||||||
python -m apps.imessage_rag --query "family dinner plans"
|
|
||||||
|
|
||||||
# Use custom database path
|
|
||||||
python -m apps.imessage_rag --db-path /path/to/backup/chat.db
|
|
||||||
|
|
||||||
# Process individual messages instead of conversations
|
|
||||||
python -m apps.imessage_rag --no-concatenate-conversations
|
|
||||||
|
|
||||||
# Limit processing for testing
|
|
||||||
python -m apps.imessage_rag --max-items 100 --query "weekend"
|
|
||||||
```
|
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
<details>
|
|
||||||
<summary><strong>💡 Click to expand: Example queries you can try</strong></summary>
|
|
||||||
|
|
||||||
Once your iMessage conversations are indexed, you can search with queries like:
|
|
||||||
- "What did we discuss about vacation plans?"
|
|
||||||
- "Find messages about restaurant recommendations"
|
|
||||||
- "Show me conversations with John about the project"
|
|
||||||
- "Search for shared links about technology"
|
|
||||||
- "Find group chat discussions about weekend events"
|
|
||||||
- "What did mom say about the family gathering?"
|
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
### MCP Integration: RAG on Live Data from Any Platform
|
|
||||||
|
|
||||||
Connect to live data sources through the Model Context Protocol (MCP). LEANN now supports real-time RAG on platforms like Slack, Twitter, and more through standardized MCP servers.
|
|
||||||
|
|
||||||
**Key Benefits:**
|
|
||||||
- **Live Data Access**: Fetch real-time data without manual exports
|
|
||||||
- **Standardized Protocol**: Use any MCP-compatible server
|
|
||||||
- **Easy Extension**: Add new platforms with minimal code
|
|
||||||
- **Secure Access**: MCP servers handle authentication
|
|
||||||
|
|
||||||
#### 💬 Slack Messages: Search Your Team Conversations
|
|
||||||
|
|
||||||
Transform your Slack workspace into a searchable knowledge base! Find discussions, decisions, and shared knowledge across all your channels.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Test MCP server connection
|
|
||||||
python -m apps.slack_rag --mcp-server "slack-mcp-server" --test-connection
|
|
||||||
|
|
||||||
# Index and search Slack messages
|
|
||||||
python -m apps.slack_rag \
|
|
||||||
--mcp-server "slack-mcp-server" \
|
|
||||||
--workspace-name "my-team" \
|
|
||||||
--channels general dev-team random \
|
|
||||||
--query "What did we decide about the product launch?"
|
|
||||||
```
|
|
||||||
|
|
||||||
**📖 Comprehensive Setup Guide**: For detailed setup instructions, troubleshooting common issues (like "users cache is not ready yet"), and advanced configuration options, see our [**Slack Setup Guide**](docs/slack-setup-guide.md).
|
|
||||||
|
|
||||||
**Quick Setup:**
|
|
||||||
1. Install a Slack MCP server (e.g., `npm install -g slack-mcp-server`)
|
|
||||||
2. Create a Slack App and get API credentials (see detailed guide above)
|
|
||||||
3. Set environment variables:
|
|
||||||
```bash
|
|
||||||
export SLACK_BOT_TOKEN="xoxb-your-bot-token"
|
|
||||||
export SLACK_APP_TOKEN="xapp-your-app-token" # Optional
|
|
||||||
```
|
|
||||||
4. Test connection with `--test-connection` flag
|
|
||||||
|
|
||||||
**Arguments:**
|
|
||||||
- `--mcp-server`: Command to start the Slack MCP server
|
|
||||||
- `--workspace-name`: Slack workspace name for organization
|
|
||||||
- `--channels`: Specific channels to index (optional)
|
|
||||||
- `--concatenate-conversations`: Group messages by channel (default: true)
|
|
||||||
- `--max-messages-per-channel`: Limit messages per channel (default: 100)
|
|
||||||
- `--max-retries`: Maximum retries for cache sync issues (default: 5)
|
|
||||||
- `--retry-delay`: Initial delay between retries in seconds (default: 2.0)
|
|
||||||
|
|
||||||
#### 🐦 Twitter Bookmarks: Your Personal Tweet Library
|
|
||||||
|
|
||||||
Search through your Twitter bookmarks! Find that perfect article, thread, or insight you saved for later.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Test MCP server connection
|
|
||||||
python -m apps.twitter_rag --mcp-server "twitter-mcp-server" --test-connection
|
|
||||||
|
|
||||||
# Index and search Twitter bookmarks
|
|
||||||
python -m apps.twitter_rag \
|
|
||||||
--mcp-server "twitter-mcp-server" \
|
|
||||||
--max-bookmarks 1000 \
|
|
||||||
--query "What AI articles did I bookmark about machine learning?"
|
|
||||||
```
|
|
||||||
|
|
||||||
**Setup Requirements:**
|
|
||||||
1. Install a Twitter MCP server (e.g., `npm install -g twitter-mcp-server`)
|
|
||||||
2. Get Twitter API credentials:
|
|
||||||
- Apply for a Twitter Developer Account at [developer.twitter.com](https://developer.twitter.com)
|
|
||||||
- Create a new app in the Twitter Developer Portal
|
|
||||||
- Generate API keys and access tokens with "Read" permissions
|
|
||||||
- For bookmarks access, you may need Twitter API v2 with appropriate scopes
|
|
||||||
```bash
|
|
||||||
export TWITTER_API_KEY="your-api-key"
|
|
||||||
export TWITTER_API_SECRET="your-api-secret"
|
|
||||||
export TWITTER_ACCESS_TOKEN="your-access-token"
|
|
||||||
export TWITTER_ACCESS_TOKEN_SECRET="your-access-token-secret"
|
|
||||||
```
|
|
||||||
3. Test connection with `--test-connection` flag
|
|
||||||
|
|
||||||
**Arguments:**
|
|
||||||
- `--mcp-server`: Command to start the Twitter MCP server
|
|
||||||
- `--username`: Filter bookmarks by username (optional)
|
|
||||||
- `--max-bookmarks`: Maximum bookmarks to fetch (default: 1000)
|
|
||||||
- `--no-tweet-content`: Exclude tweet content, only metadata
|
|
||||||
- `--no-metadata`: Exclude engagement metadata
|
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
<details>
|
|
||||||
<summary><strong>💡 Click to expand: Example queries you can try</strong></summary>
|
|
||||||
|
|
||||||
**Slack Queries:**
|
|
||||||
- "What did the team discuss about the project deadline?"
|
|
||||||
- "Find messages about the new feature launch"
|
|
||||||
- "Show me conversations about budget planning"
|
|
||||||
- "What decisions were made in the dev-team channel?"
|
|
||||||
|
|
||||||
**Twitter Queries:**
|
|
||||||
- "What AI articles did I bookmark last month?"
|
|
||||||
- "Find tweets about machine learning techniques"
|
|
||||||
- "Show me bookmarked threads about startup advice"
|
|
||||||
- "What Python tutorials did I save?"
|
|
||||||
|
|
||||||
</details>
|
|
||||||
<summary><strong>🔧 Using MCP with CLI Commands</strong></summary>
|
|
||||||
|
|
||||||
**Want to use MCP data with regular LEANN CLI?** You can combine MCP apps with CLI commands:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Step 1: Use MCP app to fetch and index data
|
|
||||||
python -m apps.slack_rag --mcp-server "slack-mcp-server" --workspace-name "my-team"
|
|
||||||
|
|
||||||
# Step 2: The data is now indexed and available via CLI
|
|
||||||
leann search slack_messages "project deadline"
|
|
||||||
leann ask slack_messages "What decisions were made about the product launch?"
|
|
||||||
|
|
||||||
# Same for Twitter bookmarks
|
|
||||||
python -m apps.twitter_rag --mcp-server "twitter-mcp-server"
|
|
||||||
leann search twitter_bookmarks "machine learning articles"
|
|
||||||
```
|
|
||||||
|
|
||||||
**MCP vs Manual Export:**
|
|
||||||
- **MCP**: Live data, automatic updates, requires server setup
|
|
||||||
- **Manual Export**: One-time setup, works offline, requires manual data export
|
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
<details>
|
|
||||||
<summary><strong>🔧 Adding New MCP Platforms</strong></summary>
|
|
||||||
|
|
||||||
Want to add support for other platforms? LEANN's MCP integration is designed for easy extension:
|
|
||||||
|
|
||||||
1. **Find or create an MCP server** for your platform
|
|
||||||
2. **Create a reader class** following the pattern in `apps/slack_data/slack_mcp_reader.py`
|
|
||||||
3. **Create a RAG application** following the pattern in `apps/slack_rag.py`
|
|
||||||
4. **Test and contribute** back to the community!
|
|
||||||
|
|
||||||
**Popular MCP servers to explore:**
|
|
||||||
- GitHub repositories and issues
|
|
||||||
- Discord messages
|
|
||||||
- Notion pages
|
|
||||||
- Google Drive documents
|
|
||||||
- And many more in the MCP ecosystem!
|
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
### 🚀 Claude Code Integration: Transform Your Development Workflow!
|
### 🚀 Claude Code Integration: Transform Your Development Workflow!
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary><strong>AST‑Aware Code Chunking</strong></summary>
|
<summary><strong>NEW!! AST‑Aware Code Chunking</strong></summary>
|
||||||
|
|
||||||
LEANN features intelligent code chunking that preserves semantic boundaries (functions, classes, methods) for Python, Java, C#, and TypeScript, improving code understanding compared to text-based chunking.
|
LEANN features intelligent code chunking that preserves semantic boundaries (functions, classes, methods) for Python, Java, C#, and TypeScript, improving code understanding compared to text-based chunking.
|
||||||
|
|
||||||
@@ -954,7 +573,7 @@ Try our fully agentic pipeline with auto query rewriting, semantic search planni
|
|||||||
|
|
||||||
**🔥 Ready to supercharge your coding?** [Complete Setup Guide →](packages/leann-mcp/README.md)
|
**🔥 Ready to supercharge your coding?** [Complete Setup Guide →](packages/leann-mcp/README.md)
|
||||||
|
|
||||||
## Command Line Interface
|
## 🖥️ Command Line Interface
|
||||||
|
|
||||||
LEANN includes a powerful CLI for document processing and search. Perfect for quick document indexing and interactive chat.
|
LEANN includes a powerful CLI for document processing and search. Perfect for quick document indexing and interactive chat.
|
||||||
|
|
||||||
@@ -1196,7 +815,7 @@ MIT License - see [LICENSE](LICENSE) for details.
|
|||||||
|
|
||||||
Core Contributors: [Yichuan Wang](https://yichuan-w.github.io/) & [Zhifei Li](https://github.com/andylizf).
|
Core Contributors: [Yichuan Wang](https://yichuan-w.github.io/) & [Zhifei Li](https://github.com/andylizf).
|
||||||
|
|
||||||
Active Contributors: [Gabriel Dehan](https://github.com/gabriel-dehan), [Aakash Suresh](https://github.com/ASuresh0524)
|
Active Contributors: [Gabriel Dehan](https://github.com/gabriel-dehan)
|
||||||
|
|
||||||
|
|
||||||
We welcome more contributors! Feel free to open issues or submit PRs.
|
We welcome more contributors! Feel free to open issues or submit PRs.
|
||||||
@@ -1213,7 +832,3 @@ This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.ed
|
|||||||
<p align="center">
|
<p align="center">
|
||||||
Made with ❤️ by the Leann team
|
Made with ❤️ by the Leann team
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
## 🤖 Explore LEANN with AI
|
|
||||||
|
|
||||||
LEANN is indexed on [DeepWiki](https://deepwiki.com/yichuan-w/LEANN), so you can ask questions to LLMs using Deep Research to explore the codebase and get help to add new features.
|
|
||||||
|
|||||||
@@ -10,39 +10,8 @@ from typing import Any
|
|||||||
|
|
||||||
import dotenv
|
import dotenv
|
||||||
from leann.api import LeannBuilder, LeannChat
|
from leann.api import LeannBuilder, LeannChat
|
||||||
|
|
||||||
# Optional import: older PyPI builds may not include interactive_utils
|
|
||||||
try:
|
|
||||||
from leann.interactive_utils import create_rag_session
|
|
||||||
except ImportError:
|
|
||||||
|
|
||||||
def create_rag_session(app_name: str, data_description: str):
|
|
||||||
class _SimpleSession:
|
|
||||||
def run_interactive_loop(self, handler):
|
|
||||||
print(f"Interactive session for {app_name}: {data_description}")
|
|
||||||
print("Interactive mode not available in this build")
|
|
||||||
|
|
||||||
return _SimpleSession()
|
|
||||||
|
|
||||||
|
|
||||||
from leann.registry import register_project_directory
|
from leann.registry import register_project_directory
|
||||||
|
|
||||||
# Optional import: older PyPI builds may not include settings
|
|
||||||
try:
|
|
||||||
from leann.settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
|
from leann.settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
|
||||||
except ImportError:
|
|
||||||
# Minimal fallbacks if settings helpers are unavailable
|
|
||||||
import os
|
|
||||||
|
|
||||||
def resolve_ollama_host(value: str | None) -> str | None:
|
|
||||||
return value or os.getenv("LEANN_OLLAMA_HOST") or os.getenv("OLLAMA_HOST")
|
|
||||||
|
|
||||||
def resolve_openai_api_key(value: str | None) -> str | None:
|
|
||||||
return value or os.getenv("OPENAI_API_KEY")
|
|
||||||
|
|
||||||
def resolve_openai_base_url(value: str | None) -> str | None:
|
|
||||||
return value or os.getenv("OPENAI_BASE_URL")
|
|
||||||
|
|
||||||
|
|
||||||
dotenv.load_dotenv()
|
dotenv.load_dotenv()
|
||||||
|
|
||||||
@@ -180,14 +149,14 @@ class BaseRAGExample(ABC):
|
|||||||
ast_group.add_argument(
|
ast_group.add_argument(
|
||||||
"--ast-chunk-size",
|
"--ast-chunk-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=300,
|
default=512,
|
||||||
help="Maximum CHARACTERS per AST chunk (default: 300). Final chunks may be larger due to overlap. For 512 token models: recommended 300 chars",
|
help="Maximum characters per AST chunk (default: 512)",
|
||||||
)
|
)
|
||||||
ast_group.add_argument(
|
ast_group.add_argument(
|
||||||
"--ast-chunk-overlap",
|
"--ast-chunk-overlap",
|
||||||
type=int,
|
type=int,
|
||||||
default=64,
|
default=64,
|
||||||
help="Overlap between AST chunks in CHARACTERS (default: 64). Added to chunk size, not included in it",
|
help="Overlap between AST chunks (default: 64)",
|
||||||
)
|
)
|
||||||
ast_group.add_argument(
|
ast_group.add_argument(
|
||||||
"--code-file-extensions",
|
"--code-file-extensions",
|
||||||
@@ -338,12 +307,19 @@ class BaseRAGExample(ABC):
|
|||||||
complexity=args.search_complexity,
|
complexity=args.search_complexity,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create interactive session
|
print(f"\n[Interactive Mode] Chat with your {self.name} data!")
|
||||||
session = create_rag_session(
|
print("Type 'quit' or 'exit' to stop.\n")
|
||||||
app_name=self.name.lower().replace(" ", "_"), data_description=self.name
|
|
||||||
)
|
while True:
|
||||||
|
try:
|
||||||
|
query = input("You: ").strip()
|
||||||
|
if query.lower() in ["quit", "exit", "q"]:
|
||||||
|
print("Goodbye!")
|
||||||
|
break
|
||||||
|
|
||||||
|
if not query:
|
||||||
|
continue
|
||||||
|
|
||||||
def handle_query(query: str):
|
|
||||||
# Prepare LLM kwargs with thinking budget if specified
|
# Prepare LLM kwargs with thinking budget if specified
|
||||||
llm_kwargs = {}
|
llm_kwargs = {}
|
||||||
if hasattr(args, "thinking_budget") and args.thinking_budget:
|
if hasattr(args, "thinking_budget") and args.thinking_budget:
|
||||||
@@ -357,7 +333,11 @@ class BaseRAGExample(ABC):
|
|||||||
)
|
)
|
||||||
print(f"\nAssistant: {response}\n")
|
print(f"\nAssistant: {response}\n")
|
||||||
|
|
||||||
session.run_interactive_loop(handle_query)
|
except KeyboardInterrupt:
|
||||||
|
print("\nGoodbye!")
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error: {e}")
|
||||||
|
|
||||||
async def run_single_query(self, args, index_path: str, query: str):
|
async def run_single_query(self, args, index_path: str, query: str):
|
||||||
"""Run a single query against the index."""
|
"""Run a single query against the index."""
|
||||||
|
|||||||
@@ -1,413 +0,0 @@
|
|||||||
"""
|
|
||||||
ChatGPT export data reader.
|
|
||||||
|
|
||||||
Reads and processes ChatGPT export data from chat.html files.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import re
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
|
||||||
from zipfile import ZipFile
|
|
||||||
|
|
||||||
from bs4 import BeautifulSoup
|
|
||||||
from llama_index.core import Document
|
|
||||||
from llama_index.core.readers.base import BaseReader
|
|
||||||
|
|
||||||
|
|
||||||
class ChatGPTReader(BaseReader):
|
|
||||||
"""
|
|
||||||
ChatGPT export data reader.
|
|
||||||
|
|
||||||
Reads ChatGPT conversation data from exported chat.html files or zip archives.
|
|
||||||
Processes conversations into structured documents with metadata.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, concatenate_conversations: bool = True) -> None:
|
|
||||||
"""
|
|
||||||
Initialize.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
concatenate_conversations: Whether to concatenate messages within conversations for better context
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
from bs4 import BeautifulSoup # noqa
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError("`beautifulsoup4` package not found: `pip install beautifulsoup4`")
|
|
||||||
|
|
||||||
self.concatenate_conversations = concatenate_conversations
|
|
||||||
|
|
||||||
def _extract_html_from_zip(self, zip_path: Path) -> str | None:
|
|
||||||
"""
|
|
||||||
Extract chat.html from ChatGPT export zip file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
zip_path: Path to the ChatGPT export zip file
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
HTML content as string, or None if not found
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
with ZipFile(zip_path, "r") as zip_file:
|
|
||||||
# Look for chat.html or conversations.html
|
|
||||||
html_files = [
|
|
||||||
f
|
|
||||||
for f in zip_file.namelist()
|
|
||||||
if f.endswith(".html") and ("chat" in f.lower() or "conversation" in f.lower())
|
|
||||||
]
|
|
||||||
|
|
||||||
if not html_files:
|
|
||||||
print(f"No HTML chat file found in {zip_path}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Use the first HTML file found
|
|
||||||
html_file = html_files[0]
|
|
||||||
print(f"Found HTML file: {html_file}")
|
|
||||||
|
|
||||||
with zip_file.open(html_file) as f:
|
|
||||||
return f.read().decode("utf-8", errors="ignore")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error extracting HTML from zip {zip_path}: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _parse_chatgpt_html(self, html_content: str) -> list[dict]:
|
|
||||||
"""
|
|
||||||
Parse ChatGPT HTML export to extract conversations.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
html_content: HTML content from ChatGPT export
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of conversation dictionaries
|
|
||||||
"""
|
|
||||||
soup = BeautifulSoup(html_content, "html.parser")
|
|
||||||
conversations = []
|
|
||||||
|
|
||||||
# Try different possible structures for ChatGPT exports
|
|
||||||
# Structure 1: Look for conversation containers
|
|
||||||
conversation_containers = soup.find_all(
|
|
||||||
["div", "section"], class_=re.compile(r"conversation|chat", re.I)
|
|
||||||
)
|
|
||||||
|
|
||||||
if not conversation_containers:
|
|
||||||
# Structure 2: Look for message containers directly
|
|
||||||
conversation_containers = [soup] # Use the entire document as one conversation
|
|
||||||
|
|
||||||
for container in conversation_containers:
|
|
||||||
conversation = self._extract_conversation_from_container(container)
|
|
||||||
if conversation and conversation.get("messages"):
|
|
||||||
conversations.append(conversation)
|
|
||||||
|
|
||||||
# If no structured conversations found, try to extract all text as one conversation
|
|
||||||
if not conversations:
|
|
||||||
all_text = soup.get_text(separator="\n", strip=True)
|
|
||||||
if all_text:
|
|
||||||
conversations.append(
|
|
||||||
{
|
|
||||||
"title": "ChatGPT Conversation",
|
|
||||||
"messages": [{"role": "mixed", "content": all_text, "timestamp": None}],
|
|
||||||
"timestamp": None,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return conversations
|
|
||||||
|
|
||||||
def _extract_conversation_from_container(self, container) -> dict | None:
|
|
||||||
"""
|
|
||||||
Extract conversation data from a container element.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
container: BeautifulSoup element containing conversation
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary with conversation data or None
|
|
||||||
"""
|
|
||||||
messages = []
|
|
||||||
|
|
||||||
# Look for message elements with various possible structures
|
|
||||||
message_selectors = ['[class*="message"]', '[class*="chat"]', "[data-message]", "p", "div"]
|
|
||||||
|
|
||||||
for selector in message_selectors:
|
|
||||||
message_elements = container.select(selector)
|
|
||||||
if message_elements:
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
message_elements = []
|
|
||||||
|
|
||||||
# If no structured messages found, treat the entire container as one message
|
|
||||||
if not message_elements:
|
|
||||||
text_content = container.get_text(separator="\n", strip=True)
|
|
||||||
if text_content:
|
|
||||||
messages.append({"role": "mixed", "content": text_content, "timestamp": None})
|
|
||||||
else:
|
|
||||||
for element in message_elements:
|
|
||||||
message = self._extract_message_from_element(element)
|
|
||||||
if message:
|
|
||||||
messages.append(message)
|
|
||||||
|
|
||||||
if not messages:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Try to extract conversation title
|
|
||||||
title_element = container.find(["h1", "h2", "h3", "title"])
|
|
||||||
title = title_element.get_text(strip=True) if title_element else "ChatGPT Conversation"
|
|
||||||
|
|
||||||
# Try to extract timestamp from various possible locations
|
|
||||||
timestamp = self._extract_timestamp_from_container(container)
|
|
||||||
|
|
||||||
return {"title": title, "messages": messages, "timestamp": timestamp}
|
|
||||||
|
|
||||||
def _extract_message_from_element(self, element) -> dict | None:
|
|
||||||
"""
|
|
||||||
Extract message data from an element.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
element: BeautifulSoup element containing message
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary with message data or None
|
|
||||||
"""
|
|
||||||
text_content = element.get_text(separator=" ", strip=True)
|
|
||||||
|
|
||||||
# Skip empty or very short messages
|
|
||||||
if not text_content or len(text_content.strip()) < 3:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Try to determine role (user/assistant) from class names or content
|
|
||||||
role = "mixed" # Default role
|
|
||||||
|
|
||||||
class_names = " ".join(element.get("class", [])).lower()
|
|
||||||
if "user" in class_names or "human" in class_names:
|
|
||||||
role = "user"
|
|
||||||
elif "assistant" in class_names or "ai" in class_names or "gpt" in class_names:
|
|
||||||
role = "assistant"
|
|
||||||
elif text_content.lower().startswith(("you:", "user:", "me:")):
|
|
||||||
role = "user"
|
|
||||||
text_content = re.sub(r"^(you|user|me):\s*", "", text_content, flags=re.IGNORECASE)
|
|
||||||
elif text_content.lower().startswith(("chatgpt:", "assistant:", "ai:")):
|
|
||||||
role = "assistant"
|
|
||||||
text_content = re.sub(
|
|
||||||
r"^(chatgpt|assistant|ai):\s*", "", text_content, flags=re.IGNORECASE
|
|
||||||
)
|
|
||||||
|
|
||||||
# Try to extract timestamp
|
|
||||||
timestamp = self._extract_timestamp_from_element(element)
|
|
||||||
|
|
||||||
return {"role": role, "content": text_content, "timestamp": timestamp}
|
|
||||||
|
|
||||||
def _extract_timestamp_from_element(self, element) -> str | None:
|
|
||||||
"""Extract timestamp from element."""
|
|
||||||
# Look for timestamp in various attributes and child elements
|
|
||||||
timestamp_attrs = ["data-timestamp", "timestamp", "datetime"]
|
|
||||||
for attr in timestamp_attrs:
|
|
||||||
if element.get(attr):
|
|
||||||
return element.get(attr)
|
|
||||||
|
|
||||||
# Look for time elements
|
|
||||||
time_element = element.find("time")
|
|
||||||
if time_element:
|
|
||||||
return time_element.get("datetime") or time_element.get_text(strip=True)
|
|
||||||
|
|
||||||
# Look for date-like text patterns
|
|
||||||
text = element.get_text()
|
|
||||||
date_patterns = [r"\d{4}-\d{2}-\d{2}", r"\d{1,2}/\d{1,2}/\d{4}", r"\w+ \d{1,2}, \d{4}"]
|
|
||||||
|
|
||||||
for pattern in date_patterns:
|
|
||||||
match = re.search(pattern, text)
|
|
||||||
if match:
|
|
||||||
return match.group()
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _extract_timestamp_from_container(self, container) -> str | None:
|
|
||||||
"""Extract timestamp from conversation container."""
|
|
||||||
return self._extract_timestamp_from_element(container)
|
|
||||||
|
|
||||||
def _create_concatenated_content(self, conversation: dict) -> str:
|
|
||||||
"""
|
|
||||||
Create concatenated content from conversation messages.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
conversation: Dictionary containing conversation data
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Formatted concatenated content
|
|
||||||
"""
|
|
||||||
title = conversation.get("title", "ChatGPT Conversation")
|
|
||||||
messages = conversation.get("messages", [])
|
|
||||||
timestamp = conversation.get("timestamp", "Unknown")
|
|
||||||
|
|
||||||
# Build message content
|
|
||||||
message_parts = []
|
|
||||||
for message in messages:
|
|
||||||
role = message.get("role", "mixed")
|
|
||||||
content = message.get("content", "")
|
|
||||||
msg_timestamp = message.get("timestamp", "")
|
|
||||||
|
|
||||||
if role == "user":
|
|
||||||
prefix = "[You]"
|
|
||||||
elif role == "assistant":
|
|
||||||
prefix = "[ChatGPT]"
|
|
||||||
else:
|
|
||||||
prefix = "[Message]"
|
|
||||||
|
|
||||||
# Add timestamp if available
|
|
||||||
if msg_timestamp:
|
|
||||||
prefix += f" ({msg_timestamp})"
|
|
||||||
|
|
||||||
message_parts.append(f"{prefix}: {content}")
|
|
||||||
|
|
||||||
concatenated_text = "\n\n".join(message_parts)
|
|
||||||
|
|
||||||
# Create final document content
|
|
||||||
doc_content = f"""Conversation: {title}
|
|
||||||
Date: {timestamp}
|
|
||||||
Messages ({len(messages)} messages):
|
|
||||||
|
|
||||||
{concatenated_text}
|
|
||||||
"""
|
|
||||||
return doc_content
|
|
||||||
|
|
||||||
def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
|
|
||||||
"""
|
|
||||||
Load ChatGPT export data.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_dir: Directory containing ChatGPT export files or path to specific file
|
|
||||||
**load_kwargs:
|
|
||||||
max_count (int): Maximum number of conversations to process
|
|
||||||
chatgpt_export_path (str): Specific path to ChatGPT export file/directory
|
|
||||||
include_metadata (bool): Whether to include metadata in documents
|
|
||||||
"""
|
|
||||||
docs: list[Document] = []
|
|
||||||
max_count = load_kwargs.get("max_count", -1)
|
|
||||||
chatgpt_export_path = load_kwargs.get("chatgpt_export_path", input_dir)
|
|
||||||
include_metadata = load_kwargs.get("include_metadata", True)
|
|
||||||
|
|
||||||
if not chatgpt_export_path:
|
|
||||||
print("No ChatGPT export path provided")
|
|
||||||
return docs
|
|
||||||
|
|
||||||
export_path = Path(chatgpt_export_path)
|
|
||||||
|
|
||||||
if not export_path.exists():
|
|
||||||
print(f"ChatGPT export path not found: {export_path}")
|
|
||||||
return docs
|
|
||||||
|
|
||||||
html_content = None
|
|
||||||
|
|
||||||
# Handle different input types
|
|
||||||
if export_path.is_file():
|
|
||||||
if export_path.suffix.lower() == ".zip":
|
|
||||||
# Extract HTML from zip file
|
|
||||||
html_content = self._extract_html_from_zip(export_path)
|
|
||||||
elif export_path.suffix.lower() == ".html":
|
|
||||||
# Read HTML file directly
|
|
||||||
try:
|
|
||||||
with open(export_path, encoding="utf-8", errors="ignore") as f:
|
|
||||||
html_content = f.read()
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error reading HTML file {export_path}: {e}")
|
|
||||||
return docs
|
|
||||||
else:
|
|
||||||
print(f"Unsupported file type: {export_path.suffix}")
|
|
||||||
return docs
|
|
||||||
|
|
||||||
elif export_path.is_dir():
|
|
||||||
# Look for HTML files in directory
|
|
||||||
html_files = list(export_path.glob("*.html"))
|
|
||||||
zip_files = list(export_path.glob("*.zip"))
|
|
||||||
|
|
||||||
if html_files:
|
|
||||||
# Use first HTML file found
|
|
||||||
html_file = html_files[0]
|
|
||||||
print(f"Found HTML file: {html_file}")
|
|
||||||
try:
|
|
||||||
with open(html_file, encoding="utf-8", errors="ignore") as f:
|
|
||||||
html_content = f.read()
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error reading HTML file {html_file}: {e}")
|
|
||||||
return docs
|
|
||||||
|
|
||||||
elif zip_files:
|
|
||||||
# Use first zip file found
|
|
||||||
zip_file = zip_files[0]
|
|
||||||
print(f"Found zip file: {zip_file}")
|
|
||||||
html_content = self._extract_html_from_zip(zip_file)
|
|
||||||
|
|
||||||
else:
|
|
||||||
print(f"No HTML or zip files found in {export_path}")
|
|
||||||
return docs
|
|
||||||
|
|
||||||
if not html_content:
|
|
||||||
print("No HTML content found to process")
|
|
||||||
return docs
|
|
||||||
|
|
||||||
# Parse conversations from HTML
|
|
||||||
print("Parsing ChatGPT conversations from HTML...")
|
|
||||||
conversations = self._parse_chatgpt_html(html_content)
|
|
||||||
|
|
||||||
if not conversations:
|
|
||||||
print("No conversations found in HTML content")
|
|
||||||
return docs
|
|
||||||
|
|
||||||
print(f"Found {len(conversations)} conversations")
|
|
||||||
|
|
||||||
# Process conversations into documents
|
|
||||||
count = 0
|
|
||||||
for conversation in conversations:
|
|
||||||
if max_count > 0 and count >= max_count:
|
|
||||||
break
|
|
||||||
|
|
||||||
if self.concatenate_conversations:
|
|
||||||
# Create one document per conversation with concatenated messages
|
|
||||||
doc_content = self._create_concatenated_content(conversation)
|
|
||||||
|
|
||||||
metadata = {}
|
|
||||||
if include_metadata:
|
|
||||||
metadata = {
|
|
||||||
"title": conversation.get("title", "ChatGPT Conversation"),
|
|
||||||
"timestamp": conversation.get("timestamp", "Unknown"),
|
|
||||||
"message_count": len(conversation.get("messages", [])),
|
|
||||||
"source": "ChatGPT Export",
|
|
||||||
}
|
|
||||||
|
|
||||||
doc = Document(text=doc_content, metadata=metadata)
|
|
||||||
docs.append(doc)
|
|
||||||
count += 1
|
|
||||||
|
|
||||||
else:
|
|
||||||
# Create separate documents for each message
|
|
||||||
for message in conversation.get("messages", []):
|
|
||||||
if max_count > 0 and count >= max_count:
|
|
||||||
break
|
|
||||||
|
|
||||||
role = message.get("role", "mixed")
|
|
||||||
content = message.get("content", "")
|
|
||||||
msg_timestamp = message.get("timestamp", "")
|
|
||||||
|
|
||||||
if not content.strip():
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Create document content with context
|
|
||||||
doc_content = f"""Conversation: {conversation.get("title", "ChatGPT Conversation")}
|
|
||||||
Role: {role}
|
|
||||||
Timestamp: {msg_timestamp or conversation.get("timestamp", "Unknown")}
|
|
||||||
Message: {content}
|
|
||||||
"""
|
|
||||||
|
|
||||||
metadata = {}
|
|
||||||
if include_metadata:
|
|
||||||
metadata = {
|
|
||||||
"conversation_title": conversation.get("title", "ChatGPT Conversation"),
|
|
||||||
"role": role,
|
|
||||||
"timestamp": msg_timestamp or conversation.get("timestamp", "Unknown"),
|
|
||||||
"source": "ChatGPT Export",
|
|
||||||
}
|
|
||||||
|
|
||||||
doc = Document(text=doc_content, metadata=metadata)
|
|
||||||
docs.append(doc)
|
|
||||||
count += 1
|
|
||||||
|
|
||||||
print(f"Created {len(docs)} documents from ChatGPT export")
|
|
||||||
return docs
|
|
||||||
@@ -1,186 +0,0 @@
|
|||||||
"""
|
|
||||||
ChatGPT RAG example using the unified interface.
|
|
||||||
Supports ChatGPT export data from chat.html files.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
# Add parent directory to path for imports
|
|
||||||
sys.path.insert(0, str(Path(__file__).parent))
|
|
||||||
|
|
||||||
from base_rag_example import BaseRAGExample
|
|
||||||
from chunking import create_text_chunks
|
|
||||||
|
|
||||||
from .chatgpt_data.chatgpt_reader import ChatGPTReader
|
|
||||||
|
|
||||||
|
|
||||||
class ChatGPTRAG(BaseRAGExample):
|
|
||||||
"""RAG example for ChatGPT conversation data."""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
# Set default values BEFORE calling super().__init__
|
|
||||||
self.max_items_default = -1 # Process all conversations by default
|
|
||||||
self.embedding_model_default = (
|
|
||||||
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
|
|
||||||
)
|
|
||||||
|
|
||||||
super().__init__(
|
|
||||||
name="ChatGPT",
|
|
||||||
description="Process and query ChatGPT conversation exports with LEANN",
|
|
||||||
default_index_name="chatgpt_conversations_index",
|
|
||||||
)
|
|
||||||
|
|
||||||
def _add_specific_arguments(self, parser):
|
|
||||||
"""Add ChatGPT-specific arguments."""
|
|
||||||
chatgpt_group = parser.add_argument_group("ChatGPT Parameters")
|
|
||||||
chatgpt_group.add_argument(
|
|
||||||
"--export-path",
|
|
||||||
type=str,
|
|
||||||
default="./chatgpt_export",
|
|
||||||
help="Path to ChatGPT export file (.zip or .html) or directory containing exports (default: ./chatgpt_export)",
|
|
||||||
)
|
|
||||||
chatgpt_group.add_argument(
|
|
||||||
"--concatenate-conversations",
|
|
||||||
action="store_true",
|
|
||||||
default=True,
|
|
||||||
help="Concatenate messages within conversations for better context (default: True)",
|
|
||||||
)
|
|
||||||
chatgpt_group.add_argument(
|
|
||||||
"--separate-messages",
|
|
||||||
action="store_true",
|
|
||||||
help="Process each message as a separate document (overrides --concatenate-conversations)",
|
|
||||||
)
|
|
||||||
chatgpt_group.add_argument(
|
|
||||||
"--chunk-size", type=int, default=512, help="Text chunk size (default: 512)"
|
|
||||||
)
|
|
||||||
chatgpt_group.add_argument(
|
|
||||||
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _find_chatgpt_exports(self, export_path: Path) -> list[Path]:
|
|
||||||
"""
|
|
||||||
Find ChatGPT export files in the given path.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
export_path: Path to search for exports
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of paths to ChatGPT export files
|
|
||||||
"""
|
|
||||||
export_files = []
|
|
||||||
|
|
||||||
if export_path.is_file():
|
|
||||||
if export_path.suffix.lower() in [".zip", ".html"]:
|
|
||||||
export_files.append(export_path)
|
|
||||||
elif export_path.is_dir():
|
|
||||||
# Look for zip and html files
|
|
||||||
export_files.extend(export_path.glob("*.zip"))
|
|
||||||
export_files.extend(export_path.glob("*.html"))
|
|
||||||
|
|
||||||
return export_files
|
|
||||||
|
|
||||||
async def load_data(self, args) -> list[str]:
|
|
||||||
"""Load ChatGPT export data and convert to text chunks."""
|
|
||||||
export_path = Path(args.export_path)
|
|
||||||
|
|
||||||
if not export_path.exists():
|
|
||||||
print(f"ChatGPT export path not found: {export_path}")
|
|
||||||
print(
|
|
||||||
"Please ensure you have exported your ChatGPT data and placed it in the correct location."
|
|
||||||
)
|
|
||||||
print("\nTo export your ChatGPT data:")
|
|
||||||
print("1. Sign in to ChatGPT")
|
|
||||||
print("2. Click on your profile icon → Settings → Data Controls")
|
|
||||||
print("3. Click 'Export' under Export Data")
|
|
||||||
print("4. Download the zip file from the email link")
|
|
||||||
print("5. Extract or place the file/directory at the specified path")
|
|
||||||
return []
|
|
||||||
|
|
||||||
# Find export files
|
|
||||||
export_files = self._find_chatgpt_exports(export_path)
|
|
||||||
|
|
||||||
if not export_files:
|
|
||||||
print(f"No ChatGPT export files (.zip or .html) found in: {export_path}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
print(f"Found {len(export_files)} ChatGPT export files")
|
|
||||||
|
|
||||||
# Create reader with appropriate settings
|
|
||||||
concatenate = args.concatenate_conversations and not args.separate_messages
|
|
||||||
reader = ChatGPTReader(concatenate_conversations=concatenate)
|
|
||||||
|
|
||||||
# Process each export file
|
|
||||||
all_documents = []
|
|
||||||
total_processed = 0
|
|
||||||
|
|
||||||
for i, export_file in enumerate(export_files):
|
|
||||||
print(f"\nProcessing export file {i + 1}/{len(export_files)}: {export_file.name}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Apply max_items limit per file
|
|
||||||
max_per_file = -1
|
|
||||||
if args.max_items > 0:
|
|
||||||
remaining = args.max_items - total_processed
|
|
||||||
if remaining <= 0:
|
|
||||||
break
|
|
||||||
max_per_file = remaining
|
|
||||||
|
|
||||||
# Load conversations
|
|
||||||
documents = reader.load_data(
|
|
||||||
chatgpt_export_path=str(export_file),
|
|
||||||
max_count=max_per_file,
|
|
||||||
include_metadata=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
if documents:
|
|
||||||
all_documents.extend(documents)
|
|
||||||
total_processed += len(documents)
|
|
||||||
print(f"Processed {len(documents)} conversations from this file")
|
|
||||||
else:
|
|
||||||
print(f"No conversations loaded from {export_file}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error processing {export_file}: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not all_documents:
|
|
||||||
print("No conversations found to process!")
|
|
||||||
print("\nTroubleshooting:")
|
|
||||||
print("- Ensure the export file is a valid ChatGPT export")
|
|
||||||
print("- Check that the HTML file contains conversation data")
|
|
||||||
print("- Try extracting the zip file and pointing to the HTML file directly")
|
|
||||||
return []
|
|
||||||
|
|
||||||
print(f"\nTotal conversations processed: {len(all_documents)}")
|
|
||||||
print("Now starting to split into text chunks... this may take some time")
|
|
||||||
|
|
||||||
# Convert to text chunks
|
|
||||||
all_texts = create_text_chunks(
|
|
||||||
all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} conversations")
|
|
||||||
return all_texts
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
# Example queries for ChatGPT RAG
|
|
||||||
print("\n🤖 ChatGPT RAG Example")
|
|
||||||
print("=" * 50)
|
|
||||||
print("\nExample queries you can try:")
|
|
||||||
print("- 'What did I ask about Python programming?'")
|
|
||||||
print("- 'Show me conversations about machine learning'")
|
|
||||||
print("- 'Find discussions about travel planning'")
|
|
||||||
print("- 'What advice did ChatGPT give me about career development?'")
|
|
||||||
print("- 'Search for conversations about cooking recipes'")
|
|
||||||
print("\nTo get started:")
|
|
||||||
print("1. Export your ChatGPT data from Settings → Data Controls → Export")
|
|
||||||
print("2. Place the downloaded zip file or extracted HTML in ./chatgpt_export/")
|
|
||||||
print("3. Run this script to build your personal ChatGPT knowledge base!")
|
|
||||||
print("\nOr run without --query for interactive mode\n")
|
|
||||||
|
|
||||||
rag = ChatGPTRAG()
|
|
||||||
asyncio.run(rag.run())
|
|
||||||
@@ -12,7 +12,6 @@ from pathlib import Path
|
|||||||
try:
|
try:
|
||||||
from leann.chunking_utils import (
|
from leann.chunking_utils import (
|
||||||
CODE_EXTENSIONS,
|
CODE_EXTENSIONS,
|
||||||
_traditional_chunks_as_dicts,
|
|
||||||
create_ast_chunks,
|
create_ast_chunks,
|
||||||
create_text_chunks,
|
create_text_chunks,
|
||||||
create_traditional_chunks,
|
create_traditional_chunks,
|
||||||
@@ -26,7 +25,6 @@ except Exception: # pragma: no cover - best-effort fallback for dev environment
|
|||||||
sys.path.insert(0, str(leann_src))
|
sys.path.insert(0, str(leann_src))
|
||||||
from leann.chunking_utils import (
|
from leann.chunking_utils import (
|
||||||
CODE_EXTENSIONS,
|
CODE_EXTENSIONS,
|
||||||
_traditional_chunks_as_dicts,
|
|
||||||
create_ast_chunks,
|
create_ast_chunks,
|
||||||
create_text_chunks,
|
create_text_chunks,
|
||||||
create_traditional_chunks,
|
create_traditional_chunks,
|
||||||
@@ -38,7 +36,6 @@ except Exception: # pragma: no cover - best-effort fallback for dev environment
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"CODE_EXTENSIONS",
|
"CODE_EXTENSIONS",
|
||||||
"_traditional_chunks_as_dicts",
|
|
||||||
"create_ast_chunks",
|
"create_ast_chunks",
|
||||||
"create_text_chunks",
|
"create_text_chunks",
|
||||||
"create_traditional_chunks",
|
"create_traditional_chunks",
|
||||||
|
|||||||
@@ -1,420 +0,0 @@
|
|||||||
"""
|
|
||||||
Claude export data reader.
|
|
||||||
|
|
||||||
Reads and processes Claude conversation data from exported JSON files.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
|
||||||
from zipfile import ZipFile
|
|
||||||
|
|
||||||
from llama_index.core import Document
|
|
||||||
from llama_index.core.readers.base import BaseReader
|
|
||||||
|
|
||||||
|
|
||||||
class ClaudeReader(BaseReader):
|
|
||||||
"""
|
|
||||||
Claude export data reader.
|
|
||||||
|
|
||||||
Reads Claude conversation data from exported JSON files or zip archives.
|
|
||||||
Processes conversations into structured documents with metadata.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, concatenate_conversations: bool = True) -> None:
|
|
||||||
"""
|
|
||||||
Initialize.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
concatenate_conversations: Whether to concatenate messages within conversations for better context
|
|
||||||
"""
|
|
||||||
self.concatenate_conversations = concatenate_conversations
|
|
||||||
|
|
||||||
def _extract_json_from_zip(self, zip_path: Path) -> list[str]:
|
|
||||||
"""
|
|
||||||
Extract JSON files from Claude export zip file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
zip_path: Path to the Claude export zip file
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of JSON content strings, or empty list if not found
|
|
||||||
"""
|
|
||||||
json_contents = []
|
|
||||||
try:
|
|
||||||
with ZipFile(zip_path, "r") as zip_file:
|
|
||||||
# Look for JSON files
|
|
||||||
json_files = [f for f in zip_file.namelist() if f.endswith(".json")]
|
|
||||||
|
|
||||||
if not json_files:
|
|
||||||
print(f"No JSON files found in {zip_path}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
print(f"Found {len(json_files)} JSON files in archive")
|
|
||||||
|
|
||||||
for json_file in json_files:
|
|
||||||
with zip_file.open(json_file) as f:
|
|
||||||
content = f.read().decode("utf-8", errors="ignore")
|
|
||||||
json_contents.append(content)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error extracting JSON from zip {zip_path}: {e}")
|
|
||||||
|
|
||||||
return json_contents
|
|
||||||
|
|
||||||
def _parse_claude_json(self, json_content: str) -> list[dict]:
|
|
||||||
"""
|
|
||||||
Parse Claude JSON export to extract conversations.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
json_content: JSON content from Claude export
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of conversation dictionaries
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
data = json.loads(json_content)
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
print(f"Error parsing JSON: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
conversations = []
|
|
||||||
|
|
||||||
# Handle different possible JSON structures
|
|
||||||
if isinstance(data, list):
|
|
||||||
# If data is a list of conversations
|
|
||||||
for item in data:
|
|
||||||
conversation = self._extract_conversation_from_json(item)
|
|
||||||
if conversation:
|
|
||||||
conversations.append(conversation)
|
|
||||||
elif isinstance(data, dict):
|
|
||||||
# Check for common structures
|
|
||||||
if "conversations" in data:
|
|
||||||
# Structure: {"conversations": [...]}
|
|
||||||
for item in data["conversations"]:
|
|
||||||
conversation = self._extract_conversation_from_json(item)
|
|
||||||
if conversation:
|
|
||||||
conversations.append(conversation)
|
|
||||||
elif "messages" in data:
|
|
||||||
# Single conversation with messages
|
|
||||||
conversation = self._extract_conversation_from_json(data)
|
|
||||||
if conversation:
|
|
||||||
conversations.append(conversation)
|
|
||||||
else:
|
|
||||||
# Try to treat the whole object as a conversation
|
|
||||||
conversation = self._extract_conversation_from_json(data)
|
|
||||||
if conversation:
|
|
||||||
conversations.append(conversation)
|
|
||||||
|
|
||||||
return conversations
|
|
||||||
|
|
||||||
def _extract_conversation_from_json(self, conv_data: dict) -> dict | None:
|
|
||||||
"""
|
|
||||||
Extract conversation data from a JSON object.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
conv_data: Dictionary containing conversation data
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary with conversation data or None
|
|
||||||
"""
|
|
||||||
if not isinstance(conv_data, dict):
|
|
||||||
return None
|
|
||||||
|
|
||||||
messages = []
|
|
||||||
|
|
||||||
# Look for messages in various possible structures
|
|
||||||
message_sources = []
|
|
||||||
if "messages" in conv_data:
|
|
||||||
message_sources = conv_data["messages"]
|
|
||||||
elif "chat" in conv_data:
|
|
||||||
message_sources = conv_data["chat"]
|
|
||||||
elif "conversation" in conv_data:
|
|
||||||
message_sources = conv_data["conversation"]
|
|
||||||
else:
|
|
||||||
# If no clear message structure, try to extract from the object itself
|
|
||||||
if "content" in conv_data and "role" in conv_data:
|
|
||||||
message_sources = [conv_data]
|
|
||||||
|
|
||||||
for msg_data in message_sources:
|
|
||||||
message = self._extract_message_from_json(msg_data)
|
|
||||||
if message:
|
|
||||||
messages.append(message)
|
|
||||||
|
|
||||||
if not messages:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Extract conversation metadata
|
|
||||||
title = self._extract_title_from_conversation(conv_data, messages)
|
|
||||||
timestamp = self._extract_timestamp_from_conversation(conv_data)
|
|
||||||
|
|
||||||
return {"title": title, "messages": messages, "timestamp": timestamp}
|
|
||||||
|
|
||||||
def _extract_message_from_json(self, msg_data: dict) -> dict | None:
|
|
||||||
"""
|
|
||||||
Extract message data from a JSON message object.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
msg_data: Dictionary containing message data
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary with message data or None
|
|
||||||
"""
|
|
||||||
if not isinstance(msg_data, dict):
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Extract content from various possible fields
|
|
||||||
content = ""
|
|
||||||
content_fields = ["content", "text", "message", "body"]
|
|
||||||
for field in content_fields:
|
|
||||||
if msg_data.get(field):
|
|
||||||
content = str(msg_data[field])
|
|
||||||
break
|
|
||||||
|
|
||||||
if not content or len(content.strip()) < 3:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Extract role (user/assistant/human/ai/claude)
|
|
||||||
role = "mixed" # Default role
|
|
||||||
role_fields = ["role", "sender", "from", "author", "type"]
|
|
||||||
for field in role_fields:
|
|
||||||
if msg_data.get(field):
|
|
||||||
role_value = str(msg_data[field]).lower()
|
|
||||||
if role_value in ["user", "human", "person"]:
|
|
||||||
role = "user"
|
|
||||||
elif role_value in ["assistant", "ai", "claude", "bot"]:
|
|
||||||
role = "assistant"
|
|
||||||
break
|
|
||||||
|
|
||||||
# Extract timestamp
|
|
||||||
timestamp = self._extract_timestamp_from_message(msg_data)
|
|
||||||
|
|
||||||
return {"role": role, "content": content, "timestamp": timestamp}
|
|
||||||
|
|
||||||
def _extract_timestamp_from_message(self, msg_data: dict) -> str | None:
|
|
||||||
"""Extract timestamp from message data."""
|
|
||||||
timestamp_fields = ["timestamp", "created_at", "date", "time"]
|
|
||||||
for field in timestamp_fields:
|
|
||||||
if msg_data.get(field):
|
|
||||||
return str(msg_data[field])
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _extract_timestamp_from_conversation(self, conv_data: dict) -> str | None:
|
|
||||||
"""Extract timestamp from conversation data."""
|
|
||||||
timestamp_fields = ["timestamp", "created_at", "date", "updated_at", "last_updated"]
|
|
||||||
for field in timestamp_fields:
|
|
||||||
if conv_data.get(field):
|
|
||||||
return str(conv_data[field])
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _extract_title_from_conversation(self, conv_data: dict, messages: list) -> str:
|
|
||||||
"""Extract or generate title for conversation."""
|
|
||||||
# Try to find explicit title
|
|
||||||
title_fields = ["title", "name", "subject", "topic"]
|
|
||||||
for field in title_fields:
|
|
||||||
if conv_data.get(field):
|
|
||||||
return str(conv_data[field])
|
|
||||||
|
|
||||||
# Generate title from first user message
|
|
||||||
for message in messages:
|
|
||||||
if message.get("role") == "user":
|
|
||||||
content = message.get("content", "")
|
|
||||||
if content:
|
|
||||||
# Use first 50 characters as title
|
|
||||||
title = content[:50].strip()
|
|
||||||
if len(content) > 50:
|
|
||||||
title += "..."
|
|
||||||
return title
|
|
||||||
|
|
||||||
return "Claude Conversation"
|
|
||||||
|
|
||||||
def _create_concatenated_content(self, conversation: dict) -> str:
|
|
||||||
"""
|
|
||||||
Create concatenated content from conversation messages.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
conversation: Dictionary containing conversation data
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Formatted concatenated content
|
|
||||||
"""
|
|
||||||
title = conversation.get("title", "Claude Conversation")
|
|
||||||
messages = conversation.get("messages", [])
|
|
||||||
timestamp = conversation.get("timestamp", "Unknown")
|
|
||||||
|
|
||||||
# Build message content
|
|
||||||
message_parts = []
|
|
||||||
for message in messages:
|
|
||||||
role = message.get("role", "mixed")
|
|
||||||
content = message.get("content", "")
|
|
||||||
msg_timestamp = message.get("timestamp", "")
|
|
||||||
|
|
||||||
if role == "user":
|
|
||||||
prefix = "[You]"
|
|
||||||
elif role == "assistant":
|
|
||||||
prefix = "[Claude]"
|
|
||||||
else:
|
|
||||||
prefix = "[Message]"
|
|
||||||
|
|
||||||
# Add timestamp if available
|
|
||||||
if msg_timestamp:
|
|
||||||
prefix += f" ({msg_timestamp})"
|
|
||||||
|
|
||||||
message_parts.append(f"{prefix}: {content}")
|
|
||||||
|
|
||||||
concatenated_text = "\n\n".join(message_parts)
|
|
||||||
|
|
||||||
# Create final document content
|
|
||||||
doc_content = f"""Conversation: {title}
|
|
||||||
Date: {timestamp}
|
|
||||||
Messages ({len(messages)} messages):
|
|
||||||
|
|
||||||
{concatenated_text}
|
|
||||||
"""
|
|
||||||
return doc_content
|
|
||||||
|
|
||||||
def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
|
|
||||||
"""
|
|
||||||
Load Claude export data.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_dir: Directory containing Claude export files or path to specific file
|
|
||||||
**load_kwargs:
|
|
||||||
max_count (int): Maximum number of conversations to process
|
|
||||||
claude_export_path (str): Specific path to Claude export file/directory
|
|
||||||
include_metadata (bool): Whether to include metadata in documents
|
|
||||||
"""
|
|
||||||
docs: list[Document] = []
|
|
||||||
max_count = load_kwargs.get("max_count", -1)
|
|
||||||
claude_export_path = load_kwargs.get("claude_export_path", input_dir)
|
|
||||||
include_metadata = load_kwargs.get("include_metadata", True)
|
|
||||||
|
|
||||||
if not claude_export_path:
|
|
||||||
print("No Claude export path provided")
|
|
||||||
return docs
|
|
||||||
|
|
||||||
export_path = Path(claude_export_path)
|
|
||||||
|
|
||||||
if not export_path.exists():
|
|
||||||
print(f"Claude export path not found: {export_path}")
|
|
||||||
return docs
|
|
||||||
|
|
||||||
json_contents = []
|
|
||||||
|
|
||||||
# Handle different input types
|
|
||||||
if export_path.is_file():
|
|
||||||
if export_path.suffix.lower() == ".zip":
|
|
||||||
# Extract JSON from zip file
|
|
||||||
json_contents = self._extract_json_from_zip(export_path)
|
|
||||||
elif export_path.suffix.lower() == ".json":
|
|
||||||
# Read JSON file directly
|
|
||||||
try:
|
|
||||||
with open(export_path, encoding="utf-8", errors="ignore") as f:
|
|
||||||
json_contents.append(f.read())
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error reading JSON file {export_path}: {e}")
|
|
||||||
return docs
|
|
||||||
else:
|
|
||||||
print(f"Unsupported file type: {export_path.suffix}")
|
|
||||||
return docs
|
|
||||||
|
|
||||||
elif export_path.is_dir():
|
|
||||||
# Look for JSON files in directory
|
|
||||||
json_files = list(export_path.glob("*.json"))
|
|
||||||
zip_files = list(export_path.glob("*.zip"))
|
|
||||||
|
|
||||||
if json_files:
|
|
||||||
print(f"Found {len(json_files)} JSON files in directory")
|
|
||||||
for json_file in json_files:
|
|
||||||
try:
|
|
||||||
with open(json_file, encoding="utf-8", errors="ignore") as f:
|
|
||||||
json_contents.append(f.read())
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error reading JSON file {json_file}: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if zip_files:
|
|
||||||
print(f"Found {len(zip_files)} ZIP files in directory")
|
|
||||||
for zip_file in zip_files:
|
|
||||||
zip_contents = self._extract_json_from_zip(zip_file)
|
|
||||||
json_contents.extend(zip_contents)
|
|
||||||
|
|
||||||
if not json_files and not zip_files:
|
|
||||||
print(f"No JSON or ZIP files found in {export_path}")
|
|
||||||
return docs
|
|
||||||
|
|
||||||
if not json_contents:
|
|
||||||
print("No JSON content found to process")
|
|
||||||
return docs
|
|
||||||
|
|
||||||
# Parse conversations from JSON content
|
|
||||||
print("Parsing Claude conversations from JSON...")
|
|
||||||
all_conversations = []
|
|
||||||
for json_content in json_contents:
|
|
||||||
conversations = self._parse_claude_json(json_content)
|
|
||||||
all_conversations.extend(conversations)
|
|
||||||
|
|
||||||
if not all_conversations:
|
|
||||||
print("No conversations found in JSON content")
|
|
||||||
return docs
|
|
||||||
|
|
||||||
print(f"Found {len(all_conversations)} conversations")
|
|
||||||
|
|
||||||
# Process conversations into documents
|
|
||||||
count = 0
|
|
||||||
for conversation in all_conversations:
|
|
||||||
if max_count > 0 and count >= max_count:
|
|
||||||
break
|
|
||||||
|
|
||||||
if self.concatenate_conversations:
|
|
||||||
# Create one document per conversation with concatenated messages
|
|
||||||
doc_content = self._create_concatenated_content(conversation)
|
|
||||||
|
|
||||||
metadata = {}
|
|
||||||
if include_metadata:
|
|
||||||
metadata = {
|
|
||||||
"title": conversation.get("title", "Claude Conversation"),
|
|
||||||
"timestamp": conversation.get("timestamp", "Unknown"),
|
|
||||||
"message_count": len(conversation.get("messages", [])),
|
|
||||||
"source": "Claude Export",
|
|
||||||
}
|
|
||||||
|
|
||||||
doc = Document(text=doc_content, metadata=metadata)
|
|
||||||
docs.append(doc)
|
|
||||||
count += 1
|
|
||||||
|
|
||||||
else:
|
|
||||||
# Create separate documents for each message
|
|
||||||
for message in conversation.get("messages", []):
|
|
||||||
if max_count > 0 and count >= max_count:
|
|
||||||
break
|
|
||||||
|
|
||||||
role = message.get("role", "mixed")
|
|
||||||
content = message.get("content", "")
|
|
||||||
msg_timestamp = message.get("timestamp", "")
|
|
||||||
|
|
||||||
if not content.strip():
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Create document content with context
|
|
||||||
doc_content = f"""Conversation: {conversation.get("title", "Claude Conversation")}
|
|
||||||
Role: {role}
|
|
||||||
Timestamp: {msg_timestamp or conversation.get("timestamp", "Unknown")}
|
|
||||||
Message: {content}
|
|
||||||
"""
|
|
||||||
|
|
||||||
metadata = {}
|
|
||||||
if include_metadata:
|
|
||||||
metadata = {
|
|
||||||
"conversation_title": conversation.get("title", "Claude Conversation"),
|
|
||||||
"role": role,
|
|
||||||
"timestamp": msg_timestamp or conversation.get("timestamp", "Unknown"),
|
|
||||||
"source": "Claude Export",
|
|
||||||
}
|
|
||||||
|
|
||||||
doc = Document(text=doc_content, metadata=metadata)
|
|
||||||
docs.append(doc)
|
|
||||||
count += 1
|
|
||||||
|
|
||||||
print(f"Created {len(docs)} documents from Claude export")
|
|
||||||
return docs
|
|
||||||
@@ -1,189 +0,0 @@
|
|||||||
"""
|
|
||||||
Claude RAG example using the unified interface.
|
|
||||||
Supports Claude export data from JSON files.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
# Add parent directory to path for imports
|
|
||||||
sys.path.insert(0, str(Path(__file__).parent))
|
|
||||||
|
|
||||||
from base_rag_example import BaseRAGExample
|
|
||||||
from chunking import create_text_chunks
|
|
||||||
|
|
||||||
from .claude_data.claude_reader import ClaudeReader
|
|
||||||
|
|
||||||
|
|
||||||
class ClaudeRAG(BaseRAGExample):
|
|
||||||
"""RAG example for Claude conversation data."""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
# Set default values BEFORE calling super().__init__
|
|
||||||
self.max_items_default = -1 # Process all conversations by default
|
|
||||||
self.embedding_model_default = (
|
|
||||||
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
|
|
||||||
)
|
|
||||||
|
|
||||||
super().__init__(
|
|
||||||
name="Claude",
|
|
||||||
description="Process and query Claude conversation exports with LEANN",
|
|
||||||
default_index_name="claude_conversations_index",
|
|
||||||
)
|
|
||||||
|
|
||||||
def _add_specific_arguments(self, parser):
|
|
||||||
"""Add Claude-specific arguments."""
|
|
||||||
claude_group = parser.add_argument_group("Claude Parameters")
|
|
||||||
claude_group.add_argument(
|
|
||||||
"--export-path",
|
|
||||||
type=str,
|
|
||||||
default="./claude_export",
|
|
||||||
help="Path to Claude export file (.json or .zip) or directory containing exports (default: ./claude_export)",
|
|
||||||
)
|
|
||||||
claude_group.add_argument(
|
|
||||||
"--concatenate-conversations",
|
|
||||||
action="store_true",
|
|
||||||
default=True,
|
|
||||||
help="Concatenate messages within conversations for better context (default: True)",
|
|
||||||
)
|
|
||||||
claude_group.add_argument(
|
|
||||||
"--separate-messages",
|
|
||||||
action="store_true",
|
|
||||||
help="Process each message as a separate document (overrides --concatenate-conversations)",
|
|
||||||
)
|
|
||||||
claude_group.add_argument(
|
|
||||||
"--chunk-size", type=int, default=512, help="Text chunk size (default: 512)"
|
|
||||||
)
|
|
||||||
claude_group.add_argument(
|
|
||||||
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _find_claude_exports(self, export_path: Path) -> list[Path]:
|
|
||||||
"""
|
|
||||||
Find Claude export files in the given path.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
export_path: Path to search for exports
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of paths to Claude export files
|
|
||||||
"""
|
|
||||||
export_files = []
|
|
||||||
|
|
||||||
if export_path.is_file():
|
|
||||||
if export_path.suffix.lower() in [".zip", ".json"]:
|
|
||||||
export_files.append(export_path)
|
|
||||||
elif export_path.is_dir():
|
|
||||||
# Look for zip and json files
|
|
||||||
export_files.extend(export_path.glob("*.zip"))
|
|
||||||
export_files.extend(export_path.glob("*.json"))
|
|
||||||
|
|
||||||
return export_files
|
|
||||||
|
|
||||||
async def load_data(self, args) -> list[str]:
|
|
||||||
"""Load Claude export data and convert to text chunks."""
|
|
||||||
export_path = Path(args.export_path)
|
|
||||||
|
|
||||||
if not export_path.exists():
|
|
||||||
print(f"Claude export path not found: {export_path}")
|
|
||||||
print(
|
|
||||||
"Please ensure you have exported your Claude data and placed it in the correct location."
|
|
||||||
)
|
|
||||||
print("\nTo export your Claude data:")
|
|
||||||
print("1. Open Claude in your browser")
|
|
||||||
print("2. Look for export/download options in settings or conversation menu")
|
|
||||||
print("3. Download the conversation data (usually in JSON format)")
|
|
||||||
print("4. Place the file/directory at the specified path")
|
|
||||||
print(
|
|
||||||
"\nNote: Claude export methods may vary. Check Claude's help documentation for current instructions."
|
|
||||||
)
|
|
||||||
return []
|
|
||||||
|
|
||||||
# Find export files
|
|
||||||
export_files = self._find_claude_exports(export_path)
|
|
||||||
|
|
||||||
if not export_files:
|
|
||||||
print(f"No Claude export files (.json or .zip) found in: {export_path}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
print(f"Found {len(export_files)} Claude export files")
|
|
||||||
|
|
||||||
# Create reader with appropriate settings
|
|
||||||
concatenate = args.concatenate_conversations and not args.separate_messages
|
|
||||||
reader = ClaudeReader(concatenate_conversations=concatenate)
|
|
||||||
|
|
||||||
# Process each export file
|
|
||||||
all_documents = []
|
|
||||||
total_processed = 0
|
|
||||||
|
|
||||||
for i, export_file in enumerate(export_files):
|
|
||||||
print(f"\nProcessing export file {i + 1}/{len(export_files)}: {export_file.name}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Apply max_items limit per file
|
|
||||||
max_per_file = -1
|
|
||||||
if args.max_items > 0:
|
|
||||||
remaining = args.max_items - total_processed
|
|
||||||
if remaining <= 0:
|
|
||||||
break
|
|
||||||
max_per_file = remaining
|
|
||||||
|
|
||||||
# Load conversations
|
|
||||||
documents = reader.load_data(
|
|
||||||
claude_export_path=str(export_file),
|
|
||||||
max_count=max_per_file,
|
|
||||||
include_metadata=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
if documents:
|
|
||||||
all_documents.extend(documents)
|
|
||||||
total_processed += len(documents)
|
|
||||||
print(f"Processed {len(documents)} conversations from this file")
|
|
||||||
else:
|
|
||||||
print(f"No conversations loaded from {export_file}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error processing {export_file}: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not all_documents:
|
|
||||||
print("No conversations found to process!")
|
|
||||||
print("\nTroubleshooting:")
|
|
||||||
print("- Ensure the export file is a valid Claude export")
|
|
||||||
print("- Check that the JSON file contains conversation data")
|
|
||||||
print("- Try using a different export format or method")
|
|
||||||
print("- Check Claude's documentation for current export procedures")
|
|
||||||
return []
|
|
||||||
|
|
||||||
print(f"\nTotal conversations processed: {len(all_documents)}")
|
|
||||||
print("Now starting to split into text chunks... this may take some time")
|
|
||||||
|
|
||||||
# Convert to text chunks
|
|
||||||
all_texts = create_text_chunks(
|
|
||||||
all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} conversations")
|
|
||||||
return all_texts
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
# Example queries for Claude RAG
|
|
||||||
print("\n🤖 Claude RAG Example")
|
|
||||||
print("=" * 50)
|
|
||||||
print("\nExample queries you can try:")
|
|
||||||
print("- 'What did I ask Claude about Python programming?'")
|
|
||||||
print("- 'Show me conversations about machine learning'")
|
|
||||||
print("- 'Find discussions about code optimization'")
|
|
||||||
print("- 'What advice did Claude give me about software design?'")
|
|
||||||
print("- 'Search for conversations about debugging techniques'")
|
|
||||||
print("\nTo get started:")
|
|
||||||
print("1. Export your Claude conversation data")
|
|
||||||
print("2. Place the JSON/ZIP file in ./claude_export/")
|
|
||||||
print("3. Run this script to build your personal Claude knowledge base!")
|
|
||||||
print("\nOr run without --query for interactive mode\n")
|
|
||||||
|
|
||||||
rag = ClaudeRAG()
|
|
||||||
asyncio.run(rag.run())
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
"""iMessage data processing module."""
|
|
||||||
@@ -1,342 +0,0 @@
|
|||||||
"""
|
|
||||||
iMessage data reader.
|
|
||||||
|
|
||||||
Reads and processes iMessage conversation data from the macOS Messages database.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import sqlite3
|
|
||||||
from datetime import datetime
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from llama_index.core import Document
|
|
||||||
from llama_index.core.readers.base import BaseReader
|
|
||||||
|
|
||||||
|
|
||||||
class IMessageReader(BaseReader):
|
|
||||||
"""
|
|
||||||
iMessage data reader.
|
|
||||||
|
|
||||||
Reads iMessage conversation data from the macOS Messages database (chat.db).
|
|
||||||
Processes conversations into structured documents with metadata.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, concatenate_conversations: bool = True) -> None:
|
|
||||||
"""
|
|
||||||
Initialize.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
concatenate_conversations: Whether to concatenate messages within conversations for better context
|
|
||||||
"""
|
|
||||||
self.concatenate_conversations = concatenate_conversations
|
|
||||||
|
|
||||||
def _get_default_chat_db_path(self) -> Path:
|
|
||||||
"""
|
|
||||||
Get the default path to the iMessage chat database.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Path to the chat.db file
|
|
||||||
"""
|
|
||||||
home = Path.home()
|
|
||||||
return home / "Library" / "Messages" / "chat.db"
|
|
||||||
|
|
||||||
def _convert_cocoa_timestamp(self, cocoa_timestamp: int) -> str:
|
|
||||||
"""
|
|
||||||
Convert Cocoa timestamp to readable format.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
cocoa_timestamp: Timestamp in Cocoa format (nanoseconds since 2001-01-01)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Formatted timestamp string
|
|
||||||
"""
|
|
||||||
if cocoa_timestamp == 0:
|
|
||||||
return "Unknown"
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Cocoa timestamp is nanoseconds since 2001-01-01 00:00:00 UTC
|
|
||||||
# Convert to seconds and add to Unix epoch
|
|
||||||
cocoa_epoch = datetime(2001, 1, 1)
|
|
||||||
unix_timestamp = cocoa_timestamp / 1_000_000_000 # Convert nanoseconds to seconds
|
|
||||||
message_time = cocoa_epoch.timestamp() + unix_timestamp
|
|
||||||
return datetime.fromtimestamp(message_time).strftime("%Y-%m-%d %H:%M:%S")
|
|
||||||
except (ValueError, OSError):
|
|
||||||
return "Unknown"
|
|
||||||
|
|
||||||
def _get_contact_name(self, handle_id: str) -> str:
|
|
||||||
"""
|
|
||||||
Get a readable contact name from handle ID.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
handle_id: The handle ID (phone number or email)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Formatted contact name
|
|
||||||
"""
|
|
||||||
if not handle_id:
|
|
||||||
return "Unknown"
|
|
||||||
|
|
||||||
# Clean up phone numbers and emails for display
|
|
||||||
if "@" in handle_id:
|
|
||||||
return handle_id # Email address
|
|
||||||
elif handle_id.startswith("+"):
|
|
||||||
return handle_id # International phone number
|
|
||||||
else:
|
|
||||||
# Try to format as phone number
|
|
||||||
digits = "".join(filter(str.isdigit, handle_id))
|
|
||||||
if len(digits) == 10:
|
|
||||||
return f"({digits[:3]}) {digits[3:6]}-{digits[6:]}"
|
|
||||||
elif len(digits) == 11 and digits[0] == "1":
|
|
||||||
return f"+1 ({digits[1:4]}) {digits[4:7]}-{digits[7:]}"
|
|
||||||
else:
|
|
||||||
return handle_id
|
|
||||||
|
|
||||||
def _read_messages_from_db(self, db_path: Path) -> list[dict]:
|
|
||||||
"""
|
|
||||||
Read messages from the iMessage database.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db_path: Path to the chat.db file
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of message dictionaries
|
|
||||||
"""
|
|
||||||
if not db_path.exists():
|
|
||||||
print(f"iMessage database not found at: {db_path}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Connect to the database
|
|
||||||
conn = sqlite3.connect(str(db_path))
|
|
||||||
cursor = conn.cursor()
|
|
||||||
|
|
||||||
# Query to get messages with chat and handle information
|
|
||||||
query = """
|
|
||||||
SELECT
|
|
||||||
m.ROWID as message_id,
|
|
||||||
m.text,
|
|
||||||
m.date,
|
|
||||||
m.is_from_me,
|
|
||||||
m.service,
|
|
||||||
c.chat_identifier,
|
|
||||||
c.display_name as chat_display_name,
|
|
||||||
h.id as handle_id,
|
|
||||||
c.ROWID as chat_id
|
|
||||||
FROM message m
|
|
||||||
LEFT JOIN chat_message_join cmj ON m.ROWID = cmj.message_id
|
|
||||||
LEFT JOIN chat c ON cmj.chat_id = c.ROWID
|
|
||||||
LEFT JOIN handle h ON m.handle_id = h.ROWID
|
|
||||||
WHERE m.text IS NOT NULL AND m.text != ''
|
|
||||||
ORDER BY c.ROWID, m.date
|
|
||||||
"""
|
|
||||||
|
|
||||||
cursor.execute(query)
|
|
||||||
rows = cursor.fetchall()
|
|
||||||
|
|
||||||
messages = []
|
|
||||||
for row in rows:
|
|
||||||
(
|
|
||||||
message_id,
|
|
||||||
text,
|
|
||||||
date,
|
|
||||||
is_from_me,
|
|
||||||
service,
|
|
||||||
chat_identifier,
|
|
||||||
chat_display_name,
|
|
||||||
handle_id,
|
|
||||||
chat_id,
|
|
||||||
) = row
|
|
||||||
|
|
||||||
message = {
|
|
||||||
"message_id": message_id,
|
|
||||||
"text": text,
|
|
||||||
"timestamp": self._convert_cocoa_timestamp(date),
|
|
||||||
"is_from_me": bool(is_from_me),
|
|
||||||
"service": service or "iMessage",
|
|
||||||
"chat_identifier": chat_identifier or "Unknown",
|
|
||||||
"chat_display_name": chat_display_name or "Unknown Chat",
|
|
||||||
"handle_id": handle_id or "Unknown",
|
|
||||||
"contact_name": self._get_contact_name(handle_id or ""),
|
|
||||||
"chat_id": chat_id,
|
|
||||||
}
|
|
||||||
messages.append(message)
|
|
||||||
|
|
||||||
conn.close()
|
|
||||||
print(f"Found {len(messages)} messages in database")
|
|
||||||
return messages
|
|
||||||
|
|
||||||
except sqlite3.Error as e:
|
|
||||||
print(f"Error reading iMessage database: {e}")
|
|
||||||
return []
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Unexpected error reading iMessage database: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
def _group_messages_by_chat(self, messages: list[dict]) -> dict[int, list[dict]]:
|
|
||||||
"""
|
|
||||||
Group messages by chat ID.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
messages: List of message dictionaries
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary mapping chat_id to list of messages
|
|
||||||
"""
|
|
||||||
chats = {}
|
|
||||||
for message in messages:
|
|
||||||
chat_id = message["chat_id"]
|
|
||||||
if chat_id not in chats:
|
|
||||||
chats[chat_id] = []
|
|
||||||
chats[chat_id].append(message)
|
|
||||||
|
|
||||||
return chats
|
|
||||||
|
|
||||||
def _create_concatenated_content(self, chat_id: int, messages: list[dict]) -> str:
|
|
||||||
"""
|
|
||||||
Create concatenated content from chat messages.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
chat_id: The chat ID
|
|
||||||
messages: List of messages in the chat
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Concatenated text content
|
|
||||||
"""
|
|
||||||
if not messages:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
# Get chat info from first message
|
|
||||||
first_msg = messages[0]
|
|
||||||
chat_name = first_msg["chat_display_name"]
|
|
||||||
chat_identifier = first_msg["chat_identifier"]
|
|
||||||
|
|
||||||
# Build message content
|
|
||||||
message_parts = []
|
|
||||||
for message in messages:
|
|
||||||
timestamp = message["timestamp"]
|
|
||||||
is_from_me = message["is_from_me"]
|
|
||||||
text = message["text"]
|
|
||||||
contact_name = message["contact_name"]
|
|
||||||
|
|
||||||
if is_from_me:
|
|
||||||
prefix = "[You]"
|
|
||||||
else:
|
|
||||||
prefix = f"[{contact_name}]"
|
|
||||||
|
|
||||||
if timestamp != "Unknown":
|
|
||||||
prefix += f" ({timestamp})"
|
|
||||||
|
|
||||||
message_parts.append(f"{prefix}: {text}")
|
|
||||||
|
|
||||||
concatenated_text = "\n\n".join(message_parts)
|
|
||||||
|
|
||||||
doc_content = f"""Chat: {chat_name}
|
|
||||||
Identifier: {chat_identifier}
|
|
||||||
Messages ({len(messages)} messages):
|
|
||||||
|
|
||||||
{concatenated_text}
|
|
||||||
"""
|
|
||||||
return doc_content
|
|
||||||
|
|
||||||
def _create_individual_content(self, message: dict) -> str:
|
|
||||||
"""
|
|
||||||
Create content for individual message.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message: Message dictionary
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Formatted message content
|
|
||||||
"""
|
|
||||||
timestamp = message["timestamp"]
|
|
||||||
is_from_me = message["is_from_me"]
|
|
||||||
text = message["text"]
|
|
||||||
contact_name = message["contact_name"]
|
|
||||||
chat_name = message["chat_display_name"]
|
|
||||||
|
|
||||||
sender = "You" if is_from_me else contact_name
|
|
||||||
|
|
||||||
return f"""Message from {sender} in chat "{chat_name}"
|
|
||||||
Time: {timestamp}
|
|
||||||
Content: {text}
|
|
||||||
"""
|
|
||||||
|
|
||||||
def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
|
|
||||||
"""
|
|
||||||
Load iMessage data and return as documents.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_dir: Optional path to directory containing chat.db file.
|
|
||||||
If not provided, uses default macOS location.
|
|
||||||
**load_kwargs: Additional arguments (unused)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of Document objects containing iMessage data
|
|
||||||
"""
|
|
||||||
docs = []
|
|
||||||
|
|
||||||
# Determine database path
|
|
||||||
if input_dir:
|
|
||||||
db_path = Path(input_dir) / "chat.db"
|
|
||||||
else:
|
|
||||||
db_path = self._get_default_chat_db_path()
|
|
||||||
|
|
||||||
print(f"Reading iMessage database from: {db_path}")
|
|
||||||
|
|
||||||
# Read messages from database
|
|
||||||
messages = self._read_messages_from_db(db_path)
|
|
||||||
if not messages:
|
|
||||||
return docs
|
|
||||||
|
|
||||||
if self.concatenate_conversations:
|
|
||||||
# Group messages by chat and create concatenated documents
|
|
||||||
chats = self._group_messages_by_chat(messages)
|
|
||||||
|
|
||||||
for chat_id, chat_messages in chats.items():
|
|
||||||
if not chat_messages:
|
|
||||||
continue
|
|
||||||
|
|
||||||
content = self._create_concatenated_content(chat_id, chat_messages)
|
|
||||||
|
|
||||||
# Create metadata
|
|
||||||
first_msg = chat_messages[0]
|
|
||||||
last_msg = chat_messages[-1]
|
|
||||||
|
|
||||||
metadata = {
|
|
||||||
"source": "iMessage",
|
|
||||||
"chat_id": chat_id,
|
|
||||||
"chat_name": first_msg["chat_display_name"],
|
|
||||||
"chat_identifier": first_msg["chat_identifier"],
|
|
||||||
"message_count": len(chat_messages),
|
|
||||||
"first_message_date": first_msg["timestamp"],
|
|
||||||
"last_message_date": last_msg["timestamp"],
|
|
||||||
"participants": list(
|
|
||||||
{msg["contact_name"] for msg in chat_messages if not msg["is_from_me"]}
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
doc = Document(text=content, metadata=metadata)
|
|
||||||
docs.append(doc)
|
|
||||||
|
|
||||||
else:
|
|
||||||
# Create individual documents for each message
|
|
||||||
for message in messages:
|
|
||||||
content = self._create_individual_content(message)
|
|
||||||
|
|
||||||
metadata = {
|
|
||||||
"source": "iMessage",
|
|
||||||
"message_id": message["message_id"],
|
|
||||||
"chat_id": message["chat_id"],
|
|
||||||
"chat_name": message["chat_display_name"],
|
|
||||||
"chat_identifier": message["chat_identifier"],
|
|
||||||
"timestamp": message["timestamp"],
|
|
||||||
"is_from_me": message["is_from_me"],
|
|
||||||
"contact_name": message["contact_name"],
|
|
||||||
"service": message["service"],
|
|
||||||
}
|
|
||||||
|
|
||||||
doc = Document(text=content, metadata=metadata)
|
|
||||||
docs.append(doc)
|
|
||||||
|
|
||||||
print(f"Created {len(docs)} documents from iMessage data")
|
|
||||||
return docs
|
|
||||||
@@ -1,125 +0,0 @@
|
|||||||
"""
|
|
||||||
iMessage RAG Example.
|
|
||||||
|
|
||||||
This example demonstrates how to build a RAG system on your iMessage conversation history.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from leann.chunking_utils import create_text_chunks
|
|
||||||
|
|
||||||
from apps.base_rag_example import BaseRAGExample
|
|
||||||
from apps.imessage_data.imessage_reader import IMessageReader
|
|
||||||
|
|
||||||
|
|
||||||
class IMessageRAG(BaseRAGExample):
|
|
||||||
"""RAG example for iMessage conversation history."""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
name="iMessage",
|
|
||||||
description="RAG on your iMessage conversation history",
|
|
||||||
default_index_name="imessage_index",
|
|
||||||
)
|
|
||||||
|
|
||||||
def _add_specific_arguments(self, parser):
|
|
||||||
"""Add iMessage-specific arguments."""
|
|
||||||
imessage_group = parser.add_argument_group("iMessage Parameters")
|
|
||||||
imessage_group.add_argument(
|
|
||||||
"--db-path",
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help="Path to iMessage chat.db file (default: ~/Library/Messages/chat.db)",
|
|
||||||
)
|
|
||||||
imessage_group.add_argument(
|
|
||||||
"--concatenate-conversations",
|
|
||||||
action="store_true",
|
|
||||||
default=True,
|
|
||||||
help="Concatenate messages within conversations for better context (default: True)",
|
|
||||||
)
|
|
||||||
imessage_group.add_argument(
|
|
||||||
"--no-concatenate-conversations",
|
|
||||||
action="store_true",
|
|
||||||
help="Process each message individually instead of concatenating by conversation",
|
|
||||||
)
|
|
||||||
imessage_group.add_argument(
|
|
||||||
"--chunk-size",
|
|
||||||
type=int,
|
|
||||||
default=1000,
|
|
||||||
help="Maximum characters per text chunk (default: 1000)",
|
|
||||||
)
|
|
||||||
imessage_group.add_argument(
|
|
||||||
"--chunk-overlap",
|
|
||||||
type=int,
|
|
||||||
default=200,
|
|
||||||
help="Overlap between text chunks (default: 200)",
|
|
||||||
)
|
|
||||||
|
|
||||||
async def load_data(self, args) -> list[str]:
|
|
||||||
"""Load iMessage history and convert to text chunks."""
|
|
||||||
print("Loading iMessage conversation history...")
|
|
||||||
|
|
||||||
# Determine concatenation setting
|
|
||||||
concatenate = args.concatenate_conversations and not args.no_concatenate_conversations
|
|
||||||
|
|
||||||
# Initialize iMessage reader
|
|
||||||
reader = IMessageReader(concatenate_conversations=concatenate)
|
|
||||||
|
|
||||||
# Load documents
|
|
||||||
try:
|
|
||||||
if args.db_path:
|
|
||||||
# Use custom database path
|
|
||||||
db_dir = str(Path(args.db_path).parent)
|
|
||||||
documents = reader.load_data(input_dir=db_dir)
|
|
||||||
else:
|
|
||||||
# Use default macOS location
|
|
||||||
documents = reader.load_data()
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error loading iMessage data: {e}")
|
|
||||||
print("\nTroubleshooting tips:")
|
|
||||||
print("1. Make sure you have granted Full Disk Access to your terminal/IDE")
|
|
||||||
print("2. Check that the iMessage database exists at ~/Library/Messages/chat.db")
|
|
||||||
print("3. Try specifying a custom path with --db-path if you have a backup")
|
|
||||||
return []
|
|
||||||
|
|
||||||
if not documents:
|
|
||||||
print("No iMessage conversations found!")
|
|
||||||
return []
|
|
||||||
|
|
||||||
print(f"Loaded {len(documents)} iMessage documents")
|
|
||||||
|
|
||||||
# Show some statistics
|
|
||||||
total_messages = sum(doc.metadata.get("message_count", 1) for doc in documents)
|
|
||||||
print(f"Total messages: {total_messages}")
|
|
||||||
|
|
||||||
if concatenate:
|
|
||||||
# Show chat statistics
|
|
||||||
chat_names = [doc.metadata.get("chat_name", "Unknown") for doc in documents]
|
|
||||||
unique_chats = len(set(chat_names))
|
|
||||||
print(f"Unique conversations: {unique_chats}")
|
|
||||||
|
|
||||||
# Convert to text chunks
|
|
||||||
all_texts = create_text_chunks(
|
|
||||||
documents,
|
|
||||||
chunk_size=args.chunk_size,
|
|
||||||
chunk_overlap=args.chunk_overlap,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Apply max_items limit if specified
|
|
||||||
if args.max_items > 0:
|
|
||||||
all_texts = all_texts[: args.max_items]
|
|
||||||
print(f"Limited to {len(all_texts)} text chunks (max_items={args.max_items})")
|
|
||||||
|
|
||||||
return all_texts
|
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
|
||||||
"""Main entry point."""
|
|
||||||
app = IMessageRAG()
|
|
||||||
await app.run()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
asyncio.run(main())
|
|
||||||
@@ -1,18 +1,12 @@
|
|||||||
import concurrent.futures
|
from __future__ import annotations
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Optional, cast
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
|
|
||||||
def _ensure_repo_paths_importable(current_file: str) -> None:
|
def _ensure_repo_paths_importable(current_file: str) -> None:
|
||||||
"""Make local leann packages importable without installing (mirrors multi-vector-leann.py)."""
|
|
||||||
_repo_root = Path(current_file).resolve().parents[3]
|
_repo_root = Path(current_file).resolve().parents[3]
|
||||||
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
|
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
|
||||||
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
|
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
|
||||||
@@ -22,380 +16,6 @@ def _ensure_repo_paths_importable(current_file: str) -> None:
|
|||||||
sys.path.append(str(_leann_hnsw_pkg))
|
sys.path.append(str(_leann_hnsw_pkg))
|
||||||
|
|
||||||
|
|
||||||
def _find_backend_module_file() -> Optional[Path]:
|
|
||||||
"""Best-effort locate the backend leann_multi_vector.py file, avoiding this file."""
|
|
||||||
this_file = Path(__file__).resolve()
|
|
||||||
candidates: list[Path] = []
|
|
||||||
|
|
||||||
# Common in-repo location
|
|
||||||
repo_root = this_file.parents[3]
|
|
||||||
candidates.append(repo_root / "packages" / "leann-backend-hnsw" / "leann_multi_vector.py")
|
|
||||||
candidates.append(
|
|
||||||
repo_root / "packages" / "leann-backend-hnsw" / "src" / "leann_multi_vector.py"
|
|
||||||
)
|
|
||||||
|
|
||||||
for cand in candidates:
|
|
||||||
try:
|
|
||||||
if cand.exists() and cand.resolve() != this_file:
|
|
||||||
return cand.resolve()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Fallback: scan sys.path for another leann_multi_vector.py different from this file
|
|
||||||
for p in list(sys.path):
|
|
||||||
try:
|
|
||||||
cand = Path(p) / "leann_multi_vector.py"
|
|
||||||
if cand.exists() and cand.resolve() != this_file:
|
|
||||||
return cand.resolve()
|
|
||||||
except Exception:
|
|
||||||
continue
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
_BACKEND_LEANN_CLASS: Optional[type] = None
|
|
||||||
|
|
||||||
|
|
||||||
def _get_backend_leann_multi_vector() -> type:
|
|
||||||
"""Load backend LeannMultiVector class even if this file shadows its module name."""
|
|
||||||
global _BACKEND_LEANN_CLASS
|
|
||||||
if _BACKEND_LEANN_CLASS is not None:
|
|
||||||
return _BACKEND_LEANN_CLASS
|
|
||||||
|
|
||||||
backend_path = _find_backend_module_file()
|
|
||||||
if backend_path is None:
|
|
||||||
# Fallback to local implementation in this module
|
|
||||||
try:
|
|
||||||
cls = LeannMultiVector # type: ignore[name-defined]
|
|
||||||
_BACKEND_LEANN_CLASS = cls
|
|
||||||
return cls
|
|
||||||
except Exception as e:
|
|
||||||
raise ImportError(
|
|
||||||
"Could not locate backend 'leann_multi_vector.py' and no local implementation found. "
|
|
||||||
"Ensure the leann backend is available under packages/leann-backend-hnsw or installed."
|
|
||||||
) from e
|
|
||||||
|
|
||||||
import importlib.util
|
|
||||||
|
|
||||||
module_name = "leann_hnsw_backend_module"
|
|
||||||
spec = importlib.util.spec_from_file_location(module_name, str(backend_path))
|
|
||||||
if spec is None or spec.loader is None:
|
|
||||||
raise ImportError(f"Failed to create spec for backend module at {backend_path}")
|
|
||||||
backend_module = importlib.util.module_from_spec(spec)
|
|
||||||
sys.modules[module_name] = backend_module
|
|
||||||
spec.loader.exec_module(backend_module) # type: ignore[assignment]
|
|
||||||
|
|
||||||
if not hasattr(backend_module, "LeannMultiVector"):
|
|
||||||
raise ImportError(f"'LeannMultiVector' not found in backend module at {backend_path}")
|
|
||||||
_BACKEND_LEANN_CLASS = backend_module.LeannMultiVector
|
|
||||||
return _BACKEND_LEANN_CLASS
|
|
||||||
|
|
||||||
|
|
||||||
def _natural_sort_key(name: str) -> int:
|
|
||||||
m = re.search(r"\d+", name)
|
|
||||||
return int(m.group()) if m else 0
|
|
||||||
|
|
||||||
|
|
||||||
def _load_images_from_dir(pages_dir: str) -> tuple[list[str], list[Image.Image]]:
|
|
||||||
filenames = [n for n in os.listdir(pages_dir) if n.lower().endswith((".png", ".jpg", ".jpeg"))]
|
|
||||||
filenames = sorted(filenames, key=_natural_sort_key)
|
|
||||||
filepaths = [os.path.join(pages_dir, n) for n in filenames]
|
|
||||||
images = [Image.open(p) for p in filepaths]
|
|
||||||
return filepaths, images
|
|
||||||
|
|
||||||
|
|
||||||
def _maybe_convert_pdf_to_images(pdf_path: Optional[str], pages_dir: str, dpi: int = 200) -> None:
|
|
||||||
if not pdf_path:
|
|
||||||
return
|
|
||||||
os.makedirs(pages_dir, exist_ok=True)
|
|
||||||
try:
|
|
||||||
from pdf2image import convert_from_path
|
|
||||||
except Exception as e:
|
|
||||||
raise RuntimeError(
|
|
||||||
"pdf2image is required to convert PDF to images. Install via pip install pdf2image"
|
|
||||||
) from e
|
|
||||||
images = convert_from_path(pdf_path, dpi=dpi)
|
|
||||||
for i, image in enumerate(images):
|
|
||||||
image.save(os.path.join(pages_dir, f"page_{i + 1}.png"), "PNG")
|
|
||||||
|
|
||||||
|
|
||||||
def _select_device_and_dtype():
|
|
||||||
import torch
|
|
||||||
from colpali_engine.utils.torch_utils import get_torch_device
|
|
||||||
|
|
||||||
device_str = (
|
|
||||||
"cuda"
|
|
||||||
if torch.cuda.is_available()
|
|
||||||
else (
|
|
||||||
"mps"
|
|
||||||
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
|
|
||||||
else "cpu"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
device = get_torch_device(device_str)
|
|
||||||
# Stable dtype selection to avoid NaNs:
|
|
||||||
# - CUDA: prefer bfloat16 if supported, else float16
|
|
||||||
# - MPS: use float32 (fp16 on MPS can produce NaNs in some ops)
|
|
||||||
# - CPU: float32
|
|
||||||
if device_str == "cuda":
|
|
||||||
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
|
||||||
try:
|
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True # Better stability/perf on Ampere+
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
elif device_str == "mps":
|
|
||||||
dtype = torch.float32
|
|
||||||
else:
|
|
||||||
dtype = torch.float32
|
|
||||||
return device_str, device, dtype
|
|
||||||
|
|
||||||
|
|
||||||
def _load_colvision(model_choice: str):
|
|
||||||
import torch
|
|
||||||
from colpali_engine.models import ColPali, ColQwen2, ColQwen2Processor
|
|
||||||
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
|
|
||||||
from transformers.utils.import_utils import is_flash_attn_2_available
|
|
||||||
|
|
||||||
device_str, device, dtype = _select_device_and_dtype()
|
|
||||||
|
|
||||||
if model_choice == "colqwen2":
|
|
||||||
model_name = "vidore/colqwen2-v1.0"
|
|
||||||
# On CPU/MPS we must avoid flash-attn and stay eager; on CUDA prefer flash-attn if available
|
|
||||||
attn_implementation = (
|
|
||||||
"flash_attention_2"
|
|
||||||
if (device_str == "cuda" and is_flash_attn_2_available())
|
|
||||||
else "eager"
|
|
||||||
)
|
|
||||||
model = ColQwen2.from_pretrained(
|
|
||||||
model_name,
|
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
device_map=device,
|
|
||||||
attn_implementation=attn_implementation,
|
|
||||||
).eval()
|
|
||||||
processor = ColQwen2Processor.from_pretrained(model_name)
|
|
||||||
else:
|
|
||||||
model_name = "vidore/colpali-v1.2"
|
|
||||||
model = ColPali.from_pretrained(
|
|
||||||
model_name,
|
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
device_map=device,
|
|
||||||
).eval()
|
|
||||||
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
|
|
||||||
|
|
||||||
return model_name, model, processor, device_str, device, dtype
|
|
||||||
|
|
||||||
|
|
||||||
def _embed_images(model, processor, images: list[Image.Image]) -> list[Any]:
|
|
||||||
import torch
|
|
||||||
from colpali_engine.utils.torch_utils import ListDataset
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
# Ensure deterministic eval and autocast for stability
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
dataloader = DataLoader(
|
|
||||||
dataset=ListDataset[Image.Image](images),
|
|
||||||
batch_size=1,
|
|
||||||
shuffle=False,
|
|
||||||
collate_fn=lambda x: processor.process_images(x),
|
|
||||||
)
|
|
||||||
|
|
||||||
doc_vecs: list[Any] = []
|
|
||||||
for batch_doc in tqdm(dataloader, desc="Embedding images"):
|
|
||||||
with torch.no_grad():
|
|
||||||
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
|
|
||||||
# autocast on CUDA for bf16/fp16; on CPU/MPS stay in fp32
|
|
||||||
if model.device.type == "cuda":
|
|
||||||
with torch.autocast(
|
|
||||||
device_type="cuda",
|
|
||||||
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
|
|
||||||
):
|
|
||||||
embeddings_doc = model(**batch_doc)
|
|
||||||
else:
|
|
||||||
embeddings_doc = model(**batch_doc)
|
|
||||||
doc_vecs.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
|
|
||||||
return doc_vecs
|
|
||||||
|
|
||||||
|
|
||||||
def _embed_queries(model, processor, queries: list[str]) -> list[Any]:
|
|
||||||
import torch
|
|
||||||
from colpali_engine.utils.torch_utils import ListDataset
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
dataloader = DataLoader(
|
|
||||||
dataset=ListDataset[str](queries),
|
|
||||||
batch_size=1,
|
|
||||||
shuffle=False,
|
|
||||||
collate_fn=lambda x: processor.process_queries(x),
|
|
||||||
)
|
|
||||||
|
|
||||||
q_vecs: list[Any] = []
|
|
||||||
for batch_query in tqdm(dataloader, desc="Embedding queries"):
|
|
||||||
with torch.no_grad():
|
|
||||||
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
|
|
||||||
if model.device.type == "cuda":
|
|
||||||
with torch.autocast(
|
|
||||||
device_type="cuda",
|
|
||||||
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
|
|
||||||
):
|
|
||||||
embeddings_query = model(**batch_query)
|
|
||||||
else:
|
|
||||||
embeddings_query = model(**batch_query)
|
|
||||||
q_vecs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
|
|
||||||
return q_vecs
|
|
||||||
|
|
||||||
|
|
||||||
def _build_index(
|
|
||||||
index_path: str, doc_vecs: list[Any], filepaths: list[str], images: list[Image.Image]
|
|
||||||
) -> Any:
|
|
||||||
LeannMultiVector = _get_backend_leann_multi_vector()
|
|
||||||
dim = int(doc_vecs[0].shape[-1])
|
|
||||||
retriever = LeannMultiVector(index_path=index_path, dim=dim)
|
|
||||||
retriever.create_collection()
|
|
||||||
for i, vec in enumerate(doc_vecs):
|
|
||||||
data = {
|
|
||||||
"colbert_vecs": vec.float().numpy(),
|
|
||||||
"doc_id": i,
|
|
||||||
"filepath": filepaths[i],
|
|
||||||
"image": images[i], # Include the original image
|
|
||||||
}
|
|
||||||
retriever.insert(data)
|
|
||||||
retriever.create_index()
|
|
||||||
return retriever
|
|
||||||
|
|
||||||
|
|
||||||
def _load_retriever_if_index_exists(index_path: str) -> Optional[Any]:
|
|
||||||
LeannMultiVector = _get_backend_leann_multi_vector()
|
|
||||||
index_base = Path(index_path)
|
|
||||||
# Check for the actual HNSW index file written by the backend + our sidecar files
|
|
||||||
index_file = index_base.parent / f"{index_base.stem}.index"
|
|
||||||
meta = index_base.parent / f"{index_base.name}.meta.json"
|
|
||||||
labels = index_base.parent / f"{index_base.name}.labels.json"
|
|
||||||
if index_file.exists() and meta.exists() and labels.exists():
|
|
||||||
try:
|
|
||||||
with open(meta, encoding="utf-8") as f:
|
|
||||||
meta_json = json.load(f)
|
|
||||||
dim = int(meta_json.get("dimensions", 128))
|
|
||||||
except Exception:
|
|
||||||
dim = 128
|
|
||||||
return LeannMultiVector(index_path=index_path, dim=dim)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _generate_similarity_map(
|
|
||||||
model,
|
|
||||||
processor,
|
|
||||||
image: Image.Image,
|
|
||||||
query: str,
|
|
||||||
token_idx: Optional[int] = None,
|
|
||||||
output_path: Optional[str] = None,
|
|
||||||
) -> tuple[int, float]:
|
|
||||||
import torch
|
|
||||||
from colpali_engine.interpretability import (
|
|
||||||
get_similarity_maps_from_embeddings,
|
|
||||||
plot_similarity_map,
|
|
||||||
)
|
|
||||||
|
|
||||||
batch_images = processor.process_images([image]).to(model.device)
|
|
||||||
batch_queries = processor.process_queries([query]).to(model.device)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
image_embeddings = model.forward(**batch_images)
|
|
||||||
query_embeddings = model.forward(**batch_queries)
|
|
||||||
|
|
||||||
n_patches = processor.get_n_patches(
|
|
||||||
image_size=image.size,
|
|
||||||
spatial_merge_size=getattr(model, "spatial_merge_size", None),
|
|
||||||
)
|
|
||||||
image_mask = processor.get_image_mask(batch_images)
|
|
||||||
|
|
||||||
batched_similarity_maps = get_similarity_maps_from_embeddings(
|
|
||||||
image_embeddings=image_embeddings,
|
|
||||||
query_embeddings=query_embeddings,
|
|
||||||
n_patches=n_patches,
|
|
||||||
image_mask=image_mask,
|
|
||||||
)
|
|
||||||
|
|
||||||
similarity_maps = batched_similarity_maps[0]
|
|
||||||
|
|
||||||
# Determine token index if not provided: choose the token with highest max score
|
|
||||||
if token_idx is None:
|
|
||||||
per_token_max = similarity_maps.view(similarity_maps.shape[0], -1).max(dim=1).values
|
|
||||||
token_idx = int(per_token_max.argmax().item())
|
|
||||||
|
|
||||||
max_sim_score = similarity_maps[token_idx, :, :].max().item()
|
|
||||||
|
|
||||||
if output_path:
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
fig, ax = plot_similarity_map(
|
|
||||||
image=image,
|
|
||||||
similarity_map=similarity_maps[token_idx],
|
|
||||||
figsize=(14, 14),
|
|
||||||
show_colorbar=False,
|
|
||||||
)
|
|
||||||
ax.set_title(f"Token #{token_idx}. MaxSim score: {max_sim_score:.2f}", fontsize=12)
|
|
||||||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
|
||||||
plt.savefig(output_path, bbox_inches="tight")
|
|
||||||
plt.close(fig)
|
|
||||||
|
|
||||||
return token_idx, float(max_sim_score)
|
|
||||||
|
|
||||||
|
|
||||||
class QwenVL:
|
|
||||||
def __init__(self, device: str):
|
|
||||||
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
|
|
||||||
from transformers.utils.import_utils import is_flash_attn_2_available
|
|
||||||
|
|
||||||
attn_implementation = "flash_attention_2" if is_flash_attn_2_available() else "eager"
|
|
||||||
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
|
||||||
"Qwen/Qwen2.5-VL-3B-Instruct",
|
|
||||||
torch_dtype="auto",
|
|
||||||
device_map=device,
|
|
||||||
attn_implementation=attn_implementation,
|
|
||||||
)
|
|
||||||
|
|
||||||
min_pixels = 256 * 28 * 28
|
|
||||||
max_pixels = 1280 * 28 * 28
|
|
||||||
self.processor = AutoProcessor.from_pretrained(
|
|
||||||
"Qwen/Qwen2.5-VL-3B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels
|
|
||||||
)
|
|
||||||
|
|
||||||
def answer(self, query: str, images: list[Image.Image], max_new_tokens: int = 128) -> str:
|
|
||||||
import base64
|
|
||||||
from io import BytesIO
|
|
||||||
|
|
||||||
from qwen_vl_utils import process_vision_info
|
|
||||||
|
|
||||||
content = []
|
|
||||||
for img in images:
|
|
||||||
buffer = BytesIO()
|
|
||||||
img.save(buffer, format="jpeg")
|
|
||||||
img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
|
||||||
content.append({"type": "image", "image": f"data:image;base64,{img_base64}"})
|
|
||||||
content.append({"type": "text", "text": query})
|
|
||||||
messages = [{"role": "user", "content": content}]
|
|
||||||
|
|
||||||
text = self.processor.apply_chat_template(
|
|
||||||
messages, tokenize=False, add_generation_prompt=True
|
|
||||||
)
|
|
||||||
image_inputs, video_inputs = process_vision_info(messages)
|
|
||||||
inputs = self.processor(
|
|
||||||
text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt"
|
|
||||||
)
|
|
||||||
inputs = inputs.to(self.model.device)
|
|
||||||
|
|
||||||
generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
|
|
||||||
generated_ids_trimmed = [
|
|
||||||
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
|
||||||
]
|
|
||||||
return self.processor.batch_decode(
|
|
||||||
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
|
||||||
)[0]
|
|
||||||
|
|
||||||
|
|
||||||
# Ensure repo paths are importable for dynamic backend loading
|
|
||||||
_ensure_repo_paths_importable(__file__)
|
_ensure_repo_paths_importable(__file__)
|
||||||
|
|
||||||
from leann_backend_hnsw.hnsw_backend import HNSWBuilder, HNSWSearcher # noqa: E402
|
from leann_backend_hnsw.hnsw_backend import HNSWBuilder, HNSWSearcher # noqa: E402
|
||||||
@@ -425,7 +45,6 @@ class LeannMultiVector:
|
|||||||
"is_recompute": is_recompute,
|
"is_recompute": is_recompute,
|
||||||
}
|
}
|
||||||
self._labels_meta: list[dict] = []
|
self._labels_meta: list[dict] = []
|
||||||
self._docid_to_indices: dict[int, list[int]] | None = None
|
|
||||||
|
|
||||||
def _meta_dict(self) -> dict:
|
def _meta_dict(self) -> dict:
|
||||||
return {
|
return {
|
||||||
@@ -450,7 +69,6 @@ class LeannMultiVector:
|
|||||||
"doc_id": int(data["doc_id"]),
|
"doc_id": int(data["doc_id"]),
|
||||||
"filepath": data.get("filepath", ""),
|
"filepath": data.get("filepath", ""),
|
||||||
"colbert_vecs": [np.asarray(v, dtype=np.float32) for v in data["colbert_vecs"]],
|
"colbert_vecs": [np.asarray(v, dtype=np.float32) for v in data["colbert_vecs"]],
|
||||||
"image": data.get("image"), # PIL Image object (optional)
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -462,15 +80,6 @@ class LeannMultiVector:
|
|||||||
index_path_obj = Path(self.index_path)
|
index_path_obj = Path(self.index_path)
|
||||||
return index_path_obj.parent / f"{index_path_obj.name}.meta.json"
|
return index_path_obj.parent / f"{index_path_obj.name}.meta.json"
|
||||||
|
|
||||||
def _embeddings_path(self) -> Path:
|
|
||||||
index_path_obj = Path(self.index_path)
|
|
||||||
return index_path_obj.parent / f"{index_path_obj.name}.emb.npy"
|
|
||||||
|
|
||||||
def _images_dir_path(self) -> Path:
|
|
||||||
"""Directory where original images are stored."""
|
|
||||||
index_path_obj = Path(self.index_path)
|
|
||||||
return index_path_obj.parent / f"{index_path_obj.name}.images"
|
|
||||||
|
|
||||||
def create_index(self) -> None:
|
def create_index(self) -> None:
|
||||||
if not self._pending_items:
|
if not self._pending_items:
|
||||||
return
|
return
|
||||||
@@ -478,23 +87,10 @@ class LeannMultiVector:
|
|||||||
embeddings: list[np.ndarray] = []
|
embeddings: list[np.ndarray] = []
|
||||||
labels_meta: list[dict] = []
|
labels_meta: list[dict] = []
|
||||||
|
|
||||||
# Create images directory if needed
|
|
||||||
images_dir = self._images_dir_path()
|
|
||||||
images_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
for item in self._pending_items:
|
for item in self._pending_items:
|
||||||
doc_id = int(item["doc_id"])
|
doc_id = int(item["doc_id"])
|
||||||
filepath = item.get("filepath", "")
|
filepath = item.get("filepath", "")
|
||||||
colbert_vecs = item["colbert_vecs"]
|
colbert_vecs = item["colbert_vecs"]
|
||||||
image = item.get("image")
|
|
||||||
|
|
||||||
# Save image if provided
|
|
||||||
image_path = ""
|
|
||||||
if image is not None and isinstance(image, Image.Image):
|
|
||||||
image_filename = f"doc_{doc_id}.png"
|
|
||||||
image_path = str(images_dir / image_filename)
|
|
||||||
image.save(image_path, "PNG")
|
|
||||||
|
|
||||||
for seq_id, vec in enumerate(colbert_vecs):
|
for seq_id, vec in enumerate(colbert_vecs):
|
||||||
vec_np = np.asarray(vec, dtype=np.float32)
|
vec_np = np.asarray(vec, dtype=np.float32)
|
||||||
embeddings.append(vec_np)
|
embeddings.append(vec_np)
|
||||||
@@ -504,7 +100,6 @@ class LeannMultiVector:
|
|||||||
"doc_id": doc_id,
|
"doc_id": doc_id,
|
||||||
"seq_id": int(seq_id),
|
"seq_id": int(seq_id),
|
||||||
"filepath": filepath,
|
"filepath": filepath,
|
||||||
"image_path": image_path, # Store the path to the saved image
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -512,6 +107,7 @@ class LeannMultiVector:
|
|||||||
return
|
return
|
||||||
|
|
||||||
embeddings_np = np.vstack(embeddings).astype(np.float32)
|
embeddings_np = np.vstack(embeddings).astype(np.float32)
|
||||||
|
# print shape of embeddings_np
|
||||||
print(embeddings_np.shape)
|
print(embeddings_np.shape)
|
||||||
|
|
||||||
builder = HNSWBuilder(**{**self._backend_kwargs, "dimensions": self.dim})
|
builder = HNSWBuilder(**{**self._backend_kwargs, "dimensions": self.dim})
|
||||||
@@ -525,9 +121,6 @@ class LeannMultiVector:
|
|||||||
with open(self._labels_path(), "w", encoding="utf-8") as f:
|
with open(self._labels_path(), "w", encoding="utf-8") as f:
|
||||||
_json.dump(labels_meta, f)
|
_json.dump(labels_meta, f)
|
||||||
|
|
||||||
# Persist embeddings for exact reranking
|
|
||||||
np.save(self._embeddings_path(), embeddings_np)
|
|
||||||
|
|
||||||
self._labels_meta = labels_meta
|
self._labels_meta = labels_meta
|
||||||
|
|
||||||
def _load_labels_meta_if_needed(self) -> None:
|
def _load_labels_meta_if_needed(self) -> None:
|
||||||
@@ -540,19 +133,6 @@ class LeannMultiVector:
|
|||||||
with open(labels_path, encoding="utf-8") as f:
|
with open(labels_path, encoding="utf-8") as f:
|
||||||
self._labels_meta = _json.load(f)
|
self._labels_meta = _json.load(f)
|
||||||
|
|
||||||
def _build_docid_to_indices_if_needed(self) -> None:
|
|
||||||
if self._docid_to_indices is not None:
|
|
||||||
return
|
|
||||||
self._load_labels_meta_if_needed()
|
|
||||||
mapping: dict[int, list[int]] = {}
|
|
||||||
for idx, meta in enumerate(self._labels_meta):
|
|
||||||
try:
|
|
||||||
doc_id = int(meta["doc_id"]) # type: ignore[index]
|
|
||||||
except Exception:
|
|
||||||
continue
|
|
||||||
mapping.setdefault(doc_id, []).append(idx)
|
|
||||||
self._docid_to_indices = mapping
|
|
||||||
|
|
||||||
def search(
|
def search(
|
||||||
self, data: np.ndarray, topk: int, first_stage_k: int = 50
|
self, data: np.ndarray, topk: int, first_stage_k: int = 50
|
||||||
) -> list[tuple[float, int]]:
|
) -> list[tuple[float, int]]:
|
||||||
@@ -600,181 +180,3 @@ class LeannMultiVector:
|
|||||||
|
|
||||||
scores = sorted(((v, k) for k, v in doc_scores.items()), key=lambda x: x[0], reverse=True)
|
scores = sorted(((v, k) for k, v in doc_scores.items()), key=lambda x: x[0], reverse=True)
|
||||||
return scores[:topk] if len(scores) >= topk else scores
|
return scores[:topk] if len(scores) >= topk else scores
|
||||||
|
|
||||||
def search_exact(
|
|
||||||
self,
|
|
||||||
data: np.ndarray,
|
|
||||||
topk: int,
|
|
||||||
*,
|
|
||||||
first_stage_k: int = 200,
|
|
||||||
max_workers: int = 32,
|
|
||||||
) -> list[tuple[float, int]]:
|
|
||||||
"""
|
|
||||||
High-precision MaxSim reranking over candidate documents.
|
|
||||||
|
|
||||||
Steps:
|
|
||||||
1) Run a first-stage ANN to collect candidate doc_ids (using seq-level neighbors).
|
|
||||||
2) For each candidate doc, load all its token embeddings and compute
|
|
||||||
MaxSim(query_tokens, doc_tokens) exactly: sum(max(dot(q_i, d_j))).
|
|
||||||
|
|
||||||
Returns top-k list of (score, doc_id).
|
|
||||||
"""
|
|
||||||
# Normalize inputs
|
|
||||||
if data.ndim == 1:
|
|
||||||
data = data.reshape(1, -1)
|
|
||||||
if data.dtype != np.float32:
|
|
||||||
data = data.astype(np.float32)
|
|
||||||
|
|
||||||
self._load_labels_meta_if_needed()
|
|
||||||
self._build_docid_to_indices_if_needed()
|
|
||||||
|
|
||||||
emb_path = self._embeddings_path()
|
|
||||||
if not emb_path.exists():
|
|
||||||
# Fallback to approximate if we don't have persisted embeddings
|
|
||||||
return self.search(data, topk, first_stage_k=first_stage_k)
|
|
||||||
|
|
||||||
# Memory-map embeddings to avoid loading all into RAM
|
|
||||||
all_embeddings = np.load(emb_path, mmap_mode="r")
|
|
||||||
if all_embeddings.dtype != np.float32:
|
|
||||||
all_embeddings = all_embeddings.astype(np.float32)
|
|
||||||
|
|
||||||
# First-stage ANN to collect candidate doc_ids
|
|
||||||
searcher = HNSWSearcher(self.index_path, meta=self._meta_dict())
|
|
||||||
raw = searcher.search(
|
|
||||||
data,
|
|
||||||
first_stage_k,
|
|
||||||
recompute_embeddings=False,
|
|
||||||
complexity=128,
|
|
||||||
beam_width=1,
|
|
||||||
prune_ratio=0.0,
|
|
||||||
batch_size=0,
|
|
||||||
)
|
|
||||||
labels = raw.get("labels")
|
|
||||||
if labels is None:
|
|
||||||
return []
|
|
||||||
candidate_doc_ids: set[int] = set()
|
|
||||||
for batch in labels:
|
|
||||||
for sid in batch:
|
|
||||||
try:
|
|
||||||
idx = int(sid)
|
|
||||||
except Exception:
|
|
||||||
continue
|
|
||||||
if 0 <= idx < len(self._labels_meta):
|
|
||||||
candidate_doc_ids.add(int(self._labels_meta[idx]["doc_id"])) # type: ignore[index]
|
|
||||||
|
|
||||||
# Exact scoring per doc (parallelized)
|
|
||||||
assert self._docid_to_indices is not None
|
|
||||||
|
|
||||||
def _score_one(doc_id: int) -> tuple[float, int]:
|
|
||||||
token_indices = self._docid_to_indices.get(doc_id, [])
|
|
||||||
if not token_indices:
|
|
||||||
return (0.0, doc_id)
|
|
||||||
doc_vecs = np.asarray(all_embeddings[token_indices], dtype=np.float32)
|
|
||||||
# (Q, D) x (P, D)^T -> (Q, P) then MaxSim over P, sum over Q
|
|
||||||
sim = np.dot(data, doc_vecs.T)
|
|
||||||
# nan-safe
|
|
||||||
sim = np.nan_to_num(sim, nan=-1e30, posinf=1e30, neginf=-1e30)
|
|
||||||
score = sim.max(axis=2).sum(axis=1) if sim.ndim == 3 else sim.max(axis=1).sum()
|
|
||||||
return (float(score), doc_id)
|
|
||||||
|
|
||||||
scores: list[tuple[float, int]] = []
|
|
||||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
|
|
||||||
futures = [ex.submit(_score_one, doc_id) for doc_id in candidate_doc_ids]
|
|
||||||
for fut in concurrent.futures.as_completed(futures):
|
|
||||||
scores.append(fut.result())
|
|
||||||
|
|
||||||
scores.sort(key=lambda x: x[0], reverse=True)
|
|
||||||
return scores[:topk] if len(scores) >= topk else scores
|
|
||||||
|
|
||||||
def search_exact_all(
|
|
||||||
self,
|
|
||||||
data: np.ndarray,
|
|
||||||
topk: int,
|
|
||||||
*,
|
|
||||||
max_workers: int = 32,
|
|
||||||
) -> list[tuple[float, int]]:
|
|
||||||
"""
|
|
||||||
Exact MaxSim over ALL documents (no ANN pre-filtering).
|
|
||||||
|
|
||||||
This computes, for each document, sum_i max_j dot(q_i, d_j).
|
|
||||||
It memory-maps the persisted token-embedding matrix for scalability.
|
|
||||||
"""
|
|
||||||
if data.ndim == 1:
|
|
||||||
data = data.reshape(1, -1)
|
|
||||||
if data.dtype != np.float32:
|
|
||||||
data = data.astype(np.float32)
|
|
||||||
|
|
||||||
self._load_labels_meta_if_needed()
|
|
||||||
self._build_docid_to_indices_if_needed()
|
|
||||||
|
|
||||||
emb_path = self._embeddings_path()
|
|
||||||
if not emb_path.exists():
|
|
||||||
return self.search(data, topk)
|
|
||||||
|
|
||||||
all_embeddings = np.load(emb_path, mmap_mode="r")
|
|
||||||
if all_embeddings.dtype != np.float32:
|
|
||||||
all_embeddings = all_embeddings.astype(np.float32)
|
|
||||||
|
|
||||||
assert self._docid_to_indices is not None
|
|
||||||
candidate_doc_ids = list(self._docid_to_indices.keys())
|
|
||||||
|
|
||||||
def _score_one(doc_id: int) -> tuple[float, int]:
|
|
||||||
token_indices = self._docid_to_indices.get(doc_id, [])
|
|
||||||
if not token_indices:
|
|
||||||
return (0.0, doc_id)
|
|
||||||
doc_vecs = np.asarray(all_embeddings[token_indices], dtype=np.float32)
|
|
||||||
sim = np.dot(data, doc_vecs.T)
|
|
||||||
sim = np.nan_to_num(sim, nan=-1e30, posinf=1e30, neginf=-1e30)
|
|
||||||
score = sim.max(axis=2).sum(axis=1) if sim.ndim == 3 else sim.max(axis=1).sum()
|
|
||||||
return (float(score), doc_id)
|
|
||||||
|
|
||||||
scores: list[tuple[float, int]] = []
|
|
||||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
|
|
||||||
futures = [ex.submit(_score_one, d) for d in candidate_doc_ids]
|
|
||||||
for fut in concurrent.futures.as_completed(futures):
|
|
||||||
scores.append(fut.result())
|
|
||||||
|
|
||||||
scores.sort(key=lambda x: x[0], reverse=True)
|
|
||||||
return scores[:topk] if len(scores) >= topk else scores
|
|
||||||
|
|
||||||
def get_image(self, doc_id: int) -> Optional[Image.Image]:
|
|
||||||
"""
|
|
||||||
Retrieve the original image for a given doc_id from the index.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
doc_id: The document ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
PIL Image object if found, None otherwise
|
|
||||||
"""
|
|
||||||
self._load_labels_meta_if_needed()
|
|
||||||
|
|
||||||
# Find the image_path for this doc_id (all seq_ids for same doc share the same image_path)
|
|
||||||
for meta in self._labels_meta:
|
|
||||||
if meta.get("doc_id") == doc_id:
|
|
||||||
image_path = meta.get("image_path", "")
|
|
||||||
if image_path and Path(image_path).exists():
|
|
||||||
return Image.open(image_path)
|
|
||||||
break
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_metadata(self, doc_id: int) -> Optional[dict]:
|
|
||||||
"""
|
|
||||||
Retrieve metadata for a given doc_id.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
doc_id: The document ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary with metadata (filepath, image_path, etc.) if found, None otherwise
|
|
||||||
"""
|
|
||||||
self._load_labels_meta_if_needed()
|
|
||||||
|
|
||||||
for meta in self._labels_meta:
|
|
||||||
if meta.get("doc_id") == doc_id:
|
|
||||||
return {
|
|
||||||
"doc_id": doc_id,
|
|
||||||
"filepath": meta.get("filepath", ""),
|
|
||||||
"image_path": meta.get("image_path", ""),
|
|
||||||
}
|
|
||||||
return None
|
|
||||||
|
|||||||
@@ -2,31 +2,34 @@
|
|||||||
# %%
|
# %%
|
||||||
# uv pip install matplotlib qwen_vl_utils
|
# uv pip install matplotlib qwen_vl_utils
|
||||||
import os
|
import os
|
||||||
from typing import Any, Optional
|
import re
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Optional, cast
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
from leann_multi_vector import ( # utility functions/classes
|
def _ensure_repo_paths_importable(current_file: str) -> None:
|
||||||
_ensure_repo_paths_importable,
|
"""Make local leann packages importable without installing (mirrors multi-vector-leann.py)."""
|
||||||
_load_images_from_dir,
|
_repo_root = Path(current_file).resolve().parents[3]
|
||||||
_maybe_convert_pdf_to_images,
|
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
|
||||||
_load_colvision,
|
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
|
||||||
_embed_images,
|
if str(_leann_core_src) not in sys.path:
|
||||||
_embed_queries,
|
sys.path.append(str(_leann_core_src))
|
||||||
_build_index,
|
if str(_leann_hnsw_pkg) not in sys.path:
|
||||||
_load_retriever_if_index_exists,
|
sys.path.append(str(_leann_hnsw_pkg))
|
||||||
_generate_similarity_map,
|
|
||||||
QwenVL,
|
|
||||||
)
|
|
||||||
|
|
||||||
_ensure_repo_paths_importable(__file__)
|
_ensure_repo_paths_importable(__file__)
|
||||||
|
|
||||||
|
from leann_multi_vector import LeannMultiVector # noqa: E402
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
# Config
|
# Config
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
QUERY = "The paper talk about the latent video generative model and data curation in the related work part?"
|
QUERY = "How does DeepSeek-V2 compare against the LLaMA family of LLMs?"
|
||||||
MODEL: str = "colqwen2" # "colpali" or "colqwen2"
|
MODEL: str = "colqwen2" # "colpali" or "colqwen2"
|
||||||
|
|
||||||
# Data source: set to True to use the Hugging Face dataset example (recommended)
|
# Data source: set to True to use the Hugging Face dataset example (recommended)
|
||||||
@@ -41,7 +44,7 @@ PAGES_DIR: str = "./pages"
|
|||||||
|
|
||||||
# Index + retrieval settings
|
# Index + retrieval settings
|
||||||
INDEX_PATH: str = "./indexes/colvision.leann"
|
INDEX_PATH: str = "./indexes/colvision.leann"
|
||||||
TOPK: int = 3
|
TOPK: int = 1
|
||||||
FIRST_STAGE_K: int = 500
|
FIRST_STAGE_K: int = 500
|
||||||
REBUILD_INDEX: bool = False
|
REBUILD_INDEX: bool = False
|
||||||
|
|
||||||
@@ -51,28 +54,307 @@ SIMILARITY_MAP: bool = True
|
|||||||
SIM_TOKEN_IDX: int = 13 # -1 means auto-select the most salient token
|
SIM_TOKEN_IDX: int = 13 # -1 means auto-select the most salient token
|
||||||
SIM_OUTPUT: str = "./figures/similarity_map.png"
|
SIM_OUTPUT: str = "./figures/similarity_map.png"
|
||||||
ANSWER: bool = True
|
ANSWER: bool = True
|
||||||
MAX_NEW_TOKENS: int = 1024
|
MAX_NEW_TOKENS: int = 128
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Helpers
|
||||||
|
def _natural_sort_key(name: str) -> int:
|
||||||
|
m = re.search(r"\d+", name)
|
||||||
|
return int(m.group()) if m else 0
|
||||||
|
|
||||||
|
|
||||||
|
def _load_images_from_dir(pages_dir: str) -> tuple[list[str], list[Image.Image]]:
|
||||||
|
filenames = [n for n in os.listdir(pages_dir) if n.lower().endswith((".png", ".jpg", ".jpeg"))]
|
||||||
|
filenames = sorted(filenames, key=_natural_sort_key)
|
||||||
|
filepaths = [os.path.join(pages_dir, n) for n in filenames]
|
||||||
|
images = [Image.open(p) for p in filepaths]
|
||||||
|
return filepaths, images
|
||||||
|
|
||||||
|
|
||||||
|
def _maybe_convert_pdf_to_images(pdf_path: Optional[str], pages_dir: str, dpi: int = 200) -> None:
|
||||||
|
if not pdf_path:
|
||||||
|
return
|
||||||
|
os.makedirs(pages_dir, exist_ok=True)
|
||||||
|
try:
|
||||||
|
from pdf2image import convert_from_path
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(
|
||||||
|
"pdf2image is required to convert PDF to images. Install via pip install pdf2image"
|
||||||
|
) from e
|
||||||
|
images = convert_from_path(pdf_path, dpi=dpi)
|
||||||
|
for i, image in enumerate(images):
|
||||||
|
image.save(os.path.join(pages_dir, f"page_{i + 1}.png"), "PNG")
|
||||||
|
|
||||||
|
|
||||||
|
def _select_device_and_dtype():
|
||||||
|
import torch
|
||||||
|
from colpali_engine.utils.torch_utils import get_torch_device
|
||||||
|
|
||||||
|
device_str = (
|
||||||
|
"cuda"
|
||||||
|
if torch.cuda.is_available()
|
||||||
|
else (
|
||||||
|
"mps"
|
||||||
|
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
|
||||||
|
else "cpu"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
device = get_torch_device(device_str)
|
||||||
|
# Stable dtype selection to avoid NaNs:
|
||||||
|
# - CUDA: prefer bfloat16 if supported, else float16
|
||||||
|
# - MPS: use float32 (fp16 on MPS can produce NaNs in some ops)
|
||||||
|
# - CPU: float32
|
||||||
|
if device_str == "cuda":
|
||||||
|
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
||||||
|
try:
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True # Better stability/perf on Ampere+
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
elif device_str == "mps":
|
||||||
|
dtype = torch.float32
|
||||||
|
else:
|
||||||
|
dtype = torch.float32
|
||||||
|
return device_str, device, dtype
|
||||||
|
|
||||||
|
|
||||||
|
def _load_colvision(model_choice: str):
|
||||||
|
import torch
|
||||||
|
from colpali_engine.models import ColPali, ColQwen2, ColQwen2Processor
|
||||||
|
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
|
||||||
|
from transformers.utils.import_utils import is_flash_attn_2_available
|
||||||
|
|
||||||
|
device_str, device, dtype = _select_device_and_dtype()
|
||||||
|
|
||||||
|
if model_choice == "colqwen2":
|
||||||
|
model_name = "vidore/colqwen2-v1.0"
|
||||||
|
# On CPU/MPS we must avoid flash-attn and stay eager; on CUDA prefer flash-attn if available
|
||||||
|
attn_implementation = (
|
||||||
|
"flash_attention_2"
|
||||||
|
if (device_str == "cuda" and is_flash_attn_2_available())
|
||||||
|
else "eager"
|
||||||
|
)
|
||||||
|
model = ColQwen2.from_pretrained(
|
||||||
|
model_name,
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device_map=device,
|
||||||
|
attn_implementation=attn_implementation,
|
||||||
|
).eval()
|
||||||
|
processor = ColQwen2Processor.from_pretrained(model_name)
|
||||||
|
else:
|
||||||
|
model_name = "vidore/colpali-v1.2"
|
||||||
|
model = ColPali.from_pretrained(
|
||||||
|
model_name,
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device_map=device,
|
||||||
|
).eval()
|
||||||
|
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
|
||||||
|
|
||||||
|
return model_name, model, processor, device_str, device, dtype
|
||||||
|
|
||||||
|
|
||||||
|
def _embed_images(model, processor, images: list[Image.Image]) -> list[Any]:
|
||||||
|
import torch
|
||||||
|
from colpali_engine.utils.torch_utils import ListDataset
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
# Ensure deterministic eval and autocast for stability
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
dataloader = DataLoader(
|
||||||
|
dataset=ListDataset[Image.Image](images),
|
||||||
|
batch_size=1,
|
||||||
|
shuffle=False,
|
||||||
|
collate_fn=lambda x: processor.process_images(x),
|
||||||
|
)
|
||||||
|
|
||||||
|
doc_vecs: list[Any] = []
|
||||||
|
for batch_doc in tqdm(dataloader, desc="Embedding images"):
|
||||||
|
with torch.no_grad():
|
||||||
|
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
|
||||||
|
# autocast on CUDA for bf16/fp16; on CPU/MPS stay in fp32
|
||||||
|
if model.device.type == "cuda":
|
||||||
|
with torch.autocast(
|
||||||
|
device_type="cuda",
|
||||||
|
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
|
||||||
|
):
|
||||||
|
embeddings_doc = model(**batch_doc)
|
||||||
|
else:
|
||||||
|
embeddings_doc = model(**batch_doc)
|
||||||
|
doc_vecs.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
|
||||||
|
return doc_vecs
|
||||||
|
|
||||||
|
|
||||||
|
def _embed_queries(model, processor, queries: list[str]) -> list[Any]:
|
||||||
|
import torch
|
||||||
|
from colpali_engine.utils.torch_utils import ListDataset
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
dataloader = DataLoader(
|
||||||
|
dataset=ListDataset[str](queries),
|
||||||
|
batch_size=1,
|
||||||
|
shuffle=False,
|
||||||
|
collate_fn=lambda x: processor.process_queries(x),
|
||||||
|
)
|
||||||
|
|
||||||
|
q_vecs: list[Any] = []
|
||||||
|
for batch_query in tqdm(dataloader, desc="Embedding queries"):
|
||||||
|
with torch.no_grad():
|
||||||
|
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
|
||||||
|
if model.device.type == "cuda":
|
||||||
|
with torch.autocast(
|
||||||
|
device_type="cuda",
|
||||||
|
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
|
||||||
|
):
|
||||||
|
embeddings_query = model(**batch_query)
|
||||||
|
else:
|
||||||
|
embeddings_query = model(**batch_query)
|
||||||
|
q_vecs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
|
||||||
|
return q_vecs
|
||||||
|
|
||||||
|
|
||||||
|
def _build_index(index_path: str, doc_vecs: list[Any], filepaths: list[str]) -> LeannMultiVector:
|
||||||
|
dim = int(doc_vecs[0].shape[-1])
|
||||||
|
retriever = LeannMultiVector(index_path=index_path, dim=dim)
|
||||||
|
retriever.create_collection()
|
||||||
|
for i, vec in enumerate(doc_vecs):
|
||||||
|
data = {
|
||||||
|
"colbert_vecs": vec.float().numpy(),
|
||||||
|
"doc_id": i,
|
||||||
|
"filepath": filepaths[i],
|
||||||
|
}
|
||||||
|
retriever.insert(data)
|
||||||
|
retriever.create_index()
|
||||||
|
return retriever
|
||||||
|
|
||||||
|
|
||||||
|
def _load_retriever_if_index_exists(index_path: str, dim: int) -> Optional[LeannMultiVector]:
|
||||||
|
index_base = Path(index_path)
|
||||||
|
# Rough heuristic: index dir exists AND meta+labels files exist
|
||||||
|
meta = index_base.parent / f"{index_base.name}.meta.json"
|
||||||
|
labels = index_base.parent / f"{index_base.name}.labels.json"
|
||||||
|
if index_base.exists() and meta.exists() and labels.exists():
|
||||||
|
return LeannMultiVector(index_path=index_path, dim=dim)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_similarity_map(
|
||||||
|
model,
|
||||||
|
processor,
|
||||||
|
image: Image.Image,
|
||||||
|
query: str,
|
||||||
|
token_idx: Optional[int] = None,
|
||||||
|
output_path: Optional[str] = None,
|
||||||
|
) -> tuple[int, float]:
|
||||||
|
import torch
|
||||||
|
from colpali_engine.interpretability import (
|
||||||
|
get_similarity_maps_from_embeddings,
|
||||||
|
plot_similarity_map,
|
||||||
|
)
|
||||||
|
|
||||||
|
batch_images = processor.process_images([image]).to(model.device)
|
||||||
|
batch_queries = processor.process_queries([query]).to(model.device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
image_embeddings = model.forward(**batch_images)
|
||||||
|
query_embeddings = model.forward(**batch_queries)
|
||||||
|
|
||||||
|
n_patches = processor.get_n_patches(
|
||||||
|
image_size=image.size,
|
||||||
|
spatial_merge_size=getattr(model, "spatial_merge_size", None),
|
||||||
|
)
|
||||||
|
image_mask = processor.get_image_mask(batch_images)
|
||||||
|
|
||||||
|
batched_similarity_maps = get_similarity_maps_from_embeddings(
|
||||||
|
image_embeddings=image_embeddings,
|
||||||
|
query_embeddings=query_embeddings,
|
||||||
|
n_patches=n_patches,
|
||||||
|
image_mask=image_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
similarity_maps = batched_similarity_maps[0]
|
||||||
|
|
||||||
|
# Determine token index if not provided: choose the token with highest max score
|
||||||
|
if token_idx is None:
|
||||||
|
per_token_max = similarity_maps.view(similarity_maps.shape[0], -1).max(dim=1).values
|
||||||
|
token_idx = int(per_token_max.argmax().item())
|
||||||
|
|
||||||
|
max_sim_score = similarity_maps[token_idx, :, :].max().item()
|
||||||
|
|
||||||
|
if output_path:
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
fig, ax = plot_similarity_map(
|
||||||
|
image=image,
|
||||||
|
similarity_map=similarity_maps[token_idx],
|
||||||
|
figsize=(14, 14),
|
||||||
|
show_colorbar=False,
|
||||||
|
)
|
||||||
|
ax.set_title(f"Token #{token_idx}. MaxSim score: {max_sim_score:.2f}", fontsize=12)
|
||||||
|
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||||
|
plt.savefig(output_path, bbox_inches="tight")
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
|
return token_idx, float(max_sim_score)
|
||||||
|
|
||||||
|
|
||||||
|
class QwenVL:
|
||||||
|
def __init__(self, device: str):
|
||||||
|
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
|
||||||
|
from transformers.utils.import_utils import is_flash_attn_2_available
|
||||||
|
|
||||||
|
attn_implementation = "flash_attention_2" if is_flash_attn_2_available() else "eager"
|
||||||
|
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||||
|
"Qwen/Qwen2.5-VL-3B-Instruct",
|
||||||
|
torch_dtype="auto",
|
||||||
|
device_map=device,
|
||||||
|
attn_implementation=attn_implementation,
|
||||||
|
)
|
||||||
|
|
||||||
|
min_pixels = 256 * 28 * 28
|
||||||
|
max_pixels = 1280 * 28 * 28
|
||||||
|
self.processor = AutoProcessor.from_pretrained(
|
||||||
|
"Qwen/Qwen2.5-VL-3B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels
|
||||||
|
)
|
||||||
|
|
||||||
|
def answer(self, query: str, images: list[Image.Image], max_new_tokens: int = 128) -> str:
|
||||||
|
import base64
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
from qwen_vl_utils import process_vision_info
|
||||||
|
|
||||||
|
content = []
|
||||||
|
for img in images:
|
||||||
|
buffer = BytesIO()
|
||||||
|
img.save(buffer, format="jpeg")
|
||||||
|
img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||||
|
content.append({"type": "image", "image": f"data:image;base64,{img_base64}"})
|
||||||
|
content.append({"type": "text", "text": query})
|
||||||
|
messages = [{"role": "user", "content": content}]
|
||||||
|
|
||||||
|
text = self.processor.apply_chat_template(
|
||||||
|
messages, tokenize=False, add_generation_prompt=True
|
||||||
|
)
|
||||||
|
image_inputs, video_inputs = process_vision_info(messages)
|
||||||
|
inputs = self.processor(
|
||||||
|
text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt"
|
||||||
|
)
|
||||||
|
inputs = inputs.to(self.model.device)
|
||||||
|
|
||||||
|
generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
|
||||||
|
generated_ids_trimmed = [
|
||||||
|
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
||||||
|
]
|
||||||
|
return self.processor.batch_decode(
|
||||||
|
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
|
|
||||||
# Step 1: Check if we can skip data loading (index already exists)
|
# Step 1: Prepare data
|
||||||
retriever: Optional[Any] = None
|
|
||||||
need_to_build_index = REBUILD_INDEX
|
|
||||||
|
|
||||||
if not REBUILD_INDEX:
|
|
||||||
retriever = _load_retriever_if_index_exists(INDEX_PATH)
|
|
||||||
if retriever is not None:
|
|
||||||
print(f"✓ Index loaded from {INDEX_PATH}")
|
|
||||||
print(f"✓ Images available at: {retriever._images_dir_path()}")
|
|
||||||
need_to_build_index = False
|
|
||||||
else:
|
|
||||||
print(f"Index not found, will build new index")
|
|
||||||
need_to_build_index = True
|
|
||||||
|
|
||||||
# Step 2: Load data only if we need to build the index
|
|
||||||
if need_to_build_index:
|
|
||||||
print("Loading dataset...")
|
|
||||||
if USE_HF_DATASET:
|
if USE_HF_DATASET:
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
|
|
||||||
@@ -84,6 +366,7 @@ if need_to_build_index:
|
|||||||
p = dataset[i]
|
p = dataset[i]
|
||||||
# Compose a descriptive identifier for printing later
|
# Compose a descriptive identifier for printing later
|
||||||
identifier = f"arXiv:{p['paper_arxiv_id']}|title:{p['paper_title']}|page:{int(p['page_number'])}|id:{p['page_id']}"
|
identifier = f"arXiv:{p['paper_arxiv_id']}|title:{p['paper_title']}|page:{int(p['page_number'])}|id:{p['page_id']}"
|
||||||
|
print(identifier)
|
||||||
filepaths.append(identifier)
|
filepaths.append(identifier)
|
||||||
images.append(p["page_image"]) # PIL Image
|
images.append(p["page_image"]) # PIL Image
|
||||||
else:
|
else:
|
||||||
@@ -93,15 +376,10 @@ if need_to_build_index:
|
|||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"No images found in {PAGES_DIR}. Provide PDF path in PDF variable or ensure images exist."
|
f"No images found in {PAGES_DIR}. Provide PDF path in PDF variable or ensure images exist."
|
||||||
)
|
)
|
||||||
print(f"Loaded {len(images)} images")
|
|
||||||
else:
|
|
||||||
print("Skipping dataset loading (using existing index)")
|
|
||||||
filepaths = [] # Not needed when using existing index
|
|
||||||
images = [] # Not needed when using existing index
|
|
||||||
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
# Step 3: Load model and processor (only if we need to build index or perform search)
|
# Step 2: Load model and processor
|
||||||
model_name, model, processor, device_str, device, dtype = _load_colvision(MODEL)
|
model_name, model, processor, device_str, device, dtype = _load_colvision(MODEL)
|
||||||
print(f"Using model={model_name}, device={device_str}, dtype={dtype}")
|
print(f"Using model={model_name}, device={device_str}, dtype={dtype}")
|
||||||
|
|
||||||
@@ -109,39 +387,34 @@ print(f"Using model={model_name}, device={device_str}, dtype={dtype}")
|
|||||||
# %%
|
# %%
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
# Step 4: Build index if needed
|
# Step 3: Build or load index
|
||||||
if need_to_build_index and retriever is None:
|
retriever: Optional[LeannMultiVector] = None
|
||||||
print("Building index...")
|
if not REBUILD_INDEX:
|
||||||
doc_vecs = _embed_images(model, processor, images)
|
try:
|
||||||
retriever = _build_index(INDEX_PATH, doc_vecs, filepaths, images)
|
one_vec = _embed_images(model, processor, [images[0]])[0]
|
||||||
print(f"✓ Index built and images saved to: {retriever._images_dir_path()}")
|
retriever = _load_retriever_if_index_exists(INDEX_PATH, dim=int(one_vec.shape[-1]))
|
||||||
# Clear memory
|
except Exception:
|
||||||
del images, filepaths, doc_vecs
|
retriever = None
|
||||||
|
|
||||||
# Note: Images are now stored in the index, retriever will load them on-demand from disk
|
if retriever is None:
|
||||||
|
doc_vecs = _embed_images(model, processor, images)
|
||||||
|
retriever = _build_index(INDEX_PATH, doc_vecs, filepaths)
|
||||||
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
# Step 5: Embed query and search
|
# Step 4: Embed query and search
|
||||||
q_vec = _embed_queries(model, processor, [QUERY])[0]
|
q_vec = _embed_queries(model, processor, [QUERY])[0]
|
||||||
results = retriever.search(q_vec.float().numpy(), topk=TOPK)
|
results = retriever.search(q_vec.float().numpy(), topk=TOPK, first_stage_k=FIRST_STAGE_K)
|
||||||
if not results:
|
if not results:
|
||||||
print("No results found.")
|
print("No results found.")
|
||||||
else:
|
else:
|
||||||
print(f'Top {len(results)} results for query: "{QUERY}"')
|
print(f'Top {len(results)} results for query: "{QUERY}"')
|
||||||
top_images: list[Image.Image] = []
|
top_images: list[Image.Image] = []
|
||||||
for rank, (score, doc_id) in enumerate(results, start=1):
|
for rank, (score, doc_id) in enumerate(results, start=1):
|
||||||
# Retrieve image from index instead of memory
|
path = filepaths[doc_id]
|
||||||
image = retriever.get_image(doc_id)
|
|
||||||
if image is None:
|
|
||||||
print(f"Warning: Could not retrieve image for doc_id {doc_id}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
metadata = retriever.get_metadata(doc_id)
|
|
||||||
path = metadata.get("filepath", "unknown") if metadata else "unknown"
|
|
||||||
# For HF dataset, path is a descriptive identifier, not a real file path
|
# For HF dataset, path is a descriptive identifier, not a real file path
|
||||||
print(f"{rank}) MaxSim: {score:.4f}, Page: {path}")
|
print(f"{rank}) MaxSim: {score:.4f}, Page: {path}")
|
||||||
top_images.append(image)
|
top_images.append(images[doc_id])
|
||||||
|
|
||||||
if SAVE_TOP_IMAGE:
|
if SAVE_TOP_IMAGE:
|
||||||
from pathlib import Path as _Path
|
from pathlib import Path as _Path
|
||||||
@@ -154,17 +427,12 @@ else:
|
|||||||
else:
|
else:
|
||||||
out_path = base / f"retrieved_page_rank{rank}.png"
|
out_path = base / f"retrieved_page_rank{rank}.png"
|
||||||
img.save(str(out_path))
|
img.save(str(out_path))
|
||||||
# Print the retrieval score (document-level MaxSim) alongside the saved path
|
|
||||||
try:
|
|
||||||
score, _doc_id = results[rank - 1]
|
|
||||||
print(f"Saved retrieved page (rank {rank}) [MaxSim={score:.4f}] to: {out_path}")
|
|
||||||
except Exception:
|
|
||||||
print(f"Saved retrieved page (rank {rank}) to: {out_path}")
|
print(f"Saved retrieved page (rank {rank}) to: {out_path}")
|
||||||
|
|
||||||
## TODO stange results of second page of DeepSeek-V2 rather than the first page
|
## TODO stange results of second page of DeepSeek-V2 rather than the first page
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
# Step 6: Similarity maps for top-K results
|
# Step 5: Similarity maps for top-K results
|
||||||
if results and SIMILARITY_MAP:
|
if results and SIMILARITY_MAP:
|
||||||
token_idx = None if SIM_TOKEN_IDX < 0 else int(SIM_TOKEN_IDX)
|
token_idx = None if SIM_TOKEN_IDX < 0 else int(SIM_TOKEN_IDX)
|
||||||
from pathlib import Path as _Path
|
from pathlib import Path as _Path
|
||||||
@@ -201,7 +469,7 @@ if results and SIMILARITY_MAP:
|
|||||||
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
# Step 7: Optional answer generation
|
# Step 6: Optional answer generation
|
||||||
if results and ANSWER:
|
if results and ANSWER:
|
||||||
qwen = QwenVL(device=device_str)
|
qwen = QwenVL(device=device_str)
|
||||||
response = qwen.answer(QUERY, top_images[:TOPK], max_new_tokens=MAX_NEW_TOKENS)
|
response = qwen.answer(QUERY, top_images[:TOPK], max_new_tokens=MAX_NEW_TOKENS)
|
||||||
|
|||||||
@@ -1,183 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
import re
|
|
||||||
import sys
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from leann import LeannSearcher
|
|
||||||
|
|
||||||
INDEX_PATH = str(Path("./").resolve() / "demo.leann")
|
|
||||||
|
|
||||||
|
|
||||||
class TimeParser:
|
|
||||||
def __init__(self):
|
|
||||||
# Main pattern: captures optional fuzzy modifier, number, unit, and optional "ago"
|
|
||||||
self.pattern = r"(?:(around|about|roughly|approximately)\s+)?(\d+)\s+(hour|day|week|month|year)s?(?:\s+ago)?"
|
|
||||||
|
|
||||||
# Compile for performance
|
|
||||||
self.regex = re.compile(self.pattern, re.IGNORECASE)
|
|
||||||
|
|
||||||
# Stop words to remove before regex parsing
|
|
||||||
self.stop_words = {
|
|
||||||
"in",
|
|
||||||
"at",
|
|
||||||
"of",
|
|
||||||
"by",
|
|
||||||
"as",
|
|
||||||
"me",
|
|
||||||
"the",
|
|
||||||
"a",
|
|
||||||
"an",
|
|
||||||
"and",
|
|
||||||
"any",
|
|
||||||
"find",
|
|
||||||
"search",
|
|
||||||
"list",
|
|
||||||
"ago",
|
|
||||||
"back",
|
|
||||||
"past",
|
|
||||||
"earlier",
|
|
||||||
}
|
|
||||||
|
|
||||||
def clean_text(self, text):
|
|
||||||
"""Remove stop words from text"""
|
|
||||||
words = text.split()
|
|
||||||
cleaned = " ".join(word for word in words if word.lower() not in self.stop_words)
|
|
||||||
return cleaned
|
|
||||||
|
|
||||||
def parse(self, text):
|
|
||||||
"""Extract all time expressions from text"""
|
|
||||||
# Clean text first
|
|
||||||
cleaned_text = self.clean_text(text)
|
|
||||||
|
|
||||||
matches = []
|
|
||||||
for match in self.regex.finditer(cleaned_text):
|
|
||||||
fuzzy = match.group(1) # "around", "about", etc.
|
|
||||||
number = int(match.group(2))
|
|
||||||
unit = match.group(3).lower()
|
|
||||||
|
|
||||||
matches.append(
|
|
||||||
{
|
|
||||||
"full_match": match.group(0),
|
|
||||||
"fuzzy": bool(fuzzy),
|
|
||||||
"number": number,
|
|
||||||
"unit": unit,
|
|
||||||
"range": self.calculate_range(number, unit, bool(fuzzy)),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return matches
|
|
||||||
|
|
||||||
def calculate_range(self, number, unit, is_fuzzy):
|
|
||||||
"""Convert to actual datetime range and return ISO format strings"""
|
|
||||||
units = {
|
|
||||||
"hour": timedelta(hours=number),
|
|
||||||
"day": timedelta(days=number),
|
|
||||||
"week": timedelta(weeks=number),
|
|
||||||
"month": timedelta(days=number * 30),
|
|
||||||
"year": timedelta(days=number * 365),
|
|
||||||
}
|
|
||||||
|
|
||||||
delta = units[unit]
|
|
||||||
now = datetime.now()
|
|
||||||
target = now - delta
|
|
||||||
|
|
||||||
if is_fuzzy:
|
|
||||||
buffer = delta * 0.2 # 20% buffer for fuzzy
|
|
||||||
start = (target - buffer).isoformat()
|
|
||||||
end = (target + buffer).isoformat()
|
|
||||||
else:
|
|
||||||
start = target.isoformat()
|
|
||||||
end = now.isoformat()
|
|
||||||
|
|
||||||
return (start, end)
|
|
||||||
|
|
||||||
|
|
||||||
def search_files(query, top_k=15):
|
|
||||||
"""Search the index and return results"""
|
|
||||||
# Parse time expressions
|
|
||||||
parser = TimeParser()
|
|
||||||
time_matches = parser.parse(query)
|
|
||||||
|
|
||||||
# Remove time expressions from query for semantic search
|
|
||||||
clean_query = query
|
|
||||||
if time_matches:
|
|
||||||
for match in time_matches:
|
|
||||||
clean_query = clean_query.replace(match["full_match"], "").strip()
|
|
||||||
|
|
||||||
# Check if clean_query is less than 4 characters
|
|
||||||
if len(clean_query) < 4:
|
|
||||||
print("Error: add more input for accurate results.")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Single query to vector DB
|
|
||||||
searcher = LeannSearcher(INDEX_PATH)
|
|
||||||
results = searcher.search(
|
|
||||||
clean_query if clean_query else query, top_k=top_k, recompute_embeddings=False
|
|
||||||
)
|
|
||||||
|
|
||||||
# Filter by time if time expression found
|
|
||||||
if time_matches:
|
|
||||||
time_range = time_matches[0]["range"] # Use first time expression
|
|
||||||
start_time, end_time = time_range
|
|
||||||
|
|
||||||
filtered_results = []
|
|
||||||
for result in results:
|
|
||||||
# Access metadata attribute directly (not .get())
|
|
||||||
metadata = result.metadata if hasattr(result, "metadata") else {}
|
|
||||||
|
|
||||||
if metadata:
|
|
||||||
# Check modification date first, fall back to creation date
|
|
||||||
date_str = metadata.get("modification_date") or metadata.get("creation_date")
|
|
||||||
|
|
||||||
if date_str:
|
|
||||||
# Convert strings to datetime objects for proper comparison
|
|
||||||
try:
|
|
||||||
file_date = datetime.fromisoformat(date_str)
|
|
||||||
start_dt = datetime.fromisoformat(start_time)
|
|
||||||
end_dt = datetime.fromisoformat(end_time)
|
|
||||||
|
|
||||||
# Compare dates properly
|
|
||||||
if start_dt <= file_date <= end_dt:
|
|
||||||
filtered_results.append(result)
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
# Handle invalid date formats
|
|
||||||
print(f"Warning: Invalid date format in metadata: {date_str}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
results = filtered_results
|
|
||||||
|
|
||||||
# Print results
|
|
||||||
print(f"\nSearch results for: '{query}'")
|
|
||||||
if time_matches:
|
|
||||||
print(
|
|
||||||
f"Time filter: {time_matches[0]['number']} {time_matches[0]['unit']}(s) {'(fuzzy)' if time_matches[0]['fuzzy'] else ''}"
|
|
||||||
)
|
|
||||||
print(
|
|
||||||
f"Date range: {time_matches[0]['range'][0][:10]} to {time_matches[0]['range'][1][:10]}"
|
|
||||||
)
|
|
||||||
print("-" * 80)
|
|
||||||
|
|
||||||
for i, result in enumerate(results, 1):
|
|
||||||
print(f"\n[{i}] Score: {result.score:.4f}")
|
|
||||||
print(f"Content: {result.text}")
|
|
||||||
|
|
||||||
# Show metadata if present
|
|
||||||
metadata = result.metadata if hasattr(result, "metadata") else None
|
|
||||||
if metadata:
|
|
||||||
if "creation_date" in metadata:
|
|
||||||
print(f"Created: {metadata['creation_date']}")
|
|
||||||
if "modification_date" in metadata:
|
|
||||||
print(f"Modified: {metadata['modification_date']}")
|
|
||||||
print("-" * 80)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
if len(sys.argv) < 2:
|
|
||||||
print('Usage: python search_index.py "<search query>" [top_k]')
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
query = sys.argv[1]
|
|
||||||
top_k = int(sys.argv[2]) if len(sys.argv) > 2 else 15
|
|
||||||
|
|
||||||
search_files(query, top_k)
|
|
||||||
@@ -1,82 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
import json
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from leann import LeannBuilder
|
|
||||||
|
|
||||||
|
|
||||||
def process_json_items(json_file_path):
|
|
||||||
"""Load and process JSON file with metadata items"""
|
|
||||||
|
|
||||||
with open(json_file_path, encoding="utf-8") as f:
|
|
||||||
items = json.load(f)
|
|
||||||
|
|
||||||
# Guard against empty JSON
|
|
||||||
if not items:
|
|
||||||
print("⚠️ No items found in the JSON file. Exiting gracefully.")
|
|
||||||
return
|
|
||||||
|
|
||||||
INDEX_PATH = str(Path("./").resolve() / "demo.leann")
|
|
||||||
builder = LeannBuilder(backend_name="hnsw", is_recompute=False)
|
|
||||||
|
|
||||||
total_items = len(items)
|
|
||||||
items_added = 0
|
|
||||||
print(f"Processing {total_items} items...")
|
|
||||||
|
|
||||||
for idx, item in enumerate(items):
|
|
||||||
try:
|
|
||||||
# Create embedding text sentence
|
|
||||||
embedding_text = f"{item.get('Name', 'unknown')} located at {item.get('Path', 'unknown')} and size {item.get('Size', 'unknown')} bytes with content type {item.get('ContentType', 'unknown')} and kind {item.get('Kind', 'unknown')}"
|
|
||||||
|
|
||||||
# Prepare metadata with dates
|
|
||||||
metadata = {}
|
|
||||||
if "CreationDate" in item:
|
|
||||||
metadata["creation_date"] = item["CreationDate"]
|
|
||||||
if "ContentChangeDate" in item:
|
|
||||||
metadata["modification_date"] = item["ContentChangeDate"]
|
|
||||||
|
|
||||||
# Add to builder
|
|
||||||
builder.add_text(embedding_text, metadata=metadata)
|
|
||||||
items_added += 1
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"\n⚠️ Warning: Failed to process item {idx}: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Show progress
|
|
||||||
progress = (idx + 1) / total_items * 100
|
|
||||||
sys.stdout.write(f"\rProgress: {idx + 1}/{total_items} ({progress:.1f}%)")
|
|
||||||
sys.stdout.flush()
|
|
||||||
|
|
||||||
print() # New line after progress
|
|
||||||
|
|
||||||
# Guard against no successfully added items
|
|
||||||
if items_added == 0:
|
|
||||||
print("⚠️ No items were successfully added to the index. Exiting gracefully.")
|
|
||||||
return
|
|
||||||
|
|
||||||
print(f"\n✅ Successfully processed {items_added}/{total_items} items")
|
|
||||||
print("Building index...")
|
|
||||||
|
|
||||||
try:
|
|
||||||
builder.build_index(INDEX_PATH)
|
|
||||||
print(f"✓ Index saved to {INDEX_PATH}")
|
|
||||||
except ValueError as e:
|
|
||||||
if "No chunks added" in str(e):
|
|
||||||
print("⚠️ No chunks were added to the builder. Index not created.")
|
|
||||||
else:
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
if len(sys.argv) != 2:
|
|
||||||
print("Usage: python build_index.py <json_file>")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
json_file = sys.argv[1]
|
|
||||||
if not Path(json_file).exists():
|
|
||||||
print(f"Error: File {json_file} not found")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
process_json_items(json_file)
|
|
||||||
@@ -1,265 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Spotlight Metadata Dumper for Vector DB
|
|
||||||
Extracts only essential metadata for semantic search embeddings
|
|
||||||
Output is optimized for vector database storage with minimal fields
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import sys
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
# Check platform before importing macOS-specific modules
|
|
||||||
if sys.platform != "darwin":
|
|
||||||
print("This script requires macOS (uses Spotlight)")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
from Foundation import NSDate, NSMetadataQuery, NSPredicate, NSRunLoop
|
|
||||||
|
|
||||||
# EDIT THIS LIST: Add or remove folders to search
|
|
||||||
# Can be either:
|
|
||||||
# - Folder names relative to home directory (e.g., "Desktop", "Downloads")
|
|
||||||
# - Absolute paths (e.g., "/Applications", "/System/Library")
|
|
||||||
SEARCH_FOLDERS = [
|
|
||||||
"Desktop",
|
|
||||||
"Downloads",
|
|
||||||
"Documents",
|
|
||||||
"Music",
|
|
||||||
"Pictures",
|
|
||||||
"Movies",
|
|
||||||
# "Library", # Uncomment to include
|
|
||||||
# "/Applications", # Absolute path example
|
|
||||||
# "Code/Projects", # Subfolder example
|
|
||||||
# Add any other folders here
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def convert_to_serializable(obj):
|
|
||||||
"""Convert NS objects to Python serializable types"""
|
|
||||||
if obj is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Handle NSDate
|
|
||||||
if hasattr(obj, "timeIntervalSince1970"):
|
|
||||||
return datetime.fromtimestamp(obj.timeIntervalSince1970()).isoformat()
|
|
||||||
|
|
||||||
# Handle NSArray
|
|
||||||
if hasattr(obj, "count") and hasattr(obj, "objectAtIndex_"):
|
|
||||||
return [convert_to_serializable(obj.objectAtIndex_(i)) for i in range(obj.count())]
|
|
||||||
|
|
||||||
# Convert to string
|
|
||||||
try:
|
|
||||||
return str(obj)
|
|
||||||
except Exception:
|
|
||||||
return repr(obj)
|
|
||||||
|
|
||||||
|
|
||||||
def dump_spotlight_data(max_items=10, output_file="spotlight_dump.json"):
|
|
||||||
"""
|
|
||||||
Dump Spotlight data using public.item predicate
|
|
||||||
"""
|
|
||||||
# Build full paths from SEARCH_FOLDERS
|
|
||||||
import os
|
|
||||||
|
|
||||||
home_dir = os.path.expanduser("~")
|
|
||||||
search_paths = []
|
|
||||||
|
|
||||||
print("Search locations:")
|
|
||||||
for folder in SEARCH_FOLDERS:
|
|
||||||
# Check if it's an absolute path or relative
|
|
||||||
if folder.startswith("/"):
|
|
||||||
full_path = folder
|
|
||||||
else:
|
|
||||||
full_path = os.path.join(home_dir, folder)
|
|
||||||
|
|
||||||
if os.path.exists(full_path):
|
|
||||||
search_paths.append(full_path)
|
|
||||||
print(f" ✓ {full_path}")
|
|
||||||
else:
|
|
||||||
print(f" ✗ {full_path} (not found)")
|
|
||||||
|
|
||||||
if not search_paths:
|
|
||||||
print("No valid search paths found!")
|
|
||||||
return []
|
|
||||||
|
|
||||||
print(f"\nDumping {max_items} items from Spotlight (public.item)...")
|
|
||||||
|
|
||||||
# Create query with public.item predicate
|
|
||||||
query = NSMetadataQuery.alloc().init()
|
|
||||||
predicate = NSPredicate.predicateWithFormat_("kMDItemContentTypeTree CONTAINS 'public.item'")
|
|
||||||
query.setPredicate_(predicate)
|
|
||||||
|
|
||||||
# Set search scopes to our specific folders
|
|
||||||
query.setSearchScopes_(search_paths)
|
|
||||||
|
|
||||||
print("Starting query...")
|
|
||||||
query.startQuery()
|
|
||||||
|
|
||||||
# Wait for gathering to complete
|
|
||||||
run_loop = NSRunLoop.currentRunLoop()
|
|
||||||
print("Gathering results...")
|
|
||||||
|
|
||||||
# Let it gather for a few seconds
|
|
||||||
for i in range(50): # 5 seconds max
|
|
||||||
run_loop.runMode_beforeDate_(
|
|
||||||
"NSDefaultRunLoopMode", NSDate.dateWithTimeIntervalSinceNow_(0.1)
|
|
||||||
)
|
|
||||||
# Check gathering status periodically
|
|
||||||
if i % 10 == 0:
|
|
||||||
current_count = query.resultCount()
|
|
||||||
if current_count > 0:
|
|
||||||
print(f" Found {current_count} items so far...")
|
|
||||||
|
|
||||||
# Continue while still gathering (up to 2 more seconds)
|
|
||||||
timeout = NSDate.dateWithTimeIntervalSinceNow_(2.0)
|
|
||||||
while query.isGathering() and timeout.timeIntervalSinceNow() > 0:
|
|
||||||
run_loop.runMode_beforeDate_(
|
|
||||||
"NSDefaultRunLoopMode", NSDate.dateWithTimeIntervalSinceNow_(0.1)
|
|
||||||
)
|
|
||||||
|
|
||||||
query.stopQuery()
|
|
||||||
|
|
||||||
total_results = query.resultCount()
|
|
||||||
print(f"Found {total_results} total items")
|
|
||||||
|
|
||||||
if total_results == 0:
|
|
||||||
print("No results found")
|
|
||||||
return []
|
|
||||||
|
|
||||||
# Process items
|
|
||||||
items_to_process = min(total_results, max_items)
|
|
||||||
results = []
|
|
||||||
|
|
||||||
# ONLY relevant attributes for vector embeddings
|
|
||||||
# These provide essential context for semantic search without bloat
|
|
||||||
attributes = [
|
|
||||||
"kMDItemPath", # Full path for file retrieval
|
|
||||||
"kMDItemFSName", # Filename for display & embedding
|
|
||||||
"kMDItemFSSize", # Size for filtering/ranking
|
|
||||||
"kMDItemContentType", # File type for categorization
|
|
||||||
"kMDItemKind", # Human-readable type for embedding
|
|
||||||
"kMDItemFSCreationDate", # Temporal context
|
|
||||||
"kMDItemFSContentChangeDate", # Recency for ranking
|
|
||||||
]
|
|
||||||
|
|
||||||
print(f"Processing {items_to_process} items...")
|
|
||||||
|
|
||||||
for i in range(items_to_process):
|
|
||||||
try:
|
|
||||||
item = query.resultAtIndex_(i)
|
|
||||||
metadata = {}
|
|
||||||
|
|
||||||
# Extract ONLY the relevant attributes
|
|
||||||
for attr in attributes:
|
|
||||||
try:
|
|
||||||
value = item.valueForAttribute_(attr)
|
|
||||||
if value is not None:
|
|
||||||
# Keep the attribute name clean (remove kMDItem prefix for cleaner JSON)
|
|
||||||
clean_key = attr.replace("kMDItem", "").replace("FS", "")
|
|
||||||
metadata[clean_key] = convert_to_serializable(value)
|
|
||||||
except (AttributeError, ValueError, TypeError):
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Only add if we have at least a path
|
|
||||||
if metadata.get("Path"):
|
|
||||||
results.append(metadata)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error processing item {i}: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Save to JSON
|
|
||||||
with open(output_file, "w", encoding="utf-8") as f:
|
|
||||||
json.dump(results, f, indent=2, ensure_ascii=False)
|
|
||||||
|
|
||||||
print(f"\n✓ Saved {len(results)} items to {output_file}")
|
|
||||||
|
|
||||||
# Show summary
|
|
||||||
print("\nSample items:")
|
|
||||||
import os
|
|
||||||
|
|
||||||
home_dir = os.path.expanduser("~")
|
|
||||||
|
|
||||||
for i, item in enumerate(results[:3]):
|
|
||||||
print(f"\n[Item {i + 1}]")
|
|
||||||
print(f" Path: {item.get('Path', 'N/A')}")
|
|
||||||
print(f" Name: {item.get('Name', 'N/A')}")
|
|
||||||
print(f" Type: {item.get('ContentType', 'N/A')}")
|
|
||||||
print(f" Kind: {item.get('Kind', 'N/A')}")
|
|
||||||
|
|
||||||
# Handle size properly
|
|
||||||
size = item.get("Size")
|
|
||||||
if size:
|
|
||||||
try:
|
|
||||||
size_int = int(size)
|
|
||||||
if size_int > 1024 * 1024:
|
|
||||||
print(f" Size: {size_int / (1024 * 1024):.2f} MB")
|
|
||||||
elif size_int > 1024:
|
|
||||||
print(f" Size: {size_int / 1024:.2f} KB")
|
|
||||||
else:
|
|
||||||
print(f" Size: {size_int} bytes")
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
print(f" Size: {size}")
|
|
||||||
|
|
||||||
# Show dates
|
|
||||||
if "CreationDate" in item:
|
|
||||||
print(f" Created: {item['CreationDate']}")
|
|
||||||
if "ContentChangeDate" in item:
|
|
||||||
print(f" Modified: {item['ContentChangeDate']}")
|
|
||||||
|
|
||||||
# Count by type
|
|
||||||
type_counts = {}
|
|
||||||
for item in results:
|
|
||||||
content_type = item.get("ContentType", "unknown")
|
|
||||||
type_counts[content_type] = type_counts.get(content_type, 0) + 1
|
|
||||||
|
|
||||||
print(f"\nTotal items saved: {len(results)}")
|
|
||||||
|
|
||||||
if type_counts:
|
|
||||||
print("\nTop content types:")
|
|
||||||
for ct, count in sorted(type_counts.items(), key=lambda x: x[1], reverse=True)[:5]:
|
|
||||||
print(f" {ct}: {count} items")
|
|
||||||
|
|
||||||
# Count by folder
|
|
||||||
folder_counts = {}
|
|
||||||
for item in results:
|
|
||||||
path = item.get("Path", "")
|
|
||||||
for folder in SEARCH_FOLDERS:
|
|
||||||
# Build the full folder path
|
|
||||||
if folder.startswith("/"):
|
|
||||||
folder_path = folder
|
|
||||||
else:
|
|
||||||
folder_path = os.path.join(home_dir, folder)
|
|
||||||
|
|
||||||
if path.startswith(folder_path):
|
|
||||||
folder_counts[folder] = folder_counts.get(folder, 0) + 1
|
|
||||||
break
|
|
||||||
|
|
||||||
if folder_counts:
|
|
||||||
print("\nItems by location:")
|
|
||||||
for folder, count in sorted(folder_counts.items(), key=lambda x: x[1], reverse=True):
|
|
||||||
print(f" {folder}: {count} items")
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
# Parse arguments
|
|
||||||
if len(sys.argv) > 1:
|
|
||||||
try:
|
|
||||||
max_items = int(sys.argv[1])
|
|
||||||
except ValueError:
|
|
||||||
print("Usage: python spot.py [number_of_items]")
|
|
||||||
print("Default: 10 items")
|
|
||||||
sys.exit(1)
|
|
||||||
else:
|
|
||||||
max_items = 10
|
|
||||||
|
|
||||||
output_file = sys.argv[2] if len(sys.argv) > 2 else "spotlight_dump.json"
|
|
||||||
|
|
||||||
# Run dump
|
|
||||||
dump_spotlight_data(max_items=max_items, output_file=output_file)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
# Slack MCP data integration for LEANN
|
|
||||||
@@ -1,510 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Slack MCP Reader for LEANN
|
|
||||||
|
|
||||||
This module provides functionality to connect to Slack MCP servers and fetch message data
|
|
||||||
for indexing in LEANN. It supports various Slack MCP server implementations and provides
|
|
||||||
flexible message processing options.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
from typing import Any, Optional
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class SlackMCPReader:
|
|
||||||
"""
|
|
||||||
Reader for Slack data via MCP (Model Context Protocol) servers.
|
|
||||||
|
|
||||||
This class connects to Slack MCP servers to fetch message data and convert it
|
|
||||||
into a format suitable for LEANN indexing.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
mcp_server_command: str,
|
|
||||||
workspace_name: Optional[str] = None,
|
|
||||||
concatenate_conversations: bool = True,
|
|
||||||
max_messages_per_conversation: int = 100,
|
|
||||||
max_retries: int = 5,
|
|
||||||
retry_delay: float = 2.0,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Initialize the Slack MCP Reader.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
mcp_server_command: Command to start the MCP server (e.g., 'slack-mcp-server')
|
|
||||||
workspace_name: Optional workspace name to filter messages
|
|
||||||
concatenate_conversations: Whether to group messages by channel/thread
|
|
||||||
max_messages_per_conversation: Maximum messages to include per conversation
|
|
||||||
max_retries: Maximum number of retries for failed operations
|
|
||||||
retry_delay: Initial delay between retries in seconds
|
|
||||||
"""
|
|
||||||
self.mcp_server_command = mcp_server_command
|
|
||||||
self.workspace_name = workspace_name
|
|
||||||
self.concatenate_conversations = concatenate_conversations
|
|
||||||
self.max_messages_per_conversation = max_messages_per_conversation
|
|
||||||
self.max_retries = max_retries
|
|
||||||
self.retry_delay = retry_delay
|
|
||||||
self.mcp_process = None
|
|
||||||
|
|
||||||
async def start_mcp_server(self):
|
|
||||||
"""Start the MCP server process."""
|
|
||||||
try:
|
|
||||||
self.mcp_process = await asyncio.create_subprocess_exec(
|
|
||||||
*self.mcp_server_command.split(),
|
|
||||||
stdin=asyncio.subprocess.PIPE,
|
|
||||||
stdout=asyncio.subprocess.PIPE,
|
|
||||||
stderr=asyncio.subprocess.PIPE,
|
|
||||||
)
|
|
||||||
logger.info(f"Started MCP server: {self.mcp_server_command}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to start MCP server: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def stop_mcp_server(self):
|
|
||||||
"""Stop the MCP server process."""
|
|
||||||
if self.mcp_process:
|
|
||||||
self.mcp_process.terminate()
|
|
||||||
await self.mcp_process.wait()
|
|
||||||
logger.info("Stopped MCP server")
|
|
||||||
|
|
||||||
async def send_mcp_request(self, request: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""Send a request to the MCP server and get response."""
|
|
||||||
if not self.mcp_process:
|
|
||||||
raise RuntimeError("MCP server not started")
|
|
||||||
|
|
||||||
request_json = json.dumps(request) + "\n"
|
|
||||||
self.mcp_process.stdin.write(request_json.encode())
|
|
||||||
await self.mcp_process.stdin.drain()
|
|
||||||
|
|
||||||
response_line = await self.mcp_process.stdout.readline()
|
|
||||||
if not response_line:
|
|
||||||
raise RuntimeError("No response from MCP server")
|
|
||||||
|
|
||||||
return json.loads(response_line.decode().strip())
|
|
||||||
|
|
||||||
async def initialize_mcp_connection(self):
|
|
||||||
"""Initialize the MCP connection."""
|
|
||||||
init_request = {
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"id": 1,
|
|
||||||
"method": "initialize",
|
|
||||||
"params": {
|
|
||||||
"protocolVersion": "2024-11-05",
|
|
||||||
"capabilities": {},
|
|
||||||
"clientInfo": {"name": "leann-slack-reader", "version": "1.0.0"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
response = await self.send_mcp_request(init_request)
|
|
||||||
if "error" in response:
|
|
||||||
raise RuntimeError(f"MCP initialization failed: {response['error']}")
|
|
||||||
|
|
||||||
logger.info("MCP connection initialized successfully")
|
|
||||||
|
|
||||||
async def list_available_tools(self) -> list[dict[str, Any]]:
|
|
||||||
"""List available tools from the MCP server."""
|
|
||||||
list_request = {"jsonrpc": "2.0", "id": 2, "method": "tools/list", "params": {}}
|
|
||||||
|
|
||||||
response = await self.send_mcp_request(list_request)
|
|
||||||
if "error" in response:
|
|
||||||
raise RuntimeError(f"Failed to list tools: {response['error']}")
|
|
||||||
|
|
||||||
return response.get("result", {}).get("tools", [])
|
|
||||||
|
|
||||||
def _is_cache_sync_error(self, error: dict) -> bool:
|
|
||||||
"""Check if the error is related to users cache not being ready."""
|
|
||||||
if isinstance(error, dict):
|
|
||||||
message = error.get("message", "").lower()
|
|
||||||
return (
|
|
||||||
"users cache is not ready" in message or "sync process is still running" in message
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def _retry_with_backoff(self, func, *args, **kwargs):
|
|
||||||
"""Retry a function with exponential backoff, especially for cache sync issues."""
|
|
||||||
last_exception = None
|
|
||||||
|
|
||||||
for attempt in range(self.max_retries + 1):
|
|
||||||
try:
|
|
||||||
return await func(*args, **kwargs)
|
|
||||||
except Exception as e:
|
|
||||||
last_exception = e
|
|
||||||
|
|
||||||
# Check if this is a cache sync error
|
|
||||||
error_dict = {}
|
|
||||||
if hasattr(e, "args") and e.args and isinstance(e.args[0], dict):
|
|
||||||
error_dict = e.args[0]
|
|
||||||
elif "Failed to fetch messages" in str(e):
|
|
||||||
# Try to extract error from the exception message
|
|
||||||
import re
|
|
||||||
|
|
||||||
match = re.search(r"'error':\s*(\{[^}]+\})", str(e))
|
|
||||||
if match:
|
|
||||||
try:
|
|
||||||
error_dict = eval(match.group(1))
|
|
||||||
except (ValueError, SyntaxError, NameError):
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
# Try alternative format
|
|
||||||
match = re.search(r"Failed to fetch messages:\s*(\{[^}]+\})", str(e))
|
|
||||||
if match:
|
|
||||||
try:
|
|
||||||
error_dict = eval(match.group(1))
|
|
||||||
except (ValueError, SyntaxError, NameError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
if self._is_cache_sync_error(error_dict):
|
|
||||||
if attempt < self.max_retries:
|
|
||||||
delay = self.retry_delay * (2**attempt) # Exponential backoff
|
|
||||||
logger.info(
|
|
||||||
f"Cache sync not ready, waiting {delay:.1f}s before retry {attempt + 1}/{self.max_retries}"
|
|
||||||
)
|
|
||||||
await asyncio.sleep(delay)
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
f"Cache sync still not ready after {self.max_retries} retries, giving up"
|
|
||||||
)
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
# Not a cache sync error, don't retry
|
|
||||||
break
|
|
||||||
|
|
||||||
# If we get here, all retries failed or it's not a retryable error
|
|
||||||
raise last_exception
|
|
||||||
|
|
||||||
async def fetch_slack_messages(
|
|
||||||
self, channel: Optional[str] = None, limit: int = 100
|
|
||||||
) -> list[dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
Fetch Slack messages using MCP tools with retry logic for cache sync issues.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
channel: Optional channel name to filter messages
|
|
||||||
limit: Maximum number of messages to fetch
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of message dictionaries
|
|
||||||
"""
|
|
||||||
return await self._retry_with_backoff(self._fetch_slack_messages_impl, channel, limit)
|
|
||||||
|
|
||||||
async def _fetch_slack_messages_impl(
|
|
||||||
self, channel: Optional[str] = None, limit: int = 100
|
|
||||||
) -> list[dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
Internal implementation of fetch_slack_messages without retry logic.
|
|
||||||
"""
|
|
||||||
# This is a generic implementation - specific MCP servers may have different tool names
|
|
||||||
# Common tool names might be: 'get_messages', 'list_messages', 'fetch_channel_history'
|
|
||||||
|
|
||||||
tools = await self.list_available_tools()
|
|
||||||
logger.info(f"Available tools: {[tool.get('name') for tool in tools]}")
|
|
||||||
message_tool = None
|
|
||||||
|
|
||||||
# Look for a tool that can fetch messages - prioritize conversations_history
|
|
||||||
message_tool = None
|
|
||||||
|
|
||||||
# First, try to find conversations_history specifically
|
|
||||||
for tool in tools:
|
|
||||||
tool_name = tool.get("name", "").lower()
|
|
||||||
if "conversations_history" in tool_name:
|
|
||||||
message_tool = tool
|
|
||||||
logger.info(f"Found conversations_history tool: {tool}")
|
|
||||||
break
|
|
||||||
|
|
||||||
# If not found, look for other message-fetching tools
|
|
||||||
if not message_tool:
|
|
||||||
for tool in tools:
|
|
||||||
tool_name = tool.get("name", "").lower()
|
|
||||||
if any(
|
|
||||||
keyword in tool_name
|
|
||||||
for keyword in ["conversations_search", "message", "history"]
|
|
||||||
):
|
|
||||||
message_tool = tool
|
|
||||||
break
|
|
||||||
|
|
||||||
if not message_tool:
|
|
||||||
raise RuntimeError("No message fetching tool found in MCP server")
|
|
||||||
|
|
||||||
# Prepare tool call parameters
|
|
||||||
tool_params = {"limit": "180d"} # Use 180 days to get older messages
|
|
||||||
if channel:
|
|
||||||
# For conversations_history, use channel_id parameter
|
|
||||||
if message_tool["name"] == "conversations_history":
|
|
||||||
tool_params["channel_id"] = channel
|
|
||||||
else:
|
|
||||||
# Try common parameter names for channel specification
|
|
||||||
for param_name in ["channel", "channel_id", "channel_name"]:
|
|
||||||
tool_params[param_name] = channel
|
|
||||||
break
|
|
||||||
|
|
||||||
logger.info(f"Tool parameters: {tool_params}")
|
|
||||||
|
|
||||||
fetch_request = {
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"id": 3,
|
|
||||||
"method": "tools/call",
|
|
||||||
"params": {"name": message_tool["name"], "arguments": tool_params},
|
|
||||||
}
|
|
||||||
|
|
||||||
response = await self.send_mcp_request(fetch_request)
|
|
||||||
if "error" in response:
|
|
||||||
raise RuntimeError(f"Failed to fetch messages: {response['error']}")
|
|
||||||
|
|
||||||
# Extract messages from response - format may vary by MCP server
|
|
||||||
result = response.get("result", {})
|
|
||||||
if "content" in result and isinstance(result["content"], list):
|
|
||||||
# Some MCP servers return content as a list
|
|
||||||
content = result["content"][0] if result["content"] else {}
|
|
||||||
if "text" in content:
|
|
||||||
try:
|
|
||||||
messages = json.loads(content["text"])
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
# If not JSON, try to parse as CSV format (Slack MCP server format)
|
|
||||||
messages = self._parse_csv_messages(content["text"], channel)
|
|
||||||
else:
|
|
||||||
messages = result["content"]
|
|
||||||
else:
|
|
||||||
# Direct message format
|
|
||||||
messages = result.get("messages", [result])
|
|
||||||
|
|
||||||
return messages if isinstance(messages, list) else [messages]
|
|
||||||
|
|
||||||
def _parse_csv_messages(self, csv_text: str, channel: str) -> list[dict[str, Any]]:
|
|
||||||
"""Parse CSV format messages from Slack MCP server."""
|
|
||||||
import csv
|
|
||||||
import io
|
|
||||||
|
|
||||||
messages = []
|
|
||||||
try:
|
|
||||||
# Split by lines and process each line as a CSV row
|
|
||||||
lines = csv_text.strip().split("\n")
|
|
||||||
if not lines:
|
|
||||||
return messages
|
|
||||||
|
|
||||||
# Skip header line if it exists
|
|
||||||
start_idx = 0
|
|
||||||
if lines[0].startswith("MsgID,UserID,UserName"):
|
|
||||||
start_idx = 1
|
|
||||||
|
|
||||||
for line in lines[start_idx:]:
|
|
||||||
if not line.strip():
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Parse CSV line
|
|
||||||
reader = csv.reader(io.StringIO(line))
|
|
||||||
try:
|
|
||||||
row = next(reader)
|
|
||||||
if len(row) >= 7: # Ensure we have enough columns
|
|
||||||
message = {
|
|
||||||
"ts": row[0],
|
|
||||||
"user": row[1],
|
|
||||||
"username": row[2],
|
|
||||||
"real_name": row[3],
|
|
||||||
"channel": row[4],
|
|
||||||
"thread_ts": row[5],
|
|
||||||
"text": row[6],
|
|
||||||
"time": row[7] if len(row) > 7 else "",
|
|
||||||
"reactions": row[8] if len(row) > 8 else "",
|
|
||||||
"cursor": row[9] if len(row) > 9 else "",
|
|
||||||
}
|
|
||||||
messages.append(message)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to parse CSV line: {line[:100]}... Error: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to parse CSV messages: {e}")
|
|
||||||
# Fallback: treat entire text as one message
|
|
||||||
messages = [{"text": csv_text, "channel": channel or "unknown"}]
|
|
||||||
|
|
||||||
return messages
|
|
||||||
|
|
||||||
def _format_message(self, message: dict[str, Any]) -> str:
|
|
||||||
"""Format a single message for indexing."""
|
|
||||||
text = message.get("text", "")
|
|
||||||
user = message.get("user", message.get("username", "Unknown"))
|
|
||||||
channel = message.get("channel", message.get("channel_name", "Unknown"))
|
|
||||||
timestamp = message.get("ts", message.get("timestamp", ""))
|
|
||||||
|
|
||||||
# Format timestamp if available
|
|
||||||
formatted_time = ""
|
|
||||||
if timestamp:
|
|
||||||
try:
|
|
||||||
import datetime
|
|
||||||
|
|
||||||
if isinstance(timestamp, str) and "." in timestamp:
|
|
||||||
dt = datetime.datetime.fromtimestamp(float(timestamp))
|
|
||||||
formatted_time = dt.strftime("%Y-%m-%d %H:%M:%S")
|
|
||||||
elif isinstance(timestamp, (int, float)):
|
|
||||||
dt = datetime.datetime.fromtimestamp(timestamp)
|
|
||||||
formatted_time = dt.strftime("%Y-%m-%d %H:%M:%S")
|
|
||||||
else:
|
|
||||||
formatted_time = str(timestamp)
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
formatted_time = str(timestamp)
|
|
||||||
|
|
||||||
# Build formatted message
|
|
||||||
parts = []
|
|
||||||
if channel:
|
|
||||||
parts.append(f"Channel: #{channel}")
|
|
||||||
if user:
|
|
||||||
parts.append(f"User: {user}")
|
|
||||||
if formatted_time:
|
|
||||||
parts.append(f"Time: {formatted_time}")
|
|
||||||
if text:
|
|
||||||
parts.append(f"Message: {text}")
|
|
||||||
|
|
||||||
return "\n".join(parts)
|
|
||||||
|
|
||||||
def _create_concatenated_content(self, messages: list[dict[str, Any]], channel: str) -> str:
|
|
||||||
"""Create concatenated content from multiple messages in a channel."""
|
|
||||||
if not messages:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
# Sort messages by timestamp if available
|
|
||||||
try:
|
|
||||||
messages.sort(key=lambda x: float(x.get("ts", x.get("timestamp", 0))))
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
pass # Keep original order if timestamps aren't numeric
|
|
||||||
|
|
||||||
# Limit messages per conversation
|
|
||||||
if len(messages) > self.max_messages_per_conversation:
|
|
||||||
messages = messages[-self.max_messages_per_conversation :]
|
|
||||||
|
|
||||||
# Create header
|
|
||||||
content_parts = [
|
|
||||||
f"Slack Channel: #{channel}",
|
|
||||||
f"Message Count: {len(messages)}",
|
|
||||||
f"Workspace: {self.workspace_name or 'Unknown'}",
|
|
||||||
"=" * 50,
|
|
||||||
"",
|
|
||||||
]
|
|
||||||
|
|
||||||
# Add messages
|
|
||||||
for message in messages:
|
|
||||||
formatted_msg = self._format_message(message)
|
|
||||||
if formatted_msg.strip():
|
|
||||||
content_parts.append(formatted_msg)
|
|
||||||
content_parts.append("-" * 30)
|
|
||||||
content_parts.append("")
|
|
||||||
|
|
||||||
return "\n".join(content_parts)
|
|
||||||
|
|
||||||
async def get_all_channels(self) -> list[str]:
|
|
||||||
"""Get list of all available channels."""
|
|
||||||
try:
|
|
||||||
channels_list_request = {
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"id": 4,
|
|
||||||
"method": "tools/call",
|
|
||||||
"params": {"name": "channels_list", "arguments": {}},
|
|
||||||
}
|
|
||||||
channels_response = await self.send_mcp_request(channels_list_request)
|
|
||||||
if "result" in channels_response:
|
|
||||||
result = channels_response["result"]
|
|
||||||
if "content" in result and isinstance(result["content"], list):
|
|
||||||
content = result["content"][0] if result["content"] else {}
|
|
||||||
if "text" in content:
|
|
||||||
# Parse the channels from the response
|
|
||||||
channels = []
|
|
||||||
lines = content["text"].split("\n")
|
|
||||||
for line in lines:
|
|
||||||
if line.strip() and ("#" in line or "C" in line[:10]):
|
|
||||||
# Extract channel ID or name
|
|
||||||
parts = line.split()
|
|
||||||
for part in parts:
|
|
||||||
if part.startswith("C") and len(part) > 5:
|
|
||||||
channels.append(part)
|
|
||||||
elif part.startswith("#"):
|
|
||||||
channels.append(part[1:]) # Remove #
|
|
||||||
logger.info(f"Found {len(channels)} channels: {channels}")
|
|
||||||
return channels
|
|
||||||
return []
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to get channels list: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
async def read_slack_data(self, channels: Optional[list[str]] = None) -> list[str]:
|
|
||||||
"""
|
|
||||||
Read Slack data and return formatted text chunks.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
channels: Optional list of channel names to fetch. If None, fetches from all available channels.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of formatted text chunks ready for LEANN indexing
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
await self.start_mcp_server()
|
|
||||||
await self.initialize_mcp_connection()
|
|
||||||
|
|
||||||
all_texts = []
|
|
||||||
|
|
||||||
if channels:
|
|
||||||
# Fetch specific channels
|
|
||||||
for channel in channels:
|
|
||||||
try:
|
|
||||||
messages = await self.fetch_slack_messages(channel=channel, limit=1000)
|
|
||||||
if messages:
|
|
||||||
if self.concatenate_conversations:
|
|
||||||
text_content = self._create_concatenated_content(messages, channel)
|
|
||||||
if text_content.strip():
|
|
||||||
all_texts.append(text_content)
|
|
||||||
else:
|
|
||||||
# Process individual messages
|
|
||||||
for message in messages:
|
|
||||||
formatted_msg = self._format_message(message)
|
|
||||||
if formatted_msg.strip():
|
|
||||||
all_texts.append(formatted_msg)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to fetch messages from channel {channel}: {e}")
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
# Fetch from all available channels
|
|
||||||
logger.info("Fetching from all available channels...")
|
|
||||||
all_channels = await self.get_all_channels()
|
|
||||||
|
|
||||||
if not all_channels:
|
|
||||||
# Fallback to common channel names if we can't get the list
|
|
||||||
all_channels = ["general", "random", "announcements", "C0GN5BX0F"]
|
|
||||||
logger.info(f"Using fallback channels: {all_channels}")
|
|
||||||
|
|
||||||
for channel in all_channels:
|
|
||||||
try:
|
|
||||||
logger.info(f"Searching channel: {channel}")
|
|
||||||
messages = await self.fetch_slack_messages(channel=channel, limit=1000)
|
|
||||||
if messages:
|
|
||||||
if self.concatenate_conversations:
|
|
||||||
text_content = self._create_concatenated_content(messages, channel)
|
|
||||||
if text_content.strip():
|
|
||||||
all_texts.append(text_content)
|
|
||||||
else:
|
|
||||||
# Process individual messages
|
|
||||||
for message in messages:
|
|
||||||
formatted_msg = self._format_message(message)
|
|
||||||
if formatted_msg.strip():
|
|
||||||
all_texts.append(formatted_msg)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to fetch messages from channel {channel}: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
return all_texts
|
|
||||||
|
|
||||||
finally:
|
|
||||||
await self.stop_mcp_server()
|
|
||||||
|
|
||||||
async def __aenter__(self):
|
|
||||||
"""Async context manager entry."""
|
|
||||||
await self.start_mcp_server()
|
|
||||||
await self.initialize_mcp_connection()
|
|
||||||
return self
|
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
|
||||||
"""Async context manager exit."""
|
|
||||||
await self.stop_mcp_server()
|
|
||||||
@@ -1,227 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Slack RAG Application with MCP Support
|
|
||||||
|
|
||||||
This application enables RAG (Retrieval-Augmented Generation) on Slack messages
|
|
||||||
by connecting to Slack MCP servers to fetch live data and index it in LEANN.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
python -m apps.slack_rag --mcp-server "slack-mcp-server" --query "What did the team discuss about the project?"
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
from apps.base_rag_example import BaseRAGExample
|
|
||||||
from apps.slack_data.slack_mcp_reader import SlackMCPReader
|
|
||||||
|
|
||||||
|
|
||||||
class SlackMCPRAG(BaseRAGExample):
|
|
||||||
"""
|
|
||||||
RAG application for Slack messages via MCP servers.
|
|
||||||
|
|
||||||
This class provides a complete RAG pipeline for Slack data, including
|
|
||||||
MCP server connection, data fetching, indexing, and interactive chat.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
name="Slack MCP RAG",
|
|
||||||
description="RAG application for Slack messages via MCP servers",
|
|
||||||
default_index_name="slack_messages",
|
|
||||||
)
|
|
||||||
|
|
||||||
def _add_specific_arguments(self, parser: argparse.ArgumentParser):
|
|
||||||
"""Add Slack MCP-specific arguments."""
|
|
||||||
parser.add_argument(
|
|
||||||
"--mcp-server",
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
help="Command to start the Slack MCP server (e.g., 'slack-mcp-server' or 'npx slack-mcp-server')",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--workspace-name",
|
|
||||||
type=str,
|
|
||||||
help="Slack workspace name for better organization and filtering",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--channels",
|
|
||||||
nargs="+",
|
|
||||||
help="Specific Slack channels to index (e.g., general random). If not specified, fetches from all available channels",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--concatenate-conversations",
|
|
||||||
action="store_true",
|
|
||||||
default=True,
|
|
||||||
help="Group messages by channel/thread for better context (default: True)",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--no-concatenate-conversations",
|
|
||||||
action="store_true",
|
|
||||||
help="Process individual messages instead of grouping by channel",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--max-messages-per-channel",
|
|
||||||
type=int,
|
|
||||||
default=100,
|
|
||||||
help="Maximum number of messages to include per channel (default: 100)",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--test-connection",
|
|
||||||
action="store_true",
|
|
||||||
help="Test MCP server connection and list available tools without indexing",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--max-retries",
|
|
||||||
type=int,
|
|
||||||
default=5,
|
|
||||||
help="Maximum number of retries for failed operations (default: 5)",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--retry-delay",
|
|
||||||
type=float,
|
|
||||||
default=2.0,
|
|
||||||
help="Initial delay between retries in seconds (default: 2.0)",
|
|
||||||
)
|
|
||||||
|
|
||||||
async def test_mcp_connection(self, args) -> bool:
|
|
||||||
"""Test the MCP server connection and display available tools."""
|
|
||||||
print(f"Testing connection to MCP server: {args.mcp_server}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
reader = SlackMCPReader(
|
|
||||||
mcp_server_command=args.mcp_server,
|
|
||||||
workspace_name=args.workspace_name,
|
|
||||||
concatenate_conversations=not args.no_concatenate_conversations,
|
|
||||||
max_messages_per_conversation=args.max_messages_per_channel,
|
|
||||||
max_retries=args.max_retries,
|
|
||||||
retry_delay=args.retry_delay,
|
|
||||||
)
|
|
||||||
|
|
||||||
async with reader:
|
|
||||||
tools = await reader.list_available_tools()
|
|
||||||
|
|
||||||
print("Successfully connected to MCP server!")
|
|
||||||
print(f"Available tools ({len(tools)}):")
|
|
||||||
|
|
||||||
for i, tool in enumerate(tools, 1):
|
|
||||||
name = tool.get("name", "Unknown")
|
|
||||||
description = tool.get("description", "No description available")
|
|
||||||
print(f"\n{i}. {name}")
|
|
||||||
print(
|
|
||||||
f" Description: {description[:100]}{'...' if len(description) > 100 else ''}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Show input schema if available
|
|
||||||
schema = tool.get("inputSchema", {})
|
|
||||||
if schema.get("properties"):
|
|
||||||
props = list(schema["properties"].keys())[:3] # Show first 3 properties
|
|
||||||
print(
|
|
||||||
f" Parameters: {', '.join(props)}{'...' if len(schema['properties']) > 3 else ''}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Failed to connect to MCP server: {e}")
|
|
||||||
print("\nTroubleshooting tips:")
|
|
||||||
print("1. Make sure the MCP server is installed and accessible")
|
|
||||||
print("2. Check if the server command is correct")
|
|
||||||
print("3. Ensure you have proper authentication/credentials configured")
|
|
||||||
print("4. Try running the MCP server command directly to test it")
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def load_data(self, args) -> list[str]:
|
|
||||||
"""Load Slack messages via MCP server."""
|
|
||||||
print(f"Connecting to Slack MCP server: {args.mcp_server}")
|
|
||||||
|
|
||||||
if args.workspace_name:
|
|
||||||
print(f"Workspace: {args.workspace_name}")
|
|
||||||
|
|
||||||
# Filter out empty strings from channels
|
|
||||||
channels = [ch for ch in args.channels if ch.strip()] if args.channels else None
|
|
||||||
|
|
||||||
if channels:
|
|
||||||
print(f"Channels: {', '.join(channels)}")
|
|
||||||
else:
|
|
||||||
print("Fetching from all available channels")
|
|
||||||
|
|
||||||
concatenate = not args.no_concatenate_conversations
|
|
||||||
print(
|
|
||||||
f"Processing mode: {'Concatenated conversations' if concatenate else 'Individual messages'}"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
reader = SlackMCPReader(
|
|
||||||
mcp_server_command=args.mcp_server,
|
|
||||||
workspace_name=args.workspace_name,
|
|
||||||
concatenate_conversations=concatenate,
|
|
||||||
max_messages_per_conversation=args.max_messages_per_channel,
|
|
||||||
max_retries=args.max_retries,
|
|
||||||
retry_delay=args.retry_delay,
|
|
||||||
)
|
|
||||||
|
|
||||||
texts = await reader.read_slack_data(channels=channels)
|
|
||||||
|
|
||||||
if not texts:
|
|
||||||
print("No messages found! This could mean:")
|
|
||||||
print("- The MCP server couldn't fetch messages")
|
|
||||||
print("- The specified channels don't exist or are empty")
|
|
||||||
print("- Authentication issues with the Slack workspace")
|
|
||||||
return []
|
|
||||||
|
|
||||||
print(f"Successfully loaded {len(texts)} text chunks from Slack")
|
|
||||||
|
|
||||||
# Show sample of what was loaded
|
|
||||||
if texts:
|
|
||||||
sample_text = texts[0][:200] + "..." if len(texts[0]) > 200 else texts[0]
|
|
||||||
print("\nSample content:")
|
|
||||||
print("-" * 40)
|
|
||||||
print(sample_text)
|
|
||||||
print("-" * 40)
|
|
||||||
|
|
||||||
return texts
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error loading Slack data: {e}")
|
|
||||||
print("\nThis might be due to:")
|
|
||||||
print("- MCP server connection issues")
|
|
||||||
print("- Authentication problems")
|
|
||||||
print("- Network connectivity issues")
|
|
||||||
print("- Incorrect channel names")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def run(self):
|
|
||||||
"""Main entry point with MCP connection testing."""
|
|
||||||
args = self.parser.parse_args()
|
|
||||||
|
|
||||||
# Test connection if requested
|
|
||||||
if args.test_connection:
|
|
||||||
success = await self.test_mcp_connection(args)
|
|
||||||
if not success:
|
|
||||||
return
|
|
||||||
print(
|
|
||||||
"MCP server is working! You can now run without --test-connection to start indexing."
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Run the standard RAG pipeline
|
|
||||||
await super().run()
|
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
|
||||||
"""Main entry point for the Slack MCP RAG application."""
|
|
||||||
app = SlackMCPRAG()
|
|
||||||
await app.run()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
asyncio.run(main())
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
# Twitter MCP data integration for LEANN
|
|
||||||
@@ -1,295 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Twitter MCP Reader for LEANN
|
|
||||||
|
|
||||||
This module provides functionality to connect to Twitter MCP servers and fetch bookmark data
|
|
||||||
for indexing in LEANN. It supports various Twitter MCP server implementations and provides
|
|
||||||
flexible bookmark processing options.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
from typing import Any, Optional
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class TwitterMCPReader:
|
|
||||||
"""
|
|
||||||
Reader for Twitter bookmark data via MCP (Model Context Protocol) servers.
|
|
||||||
|
|
||||||
This class connects to Twitter MCP servers to fetch bookmark data and convert it
|
|
||||||
into a format suitable for LEANN indexing.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
mcp_server_command: str,
|
|
||||||
username: Optional[str] = None,
|
|
||||||
include_tweet_content: bool = True,
|
|
||||||
include_metadata: bool = True,
|
|
||||||
max_bookmarks: int = 1000,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Initialize the Twitter MCP Reader.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
mcp_server_command: Command to start the MCP server (e.g., 'twitter-mcp-server')
|
|
||||||
username: Optional Twitter username to filter bookmarks
|
|
||||||
include_tweet_content: Whether to include full tweet content
|
|
||||||
include_metadata: Whether to include tweet metadata (likes, retweets, etc.)
|
|
||||||
max_bookmarks: Maximum number of bookmarks to fetch
|
|
||||||
"""
|
|
||||||
self.mcp_server_command = mcp_server_command
|
|
||||||
self.username = username
|
|
||||||
self.include_tweet_content = include_tweet_content
|
|
||||||
self.include_metadata = include_metadata
|
|
||||||
self.max_bookmarks = max_bookmarks
|
|
||||||
self.mcp_process = None
|
|
||||||
|
|
||||||
async def start_mcp_server(self):
|
|
||||||
"""Start the MCP server process."""
|
|
||||||
try:
|
|
||||||
self.mcp_process = await asyncio.create_subprocess_exec(
|
|
||||||
*self.mcp_server_command.split(),
|
|
||||||
stdin=asyncio.subprocess.PIPE,
|
|
||||||
stdout=asyncio.subprocess.PIPE,
|
|
||||||
stderr=asyncio.subprocess.PIPE,
|
|
||||||
)
|
|
||||||
logger.info(f"Started MCP server: {self.mcp_server_command}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to start MCP server: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def stop_mcp_server(self):
|
|
||||||
"""Stop the MCP server process."""
|
|
||||||
if self.mcp_process:
|
|
||||||
self.mcp_process.terminate()
|
|
||||||
await self.mcp_process.wait()
|
|
||||||
logger.info("Stopped MCP server")
|
|
||||||
|
|
||||||
async def send_mcp_request(self, request: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""Send a request to the MCP server and get response."""
|
|
||||||
if not self.mcp_process:
|
|
||||||
raise RuntimeError("MCP server not started")
|
|
||||||
|
|
||||||
request_json = json.dumps(request) + "\n"
|
|
||||||
self.mcp_process.stdin.write(request_json.encode())
|
|
||||||
await self.mcp_process.stdin.drain()
|
|
||||||
|
|
||||||
response_line = await self.mcp_process.stdout.readline()
|
|
||||||
if not response_line:
|
|
||||||
raise RuntimeError("No response from MCP server")
|
|
||||||
|
|
||||||
return json.loads(response_line.decode().strip())
|
|
||||||
|
|
||||||
async def initialize_mcp_connection(self):
|
|
||||||
"""Initialize the MCP connection."""
|
|
||||||
init_request = {
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"id": 1,
|
|
||||||
"method": "initialize",
|
|
||||||
"params": {
|
|
||||||
"protocolVersion": "2024-11-05",
|
|
||||||
"capabilities": {},
|
|
||||||
"clientInfo": {"name": "leann-twitter-reader", "version": "1.0.0"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
response = await self.send_mcp_request(init_request)
|
|
||||||
if "error" in response:
|
|
||||||
raise RuntimeError(f"MCP initialization failed: {response['error']}")
|
|
||||||
|
|
||||||
logger.info("MCP connection initialized successfully")
|
|
||||||
|
|
||||||
async def list_available_tools(self) -> list[dict[str, Any]]:
|
|
||||||
"""List available tools from the MCP server."""
|
|
||||||
list_request = {"jsonrpc": "2.0", "id": 2, "method": "tools/list", "params": {}}
|
|
||||||
|
|
||||||
response = await self.send_mcp_request(list_request)
|
|
||||||
if "error" in response:
|
|
||||||
raise RuntimeError(f"Failed to list tools: {response['error']}")
|
|
||||||
|
|
||||||
return response.get("result", {}).get("tools", [])
|
|
||||||
|
|
||||||
async def fetch_twitter_bookmarks(self, limit: Optional[int] = None) -> list[dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
Fetch Twitter bookmarks using MCP tools.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
limit: Maximum number of bookmarks to fetch
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of bookmark dictionaries
|
|
||||||
"""
|
|
||||||
tools = await self.list_available_tools()
|
|
||||||
bookmark_tool = None
|
|
||||||
|
|
||||||
# Look for a tool that can fetch bookmarks
|
|
||||||
for tool in tools:
|
|
||||||
tool_name = tool.get("name", "").lower()
|
|
||||||
if any(keyword in tool_name for keyword in ["bookmark", "saved", "favorite"]):
|
|
||||||
bookmark_tool = tool
|
|
||||||
break
|
|
||||||
|
|
||||||
if not bookmark_tool:
|
|
||||||
raise RuntimeError("No bookmark fetching tool found in MCP server")
|
|
||||||
|
|
||||||
# Prepare tool call parameters
|
|
||||||
tool_params = {}
|
|
||||||
if limit or self.max_bookmarks:
|
|
||||||
tool_params["limit"] = limit or self.max_bookmarks
|
|
||||||
if self.username:
|
|
||||||
tool_params["username"] = self.username
|
|
||||||
|
|
||||||
fetch_request = {
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"id": 3,
|
|
||||||
"method": "tools/call",
|
|
||||||
"params": {"name": bookmark_tool["name"], "arguments": tool_params},
|
|
||||||
}
|
|
||||||
|
|
||||||
response = await self.send_mcp_request(fetch_request)
|
|
||||||
if "error" in response:
|
|
||||||
raise RuntimeError(f"Failed to fetch bookmarks: {response['error']}")
|
|
||||||
|
|
||||||
# Extract bookmarks from response
|
|
||||||
result = response.get("result", {})
|
|
||||||
if "content" in result and isinstance(result["content"], list):
|
|
||||||
content = result["content"][0] if result["content"] else {}
|
|
||||||
if "text" in content:
|
|
||||||
try:
|
|
||||||
bookmarks = json.loads(content["text"])
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
# If not JSON, treat as plain text
|
|
||||||
bookmarks = [{"text": content["text"], "source": "twitter"}]
|
|
||||||
else:
|
|
||||||
bookmarks = result["content"]
|
|
||||||
else:
|
|
||||||
bookmarks = result.get("bookmarks", result.get("tweets", [result]))
|
|
||||||
|
|
||||||
return bookmarks if isinstance(bookmarks, list) else [bookmarks]
|
|
||||||
|
|
||||||
def _format_bookmark(self, bookmark: dict[str, Any]) -> str:
|
|
||||||
"""Format a single bookmark for indexing."""
|
|
||||||
# Extract tweet information
|
|
||||||
text = bookmark.get("text", bookmark.get("content", ""))
|
|
||||||
author = bookmark.get(
|
|
||||||
"author", bookmark.get("username", bookmark.get("user", {}).get("username", "Unknown"))
|
|
||||||
)
|
|
||||||
timestamp = bookmark.get("created_at", bookmark.get("timestamp", ""))
|
|
||||||
url = bookmark.get("url", bookmark.get("tweet_url", ""))
|
|
||||||
|
|
||||||
# Extract metadata if available
|
|
||||||
likes = bookmark.get("likes", bookmark.get("favorite_count", 0))
|
|
||||||
retweets = bookmark.get("retweets", bookmark.get("retweet_count", 0))
|
|
||||||
replies = bookmark.get("replies", bookmark.get("reply_count", 0))
|
|
||||||
|
|
||||||
# Build formatted bookmark
|
|
||||||
parts = []
|
|
||||||
|
|
||||||
# Header
|
|
||||||
parts.append("=== Twitter Bookmark ===")
|
|
||||||
|
|
||||||
if author:
|
|
||||||
parts.append(f"Author: @{author}")
|
|
||||||
|
|
||||||
if timestamp:
|
|
||||||
# Format timestamp if it's a standard format
|
|
||||||
try:
|
|
||||||
import datetime
|
|
||||||
|
|
||||||
if "T" in str(timestamp): # ISO format
|
|
||||||
dt = datetime.datetime.fromisoformat(timestamp.replace("Z", "+00:00"))
|
|
||||||
formatted_time = dt.strftime("%Y-%m-%d %H:%M:%S")
|
|
||||||
else:
|
|
||||||
formatted_time = str(timestamp)
|
|
||||||
parts.append(f"Date: {formatted_time}")
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
parts.append(f"Date: {timestamp}")
|
|
||||||
|
|
||||||
if url:
|
|
||||||
parts.append(f"URL: {url}")
|
|
||||||
|
|
||||||
# Tweet content
|
|
||||||
if text and self.include_tweet_content:
|
|
||||||
parts.append("")
|
|
||||||
parts.append("Content:")
|
|
||||||
parts.append(text)
|
|
||||||
|
|
||||||
# Metadata
|
|
||||||
if self.include_metadata and any([likes, retweets, replies]):
|
|
||||||
parts.append("")
|
|
||||||
parts.append("Engagement:")
|
|
||||||
if likes:
|
|
||||||
parts.append(f" Likes: {likes}")
|
|
||||||
if retweets:
|
|
||||||
parts.append(f" Retweets: {retweets}")
|
|
||||||
if replies:
|
|
||||||
parts.append(f" Replies: {replies}")
|
|
||||||
|
|
||||||
# Extract hashtags and mentions if available
|
|
||||||
hashtags = bookmark.get("hashtags", [])
|
|
||||||
mentions = bookmark.get("mentions", [])
|
|
||||||
|
|
||||||
if hashtags or mentions:
|
|
||||||
parts.append("")
|
|
||||||
if hashtags:
|
|
||||||
parts.append(f"Hashtags: {', '.join(hashtags)}")
|
|
||||||
if mentions:
|
|
||||||
parts.append(f"Mentions: {', '.join(mentions)}")
|
|
||||||
|
|
||||||
return "\n".join(parts)
|
|
||||||
|
|
||||||
async def read_twitter_bookmarks(self) -> list[str]:
|
|
||||||
"""
|
|
||||||
Read Twitter bookmark data and return formatted text chunks.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of formatted text chunks ready for LEANN indexing
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
await self.start_mcp_server()
|
|
||||||
await self.initialize_mcp_connection()
|
|
||||||
|
|
||||||
print(f"Fetching up to {self.max_bookmarks} bookmarks...")
|
|
||||||
if self.username:
|
|
||||||
print(f"Filtering for user: @{self.username}")
|
|
||||||
|
|
||||||
bookmarks = await self.fetch_twitter_bookmarks()
|
|
||||||
|
|
||||||
if not bookmarks:
|
|
||||||
print("No bookmarks found")
|
|
||||||
return []
|
|
||||||
|
|
||||||
print(f"Processing {len(bookmarks)} bookmarks...")
|
|
||||||
|
|
||||||
all_texts = []
|
|
||||||
processed_count = 0
|
|
||||||
|
|
||||||
for bookmark in bookmarks:
|
|
||||||
try:
|
|
||||||
formatted_bookmark = self._format_bookmark(bookmark)
|
|
||||||
if formatted_bookmark.strip():
|
|
||||||
all_texts.append(formatted_bookmark)
|
|
||||||
processed_count += 1
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to format bookmark: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
print(f"Successfully processed {processed_count} bookmarks")
|
|
||||||
return all_texts
|
|
||||||
|
|
||||||
finally:
|
|
||||||
await self.stop_mcp_server()
|
|
||||||
|
|
||||||
async def __aenter__(self):
|
|
||||||
"""Async context manager entry."""
|
|
||||||
await self.start_mcp_server()
|
|
||||||
await self.initialize_mcp_connection()
|
|
||||||
return self
|
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
|
||||||
"""Async context manager exit."""
|
|
||||||
await self.stop_mcp_server()
|
|
||||||
@@ -1,195 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Twitter RAG Application with MCP Support
|
|
||||||
|
|
||||||
This application enables RAG (Retrieval-Augmented Generation) on Twitter bookmarks
|
|
||||||
by connecting to Twitter MCP servers to fetch live data and index it in LEANN.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
python -m apps.twitter_rag --mcp-server "twitter-mcp-server" --query "What articles did I bookmark about AI?"
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
from apps.base_rag_example import BaseRAGExample
|
|
||||||
from apps.twitter_data.twitter_mcp_reader import TwitterMCPReader
|
|
||||||
|
|
||||||
|
|
||||||
class TwitterMCPRAG(BaseRAGExample):
|
|
||||||
"""
|
|
||||||
RAG application for Twitter bookmarks via MCP servers.
|
|
||||||
|
|
||||||
This class provides a complete RAG pipeline for Twitter bookmark data, including
|
|
||||||
MCP server connection, data fetching, indexing, and interactive chat.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
name="Twitter MCP RAG",
|
|
||||||
description="RAG application for Twitter bookmarks via MCP servers",
|
|
||||||
default_index_name="twitter_bookmarks",
|
|
||||||
)
|
|
||||||
|
|
||||||
def _add_specific_arguments(self, parser: argparse.ArgumentParser):
|
|
||||||
"""Add Twitter MCP-specific arguments."""
|
|
||||||
parser.add_argument(
|
|
||||||
"--mcp-server",
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
help="Command to start the Twitter MCP server (e.g., 'twitter-mcp-server' or 'npx twitter-mcp-server')",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--username", type=str, help="Twitter username to filter bookmarks (without @)"
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--max-bookmarks",
|
|
||||||
type=int,
|
|
||||||
default=1000,
|
|
||||||
help="Maximum number of bookmarks to fetch (default: 1000)",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--no-tweet-content",
|
|
||||||
action="store_true",
|
|
||||||
help="Exclude tweet content, only include metadata",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--no-metadata",
|
|
||||||
action="store_true",
|
|
||||||
help="Exclude engagement metadata (likes, retweets, etc.)",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--test-connection",
|
|
||||||
action="store_true",
|
|
||||||
help="Test MCP server connection and list available tools without indexing",
|
|
||||||
)
|
|
||||||
|
|
||||||
async def test_mcp_connection(self, args) -> bool:
|
|
||||||
"""Test the MCP server connection and display available tools."""
|
|
||||||
print(f"Testing connection to MCP server: {args.mcp_server}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
reader = TwitterMCPReader(
|
|
||||||
mcp_server_command=args.mcp_server,
|
|
||||||
username=args.username,
|
|
||||||
include_tweet_content=not args.no_tweet_content,
|
|
||||||
include_metadata=not args.no_metadata,
|
|
||||||
max_bookmarks=args.max_bookmarks,
|
|
||||||
)
|
|
||||||
|
|
||||||
async with reader:
|
|
||||||
tools = await reader.list_available_tools()
|
|
||||||
|
|
||||||
print("\n✅ Successfully connected to MCP server!")
|
|
||||||
print(f"Available tools ({len(tools)}):")
|
|
||||||
|
|
||||||
for i, tool in enumerate(tools, 1):
|
|
||||||
name = tool.get("name", "Unknown")
|
|
||||||
description = tool.get("description", "No description available")
|
|
||||||
print(f"\n{i}. {name}")
|
|
||||||
print(
|
|
||||||
f" Description: {description[:100]}{'...' if len(description) > 100 else ''}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Show input schema if available
|
|
||||||
schema = tool.get("inputSchema", {})
|
|
||||||
if schema.get("properties"):
|
|
||||||
props = list(schema["properties"].keys())[:3] # Show first 3 properties
|
|
||||||
print(
|
|
||||||
f" Parameters: {', '.join(props)}{'...' if len(schema['properties']) > 3 else ''}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"\n❌ Failed to connect to MCP server: {e}")
|
|
||||||
print("\nTroubleshooting tips:")
|
|
||||||
print("1. Make sure the Twitter MCP server is installed and accessible")
|
|
||||||
print("2. Check if the server command is correct")
|
|
||||||
print("3. Ensure you have proper Twitter API credentials configured")
|
|
||||||
print("4. Verify your Twitter account has bookmarks to fetch")
|
|
||||||
print("5. Try running the MCP server command directly to test it")
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def load_data(self, args) -> list[str]:
|
|
||||||
"""Load Twitter bookmarks via MCP server."""
|
|
||||||
print(f"Connecting to Twitter MCP server: {args.mcp_server}")
|
|
||||||
|
|
||||||
if args.username:
|
|
||||||
print(f"Username filter: @{args.username}")
|
|
||||||
|
|
||||||
print(f"Max bookmarks: {args.max_bookmarks}")
|
|
||||||
print(f"Include tweet content: {not args.no_tweet_content}")
|
|
||||||
print(f"Include metadata: {not args.no_metadata}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
reader = TwitterMCPReader(
|
|
||||||
mcp_server_command=args.mcp_server,
|
|
||||||
username=args.username,
|
|
||||||
include_tweet_content=not args.no_tweet_content,
|
|
||||||
include_metadata=not args.no_metadata,
|
|
||||||
max_bookmarks=args.max_bookmarks,
|
|
||||||
)
|
|
||||||
|
|
||||||
texts = await reader.read_twitter_bookmarks()
|
|
||||||
|
|
||||||
if not texts:
|
|
||||||
print("❌ No bookmarks found! This could mean:")
|
|
||||||
print("- You don't have any bookmarks on Twitter")
|
|
||||||
print("- The MCP server couldn't access your bookmarks")
|
|
||||||
print("- Authentication issues with Twitter API")
|
|
||||||
print("- The username filter didn't match any bookmarks")
|
|
||||||
return []
|
|
||||||
|
|
||||||
print(f"✅ Successfully loaded {len(texts)} bookmarks from Twitter")
|
|
||||||
|
|
||||||
# Show sample of what was loaded
|
|
||||||
if texts:
|
|
||||||
sample_text = texts[0][:300] + "..." if len(texts[0]) > 300 else texts[0]
|
|
||||||
print("\nSample bookmark:")
|
|
||||||
print("-" * 50)
|
|
||||||
print(sample_text)
|
|
||||||
print("-" * 50)
|
|
||||||
|
|
||||||
return texts
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ Error loading Twitter bookmarks: {e}")
|
|
||||||
print("\nThis might be due to:")
|
|
||||||
print("- MCP server connection issues")
|
|
||||||
print("- Twitter API authentication problems")
|
|
||||||
print("- Network connectivity issues")
|
|
||||||
print("- Rate limiting from Twitter API")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def run(self):
|
|
||||||
"""Main entry point with MCP connection testing."""
|
|
||||||
args = self.parser.parse_args()
|
|
||||||
|
|
||||||
# Test connection if requested
|
|
||||||
if args.test_connection:
|
|
||||||
success = await self.test_mcp_connection(args)
|
|
||||||
if not success:
|
|
||||||
return
|
|
||||||
print(
|
|
||||||
"\n🎉 MCP server is working! You can now run without --test-connection to start indexing."
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Run the standard RAG pipeline
|
|
||||||
await super().run()
|
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
|
||||||
"""Main entry point for the Twitter MCP RAG application."""
|
|
||||||
app = TwitterMCPRAG()
|
|
||||||
await app.run()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
asyncio.run(main())
|
|
||||||
@@ -1,98 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Test script to reproduce issue #159: Slow search performance
|
|
||||||
Configuration:
|
|
||||||
- GPU: A10
|
|
||||||
- embedding_model: BAAI/bge-large-zh-v1.5
|
|
||||||
- data size: 180M text (~90K chunks)
|
|
||||||
- backend: hnsw
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import time
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from leann.api import LeannBuilder, LeannSearcher
|
|
||||||
|
|
||||||
os.environ["LEANN_LOG_LEVEL"] = "DEBUG"
|
|
||||||
|
|
||||||
# Configuration matching the issue
|
|
||||||
INDEX_PATH = "./test_issue_159.leann"
|
|
||||||
EMBEDDING_MODEL = "BAAI/bge-large-zh-v1.5"
|
|
||||||
BACKEND_NAME = "hnsw"
|
|
||||||
|
|
||||||
|
|
||||||
def generate_test_data(num_chunks=90000, chunk_size=2000):
|
|
||||||
"""Generate test data similar to 180MB text (~90K chunks)"""
|
|
||||||
# Each chunk is approximately 2000 characters
|
|
||||||
# 90K chunks * 2000 chars ≈ 180MB
|
|
||||||
chunks = []
|
|
||||||
base_text = (
|
|
||||||
"这是一个测试文档。LEANN是一个创新的向量数据库, 通过图基选择性重计算实现97%的存储节省。"
|
|
||||||
)
|
|
||||||
|
|
||||||
for i in range(num_chunks):
|
|
||||||
chunk = f"{base_text} 文档编号: {i}. " * (chunk_size // len(base_text) + 1)
|
|
||||||
chunks.append(chunk[:chunk_size])
|
|
||||||
|
|
||||||
return chunks
|
|
||||||
|
|
||||||
|
|
||||||
def test_search_performance():
|
|
||||||
"""Test search performance with different configurations"""
|
|
||||||
print("=" * 80)
|
|
||||||
print("Testing LEANN Search Performance (Issue #159)")
|
|
||||||
print("=" * 80)
|
|
||||||
|
|
||||||
meta_path = Path(f"{INDEX_PATH}.meta.json")
|
|
||||||
if meta_path.exists():
|
|
||||||
print(f"\n✓ Index already exists at {INDEX_PATH}")
|
|
||||||
print(" Skipping build phase. Delete the index to rebuild.")
|
|
||||||
else:
|
|
||||||
print("\n📦 Building index...")
|
|
||||||
print(f" Backend: {BACKEND_NAME}")
|
|
||||||
print(f" Embedding Model: {EMBEDDING_MODEL}")
|
|
||||||
print(" Generating test data (~90K chunks, ~180MB)...")
|
|
||||||
|
|
||||||
chunks = generate_test_data(num_chunks=90000)
|
|
||||||
print(f" Generated {len(chunks)} chunks")
|
|
||||||
print(f" Total text size: {sum(len(c) for c in chunks) / (1024 * 1024):.2f} MB")
|
|
||||||
|
|
||||||
builder = LeannBuilder(
|
|
||||||
backend_name=BACKEND_NAME,
|
|
||||||
embedding_model=EMBEDDING_MODEL,
|
|
||||||
)
|
|
||||||
|
|
||||||
print(" Adding chunks to builder...")
|
|
||||||
start_time = time.time()
|
|
||||||
for i, chunk in enumerate(chunks):
|
|
||||||
builder.add_text(chunk)
|
|
||||||
if (i + 1) % 10000 == 0:
|
|
||||||
print(f" Added {i + 1}/{len(chunks)} chunks...")
|
|
||||||
|
|
||||||
print(" Building index...")
|
|
||||||
build_start = time.time()
|
|
||||||
builder.build_index(INDEX_PATH)
|
|
||||||
build_time = time.time() - build_start
|
|
||||||
print(f" ✓ Index built in {build_time:.2f} seconds")
|
|
||||||
|
|
||||||
# Test search with different complexity values
|
|
||||||
print("\n🔍 Testing search performance...")
|
|
||||||
searcher = LeannSearcher(INDEX_PATH)
|
|
||||||
|
|
||||||
test_query = "LEANN向量数据库存储优化"
|
|
||||||
|
|
||||||
# Test with minimal complexity (8)
|
|
||||||
print("\n Test 4: Minimal complexity (8)")
|
|
||||||
print(f" Query: '{test_query}'")
|
|
||||||
start_time = time.time()
|
|
||||||
results = searcher.search(test_query, top_k=10, complexity=8)
|
|
||||||
search_time = time.time() - start_time
|
|
||||||
print(f" ✓ Search completed in {search_time:.2f} seconds")
|
|
||||||
print(f" Results: {len(results)} items")
|
|
||||||
|
|
||||||
print("\n" + "=" * 80)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
test_search_performance()
|
|
||||||
@@ -54,51 +54,29 @@ def extract_thinking_answer(response):
|
|||||||
return response.strip()
|
return response.strip()
|
||||||
|
|
||||||
|
|
||||||
def load_hf_model(model_name="Qwen/Qwen3-8B", trust_remote_code=False):
|
def load_hf_model(model_name="Qwen/Qwen3-8B"):
|
||||||
"""Load HuggingFace model
|
"""Load HuggingFace model"""
|
||||||
|
|
||||||
Args:
|
|
||||||
model_name (str): Name of the model to load
|
|
||||||
trust_remote_code (bool): Whether to allow execution of code from the model repository.
|
|
||||||
Defaults to False for security. Only enable for trusted models.
|
|
||||||
"""
|
|
||||||
if not HF_AVAILABLE:
|
if not HF_AVAILABLE:
|
||||||
raise ImportError("transformers not available")
|
raise ImportError("transformers not available")
|
||||||
|
|
||||||
if trust_remote_code:
|
|
||||||
print(
|
|
||||||
"⚠️ WARNING: Loading model with trust_remote_code=True. This can execute arbitrary code."
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"Loading HF: {model_name}")
|
print(f"Loading HF: {model_name}")
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=trust_remote_code)
|
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
model_name,
|
model_name,
|
||||||
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
|
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
|
||||||
device_map="auto",
|
device_map="auto",
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=True,
|
||||||
)
|
)
|
||||||
return tokenizer, model
|
return tokenizer, model
|
||||||
|
|
||||||
|
|
||||||
def load_vllm_model(model_name="Qwen/Qwen3-8B", trust_remote_code=False):
|
def load_vllm_model(model_name="Qwen/Qwen3-8B"):
|
||||||
"""Load vLLM model
|
"""Load vLLM model"""
|
||||||
|
|
||||||
Args:
|
|
||||||
model_name (str): Name of the model to load
|
|
||||||
trust_remote_code (bool): Whether to allow execution of code from the model repository.
|
|
||||||
Defaults to False for security. Only enable for trusted models.
|
|
||||||
"""
|
|
||||||
if not VLLM_AVAILABLE:
|
if not VLLM_AVAILABLE:
|
||||||
raise ImportError("vllm not available")
|
raise ImportError("vllm not available")
|
||||||
|
|
||||||
if trust_remote_code:
|
|
||||||
print(
|
|
||||||
"⚠️ WARNING: Loading model with trust_remote_code=True. This can execute arbitrary code."
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"Loading vLLM: {model_name}")
|
print(f"Loading vLLM: {model_name}")
|
||||||
llm = LLM(model=model_name, trust_remote_code=trust_remote_code)
|
llm = LLM(model=model_name, trust_remote_code=True)
|
||||||
|
|
||||||
# Qwen3 specific config
|
# Qwen3 specific config
|
||||||
if is_qwen3_model(model_name):
|
if is_qwen3_model(model_name):
|
||||||
@@ -200,33 +178,19 @@ def evaluate_rag(searcher, llm_func, queries, domain="default", top_k=3, complex
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def load_qwen_vl_model(model_name="Qwen/Qwen2.5-VL-7B-Instruct", trust_remote_code=False):
|
def load_qwen_vl_model(model_name="Qwen/Qwen2.5-VL-7B-Instruct"):
|
||||||
"""Load Qwen2.5-VL multimodal model
|
"""Load Qwen2.5-VL multimodal model"""
|
||||||
|
|
||||||
Args:
|
|
||||||
model_name (str): Name of the model to load
|
|
||||||
trust_remote_code (bool): Whether to allow execution of code from the model repository.
|
|
||||||
Defaults to False for security. Only enable for trusted models.
|
|
||||||
"""
|
|
||||||
if not HF_AVAILABLE:
|
if not HF_AVAILABLE:
|
||||||
raise ImportError("transformers not available")
|
raise ImportError("transformers not available")
|
||||||
|
|
||||||
if trust_remote_code:
|
|
||||||
print(
|
|
||||||
"⚠️ WARNING: Loading model with trust_remote_code=True. This can execute arbitrary code."
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"Loading Qwen2.5-VL: {model_name}")
|
print(f"Loading Qwen2.5-VL: {model_name}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from transformers import AutoModelForVision2Seq, AutoProcessor
|
from transformers import AutoModelForVision2Seq, AutoProcessor
|
||||||
|
|
||||||
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=trust_remote_code)
|
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
|
||||||
model = AutoModelForVision2Seq.from_pretrained(
|
model = AutoModelForVision2Seq.from_pretrained(
|
||||||
model_name,
|
model_name, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
device_map="auto",
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return processor, model
|
return processor, model
|
||||||
@@ -238,14 +202,9 @@ def load_qwen_vl_model(model_name="Qwen/Qwen2.5-VL-7B-Instruct", trust_remote_co
|
|||||||
try:
|
try:
|
||||||
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
|
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
|
||||||
|
|
||||||
processor = AutoProcessor.from_pretrained(
|
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
|
||||||
model_name, trust_remote_code=trust_remote_code
|
|
||||||
)
|
|
||||||
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||||
model_name,
|
model_name, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
device_map="auto",
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return processor, model
|
return processor, model
|
||||||
|
|||||||
@@ -1,143 +0,0 @@
|
|||||||
# Update Benchmarks
|
|
||||||
|
|
||||||
This directory hosts two benchmark suites that exercise LEANN’s HNSW “update +
|
|
||||||
search” pipeline under different assumptions:
|
|
||||||
|
|
||||||
1. **RNG recompute latency** – measure how random-neighbour pruning and cache
|
|
||||||
settings influence incremental `add()` latency when embeddings are fetched
|
|
||||||
over the ZMQ embedding server.
|
|
||||||
2. **Update strategy comparison** – compare a fully sequential update pipeline
|
|
||||||
against an offline approach that keeps the graph static and fuses results.
|
|
||||||
|
|
||||||
Both suites build a non-compact, `is_recompute=True` index so that new
|
|
||||||
embeddings are pulled from the embedding server. Benchmark outputs are written
|
|
||||||
under `.leann/bench/` by default and appended to CSV files for later plotting.
|
|
||||||
|
|
||||||
## Benchmarks
|
|
||||||
|
|
||||||
### 1. HNSW RNG Recompute Benchmark
|
|
||||||
|
|
||||||
`bench_hnsw_rng_recompute.py` evaluates incremental update latency under four
|
|
||||||
random-neighbour (RNG) configurations. Each scenario uses the same dataset but
|
|
||||||
changes the forward / reverse RNG pruning flags and whether the embedding cache
|
|
||||||
is enabled:
|
|
||||||
|
|
||||||
| Scenario name | Forward RNG | Reverse RNG | ZMQ embedding cache |
|
|
||||||
| ---------------------------------- | ----------- | ----------- | ------------------- |
|
|
||||||
| `baseline` | Enabled | Enabled | Enabled |
|
|
||||||
| `no_cache_baseline` | Enabled | Enabled | **Disabled** |
|
|
||||||
| `disable_forward_rng` | **Disabled**| Enabled | Enabled |
|
|
||||||
| `disable_forward_and_reverse_rng` | **Disabled**| **Disabled**| Enabled |
|
|
||||||
|
|
||||||
For each scenario the script:
|
|
||||||
1. (Re)builds a `is_recompute=True` index and writes it to `.leann/bench/`.
|
|
||||||
2. Starts `leann_backend_hnsw.hnsw_embedding_server` for remote embeddings.
|
|
||||||
3. Appends the requested updates using the scenario’s RNG flags.
|
|
||||||
4. Records total time, latency per passage, ZMQ fetch counts, and stage-level
|
|
||||||
timings before appending a row to the CSV output.
|
|
||||||
|
|
||||||
**Run:**
|
|
||||||
```bash
|
|
||||||
LEANN_HNSW_LOG_PATH=.leann/bench/hnsw_server.log \
|
|
||||||
LEANN_LOG_LEVEL=INFO \
|
|
||||||
uv run -m benchmarks.update.bench_hnsw_rng_recompute \
|
|
||||||
--runs 1 \
|
|
||||||
--index-path .leann/bench/test.leann \
|
|
||||||
--initial-files data/PrideandPrejudice.txt \
|
|
||||||
--update-files data/huawei_pangu.md \
|
|
||||||
--max-initial 300 \
|
|
||||||
--max-updates 1 \
|
|
||||||
--add-timeout 120
|
|
||||||
```
|
|
||||||
|
|
||||||
**Output:**
|
|
||||||
- `benchmarks/update/bench_results.csv` – per-scenario timing statistics
|
|
||||||
(including ms/passage) for each run.
|
|
||||||
- `.leann/bench/hnsw_server.log` – detailed ZMQ/server logs (path controlled by
|
|
||||||
`LEANN_HNSW_LOG_PATH`).
|
|
||||||
_The reference CSVs checked into this branch were generated on a workstation with an NVIDIA RTX 4090 GPU; throughput numbers will differ on other hardware._
|
|
||||||
|
|
||||||
### 2. Sequential vs. Offline Update Benchmark
|
|
||||||
|
|
||||||
`bench_update_vs_offline_search.py` compares two end-to-end strategies on the
|
|
||||||
same dataset:
|
|
||||||
|
|
||||||
- **Scenario A – Sequential Update**
|
|
||||||
- Start an embedding server.
|
|
||||||
- Sequentially call `index.add()`; each call fetches embeddings via ZMQ and
|
|
||||||
mutates the HNSW graph.
|
|
||||||
- After all inserts, run a search on the updated graph.
|
|
||||||
- Metrics recorded: update time (`add_total_s`), post-update search time
|
|
||||||
(`search_time_s`), combined total (`total_time_s`), and per-passage
|
|
||||||
latency.
|
|
||||||
|
|
||||||
- **Scenario B – Offline Embedding + Concurrent Search**
|
|
||||||
- Stop Scenario A’s server and start a fresh embedding server.
|
|
||||||
- Spawn two threads: one generates embeddings for the new passages offline
|
|
||||||
(graph unchanged); the other computes the query embedding and searches the
|
|
||||||
existing graph.
|
|
||||||
- Merge offline similarities with the graph search results to emulate late
|
|
||||||
fusion, then report the merged top‑k preview.
|
|
||||||
- Metrics recorded: embedding time (`emb_time_s`), search time
|
|
||||||
(`search_time_s`), concurrent makespan (`makespan_s`), and scenario total.
|
|
||||||
|
|
||||||
**Run (both scenarios):**
|
|
||||||
```bash
|
|
||||||
uv run -m benchmarks.update.bench_update_vs_offline_search \
|
|
||||||
--index-path .leann/bench/offline_vs_update.leann \
|
|
||||||
--max-initial 300 \
|
|
||||||
--num-updates 1
|
|
||||||
```
|
|
||||||
|
|
||||||
You can pass `--only A` or `--only B` to run a single scenario. The script will
|
|
||||||
print timing summaries to stdout and append the results to CSV.
|
|
||||||
|
|
||||||
**Output:**
|
|
||||||
- `benchmarks/update/offline_vs_update.csv` – per-scenario timing statistics for
|
|
||||||
Scenario A and B.
|
|
||||||
- Console output includes Scenario B’s merged top‑k preview for quick sanity
|
|
||||||
checks.
|
|
||||||
_The sample results committed here come from runs on an RTX 4090-equipped machine; expect variations if you benchmark on different GPUs._
|
|
||||||
|
|
||||||
### 3. Visualisation
|
|
||||||
|
|
||||||
`plot_bench_results.py` combines the RNG benchmark and the update strategy
|
|
||||||
benchmark into a single two-panel plot.
|
|
||||||
|
|
||||||
**Run:**
|
|
||||||
```bash
|
|
||||||
uv run -m benchmarks.update.plot_bench_results \
|
|
||||||
--csv benchmarks/update/bench_results.csv \
|
|
||||||
--csv-right benchmarks/update/offline_vs_update.csv \
|
|
||||||
--out benchmarks/update/bench_latency_from_csv.png
|
|
||||||
```
|
|
||||||
|
|
||||||
**Options:**
|
|
||||||
- `--broken-y` – Enable a broken Y-axis (default: true when appropriate).
|
|
||||||
- `--csv` – RNG benchmark results CSV (left panel).
|
|
||||||
- `--csv-right` – Update strategy results CSV (right panel).
|
|
||||||
- `--out` – Output image path (PNG/PDF supported).
|
|
||||||
|
|
||||||
**Output:**
|
|
||||||
- `benchmarks/update/bench_latency_from_csv.png` – visual comparison of the two
|
|
||||||
suites.
|
|
||||||
- `benchmarks/update/bench_latency_from_csv.pdf` – PDF version, suitable for
|
|
||||||
slides/papers.
|
|
||||||
|
|
||||||
## Parameters & Environment
|
|
||||||
|
|
||||||
### Common CLI Flags
|
|
||||||
- `--max-initial` – Number of initial passages used to seed the index.
|
|
||||||
- `--max-updates` / `--num-updates` – Number of passages to treat as updates.
|
|
||||||
- `--index-path` – Base path (without extension) where the LEANN index is stored.
|
|
||||||
- `--runs` – Number of repetitions (RNG benchmark only).
|
|
||||||
|
|
||||||
### Environment Variables
|
|
||||||
- `LEANN_HNSW_LOG_PATH` – File to receive embedding-server logs (optional).
|
|
||||||
- `LEANN_LOG_LEVEL` – Logging verbosity (DEBUG/INFO/WARNING/ERROR).
|
|
||||||
- `CUDA_VISIBLE_DEVICES` – Set to empty string if you want to force CPU
|
|
||||||
execution of the embedding model.
|
|
||||||
|
|
||||||
With these scripts you can easily replicate LEANN’s update benchmarks, compare
|
|
||||||
multiple RNG strategies, and evaluate whether sequential updates or offline
|
|
||||||
fusion better match your latency/accuracy trade-offs.
|
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
"""Benchmarks for LEANN update workflows."""
|
|
||||||
|
|
||||||
# Expose helper to locate repository root for other modules that need it.
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
|
|
||||||
def find_repo_root() -> Path:
|
|
||||||
"""Return the project root containing pyproject.toml."""
|
|
||||||
current = Path(__file__).resolve()
|
|
||||||
for parent in current.parents:
|
|
||||||
if (parent / "pyproject.toml").exists():
|
|
||||||
return parent
|
|
||||||
return current.parents[1]
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["find_repo_root"]
|
|
||||||
@@ -1,804 +0,0 @@
|
|||||||
"""Benchmark incremental HNSW add() under different RNG pruning modes with real
|
|
||||||
embedding recomputation.
|
|
||||||
|
|
||||||
This script clones the structure of ``examples/dynamic_update_no_recompute.py``
|
|
||||||
so that we build a non-compact ``is_recompute=True`` index, spin up the
|
|
||||||
standard HNSW embedding server, and measure how long incremental ``add`` takes
|
|
||||||
when RNG pruning is fully enabled vs. partially/fully disabled.
|
|
||||||
|
|
||||||
Example usage (run from the repo root; downloads the model on first run)::
|
|
||||||
|
|
||||||
uv run -m benchmarks.update.bench_hnsw_rng_recompute \
|
|
||||||
--index-path .leann/bench/leann-demo.leann \
|
|
||||||
--runs 1
|
|
||||||
|
|
||||||
You can tweak the input documents with ``--initial-files`` / ``--update-files``
|
|
||||||
if you want a larger or different workload, and change the embedding model via
|
|
||||||
``--model-name``.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import pickle
|
|
||||||
import re
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import msgpack
|
|
||||||
import numpy as np
|
|
||||||
import zmq
|
|
||||||
from leann.api import LeannBuilder
|
|
||||||
|
|
||||||
if os.environ.get("LEANN_FORCE_CPU", "").lower() in ("1", "true", "yes"):
|
|
||||||
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
|
|
||||||
|
|
||||||
from leann.embedding_compute import compute_embeddings
|
|
||||||
from leann.embedding_server_manager import EmbeddingServerManager
|
|
||||||
from leann.registry import register_project_directory
|
|
||||||
from leann_backend_hnsw import faiss # type: ignore
|
|
||||||
from leann_backend_hnsw.convert_to_csr import prune_hnsw_embeddings_inplace
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
if not logging.getLogger().handlers:
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
|
|
||||||
|
|
||||||
def _find_repo_root() -> Path:
|
|
||||||
"""Locate project root by walking up until pyproject.toml is found."""
|
|
||||||
current = Path(__file__).resolve()
|
|
||||||
for parent in current.parents:
|
|
||||||
if (parent / "pyproject.toml").exists():
|
|
||||||
return parent
|
|
||||||
# Fallback: assume repo is two levels up (../..)
|
|
||||||
return current.parents[2]
|
|
||||||
|
|
||||||
|
|
||||||
REPO_ROOT = _find_repo_root()
|
|
||||||
if str(REPO_ROOT) not in sys.path:
|
|
||||||
sys.path.insert(0, str(REPO_ROOT))
|
|
||||||
|
|
||||||
from apps.chunking import create_text_chunks # noqa: E402
|
|
||||||
|
|
||||||
DEFAULT_INITIAL_FILES = [
|
|
||||||
REPO_ROOT / "data" / "2501.14312v1 (1).pdf",
|
|
||||||
REPO_ROOT / "data" / "huawei_pangu.md",
|
|
||||||
]
|
|
||||||
DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"]
|
|
||||||
|
|
||||||
DEFAULT_HNSW_LOG = Path(".leann/bench/hnsw_server.log")
|
|
||||||
|
|
||||||
|
|
||||||
def load_chunks_from_files(paths: list[Path], limit: int | None = None) -> list[str]:
|
|
||||||
from llama_index.core import SimpleDirectoryReader
|
|
||||||
|
|
||||||
documents = []
|
|
||||||
for path in paths:
|
|
||||||
p = path.expanduser().resolve()
|
|
||||||
if not p.exists():
|
|
||||||
raise FileNotFoundError(f"Input path not found: {p}")
|
|
||||||
if p.is_dir():
|
|
||||||
reader = SimpleDirectoryReader(str(p), recursive=False)
|
|
||||||
documents.extend(reader.load_data(show_progress=True))
|
|
||||||
else:
|
|
||||||
reader = SimpleDirectoryReader(input_files=[str(p)])
|
|
||||||
documents.extend(reader.load_data(show_progress=True))
|
|
||||||
|
|
||||||
if not documents:
|
|
||||||
return []
|
|
||||||
|
|
||||||
chunks = create_text_chunks(
|
|
||||||
documents,
|
|
||||||
chunk_size=512,
|
|
||||||
chunk_overlap=128,
|
|
||||||
use_ast_chunking=False,
|
|
||||||
)
|
|
||||||
cleaned = [c for c in chunks if isinstance(c, str) and c.strip()]
|
|
||||||
if limit is not None:
|
|
||||||
cleaned = cleaned[:limit]
|
|
||||||
return cleaned
|
|
||||||
|
|
||||||
|
|
||||||
def ensure_index_dir(index_path: Path) -> None:
|
|
||||||
index_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
|
|
||||||
def cleanup_index_files(index_path: Path) -> None:
|
|
||||||
parent = index_path.parent
|
|
||||||
if not parent.exists():
|
|
||||||
return
|
|
||||||
stem = index_path.stem
|
|
||||||
for file in parent.glob(f"{stem}*"):
|
|
||||||
if file.is_file():
|
|
||||||
file.unlink()
|
|
||||||
|
|
||||||
|
|
||||||
def build_initial_index(
|
|
||||||
index_path: Path,
|
|
||||||
paragraphs: list[str],
|
|
||||||
model_name: str,
|
|
||||||
embedding_mode: str,
|
|
||||||
distance_metric: str,
|
|
||||||
ef_construction: int,
|
|
||||||
) -> None:
|
|
||||||
builder = LeannBuilder(
|
|
||||||
backend_name="hnsw",
|
|
||||||
embedding_model=model_name,
|
|
||||||
embedding_mode=embedding_mode,
|
|
||||||
is_compact=False,
|
|
||||||
is_recompute=True,
|
|
||||||
distance_metric=distance_metric,
|
|
||||||
backend_kwargs={
|
|
||||||
"distance_metric": distance_metric,
|
|
||||||
"is_compact": False,
|
|
||||||
"is_recompute": True,
|
|
||||||
"efConstruction": ef_construction,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
for idx, passage in enumerate(paragraphs):
|
|
||||||
builder.add_text(passage, metadata={"id": str(idx)})
|
|
||||||
builder.build_index(str(index_path))
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_new_chunks(paragraphs: list[str]) -> list[dict[str, Any]]:
|
|
||||||
return [{"text": text, "metadata": {}} for text in paragraphs]
|
|
||||||
|
|
||||||
|
|
||||||
def benchmark_update_with_mode(
|
|
||||||
index_path: Path,
|
|
||||||
new_chunks: list[dict[str, Any]],
|
|
||||||
model_name: str,
|
|
||||||
embedding_mode: str,
|
|
||||||
distance_metric: str,
|
|
||||||
disable_forward_rng: bool,
|
|
||||||
disable_reverse_rng: bool,
|
|
||||||
server_port: int,
|
|
||||||
add_timeout: int,
|
|
||||||
ef_construction: int,
|
|
||||||
) -> tuple[float, float]:
|
|
||||||
meta_path = index_path.parent / f"{index_path.name}.meta.json"
|
|
||||||
passages_file = index_path.parent / f"{index_path.name}.passages.jsonl"
|
|
||||||
offset_file = index_path.parent / f"{index_path.name}.passages.idx"
|
|
||||||
index_file = index_path.parent / f"{index_path.stem}.index"
|
|
||||||
|
|
||||||
with open(meta_path, encoding="utf-8") as f:
|
|
||||||
meta = json.load(f)
|
|
||||||
|
|
||||||
with open(offset_file, "rb") as f:
|
|
||||||
offset_map: dict[str, int] = pickle.load(f)
|
|
||||||
existing_ids = set(offset_map.keys())
|
|
||||||
|
|
||||||
valid_chunks: list[dict[str, Any]] = []
|
|
||||||
for chunk in new_chunks:
|
|
||||||
text = chunk.get("text", "")
|
|
||||||
if not isinstance(text, str) or not text.strip():
|
|
||||||
continue
|
|
||||||
metadata = chunk.setdefault("metadata", {})
|
|
||||||
passage_id = chunk.get("id") or metadata.get("id")
|
|
||||||
if passage_id and passage_id in existing_ids:
|
|
||||||
raise ValueError(f"Passage ID '{passage_id}' already exists in the index.")
|
|
||||||
valid_chunks.append(chunk)
|
|
||||||
|
|
||||||
if not valid_chunks:
|
|
||||||
raise ValueError("No valid chunks to append.")
|
|
||||||
|
|
||||||
texts_to_embed = [chunk["text"] for chunk in valid_chunks]
|
|
||||||
embeddings = compute_embeddings(
|
|
||||||
texts_to_embed,
|
|
||||||
model_name,
|
|
||||||
mode=embedding_mode,
|
|
||||||
is_build=False,
|
|
||||||
batch_size=16,
|
|
||||||
)
|
|
||||||
|
|
||||||
embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
|
|
||||||
if distance_metric == "cosine":
|
|
||||||
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
|
|
||||||
norms[norms == 0] = 1
|
|
||||||
embeddings = embeddings / norms
|
|
||||||
|
|
||||||
index = faiss.read_index(str(index_file))
|
|
||||||
index.is_recompute = True
|
|
||||||
if getattr(index, "storage", None) is None:
|
|
||||||
if index.metric_type == faiss.METRIC_INNER_PRODUCT:
|
|
||||||
storage_index = faiss.IndexFlatIP(index.d)
|
|
||||||
else:
|
|
||||||
storage_index = faiss.IndexFlatL2(index.d)
|
|
||||||
index.storage = storage_index
|
|
||||||
index.own_fields = True
|
|
||||||
try:
|
|
||||||
storage_index.ntotal = index.ntotal
|
|
||||||
except AttributeError:
|
|
||||||
pass
|
|
||||||
try:
|
|
||||||
index.hnsw.set_disable_rng_during_add(disable_forward_rng)
|
|
||||||
index.hnsw.set_disable_reverse_prune(disable_reverse_rng)
|
|
||||||
if ef_construction is not None:
|
|
||||||
index.hnsw.efConstruction = ef_construction
|
|
||||||
except AttributeError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
applied_forward = getattr(index.hnsw, "disable_rng_during_add", None)
|
|
||||||
applied_reverse = getattr(index.hnsw, "disable_reverse_prune", None)
|
|
||||||
logger.info(
|
|
||||||
"HNSW RNG config -> requested forward=%s, reverse=%s | applied forward=%s, reverse=%s",
|
|
||||||
disable_forward_rng,
|
|
||||||
disable_reverse_rng,
|
|
||||||
applied_forward,
|
|
||||||
applied_reverse,
|
|
||||||
)
|
|
||||||
|
|
||||||
base_id = index.ntotal
|
|
||||||
for offset, chunk in enumerate(valid_chunks):
|
|
||||||
new_id = str(base_id + offset)
|
|
||||||
chunk.setdefault("metadata", {})["id"] = new_id
|
|
||||||
chunk["id"] = new_id
|
|
||||||
|
|
||||||
rollback_size = passages_file.stat().st_size if passages_file.exists() else 0
|
|
||||||
offset_map_backup = offset_map.copy()
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(passages_file, "a", encoding="utf-8") as f:
|
|
||||||
for chunk in valid_chunks:
|
|
||||||
offset = f.tell()
|
|
||||||
json.dump(
|
|
||||||
{
|
|
||||||
"id": chunk["id"],
|
|
||||||
"text": chunk["text"],
|
|
||||||
"metadata": chunk.get("metadata", {}),
|
|
||||||
},
|
|
||||||
f,
|
|
||||||
ensure_ascii=False,
|
|
||||||
)
|
|
||||||
f.write("\n")
|
|
||||||
offset_map[chunk["id"]] = offset
|
|
||||||
|
|
||||||
with open(offset_file, "wb") as f:
|
|
||||||
pickle.dump(offset_map, f)
|
|
||||||
|
|
||||||
server_manager = EmbeddingServerManager(
|
|
||||||
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
|
|
||||||
)
|
|
||||||
server_started, actual_port = server_manager.start_server(
|
|
||||||
port=server_port,
|
|
||||||
model_name=model_name,
|
|
||||||
embedding_mode=embedding_mode,
|
|
||||||
passages_file=str(meta_path),
|
|
||||||
distance_metric=distance_metric,
|
|
||||||
)
|
|
||||||
if not server_started:
|
|
||||||
raise RuntimeError("Failed to start embedding server.")
|
|
||||||
|
|
||||||
if hasattr(index.hnsw, "set_zmq_port"):
|
|
||||||
index.hnsw.set_zmq_port(actual_port)
|
|
||||||
elif hasattr(index, "set_zmq_port"):
|
|
||||||
index.set_zmq_port(actual_port)
|
|
||||||
|
|
||||||
_warmup_embedding_server(actual_port)
|
|
||||||
|
|
||||||
total_start = time.time()
|
|
||||||
add_elapsed = 0.0
|
|
||||||
|
|
||||||
try:
|
|
||||||
import signal
|
|
||||||
|
|
||||||
def _timeout_handler(signum, frame):
|
|
||||||
raise TimeoutError("incremental add timed out")
|
|
||||||
|
|
||||||
if add_timeout > 0:
|
|
||||||
signal.signal(signal.SIGALRM, _timeout_handler)
|
|
||||||
signal.alarm(add_timeout)
|
|
||||||
|
|
||||||
add_start = time.time()
|
|
||||||
for i in range(embeddings.shape[0]):
|
|
||||||
index.add(1, faiss.swig_ptr(embeddings[i : i + 1]))
|
|
||||||
add_elapsed = time.time() - add_start
|
|
||||||
if add_timeout > 0:
|
|
||||||
signal.alarm(0)
|
|
||||||
faiss.write_index(index, str(index_file))
|
|
||||||
finally:
|
|
||||||
server_manager.stop_server()
|
|
||||||
|
|
||||||
except TimeoutError:
|
|
||||||
raise
|
|
||||||
except Exception:
|
|
||||||
if passages_file.exists():
|
|
||||||
with open(passages_file, "rb+") as f:
|
|
||||||
f.truncate(rollback_size)
|
|
||||||
with open(offset_file, "wb") as f:
|
|
||||||
pickle.dump(offset_map_backup, f)
|
|
||||||
raise
|
|
||||||
|
|
||||||
prune_hnsw_embeddings_inplace(str(index_file))
|
|
||||||
|
|
||||||
meta["total_passages"] = len(offset_map)
|
|
||||||
with open(meta_path, "w", encoding="utf-8") as f:
|
|
||||||
json.dump(meta, f, indent=2)
|
|
||||||
|
|
||||||
# Reset toggles so the index on disk returns to baseline behaviour.
|
|
||||||
try:
|
|
||||||
index.hnsw.set_disable_rng_during_add(False)
|
|
||||||
index.hnsw.set_disable_reverse_prune(False)
|
|
||||||
except AttributeError:
|
|
||||||
pass
|
|
||||||
faiss.write_index(index, str(index_file))
|
|
||||||
|
|
||||||
total_elapsed = time.time() - total_start
|
|
||||||
|
|
||||||
return total_elapsed, add_elapsed
|
|
||||||
|
|
||||||
|
|
||||||
def _total_zmq_nodes(log_path: Path) -> int:
|
|
||||||
if not log_path.exists():
|
|
||||||
return 0
|
|
||||||
with log_path.open("r", encoding="utf-8") as log_file:
|
|
||||||
text = log_file.read()
|
|
||||||
return sum(int(match) for match in re.findall(r"ZMQ received (\d+) node IDs", text))
|
|
||||||
|
|
||||||
|
|
||||||
def _warmup_embedding_server(port: int) -> None:
|
|
||||||
"""Send a dummy REQ so the embedding server loads its model."""
|
|
||||||
ctx = zmq.Context()
|
|
||||||
try:
|
|
||||||
sock = ctx.socket(zmq.REQ)
|
|
||||||
sock.setsockopt(zmq.LINGER, 0)
|
|
||||||
sock.setsockopt(zmq.RCVTIMEO, 5000)
|
|
||||||
sock.setsockopt(zmq.SNDTIMEO, 5000)
|
|
||||||
sock.connect(f"tcp://127.0.0.1:{port}")
|
|
||||||
payload = msgpack.packb(["__WARMUP__"], use_bin_type=True)
|
|
||||||
sock.send(payload)
|
|
||||||
try:
|
|
||||||
sock.recv()
|
|
||||||
except zmq.error.Again:
|
|
||||||
pass
|
|
||||||
finally:
|
|
||||||
sock.close()
|
|
||||||
ctx.term()
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
|
||||||
parser = argparse.ArgumentParser(description=__doc__)
|
|
||||||
parser.add_argument(
|
|
||||||
"--index-path",
|
|
||||||
type=Path,
|
|
||||||
default=Path(".leann/bench/leann-demo.leann"),
|
|
||||||
help="Output index base path (without extension).",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--initial-files",
|
|
||||||
nargs="*",
|
|
||||||
type=Path,
|
|
||||||
default=DEFAULT_INITIAL_FILES,
|
|
||||||
help="Files used to build the initial index.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--update-files",
|
|
||||||
nargs="*",
|
|
||||||
type=Path,
|
|
||||||
default=DEFAULT_UPDATE_FILES,
|
|
||||||
help="Files appended during the benchmark.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--runs", type=int, default=1, help="How many times to repeat each scenario."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--model-name",
|
|
||||||
default="sentence-transformers/all-MiniLM-L6-v2",
|
|
||||||
help="Embedding model used for build/update.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--embedding-mode",
|
|
||||||
default="sentence-transformers",
|
|
||||||
help="Embedding mode passed to LeannBuilder/embedding server.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--distance-metric",
|
|
||||||
default="mips",
|
|
||||||
choices=["mips", "l2", "cosine"],
|
|
||||||
help="Distance metric for HNSW backend.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--ef-construction",
|
|
||||||
type=int,
|
|
||||||
default=200,
|
|
||||||
help="efConstruction setting for initial build.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--server-port",
|
|
||||||
type=int,
|
|
||||||
default=5557,
|
|
||||||
help="Port for the real embedding server.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--max-initial",
|
|
||||||
type=int,
|
|
||||||
default=300,
|
|
||||||
help="Optional cap on initial passages (after chunking).",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--max-updates",
|
|
||||||
type=int,
|
|
||||||
default=1,
|
|
||||||
help="Optional cap on update passages (after chunking).",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--add-timeout",
|
|
||||||
type=int,
|
|
||||||
default=900,
|
|
||||||
help="Timeout in seconds for the incremental add loop (0 = no timeout).",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--plot-path",
|
|
||||||
type=Path,
|
|
||||||
default=Path("bench_latency.png"),
|
|
||||||
help="Where to save the latency bar plot.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--cap-y",
|
|
||||||
type=float,
|
|
||||||
default=None,
|
|
||||||
help="Cap Y-axis (ms). Bars above are hatched and annotated.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--broken-y",
|
|
||||||
action="store_true",
|
|
||||||
help="Use broken Y-axis (two stacked axes with gap). Overrides --cap-y unless both provided.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--lower-cap-y",
|
|
||||||
type=float,
|
|
||||||
default=None,
|
|
||||||
help="Lower axes upper bound for broken Y (ms). Default=1.1x second-highest.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--upper-start-y",
|
|
||||||
type=float,
|
|
||||||
default=None,
|
|
||||||
help="Upper axes lower bound for broken Y (ms). Default=1.2x second-highest.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--csv-path",
|
|
||||||
type=Path,
|
|
||||||
default=Path("benchmarks/update/bench_results.csv"),
|
|
||||||
help="Where to append per-scenario results as CSV.",
|
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
register_project_directory(REPO_ROOT)
|
|
||||||
|
|
||||||
initial_paragraphs = load_chunks_from_files(args.initial_files, args.max_initial)
|
|
||||||
update_paragraphs = load_chunks_from_files(args.update_files, args.max_updates)
|
|
||||||
if not update_paragraphs:
|
|
||||||
raise ValueError("No update passages found; please provide --update-files with content.")
|
|
||||||
|
|
||||||
update_chunks = prepare_new_chunks(update_paragraphs)
|
|
||||||
ensure_index_dir(args.index_path)
|
|
||||||
|
|
||||||
scenarios = [
|
|
||||||
("baseline", False, False, True),
|
|
||||||
("no_cache_baseline", False, False, False),
|
|
||||||
("disable_forward_rng", True, False, True),
|
|
||||||
("disable_forward_and_reverse_rng", True, True, True),
|
|
||||||
]
|
|
||||||
|
|
||||||
log_path = Path(os.environ.get("LEANN_HNSW_LOG_PATH", DEFAULT_HNSW_LOG))
|
|
||||||
log_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
os.environ["LEANN_HNSW_LOG_PATH"] = str(log_path.resolve())
|
|
||||||
os.environ.setdefault("LEANN_LOG_LEVEL", "INFO")
|
|
||||||
|
|
||||||
results_total: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
|
|
||||||
results_add: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
|
|
||||||
results_zmq: dict[str, list[int]] = {name: [] for name, *_ in scenarios}
|
|
||||||
results_stageA: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
|
|
||||||
results_stageBC: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
|
|
||||||
results_ms_per_passage: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
|
|
||||||
|
|
||||||
# CSV setup
|
|
||||||
import csv
|
|
||||||
|
|
||||||
run_id = time.strftime("%Y%m%d-%H%M%S")
|
|
||||||
csv_fields = [
|
|
||||||
"run_id",
|
|
||||||
"scenario",
|
|
||||||
"cache_enabled",
|
|
||||||
"ef_construction",
|
|
||||||
"max_initial",
|
|
||||||
"max_updates",
|
|
||||||
"total_time_s",
|
|
||||||
"add_only_s",
|
|
||||||
"latency_ms_per_passage",
|
|
||||||
"zmq_nodes",
|
|
||||||
"stageA_time_s",
|
|
||||||
"stageBC_time_s",
|
|
||||||
"model_name",
|
|
||||||
"embedding_mode",
|
|
||||||
"distance_metric",
|
|
||||||
]
|
|
||||||
# Create CSV with header if missing
|
|
||||||
if args.csv_path:
|
|
||||||
args.csv_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
if not args.csv_path.exists() or args.csv_path.stat().st_size == 0:
|
|
||||||
with args.csv_path.open("w", newline="", encoding="utf-8") as f:
|
|
||||||
writer = csv.DictWriter(f, fieldnames=csv_fields)
|
|
||||||
writer.writeheader()
|
|
||||||
|
|
||||||
for run in range(args.runs):
|
|
||||||
print(f"\n=== Benchmark run {run + 1}/{args.runs} ===")
|
|
||||||
for name, disable_forward, disable_reverse, cache_enabled in scenarios:
|
|
||||||
print(f"\nScenario: {name}")
|
|
||||||
cleanup_index_files(args.index_path)
|
|
||||||
if log_path.exists():
|
|
||||||
try:
|
|
||||||
log_path.unlink()
|
|
||||||
except OSError:
|
|
||||||
pass
|
|
||||||
os.environ["LEANN_ZMQ_EMBED_CACHE"] = "1" if cache_enabled else "0"
|
|
||||||
build_initial_index(
|
|
||||||
args.index_path,
|
|
||||||
initial_paragraphs,
|
|
||||||
args.model_name,
|
|
||||||
args.embedding_mode,
|
|
||||||
args.distance_metric,
|
|
||||||
args.ef_construction,
|
|
||||||
)
|
|
||||||
|
|
||||||
prev_size = log_path.stat().st_size if log_path.exists() else 0
|
|
||||||
|
|
||||||
try:
|
|
||||||
total_elapsed, add_elapsed = benchmark_update_with_mode(
|
|
||||||
args.index_path,
|
|
||||||
update_chunks,
|
|
||||||
args.model_name,
|
|
||||||
args.embedding_mode,
|
|
||||||
args.distance_metric,
|
|
||||||
disable_forward,
|
|
||||||
disable_reverse,
|
|
||||||
args.server_port,
|
|
||||||
args.add_timeout,
|
|
||||||
args.ef_construction,
|
|
||||||
)
|
|
||||||
except TimeoutError as exc:
|
|
||||||
print(f"Scenario {name} timed out: {exc}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
curr_size = log_path.stat().st_size if log_path.exists() else 0
|
|
||||||
if curr_size < prev_size:
|
|
||||||
prev_size = 0
|
|
||||||
zmq_count = 0
|
|
||||||
if log_path.exists():
|
|
||||||
with log_path.open("r", encoding="utf-8") as log_file:
|
|
||||||
log_file.seek(prev_size)
|
|
||||||
new_entries = log_file.read()
|
|
||||||
zmq_count = sum(
|
|
||||||
int(match) for match in re.findall(r"ZMQ received (\d+) node IDs", new_entries)
|
|
||||||
)
|
|
||||||
stageA = sum(
|
|
||||||
float(x)
|
|
||||||
for x in re.findall(r"Distance calculation E2E time: ([0-9.]+)s", new_entries)
|
|
||||||
)
|
|
||||||
stageBC = sum(
|
|
||||||
float(x) for x in re.findall(r"ZMQ E2E time: ([0-9.]+)s", new_entries)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
stageA = 0.0
|
|
||||||
stageBC = 0.0
|
|
||||||
|
|
||||||
per_chunk = add_elapsed / len(update_chunks)
|
|
||||||
print(
|
|
||||||
f"Total time: {total_elapsed:.3f} s | add-only: {add_elapsed:.3f} s "
|
|
||||||
f"for {len(update_chunks)} passages => {per_chunk * 1e3:.3f} ms/passage"
|
|
||||||
)
|
|
||||||
print(f"ZMQ node fetch total: {zmq_count}")
|
|
||||||
results_total[name].append(total_elapsed)
|
|
||||||
results_add[name].append(add_elapsed)
|
|
||||||
results_zmq[name].append(zmq_count)
|
|
||||||
results_ms_per_passage[name].append(per_chunk * 1e3)
|
|
||||||
results_stageA[name].append(stageA)
|
|
||||||
results_stageBC[name].append(stageBC)
|
|
||||||
|
|
||||||
# Append row to CSV
|
|
||||||
if args.csv_path:
|
|
||||||
row = {
|
|
||||||
"run_id": run_id,
|
|
||||||
"scenario": name,
|
|
||||||
"cache_enabled": 1 if cache_enabled else 0,
|
|
||||||
"ef_construction": args.ef_construction,
|
|
||||||
"max_initial": args.max_initial,
|
|
||||||
"max_updates": args.max_updates,
|
|
||||||
"total_time_s": round(total_elapsed, 6),
|
|
||||||
"add_only_s": round(add_elapsed, 6),
|
|
||||||
"latency_ms_per_passage": round(per_chunk * 1e3, 6),
|
|
||||||
"zmq_nodes": int(zmq_count),
|
|
||||||
"stageA_time_s": round(stageA, 6),
|
|
||||||
"stageBC_time_s": round(stageBC, 6),
|
|
||||||
"model_name": args.model_name,
|
|
||||||
"embedding_mode": args.embedding_mode,
|
|
||||||
"distance_metric": args.distance_metric,
|
|
||||||
}
|
|
||||||
with args.csv_path.open("a", newline="", encoding="utf-8") as f:
|
|
||||||
writer = csv.DictWriter(f, fieldnames=csv_fields)
|
|
||||||
writer.writerow(row)
|
|
||||||
|
|
||||||
print("\n=== Summary ===")
|
|
||||||
for name in results_add:
|
|
||||||
add_values = results_add[name]
|
|
||||||
total_values = results_total[name]
|
|
||||||
zmq_values = results_zmq[name]
|
|
||||||
latency_values = results_ms_per_passage[name]
|
|
||||||
if not add_values:
|
|
||||||
print(f"{name}: no successful runs")
|
|
||||||
continue
|
|
||||||
avg_add = sum(add_values) / len(add_values)
|
|
||||||
avg_total = sum(total_values) / len(total_values)
|
|
||||||
avg_zmq = sum(zmq_values) / len(zmq_values) if zmq_values else 0.0
|
|
||||||
avg_latency = sum(latency_values) / len(latency_values) if latency_values else 0.0
|
|
||||||
runs = len(add_values)
|
|
||||||
print(
|
|
||||||
f"{name}: add-only avg {avg_add:.3f} s | total avg {avg_total:.3f} s "
|
|
||||||
f"| ZMQ avg {avg_zmq:.1f} node fetches | latency {avg_latency:.2f} ms/passage over {runs} run(s)"
|
|
||||||
)
|
|
||||||
|
|
||||||
if args.plot_path:
|
|
||||||
try:
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
labels = [name for name, *_ in scenarios]
|
|
||||||
values = [
|
|
||||||
sum(results_ms_per_passage[name]) / len(results_ms_per_passage[name])
|
|
||||||
if results_ms_per_passage[name]
|
|
||||||
else 0.0
|
|
||||||
for name in labels
|
|
||||||
]
|
|
||||||
|
|
||||||
def _auto_cap(vals: list[float]) -> float | None:
|
|
||||||
s = sorted(vals, reverse=True)
|
|
||||||
if len(s) < 2:
|
|
||||||
return None
|
|
||||||
if s[1] > 0 and s[0] >= 2.5 * s[1]:
|
|
||||||
return s[1] * 1.1
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _fmt_ms(v: float) -> str:
|
|
||||||
return f"{v / 1000:.1f}k" if v >= 1000 else f"{v:.1f}"
|
|
||||||
|
|
||||||
colors = ["#4e79a7", "#f28e2c", "#e15759", "#76b7b2"]
|
|
||||||
|
|
||||||
if args.broken_y:
|
|
||||||
s = sorted(values, reverse=True)
|
|
||||||
second = s[1] if len(s) >= 2 else (s[0] if s else 0.0)
|
|
||||||
lower_cap = args.lower_cap_y if args.lower_cap_y is not None else second * 1.1
|
|
||||||
upper_start = (
|
|
||||||
args.upper_start_y
|
|
||||||
if args.upper_start_y is not None
|
|
||||||
else max(second * 1.2, lower_cap * 1.02)
|
|
||||||
)
|
|
||||||
ymax = max(values) * 1.10 if values else 1.0
|
|
||||||
fig, (ax_top, ax_bottom) = plt.subplots(
|
|
||||||
2,
|
|
||||||
1,
|
|
||||||
sharex=True,
|
|
||||||
figsize=(7.4, 5.0),
|
|
||||||
gridspec_kw={"height_ratios": [1, 3], "hspace": 0.05},
|
|
||||||
)
|
|
||||||
x = list(range(len(labels)))
|
|
||||||
ax_bottom.bar(x, values, color=colors[: len(labels)], width=0.8)
|
|
||||||
ax_top.bar(x, values, color=colors[: len(labels)], width=0.8)
|
|
||||||
ax_bottom.set_ylim(0, lower_cap)
|
|
||||||
ax_top.set_ylim(upper_start, ymax)
|
|
||||||
for i, v in enumerate(values):
|
|
||||||
if v <= lower_cap:
|
|
||||||
ax_bottom.text(
|
|
||||||
i,
|
|
||||||
v + lower_cap * 0.02,
|
|
||||||
_fmt_ms(v),
|
|
||||||
ha="center",
|
|
||||||
va="bottom",
|
|
||||||
fontsize=9,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
ax_top.text(i, v, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
|
|
||||||
ax_top.spines["bottom"].set_visible(False)
|
|
||||||
ax_bottom.spines["top"].set_visible(False)
|
|
||||||
ax_top.tick_params(labeltop=False)
|
|
||||||
ax_bottom.xaxis.tick_bottom()
|
|
||||||
d = 0.015
|
|
||||||
kwargs = {"transform": ax_top.transAxes, "color": "k", "clip_on": False}
|
|
||||||
ax_top.plot((-d, +d), (-d, +d), **kwargs)
|
|
||||||
ax_top.plot((1 - d, 1 + d), (-d, +d), **kwargs)
|
|
||||||
kwargs.update({"transform": ax_bottom.transAxes})
|
|
||||||
ax_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs)
|
|
||||||
ax_bottom.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs)
|
|
||||||
ax_bottom.set_xticks(range(len(labels)))
|
|
||||||
ax_bottom.set_xticklabels(labels)
|
|
||||||
ax = ax_bottom
|
|
||||||
else:
|
|
||||||
cap = args.cap_y or _auto_cap(values)
|
|
||||||
plt.figure(figsize=(7.2, 4.2))
|
|
||||||
ax = plt.gca()
|
|
||||||
if cap is not None:
|
|
||||||
show_vals = [min(v, cap) for v in values]
|
|
||||||
bars = []
|
|
||||||
for i, (v, show) in enumerate(zip(values, show_vals)):
|
|
||||||
b = ax.bar(i, show, color=colors[i], width=0.8)
|
|
||||||
bars.append(b[0])
|
|
||||||
if v > cap:
|
|
||||||
bars[-1].set_hatch("//")
|
|
||||||
ax.text(i, cap * 1.02, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
|
|
||||||
else:
|
|
||||||
ax.text(
|
|
||||||
i,
|
|
||||||
show + max(1.0, 0.01 * (cap or show)),
|
|
||||||
_fmt_ms(v),
|
|
||||||
ha="center",
|
|
||||||
va="bottom",
|
|
||||||
fontsize=9,
|
|
||||||
)
|
|
||||||
ax.set_ylim(0, cap * 1.10)
|
|
||||||
ax.plot(
|
|
||||||
[0.02 - 0.02, 0.02 + 0.02],
|
|
||||||
[0.98 + 0.02, 0.98 - 0.02],
|
|
||||||
transform=ax.transAxes,
|
|
||||||
color="k",
|
|
||||||
lw=1,
|
|
||||||
)
|
|
||||||
ax.plot(
|
|
||||||
[0.98 - 0.02, 0.98 + 0.02],
|
|
||||||
[0.98 + 0.02, 0.98 - 0.02],
|
|
||||||
transform=ax.transAxes,
|
|
||||||
color="k",
|
|
||||||
lw=1,
|
|
||||||
)
|
|
||||||
if any(v > cap for v in values):
|
|
||||||
ax.legend(
|
|
||||||
[bars[0]], ["capped"], fontsize=8, frameon=False, loc="upper right"
|
|
||||||
)
|
|
||||||
ax.set_xticks(range(len(labels)))
|
|
||||||
ax.set_xticklabels(labels)
|
|
||||||
else:
|
|
||||||
ax.bar(labels, values, color=colors[: len(labels)])
|
|
||||||
for idx, val in enumerate(values):
|
|
||||||
ax.text(idx, val + 1.0, f"{val:.1f}", ha="center", va="bottom")
|
|
||||||
|
|
||||||
plt.ylabel("Average add latency (ms per passage)")
|
|
||||||
plt.title(f"Initial passages {args.max_initial}, updates {args.max_updates}")
|
|
||||||
plt.tight_layout()
|
|
||||||
plt.savefig(args.plot_path)
|
|
||||||
print(f"Saved latency bar plot to {args.plot_path}")
|
|
||||||
# ZMQ time split (Stage A vs B/C)
|
|
||||||
try:
|
|
||||||
plt.figure(figsize=(6, 4))
|
|
||||||
a_vals = [sum(results_stageA[n]) / max(1, len(results_stageA[n])) for n in labels]
|
|
||||||
bc_vals = [
|
|
||||||
sum(results_stageBC[n]) / max(1, len(results_stageBC[n])) for n in labels
|
|
||||||
]
|
|
||||||
ind = range(len(labels))
|
|
||||||
plt.bar(ind, a_vals, color="#4e79a7", label="Stage A distance (s)")
|
|
||||||
plt.bar(
|
|
||||||
ind, bc_vals, bottom=a_vals, color="#e15759", label="Stage B/C embed-by-id (s)"
|
|
||||||
)
|
|
||||||
plt.xticks(list(ind), labels, rotation=10)
|
|
||||||
plt.ylabel("Server ZMQ time (s)")
|
|
||||||
plt.title(
|
|
||||||
f"ZMQ time split (initial {args.max_initial}, updates {args.max_updates})"
|
|
||||||
)
|
|
||||||
plt.legend()
|
|
||||||
out2 = args.plot_path.with_name(
|
|
||||||
args.plot_path.stem + "_zmq_split" + args.plot_path.suffix
|
|
||||||
)
|
|
||||||
plt.tight_layout()
|
|
||||||
plt.savefig(out2)
|
|
||||||
print(f"Saved ZMQ time split plot to {out2}")
|
|
||||||
except Exception as e:
|
|
||||||
print("Failed to plot ZMQ split:", e)
|
|
||||||
except ImportError:
|
|
||||||
print("matplotlib not available; skipping plot generation")
|
|
||||||
|
|
||||||
# leave the last build on disk for inspection
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
run_id,scenario,cache_enabled,ef_construction,max_initial,max_updates,total_time_s,add_only_s,latency_ms_per_passage,zmq_nodes,stageA_time_s,stageBC_time_s,model_name,embedding_mode,distance_metric
|
|
||||||
20251024-133101,baseline,1,200,300,1,3.391856,1.120359,1120.359421,126,0.507821,0.601608,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
|
||||||
20251024-133101,no_cache_baseline,0,200,300,1,34.941514,32.91376,32913.760185,4033,0.506933,32.159928,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
|
||||||
20251024-133101,disable_forward_rng,1,200,300,1,2.746756,0.8202,820.200443,66,0.474354,0.338454,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
|
||||||
20251024-133101,disable_forward_and_reverse_rng,1,200,300,1,2.396566,0.521478,521.478415,1,0.508973,0.006938,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
|
||||||
|
@@ -1,704 +0,0 @@
|
|||||||
"""
|
|
||||||
Compare two latency models for small incremental updates vs. search:
|
|
||||||
|
|
||||||
Scenario A (sequential update then search):
|
|
||||||
- Build initial HNSW (is_recompute=True)
|
|
||||||
- Start embedding server (ZMQ) for recompute
|
|
||||||
- Add N passages one-by-one (each triggers recompute over ZMQ)
|
|
||||||
- Then run a search query on the updated index
|
|
||||||
- Report total time = sum(add_i) + search_time, with breakdowns
|
|
||||||
|
|
||||||
Scenario B (offline embeds + concurrent search; no graph updates):
|
|
||||||
- Do NOT insert the N passages into the graph
|
|
||||||
- In parallel: (1) compute embeddings for the N passages; (2) compute query
|
|
||||||
embedding and run a search on the existing index
|
|
||||||
- After both finish, compute similarity between the query embedding and the N
|
|
||||||
new passage embeddings, merge with the index search results by score, and
|
|
||||||
report time = max(embed_time, search_time) (i.e., no blocking on updates)
|
|
||||||
|
|
||||||
This script reuses the model/data loading conventions of
|
|
||||||
examples/bench_hnsw_rng_recompute.py but focuses on end-to-end latency
|
|
||||||
comparison for the two execution strategies above.
|
|
||||||
|
|
||||||
Example (from the repository root):
|
|
||||||
uv run -m benchmarks.update.bench_update_vs_offline_search \
|
|
||||||
--index-path .leann/bench/offline_vs_update.leann \
|
|
||||||
--max-initial 300 --num-updates 5 --k 10
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import csv
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import pickle
|
|
||||||
import sys
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import psutil # type: ignore
|
|
||||||
from leann.api import LeannBuilder
|
|
||||||
|
|
||||||
if os.environ.get("LEANN_FORCE_CPU", "").lower() in ("1", "true", "yes"):
|
|
||||||
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
|
|
||||||
|
|
||||||
from leann.embedding_compute import compute_embeddings
|
|
||||||
from leann.embedding_server_manager import EmbeddingServerManager
|
|
||||||
from leann.registry import register_project_directory
|
|
||||||
from leann_backend_hnsw import faiss # type: ignore
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
if not logging.getLogger().handlers:
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
|
|
||||||
|
|
||||||
def _find_repo_root() -> Path:
|
|
||||||
"""Locate project root by walking up until pyproject.toml is found."""
|
|
||||||
current = Path(__file__).resolve()
|
|
||||||
for parent in current.parents:
|
|
||||||
if (parent / "pyproject.toml").exists():
|
|
||||||
return parent
|
|
||||||
# Fallback: assume repo is two levels up (../..)
|
|
||||||
return current.parents[2]
|
|
||||||
|
|
||||||
|
|
||||||
REPO_ROOT = _find_repo_root()
|
|
||||||
if str(REPO_ROOT) not in sys.path:
|
|
||||||
sys.path.insert(0, str(REPO_ROOT))
|
|
||||||
|
|
||||||
from apps.chunking import create_text_chunks # noqa: E402
|
|
||||||
|
|
||||||
DEFAULT_INITIAL_FILES = [
|
|
||||||
REPO_ROOT / "data" / "2501.14312v1 (1).pdf",
|
|
||||||
REPO_ROOT / "data" / "huawei_pangu.md",
|
|
||||||
]
|
|
||||||
DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"]
|
|
||||||
|
|
||||||
|
|
||||||
def load_chunks_from_files(paths: list[Path], limit: int | None = None) -> list[str]:
|
|
||||||
from llama_index.core import SimpleDirectoryReader
|
|
||||||
|
|
||||||
documents = []
|
|
||||||
for path in paths:
|
|
||||||
p = path.expanduser().resolve()
|
|
||||||
if not p.exists():
|
|
||||||
raise FileNotFoundError(f"Input path not found: {p}")
|
|
||||||
if p.is_dir():
|
|
||||||
reader = SimpleDirectoryReader(str(p), recursive=False)
|
|
||||||
documents.extend(reader.load_data(show_progress=True))
|
|
||||||
else:
|
|
||||||
reader = SimpleDirectoryReader(input_files=[str(p)])
|
|
||||||
documents.extend(reader.load_data(show_progress=True))
|
|
||||||
|
|
||||||
if not documents:
|
|
||||||
return []
|
|
||||||
|
|
||||||
chunks = create_text_chunks(
|
|
||||||
documents,
|
|
||||||
chunk_size=512,
|
|
||||||
chunk_overlap=128,
|
|
||||||
use_ast_chunking=False,
|
|
||||||
)
|
|
||||||
cleaned = [c for c in chunks if isinstance(c, str) and c.strip()]
|
|
||||||
if limit is not None:
|
|
||||||
cleaned = cleaned[:limit]
|
|
||||||
return cleaned
|
|
||||||
|
|
||||||
|
|
||||||
def ensure_index_dir(index_path: Path) -> None:
|
|
||||||
index_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
|
|
||||||
def cleanup_index_files(index_path: Path) -> None:
|
|
||||||
parent = index_path.parent
|
|
||||||
if not parent.exists():
|
|
||||||
return
|
|
||||||
stem = index_path.stem
|
|
||||||
for file in parent.glob(f"{stem}*"):
|
|
||||||
if file.is_file():
|
|
||||||
file.unlink()
|
|
||||||
|
|
||||||
|
|
||||||
def build_initial_index(
|
|
||||||
index_path: Path,
|
|
||||||
paragraphs: list[str],
|
|
||||||
model_name: str,
|
|
||||||
embedding_mode: str,
|
|
||||||
distance_metric: str,
|
|
||||||
ef_construction: int,
|
|
||||||
) -> None:
|
|
||||||
builder = LeannBuilder(
|
|
||||||
backend_name="hnsw",
|
|
||||||
embedding_model=model_name,
|
|
||||||
embedding_mode=embedding_mode,
|
|
||||||
is_compact=False,
|
|
||||||
is_recompute=True,
|
|
||||||
distance_metric=distance_metric,
|
|
||||||
backend_kwargs={
|
|
||||||
"distance_metric": distance_metric,
|
|
||||||
"is_compact": False,
|
|
||||||
"is_recompute": True,
|
|
||||||
"efConstruction": ef_construction,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
for idx, passage in enumerate(paragraphs):
|
|
||||||
builder.add_text(passage, metadata={"id": str(idx)})
|
|
||||||
builder.build_index(str(index_path))
|
|
||||||
|
|
||||||
|
|
||||||
def _maybe_norm_cosine(vecs: np.ndarray, metric: str) -> np.ndarray:
|
|
||||||
if metric == "cosine":
|
|
||||||
vecs = np.ascontiguousarray(vecs, dtype=np.float32)
|
|
||||||
norms = np.linalg.norm(vecs, axis=1, keepdims=True)
|
|
||||||
norms[norms == 0] = 1
|
|
||||||
vecs = vecs / norms
|
|
||||||
return vecs
|
|
||||||
|
|
||||||
|
|
||||||
def _read_index_for_search(index_path: Path) -> Any:
|
|
||||||
index_file = index_path.parent / f"{index_path.stem}.index"
|
|
||||||
# Force-disable experimental disk cache when loading the index so that
|
|
||||||
# incremental benchmarks don't pick up stale top-degree bitmaps.
|
|
||||||
cfg = faiss.HNSWIndexConfig()
|
|
||||||
cfg.is_recompute = True
|
|
||||||
if hasattr(cfg, "disk_cache_ratio"):
|
|
||||||
cfg.disk_cache_ratio = 0.0
|
|
||||||
if hasattr(cfg, "external_storage_path"):
|
|
||||||
cfg.external_storage_path = None
|
|
||||||
io_flags = getattr(faiss, "IO_FLAG_MMAP", 0)
|
|
||||||
index = faiss.read_index(str(index_file), io_flags, cfg)
|
|
||||||
# ensure recompute mode persists after reload
|
|
||||||
try:
|
|
||||||
index.is_recompute = True
|
|
||||||
except AttributeError:
|
|
||||||
pass
|
|
||||||
try:
|
|
||||||
actual_ntotal = index.hnsw.levels.size()
|
|
||||||
except AttributeError:
|
|
||||||
actual_ntotal = index.ntotal
|
|
||||||
if actual_ntotal != index.ntotal:
|
|
||||||
print(
|
|
||||||
f"[bench_update_vs_offline_search] Correcting ntotal from {index.ntotal} to {actual_ntotal}",
|
|
||||||
flush=True,
|
|
||||||
)
|
|
||||||
index.ntotal = actual_ntotal
|
|
||||||
if getattr(index, "storage", None) is None:
|
|
||||||
if index.metric_type == faiss.METRIC_INNER_PRODUCT:
|
|
||||||
storage_index = faiss.IndexFlatIP(index.d)
|
|
||||||
else:
|
|
||||||
storage_index = faiss.IndexFlatL2(index.d)
|
|
||||||
index.storage = storage_index
|
|
||||||
index.own_fields = True
|
|
||||||
return index
|
|
||||||
|
|
||||||
|
|
||||||
def _append_passages_for_updates(
|
|
||||||
meta_path: Path,
|
|
||||||
start_id: int,
|
|
||||||
texts: list[str],
|
|
||||||
) -> list[str]:
|
|
||||||
"""Append update passages so the embedding server can serve recompute fetches."""
|
|
||||||
|
|
||||||
if not texts:
|
|
||||||
return []
|
|
||||||
|
|
||||||
index_dir = meta_path.parent
|
|
||||||
meta_name = meta_path.name
|
|
||||||
if not meta_name.endswith(".meta.json"):
|
|
||||||
raise ValueError(f"Unexpected meta filename: {meta_path}")
|
|
||||||
index_base = meta_name[: -len(".meta.json")]
|
|
||||||
|
|
||||||
passages_file = index_dir / f"{index_base}.passages.jsonl"
|
|
||||||
offsets_file = index_dir / f"{index_base}.passages.idx"
|
|
||||||
|
|
||||||
if not passages_file.exists() or not offsets_file.exists():
|
|
||||||
raise FileNotFoundError(
|
|
||||||
"Passage store missing; cannot register update passages for recompute mode."
|
|
||||||
)
|
|
||||||
|
|
||||||
with open(offsets_file, "rb") as f:
|
|
||||||
offset_map: dict[str, int] = pickle.load(f)
|
|
||||||
|
|
||||||
assigned_ids: list[str] = []
|
|
||||||
with open(passages_file, "a", encoding="utf-8") as f:
|
|
||||||
for i, text in enumerate(texts):
|
|
||||||
passage_id = str(start_id + i)
|
|
||||||
offset = f.tell()
|
|
||||||
json.dump({"id": passage_id, "text": text, "metadata": {}}, f, ensure_ascii=False)
|
|
||||||
f.write("\n")
|
|
||||||
offset_map[passage_id] = offset
|
|
||||||
assigned_ids.append(passage_id)
|
|
||||||
|
|
||||||
with open(offsets_file, "wb") as f:
|
|
||||||
pickle.dump(offset_map, f)
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(meta_path, encoding="utf-8") as f:
|
|
||||||
meta = json.load(f)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
meta = {}
|
|
||||||
meta["total_passages"] = len(offset_map)
|
|
||||||
with open(meta_path, "w", encoding="utf-8") as f:
|
|
||||||
json.dump(meta, f, indent=2)
|
|
||||||
|
|
||||||
return assigned_ids
|
|
||||||
|
|
||||||
|
|
||||||
def _search(index: Any, q: np.ndarray, k: int) -> tuple[np.ndarray, np.ndarray]:
|
|
||||||
q = np.ascontiguousarray(q, dtype=np.float32)
|
|
||||||
distances = np.zeros((1, k), dtype=np.float32)
|
|
||||||
indices = np.zeros((1, k), dtype=np.int64)
|
|
||||||
index.search(
|
|
||||||
1,
|
|
||||||
faiss.swig_ptr(q),
|
|
||||||
k,
|
|
||||||
faiss.swig_ptr(distances),
|
|
||||||
faiss.swig_ptr(indices),
|
|
||||||
)
|
|
||||||
return distances[0], indices[0]
|
|
||||||
|
|
||||||
|
|
||||||
def _score_for_metric(dist: float, metric: str) -> float:
|
|
||||||
# Convert FAISS distance to a "higher is better" score
|
|
||||||
if metric in ("mips", "cosine"):
|
|
||||||
return float(dist)
|
|
||||||
# l2 distance (smaller better) -> negative distance as score
|
|
||||||
return -float(dist)
|
|
||||||
|
|
||||||
|
|
||||||
def _merge_results(
|
|
||||||
index_results: tuple[np.ndarray, np.ndarray],
|
|
||||||
offline_scores: list[tuple[int, float]],
|
|
||||||
k: int,
|
|
||||||
metric: str,
|
|
||||||
) -> list[tuple[str, float]]:
|
|
||||||
distances, indices = index_results
|
|
||||||
merged: list[tuple[str, float]] = []
|
|
||||||
for distance, idx in zip(distances.tolist(), indices.tolist()):
|
|
||||||
merged.append((f"idx:{idx}", _score_for_metric(distance, metric)))
|
|
||||||
for j, s in offline_scores:
|
|
||||||
merged.append((f"offline:{j}", s))
|
|
||||||
merged.sort(key=lambda x: x[1], reverse=True)
|
|
||||||
return merged[:k]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ScenarioResult:
|
|
||||||
name: str
|
|
||||||
update_total_s: float
|
|
||||||
search_s: float
|
|
||||||
overall_s: float
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
|
||||||
parser = argparse.ArgumentParser(description=__doc__)
|
|
||||||
parser.add_argument(
|
|
||||||
"--index-path",
|
|
||||||
type=Path,
|
|
||||||
default=Path(".leann/bench/offline-vs-update.leann"),
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--initial-files",
|
|
||||||
nargs="*",
|
|
||||||
type=Path,
|
|
||||||
default=DEFAULT_INITIAL_FILES,
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--update-files",
|
|
||||||
nargs="*",
|
|
||||||
type=Path,
|
|
||||||
default=DEFAULT_UPDATE_FILES,
|
|
||||||
)
|
|
||||||
parser.add_argument("--max-initial", type=int, default=300)
|
|
||||||
parser.add_argument("--num-updates", type=int, default=5)
|
|
||||||
parser.add_argument("--k", type=int, default=10, help="Top-k for search/merge")
|
|
||||||
parser.add_argument(
|
|
||||||
"--query",
|
|
||||||
type=str,
|
|
||||||
default="neural network",
|
|
||||||
help="Query text used for the search benchmark.",
|
|
||||||
)
|
|
||||||
parser.add_argument("--server-port", type=int, default=5557)
|
|
||||||
parser.add_argument("--add-timeout", type=int, default=600)
|
|
||||||
parser.add_argument("--model-name", default="sentence-transformers/all-MiniLM-L6-v2")
|
|
||||||
parser.add_argument("--embedding-mode", default="sentence-transformers")
|
|
||||||
parser.add_argument(
|
|
||||||
"--distance-metric",
|
|
||||||
default="mips",
|
|
||||||
choices=["mips", "l2", "cosine"],
|
|
||||||
)
|
|
||||||
parser.add_argument("--ef-construction", type=int, default=200)
|
|
||||||
parser.add_argument(
|
|
||||||
"--only",
|
|
||||||
choices=["A", "B", "both"],
|
|
||||||
default="both",
|
|
||||||
help="Run only Scenario A, Scenario B, or both",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--csv-path",
|
|
||||||
type=Path,
|
|
||||||
default=Path("benchmarks/update/offline_vs_update.csv"),
|
|
||||||
help="Where to append results (CSV).",
|
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
register_project_directory(REPO_ROOT)
|
|
||||||
|
|
||||||
# Load data
|
|
||||||
initial_paragraphs = load_chunks_from_files(args.initial_files, args.max_initial)
|
|
||||||
update_paragraphs = load_chunks_from_files(args.update_files, None)
|
|
||||||
if not update_paragraphs:
|
|
||||||
raise ValueError("No update passages loaded from --update-files")
|
|
||||||
update_paragraphs = update_paragraphs[: args.num_updates]
|
|
||||||
if len(update_paragraphs) < args.num_updates:
|
|
||||||
raise ValueError(
|
|
||||||
f"Not enough update passages ({len(update_paragraphs)}) for --num-updates={args.num_updates}"
|
|
||||||
)
|
|
||||||
|
|
||||||
ensure_index_dir(args.index_path)
|
|
||||||
cleanup_index_files(args.index_path)
|
|
||||||
|
|
||||||
# Build initial index
|
|
||||||
build_initial_index(
|
|
||||||
args.index_path,
|
|
||||||
initial_paragraphs,
|
|
||||||
args.model_name,
|
|
||||||
args.embedding_mode,
|
|
||||||
args.distance_metric,
|
|
||||||
args.ef_construction,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Prepare index object and meta
|
|
||||||
meta_path = args.index_path.parent / f"{args.index_path.name}.meta.json"
|
|
||||||
index = _read_index_for_search(args.index_path)
|
|
||||||
|
|
||||||
# CSV setup
|
|
||||||
run_id = time.strftime("%Y%m%d-%H%M%S")
|
|
||||||
if args.csv_path:
|
|
||||||
args.csv_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
csv_fields = [
|
|
||||||
"run_id",
|
|
||||||
"scenario",
|
|
||||||
"max_initial",
|
|
||||||
"num_updates",
|
|
||||||
"k",
|
|
||||||
"total_time_s",
|
|
||||||
"add_total_s",
|
|
||||||
"search_time_s",
|
|
||||||
"emb_time_s",
|
|
||||||
"makespan_s",
|
|
||||||
"model_name",
|
|
||||||
"embedding_mode",
|
|
||||||
"distance_metric",
|
|
||||||
]
|
|
||||||
if not args.csv_path.exists() or args.csv_path.stat().st_size == 0:
|
|
||||||
with args.csv_path.open("w", newline="", encoding="utf-8") as f:
|
|
||||||
writer = csv.DictWriter(f, fieldnames=csv_fields)
|
|
||||||
writer.writeheader()
|
|
||||||
|
|
||||||
# Debug: list existing HNSW server PIDs before starting
|
|
||||||
try:
|
|
||||||
existing = [
|
|
||||||
p
|
|
||||||
for p in psutil.process_iter(attrs=["pid", "cmdline"])
|
|
||||||
if any(
|
|
||||||
isinstance(arg, str) and "leann_backend_hnsw.hnsw_embedding_server" in arg
|
|
||||||
for arg in (p.info.get("cmdline") or [])
|
|
||||||
)
|
|
||||||
]
|
|
||||||
if existing:
|
|
||||||
print("[debug] Found existing hnsw_embedding_server processes before run:")
|
|
||||||
for p in existing:
|
|
||||||
print(f"[debug] PID={p.info['pid']} cmd={' '.join(p.info.get('cmdline') or [])}")
|
|
||||||
except Exception as _e:
|
|
||||||
pass
|
|
||||||
|
|
||||||
add_total = 0.0
|
|
||||||
search_after_add = 0.0
|
|
||||||
total_seq = 0.0
|
|
||||||
port_a = None
|
|
||||||
if args.only in ("A", "both"):
|
|
||||||
# Scenario A: sequential update then search
|
|
||||||
start_id = index.ntotal
|
|
||||||
assigned_ids = _append_passages_for_updates(meta_path, start_id, update_paragraphs)
|
|
||||||
if assigned_ids:
|
|
||||||
logger.debug(
|
|
||||||
"Registered %d update passages starting at id %s",
|
|
||||||
len(assigned_ids),
|
|
||||||
assigned_ids[0],
|
|
||||||
)
|
|
||||||
server_manager = EmbeddingServerManager(
|
|
||||||
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
|
|
||||||
)
|
|
||||||
ok, port = server_manager.start_server(
|
|
||||||
port=args.server_port,
|
|
||||||
model_name=args.model_name,
|
|
||||||
embedding_mode=args.embedding_mode,
|
|
||||||
passages_file=str(meta_path),
|
|
||||||
distance_metric=args.distance_metric,
|
|
||||||
)
|
|
||||||
if not ok:
|
|
||||||
raise RuntimeError("Failed to start embedding server")
|
|
||||||
try:
|
|
||||||
# Set ZMQ port for recompute mode
|
|
||||||
if hasattr(index.hnsw, "set_zmq_port"):
|
|
||||||
index.hnsw.set_zmq_port(port)
|
|
||||||
elif hasattr(index, "set_zmq_port"):
|
|
||||||
index.set_zmq_port(port)
|
|
||||||
|
|
||||||
# Start A overall timer BEFORE computing update embeddings
|
|
||||||
t0 = time.time()
|
|
||||||
|
|
||||||
# Compute embeddings for updates (counted into A's overall)
|
|
||||||
t_emb0 = time.time()
|
|
||||||
upd_embs = compute_embeddings(
|
|
||||||
update_paragraphs,
|
|
||||||
args.model_name,
|
|
||||||
mode=args.embedding_mode,
|
|
||||||
is_build=False,
|
|
||||||
batch_size=16,
|
|
||||||
)
|
|
||||||
emb_time_updates = time.time() - t_emb0
|
|
||||||
upd_embs = np.asarray(upd_embs, dtype=np.float32)
|
|
||||||
upd_embs = _maybe_norm_cosine(upd_embs, args.distance_metric)
|
|
||||||
|
|
||||||
# Perform sequential adds
|
|
||||||
for i in range(upd_embs.shape[0]):
|
|
||||||
t_add0 = time.time()
|
|
||||||
index.add(1, faiss.swig_ptr(upd_embs[i : i + 1]))
|
|
||||||
add_total += time.time() - t_add0
|
|
||||||
# Don't persist index after adds to avoid contaminating Scenario B
|
|
||||||
# index_file = args.index_path.parent / f"{args.index_path.stem}.index"
|
|
||||||
# faiss.write_index(index, str(index_file))
|
|
||||||
|
|
||||||
# Search after updates
|
|
||||||
q_emb = compute_embeddings(
|
|
||||||
[args.query], args.model_name, mode=args.embedding_mode, is_build=False
|
|
||||||
)
|
|
||||||
q_emb = np.asarray(q_emb, dtype=np.float32)
|
|
||||||
q_emb = _maybe_norm_cosine(q_emb, args.distance_metric)
|
|
||||||
|
|
||||||
# Warm up search with a dummy query first
|
|
||||||
print("[DEBUG] Warming up search...")
|
|
||||||
_ = _search(index, q_emb, 1)
|
|
||||||
|
|
||||||
t_s0 = time.time()
|
|
||||||
D_upd, I_upd = _search(index, q_emb, args.k)
|
|
||||||
search_after_add = time.time() - t_s0
|
|
||||||
total_seq = time.time() - t0
|
|
||||||
finally:
|
|
||||||
server_manager.stop_server()
|
|
||||||
port_a = port
|
|
||||||
|
|
||||||
print("\n=== Scenario A: update->search (sequential) ===")
|
|
||||||
# emb_time_updates is defined only when A runs
|
|
||||||
try:
|
|
||||||
_emb_a = emb_time_updates
|
|
||||||
except NameError:
|
|
||||||
_emb_a = 0.0
|
|
||||||
print(
|
|
||||||
f"Adds: {args.num_updates} passages; embeds={_emb_a:.3f}s; add_total={add_total:.3f}s; "
|
|
||||||
f"search={search_after_add:.3f}s; overall={total_seq:.3f}s"
|
|
||||||
)
|
|
||||||
# CSV row for A
|
|
||||||
if args.csv_path:
|
|
||||||
row_a = {
|
|
||||||
"run_id": run_id,
|
|
||||||
"scenario": "A",
|
|
||||||
"max_initial": args.max_initial,
|
|
||||||
"num_updates": args.num_updates,
|
|
||||||
"k": args.k,
|
|
||||||
"total_time_s": round(total_seq, 6),
|
|
||||||
"add_total_s": round(add_total, 6),
|
|
||||||
"search_time_s": round(search_after_add, 6),
|
|
||||||
"emb_time_s": round(_emb_a, 6),
|
|
||||||
"makespan_s": 0.0,
|
|
||||||
"model_name": args.model_name,
|
|
||||||
"embedding_mode": args.embedding_mode,
|
|
||||||
"distance_metric": args.distance_metric,
|
|
||||||
}
|
|
||||||
with args.csv_path.open("a", newline="", encoding="utf-8") as f:
|
|
||||||
writer = csv.DictWriter(f, fieldnames=csv_fields)
|
|
||||||
writer.writerow(row_a)
|
|
||||||
|
|
||||||
# Verify server cleanup
|
|
||||||
try:
|
|
||||||
# short sleep to allow signal handling to finish
|
|
||||||
time.sleep(0.5)
|
|
||||||
leftovers = [
|
|
||||||
p
|
|
||||||
for p in psutil.process_iter(attrs=["pid", "cmdline"])
|
|
||||||
if any(
|
|
||||||
isinstance(arg, str) and "leann_backend_hnsw.hnsw_embedding_server" in arg
|
|
||||||
for arg in (p.info.get("cmdline") or [])
|
|
||||||
)
|
|
||||||
]
|
|
||||||
if leftovers:
|
|
||||||
print("[warn] hnsw_embedding_server process(es) still alive after A-stop:")
|
|
||||||
for p in leftovers:
|
|
||||||
print(
|
|
||||||
f"[warn] PID={p.info['pid']} cmd={' '.join(p.info.get('cmdline') or [])}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
print("[debug] server cleanup confirmed: no hnsw_embedding_server found")
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Scenario B: offline embeds + concurrent search (no graph updates)
|
|
||||||
if args.only in ("B", "both"):
|
|
||||||
# ensure a server is available for recompute search
|
|
||||||
server_manager_b = EmbeddingServerManager(
|
|
||||||
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
|
|
||||||
)
|
|
||||||
requested_port = args.server_port if port_a is None else port_a
|
|
||||||
ok_b, port_b = server_manager_b.start_server(
|
|
||||||
port=requested_port,
|
|
||||||
model_name=args.model_name,
|
|
||||||
embedding_mode=args.embedding_mode,
|
|
||||||
passages_file=str(meta_path),
|
|
||||||
distance_metric=args.distance_metric,
|
|
||||||
)
|
|
||||||
if not ok_b:
|
|
||||||
raise RuntimeError("Failed to start embedding server for Scenario B")
|
|
||||||
|
|
||||||
# Wait for server to fully initialize
|
|
||||||
print("[DEBUG] Waiting 2s for embedding server to fully initialize...")
|
|
||||||
time.sleep(2)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Read the index first
|
|
||||||
index_no_update = _read_index_for_search(args.index_path) # unchanged index
|
|
||||||
|
|
||||||
# Then configure ZMQ port on the correct index object
|
|
||||||
if hasattr(index_no_update.hnsw, "set_zmq_port"):
|
|
||||||
index_no_update.hnsw.set_zmq_port(port_b)
|
|
||||||
elif hasattr(index_no_update, "set_zmq_port"):
|
|
||||||
index_no_update.set_zmq_port(port_b)
|
|
||||||
|
|
||||||
# Warmup the embedding model before benchmarking (do this for both --only B and --only both)
|
|
||||||
# This ensures fair comparison as Scenario A has warmed up the model during update embeddings
|
|
||||||
logger.info("Warming up embedding model for Scenario B...")
|
|
||||||
_ = compute_embeddings(
|
|
||||||
["warmup text"], args.model_name, mode=args.embedding_mode, is_build=False
|
|
||||||
)
|
|
||||||
|
|
||||||
# Prepare worker A: compute embeddings for the same N passages
|
|
||||||
emb_time = 0.0
|
|
||||||
updates_embs_offline: np.ndarray | None = None
|
|
||||||
|
|
||||||
def _worker_emb():
|
|
||||||
nonlocal emb_time, updates_embs_offline
|
|
||||||
t = time.time()
|
|
||||||
updates_embs_offline = compute_embeddings(
|
|
||||||
update_paragraphs,
|
|
||||||
args.model_name,
|
|
||||||
mode=args.embedding_mode,
|
|
||||||
is_build=False,
|
|
||||||
batch_size=16,
|
|
||||||
)
|
|
||||||
emb_time = time.time() - t
|
|
||||||
|
|
||||||
# Pre-compute query embedding and warm up search outside of timed section.
|
|
||||||
q_vec = compute_embeddings(
|
|
||||||
[args.query], args.model_name, mode=args.embedding_mode, is_build=False
|
|
||||||
)
|
|
||||||
q_vec = np.asarray(q_vec, dtype=np.float32)
|
|
||||||
q_vec = _maybe_norm_cosine(q_vec, args.distance_metric)
|
|
||||||
print("[DEBUG B] Warming up search...")
|
|
||||||
_ = _search(index_no_update, q_vec, 1)
|
|
||||||
|
|
||||||
# Worker B: timed search on the warmed index
|
|
||||||
search_time = 0.0
|
|
||||||
offline_elapsed = 0.0
|
|
||||||
index_results: tuple[np.ndarray, np.ndarray] | None = None
|
|
||||||
|
|
||||||
def _worker_search():
|
|
||||||
nonlocal search_time, index_results
|
|
||||||
t = time.time()
|
|
||||||
distances, indices = _search(index_no_update, q_vec, args.k)
|
|
||||||
search_time = time.time() - t
|
|
||||||
index_results = (distances, indices)
|
|
||||||
|
|
||||||
# Run two workers concurrently
|
|
||||||
t0 = time.time()
|
|
||||||
th1 = threading.Thread(target=_worker_emb)
|
|
||||||
th2 = threading.Thread(target=_worker_search)
|
|
||||||
th1.start()
|
|
||||||
th2.start()
|
|
||||||
th1.join()
|
|
||||||
th2.join()
|
|
||||||
offline_elapsed = time.time() - t0
|
|
||||||
|
|
||||||
# For mixing: compute query vs. offline update similarities (pure client-side)
|
|
||||||
offline_scores: list[tuple[int, float]] = []
|
|
||||||
if updates_embs_offline is not None:
|
|
||||||
upd2 = np.asarray(updates_embs_offline, dtype=np.float32)
|
|
||||||
upd2 = _maybe_norm_cosine(upd2, args.distance_metric)
|
|
||||||
# For mips/cosine, score = dot; for l2, score = -||x-y||^2
|
|
||||||
for j in range(upd2.shape[0]):
|
|
||||||
if args.distance_metric in ("mips", "cosine"):
|
|
||||||
s = float(np.dot(q_vec[0], upd2[j]))
|
|
||||||
else:
|
|
||||||
diff = q_vec[0] - upd2[j]
|
|
||||||
s = -float(np.dot(diff, diff))
|
|
||||||
offline_scores.append((j, s))
|
|
||||||
|
|
||||||
merged_topk = (
|
|
||||||
_merge_results(index_results, offline_scores, args.k, args.distance_metric)
|
|
||||||
if index_results
|
|
||||||
else []
|
|
||||||
)
|
|
||||||
|
|
||||||
print("\n=== Scenario B: offline embeds + concurrent search (no add) ===")
|
|
||||||
print(
|
|
||||||
f"embeddings({args.num_updates})={emb_time:.3f}s; search={search_time:.3f}s; makespan≈{offline_elapsed:.3f}s (≈max)"
|
|
||||||
)
|
|
||||||
if merged_topk:
|
|
||||||
preview = ", ".join([f"{lab}:{score:.3f}" for lab, score in merged_topk[:5]])
|
|
||||||
print(f"Merged top-5 preview: {preview}")
|
|
||||||
# CSV row for B
|
|
||||||
if args.csv_path:
|
|
||||||
row_b = {
|
|
||||||
"run_id": run_id,
|
|
||||||
"scenario": "B",
|
|
||||||
"max_initial": args.max_initial,
|
|
||||||
"num_updates": args.num_updates,
|
|
||||||
"k": args.k,
|
|
||||||
"total_time_s": 0.0,
|
|
||||||
"add_total_s": 0.0,
|
|
||||||
"search_time_s": round(search_time, 6),
|
|
||||||
"emb_time_s": round(emb_time, 6),
|
|
||||||
"makespan_s": round(offline_elapsed, 6),
|
|
||||||
"model_name": args.model_name,
|
|
||||||
"embedding_mode": args.embedding_mode,
|
|
||||||
"distance_metric": args.distance_metric,
|
|
||||||
}
|
|
||||||
with args.csv_path.open("a", newline="", encoding="utf-8") as f:
|
|
||||||
writer = csv.DictWriter(f, fieldnames=csv_fields)
|
|
||||||
writer.writerow(row_b)
|
|
||||||
|
|
||||||
finally:
|
|
||||||
server_manager_b.stop_server()
|
|
||||||
|
|
||||||
# Summary
|
|
||||||
print("\n=== Summary ===")
|
|
||||||
msg_a = (
|
|
||||||
f"A: seq-add+search overall={total_seq:.3f}s (adds={add_total:.3f}s, search={search_after_add:.3f}s)"
|
|
||||||
if args.only in ("A", "both")
|
|
||||||
else "A: skipped"
|
|
||||||
)
|
|
||||||
msg_b = (
|
|
||||||
f"B: offline+concurrent overall≈{offline_elapsed:.3f}s (emb={emb_time:.3f}s, search={search_time:.3f}s)"
|
|
||||||
if args.only in ("B", "both")
|
|
||||||
else "B: skipped"
|
|
||||||
)
|
|
||||||
print(msg_a + "\n" + msg_b)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
run_id,scenario,max_initial,num_updates,k,total_time_s,add_total_s,search_time_s,emb_time_s,makespan_s,model_name,embedding_mode,distance_metric
|
|
||||||
20251024-141607,A,300,1,10,3.273957,3.050168,0.097825,0.017339,0.0,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
|
||||||
20251024-141607,B,300,1,10,0.0,0.0,0.111892,0.007869,0.112635,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
|
||||||
20251025-160652,A,300,5,10,5.061945,4.805962,0.123271,0.015008,0.0,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
|
||||||
20251025-160652,B,300,5,10,0.0,0.0,0.101809,0.008817,0.102447,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
|
||||||
|
@@ -1,645 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Plot latency bars from the benchmark CSV produced by
|
|
||||||
benchmarks/update/bench_hnsw_rng_recompute.py.
|
|
||||||
|
|
||||||
If you also provide an offline_vs_update.csv via --csv-right
|
|
||||||
(from benchmarks/update/bench_update_vs_offline_search.py), this script will
|
|
||||||
output a side-by-side figure:
|
|
||||||
- Left: ms/passage bars (four RNG scenarios).
|
|
||||||
- Right: seconds bars (Scenario A seq add+search vs Scenario B offline+search).
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
uv run python benchmarks/update/plot_bench_results.py \
|
|
||||||
--csv benchmarks/update/bench_results.csv \
|
|
||||||
--out benchmarks/update/bench_latency_from_csv.png
|
|
||||||
|
|
||||||
The script selects the latest run_id in the CSV and plots four bars for
|
|
||||||
the default scenarios:
|
|
||||||
- baseline
|
|
||||||
- no_cache_baseline
|
|
||||||
- disable_forward_rng
|
|
||||||
- disable_forward_and_reverse_rng
|
|
||||||
|
|
||||||
If multiple rows exist per scenario for that run_id, the script averages
|
|
||||||
their latency_ms_per_passage values.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import csv
|
|
||||||
from collections import defaultdict
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
DEFAULT_SCENARIOS = [
|
|
||||||
"no_cache_baseline",
|
|
||||||
"baseline",
|
|
||||||
"disable_forward_rng",
|
|
||||||
"disable_forward_and_reverse_rng",
|
|
||||||
]
|
|
||||||
|
|
||||||
SCENARIO_LABELS = {
|
|
||||||
"baseline": "+ Cache",
|
|
||||||
"no_cache_baseline": "Naive \n Recompute",
|
|
||||||
"disable_forward_rng": "+ w/o \n Fwd RNG",
|
|
||||||
"disable_forward_and_reverse_rng": "+ w/o \n Bwd RNG",
|
|
||||||
}
|
|
||||||
|
|
||||||
# Paper-style colors and hatches for scenarios
|
|
||||||
SCENARIO_STYLES = {
|
|
||||||
"no_cache_baseline": {"edgecolor": "dimgrey", "hatch": "/////"},
|
|
||||||
"baseline": {"edgecolor": "#63B8B6", "hatch": "xxxxx"},
|
|
||||||
"disable_forward_rng": {"edgecolor": "green", "hatch": "....."},
|
|
||||||
"disable_forward_and_reverse_rng": {"edgecolor": "tomato", "hatch": "\\\\\\\\\\"},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def load_latest_run(csv_path: Path):
|
|
||||||
rows = []
|
|
||||||
with csv_path.open("r", encoding="utf-8") as f:
|
|
||||||
reader = csv.DictReader(f)
|
|
||||||
for row in reader:
|
|
||||||
rows.append(row)
|
|
||||||
if not rows:
|
|
||||||
raise SystemExit("CSV is empty: no rows to plot")
|
|
||||||
# Choose latest run_id lexicographically (YYYYMMDD-HHMMSS)
|
|
||||||
run_ids = [r.get("run_id", "") for r in rows]
|
|
||||||
latest = max(run_ids)
|
|
||||||
latest_rows = [r for r in rows if r.get("run_id", "") == latest]
|
|
||||||
if not latest_rows:
|
|
||||||
# Fallback: take last 4 rows
|
|
||||||
latest_rows = rows[-4:]
|
|
||||||
latest = latest_rows[-1].get("run_id", "unknown")
|
|
||||||
return latest, latest_rows
|
|
||||||
|
|
||||||
|
|
||||||
def aggregate_latency(rows):
|
|
||||||
acc = defaultdict(list)
|
|
||||||
for r in rows:
|
|
||||||
sc = r.get("scenario", "")
|
|
||||||
try:
|
|
||||||
val = float(r.get("latency_ms_per_passage", "nan"))
|
|
||||||
except ValueError:
|
|
||||||
continue
|
|
||||||
acc[sc].append(val)
|
|
||||||
avg = {k: (sum(v) / len(v) if v else 0.0) for k, v in acc.items()}
|
|
||||||
return avg
|
|
||||||
|
|
||||||
|
|
||||||
def _auto_cap(values: list[float]) -> float | None:
|
|
||||||
if not values:
|
|
||||||
return None
|
|
||||||
sorted_vals = sorted(values, reverse=True)
|
|
||||||
if len(sorted_vals) < 2:
|
|
||||||
return None
|
|
||||||
max_v, second = sorted_vals[0], sorted_vals[1]
|
|
||||||
if second <= 0:
|
|
||||||
return None
|
|
||||||
# If the tallest bar dwarfs the second by 2.5x+, cap near the second
|
|
||||||
if max_v >= 2.5 * second:
|
|
||||||
return second * 1.1
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _add_break_marker(ax, y, rel_x0=0.02, rel_x1=0.98, size=0.02):
|
|
||||||
# Draw small diagonal ticks near left/right to signal cap
|
|
||||||
x0, x1 = rel_x0, rel_x1
|
|
||||||
ax.plot([x0 - size, x0 + size], [y + size, y - size], transform=ax.transAxes, color="k", lw=1)
|
|
||||||
ax.plot([x1 - size, x1 + size], [y + size, y - size], transform=ax.transAxes, color="k", lw=1)
|
|
||||||
|
|
||||||
|
|
||||||
def _fmt_ms(v: float) -> str:
|
|
||||||
if v >= 1000:
|
|
||||||
return f"{v / 1000:.1f}k"
|
|
||||||
return f"{v:.1f}"
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
# Set LaTeX style for paper figures (matching paper_fig.py)
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
plt.rcParams["font.family"] = "Helvetica"
|
|
||||||
plt.rcParams["ytick.direction"] = "in"
|
|
||||||
plt.rcParams["hatch.linewidth"] = 1.5
|
|
||||||
plt.rcParams["font.weight"] = "bold"
|
|
||||||
plt.rcParams["axes.labelweight"] = "bold"
|
|
||||||
plt.rcParams["text.usetex"] = True
|
|
||||||
|
|
||||||
ap = argparse.ArgumentParser(description=__doc__)
|
|
||||||
ap.add_argument(
|
|
||||||
"--csv",
|
|
||||||
type=Path,
|
|
||||||
default=Path("benchmarks/update/bench_results.csv"),
|
|
||||||
help="Path to results CSV (defaults to bench_results.csv)",
|
|
||||||
)
|
|
||||||
ap.add_argument(
|
|
||||||
"--out",
|
|
||||||
type=Path,
|
|
||||||
default=Path("add_ablation.pdf"),
|
|
||||||
help="Output image path",
|
|
||||||
)
|
|
||||||
ap.add_argument(
|
|
||||||
"--csv-right",
|
|
||||||
type=Path,
|
|
||||||
default=Path("benchmarks/update/offline_vs_update.csv"),
|
|
||||||
help="Optional: offline_vs_update.csv to render right subplot (A vs B)",
|
|
||||||
)
|
|
||||||
ap.add_argument(
|
|
||||||
"--cap-y",
|
|
||||||
type=float,
|
|
||||||
default=None,
|
|
||||||
help="Cap Y-axis at this ms value; bars above are hatched and annotated.",
|
|
||||||
)
|
|
||||||
ap.add_argument(
|
|
||||||
"--no-auto-cap",
|
|
||||||
action="store_true",
|
|
||||||
help="Disable auto-cap heuristic when --cap-y is not provided.",
|
|
||||||
)
|
|
||||||
ap.add_argument(
|
|
||||||
"--broken-y",
|
|
||||||
action="store_true",
|
|
||||||
default=True,
|
|
||||||
help="Use a broken Y-axis (two stacked axes with a gap). Overrides --cap-y unless both provided.",
|
|
||||||
)
|
|
||||||
ap.add_argument(
|
|
||||||
"--lower-cap-y",
|
|
||||||
type=float,
|
|
||||||
default=None,
|
|
||||||
help="Lower axes upper bound for broken Y (ms). Default = 1.1x second-highest.",
|
|
||||||
)
|
|
||||||
ap.add_argument(
|
|
||||||
"--upper-start-y",
|
|
||||||
type=float,
|
|
||||||
default=None,
|
|
||||||
help="Upper axes lower bound for broken Y (ms). Default = 1.2x second-highest.",
|
|
||||||
)
|
|
||||||
args = ap.parse_args()
|
|
||||||
|
|
||||||
latest_run, latest_rows = load_latest_run(args.csv)
|
|
||||||
avg = aggregate_latency(latest_rows)
|
|
||||||
|
|
||||||
try:
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
except Exception as e:
|
|
||||||
raise SystemExit(f"matplotlib not available: {e}")
|
|
||||||
|
|
||||||
scenarios = DEFAULT_SCENARIOS
|
|
||||||
values = [avg.get(name, 0.0) for name in scenarios]
|
|
||||||
labels = [SCENARIO_LABELS.get(name, name) for name in scenarios]
|
|
||||||
colors = ["#4e79a7", "#f28e2c", "#e15759", "#76b7b2"]
|
|
||||||
|
|
||||||
# If right CSV is provided, build side-by-side figure
|
|
||||||
if args.csv_right is not None:
|
|
||||||
try:
|
|
||||||
right_rows_all = []
|
|
||||||
with args.csv_right.open("r", encoding="utf-8") as f:
|
|
||||||
rreader = csv.DictReader(f)
|
|
||||||
right_rows_all = list(rreader)
|
|
||||||
if right_rows_all:
|
|
||||||
r_latest = max(r.get("run_id", "") for r in right_rows_all)
|
|
||||||
right_rows = [r for r in right_rows_all if r.get("run_id", "") == r_latest]
|
|
||||||
else:
|
|
||||||
r_latest = None
|
|
||||||
right_rows = []
|
|
||||||
except Exception:
|
|
||||||
r_latest = None
|
|
||||||
right_rows = []
|
|
||||||
|
|
||||||
a_total = 0.0
|
|
||||||
b_makespan = 0.0
|
|
||||||
for r in right_rows:
|
|
||||||
sc = (r.get("scenario", "") or "").strip().upper()
|
|
||||||
if sc == "A":
|
|
||||||
try:
|
|
||||||
a_total = float(r.get("total_time_s", 0.0))
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
elif sc == "B":
|
|
||||||
try:
|
|
||||||
b_makespan = float(r.get("makespan_s", 0.0))
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
from matplotlib import gridspec
|
|
||||||
|
|
||||||
# Left subplot (reuse current style, with optional cap)
|
|
||||||
cap = args.cap_y
|
|
||||||
if cap is None and not args.no_auto_cap:
|
|
||||||
cap = _auto_cap(values)
|
|
||||||
x = list(range(len(labels)))
|
|
||||||
|
|
||||||
if args.broken_y:
|
|
||||||
# Use broken axis for left subplot
|
|
||||||
# Auto-adjust width ratios: left has 4 bars, right has 2 bars
|
|
||||||
fig = plt.figure(figsize=(4.8, 1.8)) # Scaled down to 80%
|
|
||||||
gs = gridspec.GridSpec(
|
|
||||||
2, 2, height_ratios=[1, 3], width_ratios=[1.5, 1], hspace=0.08, wspace=0.35
|
|
||||||
)
|
|
||||||
ax_left_top = fig.add_subplot(gs[0, 0])
|
|
||||||
ax_left_bottom = fig.add_subplot(gs[1, 0], sharex=ax_left_top)
|
|
||||||
ax_right = fig.add_subplot(gs[:, 1])
|
|
||||||
|
|
||||||
# Determine break points
|
|
||||||
s = sorted(values, reverse=True)
|
|
||||||
second = s[1] if len(s) >= 2 else (s[0] if s else 0.0)
|
|
||||||
lower_cap = (
|
|
||||||
args.lower_cap_y if args.lower_cap_y is not None else second * 1.4
|
|
||||||
) # Increased to show more range
|
|
||||||
upper_start = (
|
|
||||||
args.upper_start_y
|
|
||||||
if args.upper_start_y is not None
|
|
||||||
else max(second * 1.5, lower_cap * 1.02)
|
|
||||||
)
|
|
||||||
ymax = (
|
|
||||||
max(values) * 1.90 if values else 1.0
|
|
||||||
) # Increase headroom to 1.90 for text label and tick range
|
|
||||||
|
|
||||||
# Draw bars on both axes
|
|
||||||
ax_left_bottom.bar(x, values, color=colors[: len(labels)], width=0.8)
|
|
||||||
ax_left_top.bar(x, values, color=colors[: len(labels)], width=0.8)
|
|
||||||
|
|
||||||
# Set limits
|
|
||||||
ax_left_bottom.set_ylim(0, lower_cap)
|
|
||||||
ax_left_top.set_ylim(upper_start, ymax)
|
|
||||||
|
|
||||||
# Annotate values (convert ms to s)
|
|
||||||
values_s = [v / 1000.0 for v in values]
|
|
||||||
lower_cap_s = lower_cap / 1000.0
|
|
||||||
upper_start_s = upper_start / 1000.0
|
|
||||||
ymax_s = ymax / 1000.0
|
|
||||||
|
|
||||||
ax_left_bottom.set_ylim(0, lower_cap_s)
|
|
||||||
ax_left_top.set_ylim(upper_start_s, ymax_s)
|
|
||||||
|
|
||||||
# Redraw bars with s values (paper style: white fill + colored edge + hatch)
|
|
||||||
ax_left_bottom.clear()
|
|
||||||
ax_left_top.clear()
|
|
||||||
bar_width = 0.50 # Reduced for wider spacing between bars
|
|
||||||
for i, (scenario_name, v) in enumerate(zip(scenarios, values_s)):
|
|
||||||
style = SCENARIO_STYLES.get(scenario_name, {"edgecolor": "black", "hatch": ""})
|
|
||||||
# Draw in bottom axis for all bars
|
|
||||||
ax_left_bottom.bar(
|
|
||||||
i,
|
|
||||||
v,
|
|
||||||
width=bar_width,
|
|
||||||
color="white",
|
|
||||||
edgecolor=style["edgecolor"],
|
|
||||||
hatch=style["hatch"],
|
|
||||||
linewidth=1.2,
|
|
||||||
)
|
|
||||||
# Only draw in top axis if the bar is tall enough to reach the upper range
|
|
||||||
if v > upper_start_s:
|
|
||||||
ax_left_top.bar(
|
|
||||||
i,
|
|
||||||
v,
|
|
||||||
width=bar_width,
|
|
||||||
color="white",
|
|
||||||
edgecolor=style["edgecolor"],
|
|
||||||
hatch=style["hatch"],
|
|
||||||
linewidth=1.2,
|
|
||||||
)
|
|
||||||
ax_left_bottom.set_ylim(0, lower_cap_s)
|
|
||||||
ax_left_top.set_ylim(upper_start_s, ymax_s)
|
|
||||||
|
|
||||||
for i, v in enumerate(values_s):
|
|
||||||
if v <= lower_cap_s:
|
|
||||||
ax_left_bottom.text(
|
|
||||||
i,
|
|
||||||
v + lower_cap_s * 0.02,
|
|
||||||
f"{v:.2f}",
|
|
||||||
ha="center",
|
|
||||||
va="bottom",
|
|
||||||
fontsize=8,
|
|
||||||
fontweight="bold",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
ax_left_top.text(
|
|
||||||
i,
|
|
||||||
v + (ymax_s - upper_start_s) * 0.02,
|
|
||||||
f"{v:.2f}",
|
|
||||||
ha="center",
|
|
||||||
va="bottom",
|
|
||||||
fontsize=8,
|
|
||||||
fontweight="bold",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Hide spines between axes
|
|
||||||
ax_left_top.spines["bottom"].set_visible(False)
|
|
||||||
ax_left_bottom.spines["top"].set_visible(False)
|
|
||||||
ax_left_top.tick_params(
|
|
||||||
labeltop=False, labelbottom=False, bottom=False
|
|
||||||
) # Hide tick marks
|
|
||||||
ax_left_bottom.xaxis.tick_bottom()
|
|
||||||
ax_left_bottom.tick_params(top=False) # Hide top tick marks
|
|
||||||
|
|
||||||
# Draw break marks (matching paper_fig.py style)
|
|
||||||
d = 0.015
|
|
||||||
kwargs = {
|
|
||||||
"transform": ax_left_top.transAxes,
|
|
||||||
"color": "k",
|
|
||||||
"clip_on": False,
|
|
||||||
"linewidth": 0.8,
|
|
||||||
"zorder": 10,
|
|
||||||
}
|
|
||||||
ax_left_top.plot((-d, +d), (-d, +d), **kwargs)
|
|
||||||
ax_left_top.plot((1 - d, 1 + d), (-d, +d), **kwargs)
|
|
||||||
kwargs.update({"transform": ax_left_bottom.transAxes})
|
|
||||||
ax_left_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs)
|
|
||||||
ax_left_bottom.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs)
|
|
||||||
|
|
||||||
ax_left_bottom.set_xticks(x)
|
|
||||||
ax_left_bottom.set_xticklabels(labels, rotation=0, fontsize=7)
|
|
||||||
# Don't set ylabel here - will use fig.text for alignment
|
|
||||||
ax_left_bottom.tick_params(axis="y", labelsize=10)
|
|
||||||
ax_left_top.tick_params(axis="y", labelsize=10)
|
|
||||||
# Add subtle grid for better readability
|
|
||||||
ax_left_bottom.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5)
|
|
||||||
ax_left_top.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5)
|
|
||||||
ax_left_top.set_title("Single Add Operation", fontsize=11, pad=10, fontweight="bold")
|
|
||||||
|
|
||||||
# Set x-axis limits to match bar width with right subplot
|
|
||||||
ax_left_bottom.set_xlim(-0.6, 3.6)
|
|
||||||
ax_left_top.set_xlim(-0.6, 3.6)
|
|
||||||
|
|
||||||
ax_left = ax_left_bottom # for compatibility
|
|
||||||
else:
|
|
||||||
# Regular side-by-side layout
|
|
||||||
fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(8.4, 3.15))
|
|
||||||
|
|
||||||
if cap is not None:
|
|
||||||
show_vals = [min(v, cap) for v in values]
|
|
||||||
bars = ax_left.bar(x, show_vals, color=colors[: len(labels)], width=0.8)
|
|
||||||
for i, (val, show) in enumerate(zip(values, show_vals)):
|
|
||||||
if val > cap:
|
|
||||||
bars[i].set_hatch("//")
|
|
||||||
ax_left.text(
|
|
||||||
i, cap * 1.02, _fmt_ms(val), ha="center", va="bottom", fontsize=9
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
ax_left.text(
|
|
||||||
i,
|
|
||||||
show + max(1.0, 0.01 * (cap or show)),
|
|
||||||
_fmt_ms(val),
|
|
||||||
ha="center",
|
|
||||||
va="bottom",
|
|
||||||
fontsize=9,
|
|
||||||
)
|
|
||||||
ax_left.set_ylim(0, cap * 1.10)
|
|
||||||
_add_break_marker(ax_left, y=0.98)
|
|
||||||
ax_left.set_xticks(x)
|
|
||||||
ax_left.set_xticklabels(labels, rotation=0, fontsize=10)
|
|
||||||
else:
|
|
||||||
ax_left.bar(x, values, color=colors[: len(labels)], width=0.8)
|
|
||||||
for i, v in enumerate(values):
|
|
||||||
ax_left.text(i, v + 1.0, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
|
|
||||||
ax_left.set_xticks(x)
|
|
||||||
ax_left.set_xticklabels(labels, rotation=0, fontsize=10)
|
|
||||||
ax_left.set_ylabel("Latency (ms per passage)")
|
|
||||||
max_initial = latest_rows[0].get("max_initial", "?")
|
|
||||||
max_updates = latest_rows[0].get("max_updates", "?")
|
|
||||||
ax_left.set_title(
|
|
||||||
f"HNSW RNG (run {latest_run}) | init={max_initial}, upd={max_updates}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Right subplot (A vs B, seconds) - paper style
|
|
||||||
r_labels = ["Sequential", "Delayed \n Add+Search"]
|
|
||||||
r_values = [a_total or 0.0, b_makespan or 0.0]
|
|
||||||
r_styles = [
|
|
||||||
{"edgecolor": "#59a14f", "hatch": "xxxxx"},
|
|
||||||
{"edgecolor": "#edc948", "hatch": "/////"},
|
|
||||||
]
|
|
||||||
# 2 bars, centered with proper spacing
|
|
||||||
xr = [0, 1]
|
|
||||||
bar_width = 0.50 # Reduced for wider spacing between bars
|
|
||||||
for i, (v, style) in enumerate(zip(r_values, r_styles)):
|
|
||||||
ax_right.bar(
|
|
||||||
xr[i],
|
|
||||||
v,
|
|
||||||
width=bar_width,
|
|
||||||
color="white",
|
|
||||||
edgecolor=style["edgecolor"],
|
|
||||||
hatch=style["hatch"],
|
|
||||||
linewidth=1.2,
|
|
||||||
)
|
|
||||||
for i, v in enumerate(r_values):
|
|
||||||
max_v = max(r_values) if r_values else 1.0
|
|
||||||
offset = max(0.0002, 0.02 * max_v)
|
|
||||||
ax_right.text(
|
|
||||||
xr[i],
|
|
||||||
v + offset,
|
|
||||||
f"{v:.2f}",
|
|
||||||
ha="center",
|
|
||||||
va="bottom",
|
|
||||||
fontsize=8,
|
|
||||||
fontweight="bold",
|
|
||||||
)
|
|
||||||
ax_right.set_xticks(xr)
|
|
||||||
ax_right.set_xticklabels(r_labels, rotation=0, fontsize=7)
|
|
||||||
# Don't set ylabel here - will use fig.text for alignment
|
|
||||||
ax_right.tick_params(axis="y", labelsize=10)
|
|
||||||
# Add subtle grid for better readability
|
|
||||||
ax_right.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5)
|
|
||||||
ax_right.set_title("Batched Add Operation", fontsize=11, pad=10, fontweight="bold")
|
|
||||||
|
|
||||||
# Set x-axis limits to match left subplot's bar width visually
|
|
||||||
# Accounting for width_ratios=[1.5, 1]:
|
|
||||||
# Left: 4 bars, xlim(-0.6, 3.6), range=4.2, physical_width=1.5*unit
|
|
||||||
# bar_width_visual = 0.72 * (1.5*unit / 4.2)
|
|
||||||
# Right: 2 bars, need same visual width
|
|
||||||
# 0.72 * (1.0*unit / range_right) = 0.72 * (1.5*unit / 4.2)
|
|
||||||
# range_right = 4.2 / 1.5 = 2.8
|
|
||||||
# For bars at 0, 1: padding = (2.8 - 1) / 2 = 0.9
|
|
||||||
ax_right.set_xlim(-0.9, 1.9)
|
|
||||||
|
|
||||||
# Set y-axis limit with headroom for text labels
|
|
||||||
if r_values:
|
|
||||||
max_v = max(r_values)
|
|
||||||
ax_right.set_ylim(0, max_v * 1.15)
|
|
||||||
|
|
||||||
# Format y-axis to avoid scientific notation
|
|
||||||
ax_right.ticklabel_format(style="plain", axis="y")
|
|
||||||
|
|
||||||
plt.tight_layout()
|
|
||||||
|
|
||||||
# Add aligned ylabels using fig.text (after tight_layout)
|
|
||||||
# Get the vertical center of the entire figure
|
|
||||||
fig_center_y = 0.5
|
|
||||||
# Left ylabel - closer to left plot
|
|
||||||
left_x = 0.05
|
|
||||||
fig.text(
|
|
||||||
left_x,
|
|
||||||
fig_center_y,
|
|
||||||
"Latency (s)",
|
|
||||||
va="center",
|
|
||||||
rotation="vertical",
|
|
||||||
fontsize=11,
|
|
||||||
fontweight="bold",
|
|
||||||
)
|
|
||||||
# Right ylabel - closer to right plot
|
|
||||||
right_bbox = ax_right.get_position()
|
|
||||||
right_x = right_bbox.x0 - 0.07
|
|
||||||
fig.text(
|
|
||||||
right_x,
|
|
||||||
fig_center_y,
|
|
||||||
"Latency (s)",
|
|
||||||
va="center",
|
|
||||||
rotation="vertical",
|
|
||||||
fontsize=11,
|
|
||||||
fontweight="bold",
|
|
||||||
)
|
|
||||||
|
|
||||||
plt.savefig(args.out, bbox_inches="tight", pad_inches=0.05)
|
|
||||||
# Also save PDF for paper
|
|
||||||
pdf_out = args.out.with_suffix(".pdf")
|
|
||||||
plt.savefig(pdf_out, bbox_inches="tight", pad_inches=0.05)
|
|
||||||
print(f"Saved: {args.out}")
|
|
||||||
print(f"Saved: {pdf_out}")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Broken-Y mode
|
|
||||||
if args.broken_y:
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
fig, (ax_top, ax_bottom) = plt.subplots(
|
|
||||||
2,
|
|
||||||
1,
|
|
||||||
sharex=True,
|
|
||||||
figsize=(7.5, 6.75),
|
|
||||||
gridspec_kw={"height_ratios": [1, 3], "hspace": 0.08},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Determine default breaks from second-highest
|
|
||||||
s = sorted(values, reverse=True)
|
|
||||||
second = s[1] if len(s) >= 2 else (s[0] if s else 0.0)
|
|
||||||
lower_cap = args.lower_cap_y if args.lower_cap_y is not None else second * 1.1
|
|
||||||
upper_start = (
|
|
||||||
args.upper_start_y
|
|
||||||
if args.upper_start_y is not None
|
|
||||||
else max(second * 1.2, lower_cap * 1.02)
|
|
||||||
)
|
|
||||||
ymax = max(values) * 1.10 if values else 1.0
|
|
||||||
|
|
||||||
x = list(range(len(labels)))
|
|
||||||
ax_bottom.bar(x, values, color=colors[: len(labels)], width=0.8)
|
|
||||||
ax_top.bar(x, values, color=colors[: len(labels)], width=0.8)
|
|
||||||
|
|
||||||
# Limits
|
|
||||||
ax_bottom.set_ylim(0, lower_cap)
|
|
||||||
ax_top.set_ylim(upper_start, ymax)
|
|
||||||
|
|
||||||
# Annotate values
|
|
||||||
for i, v in enumerate(values):
|
|
||||||
if v <= lower_cap:
|
|
||||||
ax_bottom.text(
|
|
||||||
i, v + lower_cap * 0.02, _fmt_ms(v), ha="center", va="bottom", fontsize=9
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
ax_top.text(i, v, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
|
|
||||||
|
|
||||||
# Hide spines between axes and draw diagonal break marks
|
|
||||||
ax_top.spines["bottom"].set_visible(False)
|
|
||||||
ax_bottom.spines["top"].set_visible(False)
|
|
||||||
ax_top.tick_params(labeltop=False) # don't put tick labels at the top
|
|
||||||
ax_bottom.xaxis.tick_bottom()
|
|
||||||
|
|
||||||
# Diagonal lines at the break (matching paper_fig.py style)
|
|
||||||
d = 0.015
|
|
||||||
kwargs = {
|
|
||||||
"transform": ax_top.transAxes,
|
|
||||||
"color": "k",
|
|
||||||
"clip_on": False,
|
|
||||||
"linewidth": 0.8,
|
|
||||||
"zorder": 10,
|
|
||||||
}
|
|
||||||
ax_top.plot((-d, +d), (-d, +d), **kwargs) # top-left diagonal
|
|
||||||
ax_top.plot((1 - d, 1 + d), (-d, +d), **kwargs) # top-right diagonal
|
|
||||||
kwargs.update({"transform": ax_bottom.transAxes})
|
|
||||||
ax_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs) # bottom-left diagonal
|
|
||||||
ax_bottom.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs) # bottom-right diagonal
|
|
||||||
|
|
||||||
ax_bottom.set_xticks(x)
|
|
||||||
ax_bottom.set_xticklabels(labels, rotation=0, fontsize=10)
|
|
||||||
ax = ax_bottom # for labeling below
|
|
||||||
else:
|
|
||||||
cap = args.cap_y
|
|
||||||
if cap is None and not args.no_auto_cap:
|
|
||||||
cap = _auto_cap(values)
|
|
||||||
|
|
||||||
plt.figure(figsize=(5.4, 3.15))
|
|
||||||
ax = plt.gca()
|
|
||||||
|
|
||||||
if cap is not None:
|
|
||||||
show_vals = [min(v, cap) for v in values]
|
|
||||||
bars = []
|
|
||||||
for i, (_label, val, show) in enumerate(zip(labels, values, show_vals)):
|
|
||||||
bar = ax.bar(i, show, color=colors[i], width=0.8)
|
|
||||||
bars.append(bar[0])
|
|
||||||
# Hatch and annotate when capped
|
|
||||||
if val > cap:
|
|
||||||
bars[-1].set_hatch("//")
|
|
||||||
ax.text(i, cap * 1.02, f"{_fmt_ms(val)}", ha="center", va="bottom", fontsize=9)
|
|
||||||
else:
|
|
||||||
ax.text(
|
|
||||||
i,
|
|
||||||
show + max(1.0, 0.01 * (cap or show)),
|
|
||||||
f"{_fmt_ms(val)}",
|
|
||||||
ha="center",
|
|
||||||
va="bottom",
|
|
||||||
fontsize=9,
|
|
||||||
)
|
|
||||||
ax.set_ylim(0, cap * 1.10)
|
|
||||||
_add_break_marker(ax, y=0.98)
|
|
||||||
ax.legend([bars[1]], ["capped"], fontsize=8, frameon=False, loc="upper right") if any(
|
|
||||||
v > cap for v in values
|
|
||||||
) else None
|
|
||||||
ax.set_xticks(range(len(labels)))
|
|
||||||
ax.set_xticklabels(labels, fontsize=11, fontweight="bold")
|
|
||||||
else:
|
|
||||||
ax.bar(labels, values, color=colors[: len(labels)])
|
|
||||||
for idx, val in enumerate(values):
|
|
||||||
ax.text(
|
|
||||||
idx,
|
|
||||||
val + 1.0,
|
|
||||||
f"{_fmt_ms(val)}",
|
|
||||||
ha="center",
|
|
||||||
va="bottom",
|
|
||||||
fontsize=10,
|
|
||||||
fontweight="bold",
|
|
||||||
)
|
|
||||||
ax.set_xticklabels(labels, fontsize=11, fontweight="bold")
|
|
||||||
# Try to extract some context for title
|
|
||||||
max_initial = latest_rows[0].get("max_initial", "?")
|
|
||||||
max_updates = latest_rows[0].get("max_updates", "?")
|
|
||||||
|
|
||||||
if args.broken_y:
|
|
||||||
fig.text(
|
|
||||||
0.02,
|
|
||||||
0.5,
|
|
||||||
"Latency (s)",
|
|
||||||
va="center",
|
|
||||||
rotation="vertical",
|
|
||||||
fontsize=11,
|
|
||||||
fontweight="bold",
|
|
||||||
)
|
|
||||||
fig.suptitle(
|
|
||||||
"Add Operation Latency",
|
|
||||||
fontsize=11,
|
|
||||||
y=0.98,
|
|
||||||
fontweight="bold",
|
|
||||||
)
|
|
||||||
plt.tight_layout(rect=(0.03, 0.04, 1, 0.96))
|
|
||||||
else:
|
|
||||||
plt.ylabel("Latency (s)", fontsize=11, fontweight="bold")
|
|
||||||
plt.title("Add Operation Latency", fontsize=11, fontweight="bold")
|
|
||||||
plt.tight_layout()
|
|
||||||
|
|
||||||
plt.savefig(args.out, bbox_inches="tight", pad_inches=0.05)
|
|
||||||
# Also save PDF for paper
|
|
||||||
pdf_out = args.out.with_suffix(".pdf")
|
|
||||||
plt.savefig(pdf_out, bbox_inches="tight", pad_inches=0.05)
|
|
||||||
print(f"Saved: {args.out}")
|
|
||||||
print(f"Saved: {pdf_out}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -455,5 +455,5 @@ Conclusion:
|
|||||||
|
|
||||||
- [Lessons Learned Developing LEANN](https://yichuan-w.github.io/blog/lessons_learned_in_dev_leann/)
|
- [Lessons Learned Developing LEANN](https://yichuan-w.github.io/blog/lessons_learned_in_dev_leann/)
|
||||||
- [LEANN Technical Paper](https://arxiv.org/abs/2506.08276)
|
- [LEANN Technical Paper](https://arxiv.org/abs/2506.08276)
|
||||||
- [DiskANN Original Paper](https://suhasjs.github.io/files/diskann_neurips19.pdf)
|
- [DiskANN Original Paper](https://papers.nips.cc/paper/2019/file/09853c7fb1d3f8ee67a61b6bf4a7f8e6-Paper.pdf)
|
||||||
- [SSD-based Graph Partitioning](https://github.com/SonglinLife/SSD_BASED_PLAN)
|
- [SSD-based Graph Partitioning](https://github.com/SonglinLife/SSD_BASED_PLAN)
|
||||||
|
|||||||
@@ -1,395 +0,0 @@
|
|||||||
# Slack Integration Setup Guide
|
|
||||||
|
|
||||||
This guide provides step-by-step instructions for setting up Slack integration with LEANN.
|
|
||||||
|
|
||||||
## Overview
|
|
||||||
|
|
||||||
LEANN's Slack integration uses MCP (Model Context Protocol) servers to fetch and index your Slack messages for RAG (Retrieval-Augmented Generation). This allows you to search through your Slack conversations using natural language queries.
|
|
||||||
|
|
||||||
## Prerequisites
|
|
||||||
|
|
||||||
1. **Slack Workspace Access**: You need admin or owner permissions in your Slack workspace to create apps and configure OAuth tokens.
|
|
||||||
|
|
||||||
2. **Slack MCP Server**: Install a Slack MCP server (e.g., `slack-mcp-server` via npm)
|
|
||||||
|
|
||||||
3. **LEANN**: Ensure you have LEANN installed and working
|
|
||||||
|
|
||||||
## Step 1: Create a Slack App
|
|
||||||
|
|
||||||
### 1.1 Go to Slack API Dashboard
|
|
||||||
|
|
||||||
1. Visit [https://api.slack.com/apps](https://api.slack.com/apps)
|
|
||||||
2. Click **"Create New App"**
|
|
||||||
3. Choose **"From scratch"**
|
|
||||||
4. Enter your app name (e.g., "LEANN Slack Integration")
|
|
||||||
5. Select your workspace
|
|
||||||
6. Click **"Create App"**
|
|
||||||
|
|
||||||
### 1.2 Configure App Permissions
|
|
||||||
|
|
||||||
#### Token Scopes
|
|
||||||
|
|
||||||
1. In your app dashboard, go to **"OAuth & Permissions"** in the left sidebar
|
|
||||||
2. Scroll down to **"Scopes"** section
|
|
||||||
3. Under **"Bot Token Scopes & OAuth Scope"**, click **"Add an OAuth Scope"**
|
|
||||||
4. Add the following scopes:
|
|
||||||
- `channels:read` - Read public channel information
|
|
||||||
- `channels:history` - Read messages in public channels
|
|
||||||
- `groups:read` - Read private channel information
|
|
||||||
- `groups:history` - Read messages in private channels
|
|
||||||
- `im:read` - Read direct message information
|
|
||||||
- `im:history` - Read direct messages
|
|
||||||
- `mpim:read` - Read group direct message information
|
|
||||||
- `mpim:history` - Read group direct messages
|
|
||||||
- `users:read` - Read user information
|
|
||||||
- `team:read` - Read workspace information
|
|
||||||
|
|
||||||
#### App-Level Tokens (Optional)
|
|
||||||
|
|
||||||
Some MCP servers may require app-level tokens:
|
|
||||||
|
|
||||||
1. Go to **"Basic Information"** in the left sidebar
|
|
||||||
2. Scroll down to **"App-Level Tokens"**
|
|
||||||
3. Click **"Generate Token and Scopes"**
|
|
||||||
4. Enter a name (e.g., "LEANN Integration")
|
|
||||||
5. Add the `connections:write` scope
|
|
||||||
6. Click **"Generate"**
|
|
||||||
7. Copy the token (starts with `xapp-`)
|
|
||||||
|
|
||||||
### 1.3 Install App to Workspace
|
|
||||||
|
|
||||||
1. Go to **"OAuth & Permissions"** in the left sidebar
|
|
||||||
2. Click **"Install to Workspace"**
|
|
||||||
3. Review the permissions and click **"Allow"**
|
|
||||||
4. Copy the **"Bot User OAuth Token"** (starts with `xoxb-`)
|
|
||||||
5. Copy the **"User OAuth Token"** (starts with `xoxp-`)
|
|
||||||
|
|
||||||
## Step 2: Install Slack MCP Server
|
|
||||||
|
|
||||||
### Option A: Using npm (Recommended)
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Install globally
|
|
||||||
npm install -g slack-mcp-server
|
|
||||||
|
|
||||||
# Or install locally
|
|
||||||
npm install slack-mcp-server
|
|
||||||
```
|
|
||||||
|
|
||||||
### Option B: Using npx (No installation required)
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Use directly without installation
|
|
||||||
npx slack-mcp-server
|
|
||||||
```
|
|
||||||
|
|
||||||
## Step 3: Install and Configure Ollama (for Real LLM Responses)
|
|
||||||
|
|
||||||
### 3.1 Install Ollama
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Install Ollama using Homebrew (macOS)
|
|
||||||
brew install ollama
|
|
||||||
|
|
||||||
# Or download from https://ollama.ai/
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3.2 Start Ollama Service
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Start Ollama as a service
|
|
||||||
brew services start ollama
|
|
||||||
|
|
||||||
# Or start manually
|
|
||||||
ollama serve
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3.3 Pull a Model
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Pull a lightweight model for testing
|
|
||||||
ollama pull llama3.2:1b
|
|
||||||
|
|
||||||
# Verify the model is available
|
|
||||||
ollama list
|
|
||||||
```
|
|
||||||
|
|
||||||
## Step 4: Configure Environment Variables
|
|
||||||
|
|
||||||
Create a `.env` file or set environment variables:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Required: User OAuth Token
|
|
||||||
SLACK_OAUTH_TOKEN=xoxp-your-user-oauth-token-here
|
|
||||||
|
|
||||||
# Optional: App-Level Token (if your MCP server requires it)
|
|
||||||
SLACK_APP_TOKEN=xapp-your-app-token-here
|
|
||||||
|
|
||||||
# Optional: Workspace-specific settings
|
|
||||||
SLACK_WORKSPACE_ID=T1234567890 # Your workspace ID (optional)
|
|
||||||
```
|
|
||||||
|
|
||||||
## Step 5: Test the Setup
|
|
||||||
|
|
||||||
### 5.1 Test MCP Server Connection
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python -m apps.slack_rag \
|
|
||||||
--mcp-server "slack-mcp-server" \
|
|
||||||
--test-connection \
|
|
||||||
--workspace-name "Your Workspace Name"
|
|
||||||
```
|
|
||||||
|
|
||||||
This will test the connection and list available tools without indexing any data.
|
|
||||||
|
|
||||||
### 5.2 Index a Specific Channel
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python -m apps.slack_rag \
|
|
||||||
--mcp-server "slack-mcp-server" \
|
|
||||||
--workspace-name "Your Workspace Name" \
|
|
||||||
--channels general \
|
|
||||||
--query "What did we discuss about the project?"
|
|
||||||
```
|
|
||||||
|
|
||||||
### 5.3 Real RAG Query Examples
|
|
||||||
|
|
||||||
This section demonstrates successful Slack RAG integration queries against the Sky Lab Computing workspace's "random" channel. The system successfully retrieves actual conversation messages and performs semantic search with high relevance scores, including finding specific research paper announcements and technical discussions.
|
|
||||||
|
|
||||||
### Example 1: Advisor Models Query
|
|
||||||
|
|
||||||
**Query:** "train black-box models to adopt to your personal data"
|
|
||||||
|
|
||||||
This query demonstrates the system's ability to find specific research announcements about training black-box models for personal data adaptation.
|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||
### Example 2: Barbarians at the Gate Query
|
|
||||||
|
|
||||||
**Query:** "AI-driven research systems ADRS"
|
|
||||||
|
|
||||||
This query demonstrates the system's ability to find specific research announcements about AI-driven research systems and algorithm discovery.
|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||
### Prerequisites
|
|
||||||
|
|
||||||
- Bot is installed in the Sky Lab Computing workspace and invited to the target channel (run `/invite @YourBotName` in the channel if needed)
|
|
||||||
- Bot token available and exported in the same terminal session
|
|
||||||
|
|
||||||
### Commands
|
|
||||||
|
|
||||||
1) Set the workspace token for this shell
|
|
||||||
|
|
||||||
```bash
|
|
||||||
export SLACK_MCP_XOXP_TOKEN="xoxp-***-redacted-***"
|
|
||||||
```
|
|
||||||
|
|
||||||
2) Run queries against the "random" channel by channel ID (C0GN5BX0F)
|
|
||||||
|
|
||||||
**Advisor Models Query:**
|
|
||||||
```bash
|
|
||||||
python -m apps.slack_rag \
|
|
||||||
--mcp-server "slack-mcp-server" \
|
|
||||||
--workspace-name "Sky Lab Computing" \
|
|
||||||
--channels C0GN5BX0F \
|
|
||||||
--max-messages-per-channel 100000 \
|
|
||||||
--query "train black-box models to adopt to your personal data" \
|
|
||||||
--llm ollama \
|
|
||||||
--llm-model "llama3.2:1b" \
|
|
||||||
--llm-host "http://localhost:11434" \
|
|
||||||
--no-concatenate-conversations
|
|
||||||
```
|
|
||||||
|
|
||||||
**Barbarians at the Gate Query:**
|
|
||||||
```bash
|
|
||||||
python -m apps.slack_rag \
|
|
||||||
--mcp-server "slack-mcp-server" \
|
|
||||||
--workspace-name "Sky Lab Computing" \
|
|
||||||
--channels C0GN5BX0F \
|
|
||||||
--max-messages-per-channel 100000 \
|
|
||||||
--query "AI-driven research systems ADRS" \
|
|
||||||
--llm ollama \
|
|
||||||
--llm-model "llama3.2:1b" \
|
|
||||||
--llm-host "http://localhost:11434" \
|
|
||||||
--no-concatenate-conversations
|
|
||||||
```
|
|
||||||
|
|
||||||
These examples demonstrate the system's ability to find and retrieve specific research announcements and technical discussions from the conversation history, showcasing the power of semantic search in Slack data.
|
|
||||||
|
|
||||||
3) Optional: Ask a broader question
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python test_channel_by_id_or_name.py \
|
|
||||||
--channel-id C0GN5BX0F \
|
|
||||||
--workspace-name "Sky Lab Computing" \
|
|
||||||
--query "What is LEANN about?"
|
|
||||||
```
|
|
||||||
|
|
||||||
Notes:
|
|
||||||
- If you see `not_in_channel`, invite the bot to the channel and re-run.
|
|
||||||
- If you see `channel_not_found`, confirm the channel ID and workspace.
|
|
||||||
- Deep search via server-side “search” tools may require additional Slack scopes; the example above performs client-side filtering over retrieved history.
|
|
||||||
|
|
||||||
## Common Issues and Solutions
|
|
||||||
|
|
||||||
### Issue 1: "users cache is not ready yet" Error
|
|
||||||
|
|
||||||
**Problem**: You see this warning:
|
|
||||||
```
|
|
||||||
WARNING - Failed to fetch messages from channel random: Failed to fetch messages: {'code': -32603, 'message': 'users cache is not ready yet, sync process is still running... please wait'}
|
|
||||||
```
|
|
||||||
|
|
||||||
**Solution**: This is a common timing issue. The LEANN integration now includes automatic retry logic:
|
|
||||||
|
|
||||||
1. **Wait and Retry**: The system will automatically retry with exponential backoff (2s, 4s, 8s, etc.)
|
|
||||||
2. **Increase Retry Parameters**: If needed, you can customize retry behavior:
|
|
||||||
```bash
|
|
||||||
python -m apps.slack_rag \
|
|
||||||
--mcp-server "slack-mcp-server" \
|
|
||||||
--max-retries 10 \
|
|
||||||
--retry-delay 3.0 \
|
|
||||||
--channels general \
|
|
||||||
--query "Your query here"
|
|
||||||
```
|
|
||||||
3. **Keep MCP Server Running**: Start the MCP server separately and keep it running:
|
|
||||||
```bash
|
|
||||||
# Terminal 1: Start MCP server
|
|
||||||
slack-mcp-server
|
|
||||||
|
|
||||||
# Terminal 2: Run LEANN (it will connect to the running server)
|
|
||||||
python -m apps.slack_rag --mcp-server "slack-mcp-server" --channels general --query "test"
|
|
||||||
```
|
|
||||||
|
|
||||||
### Issue 2: "No message fetching tool found"
|
|
||||||
|
|
||||||
**Problem**: The MCP server doesn't have the expected tools.
|
|
||||||
|
|
||||||
**Solution**:
|
|
||||||
1. Check if your MCP server is properly installed and configured
|
|
||||||
2. Verify your Slack tokens are correct
|
|
||||||
3. Try a different MCP server implementation
|
|
||||||
4. Check the MCP server documentation for required configuration
|
|
||||||
|
|
||||||
### Issue 3: Permission Denied Errors
|
|
||||||
|
|
||||||
**Problem**: You get permission errors when trying to access channels.
|
|
||||||
|
|
||||||
**Solutions**:
|
|
||||||
1. **Check Bot Permissions**: Ensure your bot has been added to the channels you want to access
|
|
||||||
2. **Verify Token Scopes**: Make sure you have all required scopes configured
|
|
||||||
3. **Channel Access**: For private channels, the bot needs to be explicitly invited
|
|
||||||
4. **Workspace Permissions**: Ensure your Slack app has the necessary workspace permissions
|
|
||||||
|
|
||||||
### Issue 4: Empty Results
|
|
||||||
|
|
||||||
**Problem**: No messages are returned even though the channel has messages.
|
|
||||||
|
|
||||||
**Solutions**:
|
|
||||||
1. **Check Channel Names**: Ensure channel names are correct (without the # symbol)
|
|
||||||
2. **Verify Bot Access**: Make sure the bot can access the channels
|
|
||||||
3. **Check Date Ranges**: Some MCP servers have limitations on message history
|
|
||||||
4. **Increase Message Limits**: Try increasing the message limit:
|
|
||||||
```bash
|
|
||||||
python -m apps.slack_rag \
|
|
||||||
--mcp-server "slack-mcp-server" \
|
|
||||||
--channels general \
|
|
||||||
--max-messages-per-channel 1000 \
|
|
||||||
--query "test"
|
|
||||||
```
|
|
||||||
|
|
||||||
## Advanced Configuration
|
|
||||||
|
|
||||||
### Custom MCP Server Commands
|
|
||||||
|
|
||||||
If you need to pass additional parameters to your MCP server:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python -m apps.slack_rag \
|
|
||||||
--mcp-server "slack-mcp-server --token-file /path/to/tokens.json" \
|
|
||||||
--workspace-name "Your Workspace" \
|
|
||||||
--channels general \
|
|
||||||
--query "Your query"
|
|
||||||
```
|
|
||||||
|
|
||||||
### Multiple Workspaces
|
|
||||||
|
|
||||||
To work with multiple Slack workspaces, you can:
|
|
||||||
|
|
||||||
1. Create separate apps for each workspace
|
|
||||||
2. Use different environment variables
|
|
||||||
3. Run separate instances with different configurations
|
|
||||||
|
|
||||||
### Performance Optimization
|
|
||||||
|
|
||||||
For better performance with large workspaces:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python -m apps.slack_rag \
|
|
||||||
--mcp-server "slack-mcp-server" \
|
|
||||||
--workspace-name "Your Workspace" \
|
|
||||||
--max-messages-per-channel 500 \
|
|
||||||
--no-concatenate-conversations \
|
|
||||||
--query "Your query"
|
|
||||||
```
|
|
||||||
---
|
|
||||||
|
|
||||||
## Troubleshooting Checklist
|
|
||||||
|
|
||||||
- [ ] Slack app created with proper permissions
|
|
||||||
- [ ] Bot token (xoxb-) copied correctly
|
|
||||||
- [ ] App-level token (xapp-) created if needed
|
|
||||||
- [ ] MCP server installed and accessible
|
|
||||||
- [ ] Ollama installed and running (`brew services start ollama`)
|
|
||||||
- [ ] Ollama model pulled (`ollama pull llama3.2:1b`)
|
|
||||||
- [ ] Environment variables set correctly
|
|
||||||
- [ ] Bot invited to relevant channels
|
|
||||||
- [ ] Channel names specified without # symbol
|
|
||||||
- [ ] Sufficient retry attempts configured
|
|
||||||
- [ ] Network connectivity to Slack APIs
|
|
||||||
|
|
||||||
## Getting Help
|
|
||||||
|
|
||||||
If you continue to have issues:
|
|
||||||
|
|
||||||
1. **Check Logs**: Look for detailed error messages in the console output
|
|
||||||
2. **Test MCP Server**: Use `--test-connection` to verify the MCP server is working
|
|
||||||
3. **Verify Tokens**: Double-check that your Slack tokens are valid and have the right scopes
|
|
||||||
4. **Check Ollama**: Ensure Ollama is running (`ollama serve`) and the model is available (`ollama list`)
|
|
||||||
5. **Community Support**: Reach out to the LEANN community for help
|
|
||||||
|
|
||||||
## Example Commands
|
|
||||||
|
|
||||||
### Basic Usage
|
|
||||||
```bash
|
|
||||||
# Test connection
|
|
||||||
python -m apps.slack_rag --mcp-server "slack-mcp-server" --test-connection
|
|
||||||
|
|
||||||
# Index specific channels
|
|
||||||
python -m apps.slack_rag \
|
|
||||||
--mcp-server "slack-mcp-server" \
|
|
||||||
--workspace-name "My Company" \
|
|
||||||
--channels general random \
|
|
||||||
--query "What did we decide about the project timeline?"
|
|
||||||
```
|
|
||||||
|
|
||||||
### Advanced Usage
|
|
||||||
```bash
|
|
||||||
# With custom retry settings
|
|
||||||
python -m apps.slack_rag \
|
|
||||||
--mcp-server "slack-mcp-server" \
|
|
||||||
--workspace-name "My Company" \
|
|
||||||
--channels general \
|
|
||||||
--max-retries 10 \
|
|
||||||
--retry-delay 5.0 \
|
|
||||||
--max-messages-per-channel 2000 \
|
|
||||||
--query "Show me all decisions made in the last month"
|
|
||||||
```
|
|
||||||
Binary file not shown.
|
Before Width: | Height: | Size: 445 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 508 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 437 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 474 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 501 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 454 KiB |
@@ -1,178 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
MCP Integration Examples for LEANN
|
|
||||||
|
|
||||||
This script demonstrates how to use LEANN with different MCP servers for
|
|
||||||
RAG on various platforms like Slack and Twitter.
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
1. Slack message RAG via MCP
|
|
||||||
2. Twitter bookmark RAG via MCP
|
|
||||||
3. Testing MCP server connections
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
# Add the parent directory to the path so we can import from apps
|
|
||||||
sys.path.append(str(Path(__file__).parent.parent))
|
|
||||||
|
|
||||||
|
|
||||||
async def demo_slack_mcp():
|
|
||||||
"""Demonstrate Slack MCP integration."""
|
|
||||||
print("=" * 60)
|
|
||||||
print("🔥 Slack MCP RAG Demo")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
print("\n1. Testing Slack MCP server connection...")
|
|
||||||
|
|
||||||
# This would typically use a real MCP server command
|
|
||||||
# For demo purposes, we show what the command would look like
|
|
||||||
# slack_app = SlackMCPRAG() # Would be used for actual testing
|
|
||||||
|
|
||||||
# Simulate command line arguments for testing
|
|
||||||
class MockArgs:
|
|
||||||
mcp_server = "slack-mcp-server" # This would be the actual MCP server command
|
|
||||||
workspace_name = "my-workspace"
|
|
||||||
channels = ["general", "random", "dev-team"]
|
|
||||||
no_concatenate_conversations = False
|
|
||||||
max_messages_per_channel = 50
|
|
||||||
test_connection = True
|
|
||||||
|
|
||||||
print(f"MCP Server Command: {MockArgs.mcp_server}")
|
|
||||||
print(f"Workspace: {MockArgs.workspace_name}")
|
|
||||||
print(f"Channels: {', '.join(MockArgs.channels)}")
|
|
||||||
|
|
||||||
# In a real scenario, you would run:
|
|
||||||
# success = await slack_app.test_mcp_connection(MockArgs)
|
|
||||||
|
|
||||||
print("\n📝 Example usage:")
|
|
||||||
print("python -m apps.slack_rag \\")
|
|
||||||
print(" --mcp-server 'slack-mcp-server' \\")
|
|
||||||
print(" --workspace-name 'my-team' \\")
|
|
||||||
print(" --channels general dev-team \\")
|
|
||||||
print(" --test-connection")
|
|
||||||
|
|
||||||
print("\n🔍 After indexing, you could query:")
|
|
||||||
print("- 'What did the team discuss about the project deadline?'")
|
|
||||||
print("- 'Find messages about the new feature launch'")
|
|
||||||
print("- 'Show me conversations about budget planning'")
|
|
||||||
|
|
||||||
|
|
||||||
async def demo_twitter_mcp():
|
|
||||||
"""Demonstrate Twitter MCP integration."""
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print("🐦 Twitter MCP RAG Demo")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
print("\n1. Testing Twitter MCP server connection...")
|
|
||||||
|
|
||||||
# twitter_app = TwitterMCPRAG() # Would be used for actual testing
|
|
||||||
|
|
||||||
class MockArgs:
|
|
||||||
mcp_server = "twitter-mcp-server"
|
|
||||||
username = None # Fetch all bookmarks
|
|
||||||
max_bookmarks = 500
|
|
||||||
no_tweet_content = False
|
|
||||||
no_metadata = False
|
|
||||||
test_connection = True
|
|
||||||
|
|
||||||
print(f"MCP Server Command: {MockArgs.mcp_server}")
|
|
||||||
print(f"Max Bookmarks: {MockArgs.max_bookmarks}")
|
|
||||||
print(f"Include Content: {not MockArgs.no_tweet_content}")
|
|
||||||
print(f"Include Metadata: {not MockArgs.no_metadata}")
|
|
||||||
|
|
||||||
print("\n📝 Example usage:")
|
|
||||||
print("python -m apps.twitter_rag \\")
|
|
||||||
print(" --mcp-server 'twitter-mcp-server' \\")
|
|
||||||
print(" --max-bookmarks 1000 \\")
|
|
||||||
print(" --test-connection")
|
|
||||||
|
|
||||||
print("\n🔍 After indexing, you could query:")
|
|
||||||
print("- 'What AI articles did I bookmark last month?'")
|
|
||||||
print("- 'Find tweets about machine learning techniques'")
|
|
||||||
print("- 'Show me bookmarked threads about startup advice'")
|
|
||||||
|
|
||||||
|
|
||||||
async def show_mcp_server_setup():
|
|
||||||
"""Show how to set up MCP servers."""
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print("⚙️ MCP Server Setup Guide")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
print("\n🔧 Setting up Slack MCP Server:")
|
|
||||||
print("1. Install a Slack MCP server (example commands):")
|
|
||||||
print(" npm install -g slack-mcp-server")
|
|
||||||
print(" # OR")
|
|
||||||
print(" pip install slack-mcp-server")
|
|
||||||
|
|
||||||
print("\n2. Configure Slack credentials:")
|
|
||||||
print(" export SLACK_BOT_TOKEN='xoxb-your-bot-token'")
|
|
||||||
print(" export SLACK_APP_TOKEN='xapp-your-app-token'")
|
|
||||||
|
|
||||||
print("\n3. Test the server:")
|
|
||||||
print(" slack-mcp-server --help")
|
|
||||||
|
|
||||||
print("\n🔧 Setting up Twitter MCP Server:")
|
|
||||||
print("1. Install a Twitter MCP server:")
|
|
||||||
print(" npm install -g twitter-mcp-server")
|
|
||||||
print(" # OR")
|
|
||||||
print(" pip install twitter-mcp-server")
|
|
||||||
|
|
||||||
print("\n2. Configure Twitter API credentials:")
|
|
||||||
print(" export TWITTER_API_KEY='your-api-key'")
|
|
||||||
print(" export TWITTER_API_SECRET='your-api-secret'")
|
|
||||||
print(" export TWITTER_ACCESS_TOKEN='your-access-token'")
|
|
||||||
print(" export TWITTER_ACCESS_TOKEN_SECRET='your-access-token-secret'")
|
|
||||||
|
|
||||||
print("\n3. Test the server:")
|
|
||||||
print(" twitter-mcp-server --help")
|
|
||||||
|
|
||||||
|
|
||||||
async def show_integration_benefits():
|
|
||||||
"""Show the benefits of MCP integration."""
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print("🌟 Benefits of MCP Integration")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
benefits = [
|
|
||||||
("🔄 Live Data Access", "Fetch real-time data from platforms without manual exports"),
|
|
||||||
("🔌 Standardized Protocol", "Use any MCP-compatible server with minimal code changes"),
|
|
||||||
("🚀 Easy Extension", "Add new platforms by implementing MCP readers"),
|
|
||||||
("🔒 Secure Access", "MCP servers handle authentication and API management"),
|
|
||||||
("📊 Rich Metadata", "Access full platform metadata (timestamps, engagement, etc.)"),
|
|
||||||
("⚡ Efficient Processing", "Stream data directly into LEANN without intermediate files"),
|
|
||||||
]
|
|
||||||
|
|
||||||
for title, description in benefits:
|
|
||||||
print(f"\n{title}")
|
|
||||||
print(f" {description}")
|
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
|
||||||
"""Main demo function."""
|
|
||||||
print("🎯 LEANN MCP Integration Examples")
|
|
||||||
print("This demo shows how to integrate LEANN with MCP servers for various platforms.")
|
|
||||||
|
|
||||||
await demo_slack_mcp()
|
|
||||||
await demo_twitter_mcp()
|
|
||||||
await show_mcp_server_setup()
|
|
||||||
await show_integration_benefits()
|
|
||||||
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print("✨ Next Steps")
|
|
||||||
print("=" * 60)
|
|
||||||
print("1. Install and configure MCP servers for your platforms")
|
|
||||||
print("2. Test connections using --test-connection flag")
|
|
||||||
print("3. Run indexing to build your RAG knowledge base")
|
|
||||||
print("4. Start querying your personal data!")
|
|
||||||
|
|
||||||
print("\n📚 For more information:")
|
|
||||||
print("- Check the README for detailed setup instructions")
|
|
||||||
print("- Look at the apps/slack_rag.py and apps/twitter_rag.py for implementation details")
|
|
||||||
print("- Explore other MCP servers for additional platforms")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
asyncio.run(main())
|
|
||||||
@@ -29,25 +29,12 @@ if(APPLE)
|
|||||||
set(CMAKE_OSX_DEPLOYMENT_TARGET "11.0" CACHE STRING "Minimum macOS version")
|
set(CMAKE_OSX_DEPLOYMENT_TARGET "11.0" CACHE STRING "Minimum macOS version")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# Find ZMQ using pkg-config with IMPORTED_TARGET for automatic target creation
|
# Use system ZeroMQ instead of building from source
|
||||||
find_package(PkgConfig REQUIRED)
|
find_package(PkgConfig REQUIRED)
|
||||||
|
pkg_check_modules(ZMQ REQUIRED libzmq)
|
||||||
# On ARM64 macOS, ensure pkg-config finds ARM64 Homebrew packages first
|
|
||||||
if(APPLE AND CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64|arm64")
|
|
||||||
set(ENV{PKG_CONFIG_PATH} "/opt/homebrew/lib/pkgconfig:/opt/homebrew/share/pkgconfig:$ENV{PKG_CONFIG_PATH}")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
pkg_check_modules(ZMQ REQUIRED IMPORTED_TARGET libzmq)
|
|
||||||
|
|
||||||
# This creates PkgConfig::ZMQ target automatically with correct properties
|
|
||||||
if(TARGET PkgConfig::ZMQ)
|
|
||||||
message(STATUS "Found and configured ZMQ target: PkgConfig::ZMQ")
|
|
||||||
else()
|
|
||||||
message(FATAL_ERROR "pkg_check_modules did not create IMPORTED target for ZMQ.")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
# Add cppzmq headers
|
# Add cppzmq headers
|
||||||
include_directories(SYSTEM third_party/cppzmq)
|
include_directories(third_party/cppzmq)
|
||||||
|
|
||||||
# Configure msgpack-c - disable boost dependency
|
# Configure msgpack-c - disable boost dependency
|
||||||
set(MSGPACK_USE_BOOST OFF CACHE BOOL "" FORCE)
|
set(MSGPACK_USE_BOOST OFF CACHE BOOL "" FORCE)
|
||||||
|
|||||||
@@ -215,8 +215,6 @@ class HNSWSearcher(BaseSearcher):
|
|||||||
if recompute_embeddings:
|
if recompute_embeddings:
|
||||||
if zmq_port is None:
|
if zmq_port is None:
|
||||||
raise ValueError("zmq_port must be provided if recompute_embeddings is True")
|
raise ValueError("zmq_port must be provided if recompute_embeddings is True")
|
||||||
if hasattr(self._index, "set_zmq_port"):
|
|
||||||
self._index.set_zmq_port(zmq_port)
|
|
||||||
|
|
||||||
if query.dtype != np.float32:
|
if query.dtype != np.float32:
|
||||||
query = query.astype(np.float32)
|
query = query.astype(np.float32)
|
||||||
|
|||||||
@@ -143,6 +143,8 @@ def create_hnsw_embedding_server(
|
|||||||
pass
|
pass
|
||||||
return str(nid)
|
return str(nid)
|
||||||
|
|
||||||
|
# (legacy ZMQ thread removed; using shutdown-capable server only)
|
||||||
|
|
||||||
def zmq_server_thread_with_shutdown(shutdown_event):
|
def zmq_server_thread_with_shutdown(shutdown_event):
|
||||||
"""ZMQ server thread that respects shutdown signal.
|
"""ZMQ server thread that respects shutdown signal.
|
||||||
|
|
||||||
@@ -156,31 +158,35 @@ def create_hnsw_embedding_server(
|
|||||||
rep_socket.bind(f"tcp://*:{zmq_port}")
|
rep_socket.bind(f"tcp://*:{zmq_port}")
|
||||||
logger.info(f"HNSW ZMQ REP server listening on port {zmq_port}")
|
logger.info(f"HNSW ZMQ REP server listening on port {zmq_port}")
|
||||||
rep_socket.setsockopt(zmq.RCVTIMEO, 1000)
|
rep_socket.setsockopt(zmq.RCVTIMEO, 1000)
|
||||||
|
# Keep sends from blocking during shutdown; fail fast and drop on close
|
||||||
rep_socket.setsockopt(zmq.SNDTIMEO, 1000)
|
rep_socket.setsockopt(zmq.SNDTIMEO, 1000)
|
||||||
rep_socket.setsockopt(zmq.LINGER, 0)
|
rep_socket.setsockopt(zmq.LINGER, 0)
|
||||||
|
|
||||||
last_request_type = "unknown"
|
# Track last request type/length for shape-correct fallbacks
|
||||||
|
last_request_type = "unknown" # 'text' | 'distance' | 'embedding' | 'unknown'
|
||||||
last_request_length = 0
|
last_request_length = 0
|
||||||
|
|
||||||
def _build_safe_fallback():
|
try:
|
||||||
if last_request_type == "distance":
|
while not shutdown_event.is_set():
|
||||||
large_distance = 1e9
|
try:
|
||||||
fallback_len = max(0, int(last_request_length))
|
|
||||||
return [[large_distance] * fallback_len]
|
|
||||||
if last_request_type == "embedding":
|
|
||||||
bsz = max(0, int(last_request_length))
|
|
||||||
dim = max(0, int(embedding_dim))
|
|
||||||
if dim > 0:
|
|
||||||
return [[bsz, dim], [0.0] * (bsz * dim)]
|
|
||||||
return [[0, 0], []]
|
|
||||||
if last_request_type == "text":
|
|
||||||
return []
|
|
||||||
return [[0, int(embedding_dim) if embedding_dim > 0 else 0], []]
|
|
||||||
|
|
||||||
def _handle_text_embedding(request: list[str]) -> None:
|
|
||||||
nonlocal last_request_type, last_request_length
|
|
||||||
|
|
||||||
e2e_start = time.time()
|
e2e_start = time.time()
|
||||||
|
logger.debug("🔍 Waiting for ZMQ message...")
|
||||||
|
request_bytes = rep_socket.recv()
|
||||||
|
|
||||||
|
# Rest of the processing logic (same as original)
|
||||||
|
request = msgpack.unpackb(request_bytes)
|
||||||
|
|
||||||
|
if len(request) == 1 and request[0] == "__QUERY_MODEL__":
|
||||||
|
response_bytes = msgpack.packb([model_name])
|
||||||
|
rep_socket.send(response_bytes)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Handle direct text embedding request
|
||||||
|
if (
|
||||||
|
isinstance(request, list)
|
||||||
|
and request
|
||||||
|
and all(isinstance(item, str) for item in request)
|
||||||
|
):
|
||||||
last_request_type = "text"
|
last_request_type = "text"
|
||||||
last_request_length = len(request)
|
last_request_length = len(request)
|
||||||
embeddings = compute_embeddings(
|
embeddings = compute_embeddings(
|
||||||
@@ -191,13 +197,18 @@ def create_hnsw_embedding_server(
|
|||||||
)
|
)
|
||||||
rep_socket.send(msgpack.packb(embeddings.tolist()))
|
rep_socket.send(msgpack.packb(embeddings.tolist()))
|
||||||
e2e_end = time.time()
|
e2e_end = time.time()
|
||||||
logger.info(f"⏱️ Direct text embedding E2E time: {e2e_end - e2e_start:.6f}s")
|
logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
|
||||||
|
continue
|
||||||
|
|
||||||
def _handle_distance_request(request: list[Any]) -> None:
|
# Handle distance calculation request: [[ids], [query_vector]]
|
||||||
nonlocal last_request_type, last_request_length
|
if (
|
||||||
|
isinstance(request, list)
|
||||||
e2e_start = time.time()
|
and len(request) == 2
|
||||||
|
and isinstance(request[0], list)
|
||||||
|
and isinstance(request[1], list)
|
||||||
|
):
|
||||||
node_ids = request[0]
|
node_ids = request[0]
|
||||||
|
# Handle nested [[ids]] shape defensively
|
||||||
if len(node_ids) == 1 and isinstance(node_ids[0], list):
|
if len(node_ids) == 1 and isinstance(node_ids[0], list):
|
||||||
node_ids = node_ids[0]
|
node_ids = node_ids[0]
|
||||||
query_vector = np.array(request[1], dtype=np.float32)
|
query_vector = np.array(request[1], dtype=np.float32)
|
||||||
@@ -208,6 +219,7 @@ def create_hnsw_embedding_server(
|
|||||||
logger.debug(f" Node IDs: {node_ids}")
|
logger.debug(f" Node IDs: {node_ids}")
|
||||||
logger.debug(f" Query vector dim: {len(query_vector)}")
|
logger.debug(f" Query vector dim: {len(query_vector)}")
|
||||||
|
|
||||||
|
# Gather texts for found ids
|
||||||
texts: list[str] = []
|
texts: list[str] = []
|
||||||
found_indices: list[int] = []
|
found_indices: list[int] = []
|
||||||
for idx, nid in enumerate(node_ids):
|
for idx, nid in enumerate(node_ids):
|
||||||
@@ -222,9 +234,10 @@ def create_hnsw_embedding_server(
|
|||||||
logger.error(f"Empty text for passage ID {passage_id}")
|
logger.error(f"Empty text for passage ID {passage_id}")
|
||||||
except KeyError:
|
except KeyError:
|
||||||
logger.error(f"Passage ID {nid} not found")
|
logger.error(f"Passage ID {nid} not found")
|
||||||
except Exception as exc:
|
except Exception as e:
|
||||||
logger.error(f"Exception looking up passage ID {nid}: {exc}")
|
logger.error(f"Exception looking up passage ID {nid}: {e}")
|
||||||
|
|
||||||
|
# Prepare full-length response with large sentinel values
|
||||||
large_distance = 1e9
|
large_distance = 1e9
|
||||||
response_distances = [large_distance] * len(node_ids)
|
response_distances = [large_distance] * len(node_ids)
|
||||||
|
|
||||||
@@ -243,33 +256,36 @@ def create_hnsw_embedding_server(
|
|||||||
partial = np.sum(
|
partial = np.sum(
|
||||||
np.square(embeddings - query_vector.reshape(1, -1)), axis=1
|
np.square(embeddings - query_vector.reshape(1, -1)), axis=1
|
||||||
)
|
)
|
||||||
else:
|
else: # mips or cosine
|
||||||
partial = -np.dot(embeddings, query_vector)
|
partial = -np.dot(embeddings, query_vector)
|
||||||
|
|
||||||
for pos, dval in zip(found_indices, partial.flatten().tolist()):
|
for pos, dval in zip(found_indices, partial.flatten().tolist()):
|
||||||
response_distances[pos] = float(dval)
|
response_distances[pos] = float(dval)
|
||||||
except Exception as exc:
|
except Exception as e:
|
||||||
logger.error(f"Distance computation error, using sentinels: {exc}")
|
logger.error(f"Distance computation error, using sentinels: {e}")
|
||||||
|
|
||||||
|
# Send response in expected shape [[distances]]
|
||||||
rep_socket.send(msgpack.packb([response_distances], use_single_float=True))
|
rep_socket.send(msgpack.packb([response_distances], use_single_float=True))
|
||||||
e2e_end = time.time()
|
e2e_end = time.time()
|
||||||
logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s")
|
logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s")
|
||||||
|
continue
|
||||||
|
|
||||||
def _handle_embedding_by_id(request: Any) -> None:
|
# Fallback: treat as embedding-by-id request
|
||||||
nonlocal last_request_type, last_request_length
|
if (
|
||||||
|
isinstance(request, list)
|
||||||
if isinstance(request, list) and len(request) == 1 and isinstance(request[0], list):
|
and len(request) == 1
|
||||||
|
and isinstance(request[0], list)
|
||||||
|
):
|
||||||
node_ids = request[0]
|
node_ids = request[0]
|
||||||
elif isinstance(request, list):
|
elif isinstance(request, list):
|
||||||
node_ids = request
|
node_ids = request
|
||||||
else:
|
else:
|
||||||
node_ids = []
|
node_ids = []
|
||||||
|
|
||||||
e2e_start = time.time()
|
|
||||||
last_request_type = "embedding"
|
last_request_type = "embedding"
|
||||||
last_request_length = len(node_ids)
|
last_request_length = len(node_ids)
|
||||||
logger.info(f"ZMQ received {len(node_ids)} node IDs for embedding fetch")
|
logger.info(f"ZMQ received {len(node_ids)} node IDs for embedding fetch")
|
||||||
|
|
||||||
|
# Preallocate zero-filled flat data for robustness
|
||||||
if embedding_dim <= 0:
|
if embedding_dim <= 0:
|
||||||
dims = [0, 0]
|
dims = [0, 0]
|
||||||
flat_data: list[float] = []
|
flat_data: list[float] = []
|
||||||
@@ -277,6 +293,7 @@ def create_hnsw_embedding_server(
|
|||||||
dims = [len(node_ids), embedding_dim]
|
dims = [len(node_ids), embedding_dim]
|
||||||
flat_data = [0.0] * (dims[0] * dims[1])
|
flat_data = [0.0] * (dims[0] * dims[1])
|
||||||
|
|
||||||
|
# Collect texts for found ids
|
||||||
texts: list[str] = []
|
texts: list[str] = []
|
||||||
found_indices: list[int] = []
|
found_indices: list[int] = []
|
||||||
for idx, nid in enumerate(node_ids):
|
for idx, nid in enumerate(node_ids):
|
||||||
@@ -291,8 +308,8 @@ def create_hnsw_embedding_server(
|
|||||||
logger.error(f"Empty text for passage ID {passage_id}")
|
logger.error(f"Empty text for passage ID {passage_id}")
|
||||||
except KeyError:
|
except KeyError:
|
||||||
logger.error(f"Passage with ID {nid} not found")
|
logger.error(f"Passage with ID {nid} not found")
|
||||||
except Exception as exc:
|
except Exception as e:
|
||||||
logger.error(f"Exception looking up passage ID {nid}: {exc}")
|
logger.error(f"Exception looking up passage ID {nid}: {e}")
|
||||||
|
|
||||||
if texts:
|
if texts:
|
||||||
try:
|
try:
|
||||||
@@ -322,72 +339,44 @@ def create_hnsw_embedding_server(
|
|||||||
flat_data[start:end] = flat[
|
flat_data[start:end] = flat[
|
||||||
j * embedding_dim : (j + 1) * embedding_dim
|
j * embedding_dim : (j + 1) * embedding_dim
|
||||||
]
|
]
|
||||||
except Exception as exc:
|
except Exception as e:
|
||||||
logger.error(f"Embedding computation error, returning zeros: {exc}")
|
logger.error(f"Embedding computation error, returning zeros: {e}")
|
||||||
|
|
||||||
response_payload = [dims, flat_data]
|
response_payload = [dims, flat_data]
|
||||||
rep_socket.send(msgpack.packb(response_payload, use_single_float=True))
|
response_bytes = msgpack.packb(response_payload, use_single_float=True)
|
||||||
|
|
||||||
|
rep_socket.send(response_bytes)
|
||||||
e2e_end = time.time()
|
e2e_end = time.time()
|
||||||
logger.info(f"⏱️ Fallback Embed by Id E2E time: {e2e_end - e2e_start:.6f}s")
|
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
|
||||||
|
|
||||||
try:
|
|
||||||
while not shutdown_event.is_set():
|
|
||||||
try:
|
|
||||||
logger.debug("🔍 Waiting for ZMQ message...")
|
|
||||||
request_bytes = rep_socket.recv()
|
|
||||||
except zmq.Again:
|
except zmq.Again:
|
||||||
|
# Timeout - check shutdown_event and continue
|
||||||
continue
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
if not shutdown_event.is_set():
|
||||||
|
logger.error(f"Error in ZMQ server loop: {e}")
|
||||||
|
# Shape-correct fallback
|
||||||
try:
|
try:
|
||||||
request = msgpack.unpackb(request_bytes)
|
if last_request_type == "distance":
|
||||||
except Exception as exc:
|
large_distance = 1e9
|
||||||
if shutdown_event.is_set():
|
fallback_len = max(0, int(last_request_length))
|
||||||
logger.info("Shutdown in progress, ignoring ZMQ error")
|
safe = [[large_distance] * fallback_len]
|
||||||
break
|
elif last_request_type == "embedding":
|
||||||
logger.error(f"Error unpacking ZMQ message: {exc}")
|
bsz = max(0, int(last_request_length))
|
||||||
try:
|
dim = max(0, int(embedding_dim))
|
||||||
safe = _build_safe_fallback()
|
safe = (
|
||||||
rep_socket.send(msgpack.packb(safe, use_single_float=True))
|
[[bsz, dim], [0.0] * (bsz * dim)] if dim > 0 else [[0, 0], []]
|
||||||
except Exception:
|
)
|
||||||
pass
|
elif last_request_type == "text":
|
||||||
continue
|
safe = [] # direct text embeddings expectation is a flat list
|
||||||
|
|
||||||
try:
|
|
||||||
# Model query
|
|
||||||
if (
|
|
||||||
isinstance(request, list)
|
|
||||||
and len(request) == 1
|
|
||||||
and request[0] == "__QUERY_MODEL__"
|
|
||||||
):
|
|
||||||
rep_socket.send(msgpack.packb([model_name]))
|
|
||||||
# Direct text embedding
|
|
||||||
elif (
|
|
||||||
isinstance(request, list)
|
|
||||||
and request
|
|
||||||
and all(isinstance(item, str) for item in request)
|
|
||||||
):
|
|
||||||
_handle_text_embedding(request)
|
|
||||||
# Distance calculation: [[ids], [query_vector]]
|
|
||||||
elif (
|
|
||||||
isinstance(request, list)
|
|
||||||
and len(request) == 2
|
|
||||||
and isinstance(request[0], list)
|
|
||||||
and isinstance(request[1], list)
|
|
||||||
):
|
|
||||||
_handle_distance_request(request)
|
|
||||||
# Embedding-by-id fallback
|
|
||||||
else:
|
else:
|
||||||
_handle_embedding_by_id(request)
|
safe = [[0, int(embedding_dim) if embedding_dim > 0 else 0], []]
|
||||||
except Exception as exc:
|
|
||||||
if shutdown_event.is_set():
|
|
||||||
logger.info("Shutdown in progress, ignoring ZMQ error")
|
|
||||||
break
|
|
||||||
logger.error(f"Error in ZMQ server loop: {exc}")
|
|
||||||
try:
|
|
||||||
safe = _build_safe_fallback()
|
|
||||||
rep_socket.send(msgpack.packb(safe, use_single_float=True))
|
rep_socket.send(msgpack.packb(safe, use_single_float=True))
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
else:
|
||||||
|
logger.info("Shutdown in progress, ignoring ZMQ error")
|
||||||
|
break
|
||||||
finally:
|
finally:
|
||||||
try:
|
try:
|
||||||
rep_socket.close(0)
|
rep_socket.close(0)
|
||||||
|
|||||||
Submodule packages/leann-backend-hnsw/third_party/faiss updated: 301bf24f14...5952745237
@@ -18,16 +18,14 @@ dependencies = [
|
|||||||
"pyzmq>=23.0.0",
|
"pyzmq>=23.0.0",
|
||||||
"msgpack>=1.0.0",
|
"msgpack>=1.0.0",
|
||||||
"torch>=2.0.0",
|
"torch>=2.0.0",
|
||||||
"sentence-transformers>=3.0.0",
|
"sentence-transformers>=2.2.0",
|
||||||
"llama-index-core>=0.12.0",
|
"llama-index-core>=0.12.0",
|
||||||
"llama-index-readers-file>=0.4.0", # Essential for document reading
|
"llama-index-readers-file>=0.4.0", # Essential for document reading
|
||||||
"llama-index-embeddings-huggingface>=0.5.5", # For embeddings
|
"llama-index-embeddings-huggingface>=0.5.5", # For embeddings
|
||||||
"python-dotenv>=1.0.0",
|
"python-dotenv>=1.0.0",
|
||||||
"openai>=1.0.0",
|
"openai>=1.0.0",
|
||||||
"huggingface-hub>=0.20.0",
|
"huggingface-hub>=0.20.0",
|
||||||
# Keep transformers below 4.46: 4.46.0 adds Python 3.10-only return type syntax and
|
"transformers>=4.30.0",
|
||||||
# breaks Python 3.9 environments.
|
|
||||||
"transformers>=4.30.0,<4.46",
|
|
||||||
"requests>=2.25.0",
|
"requests>=2.25.0",
|
||||||
"accelerate>=0.20.0",
|
"accelerate>=0.20.0",
|
||||||
"PyPDF2>=3.0.0",
|
"PyPDF2>=3.0.0",
|
||||||
@@ -42,7 +40,7 @@ dependencies = [
|
|||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
colab = [
|
colab = [
|
||||||
"torch>=2.0.0,<3.0.0", # Limit torch version to avoid conflicts
|
"torch>=2.0.0,<3.0.0", # Limit torch version to avoid conflicts
|
||||||
"transformers>=4.30.0,<4.46", # 4.46.0 switches to PEP 604 typing (int | None), breaks Py3.9
|
"transformers>=4.30.0,<5.0.0", # Limit transformers version
|
||||||
"accelerate>=0.20.0,<1.0.0", # Limit accelerate version
|
"accelerate>=0.20.0,<1.0.0", # Limit accelerate version
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ from typing import Any, Literal, Optional, Union
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from leann_backend_hnsw.convert_to_csr import prune_hnsw_embeddings_inplace
|
from leann_backend_hnsw.convert_to_csr import prune_hnsw_embeddings_inplace
|
||||||
|
|
||||||
from leann.interactive_utils import create_api_session
|
|
||||||
from leann.interface import LeannBackendSearcherInterface
|
from leann.interface import LeannBackendSearcherInterface
|
||||||
|
|
||||||
from .chat import get_llm
|
from .chat import get_llm
|
||||||
@@ -814,16 +813,11 @@ class LeannBuilder:
|
|||||||
"Failed to start HNSW embedding server for recompute update."
|
"Failed to start HNSW embedding server for recompute update."
|
||||||
)
|
)
|
||||||
if actual_port != requested_zmq_port:
|
if actual_port != requested_zmq_port:
|
||||||
logger.warning(
|
server_manager.stop_server()
|
||||||
"Embedding server started on port %s instead of requested %s. "
|
raise RuntimeError(
|
||||||
"Using reassigned port.",
|
"Embedding server started on unexpected port "
|
||||||
actual_port,
|
f"{actual_port}; expected {requested_zmq_port}. Make sure the desired ZMQ port is free."
|
||||||
requested_zmq_port,
|
|
||||||
)
|
)
|
||||||
if hasattr(index.hnsw, "set_zmq_port"):
|
|
||||||
index.hnsw.set_zmq_port(actual_port)
|
|
||||||
elif hasattr(index, "set_zmq_port"):
|
|
||||||
index.set_zmq_port(actual_port)
|
|
||||||
|
|
||||||
if needs_recompute:
|
if needs_recompute:
|
||||||
for i in range(embeddings.shape[0]):
|
for i in range(embeddings.shape[0]):
|
||||||
@@ -864,13 +858,7 @@ class LeannBuilder:
|
|||||||
|
|
||||||
|
|
||||||
class LeannSearcher:
|
class LeannSearcher:
|
||||||
def __init__(
|
def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
|
||||||
self,
|
|
||||||
index_path: str,
|
|
||||||
enable_warmup: bool = True,
|
|
||||||
recompute_embeddings: bool = True,
|
|
||||||
**backend_kwargs,
|
|
||||||
):
|
|
||||||
# Fix path resolution for Colab and other environments
|
# Fix path resolution for Colab and other environments
|
||||||
if not Path(index_path).is_absolute():
|
if not Path(index_path).is_absolute():
|
||||||
index_path = str(Path(index_path).resolve())
|
index_path = str(Path(index_path).resolve())
|
||||||
@@ -901,32 +889,14 @@ class LeannSearcher:
|
|||||||
backend_factory = BACKEND_REGISTRY.get(backend_name)
|
backend_factory = BACKEND_REGISTRY.get(backend_name)
|
||||||
if backend_factory is None:
|
if backend_factory is None:
|
||||||
raise ValueError(f"Backend '{backend_name}' not found.")
|
raise ValueError(f"Backend '{backend_name}' not found.")
|
||||||
|
|
||||||
# Global recompute flag for this searcher (explicit knob, default True)
|
|
||||||
self.recompute_embeddings: bool = bool(recompute_embeddings)
|
|
||||||
|
|
||||||
# Warmup flag: keep using the existing enable_warmup parameter,
|
|
||||||
# but default it to True so cold-start happens earlier.
|
|
||||||
self._warmup: bool = bool(enable_warmup)
|
|
||||||
|
|
||||||
final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
|
final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
|
||||||
final_kwargs["enable_warmup"] = self._warmup
|
final_kwargs["enable_warmup"] = enable_warmup
|
||||||
if self.embedding_options:
|
if self.embedding_options:
|
||||||
final_kwargs.setdefault("embedding_options", self.embedding_options)
|
final_kwargs.setdefault("embedding_options", self.embedding_options)
|
||||||
self.backend_impl: LeannBackendSearcherInterface = backend_factory.searcher(
|
self.backend_impl: LeannBackendSearcherInterface = backend_factory.searcher(
|
||||||
index_path, **final_kwargs
|
index_path, **final_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
# Optional one-shot warmup at construction time to hide cold-start latency.
|
|
||||||
if self._warmup:
|
|
||||||
try:
|
|
||||||
_ = self.backend_impl.compute_query_embedding(
|
|
||||||
"__LEANN_WARMUP__",
|
|
||||||
use_server_if_available=self.recompute_embeddings,
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning(f"Warmup embedding failed (ignored): {exc}")
|
|
||||||
|
|
||||||
def search(
|
def search(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
@@ -934,7 +904,7 @@ class LeannSearcher:
|
|||||||
complexity: int = 64,
|
complexity: int = 64,
|
||||||
beam_width: int = 1,
|
beam_width: int = 1,
|
||||||
prune_ratio: float = 0.0,
|
prune_ratio: float = 0.0,
|
||||||
recompute_embeddings: Optional[bool] = None,
|
recompute_embeddings: bool = True,
|
||||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||||
expected_zmq_port: int = 5557,
|
expected_zmq_port: int = 5557,
|
||||||
metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
|
metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
|
||||||
@@ -951,8 +921,7 @@ class LeannSearcher:
|
|||||||
complexity: Search complexity/candidate list size, higher = more accurate but slower
|
complexity: Search complexity/candidate list size, higher = more accurate but slower
|
||||||
beam_width: Number of parallel search paths/IO requests per iteration
|
beam_width: Number of parallel search paths/IO requests per iteration
|
||||||
prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
|
prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
|
||||||
recompute_embeddings: (Deprecated) Per-call override for recompute mode.
|
recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored codes
|
||||||
Configure this at LeannSearcher(..., recompute_embeddings=...) instead.
|
|
||||||
pruning_strategy: Candidate selection strategy - "global" (default), "local", or "proportional"
|
pruning_strategy: Candidate selection strategy - "global" (default), "local", or "proportional"
|
||||||
expected_zmq_port: ZMQ port for embedding server communication
|
expected_zmq_port: ZMQ port for embedding server communication
|
||||||
metadata_filters: Optional filters to apply to search results based on metadata.
|
metadata_filters: Optional filters to apply to search results based on metadata.
|
||||||
@@ -991,19 +960,8 @@ class LeannSearcher:
|
|||||||
|
|
||||||
zmq_port = None
|
zmq_port = None
|
||||||
|
|
||||||
# Resolve effective recompute flag for this search.
|
|
||||||
if recompute_embeddings is not None:
|
|
||||||
logger.warning(
|
|
||||||
"LeannSearcher.search(..., recompute_embeddings=...) is deprecated and "
|
|
||||||
"will be removed in a future version. Configure recompute at "
|
|
||||||
"LeannSearcher(..., recompute_embeddings=...) instead."
|
|
||||||
)
|
|
||||||
effective_recompute = bool(recompute_embeddings)
|
|
||||||
else:
|
|
||||||
effective_recompute = self.recompute_embeddings
|
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
if effective_recompute:
|
if recompute_embeddings:
|
||||||
zmq_port = self.backend_impl._ensure_server_running(
|
zmq_port = self.backend_impl._ensure_server_running(
|
||||||
self.meta_path_str,
|
self.meta_path_str,
|
||||||
port=expected_zmq_port,
|
port=expected_zmq_port,
|
||||||
@@ -1017,7 +975,7 @@ class LeannSearcher:
|
|||||||
|
|
||||||
query_embedding = self.backend_impl.compute_query_embedding(
|
query_embedding = self.backend_impl.compute_query_embedding(
|
||||||
query,
|
query,
|
||||||
use_server_if_available=effective_recompute,
|
use_server_if_available=recompute_embeddings,
|
||||||
zmq_port=zmq_port,
|
zmq_port=zmq_port,
|
||||||
)
|
)
|
||||||
logger.info(f" Generated embedding shape: {query_embedding.shape}")
|
logger.info(f" Generated embedding shape: {query_embedding.shape}")
|
||||||
@@ -1029,7 +987,7 @@ class LeannSearcher:
|
|||||||
"complexity": complexity,
|
"complexity": complexity,
|
||||||
"beam_width": beam_width,
|
"beam_width": beam_width,
|
||||||
"prune_ratio": prune_ratio,
|
"prune_ratio": prune_ratio,
|
||||||
"recompute_embeddings": effective_recompute,
|
"recompute_embeddings": recompute_embeddings,
|
||||||
"pruning_strategy": pruning_strategy,
|
"pruning_strategy": pruning_strategy,
|
||||||
"zmq_port": zmq_port,
|
"zmq_port": zmq_port,
|
||||||
}
|
}
|
||||||
@@ -1272,17 +1230,6 @@ class LeannChat:
|
|||||||
"Please provide the best answer you can based on this context and your knowledge."
|
"Please provide the best answer you can based on this context and your knowledge."
|
||||||
)
|
)
|
||||||
|
|
||||||
print("The context provided to the LLM is:")
|
|
||||||
print(f"{'Relevance':<10} | {'Chunk id':<10} | {'Content':<60} | {'Source':<80}")
|
|
||||||
print("-" * 150)
|
|
||||||
for r in results:
|
|
||||||
chunk_relevance = f"{r.score:.3f}"
|
|
||||||
chunk_id = r.id
|
|
||||||
chunk_content = r.text[:60]
|
|
||||||
chunk_source = r.metadata.get("source", "")[:80]
|
|
||||||
print(
|
|
||||||
f"{chunk_relevance:<10} | {chunk_id:<10} | {chunk_content:<60} | {chunk_source:<80}"
|
|
||||||
)
|
|
||||||
ask_time = time.time()
|
ask_time = time.time()
|
||||||
ans = self.llm.ask(prompt, **llm_kwargs)
|
ans = self.llm.ask(prompt, **llm_kwargs)
|
||||||
ask_time = time.time() - ask_time
|
ask_time = time.time() - ask_time
|
||||||
@@ -1290,14 +1237,19 @@ class LeannChat:
|
|||||||
return ans
|
return ans
|
||||||
|
|
||||||
def start_interactive(self):
|
def start_interactive(self):
|
||||||
"""Start interactive chat session."""
|
print("\nLeann Chat started (type 'quit' to exit)")
|
||||||
session = create_api_session()
|
while True:
|
||||||
|
try:
|
||||||
def handle_query(user_input: str):
|
user_input = input("You: ").strip()
|
||||||
|
if user_input.lower() in ["quit", "exit"]:
|
||||||
|
break
|
||||||
|
if not user_input:
|
||||||
|
continue
|
||||||
response = self.ask(user_input)
|
response = self.ask(user_input)
|
||||||
print(f"Leann: {response}")
|
print(f"Leann: {response}")
|
||||||
|
except (KeyboardInterrupt, EOFError):
|
||||||
session.run_interactive_loop(handle_query)
|
print("\nGoodbye!")
|
||||||
|
break
|
||||||
|
|
||||||
def cleanup(self):
|
def cleanup(self):
|
||||||
"""Explicitly cleanup embedding server resources.
|
"""Explicitly cleanup embedding server resources.
|
||||||
|
|||||||
@@ -546,30 +546,11 @@ class OllamaChat(LLMInterface):
|
|||||||
|
|
||||||
|
|
||||||
class HFChat(LLMInterface):
|
class HFChat(LLMInterface):
|
||||||
"""LLM interface for local Hugging Face Transformers models with proper chat templates.
|
"""LLM interface for local Hugging Face Transformers models with proper chat templates."""
|
||||||
|
|
||||||
Args:
|
def __init__(self, model_name: str = "deepseek-ai/deepseek-llm-7b-chat"):
|
||||||
model_name (str): Name of the Hugging Face model to load.
|
|
||||||
trust_remote_code (bool): Whether to allow execution of code from the model repository.
|
|
||||||
Defaults to False for security. Only enable for trusted models as this can pose
|
|
||||||
a security risk if the model repository is compromised.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, model_name: str = "deepseek-ai/deepseek-llm-7b-chat", trust_remote_code: bool = False
|
|
||||||
):
|
|
||||||
logger.info(f"Initializing HFChat with model='{model_name}'")
|
logger.info(f"Initializing HFChat with model='{model_name}'")
|
||||||
|
|
||||||
# Security warning when trust_remote_code is enabled
|
|
||||||
if trust_remote_code:
|
|
||||||
logger.warning(
|
|
||||||
"SECURITY WARNING: trust_remote_code=True allows execution of arbitrary code from the model repository. "
|
|
||||||
"Only enable this for models from trusted sources. This creates a potential security risk if the model "
|
|
||||||
"repository is compromised."
|
|
||||||
)
|
|
||||||
|
|
||||||
self.trust_remote_code = trust_remote_code
|
|
||||||
|
|
||||||
# Pre-check model availability with helpful suggestions
|
# Pre-check model availability with helpful suggestions
|
||||||
model_error = validate_model_and_suggest(model_name, "hf")
|
model_error = validate_model_and_suggest(model_name, "hf")
|
||||||
if model_error:
|
if model_error:
|
||||||
@@ -607,16 +588,14 @@ class HFChat(LLMInterface):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info(f"Loading tokenizer for {model_name}...")
|
logger.info(f"Loading tokenizer for {model_name}...")
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
model_name, trust_remote_code=self.trust_remote_code
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"Loading model {model_name}...")
|
logger.info(f"Loading model {model_name}...")
|
||||||
self.model = AutoModelForCausalLM.from_pretrained(
|
self.model = AutoModelForCausalLM.from_pretrained(
|
||||||
model_name,
|
model_name,
|
||||||
torch_dtype=torch.float16 if self.device != "cpu" else torch.float32,
|
torch_dtype=torch.float16 if self.device != "cpu" else torch.float32,
|
||||||
device_map="auto" if self.device != "cpu" else None,
|
device_map="auto" if self.device != "cpu" else None,
|
||||||
trust_remote_code=self.trust_remote_code,
|
trust_remote_code=True,
|
||||||
)
|
)
|
||||||
logger.info(f"Successfully loaded {model_name}")
|
logger.info(f"Successfully loaded {model_name}")
|
||||||
finally:
|
finally:
|
||||||
@@ -834,11 +813,6 @@ class OpenAIChat(LLMInterface):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
response = self.client.chat.completions.create(**params)
|
response = self.client.chat.completions.create(**params)
|
||||||
print(
|
|
||||||
f"Total tokens = {response.usage.total_tokens}, prompt tokens = {response.usage.prompt_tokens}, completion tokens = {response.usage.completion_tokens}"
|
|
||||||
)
|
|
||||||
if response.choices[0].finish_reason == "length":
|
|
||||||
print("The query is exceeding the maximum allowed number of tokens")
|
|
||||||
return response.choices[0].message.content.strip()
|
return response.choices[0].message.content.strip()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error communicating with OpenAI: {e}")
|
logger.error(f"Error communicating with OpenAI: {e}")
|
||||||
@@ -885,10 +859,7 @@ def get_llm(llm_config: Optional[dict[str, Any]] = None) -> LLMInterface:
|
|||||||
host=llm_config.get("host"),
|
host=llm_config.get("host"),
|
||||||
)
|
)
|
||||||
elif llm_type == "hf":
|
elif llm_type == "hf":
|
||||||
return HFChat(
|
return HFChat(model_name=model or "deepseek-ai/deepseek-llm-7b-chat")
|
||||||
model_name=model or "deepseek-ai/deepseek-llm-7b-chat",
|
|
||||||
trust_remote_code=llm_config.get("trust_remote_code", False),
|
|
||||||
)
|
|
||||||
elif llm_type == "openai":
|
elif llm_type == "openai":
|
||||||
return OpenAIChat(
|
return OpenAIChat(
|
||||||
model=model or "gpt-4o",
|
model=model or "gpt-4o",
|
||||||
|
|||||||
@@ -5,128 +5,12 @@ Packaged within leann-core so installed wheels can import it reliably.
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Optional
|
from typing import Optional
|
||||||
|
|
||||||
from llama_index.core.node_parser import SentenceSplitter
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Flag to ensure AST token warning only shown once per session
|
|
||||||
_ast_token_warning_shown = False
|
|
||||||
|
|
||||||
|
|
||||||
def estimate_token_count(text: str) -> int:
|
|
||||||
"""
|
|
||||||
Estimate token count for a text string.
|
|
||||||
Uses conservative estimation: ~4 characters per token for natural text,
|
|
||||||
~1.2 tokens per character for code (worse tokenization).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text: Input text to estimate tokens for
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Estimated token count
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
import tiktoken
|
|
||||||
|
|
||||||
encoder = tiktoken.get_encoding("cl100k_base")
|
|
||||||
return len(encoder.encode(text))
|
|
||||||
except ImportError:
|
|
||||||
# Fallback: Conservative character-based estimation
|
|
||||||
# Assume worst case for code: 1.2 tokens per character
|
|
||||||
return int(len(text) * 1.2)
|
|
||||||
|
|
||||||
|
|
||||||
def calculate_safe_chunk_size(
|
|
||||||
model_token_limit: int,
|
|
||||||
overlap_tokens: int,
|
|
||||||
chunking_mode: str = "traditional",
|
|
||||||
safety_factor: float = 0.9,
|
|
||||||
) -> int:
|
|
||||||
"""
|
|
||||||
Calculate safe chunk size accounting for overlap and safety margin.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_token_limit: Maximum tokens supported by embedding model
|
|
||||||
overlap_tokens: Overlap size (tokens for traditional, chars for AST)
|
|
||||||
chunking_mode: "traditional" (tokens) or "ast" (characters)
|
|
||||||
safety_factor: Safety margin (0.9 = 10% safety margin)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Safe chunk size: tokens for traditional, characters for AST
|
|
||||||
"""
|
|
||||||
safe_limit = int(model_token_limit * safety_factor)
|
|
||||||
|
|
||||||
if chunking_mode == "traditional":
|
|
||||||
# Traditional chunking uses tokens
|
|
||||||
# Max chunk = chunk_size + overlap, so chunk_size = limit - overlap
|
|
||||||
return max(1, safe_limit - overlap_tokens)
|
|
||||||
else: # AST chunking
|
|
||||||
# AST uses characters, need to convert
|
|
||||||
# Conservative estimate: 1.2 tokens per char for code
|
|
||||||
overlap_chars = int(overlap_tokens * 3) # ~3 chars per token for code
|
|
||||||
safe_chars = int(safe_limit / 1.2)
|
|
||||||
return max(1, safe_chars - overlap_chars)
|
|
||||||
|
|
||||||
|
|
||||||
def validate_chunk_token_limits(chunks: list[str], max_tokens: int = 512) -> tuple[list[str], int]:
|
|
||||||
"""
|
|
||||||
Validate that chunks don't exceed token limits and truncate if necessary.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
chunks: List of text chunks to validate
|
|
||||||
max_tokens: Maximum tokens allowed per chunk
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (validated_chunks, num_truncated)
|
|
||||||
"""
|
|
||||||
validated_chunks = []
|
|
||||||
num_truncated = 0
|
|
||||||
|
|
||||||
for i, chunk in enumerate(chunks):
|
|
||||||
estimated_tokens = estimate_token_count(chunk)
|
|
||||||
|
|
||||||
if estimated_tokens > max_tokens:
|
|
||||||
# Truncate chunk to fit token limit
|
|
||||||
try:
|
|
||||||
import tiktoken
|
|
||||||
|
|
||||||
encoder = tiktoken.get_encoding("cl100k_base")
|
|
||||||
tokens = encoder.encode(chunk)
|
|
||||||
if len(tokens) > max_tokens:
|
|
||||||
truncated_tokens = tokens[:max_tokens]
|
|
||||||
truncated_chunk = encoder.decode(truncated_tokens)
|
|
||||||
validated_chunks.append(truncated_chunk)
|
|
||||||
num_truncated += 1
|
|
||||||
logger.warning(
|
|
||||||
f"Truncated chunk {i} from {len(tokens)} to {max_tokens} tokens "
|
|
||||||
f"(from {len(chunk)} to {len(truncated_chunk)} characters)"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
validated_chunks.append(chunk)
|
|
||||||
except ImportError:
|
|
||||||
# Fallback: Conservative character truncation
|
|
||||||
char_limit = int(max_tokens / 1.2) # Conservative for code
|
|
||||||
if len(chunk) > char_limit:
|
|
||||||
truncated_chunk = chunk[:char_limit]
|
|
||||||
validated_chunks.append(truncated_chunk)
|
|
||||||
num_truncated += 1
|
|
||||||
logger.warning(
|
|
||||||
f"Truncated chunk {i} from {len(chunk)} to {char_limit} characters "
|
|
||||||
f"(conservative estimate for {max_tokens} tokens)"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
validated_chunks.append(chunk)
|
|
||||||
else:
|
|
||||||
validated_chunks.append(chunk)
|
|
||||||
|
|
||||||
if num_truncated > 0:
|
|
||||||
logger.warning(f"Truncated {num_truncated}/{len(chunks)} chunks to fit token limits")
|
|
||||||
|
|
||||||
return validated_chunks, num_truncated
|
|
||||||
|
|
||||||
|
|
||||||
# Code file extensions supported by astchunk
|
# Code file extensions supported by astchunk
|
||||||
CODE_EXTENSIONS = {
|
CODE_EXTENSIONS = {
|
||||||
".py": "python",
|
".py": "python",
|
||||||
@@ -177,45 +61,27 @@ def create_ast_chunks(
|
|||||||
max_chunk_size: int = 512,
|
max_chunk_size: int = 512,
|
||||||
chunk_overlap: int = 64,
|
chunk_overlap: int = 64,
|
||||||
metadata_template: str = "default",
|
metadata_template: str = "default",
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[str]:
|
||||||
"""Create AST-aware chunks from code documents using astchunk.
|
"""Create AST-aware chunks from code documents using astchunk.
|
||||||
|
|
||||||
Falls back to traditional chunking if astchunk is unavailable.
|
Falls back to traditional chunking if astchunk is unavailable.
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of dicts with {"text": str, "metadata": dict}
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
from astchunk import ASTChunkBuilder # optional dependency
|
from astchunk import ASTChunkBuilder # optional dependency
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
logger.error(f"astchunk not available: {e}")
|
logger.error(f"astchunk not available: {e}")
|
||||||
logger.info("Falling back to traditional chunking for code files")
|
logger.info("Falling back to traditional chunking for code files")
|
||||||
return _traditional_chunks_as_dicts(documents, max_chunk_size, chunk_overlap)
|
return create_traditional_chunks(documents, max_chunk_size, chunk_overlap)
|
||||||
|
|
||||||
all_chunks = []
|
all_chunks = []
|
||||||
for doc in documents:
|
for doc in documents:
|
||||||
language = doc.metadata.get("language")
|
language = doc.metadata.get("language")
|
||||||
if not language:
|
if not language:
|
||||||
logger.warning("No language detected; falling back to traditional chunking")
|
logger.warning("No language detected; falling back to traditional chunking")
|
||||||
all_chunks.extend(_traditional_chunks_as_dicts([doc], max_chunk_size, chunk_overlap))
|
all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Warn once if AST chunk size + overlap might exceed common token limits
|
|
||||||
# Note: Actual truncation happens at embedding time with dynamic model limits
|
|
||||||
global _ast_token_warning_shown
|
|
||||||
estimated_max_tokens = int(
|
|
||||||
(max_chunk_size + chunk_overlap) * 1.2
|
|
||||||
) # Conservative estimate
|
|
||||||
if estimated_max_tokens > 512 and not _ast_token_warning_shown:
|
|
||||||
logger.warning(
|
|
||||||
f"AST chunk size ({max_chunk_size}) + overlap ({chunk_overlap}) = {max_chunk_size + chunk_overlap} chars "
|
|
||||||
f"may exceed 512 token limit (~{estimated_max_tokens} tokens estimated). "
|
|
||||||
f"Consider reducing --ast-chunk-size to {int(400 / 1.2)} or --ast-chunk-overlap to {int(50 / 1.2)}. "
|
|
||||||
f"Note: Chunks will be auto-truncated at embedding time based on your model's actual token limit."
|
|
||||||
)
|
|
||||||
_ast_token_warning_shown = True
|
|
||||||
|
|
||||||
configs = {
|
configs = {
|
||||||
"max_chunk_size": max_chunk_size,
|
"max_chunk_size": max_chunk_size,
|
||||||
"language": language,
|
"language": language,
|
||||||
@@ -239,40 +105,17 @@ def create_ast_chunks(
|
|||||||
|
|
||||||
chunks = chunk_builder.chunkify(code_content)
|
chunks = chunk_builder.chunkify(code_content)
|
||||||
for chunk in chunks:
|
for chunk in chunks:
|
||||||
chunk_text = None
|
|
||||||
astchunk_metadata = {}
|
|
||||||
|
|
||||||
if hasattr(chunk, "text"):
|
if hasattr(chunk, "text"):
|
||||||
chunk_text = chunk.text
|
chunk_text = chunk.text
|
||||||
|
elif isinstance(chunk, dict) and "text" in chunk:
|
||||||
|
chunk_text = chunk["text"]
|
||||||
elif isinstance(chunk, str):
|
elif isinstance(chunk, str):
|
||||||
chunk_text = chunk
|
chunk_text = chunk
|
||||||
elif isinstance(chunk, dict):
|
|
||||||
# Handle astchunk format: {"content": "...", "metadata": {...}}
|
|
||||||
if "content" in chunk:
|
|
||||||
chunk_text = chunk["content"]
|
|
||||||
astchunk_metadata = chunk.get("metadata", {})
|
|
||||||
elif "text" in chunk:
|
|
||||||
chunk_text = chunk["text"]
|
|
||||||
else:
|
|
||||||
chunk_text = str(chunk) # Last resort
|
|
||||||
else:
|
else:
|
||||||
chunk_text = str(chunk)
|
chunk_text = str(chunk)
|
||||||
|
|
||||||
if chunk_text and chunk_text.strip():
|
if chunk_text and chunk_text.strip():
|
||||||
# Extract document-level metadata
|
all_chunks.append(chunk_text.strip())
|
||||||
doc_metadata = {
|
|
||||||
"file_path": doc.metadata.get("file_path", ""),
|
|
||||||
"file_name": doc.metadata.get("file_name", ""),
|
|
||||||
}
|
|
||||||
if "creation_date" in doc.metadata:
|
|
||||||
doc_metadata["creation_date"] = doc.metadata["creation_date"]
|
|
||||||
if "last_modified_date" in doc.metadata:
|
|
||||||
doc_metadata["last_modified_date"] = doc.metadata["last_modified_date"]
|
|
||||||
|
|
||||||
# Merge document metadata + astchunk metadata
|
|
||||||
combined_metadata = {**doc_metadata, **astchunk_metadata}
|
|
||||||
|
|
||||||
all_chunks.append({"text": chunk_text.strip(), "metadata": combined_metadata})
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Created {len(chunks)} AST chunks from {language} file: {doc.metadata.get('file_name', 'unknown')}"
|
f"Created {len(chunks)} AST chunks from {language} file: {doc.metadata.get('file_name', 'unknown')}"
|
||||||
@@ -280,19 +123,15 @@ def create_ast_chunks(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"AST chunking failed for {language} file: {e}")
|
logger.warning(f"AST chunking failed for {language} file: {e}")
|
||||||
logger.info("Falling back to traditional chunking")
|
logger.info("Falling back to traditional chunking")
|
||||||
all_chunks.extend(_traditional_chunks_as_dicts([doc], max_chunk_size, chunk_overlap))
|
all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap))
|
||||||
|
|
||||||
return all_chunks
|
return all_chunks
|
||||||
|
|
||||||
|
|
||||||
def create_traditional_chunks(
|
def create_traditional_chunks(
|
||||||
documents, chunk_size: int = 256, chunk_overlap: int = 128
|
documents, chunk_size: int = 256, chunk_overlap: int = 128
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[str]:
|
||||||
"""Create traditional text chunks using LlamaIndex SentenceSplitter.
|
"""Create traditional text chunks using LlamaIndex SentenceSplitter."""
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of dicts with {"text": str, "metadata": dict}
|
|
||||||
"""
|
|
||||||
if chunk_size <= 0:
|
if chunk_size <= 0:
|
||||||
logger.warning(f"Invalid chunk_size={chunk_size}, using default value of 256")
|
logger.warning(f"Invalid chunk_size={chunk_size}, using default value of 256")
|
||||||
chunk_size = 256
|
chunk_size = 256
|
||||||
@@ -308,40 +147,19 @@ def create_traditional_chunks(
|
|||||||
paragraph_separator="\n\n",
|
paragraph_separator="\n\n",
|
||||||
)
|
)
|
||||||
|
|
||||||
result = []
|
all_texts = []
|
||||||
for doc in documents:
|
for doc in documents:
|
||||||
# Extract document-level metadata
|
|
||||||
doc_metadata = {
|
|
||||||
"file_path": doc.metadata.get("file_path", ""),
|
|
||||||
"file_name": doc.metadata.get("file_name", ""),
|
|
||||||
}
|
|
||||||
if "creation_date" in doc.metadata:
|
|
||||||
doc_metadata["creation_date"] = doc.metadata["creation_date"]
|
|
||||||
if "last_modified_date" in doc.metadata:
|
|
||||||
doc_metadata["last_modified_date"] = doc.metadata["last_modified_date"]
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
nodes = node_parser.get_nodes_from_documents([doc])
|
nodes = node_parser.get_nodes_from_documents([doc])
|
||||||
if nodes:
|
if nodes:
|
||||||
for node in nodes:
|
all_texts.extend(node.get_content() for node in nodes)
|
||||||
result.append({"text": node.get_content(), "metadata": doc_metadata})
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Traditional chunking failed for document: {e}")
|
logger.error(f"Traditional chunking failed for document: {e}")
|
||||||
content = doc.get_content()
|
content = doc.get_content()
|
||||||
if content and content.strip():
|
if content and content.strip():
|
||||||
result.append({"text": content.strip(), "metadata": doc_metadata})
|
all_texts.append(content.strip())
|
||||||
|
|
||||||
return result
|
return all_texts
|
||||||
|
|
||||||
|
|
||||||
def _traditional_chunks_as_dicts(
|
|
||||||
documents, chunk_size: int = 256, chunk_overlap: int = 128
|
|
||||||
) -> list[dict[str, Any]]:
|
|
||||||
"""Helper: Traditional chunking that returns dict format for consistency.
|
|
||||||
|
|
||||||
This is now just an alias for create_traditional_chunks for backwards compatibility.
|
|
||||||
"""
|
|
||||||
return create_traditional_chunks(documents, chunk_size, chunk_overlap)
|
|
||||||
|
|
||||||
|
|
||||||
def create_text_chunks(
|
def create_text_chunks(
|
||||||
@@ -353,12 +171,8 @@ def create_text_chunks(
|
|||||||
ast_chunk_overlap: int = 64,
|
ast_chunk_overlap: int = 64,
|
||||||
code_file_extensions: Optional[list[str]] = None,
|
code_file_extensions: Optional[list[str]] = None,
|
||||||
ast_fallback_traditional: bool = True,
|
ast_fallback_traditional: bool = True,
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[str]:
|
||||||
"""Create text chunks from documents with optional AST support for code files.
|
"""Create text chunks from documents with optional AST support for code files."""
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of dicts with {"text": str, "metadata": dict}
|
|
||||||
"""
|
|
||||||
if not documents:
|
if not documents:
|
||||||
logger.warning("No documents provided for chunking")
|
logger.warning("No documents provided for chunking")
|
||||||
return []
|
return []
|
||||||
@@ -393,17 +207,14 @@ def create_text_chunks(
|
|||||||
logger.error(f"AST chunking failed: {e}")
|
logger.error(f"AST chunking failed: {e}")
|
||||||
if ast_fallback_traditional:
|
if ast_fallback_traditional:
|
||||||
all_chunks.extend(
|
all_chunks.extend(
|
||||||
_traditional_chunks_as_dicts(code_docs, chunk_size, chunk_overlap)
|
create_traditional_chunks(code_docs, chunk_size, chunk_overlap)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
if text_docs:
|
if text_docs:
|
||||||
all_chunks.extend(_traditional_chunks_as_dicts(text_docs, chunk_size, chunk_overlap))
|
all_chunks.extend(create_traditional_chunks(text_docs, chunk_size, chunk_overlap))
|
||||||
else:
|
else:
|
||||||
all_chunks = _traditional_chunks_as_dicts(documents, chunk_size, chunk_overlap)
|
all_chunks = create_traditional_chunks(documents, chunk_size, chunk_overlap)
|
||||||
|
|
||||||
logger.info(f"Total chunks created: {len(all_chunks)}")
|
logger.info(f"Total chunks created: {len(all_chunks)}")
|
||||||
|
|
||||||
# Note: Token truncation is now handled at embedding time with dynamic model limits
|
|
||||||
# See get_model_token_limit() and truncate_to_token_limit() in embedding_compute.py
|
|
||||||
return all_chunks
|
return all_chunks
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
import asyncio
|
||||||
import time
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
@@ -9,7 +8,6 @@ from llama_index.core.node_parser import SentenceSplitter
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from .api import LeannBuilder, LeannChat, LeannSearcher
|
from .api import LeannBuilder, LeannChat, LeannSearcher
|
||||||
from .interactive_utils import create_cli_session
|
|
||||||
from .registry import register_project_directory
|
from .registry import register_project_directory
|
||||||
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
|
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
|
||||||
|
|
||||||
@@ -107,7 +105,7 @@ Examples:
|
|||||||
help="Documents directories and/or files (default: current directory)",
|
help="Documents directories and/or files (default: current directory)",
|
||||||
)
|
)
|
||||||
build_parser.add_argument(
|
build_parser.add_argument(
|
||||||
"--backend-name",
|
"--backend",
|
||||||
type=str,
|
type=str,
|
||||||
default="hnsw",
|
default="hnsw",
|
||||||
choices=["hnsw", "diskann"],
|
choices=["hnsw", "diskann"],
|
||||||
@@ -181,25 +179,25 @@ Examples:
|
|||||||
"--doc-chunk-size",
|
"--doc-chunk-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=256,
|
default=256,
|
||||||
help="Document chunk size in TOKENS (default: 256). Final chunks may be larger due to overlap. For 512 token models: recommended 350 tokens (350 + 128 overlap = 478 max)",
|
help="Document chunk size in tokens/characters (default: 256)",
|
||||||
)
|
)
|
||||||
build_parser.add_argument(
|
build_parser.add_argument(
|
||||||
"--doc-chunk-overlap",
|
"--doc-chunk-overlap",
|
||||||
type=int,
|
type=int,
|
||||||
default=128,
|
default=128,
|
||||||
help="Document chunk overlap in TOKENS (default: 128). Added to chunk size, not included in it",
|
help="Document chunk overlap (default: 128)",
|
||||||
)
|
)
|
||||||
build_parser.add_argument(
|
build_parser.add_argument(
|
||||||
"--code-chunk-size",
|
"--code-chunk-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=512,
|
default=512,
|
||||||
help="Code chunk size in TOKENS (default: 512). Final chunks may be larger due to overlap. For 512 token models: recommended 400 tokens (400 + 50 overlap = 450 max)",
|
help="Code chunk size in tokens/lines (default: 512)",
|
||||||
)
|
)
|
||||||
build_parser.add_argument(
|
build_parser.add_argument(
|
||||||
"--code-chunk-overlap",
|
"--code-chunk-overlap",
|
||||||
type=int,
|
type=int,
|
||||||
default=50,
|
default=50,
|
||||||
help="Code chunk overlap in TOKENS (default: 50). Added to chunk size, not included in it",
|
help="Code chunk overlap (default: 50)",
|
||||||
)
|
)
|
||||||
build_parser.add_argument(
|
build_parser.add_argument(
|
||||||
"--use-ast-chunking",
|
"--use-ast-chunking",
|
||||||
@@ -209,14 +207,14 @@ Examples:
|
|||||||
build_parser.add_argument(
|
build_parser.add_argument(
|
||||||
"--ast-chunk-size",
|
"--ast-chunk-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=300,
|
default=768,
|
||||||
help="AST chunk size in CHARACTERS (non-whitespace) (default: 300). Final chunks may be larger due to overlap and expansion. For 512 token models: recommended 300 chars (300 + 64 overlap ~= 480 tokens)",
|
help="AST chunk size in characters (default: 768)",
|
||||||
)
|
)
|
||||||
build_parser.add_argument(
|
build_parser.add_argument(
|
||||||
"--ast-chunk-overlap",
|
"--ast-chunk-overlap",
|
||||||
type=int,
|
type=int,
|
||||||
default=64,
|
default=96,
|
||||||
help="AST chunk overlap in CHARACTERS (default: 64). Added to chunk size, not included in it. ~1.2 tokens per character for code",
|
help="AST chunk overlap in characters (default: 96)",
|
||||||
)
|
)
|
||||||
build_parser.add_argument(
|
build_parser.add_argument(
|
||||||
"--ast-fallback-traditional",
|
"--ast-fallback-traditional",
|
||||||
@@ -255,11 +253,6 @@ Examples:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Non-interactive mode: automatically select index without prompting",
|
help="Non-interactive mode: automatically select index without prompting",
|
||||||
)
|
)
|
||||||
search_parser.add_argument(
|
|
||||||
"--show-metadata",
|
|
||||||
action="store_true",
|
|
||||||
help="Display file paths and metadata in search results",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Ask command
|
# Ask command
|
||||||
ask_parser = subparsers.add_parser("ask", help="Ask questions")
|
ask_parser = subparsers.add_parser("ask", help="Ask questions")
|
||||||
@@ -1192,7 +1185,6 @@ Examples:
|
|||||||
for doc in other_docs:
|
for doc in other_docs:
|
||||||
file_path = doc.metadata.get("file_path", "")
|
file_path = doc.metadata.get("file_path", "")
|
||||||
if file_filter(file_path):
|
if file_filter(file_path):
|
||||||
doc.metadata["source"] = file_path
|
|
||||||
filtered_docs.append(doc)
|
filtered_docs.append(doc)
|
||||||
|
|
||||||
documents.extend(filtered_docs)
|
documents.extend(filtered_docs)
|
||||||
@@ -1268,7 +1260,7 @@ Examples:
|
|||||||
from .chunking_utils import create_text_chunks
|
from .chunking_utils import create_text_chunks
|
||||||
|
|
||||||
# Use enhanced chunking with AST support
|
# Use enhanced chunking with AST support
|
||||||
chunk_texts = create_text_chunks(
|
all_texts = create_text_chunks(
|
||||||
documents,
|
documents,
|
||||||
chunk_size=self.node_parser.chunk_size,
|
chunk_size=self.node_parser.chunk_size,
|
||||||
chunk_overlap=self.node_parser.chunk_overlap,
|
chunk_overlap=self.node_parser.chunk_overlap,
|
||||||
@@ -1279,9 +1271,6 @@ Examples:
|
|||||||
ast_fallback_traditional=getattr(args, "ast_fallback_traditional", True),
|
ast_fallback_traditional=getattr(args, "ast_fallback_traditional", True),
|
||||||
)
|
)
|
||||||
|
|
||||||
# create_text_chunks now returns list[dict] with metadata preserved
|
|
||||||
all_texts.extend(chunk_texts)
|
|
||||||
|
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
print(
|
print(
|
||||||
f"⚠️ AST chunking utilities not available in package ({e}), falling back to traditional chunking"
|
f"⚠️ AST chunking utilities not available in package ({e}), falling back to traditional chunking"
|
||||||
@@ -1293,27 +1282,14 @@ Examples:
|
|||||||
for doc in tqdm(documents, desc="Chunking documents", unit="doc"):
|
for doc in tqdm(documents, desc="Chunking documents", unit="doc"):
|
||||||
# Check if this is a code file based on source path
|
# Check if this is a code file based on source path
|
||||||
source_path = doc.metadata.get("source", "")
|
source_path = doc.metadata.get("source", "")
|
||||||
file_path = doc.metadata.get("file_path", "")
|
|
||||||
is_code_file = any(source_path.endswith(ext) for ext in code_file_exts)
|
is_code_file = any(source_path.endswith(ext) for ext in code_file_exts)
|
||||||
|
|
||||||
# Extract metadata to preserve with chunks
|
|
||||||
chunk_metadata = {
|
|
||||||
"file_path": file_path or source_path,
|
|
||||||
"file_name": doc.metadata.get("file_name", ""),
|
|
||||||
}
|
|
||||||
|
|
||||||
# Add optional metadata if available
|
|
||||||
if "creation_date" in doc.metadata:
|
|
||||||
chunk_metadata["creation_date"] = doc.metadata["creation_date"]
|
|
||||||
if "last_modified_date" in doc.metadata:
|
|
||||||
chunk_metadata["last_modified_date"] = doc.metadata["last_modified_date"]
|
|
||||||
|
|
||||||
# Use appropriate parser based on file type
|
# Use appropriate parser based on file type
|
||||||
parser = self.code_parser if is_code_file else self.node_parser
|
parser = self.code_parser if is_code_file else self.node_parser
|
||||||
nodes = parser.get_nodes_from_documents([doc])
|
nodes = parser.get_nodes_from_documents([doc])
|
||||||
|
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
all_texts.append({"text": node.get_content(), "metadata": chunk_metadata})
|
all_texts.append(node.get_content())
|
||||||
|
|
||||||
print(f"Loaded {len(documents)} documents, {len(all_texts)} chunks")
|
print(f"Loaded {len(documents)} documents, {len(all_texts)} chunks")
|
||||||
return all_texts
|
return all_texts
|
||||||
@@ -1388,7 +1364,7 @@ Examples:
|
|||||||
|
|
||||||
index_dir.mkdir(parents=True, exist_ok=True)
|
index_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
print(f"Building index '{index_name}' with {args.backend_name} backend...")
|
print(f"Building index '{index_name}' with {args.backend} backend...")
|
||||||
|
|
||||||
embedding_options: dict[str, Any] = {}
|
embedding_options: dict[str, Any] = {}
|
||||||
if args.embedding_mode == "ollama":
|
if args.embedding_mode == "ollama":
|
||||||
@@ -1400,7 +1376,7 @@ Examples:
|
|||||||
embedding_options["api_key"] = resolved_embedding_key
|
embedding_options["api_key"] = resolved_embedding_key
|
||||||
|
|
||||||
builder = LeannBuilder(
|
builder = LeannBuilder(
|
||||||
backend_name=args.backend_name,
|
backend_name=args.backend,
|
||||||
embedding_model=args.embedding_model,
|
embedding_model=args.embedding_model,
|
||||||
embedding_mode=args.embedding_mode,
|
embedding_mode=args.embedding_mode,
|
||||||
embedding_options=embedding_options or None,
|
embedding_options=embedding_options or None,
|
||||||
@@ -1411,8 +1387,8 @@ Examples:
|
|||||||
num_threads=args.num_threads,
|
num_threads=args.num_threads,
|
||||||
)
|
)
|
||||||
|
|
||||||
for chunk in all_texts:
|
for chunk_text in all_texts:
|
||||||
builder.add_text(chunk["text"], metadata=chunk["metadata"])
|
builder.add_text(chunk_text)
|
||||||
|
|
||||||
builder.build_index(index_path)
|
builder.build_index(index_path)
|
||||||
print(f"Index built at {index_path}")
|
print(f"Index built at {index_path}")
|
||||||
@@ -1533,25 +1509,7 @@ Examples:
|
|||||||
print(f"Search results for '{query}' (top {len(results)}):")
|
print(f"Search results for '{query}' (top {len(results)}):")
|
||||||
for i, result in enumerate(results, 1):
|
for i, result in enumerate(results, 1):
|
||||||
print(f"{i}. Score: {result.score:.3f}")
|
print(f"{i}. Score: {result.score:.3f}")
|
||||||
|
|
||||||
# Display metadata if flag is set
|
|
||||||
if args.show_metadata and result.metadata:
|
|
||||||
file_path = result.metadata.get("file_path", "")
|
|
||||||
if file_path:
|
|
||||||
print(f" 📄 File: {file_path}")
|
|
||||||
|
|
||||||
file_name = result.metadata.get("file_name", "")
|
|
||||||
if file_name and file_name != file_path:
|
|
||||||
print(f" 📝 Name: {file_name}")
|
|
||||||
|
|
||||||
# Show timestamps if available
|
|
||||||
if "creation_date" in result.metadata:
|
|
||||||
print(f" 🕐 Created: {result.metadata['creation_date']}")
|
|
||||||
if "last_modified_date" in result.metadata:
|
|
||||||
print(f" 🕑 Modified: {result.metadata['last_modified_date']}")
|
|
||||||
|
|
||||||
print(f" {result.text[:200]}...")
|
print(f" {result.text[:200]}...")
|
||||||
print(f" Source: {result.metadata.get('source', '')}")
|
|
||||||
print()
|
print()
|
||||||
|
|
||||||
async def ask_questions(self, args):
|
async def ask_questions(self, args):
|
||||||
@@ -1583,7 +1541,6 @@ Examples:
|
|||||||
llm_kwargs["thinking_budget"] = args.thinking_budget
|
llm_kwargs["thinking_budget"] = args.thinking_budget
|
||||||
|
|
||||||
def _ask_once(prompt: str) -> None:
|
def _ask_once(prompt: str) -> None:
|
||||||
query_start_time = time.time()
|
|
||||||
response = chat.ask(
|
response = chat.ask(
|
||||||
prompt,
|
prompt,
|
||||||
top_k=args.top_k,
|
top_k=args.top_k,
|
||||||
@@ -1594,20 +1551,27 @@ Examples:
|
|||||||
pruning_strategy=args.pruning_strategy,
|
pruning_strategy=args.pruning_strategy,
|
||||||
llm_kwargs=llm_kwargs,
|
llm_kwargs=llm_kwargs,
|
||||||
)
|
)
|
||||||
query_completion_time = time.time() - query_start_time
|
|
||||||
print(f"LEANN: {response}")
|
print(f"LEANN: {response}")
|
||||||
print(f"The query took {query_completion_time:.3f} seconds to finish")
|
|
||||||
|
|
||||||
initial_query = (args.query or "").strip()
|
initial_query = (args.query or "").strip()
|
||||||
|
|
||||||
if args.interactive:
|
if args.interactive:
|
||||||
# Create interactive session
|
|
||||||
session = create_cli_session(index_name)
|
|
||||||
|
|
||||||
if initial_query:
|
if initial_query:
|
||||||
_ask_once(initial_query)
|
_ask_once(initial_query)
|
||||||
|
|
||||||
session.run_interactive_loop(_ask_once)
|
print("LEANN Assistant ready! Type 'quit' to exit")
|
||||||
|
print("=" * 40)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
user_input = input("\nYou: ").strip()
|
||||||
|
if user_input.lower() in ["quit", "exit", "q"]:
|
||||||
|
print("Goodbye!")
|
||||||
|
break
|
||||||
|
|
||||||
|
if not user_input:
|
||||||
|
continue
|
||||||
|
|
||||||
|
_ask_once(user_input)
|
||||||
else:
|
else:
|
||||||
query = initial_query or input("Enter your question: ").strip()
|
query = initial_query or input("Enter your question: ").strip()
|
||||||
if not query:
|
if not query:
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import time
|
|||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tiktoken
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
|
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
|
||||||
@@ -21,170 +20,6 @@ LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
|||||||
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
|
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
|
||||||
logger.setLevel(log_level)
|
logger.setLevel(log_level)
|
||||||
|
|
||||||
# Token limit registry for embedding models
|
|
||||||
# Used as fallback when dynamic discovery fails (e.g., LM Studio, OpenAI)
|
|
||||||
# Ollama models use dynamic discovery via /api/show
|
|
||||||
EMBEDDING_MODEL_LIMITS = {
|
|
||||||
# Nomic models (common across servers)
|
|
||||||
"nomic-embed-text": 2048, # Corrected from 512 - verified via /api/show
|
|
||||||
"nomic-embed-text-v1.5": 2048,
|
|
||||||
"nomic-embed-text-v2": 512,
|
|
||||||
# Other embedding models
|
|
||||||
"mxbai-embed-large": 512,
|
|
||||||
"all-minilm": 512,
|
|
||||||
"bge-m3": 8192,
|
|
||||||
"snowflake-arctic-embed": 512,
|
|
||||||
# OpenAI models
|
|
||||||
"text-embedding-3-small": 8192,
|
|
||||||
"text-embedding-3-large": 8192,
|
|
||||||
"text-embedding-ada-002": 8192,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def get_model_token_limit(
|
|
||||||
model_name: str,
|
|
||||||
base_url: Optional[str] = None,
|
|
||||||
default: int = 2048,
|
|
||||||
) -> int:
|
|
||||||
"""
|
|
||||||
Get token limit for a given embedding model.
|
|
||||||
Uses hybrid approach: dynamic discovery for Ollama, registry fallback for others.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_name: Name of the embedding model
|
|
||||||
base_url: Base URL of the embedding server (for dynamic discovery)
|
|
||||||
default: Default token limit if model not found
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Token limit for the model in tokens
|
|
||||||
"""
|
|
||||||
# Try Ollama dynamic discovery if base_url provided
|
|
||||||
if base_url:
|
|
||||||
# Detect Ollama servers by port or "ollama" in URL
|
|
||||||
if "11434" in base_url or "ollama" in base_url.lower():
|
|
||||||
limit = _query_ollama_context_limit(model_name, base_url)
|
|
||||||
if limit:
|
|
||||||
return limit
|
|
||||||
|
|
||||||
# Fallback to known model registry with version handling (from PR #154)
|
|
||||||
# Handle versioned model names (e.g., "nomic-embed-text:latest" -> "nomic-embed-text")
|
|
||||||
base_model_name = model_name.split(":")[0]
|
|
||||||
|
|
||||||
# Check exact match first
|
|
||||||
if model_name in EMBEDDING_MODEL_LIMITS:
|
|
||||||
return EMBEDDING_MODEL_LIMITS[model_name]
|
|
||||||
|
|
||||||
# Check base name match
|
|
||||||
if base_model_name in EMBEDDING_MODEL_LIMITS:
|
|
||||||
return EMBEDDING_MODEL_LIMITS[base_model_name]
|
|
||||||
|
|
||||||
# Check partial matches for common patterns
|
|
||||||
for known_model, limit in EMBEDDING_MODEL_LIMITS.items():
|
|
||||||
if known_model in base_model_name or base_model_name in known_model:
|
|
||||||
return limit
|
|
||||||
|
|
||||||
# Default fallback
|
|
||||||
logger.warning(f"Unknown model '{model_name}', using default {default} token limit")
|
|
||||||
return default
|
|
||||||
|
|
||||||
|
|
||||||
def truncate_to_token_limit(texts: list[str], token_limit: int) -> list[str]:
|
|
||||||
"""
|
|
||||||
Truncate texts to fit within token limit using tiktoken.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
texts: List of text strings to truncate
|
|
||||||
token_limit: Maximum number of tokens allowed
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of truncated texts (same length as input)
|
|
||||||
"""
|
|
||||||
if not texts:
|
|
||||||
return []
|
|
||||||
|
|
||||||
# Use tiktoken with cl100k_base encoding
|
|
||||||
enc = tiktoken.get_encoding("cl100k_base")
|
|
||||||
|
|
||||||
truncated_texts = []
|
|
||||||
truncation_count = 0
|
|
||||||
total_tokens_removed = 0
|
|
||||||
max_original_length = 0
|
|
||||||
|
|
||||||
for i, text in enumerate(texts):
|
|
||||||
tokens = enc.encode(text)
|
|
||||||
original_length = len(tokens)
|
|
||||||
|
|
||||||
if original_length <= token_limit:
|
|
||||||
# Text is within limit, keep as is
|
|
||||||
truncated_texts.append(text)
|
|
||||||
else:
|
|
||||||
# Truncate to token_limit
|
|
||||||
truncated_tokens = tokens[:token_limit]
|
|
||||||
truncated_text = enc.decode(truncated_tokens)
|
|
||||||
truncated_texts.append(truncated_text)
|
|
||||||
|
|
||||||
# Track truncation statistics
|
|
||||||
truncation_count += 1
|
|
||||||
tokens_removed = original_length - token_limit
|
|
||||||
total_tokens_removed += tokens_removed
|
|
||||||
max_original_length = max(max_original_length, original_length)
|
|
||||||
|
|
||||||
# Log individual truncation at WARNING level (first few only)
|
|
||||||
if truncation_count <= 3:
|
|
||||||
logger.warning(
|
|
||||||
f"Text {i + 1} truncated: {original_length} → {token_limit} tokens "
|
|
||||||
f"({tokens_removed} tokens removed)"
|
|
||||||
)
|
|
||||||
elif truncation_count == 4:
|
|
||||||
logger.warning("Further truncation warnings suppressed...")
|
|
||||||
|
|
||||||
# Log summary at INFO level
|
|
||||||
if truncation_count > 0:
|
|
||||||
logger.warning(
|
|
||||||
f"Truncation summary: {truncation_count}/{len(texts)} texts truncated "
|
|
||||||
f"(removed {total_tokens_removed} tokens total, longest was {max_original_length} tokens)"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.debug(
|
|
||||||
f"No truncation needed - all {len(texts)} texts within {token_limit} token limit"
|
|
||||||
)
|
|
||||||
|
|
||||||
return truncated_texts
|
|
||||||
|
|
||||||
|
|
||||||
def _query_ollama_context_limit(model_name: str, base_url: str) -> Optional[int]:
|
|
||||||
"""
|
|
||||||
Query Ollama /api/show for model context limit.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_name: Name of the Ollama model
|
|
||||||
base_url: Base URL of the Ollama server
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Context limit in tokens if found, None otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
import requests
|
|
||||||
|
|
||||||
response = requests.post(
|
|
||||||
f"{base_url}/api/show",
|
|
||||||
json={"name": model_name},
|
|
||||||
timeout=5,
|
|
||||||
)
|
|
||||||
if response.status_code == 200:
|
|
||||||
data = response.json()
|
|
||||||
if "model_info" in data:
|
|
||||||
# Look for *.context_length in model_info
|
|
||||||
for key, value in data["model_info"].items():
|
|
||||||
if "context_length" in key and isinstance(value, int):
|
|
||||||
logger.info(f"Detected {model_name} context limit: {value} tokens")
|
|
||||||
return value
|
|
||||||
except Exception as e:
|
|
||||||
logger.debug(f"Failed to query Ollama context limit: {e}")
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
# Global model cache to avoid repeated loading
|
# Global model cache to avoid repeated loading
|
||||||
_model_cache: dict[str, Any] = {}
|
_model_cache: dict[str, Any] = {}
|
||||||
|
|
||||||
@@ -215,14 +50,9 @@ def compute_embeddings(
|
|||||||
Normalized embeddings array, shape: (len(texts), embedding_dim)
|
Normalized embeddings array, shape: (len(texts), embedding_dim)
|
||||||
"""
|
"""
|
||||||
provider_options = provider_options or {}
|
provider_options = provider_options or {}
|
||||||
wrapper_start_time = time.time()
|
|
||||||
logger.debug(
|
|
||||||
f"[compute_embeddings] entry: mode={mode}, model='{model_name}', text_count={len(texts)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if mode == "sentence-transformers":
|
if mode == "sentence-transformers":
|
||||||
inner_start_time = time.time()
|
return compute_embeddings_sentence_transformers(
|
||||||
result = compute_embeddings_sentence_transformers(
|
|
||||||
texts,
|
texts,
|
||||||
model_name,
|
model_name,
|
||||||
is_build=is_build,
|
is_build=is_build,
|
||||||
@@ -231,14 +61,6 @@ def compute_embeddings(
|
|||||||
manual_tokenize=manual_tokenize,
|
manual_tokenize=manual_tokenize,
|
||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
)
|
)
|
||||||
inner_end_time = time.time()
|
|
||||||
wrapper_end_time = time.time()
|
|
||||||
logger.debug(
|
|
||||||
"[compute_embeddings] sentence-transformers timings: "
|
|
||||||
f"inner={inner_end_time - inner_start_time:.6f}s, "
|
|
||||||
f"wrapper_total={wrapper_end_time - wrapper_start_time:.6f}s"
|
|
||||||
)
|
|
||||||
return result
|
|
||||||
elif mode == "openai":
|
elif mode == "openai":
|
||||||
return compute_embeddings_openai(
|
return compute_embeddings_openai(
|
||||||
texts,
|
texts,
|
||||||
@@ -284,7 +106,6 @@ def compute_embeddings_sentence_transformers(
|
|||||||
is_build: Whether this is a build operation (shows progress bar)
|
is_build: Whether this is a build operation (shows progress bar)
|
||||||
adaptive_optimization: Whether to use adaptive optimization based on batch size
|
adaptive_optimization: Whether to use adaptive optimization based on batch size
|
||||||
"""
|
"""
|
||||||
outer_start_time = time.time()
|
|
||||||
# Handle empty input
|
# Handle empty input
|
||||||
if not texts:
|
if not texts:
|
||||||
raise ValueError("Cannot compute embeddings for empty text list")
|
raise ValueError("Cannot compute embeddings for empty text list")
|
||||||
@@ -315,14 +136,7 @@ def compute_embeddings_sentence_transformers(
|
|||||||
# Create cache key
|
# Create cache key
|
||||||
cache_key = f"sentence_transformers_{model_name}_{device}_{use_fp16}_optimized"
|
cache_key = f"sentence_transformers_{model_name}_{device}_{use_fp16}_optimized"
|
||||||
|
|
||||||
pre_model_init_end_time = time.time()
|
|
||||||
logger.debug(
|
|
||||||
"compute_embeddings_sentence_transformers pre-model-init time "
|
|
||||||
f"(device/batch selection etc.): {pre_model_init_end_time - outer_start_time:.6f}s"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check if model is already cached
|
# Check if model is already cached
|
||||||
start_time = time.time()
|
|
||||||
if cache_key in _model_cache:
|
if cache_key in _model_cache:
|
||||||
logger.info(f"Using cached optimized model: {model_name}")
|
logger.info(f"Using cached optimized model: {model_name}")
|
||||||
model = _model_cache[cache_key]
|
model = _model_cache[cache_key]
|
||||||
@@ -369,73 +183,32 @@ def compute_embeddings_sentence_transformers(
|
|||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Try loading with advanced parameters first (newer versions)
|
# Try local loading first
|
||||||
local_model_kwargs = model_kwargs.copy()
|
model_kwargs["local_files_only"] = True
|
||||||
local_tokenizer_kwargs = tokenizer_kwargs.copy()
|
tokenizer_kwargs["local_files_only"] = True
|
||||||
local_model_kwargs["local_files_only"] = True
|
|
||||||
local_tokenizer_kwargs["local_files_only"] = True
|
|
||||||
|
|
||||||
model = SentenceTransformer(
|
model = SentenceTransformer(
|
||||||
model_name,
|
model_name,
|
||||||
device=device,
|
device=device,
|
||||||
model_kwargs=local_model_kwargs,
|
model_kwargs=model_kwargs,
|
||||||
tokenizer_kwargs=local_tokenizer_kwargs,
|
tokenizer_kwargs=tokenizer_kwargs,
|
||||||
local_files_only=True,
|
local_files_only=True,
|
||||||
)
|
)
|
||||||
logger.info("Model loaded successfully! (local + optimized)")
|
logger.info("Model loaded successfully! (local + optimized)")
|
||||||
except TypeError as e:
|
|
||||||
if "model_kwargs" in str(e) or "tokenizer_kwargs" in str(e):
|
|
||||||
logger.warning(
|
|
||||||
f"Advanced parameters not supported ({e}), using basic initialization..."
|
|
||||||
)
|
|
||||||
# Fallback to basic initialization for older versions
|
|
||||||
try:
|
|
||||||
model = SentenceTransformer(
|
|
||||||
model_name,
|
|
||||||
device=device,
|
|
||||||
local_files_only=True,
|
|
||||||
)
|
|
||||||
logger.info("Model loaded successfully! (local + basic)")
|
|
||||||
except Exception as e2:
|
|
||||||
logger.warning(f"Local loading failed ({e2}), trying network download...")
|
|
||||||
model = SentenceTransformer(
|
|
||||||
model_name,
|
|
||||||
device=device,
|
|
||||||
local_files_only=False,
|
|
||||||
)
|
|
||||||
logger.info("Model loaded successfully! (network + basic)")
|
|
||||||
else:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Local loading failed ({e}), trying network download...")
|
logger.warning(f"Local loading failed ({e}), trying network download...")
|
||||||
# Fallback to network loading with advanced parameters
|
# Fallback to network loading
|
||||||
try:
|
model_kwargs["local_files_only"] = False
|
||||||
network_model_kwargs = model_kwargs.copy()
|
tokenizer_kwargs["local_files_only"] = False
|
||||||
network_tokenizer_kwargs = tokenizer_kwargs.copy()
|
|
||||||
network_model_kwargs["local_files_only"] = False
|
|
||||||
network_tokenizer_kwargs["local_files_only"] = False
|
|
||||||
|
|
||||||
model = SentenceTransformer(
|
model = SentenceTransformer(
|
||||||
model_name,
|
model_name,
|
||||||
device=device,
|
device=device,
|
||||||
model_kwargs=network_model_kwargs,
|
model_kwargs=model_kwargs,
|
||||||
tokenizer_kwargs=network_tokenizer_kwargs,
|
tokenizer_kwargs=tokenizer_kwargs,
|
||||||
local_files_only=False,
|
local_files_only=False,
|
||||||
)
|
)
|
||||||
logger.info("Model loaded successfully! (network + optimized)")
|
logger.info("Model loaded successfully! (network + optimized)")
|
||||||
except TypeError as e2:
|
|
||||||
if "model_kwargs" in str(e2) or "tokenizer_kwargs" in str(e2):
|
|
||||||
logger.warning(
|
|
||||||
f"Advanced parameters not supported ({e2}), using basic network loading..."
|
|
||||||
)
|
|
||||||
model = SentenceTransformer(
|
|
||||||
model_name,
|
|
||||||
device=device,
|
|
||||||
local_files_only=False,
|
|
||||||
)
|
|
||||||
logger.info("Model loaded successfully! (network + basic)")
|
|
||||||
else:
|
|
||||||
raise
|
|
||||||
|
|
||||||
# Apply additional optimizations based on mode
|
# Apply additional optimizations based on mode
|
||||||
if use_fp16 and device in ["cuda", "mps"]:
|
if use_fp16 and device in ["cuda", "mps"]:
|
||||||
@@ -462,13 +235,10 @@ def compute_embeddings_sentence_transformers(
|
|||||||
_model_cache[cache_key] = model
|
_model_cache[cache_key] = model
|
||||||
logger.info(f"Model cached: {cache_key}")
|
logger.info(f"Model cached: {cache_key}")
|
||||||
|
|
||||||
end_time = time.time()
|
|
||||||
|
|
||||||
# Compute embeddings with optimized inference mode
|
# Compute embeddings with optimized inference mode
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Starting embedding computation... (batch_size: {batch_size}, manual_tokenize={manual_tokenize})"
|
f"Starting embedding computation... (batch_size: {batch_size}, manual_tokenize={manual_tokenize})"
|
||||||
)
|
)
|
||||||
logger.info(f"start sentence transformers {model} takes {end_time - start_time}")
|
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
if not manual_tokenize:
|
if not manual_tokenize:
|
||||||
@@ -489,46 +259,32 @@ def compute_embeddings_sentence_transformers(
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
# Manual tokenization + forward pass using HF AutoTokenizer/AutoModel.
|
# Manual tokenization + forward pass using HF AutoTokenizer/AutoModel
|
||||||
# This path is reserved for an aggressively optimized FP pipeline
|
|
||||||
# (no quantization), mainly for experimentation.
|
|
||||||
try:
|
try:
|
||||||
from transformers import AutoModel, AutoTokenizer # type: ignore
|
from transformers import AutoModel, AutoTokenizer # type: ignore
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ImportError(f"transformers is required for manual_tokenize=True: {e}")
|
raise ImportError(f"transformers is required for manual_tokenize=True: {e}")
|
||||||
|
|
||||||
|
# Cache tokenizer and model
|
||||||
tok_cache_key = f"hf_tokenizer_{model_name}"
|
tok_cache_key = f"hf_tokenizer_{model_name}"
|
||||||
mdl_cache_key = f"hf_model_{model_name}_{device}_{use_fp16}_fp"
|
mdl_cache_key = f"hf_model_{model_name}_{device}_{use_fp16}"
|
||||||
|
|
||||||
if tok_cache_key in _model_cache and mdl_cache_key in _model_cache:
|
if tok_cache_key in _model_cache and mdl_cache_key in _model_cache:
|
||||||
hf_tokenizer = _model_cache[tok_cache_key]
|
hf_tokenizer = _model_cache[tok_cache_key]
|
||||||
hf_model = _model_cache[mdl_cache_key]
|
hf_model = _model_cache[mdl_cache_key]
|
||||||
logger.info("Using cached HF tokenizer/model for manual FP path")
|
logger.info("Using cached HF tokenizer/model for manual path")
|
||||||
else:
|
else:
|
||||||
logger.info("Loading HF tokenizer/model for manual FP path")
|
logger.info("Loading HF tokenizer/model for manual tokenization path")
|
||||||
hf_tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
|
hf_tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
|
||||||
|
|
||||||
torch_dtype = torch.float16 if (use_fp16 and device == "cuda") else torch.float32
|
torch_dtype = torch.float16 if (use_fp16 and device == "cuda") else torch.float32
|
||||||
hf_model = AutoModel.from_pretrained(
|
hf_model = AutoModel.from_pretrained(model_name, torch_dtype=torch_dtype)
|
||||||
model_name,
|
|
||||||
torch_dtype=torch_dtype,
|
|
||||||
)
|
|
||||||
hf_model.to(device)
|
hf_model.to(device)
|
||||||
|
|
||||||
hf_model.eval()
|
hf_model.eval()
|
||||||
# Optional compile on supported devices
|
# Optional compile on supported devices
|
||||||
if device in ["cuda", "mps"]:
|
if device in ["cuda", "mps"]:
|
||||||
try:
|
try:
|
||||||
hf_model = torch.compile( # type: ignore
|
hf_model = torch.compile(hf_model, mode="reduce-overhead", dynamic=True) # type: ignore
|
||||||
hf_model, mode="reduce-overhead", dynamic=True
|
except Exception:
|
||||||
)
|
pass
|
||||||
logger.info(
|
|
||||||
f"Applied torch.compile to HF model for {model_name} "
|
|
||||||
f"(device={device}, dtype={torch_dtype})"
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning(f"torch.compile optimization failed: {exc}")
|
|
||||||
|
|
||||||
_model_cache[tok_cache_key] = hf_tokenizer
|
_model_cache[tok_cache_key] = hf_tokenizer
|
||||||
_model_cache[mdl_cache_key] = hf_model
|
_model_cache[mdl_cache_key] = hf_model
|
||||||
|
|
||||||
@@ -554,6 +310,7 @@ def compute_embeddings_sentence_transformers(
|
|||||||
for start_index in batch_iter:
|
for start_index in batch_iter:
|
||||||
end_index = min(start_index + batch_size, len(texts))
|
end_index = min(start_index + batch_size, len(texts))
|
||||||
batch_texts = texts[start_index:end_index]
|
batch_texts = texts[start_index:end_index]
|
||||||
|
tokenize_start_time = time.time()
|
||||||
inputs = hf_tokenizer(
|
inputs = hf_tokenizer(
|
||||||
batch_texts,
|
batch_texts,
|
||||||
padding=True,
|
padding=True,
|
||||||
@@ -561,17 +318,34 @@ def compute_embeddings_sentence_transformers(
|
|||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
)
|
)
|
||||||
|
tokenize_end_time = time.time()
|
||||||
|
logger.info(
|
||||||
|
f"Tokenize time taken: {tokenize_end_time - tokenize_start_time} seconds"
|
||||||
|
)
|
||||||
|
# Print shapes of all input tensors for debugging
|
||||||
|
for k, v in inputs.items():
|
||||||
|
print(f"inputs[{k!r}] shape: {getattr(v, 'shape', type(v))}")
|
||||||
|
to_device_start_time = time.time()
|
||||||
inputs = {k: v.to(device) for k, v in inputs.items()}
|
inputs = {k: v.to(device) for k, v in inputs.items()}
|
||||||
|
to_device_end_time = time.time()
|
||||||
|
logger.info(
|
||||||
|
f"To device time taken: {to_device_end_time - to_device_start_time} seconds"
|
||||||
|
)
|
||||||
|
forward_start_time = time.time()
|
||||||
outputs = hf_model(**inputs)
|
outputs = hf_model(**inputs)
|
||||||
|
forward_end_time = time.time()
|
||||||
|
logger.info(f"Forward time taken: {forward_end_time - forward_start_time} seconds")
|
||||||
last_hidden_state = outputs.last_hidden_state # (B, L, H)
|
last_hidden_state = outputs.last_hidden_state # (B, L, H)
|
||||||
attention_mask = inputs.get("attention_mask")
|
attention_mask = inputs.get("attention_mask")
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
|
# Fallback: assume all tokens are valid
|
||||||
pooled = last_hidden_state.mean(dim=1)
|
pooled = last_hidden_state.mean(dim=1)
|
||||||
else:
|
else:
|
||||||
mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
|
mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
|
||||||
masked = last_hidden_state * mask
|
masked = last_hidden_state * mask
|
||||||
lengths = mask.sum(dim=1).clamp(min=1)
|
lengths = mask.sum(dim=1).clamp(min=1)
|
||||||
pooled = masked.sum(dim=1) / lengths
|
pooled = masked.sum(dim=1) / lengths
|
||||||
|
# Move to CPU float32
|
||||||
batch_embeddings = pooled.detach().to("cpu").float().numpy()
|
batch_embeddings = pooled.detach().to("cpu").float().numpy()
|
||||||
all_embeddings.append(batch_embeddings)
|
all_embeddings.append(batch_embeddings)
|
||||||
|
|
||||||
@@ -591,12 +365,6 @@ def compute_embeddings_sentence_transformers(
|
|||||||
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
||||||
raise RuntimeError(f"Detected NaN or Inf values in embeddings, model: {model_name}")
|
raise RuntimeError(f"Detected NaN or Inf values in embeddings, model: {model_name}")
|
||||||
|
|
||||||
outer_end_time = time.time()
|
|
||||||
logger.debug(
|
|
||||||
"compute_embeddings_sentence_transformers total time "
|
|
||||||
f"(function entry -> return): {outer_end_time - outer_start_time:.6f}s"
|
|
||||||
)
|
|
||||||
|
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
@@ -765,10 +533,9 @@ def compute_embeddings_ollama(
|
|||||||
host: Optional[str] = None,
|
host: Optional[str] = None,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Compute embeddings using Ollama API with true batch processing.
|
Compute embeddings using Ollama API with simplified batch processing.
|
||||||
|
|
||||||
Uses the /api/embed endpoint which supports batch inputs.
|
Uses batch size of 32 for MPS/CPU and 128 for CUDA to optimize performance.
|
||||||
Batch size: 32 for MPS/CPU, 128 for CUDA to optimize performance.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
texts: List of texts to compute embeddings for
|
texts: List of texts to compute embeddings for
|
||||||
@@ -873,11 +640,11 @@ def compute_embeddings_ollama(
|
|||||||
logger.info(f"Resolved model name '{model_name}' to '{resolved_model_name}'")
|
logger.info(f"Resolved model name '{model_name}' to '{resolved_model_name}'")
|
||||||
model_name = resolved_model_name
|
model_name = resolved_model_name
|
||||||
|
|
||||||
# Verify the model supports embeddings by testing it with /api/embed
|
# Verify the model supports embeddings by testing it
|
||||||
try:
|
try:
|
||||||
test_response = requests.post(
|
test_response = requests.post(
|
||||||
f"{resolved_host}/api/embed",
|
f"{resolved_host}/api/embeddings",
|
||||||
json={"model": model_name, "input": "test"},
|
json={"model": model_name, "prompt": "test"},
|
||||||
timeout=10,
|
timeout=10,
|
||||||
)
|
)
|
||||||
if test_response.status_code != 200:
|
if test_response.status_code != 200:
|
||||||
@@ -909,71 +676,56 @@ def compute_embeddings_ollama(
|
|||||||
# If torch is not available, use conservative batch size
|
# If torch is not available, use conservative batch size
|
||||||
batch_size = 32
|
batch_size = 32
|
||||||
|
|
||||||
logger.info(f"Using batch size: {batch_size} for true batch processing")
|
logger.info(f"Using batch size: {batch_size}")
|
||||||
|
|
||||||
# Get model token limit and apply truncation before batching
|
|
||||||
token_limit = get_model_token_limit(model_name, base_url=resolved_host)
|
|
||||||
logger.info(f"Model '{model_name}' token limit: {token_limit}")
|
|
||||||
|
|
||||||
# Apply truncation to all texts before batch processing
|
|
||||||
# Function logs truncation details internally
|
|
||||||
texts = truncate_to_token_limit(texts, token_limit)
|
|
||||||
|
|
||||||
def get_batch_embeddings(batch_texts):
|
def get_batch_embeddings(batch_texts):
|
||||||
"""Get embeddings for a batch of texts using /api/embed endpoint."""
|
"""Get embeddings for a batch of texts."""
|
||||||
|
all_embeddings = []
|
||||||
|
failed_indices = []
|
||||||
|
|
||||||
|
for i, text in enumerate(batch_texts):
|
||||||
max_retries = 3
|
max_retries = 3
|
||||||
retry_count = 0
|
retry_count = 0
|
||||||
|
|
||||||
# Texts are already truncated to token limit by the outer function
|
# Truncate very long texts to avoid API issues
|
||||||
|
truncated_text = text[:8000] if len(text) > 8000 else text
|
||||||
while retry_count < max_retries:
|
while retry_count < max_retries:
|
||||||
try:
|
try:
|
||||||
# Use /api/embed endpoint with "input" parameter for batch processing
|
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
f"{resolved_host}/api/embed",
|
f"{resolved_host}/api/embeddings",
|
||||||
json={"model": model_name, "input": batch_texts},
|
json={"model": model_name, "prompt": truncated_text},
|
||||||
timeout=60, # Increased timeout for batch processing
|
timeout=30,
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
result = response.json()
|
result = response.json()
|
||||||
batch_embeddings = result.get("embeddings")
|
embedding = result.get("embedding")
|
||||||
|
|
||||||
if batch_embeddings is None:
|
if embedding is None:
|
||||||
raise ValueError("No embeddings returned from API")
|
raise ValueError(f"No embedding returned for text {i}")
|
||||||
|
|
||||||
if not isinstance(batch_embeddings, list):
|
if not isinstance(embedding, list) or len(embedding) == 0:
|
||||||
raise ValueError(f"Invalid embeddings format: {type(batch_embeddings)}")
|
raise ValueError(f"Invalid embedding format for text {i}")
|
||||||
|
|
||||||
if len(batch_embeddings) != len(batch_texts):
|
all_embeddings.append(embedding)
|
||||||
raise ValueError(
|
break
|
||||||
f"Mismatch: requested {len(batch_texts)} embeddings, got {len(batch_embeddings)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return batch_embeddings, []
|
|
||||||
|
|
||||||
except requests.exceptions.Timeout:
|
except requests.exceptions.Timeout:
|
||||||
retry_count += 1
|
retry_count += 1
|
||||||
if retry_count >= max_retries:
|
if retry_count >= max_retries:
|
||||||
logger.warning(f"Timeout for batch after {max_retries} retries")
|
logger.warning(f"Timeout for text {i} after {max_retries} retries")
|
||||||
return None, list(range(len(batch_texts)))
|
failed_indices.append(i)
|
||||||
|
all_embeddings.append(None)
|
||||||
|
break
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
retry_count += 1
|
retry_count += 1
|
||||||
if retry_count >= max_retries:
|
if retry_count >= max_retries:
|
||||||
# Enhanced error detection for token limit violations
|
logger.error(f"Failed to get embedding for text {i}: {e}")
|
||||||
error_msg = str(e).lower()
|
failed_indices.append(i)
|
||||||
if "token" in error_msg and (
|
all_embeddings.append(None)
|
||||||
"limit" in error_msg or "exceed" in error_msg or "length" in error_msg
|
break
|
||||||
):
|
return all_embeddings, failed_indices
|
||||||
logger.error(
|
|
||||||
f"Token limit exceeded for batch. Error: {e}. "
|
|
||||||
f"Consider reducing chunk sizes or check token truncation."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.error(f"Failed to get embeddings for batch: {e}")
|
|
||||||
return None, list(range(len(batch_texts)))
|
|
||||||
|
|
||||||
return None, list(range(len(batch_texts)))
|
|
||||||
|
|
||||||
# Process texts in batches
|
# Process texts in batches
|
||||||
all_embeddings = []
|
all_embeddings = []
|
||||||
@@ -991,7 +743,7 @@ def compute_embeddings_ollama(
|
|||||||
num_batches = (len(texts) + batch_size - 1) // batch_size
|
num_batches = (len(texts) + batch_size - 1) // batch_size
|
||||||
|
|
||||||
if show_progress:
|
if show_progress:
|
||||||
batch_iterator = tqdm(range(num_batches), desc="Computing Ollama embeddings (batched)")
|
batch_iterator = tqdm(range(num_batches), desc="Computing Ollama embeddings")
|
||||||
else:
|
else:
|
||||||
batch_iterator = range(num_batches)
|
batch_iterator = range(num_batches)
|
||||||
|
|
||||||
@@ -1002,14 +754,10 @@ def compute_embeddings_ollama(
|
|||||||
|
|
||||||
batch_embeddings, batch_failed = get_batch_embeddings(batch_texts)
|
batch_embeddings, batch_failed = get_batch_embeddings(batch_texts)
|
||||||
|
|
||||||
if batch_embeddings is not None:
|
|
||||||
all_embeddings.extend(batch_embeddings)
|
|
||||||
else:
|
|
||||||
# Entire batch failed, add None placeholders
|
|
||||||
all_embeddings.extend([None] * len(batch_texts))
|
|
||||||
# Adjust failed indices to global indices
|
# Adjust failed indices to global indices
|
||||||
global_failed = [start_idx + idx for idx in batch_failed]
|
global_failed = [start_idx + idx for idx in batch_failed]
|
||||||
all_failed_indices.extend(global_failed)
|
all_failed_indices.extend(global_failed)
|
||||||
|
all_embeddings.extend(batch_embeddings)
|
||||||
|
|
||||||
# Handle failed embeddings
|
# Handle failed embeddings
|
||||||
if all_failed_indices:
|
if all_failed_indices:
|
||||||
|
|||||||
@@ -1,189 +0,0 @@
|
|||||||
"""
|
|
||||||
Interactive session utilities for LEANN applications.
|
|
||||||
|
|
||||||
Provides shared readline functionality and command handling across
|
|
||||||
CLI, API, and RAG example interactive modes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import atexit
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Callable, Optional
|
|
||||||
|
|
||||||
# Try to import readline with fallback for Windows
|
|
||||||
try:
|
|
||||||
import readline
|
|
||||||
|
|
||||||
HAS_READLINE = True
|
|
||||||
except ImportError:
|
|
||||||
# Windows doesn't have readline by default
|
|
||||||
HAS_READLINE = False
|
|
||||||
readline = None
|
|
||||||
|
|
||||||
|
|
||||||
class InteractiveSession:
|
|
||||||
"""Manages interactive session with optional readline support and common commands."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
history_name: str,
|
|
||||||
prompt: str = "You: ",
|
|
||||||
welcome_message: str = "",
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Initialize interactive session with optional readline support.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
history_name: Name for history file (e.g., "cli", "api_chat")
|
|
||||||
(ignored if readline not available)
|
|
||||||
prompt: Input prompt to display
|
|
||||||
welcome_message: Message to show when starting session
|
|
||||||
|
|
||||||
Note:
|
|
||||||
On systems without readline (e.g., Windows), falls back to basic input()
|
|
||||||
with limited functionality (no history, no line editing).
|
|
||||||
"""
|
|
||||||
self.history_name = history_name
|
|
||||||
self.prompt = prompt
|
|
||||||
self.welcome_message = welcome_message
|
|
||||||
self._setup_complete = False
|
|
||||||
|
|
||||||
def setup_readline(self):
|
|
||||||
"""Setup readline with history support (if available)."""
|
|
||||||
if self._setup_complete:
|
|
||||||
return
|
|
||||||
|
|
||||||
if not HAS_READLINE:
|
|
||||||
# Readline not available (likely Windows), skip setup
|
|
||||||
self._setup_complete = True
|
|
||||||
return
|
|
||||||
|
|
||||||
# History file setup
|
|
||||||
history_dir = Path.home() / ".leann" / "history"
|
|
||||||
history_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
history_file = history_dir / f"{self.history_name}.history"
|
|
||||||
|
|
||||||
# Load history if exists
|
|
||||||
try:
|
|
||||||
readline.read_history_file(str(history_file))
|
|
||||||
readline.set_history_length(1000)
|
|
||||||
except FileNotFoundError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Save history on exit
|
|
||||||
atexit.register(readline.write_history_file, str(history_file))
|
|
||||||
|
|
||||||
# Optional: Enable vi editing mode (commented out by default)
|
|
||||||
# readline.parse_and_bind("set editing-mode vi")
|
|
||||||
|
|
||||||
self._setup_complete = True
|
|
||||||
|
|
||||||
def _show_help(self):
|
|
||||||
"""Show available commands."""
|
|
||||||
print("Commands:")
|
|
||||||
print(" quit/exit/q - Exit the chat")
|
|
||||||
print(" help - Show this help message")
|
|
||||||
print(" clear - Clear screen")
|
|
||||||
print(" history - Show command history")
|
|
||||||
|
|
||||||
def _show_history(self):
|
|
||||||
"""Show command history."""
|
|
||||||
if not HAS_READLINE:
|
|
||||||
print(" History not available (readline not supported on this system)")
|
|
||||||
return
|
|
||||||
|
|
||||||
history_length = readline.get_current_history_length()
|
|
||||||
if history_length == 0:
|
|
||||||
print(" No history available")
|
|
||||||
return
|
|
||||||
|
|
||||||
for i in range(history_length):
|
|
||||||
item = readline.get_history_item(i + 1)
|
|
||||||
if item:
|
|
||||||
print(f" {i + 1}: {item}")
|
|
||||||
|
|
||||||
def get_user_input(self) -> Optional[str]:
|
|
||||||
"""
|
|
||||||
Get user input with readline support.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
User input string, or None if EOF (Ctrl+D)
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
return input(self.prompt).strip()
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
print("\n(Use 'quit' to exit)")
|
|
||||||
return "" # Return empty string to continue
|
|
||||||
except EOFError:
|
|
||||||
print("\nGoodbye!")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def run_interactive_loop(self, handler_func: Callable[[str], None]):
|
|
||||||
"""
|
|
||||||
Run the interactive loop with a custom handler function.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
handler_func: Function to handle user input that's not a built-in command
|
|
||||||
Should accept a string and handle the user's query
|
|
||||||
"""
|
|
||||||
self.setup_readline()
|
|
||||||
|
|
||||||
if self.welcome_message:
|
|
||||||
print(self.welcome_message)
|
|
||||||
|
|
||||||
while True:
|
|
||||||
user_input = self.get_user_input()
|
|
||||||
|
|
||||||
if user_input is None: # EOF (Ctrl+D)
|
|
||||||
break
|
|
||||||
|
|
||||||
if not user_input: # Empty input or KeyboardInterrupt
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Handle built-in commands
|
|
||||||
command = user_input.lower()
|
|
||||||
if command in ["quit", "exit", "q"]:
|
|
||||||
print("Goodbye!")
|
|
||||||
break
|
|
||||||
elif command == "help":
|
|
||||||
self._show_help()
|
|
||||||
elif command == "clear":
|
|
||||||
os.system("clear" if os.name != "nt" else "cls")
|
|
||||||
elif command == "history":
|
|
||||||
self._show_history()
|
|
||||||
else:
|
|
||||||
# Regular user input - pass to handler
|
|
||||||
try:
|
|
||||||
handler_func(user_input)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
def create_cli_session(index_name: str) -> InteractiveSession:
|
|
||||||
"""Create an interactive session for CLI usage."""
|
|
||||||
return InteractiveSession(
|
|
||||||
history_name=index_name,
|
|
||||||
prompt="\nYou: ",
|
|
||||||
welcome_message="LEANN Assistant ready! Type 'quit' to exit, 'help' for commands\n"
|
|
||||||
+ "=" * 40,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def create_api_session() -> InteractiveSession:
|
|
||||||
"""Create an interactive session for API chat."""
|
|
||||||
return InteractiveSession(
|
|
||||||
history_name="api_chat",
|
|
||||||
prompt="You: ",
|
|
||||||
welcome_message="Leann Chat started (type 'quit' to exit, 'help' for commands)\n"
|
|
||||||
+ "=" * 40,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def create_rag_session(app_name: str, data_description: str) -> InteractiveSession:
|
|
||||||
"""Create an interactive session for RAG examples."""
|
|
||||||
return InteractiveSession(
|
|
||||||
history_name=f"{app_name}_rag",
|
|
||||||
prompt="You: ",
|
|
||||||
welcome_message=f"[Interactive Mode] Chat with your {data_description} data!\nType 'quit' or 'exit' to stop, 'help' for commands.\n"
|
|
||||||
+ "=" * 40,
|
|
||||||
)
|
|
||||||
@@ -60,11 +60,6 @@ def handle_request(request):
|
|||||||
"maximum": 128,
|
"maximum": 128,
|
||||||
"description": "Search complexity level. Use 16-32 for fast searches (recommended), 64+ for higher precision when needed.",
|
"description": "Search complexity level. Use 16-32 for fast searches (recommended), 64+ for higher precision when needed.",
|
||||||
},
|
},
|
||||||
"show_metadata": {
|
|
||||||
"type": "boolean",
|
|
||||||
"default": False,
|
|
||||||
"description": "Include file paths and metadata in search results. Useful for understanding which files contain the results.",
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
"required": ["index_name", "query"],
|
"required": ["index_name", "query"],
|
||||||
},
|
},
|
||||||
@@ -109,8 +104,6 @@ def handle_request(request):
|
|||||||
f"--complexity={args.get('complexity', 32)}",
|
f"--complexity={args.get('complexity', 32)}",
|
||||||
"--non-interactive",
|
"--non-interactive",
|
||||||
]
|
]
|
||||||
if args.get("show_metadata", False):
|
|
||||||
cmd.append("--show-metadata")
|
|
||||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||||
|
|
||||||
elif tool_name == "leann_list":
|
elif tool_name == "leann_list":
|
||||||
|
|||||||
@@ -22,10 +22,7 @@ dependencies = [
|
|||||||
"sglang",
|
"sglang",
|
||||||
"ollama",
|
"ollama",
|
||||||
"requests>=2.25.0",
|
"requests>=2.25.0",
|
||||||
"sentence-transformers>=3.0.0",
|
"sentence-transformers>=2.2.0",
|
||||||
# Pin transformers below 4.46: 4.46.0 introduced Python 3.10-only typing (PEP 604) and
|
|
||||||
# breaks our Python 3.9 test matrix when pulled in by sentence-transformers.
|
|
||||||
"transformers<4.46",
|
|
||||||
"openai>=1.0.0",
|
"openai>=1.0.0",
|
||||||
# PDF parsing dependencies - essential for document processing
|
# PDF parsing dependencies - essential for document processing
|
||||||
"PyPDF2>=3.0.0",
|
"PyPDF2>=3.0.0",
|
||||||
@@ -57,8 +54,6 @@ dependencies = [
|
|||||||
"tree-sitter-c-sharp>=0.20.0",
|
"tree-sitter-c-sharp>=0.20.0",
|
||||||
"tree-sitter-typescript>=0.20.0",
|
"tree-sitter-typescript>=0.20.0",
|
||||||
"torchvision>=0.23.0",
|
"torchvision>=0.23.0",
|
||||||
"einops",
|
|
||||||
"seaborn",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import subprocess
|
|||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@@ -116,10 +116,8 @@ class TestChunkingFunctions:
|
|||||||
chunks = create_traditional_chunks(docs, chunk_size=50, chunk_overlap=10)
|
chunks = create_traditional_chunks(docs, chunk_size=50, chunk_overlap=10)
|
||||||
|
|
||||||
assert len(chunks) > 0
|
assert len(chunks) > 0
|
||||||
# Traditional chunks now return dict format for consistency
|
assert all(isinstance(chunk, str) for chunk in chunks)
|
||||||
assert all(isinstance(chunk, dict) for chunk in chunks)
|
assert all(len(chunk.strip()) > 0 for chunk in chunks)
|
||||||
assert all("text" in chunk and "metadata" in chunk for chunk in chunks)
|
|
||||||
assert all(len(chunk["text"].strip()) > 0 for chunk in chunks)
|
|
||||||
|
|
||||||
def test_create_traditional_chunks_empty_docs(self):
|
def test_create_traditional_chunks_empty_docs(self):
|
||||||
"""Test traditional chunking with empty documents."""
|
"""Test traditional chunking with empty documents."""
|
||||||
@@ -160,22 +158,11 @@ class Calculator:
|
|||||||
|
|
||||||
# Should have multiple chunks due to different functions/classes
|
# Should have multiple chunks due to different functions/classes
|
||||||
assert len(chunks) > 0
|
assert len(chunks) > 0
|
||||||
# R3: Expect dict format with "text" and "metadata" keys
|
assert all(isinstance(chunk, str) for chunk in chunks)
|
||||||
assert all(isinstance(chunk, dict) for chunk in chunks), "All chunks should be dicts"
|
assert all(len(chunk.strip()) > 0 for chunk in chunks)
|
||||||
assert all("text" in chunk and "metadata" in chunk for chunk in chunks), (
|
|
||||||
"Each chunk should have 'text' and 'metadata' keys"
|
|
||||||
)
|
|
||||||
assert all(len(chunk["text"].strip()) > 0 for chunk in chunks), (
|
|
||||||
"Each chunk text should be non-empty"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check metadata is present
|
|
||||||
assert all("file_path" in chunk["metadata"] for chunk in chunks), (
|
|
||||||
"Each chunk should have file_path metadata"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check that code structure is somewhat preserved
|
# Check that code structure is somewhat preserved
|
||||||
combined_content = " ".join([c["text"] for c in chunks])
|
combined_content = " ".join(chunks)
|
||||||
assert "def hello_world" in combined_content
|
assert "def hello_world" in combined_content
|
||||||
assert "class Calculator" in combined_content
|
assert "class Calculator" in combined_content
|
||||||
|
|
||||||
@@ -207,11 +194,7 @@ class Calculator:
|
|||||||
chunks = create_text_chunks(docs, use_ast_chunking=False, chunk_size=50, chunk_overlap=10)
|
chunks = create_text_chunks(docs, use_ast_chunking=False, chunk_size=50, chunk_overlap=10)
|
||||||
|
|
||||||
assert len(chunks) > 0
|
assert len(chunks) > 0
|
||||||
# R3: Traditional chunking should also return dict format for consistency
|
assert all(isinstance(chunk, str) for chunk in chunks)
|
||||||
assert all(isinstance(chunk, dict) for chunk in chunks), "All chunks should be dicts"
|
|
||||||
assert all("text" in chunk and "metadata" in chunk for chunk in chunks), (
|
|
||||||
"Each chunk should have 'text' and 'metadata' keys"
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_create_text_chunks_ast_mode(self):
|
def test_create_text_chunks_ast_mode(self):
|
||||||
"""Test text chunking in AST mode."""
|
"""Test text chunking in AST mode."""
|
||||||
@@ -230,11 +213,7 @@ class Calculator:
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert len(chunks) > 0
|
assert len(chunks) > 0
|
||||||
# R3: AST mode should also return dict format
|
assert all(isinstance(chunk, str) for chunk in chunks)
|
||||||
assert all(isinstance(chunk, dict) for chunk in chunks), "All chunks should be dicts"
|
|
||||||
assert all("text" in chunk and "metadata" in chunk for chunk in chunks), (
|
|
||||||
"Each chunk should have 'text' and 'metadata' keys"
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_create_text_chunks_custom_extensions(self):
|
def test_create_text_chunks_custom_extensions(self):
|
||||||
"""Test text chunking with custom code file extensions."""
|
"""Test text chunking with custom code file extensions."""
|
||||||
@@ -374,552 +353,6 @@ class MathUtils:
|
|||||||
pytest.skip("Test timed out - likely due to model download in CI")
|
pytest.skip("Test timed out - likely due to model download in CI")
|
||||||
|
|
||||||
|
|
||||||
class TestASTContentExtraction:
|
|
||||||
"""Test AST content extraction bug fix.
|
|
||||||
|
|
||||||
These tests verify that astchunk's dict format with 'content' key is handled correctly,
|
|
||||||
and that the extraction logic doesn't fall through to stringifying entire dicts.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def test_extract_content_from_astchunk_dict(self):
|
|
||||||
"""Test that astchunk dict format with 'content' key is handled correctly.
|
|
||||||
|
|
||||||
Bug: Current code checks for chunk["text"] but astchunk returns chunk["content"].
|
|
||||||
This causes fallthrough to str(chunk), stringifying the entire dict.
|
|
||||||
|
|
||||||
This test will FAIL until the bug is fixed because:
|
|
||||||
- Current code will stringify the dict: "{'content': '...', 'metadata': {...}}"
|
|
||||||
- Fixed code should extract just the content value
|
|
||||||
"""
|
|
||||||
# Mock the ASTChunkBuilder class
|
|
||||||
mock_builder = Mock()
|
|
||||||
|
|
||||||
# Astchunk returns this format
|
|
||||||
astchunk_format_chunk = {
|
|
||||||
"content": "def hello():\n print('world')",
|
|
||||||
"metadata": {
|
|
||||||
"filepath": "test.py",
|
|
||||||
"line_count": 2,
|
|
||||||
"start_line_no": 0,
|
|
||||||
"end_line_no": 1,
|
|
||||||
"node_count": 1,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
mock_builder.chunkify.return_value = [astchunk_format_chunk]
|
|
||||||
|
|
||||||
# Create mock document
|
|
||||||
doc = MockDocument(
|
|
||||||
"def hello():\n print('world')", "/test/test.py", {"language": "python"}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock the astchunk module and its ASTChunkBuilder class
|
|
||||||
mock_astchunk = Mock()
|
|
||||||
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
|
|
||||||
|
|
||||||
# Patch sys.modules to inject our mock before the import
|
|
||||||
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
|
|
||||||
# Call create_ast_chunks
|
|
||||||
chunks = create_ast_chunks([doc])
|
|
||||||
|
|
||||||
# R3: Should return dict format with proper metadata
|
|
||||||
assert len(chunks) > 0, "Should return at least one chunk"
|
|
||||||
|
|
||||||
# R3: Each chunk should be a dict
|
|
||||||
chunk = chunks[0]
|
|
||||||
assert isinstance(chunk, dict), "Chunk should be a dict"
|
|
||||||
assert "text" in chunk, "Chunk should have 'text' key"
|
|
||||||
assert "metadata" in chunk, "Chunk should have 'metadata' key"
|
|
||||||
|
|
||||||
chunk_text = chunk["text"]
|
|
||||||
|
|
||||||
# CRITICAL: Should NOT contain stringified dict markers in the text field
|
|
||||||
# These assertions will FAIL with current buggy code
|
|
||||||
assert "'content':" not in chunk_text, (
|
|
||||||
f"Chunk text contains stringified dict - extraction failed! Got: {chunk_text[:100]}..."
|
|
||||||
)
|
|
||||||
assert "'metadata':" not in chunk_text, (
|
|
||||||
"Chunk text contains stringified metadata - extraction failed! "
|
|
||||||
f"Got: {chunk_text[:100]}..."
|
|
||||||
)
|
|
||||||
assert "{" not in chunk_text or "def hello" in chunk_text.split("{")[0], (
|
|
||||||
"Chunk text appears to be a stringified dict"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should contain actual content
|
|
||||||
assert "def hello()" in chunk_text, "Should extract actual code content"
|
|
||||||
assert "print('world')" in chunk_text, "Should extract complete code content"
|
|
||||||
|
|
||||||
# R3: Should preserve astchunk metadata
|
|
||||||
assert "filepath" in chunk["metadata"] or "file_path" in chunk["metadata"], (
|
|
||||||
"Should preserve file path metadata"
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_extract_text_key_fallback(self):
|
|
||||||
"""Test that 'text' key still works for backward compatibility.
|
|
||||||
|
|
||||||
Some chunks might use 'text' instead of 'content' - ensure backward compatibility.
|
|
||||||
This test should PASS even with current code.
|
|
||||||
"""
|
|
||||||
mock_builder = Mock()
|
|
||||||
|
|
||||||
# Some chunks might use "text" key
|
|
||||||
text_key_chunk = {"text": "def legacy_function():\n return True"}
|
|
||||||
mock_builder.chunkify.return_value = [text_key_chunk]
|
|
||||||
|
|
||||||
# Create mock document
|
|
||||||
doc = MockDocument(
|
|
||||||
"def legacy_function():\n return True", "/test/legacy.py", {"language": "python"}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock the astchunk module
|
|
||||||
mock_astchunk = Mock()
|
|
||||||
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
|
|
||||||
|
|
||||||
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
|
|
||||||
# Call create_ast_chunks
|
|
||||||
chunks = create_ast_chunks([doc])
|
|
||||||
|
|
||||||
# R3: Should extract text correctly as dict format
|
|
||||||
assert len(chunks) > 0
|
|
||||||
chunk = chunks[0]
|
|
||||||
assert isinstance(chunk, dict), "Chunk should be a dict"
|
|
||||||
assert "text" in chunk, "Chunk should have 'text' key"
|
|
||||||
|
|
||||||
chunk_text = chunk["text"]
|
|
||||||
|
|
||||||
# Should NOT be stringified
|
|
||||||
assert "'text':" not in chunk_text, "Should not stringify dict with 'text' key"
|
|
||||||
|
|
||||||
# Should contain actual content
|
|
||||||
assert "def legacy_function()" in chunk_text
|
|
||||||
assert "return True" in chunk_text
|
|
||||||
|
|
||||||
def test_handles_string_chunks(self):
|
|
||||||
"""Test that plain string chunks still work.
|
|
||||||
|
|
||||||
Some chunkers might return plain strings - verify these are preserved.
|
|
||||||
This test should PASS with current code.
|
|
||||||
"""
|
|
||||||
mock_builder = Mock()
|
|
||||||
|
|
||||||
# Plain string chunk
|
|
||||||
plain_string_chunk = "def simple_function():\n pass"
|
|
||||||
mock_builder.chunkify.return_value = [plain_string_chunk]
|
|
||||||
|
|
||||||
# Create mock document
|
|
||||||
doc = MockDocument(
|
|
||||||
"def simple_function():\n pass", "/test/simple.py", {"language": "python"}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock the astchunk module
|
|
||||||
mock_astchunk = Mock()
|
|
||||||
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
|
|
||||||
|
|
||||||
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
|
|
||||||
# Call create_ast_chunks
|
|
||||||
chunks = create_ast_chunks([doc])
|
|
||||||
|
|
||||||
# R3: Should wrap string in dict format
|
|
||||||
assert len(chunks) > 0
|
|
||||||
chunk = chunks[0]
|
|
||||||
assert isinstance(chunk, dict), "Even string chunks should be wrapped in dict"
|
|
||||||
assert "text" in chunk, "Chunk should have 'text' key"
|
|
||||||
|
|
||||||
chunk_text = chunk["text"]
|
|
||||||
|
|
||||||
assert chunk_text == plain_string_chunk.strip(), (
|
|
||||||
"Should preserve plain string chunk content"
|
|
||||||
)
|
|
||||||
assert "def simple_function()" in chunk_text
|
|
||||||
assert "pass" in chunk_text
|
|
||||||
|
|
||||||
def test_multiple_chunks_with_mixed_formats(self):
|
|
||||||
"""Test handling of multiple chunks with different formats.
|
|
||||||
|
|
||||||
Real-world scenario: astchunk might return a mix of formats.
|
|
||||||
This test will FAIL if any chunk with 'content' key gets stringified.
|
|
||||||
"""
|
|
||||||
mock_builder = Mock()
|
|
||||||
|
|
||||||
# Mix of formats
|
|
||||||
mixed_chunks = [
|
|
||||||
{"content": "def first():\n return 1", "metadata": {"line_count": 2}},
|
|
||||||
"def second():\n return 2", # Plain string
|
|
||||||
{"text": "def third():\n return 3"}, # Old format
|
|
||||||
{"content": "class MyClass:\n pass", "metadata": {"node_count": 1}},
|
|
||||||
]
|
|
||||||
mock_builder.chunkify.return_value = mixed_chunks
|
|
||||||
|
|
||||||
# Create mock document
|
|
||||||
code = "def first():\n return 1\n\ndef second():\n return 2\n\ndef third():\n return 3\n\nclass MyClass:\n pass"
|
|
||||||
doc = MockDocument(code, "/test/mixed.py", {"language": "python"})
|
|
||||||
|
|
||||||
# Mock the astchunk module
|
|
||||||
mock_astchunk = Mock()
|
|
||||||
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
|
|
||||||
|
|
||||||
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
|
|
||||||
# Call create_ast_chunks
|
|
||||||
chunks = create_ast_chunks([doc])
|
|
||||||
|
|
||||||
# R3: Should extract all chunks correctly as dicts
|
|
||||||
assert len(chunks) == 4, "Should extract all 4 chunks"
|
|
||||||
|
|
||||||
# Check each chunk
|
|
||||||
for i, chunk in enumerate(chunks):
|
|
||||||
assert isinstance(chunk, dict), f"Chunk {i} should be a dict"
|
|
||||||
assert "text" in chunk, f"Chunk {i} should have 'text' key"
|
|
||||||
assert "metadata" in chunk, f"Chunk {i} should have 'metadata' key"
|
|
||||||
|
|
||||||
chunk_text = chunk["text"]
|
|
||||||
# None should be stringified dicts
|
|
||||||
assert "'content':" not in chunk_text, f"Chunk {i} text is stringified (has 'content':)"
|
|
||||||
assert "'metadata':" not in chunk_text, (
|
|
||||||
f"Chunk {i} text is stringified (has 'metadata':)"
|
|
||||||
)
|
|
||||||
assert "'text':" not in chunk_text, f"Chunk {i} text is stringified (has 'text':)"
|
|
||||||
|
|
||||||
# Verify actual content is present
|
|
||||||
combined = "\n".join([c["text"] for c in chunks])
|
|
||||||
assert "def first()" in combined
|
|
||||||
assert "def second()" in combined
|
|
||||||
assert "def third()" in combined
|
|
||||||
assert "class MyClass:" in combined
|
|
||||||
|
|
||||||
def test_empty_content_value_handling(self):
|
|
||||||
"""Test handling of chunks with empty content values.
|
|
||||||
|
|
||||||
Edge case: chunk has 'content' key but value is empty.
|
|
||||||
Should skip these chunks, not stringify them.
|
|
||||||
"""
|
|
||||||
mock_builder = Mock()
|
|
||||||
|
|
||||||
chunks_with_empty = [
|
|
||||||
{"content": "", "metadata": {"line_count": 0}}, # Empty content
|
|
||||||
{"content": " ", "metadata": {"line_count": 1}}, # Whitespace only
|
|
||||||
{"content": "def valid():\n return True", "metadata": {"line_count": 2}}, # Valid
|
|
||||||
]
|
|
||||||
mock_builder.chunkify.return_value = chunks_with_empty
|
|
||||||
|
|
||||||
doc = MockDocument(
|
|
||||||
"def valid():\n return True", "/test/empty.py", {"language": "python"}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock the astchunk module
|
|
||||||
mock_astchunk = Mock()
|
|
||||||
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
|
|
||||||
|
|
||||||
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
|
|
||||||
chunks = create_ast_chunks([doc])
|
|
||||||
|
|
||||||
# R3: Should only have the valid chunk (empty ones filtered out)
|
|
||||||
assert len(chunks) == 1, "Should filter out empty content chunks"
|
|
||||||
|
|
||||||
chunk = chunks[0]
|
|
||||||
assert isinstance(chunk, dict), "Chunk should be a dict"
|
|
||||||
assert "text" in chunk, "Chunk should have 'text' key"
|
|
||||||
assert "def valid()" in chunk["text"]
|
|
||||||
|
|
||||||
# Should not have stringified the empty dict
|
|
||||||
assert "'content': ''" not in chunk["text"]
|
|
||||||
|
|
||||||
|
|
||||||
class TestASTMetadataPreservation:
|
|
||||||
"""Test metadata preservation in AST chunk dictionaries.
|
|
||||||
|
|
||||||
R3: These tests define the contract for metadata preservation when returning
|
|
||||||
chunk dictionaries instead of plain strings. Each chunk dict should have:
|
|
||||||
- "text": str - the actual chunk content
|
|
||||||
- "metadata": dict - all metadata from document AND astchunk
|
|
||||||
|
|
||||||
These tests will FAIL until G3 implementation changes return type to list[dict].
|
|
||||||
"""
|
|
||||||
|
|
||||||
def test_ast_chunks_preserve_file_metadata(self):
|
|
||||||
"""Test that document metadata is preserved in chunk metadata.
|
|
||||||
|
|
||||||
This test verifies that all document-level metadata (file_path, file_name,
|
|
||||||
creation_date, last_modified_date) is included in each chunk's metadata dict.
|
|
||||||
|
|
||||||
This will FAIL because current code returns list[str], not list[dict].
|
|
||||||
"""
|
|
||||||
# Create mock document with rich metadata
|
|
||||||
python_code = '''
|
|
||||||
def calculate_sum(numbers):
|
|
||||||
"""Calculate sum of numbers."""
|
|
||||||
return sum(numbers)
|
|
||||||
|
|
||||||
class DataProcessor:
|
|
||||||
"""Process data records."""
|
|
||||||
|
|
||||||
def process(self, data):
|
|
||||||
return [x * 2 for x in data]
|
|
||||||
'''
|
|
||||||
doc = MockDocument(
|
|
||||||
python_code,
|
|
||||||
file_path="/project/src/utils.py",
|
|
||||||
metadata={
|
|
||||||
"language": "python",
|
|
||||||
"file_path": "/project/src/utils.py",
|
|
||||||
"file_name": "utils.py",
|
|
||||||
"creation_date": "2024-01-15T10:30:00",
|
|
||||||
"last_modified_date": "2024-10-31T15:45:00",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock astchunk to return chunks with metadata
|
|
||||||
mock_builder = Mock()
|
|
||||||
astchunk_chunks = [
|
|
||||||
{
|
|
||||||
"content": "def calculate_sum(numbers):\n return sum(numbers)",
|
|
||||||
"metadata": {
|
|
||||||
"filepath": "/project/src/utils.py",
|
|
||||||
"line_count": 2,
|
|
||||||
"start_line_no": 1,
|
|
||||||
"end_line_no": 2,
|
|
||||||
"node_count": 1,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"content": "class DataProcessor:\n def process(self, data):\n return [x * 2 for x in data]",
|
|
||||||
"metadata": {
|
|
||||||
"filepath": "/project/src/utils.py",
|
|
||||||
"line_count": 3,
|
|
||||||
"start_line_no": 5,
|
|
||||||
"end_line_no": 7,
|
|
||||||
"node_count": 2,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
]
|
|
||||||
mock_builder.chunkify.return_value = astchunk_chunks
|
|
||||||
|
|
||||||
mock_astchunk = Mock()
|
|
||||||
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
|
|
||||||
|
|
||||||
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
|
|
||||||
chunks = create_ast_chunks([doc])
|
|
||||||
|
|
||||||
# CRITICAL: These assertions will FAIL with current list[str] return type
|
|
||||||
assert len(chunks) == 2, "Should return 2 chunks"
|
|
||||||
|
|
||||||
for i, chunk in enumerate(chunks):
|
|
||||||
# Structure assertions - WILL FAIL: current code returns strings
|
|
||||||
assert isinstance(chunk, dict), f"Chunk {i} should be dict, got {type(chunk)}"
|
|
||||||
assert "text" in chunk, f"Chunk {i} must have 'text' key"
|
|
||||||
assert "metadata" in chunk, f"Chunk {i} must have 'metadata' key"
|
|
||||||
assert isinstance(chunk["metadata"], dict), f"Chunk {i} metadata should be dict"
|
|
||||||
|
|
||||||
# Document metadata preservation - WILL FAIL
|
|
||||||
metadata = chunk["metadata"]
|
|
||||||
assert "file_path" in metadata, f"Chunk {i} should preserve file_path"
|
|
||||||
assert metadata["file_path"] == "/project/src/utils.py", (
|
|
||||||
f"Chunk {i} file_path incorrect"
|
|
||||||
)
|
|
||||||
|
|
||||||
assert "file_name" in metadata, f"Chunk {i} should preserve file_name"
|
|
||||||
assert metadata["file_name"] == "utils.py", f"Chunk {i} file_name incorrect"
|
|
||||||
|
|
||||||
assert "creation_date" in metadata, f"Chunk {i} should preserve creation_date"
|
|
||||||
assert metadata["creation_date"] == "2024-01-15T10:30:00", (
|
|
||||||
f"Chunk {i} creation_date incorrect"
|
|
||||||
)
|
|
||||||
|
|
||||||
assert "last_modified_date" in metadata, f"Chunk {i} should preserve last_modified_date"
|
|
||||||
assert metadata["last_modified_date"] == "2024-10-31T15:45:00", (
|
|
||||||
f"Chunk {i} last_modified_date incorrect"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify metadata is consistent across chunks from same document
|
|
||||||
assert chunks[0]["metadata"]["file_path"] == chunks[1]["metadata"]["file_path"], (
|
|
||||||
"All chunks from same document should have same file_path"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify text content is present and not stringified
|
|
||||||
assert "def calculate_sum" in chunks[0]["text"]
|
|
||||||
assert "class DataProcessor" in chunks[1]["text"]
|
|
||||||
|
|
||||||
def test_ast_chunks_include_astchunk_metadata(self):
|
|
||||||
"""Test that astchunk-specific metadata is merged into chunk metadata.
|
|
||||||
|
|
||||||
This test verifies that astchunk's metadata (line_count, start_line_no,
|
|
||||||
end_line_no, node_count) is merged with document metadata.
|
|
||||||
|
|
||||||
This will FAIL because current code returns list[str], not list[dict].
|
|
||||||
"""
|
|
||||||
python_code = '''
|
|
||||||
def function_one():
|
|
||||||
"""First function."""
|
|
||||||
x = 1
|
|
||||||
y = 2
|
|
||||||
return x + y
|
|
||||||
|
|
||||||
def function_two():
|
|
||||||
"""Second function."""
|
|
||||||
return 42
|
|
||||||
'''
|
|
||||||
doc = MockDocument(
|
|
||||||
python_code,
|
|
||||||
file_path="/test/code.py",
|
|
||||||
metadata={
|
|
||||||
"language": "python",
|
|
||||||
"file_path": "/test/code.py",
|
|
||||||
"file_name": "code.py",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock astchunk with detailed metadata
|
|
||||||
mock_builder = Mock()
|
|
||||||
astchunk_chunks = [
|
|
||||||
{
|
|
||||||
"content": "def function_one():\n x = 1\n y = 2\n return x + y",
|
|
||||||
"metadata": {
|
|
||||||
"filepath": "/test/code.py",
|
|
||||||
"line_count": 4,
|
|
||||||
"start_line_no": 1,
|
|
||||||
"end_line_no": 4,
|
|
||||||
"node_count": 5, # function, assignments, return
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"content": "def function_two():\n return 42",
|
|
||||||
"metadata": {
|
|
||||||
"filepath": "/test/code.py",
|
|
||||||
"line_count": 2,
|
|
||||||
"start_line_no": 7,
|
|
||||||
"end_line_no": 8,
|
|
||||||
"node_count": 2, # function, return
|
|
||||||
},
|
|
||||||
},
|
|
||||||
]
|
|
||||||
mock_builder.chunkify.return_value = astchunk_chunks
|
|
||||||
|
|
||||||
mock_astchunk = Mock()
|
|
||||||
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
|
|
||||||
|
|
||||||
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
|
|
||||||
chunks = create_ast_chunks([doc])
|
|
||||||
|
|
||||||
# CRITICAL: These will FAIL with current list[str] return
|
|
||||||
assert len(chunks) == 2
|
|
||||||
|
|
||||||
# First chunk - function_one
|
|
||||||
chunk1 = chunks[0]
|
|
||||||
assert isinstance(chunk1, dict), "Chunk should be dict"
|
|
||||||
assert "metadata" in chunk1
|
|
||||||
|
|
||||||
metadata1 = chunk1["metadata"]
|
|
||||||
|
|
||||||
# Check astchunk metadata is present
|
|
||||||
assert "line_count" in metadata1, "Should include astchunk line_count"
|
|
||||||
assert metadata1["line_count"] == 4, "line_count should be 4"
|
|
||||||
|
|
||||||
assert "start_line_no" in metadata1, "Should include astchunk start_line_no"
|
|
||||||
assert metadata1["start_line_no"] == 1, "start_line_no should be 1"
|
|
||||||
|
|
||||||
assert "end_line_no" in metadata1, "Should include astchunk end_line_no"
|
|
||||||
assert metadata1["end_line_no"] == 4, "end_line_no should be 4"
|
|
||||||
|
|
||||||
assert "node_count" in metadata1, "Should include astchunk node_count"
|
|
||||||
assert metadata1["node_count"] == 5, "node_count should be 5"
|
|
||||||
|
|
||||||
# Second chunk - function_two
|
|
||||||
chunk2 = chunks[1]
|
|
||||||
metadata2 = chunk2["metadata"]
|
|
||||||
|
|
||||||
assert metadata2["line_count"] == 2, "line_count should be 2"
|
|
||||||
assert metadata2["start_line_no"] == 7, "start_line_no should be 7"
|
|
||||||
assert metadata2["end_line_no"] == 8, "end_line_no should be 8"
|
|
||||||
assert metadata2["node_count"] == 2, "node_count should be 2"
|
|
||||||
|
|
||||||
# Verify document metadata is ALSO present (merged, not replaced)
|
|
||||||
assert metadata1["file_path"] == "/test/code.py"
|
|
||||||
assert metadata1["file_name"] == "code.py"
|
|
||||||
assert metadata2["file_path"] == "/test/code.py"
|
|
||||||
assert metadata2["file_name"] == "code.py"
|
|
||||||
|
|
||||||
# Verify text content is correct
|
|
||||||
assert "def function_one" in chunk1["text"]
|
|
||||||
assert "def function_two" in chunk2["text"]
|
|
||||||
|
|
||||||
def test_traditional_chunks_as_dicts_helper(self):
|
|
||||||
"""Test the helper function that wraps traditional chunks as dicts.
|
|
||||||
|
|
||||||
This test verifies that when create_traditional_chunks is called,
|
|
||||||
its plain string chunks are wrapped into dict format with metadata.
|
|
||||||
|
|
||||||
This will FAIL because the helper function _traditional_chunks_as_dicts()
|
|
||||||
doesn't exist yet, and create_traditional_chunks returns list[str].
|
|
||||||
"""
|
|
||||||
# Create documents with various metadata
|
|
||||||
docs = [
|
|
||||||
MockDocument(
|
|
||||||
"This is the first paragraph of text. It contains multiple sentences. "
|
|
||||||
"This should be split into chunks based on size.",
|
|
||||||
file_path="/docs/readme.txt",
|
|
||||||
metadata={
|
|
||||||
"file_path": "/docs/readme.txt",
|
|
||||||
"file_name": "readme.txt",
|
|
||||||
"creation_date": "2024-01-01",
|
|
||||||
},
|
|
||||||
),
|
|
||||||
MockDocument(
|
|
||||||
"Second document with different metadata. It also has content that needs chunking.",
|
|
||||||
file_path="/docs/guide.md",
|
|
||||||
metadata={
|
|
||||||
"file_path": "/docs/guide.md",
|
|
||||||
"file_name": "guide.md",
|
|
||||||
"last_modified_date": "2024-10-31",
|
|
||||||
},
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
# Call create_traditional_chunks (which should now return list[dict])
|
|
||||||
chunks = create_traditional_chunks(docs, chunk_size=50, chunk_overlap=10)
|
|
||||||
|
|
||||||
# CRITICAL: Will FAIL - current code returns list[str]
|
|
||||||
assert len(chunks) > 0, "Should return chunks"
|
|
||||||
|
|
||||||
for i, chunk in enumerate(chunks):
|
|
||||||
# Structure assertions - WILL FAIL
|
|
||||||
assert isinstance(chunk, dict), f"Chunk {i} should be dict, got {type(chunk)}"
|
|
||||||
assert "text" in chunk, f"Chunk {i} must have 'text' key"
|
|
||||||
assert "metadata" in chunk, f"Chunk {i} must have 'metadata' key"
|
|
||||||
|
|
||||||
# Text should be non-empty
|
|
||||||
assert len(chunk["text"].strip()) > 0, f"Chunk {i} text should be non-empty"
|
|
||||||
|
|
||||||
# Metadata should include document info
|
|
||||||
metadata = chunk["metadata"]
|
|
||||||
assert "file_path" in metadata, f"Chunk {i} should have file_path in metadata"
|
|
||||||
assert "file_name" in metadata, f"Chunk {i} should have file_name in metadata"
|
|
||||||
|
|
||||||
# Verify metadata tracking works correctly
|
|
||||||
# At least one chunk should be from readme.txt
|
|
||||||
readme_chunks = [c for c in chunks if "readme.txt" in c["metadata"]["file_name"]]
|
|
||||||
assert len(readme_chunks) > 0, "Should have chunks from readme.txt"
|
|
||||||
|
|
||||||
# At least one chunk should be from guide.md
|
|
||||||
guide_chunks = [c for c in chunks if "guide.md" in c["metadata"]["file_name"]]
|
|
||||||
assert len(guide_chunks) > 0, "Should have chunks from guide.md"
|
|
||||||
|
|
||||||
# Verify creation_date is preserved for readme chunks
|
|
||||||
for chunk in readme_chunks:
|
|
||||||
assert chunk["metadata"].get("creation_date") == "2024-01-01", (
|
|
||||||
"readme.txt chunks should preserve creation_date"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify last_modified_date is preserved for guide chunks
|
|
||||||
for chunk in guide_chunks:
|
|
||||||
assert chunk["metadata"].get("last_modified_date") == "2024-10-31", (
|
|
||||||
"guide.md chunks should preserve last_modified_date"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify text content is present
|
|
||||||
all_text = " ".join([c["text"] for c in chunks])
|
|
||||||
assert "first paragraph" in all_text
|
|
||||||
assert "Second document" in all_text
|
|
||||||
|
|
||||||
|
|
||||||
class TestErrorHandling:
|
class TestErrorHandling:
|
||||||
"""Test error handling and edge cases."""
|
"""Test error handling and edge cases."""
|
||||||
|
|
||||||
|
|||||||
@@ -1,208 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Test script for MCP integration implementations.
|
|
||||||
|
|
||||||
This script tests the basic functionality of the MCP readers and RAG applications
|
|
||||||
without requiring actual MCP servers to be running.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
# Add the parent directory to the path so we can import from apps
|
|
||||||
sys.path.append(str(Path(__file__).parent.parent))
|
|
||||||
|
|
||||||
from apps.slack_data.slack_mcp_reader import SlackMCPReader
|
|
||||||
from apps.slack_rag import SlackMCPRAG
|
|
||||||
from apps.twitter_data.twitter_mcp_reader import TwitterMCPReader
|
|
||||||
from apps.twitter_rag import TwitterMCPRAG
|
|
||||||
|
|
||||||
|
|
||||||
def test_slack_reader_initialization():
|
|
||||||
"""Test that SlackMCPReader can be initialized with various parameters."""
|
|
||||||
print("Testing SlackMCPReader initialization...")
|
|
||||||
|
|
||||||
# Test basic initialization
|
|
||||||
reader = SlackMCPReader("slack-mcp-server")
|
|
||||||
assert reader.mcp_server_command == "slack-mcp-server"
|
|
||||||
assert reader.concatenate_conversations
|
|
||||||
assert reader.max_messages_per_conversation == 100
|
|
||||||
|
|
||||||
# Test with custom parameters
|
|
||||||
reader = SlackMCPReader(
|
|
||||||
"custom-slack-server",
|
|
||||||
workspace_name="test-workspace",
|
|
||||||
concatenate_conversations=False,
|
|
||||||
max_messages_per_conversation=50,
|
|
||||||
)
|
|
||||||
assert reader.workspace_name == "test-workspace"
|
|
||||||
assert not reader.concatenate_conversations
|
|
||||||
assert reader.max_messages_per_conversation == 50
|
|
||||||
|
|
||||||
print("✅ SlackMCPReader initialization tests passed")
|
|
||||||
|
|
||||||
|
|
||||||
def test_twitter_reader_initialization():
|
|
||||||
"""Test that TwitterMCPReader can be initialized with various parameters."""
|
|
||||||
print("Testing TwitterMCPReader initialization...")
|
|
||||||
|
|
||||||
# Test basic initialization
|
|
||||||
reader = TwitterMCPReader("twitter-mcp-server")
|
|
||||||
assert reader.mcp_server_command == "twitter-mcp-server"
|
|
||||||
assert reader.include_tweet_content
|
|
||||||
assert reader.include_metadata
|
|
||||||
assert reader.max_bookmarks == 1000
|
|
||||||
|
|
||||||
# Test with custom parameters
|
|
||||||
reader = TwitterMCPReader(
|
|
||||||
"custom-twitter-server",
|
|
||||||
username="testuser",
|
|
||||||
include_tweet_content=False,
|
|
||||||
include_metadata=False,
|
|
||||||
max_bookmarks=500,
|
|
||||||
)
|
|
||||||
assert reader.username == "testuser"
|
|
||||||
assert not reader.include_tweet_content
|
|
||||||
assert not reader.include_metadata
|
|
||||||
assert reader.max_bookmarks == 500
|
|
||||||
|
|
||||||
print("✅ TwitterMCPReader initialization tests passed")
|
|
||||||
|
|
||||||
|
|
||||||
def test_slack_message_formatting():
|
|
||||||
"""Test Slack message formatting functionality."""
|
|
||||||
print("Testing Slack message formatting...")
|
|
||||||
|
|
||||||
reader = SlackMCPReader("slack-mcp-server")
|
|
||||||
|
|
||||||
# Test basic message formatting
|
|
||||||
message = {
|
|
||||||
"text": "Hello, world!",
|
|
||||||
"user": "john_doe",
|
|
||||||
"channel": "general",
|
|
||||||
"ts": "1234567890.123456",
|
|
||||||
}
|
|
||||||
|
|
||||||
formatted = reader._format_message(message)
|
|
||||||
assert "Channel: #general" in formatted
|
|
||||||
assert "User: john_doe" in formatted
|
|
||||||
assert "Message: Hello, world!" in formatted
|
|
||||||
assert "Time:" in formatted
|
|
||||||
|
|
||||||
# Test with missing fields
|
|
||||||
message = {"text": "Simple message"}
|
|
||||||
formatted = reader._format_message(message)
|
|
||||||
assert "Message: Simple message" in formatted
|
|
||||||
|
|
||||||
print("✅ Slack message formatting tests passed")
|
|
||||||
|
|
||||||
|
|
||||||
def test_twitter_bookmark_formatting():
|
|
||||||
"""Test Twitter bookmark formatting functionality."""
|
|
||||||
print("Testing Twitter bookmark formatting...")
|
|
||||||
|
|
||||||
reader = TwitterMCPReader("twitter-mcp-server")
|
|
||||||
|
|
||||||
# Test basic bookmark formatting
|
|
||||||
bookmark = {
|
|
||||||
"text": "This is a great article about AI!",
|
|
||||||
"author": "ai_researcher",
|
|
||||||
"created_at": "2024-01-01T12:00:00Z",
|
|
||||||
"url": "https://twitter.com/ai_researcher/status/123456789",
|
|
||||||
"likes": 42,
|
|
||||||
"retweets": 15,
|
|
||||||
}
|
|
||||||
|
|
||||||
formatted = reader._format_bookmark(bookmark)
|
|
||||||
assert "=== Twitter Bookmark ===" in formatted
|
|
||||||
assert "Author: @ai_researcher" in formatted
|
|
||||||
assert "Content:" in formatted
|
|
||||||
assert "This is a great article about AI!" in formatted
|
|
||||||
assert "URL: https://twitter.com" in formatted
|
|
||||||
assert "Likes: 42" in formatted
|
|
||||||
assert "Retweets: 15" in formatted
|
|
||||||
|
|
||||||
# Test with minimal data
|
|
||||||
bookmark = {"text": "Simple tweet"}
|
|
||||||
formatted = reader._format_bookmark(bookmark)
|
|
||||||
assert "=== Twitter Bookmark ===" in formatted
|
|
||||||
assert "Simple tweet" in formatted
|
|
||||||
|
|
||||||
print("✅ Twitter bookmark formatting tests passed")
|
|
||||||
|
|
||||||
|
|
||||||
def test_slack_rag_initialization():
|
|
||||||
"""Test that SlackMCPRAG can be initialized."""
|
|
||||||
print("Testing SlackMCPRAG initialization...")
|
|
||||||
|
|
||||||
app = SlackMCPRAG()
|
|
||||||
assert app.default_index_name == "slack_messages"
|
|
||||||
assert hasattr(app, "parser")
|
|
||||||
|
|
||||||
print("✅ SlackMCPRAG initialization tests passed")
|
|
||||||
|
|
||||||
|
|
||||||
def test_twitter_rag_initialization():
|
|
||||||
"""Test that TwitterMCPRAG can be initialized."""
|
|
||||||
print("Testing TwitterMCPRAG initialization...")
|
|
||||||
|
|
||||||
app = TwitterMCPRAG()
|
|
||||||
assert app.default_index_name == "twitter_bookmarks"
|
|
||||||
assert hasattr(app, "parser")
|
|
||||||
|
|
||||||
print("✅ TwitterMCPRAG initialization tests passed")
|
|
||||||
|
|
||||||
|
|
||||||
def test_concatenated_content_creation():
|
|
||||||
"""Test creation of concatenated content from multiple messages."""
|
|
||||||
print("Testing concatenated content creation...")
|
|
||||||
|
|
||||||
reader = SlackMCPReader("slack-mcp-server", workspace_name="test-workspace")
|
|
||||||
|
|
||||||
messages = [
|
|
||||||
{"text": "First message", "user": "alice", "ts": "1000"},
|
|
||||||
{"text": "Second message", "user": "bob", "ts": "2000"},
|
|
||||||
{"text": "Third message", "user": "charlie", "ts": "3000"},
|
|
||||||
]
|
|
||||||
|
|
||||||
content = reader._create_concatenated_content(messages, "general")
|
|
||||||
|
|
||||||
assert "Slack Channel: #general" in content
|
|
||||||
assert "Message Count: 3" in content
|
|
||||||
assert "Workspace: test-workspace" in content
|
|
||||||
assert "First message" in content
|
|
||||||
assert "Second message" in content
|
|
||||||
assert "Third message" in content
|
|
||||||
|
|
||||||
print("✅ Concatenated content creation tests passed")
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
"""Run all tests."""
|
|
||||||
print("🧪 Running MCP Integration Tests")
|
|
||||||
print("=" * 50)
|
|
||||||
|
|
||||||
try:
|
|
||||||
test_slack_reader_initialization()
|
|
||||||
test_twitter_reader_initialization()
|
|
||||||
test_slack_message_formatting()
|
|
||||||
test_twitter_bookmark_formatting()
|
|
||||||
test_slack_rag_initialization()
|
|
||||||
test_twitter_rag_initialization()
|
|
||||||
test_concatenated_content_creation()
|
|
||||||
|
|
||||||
print("\n" + "=" * 50)
|
|
||||||
print("🎉 All tests passed! MCP integration is working correctly.")
|
|
||||||
print("\nNext steps:")
|
|
||||||
print("1. Install actual MCP servers for Slack and Twitter")
|
|
||||||
print("2. Configure API credentials")
|
|
||||||
print("3. Test with --test-connection flag")
|
|
||||||
print("4. Start indexing your live data!")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"\n❌ Test failed: {e}")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,221 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Standalone test script for MCP integration implementations.
|
|
||||||
|
|
||||||
This script tests the basic functionality of the MCP readers
|
|
||||||
without requiring LEANN core dependencies.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
# Add the parent directory to the path so we can import from apps
|
|
||||||
sys.path.append(str(Path(__file__).parent.parent))
|
|
||||||
|
|
||||||
|
|
||||||
def test_slack_reader_basic():
|
|
||||||
"""Test basic SlackMCPReader functionality without async operations."""
|
|
||||||
print("Testing SlackMCPReader basic functionality...")
|
|
||||||
|
|
||||||
# Import and test initialization
|
|
||||||
from apps.slack_data.slack_mcp_reader import SlackMCPReader
|
|
||||||
|
|
||||||
reader = SlackMCPReader("slack-mcp-server")
|
|
||||||
assert reader.mcp_server_command == "slack-mcp-server"
|
|
||||||
assert reader.concatenate_conversations
|
|
||||||
|
|
||||||
# Test message formatting
|
|
||||||
message = {
|
|
||||||
"text": "Hello team! How's the project going?",
|
|
||||||
"user": "john_doe",
|
|
||||||
"channel": "general",
|
|
||||||
"ts": "1234567890.123456",
|
|
||||||
}
|
|
||||||
|
|
||||||
formatted = reader._format_message(message)
|
|
||||||
assert "Channel: #general" in formatted
|
|
||||||
assert "User: john_doe" in formatted
|
|
||||||
assert "Message: Hello team!" in formatted
|
|
||||||
|
|
||||||
# Test concatenated content creation
|
|
||||||
messages = [
|
|
||||||
{"text": "First message", "user": "alice", "ts": "1000"},
|
|
||||||
{"text": "Second message", "user": "bob", "ts": "2000"},
|
|
||||||
]
|
|
||||||
|
|
||||||
content = reader._create_concatenated_content(messages, "dev-team")
|
|
||||||
assert "Slack Channel: #dev-team" in content
|
|
||||||
assert "Message Count: 2" in content
|
|
||||||
assert "First message" in content
|
|
||||||
assert "Second message" in content
|
|
||||||
|
|
||||||
print("✅ SlackMCPReader basic tests passed")
|
|
||||||
|
|
||||||
|
|
||||||
def test_twitter_reader_basic():
|
|
||||||
"""Test basic TwitterMCPReader functionality."""
|
|
||||||
print("Testing TwitterMCPReader basic functionality...")
|
|
||||||
|
|
||||||
from apps.twitter_data.twitter_mcp_reader import TwitterMCPReader
|
|
||||||
|
|
||||||
reader = TwitterMCPReader("twitter-mcp-server")
|
|
||||||
assert reader.mcp_server_command == "twitter-mcp-server"
|
|
||||||
assert reader.include_tweet_content
|
|
||||||
assert reader.max_bookmarks == 1000
|
|
||||||
|
|
||||||
# Test bookmark formatting
|
|
||||||
bookmark = {
|
|
||||||
"text": "Amazing article about the future of AI! Must read for everyone interested in tech.",
|
|
||||||
"author": "tech_guru",
|
|
||||||
"created_at": "2024-01-15T14:30:00Z",
|
|
||||||
"url": "https://twitter.com/tech_guru/status/123456789",
|
|
||||||
"likes": 156,
|
|
||||||
"retweets": 42,
|
|
||||||
"replies": 23,
|
|
||||||
"hashtags": ["AI", "tech", "future"],
|
|
||||||
"mentions": ["@openai", "@anthropic"],
|
|
||||||
}
|
|
||||||
|
|
||||||
formatted = reader._format_bookmark(bookmark)
|
|
||||||
assert "=== Twitter Bookmark ===" in formatted
|
|
||||||
assert "Author: @tech_guru" in formatted
|
|
||||||
assert "Amazing article about the future of AI!" in formatted
|
|
||||||
assert "Likes: 156" in formatted
|
|
||||||
assert "Retweets: 42" in formatted
|
|
||||||
assert "Hashtags: AI, tech, future" in formatted
|
|
||||||
assert "Mentions: @openai, @anthropic" in formatted
|
|
||||||
|
|
||||||
# Test with minimal data
|
|
||||||
simple_bookmark = {"text": "Short tweet", "author": "user123"}
|
|
||||||
formatted_simple = reader._format_bookmark(simple_bookmark)
|
|
||||||
assert "=== Twitter Bookmark ===" in formatted_simple
|
|
||||||
assert "Short tweet" in formatted_simple
|
|
||||||
assert "Author: @user123" in formatted_simple
|
|
||||||
|
|
||||||
print("✅ TwitterMCPReader basic tests passed")
|
|
||||||
|
|
||||||
|
|
||||||
def test_mcp_request_format():
|
|
||||||
"""Test MCP request formatting."""
|
|
||||||
print("Testing MCP request formatting...")
|
|
||||||
|
|
||||||
# Test initialization request format
|
|
||||||
init_request = {
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"id": 1,
|
|
||||||
"method": "initialize",
|
|
||||||
"params": {
|
|
||||||
"protocolVersion": "2024-11-05",
|
|
||||||
"capabilities": {},
|
|
||||||
"clientInfo": {"name": "leann-slack-reader", "version": "1.0.0"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
# Verify it's valid JSON
|
|
||||||
json_str = json.dumps(init_request)
|
|
||||||
parsed = json.loads(json_str)
|
|
||||||
assert parsed["jsonrpc"] == "2.0"
|
|
||||||
assert parsed["method"] == "initialize"
|
|
||||||
assert parsed["params"]["protocolVersion"] == "2024-11-05"
|
|
||||||
|
|
||||||
# Test tools/list request
|
|
||||||
list_request = {"jsonrpc": "2.0", "id": 2, "method": "tools/list", "params": {}}
|
|
||||||
|
|
||||||
json_str = json.dumps(list_request)
|
|
||||||
parsed = json.loads(json_str)
|
|
||||||
assert parsed["method"] == "tools/list"
|
|
||||||
|
|
||||||
print("✅ MCP request formatting tests passed")
|
|
||||||
|
|
||||||
|
|
||||||
def test_data_processing():
|
|
||||||
"""Test data processing capabilities."""
|
|
||||||
print("Testing data processing capabilities...")
|
|
||||||
|
|
||||||
from apps.slack_data.slack_mcp_reader import SlackMCPReader
|
|
||||||
from apps.twitter_data.twitter_mcp_reader import TwitterMCPReader
|
|
||||||
|
|
||||||
# Test Slack message processing with various formats
|
|
||||||
slack_reader = SlackMCPReader("test-server")
|
|
||||||
|
|
||||||
messages_with_timestamps = [
|
|
||||||
{"text": "Meeting in 5 minutes", "user": "alice", "ts": "1000.123"},
|
|
||||||
{"text": "On my way!", "user": "bob", "ts": "1001.456"},
|
|
||||||
{"text": "Starting now", "user": "charlie", "ts": "1002.789"},
|
|
||||||
]
|
|
||||||
|
|
||||||
content = slack_reader._create_concatenated_content(messages_with_timestamps, "meetings")
|
|
||||||
assert "Meeting in 5 minutes" in content
|
|
||||||
assert "On my way!" in content
|
|
||||||
assert "Starting now" in content
|
|
||||||
|
|
||||||
# Test Twitter bookmark processing with engagement data
|
|
||||||
twitter_reader = TwitterMCPReader("test-server", include_metadata=True)
|
|
||||||
|
|
||||||
high_engagement_bookmark = {
|
|
||||||
"text": "Thread about startup lessons learned 🧵",
|
|
||||||
"author": "startup_founder",
|
|
||||||
"likes": 1250,
|
|
||||||
"retweets": 340,
|
|
||||||
"replies": 89,
|
|
||||||
}
|
|
||||||
|
|
||||||
formatted = twitter_reader._format_bookmark(high_engagement_bookmark)
|
|
||||||
assert "Thread about startup lessons learned" in formatted
|
|
||||||
assert "Likes: 1250" in formatted
|
|
||||||
assert "Retweets: 340" in formatted
|
|
||||||
assert "Replies: 89" in formatted
|
|
||||||
|
|
||||||
# Test with metadata disabled
|
|
||||||
twitter_reader_no_meta = TwitterMCPReader("test-server", include_metadata=False)
|
|
||||||
formatted_no_meta = twitter_reader_no_meta._format_bookmark(high_engagement_bookmark)
|
|
||||||
assert "Thread about startup lessons learned" in formatted_no_meta
|
|
||||||
assert "Likes:" not in formatted_no_meta
|
|
||||||
assert "Retweets:" not in formatted_no_meta
|
|
||||||
|
|
||||||
print("✅ Data processing tests passed")
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
"""Run all standalone tests."""
|
|
||||||
print("🧪 Running MCP Integration Standalone Tests")
|
|
||||||
print("=" * 60)
|
|
||||||
print("Testing core functionality without LEANN dependencies...")
|
|
||||||
print()
|
|
||||||
|
|
||||||
try:
|
|
||||||
test_slack_reader_basic()
|
|
||||||
test_twitter_reader_basic()
|
|
||||||
test_mcp_request_format()
|
|
||||||
test_data_processing()
|
|
||||||
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print("🎉 All standalone tests passed!")
|
|
||||||
print("\n✨ MCP Integration Summary:")
|
|
||||||
print("- SlackMCPReader: Ready for Slack message processing")
|
|
||||||
print("- TwitterMCPReader: Ready for Twitter bookmark processing")
|
|
||||||
print("- MCP Protocol: Properly formatted JSON-RPC requests")
|
|
||||||
print("- Data Processing: Handles various message/bookmark formats")
|
|
||||||
|
|
||||||
print("\n🚀 Next Steps:")
|
|
||||||
print("1. Install MCP servers: npm install -g slack-mcp-server twitter-mcp-server")
|
|
||||||
print("2. Configure API credentials for Slack and Twitter")
|
|
||||||
print("3. Test connections: python -m apps.slack_rag --test-connection")
|
|
||||||
print("4. Start indexing live data from your platforms!")
|
|
||||||
|
|
||||||
print("\n📖 Documentation:")
|
|
||||||
print("- Check README.md for detailed setup instructions")
|
|
||||||
print("- Run examples/mcp_integration_demo.py for usage examples")
|
|
||||||
print("- Explore apps/slack_rag.py and apps/twitter_rag.py for implementation details")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"\n❌ Test failed: {e}")
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
traceback.print_exc()
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,268 +0,0 @@
|
|||||||
"""Unit tests for token-aware truncation functionality.
|
|
||||||
|
|
||||||
This test suite defines the contract for token truncation functions that prevent
|
|
||||||
500 errors from Ollama when text exceeds model token limits. These tests verify:
|
|
||||||
|
|
||||||
1. Model token limit retrieval (known and unknown models)
|
|
||||||
2. Text truncation behavior for single and multiple texts
|
|
||||||
3. Token counting and truncation accuracy using tiktoken
|
|
||||||
|
|
||||||
All tests are written in Red Phase - they should FAIL initially because the
|
|
||||||
implementation does not exist yet.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import tiktoken
|
|
||||||
from leann.embedding_compute import (
|
|
||||||
EMBEDDING_MODEL_LIMITS,
|
|
||||||
get_model_token_limit,
|
|
||||||
truncate_to_token_limit,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestModelTokenLimits:
|
|
||||||
"""Tests for retrieving model-specific token limits."""
|
|
||||||
|
|
||||||
def test_get_model_token_limit_known_model(self):
|
|
||||||
"""Verify correct token limit is returned for known models.
|
|
||||||
|
|
||||||
Known models should return their specific token limits from
|
|
||||||
EMBEDDING_MODEL_LIMITS dictionary.
|
|
||||||
"""
|
|
||||||
# Test nomic-embed-text (2048 tokens)
|
|
||||||
limit = get_model_token_limit("nomic-embed-text")
|
|
||||||
assert limit == 2048, "nomic-embed-text should have 2048 token limit"
|
|
||||||
|
|
||||||
# Test nomic-embed-text-v1.5 (2048 tokens)
|
|
||||||
limit = get_model_token_limit("nomic-embed-text-v1.5")
|
|
||||||
assert limit == 2048, "nomic-embed-text-v1.5 should have 2048 token limit"
|
|
||||||
|
|
||||||
# Test nomic-embed-text-v2 (512 tokens)
|
|
||||||
limit = get_model_token_limit("nomic-embed-text-v2")
|
|
||||||
assert limit == 512, "nomic-embed-text-v2 should have 512 token limit"
|
|
||||||
|
|
||||||
# Test OpenAI models (8192 tokens)
|
|
||||||
limit = get_model_token_limit("text-embedding-3-small")
|
|
||||||
assert limit == 8192, "text-embedding-3-small should have 8192 token limit"
|
|
||||||
|
|
||||||
def test_get_model_token_limit_unknown_model(self):
|
|
||||||
"""Verify default token limit is returned for unknown models.
|
|
||||||
|
|
||||||
Unknown models should return the default limit (2048) to allow
|
|
||||||
operation with reasonable safety margin.
|
|
||||||
"""
|
|
||||||
# Test with completely unknown model
|
|
||||||
limit = get_model_token_limit("unknown-model-xyz")
|
|
||||||
assert limit == 2048, "Unknown models should return default 2048"
|
|
||||||
|
|
||||||
# Test with empty string
|
|
||||||
limit = get_model_token_limit("")
|
|
||||||
assert limit == 2048, "Empty model name should return default 2048"
|
|
||||||
|
|
||||||
def test_get_model_token_limit_custom_default(self):
|
|
||||||
"""Verify custom default can be specified for unknown models.
|
|
||||||
|
|
||||||
Allow callers to specify their own default token limit when
|
|
||||||
model is not in the known models dictionary.
|
|
||||||
"""
|
|
||||||
limit = get_model_token_limit("unknown-model", default=4096)
|
|
||||||
assert limit == 4096, "Should return custom default for unknown models"
|
|
||||||
|
|
||||||
# Known model should ignore custom default
|
|
||||||
limit = get_model_token_limit("nomic-embed-text", default=4096)
|
|
||||||
assert limit == 2048, "Known model should ignore custom default"
|
|
||||||
|
|
||||||
def test_embedding_model_limits_dictionary_exists(self):
|
|
||||||
"""Verify EMBEDDING_MODEL_LIMITS dictionary contains expected models.
|
|
||||||
|
|
||||||
The dictionary should be importable and contain at least the
|
|
||||||
known nomic models with correct token limits.
|
|
||||||
"""
|
|
||||||
assert isinstance(EMBEDDING_MODEL_LIMITS, dict), "Should be a dictionary"
|
|
||||||
assert "nomic-embed-text" in EMBEDDING_MODEL_LIMITS, "Should contain nomic-embed-text"
|
|
||||||
assert "nomic-embed-text-v1.5" in EMBEDDING_MODEL_LIMITS, (
|
|
||||||
"Should contain nomic-embed-text-v1.5"
|
|
||||||
)
|
|
||||||
assert EMBEDDING_MODEL_LIMITS["nomic-embed-text"] == 2048
|
|
||||||
assert EMBEDDING_MODEL_LIMITS["nomic-embed-text-v1.5"] == 2048
|
|
||||||
assert EMBEDDING_MODEL_LIMITS["nomic-embed-text-v2"] == 512
|
|
||||||
# OpenAI models
|
|
||||||
assert EMBEDDING_MODEL_LIMITS["text-embedding-3-small"] == 8192
|
|
||||||
|
|
||||||
|
|
||||||
class TestTokenTruncation:
|
|
||||||
"""Tests for truncating texts to token limits."""
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def tokenizer(self):
|
|
||||||
"""Provide tiktoken tokenizer for token counting verification."""
|
|
||||||
return tiktoken.get_encoding("cl100k_base")
|
|
||||||
|
|
||||||
def test_truncate_single_text_under_limit(self, tokenizer):
|
|
||||||
"""Verify text under token limit remains unchanged.
|
|
||||||
|
|
||||||
When text is already within the token limit, it should be
|
|
||||||
returned unchanged with no truncation.
|
|
||||||
"""
|
|
||||||
text = "This is a short text that is well under the token limit."
|
|
||||||
token_count = len(tokenizer.encode(text))
|
|
||||||
assert token_count < 100, f"Test setup: text should be short (has {token_count} tokens)"
|
|
||||||
|
|
||||||
# Truncate with generous limit
|
|
||||||
result = truncate_to_token_limit([text], token_limit=512)
|
|
||||||
|
|
||||||
assert len(result) == 1, "Should return same number of texts"
|
|
||||||
assert result[0] == text, "Text under limit should be unchanged"
|
|
||||||
|
|
||||||
def test_truncate_single_text_over_limit(self, tokenizer):
|
|
||||||
"""Verify text over token limit is truncated correctly.
|
|
||||||
|
|
||||||
When text exceeds the token limit, it should be truncated to
|
|
||||||
fit within the limit while maintaining valid token boundaries.
|
|
||||||
"""
|
|
||||||
# Create a text that definitely exceeds limit
|
|
||||||
text = "word " * 200 # ~200 tokens (each "word " is typically 1-2 tokens)
|
|
||||||
original_token_count = len(tokenizer.encode(text))
|
|
||||||
assert original_token_count > 50, (
|
|
||||||
f"Test setup: text should be long (has {original_token_count} tokens)"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Truncate to 50 tokens
|
|
||||||
result = truncate_to_token_limit([text], token_limit=50)
|
|
||||||
|
|
||||||
assert len(result) == 1, "Should return same number of texts"
|
|
||||||
assert result[0] != text, "Text over limit should be truncated"
|
|
||||||
assert len(result[0]) < len(text), "Truncated text should be shorter"
|
|
||||||
|
|
||||||
# Verify truncated text is within token limit
|
|
||||||
truncated_token_count = len(tokenizer.encode(result[0]))
|
|
||||||
assert truncated_token_count <= 50, (
|
|
||||||
f"Truncated text should be ≤50 tokens, got {truncated_token_count}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_truncate_multiple_texts_mixed_lengths(self, tokenizer):
|
|
||||||
"""Verify multiple texts with mixed lengths are handled correctly.
|
|
||||||
|
|
||||||
When processing multiple texts:
|
|
||||||
- Texts under limit should remain unchanged
|
|
||||||
- Texts over limit should be truncated independently
|
|
||||||
- Output list should maintain same order and length
|
|
||||||
"""
|
|
||||||
texts = [
|
|
||||||
"Short text.", # Under limit
|
|
||||||
"word " * 200, # Over limit
|
|
||||||
"Another short one.", # Under limit
|
|
||||||
"token " * 150, # Over limit
|
|
||||||
]
|
|
||||||
|
|
||||||
# Verify test setup
|
|
||||||
for i, text in enumerate(texts):
|
|
||||||
token_count = len(tokenizer.encode(text))
|
|
||||||
if i in [1, 3]:
|
|
||||||
assert token_count > 50, f"Text {i} should be over limit (has {token_count} tokens)"
|
|
||||||
else:
|
|
||||||
assert token_count < 50, (
|
|
||||||
f"Text {i} should be under limit (has {token_count} tokens)"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Truncate with 50 token limit
|
|
||||||
result = truncate_to_token_limit(texts, token_limit=50)
|
|
||||||
|
|
||||||
assert len(result) == len(texts), "Should return same number of texts"
|
|
||||||
|
|
||||||
# Verify each text individually
|
|
||||||
for i, (original, truncated) in enumerate(zip(texts, result)):
|
|
||||||
token_count = len(tokenizer.encode(truncated))
|
|
||||||
assert token_count <= 50, f"Text {i} should be ≤50 tokens, got {token_count}"
|
|
||||||
|
|
||||||
# Short texts should be unchanged
|
|
||||||
if i in [0, 2]:
|
|
||||||
assert truncated == original, f"Short text {i} should be unchanged"
|
|
||||||
# Long texts should be truncated
|
|
||||||
else:
|
|
||||||
assert len(truncated) < len(original), f"Long text {i} should be truncated"
|
|
||||||
|
|
||||||
def test_truncate_empty_list(self):
|
|
||||||
"""Verify empty input list returns empty output list.
|
|
||||||
|
|
||||||
Edge case: empty list should return empty list without errors.
|
|
||||||
"""
|
|
||||||
result = truncate_to_token_limit([], token_limit=512)
|
|
||||||
assert result == [], "Empty input should return empty output"
|
|
||||||
|
|
||||||
def test_truncate_preserves_order(self, tokenizer):
|
|
||||||
"""Verify truncation preserves original text order.
|
|
||||||
|
|
||||||
Output list should maintain the same order as input list,
|
|
||||||
regardless of which texts were truncated.
|
|
||||||
"""
|
|
||||||
texts = [
|
|
||||||
"First text " * 50, # Will be truncated
|
|
||||||
"Second text.", # Won't be truncated
|
|
||||||
"Third text " * 50, # Will be truncated
|
|
||||||
]
|
|
||||||
|
|
||||||
result = truncate_to_token_limit(texts, token_limit=20)
|
|
||||||
|
|
||||||
assert len(result) == 3, "Should preserve list length"
|
|
||||||
# Check that order is maintained by looking for distinctive words
|
|
||||||
assert "First" in result[0], "First text should remain in first position"
|
|
||||||
assert "Second" in result[1], "Second text should remain in second position"
|
|
||||||
assert "Third" in result[2], "Third text should remain in third position"
|
|
||||||
|
|
||||||
def test_truncate_extremely_long_text(self, tokenizer):
|
|
||||||
"""Verify extremely long texts are truncated efficiently.
|
|
||||||
|
|
||||||
Test with text that far exceeds token limit to ensure
|
|
||||||
truncation handles extreme cases without performance issues.
|
|
||||||
"""
|
|
||||||
# Create very long text (simulate real-world scenario)
|
|
||||||
text = "token " * 5000 # ~5000+ tokens
|
|
||||||
original_token_count = len(tokenizer.encode(text))
|
|
||||||
assert original_token_count > 1000, "Test setup: text should be very long"
|
|
||||||
|
|
||||||
# Truncate to small limit
|
|
||||||
result = truncate_to_token_limit([text], token_limit=100)
|
|
||||||
|
|
||||||
assert len(result) == 1
|
|
||||||
truncated_token_count = len(tokenizer.encode(result[0]))
|
|
||||||
assert truncated_token_count <= 100, (
|
|
||||||
f"Should truncate to ≤100 tokens, got {truncated_token_count}"
|
|
||||||
)
|
|
||||||
assert len(result[0]) < len(text) // 10, "Should significantly reduce text length"
|
|
||||||
|
|
||||||
def test_truncate_exact_token_limit(self, tokenizer):
|
|
||||||
"""Verify text at exactly token limit is handled correctly.
|
|
||||||
|
|
||||||
Edge case: text with exactly the token limit should either
|
|
||||||
remain unchanged or be safely truncated by 1 token.
|
|
||||||
"""
|
|
||||||
# Create text with approximately 50 tokens
|
|
||||||
# We'll adjust to get exactly 50
|
|
||||||
target_tokens = 50
|
|
||||||
text = "word " * 50
|
|
||||||
tokens = tokenizer.encode(text)
|
|
||||||
|
|
||||||
# Adjust to get exactly target_tokens
|
|
||||||
if len(tokens) > target_tokens:
|
|
||||||
tokens = tokens[:target_tokens]
|
|
||||||
text = tokenizer.decode(tokens)
|
|
||||||
elif len(tokens) < target_tokens:
|
|
||||||
# Add more words
|
|
||||||
while len(tokenizer.encode(text)) < target_tokens:
|
|
||||||
text += "word "
|
|
||||||
tokens = tokenizer.encode(text)[:target_tokens]
|
|
||||||
text = tokenizer.decode(tokens)
|
|
||||||
|
|
||||||
# Verify we have exactly target_tokens
|
|
||||||
assert len(tokenizer.encode(text)) == target_tokens, (
|
|
||||||
"Test setup: should have exactly 50 tokens"
|
|
||||||
)
|
|
||||||
|
|
||||||
result = truncate_to_token_limit([text], token_limit=target_tokens)
|
|
||||||
|
|
||||||
assert len(result) == 1
|
|
||||||
result_tokens = len(tokenizer.encode(result[0]))
|
|
||||||
assert result_tokens <= target_tokens, (
|
|
||||||
f"Should be ≤{target_tokens} tokens, got {result_tokens}"
|
|
||||||
)
|
|
||||||
Reference in New Issue
Block a user