Compare commits

2 Commits: feature/mc... → fix/ask-cl...

| Author | SHA1 | Date |
|---|---|---|
|  | 47aeb85f82 |  |
|  | db7ba27ff6 |  |

README.md (358)
@@ -20,7 +20,7 @@ LEANN is an innovative vector database that democratizes personal AI. Transform

 LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration Fig →](#️-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)

-**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can semantic search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)** ([WeChat](#-wechat-detective-unlock-your-golden-memories), [iMessage](#-imessage-history-your-personal-conversation-archive)), **[agent memory](#-chatgpt-chat-history-your-personal-ai-conversation-archive)** ([ChatGPT](#-chatgpt-chat-history-your-personal-ai-conversation-archive), [Claude](#-claude-chat-history-your-personal-ai-conversation-archive)), **[codebase](#-claude-code-integration-transform-your-development-workflow)**\* , or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
+**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can semantic search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)**, **[codebase](#-claude-code-integration-transform-your-development-workflow)**\* , or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.

 \* Claude Code only supports basic `grep`-style keyword search. **LEANN** is a drop-in **semantic search MCP service fully compatible with Claude Code**, unlocking intelligent retrieval without changing your workflow. 🔥 Check out [the easy setup →](packages/leann-mcp/README.md)
@@ -176,7 +176,7 @@ response = chat.ask("How much storage does LEANN save?", top_k=1)

 ## RAG on Everything!

-LEANN supports RAG on various data sources including documents (`.pdf`, `.txt`, `.md`), Apple Mail, Google Search History, WeChat, ChatGPT conversations, Claude conversations, iMessage conversations, and **live data from any platform through MCP (Model Context Protocol) servers** - including Slack, Twitter, and more.
+LEANN supports RAG on various data sources including documents (`.pdf`, `.txt`, `.md`), Apple Mail, Google Search History, WeChat, and more.
@@ -477,355 +477,6 @@ Once the index is built, you can ask questions like:

 </details>
### 🤖 ChatGPT Chat History: Your Personal AI Conversation Archive!

Transform your ChatGPT conversations into a searchable knowledge base! Search through all your ChatGPT discussions about coding, research, brainstorming, and more.

```bash
python -m apps.chatgpt_rag --export-path chatgpt_export.html --query "How do I create a list in Python?"
```

**Unlock your AI conversation history.** Never lose track of valuable insights from your ChatGPT discussions again.

<details>
<summary><strong>📋 Click to expand: How to Export ChatGPT Data</strong></summary>

**Step-by-step export process:**

1. **Sign in to ChatGPT**
2. **Click your profile icon** in the top right corner
3. **Navigate to Settings** → **Data Controls**
4. **Click "Export"** under Export Data
5. **Confirm the export** request
6. **Download the ZIP file** from the email link (expires in 24 hours)
7. **Extract or use directly** with LEANN

**Supported formats:**
- `.html` files from ChatGPT exports
- `.zip` archives from ChatGPT
- Directories with multiple export files

</details>

<details>
<summary><strong>📋 Click to expand: ChatGPT-Specific Arguments</strong></summary>

#### Parameters
```bash
--export-path PATH    # Path to ChatGPT export file (.html/.zip) or directory (default: ./chatgpt_export)
--separate-messages   # Process each message separately instead of concatenated conversations
--chunk-size N        # Text chunk size (default: 512)
--chunk-overlap N     # Overlap between chunks (default: 128)
```

#### Example Commands
```bash
# Basic usage with HTML export
python -m apps.chatgpt_rag --export-path conversations.html

# Process ZIP archive from ChatGPT
python -m apps.chatgpt_rag --export-path chatgpt_export.zip

# Search with specific query
python -m apps.chatgpt_rag --export-path chatgpt_data.html --query "Python programming help"

# Process individual messages for fine-grained search
python -m apps.chatgpt_rag --separate-messages --export-path chatgpt_export.html

# Process directory containing multiple exports
python -m apps.chatgpt_rag --export-path ./chatgpt_exports/ --max-items 1000
```

</details>
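For intuition, `--chunk-size` and `--chunk-overlap` describe a sliding-window split over each conversation before indexing. The repo uses its own `create_text_chunks` helper for this; the standalone function below is only an illustrative sketch of the size/overlap semantics, not LEANN's implementation.

```python
# Minimal sketch of size/overlap chunking (illustrative only, not LEANN's create_text_chunks)
def chunk_text(text: str, chunk_size: int = 512, chunk_overlap: int = 128) -> list[str]:
    """Split text into ~chunk_size-character windows that overlap by chunk_overlap."""
    if chunk_size <= chunk_overlap:
        raise ValueError("chunk_size must be larger than chunk_overlap")
    step = chunk_size - chunk_overlap
    return [text[i : i + chunk_size] for i in range(0, max(len(text) - chunk_overlap, 1), step)]
```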
<details>
<summary><strong>💡 Click to expand: Example queries you can try</strong></summary>

Once your ChatGPT conversations are indexed, you can search with queries like:

- "What did I ask ChatGPT about Python programming?"
- "Show me conversations about machine learning algorithms"
- "Find discussions about web development frameworks"
- "What coding advice did ChatGPT give me?"
- "Search for conversations about debugging techniques"
- "Find ChatGPT's recommendations for learning resources"

</details>
### 🤖 Claude Chat History: Your Personal AI Conversation Archive!

Transform your Claude conversations into a searchable knowledge base! Search through all your Claude discussions about coding, research, brainstorming, and more.

```bash
python -m apps.claude_rag --export-path claude_export.json --query "What did I ask about Python dictionaries?"
```

**Unlock your AI conversation history.** Never lose track of valuable insights from your Claude discussions again.

<details>
<summary><strong>📋 Click to expand: How to Export Claude Data</strong></summary>

**Step-by-step export process:**

1. **Open Claude** in your browser
2. **Navigate to Settings** (look for gear icon or settings menu)
3. **Find Export/Download** options in your account settings
4. **Download conversation data** (usually in JSON format)
5. **Place the file** in your project directory

*Note: Claude export methods may vary depending on the interface you're using. Check Claude's help documentation for the most current export instructions.*

**Supported formats:**
- `.json` files (recommended)
- `.zip` archives containing JSON data
- Directories with multiple export files

</details>
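The Claude reader (removed later in this diff) is tolerant of several export layouts: a top-level list of conversations, or objects with a `conversations` or `messages` key, with `role`/`content` (and optional timestamp) fields per message. A hypothetical example of a shape it can ingest — field names taken from the reader, values are placeholders:

```python
# Hypothetical Claude export structure (placeholder values, for illustration only)
example_export = {
    "conversations": [
        {
            "title": "Python dictionaries",
            "created_at": "2024-01-01T12:00:00Z",
            "messages": [
                {"role": "human", "content": "How do I merge two dicts?"},
                {"role": "assistant", "content": "You can use the | operator or dict.update()."},
            ],
        }
    ]
}
```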
<details>
<summary><strong>📋 Click to expand: Claude-Specific Arguments</strong></summary>

#### Parameters
```bash
--export-path PATH    # Path to Claude export file (.json/.zip) or directory (default: ./claude_export)
--separate-messages   # Process each message separately instead of concatenated conversations
--chunk-size N        # Text chunk size (default: 512)
--chunk-overlap N     # Overlap between chunks (default: 128)
```

#### Example Commands
```bash
# Basic usage with JSON export
python -m apps.claude_rag --export-path my_claude_conversations.json

# Process ZIP archive from Claude
python -m apps.claude_rag --export-path claude_export.zip

# Search with specific query
python -m apps.claude_rag --export-path claude_data.json --query "machine learning advice"

# Process individual messages for fine-grained search
python -m apps.claude_rag --separate-messages --export-path claude_export.json

# Process directory containing multiple exports
python -m apps.claude_rag --export-path ./claude_exports/ --max-items 1000
```

</details>

<details>
<summary><strong>💡 Click to expand: Example queries you can try</strong></summary>

Once your Claude conversations are indexed, you can search with queries like:

- "What did I ask Claude about Python programming?"
- "Show me conversations about machine learning algorithms"
- "Find discussions about software architecture patterns"
- "What debugging advice did Claude give me?"
- "Search for conversations about data structures"
- "Find Claude's recommendations for learning resources"

</details>
### 💬 iMessage History: Your Personal Conversation Archive!

Transform your iMessage conversations into a searchable knowledge base! Search through all your text messages, group chats, and conversations with friends, family, and colleagues.

```bash
python -m apps.imessage_rag --query "What did we discuss about the weekend plans?"
```

**Unlock your message history.** Never lose track of important conversations, shared links, or memorable moments from your iMessage history.

<details>
<summary><strong>📋 Click to expand: How to Access iMessage Data</strong></summary>

**iMessage data location:**

iMessage conversations are stored in a SQLite database on your Mac at:
```
~/Library/Messages/chat.db
```

**Important setup requirements:**

1. **Grant Full Disk Access** to your terminal or IDE:
   - Open **System Preferences** → **Security & Privacy** → **Privacy**
   - Select **Full Disk Access** from the left sidebar
   - Click the **+** button and add your terminal app (Terminal, iTerm2) or IDE (VS Code, etc.)
   - Restart your terminal/IDE after granting access

2. **Alternative: Use a backup database**
   - If you have Time Machine backups or manual copies of the database
   - Use `--db-path` to specify a custom location

**Supported formats:**
- Direct access to `~/Library/Messages/chat.db` (default)
- Custom database path with `--db-path`
- Works with backup copies of the database

</details>
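To verify that Full Disk Access is working before running the full pipeline, you can open `chat.db` read-only with Python's built-in `sqlite3`. This is just a sanity-check sketch; the exact Messages schema varies across macOS versions, though a `message` table with a `text` column is typical.

```python
# Quick read-only sanity check on chat.db (schema may differ across macOS versions)
import sqlite3
from pathlib import Path

db_path = Path.home() / "Library" / "Messages" / "chat.db"
conn = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True)  # open read-only
for (text,) in conn.execute("SELECT text FROM message WHERE text IS NOT NULL LIMIT 5"):
    print(text)
conn.close()
```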
<details>
<summary><strong>📋 Click to expand: iMessage-Specific Arguments</strong></summary>

#### Parameters
```bash
--db-path PATH                   # Path to chat.db file (default: ~/Library/Messages/chat.db)
--concatenate-conversations      # Group messages by conversation (default: True)
--no-concatenate-conversations   # Process each message individually
--chunk-size N                   # Text chunk size (default: 1000)
--chunk-overlap N                # Overlap between chunks (default: 200)
```

#### Example Commands
```bash
# Basic usage (requires Full Disk Access)
python -m apps.imessage_rag

# Search with specific query
python -m apps.imessage_rag --query "family dinner plans"

# Use custom database path
python -m apps.imessage_rag --db-path /path/to/backup/chat.db

# Process individual messages instead of conversations
python -m apps.imessage_rag --no-concatenate-conversations

# Limit processing for testing
python -m apps.imessage_rag --max-items 100 --query "weekend"
```

</details>

<details>
<summary><strong>💡 Click to expand: Example queries you can try</strong></summary>

Once your iMessage conversations are indexed, you can search with queries like:

- "What did we discuss about vacation plans?"
- "Find messages about restaurant recommendations"
- "Show me conversations with John about the project"
- "Search for shared links about technology"
- "Find group chat discussions about weekend events"
- "What did mom say about the family gathering?"

</details>
### 🔌 MCP Integration: RAG on Live Data from Any Platform!

**NEW!** Connect to live data sources through the Model Context Protocol (MCP). LEANN now supports real-time RAG on platforms like Slack, Twitter, and more through standardized MCP servers.

**Key Benefits:**
- 🔄 **Live Data Access**: Fetch real-time data without manual exports
- 🔌 **Standardized Protocol**: Use any MCP-compatible server
- 🚀 **Easy Extension**: Add new platforms with minimal code
- 🔒 **Secure Access**: MCP servers handle authentication

<details>
<summary><strong>💬 Slack Messages: Search Your Team Conversations</strong></summary>

Transform your Slack workspace into a searchable knowledge base! Find discussions, decisions, and shared knowledge across all your channels.

```bash
# Test MCP server connection
python -m apps.slack_rag --mcp-server "slack-mcp-server" --test-connection

# Index and search Slack messages
python -m apps.slack_rag \
    --mcp-server "slack-mcp-server" \
    --workspace-name "my-team" \
    --channels general dev-team random \
    --query "What did we decide about the product launch?"
```

**Setup Requirements:**
1. Install a Slack MCP server (e.g., `npm install -g slack-mcp-server`)
2. Configure Slack API credentials:
   ```bash
   export SLACK_BOT_TOKEN="xoxb-your-bot-token"
   export SLACK_APP_TOKEN="xapp-your-app-token"
   ```
3. Test connection with `--test-connection` flag

**Arguments:**
- `--mcp-server`: Command to start the Slack MCP server
- `--workspace-name`: Slack workspace name for organization
- `--channels`: Specific channels to index (optional)
- `--concatenate-conversations`: Group messages by channel (default: true)
- `--max-messages-per-channel`: Limit messages per channel (default: 100)

</details>
<details>
<summary><strong>🐦 Twitter Bookmarks: Your Personal Tweet Library</strong></summary>

Search through your Twitter bookmarks! Find that perfect article, thread, or insight you saved for later.

```bash
# Test MCP server connection
python -m apps.twitter_rag --mcp-server "twitter-mcp-server" --test-connection

# Index and search Twitter bookmarks
python -m apps.twitter_rag \
    --mcp-server "twitter-mcp-server" \
    --max-bookmarks 1000 \
    --query "What AI articles did I bookmark about machine learning?"
```

**Setup Requirements:**
1. Install a Twitter MCP server (e.g., `npm install -g twitter-mcp-server`)
2. Configure Twitter API credentials:
   ```bash
   export TWITTER_API_KEY="your-api-key"
   export TWITTER_API_SECRET="your-api-secret"
   export TWITTER_ACCESS_TOKEN="your-access-token"
   export TWITTER_ACCESS_TOKEN_SECRET="your-access-token-secret"
   ```
3. Test connection with `--test-connection` flag

**Arguments:**
- `--mcp-server`: Command to start the Twitter MCP server
- `--username`: Filter bookmarks by username (optional)
- `--max-bookmarks`: Maximum bookmarks to fetch (default: 1000)
- `--no-tweet-content`: Exclude tweet content, only metadata
- `--no-metadata`: Exclude engagement metadata

</details>
<details>
<summary><strong>💡 Click to expand: Example queries you can try</strong></summary>

**Slack Queries:**
- "What did the team discuss about the project deadline?"
- "Find messages about the new feature launch"
- "Show me conversations about budget planning"
- "What decisions were made in the dev-team channel?"

**Twitter Queries:**
- "What AI articles did I bookmark last month?"
- "Find tweets about machine learning techniques"
- "Show me bookmarked threads about startup advice"
- "What Python tutorials did I save?"

</details>
<details>
<summary><strong>🔧 Adding New MCP Platforms</strong></summary>

Want to add support for other platforms? LEANN's MCP integration is designed for easy extension:

1. **Find or create an MCP server** for your platform
2. **Create a reader class** following the pattern in `apps/slack_data/slack_mcp_reader.py`
3. **Create a RAG application** following the pattern in `apps/slack_rag.py`
4. **Test and contribute** back to the community!

**Popular MCP servers to explore:**
- GitHub repositories and issues
- Discord messages
- Notion pages
- Google Drive documents
- And many more in the MCP ecosystem!

</details>
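Following the reader pattern referenced above: the other readers in this diff subclass LlamaIndex's `BaseReader` and return `Document` objects, so a new platform reader mostly amounts to fetching records from its MCP server and wrapping them. The skeleton below is a hypothetical sketch only — `slack_mcp_reader.py` is not shown here, and `fetch_messages` stands in for whatever client call your MCP server exposes.

```python
# Hypothetical skeleton for a new MCP platform reader (not the actual slack_mcp_reader.py)
from typing import Any, Callable

from llama_index.core import Document
from llama_index.core.readers.base import BaseReader


class MyPlatformMCPReader(BaseReader):
    """Turn records fetched from an MCP server into LlamaIndex Documents."""

    def __init__(self, fetch_messages: Callable[..., list[dict]]) -> None:
        # fetch_messages is a placeholder callable that talks to your MCP server
        self.fetch_messages = fetch_messages

    def load_data(self, **load_kwargs: Any) -> list[Document]:
        docs: list[Document] = []
        for record in self.fetch_messages(limit=load_kwargs.get("max_count", 100)):
            docs.append(
                Document(
                    text=record.get("text", ""),
                    metadata={"channel": record.get("channel", ""), "source": "MyPlatform MCP"},
                )
            )
        return docs
```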
### 🚀 Claude Code Integration: Transform Your Development Workflow!

<details>
@@ -895,6 +546,9 @@ leann search my-docs "machine learning concepts"

 # Interactive chat with your documents
 leann ask my-docs --interactive

+# Ask a single question (non-interactive)
+leann ask my-docs "Where are prompts configured?"

 # List all your indexes
 leann list
@@ -1097,7 +751,7 @@ MIT License - see [LICENSE](LICENSE) for details.

 Core Contributors: [Yichuan Wang](https://yichuan-w.github.io/) & [Zhifei Li](https://github.com/andylizf).

-Active Contributors: [Gabriel Dehan](https://github.com/gabriel-dehan), [Aakash Suresh](https://github.com/ASuresh0524)
+Active Contributors: [Gabriel Dehan](https://github.com/gabriel-dehan)

 We welcome more contributors! Feel free to open issues or submit PRs.
@@ -11,6 +11,7 @@ from typing import Any

 import dotenv
 from leann.api import LeannBuilder, LeannChat
 from leann.registry import register_project_directory
+from leann.settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url

 dotenv.load_dotenv()
@@ -78,6 +79,24 @@ class BaseRAGExample(ABC):
             choices=["sentence-transformers", "openai", "mlx", "ollama"],
             help="Embedding backend mode (default: sentence-transformers), we provide sentence-transformers, openai, mlx, or ollama",
         )
+        embedding_group.add_argument(
+            "--embedding-host",
+            type=str,
+            default=None,
+            help="Override Ollama-compatible embedding host",
+        )
+        embedding_group.add_argument(
+            "--embedding-api-base",
+            type=str,
+            default=None,
+            help="Base URL for OpenAI-compatible embedding services",
+        )
+        embedding_group.add_argument(
+            "--embedding-api-key",
+            type=str,
+            default=None,
+            help="API key for embedding service (defaults to OPENAI_API_KEY)",
+        )

         # LLM parameters
         llm_group = parser.add_argument_group("LLM Parameters")
@@ -97,8 +116,8 @@ class BaseRAGExample(ABC):
         llm_group.add_argument(
             "--llm-host",
             type=str,
-            default="http://localhost:11434",
-            help="Host for Ollama API (default: http://localhost:11434)",
+            default=None,
+            help="Host for Ollama-compatible APIs (defaults to LEANN_OLLAMA_HOST/OLLAMA_HOST)",
         )
         llm_group.add_argument(
             "--thinking-budget",
@@ -107,6 +126,18 @@ class BaseRAGExample(ABC):
             default=None,
             help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.",
         )
+        llm_group.add_argument(
+            "--llm-api-base",
+            type=str,
+            default=None,
+            help="Base URL for OpenAI-compatible APIs",
+        )
+        llm_group.add_argument(
+            "--llm-api-key",
+            type=str,
+            default=None,
+            help="API key for OpenAI-compatible APIs (defaults to OPENAI_API_KEY)",
+        )

         # AST Chunking parameters
         ast_group = parser.add_argument_group("AST Chunking Parameters")
@@ -205,9 +236,13 @@ class BaseRAGExample(ABC):

         if args.llm == "openai":
             config["model"] = args.llm_model or "gpt-4o"
+            config["base_url"] = resolve_openai_base_url(args.llm_api_base)
+            resolved_key = resolve_openai_api_key(args.llm_api_key)
+            if resolved_key:
+                config["api_key"] = resolved_key
         elif args.llm == "ollama":
             config["model"] = args.llm_model or "llama3.2:1b"
-            config["host"] = args.llm_host
+            config["host"] = resolve_ollama_host(args.llm_host)
         elif args.llm == "hf":
             config["model"] = args.llm_model or "Qwen/Qwen2.5-1.5B-Instruct"
         elif args.llm == "simulated":
@@ -223,10 +258,20 @@ class BaseRAGExample(ABC):
         print(f"\n[Building Index] Creating {self.name} index...")
         print(f"Total text chunks: {len(texts)}")

+        embedding_options: dict[str, Any] = {}
+        if args.embedding_mode == "ollama":
+            embedding_options["host"] = resolve_ollama_host(args.embedding_host)
+        elif args.embedding_mode == "openai":
+            embedding_options["base_url"] = resolve_openai_base_url(args.embedding_api_base)
+            resolved_embedding_key = resolve_openai_api_key(args.embedding_api_key)
+            if resolved_embedding_key:
+                embedding_options["api_key"] = resolved_embedding_key
+
         builder = LeannBuilder(
             backend_name=args.backend_name,
             embedding_model=args.embedding_model,
             embedding_mode=args.embedding_mode,
+            embedding_options=embedding_options or None,
             graph_degree=args.graph_degree,
             complexity=args.build_complexity,
             is_compact=not args.no_compact,
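The new `resolve_*` helpers come from `leann.settings`, whose implementation is not shown in this diff. Judging only from the help strings above ("defaults to LEANN_OLLAMA_HOST/OLLAMA_HOST", "defaults to OPENAI_API_KEY"), they behave like environment-variable fallbacks. The following is a sketch of that implied behavior, not the actual module:

```python
# Sketch of the leann.settings resolvers implied by the help text (actual code may differ)
import os


def resolve_ollama_host(explicit: str | None) -> str:
    # CLI flag wins; otherwise fall back to env vars, then the usual local default
    return explicit or os.getenv("LEANN_OLLAMA_HOST") or os.getenv("OLLAMA_HOST") or "http://localhost:11434"


def resolve_openai_base_url(explicit: str | None) -> str | None:
    # OPENAI_BASE_URL as a fallback is an assumption for this sketch
    return explicit or os.getenv("OPENAI_BASE_URL")


def resolve_openai_api_key(explicit: str | None) -> str | None:
    return explicit or os.getenv("OPENAI_API_KEY")
```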
@@ -1,413 +0,0 @@
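This hunk removes the ChatGPT reader module in its entirety (apps/chatgpt_data/chatgpt_reader.py, path inferred from the relative import in apps/chatgpt_rag.py further below). For reference, a hypothetical usage sketch based on the `load_data` signature in the removed code — file paths here are placeholders:

```python
# Hypothetical usage of the removed ChatGPTReader (paths are placeholders)
from apps.chatgpt_data.chatgpt_reader import ChatGPTReader

reader = ChatGPTReader(concatenate_conversations=True)
documents = reader.load_data(chatgpt_export_path="chatgpt_export.zip", max_count=100)
print(f"Loaded {len(documents)} conversation documents")
```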
"""
|
|
||||||
ChatGPT export data reader.
|
|
||||||
|
|
||||||
Reads and processes ChatGPT export data from chat.html files.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import re
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
|
||||||
from zipfile import ZipFile
|
|
||||||
|
|
||||||
from bs4 import BeautifulSoup
|
|
||||||
from llama_index.core import Document
|
|
||||||
from llama_index.core.readers.base import BaseReader
|
|
||||||
|
|
||||||
|
|
||||||
class ChatGPTReader(BaseReader):
|
|
||||||
"""
|
|
||||||
ChatGPT export data reader.
|
|
||||||
|
|
||||||
Reads ChatGPT conversation data from exported chat.html files or zip archives.
|
|
||||||
Processes conversations into structured documents with metadata.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, concatenate_conversations: bool = True) -> None:
|
|
||||||
"""
|
|
||||||
Initialize.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
concatenate_conversations: Whether to concatenate messages within conversations for better context
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
from bs4 import BeautifulSoup # noqa
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError("`beautifulsoup4` package not found: `pip install beautifulsoup4`")
|
|
||||||
|
|
||||||
self.concatenate_conversations = concatenate_conversations
|
|
||||||
|
|
||||||
def _extract_html_from_zip(self, zip_path: Path) -> str | None:
|
|
||||||
"""
|
|
||||||
Extract chat.html from ChatGPT export zip file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
zip_path: Path to the ChatGPT export zip file
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
HTML content as string, or None if not found
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
with ZipFile(zip_path, "r") as zip_file:
|
|
||||||
# Look for chat.html or conversations.html
|
|
||||||
html_files = [
|
|
||||||
f
|
|
||||||
for f in zip_file.namelist()
|
|
||||||
if f.endswith(".html") and ("chat" in f.lower() or "conversation" in f.lower())
|
|
||||||
]
|
|
||||||
|
|
||||||
if not html_files:
|
|
||||||
print(f"No HTML chat file found in {zip_path}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Use the first HTML file found
|
|
||||||
html_file = html_files[0]
|
|
||||||
print(f"Found HTML file: {html_file}")
|
|
||||||
|
|
||||||
with zip_file.open(html_file) as f:
|
|
||||||
return f.read().decode("utf-8", errors="ignore")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error extracting HTML from zip {zip_path}: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _parse_chatgpt_html(self, html_content: str) -> list[dict]:
|
|
||||||
"""
|
|
||||||
Parse ChatGPT HTML export to extract conversations.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
html_content: HTML content from ChatGPT export
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of conversation dictionaries
|
|
||||||
"""
|
|
||||||
soup = BeautifulSoup(html_content, "html.parser")
|
|
||||||
conversations = []
|
|
||||||
|
|
||||||
# Try different possible structures for ChatGPT exports
|
|
||||||
# Structure 1: Look for conversation containers
|
|
||||||
conversation_containers = soup.find_all(
|
|
||||||
["div", "section"], class_=re.compile(r"conversation|chat", re.I)
|
|
||||||
)
|
|
||||||
|
|
||||||
if not conversation_containers:
|
|
||||||
# Structure 2: Look for message containers directly
|
|
||||||
conversation_containers = [soup] # Use the entire document as one conversation
|
|
||||||
|
|
||||||
for container in conversation_containers:
|
|
||||||
conversation = self._extract_conversation_from_container(container)
|
|
||||||
if conversation and conversation.get("messages"):
|
|
||||||
conversations.append(conversation)
|
|
||||||
|
|
||||||
# If no structured conversations found, try to extract all text as one conversation
|
|
||||||
if not conversations:
|
|
||||||
all_text = soup.get_text(separator="\n", strip=True)
|
|
||||||
if all_text:
|
|
||||||
conversations.append(
|
|
||||||
{
|
|
||||||
"title": "ChatGPT Conversation",
|
|
||||||
"messages": [{"role": "mixed", "content": all_text, "timestamp": None}],
|
|
||||||
"timestamp": None,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return conversations
|
|
||||||
|
|
||||||
def _extract_conversation_from_container(self, container) -> dict | None:
|
|
||||||
"""
|
|
||||||
Extract conversation data from a container element.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
container: BeautifulSoup element containing conversation
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary with conversation data or None
|
|
||||||
"""
|
|
||||||
messages = []
|
|
||||||
|
|
||||||
# Look for message elements with various possible structures
|
|
||||||
message_selectors = ['[class*="message"]', '[class*="chat"]', "[data-message]", "p", "div"]
|
|
||||||
|
|
||||||
for selector in message_selectors:
|
|
||||||
message_elements = container.select(selector)
|
|
||||||
if message_elements:
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
message_elements = []
|
|
||||||
|
|
||||||
# If no structured messages found, treat the entire container as one message
|
|
||||||
if not message_elements:
|
|
||||||
text_content = container.get_text(separator="\n", strip=True)
|
|
||||||
if text_content:
|
|
||||||
messages.append({"role": "mixed", "content": text_content, "timestamp": None})
|
|
||||||
else:
|
|
||||||
for element in message_elements:
|
|
||||||
message = self._extract_message_from_element(element)
|
|
||||||
if message:
|
|
||||||
messages.append(message)
|
|
||||||
|
|
||||||
if not messages:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Try to extract conversation title
|
|
||||||
title_element = container.find(["h1", "h2", "h3", "title"])
|
|
||||||
title = title_element.get_text(strip=True) if title_element else "ChatGPT Conversation"
|
|
||||||
|
|
||||||
# Try to extract timestamp from various possible locations
|
|
||||||
timestamp = self._extract_timestamp_from_container(container)
|
|
||||||
|
|
||||||
return {"title": title, "messages": messages, "timestamp": timestamp}
|
|
||||||
|
|
||||||
def _extract_message_from_element(self, element) -> dict | None:
|
|
||||||
"""
|
|
||||||
Extract message data from an element.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
element: BeautifulSoup element containing message
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary with message data or None
|
|
||||||
"""
|
|
||||||
text_content = element.get_text(separator=" ", strip=True)
|
|
||||||
|
|
||||||
# Skip empty or very short messages
|
|
||||||
if not text_content or len(text_content.strip()) < 3:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Try to determine role (user/assistant) from class names or content
|
|
||||||
role = "mixed" # Default role
|
|
||||||
|
|
||||||
class_names = " ".join(element.get("class", [])).lower()
|
|
||||||
if "user" in class_names or "human" in class_names:
|
|
||||||
role = "user"
|
|
||||||
elif "assistant" in class_names or "ai" in class_names or "gpt" in class_names:
|
|
||||||
role = "assistant"
|
|
||||||
elif text_content.lower().startswith(("you:", "user:", "me:")):
|
|
||||||
role = "user"
|
|
||||||
text_content = re.sub(r"^(you|user|me):\s*", "", text_content, flags=re.IGNORECASE)
|
|
||||||
elif text_content.lower().startswith(("chatgpt:", "assistant:", "ai:")):
|
|
||||||
role = "assistant"
|
|
||||||
text_content = re.sub(
|
|
||||||
r"^(chatgpt|assistant|ai):\s*", "", text_content, flags=re.IGNORECASE
|
|
||||||
)
|
|
||||||
|
|
||||||
# Try to extract timestamp
|
|
||||||
timestamp = self._extract_timestamp_from_element(element)
|
|
||||||
|
|
||||||
return {"role": role, "content": text_content, "timestamp": timestamp}
|
|
||||||
|
|
||||||
def _extract_timestamp_from_element(self, element) -> str | None:
|
|
||||||
"""Extract timestamp from element."""
|
|
||||||
# Look for timestamp in various attributes and child elements
|
|
||||||
timestamp_attrs = ["data-timestamp", "timestamp", "datetime"]
|
|
||||||
for attr in timestamp_attrs:
|
|
||||||
if element.get(attr):
|
|
||||||
return element.get(attr)
|
|
||||||
|
|
||||||
# Look for time elements
|
|
||||||
time_element = element.find("time")
|
|
||||||
if time_element:
|
|
||||||
return time_element.get("datetime") or time_element.get_text(strip=True)
|
|
||||||
|
|
||||||
# Look for date-like text patterns
|
|
||||||
text = element.get_text()
|
|
||||||
date_patterns = [r"\d{4}-\d{2}-\d{2}", r"\d{1,2}/\d{1,2}/\d{4}", r"\w+ \d{1,2}, \d{4}"]
|
|
||||||
|
|
||||||
for pattern in date_patterns:
|
|
||||||
match = re.search(pattern, text)
|
|
||||||
if match:
|
|
||||||
return match.group()
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _extract_timestamp_from_container(self, container) -> str | None:
|
|
||||||
"""Extract timestamp from conversation container."""
|
|
||||||
return self._extract_timestamp_from_element(container)
|
|
||||||
|
|
||||||
def _create_concatenated_content(self, conversation: dict) -> str:
|
|
||||||
"""
|
|
||||||
Create concatenated content from conversation messages.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
conversation: Dictionary containing conversation data
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Formatted concatenated content
|
|
||||||
"""
|
|
||||||
title = conversation.get("title", "ChatGPT Conversation")
|
|
||||||
messages = conversation.get("messages", [])
|
|
||||||
timestamp = conversation.get("timestamp", "Unknown")
|
|
||||||
|
|
||||||
# Build message content
|
|
||||||
message_parts = []
|
|
||||||
for message in messages:
|
|
||||||
role = message.get("role", "mixed")
|
|
||||||
content = message.get("content", "")
|
|
||||||
msg_timestamp = message.get("timestamp", "")
|
|
||||||
|
|
||||||
if role == "user":
|
|
||||||
prefix = "[You]"
|
|
||||||
elif role == "assistant":
|
|
||||||
prefix = "[ChatGPT]"
|
|
||||||
else:
|
|
||||||
prefix = "[Message]"
|
|
||||||
|
|
||||||
# Add timestamp if available
|
|
||||||
if msg_timestamp:
|
|
||||||
prefix += f" ({msg_timestamp})"
|
|
||||||
|
|
||||||
message_parts.append(f"{prefix}: {content}")
|
|
||||||
|
|
||||||
concatenated_text = "\n\n".join(message_parts)
|
|
||||||
|
|
||||||
# Create final document content
|
|
||||||
doc_content = f"""Conversation: {title}
|
|
||||||
Date: {timestamp}
|
|
||||||
Messages ({len(messages)} messages):
|
|
||||||
|
|
||||||
{concatenated_text}
|
|
||||||
"""
|
|
||||||
return doc_content
|
|
||||||
|
|
||||||
def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
|
|
||||||
"""
|
|
||||||
Load ChatGPT export data.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_dir: Directory containing ChatGPT export files or path to specific file
|
|
||||||
**load_kwargs:
|
|
||||||
max_count (int): Maximum number of conversations to process
|
|
||||||
chatgpt_export_path (str): Specific path to ChatGPT export file/directory
|
|
||||||
include_metadata (bool): Whether to include metadata in documents
|
|
||||||
"""
|
|
||||||
docs: list[Document] = []
|
|
||||||
max_count = load_kwargs.get("max_count", -1)
|
|
||||||
chatgpt_export_path = load_kwargs.get("chatgpt_export_path", input_dir)
|
|
||||||
include_metadata = load_kwargs.get("include_metadata", True)
|
|
||||||
|
|
||||||
if not chatgpt_export_path:
|
|
||||||
print("No ChatGPT export path provided")
|
|
||||||
return docs
|
|
||||||
|
|
||||||
export_path = Path(chatgpt_export_path)
|
|
||||||
|
|
||||||
if not export_path.exists():
|
|
||||||
print(f"ChatGPT export path not found: {export_path}")
|
|
||||||
return docs
|
|
||||||
|
|
||||||
html_content = None
|
|
||||||
|
|
||||||
# Handle different input types
|
|
||||||
if export_path.is_file():
|
|
||||||
if export_path.suffix.lower() == ".zip":
|
|
||||||
# Extract HTML from zip file
|
|
||||||
html_content = self._extract_html_from_zip(export_path)
|
|
||||||
elif export_path.suffix.lower() == ".html":
|
|
||||||
# Read HTML file directly
|
|
||||||
try:
|
|
||||||
with open(export_path, encoding="utf-8", errors="ignore") as f:
|
|
||||||
html_content = f.read()
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error reading HTML file {export_path}: {e}")
|
|
||||||
return docs
|
|
||||||
else:
|
|
||||||
print(f"Unsupported file type: {export_path.suffix}")
|
|
||||||
return docs
|
|
||||||
|
|
||||||
elif export_path.is_dir():
|
|
||||||
# Look for HTML files in directory
|
|
||||||
html_files = list(export_path.glob("*.html"))
|
|
||||||
zip_files = list(export_path.glob("*.zip"))
|
|
||||||
|
|
||||||
if html_files:
|
|
||||||
# Use first HTML file found
|
|
||||||
html_file = html_files[0]
|
|
||||||
print(f"Found HTML file: {html_file}")
|
|
||||||
try:
|
|
||||||
with open(html_file, encoding="utf-8", errors="ignore") as f:
|
|
||||||
html_content = f.read()
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error reading HTML file {html_file}: {e}")
|
|
||||||
return docs
|
|
||||||
|
|
||||||
elif zip_files:
|
|
||||||
# Use first zip file found
|
|
||||||
zip_file = zip_files[0]
|
|
||||||
print(f"Found zip file: {zip_file}")
|
|
||||||
html_content = self._extract_html_from_zip(zip_file)
|
|
||||||
|
|
||||||
else:
|
|
||||||
print(f"No HTML or zip files found in {export_path}")
|
|
||||||
return docs
|
|
||||||
|
|
||||||
if not html_content:
|
|
||||||
print("No HTML content found to process")
|
|
||||||
return docs
|
|
||||||
|
|
||||||
# Parse conversations from HTML
|
|
||||||
print("Parsing ChatGPT conversations from HTML...")
|
|
||||||
conversations = self._parse_chatgpt_html(html_content)
|
|
||||||
|
|
||||||
if not conversations:
|
|
||||||
print("No conversations found in HTML content")
|
|
||||||
return docs
|
|
||||||
|
|
||||||
print(f"Found {len(conversations)} conversations")
|
|
||||||
|
|
||||||
# Process conversations into documents
|
|
||||||
count = 0
|
|
||||||
for conversation in conversations:
|
|
||||||
if max_count > 0 and count >= max_count:
|
|
||||||
break
|
|
||||||
|
|
||||||
if self.concatenate_conversations:
|
|
||||||
# Create one document per conversation with concatenated messages
|
|
||||||
doc_content = self._create_concatenated_content(conversation)
|
|
||||||
|
|
||||||
metadata = {}
|
|
||||||
if include_metadata:
|
|
||||||
metadata = {
|
|
||||||
"title": conversation.get("title", "ChatGPT Conversation"),
|
|
||||||
"timestamp": conversation.get("timestamp", "Unknown"),
|
|
||||||
"message_count": len(conversation.get("messages", [])),
|
|
||||||
"source": "ChatGPT Export",
|
|
||||||
}
|
|
||||||
|
|
||||||
doc = Document(text=doc_content, metadata=metadata)
|
|
||||||
docs.append(doc)
|
|
||||||
count += 1
|
|
||||||
|
|
||||||
else:
|
|
||||||
# Create separate documents for each message
|
|
||||||
for message in conversation.get("messages", []):
|
|
||||||
if max_count > 0 and count >= max_count:
|
|
||||||
break
|
|
||||||
|
|
||||||
role = message.get("role", "mixed")
|
|
||||||
content = message.get("content", "")
|
|
||||||
msg_timestamp = message.get("timestamp", "")
|
|
||||||
|
|
||||||
if not content.strip():
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Create document content with context
|
|
||||||
doc_content = f"""Conversation: {conversation.get("title", "ChatGPT Conversation")}
|
|
||||||
Role: {role}
|
|
||||||
Timestamp: {msg_timestamp or conversation.get("timestamp", "Unknown")}
|
|
||||||
Message: {content}
|
|
||||||
"""
|
|
||||||
|
|
||||||
metadata = {}
|
|
||||||
if include_metadata:
|
|
||||||
metadata = {
|
|
||||||
"conversation_title": conversation.get("title", "ChatGPT Conversation"),
|
|
||||||
"role": role,
|
|
||||||
"timestamp": msg_timestamp or conversation.get("timestamp", "Unknown"),
|
|
||||||
"source": "ChatGPT Export",
|
|
||||||
}
|
|
||||||
|
|
||||||
doc = Document(text=doc_content, metadata=metadata)
|
|
||||||
docs.append(doc)
|
|
||||||
count += 1
|
|
||||||
|
|
||||||
print(f"Created {len(docs)} documents from ChatGPT export")
|
|
||||||
return docs
|
|
||||||
@@ -1,186 +0,0 @@
"""
|
|
||||||
ChatGPT RAG example using the unified interface.
|
|
||||||
Supports ChatGPT export data from chat.html files.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
# Add parent directory to path for imports
|
|
||||||
sys.path.insert(0, str(Path(__file__).parent))
|
|
||||||
|
|
||||||
from base_rag_example import BaseRAGExample
|
|
||||||
from chunking import create_text_chunks
|
|
||||||
|
|
||||||
from .chatgpt_data.chatgpt_reader import ChatGPTReader
|
|
||||||
|
|
||||||
|
|
||||||
class ChatGPTRAG(BaseRAGExample):
|
|
||||||
"""RAG example for ChatGPT conversation data."""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
# Set default values BEFORE calling super().__init__
|
|
||||||
self.max_items_default = -1 # Process all conversations by default
|
|
||||||
self.embedding_model_default = (
|
|
||||||
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
|
|
||||||
)
|
|
||||||
|
|
||||||
super().__init__(
|
|
||||||
name="ChatGPT",
|
|
||||||
description="Process and query ChatGPT conversation exports with LEANN",
|
|
||||||
default_index_name="chatgpt_conversations_index",
|
|
||||||
)
|
|
||||||
|
|
||||||
def _add_specific_arguments(self, parser):
|
|
||||||
"""Add ChatGPT-specific arguments."""
|
|
||||||
chatgpt_group = parser.add_argument_group("ChatGPT Parameters")
|
|
||||||
chatgpt_group.add_argument(
|
|
||||||
"--export-path",
|
|
||||||
type=str,
|
|
||||||
default="./chatgpt_export",
|
|
||||||
help="Path to ChatGPT export file (.zip or .html) or directory containing exports (default: ./chatgpt_export)",
|
|
||||||
)
|
|
||||||
chatgpt_group.add_argument(
|
|
||||||
"--concatenate-conversations",
|
|
||||||
action="store_true",
|
|
||||||
default=True,
|
|
||||||
help="Concatenate messages within conversations for better context (default: True)",
|
|
||||||
)
|
|
||||||
chatgpt_group.add_argument(
|
|
||||||
"--separate-messages",
|
|
||||||
action="store_true",
|
|
||||||
help="Process each message as a separate document (overrides --concatenate-conversations)",
|
|
||||||
)
|
|
||||||
chatgpt_group.add_argument(
|
|
||||||
"--chunk-size", type=int, default=512, help="Text chunk size (default: 512)"
|
|
||||||
)
|
|
||||||
chatgpt_group.add_argument(
|
|
||||||
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _find_chatgpt_exports(self, export_path: Path) -> list[Path]:
|
|
||||||
"""
|
|
||||||
Find ChatGPT export files in the given path.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
export_path: Path to search for exports
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of paths to ChatGPT export files
|
|
||||||
"""
|
|
||||||
export_files = []
|
|
||||||
|
|
||||||
if export_path.is_file():
|
|
||||||
if export_path.suffix.lower() in [".zip", ".html"]:
|
|
||||||
export_files.append(export_path)
|
|
||||||
elif export_path.is_dir():
|
|
||||||
# Look for zip and html files
|
|
||||||
export_files.extend(export_path.glob("*.zip"))
|
|
||||||
export_files.extend(export_path.glob("*.html"))
|
|
||||||
|
|
||||||
return export_files
|
|
||||||
|
|
||||||
async def load_data(self, args) -> list[str]:
|
|
||||||
"""Load ChatGPT export data and convert to text chunks."""
|
|
||||||
export_path = Path(args.export_path)
|
|
||||||
|
|
||||||
if not export_path.exists():
|
|
||||||
print(f"ChatGPT export path not found: {export_path}")
|
|
||||||
print(
|
|
||||||
"Please ensure you have exported your ChatGPT data and placed it in the correct location."
|
|
||||||
)
|
|
||||||
print("\nTo export your ChatGPT data:")
|
|
||||||
print("1. Sign in to ChatGPT")
|
|
||||||
print("2. Click on your profile icon → Settings → Data Controls")
|
|
||||||
print("3. Click 'Export' under Export Data")
|
|
||||||
print("4. Download the zip file from the email link")
|
|
||||||
print("5. Extract or place the file/directory at the specified path")
|
|
||||||
return []
|
|
||||||
|
|
||||||
# Find export files
|
|
||||||
export_files = self._find_chatgpt_exports(export_path)
|
|
||||||
|
|
||||||
if not export_files:
|
|
||||||
print(f"No ChatGPT export files (.zip or .html) found in: {export_path}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
print(f"Found {len(export_files)} ChatGPT export files")
|
|
||||||
|
|
||||||
# Create reader with appropriate settings
|
|
||||||
concatenate = args.concatenate_conversations and not args.separate_messages
|
|
||||||
reader = ChatGPTReader(concatenate_conversations=concatenate)
|
|
||||||
|
|
||||||
# Process each export file
|
|
||||||
all_documents = []
|
|
||||||
total_processed = 0
|
|
||||||
|
|
||||||
for i, export_file in enumerate(export_files):
|
|
||||||
print(f"\nProcessing export file {i + 1}/{len(export_files)}: {export_file.name}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Apply max_items limit per file
|
|
||||||
max_per_file = -1
|
|
||||||
if args.max_items > 0:
|
|
||||||
remaining = args.max_items - total_processed
|
|
||||||
if remaining <= 0:
|
|
||||||
break
|
|
||||||
max_per_file = remaining
|
|
||||||
|
|
||||||
# Load conversations
|
|
||||||
documents = reader.load_data(
|
|
||||||
chatgpt_export_path=str(export_file),
|
|
||||||
max_count=max_per_file,
|
|
||||||
include_metadata=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
if documents:
|
|
||||||
all_documents.extend(documents)
|
|
||||||
total_processed += len(documents)
|
|
||||||
print(f"Processed {len(documents)} conversations from this file")
|
|
||||||
else:
|
|
||||||
print(f"No conversations loaded from {export_file}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error processing {export_file}: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not all_documents:
|
|
||||||
print("No conversations found to process!")
|
|
||||||
print("\nTroubleshooting:")
|
|
||||||
print("- Ensure the export file is a valid ChatGPT export")
|
|
||||||
print("- Check that the HTML file contains conversation data")
|
|
||||||
print("- Try extracting the zip file and pointing to the HTML file directly")
|
|
||||||
return []
|
|
||||||
|
|
||||||
print(f"\nTotal conversations processed: {len(all_documents)}")
|
|
||||||
print("Now starting to split into text chunks... this may take some time")
|
|
||||||
|
|
||||||
# Convert to text chunks
|
|
||||||
all_texts = create_text_chunks(
|
|
||||||
all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} conversations")
|
|
||||||
return all_texts
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
# Example queries for ChatGPT RAG
|
|
||||||
print("\n🤖 ChatGPT RAG Example")
|
|
||||||
print("=" * 50)
|
|
||||||
print("\nExample queries you can try:")
|
|
||||||
print("- 'What did I ask about Python programming?'")
|
|
||||||
print("- 'Show me conversations about machine learning'")
|
|
||||||
print("- 'Find discussions about travel planning'")
|
|
||||||
print("- 'What advice did ChatGPT give me about career development?'")
|
|
||||||
print("- 'Search for conversations about cooking recipes'")
|
|
||||||
print("\nTo get started:")
|
|
||||||
print("1. Export your ChatGPT data from Settings → Data Controls → Export")
|
|
||||||
print("2. Place the downloaded zip file or extracted HTML in ./chatgpt_export/")
|
|
||||||
print("3. Run this script to build your personal ChatGPT knowledge base!")
|
|
||||||
print("\nOr run without --query for interactive mode\n")
|
|
||||||
|
|
||||||
rag = ChatGPTRAG()
|
|
||||||
asyncio.run(rag.run())
|
|
||||||
@@ -1,420 +0,0 @@
"""
|
|
||||||
Claude export data reader.
|
|
||||||
|
|
||||||
Reads and processes Claude conversation data from exported JSON files.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
|
||||||
from zipfile import ZipFile
|
|
||||||
|
|
||||||
from llama_index.core import Document
|
|
||||||
from llama_index.core.readers.base import BaseReader
|
|
||||||
|
|
||||||
|
|
||||||
class ClaudeReader(BaseReader):
|
|
||||||
"""
|
|
||||||
Claude export data reader.
|
|
||||||
|
|
||||||
Reads Claude conversation data from exported JSON files or zip archives.
|
|
||||||
Processes conversations into structured documents with metadata.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, concatenate_conversations: bool = True) -> None:
|
|
||||||
"""
|
|
||||||
Initialize.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
concatenate_conversations: Whether to concatenate messages within conversations for better context
|
|
||||||
"""
|
|
||||||
self.concatenate_conversations = concatenate_conversations
|
|
||||||
|
|
||||||
def _extract_json_from_zip(self, zip_path: Path) -> list[str]:
|
|
||||||
"""
|
|
||||||
Extract JSON files from Claude export zip file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
zip_path: Path to the Claude export zip file
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of JSON content strings, or empty list if not found
|
|
||||||
"""
|
|
||||||
json_contents = []
|
|
||||||
try:
|
|
||||||
with ZipFile(zip_path, "r") as zip_file:
|
|
||||||
# Look for JSON files
|
|
||||||
json_files = [f for f in zip_file.namelist() if f.endswith(".json")]
|
|
||||||
|
|
||||||
if not json_files:
|
|
||||||
print(f"No JSON files found in {zip_path}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
print(f"Found {len(json_files)} JSON files in archive")
|
|
||||||
|
|
||||||
for json_file in json_files:
|
|
||||||
with zip_file.open(json_file) as f:
|
|
||||||
content = f.read().decode("utf-8", errors="ignore")
|
|
||||||
json_contents.append(content)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error extracting JSON from zip {zip_path}: {e}")
|
|
||||||
|
|
||||||
return json_contents
|
|
||||||
|
|
||||||
def _parse_claude_json(self, json_content: str) -> list[dict]:
|
|
||||||
"""
|
|
||||||
Parse Claude JSON export to extract conversations.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
json_content: JSON content from Claude export
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of conversation dictionaries
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
data = json.loads(json_content)
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
print(f"Error parsing JSON: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
conversations = []
|
|
||||||
|
|
||||||
# Handle different possible JSON structures
|
|
||||||
if isinstance(data, list):
|
|
||||||
# If data is a list of conversations
|
|
||||||
            for item in data:
                conversation = self._extract_conversation_from_json(item)
                if conversation:
                    conversations.append(conversation)
        elif isinstance(data, dict):
            # Check for common structures
            if "conversations" in data:
                # Structure: {"conversations": [...]}
                for item in data["conversations"]:
                    conversation = self._extract_conversation_from_json(item)
                    if conversation:
                        conversations.append(conversation)
            elif "messages" in data:
                # Single conversation with messages
                conversation = self._extract_conversation_from_json(data)
                if conversation:
                    conversations.append(conversation)
            else:
                # Try to treat the whole object as a conversation
                conversation = self._extract_conversation_from_json(data)
                if conversation:
                    conversations.append(conversation)

        return conversations

    def _extract_conversation_from_json(self, conv_data: dict) -> dict | None:
        """
        Extract conversation data from a JSON object.

        Args:
            conv_data: Dictionary containing conversation data

        Returns:
            Dictionary with conversation data or None
        """
        if not isinstance(conv_data, dict):
            return None

        messages = []

        # Look for messages in various possible structures
        message_sources = []
        if "messages" in conv_data:
            message_sources = conv_data["messages"]
        elif "chat" in conv_data:
            message_sources = conv_data["chat"]
        elif "conversation" in conv_data:
            message_sources = conv_data["conversation"]
        else:
            # If no clear message structure, try to extract from the object itself
            if "content" in conv_data and "role" in conv_data:
                message_sources = [conv_data]

        for msg_data in message_sources:
            message = self._extract_message_from_json(msg_data)
            if message:
                messages.append(message)

        if not messages:
            return None

        # Extract conversation metadata
        title = self._extract_title_from_conversation(conv_data, messages)
        timestamp = self._extract_timestamp_from_conversation(conv_data)

        return {"title": title, "messages": messages, "timestamp": timestamp}

    def _extract_message_from_json(self, msg_data: dict) -> dict | None:
        """
        Extract message data from a JSON message object.

        Args:
            msg_data: Dictionary containing message data

        Returns:
            Dictionary with message data or None
        """
        if not isinstance(msg_data, dict):
            return None

        # Extract content from various possible fields
        content = ""
        content_fields = ["content", "text", "message", "body"]
        for field in content_fields:
            if msg_data.get(field):
                content = str(msg_data[field])
                break

        if not content or len(content.strip()) < 3:
            return None

        # Extract role (user/assistant/human/ai/claude)
        role = "mixed"  # Default role
        role_fields = ["role", "sender", "from", "author", "type"]
        for field in role_fields:
            if msg_data.get(field):
                role_value = str(msg_data[field]).lower()
                if role_value in ["user", "human", "person"]:
                    role = "user"
                elif role_value in ["assistant", "ai", "claude", "bot"]:
                    role = "assistant"
                break

        # Extract timestamp
        timestamp = self._extract_timestamp_from_message(msg_data)

        return {"role": role, "content": content, "timestamp": timestamp}

    def _extract_timestamp_from_message(self, msg_data: dict) -> str | None:
        """Extract timestamp from message data."""
        timestamp_fields = ["timestamp", "created_at", "date", "time"]
        for field in timestamp_fields:
            if msg_data.get(field):
                return str(msg_data[field])
        return None

    def _extract_timestamp_from_conversation(self, conv_data: dict) -> str | None:
        """Extract timestamp from conversation data."""
        timestamp_fields = ["timestamp", "created_at", "date", "updated_at", "last_updated"]
        for field in timestamp_fields:
            if conv_data.get(field):
                return str(conv_data[field])
        return None

    def _extract_title_from_conversation(self, conv_data: dict, messages: list) -> str:
        """Extract or generate title for conversation."""
        # Try to find explicit title
        title_fields = ["title", "name", "subject", "topic"]
        for field in title_fields:
            if conv_data.get(field):
                return str(conv_data[field])

        # Generate title from first user message
        for message in messages:
            if message.get("role") == "user":
                content = message.get("content", "")
                if content:
                    # Use first 50 characters as title
                    title = content[:50].strip()
                    if len(content) > 50:
                        title += "..."
                    return title

        return "Claude Conversation"

    def _create_concatenated_content(self, conversation: dict) -> str:
        """
        Create concatenated content from conversation messages.

        Args:
            conversation: Dictionary containing conversation data

        Returns:
            Formatted concatenated content
        """
        title = conversation.get("title", "Claude Conversation")
        messages = conversation.get("messages", [])
        timestamp = conversation.get("timestamp", "Unknown")

        # Build message content
        message_parts = []
        for message in messages:
            role = message.get("role", "mixed")
            content = message.get("content", "")
            msg_timestamp = message.get("timestamp", "")

            if role == "user":
                prefix = "[You]"
            elif role == "assistant":
                prefix = "[Claude]"
            else:
                prefix = "[Message]"

            # Add timestamp if available
            if msg_timestamp:
                prefix += f" ({msg_timestamp})"

            message_parts.append(f"{prefix}: {content}")

        concatenated_text = "\n\n".join(message_parts)

        # Create final document content
        doc_content = f"""Conversation: {title}
Date: {timestamp}
Messages ({len(messages)} messages):

{concatenated_text}
"""
        return doc_content

    def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
        """
        Load Claude export data.

        Args:
            input_dir: Directory containing Claude export files or path to specific file
            **load_kwargs:
                max_count (int): Maximum number of conversations to process
                claude_export_path (str): Specific path to Claude export file/directory
                include_metadata (bool): Whether to include metadata in documents
        """
        docs: list[Document] = []
        max_count = load_kwargs.get("max_count", -1)
        claude_export_path = load_kwargs.get("claude_export_path", input_dir)
        include_metadata = load_kwargs.get("include_metadata", True)

        if not claude_export_path:
            print("No Claude export path provided")
            return docs

        export_path = Path(claude_export_path)

        if not export_path.exists():
            print(f"Claude export path not found: {export_path}")
            return docs

        json_contents = []

        # Handle different input types
        if export_path.is_file():
            if export_path.suffix.lower() == ".zip":
                # Extract JSON from zip file
                json_contents = self._extract_json_from_zip(export_path)
            elif export_path.suffix.lower() == ".json":
                # Read JSON file directly
                try:
                    with open(export_path, encoding="utf-8", errors="ignore") as f:
                        json_contents.append(f.read())
                except Exception as e:
                    print(f"Error reading JSON file {export_path}: {e}")
                    return docs
            else:
                print(f"Unsupported file type: {export_path.suffix}")
                return docs

        elif export_path.is_dir():
            # Look for JSON files in directory
            json_files = list(export_path.glob("*.json"))
            zip_files = list(export_path.glob("*.zip"))

            if json_files:
                print(f"Found {len(json_files)} JSON files in directory")
                for json_file in json_files:
                    try:
                        with open(json_file, encoding="utf-8", errors="ignore") as f:
                            json_contents.append(f.read())
                    except Exception as e:
                        print(f"Error reading JSON file {json_file}: {e}")
                        continue

            if zip_files:
                print(f"Found {len(zip_files)} ZIP files in directory")
                for zip_file in zip_files:
                    zip_contents = self._extract_json_from_zip(zip_file)
                    json_contents.extend(zip_contents)

            if not json_files and not zip_files:
                print(f"No JSON or ZIP files found in {export_path}")
                return docs

        if not json_contents:
            print("No JSON content found to process")
            return docs

        # Parse conversations from JSON content
        print("Parsing Claude conversations from JSON...")
        all_conversations = []
        for json_content in json_contents:
            conversations = self._parse_claude_json(json_content)
            all_conversations.extend(conversations)

        if not all_conversations:
            print("No conversations found in JSON content")
            return docs

        print(f"Found {len(all_conversations)} conversations")

        # Process conversations into documents
        count = 0
        for conversation in all_conversations:
            if max_count > 0 and count >= max_count:
                break

            if self.concatenate_conversations:
                # Create one document per conversation with concatenated messages
                doc_content = self._create_concatenated_content(conversation)

                metadata = {}
                if include_metadata:
                    metadata = {
                        "title": conversation.get("title", "Claude Conversation"),
                        "timestamp": conversation.get("timestamp", "Unknown"),
                        "message_count": len(conversation.get("messages", [])),
                        "source": "Claude Export",
                    }

                doc = Document(text=doc_content, metadata=metadata)
                docs.append(doc)
                count += 1

            else:
                # Create separate documents for each message
                for message in conversation.get("messages", []):
                    if max_count > 0 and count >= max_count:
                        break

                    role = message.get("role", "mixed")
                    content = message.get("content", "")
                    msg_timestamp = message.get("timestamp", "")

                    if not content.strip():
                        continue

                    # Create document content with context
                    doc_content = f"""Conversation: {conversation.get("title", "Claude Conversation")}
Role: {role}
Timestamp: {msg_timestamp or conversation.get("timestamp", "Unknown")}
Message: {content}
"""

                    metadata = {}
                    if include_metadata:
                        metadata = {
                            "conversation_title": conversation.get("title", "Claude Conversation"),
                            "role": role,
                            "timestamp": msg_timestamp or conversation.get("timestamp", "Unknown"),
                            "source": "Claude Export",
                        }

                    doc = Document(text=doc_content, metadata=metadata)
                    docs.append(doc)
                    count += 1

        print(f"Created {len(docs)} documents from Claude export")
        return docs
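
# --- Usage sketch (added for illustration; not part of the original reader file) ---
# A minimal, hedged example of driving the reader above directly. The path
# "./claude_export/conversations.json" is a hypothetical export location.
if __name__ == "__main__":
    reader = ClaudeReader(concatenate_conversations=True)
    documents = reader.load_data(
        claude_export_path="./claude_export/conversations.json",  # hypothetical path
        max_count=10,
        include_metadata=True,
    )
    for doc in documents[:3]:
        # Each Document carries the conversation title and the concatenated text
        print(doc.metadata.get("title"), len(doc.text))
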
@@ -1,189 +0,0 @@
"""
Claude RAG example using the unified interface.
Supports Claude export data from JSON files.
"""

import sys
from pathlib import Path

# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))

from base_rag_example import BaseRAGExample
from chunking import create_text_chunks

from .claude_data.claude_reader import ClaudeReader


class ClaudeRAG(BaseRAGExample):
    """RAG example for Claude conversation data."""

    def __init__(self):
        # Set default values BEFORE calling super().__init__
        self.max_items_default = -1  # Process all conversations by default
        self.embedding_model_default = (
            "sentence-transformers/all-MiniLM-L6-v2"  # Fast 384-dim model
        )

        super().__init__(
            name="Claude",
            description="Process and query Claude conversation exports with LEANN",
            default_index_name="claude_conversations_index",
        )

    def _add_specific_arguments(self, parser):
        """Add Claude-specific arguments."""
        claude_group = parser.add_argument_group("Claude Parameters")
        claude_group.add_argument(
            "--export-path",
            type=str,
            default="./claude_export",
            help="Path to Claude export file (.json or .zip) or directory containing exports (default: ./claude_export)",
        )
        claude_group.add_argument(
            "--concatenate-conversations",
            action="store_true",
            default=True,
            help="Concatenate messages within conversations for better context (default: True)",
        )
        claude_group.add_argument(
            "--separate-messages",
            action="store_true",
            help="Process each message as a separate document (overrides --concatenate-conversations)",
        )
        claude_group.add_argument(
            "--chunk-size", type=int, default=512, help="Text chunk size (default: 512)"
        )
        claude_group.add_argument(
            "--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
        )

    def _find_claude_exports(self, export_path: Path) -> list[Path]:
        """
        Find Claude export files in the given path.

        Args:
            export_path: Path to search for exports

        Returns:
            List of paths to Claude export files
        """
        export_files = []

        if export_path.is_file():
            if export_path.suffix.lower() in [".zip", ".json"]:
                export_files.append(export_path)
        elif export_path.is_dir():
            # Look for zip and json files
            export_files.extend(export_path.glob("*.zip"))
            export_files.extend(export_path.glob("*.json"))

        return export_files

    async def load_data(self, args) -> list[str]:
        """Load Claude export data and convert to text chunks."""
        export_path = Path(args.export_path)

        if not export_path.exists():
            print(f"Claude export path not found: {export_path}")
            print(
                "Please ensure you have exported your Claude data and placed it in the correct location."
            )
            print("\nTo export your Claude data:")
            print("1. Open Claude in your browser")
            print("2. Look for export/download options in settings or conversation menu")
            print("3. Download the conversation data (usually in JSON format)")
            print("4. Place the file/directory at the specified path")
            print(
                "\nNote: Claude export methods may vary. Check Claude's help documentation for current instructions."
            )
            return []

        # Find export files
        export_files = self._find_claude_exports(export_path)

        if not export_files:
            print(f"No Claude export files (.json or .zip) found in: {export_path}")
            return []

        print(f"Found {len(export_files)} Claude export files")

        # Create reader with appropriate settings
        concatenate = args.concatenate_conversations and not args.separate_messages
        reader = ClaudeReader(concatenate_conversations=concatenate)

        # Process each export file
        all_documents = []
        total_processed = 0

        for i, export_file in enumerate(export_files):
            print(f"\nProcessing export file {i + 1}/{len(export_files)}: {export_file.name}")

            try:
                # Apply max_items limit per file
                max_per_file = -1
                if args.max_items > 0:
                    remaining = args.max_items - total_processed
                    if remaining <= 0:
                        break
                    max_per_file = remaining

                # Load conversations
                documents = reader.load_data(
                    claude_export_path=str(export_file),
                    max_count=max_per_file,
                    include_metadata=True,
                )

                if documents:
                    all_documents.extend(documents)
                    total_processed += len(documents)
                    print(f"Processed {len(documents)} conversations from this file")
                else:
                    print(f"No conversations loaded from {export_file}")

            except Exception as e:
                print(f"Error processing {export_file}: {e}")
                continue

        if not all_documents:
            print("No conversations found to process!")
            print("\nTroubleshooting:")
            print("- Ensure the export file is a valid Claude export")
            print("- Check that the JSON file contains conversation data")
            print("- Try using a different export format or method")
            print("- Check Claude's documentation for current export procedures")
            return []

        print(f"\nTotal conversations processed: {len(all_documents)}")
        print("Now starting to split into text chunks... this may take some time")

        # Convert to text chunks
        all_texts = create_text_chunks(
            all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
        )

        print(f"Created {len(all_texts)} text chunks from {len(all_documents)} conversations")
        return all_texts


if __name__ == "__main__":
    import asyncio

    # Example queries for Claude RAG
    print("\n🤖 Claude RAG Example")
    print("=" * 50)
    print("\nExample queries you can try:")
    print("- 'What did I ask Claude about Python programming?'")
    print("- 'Show me conversations about machine learning'")
    print("- 'Find discussions about code optimization'")
    print("- 'What advice did Claude give me about software design?'")
    print("- 'Search for conversations about debugging techniques'")
    print("\nTo get started:")
    print("1. Export your Claude conversation data")
    print("2. Place the JSON/ZIP file in ./claude_export/")
    print("3. Run this script to build your personal Claude knowledge base!")
    print("\nOr run without --query for interactive mode\n")

    rag = ClaudeRAG()
    asyncio.run(rag.run())

@@ -1 +0,0 @@
"""iMessage data processing module."""

@@ -1,342 +0,0 @@
"""
iMessage data reader.

Reads and processes iMessage conversation data from the macOS Messages database.
"""

import sqlite3
from datetime import datetime
from pathlib import Path
from typing import Any

from llama_index.core import Document
from llama_index.core.readers.base import BaseReader


class IMessageReader(BaseReader):
    """
    iMessage data reader.

    Reads iMessage conversation data from the macOS Messages database (chat.db).
    Processes conversations into structured documents with metadata.
    """

    def __init__(self, concatenate_conversations: bool = True) -> None:
        """
        Initialize.

        Args:
            concatenate_conversations: Whether to concatenate messages within conversations for better context
        """
        self.concatenate_conversations = concatenate_conversations

    def _get_default_chat_db_path(self) -> Path:
        """
        Get the default path to the iMessage chat database.

        Returns:
            Path to the chat.db file
        """
        home = Path.home()
        return home / "Library" / "Messages" / "chat.db"

    def _convert_cocoa_timestamp(self, cocoa_timestamp: int) -> str:
        """
        Convert Cocoa timestamp to readable format.

        Args:
            cocoa_timestamp: Timestamp in Cocoa format (nanoseconds since 2001-01-01)

        Returns:
            Formatted timestamp string
        """
        if cocoa_timestamp == 0:
            return "Unknown"

        try:
            # Cocoa timestamp is nanoseconds since 2001-01-01 00:00:00 UTC
            # Convert to seconds and add to Unix epoch
            cocoa_epoch = datetime(2001, 1, 1)
            unix_timestamp = cocoa_timestamp / 1_000_000_000  # Convert nanoseconds to seconds
            message_time = cocoa_epoch.timestamp() + unix_timestamp
            return datetime.fromtimestamp(message_time).strftime("%Y-%m-%d %H:%M:%S")
        except (ValueError, OSError):
            return "Unknown"
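
    # Worked example (illustrative, not from the original file): a raw database value
    # of 700_000_000 * 1_000_000_000 nanoseconds is 700,000,000 seconds past the
    # 2001-01-01 Cocoa epoch, which lands in early March 2023; the exact string
    # depends on the local timezone because the conversion above uses local time.
    #
    #   reader = IMessageReader()
    #   print(reader._convert_cocoa_timestamp(700_000_000 * 1_000_000_000))
    #   # -> e.g. "2023-03-08 ..." (timezone-dependent)
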
    def _get_contact_name(self, handle_id: str) -> str:
        """
        Get a readable contact name from handle ID.

        Args:
            handle_id: The handle ID (phone number or email)

        Returns:
            Formatted contact name
        """
        if not handle_id:
            return "Unknown"

        # Clean up phone numbers and emails for display
        if "@" in handle_id:
            return handle_id  # Email address
        elif handle_id.startswith("+"):
            return handle_id  # International phone number
        else:
            # Try to format as phone number
            digits = "".join(filter(str.isdigit, handle_id))
            if len(digits) == 10:
                return f"({digits[:3]}) {digits[3:6]}-{digits[6:]}"
            elif len(digits) == 11 and digits[0] == "1":
                return f"+1 ({digits[1:4]}) {digits[4:7]}-{digits[7:]}"
            else:
                return handle_id
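
    # Worked examples for the formatting above (illustrative, not from the original file):
    #   _get_contact_name("5551234567")       -> "(555) 123-4567"
    #   _get_contact_name("15551234567")      -> "+1 (555) 123-4567"
    #   _get_contact_name("user@example.com") -> "user@example.com"  (emails pass through)
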
    def _read_messages_from_db(self, db_path: Path) -> list[dict]:
        """
        Read messages from the iMessage database.

        Args:
            db_path: Path to the chat.db file

        Returns:
            List of message dictionaries
        """
        if not db_path.exists():
            print(f"iMessage database not found at: {db_path}")
            return []

        try:
            # Connect to the database
            conn = sqlite3.connect(str(db_path))
            cursor = conn.cursor()

            # Query to get messages with chat and handle information
            query = """
                SELECT
                    m.ROWID as message_id,
                    m.text,
                    m.date,
                    m.is_from_me,
                    m.service,
                    c.chat_identifier,
                    c.display_name as chat_display_name,
                    h.id as handle_id,
                    c.ROWID as chat_id
                FROM message m
                LEFT JOIN chat_message_join cmj ON m.ROWID = cmj.message_id
                LEFT JOIN chat c ON cmj.chat_id = c.ROWID
                LEFT JOIN handle h ON m.handle_id = h.ROWID
                WHERE m.text IS NOT NULL AND m.text != ''
                ORDER BY c.ROWID, m.date
            """

            cursor.execute(query)
            rows = cursor.fetchall()

            messages = []
            for row in rows:
                (
                    message_id,
                    text,
                    date,
                    is_from_me,
                    service,
                    chat_identifier,
                    chat_display_name,
                    handle_id,
                    chat_id,
                ) = row

                message = {
                    "message_id": message_id,
                    "text": text,
                    "timestamp": self._convert_cocoa_timestamp(date),
                    "is_from_me": bool(is_from_me),
                    "service": service or "iMessage",
                    "chat_identifier": chat_identifier or "Unknown",
                    "chat_display_name": chat_display_name or "Unknown Chat",
                    "handle_id": handle_id or "Unknown",
                    "contact_name": self._get_contact_name(handle_id or ""),
                    "chat_id": chat_id,
                }
                messages.append(message)

            conn.close()
            print(f"Found {len(messages)} messages in database")
            return messages

        except sqlite3.Error as e:
            print(f"Error reading iMessage database: {e}")
            return []
        except Exception as e:
            print(f"Unexpected error reading iMessage database: {e}")
            return []

    def _group_messages_by_chat(self, messages: list[dict]) -> dict[int, list[dict]]:
        """
        Group messages by chat ID.

        Args:
            messages: List of message dictionaries

        Returns:
            Dictionary mapping chat_id to list of messages
        """
        chats = {}
        for message in messages:
            chat_id = message["chat_id"]
            if chat_id not in chats:
                chats[chat_id] = []
            chats[chat_id].append(message)

        return chats

    def _create_concatenated_content(self, chat_id: int, messages: list[dict]) -> str:
        """
        Create concatenated content from chat messages.

        Args:
            chat_id: The chat ID
            messages: List of messages in the chat

        Returns:
            Concatenated text content
        """
        if not messages:
            return ""

        # Get chat info from first message
        first_msg = messages[0]
        chat_name = first_msg["chat_display_name"]
        chat_identifier = first_msg["chat_identifier"]

        # Build message content
        message_parts = []
        for message in messages:
            timestamp = message["timestamp"]
            is_from_me = message["is_from_me"]
            text = message["text"]
            contact_name = message["contact_name"]

            if is_from_me:
                prefix = "[You]"
            else:
                prefix = f"[{contact_name}]"

            if timestamp != "Unknown":
                prefix += f" ({timestamp})"

            message_parts.append(f"{prefix}: {text}")

        concatenated_text = "\n\n".join(message_parts)

        doc_content = f"""Chat: {chat_name}
Identifier: {chat_identifier}
Messages ({len(messages)} messages):

{concatenated_text}
"""
        return doc_content

    def _create_individual_content(self, message: dict) -> str:
        """
        Create content for individual message.

        Args:
            message: Message dictionary

        Returns:
            Formatted message content
        """
        timestamp = message["timestamp"]
        is_from_me = message["is_from_me"]
        text = message["text"]
        contact_name = message["contact_name"]
        chat_name = message["chat_display_name"]

        sender = "You" if is_from_me else contact_name

        return f"""Message from {sender} in chat "{chat_name}"
Time: {timestamp}
Content: {text}
"""

    def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
        """
        Load iMessage data and return as documents.

        Args:
            input_dir: Optional path to directory containing chat.db file.
                If not provided, uses default macOS location.
            **load_kwargs: Additional arguments (unused)

        Returns:
            List of Document objects containing iMessage data
        """
        docs = []

        # Determine database path
        if input_dir:
            db_path = Path(input_dir) / "chat.db"
        else:
            db_path = self._get_default_chat_db_path()

        print(f"Reading iMessage database from: {db_path}")

        # Read messages from database
        messages = self._read_messages_from_db(db_path)
        if not messages:
            return docs

        if self.concatenate_conversations:
            # Group messages by chat and create concatenated documents
            chats = self._group_messages_by_chat(messages)

            for chat_id, chat_messages in chats.items():
                if not chat_messages:
                    continue

                content = self._create_concatenated_content(chat_id, chat_messages)

                # Create metadata
                first_msg = chat_messages[0]
                last_msg = chat_messages[-1]

                metadata = {
                    "source": "iMessage",
                    "chat_id": chat_id,
                    "chat_name": first_msg["chat_display_name"],
                    "chat_identifier": first_msg["chat_identifier"],
                    "message_count": len(chat_messages),
                    "first_message_date": first_msg["timestamp"],
                    "last_message_date": last_msg["timestamp"],
                    "participants": list(
                        {msg["contact_name"] for msg in chat_messages if not msg["is_from_me"]}
                    ),
                }

                doc = Document(text=content, metadata=metadata)
                docs.append(doc)

        else:
            # Create individual documents for each message
            for message in messages:
                content = self._create_individual_content(message)

                metadata = {
                    "source": "iMessage",
                    "message_id": message["message_id"],
                    "chat_id": message["chat_id"],
                    "chat_name": message["chat_display_name"],
                    "chat_identifier": message["chat_identifier"],
                    "timestamp": message["timestamp"],
                    "is_from_me": message["is_from_me"],
                    "contact_name": message["contact_name"],
                    "service": message["service"],
                }

                doc = Document(text=content, metadata=metadata)
                docs.append(doc)

        print(f"Created {len(docs)} documents from iMessage data")
        return docs
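
# --- Usage sketch (added for illustration; not part of the original reader file) ---
# Assumes Full Disk Access has been granted so ~/Library/Messages/chat.db is readable.
if __name__ == "__main__":
    reader = IMessageReader(concatenate_conversations=True)
    documents = reader.load_data()  # reads the default chat.db location
    for doc in documents[:3]:
        # Each Document corresponds to one chat when concatenation is enabled
        print(doc.metadata.get("chat_name"), doc.metadata.get("message_count"))
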
@@ -1,125 +0,0 @@
"""
iMessage RAG Example.

This example demonstrates how to build a RAG system on your iMessage conversation history.
"""

import asyncio
from pathlib import Path

from leann.chunking_utils import create_text_chunks

from apps.base_rag_example import BaseRAGExample
from apps.imessage_data.imessage_reader import IMessageReader


class IMessageRAG(BaseRAGExample):
    """RAG example for iMessage conversation history."""

    def __init__(self):
        super().__init__(
            name="iMessage",
            description="RAG on your iMessage conversation history",
            default_index_name="imessage_index",
        )

    def _add_specific_arguments(self, parser):
        """Add iMessage-specific arguments."""
        imessage_group = parser.add_argument_group("iMessage Parameters")
        imessage_group.add_argument(
            "--db-path",
            type=str,
            default=None,
            help="Path to iMessage chat.db file (default: ~/Library/Messages/chat.db)",
        )
        imessage_group.add_argument(
            "--concatenate-conversations",
            action="store_true",
            default=True,
            help="Concatenate messages within conversations for better context (default: True)",
        )
        imessage_group.add_argument(
            "--no-concatenate-conversations",
            action="store_true",
            help="Process each message individually instead of concatenating by conversation",
        )
        imessage_group.add_argument(
            "--chunk-size",
            type=int,
            default=1000,
            help="Maximum characters per text chunk (default: 1000)",
        )
        imessage_group.add_argument(
            "--chunk-overlap",
            type=int,
            default=200,
            help="Overlap between text chunks (default: 200)",
        )

    async def load_data(self, args) -> list[str]:
        """Load iMessage history and convert to text chunks."""
        print("Loading iMessage conversation history...")

        # Determine concatenation setting
        concatenate = args.concatenate_conversations and not args.no_concatenate_conversations

        # Initialize iMessage reader
        reader = IMessageReader(concatenate_conversations=concatenate)

        # Load documents
        try:
            if args.db_path:
                # Use custom database path
                db_dir = str(Path(args.db_path).parent)
                documents = reader.load_data(input_dir=db_dir)
            else:
                # Use default macOS location
                documents = reader.load_data()

        except Exception as e:
            print(f"Error loading iMessage data: {e}")
            print("\nTroubleshooting tips:")
            print("1. Make sure you have granted Full Disk Access to your terminal/IDE")
            print("2. Check that the iMessage database exists at ~/Library/Messages/chat.db")
            print("3. Try specifying a custom path with --db-path if you have a backup")
            return []

        if not documents:
            print("No iMessage conversations found!")
            return []

        print(f"Loaded {len(documents)} iMessage documents")

        # Show some statistics
        total_messages = sum(doc.metadata.get("message_count", 1) for doc in documents)
        print(f"Total messages: {total_messages}")

        if concatenate:
            # Show chat statistics
            chat_names = [doc.metadata.get("chat_name", "Unknown") for doc in documents]
            unique_chats = len(set(chat_names))
            print(f"Unique conversations: {unique_chats}")

        # Convert to text chunks
        all_texts = create_text_chunks(
            documents,
            chunk_size=args.chunk_size,
            chunk_overlap=args.chunk_overlap,
        )

        # Apply max_items limit if specified
        if args.max_items > 0:
            all_texts = all_texts[: args.max_items]
            print(f"Limited to {len(all_texts)} text chunks (max_items={args.max_items})")

        return all_texts


async def main():
    """Main entry point."""
    app = IMessageRAG()
    await app.run()


if __name__ == "__main__":
    asyncio.run(main())

@@ -1 +0,0 @@
# Slack MCP data integration for LEANN

@@ -1,334 +0,0 @@
#!/usr/bin/env python3
"""
Slack MCP Reader for LEANN

This module provides functionality to connect to Slack MCP servers and fetch message data
for indexing in LEANN. It supports various Slack MCP server implementations and provides
flexible message processing options.
"""

import asyncio
import json
import logging
from typing import Any, Dict, List, Optional

logger = logging.getLogger(__name__)


class SlackMCPReader:
    """
    Reader for Slack data via MCP (Model Context Protocol) servers.

    This class connects to Slack MCP servers to fetch message data and convert it
    into a format suitable for LEANN indexing.
    """

    def __init__(
        self,
        mcp_server_command: str,
        workspace_name: Optional[str] = None,
        concatenate_conversations: bool = True,
        max_messages_per_conversation: int = 100,
    ):
        """
        Initialize the Slack MCP Reader.

        Args:
            mcp_server_command: Command to start the MCP server (e.g., 'slack-mcp-server')
            workspace_name: Optional workspace name to filter messages
            concatenate_conversations: Whether to group messages by channel/thread
            max_messages_per_conversation: Maximum messages to include per conversation
        """
        self.mcp_server_command = mcp_server_command
        self.workspace_name = workspace_name
        self.concatenate_conversations = concatenate_conversations
        self.max_messages_per_conversation = max_messages_per_conversation
        self.mcp_process = None

    async def start_mcp_server(self):
        """Start the MCP server process."""
        try:
            self.mcp_process = await asyncio.create_subprocess_exec(
                *self.mcp_server_command.split(),
                stdin=asyncio.subprocess.PIPE,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
            logger.info(f"Started MCP server: {self.mcp_server_command}")
        except Exception as e:
            logger.error(f"Failed to start MCP server: {e}")
            raise

    async def stop_mcp_server(self):
        """Stop the MCP server process."""
        if self.mcp_process:
            self.mcp_process.terminate()
            await self.mcp_process.wait()
            logger.info("Stopped MCP server")

    async def send_mcp_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
        """Send a request to the MCP server and get response."""
        if not self.mcp_process:
            raise RuntimeError("MCP server not started")

        request_json = json.dumps(request) + "\n"
        self.mcp_process.stdin.write(request_json.encode())
        await self.mcp_process.stdin.drain()

        response_line = await self.mcp_process.stdout.readline()
        if not response_line:
            raise RuntimeError("No response from MCP server")

        return json.loads(response_line.decode().strip())
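
    # Framing note (added for illustration): requests are sent as newline-delimited
    # JSON-RPC 2.0 over the server's stdin and one JSON line is read back per request.
    # For example, the tools/list call used further below can be issued directly:
    #
    #   response = await reader.send_mcp_request(
    #       {"jsonrpc": "2.0", "id": 2, "method": "tools/list", "params": {}}
    #   )
    #   tools = response.get("result", {}).get("tools", [])
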
    async def initialize_mcp_connection(self):
        """Initialize the MCP connection."""
        init_request = {
            "jsonrpc": "2.0",
            "id": 1,
            "method": "initialize",
            "params": {
                "protocolVersion": "2024-11-05",
                "capabilities": {},
                "clientInfo": {"name": "leann-slack-reader", "version": "1.0.0"},
            },
        }

        response = await self.send_mcp_request(init_request)
        if "error" in response:
            raise RuntimeError(f"MCP initialization failed: {response['error']}")

        logger.info("MCP connection initialized successfully")

    async def list_available_tools(self) -> List[Dict[str, Any]]:
        """List available tools from the MCP server."""
        list_request = {"jsonrpc": "2.0", "id": 2, "method": "tools/list", "params": {}}

        response = await self.send_mcp_request(list_request)
        if "error" in response:
            raise RuntimeError(f"Failed to list tools: {response['error']}")

        return response.get("result", {}).get("tools", [])

    async def fetch_slack_messages(
        self, channel: Optional[str] = None, limit: int = 100
    ) -> List[Dict[str, Any]]:
        """
        Fetch Slack messages using MCP tools.

        Args:
            channel: Optional channel name to filter messages
            limit: Maximum number of messages to fetch

        Returns:
            List of message dictionaries
        """
        # This is a generic implementation - specific MCP servers may have different tool names
        # Common tool names might be: 'get_messages', 'list_messages', 'fetch_channel_history'

        tools = await self.list_available_tools()
        message_tool = None

        # Look for a tool that can fetch messages
        for tool in tools:
            tool_name = tool.get("name", "").lower()
            if any(
                keyword in tool_name
                for keyword in ["message", "history", "channel", "conversation"]
            ):
                message_tool = tool
                break

        if not message_tool:
            raise RuntimeError("No message fetching tool found in MCP server")

        # Prepare tool call parameters
        tool_params = {"limit": limit}
        if channel:
            # Try common parameter names for channel specification
            for param_name in ["channel", "channel_id", "channel_name"]:
                tool_params[param_name] = channel
                break

        fetch_request = {
            "jsonrpc": "2.0",
            "id": 3,
            "method": "tools/call",
            "params": {"name": message_tool["name"], "arguments": tool_params},
        }

        response = await self.send_mcp_request(fetch_request)
        if "error" in response:
            raise RuntimeError(f"Failed to fetch messages: {response['error']}")

        # Extract messages from response - format may vary by MCP server
        result = response.get("result", {})
        if "content" in result and isinstance(result["content"], list):
            # Some MCP servers return content as a list
            content = result["content"][0] if result["content"] else {}
            if "text" in content:
                try:
                    messages = json.loads(content["text"])
                except json.JSONDecodeError:
                    # If not JSON, treat as plain text
                    messages = [{"text": content["text"], "channel": channel or "unknown"}]
            else:
                messages = result["content"]
        else:
            # Direct message format
            messages = result.get("messages", [result])

        return messages if isinstance(messages, list) else [messages]

    def _format_message(self, message: Dict[str, Any]) -> str:
        """Format a single message for indexing."""
        text = message.get("text", "")
        user = message.get("user", message.get("username", "Unknown"))
        channel = message.get("channel", message.get("channel_name", "Unknown"))
        timestamp = message.get("ts", message.get("timestamp", ""))

        # Format timestamp if available
        formatted_time = ""
        if timestamp:
            try:
                import datetime

                if isinstance(timestamp, str) and "." in timestamp:
                    dt = datetime.datetime.fromtimestamp(float(timestamp))
                    formatted_time = dt.strftime("%Y-%m-%d %H:%M:%S")
                elif isinstance(timestamp, (int, float)):
                    dt = datetime.datetime.fromtimestamp(timestamp)
                    formatted_time = dt.strftime("%Y-%m-%d %H:%M:%S")
                else:
                    formatted_time = str(timestamp)
            except (ValueError, TypeError):
                formatted_time = str(timestamp)

        # Build formatted message
        parts = []
        if channel:
            parts.append(f"Channel: #{channel}")
        if user:
            parts.append(f"User: {user}")
        if formatted_time:
            parts.append(f"Time: {formatted_time}")
        if text:
            parts.append(f"Message: {text}")

        return "\n".join(parts)

    def _create_concatenated_content(self, messages: List[Dict[str, Any]], channel: str) -> str:
        """Create concatenated content from multiple messages in a channel."""
        if not messages:
            return ""

        # Sort messages by timestamp if available
        try:
            messages.sort(key=lambda x: float(x.get("ts", x.get("timestamp", 0))))
        except (ValueError, TypeError):
            pass  # Keep original order if timestamps aren't numeric

        # Limit messages per conversation
        if len(messages) > self.max_messages_per_conversation:
            messages = messages[-self.max_messages_per_conversation :]

        # Create header
        content_parts = [
            f"Slack Channel: #{channel}",
            f"Message Count: {len(messages)}",
            f"Workspace: {self.workspace_name or 'Unknown'}",
            "=" * 50,
            "",
        ]

        # Add messages
        for message in messages:
            formatted_msg = self._format_message(message)
            if formatted_msg.strip():
                content_parts.append(formatted_msg)
                content_parts.append("-" * 30)
                content_parts.append("")

        return "\n".join(content_parts)

    async def read_slack_data(self, channels: Optional[List[str]] = None) -> List[str]:
        """
        Read Slack data and return formatted text chunks.

        Args:
            channels: Optional list of channel names to fetch. If None, fetches from all available channels.

        Returns:
            List of formatted text chunks ready for LEANN indexing
        """
        try:
            await self.start_mcp_server()
            await self.initialize_mcp_connection()

            all_texts = []

            if channels:
                # Fetch specific channels
                for channel in channels:
                    try:
                        messages = await self.fetch_slack_messages(channel=channel, limit=1000)
                        if messages:
                            if self.concatenate_conversations:
                                text_content = self._create_concatenated_content(messages, channel)
                                if text_content.strip():
                                    all_texts.append(text_content)
                            else:
                                # Process individual messages
                                for message in messages:
                                    formatted_msg = self._format_message(message)
                                    if formatted_msg.strip():
                                        all_texts.append(formatted_msg)
                    except Exception as e:
                        logger.warning(f"Failed to fetch messages from channel {channel}: {e}")
                        continue
            else:
                # Fetch from all available channels/conversations
                # This is a simplified approach - real implementation would need to
                # discover available channels first
                try:
                    messages = await self.fetch_slack_messages(limit=1000)
                    if messages:
                        # Group messages by channel if concatenating
                        if self.concatenate_conversations:
                            channel_messages = {}
                            for message in messages:
                                channel = message.get(
                                    "channel", message.get("channel_name", "general")
                                )
                                if channel not in channel_messages:
                                    channel_messages[channel] = []
                                channel_messages[channel].append(message)

                            # Create concatenated content for each channel
                            for channel, msgs in channel_messages.items():
                                text_content = self._create_concatenated_content(msgs, channel)
                                if text_content.strip():
                                    all_texts.append(text_content)
                        else:
                            # Process individual messages
                            for message in messages:
                                formatted_msg = self._format_message(message)
                                if formatted_msg.strip():
                                    all_texts.append(formatted_msg)
                except Exception as e:
                    logger.error(f"Failed to fetch messages: {e}")

            return all_texts

        finally:
            await self.stop_mcp_server()

    async def __aenter__(self):
        """Async context manager entry."""
        await self.start_mcp_server()
        await self.initialize_mcp_connection()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit."""
        await self.stop_mcp_server()
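
# --- Usage sketch (added for illustration; not part of the original reader file) ---
# The server command, workspace, and channel below are placeholders.
async def _example_read_slack():
    reader = SlackMCPReader(
        mcp_server_command="npx slack-mcp-server",  # hypothetical server command
        workspace_name="my-workspace",              # hypothetical workspace name
        concatenate_conversations=True,
    )
    # read_slack_data() starts the server, fetches, and stops it again
    texts = await reader.read_slack_data(channels=["general"])
    print(f"Fetched {len(texts)} text chunks")

# asyncio.run(_example_read_slack())
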
@@ -1,204 +0,0 @@
#!/usr/bin/env python3
"""
Slack RAG Application with MCP Support

This application enables RAG (Retrieval-Augmented Generation) on Slack messages
by connecting to Slack MCP servers to fetch live data and index it in LEANN.

Usage:
    python -m apps.slack_rag --mcp-server "slack-mcp-server" --query "What did the team discuss about the project?"
"""

import argparse
import asyncio
from typing import List

from apps.base_rag_example import BaseRAGExample
from apps.slack_data.slack_mcp_reader import SlackMCPReader


class SlackMCPRAG(BaseRAGExample):
    """
    RAG application for Slack messages via MCP servers.

    This class provides a complete RAG pipeline for Slack data, including
    MCP server connection, data fetching, indexing, and interactive chat.
    """

    def __init__(self):
        super().__init__()
        self.default_index_name = "slack_messages"

    def _add_specific_arguments(self, parser: argparse.ArgumentParser):
        """Add Slack MCP-specific arguments."""
        parser.add_argument(
            "--mcp-server",
            type=str,
            required=True,
            help="Command to start the Slack MCP server (e.g., 'slack-mcp-server' or 'npx slack-mcp-server')",
        )

        parser.add_argument(
            "--workspace-name",
            type=str,
            help="Slack workspace name for better organization and filtering",
        )

        parser.add_argument(
            "--channels",
            nargs="+",
            help="Specific Slack channels to index (e.g., general random). If not specified, fetches from all available channels",
        )

        parser.add_argument(
            "--concatenate-conversations",
            action="store_true",
            default=True,
            help="Group messages by channel/thread for better context (default: True)",
        )

        parser.add_argument(
            "--no-concatenate-conversations",
            action="store_true",
            help="Process individual messages instead of grouping by channel",
        )

        parser.add_argument(
            "--max-messages-per-channel",
            type=int,
            default=100,
            help="Maximum number of messages to include per channel (default: 100)",
        )

        parser.add_argument(
            "--test-connection",
            action="store_true",
            help="Test MCP server connection and list available tools without indexing",
        )

    async def test_mcp_connection(self, args) -> bool:
        """Test the MCP server connection and display available tools."""
        print(f"Testing connection to MCP server: {args.mcp_server}")

        try:
            reader = SlackMCPReader(
                mcp_server_command=args.mcp_server,
                workspace_name=args.workspace_name,
                concatenate_conversations=not args.no_concatenate_conversations,
                max_messages_per_conversation=args.max_messages_per_channel,
            )

            async with reader:
                tools = await reader.list_available_tools()

                print("\n✅ Successfully connected to MCP server!")
                print(f"Available tools ({len(tools)}):")

                for i, tool in enumerate(tools, 1):
                    name = tool.get("name", "Unknown")
                    description = tool.get("description", "No description available")
                    print(f"\n{i}. {name}")
                    print(
                        f"   Description: {description[:100]}{'...' if len(description) > 100 else ''}"
                    )

                    # Show input schema if available
                    schema = tool.get("inputSchema", {})
                    if schema.get("properties"):
                        props = list(schema["properties"].keys())[:3]  # Show first 3 properties
                        print(
                            f"   Parameters: {', '.join(props)}{'...' if len(schema['properties']) > 3 else ''}"
                        )

                return True

        except Exception as e:
            print(f"\n❌ Failed to connect to MCP server: {e}")
            print("\nTroubleshooting tips:")
            print("1. Make sure the MCP server is installed and accessible")
            print("2. Check if the server command is correct")
            print("3. Ensure you have proper authentication/credentials configured")
            print("4. Try running the MCP server command directly to test it")
            return False

    async def load_data(self, args) -> List[str]:
        """Load Slack messages via MCP server."""
        print(f"Connecting to Slack MCP server: {args.mcp_server}")

        if args.workspace_name:
            print(f"Workspace: {args.workspace_name}")

        if args.channels:
            print(f"Channels: {', '.join(args.channels)}")
        else:
            print("Fetching from all available channels")

        concatenate = not args.no_concatenate_conversations
        print(
            f"Processing mode: {'Concatenated conversations' if concatenate else 'Individual messages'}"
        )

        try:
            reader = SlackMCPReader(
                mcp_server_command=args.mcp_server,
                workspace_name=args.workspace_name,
                concatenate_conversations=concatenate,
                max_messages_per_conversation=args.max_messages_per_channel,
            )

            texts = await reader.read_slack_data(channels=args.channels)

            if not texts:
                print("❌ No messages found! This could mean:")
                print("- The MCP server couldn't fetch messages")
                print("- The specified channels don't exist or are empty")
                print("- Authentication issues with the Slack workspace")
                return []

            print(f"✅ Successfully loaded {len(texts)} text chunks from Slack")

            # Show sample of what was loaded
            if texts:
                sample_text = texts[0][:200] + "..." if len(texts[0]) > 200 else texts[0]
                print("\nSample content:")
                print("-" * 40)
                print(sample_text)
                print("-" * 40)

            return texts

        except Exception as e:
            print(f"❌ Error loading Slack data: {e}")
            print("\nThis might be due to:")
            print("- MCP server connection issues")
            print("- Authentication problems")
            print("- Network connectivity issues")
            print("- Incorrect channel names")
            raise

    async def run(self):
        """Main entry point with MCP connection testing."""
        args = self.parser.parse_args()

        # Test connection if requested
        if args.test_connection:
            success = await self.test_mcp_connection(args)
            if not success:
                return
            print(
                "\n🎉 MCP server is working! You can now run without --test-connection to start indexing."
            )
            return

        # Run the standard RAG pipeline
        await super().run()


async def main():
    """Main entry point for the Slack MCP RAG application."""
    app = SlackMCPRAG()
    await app.run()


if __name__ == "__main__":
    asyncio.run(main())
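
# --- Example invocations (illustrative; drawn from the module docstring and flags above) ---
#   python -m apps.slack_rag --mcp-server "npx slack-mcp-server" --test-connection
#   python -m apps.slack_rag --mcp-server "npx slack-mcp-server" --channels general random \
#       --query "What did the team discuss about the project?"
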
@@ -1 +0,0 @@
# Twitter MCP data integration for LEANN
@@ -1,295 +0,0 @@
#!/usr/bin/env python3
"""
Twitter MCP Reader for LEANN

This module provides functionality to connect to Twitter MCP servers and fetch bookmark data
for indexing in LEANN. It supports various Twitter MCP server implementations and provides
flexible bookmark processing options.
"""

import asyncio
import json
import logging
from typing import Any, Dict, List, Optional

logger = logging.getLogger(__name__)


class TwitterMCPReader:
    """
    Reader for Twitter bookmark data via MCP (Model Context Protocol) servers.

    This class connects to Twitter MCP servers to fetch bookmark data and convert it
    into a format suitable for LEANN indexing.
    """

    def __init__(
        self,
        mcp_server_command: str,
        username: Optional[str] = None,
        include_tweet_content: bool = True,
        include_metadata: bool = True,
        max_bookmarks: int = 1000,
    ):
        """
        Initialize the Twitter MCP Reader.

        Args:
            mcp_server_command: Command to start the MCP server (e.g., 'twitter-mcp-server')
            username: Optional Twitter username to filter bookmarks
            include_tweet_content: Whether to include full tweet content
            include_metadata: Whether to include tweet metadata (likes, retweets, etc.)
            max_bookmarks: Maximum number of bookmarks to fetch
        """
        self.mcp_server_command = mcp_server_command
        self.username = username
        self.include_tweet_content = include_tweet_content
        self.include_metadata = include_metadata
        self.max_bookmarks = max_bookmarks
        self.mcp_process = None

    async def start_mcp_server(self):
        """Start the MCP server process."""
        try:
            self.mcp_process = await asyncio.create_subprocess_exec(
                *self.mcp_server_command.split(),
                stdin=asyncio.subprocess.PIPE,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
            logger.info(f"Started MCP server: {self.mcp_server_command}")
        except Exception as e:
            logger.error(f"Failed to start MCP server: {e}")
            raise

    async def stop_mcp_server(self):
        """Stop the MCP server process."""
        if self.mcp_process:
            self.mcp_process.terminate()
            await self.mcp_process.wait()
            logger.info("Stopped MCP server")

    async def send_mcp_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
        """Send a request to the MCP server and get response."""
        if not self.mcp_process:
            raise RuntimeError("MCP server not started")

        request_json = json.dumps(request) + "\n"
        self.mcp_process.stdin.write(request_json.encode())
        await self.mcp_process.stdin.drain()

        response_line = await self.mcp_process.stdout.readline()
        if not response_line:
            raise RuntimeError("No response from MCP server")

        return json.loads(response_line.decode().strip())

    async def initialize_mcp_connection(self):
        """Initialize the MCP connection."""
        init_request = {
            "jsonrpc": "2.0",
            "id": 1,
            "method": "initialize",
            "params": {
                "protocolVersion": "2024-11-05",
                "capabilities": {},
                "clientInfo": {"name": "leann-twitter-reader", "version": "1.0.0"},
            },
        }

        response = await self.send_mcp_request(init_request)
        if "error" in response:
            raise RuntimeError(f"MCP initialization failed: {response['error']}")

        logger.info("MCP connection initialized successfully")

    async def list_available_tools(self) -> List[Dict[str, Any]]:
        """List available tools from the MCP server."""
        list_request = {"jsonrpc": "2.0", "id": 2, "method": "tools/list", "params": {}}

        response = await self.send_mcp_request(list_request)
        if "error" in response:
            raise RuntimeError(f"Failed to list tools: {response['error']}")

        return response.get("result", {}).get("tools", [])

    async def fetch_twitter_bookmarks(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        Fetch Twitter bookmarks using MCP tools.

        Args:
            limit: Maximum number of bookmarks to fetch

        Returns:
            List of bookmark dictionaries
        """
        tools = await self.list_available_tools()
        bookmark_tool = None

        # Look for a tool that can fetch bookmarks
        for tool in tools:
            tool_name = tool.get("name", "").lower()
            if any(keyword in tool_name for keyword in ["bookmark", "saved", "favorite"]):
                bookmark_tool = tool
                break

        if not bookmark_tool:
            raise RuntimeError("No bookmark fetching tool found in MCP server")

        # Prepare tool call parameters
        tool_params = {}
        if limit or self.max_bookmarks:
            tool_params["limit"] = limit or self.max_bookmarks
        if self.username:
            tool_params["username"] = self.username

        fetch_request = {
            "jsonrpc": "2.0",
            "id": 3,
            "method": "tools/call",
            "params": {"name": bookmark_tool["name"], "arguments": tool_params},
        }

        response = await self.send_mcp_request(fetch_request)
        if "error" in response:
            raise RuntimeError(f"Failed to fetch bookmarks: {response['error']}")

        # Extract bookmarks from response
        result = response.get("result", {})
        if "content" in result and isinstance(result["content"], list):
            content = result["content"][0] if result["content"] else {}
            if "text" in content:
                try:
                    bookmarks = json.loads(content["text"])
                except json.JSONDecodeError:
                    # If not JSON, treat as plain text
                    bookmarks = [{"text": content["text"], "source": "twitter"}]
            else:
                bookmarks = result["content"]
        else:
            bookmarks = result.get("bookmarks", result.get("tweets", [result]))

        return bookmarks if isinstance(bookmarks, list) else [bookmarks]

    def _format_bookmark(self, bookmark: Dict[str, Any]) -> str:
        """Format a single bookmark for indexing."""
        # Extract tweet information
        text = bookmark.get("text", bookmark.get("content", ""))
        author = bookmark.get(
            "author", bookmark.get("username", bookmark.get("user", {}).get("username", "Unknown"))
        )
        timestamp = bookmark.get("created_at", bookmark.get("timestamp", ""))
        url = bookmark.get("url", bookmark.get("tweet_url", ""))

        # Extract metadata if available
        likes = bookmark.get("likes", bookmark.get("favorite_count", 0))
        retweets = bookmark.get("retweets", bookmark.get("retweet_count", 0))
        replies = bookmark.get("replies", bookmark.get("reply_count", 0))

        # Build formatted bookmark
        parts = []

        # Header
        parts.append("=== Twitter Bookmark ===")

        if author:
            parts.append(f"Author: @{author}")

        if timestamp:
            # Format timestamp if it's a standard format
            try:
                import datetime

                if "T" in str(timestamp):  # ISO format
                    dt = datetime.datetime.fromisoformat(timestamp.replace("Z", "+00:00"))
                    formatted_time = dt.strftime("%Y-%m-%d %H:%M:%S")
                else:
                    formatted_time = str(timestamp)
                parts.append(f"Date: {formatted_time}")
            except (ValueError, TypeError):
                parts.append(f"Date: {timestamp}")

        if url:
            parts.append(f"URL: {url}")

        # Tweet content
        if text and self.include_tweet_content:
            parts.append("")
            parts.append("Content:")
            parts.append(text)

        # Metadata
        if self.include_metadata and any([likes, retweets, replies]):
            parts.append("")
            parts.append("Engagement:")
            if likes:
                parts.append(f"  Likes: {likes}")
            if retweets:
                parts.append(f"  Retweets: {retweets}")
            if replies:
                parts.append(f"  Replies: {replies}")

        # Extract hashtags and mentions if available
        hashtags = bookmark.get("hashtags", [])
        mentions = bookmark.get("mentions", [])

        if hashtags or mentions:
            parts.append("")
            if hashtags:
                parts.append(f"Hashtags: {', '.join(hashtags)}")
            if mentions:
                parts.append(f"Mentions: {', '.join(mentions)}")

        return "\n".join(parts)

    async def read_twitter_bookmarks(self) -> List[str]:
        """
        Read Twitter bookmark data and return formatted text chunks.

        Returns:
            List of formatted text chunks ready for LEANN indexing
        """
        try:
            await self.start_mcp_server()
            await self.initialize_mcp_connection()

            print(f"Fetching up to {self.max_bookmarks} bookmarks...")
            if self.username:
                print(f"Filtering for user: @{self.username}")

            bookmarks = await self.fetch_twitter_bookmarks()

            if not bookmarks:
                print("No bookmarks found")
                return []

            print(f"Processing {len(bookmarks)} bookmarks...")

            all_texts = []
            processed_count = 0

            for bookmark in bookmarks:
                try:
                    formatted_bookmark = self._format_bookmark(bookmark)
                    if formatted_bookmark.strip():
                        all_texts.append(formatted_bookmark)
                        processed_count += 1
                except Exception as e:
                    logger.warning(f"Failed to format bookmark: {e}")
                    continue

            print(f"Successfully processed {processed_count} bookmarks")
            return all_texts

        finally:
            await self.stop_mcp_server()

    async def __aenter__(self):
        """Async context manager entry."""
        await self.start_mcp_server()
        await self.initialize_mcp_connection()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit."""
        await self.stop_mcp_server()
@@ -1,190 +0,0 @@
#!/usr/bin/env python3
"""
Twitter RAG Application with MCP Support

This application enables RAG (Retrieval-Augmented Generation) on Twitter bookmarks
by connecting to Twitter MCP servers to fetch live data and index it in LEANN.

Usage:
    python -m apps.twitter_rag --mcp-server "twitter-mcp-server" --query "What articles did I bookmark about AI?"
"""

import argparse
import asyncio
from pathlib import Path
from typing import List

from apps.base_rag_example import BaseRAGExample
from apps.twitter_data.twitter_mcp_reader import TwitterMCPReader


class TwitterMCPRAG(BaseRAGExample):
    """
    RAG application for Twitter bookmarks via MCP servers.

    This class provides a complete RAG pipeline for Twitter bookmark data, including
    MCP server connection, data fetching, indexing, and interactive chat.
    """

    def __init__(self):
        super().__init__()
        self.default_index_name = "twitter_bookmarks"

    def _add_specific_arguments(self, parser: argparse.ArgumentParser):
        """Add Twitter MCP-specific arguments."""
        parser.add_argument(
            "--mcp-server",
            type=str,
            required=True,
            help="Command to start the Twitter MCP server (e.g., 'twitter-mcp-server' or 'npx twitter-mcp-server')"
        )
        parser.add_argument(
            "--username",
            type=str,
            help="Twitter username to filter bookmarks (without @)"
        )
        parser.add_argument(
            "--max-bookmarks",
            type=int,
            default=1000,
            help="Maximum number of bookmarks to fetch (default: 1000)"
        )
        parser.add_argument(
            "--no-tweet-content",
            action="store_true",
            help="Exclude tweet content, only include metadata"
        )
        parser.add_argument(
            "--no-metadata",
            action="store_true",
            help="Exclude engagement metadata (likes, retweets, etc.)"
        )
        parser.add_argument(
            "--test-connection",
            action="store_true",
            help="Test MCP server connection and list available tools without indexing"
        )

    async def test_mcp_connection(self, args) -> bool:
        """Test the MCP server connection and display available tools."""
        print(f"Testing connection to MCP server: {args.mcp_server}")

        try:
            reader = TwitterMCPReader(
                mcp_server_command=args.mcp_server,
                username=args.username,
                include_tweet_content=not args.no_tweet_content,
                include_metadata=not args.no_metadata,
                max_bookmarks=args.max_bookmarks,
            )

            async with reader:
                tools = await reader.list_available_tools()

                print(f"\n✅ Successfully connected to MCP server!")
                print(f"Available tools ({len(tools)}):")

                for i, tool in enumerate(tools, 1):
                    name = tool.get("name", "Unknown")
                    description = tool.get("description", "No description available")
                    print(f"\n{i}. {name}")
                    print(f"   Description: {description[:100]}{'...' if len(description) > 100 else ''}")

                    # Show input schema if available
                    schema = tool.get("inputSchema", {})
                    if schema.get("properties"):
                        props = list(schema["properties"].keys())[:3]  # Show first 3 properties
                        print(f"   Parameters: {', '.join(props)}{'...' if len(schema['properties']) > 3 else ''}")

            return True

        except Exception as e:
            print(f"\n❌ Failed to connect to MCP server: {e}")
            print("\nTroubleshooting tips:")
            print("1. Make sure the Twitter MCP server is installed and accessible")
            print("2. Check if the server command is correct")
            print("3. Ensure you have proper Twitter API credentials configured")
            print("4. Verify your Twitter account has bookmarks to fetch")
            print("5. Try running the MCP server command directly to test it")
            return False

    async def load_data(self, args) -> List[str]:
        """Load Twitter bookmarks via MCP server."""
        print(f"Connecting to Twitter MCP server: {args.mcp_server}")

        if args.username:
            print(f"Username filter: @{args.username}")

        print(f"Max bookmarks: {args.max_bookmarks}")
        print(f"Include tweet content: {not args.no_tweet_content}")
        print(f"Include metadata: {not args.no_metadata}")

        try:
            reader = TwitterMCPReader(
                mcp_server_command=args.mcp_server,
                username=args.username,
                include_tweet_content=not args.no_tweet_content,
                include_metadata=not args.no_metadata,
                max_bookmarks=args.max_bookmarks,
            )

            texts = await reader.read_twitter_bookmarks()

            if not texts:
                print("❌ No bookmarks found! This could mean:")
                print("- You don't have any bookmarks on Twitter")
                print("- The MCP server couldn't access your bookmarks")
                print("- Authentication issues with Twitter API")
                print("- The username filter didn't match any bookmarks")
                return []

            print(f"✅ Successfully loaded {len(texts)} bookmarks from Twitter")

            # Show sample of what was loaded
            if texts:
                sample_text = texts[0][:300] + "..." if len(texts[0]) > 300 else texts[0]
                print(f"\nSample bookmark:")
                print("-" * 50)
                print(sample_text)
                print("-" * 50)

            return texts

        except Exception as e:
            print(f"❌ Error loading Twitter bookmarks: {e}")
            print("\nThis might be due to:")
            print("- MCP server connection issues")
            print("- Twitter API authentication problems")
            print("- Network connectivity issues")
            print("- Rate limiting from Twitter API")
            raise

    async def run(self):
        """Main entry point with MCP connection testing."""
        args = self.parser.parse_args()

        # Test connection if requested
        if args.test_connection:
            success = await self.test_mcp_connection(args)
            if not success:
                return
            print(f"\n🎉 MCP server is working! You can now run without --test-connection to start indexing.")
            return

        # Run the standard RAG pipeline
        await super().run()


async def main():
    """Main entry point for the Twitter MCP RAG application."""
    app = TwitterMCPRAG()
    await app.run()


if __name__ == "__main__":
    asyncio.run(main())
@@ -83,6 +83,81 @@ ollama pull nomic-embed-text

</details>

## Local & Remote Inference Endpoints

> Applies to both LLMs (`leann ask`) and embeddings (`leann build`).

LEANN now treats Ollama, LM Studio, and other OpenAI-compatible runtimes as first-class providers. You can point LEANN at any compatible endpoint – either on the same machine or across the network – with a couple of flags or environment variables.

### One-Time Environment Setup

```bash
# Works for OpenAI-compatible runtimes such as LM Studio, vLLM, SGLang, llamafile, etc.
export OPENAI_API_KEY="your-key"                    # or leave unset for local servers that do not check keys
export OPENAI_BASE_URL="http://localhost:1234/v1"

# Ollama-compatible runtimes (Ollama, Ollama on another host, llamacpp-server, etc.)
export LEANN_OLLAMA_HOST="http://localhost:11434"   # falls back to OLLAMA_HOST or LOCAL_LLM_ENDPOINT
```

LEANN also recognises `LEANN_LOCAL_LLM_HOST` (highest priority), `LEANN_OPENAI_BASE_URL`, and `LOCAL_OPENAI_BASE_URL`, so existing scripts continue to work.

### Passing Hosts Per Command

```bash
# Build an index with a remote embedding server
leann build my-notes \
  --docs ./notes \
  --embedding-mode openai \
  --embedding-model text-embedding-qwen3-embedding-0.6b \
  --embedding-api-base http://192.168.1.50:1234/v1 \
  --embedding-api-key local-dev-key

# Query using a local LM Studio instance via OpenAI-compatible API
leann ask my-notes \
  --llm openai \
  --llm-model qwen3-8b \
  --api-base http://localhost:1234/v1 \
  --api-key local-dev-key

# Query an Ollama instance running on another box
leann ask my-notes \
  --llm ollama \
  --llm-model qwen3:14b \
  --host http://192.168.1.101:11434
```

⚠️ **Make sure the endpoint is reachable**: when your inference server runs on a home/workstation and the index/search job runs in the cloud, the server must be able to reach the host you configured. Typical options include:

- Expose a public IP (and open the relevant port) on the machine that hosts LM Studio/Ollama.
- Configure router or cloud provider port forwarding.
- Tunnel traffic through tools like `tailscale`, `cloudflared`, or `ssh -R` (a minimal sketch follows this list).
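For the tunnelling option, a minimal `ssh -R` sketch (host names and addresses are illustrative assumptions, not part of the LEANN docs): run the tunnel from the workstation that hosts Ollama, then let the cloud job talk to its own `localhost`.

```bash
# Illustrative only: forward the cloud host's port 11434 back to the local Ollama.
ssh -N -R 11434:localhost:11434 user@cloud-host

# On the cloud host, LEANN then uses the tunnelled endpoint:
export LEANN_OLLAMA_HOST="http://localhost:11434"
```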
When you set these options while building an index, LEANN stores them in `meta.json`. Any subsequent `leann ask` or searcher process automatically reuses the same provider settings – even when we spawn background embedding servers. This makes the “server without GPU talking to my local workstation” workflow from [issue #80](https://github.com/yichuan-w/LEANN/issues/80#issuecomment-2287230548) work out-of-the-box.
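For illustration, a trimmed `meta.json` excerpt might contain entries like the following (the field names match the build options above; the exact file contents depend on your build and are not shown here):

```json
{
  "embedding_model": "text-embedding-qwen3-embedding-0.6b",
  "embedding_mode": "openai",
  "embedding_options": {
    "base_url": "http://192.168.1.50:1234/v1",
    "api_key": "local-dev-key"
  }
}
```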
**Tip:** If your runtime does not require an API key (many local stacks don’t), leave `--api-key` unset. LEANN will skip injecting credentials.

### Python API Usage

You can pass the same configuration from Python:

```python
from leann.api import LeannBuilder

builder = LeannBuilder(
    backend_name="hnsw",
    embedding_mode="openai",
    embedding_model="text-embedding-qwen3-embedding-0.6b",
    embedding_options={
        "base_url": "http://192.168.1.50:1234/v1",
        "api_key": "local-dev-key",
    },
)
builder.build_index("./indexes/my-notes", chunks)
```

`embedding_options` is persisted to the index `meta.json`, so subsequent `LeannSearcher` or `LeannChat` sessions automatically reuse the same provider settings (the embedding server manager forwards them to the provider for you).
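As a rough sketch of that reuse (assuming the index built above exists and that `LeannSearcher` exposes a `search(query, top_k=...)` method; the query string is illustrative), no provider settings need to be repeated at query time:

```python
from leann.api import LeannSearcher

# embedding_options are read back from the index's meta.json, so the remote
# endpoint configured at build time is reused here without extra flags.
searcher = LeannSearcher("./indexes/my-notes")
results = searcher.search("Which endpoint computes the embeddings?", top_k=3)
```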
## Index Selection: Matching Your Scale

### HNSW (Hierarchical Navigable Small World)
@@ -1,181 +0,0 @@
#!/usr/bin/env python3
"""
MCP Integration Examples for LEANN

This script demonstrates how to use LEANN with different MCP servers for
RAG on various platforms like Slack and Twitter.

Examples:
    1. Slack message RAG via MCP
    2. Twitter bookmark RAG via MCP
    3. Testing MCP server connections
"""

import asyncio
import sys
from pathlib import Path

# Add the parent directory to the path so we can import from apps
sys.path.append(str(Path(__file__).parent.parent))

from apps.slack_rag import SlackMCPRAG
from apps.twitter_rag import TwitterMCPRAG


async def demo_slack_mcp():
    """Demonstrate Slack MCP integration."""
    print("=" * 60)
    print("🔥 Slack MCP RAG Demo")
    print("=" * 60)

    print("\n1. Testing Slack MCP server connection...")

    # This would typically use a real MCP server command
    # For demo purposes, we show what the command would look like
    slack_app = SlackMCPRAG()

    # Simulate command line arguments for testing
    class MockArgs:
        mcp_server = "slack-mcp-server"  # This would be the actual MCP server command
        workspace_name = "my-workspace"
        channels = ["general", "random", "dev-team"]
        no_concatenate_conversations = False
        max_messages_per_channel = 50
        test_connection = True

    print(f"MCP Server Command: {MockArgs.mcp_server}")
    print(f"Workspace: {MockArgs.workspace_name}")
    print(f"Channels: {', '.join(MockArgs.channels)}")

    # In a real scenario, you would run:
    # success = await slack_app.test_mcp_connection(MockArgs)

    print("\n📝 Example usage:")
    print("python -m apps.slack_rag \\")
    print("  --mcp-server 'slack-mcp-server' \\")
    print("  --workspace-name 'my-team' \\")
    print("  --channels general dev-team \\")
    print("  --test-connection")

    print("\n🔍 After indexing, you could query:")
    print("- 'What did the team discuss about the project deadline?'")
    print("- 'Find messages about the new feature launch'")
    print("- 'Show me conversations about budget planning'")


async def demo_twitter_mcp():
    """Demonstrate Twitter MCP integration."""
    print("\n" + "=" * 60)
    print("🐦 Twitter MCP RAG Demo")
    print("=" * 60)

    print("\n1. Testing Twitter MCP server connection...")

    twitter_app = TwitterMCPRAG()

    class MockArgs:
        mcp_server = "twitter-mcp-server"
        username = None  # Fetch all bookmarks
        max_bookmarks = 500
        no_tweet_content = False
        no_metadata = False
        test_connection = True

    print(f"MCP Server Command: {MockArgs.mcp_server}")
    print(f"Max Bookmarks: {MockArgs.max_bookmarks}")
    print(f"Include Content: {not MockArgs.no_tweet_content}")
    print(f"Include Metadata: {not MockArgs.no_metadata}")

    print("\n📝 Example usage:")
    print("python -m apps.twitter_rag \\")
    print("  --mcp-server 'twitter-mcp-server' \\")
    print("  --max-bookmarks 1000 \\")
    print("  --test-connection")

    print("\n🔍 After indexing, you could query:")
    print("- 'What AI articles did I bookmark last month?'")
    print("- 'Find tweets about machine learning techniques'")
    print("- 'Show me bookmarked threads about startup advice'")


async def show_mcp_server_setup():
    """Show how to set up MCP servers."""
    print("\n" + "=" * 60)
    print("⚙️ MCP Server Setup Guide")
    print("=" * 60)

    print("\n🔧 Setting up Slack MCP Server:")
    print("1. Install a Slack MCP server (example commands):")
    print("   npm install -g slack-mcp-server")
    print("   # OR")
    print("   pip install slack-mcp-server")

    print("\n2. Configure Slack credentials:")
    print("   export SLACK_BOT_TOKEN='xoxb-your-bot-token'")
    print("   export SLACK_APP_TOKEN='xapp-your-app-token'")

    print("\n3. Test the server:")
    print("   slack-mcp-server --help")

    print("\n🔧 Setting up Twitter MCP Server:")
    print("1. Install a Twitter MCP server:")
    print("   npm install -g twitter-mcp-server")
    print("   # OR")
    print("   pip install twitter-mcp-server")

    print("\n2. Configure Twitter API credentials:")
    print("   export TWITTER_API_KEY='your-api-key'")
    print("   export TWITTER_API_SECRET='your-api-secret'")
    print("   export TWITTER_ACCESS_TOKEN='your-access-token'")
    print("   export TWITTER_ACCESS_TOKEN_SECRET='your-access-token-secret'")

    print("\n3. Test the server:")
    print("   twitter-mcp-server --help")


async def show_integration_benefits():
    """Show the benefits of MCP integration."""
    print("\n" + "=" * 60)
    print("🌟 Benefits of MCP Integration")
    print("=" * 60)

    benefits = [
        ("🔄 Live Data Access", "Fetch real-time data from platforms without manual exports"),
        ("🔌 Standardized Protocol", "Use any MCP-compatible server with minimal code changes"),
        ("🚀 Easy Extension", "Add new platforms by implementing MCP readers"),
        ("🔒 Secure Access", "MCP servers handle authentication and API management"),
        ("📊 Rich Metadata", "Access full platform metadata (timestamps, engagement, etc.)"),
        ("⚡ Efficient Processing", "Stream data directly into LEANN without intermediate files"),
    ]

    for title, description in benefits:
        print(f"\n{title}")
        print(f"  {description}")


async def main():
    """Main demo function."""
    print("🎯 LEANN MCP Integration Examples")
    print("This demo shows how to integrate LEANN with MCP servers for various platforms.")

    await demo_slack_mcp()
    await demo_twitter_mcp()
    await show_mcp_server_setup()
    await show_integration_benefits()

    print("\n" + "=" * 60)
    print("✨ Next Steps")
    print("=" * 60)
    print("1. Install and configure MCP servers for your platforms")
    print("2. Test connections using --test-connection flag")
    print("3. Run indexing to build your RAG knowledge base")
    print("4. Start querying your personal data!")

    print("\n📚 For more information:")
    print("- Check the README for detailed setup instructions")
    print("- Look at the apps/slack_rag.py and apps/twitter_rag.py for implementation details")
    print("- Explore other MCP servers for additional platforms")


if __name__ == "__main__":
    asyncio.run(main())
@@ -10,7 +10,7 @@ import sys
 import threading
 import time
 from pathlib import Path
-from typing import Optional
+from typing import Any, Optional

 import numpy as np
 import zmq
@@ -32,6 +32,16 @@ if not logger.handlers:
     logger.propagate = False


+_RAW_PROVIDER_OPTIONS = os.getenv("LEANN_EMBEDDING_OPTIONS")
+try:
+    PROVIDER_OPTIONS: dict[str, Any] = (
+        json.loads(_RAW_PROVIDER_OPTIONS) if _RAW_PROVIDER_OPTIONS else {}
+    )
+except json.JSONDecodeError:
+    logger.warning("Failed to parse LEANN_EMBEDDING_OPTIONS; ignoring provider options")
+    PROVIDER_OPTIONS = {}
+
+
 def create_diskann_embedding_server(
     passages_file: Optional[str] = None,
     zmq_port: int = 5555,
@@ -181,7 +191,12 @@ def create_diskann_embedding_server(
             logger.debug(f"Text lengths: {[len(t) for t in texts[:5]]}")  # Show first 5

             # Process embeddings using unified computation
-            embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
+            embeddings = compute_embeddings(
+                texts,
+                model_name,
+                mode=embedding_mode,
+                provider_options=PROVIDER_OPTIONS,
+            )
             logger.info(
                 f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
             )
@@ -296,7 +311,12 @@ def create_diskann_embedding_server(
                 continue

             # Process the request
-            embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
+            embeddings = compute_embeddings(
+                texts,
+                model_name,
+                mode=embedding_mode,
+                provider_options=PROVIDER_OPTIONS,
+            )
             logger.info(f"Computed embeddings shape: {embeddings.shape}")

             # Validation
@@ -10,7 +10,7 @@ import sys
 import threading
 import time
 from pathlib import Path
-from typing import Optional
+from typing import Any, Optional

 import msgpack
 import numpy as np
@@ -45,6 +45,15 @@ if log_path:

     logger.propagate = False

+_RAW_PROVIDER_OPTIONS = os.getenv("LEANN_EMBEDDING_OPTIONS")
+try:
+    PROVIDER_OPTIONS: dict[str, Any] = (
+        json.loads(_RAW_PROVIDER_OPTIONS) if _RAW_PROVIDER_OPTIONS else {}
+    )
+except json.JSONDecodeError:
+    logger.warning("Failed to parse LEANN_EMBEDDING_OPTIONS; ignoring provider options")
+    PROVIDER_OPTIONS = {}
+

 def create_hnsw_embedding_server(
     passages_file: Optional[str] = None,
@@ -151,7 +160,12 @@ def create_hnsw_embedding_server(
             ):
                 last_request_type = "text"
                 last_request_length = len(request)
-                embeddings = compute_embeddings(request, model_name, mode=embedding_mode)
+                embeddings = compute_embeddings(
+                    request,
+                    model_name,
+                    mode=embedding_mode,
+                    provider_options=PROVIDER_OPTIONS,
+                )
                 rep_socket.send(msgpack.packb(embeddings.tolist()))
                 e2e_end = time.time()
                 logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
@@ -200,7 +214,10 @@ def create_hnsw_embedding_server(
             if texts:
                 try:
                     embeddings = compute_embeddings(
-                        texts, model_name, mode=embedding_mode
+                        texts,
+                        model_name,
+                        mode=embedding_mode,
+                        provider_options=PROVIDER_OPTIONS,
                     )
                     logger.info(
                         f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
@@ -265,7 +282,12 @@ def create_hnsw_embedding_server(

             if texts:
                 try:
-                    embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
+                    embeddings = compute_embeddings(
+                        texts,
+                        model_name,
+                        mode=embedding_mode,
+                        provider_options=PROVIDER_OPTIONS,
+                    )
                     logger.info(
                         f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
                     )
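Both embedding servers above read these provider options from the `LEANN_EMBEDDING_OPTIONS` environment variable as a JSON object. A minimal illustrative export (the values are placeholders, only the variable name comes from the diff) would be:

```bash
export LEANN_EMBEDDING_OPTIONS='{"base_url": "http://192.168.1.50:1234/v1", "api_key": "local-dev-key"}'
```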
@@ -39,6 +39,7 @@ def compute_embeddings(
     use_server: bool = True,
     port: Optional[int] = None,
     is_build=False,
+    provider_options: Optional[dict[str, Any]] = None,
 ) -> np.ndarray:
     """
     Computes embeddings using different backends.
@@ -72,6 +73,7 @@ def compute_embeddings(
             model_name,
             mode=mode,
             is_build=is_build,
+            provider_options=provider_options,
         )


@@ -278,6 +280,7 @@ class LeannBuilder:
         embedding_model: str = "facebook/contriever",
         dimensions: Optional[int] = None,
         embedding_mode: str = "sentence-transformers",
+        embedding_options: Optional[dict[str, Any]] = None,
         **backend_kwargs,
     ):
         self.backend_name = backend_name
@@ -300,6 +303,7 @@ class LeannBuilder:
         self.embedding_model = embedding_model
         self.dimensions = dimensions
         self.embedding_mode = embedding_mode
+        self.embedding_options = embedding_options or {}

         # Check if we need to use cosine distance for normalized embeddings
         normalized_embeddings_models = {
@@ -407,6 +411,7 @@ class LeannBuilder:
                 self.embedding_model,
                 self.embedding_mode,
                 use_server=False,
+                provider_options=self.embedding_options,
             )[0]
         )
         path = Path(index_path)
@@ -446,6 +451,7 @@ class LeannBuilder:
             self.embedding_mode,
             use_server=False,
             is_build=True,
+            provider_options=self.embedding_options,
         )
         string_ids = [chunk["id"] for chunk in self.chunks]
         current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
@@ -472,6 +478,9 @@ class LeannBuilder:
             ],
         }

+        if self.embedding_options:
+            meta_data["embedding_options"] = self.embedding_options
+
         # Add storage status flags for HNSW backend
         if self.backend_name == "hnsw":
             is_compact = self.backend_kwargs.get("is_compact", True)
@@ -592,6 +601,9 @@ class LeannBuilder:
             "embeddings_source": str(embeddings_file),
         }

+        if self.embedding_options:
+            meta_data["embedding_options"] = self.embedding_options
+
         # Add storage status flags for HNSW backend
         if self.backend_name == "hnsw":
             is_compact = self.backend_kwargs.get("is_compact", True)
@@ -673,6 +685,7 @@ class LeannBuilder:
             self.embedding_mode,
             use_server=False,
             is_build=True,
+            provider_options=self.embedding_options,
         )

         embedding_dim = embeddings.shape[1]
@@ -771,6 +784,7 @@ class LeannSearcher:
         self.embedding_model = self.meta_data["embedding_model"]
         # Support both old and new format
         self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers")
+        self.embedding_options = self.meta_data.get("embedding_options", {})
         # Delegate portability handling to PassageManager
         self.passage_manager = PassageManager(
             self.meta_data.get("passage_sources", []), metadata_file_path=self.meta_path_str
@@ -782,6 +796,8 @@ class LeannSearcher:
             raise ValueError(f"Backend '{backend_name}' not found.")
         final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
         final_kwargs["enable_warmup"] = enable_warmup
+        if self.embedding_options:
+            final_kwargs.setdefault("embedding_options", self.embedding_options)
         self.backend_impl: LeannBackendSearcherInterface = backend_factory.searcher(
             index_path, **final_kwargs
         )
@@ -12,6 +12,8 @@ from typing import Any, Optional

 import torch

+from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
+
 # Configure logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -310,11 +312,12 @@ def search_hf_models(query: str, limit: int = 10) -> list[str]:


 def validate_model_and_suggest(
-    model_name: str, llm_type: str, host: str = "http://localhost:11434"
+    model_name: str, llm_type: str, host: Optional[str] = None
 ) -> Optional[str]:
     """Validate model name and provide suggestions if invalid"""
     if llm_type == "ollama":
-        available_models = check_ollama_models(host)
+        resolved_host = resolve_ollama_host(host)
+        available_models = check_ollama_models(resolved_host)
         if available_models and model_name not in available_models:
             error_msg = f"Model '{model_name}' not found in your local Ollama installation."

@@ -457,19 +460,19 @@ class LLMInterface(ABC):
 class OllamaChat(LLMInterface):
     """LLM interface for Ollama models."""

-    def __init__(self, model: str = "llama3:8b", host: str = "http://localhost:11434"):
+    def __init__(self, model: str = "llama3:8b", host: Optional[str] = None):
         self.model = model
-        self.host = host
-        logger.info(f"Initializing OllamaChat with model='{model}' and host='{host}'")
+        self.host = resolve_ollama_host(host)
+        logger.info(f"Initializing OllamaChat with model='{model}' and host='{self.host}'")
         try:
             import requests

             # Check if the Ollama server is responsive
-            if host:
-                requests.get(host)
+            if self.host:
+                requests.get(self.host)

             # Pre-check model availability with helpful suggestions
-            model_error = validate_model_and_suggest(model, "ollama", host)
+            model_error = validate_model_and_suggest(model, "ollama", self.host)
             if model_error:
                 raise ValueError(model_error)

@@ -478,9 +481,11 @@ class OllamaChat(LLMInterface):
                 "The 'requests' library is required for Ollama. Please install it with 'pip install requests'."
             )
         except requests.exceptions.ConnectionError:
-            logger.error(f"Could not connect to Ollama at {host}. Please ensure Ollama is running.")
+            logger.error(
+                f"Could not connect to Ollama at {self.host}. Please ensure Ollama is running."
+            )
             raise ConnectionError(
-                f"Could not connect to Ollama at {host}. Please ensure Ollama is running."
+                f"Could not connect to Ollama at {self.host}. Please ensure Ollama is running."
             )

     def ask(self, prompt: str, **kwargs) -> str:
@@ -737,21 +742,31 @@ class GeminiChat(LLMInterface):
 class OpenAIChat(LLMInterface):
     """LLM interface for OpenAI models."""

-    def __init__(self, model: str = "gpt-4o", api_key: Optional[str] = None):
+    def __init__(
+        self,
+        model: str = "gpt-4o",
+        api_key: Optional[str] = None,
+        base_url: Optional[str] = None,
+    ):
         self.model = model
-        self.api_key = api_key or os.getenv("OPENAI_API_KEY")
+        self.base_url = resolve_openai_base_url(base_url)
+        self.api_key = resolve_openai_api_key(api_key)
+
         if not self.api_key:
             raise ValueError(
                 "OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter."
             )

-        logger.info(f"Initializing OpenAI Chat with model='{model}'")
+        logger.info(
+            "Initializing OpenAI Chat with model='%s' and base_url='%s'",
+            model,
+            self.base_url,
+        )

         try:
             import openai

-            self.client = openai.OpenAI(api_key=self.api_key)
+            self.client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url)
         except ImportError:
             raise ImportError(
                 "The 'openai' library is required for OpenAI models. Please install it with 'pip install openai'."
@@ -841,12 +856,16 @@ def get_llm(llm_config: Optional[dict[str, Any]] = None) -> LLMInterface:
     if llm_type == "ollama":
         return OllamaChat(
             model=model or "llama3:8b",
-            host=llm_config.get("host", "http://localhost:11434"),
+            host=llm_config.get("host"),
         )
     elif llm_type == "hf":
         return HFChat(model_name=model or "deepseek-ai/deepseek-llm-7b-chat")
     elif llm_type == "openai":
-        return OpenAIChat(model=model or "gpt-4o", api_key=llm_config.get("api_key"))
+        return OpenAIChat(
+            model=model or "gpt-4o",
+            api_key=llm_config.get("api_key"),
+            base_url=llm_config.get("base_url"),
+        )
     elif llm_type == "gemini":
         return GeminiChat(model=model or "gemini-2.5-flash", api_key=llm_config.get("api_key"))
     elif llm_type == "simulated":
@@ -9,6 +9,7 @@ from tqdm import tqdm

 from .api import LeannBuilder, LeannChat, LeannSearcher
 from .registry import register_project_directory
+from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url


 def extract_pdf_text_with_pymupdf(file_path: str) -> str:
@@ -123,6 +124,24 @@ Examples:
             choices=["sentence-transformers", "openai", "mlx", "ollama"],
             help="Embedding backend mode (default: sentence-transformers)",
         )
+        build_parser.add_argument(
+            "--embedding-host",
+            type=str,
+            default=None,
+            help="Override Ollama-compatible embedding host",
+        )
+        build_parser.add_argument(
+            "--embedding-api-base",
+            type=str,
+            default=None,
+            help="Base URL for OpenAI-compatible embedding services",
+        )
+        build_parser.add_argument(
+            "--embedding-api-key",
+            type=str,
+            default=None,
+            help="API key for embedding service (defaults to OPENAI_API_KEY)",
+        )
         build_parser.add_argument(
             "--force", "-f", action="store_true", help="Force rebuild existing index"
         )
@@ -238,6 +257,11 @@ Examples:
         # Ask command
         ask_parser = subparsers.add_parser("ask", help="Ask questions")
         ask_parser.add_argument("index_name", help="Index name")
+        ask_parser.add_argument(
+            "query",
+            nargs="?",
+            help="Question to ask (omit for prompt or when using --interactive)",
+        )
         ask_parser.add_argument(
             "--llm",
             type=str,
@@ -248,7 +272,12 @@ Examples:
         ask_parser.add_argument(
             "--model", type=str, default="qwen3:8b", help="Model name (default: qwen3:8b)"
         )
-        ask_parser.add_argument("--host", type=str, default="http://localhost:11434")
+        ask_parser.add_argument(
+            "--host",
+            type=str,
+            default=None,
+            help="Override Ollama-compatible host (defaults to LEANN_OLLAMA_HOST/OLLAMA_HOST)",
+        )
         ask_parser.add_argument(
             "--interactive", "-i", action="store_true", help="Interactive chat mode"
         )
@@ -277,6 +306,18 @@ Examples:
             default=None,
             help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.",
         )
+        ask_parser.add_argument(
+            "--api-base",
+            type=str,
+            default=None,
+            help="Base URL for OpenAI-compatible APIs (e.g., http://localhost:10000/v1)",
+        )
+        ask_parser.add_argument(
+            "--api-key",
+            type=str,
+            default=None,
+            help="API key for OpenAI-compatible APIs (defaults to OPENAI_API_KEY)",
+        )

         # List command
         subparsers.add_parser("list", help="List all indexes")
@@ -1325,10 +1366,20 @@ Examples:

         print(f"Building index '{index_name}' with {args.backend} backend...")

+        embedding_options: dict[str, Any] = {}
+        if args.embedding_mode == "ollama":
+            embedding_options["host"] = resolve_ollama_host(args.embedding_host)
+        elif args.embedding_mode == "openai":
+            embedding_options["base_url"] = resolve_openai_base_url(args.embedding_api_base)
+            resolved_embedding_key = resolve_openai_api_key(args.embedding_api_key)
+            if resolved_embedding_key:
+                embedding_options["api_key"] = resolved_embedding_key
+
         builder = LeannBuilder(
             backend_name=args.backend,
             embedding_model=args.embedding_model,
             embedding_mode=args.embedding_mode,
+            embedding_options=embedding_options or None,
             graph_degree=args.graph_degree,
             complexity=args.complexity,
             is_compact=args.compact,
@@ -1476,11 +1527,38 @@ Examples:

         llm_config = {"type": args.llm, "model": args.model}
         if args.llm == "ollama":
-            llm_config["host"] = args.host
+            llm_config["host"] = resolve_ollama_host(args.host)
+        elif args.llm == "openai":
+            llm_config["base_url"] = resolve_openai_base_url(args.api_base)
+            resolved_api_key = resolve_openai_api_key(args.api_key)
+            if resolved_api_key:
+                llm_config["api_key"] = resolved_api_key

         chat = LeannChat(index_path=index_path, llm_config=llm_config)

+        llm_kwargs: dict[str, Any] = {}
+        if args.thinking_budget:
+            llm_kwargs["thinking_budget"] = args.thinking_budget
+
+        def _ask_once(prompt: str) -> None:
+            response = chat.ask(
+                prompt,
+                top_k=args.top_k,
+                complexity=args.complexity,
+                beam_width=args.beam_width,
+                prune_ratio=args.prune_ratio,
+                recompute_embeddings=args.recompute_embeddings,
+                pruning_strategy=args.pruning_strategy,
+                llm_kwargs=llm_kwargs,
+            )
+            print(f"LEANN: {response}")
+
+        initial_query = (args.query or "").strip()
+
         if args.interactive:
+            if initial_query:
+                _ask_once(initial_query)
+
             print("LEANN Assistant ready! Type 'quit' to exit")
             print("=" * 40)

@@ -1493,41 +1571,14 @@ Examples:
                 if not user_input:
                     continue

-                # Prepare LLM kwargs with thinking budget if specified
-                llm_kwargs = {}
-                if args.thinking_budget:
-                    llm_kwargs["thinking_budget"] = args.thinking_budget
-
-                response = chat.ask(
-                    user_input,
-                    top_k=args.top_k,
-                    complexity=args.complexity,
-                    beam_width=args.beam_width,
-                    prune_ratio=args.prune_ratio,
-                    recompute_embeddings=args.recompute_embeddings,
-                    pruning_strategy=args.pruning_strategy,
-                    llm_kwargs=llm_kwargs,
-                )
-                print(f"LEANN: {response}")
+                _ask_once(user_input)
         else:
-            query = input("Enter your question: ").strip()
-            if query:
-                # Prepare LLM kwargs with thinking budget if specified
-                llm_kwargs = {}
-                if args.thinking_budget:
-                    llm_kwargs["thinking_budget"] = args.thinking_budget
-
-                response = chat.ask(
-                    query,
-                    top_k=args.top_k,
-                    complexity=args.complexity,
-                    beam_width=args.beam_width,
-                    prune_ratio=args.prune_ratio,
-                    recompute_embeddings=args.recompute_embeddings,
-                    pruning_strategy=args.pruning_strategy,
-                    llm_kwargs=llm_kwargs,
-                )
-                print(f"LEANN: {response}")
+            query = initial_query or input("Enter your question: ").strip()
+            if not query:
+                print("No question provided. Exiting.")
+                return
+
+            _ask_once(query)

     async def run(self, args=None):
         parser = self.create_parser()
@@ -7,11 +7,13 @@ Preserves all optimization parameters to ensure performance
 import logging
 import os
 import time
-from typing import Any
+from typing import Any, Optional
 
 import numpy as np
 import torch
 
+from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
+
 # Set up logger with proper level
 logger = logging.getLogger(__name__)
 LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()

@@ -31,6 +33,7 @@ def compute_embeddings(
     adaptive_optimization: bool = True,
     manual_tokenize: bool = False,
     max_length: int = 512,
+    provider_options: Optional[dict[str, Any]] = None,
 ) -> np.ndarray:
     """
     Unified embedding computation entry point

@@ -46,6 +49,8 @@ def compute_embeddings(
     Returns:
         Normalized embeddings array, shape: (len(texts), embedding_dim)
     """
+    provider_options = provider_options or {}
+
     if mode == "sentence-transformers":
         return compute_embeddings_sentence_transformers(
             texts,

@@ -57,11 +62,21 @@ def compute_embeddings(
             max_length=max_length,
         )
     elif mode == "openai":
-        return compute_embeddings_openai(texts, model_name)
+        return compute_embeddings_openai(
+            texts,
+            model_name,
+            base_url=provider_options.get("base_url"),
+            api_key=provider_options.get("api_key"),
+        )
     elif mode == "mlx":
         return compute_embeddings_mlx(texts, model_name)
     elif mode == "ollama":
-        return compute_embeddings_ollama(texts, model_name, is_build=is_build)
+        return compute_embeddings_ollama(
+            texts,
+            model_name,
+            is_build=is_build,
+            host=provider_options.get("host"),
+        )
     elif mode == "gemini":
         return compute_embeddings_gemini(texts, model_name, is_build=is_build)
     else:

@@ -353,12 +368,15 @@ def compute_embeddings_sentence_transformers(
     return embeddings
 
 
-def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray:
+def compute_embeddings_openai(
+    texts: list[str],
+    model_name: str,
+    base_url: Optional[str] = None,
+    api_key: Optional[str] = None,
+) -> np.ndarray:
     # TODO: @yichuan-w add progress bar only in build mode
     """Compute embeddings using OpenAI API"""
     try:
-        import os
-
         import openai
     except ImportError as e:
         raise ImportError(f"OpenAI package not installed: {e}")

@@ -373,16 +391,18 @@ def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray:
             f"Found {invalid_count} empty/invalid text(s) in input. Upstream should filter before calling OpenAI."
         )
 
-    api_key = os.getenv("OPENAI_API_KEY")
-    if not api_key:
+    resolved_base_url = resolve_openai_base_url(base_url)
+    resolved_api_key = resolve_openai_api_key(api_key)
+
+    if not resolved_api_key:
         raise RuntimeError("OPENAI_API_KEY environment variable not set")
 
     # Cache OpenAI client
-    cache_key = "openai_client"
+    cache_key = f"openai_client::{resolved_base_url}"
     if cache_key in _model_cache:
         client = _model_cache[cache_key]
     else:
-        client = openai.OpenAI(api_key=api_key)
+        client = openai.OpenAI(api_key=resolved_api_key, base_url=resolved_base_url)
         _model_cache[cache_key] = client
         logger.info("OpenAI client cached")

@@ -507,7 +527,10 @@ def compute_embeddings_mlx(chunks: list[str], model_name: str, batch_size: int =
 
 
 def compute_embeddings_ollama(
-    texts: list[str], model_name: str, is_build: bool = False, host: str = "http://localhost:11434"
+    texts: list[str],
+    model_name: str,
+    is_build: bool = False,
+    host: Optional[str] = None,
 ) -> np.ndarray:
     """
     Compute embeddings using Ollama API with simplified batch processing.

@@ -518,7 +541,7 @@ def compute_embeddings_ollama(
         texts: List of texts to compute embeddings for
         model_name: Ollama model name (e.g., "nomic-embed-text", "mxbai-embed-large")
        is_build: Whether this is a build operation (shows progress bar)
-        host: Ollama host URL (default: http://localhost:11434)
+        host: Ollama host URL (defaults to environment or http://localhost:11434)
 
     Returns:
         Normalized embeddings array, shape: (len(texts), embedding_dim)

@@ -533,17 +556,19 @@ def compute_embeddings_ollama(
     if not texts:
         raise ValueError("Cannot compute embeddings for empty text list")
 
+    resolved_host = resolve_ollama_host(host)
+
     logger.info(
-        f"Computing embeddings for {len(texts)} texts using Ollama API, model: '{model_name}'"
+        f"Computing embeddings for {len(texts)} texts using Ollama API, model: '{model_name}', host: '{resolved_host}'"
     )
 
     # Check if Ollama is running
     try:
-        response = requests.get(f"{host}/api/version", timeout=5)
+        response = requests.get(f"{resolved_host}/api/version", timeout=5)
         response.raise_for_status()
     except requests.exceptions.ConnectionError:
         error_msg = (
-            f"❌ Could not connect to Ollama at {host}.\n\n"
+            f"❌ Could not connect to Ollama at {resolved_host}.\n\n"
             "Please ensure Ollama is running:\n"
             " • macOS/Linux: ollama serve\n"
             " • Windows: Make sure Ollama is running in the system tray\n\n"

@@ -555,7 +580,7 @@ def compute_embeddings_ollama(
 
     # Check if model exists and provide helpful suggestions
     try:
-        response = requests.get(f"{host}/api/tags", timeout=5)
+        response = requests.get(f"{resolved_host}/api/tags", timeout=5)
         response.raise_for_status()
         models = response.json()
         model_names = [model["name"] for model in models.get("models", [])]

@@ -618,7 +643,9 @@ def compute_embeddings_ollama(
     # Verify the model supports embeddings by testing it
     try:
         test_response = requests.post(
-            f"{host}/api/embeddings", json={"model": model_name, "prompt": "test"}, timeout=10
+            f"{resolved_host}/api/embeddings",
+            json={"model": model_name, "prompt": "test"},
+            timeout=10,
         )
         if test_response.status_code != 200:
             error_msg = (

@@ -665,7 +692,7 @@ def compute_embeddings_ollama(
     while retry_count < max_retries:
         try:
             response = requests.post(
-                f"{host}/api/embeddings",
+                f"{resolved_host}/api/embeddings",
                 json={"model": model_name, "prompt": truncated_text},
                 timeout=30,
             )
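Note: after this change, provider-specific settings reach the embedding backends through a single `provider_options` dict. A hedged usage sketch; the import path is assumed from the package layout, and the model names and endpoints are placeholders rather than values from this diff:

```python
# Import path assumed from the package layout (packages/leann-core/src/leann/).
from leann.embedding_compute import compute_embeddings

# OpenAI-compatible endpoint: base_url and api_key are read from provider_options,
# falling back to resolve_openai_base_url()/resolve_openai_api_key() when omitted.
vecs = compute_embeddings(
    ["what changed in the ask command?"],
    "text-embedding-3-small",  # placeholder model name
    mode="openai",
    provider_options={
        "base_url": "http://localhost:8000/v1",  # e.g. a local OpenAI-compatible server
        "api_key": "sk-placeholder",
    },
)

# Ollama: only "host" is consulted; when omitted, resolve_ollama_host() falls back
# to the environment or http://localhost:11434.
vecs = compute_embeddings(
    ["what changed in the ask command?"],
    "nomic-embed-text",
    mode="ollama",
    provider_options={"host": "http://192.168.1.10:11434"},
)
```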
@@ -8,6 +8,8 @@ import time
 from pathlib import Path
 from typing import Optional
 
+from .settings import encode_provider_options
+
 # Lightweight, self-contained server manager with no cross-process inspection
 
 # Set up logging based on environment variable

@@ -82,16 +84,40 @@ class EmbeddingServerManager:
     ) -> tuple[bool, int]:
         """Start the embedding server."""
         # passages_file may be present in kwargs for server CLI, but we don't need it here
+        provider_options = kwargs.pop("provider_options", None)
+
+        config_signature = {
+            "model_name": model_name,
+            "passages_file": kwargs.get("passages_file", ""),
+            "embedding_mode": embedding_mode,
+            "provider_options": provider_options or {},
+        }
+
         # If this manager already has a live server, just reuse it
-        if self.server_process and self.server_process.poll() is None and self.server_port:
+        if (
+            self.server_process
+            and self.server_process.poll() is None
+            and self.server_port
+            and self._server_config == config_signature
+        ):
             logger.info("Reusing in-process server")
             return True, self.server_port
 
+        # Configuration changed, stop existing server before starting a new one
+        if self.server_process and self.server_process.poll() is None:
+            logger.info("Existing server configuration differs; restarting embedding server")
+            self.stop_server()
+
         # For Colab environment, use a different strategy
         if _is_colab_environment():
             logger.info("Detected Colab environment, using alternative startup strategy")
-            return self._start_server_colab(port, model_name, embedding_mode, **kwargs)
+            return self._start_server_colab(
+                port,
+                model_name,
+                embedding_mode,
+                provider_options=provider_options,
+                **kwargs,
+            )
 
         # Always pick a fresh available port
         try:

@@ -101,13 +127,21 @@ class EmbeddingServerManager:
             return False, port
 
         # Start a new server
-        return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs)
+        return self._start_new_server(
+            actual_port,
+            model_name,
+            embedding_mode,
+            provider_options=provider_options,
+            config_signature=config_signature,
+            **kwargs,
+        )
 
     def _start_server_colab(
         self,
         port: int,
         model_name: str,
         embedding_mode: str = "sentence-transformers",
+        provider_options: Optional[dict] = None,
         **kwargs,
     ) -> tuple[bool, int]:
         """Start server with Colab-specific configuration."""

@@ -125,8 +159,20 @@ class EmbeddingServerManager:
 
         try:
             # In Colab, we'll use a more direct approach
-            self._launch_server_process_colab(command, actual_port)
-            return self._wait_for_server_ready_colab(actual_port)
+            self._launch_server_process_colab(
+                command,
+                actual_port,
+                provider_options=provider_options,
+            )
+            started, ready_port = self._wait_for_server_ready_colab(actual_port)
+            if started:
+                self._server_config = {
+                    "model_name": model_name,
+                    "passages_file": kwargs.get("passages_file", ""),
+                    "embedding_mode": embedding_mode,
+                    "provider_options": provider_options or {},
+                }
+            return started, ready_port
         except Exception as e:
             logger.error(f"Failed to start embedding server in Colab: {e}")
             return False, actual_port

@@ -134,7 +180,13 @@ class EmbeddingServerManager:
     # Note: No compatibility check needed; manager is per-searcher and configs are stable per instance
 
     def _start_new_server(
-        self, port: int, model_name: str, embedding_mode: str, **kwargs
+        self,
+        port: int,
+        model_name: str,
+        embedding_mode: str,
+        provider_options: Optional[dict] = None,
+        config_signature: Optional[dict] = None,
+        **kwargs,
     ) -> tuple[bool, int]:
         """Start a new embedding server on the given port."""
         logger.info(f"Starting embedding server on port {port}...")

@@ -142,8 +194,20 @@ class EmbeddingServerManager:
         command = self._build_server_command(port, model_name, embedding_mode, **kwargs)
 
         try:
-            self._launch_server_process(command, port)
-            return self._wait_for_server_ready(port)
+            self._launch_server_process(
+                command,
+                port,
+                provider_options=provider_options,
+            )
+            started, ready_port = self._wait_for_server_ready(port)
+            if started:
+                self._server_config = config_signature or {
+                    "model_name": model_name,
+                    "passages_file": kwargs.get("passages_file", ""),
+                    "embedding_mode": embedding_mode,
+                    "provider_options": provider_options or {},
+                }
+            return started, ready_port
         except Exception as e:
             logger.error(f"Failed to start embedding server: {e}")
             return False, port

@@ -173,7 +237,12 @@ class EmbeddingServerManager:
 
         return command
 
-    def _launch_server_process(self, command: list, port: int) -> None:
+    def _launch_server_process(
+        self,
+        command: list,
+        port: int,
+        provider_options: Optional[dict] = None,
+    ) -> None:
         """Launch the server process."""
         project_root = Path(__file__).parent.parent.parent.parent.parent
         logger.info(f"Command: {' '.join(command)}")

@@ -193,14 +262,20 @@ class EmbeddingServerManager:
 
         # Start embedding server subprocess
         logger.info(f"Starting server process with command: {' '.join(command)}")
+        env = os.environ.copy()
+        encoded_options = encode_provider_options(provider_options)
+        if encoded_options:
+            env["LEANN_EMBEDDING_OPTIONS"] = encoded_options
+
         self.server_process = subprocess.Popen(
             command,
             cwd=project_root,
             stdout=stdout_target,
             stderr=stderr_target,
+            env=env,
         )
         self.server_port = port
-        # Record config for in-process reuse
+        # Record config for in-process reuse (best effort; refined later when ready)
         try:
             self._server_config = {
                 "model_name": command[command.index("--model-name") + 1]

@@ -212,12 +287,14 @@ class EmbeddingServerManager:
                 "embedding_mode": command[command.index("--embedding-mode") + 1]
                 if "--embedding-mode" in command
                 else "sentence-transformers",
+                "provider_options": provider_options or {},
             }
         except Exception:
             self._server_config = {
                 "model_name": "",
                 "passages_file": "",
                 "embedding_mode": "sentence-transformers",
+                "provider_options": provider_options or {},
             }
         logger.info(f"Server process started with PID: {self.server_process.pid}")

@@ -322,16 +399,27 @@ class EmbeddingServerManager:
         # Removed: cross-process adoption no longer supported
         return
 
-    def _launch_server_process_colab(self, command: list, port: int) -> None:
+    def _launch_server_process_colab(
+        self,
+        command: list,
+        port: int,
+        provider_options: Optional[dict] = None,
+    ) -> None:
         """Launch the server process with Colab-specific settings."""
         logger.info(f"Colab Command: {' '.join(command)}")
 
         # In Colab, we need to be more careful about process management
+        env = os.environ.copy()
+        encoded_options = encode_provider_options(provider_options)
+        if encoded_options:
+            env["LEANN_EMBEDDING_OPTIONS"] = encoded_options
+
         self.server_process = subprocess.Popen(
             command,
             stdout=subprocess.PIPE,
             stderr=subprocess.PIPE,
             text=True,
+            env=env,
         )
         self.server_port = port
         logger.info(f"Colab server process started with PID: {self.server_process.pid}")

@@ -345,6 +433,7 @@ class EmbeddingServerManager:
             "model_name": "",
             "passages_file": "",
             "embedding_mode": "sentence-transformers",
+            "provider_options": provider_options or {},
         }
 
     def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]:
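Note: the manager now hands provider options to the spawned embedding server through the `LEANN_EMBEDDING_OPTIONS` environment variable, serialized by `encode_provider_options` from the new settings module below. The decoding side is not part of this diff; a minimal sketch of what the server process might do, assuming it simply mirrors the JSON encoding:

```python
import json
import os


def load_provider_options() -> dict:
    """Hypothetical server-side counterpart to encode_provider_options."""
    raw = os.environ.get("LEANN_EMBEDDING_OPTIONS")
    if not raw:
        return {}
    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        return {}


# e.g. {"host": "http://192.168.1.10:11434"} when the parent process set the variable.
options = load_provider_options()
```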
@@ -41,6 +41,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
             print("WARNING: embedding_model not found in meta.json. Recompute will fail.")
 
         self.embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
+        self.embedding_options = self.meta.get("embedding_options", {})
 
         self.embedding_server_manager = EmbeddingServerManager(
             backend_module_name=backend_module_name,

@@ -77,6 +78,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
             passages_file=passages_source_file,
             distance_metric=distance_metric,
             enable_warmup=kwargs.get("enable_warmup", False),
+            provider_options=self.embedding_options,
         )
         if not server_started:
             raise RuntimeError(f"Failed to start embedding server on port {actual_port}")

@@ -125,7 +127,12 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
         from .embedding_compute import compute_embeddings
 
         embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
-        return compute_embeddings([query], self.embedding_model, embedding_mode)
+        return compute_embeddings(
+            [query],
+            self.embedding_model,
+            embedding_mode,
+            provider_options=self.embedding_options,
+        )
 
     def _compute_embedding_via_server(self, chunks: list, zmq_port: int) -> np.ndarray:
         """Compute embeddings using the ZMQ embedding server."""
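Note: `BaseSearcher` now picks up an optional `embedding_options` block from the index metadata and threads it into both the embedding server and direct query-embedding computation. An illustrative sketch of such a metadata entry; only the keys are taken from this diff, the values are placeholders:

```python
# Illustrative shape of self.meta after loading meta.json.
meta = {
    "embedding_model": "nomic-embed-text",
    "embedding_mode": "ollama",
    "embedding_options": {"host": "http://gpu-box:11434"},  # forwarded as provider_options
}
```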
packages/leann-core/src/leann/settings.py (new file, 74 lines)
@@ -0,0 +1,74 @@
+"""Runtime configuration helpers for LEANN."""
+
+from __future__ import annotations
+
+import json
+import os
+from typing import Any
+
+# Default fallbacks to preserve current behaviour while keeping them in one place.
+_DEFAULT_OLLAMA_HOST = "http://localhost:11434"
+_DEFAULT_OPENAI_BASE_URL = "https://api.openai.com/v1"
+
+
+def _clean_url(value: str) -> str:
+    """Normalize URL strings by stripping trailing slashes."""
+
+    return value.rstrip("/") if value else value
+
+
+def resolve_ollama_host(explicit: str | None = None) -> str:
+    """Resolve the Ollama-compatible endpoint to use."""
+
+    candidates = (
+        explicit,
+        os.getenv("LEANN_LOCAL_LLM_HOST"),
+        os.getenv("LEANN_OLLAMA_HOST"),
+        os.getenv("OLLAMA_HOST"),
+        os.getenv("LOCAL_LLM_ENDPOINT"),
+    )
+
+    for candidate in candidates:
+        if candidate:
+            return _clean_url(candidate)
+
+    return _clean_url(_DEFAULT_OLLAMA_HOST)
+
+
+def resolve_openai_base_url(explicit: str | None = None) -> str:
+    """Resolve the base URL for OpenAI-compatible services."""
+
+    candidates = (
+        explicit,
+        os.getenv("LEANN_OPENAI_BASE_URL"),
+        os.getenv("OPENAI_BASE_URL"),
+        os.getenv("LOCAL_OPENAI_BASE_URL"),
+    )
+
+    for candidate in candidates:
+        if candidate:
+            return _clean_url(candidate)
+
+    return _clean_url(_DEFAULT_OPENAI_BASE_URL)
+
+
+def resolve_openai_api_key(explicit: str | None = None) -> str | None:
+    """Resolve the API key for OpenAI-compatible services."""
+
+    if explicit:
+        return explicit
+
+    return os.getenv("OPENAI_API_KEY")
+
+
+def encode_provider_options(options: dict[str, Any] | None) -> str | None:
+    """Serialize provider options for child processes."""
+
+    if not options:
+        return None
+
+    try:
+        return json.dumps(options)
+    except (TypeError, ValueError):
+        # Fall back to empty payload if serialization fails
+        return None
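Note: a quick sketch of the resolution order these helpers implement; an explicit argument wins, then the listed environment variables in order, then the hard-coded default. Values below are illustrative, and the last two assertions assume the corresponding OpenAI variables are unset:

```python
import os

from leann.settings import encode_provider_options, resolve_ollama_host, resolve_openai_base_url

# Explicit argument beats every environment variable; trailing slashes are stripped.
assert resolve_ollama_host("http://gpu-box:11434/") == "http://gpu-box:11434"

# Otherwise the first set variable wins (LEANN_LOCAL_LLM_HOST, LEANN_OLLAMA_HOST,
# OLLAMA_HOST, LOCAL_LLM_ENDPOINT).
os.environ.pop("LEANN_LOCAL_LLM_HOST", None)
os.environ.pop("LEANN_OLLAMA_HOST", None)
os.environ["OLLAMA_HOST"] = "http://192.168.1.10:11434/"
assert resolve_ollama_host() == "http://192.168.1.10:11434"

# With no OpenAI-related variables set, the stock endpoint is returned.
assert resolve_openai_base_url() == "https://api.openai.com/v1"

# Empty option dicts are not exported to child processes at all.
assert encode_provider_options({}) is None
```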
tests/test_cli_ask.py (new file, 14 lines)
@@ -0,0 +1,14 @@
+from leann.cli import LeannCLI
+
+
+def test_cli_ask_accepts_positional_query(tmp_path, monkeypatch):
+    monkeypatch.chdir(tmp_path)
+
+    cli = LeannCLI()
+    parser = cli.create_parser()
+
+    args = parser.parse_args(["ask", "my-docs", "Where are prompts configured?"])
+
+    assert args.command == "ask"
+    assert args.index_name == "my-docs"
+    assert args.query == "Where are prompts configured?"
(deleted file, 209 lines)
@@ -1,209 +0,0 @@
-#!/usr/bin/env python3
-"""
-Test script for MCP integration implementations.
-
-This script tests the basic functionality of the MCP readers and RAG applications
-without requiring actual MCP servers to be running.
-"""
-
-import sys
-import asyncio
-from pathlib import Path
-
-# Add the parent directory to the path so we can import from apps
-sys.path.append(str(Path(__file__).parent.parent))
-
-from apps.slack_data.slack_mcp_reader import SlackMCPReader
-from apps.twitter_data.twitter_mcp_reader import TwitterMCPReader
-from apps.slack_rag import SlackMCPRAG
-from apps.twitter_rag import TwitterMCPRAG
-
-
-def test_slack_reader_initialization():
-    """Test that SlackMCPReader can be initialized with various parameters."""
-    print("Testing SlackMCPReader initialization...")
-
-    # Test basic initialization
-    reader = SlackMCPReader("slack-mcp-server")
-    assert reader.mcp_server_command == "slack-mcp-server"
-    assert reader.concatenate_conversations == True
-    assert reader.max_messages_per_conversation == 100
-
-    # Test with custom parameters
-    reader = SlackMCPReader(
-        "custom-slack-server",
-        workspace_name="test-workspace",
-        concatenate_conversations=False,
-        max_messages_per_conversation=50
-    )
-    assert reader.workspace_name == "test-workspace"
-    assert reader.concatenate_conversations == False
-    assert reader.max_messages_per_conversation == 50
-
-    print("✅ SlackMCPReader initialization tests passed")
-
-
-def test_twitter_reader_initialization():
-    """Test that TwitterMCPReader can be initialized with various parameters."""
-    print("Testing TwitterMCPReader initialization...")
-
-    # Test basic initialization
-    reader = TwitterMCPReader("twitter-mcp-server")
-    assert reader.mcp_server_command == "twitter-mcp-server"
-    assert reader.include_tweet_content == True
-    assert reader.include_metadata == True
-    assert reader.max_bookmarks == 1000
-
-    # Test with custom parameters
-    reader = TwitterMCPReader(
-        "custom-twitter-server",
-        username="testuser",
-        include_tweet_content=False,
-        include_metadata=False,
-        max_bookmarks=500
-    )
-    assert reader.username == "testuser"
-    assert reader.include_tweet_content == False
-    assert reader.include_metadata == False
-    assert reader.max_bookmarks == 500
-
-    print("✅ TwitterMCPReader initialization tests passed")
-
-
-def test_slack_message_formatting():
-    """Test Slack message formatting functionality."""
-    print("Testing Slack message formatting...")
-
-    reader = SlackMCPReader("slack-mcp-server")
-
-    # Test basic message formatting
-    message = {
-        "text": "Hello, world!",
-        "user": "john_doe",
-        "channel": "general",
-        "ts": "1234567890.123456"
-    }
-
-    formatted = reader._format_message(message)
-    assert "Channel: #general" in formatted
-    assert "User: john_doe" in formatted
-    assert "Message: Hello, world!" in formatted
-    assert "Time:" in formatted
-
-    # Test with missing fields
-    message = {"text": "Simple message"}
-    formatted = reader._format_message(message)
-    assert "Message: Simple message" in formatted
-
-    print("✅ Slack message formatting tests passed")
-
-
-def test_twitter_bookmark_formatting():
-    """Test Twitter bookmark formatting functionality."""
-    print("Testing Twitter bookmark formatting...")
-
-    reader = TwitterMCPReader("twitter-mcp-server")
-
-    # Test basic bookmark formatting
-    bookmark = {
-        "text": "This is a great article about AI!",
-        "author": "ai_researcher",
-        "created_at": "2024-01-01T12:00:00Z",
-        "url": "https://twitter.com/ai_researcher/status/123456789",
-        "likes": 42,
-        "retweets": 15
-    }
-
-    formatted = reader._format_bookmark(bookmark)
-    assert "=== Twitter Bookmark ===" in formatted
-    assert "Author: @ai_researcher" in formatted
-    assert "Content:" in formatted
-    assert "This is a great article about AI!" in formatted
-    assert "URL: https://twitter.com" in formatted
-    assert "Likes: 42" in formatted
-    assert "Retweets: 15" in formatted
-
-    # Test with minimal data
-    bookmark = {"text": "Simple tweet"}
-    formatted = reader._format_bookmark(bookmark)
-    assert "=== Twitter Bookmark ===" in formatted
-    assert "Simple tweet" in formatted
-
-    print("✅ Twitter bookmark formatting tests passed")
-
-
-def test_slack_rag_initialization():
-    """Test that SlackMCPRAG can be initialized."""
-    print("Testing SlackMCPRAG initialization...")
-
-    app = SlackMCPRAG()
-    assert app.default_index_name == "slack_messages"
-    assert hasattr(app, 'parser')
-
-    print("✅ SlackMCPRAG initialization tests passed")
-
-
-def test_twitter_rag_initialization():
-    """Test that TwitterMCPRAG can be initialized."""
-    print("Testing TwitterMCPRAG initialization...")
-
-    app = TwitterMCPRAG()
-    assert app.default_index_name == "twitter_bookmarks"
-    assert hasattr(app, 'parser')
-
-    print("✅ TwitterMCPRAG initialization tests passed")
-
-
-def test_concatenated_content_creation():
-    """Test creation of concatenated content from multiple messages."""
-    print("Testing concatenated content creation...")
-
-    reader = SlackMCPReader("slack-mcp-server", workspace_name="test-workspace")
-
-    messages = [
-        {"text": "First message", "user": "alice", "ts": "1000"},
-        {"text": "Second message", "user": "bob", "ts": "2000"},
-        {"text": "Third message", "user": "charlie", "ts": "3000"}
-    ]
-
-    content = reader._create_concatenated_content(messages, "general")
-
-    assert "Slack Channel: #general" in content
-    assert "Message Count: 3" in content
-    assert "Workspace: test-workspace" in content
-    assert "First message" in content
-    assert "Second message" in content
-    assert "Third message" in content
-
-    print("✅ Concatenated content creation tests passed")
-
-
-def main():
-    """Run all tests."""
-    print("🧪 Running MCP Integration Tests")
-    print("=" * 50)
-
-    try:
-        test_slack_reader_initialization()
-        test_twitter_reader_initialization()
-        test_slack_message_formatting()
-        test_twitter_bookmark_formatting()
-        test_slack_rag_initialization()
-        test_twitter_rag_initialization()
-        test_concatenated_content_creation()
-
-        print("\n" + "=" * 50)
-        print("🎉 All tests passed! MCP integration is working correctly.")
-        print("\nNext steps:")
-        print("1. Install actual MCP servers for Slack and Twitter")
-        print("2. Configure API credentials")
-        print("3. Test with --test-connection flag")
-        print("4. Start indexing your live data!")
-
-    except Exception as e:
-        print(f"\n❌ Test failed: {e}")
-        sys.exit(1)
-
-
-if __name__ == "__main__":
-    main()
(deleted file, 225 lines)
@@ -1,225 +0,0 @@
-#!/usr/bin/env python3
-"""
-Standalone test script for MCP integration implementations.
-
-This script tests the basic functionality of the MCP readers
-without requiring LEANN core dependencies.
-"""
-
-import sys
-import json
-from pathlib import Path
-
-# Add the parent directory to the path so we can import from apps
-sys.path.append(str(Path(__file__).parent.parent))
-
-
-def test_slack_reader_basic():
-    """Test basic SlackMCPReader functionality without async operations."""
-    print("Testing SlackMCPReader basic functionality...")
-
-    # Import and test initialization
-    from apps.slack_data.slack_mcp_reader import SlackMCPReader
-
-    reader = SlackMCPReader("slack-mcp-server")
-    assert reader.mcp_server_command == "slack-mcp-server"
-    assert reader.concatenate_conversations == True
-
-    # Test message formatting
-    message = {
-        "text": "Hello team! How's the project going?",
-        "user": "john_doe",
-        "channel": "general",
-        "ts": "1234567890.123456"
-    }
-
-    formatted = reader._format_message(message)
-    assert "Channel: #general" in formatted
-    assert "User: john_doe" in formatted
-    assert "Message: Hello team!" in formatted
-
-    # Test concatenated content creation
-    messages = [
-        {"text": "First message", "user": "alice", "ts": "1000"},
-        {"text": "Second message", "user": "bob", "ts": "2000"}
-    ]
-
-    content = reader._create_concatenated_content(messages, "dev-team")
-    assert "Slack Channel: #dev-team" in content
-    assert "Message Count: 2" in content
-    assert "First message" in content
-    assert "Second message" in content
-
-    print("✅ SlackMCPReader basic tests passed")
-
-
-def test_twitter_reader_basic():
-    """Test basic TwitterMCPReader functionality."""
-    print("Testing TwitterMCPReader basic functionality...")
-
-    from apps.twitter_data.twitter_mcp_reader import TwitterMCPReader
-
-    reader = TwitterMCPReader("twitter-mcp-server")
-    assert reader.mcp_server_command == "twitter-mcp-server"
-    assert reader.include_tweet_content == True
-    assert reader.max_bookmarks == 1000
-
-    # Test bookmark formatting
-    bookmark = {
-        "text": "Amazing article about the future of AI! Must read for everyone interested in tech.",
-        "author": "tech_guru",
-        "created_at": "2024-01-15T14:30:00Z",
-        "url": "https://twitter.com/tech_guru/status/123456789",
-        "likes": 156,
-        "retweets": 42,
-        "replies": 23,
-        "hashtags": ["AI", "tech", "future"],
-        "mentions": ["@openai", "@anthropic"]
-    }
-
-    formatted = reader._format_bookmark(bookmark)
-    assert "=== Twitter Bookmark ===" in formatted
-    assert "Author: @tech_guru" in formatted
-    assert "Amazing article about the future of AI!" in formatted
-    assert "Likes: 156" in formatted
-    assert "Retweets: 42" in formatted
-    assert "Hashtags: AI, tech, future" in formatted
-    assert "Mentions: @openai, @anthropic" in formatted
-
-    # Test with minimal data
-    simple_bookmark = {"text": "Short tweet", "author": "user123"}
-    formatted_simple = reader._format_bookmark(simple_bookmark)
-    assert "=== Twitter Bookmark ===" in formatted_simple
-    assert "Short tweet" in formatted_simple
-    assert "Author: @user123" in formatted_simple
-
-    print("✅ TwitterMCPReader basic tests passed")
-
-
-def test_mcp_request_format():
-    """Test MCP request formatting."""
-    print("Testing MCP request formatting...")
-
-    # Test initialization request format
-    init_request = {
-        "jsonrpc": "2.0",
-        "id": 1,
-        "method": "initialize",
-        "params": {
-            "protocolVersion": "2024-11-05",
-            "capabilities": {},
-            "clientInfo": {"name": "leann-slack-reader", "version": "1.0.0"}
-        }
-    }
-
-    # Verify it's valid JSON
-    json_str = json.dumps(init_request)
-    parsed = json.loads(json_str)
-    assert parsed["jsonrpc"] == "2.0"
-    assert parsed["method"] == "initialize"
-    assert parsed["params"]["protocolVersion"] == "2024-11-05"
-
-    # Test tools/list request
-    list_request = {
-        "jsonrpc": "2.0",
-        "id": 2,
-        "method": "tools/list",
-        "params": {}
-    }
-
-    json_str = json.dumps(list_request)
-    parsed = json.loads(json_str)
-    assert parsed["method"] == "tools/list"
-
-    print("✅ MCP request formatting tests passed")
-
-
-def test_data_processing():
-    """Test data processing capabilities."""
-    print("Testing data processing capabilities...")
-
-    from apps.slack_data.slack_mcp_reader import SlackMCPReader
-    from apps.twitter_data.twitter_mcp_reader import TwitterMCPReader
-
-    # Test Slack message processing with various formats
-    slack_reader = SlackMCPReader("test-server")
-
-    messages_with_timestamps = [
-        {"text": "Meeting in 5 minutes", "user": "alice", "ts": "1000.123"},
-        {"text": "On my way!", "user": "bob", "ts": "1001.456"},
-        {"text": "Starting now", "user": "charlie", "ts": "1002.789"}
-    ]
-
-    content = slack_reader._create_concatenated_content(messages_with_timestamps, "meetings")
-    assert "Meeting in 5 minutes" in content
-    assert "On my way!" in content
-    assert "Starting now" in content
-
-    # Test Twitter bookmark processing with engagement data
-    twitter_reader = TwitterMCPReader("test-server", include_metadata=True)
-
-    high_engagement_bookmark = {
-        "text": "Thread about startup lessons learned 🧵",
-        "author": "startup_founder",
-        "likes": 1250,
-        "retweets": 340,
-        "replies": 89
-    }
-
-    formatted = twitter_reader._format_bookmark(high_engagement_bookmark)
-    assert "Thread about startup lessons learned" in formatted
-    assert "Likes: 1250" in formatted
-    assert "Retweets: 340" in formatted
-    assert "Replies: 89" in formatted
-
-    # Test with metadata disabled
-    twitter_reader_no_meta = TwitterMCPReader("test-server", include_metadata=False)
-    formatted_no_meta = twitter_reader_no_meta._format_bookmark(high_engagement_bookmark)
-    assert "Thread about startup lessons learned" in formatted_no_meta
-    assert "Likes:" not in formatted_no_meta
-    assert "Retweets:" not in formatted_no_meta
-
-    print("✅ Data processing tests passed")
-
-
-def main():
-    """Run all standalone tests."""
-    print("🧪 Running MCP Integration Standalone Tests")
-    print("=" * 60)
-    print("Testing core functionality without LEANN dependencies...")
-    print()
-
-    try:
-        test_slack_reader_basic()
-        test_twitter_reader_basic()
-        test_mcp_request_format()
-        test_data_processing()
-
-        print("\n" + "=" * 60)
-        print("🎉 All standalone tests passed!")
-        print("\n✨ MCP Integration Summary:")
-        print("- SlackMCPReader: Ready for Slack message processing")
-        print("- TwitterMCPReader: Ready for Twitter bookmark processing")
-        print("- MCP Protocol: Properly formatted JSON-RPC requests")
-        print("- Data Processing: Handles various message/bookmark formats")
-
-        print("\n🚀 Next Steps:")
-        print("1. Install MCP servers: npm install -g slack-mcp-server twitter-mcp-server")
-        print("2. Configure API credentials for Slack and Twitter")
-        print("3. Test connections: python -m apps.slack_rag --test-connection")
-        print("4. Start indexing live data from your platforms!")
-
-        print("\n📖 Documentation:")
-        print("- Check README.md for detailed setup instructions")
-        print("- Run examples/mcp_integration_demo.py for usage examples")
-        print("- Explore apps/slack_rag.py and apps/twitter_rag.py for implementation details")
-
-    except Exception as e:
-        print(f"\n❌ Test failed: {e}")
-        import traceback
-        traceback.print_exc()
-        sys.exit(1)
-
-
-if __name__ == "__main__":
-    main()