fix readme

yichuan-w
2025-10-08 21:38:55 +00:00
parent 3ec5e8d035
commit 5be0c144ad
72 changed files with 16608 additions and 4175 deletions

View File

@@ -10,7 +10,9 @@ from typing import Any
import dotenv
from leann.api import LeannBuilder, LeannChat
from leann.interactive_utils import create_rag_session
from leann.registry import register_project_directory
from leann.settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
dotenv.load_dotenv()
@@ -78,6 +80,24 @@ class BaseRAGExample(ABC):
choices=["sentence-transformers", "openai", "mlx", "ollama"],
help="Embedding backend mode (default: sentence-transformers), we provide sentence-transformers, openai, mlx, or ollama",
)
embedding_group.add_argument(
"--embedding-host",
type=str,
default=None,
help="Override Ollama-compatible embedding host",
)
embedding_group.add_argument(
"--embedding-api-base",
type=str,
default=None,
help="Base URL for OpenAI-compatible embedding services",
)
embedding_group.add_argument(
"--embedding-api-key",
type=str,
default=None,
help="API key for embedding service (defaults to OPENAI_API_KEY)",
)
# LLM parameters
llm_group = parser.add_argument_group("LLM Parameters")
@@ -97,8 +117,8 @@ class BaseRAGExample(ABC):
llm_group.add_argument(
"--llm-host",
type=str,
default="http://localhost:11434",
help="Host for Ollama API (default: http://localhost:11434)",
default=None,
help="Host for Ollama-compatible APIs (defaults to LEANN_OLLAMA_HOST/OLLAMA_HOST)",
)
llm_group.add_argument(
"--thinking-budget",
@@ -107,6 +127,18 @@ class BaseRAGExample(ABC):
default=None,
help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.",
)
llm_group.add_argument(
"--llm-api-base",
type=str,
default=None,
help="Base URL for OpenAI-compatible APIs",
)
llm_group.add_argument(
"--llm-api-key",
type=str,
default=None,
help="API key for OpenAI-compatible APIs (defaults to OPENAI_API_KEY)",
)
# AST Chunking parameters
ast_group = parser.add_argument_group("AST Chunking Parameters")
@@ -205,9 +237,13 @@ class BaseRAGExample(ABC):
if args.llm == "openai":
config["model"] = args.llm_model or "gpt-4o"
config["base_url"] = resolve_openai_base_url(args.llm_api_base)
resolved_key = resolve_openai_api_key(args.llm_api_key)
if resolved_key:
config["api_key"] = resolved_key
elif args.llm == "ollama":
config["model"] = args.llm_model or "llama3.2:1b"
config["host"] = args.llm_host
config["host"] = resolve_ollama_host(args.llm_host)
elif args.llm == "hf":
config["model"] = args.llm_model or "Qwen/Qwen2.5-1.5B-Instruct"
elif args.llm == "simulated":
@@ -223,10 +259,20 @@ class BaseRAGExample(ABC):
print(f"\n[Building Index] Creating {self.name} index...")
print(f"Total text chunks: {len(texts)}")
embedding_options: dict[str, Any] = {}
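        # CLI flags take precedence; the resolve_* helpers fall back to env vars
        # (LEANN_OLLAMA_HOST/OLLAMA_HOST, OPENAI_API_KEY) per the argument help above.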
if args.embedding_mode == "ollama":
embedding_options["host"] = resolve_ollama_host(args.embedding_host)
elif args.embedding_mode == "openai":
embedding_options["base_url"] = resolve_openai_base_url(args.embedding_api_base)
resolved_embedding_key = resolve_openai_api_key(args.embedding_api_key)
if resolved_embedding_key:
embedding_options["api_key"] = resolved_embedding_key
builder = LeannBuilder(
backend_name=args.backend_name,
embedding_model=args.embedding_model,
embedding_mode=args.embedding_mode,
embedding_options=embedding_options or None,
graph_degree=args.graph_degree,
complexity=args.build_complexity,
is_compact=not args.no_compact,
@@ -262,37 +308,26 @@ class BaseRAGExample(ABC):
complexity=args.search_complexity,
)
print(f"\n[Interactive Mode] Chat with your {self.name} data!")
print("Type 'quit' or 'exit' to stop.\n")
+        # Create interactive session
+        session = create_rag_session(
+            app_name=self.name.lower().replace(" ", "_"), data_description=self.name
+        )
-        while True:
-            try:
-                query = input("You: ").strip()
-                if query.lower() in ["quit", "exit", "q"]:
-                    print("Goodbye!")
-                    break
-                if not query:
-                    continue
-                # Prepare LLM kwargs with thinking budget if specified
-                llm_kwargs = {}
-                if hasattr(args, "thinking_budget") and args.thinking_budget:
-                    llm_kwargs["thinking_budget"] = args.thinking_budget
-                response = chat.ask(
-                    query,
-                    top_k=args.top_k,
-                    complexity=args.search_complexity,
-                    llm_kwargs=llm_kwargs,
-                )
-                print(f"\nAssistant: {response}\n")
-            except KeyboardInterrupt:
-                print("\nGoodbye!")
-                break
-            except Exception as e:
-                print(f"Error: {e}")
+        def handle_query(query: str):
+            # Prepare LLM kwargs with thinking budget if specified
+            llm_kwargs = {}
+            if hasattr(args, "thinking_budget") and args.thinking_budget:
+                llm_kwargs["thinking_budget"] = args.thinking_budget
+            response = chat.ask(
+                query,
+                top_k=args.top_k,
+                complexity=args.search_complexity,
+                llm_kwargs=llm_kwargs,
+            )
+            print(f"\nAssistant: {response}\n")
+
+        session.run_interactive_loop(handle_query)
async def run_single_query(self, args, index_path: str, query: str):
"""Run a single query against the index."""

View File

@@ -0,0 +1,413 @@
"""
ChatGPT export data reader.
Reads and processes ChatGPT export data from chat.html files.
"""
import re
from pathlib import Path
from typing import Any
from zipfile import ZipFile
from bs4 import BeautifulSoup
from llama_index.core import Document
from llama_index.core.readers.base import BaseReader
class ChatGPTReader(BaseReader):
"""
ChatGPT export data reader.
Reads ChatGPT conversation data from exported chat.html files or zip archives.
Processes conversations into structured documents with metadata.
"""
def __init__(self, concatenate_conversations: bool = True) -> None:
"""
Initialize.
Args:
concatenate_conversations: Whether to concatenate messages within conversations for better context
"""
try:
from bs4 import BeautifulSoup # noqa
except ImportError:
raise ImportError("`beautifulsoup4` package not found: `pip install beautifulsoup4`")
self.concatenate_conversations = concatenate_conversations
def _extract_html_from_zip(self, zip_path: Path) -> str | None:
"""
Extract chat.html from ChatGPT export zip file.
Args:
zip_path: Path to the ChatGPT export zip file
Returns:
HTML content as string, or None if not found
"""
try:
with ZipFile(zip_path, "r") as zip_file:
# Look for chat.html or conversations.html
html_files = [
f
for f in zip_file.namelist()
if f.endswith(".html") and ("chat" in f.lower() or "conversation" in f.lower())
]
if not html_files:
print(f"No HTML chat file found in {zip_path}")
return None
# Use the first HTML file found
html_file = html_files[0]
print(f"Found HTML file: {html_file}")
with zip_file.open(html_file) as f:
return f.read().decode("utf-8", errors="ignore")
except Exception as e:
print(f"Error extracting HTML from zip {zip_path}: {e}")
return None
def _parse_chatgpt_html(self, html_content: str) -> list[dict]:
"""
Parse ChatGPT HTML export to extract conversations.
Args:
html_content: HTML content from ChatGPT export
Returns:
List of conversation dictionaries
"""
soup = BeautifulSoup(html_content, "html.parser")
conversations = []
# Try different possible structures for ChatGPT exports
# Structure 1: Look for conversation containers
conversation_containers = soup.find_all(
["div", "section"], class_=re.compile(r"conversation|chat", re.I)
)
if not conversation_containers:
# Structure 2: Look for message containers directly
conversation_containers = [soup] # Use the entire document as one conversation
for container in conversation_containers:
conversation = self._extract_conversation_from_container(container)
if conversation and conversation.get("messages"):
conversations.append(conversation)
# If no structured conversations found, try to extract all text as one conversation
if not conversations:
all_text = soup.get_text(separator="\n", strip=True)
if all_text:
conversations.append(
{
"title": "ChatGPT Conversation",
"messages": [{"role": "mixed", "content": all_text, "timestamp": None}],
"timestamp": None,
}
)
return conversations
def _extract_conversation_from_container(self, container) -> dict | None:
"""
Extract conversation data from a container element.
Args:
container: BeautifulSoup element containing conversation
Returns:
Dictionary with conversation data or None
"""
messages = []
# Look for message elements with various possible structures
message_selectors = ['[class*="message"]', '[class*="chat"]', "[data-message]", "p", "div"]
for selector in message_selectors:
message_elements = container.select(selector)
if message_elements:
break
else:
message_elements = []
# If no structured messages found, treat the entire container as one message
if not message_elements:
text_content = container.get_text(separator="\n", strip=True)
if text_content:
messages.append({"role": "mixed", "content": text_content, "timestamp": None})
else:
for element in message_elements:
message = self._extract_message_from_element(element)
if message:
messages.append(message)
if not messages:
return None
# Try to extract conversation title
title_element = container.find(["h1", "h2", "h3", "title"])
title = title_element.get_text(strip=True) if title_element else "ChatGPT Conversation"
# Try to extract timestamp from various possible locations
timestamp = self._extract_timestamp_from_container(container)
return {"title": title, "messages": messages, "timestamp": timestamp}
def _extract_message_from_element(self, element) -> dict | None:
"""
Extract message data from an element.
Args:
element: BeautifulSoup element containing message
Returns:
Dictionary with message data or None
"""
text_content = element.get_text(separator=" ", strip=True)
# Skip empty or very short messages
if not text_content or len(text_content.strip()) < 3:
return None
# Try to determine role (user/assistant) from class names or content
role = "mixed" # Default role
class_names = " ".join(element.get("class", [])).lower()
if "user" in class_names or "human" in class_names:
role = "user"
elif "assistant" in class_names or "ai" in class_names or "gpt" in class_names:
role = "assistant"
elif text_content.lower().startswith(("you:", "user:", "me:")):
role = "user"
text_content = re.sub(r"^(you|user|me):\s*", "", text_content, flags=re.IGNORECASE)
elif text_content.lower().startswith(("chatgpt:", "assistant:", "ai:")):
role = "assistant"
text_content = re.sub(
r"^(chatgpt|assistant|ai):\s*", "", text_content, flags=re.IGNORECASE
)
# Try to extract timestamp
timestamp = self._extract_timestamp_from_element(element)
return {"role": role, "content": text_content, "timestamp": timestamp}
def _extract_timestamp_from_element(self, element) -> str | None:
"""Extract timestamp from element."""
# Look for timestamp in various attributes and child elements
timestamp_attrs = ["data-timestamp", "timestamp", "datetime"]
for attr in timestamp_attrs:
if element.get(attr):
return element.get(attr)
# Look for time elements
time_element = element.find("time")
if time_element:
return time_element.get("datetime") or time_element.get_text(strip=True)
# Look for date-like text patterns
text = element.get_text()
date_patterns = [r"\d{4}-\d{2}-\d{2}", r"\d{1,2}/\d{1,2}/\d{4}", r"\w+ \d{1,2}, \d{4}"]
for pattern in date_patterns:
match = re.search(pattern, text)
if match:
return match.group()
return None
def _extract_timestamp_from_container(self, container) -> str | None:
"""Extract timestamp from conversation container."""
return self._extract_timestamp_from_element(container)
def _create_concatenated_content(self, conversation: dict) -> str:
"""
Create concatenated content from conversation messages.
Args:
conversation: Dictionary containing conversation data
Returns:
Formatted concatenated content
"""
title = conversation.get("title", "ChatGPT Conversation")
messages = conversation.get("messages", [])
timestamp = conversation.get("timestamp", "Unknown")
# Build message content
message_parts = []
for message in messages:
role = message.get("role", "mixed")
content = message.get("content", "")
msg_timestamp = message.get("timestamp", "")
if role == "user":
prefix = "[You]"
elif role == "assistant":
prefix = "[ChatGPT]"
else:
prefix = "[Message]"
# Add timestamp if available
if msg_timestamp:
prefix += f" ({msg_timestamp})"
message_parts.append(f"{prefix}: {content}")
concatenated_text = "\n\n".join(message_parts)
# Create final document content
doc_content = f"""Conversation: {title}
Date: {timestamp}
Messages ({len(messages)} messages):
{concatenated_text}
"""
return doc_content
def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
"""
Load ChatGPT export data.
Args:
input_dir: Directory containing ChatGPT export files or path to specific file
**load_kwargs:
max_count (int): Maximum number of conversations to process
chatgpt_export_path (str): Specific path to ChatGPT export file/directory
include_metadata (bool): Whether to include metadata in documents
"""
docs: list[Document] = []
max_count = load_kwargs.get("max_count", -1)
chatgpt_export_path = load_kwargs.get("chatgpt_export_path", input_dir)
include_metadata = load_kwargs.get("include_metadata", True)
if not chatgpt_export_path:
print("No ChatGPT export path provided")
return docs
export_path = Path(chatgpt_export_path)
if not export_path.exists():
print(f"ChatGPT export path not found: {export_path}")
return docs
html_content = None
# Handle different input types
if export_path.is_file():
if export_path.suffix.lower() == ".zip":
# Extract HTML from zip file
html_content = self._extract_html_from_zip(export_path)
elif export_path.suffix.lower() == ".html":
# Read HTML file directly
try:
with open(export_path, encoding="utf-8", errors="ignore") as f:
html_content = f.read()
except Exception as e:
print(f"Error reading HTML file {export_path}: {e}")
return docs
else:
print(f"Unsupported file type: {export_path.suffix}")
return docs
elif export_path.is_dir():
# Look for HTML files in directory
html_files = list(export_path.glob("*.html"))
zip_files = list(export_path.glob("*.zip"))
if html_files:
# Use first HTML file found
html_file = html_files[0]
print(f"Found HTML file: {html_file}")
try:
with open(html_file, encoding="utf-8", errors="ignore") as f:
html_content = f.read()
except Exception as e:
print(f"Error reading HTML file {html_file}: {e}")
return docs
elif zip_files:
# Use first zip file found
zip_file = zip_files[0]
print(f"Found zip file: {zip_file}")
html_content = self._extract_html_from_zip(zip_file)
else:
print(f"No HTML or zip files found in {export_path}")
return docs
if not html_content:
print("No HTML content found to process")
return docs
# Parse conversations from HTML
print("Parsing ChatGPT conversations from HTML...")
conversations = self._parse_chatgpt_html(html_content)
if not conversations:
print("No conversations found in HTML content")
return docs
print(f"Found {len(conversations)} conversations")
# Process conversations into documents
count = 0
for conversation in conversations:
if max_count > 0 and count >= max_count:
break
if self.concatenate_conversations:
# Create one document per conversation with concatenated messages
doc_content = self._create_concatenated_content(conversation)
metadata = {}
if include_metadata:
metadata = {
"title": conversation.get("title", "ChatGPT Conversation"),
"timestamp": conversation.get("timestamp", "Unknown"),
"message_count": len(conversation.get("messages", [])),
"source": "ChatGPT Export",
}
doc = Document(text=doc_content, metadata=metadata)
docs.append(doc)
count += 1
else:
# Create separate documents for each message
for message in conversation.get("messages", []):
if max_count > 0 and count >= max_count:
break
role = message.get("role", "mixed")
content = message.get("content", "")
msg_timestamp = message.get("timestamp", "")
if not content.strip():
continue
# Create document content with context
doc_content = f"""Conversation: {conversation.get("title", "ChatGPT Conversation")}
Role: {role}
Timestamp: {msg_timestamp or conversation.get("timestamp", "Unknown")}
Message: {content}
"""
metadata = {}
if include_metadata:
metadata = {
"conversation_title": conversation.get("title", "ChatGPT Conversation"),
"role": role,
"timestamp": msg_timestamp or conversation.get("timestamp", "Unknown"),
"source": "ChatGPT Export",
}
doc = Document(text=doc_content, metadata=metadata)
docs.append(doc)
count += 1
print(f"Created {len(docs)} documents from ChatGPT export")
return docs
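
For orientation, a minimal usage sketch of this reader, assuming the `apps` directory is on `sys.path`; the export path below is a placeholder:

```python
from chatgpt_data.chatgpt_reader import ChatGPTReader

reader = ChatGPTReader(concatenate_conversations=True)
docs = reader.load_data(
    chatgpt_export_path="./chatgpt_export/chat.html",  # placeholder path
    max_count=10,
    include_metadata=True,
)
print(f"Loaded {len(docs)} documents")
```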

apps/chatgpt_rag.py (new file, 186 lines)
View File

@@ -0,0 +1,186 @@
"""
ChatGPT RAG example using the unified interface.
Supports ChatGPT export data from chat.html files.
"""
import sys
from pathlib import Path
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))
from base_rag_example import BaseRAGExample
from chunking import create_text_chunks
from chatgpt_data.chatgpt_reader import ChatGPTReader  # absolute import; a relative import fails when run as a script
class ChatGPTRAG(BaseRAGExample):
"""RAG example for ChatGPT conversation data."""
def __init__(self):
# Set default values BEFORE calling super().__init__
self.max_items_default = -1 # Process all conversations by default
self.embedding_model_default = (
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
)
super().__init__(
name="ChatGPT",
description="Process and query ChatGPT conversation exports with LEANN",
default_index_name="chatgpt_conversations_index",
)
def _add_specific_arguments(self, parser):
"""Add ChatGPT-specific arguments."""
chatgpt_group = parser.add_argument_group("ChatGPT Parameters")
chatgpt_group.add_argument(
"--export-path",
type=str,
default="./chatgpt_export",
help="Path to ChatGPT export file (.zip or .html) or directory containing exports (default: ./chatgpt_export)",
)
chatgpt_group.add_argument(
"--concatenate-conversations",
action="store_true",
default=True,
help="Concatenate messages within conversations for better context (default: True)",
)
chatgpt_group.add_argument(
"--separate-messages",
action="store_true",
help="Process each message as a separate document (overrides --concatenate-conversations)",
)
chatgpt_group.add_argument(
"--chunk-size", type=int, default=512, help="Text chunk size (default: 512)"
)
chatgpt_group.add_argument(
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
)
def _find_chatgpt_exports(self, export_path: Path) -> list[Path]:
"""
Find ChatGPT export files in the given path.
Args:
export_path: Path to search for exports
Returns:
List of paths to ChatGPT export files
"""
export_files = []
if export_path.is_file():
if export_path.suffix.lower() in [".zip", ".html"]:
export_files.append(export_path)
elif export_path.is_dir():
# Look for zip and html files
export_files.extend(export_path.glob("*.zip"))
export_files.extend(export_path.glob("*.html"))
return export_files
async def load_data(self, args) -> list[str]:
"""Load ChatGPT export data and convert to text chunks."""
export_path = Path(args.export_path)
if not export_path.exists():
print(f"ChatGPT export path not found: {export_path}")
print(
"Please ensure you have exported your ChatGPT data and placed it in the correct location."
)
print("\nTo export your ChatGPT data:")
print("1. Sign in to ChatGPT")
print("2. Click on your profile icon → Settings → Data Controls")
print("3. Click 'Export' under Export Data")
print("4. Download the zip file from the email link")
print("5. Extract or place the file/directory at the specified path")
return []
# Find export files
export_files = self._find_chatgpt_exports(export_path)
if not export_files:
print(f"No ChatGPT export files (.zip or .html) found in: {export_path}")
return []
print(f"Found {len(export_files)} ChatGPT export files")
# Create reader with appropriate settings
concatenate = args.concatenate_conversations and not args.separate_messages
reader = ChatGPTReader(concatenate_conversations=concatenate)
# Process each export file
all_documents = []
total_processed = 0
for i, export_file in enumerate(export_files):
print(f"\nProcessing export file {i + 1}/{len(export_files)}: {export_file.name}")
try:
# Apply max_items limit per file
max_per_file = -1
if args.max_items > 0:
remaining = args.max_items - total_processed
if remaining <= 0:
break
max_per_file = remaining
# Load conversations
documents = reader.load_data(
chatgpt_export_path=str(export_file),
max_count=max_per_file,
include_metadata=True,
)
if documents:
all_documents.extend(documents)
total_processed += len(documents)
print(f"Processed {len(documents)} conversations from this file")
else:
print(f"No conversations loaded from {export_file}")
except Exception as e:
print(f"Error processing {export_file}: {e}")
continue
if not all_documents:
print("No conversations found to process!")
print("\nTroubleshooting:")
print("- Ensure the export file is a valid ChatGPT export")
print("- Check that the HTML file contains conversation data")
print("- Try extracting the zip file and pointing to the HTML file directly")
return []
print(f"\nTotal conversations processed: {len(all_documents)}")
print("Now starting to split into text chunks... this may take some time")
# Convert to text chunks
all_texts = create_text_chunks(
all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
)
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} conversations")
return all_texts
if __name__ == "__main__":
import asyncio
# Example queries for ChatGPT RAG
print("\n🤖 ChatGPT RAG Example")
print("=" * 50)
print("\nExample queries you can try:")
print("- 'What did I ask about Python programming?'")
print("- 'Show me conversations about machine learning'")
print("- 'Find discussions about travel planning'")
print("- 'What advice did ChatGPT give me about career development?'")
print("- 'Search for conversations about cooking recipes'")
print("\nTo get started:")
print("1. Export your ChatGPT data from Settings → Data Controls → Export")
print("2. Place the downloaded zip file or extracted HTML in ./chatgpt_export/")
print("3. Run this script to build your personal ChatGPT knowledge base!")
print("\nOr run without --query for interactive mode\n")
rag = ChatGPTRAG()
asyncio.run(rag.run())

View File

@@ -0,0 +1,420 @@
"""
Claude export data reader.
Reads and processes Claude conversation data from exported JSON files.
"""
import json
from pathlib import Path
from typing import Any
from zipfile import ZipFile
from llama_index.core import Document
from llama_index.core.readers.base import BaseReader
class ClaudeReader(BaseReader):
"""
Claude export data reader.
Reads Claude conversation data from exported JSON files or zip archives.
Processes conversations into structured documents with metadata.
"""
def __init__(self, concatenate_conversations: bool = True) -> None:
"""
Initialize.
Args:
concatenate_conversations: Whether to concatenate messages within conversations for better context
"""
self.concatenate_conversations = concatenate_conversations
def _extract_json_from_zip(self, zip_path: Path) -> list[str]:
"""
Extract JSON files from Claude export zip file.
Args:
zip_path: Path to the Claude export zip file
Returns:
List of JSON content strings, or empty list if not found
"""
json_contents = []
try:
with ZipFile(zip_path, "r") as zip_file:
# Look for JSON files
json_files = [f for f in zip_file.namelist() if f.endswith(".json")]
if not json_files:
print(f"No JSON files found in {zip_path}")
return []
print(f"Found {len(json_files)} JSON files in archive")
for json_file in json_files:
with zip_file.open(json_file) as f:
content = f.read().decode("utf-8", errors="ignore")
json_contents.append(content)
except Exception as e:
print(f"Error extracting JSON from zip {zip_path}: {e}")
return json_contents
def _parse_claude_json(self, json_content: str) -> list[dict]:
"""
Parse Claude JSON export to extract conversations.
Args:
json_content: JSON content from Claude export
Returns:
List of conversation dictionaries
"""
try:
data = json.loads(json_content)
except json.JSONDecodeError as e:
print(f"Error parsing JSON: {e}")
return []
conversations = []
# Handle different possible JSON structures
if isinstance(data, list):
# If data is a list of conversations
for item in data:
conversation = self._extract_conversation_from_json(item)
if conversation:
conversations.append(conversation)
elif isinstance(data, dict):
# Check for common structures
if "conversations" in data:
# Structure: {"conversations": [...]}
for item in data["conversations"]:
conversation = self._extract_conversation_from_json(item)
if conversation:
conversations.append(conversation)
elif "messages" in data:
# Single conversation with messages
conversation = self._extract_conversation_from_json(data)
if conversation:
conversations.append(conversation)
else:
# Try to treat the whole object as a conversation
conversation = self._extract_conversation_from_json(data)
if conversation:
conversations.append(conversation)
return conversations
def _extract_conversation_from_json(self, conv_data: dict) -> dict | None:
"""
Extract conversation data from a JSON object.
Args:
conv_data: Dictionary containing conversation data
Returns:
Dictionary with conversation data or None
"""
if not isinstance(conv_data, dict):
return None
messages = []
# Look for messages in various possible structures
message_sources = []
if "messages" in conv_data:
message_sources = conv_data["messages"]
elif "chat" in conv_data:
message_sources = conv_data["chat"]
elif "conversation" in conv_data:
message_sources = conv_data["conversation"]
else:
# If no clear message structure, try to extract from the object itself
if "content" in conv_data and "role" in conv_data:
message_sources = [conv_data]
for msg_data in message_sources:
message = self._extract_message_from_json(msg_data)
if message:
messages.append(message)
if not messages:
return None
# Extract conversation metadata
title = self._extract_title_from_conversation(conv_data, messages)
timestamp = self._extract_timestamp_from_conversation(conv_data)
return {"title": title, "messages": messages, "timestamp": timestamp}
def _extract_message_from_json(self, msg_data: dict) -> dict | None:
"""
Extract message data from a JSON message object.
Args:
msg_data: Dictionary containing message data
Returns:
Dictionary with message data or None
"""
if not isinstance(msg_data, dict):
return None
# Extract content from various possible fields
content = ""
content_fields = ["content", "text", "message", "body"]
for field in content_fields:
if msg_data.get(field):
content = str(msg_data[field])
break
if not content or len(content.strip()) < 3:
return None
# Extract role (user/assistant/human/ai/claude)
role = "mixed" # Default role
role_fields = ["role", "sender", "from", "author", "type"]
for field in role_fields:
if msg_data.get(field):
role_value = str(msg_data[field]).lower()
if role_value in ["user", "human", "person"]:
role = "user"
elif role_value in ["assistant", "ai", "claude", "bot"]:
role = "assistant"
break
# Extract timestamp
timestamp = self._extract_timestamp_from_message(msg_data)
return {"role": role, "content": content, "timestamp": timestamp}
def _extract_timestamp_from_message(self, msg_data: dict) -> str | None:
"""Extract timestamp from message data."""
timestamp_fields = ["timestamp", "created_at", "date", "time"]
for field in timestamp_fields:
if msg_data.get(field):
return str(msg_data[field])
return None
def _extract_timestamp_from_conversation(self, conv_data: dict) -> str | None:
"""Extract timestamp from conversation data."""
timestamp_fields = ["timestamp", "created_at", "date", "updated_at", "last_updated"]
for field in timestamp_fields:
if conv_data.get(field):
return str(conv_data[field])
return None
def _extract_title_from_conversation(self, conv_data: dict, messages: list) -> str:
"""Extract or generate title for conversation."""
# Try to find explicit title
title_fields = ["title", "name", "subject", "topic"]
for field in title_fields:
if conv_data.get(field):
return str(conv_data[field])
# Generate title from first user message
for message in messages:
if message.get("role") == "user":
content = message.get("content", "")
if content:
# Use first 50 characters as title
title = content[:50].strip()
if len(content) > 50:
title += "..."
return title
return "Claude Conversation"
def _create_concatenated_content(self, conversation: dict) -> str:
"""
Create concatenated content from conversation messages.
Args:
conversation: Dictionary containing conversation data
Returns:
Formatted concatenated content
"""
title = conversation.get("title", "Claude Conversation")
messages = conversation.get("messages", [])
timestamp = conversation.get("timestamp", "Unknown")
# Build message content
message_parts = []
for message in messages:
role = message.get("role", "mixed")
content = message.get("content", "")
msg_timestamp = message.get("timestamp", "")
if role == "user":
prefix = "[You]"
elif role == "assistant":
prefix = "[Claude]"
else:
prefix = "[Message]"
# Add timestamp if available
if msg_timestamp:
prefix += f" ({msg_timestamp})"
message_parts.append(f"{prefix}: {content}")
concatenated_text = "\n\n".join(message_parts)
# Create final document content
doc_content = f"""Conversation: {title}
Date: {timestamp}
Messages ({len(messages)} messages):
{concatenated_text}
"""
return doc_content
def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
"""
Load Claude export data.
Args:
input_dir: Directory containing Claude export files or path to specific file
**load_kwargs:
max_count (int): Maximum number of conversations to process
claude_export_path (str): Specific path to Claude export file/directory
include_metadata (bool): Whether to include metadata in documents
"""
docs: list[Document] = []
max_count = load_kwargs.get("max_count", -1)
claude_export_path = load_kwargs.get("claude_export_path", input_dir)
include_metadata = load_kwargs.get("include_metadata", True)
if not claude_export_path:
print("No Claude export path provided")
return docs
export_path = Path(claude_export_path)
if not export_path.exists():
print(f"Claude export path not found: {export_path}")
return docs
json_contents = []
# Handle different input types
if export_path.is_file():
if export_path.suffix.lower() == ".zip":
# Extract JSON from zip file
json_contents = self._extract_json_from_zip(export_path)
elif export_path.suffix.lower() == ".json":
# Read JSON file directly
try:
with open(export_path, encoding="utf-8", errors="ignore") as f:
json_contents.append(f.read())
except Exception as e:
print(f"Error reading JSON file {export_path}: {e}")
return docs
else:
print(f"Unsupported file type: {export_path.suffix}")
return docs
elif export_path.is_dir():
# Look for JSON files in directory
json_files = list(export_path.glob("*.json"))
zip_files = list(export_path.glob("*.zip"))
if json_files:
print(f"Found {len(json_files)} JSON files in directory")
for json_file in json_files:
try:
with open(json_file, encoding="utf-8", errors="ignore") as f:
json_contents.append(f.read())
except Exception as e:
print(f"Error reading JSON file {json_file}: {e}")
continue
if zip_files:
print(f"Found {len(zip_files)} ZIP files in directory")
for zip_file in zip_files:
zip_contents = self._extract_json_from_zip(zip_file)
json_contents.extend(zip_contents)
if not json_files and not zip_files:
print(f"No JSON or ZIP files found in {export_path}")
return docs
if not json_contents:
print("No JSON content found to process")
return docs
# Parse conversations from JSON content
print("Parsing Claude conversations from JSON...")
all_conversations = []
for json_content in json_contents:
conversations = self._parse_claude_json(json_content)
all_conversations.extend(conversations)
if not all_conversations:
print("No conversations found in JSON content")
return docs
print(f"Found {len(all_conversations)} conversations")
# Process conversations into documents
count = 0
for conversation in all_conversations:
if max_count > 0 and count >= max_count:
break
if self.concatenate_conversations:
# Create one document per conversation with concatenated messages
doc_content = self._create_concatenated_content(conversation)
metadata = {}
if include_metadata:
metadata = {
"title": conversation.get("title", "Claude Conversation"),
"timestamp": conversation.get("timestamp", "Unknown"),
"message_count": len(conversation.get("messages", [])),
"source": "Claude Export",
}
doc = Document(text=doc_content, metadata=metadata)
docs.append(doc)
count += 1
else:
# Create separate documents for each message
for message in conversation.get("messages", []):
if max_count > 0 and count >= max_count:
break
role = message.get("role", "mixed")
content = message.get("content", "")
msg_timestamp = message.get("timestamp", "")
if not content.strip():
continue
# Create document content with context
doc_content = f"""Conversation: {conversation.get("title", "Claude Conversation")}
Role: {role}
Timestamp: {msg_timestamp or conversation.get("timestamp", "Unknown")}
Message: {content}
"""
metadata = {}
if include_metadata:
metadata = {
"conversation_title": conversation.get("title", "Claude Conversation"),
"role": role,
"timestamp": msg_timestamp or conversation.get("timestamp", "Unknown"),
"source": "Claude Export",
}
doc = Document(text=doc_content, metadata=metadata)
docs.append(doc)
count += 1
print(f"Created {len(docs)} documents from Claude export")
return docs

apps/claude_rag.py (new file, 189 lines)
View File

@@ -0,0 +1,189 @@
"""
Claude RAG example using the unified interface.
Supports Claude export data from JSON files.
"""
import sys
from pathlib import Path
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))
from base_rag_example import BaseRAGExample
from chunking import create_text_chunks
from claude_data.claude_reader import ClaudeReader  # absolute import; a relative import fails when run as a script
class ClaudeRAG(BaseRAGExample):
"""RAG example for Claude conversation data."""
def __init__(self):
# Set default values BEFORE calling super().__init__
self.max_items_default = -1 # Process all conversations by default
self.embedding_model_default = (
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
)
super().__init__(
name="Claude",
description="Process and query Claude conversation exports with LEANN",
default_index_name="claude_conversations_index",
)
def _add_specific_arguments(self, parser):
"""Add Claude-specific arguments."""
claude_group = parser.add_argument_group("Claude Parameters")
claude_group.add_argument(
"--export-path",
type=str,
default="./claude_export",
help="Path to Claude export file (.json or .zip) or directory containing exports (default: ./claude_export)",
)
claude_group.add_argument(
"--concatenate-conversations",
action="store_true",
default=True,
help="Concatenate messages within conversations for better context (default: True)",
)
claude_group.add_argument(
"--separate-messages",
action="store_true",
help="Process each message as a separate document (overrides --concatenate-conversations)",
)
claude_group.add_argument(
"--chunk-size", type=int, default=512, help="Text chunk size (default: 512)"
)
claude_group.add_argument(
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
)
def _find_claude_exports(self, export_path: Path) -> list[Path]:
"""
Find Claude export files in the given path.
Args:
export_path: Path to search for exports
Returns:
List of paths to Claude export files
"""
export_files = []
if export_path.is_file():
if export_path.suffix.lower() in [".zip", ".json"]:
export_files.append(export_path)
elif export_path.is_dir():
# Look for zip and json files
export_files.extend(export_path.glob("*.zip"))
export_files.extend(export_path.glob("*.json"))
return export_files
async def load_data(self, args) -> list[str]:
"""Load Claude export data and convert to text chunks."""
export_path = Path(args.export_path)
if not export_path.exists():
print(f"Claude export path not found: {export_path}")
print(
"Please ensure you have exported your Claude data and placed it in the correct location."
)
print("\nTo export your Claude data:")
print("1. Open Claude in your browser")
print("2. Look for export/download options in settings or conversation menu")
print("3. Download the conversation data (usually in JSON format)")
print("4. Place the file/directory at the specified path")
print(
"\nNote: Claude export methods may vary. Check Claude's help documentation for current instructions."
)
return []
# Find export files
export_files = self._find_claude_exports(export_path)
if not export_files:
print(f"No Claude export files (.json or .zip) found in: {export_path}")
return []
print(f"Found {len(export_files)} Claude export files")
# Create reader with appropriate settings
concatenate = args.concatenate_conversations and not args.separate_messages
reader = ClaudeReader(concatenate_conversations=concatenate)
# Process each export file
all_documents = []
total_processed = 0
for i, export_file in enumerate(export_files):
print(f"\nProcessing export file {i + 1}/{len(export_files)}: {export_file.name}")
try:
# Apply max_items limit per file
max_per_file = -1
if args.max_items > 0:
remaining = args.max_items - total_processed
if remaining <= 0:
break
max_per_file = remaining
# Load conversations
documents = reader.load_data(
claude_export_path=str(export_file),
max_count=max_per_file,
include_metadata=True,
)
if documents:
all_documents.extend(documents)
total_processed += len(documents)
print(f"Processed {len(documents)} conversations from this file")
else:
print(f"No conversations loaded from {export_file}")
except Exception as e:
print(f"Error processing {export_file}: {e}")
continue
if not all_documents:
print("No conversations found to process!")
print("\nTroubleshooting:")
print("- Ensure the export file is a valid Claude export")
print("- Check that the JSON file contains conversation data")
print("- Try using a different export format or method")
print("- Check Claude's documentation for current export procedures")
return []
print(f"\nTotal conversations processed: {len(all_documents)}")
print("Now starting to split into text chunks... this may take some time")
# Convert to text chunks
all_texts = create_text_chunks(
all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
)
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} conversations")
return all_texts
if __name__ == "__main__":
import asyncio
# Example queries for Claude RAG
print("\n🤖 Claude RAG Example")
print("=" * 50)
print("\nExample queries you can try:")
print("- 'What did I ask Claude about Python programming?'")
print("- 'Show me conversations about machine learning'")
print("- 'Find discussions about code optimization'")
print("- 'What advice did Claude give me about software design?'")
print("- 'Search for conversations about debugging techniques'")
print("\nTo get started:")
print("1. Export your Claude conversation data")
print("2. Place the JSON/ZIP file in ./claude_export/")
print("3. Run this script to build your personal Claude knowledge base!")
print("\nOr run without --query for interactive mode\n")
rag = ClaudeRAG()
asyncio.run(rag.run())

View File

@@ -0,0 +1 @@
"""iMessage data processing module."""

View File

@@ -0,0 +1,342 @@
"""
iMessage data reader.
Reads and processes iMessage conversation data from the macOS Messages database.
"""
import sqlite3
from datetime import datetime
from pathlib import Path
from typing import Any
from llama_index.core import Document
from llama_index.core.readers.base import BaseReader
class IMessageReader(BaseReader):
"""
iMessage data reader.
Reads iMessage conversation data from the macOS Messages database (chat.db).
Processes conversations into structured documents with metadata.
"""
def __init__(self, concatenate_conversations: bool = True) -> None:
"""
Initialize.
Args:
concatenate_conversations: Whether to concatenate messages within conversations for better context
"""
self.concatenate_conversations = concatenate_conversations
def _get_default_chat_db_path(self) -> Path:
"""
Get the default path to the iMessage chat database.
Returns:
Path to the chat.db file
"""
home = Path.home()
return home / "Library" / "Messages" / "chat.db"
def _convert_cocoa_timestamp(self, cocoa_timestamp: int) -> str:
"""
Convert Cocoa timestamp to readable format.
Args:
cocoa_timestamp: Timestamp in Cocoa format (nanoseconds since 2001-01-01)
Returns:
Formatted timestamp string
"""
if cocoa_timestamp == 0:
return "Unknown"
try:
# Cocoa timestamp is nanoseconds since 2001-01-01 00:00:00 UTC
# Convert to seconds and add to Unix epoch
cocoa_epoch = datetime(2001, 1, 1)
unix_timestamp = cocoa_timestamp / 1_000_000_000 # Convert nanoseconds to seconds
message_time = cocoa_epoch.timestamp() + unix_timestamp
return datetime.fromtimestamp(message_time).strftime("%Y-%m-%d %H:%M:%S")
except (ValueError, OSError):
return "Unknown"
def _get_contact_name(self, handle_id: str) -> str:
"""
Get a readable contact name from handle ID.
Args:
handle_id: The handle ID (phone number or email)
Returns:
Formatted contact name
"""
if not handle_id:
return "Unknown"
# Clean up phone numbers and emails for display
if "@" in handle_id:
return handle_id # Email address
elif handle_id.startswith("+"):
return handle_id # International phone number
else:
# Try to format as phone number
digits = "".join(filter(str.isdigit, handle_id))
if len(digits) == 10:
return f"({digits[:3]}) {digits[3:6]}-{digits[6:]}"
elif len(digits) == 11 and digits[0] == "1":
return f"+1 ({digits[1:4]}) {digits[4:7]}-{digits[7:]}"
else:
return handle_id
def _read_messages_from_db(self, db_path: Path) -> list[dict]:
"""
Read messages from the iMessage database.
Args:
db_path: Path to the chat.db file
Returns:
List of message dictionaries
"""
if not db_path.exists():
print(f"iMessage database not found at: {db_path}")
return []
try:
# Connect to the database
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
# Query to get messages with chat and handle information
query = """
SELECT
m.ROWID as message_id,
m.text,
m.date,
m.is_from_me,
m.service,
c.chat_identifier,
c.display_name as chat_display_name,
h.id as handle_id,
c.ROWID as chat_id
FROM message m
LEFT JOIN chat_message_join cmj ON m.ROWID = cmj.message_id
LEFT JOIN chat c ON cmj.chat_id = c.ROWID
LEFT JOIN handle h ON m.handle_id = h.ROWID
WHERE m.text IS NOT NULL AND m.text != ''
ORDER BY c.ROWID, m.date
"""
cursor.execute(query)
rows = cursor.fetchall()
messages = []
for row in rows:
(
message_id,
text,
date,
is_from_me,
service,
chat_identifier,
chat_display_name,
handle_id,
chat_id,
) = row
message = {
"message_id": message_id,
"text": text,
"timestamp": self._convert_cocoa_timestamp(date),
"is_from_me": bool(is_from_me),
"service": service or "iMessage",
"chat_identifier": chat_identifier or "Unknown",
"chat_display_name": chat_display_name or "Unknown Chat",
"handle_id": handle_id or "Unknown",
"contact_name": self._get_contact_name(handle_id or ""),
"chat_id": chat_id,
}
messages.append(message)
conn.close()
print(f"Found {len(messages)} messages in database")
return messages
except sqlite3.Error as e:
print(f"Error reading iMessage database: {e}")
return []
except Exception as e:
print(f"Unexpected error reading iMessage database: {e}")
return []
def _group_messages_by_chat(self, messages: list[dict]) -> dict[int, list[dict]]:
"""
Group messages by chat ID.
Args:
messages: List of message dictionaries
Returns:
Dictionary mapping chat_id to list of messages
"""
chats = {}
for message in messages:
chat_id = message["chat_id"]
if chat_id not in chats:
chats[chat_id] = []
chats[chat_id].append(message)
return chats
def _create_concatenated_content(self, chat_id: int, messages: list[dict]) -> str:
"""
Create concatenated content from chat messages.
Args:
chat_id: The chat ID
messages: List of messages in the chat
Returns:
Concatenated text content
"""
if not messages:
return ""
# Get chat info from first message
first_msg = messages[0]
chat_name = first_msg["chat_display_name"]
chat_identifier = first_msg["chat_identifier"]
# Build message content
message_parts = []
for message in messages:
timestamp = message["timestamp"]
is_from_me = message["is_from_me"]
text = message["text"]
contact_name = message["contact_name"]
if is_from_me:
prefix = "[You]"
else:
prefix = f"[{contact_name}]"
if timestamp != "Unknown":
prefix += f" ({timestamp})"
message_parts.append(f"{prefix}: {text}")
concatenated_text = "\n\n".join(message_parts)
doc_content = f"""Chat: {chat_name}
Identifier: {chat_identifier}
Messages ({len(messages)} messages):
{concatenated_text}
"""
return doc_content
def _create_individual_content(self, message: dict) -> str:
"""
Create content for individual message.
Args:
message: Message dictionary
Returns:
Formatted message content
"""
timestamp = message["timestamp"]
is_from_me = message["is_from_me"]
text = message["text"]
contact_name = message["contact_name"]
chat_name = message["chat_display_name"]
sender = "You" if is_from_me else contact_name
return f"""Message from {sender} in chat "{chat_name}"
Time: {timestamp}
Content: {text}
"""
def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
"""
Load iMessage data and return as documents.
Args:
input_dir: Optional path to directory containing chat.db file.
If not provided, uses default macOS location.
**load_kwargs: Additional arguments (unused)
Returns:
List of Document objects containing iMessage data
"""
docs = []
# Determine database path
if input_dir:
db_path = Path(input_dir) / "chat.db"
else:
db_path = self._get_default_chat_db_path()
print(f"Reading iMessage database from: {db_path}")
# Read messages from database
messages = self._read_messages_from_db(db_path)
if not messages:
return docs
if self.concatenate_conversations:
# Group messages by chat and create concatenated documents
chats = self._group_messages_by_chat(messages)
for chat_id, chat_messages in chats.items():
if not chat_messages:
continue
content = self._create_concatenated_content(chat_id, chat_messages)
# Create metadata
first_msg = chat_messages[0]
last_msg = chat_messages[-1]
metadata = {
"source": "iMessage",
"chat_id": chat_id,
"chat_name": first_msg["chat_display_name"],
"chat_identifier": first_msg["chat_identifier"],
"message_count": len(chat_messages),
"first_message_date": first_msg["timestamp"],
"last_message_date": last_msg["timestamp"],
"participants": list(
{msg["contact_name"] for msg in chat_messages if not msg["is_from_me"]}
),
}
doc = Document(text=content, metadata=metadata)
docs.append(doc)
else:
# Create individual documents for each message
for message in messages:
content = self._create_individual_content(message)
metadata = {
"source": "iMessage",
"message_id": message["message_id"],
"chat_id": message["chat_id"],
"chat_name": message["chat_display_name"],
"chat_identifier": message["chat_identifier"],
"timestamp": message["timestamp"],
"is_from_me": message["is_from_me"],
"contact_name": message["contact_name"],
"service": message["service"],
}
doc = Document(text=content, metadata=metadata)
docs.append(doc)
print(f"Created {len(docs)} documents from iMessage data")
return docs

apps/imessage_rag.py (new file, 125 lines)
View File

@@ -0,0 +1,125 @@
"""
iMessage RAG Example.
This example demonstrates how to build a RAG system on your iMessage conversation history.
"""
import asyncio
from pathlib import Path
from leann.chunking_utils import create_text_chunks
from apps.base_rag_example import BaseRAGExample
from apps.imessage_data.imessage_reader import IMessageReader
class IMessageRAG(BaseRAGExample):
"""RAG example for iMessage conversation history."""
def __init__(self):
super().__init__(
name="iMessage",
description="RAG on your iMessage conversation history",
default_index_name="imessage_index",
)
def _add_specific_arguments(self, parser):
"""Add iMessage-specific arguments."""
imessage_group = parser.add_argument_group("iMessage Parameters")
imessage_group.add_argument(
"--db-path",
type=str,
default=None,
help="Path to iMessage chat.db file (default: ~/Library/Messages/chat.db)",
)
imessage_group.add_argument(
"--concatenate-conversations",
action="store_true",
default=True,
help="Concatenate messages within conversations for better context (default: True)",
)
imessage_group.add_argument(
"--no-concatenate-conversations",
action="store_true",
help="Process each message individually instead of concatenating by conversation",
)
imessage_group.add_argument(
"--chunk-size",
type=int,
default=1000,
help="Maximum characters per text chunk (default: 1000)",
)
imessage_group.add_argument(
"--chunk-overlap",
type=int,
default=200,
help="Overlap between text chunks (default: 200)",
)
async def load_data(self, args) -> list[str]:
"""Load iMessage history and convert to text chunks."""
print("Loading iMessage conversation history...")
# Determine concatenation setting
concatenate = args.concatenate_conversations and not args.no_concatenate_conversations
# Initialize iMessage reader
reader = IMessageReader(concatenate_conversations=concatenate)
# Load documents
try:
if args.db_path:
# Use custom database path
db_dir = str(Path(args.db_path).parent)
documents = reader.load_data(input_dir=db_dir)
else:
# Use default macOS location
documents = reader.load_data()
except Exception as e:
print(f"Error loading iMessage data: {e}")
print("\nTroubleshooting tips:")
print("1. Make sure you have granted Full Disk Access to your terminal/IDE")
print("2. Check that the iMessage database exists at ~/Library/Messages/chat.db")
print("3. Try specifying a custom path with --db-path if you have a backup")
return []
if not documents:
print("No iMessage conversations found!")
return []
print(f"Loaded {len(documents)} iMessage documents")
# Show some statistics
total_messages = sum(doc.metadata.get("message_count", 1) for doc in documents)
print(f"Total messages: {total_messages}")
if concatenate:
# Show chat statistics
chat_names = [doc.metadata.get("chat_name", "Unknown") for doc in documents]
unique_chats = len(set(chat_names))
print(f"Unique conversations: {unique_chats}")
# Convert to text chunks
all_texts = create_text_chunks(
documents,
chunk_size=args.chunk_size,
chunk_overlap=args.chunk_overlap,
)
# Apply max_items limit if specified
if args.max_items > 0:
all_texts = all_texts[: args.max_items]
print(f"Limited to {len(all_texts)} text chunks (max_items={args.max_items})")
return all_texts
async def main():
"""Main entry point."""
app = IMessageRAG()
await app.run()
if __name__ == "__main__":
asyncio.run(main())
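
A hedged example invocation, assuming the script is run from the repository root with the flags defined above:

```bash
# Grant Full Disk Access to your terminal first, then:
python apps/imessage_rag.py --query "What did we plan for the trip?"

# Or point at a backup copy of the database:
python apps/imessage_rag.py --db-path ~/backup/chat.db
```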

View File

@@ -0,0 +1,113 @@
## Vision-based PDF Multi-Vector Demos (macOS/MPS)
This folder contains two demos to index PDF pages as images and run multi-vector retrieval with ColPali/ColQwen2, plus optional similarity map visualization and answer generation.
### What you'll run
- `multi-vector-leann-paper-example.py`: local PDF → pages → embed → build HNSW index → search.
- `multi-vector-leann-similarity-map.py`: HF dataset (default) or local pages → embed → index → retrieve → similarity maps → optional Qwen-VL answer.
## Prerequisites (macOS)
### 1) Homebrew poppler (for pdf2image)
```bash
brew install poppler
which pdfinfo && pdfinfo -v
```
### 2) Python environment
Use uv (recommended) or pip. Python 3.9+.
Using uv:
```bash
uv pip install \
colpali_engine \
pdf2image \
pillow \
matplotlib qwen_vl_utils \
einops \
seaborn
```
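The same set with plain pip:
```bash
pip install colpali_engine pdf2image pillow matplotlib qwen_vl_utils einops seaborn
```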
Notes:
- On first run, models are downloaded from Hugging Face; log in or configure credentials if needed.
- The scripts auto-select device: CUDA > MPS > CPU. Verify MPS:
```bash
python -c "import torch; print('MPS available:', bool(getattr(torch.backends, 'mps', None) and torch.backends.mps.is_available()))"
```
## Run the demos
### A) Local PDF example
Converts a local PDF into page images, embeds them, builds an index, and searches.
```bash
cd apps/multimodal/vision-based-pdf-multi-vector
# If you don't have the sample PDF locally, download it (ignored by Git)
mkdir -p pdfs
curl -L -o pdfs/2004.12832v2.pdf https://arxiv.org/pdf/2004.12832.pdf
ls pdfs/2004.12832v2.pdf
# Ensure output dir exists
mkdir -p pages
python multi-vector-leann-paper-example.py
```
Expected:
- Page images in `pages/`.
- Console prints like `Using device=mps, dtype=...` and retrieved file paths for queries.
To use your own PDF: edit `pdf_path` near the top of the script.
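For example (placeholder path):
```python
pdf_path = "pdfs/my_paper.pdf"  # placeholder; any local PDF works
```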
### B) Similarity map + answer demo
Uses HF dataset `weaviate/arXiv-AI-papers-multi-vector` by default; can switch to local pages.
```bash
cd apps/multimodal/vision-based-pdf-multi-vector
python multi-vector-leann-similarity-map.py
```
Artifacts (when enabled):
- Retrieved pages: `./figures/retrieved_page_rank{K}.png`
- Similarity maps: `./figures/similarity_map_rank{K}.png`
Key knobs in the script (top of file; a sketch follows this list):
- `QUERY`: your question
- `MODEL`: `"colqwen2"` or `"colpali"`
- `USE_HF_DATASET`: set `False` to use local pages
- `PDF`, `PAGES_DIR`: for local mode
- `INDEX_PATH`, `TOPK`, `FIRST_STAGE_K`, `REBUILD_INDEX`
- `SIMILARITY_MAP`, `SIM_TOKEN_IDX`, `SIM_OUTPUT`
- `ANSWER`, `MAX_NEW_TOKENS` (Qwen-VL)
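A hedged sketch of what that configuration block might look like. Only the knob names come from the list above; every value here is illustrative:
```python
QUERY = "How does DeepSeek-V2 compare against the LLaMA family of LLMs?"
MODEL = "colqwen2"                # or "colpali"
USE_HF_DATASET = True             # False -> use local PDF pages
PDF = "pdfs/2004.12832v2.pdf"     # local mode only
PAGES_DIR = "pages"               # local mode only
INDEX_PATH = "./indexes/demo"     # illustrative path
TOPK = 1
FIRST_STAGE_K = 50
REBUILD_INDEX = False
SIMILARITY_MAP = True             # generate heatmaps
SIM_TOKEN_IDX = -1                # -1 auto-selects the most salient token
SIM_OUTPUT = "./figures/similarity_map.png"
ANSWER = False                    # enable Qwen-VL answer generation
MAX_NEW_TOKENS = 128
```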
## Troubleshooting
- pdf2image errors on macOS: ensure `brew install poppler` and `pdfinfo` works in terminal.
- Slow or OOM on MPS: reduce dataset size (e.g., set `MAX_DOCS`) or switch to CPU.
- NaNs on MPS: keep fp32 on MPS (default in similarity-map script); avoid fp16 there.
- First-run model downloads can be large; ensure network access (HF mirrors if needed).
## Notes
- Index files are under `./indexes/`. Delete or set `REBUILD_INDEX=True` to rebuild.
- For local PDFs, page images go to `./pages/`.
### Retrieval and Visualization Example
Example settings in `multi-vector-leann-similarity-map.py`:
- `QUERY = "How does DeepSeek-V2 compare against the LLaMA family of LLMs?"`
- `SIMILARITY_MAP = True` (to generate heatmaps)
- `TOPK = 1` (save the top retrieved page and its similarity map)
Run:
```bash
cd apps/multimodal/vision-based-pdf-multi-vector
python multi-vector-leann-similarity-map.py
```
Outputs (by default):
- Retrieved page: `./figures/retrieved_page_rank1.png`
- Similarity map: `./figures/similarity_map_rank1.png`
Sample visualization (example result, generated with QUERY = "How does Vim model performance and efficiency compared to other models?"):
![Similarity map example](fig/image.png)
Notes:
- Set `SIM_TOKEN_IDX` to visualize a specific token index; set `-1` to auto-select the most salient token.
- If you change `SIM_OUTPUT` to a file path (e.g., `./figures/my_map.png`), multiple ranks are saved as `my_map_rank{K}.png`.

View File

Binary image added (166 KiB; not shown).

View File

@@ -0,0 +1,182 @@
from __future__ import annotations
import sys
from pathlib import Path
import numpy as np
def _ensure_repo_paths_importable(current_file: str) -> None:
_repo_root = Path(current_file).resolve().parents[3]
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
if str(_leann_core_src) not in sys.path:
sys.path.append(str(_leann_core_src))
if str(_leann_hnsw_pkg) not in sys.path:
sys.path.append(str(_leann_hnsw_pkg))
_ensure_repo_paths_importable(__file__)
from leann_backend_hnsw.hnsw_backend import HNSWBuilder, HNSWSearcher # noqa: E402
class LeannMultiVector:
def __init__(
self,
index_path: str,
dim: int = 128,
distance_metric: str = "mips",
m: int = 16,
ef_construction: int = 500,
is_compact: bool = False,
is_recompute: bool = False,
embedding_model_name: str = "colvision",
) -> None:
self.index_path = index_path
self.dim = dim
self.embedding_model_name = embedding_model_name
self._pending_items: list[dict] = []
self._backend_kwargs = {
"distance_metric": distance_metric,
"M": m,
"efConstruction": ef_construction,
"is_compact": is_compact,
"is_recompute": is_recompute,
}
self._labels_meta: list[dict] = []
def _meta_dict(self) -> dict:
return {
"version": "1.0",
"backend_name": "hnsw",
"embedding_model": self.embedding_model_name,
"embedding_mode": "custom",
"dimensions": self.dim,
"backend_kwargs": self._backend_kwargs,
"is_compact": self._backend_kwargs.get("is_compact", True),
"is_pruned": self._backend_kwargs.get("is_compact", True)
and self._backend_kwargs.get("is_recompute", True),
}
def create_collection(self) -> None:
path = Path(self.index_path)
path.parent.mkdir(parents=True, exist_ok=True)
def insert(self, data: dict) -> None:
self._pending_items.append(
{
"doc_id": int(data["doc_id"]),
"filepath": data.get("filepath", ""),
"colbert_vecs": [np.asarray(v, dtype=np.float32) for v in data["colbert_vecs"]],
}
)
def _labels_path(self) -> Path:
index_path_obj = Path(self.index_path)
return index_path_obj.parent / f"{index_path_obj.name}.labels.json"
def _meta_path(self) -> Path:
index_path_obj = Path(self.index_path)
return index_path_obj.parent / f"{index_path_obj.name}.meta.json"
def create_index(self) -> None:
if not self._pending_items:
return
embeddings: list[np.ndarray] = []
labels_meta: list[dict] = []
for item in self._pending_items:
doc_id = int(item["doc_id"])
filepath = item.get("filepath", "")
colbert_vecs = item["colbert_vecs"]
for seq_id, vec in enumerate(colbert_vecs):
vec_np = np.asarray(vec, dtype=np.float32)
embeddings.append(vec_np)
labels_meta.append(
{
"id": f"{doc_id}:{seq_id}",
"doc_id": doc_id,
"seq_id": int(seq_id),
"filepath": filepath,
}
)
if not embeddings:
return
embeddings_np = np.vstack(embeddings).astype(np.float32)
# Log the stacked embedding matrix shape: (total_vectors, dim)
print(f"Stacked embeddings shape: {embeddings_np.shape}")
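# Vectors receive sequential string ids; search() maps a returned id back to
# labels_meta[int(id)] to recover the owning doc_id, so order must match.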
builder = HNSWBuilder(**{**self._backend_kwargs, "dimensions": self.dim})
ids = [str(i) for i in range(embeddings_np.shape[0])]
builder.build(embeddings_np, ids, self.index_path)
import json as _json
with open(self._meta_path(), "w", encoding="utf-8") as f:
_json.dump(self._meta_dict(), f, indent=2)
with open(self._labels_path(), "w", encoding="utf-8") as f:
_json.dump(labels_meta, f)
self._labels_meta = labels_meta
def _load_labels_meta_if_needed(self) -> None:
if self._labels_meta:
return
labels_path = self._labels_path()
if labels_path.exists():
import json as _json
with open(labels_path, encoding="utf-8") as f:
self._labels_meta = _json.load(f)
def search(
self, data: np.ndarray, topk: int, first_stage_k: int = 50
) -> list[tuple[float, int]]:
if data.ndim == 1:
data = data.reshape(1, -1)
if data.dtype != np.float32:
data = data.astype(np.float32)
self._load_labels_meta_if_needed()
searcher = HNSWSearcher(self.index_path, meta=self._meta_dict())
raw = searcher.search(
data,
first_stage_k,
recompute_embeddings=False,
complexity=128,
beam_width=1,
prune_ratio=0.0,
batch_size=0,
)
labels = raw.get("labels")
distances = raw.get("distances")
if labels is None or distances is None:
return []
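# ColBERT-style late interaction: for each query token (row b), keep the best
# patch score per document, then sum those per-token maxima into a per-document
# MaxSim score.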
doc_scores: dict[int, float] = {}
B = len(labels)
for b in range(B):
per_doc_best: dict[int, float] = {}
for k, sid in enumerate(labels[b]):
try:
idx = int(sid)
except Exception:
continue
if 0 <= idx < len(self._labels_meta):
doc_id = int(self._labels_meta[idx]["doc_id"]) # type: ignore[index]
else:
continue
score = float(distances[b][k])
if (doc_id not in per_doc_best) or (score > per_doc_best[doc_id]):
per_doc_best[doc_id] = score
for doc_id, best_score in per_doc_best.items():
doc_scores[doc_id] = doc_scores.get(doc_id, 0.0) + best_score
scores = sorted(((v, k) for k, v in doc_scores.items()), key=lambda x: x[0], reverse=True)
return scores[:topk]
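# Usage sketch (assumes per-page multi-vector embeddings already computed):
#   retriever = LeannMultiVector(index_path="./indexes/demo.leann", dim=128)
#   retriever.create_collection()
#   retriever.insert({"doc_id": 0, "filepath": "page_1.png", "colbert_vecs": vecs})
#   retriever.create_index()
#   hits = retriever.search(query_vecs, topk=3)  # [(score, doc_id), ...]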

View File

@@ -0,0 +1,112 @@
# pip install pdf2image
# pip install colpali_engine
# pip install tqdm
# pip install pillow
import os
import re
import sys
from pathlib import Path
from typing import cast
from PIL import Image
from tqdm import tqdm
# Ensure local leann packages are importable before importing them
_repo_root = Path(__file__).resolve().parents[3]
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
if str(_leann_core_src) not in sys.path:
sys.path.append(str(_leann_core_src))
if str(_leann_hnsw_pkg) not in sys.path:
sys.path.append(str(_leann_hnsw_pkg))
import torch
from colpali_engine.models import ColPali
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
from colpali_engine.utils.torch_utils import ListDataset, get_torch_device
from torch.utils.data import DataLoader
# Auto-select device: CUDA > MPS (mac) > CPU
_device_str = (
"cuda"
if torch.cuda.is_available()
else (
"mps"
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
else "cpu"
)
)
device = get_torch_device(_device_str)
# Prefer fp16 on GPU/MPS, bfloat16 on CPU
_dtype = torch.float16 if _device_str in ("cuda", "mps") else torch.bfloat16
model_name = "vidore/colpali-v1.2"
model = ColPali.from_pretrained(
model_name,
torch_dtype=_dtype,
device_map=device,
).eval()
print(f"Using device={_device_str}, dtype={_dtype}")
queries = [
"How to end-to-end retrieval with ColBert",
"Where is ColBERT performance Table, including text representation results?",
]
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
dataloader = DataLoader(
dataset=ListDataset[str](queries),
batch_size=1,
shuffle=False,
collate_fn=lambda x: processor.process_queries(x),
)
qs: list[torch.Tensor] = []
for batch_query in dataloader:
with torch.no_grad():
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
embeddings_query = model(**batch_query)
qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
print(qs[0].shape)
# %%
page_filenames = sorted(
(n for n in os.listdir("./pages") if re.search(r"\d+", n)),
key=lambda n: int(re.search(r"\d+", n).group()),
)
images = [Image.open(os.path.join("./pages", name)) for name in page_filenames]
dataloader = DataLoader(
dataset=ListDataset[str](images),
batch_size=1,
shuffle=False,
collate_fn=lambda x: processor.process_images(x),
)
ds: list[torch.Tensor] = []
for batch_doc in tqdm(dataloader):
with torch.no_grad():
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
embeddings_doc = model(**batch_doc)
ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
print(ds[0].shape)
# %%
# Build HNSW index via LeannMultiVector primitives and run search
from leann_multi_vector import LeannMultiVector  # noqa: E402

index_path = "./indexes/colpali.leann"
retriever = LeannMultiVector(index_path=index_path, dim=int(ds[0].shape[-1]))
retriever.create_collection()
filepaths = [os.path.join("./pages", name) for name in page_filenames]
for i in range(len(filepaths)):
data = {
"colbert_vecs": ds[i].float().numpy(),
"doc_id": i,
"filepath": filepaths[i],
}
retriever.insert(data)
retriever.create_index()
for query in qs:
query_np = query.float().numpy()
result = retriever.search(query_np, topk=1)
print(filepaths[result[0][1]])

View File

@@ -0,0 +1,477 @@
## Jupyter-style notebook script
# %%
# uv pip install matplotlib qwen_vl_utils
import os
import re
import sys
from pathlib import Path
from typing import Any, Optional, cast
from PIL import Image
from tqdm import tqdm
def _ensure_repo_paths_importable(current_file: str) -> None:
"""Make local leann packages importable without installing (mirrors multi-vector-leann.py)."""
_repo_root = Path(current_file).resolve().parents[3]
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
if str(_leann_core_src) not in sys.path:
sys.path.append(str(_leann_core_src))
if str(_leann_hnsw_pkg) not in sys.path:
sys.path.append(str(_leann_hnsw_pkg))
_ensure_repo_paths_importable(__file__)
from leann_multi_vector import LeannMultiVector # noqa: E402
# %%
# Config
os.environ["TOKENIZERS_PARALLELISM"] = "false"
QUERY = "How does DeepSeek-V2 compare against the LLaMA family of LLMs?"
MODEL: str = "colqwen2" # "colpali" or "colqwen2"
# Data source: set to True to use the Hugging Face dataset example (recommended)
USE_HF_DATASET: bool = True
DATASET_NAME: str = "weaviate/arXiv-AI-papers-multi-vector"
DATASET_SPLIT: str = "train"
MAX_DOCS: Optional[int] = None # limit number of pages to index; None = all
# Local pages (used when USE_HF_DATASET == False)
PDF: Optional[str] = None # e.g., "./pdfs/2004.12832v2.pdf"
PAGES_DIR: str = "./pages"
# Index + retrieval settings
INDEX_PATH: str = "./indexes/colvision.leann"
TOPK: int = 1
FIRST_STAGE_K: int = 500
REBUILD_INDEX: bool = False
# Artifacts
SAVE_TOP_IMAGE: Optional[str] = "./figures/retrieved_page.png"
SIMILARITY_MAP: bool = True
SIM_TOKEN_IDX: int = 13 # -1 means auto-select the most salient token
SIM_OUTPUT: str = "./figures/similarity_map.png"
ANSWER: bool = True
MAX_NEW_TOKENS: int = 128
# %%
# Helpers
def _natural_sort_key(name: str) -> int:
m = re.search(r"\d+", name)
return int(m.group()) if m else 0
def _load_images_from_dir(pages_dir: str) -> tuple[list[str], list[Image.Image]]:
filenames = [n for n in os.listdir(pages_dir) if n.lower().endswith((".png", ".jpg", ".jpeg"))]
filenames = sorted(filenames, key=_natural_sort_key)
filepaths = [os.path.join(pages_dir, n) for n in filenames]
images = [Image.open(p) for p in filepaths]
return filepaths, images
def _maybe_convert_pdf_to_images(pdf_path: Optional[str], pages_dir: str, dpi: int = 200) -> None:
if not pdf_path:
return
os.makedirs(pages_dir, exist_ok=True)
try:
from pdf2image import convert_from_path
except Exception as e:
raise RuntimeError(
"pdf2image is required to convert PDF to images. Install via pip install pdf2image"
) from e
images = convert_from_path(pdf_path, dpi=dpi)
for i, image in enumerate(images):
image.save(os.path.join(pages_dir, f"page_{i + 1}.png"), "PNG")
def _select_device_and_dtype():
import torch
from colpali_engine.utils.torch_utils import get_torch_device
device_str = (
"cuda"
if torch.cuda.is_available()
else (
"mps"
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
else "cpu"
)
)
device = get_torch_device(device_str)
# Stable dtype selection to avoid NaNs:
# - CUDA: prefer bfloat16 if supported, else float16
# - MPS: use float32 (fp16 on MPS can produce NaNs in some ops)
# - CPU: float32
if device_str == "cuda":
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
try:
torch.backends.cuda.matmul.allow_tf32 = True # Better stability/perf on Ampere+
except Exception:
pass
elif device_str == "mps":
dtype = torch.float32
else:
dtype = torch.float32
return device_str, device, dtype
def _load_colvision(model_choice: str):
import torch
from colpali_engine.models import ColPali, ColQwen2, ColQwen2Processor
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
from transformers.utils.import_utils import is_flash_attn_2_available
device_str, device, dtype = _select_device_and_dtype()
if model_choice == "colqwen2":
model_name = "vidore/colqwen2-v1.0"
# On CPU/MPS we must avoid flash-attn and stay eager; on CUDA prefer flash-attn if available
attn_implementation = (
"flash_attention_2"
if (device_str == "cuda" and is_flash_attn_2_available())
else "eager"
)
model = ColQwen2.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map=device,
attn_implementation=attn_implementation,
).eval()
processor = ColQwen2Processor.from_pretrained(model_name)
else:
model_name = "vidore/colpali-v1.2"
model = ColPali.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map=device,
).eval()
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
return model_name, model, processor, device_str, device, dtype
def _embed_images(model, processor, images: list[Image.Image]) -> list[Any]:
import torch
from colpali_engine.utils.torch_utils import ListDataset
from torch.utils.data import DataLoader
# Ensure deterministic eval and autocast for stability
model.eval()
dataloader = DataLoader(
dataset=ListDataset[Image.Image](images),
batch_size=1,
shuffle=False,
collate_fn=lambda x: processor.process_images(x),
)
doc_vecs: list[Any] = []
for batch_doc in tqdm(dataloader, desc="Embedding images"):
with torch.no_grad():
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
# autocast on CUDA for bf16/fp16; on CPU/MPS stay in fp32
if model.device.type == "cuda":
with torch.autocast(
device_type="cuda",
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
):
embeddings_doc = model(**batch_doc)
else:
embeddings_doc = model(**batch_doc)
doc_vecs.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
return doc_vecs
def _embed_queries(model, processor, queries: list[str]) -> list[Any]:
import torch
from colpali_engine.utils.torch_utils import ListDataset
from torch.utils.data import DataLoader
model.eval()
dataloader = DataLoader(
dataset=ListDataset[str](queries),
batch_size=1,
shuffle=False,
collate_fn=lambda x: processor.process_queries(x),
)
q_vecs: list[Any] = []
for batch_query in tqdm(dataloader, desc="Embedding queries"):
with torch.no_grad():
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
if model.device.type == "cuda":
with torch.autocast(
device_type="cuda",
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
):
embeddings_query = model(**batch_query)
else:
embeddings_query = model(**batch_query)
q_vecs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
return q_vecs
def _build_index(index_path: str, doc_vecs: list[Any], filepaths: list[str]) -> LeannMultiVector:
dim = int(doc_vecs[0].shape[-1])
retriever = LeannMultiVector(index_path=index_path, dim=dim)
retriever.create_collection()
for i, vec in enumerate(doc_vecs):
data = {
"colbert_vecs": vec.float().numpy(),
"doc_id": i,
"filepath": filepaths[i],
}
retriever.insert(data)
retriever.create_index()
return retriever
def _load_retriever_if_index_exists(index_path: str, dim: int) -> Optional[LeannMultiVector]:
index_base = Path(index_path)
# Rough heuristic: index dir exists AND meta+labels files exist
meta = index_base.parent / f"{index_base.name}.meta.json"
labels = index_base.parent / f"{index_base.name}.labels.json"
if index_base.exists() and meta.exists() and labels.exists():
return LeannMultiVector(index_path=index_path, dim=dim)
return None
def _generate_similarity_map(
model,
processor,
image: Image.Image,
query: str,
token_idx: Optional[int] = None,
output_path: Optional[str] = None,
) -> tuple[int, float]:
import torch
from colpali_engine.interpretability import (
get_similarity_maps_from_embeddings,
plot_similarity_map,
)
batch_images = processor.process_images([image]).to(model.device)
batch_queries = processor.process_queries([query]).to(model.device)
with torch.no_grad():
image_embeddings = model.forward(**batch_images)
query_embeddings = model.forward(**batch_queries)
n_patches = processor.get_n_patches(
image_size=image.size,
spatial_merge_size=getattr(model, "spatial_merge_size", None),
)
image_mask = processor.get_image_mask(batch_images)
batched_similarity_maps = get_similarity_maps_from_embeddings(
image_embeddings=image_embeddings,
query_embeddings=query_embeddings,
n_patches=n_patches,
image_mask=image_mask,
)
similarity_maps = batched_similarity_maps[0]
# Determine token index if not provided: choose the token with highest max score
if token_idx is None:
per_token_max = similarity_maps.view(similarity_maps.shape[0], -1).max(dim=1).values
token_idx = int(per_token_max.argmax().item())
max_sim_score = similarity_maps[token_idx, :, :].max().item()
if output_path:
import matplotlib.pyplot as plt
fig, ax = plot_similarity_map(
image=image,
similarity_map=similarity_maps[token_idx],
figsize=(14, 14),
show_colorbar=False,
)
ax.set_title(f"Token #{token_idx}. MaxSim score: {max_sim_score:.2f}", fontsize=12)
os.makedirs(os.path.dirname(output_path), exist_ok=True)
plt.savefig(output_path, bbox_inches="tight")
plt.close(fig)
return token_idx, float(max_sim_score)
class QwenVL:
def __init__(self, device: str):
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
from transformers.utils.import_utils import is_flash_attn_2_available
attn_implementation = "flash_attention_2" if is_flash_attn_2_available() else "eager"
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
"Qwen/Qwen2.5-VL-3B-Instruct",
torch_dtype="auto",
device_map=device,
attn_implementation=attn_implementation,
)
min_pixels = 256 * 28 * 28
max_pixels = 1280 * 28 * 28
self.processor = AutoProcessor.from_pretrained(
"Qwen/Qwen2.5-VL-3B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels
)
def answer(self, query: str, images: list[Image.Image], max_new_tokens: int = 128) -> str:
import base64
from io import BytesIO
from qwen_vl_utils import process_vision_info
content = []
for img in images:
buffer = BytesIO()
img.save(buffer, format="jpeg")
img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
content.append({"type": "image", "image": f"data:image;base64,{img_base64}"})
content.append({"type": "text", "text": query})
messages = [{"role": "user", "content": content}]
text = self.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = self.processor(
text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt"
)
inputs = inputs.to(self.model.device)
generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
generated_ids_trimmed = [
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
return self.processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0]
# %%
# Step 1: Prepare data
if USE_HF_DATASET:
from datasets import load_dataset
dataset = load_dataset(DATASET_NAME, split=DATASET_SPLIT)
N = len(dataset) if MAX_DOCS is None else min(MAX_DOCS, len(dataset))
filepaths: list[str] = []
images: list[Image.Image] = []
for i in tqdm(range(N), desc="Loading dataset", total=N):
p = dataset[i]
# Compose a descriptive identifier for printing later
identifier = f"arXiv:{p['paper_arxiv_id']}|title:{p['paper_title']}|page:{int(p['page_number'])}|id:{p['page_id']}"
print(identifier)
filepaths.append(identifier)
images.append(p["page_image"]) # PIL Image
else:
_maybe_convert_pdf_to_images(PDF, PAGES_DIR)
filepaths, images = _load_images_from_dir(PAGES_DIR)
if not images:
raise RuntimeError(
f"No images found in {PAGES_DIR}. Provide PDF path in PDF variable or ensure images exist."
)
# %%
# Step 2: Load model and processor
model_name, model, processor, device_str, device, dtype = _load_colvision(MODEL)
print(f"Using model={model_name}, device={device_str}, dtype={dtype}")
# %%
# Step 3: Build or load index
retriever: Optional[LeannMultiVector] = None
if not REBUILD_INDEX:
try:
one_vec = _embed_images(model, processor, [images[0]])[0]
retriever = _load_retriever_if_index_exists(INDEX_PATH, dim=int(one_vec.shape[-1]))
except Exception:
retriever = None
if retriever is None:
doc_vecs = _embed_images(model, processor, images)
retriever = _build_index(INDEX_PATH, doc_vecs, filepaths)
# %%
# Step 4: Embed query and search
q_vec = _embed_queries(model, processor, [QUERY])[0]
results = retriever.search(q_vec.float().numpy(), topk=TOPK, first_stage_k=FIRST_STAGE_K)
if not results:
print("No results found.")
else:
print(f'Top {len(results)} results for query: "{QUERY}"')
top_images: list[Image.Image] = []
for rank, (score, doc_id) in enumerate(results, start=1):
path = filepaths[doc_id]
# For HF dataset, path is a descriptive identifier, not a real file path
print(f"{rank}) MaxSim: {score:.4f}, Page: {path}")
top_images.append(images[doc_id])
if SAVE_TOP_IMAGE:
from pathlib import Path as _Path
base = _Path(SAVE_TOP_IMAGE)
base.parent.mkdir(parents=True, exist_ok=True)
for rank, img in enumerate(top_images[:TOPK], start=1):
if base.suffix:
out_path = base.parent / f"{base.stem}_rank{rank}{base.suffix}"
else:
out_path = base / f"retrieved_page_rank{rank}.png"
img.save(str(out_path))
print(f"Saved retrieved page (rank {rank}) to: {out_path}")
## TODO: strange result, the second page of DeepSeek-V2 is retrieved rather than the first
# %%
# Step 5: Similarity maps for top-K results
if results and SIMILARITY_MAP:
token_idx = None if SIM_TOKEN_IDX < 0 else int(SIM_TOKEN_IDX)
from pathlib import Path as _Path
output_base = _Path(SIM_OUTPUT) if SIM_OUTPUT else None
for rank, img in enumerate(top_images[:TOPK], start=1):
if output_base:
if output_base.suffix:
out_dir = output_base.parent
out_name = f"{output_base.stem}_rank{rank}{output_base.suffix}"
out_path = str(out_dir / out_name)
else:
out_dir = output_base
out_dir.mkdir(parents=True, exist_ok=True)
out_path = str(out_dir / f"similarity_map_rank{rank}.png")
else:
out_path = None
chosen_idx, max_sim = _generate_similarity_map(
model=model,
processor=processor,
image=img,
query=QUERY,
token_idx=token_idx,
output_path=out_path,
)
if out_path:
print(
f"Saved similarity map for rank {rank}, token #{chosen_idx} (max={max_sim:.2f}) to: {out_path}"
)
else:
print(
f"Computed similarity map for rank {rank}, token #{chosen_idx} (max={max_sim:.2f})"
)
# %%
# Step 6: Optional answer generation
if results and ANSWER:
qwen = QwenVL(device=device_str)
response = qwen.answer(QUERY, top_images[:TOPK], max_new_tokens=MAX_NEW_TOKENS)
print("\nAnswer:")
print(response)

View File

@@ -0,0 +1,183 @@
#!/usr/bin/env python3
import re
import sys
from datetime import datetime, timedelta
from pathlib import Path
from leann import LeannSearcher
INDEX_PATH = str(Path("./").resolve() / "demo.leann")
class TimeParser:
def __init__(self):
# Main pattern: captures optional fuzzy modifier, number, unit, and optional "ago"
self.pattern = r"(?:(around|about|roughly|approximately)\s+)?(\d+)\s+(hour|day|week|month|year)s?(?:\s+ago)?"
# Compile for performance
self.regex = re.compile(self.pattern, re.IGNORECASE)
# Stop words to remove before regex parsing
self.stop_words = {
"in",
"at",
"of",
"by",
"as",
"me",
"the",
"a",
"an",
"and",
"any",
"find",
"search",
"list",
"ago",
"back",
"past",
"earlier",
}
def clean_text(self, text):
"""Remove stop words from text"""
words = text.split()
cleaned = " ".join(word for word in words if word.lower() not in self.stop_words)
return cleaned
def parse(self, text):
"""Extract all time expressions from text"""
# Clean text first
cleaned_text = self.clean_text(text)
matches = []
for match in self.regex.finditer(cleaned_text):
fuzzy = match.group(1) # "around", "about", etc.
number = int(match.group(2))
unit = match.group(3).lower()
matches.append(
{
"full_match": match.group(0),
"fuzzy": bool(fuzzy),
"number": number,
"unit": unit,
"range": self.calculate_range(number, unit, bool(fuzzy)),
}
)
return matches
def calculate_range(self, number, unit, is_fuzzy):
"""Convert to actual datetime range and return ISO format strings"""
units = {
"hour": timedelta(hours=number),
"day": timedelta(days=number),
"week": timedelta(weeks=number),
"month": timedelta(days=number * 30),
"year": timedelta(days=number * 365),
}
delta = units[unit]
now = datetime.now()
target = now - delta
if is_fuzzy:
buffer = delta * 0.2 # 20% buffer for fuzzy
start = (target - buffer).isoformat()
end = (target + buffer).isoformat()
else:
start = target.isoformat()
end = now.isoformat()
return (start, end)
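# Example (illustrative):
#   TimeParser().parse("notes about 2 weeks ago")
#   -> [{"full_match": "about 2 weeks", "fuzzy": True, "number": 2,
#        "unit": "week", "range": ("<iso-start>", "<iso-end>")}]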
def search_files(query, top_k=15):
"""Search the index and return results"""
# Parse time expressions
parser = TimeParser()
time_matches = parser.parse(query)
# Remove time expressions from query for semantic search
clean_query = query
if time_matches:
for match in time_matches:
clean_query = clean_query.replace(match["full_match"], "").strip()
# Require at least 4 characters of semantic query after stripping time expressions
if len(clean_query) < 4:
print("Error: query too short after removing time expressions; add more descriptive terms.")
return
# Single query to vector DB
searcher = LeannSearcher(INDEX_PATH)
results = searcher.search(
clean_query if clean_query else query, top_k=top_k, recompute_embeddings=False
)
# Filter by time if time expression found
if time_matches:
time_range = time_matches[0]["range"] # Use first time expression
start_time, end_time = time_range
filtered_results = []
for result in results:
# Access metadata attribute directly (not .get())
metadata = result.metadata if hasattr(result, "metadata") else {}
if metadata:
# Check modification date first, fall back to creation date
date_str = metadata.get("modification_date") or metadata.get("creation_date")
if date_str:
# Convert strings to datetime objects for proper comparison
try:
file_date = datetime.fromisoformat(date_str)
start_dt = datetime.fromisoformat(start_time)
end_dt = datetime.fromisoformat(end_time)
# Compare dates properly
if start_dt <= file_date <= end_dt:
filtered_results.append(result)
except (ValueError, TypeError):
# Handle invalid date formats
print(f"Warning: Invalid date format in metadata: {date_str}")
continue
results = filtered_results
# Print results
print(f"\nSearch results for: '{query}'")
if time_matches:
print(
f"Time filter: {time_matches[0]['number']} {time_matches[0]['unit']}(s) {'(fuzzy)' if time_matches[0]['fuzzy'] else ''}"
)
print(
f"Date range: {time_matches[0]['range'][0][:10]} to {time_matches[0]['range'][1][:10]}"
)
print("-" * 80)
for i, result in enumerate(results, 1):
print(f"\n[{i}] Score: {result.score:.4f}")
print(f"Content: {result.text}")
# Show metadata if present
metadata = result.metadata if hasattr(result, "metadata") else None
if metadata:
if "creation_date" in metadata:
print(f"Created: {metadata['creation_date']}")
if "modification_date" in metadata:
print(f"Modified: {metadata['modification_date']}")
print("-" * 80)
if __name__ == "__main__":
if len(sys.argv) < 2:
print('Usage: python search_index.py "<search query>" [top_k]')
sys.exit(1)
query = sys.argv[1]
top_k = int(sys.argv[2]) if len(sys.argv) > 2 else 15
search_files(query, top_k)

View File

@@ -0,0 +1,82 @@
#!/usr/bin/env python3
import json
import sys
from pathlib import Path
from leann import LeannBuilder
def process_json_items(json_file_path):
"""Load and process JSON file with metadata items"""
with open(json_file_path, encoding="utf-8") as f:
items = json.load(f)
# Guard against empty JSON
if not items:
print("⚠️ No items found in the JSON file. Exiting gracefully.")
return
INDEX_PATH = str(Path("./").resolve() / "demo.leann")
builder = LeannBuilder(backend_name="hnsw", is_recompute=False)
total_items = len(items)
items_added = 0
print(f"Processing {total_items} items...")
for idx, item in enumerate(items):
try:
# Create embedding text sentence
embedding_text = f"{item.get('Name', 'unknown')} located at {item.get('Path', 'unknown')} and size {item.get('Size', 'unknown')} bytes with content type {item.get('ContentType', 'unknown')} and kind {item.get('Kind', 'unknown')}"
# Prepare metadata with dates
metadata = {}
if "CreationDate" in item:
metadata["creation_date"] = item["CreationDate"]
if "ContentChangeDate" in item:
metadata["modification_date"] = item["ContentChangeDate"]
# Add to builder
builder.add_text(embedding_text, metadata=metadata)
items_added += 1
except Exception as e:
print(f"\n⚠️ Warning: Failed to process item {idx}: {e}")
continue
# Show progress
progress = (idx + 1) / total_items * 100
sys.stdout.write(f"\rProgress: {idx + 1}/{total_items} ({progress:.1f}%)")
sys.stdout.flush()
print() # New line after progress
# Guard against no successfully added items
if items_added == 0:
print("⚠️ No items were successfully added to the index. Exiting gracefully.")
return
print(f"\n✅ Successfully processed {items_added}/{total_items} items")
print("Building index...")
try:
builder.build_index(INDEX_PATH)
print(f"✓ Index saved to {INDEX_PATH}")
except ValueError as e:
if "No chunks added" in str(e):
print("⚠️ No chunks were added to the builder. Index not created.")
else:
raise
if __name__ == "__main__":
if len(sys.argv) != 2:
print("Usage: python build_index.py <json_file>")
sys.exit(1)
json_file = sys.argv[1]
if not Path(json_file).exists():
print(f"Error: File {json_file} not found")
sys.exit(1)
process_json_items(json_file)
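# End-to-end pipeline sketch (spot.py and search_index.py are the sibling
# scripts in this example):
#   python spot.py 500 spotlight_dump.json
#   python build_index.py spotlight_dump.json
#   python search_index.py "pdf files 2 weeks ago"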

View File

@@ -0,0 +1,265 @@
#!/usr/bin/env python3
"""
Spotlight Metadata Dumper for Vector DB
Extracts only essential metadata for semantic search embeddings
Output is optimized for vector database storage with minimal fields
"""
import json
import sys
from datetime import datetime
# Check platform before importing macOS-specific modules
if sys.platform != "darwin":
print("This script requires macOS (uses Spotlight)")
sys.exit(1)
from Foundation import NSDate, NSMetadataQuery, NSPredicate, NSRunLoop
# EDIT THIS LIST: Add or remove folders to search
# Can be either:
# - Folder names relative to home directory (e.g., "Desktop", "Downloads")
# - Absolute paths (e.g., "/Applications", "/System/Library")
SEARCH_FOLDERS = [
"Desktop",
"Downloads",
"Documents",
"Music",
"Pictures",
"Movies",
# "Library", # Uncomment to include
# "/Applications", # Absolute path example
# "Code/Projects", # Subfolder example
# Add any other folders here
]
def convert_to_serializable(obj):
"""Convert NS objects to Python serializable types"""
if obj is None:
return None
# Handle NSDate
if hasattr(obj, "timeIntervalSince1970"):
return datetime.fromtimestamp(obj.timeIntervalSince1970()).isoformat()
# Handle NSArray
if hasattr(obj, "count") and hasattr(obj, "objectAtIndex_"):
return [convert_to_serializable(obj.objectAtIndex_(i)) for i in range(obj.count())]
# Convert to string
try:
return str(obj)
except Exception:
return repr(obj)
def dump_spotlight_data(max_items=10, output_file="spotlight_dump.json"):
"""
Dump Spotlight data using public.item predicate
"""
# Build full paths from SEARCH_FOLDERS
import os
home_dir = os.path.expanduser("~")
search_paths = []
print("Search locations:")
for folder in SEARCH_FOLDERS:
# Check if it's an absolute path or relative
if folder.startswith("/"):
full_path = folder
else:
full_path = os.path.join(home_dir, folder)
if os.path.exists(full_path):
search_paths.append(full_path)
print(f"{full_path}")
else:
print(f"{full_path} (not found)")
if not search_paths:
print("No valid search paths found!")
return []
print(f"\nDumping {max_items} items from Spotlight (public.item)...")
# Create query with public.item predicate
query = NSMetadataQuery.alloc().init()
predicate = NSPredicate.predicateWithFormat_("kMDItemContentTypeTree CONTAINS 'public.item'")
query.setPredicate_(predicate)
# Set search scopes to our specific folders
query.setSearchScopes_(search_paths)
print("Starting query...")
query.startQuery()
# Wait for gathering to complete
run_loop = NSRunLoop.currentRunLoop()
print("Gathering results...")
# Let it gather for a few seconds
for i in range(50): # 5 seconds max
run_loop.runMode_beforeDate_(
"NSDefaultRunLoopMode", NSDate.dateWithTimeIntervalSinceNow_(0.1)
)
# Check gathering status periodically
if i % 10 == 0:
current_count = query.resultCount()
if current_count > 0:
print(f" Found {current_count} items so far...")
# Continue while still gathering (up to 2 more seconds)
timeout = NSDate.dateWithTimeIntervalSinceNow_(2.0)
while query.isGathering() and timeout.timeIntervalSinceNow() > 0:
run_loop.runMode_beforeDate_(
"NSDefaultRunLoopMode", NSDate.dateWithTimeIntervalSinceNow_(0.1)
)
query.stopQuery()
total_results = query.resultCount()
print(f"Found {total_results} total items")
if total_results == 0:
print("No results found")
return []
# Process items
items_to_process = min(total_results, max_items)
results = []
# ONLY relevant attributes for vector embeddings
# These provide essential context for semantic search without bloat
attributes = [
"kMDItemPath", # Full path for file retrieval
"kMDItemFSName", # Filename for display & embedding
"kMDItemFSSize", # Size for filtering/ranking
"kMDItemContentType", # File type for categorization
"kMDItemKind", # Human-readable type for embedding
"kMDItemFSCreationDate", # Temporal context
"kMDItemFSContentChangeDate", # Recency for ranking
]
print(f"Processing {items_to_process} items...")
for i in range(items_to_process):
try:
item = query.resultAtIndex_(i)
metadata = {}
# Extract ONLY the relevant attributes
for attr in attributes:
try:
value = item.valueForAttribute_(attr)
if value is not None:
# Keep the attribute name clean (remove kMDItem prefix for cleaner JSON)
clean_key = attr.replace("kMDItem", "").replace("FS", "")
metadata[clean_key] = convert_to_serializable(value)
except (AttributeError, ValueError, TypeError):
continue
# Only add if we have at least a path
if metadata.get("Path"):
results.append(metadata)
except Exception as e:
print(f"Error processing item {i}: {e}")
continue
# Save to JSON
with open(output_file, "w", encoding="utf-8") as f:
json.dump(results, f, indent=2, ensure_ascii=False)
print(f"\n✓ Saved {len(results)} items to {output_file}")
# Show summary
print("\nSample items:")
for i, item in enumerate(results[:3]):
print(f"\n[Item {i + 1}]")
print(f" Path: {item.get('Path', 'N/A')}")
print(f" Name: {item.get('Name', 'N/A')}")
print(f" Type: {item.get('ContentType', 'N/A')}")
print(f" Kind: {item.get('Kind', 'N/A')}")
# Handle size properly
size = item.get("Size")
if size:
try:
size_int = int(size)
if size_int > 1024 * 1024:
print(f" Size: {size_int / (1024 * 1024):.2f} MB")
elif size_int > 1024:
print(f" Size: {size_int / 1024:.2f} KB")
else:
print(f" Size: {size_int} bytes")
except (ValueError, TypeError):
print(f" Size: {size}")
# Show dates
if "CreationDate" in item:
print(f" Created: {item['CreationDate']}")
if "ContentChangeDate" in item:
print(f" Modified: {item['ContentChangeDate']}")
# Count by type
type_counts = {}
for item in results:
content_type = item.get("ContentType", "unknown")
type_counts[content_type] = type_counts.get(content_type, 0) + 1
print(f"\nTotal items saved: {len(results)}")
if type_counts:
print("\nTop content types:")
for ct, count in sorted(type_counts.items(), key=lambda x: x[1], reverse=True)[:5]:
print(f" {ct}: {count} items")
# Count by folder
folder_counts = {}
for item in results:
path = item.get("Path", "")
for folder in SEARCH_FOLDERS:
# Build the full folder path
if folder.startswith("/"):
folder_path = folder
else:
folder_path = os.path.join(home_dir, folder)
if path.startswith(folder_path):
folder_counts[folder] = folder_counts.get(folder, 0) + 1
break
if folder_counts:
print("\nItems by location:")
for folder, count in sorted(folder_counts.items(), key=lambda x: x[1], reverse=True):
print(f" {folder}: {count} items")
return results
def main():
# Parse arguments
if len(sys.argv) > 1:
try:
max_items = int(sys.argv[1])
except ValueError:
print("Usage: python spot.py [number_of_items]")
print("Default: 10 items")
sys.exit(1)
else:
max_items = 10
output_file = sys.argv[2] if len(sys.argv) > 2 else "spotlight_dump.json"
# Run dump
dump_spotlight_data(max_items=max_items, output_file=output_file)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1 @@
# Slack MCP data integration for LEANN

View File

@@ -0,0 +1,334 @@
#!/usr/bin/env python3
"""
Slack MCP Reader for LEANN
This module provides functionality to connect to Slack MCP servers and fetch message data
for indexing in LEANN. It supports various Slack MCP server implementations and provides
flexible message processing options.
"""
import asyncio
import json
import logging
from typing import Any, Optional
logger = logging.getLogger(__name__)
class SlackMCPReader:
"""
Reader for Slack data via MCP (Model Context Protocol) servers.
This class connects to Slack MCP servers to fetch message data and convert it
into a format suitable for LEANN indexing.
"""
def __init__(
self,
mcp_server_command: str,
workspace_name: Optional[str] = None,
concatenate_conversations: bool = True,
max_messages_per_conversation: int = 100,
):
"""
Initialize the Slack MCP Reader.
Args:
mcp_server_command: Command to start the MCP server (e.g., 'slack-mcp-server')
workspace_name: Optional workspace name to filter messages
concatenate_conversations: Whether to group messages by channel/thread
max_messages_per_conversation: Maximum messages to include per conversation
"""
self.mcp_server_command = mcp_server_command
self.workspace_name = workspace_name
self.concatenate_conversations = concatenate_conversations
self.max_messages_per_conversation = max_messages_per_conversation
self.mcp_process = None
async def start_mcp_server(self):
"""Start the MCP server process."""
try:
self.mcp_process = await asyncio.create_subprocess_exec(
*self.mcp_server_command.split(),
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
logger.info(f"Started MCP server: {self.mcp_server_command}")
except Exception as e:
logger.error(f"Failed to start MCP server: {e}")
raise
async def stop_mcp_server(self):
"""Stop the MCP server process."""
if self.mcp_process:
self.mcp_process.terminate()
await self.mcp_process.wait()
logger.info("Stopped MCP server")
async def send_mcp_request(self, request: dict[str, Any]) -> dict[str, Any]:
"""Send a request to the MCP server and get response."""
if not self.mcp_process:
raise RuntimeError("MCP server not started")
request_json = json.dumps(request) + "\n"
self.mcp_process.stdin.write(request_json.encode())
await self.mcp_process.stdin.drain()
response_line = await self.mcp_process.stdout.readline()
if not response_line:
raise RuntimeError("No response from MCP server")
return json.loads(response_line.decode().strip())
async def initialize_mcp_connection(self):
"""Initialize the MCP connection."""
init_request = {
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {},
"clientInfo": {"name": "leann-slack-reader", "version": "1.0.0"},
},
}
response = await self.send_mcp_request(init_request)
if "error" in response:
raise RuntimeError(f"MCP initialization failed: {response['error']}")
logger.info("MCP connection initialized successfully")
async def list_available_tools(self) -> list[dict[str, Any]]:
"""List available tools from the MCP server."""
list_request = {"jsonrpc": "2.0", "id": 2, "method": "tools/list", "params": {}}
response = await self.send_mcp_request(list_request)
if "error" in response:
raise RuntimeError(f"Failed to list tools: {response['error']}")
return response.get("result", {}).get("tools", [])
async def fetch_slack_messages(
self, channel: Optional[str] = None, limit: int = 100
) -> list[dict[str, Any]]:
"""
Fetch Slack messages using MCP tools.
Args:
channel: Optional channel name to filter messages
limit: Maximum number of messages to fetch
Returns:
List of message dictionaries
"""
# This is a generic implementation - specific MCP servers may have different tool names
# Common tool names might be: 'get_messages', 'list_messages', 'fetch_channel_history'
tools = await self.list_available_tools()
message_tool = None
# Look for a tool that can fetch messages
for tool in tools:
tool_name = tool.get("name", "").lower()
if any(
keyword in tool_name
for keyword in ["message", "history", "channel", "conversation"]
):
message_tool = tool
break
if not message_tool:
raise RuntimeError("No message fetching tool found in MCP server")
# Pass the channel under the most common parameter name; some MCP servers
# may expect "channel_id" or "channel_name" instead.
tool_params = {"limit": limit}
if channel:
tool_params["channel"] = channel
fetch_request = {
"jsonrpc": "2.0",
"id": 3,
"method": "tools/call",
"params": {"name": message_tool["name"], "arguments": tool_params},
}
response = await self.send_mcp_request(fetch_request)
if "error" in response:
raise RuntimeError(f"Failed to fetch messages: {response['error']}")
# Extract messages from response - format may vary by MCP server
result = response.get("result", {})
if "content" in result and isinstance(result["content"], list):
# Some MCP servers return content as a list
content = result["content"][0] if result["content"] else {}
if "text" in content:
try:
messages = json.loads(content["text"])
except json.JSONDecodeError:
# If not JSON, treat as plain text
messages = [{"text": content["text"], "channel": channel or "unknown"}]
else:
messages = result["content"]
else:
# Direct message format
messages = result.get("messages", [result])
return messages if isinstance(messages, list) else [messages]
def _format_message(self, message: dict[str, Any]) -> str:
"""Format a single message for indexing."""
text = message.get("text", "")
user = message.get("user", message.get("username", "Unknown"))
channel = message.get("channel", message.get("channel_name", "Unknown"))
timestamp = message.get("ts", message.get("timestamp", ""))
# Format timestamp if available
formatted_time = ""
if timestamp:
try:
import datetime
if isinstance(timestamp, str) and "." in timestamp:
dt = datetime.datetime.fromtimestamp(float(timestamp))
formatted_time = dt.strftime("%Y-%m-%d %H:%M:%S")
elif isinstance(timestamp, (int, float)):
dt = datetime.datetime.fromtimestamp(timestamp)
formatted_time = dt.strftime("%Y-%m-%d %H:%M:%S")
else:
formatted_time = str(timestamp)
except (ValueError, TypeError):
formatted_time = str(timestamp)
# Build formatted message
parts = []
if channel:
parts.append(f"Channel: #{channel}")
if user:
parts.append(f"User: {user}")
if formatted_time:
parts.append(f"Time: {formatted_time}")
if text:
parts.append(f"Message: {text}")
return "\n".join(parts)
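# Example formatted message (values are illustrative):
#   Channel: #general
#   User: U024BE7LH
#   Time: 2024-05-01 09:30:00
#   Message: Standup moved to 10am today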
def _create_concatenated_content(self, messages: list[dict[str, Any]], channel: str) -> str:
"""Create concatenated content from multiple messages in a channel."""
if not messages:
return ""
# Sort messages by timestamp if available
try:
messages.sort(key=lambda x: float(x.get("ts", x.get("timestamp", 0))))
except (ValueError, TypeError):
pass # Keep original order if timestamps aren't numeric
# Limit messages per conversation
if len(messages) > self.max_messages_per_conversation:
messages = messages[-self.max_messages_per_conversation :]
# Create header
content_parts = [
f"Slack Channel: #{channel}",
f"Message Count: {len(messages)}",
f"Workspace: {self.workspace_name or 'Unknown'}",
"=" * 50,
"",
]
# Add messages
for message in messages:
formatted_msg = self._format_message(message)
if formatted_msg.strip():
content_parts.append(formatted_msg)
content_parts.append("-" * 30)
content_parts.append("")
return "\n".join(content_parts)
async def read_slack_data(self, channels: Optional[list[str]] = None) -> list[str]:
"""
Read Slack data and return formatted text chunks.
Args:
channels: Optional list of channel names to fetch. If None, fetches from all available channels.
Returns:
List of formatted text chunks ready for LEANN indexing
"""
try:
await self.start_mcp_server()
await self.initialize_mcp_connection()
all_texts = []
if channels:
# Fetch specific channels
for channel in channels:
try:
messages = await self.fetch_slack_messages(channel=channel, limit=1000)
if messages:
if self.concatenate_conversations:
text_content = self._create_concatenated_content(messages, channel)
if text_content.strip():
all_texts.append(text_content)
else:
# Process individual messages
for message in messages:
formatted_msg = self._format_message(message)
if formatted_msg.strip():
all_texts.append(formatted_msg)
except Exception as e:
logger.warning(f"Failed to fetch messages from channel {channel}: {e}")
continue
else:
# Fetch from all available channels/conversations
# This is a simplified approach - real implementation would need to
# discover available channels first
try:
messages = await self.fetch_slack_messages(limit=1000)
if messages:
# Group messages by channel if concatenating
if self.concatenate_conversations:
channel_messages = {}
for message in messages:
channel = message.get(
"channel", message.get("channel_name", "general")
)
if channel not in channel_messages:
channel_messages[channel] = []
channel_messages[channel].append(message)
# Create concatenated content for each channel
for channel, msgs in channel_messages.items():
text_content = self._create_concatenated_content(msgs, channel)
if text_content.strip():
all_texts.append(text_content)
else:
# Process individual messages
for message in messages:
formatted_msg = self._format_message(message)
if formatted_msg.strip():
all_texts.append(formatted_msg)
except Exception as e:
logger.error(f"Failed to fetch messages: {e}")
return all_texts
finally:
await self.stop_mcp_server()
async def __aenter__(self):
"""Async context manager entry."""
await self.start_mcp_server()
await self.initialize_mcp_connection()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Async context manager exit."""
await self.stop_mcp_server()
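# Usage sketch (server command is an example; read_slack_data manages the
# server lifecycle itself, so no context manager is required):
#   reader = SlackMCPReader("npx slack-mcp-server", workspace_name="my-team")
#   texts = asyncio.run(reader.read_slack_data(channels=["general"]))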

apps/slack_rag.py (new file, 206 lines)
View File

@@ -0,0 +1,206 @@
#!/usr/bin/env python3
"""
Slack RAG Application with MCP Support
This application enables RAG (Retrieval-Augmented Generation) on Slack messages
by connecting to Slack MCP servers to fetch live data and index it in LEANN.
Usage:
python -m apps.slack_rag --mcp-server "slack-mcp-server" --query "What did the team discuss about the project?"
"""
import argparse
import asyncio
from apps.base_rag_example import BaseRAGExample
from apps.slack_data.slack_mcp_reader import SlackMCPReader
class SlackMCPRAG(BaseRAGExample):
"""
RAG application for Slack messages via MCP servers.
This class provides a complete RAG pipeline for Slack data, including
MCP server connection, data fetching, indexing, and interactive chat.
"""
def __init__(self):
super().__init__(
name="Slack MCP RAG",
description="RAG application for Slack messages via MCP servers",
default_index_name="slack_messages",
)
def _add_specific_arguments(self, parser: argparse.ArgumentParser):
"""Add Slack MCP-specific arguments."""
parser.add_argument(
"--mcp-server",
type=str,
required=True,
help="Command to start the Slack MCP server (e.g., 'slack-mcp-server' or 'npx slack-mcp-server')",
)
parser.add_argument(
"--workspace-name",
type=str,
help="Slack workspace name for better organization and filtering",
)
parser.add_argument(
"--channels",
nargs="+",
help="Specific Slack channels to index (e.g., general random). If not specified, fetches from all available channels",
)
parser.add_argument(
"--concatenate-conversations",
action="store_true",
default=True,
help="Group messages by channel/thread for better context (default: True)",
)
parser.add_argument(
"--no-concatenate-conversations",
action="store_true",
help="Process individual messages instead of grouping by channel",
)
parser.add_argument(
"--max-messages-per-channel",
type=int,
default=100,
help="Maximum number of messages to include per channel (default: 100)",
)
parser.add_argument(
"--test-connection",
action="store_true",
help="Test MCP server connection and list available tools without indexing",
)
async def test_mcp_connection(self, args) -> bool:
"""Test the MCP server connection and display available tools."""
print(f"Testing connection to MCP server: {args.mcp_server}")
try:
reader = SlackMCPReader(
mcp_server_command=args.mcp_server,
workspace_name=args.workspace_name,
concatenate_conversations=not args.no_concatenate_conversations,
max_messages_per_conversation=args.max_messages_per_channel,
)
async with reader:
tools = await reader.list_available_tools()
print("\n✅ Successfully connected to MCP server!")
print(f"Available tools ({len(tools)}):")
for i, tool in enumerate(tools, 1):
name = tool.get("name", "Unknown")
description = tool.get("description", "No description available")
print(f"\n{i}. {name}")
print(
f" Description: {description[:100]}{'...' if len(description) > 100 else ''}"
)
# Show input schema if available
schema = tool.get("inputSchema", {})
if schema.get("properties"):
props = list(schema["properties"].keys())[:3] # Show first 3 properties
print(
f" Parameters: {', '.join(props)}{'...' if len(schema['properties']) > 3 else ''}"
)
return True
except Exception as e:
print(f"\n❌ Failed to connect to MCP server: {e}")
print("\nTroubleshooting tips:")
print("1. Make sure the MCP server is installed and accessible")
print("2. Check if the server command is correct")
print("3. Ensure you have proper authentication/credentials configured")
print("4. Try running the MCP server command directly to test it")
return False
async def load_data(self, args) -> list[str]:
"""Load Slack messages via MCP server."""
print(f"Connecting to Slack MCP server: {args.mcp_server}")
if args.workspace_name:
print(f"Workspace: {args.workspace_name}")
if args.channels:
print(f"Channels: {', '.join(args.channels)}")
else:
print("Fetching from all available channels")
concatenate = not args.no_concatenate_conversations
print(
f"Processing mode: {'Concatenated conversations' if concatenate else 'Individual messages'}"
)
try:
reader = SlackMCPReader(
mcp_server_command=args.mcp_server,
workspace_name=args.workspace_name,
concatenate_conversations=concatenate,
max_messages_per_conversation=args.max_messages_per_channel,
)
texts = await reader.read_slack_data(channels=args.channels)
if not texts:
print("❌ No messages found! This could mean:")
print("- The MCP server couldn't fetch messages")
print("- The specified channels don't exist or are empty")
print("- Authentication issues with the Slack workspace")
return []
print(f"✅ Successfully loaded {len(texts)} text chunks from Slack")
# Show sample of what was loaded
if texts:
sample_text = texts[0][:200] + "..." if len(texts[0]) > 200 else texts[0]
print("\nSample content:")
print("-" * 40)
print(sample_text)
print("-" * 40)
return texts
except Exception as e:
print(f"❌ Error loading Slack data: {e}")
print("\nThis might be due to:")
print("- MCP server connection issues")
print("- Authentication problems")
print("- Network connectivity issues")
print("- Incorrect channel names")
raise
async def run(self):
"""Main entry point with MCP connection testing."""
args = self.parser.parse_args()
# Test connection if requested
if args.test_connection:
success = await self.test_mcp_connection(args)
if not success:
return
print(
"\n🎉 MCP server is working! You can now run without --test-connection to start indexing."
)
return
# Run the standard RAG pipeline
await super().run()
async def main():
"""Main entry point for the Slack MCP RAG application."""
app = SlackMCPRAG()
await app.run()
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -0,0 +1 @@
# Twitter MCP data integration for LEANN

View File

@@ -0,0 +1,295 @@
#!/usr/bin/env python3
"""
Twitter MCP Reader for LEANN
This module provides functionality to connect to Twitter MCP servers and fetch bookmark data
for indexing in LEANN. It supports various Twitter MCP server implementations and provides
flexible bookmark processing options.
"""
import asyncio
import json
import logging
from typing import Any, Optional
logger = logging.getLogger(__name__)
class TwitterMCPReader:
"""
Reader for Twitter bookmark data via MCP (Model Context Protocol) servers.
This class connects to Twitter MCP servers to fetch bookmark data and convert it
into a format suitable for LEANN indexing.
"""
def __init__(
self,
mcp_server_command: str,
username: Optional[str] = None,
include_tweet_content: bool = True,
include_metadata: bool = True,
max_bookmarks: int = 1000,
):
"""
Initialize the Twitter MCP Reader.
Args:
mcp_server_command: Command to start the MCP server (e.g., 'twitter-mcp-server')
username: Optional Twitter username to filter bookmarks
include_tweet_content: Whether to include full tweet content
include_metadata: Whether to include tweet metadata (likes, retweets, etc.)
max_bookmarks: Maximum number of bookmarks to fetch
"""
self.mcp_server_command = mcp_server_command
self.username = username
self.include_tweet_content = include_tweet_content
self.include_metadata = include_metadata
self.max_bookmarks = max_bookmarks
self.mcp_process = None
async def start_mcp_server(self):
"""Start the MCP server process."""
try:
self.mcp_process = await asyncio.create_subprocess_exec(
*self.mcp_server_command.split(),
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
logger.info(f"Started MCP server: {self.mcp_server_command}")
except Exception as e:
logger.error(f"Failed to start MCP server: {e}")
raise
async def stop_mcp_server(self):
"""Stop the MCP server process."""
if self.mcp_process:
self.mcp_process.terminate()
await self.mcp_process.wait()
logger.info("Stopped MCP server")
async def send_mcp_request(self, request: dict[str, Any]) -> dict[str, Any]:
"""Send a request to the MCP server and get response."""
if not self.mcp_process:
raise RuntimeError("MCP server not started")
request_json = json.dumps(request) + "\n"
self.mcp_process.stdin.write(request_json.encode())
await self.mcp_process.stdin.drain()
response_line = await self.mcp_process.stdout.readline()
if not response_line:
raise RuntimeError("No response from MCP server")
return json.loads(response_line.decode().strip())
async def initialize_mcp_connection(self):
"""Initialize the MCP connection."""
init_request = {
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {},
"clientInfo": {"name": "leann-twitter-reader", "version": "1.0.0"},
},
}
response = await self.send_mcp_request(init_request)
if "error" in response:
raise RuntimeError(f"MCP initialization failed: {response['error']}")
logger.info("MCP connection initialized successfully")
async def list_available_tools(self) -> list[dict[str, Any]]:
"""List available tools from the MCP server."""
list_request = {"jsonrpc": "2.0", "id": 2, "method": "tools/list", "params": {}}
response = await self.send_mcp_request(list_request)
if "error" in response:
raise RuntimeError(f"Failed to list tools: {response['error']}")
return response.get("result", {}).get("tools", [])
async def fetch_twitter_bookmarks(self, limit: Optional[int] = None) -> list[dict[str, Any]]:
"""
Fetch Twitter bookmarks using MCP tools.
Args:
limit: Maximum number of bookmarks to fetch
Returns:
List of bookmark dictionaries
"""
tools = await self.list_available_tools()
bookmark_tool = None
# Look for a tool that can fetch bookmarks
for tool in tools:
tool_name = tool.get("name", "").lower()
if any(keyword in tool_name for keyword in ["bookmark", "saved", "favorite"]):
bookmark_tool = tool
break
if not bookmark_tool:
raise RuntimeError("No bookmark fetching tool found in MCP server")
# Prepare tool call parameters
tool_params = {}
if limit or self.max_bookmarks:
tool_params["limit"] = limit or self.max_bookmarks
if self.username:
tool_params["username"] = self.username
fetch_request = {
"jsonrpc": "2.0",
"id": 3,
"method": "tools/call",
"params": {"name": bookmark_tool["name"], "arguments": tool_params},
}
response = await self.send_mcp_request(fetch_request)
if "error" in response:
raise RuntimeError(f"Failed to fetch bookmarks: {response['error']}")
# Extract bookmarks from response
result = response.get("result", {})
if "content" in result and isinstance(result["content"], list):
content = result["content"][0] if result["content"] else {}
if "text" in content:
try:
bookmarks = json.loads(content["text"])
except json.JSONDecodeError:
# If not JSON, treat as plain text
bookmarks = [{"text": content["text"], "source": "twitter"}]
else:
bookmarks = result["content"]
else:
bookmarks = result.get("bookmarks", result.get("tweets", [result]))
return bookmarks if isinstance(bookmarks, list) else [bookmarks]
def _format_bookmark(self, bookmark: dict[str, Any]) -> str:
"""Format a single bookmark for indexing."""
# Extract tweet information
text = bookmark.get("text", bookmark.get("content", ""))
author = bookmark.get(
"author", bookmark.get("username", bookmark.get("user", {}).get("username", "Unknown"))
)
timestamp = bookmark.get("created_at", bookmark.get("timestamp", ""))
url = bookmark.get("url", bookmark.get("tweet_url", ""))
# Extract metadata if available
likes = bookmark.get("likes", bookmark.get("favorite_count", 0))
retweets = bookmark.get("retweets", bookmark.get("retweet_count", 0))
replies = bookmark.get("replies", bookmark.get("reply_count", 0))
# Build formatted bookmark
parts = []
# Header
parts.append("=== Twitter Bookmark ===")
if author:
parts.append(f"Author: @{author}")
if timestamp:
# Format timestamp if it's a standard format
try:
import datetime
if "T" in str(timestamp): # ISO format
dt = datetime.datetime.fromisoformat(timestamp.replace("Z", "+00:00"))
formatted_time = dt.strftime("%Y-%m-%d %H:%M:%S")
else:
formatted_time = str(timestamp)
parts.append(f"Date: {formatted_time}")
except (ValueError, TypeError):
parts.append(f"Date: {timestamp}")
if url:
parts.append(f"URL: {url}")
# Tweet content
if text and self.include_tweet_content:
parts.append("")
parts.append("Content:")
parts.append(text)
# Metadata
if self.include_metadata and any([likes, retweets, replies]):
parts.append("")
parts.append("Engagement:")
if likes:
parts.append(f" Likes: {likes}")
if retweets:
parts.append(f" Retweets: {retweets}")
if replies:
parts.append(f" Replies: {replies}")
# Extract hashtags and mentions if available
hashtags = bookmark.get("hashtags", [])
mentions = bookmark.get("mentions", [])
if hashtags or mentions:
parts.append("")
if hashtags:
parts.append(f"Hashtags: {', '.join(hashtags)}")
if mentions:
parts.append(f"Mentions: {', '.join(mentions)}")
return "\n".join(parts)
async def read_twitter_bookmarks(self) -> list[str]:
"""
Read Twitter bookmark data and return formatted text chunks.
Returns:
List of formatted text chunks ready for LEANN indexing
"""
try:
await self.start_mcp_server()
await self.initialize_mcp_connection()
print(f"Fetching up to {self.max_bookmarks} bookmarks...")
if self.username:
print(f"Filtering for user: @{self.username}")
bookmarks = await self.fetch_twitter_bookmarks()
if not bookmarks:
print("No bookmarks found")
return []
print(f"Processing {len(bookmarks)} bookmarks...")
all_texts = []
processed_count = 0
for bookmark in bookmarks:
try:
formatted_bookmark = self._format_bookmark(bookmark)
if formatted_bookmark.strip():
all_texts.append(formatted_bookmark)
processed_count += 1
except Exception as e:
logger.warning(f"Failed to format bookmark: {e}")
continue
print(f"Successfully processed {processed_count} bookmarks")
return all_texts
finally:
await self.stop_mcp_server()
async def __aenter__(self):
"""Async context manager entry."""
await self.start_mcp_server()
await self.initialize_mcp_connection()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Async context manager exit."""
await self.stop_mcp_server()
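
# Minimal usage sketch (assumes a reachable Twitter MCP server; the command
# string is illustrative):
#
#   async def demo():
#       # One-shot helper: starts the server, fetches and formats bookmarks,
#       # then stops the server in a finally block.
#       reader = TwitterMCPReader(mcp_server_command="twitter-mcp-server")
#       texts = await reader.read_twitter_bookmarks()
#
#       # Or manage the connection explicitly with the context manager:
#       async with TwitterMCPReader(mcp_server_command="twitter-mcp-server") as reader:
#           tools = await reader.list_available_tools()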

195
apps/twitter_rag.py Normal file
View File

@@ -0,0 +1,195 @@
#!/usr/bin/env python3
"""
Twitter RAG Application with MCP Support
This application enables RAG (Retrieval-Augmented Generation) on Twitter bookmarks
by connecting to Twitter MCP servers to fetch live data and index it in LEANN.
Usage:
python -m apps.twitter_rag --mcp-server "twitter-mcp-server" --query "What articles did I bookmark about AI?"
"""
import argparse
import asyncio
from apps.base_rag_example import BaseRAGExample
from apps.twitter_data.twitter_mcp_reader import TwitterMCPReader
class TwitterMCPRAG(BaseRAGExample):
"""
RAG application for Twitter bookmarks via MCP servers.
This class provides a complete RAG pipeline for Twitter bookmark data, including
MCP server connection, data fetching, indexing, and interactive chat.
"""
def __init__(self):
super().__init__(
name="Twitter MCP RAG",
description="RAG application for Twitter bookmarks via MCP servers",
default_index_name="twitter_bookmarks",
)

    def _add_specific_arguments(self, parser: argparse.ArgumentParser):
"""Add Twitter MCP-specific arguments."""
parser.add_argument(
"--mcp-server",
type=str,
required=True,
help="Command to start the Twitter MCP server (e.g., 'twitter-mcp-server' or 'npx twitter-mcp-server')",
)
parser.add_argument(
"--username", type=str, help="Twitter username to filter bookmarks (without @)"
)
parser.add_argument(
"--max-bookmarks",
type=int,
default=1000,
help="Maximum number of bookmarks to fetch (default: 1000)",
)
parser.add_argument(
"--no-tweet-content",
action="store_true",
help="Exclude tweet content, only include metadata",
)
parser.add_argument(
"--no-metadata",
action="store_true",
help="Exclude engagement metadata (likes, retweets, etc.)",
)
parser.add_argument(
"--test-connection",
action="store_true",
help="Test MCP server connection and list available tools without indexing",
)

    async def test_mcp_connection(self, args) -> bool:
"""Test the MCP server connection and display available tools."""
print(f"Testing connection to MCP server: {args.mcp_server}")
try:
reader = TwitterMCPReader(
mcp_server_command=args.mcp_server,
username=args.username,
include_tweet_content=not args.no_tweet_content,
include_metadata=not args.no_metadata,
max_bookmarks=args.max_bookmarks,
)
async with reader:
tools = await reader.list_available_tools()
print("\n✅ Successfully connected to MCP server!")
print(f"Available tools ({len(tools)}):")
for i, tool in enumerate(tools, 1):
name = tool.get("name", "Unknown")
description = tool.get("description", "No description available")
print(f"\n{i}. {name}")
print(
f" Description: {description[:100]}{'...' if len(description) > 100 else ''}"
)
# Show input schema if available
schema = tool.get("inputSchema", {})
if schema.get("properties"):
props = list(schema["properties"].keys())[:3] # Show first 3 properties
print(
f" Parameters: {', '.join(props)}{'...' if len(schema['properties']) > 3 else ''}"
)
return True
except Exception as e:
print(f"\n❌ Failed to connect to MCP server: {e}")
print("\nTroubleshooting tips:")
print("1. Make sure the Twitter MCP server is installed and accessible")
print("2. Check if the server command is correct")
print("3. Ensure you have proper Twitter API credentials configured")
print("4. Verify your Twitter account has bookmarks to fetch")
print("5. Try running the MCP server command directly to test it")
return False

    async def load_data(self, args) -> list[str]:
"""Load Twitter bookmarks via MCP server."""
print(f"Connecting to Twitter MCP server: {args.mcp_server}")
if args.username:
print(f"Username filter: @{args.username}")
print(f"Max bookmarks: {args.max_bookmarks}")
print(f"Include tweet content: {not args.no_tweet_content}")
print(f"Include metadata: {not args.no_metadata}")
try:
reader = TwitterMCPReader(
mcp_server_command=args.mcp_server,
username=args.username,
include_tweet_content=not args.no_tweet_content,
include_metadata=not args.no_metadata,
max_bookmarks=args.max_bookmarks,
)
texts = await reader.read_twitter_bookmarks()
if not texts:
print("❌ No bookmarks found! This could mean:")
print("- You don't have any bookmarks on Twitter")
print("- The MCP server couldn't access your bookmarks")
print("- Authentication issues with Twitter API")
print("- The username filter didn't match any bookmarks")
return []
print(f"✅ Successfully loaded {len(texts)} bookmarks from Twitter")
# Show sample of what was loaded
if texts:
sample_text = texts[0][:300] + "..." if len(texts[0]) > 300 else texts[0]
print("\nSample bookmark:")
print("-" * 50)
print(sample_text)
print("-" * 50)
return texts
except Exception as e:
print(f"❌ Error loading Twitter bookmarks: {e}")
print("\nThis might be due to:")
print("- MCP server connection issues")
print("- Twitter API authentication problems")
print("- Network connectivity issues")
print("- Rate limiting from Twitter API")
raise

    async def run(self):
        """Main entry point with MCP connection testing."""
        # Parse args early so --test-connection can short-circuit the full pipeline
        args = self.parser.parse_args()

        # Test the connection if requested
if args.test_connection:
success = await self.test_mcp_connection(args)
if not success:
return
print(
"\n🎉 MCP server is working! You can now run without --test-connection to start indexing."
)
return
# Run the standard RAG pipeline
await super().run()


async def main():
"""Main entry point for the Twitter MCP RAG application."""
app = TwitterMCPRAG()
await app.run()


if __name__ == "__main__":
asyncio.run(main())
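
# Example invocations (the server command is a placeholder; substitute your own
# Twitter MCP server):
#
#   # Verify the server responds and list its tools without indexing:
#   python -m apps.twitter_rag --mcp-server "twitter-mcp-server" --test-connection
#
#   # Fetch bookmarks, build the index, and ask a question:
#   python -m apps.twitter_rag --mcp-server "twitter-mcp-server" \
#       --max-bookmarks 500 --query "What articles did I bookmark about AI?"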