Compare commits
8 Commits
colqwen
...
feature/ch
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f75c0fd19a | ||
|
|
d18adadf58 | ||
|
|
f52bce23c3 | ||
|
|
68333d1837 | ||
|
|
f1355b70d8 | ||
|
|
2dd4147de2 | ||
|
|
be17980114 | ||
|
|
5f7806e16f |
4
.gitignore
vendored
4
.gitignore
vendored
@@ -18,7 +18,6 @@ demo/experiment_results/**/*.json
|
||||
*.eml
|
||||
*.emlx
|
||||
*.json
|
||||
*.png
|
||||
!.vscode/*.json
|
||||
*.sh
|
||||
*.txt
|
||||
@@ -102,6 +101,3 @@ CLAUDE.local.md
|
||||
.claude/*.local.*
|
||||
.claude/local/*
|
||||
benchmarks/data/
|
||||
|
||||
## multi vector
|
||||
apps/multimodal/vision-based-pdf-multi-vector/multi-vector-colpali-native-weaviate.py
|
||||
|
||||
150
README.md
150
README.md
@@ -176,7 +176,7 @@ response = chat.ask("How much storage does LEANN save?", top_k=1)
|
||||
|
||||
## RAG on Everything!
|
||||
|
||||
LEANN supports RAG on various data sources including documents (`.pdf`, `.txt`, `.md`), Apple Mail, Google Search History, WeChat, and more.
|
||||
LEANN supports RAG on various data sources including documents (`.pdf`, `.txt`, `.md`), Apple Mail, Google Search History, WeChat, ChatGPT conversations, Claude conversations, and more.
|
||||
|
||||
|
||||
|
||||
@@ -477,6 +477,154 @@ Once the index is built, you can ask questions like:
|
||||
|
||||
</details>
|
||||
|
||||
### 🤖 ChatGPT Chat History: Your Personal AI Conversation Archive!
|
||||
|
||||
Transform your ChatGPT conversations into a searchable knowledge base! Search through all your ChatGPT discussions about coding, research, brainstorming, and more.
|
||||
|
||||
```bash
|
||||
python -m apps.chatgpt_rag --export-path chatgpt_export.html --query "How do I create a list in Python?"
|
||||
```
|
||||
|
||||
**Unlock your AI conversation history.** Never lose track of valuable insights from your ChatGPT discussions again.
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: How to Export ChatGPT Data</strong></summary>
|
||||
|
||||
**Step-by-step export process:**
|
||||
|
||||
1. **Sign in to ChatGPT**
|
||||
2. **Click your profile icon** in the top right corner
|
||||
3. **Navigate to Settings** → **Data Controls**
|
||||
4. **Click "Export"** under Export Data
|
||||
5. **Confirm the export** request
|
||||
6. **Download the ZIP file** from the email link (expires in 24 hours)
|
||||
7. **Extract or use directly** with LEANN
|
||||
|
||||
**Supported formats:**
|
||||
- `.html` files from ChatGPT exports
|
||||
- `.zip` archives from ChatGPT
|
||||
- Directories with multiple export files
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: ChatGPT-Specific Arguments</strong></summary>
|
||||
|
||||
#### Parameters
|
||||
```bash
|
||||
--export-path PATH # Path to ChatGPT export file (.html/.zip) or directory (default: ./chatgpt_export)
|
||||
--separate-messages # Process each message separately instead of concatenated conversations
|
||||
--chunk-size N # Text chunk size (default: 512)
|
||||
--chunk-overlap N # Overlap between chunks (default: 128)
|
||||
```
|
||||
|
||||
#### Example Commands
|
||||
```bash
|
||||
# Basic usage with HTML export
|
||||
python -m apps.chatgpt_rag --export-path conversations.html
|
||||
|
||||
# Process ZIP archive from ChatGPT
|
||||
python -m apps.chatgpt_rag --export-path chatgpt_export.zip
|
||||
|
||||
# Search with specific query
|
||||
python -m apps.chatgpt_rag --export-path chatgpt_data.html --query "Python programming help"
|
||||
|
||||
# Process individual messages for fine-grained search
|
||||
python -m apps.chatgpt_rag --separate-messages --export-path chatgpt_export.html
|
||||
|
||||
# Process directory containing multiple exports
|
||||
python -m apps.chatgpt_rag --export-path ./chatgpt_exports/ --max-items 1000
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><strong>💡 Click to expand: Example queries you can try</strong></summary>
|
||||
|
||||
Once your ChatGPT conversations are indexed, you can search with queries like:
|
||||
- "What did I ask ChatGPT about Python programming?"
|
||||
- "Show me conversations about machine learning algorithms"
|
||||
- "Find discussions about web development frameworks"
|
||||
- "What coding advice did ChatGPT give me?"
|
||||
- "Search for conversations about debugging techniques"
|
||||
- "Find ChatGPT's recommendations for learning resources"
|
||||
|
||||
</details>
|
||||
|
||||
### 🤖 Claude Chat History: Your Personal AI Conversation Archive!
|
||||
|
||||
Transform your Claude conversations into a searchable knowledge base! Search through all your Claude discussions about coding, research, brainstorming, and more.
|
||||
|
||||
```bash
|
||||
python -m apps.claude_rag --export-path claude_export.json --query "What did I ask about Python dictionaries?"
|
||||
```
|
||||
|
||||
**Unlock your AI conversation history.** Never lose track of valuable insights from your Claude discussions again.
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: How to Export Claude Data</strong></summary>
|
||||
|
||||
**Step-by-step export process:**
|
||||
|
||||
1. **Open Claude** in your browser
|
||||
2. **Navigate to Settings** (look for gear icon or settings menu)
|
||||
3. **Find Export/Download** options in your account settings
|
||||
4. **Download conversation data** (usually in JSON format)
|
||||
5. **Place the file** in your project directory
|
||||
|
||||
*Note: Claude export methods may vary depending on the interface you're using. Check Claude's help documentation for the most current export instructions.*
|
||||
|
||||
**Supported formats:**
|
||||
- `.json` files (recommended)
|
||||
- `.zip` archives containing JSON data
|
||||
- Directories with multiple export files
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: Claude-Specific Arguments</strong></summary>
|
||||
|
||||
#### Parameters
|
||||
```bash
|
||||
--export-path PATH # Path to Claude export file (.json/.zip) or directory (default: ./claude_export)
|
||||
--separate-messages # Process each message separately instead of concatenated conversations
|
||||
--chunk-size N # Text chunk size (default: 512)
|
||||
--chunk-overlap N # Overlap between chunks (default: 128)
|
||||
```
|
||||
|
||||
#### Example Commands
|
||||
```bash
|
||||
# Basic usage with JSON export
|
||||
python -m apps.claude_rag --export-path my_claude_conversations.json
|
||||
|
||||
# Process ZIP archive from Claude
|
||||
python -m apps.claude_rag --export-path claude_export.zip
|
||||
|
||||
# Search with specific query
|
||||
python -m apps.claude_rag --export-path claude_data.json --query "machine learning advice"
|
||||
|
||||
# Process individual messages for fine-grained search
|
||||
python -m apps.claude_rag --separate-messages --export-path claude_export.json
|
||||
|
||||
# Process directory containing multiple exports
|
||||
python -m apps.claude_rag --export-path ./claude_exports/ --max-items 1000
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><strong>💡 Click to expand: Example queries you can try</strong></summary>
|
||||
|
||||
Once your Claude conversations are indexed, you can search with queries like:
|
||||
- "What did I ask Claude about Python programming?"
|
||||
- "Show me conversations about machine learning algorithms"
|
||||
- "Find discussions about software architecture patterns"
|
||||
- "What debugging advice did Claude give me?"
|
||||
- "Search for conversations about data structures"
|
||||
- "Find Claude's recommendations for learning resources"
|
||||
|
||||
</details>
|
||||
|
||||
### 🚀 Claude Code Integration: Transform Your Development Workflow!
|
||||
|
||||
<details>
|
||||
|
||||
0
apps/chatgpt_data/__init__.py
Normal file
0
apps/chatgpt_data/__init__.py
Normal file
413
apps/chatgpt_data/chatgpt_reader.py
Normal file
413
apps/chatgpt_data/chatgpt_reader.py
Normal file
@@ -0,0 +1,413 @@
|
||||
"""
|
||||
ChatGPT export data reader.
|
||||
|
||||
Reads and processes ChatGPT export data from chat.html files.
|
||||
"""
|
||||
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from zipfile import ZipFile
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
from llama_index.core import Document
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
|
||||
|
||||
class ChatGPTReader(BaseReader):
|
||||
"""
|
||||
ChatGPT export data reader.
|
||||
|
||||
Reads ChatGPT conversation data from exported chat.html files or zip archives.
|
||||
Processes conversations into structured documents with metadata.
|
||||
"""
|
||||
|
||||
def __init__(self, concatenate_conversations: bool = True) -> None:
|
||||
"""
|
||||
Initialize.
|
||||
|
||||
Args:
|
||||
concatenate_conversations: Whether to concatenate messages within conversations for better context
|
||||
"""
|
||||
try:
|
||||
from bs4 import BeautifulSoup # noqa
|
||||
except ImportError:
|
||||
raise ImportError("`beautifulsoup4` package not found: `pip install beautifulsoup4`")
|
||||
|
||||
self.concatenate_conversations = concatenate_conversations
|
||||
|
||||
def _extract_html_from_zip(self, zip_path: Path) -> str | None:
|
||||
"""
|
||||
Extract chat.html from ChatGPT export zip file.
|
||||
|
||||
Args:
|
||||
zip_path: Path to the ChatGPT export zip file
|
||||
|
||||
Returns:
|
||||
HTML content as string, or None if not found
|
||||
"""
|
||||
try:
|
||||
with ZipFile(zip_path, "r") as zip_file:
|
||||
# Look for chat.html or conversations.html
|
||||
html_files = [
|
||||
f
|
||||
for f in zip_file.namelist()
|
||||
if f.endswith(".html") and ("chat" in f.lower() or "conversation" in f.lower())
|
||||
]
|
||||
|
||||
if not html_files:
|
||||
print(f"No HTML chat file found in {zip_path}")
|
||||
return None
|
||||
|
||||
# Use the first HTML file found
|
||||
html_file = html_files[0]
|
||||
print(f"Found HTML file: {html_file}")
|
||||
|
||||
with zip_file.open(html_file) as f:
|
||||
return f.read().decode("utf-8", errors="ignore")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error extracting HTML from zip {zip_path}: {e}")
|
||||
return None
|
||||
|
||||
def _parse_chatgpt_html(self, html_content: str) -> list[dict]:
|
||||
"""
|
||||
Parse ChatGPT HTML export to extract conversations.
|
||||
|
||||
Args:
|
||||
html_content: HTML content from ChatGPT export
|
||||
|
||||
Returns:
|
||||
List of conversation dictionaries
|
||||
"""
|
||||
soup = BeautifulSoup(html_content, "html.parser")
|
||||
conversations = []
|
||||
|
||||
# Try different possible structures for ChatGPT exports
|
||||
# Structure 1: Look for conversation containers
|
||||
conversation_containers = soup.find_all(
|
||||
["div", "section"], class_=re.compile(r"conversation|chat", re.I)
|
||||
)
|
||||
|
||||
if not conversation_containers:
|
||||
# Structure 2: Look for message containers directly
|
||||
conversation_containers = [soup] # Use the entire document as one conversation
|
||||
|
||||
for container in conversation_containers:
|
||||
conversation = self._extract_conversation_from_container(container)
|
||||
if conversation and conversation.get("messages"):
|
||||
conversations.append(conversation)
|
||||
|
||||
# If no structured conversations found, try to extract all text as one conversation
|
||||
if not conversations:
|
||||
all_text = soup.get_text(separator="\n", strip=True)
|
||||
if all_text:
|
||||
conversations.append(
|
||||
{
|
||||
"title": "ChatGPT Conversation",
|
||||
"messages": [{"role": "mixed", "content": all_text, "timestamp": None}],
|
||||
"timestamp": None,
|
||||
}
|
||||
)
|
||||
|
||||
return conversations
|
||||
|
||||
def _extract_conversation_from_container(self, container) -> dict | None:
|
||||
"""
|
||||
Extract conversation data from a container element.
|
||||
|
||||
Args:
|
||||
container: BeautifulSoup element containing conversation
|
||||
|
||||
Returns:
|
||||
Dictionary with conversation data or None
|
||||
"""
|
||||
messages = []
|
||||
|
||||
# Look for message elements with various possible structures
|
||||
message_selectors = ['[class*="message"]', '[class*="chat"]', "[data-message]", "p", "div"]
|
||||
|
||||
for selector in message_selectors:
|
||||
message_elements = container.select(selector)
|
||||
if message_elements:
|
||||
break
|
||||
else:
|
||||
message_elements = []
|
||||
|
||||
# If no structured messages found, treat the entire container as one message
|
||||
if not message_elements:
|
||||
text_content = container.get_text(separator="\n", strip=True)
|
||||
if text_content:
|
||||
messages.append({"role": "mixed", "content": text_content, "timestamp": None})
|
||||
else:
|
||||
for element in message_elements:
|
||||
message = self._extract_message_from_element(element)
|
||||
if message:
|
||||
messages.append(message)
|
||||
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
# Try to extract conversation title
|
||||
title_element = container.find(["h1", "h2", "h3", "title"])
|
||||
title = title_element.get_text(strip=True) if title_element else "ChatGPT Conversation"
|
||||
|
||||
# Try to extract timestamp from various possible locations
|
||||
timestamp = self._extract_timestamp_from_container(container)
|
||||
|
||||
return {"title": title, "messages": messages, "timestamp": timestamp}
|
||||
|
||||
def _extract_message_from_element(self, element) -> dict | None:
|
||||
"""
|
||||
Extract message data from an element.
|
||||
|
||||
Args:
|
||||
element: BeautifulSoup element containing message
|
||||
|
||||
Returns:
|
||||
Dictionary with message data or None
|
||||
"""
|
||||
text_content = element.get_text(separator=" ", strip=True)
|
||||
|
||||
# Skip empty or very short messages
|
||||
if not text_content or len(text_content.strip()) < 3:
|
||||
return None
|
||||
|
||||
# Try to determine role (user/assistant) from class names or content
|
||||
role = "mixed" # Default role
|
||||
|
||||
class_names = " ".join(element.get("class", [])).lower()
|
||||
if "user" in class_names or "human" in class_names:
|
||||
role = "user"
|
||||
elif "assistant" in class_names or "ai" in class_names or "gpt" in class_names:
|
||||
role = "assistant"
|
||||
elif text_content.lower().startswith(("you:", "user:", "me:")):
|
||||
role = "user"
|
||||
text_content = re.sub(r"^(you|user|me):\s*", "", text_content, flags=re.IGNORECASE)
|
||||
elif text_content.lower().startswith(("chatgpt:", "assistant:", "ai:")):
|
||||
role = "assistant"
|
||||
text_content = re.sub(
|
||||
r"^(chatgpt|assistant|ai):\s*", "", text_content, flags=re.IGNORECASE
|
||||
)
|
||||
|
||||
# Try to extract timestamp
|
||||
timestamp = self._extract_timestamp_from_element(element)
|
||||
|
||||
return {"role": role, "content": text_content, "timestamp": timestamp}
|
||||
|
||||
def _extract_timestamp_from_element(self, element) -> str | None:
|
||||
"""Extract timestamp from element."""
|
||||
# Look for timestamp in various attributes and child elements
|
||||
timestamp_attrs = ["data-timestamp", "timestamp", "datetime"]
|
||||
for attr in timestamp_attrs:
|
||||
if element.get(attr):
|
||||
return element.get(attr)
|
||||
|
||||
# Look for time elements
|
||||
time_element = element.find("time")
|
||||
if time_element:
|
||||
return time_element.get("datetime") or time_element.get_text(strip=True)
|
||||
|
||||
# Look for date-like text patterns
|
||||
text = element.get_text()
|
||||
date_patterns = [r"\d{4}-\d{2}-\d{2}", r"\d{1,2}/\d{1,2}/\d{4}", r"\w+ \d{1,2}, \d{4}"]
|
||||
|
||||
for pattern in date_patterns:
|
||||
match = re.search(pattern, text)
|
||||
if match:
|
||||
return match.group()
|
||||
|
||||
return None
|
||||
|
||||
def _extract_timestamp_from_container(self, container) -> str | None:
|
||||
"""Extract timestamp from conversation container."""
|
||||
return self._extract_timestamp_from_element(container)
|
||||
|
||||
def _create_concatenated_content(self, conversation: dict) -> str:
|
||||
"""
|
||||
Create concatenated content from conversation messages.
|
||||
|
||||
Args:
|
||||
conversation: Dictionary containing conversation data
|
||||
|
||||
Returns:
|
||||
Formatted concatenated content
|
||||
"""
|
||||
title = conversation.get("title", "ChatGPT Conversation")
|
||||
messages = conversation.get("messages", [])
|
||||
timestamp = conversation.get("timestamp", "Unknown")
|
||||
|
||||
# Build message content
|
||||
message_parts = []
|
||||
for message in messages:
|
||||
role = message.get("role", "mixed")
|
||||
content = message.get("content", "")
|
||||
msg_timestamp = message.get("timestamp", "")
|
||||
|
||||
if role == "user":
|
||||
prefix = "[You]"
|
||||
elif role == "assistant":
|
||||
prefix = "[ChatGPT]"
|
||||
else:
|
||||
prefix = "[Message]"
|
||||
|
||||
# Add timestamp if available
|
||||
if msg_timestamp:
|
||||
prefix += f" ({msg_timestamp})"
|
||||
|
||||
message_parts.append(f"{prefix}: {content}")
|
||||
|
||||
concatenated_text = "\n\n".join(message_parts)
|
||||
|
||||
# Create final document content
|
||||
doc_content = f"""Conversation: {title}
|
||||
Date: {timestamp}
|
||||
Messages ({len(messages)} messages):
|
||||
|
||||
{concatenated_text}
|
||||
"""
|
||||
return doc_content
|
||||
|
||||
def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
|
||||
"""
|
||||
Load ChatGPT export data.
|
||||
|
||||
Args:
|
||||
input_dir: Directory containing ChatGPT export files or path to specific file
|
||||
**load_kwargs:
|
||||
max_count (int): Maximum number of conversations to process
|
||||
chatgpt_export_path (str): Specific path to ChatGPT export file/directory
|
||||
include_metadata (bool): Whether to include metadata in documents
|
||||
"""
|
||||
docs: list[Document] = []
|
||||
max_count = load_kwargs.get("max_count", -1)
|
||||
chatgpt_export_path = load_kwargs.get("chatgpt_export_path", input_dir)
|
||||
include_metadata = load_kwargs.get("include_metadata", True)
|
||||
|
||||
if not chatgpt_export_path:
|
||||
print("No ChatGPT export path provided")
|
||||
return docs
|
||||
|
||||
export_path = Path(chatgpt_export_path)
|
||||
|
||||
if not export_path.exists():
|
||||
print(f"ChatGPT export path not found: {export_path}")
|
||||
return docs
|
||||
|
||||
html_content = None
|
||||
|
||||
# Handle different input types
|
||||
if export_path.is_file():
|
||||
if export_path.suffix.lower() == ".zip":
|
||||
# Extract HTML from zip file
|
||||
html_content = self._extract_html_from_zip(export_path)
|
||||
elif export_path.suffix.lower() == ".html":
|
||||
# Read HTML file directly
|
||||
try:
|
||||
with open(export_path, encoding="utf-8", errors="ignore") as f:
|
||||
html_content = f.read()
|
||||
except Exception as e:
|
||||
print(f"Error reading HTML file {export_path}: {e}")
|
||||
return docs
|
||||
else:
|
||||
print(f"Unsupported file type: {export_path.suffix}")
|
||||
return docs
|
||||
|
||||
elif export_path.is_dir():
|
||||
# Look for HTML files in directory
|
||||
html_files = list(export_path.glob("*.html"))
|
||||
zip_files = list(export_path.glob("*.zip"))
|
||||
|
||||
if html_files:
|
||||
# Use first HTML file found
|
||||
html_file = html_files[0]
|
||||
print(f"Found HTML file: {html_file}")
|
||||
try:
|
||||
with open(html_file, encoding="utf-8", errors="ignore") as f:
|
||||
html_content = f.read()
|
||||
except Exception as e:
|
||||
print(f"Error reading HTML file {html_file}: {e}")
|
||||
return docs
|
||||
|
||||
elif zip_files:
|
||||
# Use first zip file found
|
||||
zip_file = zip_files[0]
|
||||
print(f"Found zip file: {zip_file}")
|
||||
html_content = self._extract_html_from_zip(zip_file)
|
||||
|
||||
else:
|
||||
print(f"No HTML or zip files found in {export_path}")
|
||||
return docs
|
||||
|
||||
if not html_content:
|
||||
print("No HTML content found to process")
|
||||
return docs
|
||||
|
||||
# Parse conversations from HTML
|
||||
print("Parsing ChatGPT conversations from HTML...")
|
||||
conversations = self._parse_chatgpt_html(html_content)
|
||||
|
||||
if not conversations:
|
||||
print("No conversations found in HTML content")
|
||||
return docs
|
||||
|
||||
print(f"Found {len(conversations)} conversations")
|
||||
|
||||
# Process conversations into documents
|
||||
count = 0
|
||||
for conversation in conversations:
|
||||
if max_count > 0 and count >= max_count:
|
||||
break
|
||||
|
||||
if self.concatenate_conversations:
|
||||
# Create one document per conversation with concatenated messages
|
||||
doc_content = self._create_concatenated_content(conversation)
|
||||
|
||||
metadata = {}
|
||||
if include_metadata:
|
||||
metadata = {
|
||||
"title": conversation.get("title", "ChatGPT Conversation"),
|
||||
"timestamp": conversation.get("timestamp", "Unknown"),
|
||||
"message_count": len(conversation.get("messages", [])),
|
||||
"source": "ChatGPT Export",
|
||||
}
|
||||
|
||||
doc = Document(text=doc_content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
|
||||
else:
|
||||
# Create separate documents for each message
|
||||
for message in conversation.get("messages", []):
|
||||
if max_count > 0 and count >= max_count:
|
||||
break
|
||||
|
||||
role = message.get("role", "mixed")
|
||||
content = message.get("content", "")
|
||||
msg_timestamp = message.get("timestamp", "")
|
||||
|
||||
if not content.strip():
|
||||
continue
|
||||
|
||||
# Create document content with context
|
||||
doc_content = f"""Conversation: {conversation.get("title", "ChatGPT Conversation")}
|
||||
Role: {role}
|
||||
Timestamp: {msg_timestamp or conversation.get("timestamp", "Unknown")}
|
||||
Message: {content}
|
||||
"""
|
||||
|
||||
metadata = {}
|
||||
if include_metadata:
|
||||
metadata = {
|
||||
"conversation_title": conversation.get("title", "ChatGPT Conversation"),
|
||||
"role": role,
|
||||
"timestamp": msg_timestamp or conversation.get("timestamp", "Unknown"),
|
||||
"source": "ChatGPT Export",
|
||||
}
|
||||
|
||||
doc = Document(text=doc_content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
|
||||
print(f"Created {len(docs)} documents from ChatGPT export")
|
||||
return docs
|
||||
186
apps/chatgpt_rag.py
Normal file
186
apps/chatgpt_rag.py
Normal file
@@ -0,0 +1,186 @@
|
||||
"""
|
||||
ChatGPT RAG example using the unified interface.
|
||||
Supports ChatGPT export data from chat.html files.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from base_rag_example import BaseRAGExample
|
||||
from chunking import create_text_chunks
|
||||
|
||||
from .chatgpt_data.chatgpt_reader import ChatGPTReader
|
||||
|
||||
|
||||
class ChatGPTRAG(BaseRAGExample):
|
||||
"""RAG example for ChatGPT conversation data."""
|
||||
|
||||
def __init__(self):
|
||||
# Set default values BEFORE calling super().__init__
|
||||
self.max_items_default = -1 # Process all conversations by default
|
||||
self.embedding_model_default = (
|
||||
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
name="ChatGPT",
|
||||
description="Process and query ChatGPT conversation exports with LEANN",
|
||||
default_index_name="chatgpt_conversations_index",
|
||||
)
|
||||
|
||||
def _add_specific_arguments(self, parser):
|
||||
"""Add ChatGPT-specific arguments."""
|
||||
chatgpt_group = parser.add_argument_group("ChatGPT Parameters")
|
||||
chatgpt_group.add_argument(
|
||||
"--export-path",
|
||||
type=str,
|
||||
default="./chatgpt_export",
|
||||
help="Path to ChatGPT export file (.zip or .html) or directory containing exports (default: ./chatgpt_export)",
|
||||
)
|
||||
chatgpt_group.add_argument(
|
||||
"--concatenate-conversations",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Concatenate messages within conversations for better context (default: True)",
|
||||
)
|
||||
chatgpt_group.add_argument(
|
||||
"--separate-messages",
|
||||
action="store_true",
|
||||
help="Process each message as a separate document (overrides --concatenate-conversations)",
|
||||
)
|
||||
chatgpt_group.add_argument(
|
||||
"--chunk-size", type=int, default=512, help="Text chunk size (default: 512)"
|
||||
)
|
||||
chatgpt_group.add_argument(
|
||||
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
|
||||
)
|
||||
|
||||
def _find_chatgpt_exports(self, export_path: Path) -> list[Path]:
|
||||
"""
|
||||
Find ChatGPT export files in the given path.
|
||||
|
||||
Args:
|
||||
export_path: Path to search for exports
|
||||
|
||||
Returns:
|
||||
List of paths to ChatGPT export files
|
||||
"""
|
||||
export_files = []
|
||||
|
||||
if export_path.is_file():
|
||||
if export_path.suffix.lower() in [".zip", ".html"]:
|
||||
export_files.append(export_path)
|
||||
elif export_path.is_dir():
|
||||
# Look for zip and html files
|
||||
export_files.extend(export_path.glob("*.zip"))
|
||||
export_files.extend(export_path.glob("*.html"))
|
||||
|
||||
return export_files
|
||||
|
||||
async def load_data(self, args) -> list[str]:
|
||||
"""Load ChatGPT export data and convert to text chunks."""
|
||||
export_path = Path(args.export_path)
|
||||
|
||||
if not export_path.exists():
|
||||
print(f"ChatGPT export path not found: {export_path}")
|
||||
print(
|
||||
"Please ensure you have exported your ChatGPT data and placed it in the correct location."
|
||||
)
|
||||
print("\nTo export your ChatGPT data:")
|
||||
print("1. Sign in to ChatGPT")
|
||||
print("2. Click on your profile icon → Settings → Data Controls")
|
||||
print("3. Click 'Export' under Export Data")
|
||||
print("4. Download the zip file from the email link")
|
||||
print("5. Extract or place the file/directory at the specified path")
|
||||
return []
|
||||
|
||||
# Find export files
|
||||
export_files = self._find_chatgpt_exports(export_path)
|
||||
|
||||
if not export_files:
|
||||
print(f"No ChatGPT export files (.zip or .html) found in: {export_path}")
|
||||
return []
|
||||
|
||||
print(f"Found {len(export_files)} ChatGPT export files")
|
||||
|
||||
# Create reader with appropriate settings
|
||||
concatenate = args.concatenate_conversations and not args.separate_messages
|
||||
reader = ChatGPTReader(concatenate_conversations=concatenate)
|
||||
|
||||
# Process each export file
|
||||
all_documents = []
|
||||
total_processed = 0
|
||||
|
||||
for i, export_file in enumerate(export_files):
|
||||
print(f"\nProcessing export file {i + 1}/{len(export_files)}: {export_file.name}")
|
||||
|
||||
try:
|
||||
# Apply max_items limit per file
|
||||
max_per_file = -1
|
||||
if args.max_items > 0:
|
||||
remaining = args.max_items - total_processed
|
||||
if remaining <= 0:
|
||||
break
|
||||
max_per_file = remaining
|
||||
|
||||
# Load conversations
|
||||
documents = reader.load_data(
|
||||
chatgpt_export_path=str(export_file),
|
||||
max_count=max_per_file,
|
||||
include_metadata=True,
|
||||
)
|
||||
|
||||
if documents:
|
||||
all_documents.extend(documents)
|
||||
total_processed += len(documents)
|
||||
print(f"Processed {len(documents)} conversations from this file")
|
||||
else:
|
||||
print(f"No conversations loaded from {export_file}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing {export_file}: {e}")
|
||||
continue
|
||||
|
||||
if not all_documents:
|
||||
print("No conversations found to process!")
|
||||
print("\nTroubleshooting:")
|
||||
print("- Ensure the export file is a valid ChatGPT export")
|
||||
print("- Check that the HTML file contains conversation data")
|
||||
print("- Try extracting the zip file and pointing to the HTML file directly")
|
||||
return []
|
||||
|
||||
print(f"\nTotal conversations processed: {len(all_documents)}")
|
||||
print("Now starting to split into text chunks... this may take some time")
|
||||
|
||||
# Convert to text chunks
|
||||
all_texts = create_text_chunks(
|
||||
all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
||||
)
|
||||
|
||||
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} conversations")
|
||||
return all_texts
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
# Example queries for ChatGPT RAG
|
||||
print("\n🤖 ChatGPT RAG Example")
|
||||
print("=" * 50)
|
||||
print("\nExample queries you can try:")
|
||||
print("- 'What did I ask about Python programming?'")
|
||||
print("- 'Show me conversations about machine learning'")
|
||||
print("- 'Find discussions about travel planning'")
|
||||
print("- 'What advice did ChatGPT give me about career development?'")
|
||||
print("- 'Search for conversations about cooking recipes'")
|
||||
print("\nTo get started:")
|
||||
print("1. Export your ChatGPT data from Settings → Data Controls → Export")
|
||||
print("2. Place the downloaded zip file or extracted HTML in ./chatgpt_export/")
|
||||
print("3. Run this script to build your personal ChatGPT knowledge base!")
|
||||
print("\nOr run without --query for interactive mode\n")
|
||||
|
||||
rag = ChatGPTRAG()
|
||||
asyncio.run(rag.run())
|
||||
0
apps/claude_data/__init__.py
Normal file
0
apps/claude_data/__init__.py
Normal file
420
apps/claude_data/claude_reader.py
Normal file
420
apps/claude_data/claude_reader.py
Normal file
@@ -0,0 +1,420 @@
|
||||
"""
|
||||
Claude export data reader.
|
||||
|
||||
Reads and processes Claude conversation data from exported JSON files.
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from zipfile import ZipFile
|
||||
|
||||
from llama_index.core import Document
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
|
||||
|
||||
class ClaudeReader(BaseReader):
|
||||
"""
|
||||
Claude export data reader.
|
||||
|
||||
Reads Claude conversation data from exported JSON files or zip archives.
|
||||
Processes conversations into structured documents with metadata.
|
||||
"""
|
||||
|
||||
def __init__(self, concatenate_conversations: bool = True) -> None:
|
||||
"""
|
||||
Initialize.
|
||||
|
||||
Args:
|
||||
concatenate_conversations: Whether to concatenate messages within conversations for better context
|
||||
"""
|
||||
self.concatenate_conversations = concatenate_conversations
|
||||
|
||||
def _extract_json_from_zip(self, zip_path: Path) -> list[str]:
|
||||
"""
|
||||
Extract JSON files from Claude export zip file.
|
||||
|
||||
Args:
|
||||
zip_path: Path to the Claude export zip file
|
||||
|
||||
Returns:
|
||||
List of JSON content strings, or empty list if not found
|
||||
"""
|
||||
json_contents = []
|
||||
try:
|
||||
with ZipFile(zip_path, "r") as zip_file:
|
||||
# Look for JSON files
|
||||
json_files = [f for f in zip_file.namelist() if f.endswith(".json")]
|
||||
|
||||
if not json_files:
|
||||
print(f"No JSON files found in {zip_path}")
|
||||
return []
|
||||
|
||||
print(f"Found {len(json_files)} JSON files in archive")
|
||||
|
||||
for json_file in json_files:
|
||||
with zip_file.open(json_file) as f:
|
||||
content = f.read().decode("utf-8", errors="ignore")
|
||||
json_contents.append(content)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error extracting JSON from zip {zip_path}: {e}")
|
||||
|
||||
return json_contents
|
||||
|
||||
def _parse_claude_json(self, json_content: str) -> list[dict]:
|
||||
"""
|
||||
Parse Claude JSON export to extract conversations.
|
||||
|
||||
Args:
|
||||
json_content: JSON content from Claude export
|
||||
|
||||
Returns:
|
||||
List of conversation dictionaries
|
||||
"""
|
||||
try:
|
||||
data = json.loads(json_content)
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"Error parsing JSON: {e}")
|
||||
return []
|
||||
|
||||
conversations = []
|
||||
|
||||
# Handle different possible JSON structures
|
||||
if isinstance(data, list):
|
||||
# If data is a list of conversations
|
||||
for item in data:
|
||||
conversation = self._extract_conversation_from_json(item)
|
||||
if conversation:
|
||||
conversations.append(conversation)
|
||||
elif isinstance(data, dict):
|
||||
# Check for common structures
|
||||
if "conversations" in data:
|
||||
# Structure: {"conversations": [...]}
|
||||
for item in data["conversations"]:
|
||||
conversation = self._extract_conversation_from_json(item)
|
||||
if conversation:
|
||||
conversations.append(conversation)
|
||||
elif "messages" in data:
|
||||
# Single conversation with messages
|
||||
conversation = self._extract_conversation_from_json(data)
|
||||
if conversation:
|
||||
conversations.append(conversation)
|
||||
else:
|
||||
# Try to treat the whole object as a conversation
|
||||
conversation = self._extract_conversation_from_json(data)
|
||||
if conversation:
|
||||
conversations.append(conversation)
|
||||
|
||||
return conversations
|
||||
|
||||
def _extract_conversation_from_json(self, conv_data: dict) -> dict | None:
|
||||
"""
|
||||
Extract conversation data from a JSON object.
|
||||
|
||||
Args:
|
||||
conv_data: Dictionary containing conversation data
|
||||
|
||||
Returns:
|
||||
Dictionary with conversation data or None
|
||||
"""
|
||||
if not isinstance(conv_data, dict):
|
||||
return None
|
||||
|
||||
messages = []
|
||||
|
||||
# Look for messages in various possible structures
|
||||
message_sources = []
|
||||
if "messages" in conv_data:
|
||||
message_sources = conv_data["messages"]
|
||||
elif "chat" in conv_data:
|
||||
message_sources = conv_data["chat"]
|
||||
elif "conversation" in conv_data:
|
||||
message_sources = conv_data["conversation"]
|
||||
else:
|
||||
# If no clear message structure, try to extract from the object itself
|
||||
if "content" in conv_data and "role" in conv_data:
|
||||
message_sources = [conv_data]
|
||||
|
||||
for msg_data in message_sources:
|
||||
message = self._extract_message_from_json(msg_data)
|
||||
if message:
|
||||
messages.append(message)
|
||||
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
# Extract conversation metadata
|
||||
title = self._extract_title_from_conversation(conv_data, messages)
|
||||
timestamp = self._extract_timestamp_from_conversation(conv_data)
|
||||
|
||||
return {"title": title, "messages": messages, "timestamp": timestamp}
|
||||
|
||||
def _extract_message_from_json(self, msg_data: dict) -> dict | None:
|
||||
"""
|
||||
Extract message data from a JSON message object.
|
||||
|
||||
Args:
|
||||
msg_data: Dictionary containing message data
|
||||
|
||||
Returns:
|
||||
Dictionary with message data or None
|
||||
"""
|
||||
if not isinstance(msg_data, dict):
|
||||
return None
|
||||
|
||||
# Extract content from various possible fields
|
||||
content = ""
|
||||
content_fields = ["content", "text", "message", "body"]
|
||||
for field in content_fields:
|
||||
if msg_data.get(field):
|
||||
content = str(msg_data[field])
|
||||
break
|
||||
|
||||
if not content or len(content.strip()) < 3:
|
||||
return None
|
||||
|
||||
# Extract role (user/assistant/human/ai/claude)
|
||||
role = "mixed" # Default role
|
||||
role_fields = ["role", "sender", "from", "author", "type"]
|
||||
for field in role_fields:
|
||||
if msg_data.get(field):
|
||||
role_value = str(msg_data[field]).lower()
|
||||
if role_value in ["user", "human", "person"]:
|
||||
role = "user"
|
||||
elif role_value in ["assistant", "ai", "claude", "bot"]:
|
||||
role = "assistant"
|
||||
break
|
||||
|
||||
# Extract timestamp
|
||||
timestamp = self._extract_timestamp_from_message(msg_data)
|
||||
|
||||
return {"role": role, "content": content, "timestamp": timestamp}
|
||||
|
||||
def _extract_timestamp_from_message(self, msg_data: dict) -> str | None:
|
||||
"""Extract timestamp from message data."""
|
||||
timestamp_fields = ["timestamp", "created_at", "date", "time"]
|
||||
for field in timestamp_fields:
|
||||
if msg_data.get(field):
|
||||
return str(msg_data[field])
|
||||
return None
|
||||
|
||||
def _extract_timestamp_from_conversation(self, conv_data: dict) -> str | None:
|
||||
"""Extract timestamp from conversation data."""
|
||||
timestamp_fields = ["timestamp", "created_at", "date", "updated_at", "last_updated"]
|
||||
for field in timestamp_fields:
|
||||
if conv_data.get(field):
|
||||
return str(conv_data[field])
|
||||
return None
|
||||
|
||||
def _extract_title_from_conversation(self, conv_data: dict, messages: list) -> str:
|
||||
"""Extract or generate title for conversation."""
|
||||
# Try to find explicit title
|
||||
title_fields = ["title", "name", "subject", "topic"]
|
||||
for field in title_fields:
|
||||
if conv_data.get(field):
|
||||
return str(conv_data[field])
|
||||
|
||||
# Generate title from first user message
|
||||
for message in messages:
|
||||
if message.get("role") == "user":
|
||||
content = message.get("content", "")
|
||||
if content:
|
||||
# Use first 50 characters as title
|
||||
title = content[:50].strip()
|
||||
if len(content) > 50:
|
||||
title += "..."
|
||||
return title
|
||||
|
||||
return "Claude Conversation"
|
||||
|
||||
def _create_concatenated_content(self, conversation: dict) -> str:
|
||||
"""
|
||||
Create concatenated content from conversation messages.
|
||||
|
||||
Args:
|
||||
conversation: Dictionary containing conversation data
|
||||
|
||||
Returns:
|
||||
Formatted concatenated content
|
||||
"""
|
||||
title = conversation.get("title", "Claude Conversation")
|
||||
messages = conversation.get("messages", [])
|
||||
timestamp = conversation.get("timestamp", "Unknown")
|
||||
|
||||
# Build message content
|
||||
message_parts = []
|
||||
for message in messages:
|
||||
role = message.get("role", "mixed")
|
||||
content = message.get("content", "")
|
||||
msg_timestamp = message.get("timestamp", "")
|
||||
|
||||
if role == "user":
|
||||
prefix = "[You]"
|
||||
elif role == "assistant":
|
||||
prefix = "[Claude]"
|
||||
else:
|
||||
prefix = "[Message]"
|
||||
|
||||
# Add timestamp if available
|
||||
if msg_timestamp:
|
||||
prefix += f" ({msg_timestamp})"
|
||||
|
||||
message_parts.append(f"{prefix}: {content}")
|
||||
|
||||
concatenated_text = "\n\n".join(message_parts)
|
||||
|
||||
# Create final document content
|
||||
doc_content = f"""Conversation: {title}
|
||||
Date: {timestamp}
|
||||
Messages ({len(messages)} messages):
|
||||
|
||||
{concatenated_text}
|
||||
"""
|
||||
return doc_content
|
||||
|
||||
def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
|
||||
"""
|
||||
Load Claude export data.
|
||||
|
||||
Args:
|
||||
input_dir: Directory containing Claude export files or path to specific file
|
||||
**load_kwargs:
|
||||
max_count (int): Maximum number of conversations to process
|
||||
claude_export_path (str): Specific path to Claude export file/directory
|
||||
include_metadata (bool): Whether to include metadata in documents
|
||||
"""
|
||||
docs: list[Document] = []
|
||||
max_count = load_kwargs.get("max_count", -1)
|
||||
claude_export_path = load_kwargs.get("claude_export_path", input_dir)
|
||||
include_metadata = load_kwargs.get("include_metadata", True)
|
||||
|
||||
if not claude_export_path:
|
||||
print("No Claude export path provided")
|
||||
return docs
|
||||
|
||||
export_path = Path(claude_export_path)
|
||||
|
||||
if not export_path.exists():
|
||||
print(f"Claude export path not found: {export_path}")
|
||||
return docs
|
||||
|
||||
json_contents = []
|
||||
|
||||
# Handle different input types
|
||||
if export_path.is_file():
|
||||
if export_path.suffix.lower() == ".zip":
|
||||
# Extract JSON from zip file
|
||||
json_contents = self._extract_json_from_zip(export_path)
|
||||
elif export_path.suffix.lower() == ".json":
|
||||
# Read JSON file directly
|
||||
try:
|
||||
with open(export_path, encoding="utf-8", errors="ignore") as f:
|
||||
json_contents.append(f.read())
|
||||
except Exception as e:
|
||||
print(f"Error reading JSON file {export_path}: {e}")
|
||||
return docs
|
||||
else:
|
||||
print(f"Unsupported file type: {export_path.suffix}")
|
||||
return docs
|
||||
|
||||
elif export_path.is_dir():
|
||||
# Look for JSON files in directory
|
||||
json_files = list(export_path.glob("*.json"))
|
||||
zip_files = list(export_path.glob("*.zip"))
|
||||
|
||||
if json_files:
|
||||
print(f"Found {len(json_files)} JSON files in directory")
|
||||
for json_file in json_files:
|
||||
try:
|
||||
with open(json_file, encoding="utf-8", errors="ignore") as f:
|
||||
json_contents.append(f.read())
|
||||
except Exception as e:
|
||||
print(f"Error reading JSON file {json_file}: {e}")
|
||||
continue
|
||||
|
||||
if zip_files:
|
||||
print(f"Found {len(zip_files)} ZIP files in directory")
|
||||
for zip_file in zip_files:
|
||||
zip_contents = self._extract_json_from_zip(zip_file)
|
||||
json_contents.extend(zip_contents)
|
||||
|
||||
if not json_files and not zip_files:
|
||||
print(f"No JSON or ZIP files found in {export_path}")
|
||||
return docs
|
||||
|
||||
if not json_contents:
|
||||
print("No JSON content found to process")
|
||||
return docs
|
||||
|
||||
# Parse conversations from JSON content
|
||||
print("Parsing Claude conversations from JSON...")
|
||||
all_conversations = []
|
||||
for json_content in json_contents:
|
||||
conversations = self._parse_claude_json(json_content)
|
||||
all_conversations.extend(conversations)
|
||||
|
||||
if not all_conversations:
|
||||
print("No conversations found in JSON content")
|
||||
return docs
|
||||
|
||||
print(f"Found {len(all_conversations)} conversations")
|
||||
|
||||
# Process conversations into documents
|
||||
count = 0
|
||||
for conversation in all_conversations:
|
||||
if max_count > 0 and count >= max_count:
|
||||
break
|
||||
|
||||
if self.concatenate_conversations:
|
||||
# Create one document per conversation with concatenated messages
|
||||
doc_content = self._create_concatenated_content(conversation)
|
||||
|
||||
metadata = {}
|
||||
if include_metadata:
|
||||
metadata = {
|
||||
"title": conversation.get("title", "Claude Conversation"),
|
||||
"timestamp": conversation.get("timestamp", "Unknown"),
|
||||
"message_count": len(conversation.get("messages", [])),
|
||||
"source": "Claude Export",
|
||||
}
|
||||
|
||||
doc = Document(text=doc_content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
|
||||
else:
|
||||
# Create separate documents for each message
|
||||
for message in conversation.get("messages", []):
|
||||
if max_count > 0 and count >= max_count:
|
||||
break
|
||||
|
||||
role = message.get("role", "mixed")
|
||||
content = message.get("content", "")
|
||||
msg_timestamp = message.get("timestamp", "")
|
||||
|
||||
if not content.strip():
|
||||
continue
|
||||
|
||||
# Create document content with context
|
||||
doc_content = f"""Conversation: {conversation.get("title", "Claude Conversation")}
|
||||
Role: {role}
|
||||
Timestamp: {msg_timestamp or conversation.get("timestamp", "Unknown")}
|
||||
Message: {content}
|
||||
"""
|
||||
|
||||
metadata = {}
|
||||
if include_metadata:
|
||||
metadata = {
|
||||
"conversation_title": conversation.get("title", "Claude Conversation"),
|
||||
"role": role,
|
||||
"timestamp": msg_timestamp or conversation.get("timestamp", "Unknown"),
|
||||
"source": "Claude Export",
|
||||
}
|
||||
|
||||
doc = Document(text=doc_content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
|
||||
print(f"Created {len(docs)} documents from Claude export")
|
||||
return docs
|
||||
189
apps/claude_rag.py
Normal file
189
apps/claude_rag.py
Normal file
@@ -0,0 +1,189 @@
|
||||
"""
|
||||
Claude RAG example using the unified interface.
|
||||
Supports Claude export data from JSON files.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from base_rag_example import BaseRAGExample
|
||||
from chunking import create_text_chunks
|
||||
|
||||
from .claude_data.claude_reader import ClaudeReader
|
||||
|
||||
|
||||
class ClaudeRAG(BaseRAGExample):
|
||||
"""RAG example for Claude conversation data."""
|
||||
|
||||
def __init__(self):
|
||||
# Set default values BEFORE calling super().__init__
|
||||
self.max_items_default = -1 # Process all conversations by default
|
||||
self.embedding_model_default = (
|
||||
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
name="Claude",
|
||||
description="Process and query Claude conversation exports with LEANN",
|
||||
default_index_name="claude_conversations_index",
|
||||
)
|
||||
|
||||
def _add_specific_arguments(self, parser):
|
||||
"""Add Claude-specific arguments."""
|
||||
claude_group = parser.add_argument_group("Claude Parameters")
|
||||
claude_group.add_argument(
|
||||
"--export-path",
|
||||
type=str,
|
||||
default="./claude_export",
|
||||
help="Path to Claude export file (.json or .zip) or directory containing exports (default: ./claude_export)",
|
||||
)
|
||||
claude_group.add_argument(
|
||||
"--concatenate-conversations",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Concatenate messages within conversations for better context (default: True)",
|
||||
)
|
||||
claude_group.add_argument(
|
||||
"--separate-messages",
|
||||
action="store_true",
|
||||
help="Process each message as a separate document (overrides --concatenate-conversations)",
|
||||
)
|
||||
claude_group.add_argument(
|
||||
"--chunk-size", type=int, default=512, help="Text chunk size (default: 512)"
|
||||
)
|
||||
claude_group.add_argument(
|
||||
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
|
||||
)
|
||||
|
||||
def _find_claude_exports(self, export_path: Path) -> list[Path]:
|
||||
"""
|
||||
Find Claude export files in the given path.
|
||||
|
||||
Args:
|
||||
export_path: Path to search for exports
|
||||
|
||||
Returns:
|
||||
List of paths to Claude export files
|
||||
"""
|
||||
export_files = []
|
||||
|
||||
if export_path.is_file():
|
||||
if export_path.suffix.lower() in [".zip", ".json"]:
|
||||
export_files.append(export_path)
|
||||
elif export_path.is_dir():
|
||||
# Look for zip and json files
|
||||
export_files.extend(export_path.glob("*.zip"))
|
||||
export_files.extend(export_path.glob("*.json"))
|
||||
|
||||
return export_files
|
||||
|
||||
async def load_data(self, args) -> list[str]:
|
||||
"""Load Claude export data and convert to text chunks."""
|
||||
export_path = Path(args.export_path)
|
||||
|
||||
if not export_path.exists():
|
||||
print(f"Claude export path not found: {export_path}")
|
||||
print(
|
||||
"Please ensure you have exported your Claude data and placed it in the correct location."
|
||||
)
|
||||
print("\nTo export your Claude data:")
|
||||
print("1. Open Claude in your browser")
|
||||
print("2. Look for export/download options in settings or conversation menu")
|
||||
print("3. Download the conversation data (usually in JSON format)")
|
||||
print("4. Place the file/directory at the specified path")
|
||||
print(
|
||||
"\nNote: Claude export methods may vary. Check Claude's help documentation for current instructions."
|
||||
)
|
||||
return []
|
||||
|
||||
# Find export files
|
||||
export_files = self._find_claude_exports(export_path)
|
||||
|
||||
if not export_files:
|
||||
print(f"No Claude export files (.json or .zip) found in: {export_path}")
|
||||
return []
|
||||
|
||||
print(f"Found {len(export_files)} Claude export files")
|
||||
|
||||
# Create reader with appropriate settings
|
||||
concatenate = args.concatenate_conversations and not args.separate_messages
|
||||
reader = ClaudeReader(concatenate_conversations=concatenate)
|
||||
|
||||
# Process each export file
|
||||
all_documents = []
|
||||
total_processed = 0
|
||||
|
||||
for i, export_file in enumerate(export_files):
|
||||
print(f"\nProcessing export file {i + 1}/{len(export_files)}: {export_file.name}")
|
||||
|
||||
try:
|
||||
# Apply max_items limit per file
|
||||
max_per_file = -1
|
||||
if args.max_items > 0:
|
||||
remaining = args.max_items - total_processed
|
||||
if remaining <= 0:
|
||||
break
|
||||
max_per_file = remaining
|
||||
|
||||
# Load conversations
|
||||
documents = reader.load_data(
|
||||
claude_export_path=str(export_file),
|
||||
max_count=max_per_file,
|
||||
include_metadata=True,
|
||||
)
|
||||
|
||||
if documents:
|
||||
all_documents.extend(documents)
|
||||
total_processed += len(documents)
|
||||
print(f"Processed {len(documents)} conversations from this file")
|
||||
else:
|
||||
print(f"No conversations loaded from {export_file}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing {export_file}: {e}")
|
||||
continue
|
||||
|
||||
if not all_documents:
|
||||
print("No conversations found to process!")
|
||||
print("\nTroubleshooting:")
|
||||
print("- Ensure the export file is a valid Claude export")
|
||||
print("- Check that the JSON file contains conversation data")
|
||||
print("- Try using a different export format or method")
|
||||
print("- Check Claude's documentation for current export procedures")
|
||||
return []
|
||||
|
||||
print(f"\nTotal conversations processed: {len(all_documents)}")
|
||||
print("Now starting to split into text chunks... this may take some time")
|
||||
|
||||
# Convert to text chunks
|
||||
all_texts = create_text_chunks(
|
||||
all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
||||
)
|
||||
|
||||
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} conversations")
|
||||
return all_texts
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
# Example queries for Claude RAG
|
||||
print("\n🤖 Claude RAG Example")
|
||||
print("=" * 50)
|
||||
print("\nExample queries you can try:")
|
||||
print("- 'What did I ask Claude about Python programming?'")
|
||||
print("- 'Show me conversations about machine learning'")
|
||||
print("- 'Find discussions about code optimization'")
|
||||
print("- 'What advice did Claude give me about software design?'")
|
||||
print("- 'Search for conversations about debugging techniques'")
|
||||
print("\nTo get started:")
|
||||
print("1. Export your Claude conversation data")
|
||||
print("2. Place the JSON/ZIP file in ./claude_export/")
|
||||
print("3. Run this script to build your personal Claude knowledge base!")
|
||||
print("\nOr run without --query for interactive mode\n")
|
||||
|
||||
rag = ClaudeRAG()
|
||||
asyncio.run(rag.run())
|
||||
@@ -1,182 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _ensure_repo_paths_importable(current_file: str) -> None:
|
||||
_repo_root = Path(current_file).resolve().parents[3]
|
||||
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
|
||||
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
|
||||
if str(_leann_core_src) not in sys.path:
|
||||
sys.path.append(str(_leann_core_src))
|
||||
if str(_leann_hnsw_pkg) not in sys.path:
|
||||
sys.path.append(str(_leann_hnsw_pkg))
|
||||
|
||||
|
||||
_ensure_repo_paths_importable(__file__)
|
||||
|
||||
from leann_backend_hnsw.hnsw_backend import HNSWBuilder, HNSWSearcher # noqa: E402
|
||||
|
||||
|
||||
class LeannMultiVector:
|
||||
def __init__(
|
||||
self,
|
||||
index_path: str,
|
||||
dim: int = 128,
|
||||
distance_metric: str = "mips",
|
||||
m: int = 16,
|
||||
ef_construction: int = 500,
|
||||
is_compact: bool = False,
|
||||
is_recompute: bool = False,
|
||||
embedding_model_name: str = "colvision",
|
||||
) -> None:
|
||||
self.index_path = index_path
|
||||
self.dim = dim
|
||||
self.embedding_model_name = embedding_model_name
|
||||
self._pending_items: list[dict] = []
|
||||
self._backend_kwargs = {
|
||||
"distance_metric": distance_metric,
|
||||
"M": m,
|
||||
"efConstruction": ef_construction,
|
||||
"is_compact": is_compact,
|
||||
"is_recompute": is_recompute,
|
||||
}
|
||||
self._labels_meta: list[dict] = []
|
||||
|
||||
def _meta_dict(self) -> dict:
|
||||
return {
|
||||
"version": "1.0",
|
||||
"backend_name": "hnsw",
|
||||
"embedding_model": self.embedding_model_name,
|
||||
"embedding_mode": "custom",
|
||||
"dimensions": self.dim,
|
||||
"backend_kwargs": self._backend_kwargs,
|
||||
"is_compact": self._backend_kwargs.get("is_compact", True),
|
||||
"is_pruned": self._backend_kwargs.get("is_compact", True)
|
||||
and self._backend_kwargs.get("is_recompute", True),
|
||||
}
|
||||
|
||||
def create_collection(self) -> None:
|
||||
path = Path(self.index_path)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def insert(self, data: dict) -> None:
|
||||
self._pending_items.append(
|
||||
{
|
||||
"doc_id": int(data["doc_id"]),
|
||||
"filepath": data.get("filepath", ""),
|
||||
"colbert_vecs": [np.asarray(v, dtype=np.float32) for v in data["colbert_vecs"]],
|
||||
}
|
||||
)
|
||||
|
||||
def _labels_path(self) -> Path:
|
||||
index_path_obj = Path(self.index_path)
|
||||
return index_path_obj.parent / f"{index_path_obj.name}.labels.json"
|
||||
|
||||
def _meta_path(self) -> Path:
|
||||
index_path_obj = Path(self.index_path)
|
||||
return index_path_obj.parent / f"{index_path_obj.name}.meta.json"
|
||||
|
||||
def create_index(self) -> None:
|
||||
if not self._pending_items:
|
||||
return
|
||||
|
||||
embeddings: list[np.ndarray] = []
|
||||
labels_meta: list[dict] = []
|
||||
|
||||
for item in self._pending_items:
|
||||
doc_id = int(item["doc_id"])
|
||||
filepath = item.get("filepath", "")
|
||||
colbert_vecs = item["colbert_vecs"]
|
||||
for seq_id, vec in enumerate(colbert_vecs):
|
||||
vec_np = np.asarray(vec, dtype=np.float32)
|
||||
embeddings.append(vec_np)
|
||||
labels_meta.append(
|
||||
{
|
||||
"id": f"{doc_id}:{seq_id}",
|
||||
"doc_id": doc_id,
|
||||
"seq_id": int(seq_id),
|
||||
"filepath": filepath,
|
||||
}
|
||||
)
|
||||
|
||||
if not embeddings:
|
||||
return
|
||||
|
||||
embeddings_np = np.vstack(embeddings).astype(np.float32)
|
||||
# print shape of embeddings_np
|
||||
print(embeddings_np.shape)
|
||||
|
||||
builder = HNSWBuilder(**{**self._backend_kwargs, "dimensions": self.dim})
|
||||
ids = [str(i) for i in range(embeddings_np.shape[0])]
|
||||
builder.build(embeddings_np, ids, self.index_path)
|
||||
|
||||
import json as _json
|
||||
|
||||
with open(self._meta_path(), "w", encoding="utf-8") as f:
|
||||
_json.dump(self._meta_dict(), f, indent=2)
|
||||
with open(self._labels_path(), "w", encoding="utf-8") as f:
|
||||
_json.dump(labels_meta, f)
|
||||
|
||||
self._labels_meta = labels_meta
|
||||
|
||||
def _load_labels_meta_if_needed(self) -> None:
|
||||
if self._labels_meta:
|
||||
return
|
||||
labels_path = self._labels_path()
|
||||
if labels_path.exists():
|
||||
import json as _json
|
||||
|
||||
with open(labels_path, encoding="utf-8") as f:
|
||||
self._labels_meta = _json.load(f)
|
||||
|
||||
def search(
|
||||
self, data: np.ndarray, topk: int, first_stage_k: int = 50
|
||||
) -> list[tuple[float, int]]:
|
||||
if data.ndim == 1:
|
||||
data = data.reshape(1, -1)
|
||||
if data.dtype != np.float32:
|
||||
data = data.astype(np.float32)
|
||||
|
||||
self._load_labels_meta_if_needed()
|
||||
|
||||
searcher = HNSWSearcher(self.index_path, meta=self._meta_dict())
|
||||
raw = searcher.search(
|
||||
data,
|
||||
first_stage_k,
|
||||
recompute_embeddings=False,
|
||||
complexity=128,
|
||||
beam_width=1,
|
||||
prune_ratio=0.0,
|
||||
batch_size=0,
|
||||
)
|
||||
|
||||
labels = raw.get("labels")
|
||||
distances = raw.get("distances")
|
||||
if labels is None or distances is None:
|
||||
return []
|
||||
|
||||
doc_scores: dict[int, float] = {}
|
||||
B = len(labels)
|
||||
for b in range(B):
|
||||
per_doc_best: dict[int, float] = {}
|
||||
for k, sid in enumerate(labels[b]):
|
||||
try:
|
||||
idx = int(sid)
|
||||
except Exception:
|
||||
continue
|
||||
if 0 <= idx < len(self._labels_meta):
|
||||
doc_id = int(self._labels_meta[idx]["doc_id"]) # type: ignore[index]
|
||||
else:
|
||||
continue
|
||||
score = float(distances[b][k])
|
||||
if (doc_id not in per_doc_best) or (score > per_doc_best[doc_id]):
|
||||
per_doc_best[doc_id] = score
|
||||
for doc_id, best_score in per_doc_best.items():
|
||||
doc_scores[doc_id] = doc_scores.get(doc_id, 0.0) + best_score
|
||||
|
||||
scores = sorted(((v, k) for k, v in doc_scores.items()), key=lambda x: x[0], reverse=True)
|
||||
return scores[:topk] if len(scores) >= topk else scores
|
||||
@@ -1,477 +0,0 @@
|
||||
## Jupyter-style notebook script
|
||||
# %%
|
||||
# uv pip install matplotlib qwen_vl_utils
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def _ensure_repo_paths_importable(current_file: str) -> None:
|
||||
"""Make local leann packages importable without installing (mirrors multi-vector-leann.py)."""
|
||||
_repo_root = Path(current_file).resolve().parents[3]
|
||||
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
|
||||
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
|
||||
if str(_leann_core_src) not in sys.path:
|
||||
sys.path.append(str(_leann_core_src))
|
||||
if str(_leann_hnsw_pkg) not in sys.path:
|
||||
sys.path.append(str(_leann_hnsw_pkg))
|
||||
|
||||
|
||||
_ensure_repo_paths_importable(__file__)
|
||||
|
||||
from leann_multi_vector import LeannMultiVector # noqa: E402
|
||||
|
||||
# %%
|
||||
# Config
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
QUERY = "How does DeepSeek-V2 compare against the LLaMA family of LLMs?"
|
||||
MODEL: str = "colqwen2" # "colpali" or "colqwen2"
|
||||
|
||||
# Data source: set to True to use the Hugging Face dataset example (recommended)
|
||||
USE_HF_DATASET: bool = True
|
||||
DATASET_NAME: str = "weaviate/arXiv-AI-papers-multi-vector"
|
||||
DATASET_SPLIT: str = "train"
|
||||
MAX_DOCS: Optional[int] = None # limit number of pages to index; None = all
|
||||
|
||||
# Local pages (used when USE_HF_DATASET == False)
|
||||
PDF: Optional[str] = None # e.g., "./pdfs/2004.12832v2.pdf"
|
||||
PAGES_DIR: str = "./pages"
|
||||
|
||||
# Index + retrieval settings
|
||||
INDEX_PATH: str = "./indexes/colvision.leann"
|
||||
TOPK: int = 1
|
||||
FIRST_STAGE_K: int = 500
|
||||
REBUILD_INDEX: bool = False
|
||||
|
||||
# Artifacts
|
||||
SAVE_TOP_IMAGE: Optional[str] = "./figures/retrieved_page.png"
|
||||
SIMILARITY_MAP: bool = True
|
||||
SIM_TOKEN_IDX: int = 13 # -1 means auto-select the most salient token
|
||||
SIM_OUTPUT: str = "./figures/similarity_map.png"
|
||||
ANSWER: bool = True
|
||||
MAX_NEW_TOKENS: int = 128
|
||||
|
||||
|
||||
# %%
|
||||
# Helpers
|
||||
def _natural_sort_key(name: str) -> int:
|
||||
m = re.search(r"\d+", name)
|
||||
return int(m.group()) if m else 0
|
||||
|
||||
|
||||
def _load_images_from_dir(pages_dir: str) -> tuple[list[str], list[Image.Image]]:
|
||||
filenames = [n for n in os.listdir(pages_dir) if n.lower().endswith((".png", ".jpg", ".jpeg"))]
|
||||
filenames = sorted(filenames, key=_natural_sort_key)
|
||||
filepaths = [os.path.join(pages_dir, n) for n in filenames]
|
||||
images = [Image.open(p) for p in filepaths]
|
||||
return filepaths, images
|
||||
|
||||
|
||||
def _maybe_convert_pdf_to_images(pdf_path: Optional[str], pages_dir: str, dpi: int = 200) -> None:
|
||||
if not pdf_path:
|
||||
return
|
||||
os.makedirs(pages_dir, exist_ok=True)
|
||||
try:
|
||||
from pdf2image import convert_from_path
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
"pdf2image is required to convert PDF to images. Install via pip install pdf2image"
|
||||
) from e
|
||||
images = convert_from_path(pdf_path, dpi=dpi)
|
||||
for i, image in enumerate(images):
|
||||
image.save(os.path.join(pages_dir, f"page_{i + 1}.png"), "PNG")
|
||||
|
||||
|
||||
def _select_device_and_dtype():
|
||||
import torch
|
||||
from colpali_engine.utils.torch_utils import get_torch_device
|
||||
|
||||
device_str = (
|
||||
"cuda"
|
||||
if torch.cuda.is_available()
|
||||
else (
|
||||
"mps"
|
||||
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
|
||||
else "cpu"
|
||||
)
|
||||
)
|
||||
device = get_torch_device(device_str)
|
||||
# Stable dtype selection to avoid NaNs:
|
||||
# - CUDA: prefer bfloat16 if supported, else float16
|
||||
# - MPS: use float32 (fp16 on MPS can produce NaNs in some ops)
|
||||
# - CPU: float32
|
||||
if device_str == "cuda":
|
||||
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
||||
try:
|
||||
torch.backends.cuda.matmul.allow_tf32 = True # Better stability/perf on Ampere+
|
||||
except Exception:
|
||||
pass
|
||||
elif device_str == "mps":
|
||||
dtype = torch.float32
|
||||
else:
|
||||
dtype = torch.float32
|
||||
return device_str, device, dtype
|
||||
|
||||
|
||||
def _load_colvision(model_choice: str):
|
||||
import torch
|
||||
from colpali_engine.models import ColPali, ColQwen2, ColQwen2Processor
|
||||
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
|
||||
from transformers.utils.import_utils import is_flash_attn_2_available
|
||||
|
||||
device_str, device, dtype = _select_device_and_dtype()
|
||||
|
||||
if model_choice == "colqwen2":
|
||||
model_name = "vidore/colqwen2-v1.0"
|
||||
# On CPU/MPS we must avoid flash-attn and stay eager; on CUDA prefer flash-attn if available
|
||||
attn_implementation = (
|
||||
"flash_attention_2"
|
||||
if (device_str == "cuda" and is_flash_attn_2_available())
|
||||
else "eager"
|
||||
)
|
||||
model = ColQwen2.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map=device,
|
||||
attn_implementation=attn_implementation,
|
||||
).eval()
|
||||
processor = ColQwen2Processor.from_pretrained(model_name)
|
||||
else:
|
||||
model_name = "vidore/colpali-v1.2"
|
||||
model = ColPali.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map=device,
|
||||
).eval()
|
||||
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
|
||||
|
||||
return model_name, model, processor, device_str, device, dtype
|
||||
|
||||
|
||||
def _embed_images(model, processor, images: list[Image.Image]) -> list[Any]:
|
||||
import torch
|
||||
from colpali_engine.utils.torch_utils import ListDataset
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
# Ensure deterministic eval and autocast for stability
|
||||
model.eval()
|
||||
|
||||
dataloader = DataLoader(
|
||||
dataset=ListDataset[Image.Image](images),
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
collate_fn=lambda x: processor.process_images(x),
|
||||
)
|
||||
|
||||
doc_vecs: list[Any] = []
|
||||
for batch_doc in dataloader:
|
||||
with torch.no_grad():
|
||||
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
|
||||
# autocast on CUDA for bf16/fp16; on CPU/MPS stay in fp32
|
||||
if model.device.type == "cuda":
|
||||
with torch.autocast(
|
||||
device_type="cuda",
|
||||
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
|
||||
):
|
||||
embeddings_doc = model(**batch_doc)
|
||||
else:
|
||||
embeddings_doc = model(**batch_doc)
|
||||
doc_vecs.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
|
||||
return doc_vecs
|
||||
|
||||
|
||||
def _embed_queries(model, processor, queries: list[str]) -> list[Any]:
|
||||
import torch
|
||||
from colpali_engine.utils.torch_utils import ListDataset
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
model.eval()
|
||||
|
||||
dataloader = DataLoader(
|
||||
dataset=ListDataset[str](queries),
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
collate_fn=lambda x: processor.process_queries(x),
|
||||
)
|
||||
|
||||
q_vecs: list[Any] = []
|
||||
for batch_query in dataloader:
|
||||
with torch.no_grad():
|
||||
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
|
||||
if model.device.type == "cuda":
|
||||
with torch.autocast(
|
||||
device_type="cuda",
|
||||
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
|
||||
):
|
||||
embeddings_query = model(**batch_query)
|
||||
else:
|
||||
embeddings_query = model(**batch_query)
|
||||
q_vecs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
|
||||
return q_vecs
|
||||
|
||||
|
||||
def _build_index(index_path: str, doc_vecs: list[Any], filepaths: list[str]) -> LeannMultiVector:
|
||||
dim = int(doc_vecs[0].shape[-1])
|
||||
retriever = LeannMultiVector(index_path=index_path, dim=dim)
|
||||
retriever.create_collection()
|
||||
for i, vec in enumerate(doc_vecs):
|
||||
data = {
|
||||
"colbert_vecs": vec.float().numpy(),
|
||||
"doc_id": i,
|
||||
"filepath": filepaths[i],
|
||||
}
|
||||
retriever.insert(data)
|
||||
retriever.create_index()
|
||||
return retriever
|
||||
|
||||
|
||||
def _load_retriever_if_index_exists(index_path: str, dim: int) -> Optional[LeannMultiVector]:
|
||||
index_base = Path(index_path)
|
||||
# Rough heuristic: index dir exists AND meta+labels files exist
|
||||
meta = index_base.parent / f"{index_base.name}.meta.json"
|
||||
labels = index_base.parent / f"{index_base.name}.labels.json"
|
||||
if index_base.exists() and meta.exists() and labels.exists():
|
||||
return LeannMultiVector(index_path=index_path, dim=dim)
|
||||
return None
|
||||
|
||||
|
||||
def _generate_similarity_map(
|
||||
model,
|
||||
processor,
|
||||
image: Image.Image,
|
||||
query: str,
|
||||
token_idx: Optional[int] = None,
|
||||
output_path: Optional[str] = None,
|
||||
) -> tuple[int, float]:
|
||||
import torch
|
||||
from colpali_engine.interpretability import (
|
||||
get_similarity_maps_from_embeddings,
|
||||
plot_similarity_map,
|
||||
)
|
||||
|
||||
batch_images = processor.process_images([image]).to(model.device)
|
||||
batch_queries = processor.process_queries([query]).to(model.device)
|
||||
|
||||
with torch.no_grad():
|
||||
image_embeddings = model.forward(**batch_images)
|
||||
query_embeddings = model.forward(**batch_queries)
|
||||
|
||||
n_patches = processor.get_n_patches(
|
||||
image_size=image.size,
|
||||
spatial_merge_size=getattr(model, "spatial_merge_size", None),
|
||||
)
|
||||
image_mask = processor.get_image_mask(batch_images)
|
||||
|
||||
batched_similarity_maps = get_similarity_maps_from_embeddings(
|
||||
image_embeddings=image_embeddings,
|
||||
query_embeddings=query_embeddings,
|
||||
n_patches=n_patches,
|
||||
image_mask=image_mask,
|
||||
)
|
||||
|
||||
similarity_maps = batched_similarity_maps[0]
|
||||
|
||||
# Determine token index if not provided: choose the token with highest max score
|
||||
if token_idx is None:
|
||||
per_token_max = similarity_maps.view(similarity_maps.shape[0], -1).max(dim=1).values
|
||||
token_idx = int(per_token_max.argmax().item())
|
||||
|
||||
max_sim_score = similarity_maps[token_idx, :, :].max().item()
|
||||
|
||||
if output_path:
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
fig, ax = plot_similarity_map(
|
||||
image=image,
|
||||
similarity_map=similarity_maps[token_idx],
|
||||
figsize=(14, 14),
|
||||
show_colorbar=False,
|
||||
)
|
||||
ax.set_title(f"Token #{token_idx}. MaxSim score: {max_sim_score:.2f}", fontsize=12)
|
||||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||
plt.savefig(output_path, bbox_inches="tight")
|
||||
plt.close(fig)
|
||||
|
||||
return token_idx, float(max_sim_score)
|
||||
|
||||
|
||||
class QwenVL:
|
||||
def __init__(self, device: str):
|
||||
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
|
||||
from transformers.utils.import_utils import is_flash_attn_2_available
|
||||
|
||||
attn_implementation = "flash_attention_2" if is_flash_attn_2_available() else "eager"
|
||||
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||
"Qwen/Qwen2.5-VL-3B-Instruct",
|
||||
torch_dtype="auto",
|
||||
device_map=device,
|
||||
attn_implementation=attn_implementation,
|
||||
)
|
||||
|
||||
min_pixels = 256 * 28 * 28
|
||||
max_pixels = 1280 * 28 * 28
|
||||
self.processor = AutoProcessor.from_pretrained(
|
||||
"Qwen/Qwen2.5-VL-3B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels
|
||||
)
|
||||
|
||||
def answer(self, query: str, images: list[Image.Image], max_new_tokens: int = 128) -> str:
|
||||
import base64
|
||||
from io import BytesIO
|
||||
|
||||
from qwen_vl_utils import process_vision_info
|
||||
|
||||
content = []
|
||||
for img in images:
|
||||
buffer = BytesIO()
|
||||
img.save(buffer, format="jpeg")
|
||||
img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
content.append({"type": "image", "image": f"data:image;base64,{img_base64}"})
|
||||
content.append({"type": "text", "text": query})
|
||||
messages = [{"role": "user", "content": content}]
|
||||
|
||||
text = self.processor.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
image_inputs, video_inputs = process_vision_info(messages)
|
||||
inputs = self.processor(
|
||||
text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt"
|
||||
)
|
||||
inputs = inputs.to(self.model.device)
|
||||
|
||||
generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
|
||||
generated_ids_trimmed = [
|
||||
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
||||
]
|
||||
return self.processor.batch_decode(
|
||||
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
)[0]
|
||||
|
||||
|
||||
# %%
|
||||
|
||||
# Step 1: Prepare data
|
||||
if USE_HF_DATASET:
|
||||
from datasets import load_dataset
|
||||
|
||||
dataset = load_dataset(DATASET_NAME, split=DATASET_SPLIT)
|
||||
N = len(dataset) if MAX_DOCS is None else min(MAX_DOCS, len(dataset))
|
||||
filepaths: list[str] = []
|
||||
images: list[Image.Image] = []
|
||||
for i in tqdm(range(N), desc="Loading dataset"):
|
||||
p = dataset[i]
|
||||
# Compose a descriptive identifier for printing later
|
||||
identifier = f"arXiv:{p['paper_arxiv_id']}|title:{p['paper_title']}|page:{int(p['page_number'])}|id:{p['page_id']}"
|
||||
print(identifier)
|
||||
filepaths.append(identifier)
|
||||
images.append(p["page_image"]) # PIL Image
|
||||
else:
|
||||
_maybe_convert_pdf_to_images(PDF, PAGES_DIR)
|
||||
filepaths, images = _load_images_from_dir(PAGES_DIR)
|
||||
if not images:
|
||||
raise RuntimeError(
|
||||
f"No images found in {PAGES_DIR}. Provide PDF path in PDF variable or ensure images exist."
|
||||
)
|
||||
|
||||
|
||||
# %%
|
||||
# Step 2: Load model and processor
|
||||
model_name, model, processor, device_str, device, dtype = _load_colvision(MODEL)
|
||||
print(f"Using model={model_name}, device={device_str}, dtype={dtype}")
|
||||
|
||||
|
||||
# %%
|
||||
|
||||
# %%
|
||||
# Step 3: Build or load index
|
||||
retriever: Optional[LeannMultiVector] = None
|
||||
if not REBUILD_INDEX:
|
||||
try:
|
||||
one_vec = _embed_images(model, processor, [images[0]])[0]
|
||||
retriever = _load_retriever_if_index_exists(INDEX_PATH, dim=int(one_vec.shape[-1]))
|
||||
except Exception:
|
||||
retriever = None
|
||||
|
||||
if retriever is None:
|
||||
doc_vecs = _embed_images(model, processor, images)
|
||||
retriever = _build_index(INDEX_PATH, doc_vecs, filepaths)
|
||||
|
||||
|
||||
# %%
|
||||
# Step 4: Embed query and search
|
||||
q_vec = _embed_queries(model, processor, [QUERY])[0]
|
||||
results = retriever.search(q_vec.float().numpy(), topk=TOPK, first_stage_k=FIRST_STAGE_K)
|
||||
if not results:
|
||||
print("No results found.")
|
||||
else:
|
||||
print(f'Top {len(results)} results for query: "{QUERY}"')
|
||||
top_images: list[Image.Image] = []
|
||||
for rank, (score, doc_id) in enumerate(results, start=1):
|
||||
path = filepaths[doc_id]
|
||||
# For HF dataset, path is a descriptive identifier, not a real file path
|
||||
print(f"{rank}) MaxSim: {score:.4f}, Page: {path}")
|
||||
top_images.append(images[doc_id])
|
||||
|
||||
if SAVE_TOP_IMAGE:
|
||||
from pathlib import Path as _Path
|
||||
|
||||
base = _Path(SAVE_TOP_IMAGE)
|
||||
base.parent.mkdir(parents=True, exist_ok=True)
|
||||
for rank, img in enumerate(top_images[:TOPK], start=1):
|
||||
if base.suffix:
|
||||
out_path = base.parent / f"{base.stem}_rank{rank}{base.suffix}"
|
||||
else:
|
||||
out_path = base / f"retrieved_page_rank{rank}.png"
|
||||
img.save(str(out_path))
|
||||
print(f"Saved retrieved page (rank {rank}) to: {out_path}")
|
||||
|
||||
## TODO stange results of second page of DeepSeek-V2 rather than the first page
|
||||
|
||||
# %%
|
||||
# Step 5: Similarity maps for top-K results
|
||||
if results and SIMILARITY_MAP:
|
||||
token_idx = None if SIM_TOKEN_IDX < 0 else int(SIM_TOKEN_IDX)
|
||||
from pathlib import Path as _Path
|
||||
|
||||
output_base = _Path(SIM_OUTPUT) if SIM_OUTPUT else None
|
||||
for rank, img in enumerate(top_images[:TOPK], start=1):
|
||||
if output_base:
|
||||
if output_base.suffix:
|
||||
out_dir = output_base.parent
|
||||
out_name = f"{output_base.stem}_rank{rank}{output_base.suffix}"
|
||||
out_path = str(out_dir / out_name)
|
||||
else:
|
||||
out_dir = output_base
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
out_path = str(out_dir / f"similarity_map_rank{rank}.png")
|
||||
else:
|
||||
out_path = None
|
||||
chosen_idx, max_sim = _generate_similarity_map(
|
||||
model=model,
|
||||
processor=processor,
|
||||
image=img,
|
||||
query=QUERY,
|
||||
token_idx=token_idx,
|
||||
output_path=out_path,
|
||||
)
|
||||
if out_path:
|
||||
print(
|
||||
f"Saved similarity map for rank {rank}, token #{chosen_idx} (max={max_sim:.2f}) to: {out_path}"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"Computed similarity map for rank {rank}, token #{chosen_idx} (max={max_sim:.2f})"
|
||||
)
|
||||
|
||||
|
||||
# %%
|
||||
# Step 6: Optional answer generation
|
||||
if results and ANSWER:
|
||||
qwen = QwenVL(device=device_str)
|
||||
response = qwen.answer(QUERY, top_images[:TOPK], max_new_tokens=MAX_NEW_TOKENS)
|
||||
print("\nAnswer:")
|
||||
print(response)
|
||||
@@ -1,134 +0,0 @@
|
||||
# pip install pdf2image
|
||||
# pip install pymilvus
|
||||
# pip install colpali_engine
|
||||
# pip install tqdm
|
||||
# pip install pillow
|
||||
|
||||
# %%
|
||||
from pdf2image import convert_from_path
|
||||
|
||||
pdf_path = "pdfs/2004.12832v2.pdf"
|
||||
images = convert_from_path(pdf_path)
|
||||
|
||||
for i, image in enumerate(images):
|
||||
image.save(f"pages/page_{i + 1}.png", "PNG")
|
||||
|
||||
# %%
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# Make local leann packages importable without installing
|
||||
_repo_root = Path(__file__).resolve().parents[3]
|
||||
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
|
||||
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
|
||||
import sys
|
||||
|
||||
if str(_leann_core_src) not in sys.path:
|
||||
sys.path.append(str(_leann_core_src))
|
||||
if str(_leann_hnsw_pkg) not in sys.path:
|
||||
sys.path.append(str(_leann_hnsw_pkg))
|
||||
|
||||
from leann_multi_vector import LeannMultiVector
|
||||
|
||||
|
||||
class LeannRetriever(LeannMultiVector):
|
||||
pass
|
||||
|
||||
|
||||
# %%
|
||||
from typing import cast
|
||||
|
||||
import torch
|
||||
from colpali_engine.models import ColPali
|
||||
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
|
||||
from colpali_engine.utils.torch_utils import ListDataset, get_torch_device
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
# Auto-select device: CUDA > MPS (mac) > CPU
|
||||
_device_str = (
|
||||
"cuda"
|
||||
if torch.cuda.is_available()
|
||||
else (
|
||||
"mps"
|
||||
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
|
||||
else "cpu"
|
||||
)
|
||||
)
|
||||
device = get_torch_device(_device_str)
|
||||
# Prefer fp16 on GPU/MPS, bfloat16 on CPU
|
||||
_dtype = torch.float16 if _device_str in ("cuda", "mps") else torch.bfloat16
|
||||
model_name = "vidore/colpali-v1.2"
|
||||
|
||||
model = ColPali.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=_dtype,
|
||||
device_map=device,
|
||||
).eval()
|
||||
print(f"Using device={_device_str}, dtype={_dtype}")
|
||||
|
||||
queries = [
|
||||
"How to end-to-end retrieval with ColBert",
|
||||
"Where is ColBERT performance Table, including text representation results?",
|
||||
]
|
||||
|
||||
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
|
||||
|
||||
dataloader = DataLoader(
|
||||
dataset=ListDataset[str](queries),
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
collate_fn=lambda x: processor.process_queries(x),
|
||||
)
|
||||
|
||||
qs: list[torch.Tensor] = []
|
||||
for batch_query in dataloader:
|
||||
with torch.no_grad():
|
||||
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
|
||||
embeddings_query = model(**batch_query)
|
||||
qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
|
||||
print(qs[0].shape)
|
||||
# %%
|
||||
|
||||
|
||||
import re
|
||||
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
page_filenames = sorted(os.listdir("./pages"), key=lambda n: int(re.search(r"\d+", n).group()))
|
||||
images = [Image.open(os.path.join("./pages", name)) for name in page_filenames]
|
||||
|
||||
dataloader = DataLoader(
|
||||
dataset=ListDataset[str](images),
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
collate_fn=lambda x: processor.process_images(x),
|
||||
)
|
||||
|
||||
ds: list[torch.Tensor] = []
|
||||
for batch_doc in tqdm(dataloader):
|
||||
with torch.no_grad():
|
||||
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
|
||||
embeddings_doc = model(**batch_doc)
|
||||
ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
|
||||
|
||||
print(ds[0].shape)
|
||||
|
||||
# %%
|
||||
# Build HNSW index via LeannRetriever primitives and run search
|
||||
index_path = "./indexes/colpali.leann"
|
||||
retriever = LeannRetriever(index_path=index_path, dim=int(ds[0].shape[-1]))
|
||||
retriever.create_collection()
|
||||
filepaths = [os.path.join("./pages", name) for name in page_filenames]
|
||||
for i in range(len(filepaths)):
|
||||
data = {
|
||||
"colbert_vecs": ds[i].float().numpy(),
|
||||
"doc_id": i,
|
||||
"filepath": filepaths[i],
|
||||
}
|
||||
retriever.insert(data)
|
||||
retriever.create_index()
|
||||
for query in qs:
|
||||
query_np = query.float().numpy()
|
||||
result = retriever.search(query_np, topk=1)
|
||||
print(filepaths[result[0][1]])
|
||||
0
examples/__init__.py
Normal file
0
examples/__init__.py
Normal file
404
examples/dynamic_update_no_recompute.py
Normal file
404
examples/dynamic_update_no_recompute.py
Normal file
@@ -0,0 +1,404 @@
|
||||
"""Dynamic HNSW update demo without compact storage.
|
||||
|
||||
This script reproduces the minimal scenario we used while debugging on-the-fly
|
||||
recompute:
|
||||
|
||||
1. Build a non-compact HNSW index from the first few paragraphs of a text file.
|
||||
2. Print the top results with `recompute_embeddings=True`.
|
||||
3. Append additional paragraphs with :meth:`LeannBuilder.update_index`.
|
||||
4. Run the same query again to show the newly inserted passages.
|
||||
|
||||
Run it with ``uv`` (optionally pointing LEANN_HNSW_LOG_PATH at a file to inspect
|
||||
ZMQ activity)::
|
||||
|
||||
LEANN_HNSW_LOG_PATH=embedding_fetch.log \
|
||||
uv run -m examples.dynamic_update_no_recompute \
|
||||
--index-path .leann/examples/leann-demo.leann
|
||||
|
||||
By default the script builds an index from ``data/2501.14312v1 (1).pdf`` and
|
||||
then updates it with LEANN-related material from ``data/2506.08276v1.pdf``.
|
||||
It issues the query "What's LEANN?" before and after the update to show how the
|
||||
new passages become immediately searchable. The script uses the
|
||||
``sentence-transformers/all-MiniLM-L6-v2`` model with ``is_recompute=True`` so
|
||||
Faiss pulls existing vectors on demand via the ZMQ embedding server, while
|
||||
freshly added passages are embedded locally just like the initial build.
|
||||
|
||||
To make storage comparisons easy, the script can also build a matching
|
||||
``is_recompute=False`` baseline (enabled by default) and report the index size
|
||||
delta after the update. Disable the baseline run with
|
||||
``--skip-compare-no-recompute`` if you only need the recompute flow.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from collections.abc import Iterable
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from leann.api import LeannBuilder, LeannSearcher
|
||||
from leann.registry import register_project_directory
|
||||
|
||||
from apps.chunking import create_text_chunks
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||
|
||||
DEFAULT_QUERY = "What's LEANN?"
|
||||
DEFAULT_INITIAL_FILES = [REPO_ROOT / "data" / "2501.14312v1 (1).pdf"]
|
||||
DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"]
|
||||
|
||||
|
||||
def load_chunks_from_files(paths: list[Path]) -> list[str]:
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
|
||||
documents = []
|
||||
for path in paths:
|
||||
p = path.expanduser().resolve()
|
||||
if not p.exists():
|
||||
raise FileNotFoundError(f"Input path not found: {p}")
|
||||
if p.is_dir():
|
||||
reader = SimpleDirectoryReader(str(p), recursive=False)
|
||||
documents.extend(reader.load_data(show_progress=True))
|
||||
else:
|
||||
reader = SimpleDirectoryReader(input_files=[str(p)])
|
||||
documents.extend(reader.load_data(show_progress=True))
|
||||
|
||||
if not documents:
|
||||
return []
|
||||
|
||||
chunks = create_text_chunks(
|
||||
documents,
|
||||
chunk_size=512,
|
||||
chunk_overlap=128,
|
||||
use_ast_chunking=False,
|
||||
)
|
||||
return [c for c in chunks if isinstance(c, str) and c.strip()]
|
||||
|
||||
|
||||
def run_search(index_path: Path, query: str, top_k: int, *, recompute_embeddings: bool) -> list:
|
||||
searcher = LeannSearcher(str(index_path))
|
||||
try:
|
||||
return searcher.search(
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
recompute_embeddings=recompute_embeddings,
|
||||
batch_size=16,
|
||||
)
|
||||
finally:
|
||||
searcher.cleanup()
|
||||
|
||||
|
||||
def print_results(title: str, results: Iterable) -> None:
|
||||
print(f"\n=== {title} ===")
|
||||
res_list = list(results)
|
||||
print(f"results count: {len(res_list)}")
|
||||
print("passages:")
|
||||
if not res_list:
|
||||
print(" (no passages returned)")
|
||||
for res in res_list:
|
||||
snippet = res.text.replace("\n", " ")[:120]
|
||||
print(f" - {res.id}: {snippet}... (score={res.score:.4f})")
|
||||
|
||||
|
||||
def build_initial_index(
|
||||
index_path: Path,
|
||||
paragraphs: list[str],
|
||||
model_name: str,
|
||||
embedding_mode: str,
|
||||
is_recompute: bool,
|
||||
) -> None:
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model=model_name,
|
||||
embedding_mode=embedding_mode,
|
||||
is_compact=False,
|
||||
is_recompute=is_recompute,
|
||||
)
|
||||
for idx, passage in enumerate(paragraphs):
|
||||
builder.add_text(passage, metadata={"id": str(idx)})
|
||||
builder.build_index(str(index_path))
|
||||
|
||||
|
||||
def update_index(
|
||||
index_path: Path,
|
||||
start_id: int,
|
||||
paragraphs: list[str],
|
||||
model_name: str,
|
||||
embedding_mode: str,
|
||||
is_recompute: bool,
|
||||
) -> None:
|
||||
updater = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model=model_name,
|
||||
embedding_mode=embedding_mode,
|
||||
is_compact=False,
|
||||
is_recompute=is_recompute,
|
||||
)
|
||||
for offset, passage in enumerate(paragraphs, start=start_id):
|
||||
updater.add_text(passage, metadata={"id": str(offset)})
|
||||
updater.update_index(str(index_path))
|
||||
|
||||
|
||||
def ensure_index_dir(index_path: Path) -> None:
|
||||
index_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def cleanup_index_files(index_path: Path) -> None:
|
||||
"""Remove leftover index artifacts for a clean rebuild."""
|
||||
|
||||
parent = index_path.parent
|
||||
if not parent.exists():
|
||||
return
|
||||
stem = index_path.stem
|
||||
for file in parent.glob(f"{stem}*"):
|
||||
if file.is_file():
|
||||
file.unlink()
|
||||
|
||||
|
||||
def index_file_size(index_path: Path) -> int:
|
||||
"""Return the size of the primary .index file for the given index path."""
|
||||
|
||||
index_file = index_path.parent / f"{index_path.stem}.index"
|
||||
return index_file.stat().st_size if index_file.exists() else 0
|
||||
|
||||
|
||||
def load_metadata_snapshot(index_path: Path) -> dict[str, Any] | None:
|
||||
meta_path = index_path.parent / f"{index_path.name}.meta.json"
|
||||
if not meta_path.exists():
|
||||
return None
|
||||
try:
|
||||
return json.loads(meta_path.read_text())
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
|
||||
def run_workflow(
|
||||
*,
|
||||
label: str,
|
||||
index_path: Path,
|
||||
initial_paragraphs: list[str],
|
||||
update_paragraphs: list[str],
|
||||
model_name: str,
|
||||
embedding_mode: str,
|
||||
is_recompute: bool,
|
||||
query: str,
|
||||
top_k: int,
|
||||
) -> dict[str, Any]:
|
||||
prefix = f"[{label}] " if label else ""
|
||||
|
||||
ensure_index_dir(index_path)
|
||||
cleanup_index_files(index_path)
|
||||
|
||||
print(f"{prefix}Building initial index...")
|
||||
build_initial_index(
|
||||
index_path,
|
||||
initial_paragraphs,
|
||||
model_name,
|
||||
embedding_mode,
|
||||
is_recompute=is_recompute,
|
||||
)
|
||||
|
||||
initial_size = index_file_size(index_path)
|
||||
before_results = run_search(
|
||||
index_path,
|
||||
query,
|
||||
top_k,
|
||||
recompute_embeddings=is_recompute,
|
||||
)
|
||||
|
||||
print(f"\n{prefix}Updating index with additional passages...")
|
||||
update_index(
|
||||
index_path,
|
||||
start_id=len(initial_paragraphs),
|
||||
paragraphs=update_paragraphs,
|
||||
model_name=model_name,
|
||||
embedding_mode=embedding_mode,
|
||||
is_recompute=is_recompute,
|
||||
)
|
||||
|
||||
after_results = run_search(
|
||||
index_path,
|
||||
query,
|
||||
top_k,
|
||||
recompute_embeddings=is_recompute,
|
||||
)
|
||||
updated_size = index_file_size(index_path)
|
||||
|
||||
return {
|
||||
"initial_size": initial_size,
|
||||
"updated_size": updated_size,
|
||||
"delta": updated_size - initial_size,
|
||||
"before_results": before_results,
|
||||
"after_results": after_results,
|
||||
"metadata": load_metadata_snapshot(index_path),
|
||||
}
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"--initial-files",
|
||||
type=Path,
|
||||
nargs="+",
|
||||
default=DEFAULT_INITIAL_FILES,
|
||||
help="Initial document files (PDF/TXT) used to build the base index",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--index-path",
|
||||
type=Path,
|
||||
default=Path(".leann/examples/leann-demo.leann"),
|
||||
help="Destination index path (default: .leann/examples/leann-demo.leann)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--initial-count",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Number of chunks to use from the initial documents (default: 8)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--update-files",
|
||||
type=Path,
|
||||
nargs="*",
|
||||
default=DEFAULT_UPDATE_FILES,
|
||||
help="Additional documents to add during update (PDF/TXT)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--update-count",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Number of chunks to append from update documents (default: 4)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--update-text",
|
||||
type=str,
|
||||
default=(
|
||||
"LEANN (Lightweight Embedding ANN) is an indexing toolkit focused on "
|
||||
"recompute-aware HNSW graphs, allowing embeddings to be regenerated "
|
||||
"on demand to keep disk usage minimal."
|
||||
),
|
||||
help="Fallback text to append if --update-files is omitted",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top-k",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Number of results to show for each search (default: 4)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--query",
|
||||
type=str,
|
||||
default=DEFAULT_QUERY,
|
||||
help="Query to run before/after the update",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--embedding-model",
|
||||
type=str,
|
||||
default="sentence-transformers/all-MiniLM-L6-v2",
|
||||
help="Embedding model name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--embedding-mode",
|
||||
type=str,
|
||||
default="sentence-transformers",
|
||||
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
||||
help="Embedding backend mode",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--compare-no-recompute",
|
||||
dest="compare_no_recompute",
|
||||
action="store_true",
|
||||
help="Also run a baseline with is_recompute=False and report its index growth.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip-compare-no-recompute",
|
||||
dest="compare_no_recompute",
|
||||
action="store_false",
|
||||
help="Skip building the no-recompute baseline.",
|
||||
)
|
||||
parser.set_defaults(compare_no_recompute=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
ensure_index_dir(args.index_path)
|
||||
register_project_directory(REPO_ROOT)
|
||||
|
||||
initial_chunks = load_chunks_from_files(list(args.initial_files))
|
||||
if not initial_chunks:
|
||||
raise ValueError("No text chunks extracted from the initial files.")
|
||||
|
||||
initial = initial_chunks[: args.initial_count]
|
||||
if not initial:
|
||||
raise ValueError("Initial chunk set is empty after applying --initial-count.")
|
||||
|
||||
if args.update_files:
|
||||
update_chunks = load_chunks_from_files(list(args.update_files))
|
||||
if not update_chunks:
|
||||
raise ValueError("No text chunks extracted from the update files.")
|
||||
to_add = update_chunks[: args.update_count]
|
||||
else:
|
||||
if not args.update_text:
|
||||
raise ValueError("Provide --update-files or --update-text for the update step.")
|
||||
to_add = [args.update_text]
|
||||
if not to_add:
|
||||
raise ValueError("Update chunk set is empty after applying --update-count.")
|
||||
|
||||
recompute_stats = run_workflow(
|
||||
label="recompute",
|
||||
index_path=args.index_path,
|
||||
initial_paragraphs=initial,
|
||||
update_paragraphs=to_add,
|
||||
model_name=args.embedding_model,
|
||||
embedding_mode=args.embedding_mode,
|
||||
is_recompute=True,
|
||||
query=args.query,
|
||||
top_k=args.top_k,
|
||||
)
|
||||
|
||||
print_results("initial search", recompute_stats["before_results"])
|
||||
print_results("after update", recompute_stats["after_results"])
|
||||
print(
|
||||
f"\n[recompute] Index file size change: {recompute_stats['initial_size']} -> {recompute_stats['updated_size']} bytes"
|
||||
f" (Δ {recompute_stats['delta']})"
|
||||
)
|
||||
|
||||
if recompute_stats["metadata"]:
|
||||
meta_view = {k: recompute_stats["metadata"].get(k) for k in ("is_compact", "is_pruned")}
|
||||
print("[recompute] metadata snapshot:")
|
||||
print(json.dumps(meta_view, indent=2))
|
||||
|
||||
if args.compare_no_recompute:
|
||||
baseline_path = (
|
||||
args.index_path.parent / f"{args.index_path.stem}-norecompute{args.index_path.suffix}"
|
||||
)
|
||||
baseline_stats = run_workflow(
|
||||
label="no-recompute",
|
||||
index_path=baseline_path,
|
||||
initial_paragraphs=initial,
|
||||
update_paragraphs=to_add,
|
||||
model_name=args.embedding_model,
|
||||
embedding_mode=args.embedding_mode,
|
||||
is_recompute=False,
|
||||
query=args.query,
|
||||
top_k=args.top_k,
|
||||
)
|
||||
|
||||
print(
|
||||
f"\n[no-recompute] Index file size change: {baseline_stats['initial_size']} -> {baseline_stats['updated_size']} bytes"
|
||||
f" (Δ {baseline_stats['delta']})"
|
||||
)
|
||||
|
||||
after_texts = [res.text for res in recompute_stats["after_results"]]
|
||||
baseline_after_texts = [res.text for res in baseline_stats["after_results"]]
|
||||
if after_texts == baseline_after_texts:
|
||||
print(
|
||||
"[no-recompute] Search results match recompute baseline; see above for the shared output."
|
||||
)
|
||||
else:
|
||||
print("[no-recompute] WARNING: search results differ from recompute baseline.")
|
||||
|
||||
if baseline_stats["metadata"]:
|
||||
meta_view = {k: baseline_stats["metadata"].get(k) for k in ("is_compact", "is_pruned")}
|
||||
print("[no-recompute] metadata snapshot:")
|
||||
print(json.dumps(meta_view, indent=2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -5,6 +5,8 @@ import os
|
||||
import struct
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -237,6 +239,288 @@ def write_compact_format(
|
||||
f_out.write(storage_data)
|
||||
|
||||
|
||||
@dataclass
|
||||
class HNSWComponents:
|
||||
original_hnsw_data: dict[str, Any]
|
||||
assign_probas_np: np.ndarray
|
||||
cum_nneighbor_per_level_np: np.ndarray
|
||||
levels_np: np.ndarray
|
||||
is_compact: bool
|
||||
compact_level_ptr: Optional[np.ndarray] = None
|
||||
compact_node_offsets_np: Optional[np.ndarray] = None
|
||||
compact_neighbors_data: Optional[list[int]] = None
|
||||
offsets_np: Optional[np.ndarray] = None
|
||||
neighbors_np: Optional[np.ndarray] = None
|
||||
storage_fourcc: int = NULL_INDEX_FOURCC
|
||||
storage_data: bytes = b""
|
||||
|
||||
|
||||
def _read_hnsw_structure(f) -> HNSWComponents:
|
||||
original_hnsw_data: dict[str, Any] = {}
|
||||
|
||||
hnsw_index_fourcc = read_struct(f, "<I")
|
||||
if hnsw_index_fourcc not in EXPECTED_HNSW_FOURCCS:
|
||||
raise ValueError(
|
||||
f"Unexpected HNSW FourCC: {hnsw_index_fourcc:08x}. Expected one of {EXPECTED_HNSW_FOURCCS}."
|
||||
)
|
||||
|
||||
original_hnsw_data["index_fourcc"] = hnsw_index_fourcc
|
||||
original_hnsw_data["d"] = read_struct(f, "<i")
|
||||
original_hnsw_data["ntotal"] = read_struct(f, "<q")
|
||||
original_hnsw_data["dummy1"] = read_struct(f, "<q")
|
||||
original_hnsw_data["dummy2"] = read_struct(f, "<q")
|
||||
original_hnsw_data["is_trained"] = read_struct(f, "?")
|
||||
original_hnsw_data["metric_type"] = read_struct(f, "<i")
|
||||
original_hnsw_data["metric_arg"] = 0.0
|
||||
if original_hnsw_data["metric_type"] > 1:
|
||||
original_hnsw_data["metric_arg"] = read_struct(f, "<f")
|
||||
|
||||
assign_probas_np = read_numpy_vector(f, np.float64, "d")
|
||||
cum_nneighbor_per_level_np = read_numpy_vector(f, np.int32, "i")
|
||||
levels_np = read_numpy_vector(f, np.int32, "i")
|
||||
|
||||
ntotal = len(levels_np)
|
||||
if ntotal != original_hnsw_data["ntotal"]:
|
||||
original_hnsw_data["ntotal"] = ntotal
|
||||
|
||||
pos_before_compact = f.tell()
|
||||
is_compact_flag = None
|
||||
try:
|
||||
is_compact_flag = read_struct(f, "<?")
|
||||
except EOFError:
|
||||
is_compact_flag = None
|
||||
|
||||
if is_compact_flag:
|
||||
compact_level_ptr = read_numpy_vector(f, np.uint64, "Q")
|
||||
compact_node_offsets_np = read_numpy_vector(f, np.uint64, "Q")
|
||||
|
||||
original_hnsw_data["entry_point"] = read_struct(f, "<i")
|
||||
original_hnsw_data["max_level"] = read_struct(f, "<i")
|
||||
original_hnsw_data["efConstruction"] = read_struct(f, "<i")
|
||||
original_hnsw_data["efSearch"] = read_struct(f, "<i")
|
||||
original_hnsw_data["dummy_upper_beam"] = read_struct(f, "<i")
|
||||
|
||||
storage_fourcc = read_struct(f, "<I")
|
||||
compact_neighbors_data_np = read_numpy_vector(f, np.int32, "i")
|
||||
compact_neighbors_data = compact_neighbors_data_np.tolist()
|
||||
storage_data = f.read()
|
||||
|
||||
return HNSWComponents(
|
||||
original_hnsw_data=original_hnsw_data,
|
||||
assign_probas_np=assign_probas_np,
|
||||
cum_nneighbor_per_level_np=cum_nneighbor_per_level_np,
|
||||
levels_np=levels_np,
|
||||
is_compact=True,
|
||||
compact_level_ptr=compact_level_ptr,
|
||||
compact_node_offsets_np=compact_node_offsets_np,
|
||||
compact_neighbors_data=compact_neighbors_data,
|
||||
storage_fourcc=storage_fourcc,
|
||||
storage_data=storage_data,
|
||||
)
|
||||
|
||||
# Non-compact case
|
||||
f.seek(pos_before_compact)
|
||||
|
||||
pos_before_probe = f.tell()
|
||||
try:
|
||||
suspected_flag = read_struct(f, "<B")
|
||||
if suspected_flag != 0x00:
|
||||
f.seek(pos_before_probe)
|
||||
except EOFError:
|
||||
f.seek(pos_before_probe)
|
||||
|
||||
offsets_np = read_numpy_vector(f, np.uint64, "Q")
|
||||
neighbors_np = read_numpy_vector(f, np.int32, "i")
|
||||
|
||||
original_hnsw_data["entry_point"] = read_struct(f, "<i")
|
||||
original_hnsw_data["max_level"] = read_struct(f, "<i")
|
||||
original_hnsw_data["efConstruction"] = read_struct(f, "<i")
|
||||
original_hnsw_data["efSearch"] = read_struct(f, "<i")
|
||||
original_hnsw_data["dummy_upper_beam"] = read_struct(f, "<i")
|
||||
|
||||
storage_fourcc = NULL_INDEX_FOURCC
|
||||
storage_data = b""
|
||||
try:
|
||||
storage_fourcc = read_struct(f, "<I")
|
||||
storage_data = f.read()
|
||||
except EOFError:
|
||||
storage_fourcc = NULL_INDEX_FOURCC
|
||||
|
||||
return HNSWComponents(
|
||||
original_hnsw_data=original_hnsw_data,
|
||||
assign_probas_np=assign_probas_np,
|
||||
cum_nneighbor_per_level_np=cum_nneighbor_per_level_np,
|
||||
levels_np=levels_np,
|
||||
is_compact=False,
|
||||
offsets_np=offsets_np,
|
||||
neighbors_np=neighbors_np,
|
||||
storage_fourcc=storage_fourcc,
|
||||
storage_data=storage_data,
|
||||
)
|
||||
|
||||
|
||||
def _read_hnsw_structure_from_file(path: str) -> HNSWComponents:
|
||||
with open(path, "rb") as f:
|
||||
return _read_hnsw_structure(f)
|
||||
|
||||
|
||||
def write_original_format(
|
||||
f_out,
|
||||
original_hnsw_data,
|
||||
assign_probas_np,
|
||||
cum_nneighbor_per_level_np,
|
||||
levels_np,
|
||||
offsets_np,
|
||||
neighbors_np,
|
||||
storage_fourcc,
|
||||
storage_data,
|
||||
):
|
||||
"""Write non-compact HNSW data in original FAISS order."""
|
||||
|
||||
f_out.write(struct.pack("<I", original_hnsw_data["index_fourcc"]))
|
||||
f_out.write(struct.pack("<i", original_hnsw_data["d"]))
|
||||
f_out.write(struct.pack("<q", original_hnsw_data["ntotal"]))
|
||||
f_out.write(struct.pack("<q", original_hnsw_data["dummy1"]))
|
||||
f_out.write(struct.pack("<q", original_hnsw_data["dummy2"]))
|
||||
f_out.write(struct.pack("<?", original_hnsw_data["is_trained"]))
|
||||
f_out.write(struct.pack("<i", original_hnsw_data["metric_type"]))
|
||||
if original_hnsw_data["metric_type"] > 1:
|
||||
f_out.write(struct.pack("<f", original_hnsw_data["metric_arg"]))
|
||||
|
||||
write_numpy_vector(f_out, assign_probas_np, "d")
|
||||
write_numpy_vector(f_out, cum_nneighbor_per_level_np, "i")
|
||||
write_numpy_vector(f_out, levels_np, "i")
|
||||
|
||||
write_numpy_vector(f_out, offsets_np, "Q")
|
||||
write_numpy_vector(f_out, neighbors_np, "i")
|
||||
|
||||
f_out.write(struct.pack("<i", original_hnsw_data["entry_point"]))
|
||||
f_out.write(struct.pack("<i", original_hnsw_data["max_level"]))
|
||||
f_out.write(struct.pack("<i", original_hnsw_data["efConstruction"]))
|
||||
f_out.write(struct.pack("<i", original_hnsw_data["efSearch"]))
|
||||
f_out.write(struct.pack("<i", original_hnsw_data["dummy_upper_beam"]))
|
||||
|
||||
f_out.write(struct.pack("<I", storage_fourcc))
|
||||
if storage_fourcc != NULL_INDEX_FOURCC and storage_data:
|
||||
f_out.write(storage_data)
|
||||
|
||||
|
||||
def prune_hnsw_embeddings(input_filename: str, output_filename: str) -> bool:
|
||||
"""Rewrite an HNSW index while dropping the embedded storage section."""
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
with open(input_filename, "rb") as f_in, open(output_filename, "wb") as f_out:
|
||||
original_hnsw_data: dict[str, Any] = {}
|
||||
|
||||
hnsw_index_fourcc = read_struct(f_in, "<I")
|
||||
if hnsw_index_fourcc not in EXPECTED_HNSW_FOURCCS:
|
||||
print(
|
||||
f"Error: Expected HNSW Index FourCC ({list(EXPECTED_HNSW_FOURCCS)}), got {hnsw_index_fourcc:08x}.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return False
|
||||
|
||||
original_hnsw_data["index_fourcc"] = hnsw_index_fourcc
|
||||
original_hnsw_data["d"] = read_struct(f_in, "<i")
|
||||
original_hnsw_data["ntotal"] = read_struct(f_in, "<q")
|
||||
original_hnsw_data["dummy1"] = read_struct(f_in, "<q")
|
||||
original_hnsw_data["dummy2"] = read_struct(f_in, "<q")
|
||||
original_hnsw_data["is_trained"] = read_struct(f_in, "?")
|
||||
original_hnsw_data["metric_type"] = read_struct(f_in, "<i")
|
||||
original_hnsw_data["metric_arg"] = 0.0
|
||||
if original_hnsw_data["metric_type"] > 1:
|
||||
original_hnsw_data["metric_arg"] = read_struct(f_in, "<f")
|
||||
|
||||
assign_probas_np = read_numpy_vector(f_in, np.float64, "d")
|
||||
cum_nneighbor_per_level_np = read_numpy_vector(f_in, np.int32, "i")
|
||||
levels_np = read_numpy_vector(f_in, np.int32, "i")
|
||||
|
||||
ntotal = len(levels_np)
|
||||
if ntotal != original_hnsw_data["ntotal"]:
|
||||
original_hnsw_data["ntotal"] = ntotal
|
||||
|
||||
pos_before_compact = f_in.tell()
|
||||
is_compact_flag = None
|
||||
try:
|
||||
is_compact_flag = read_struct(f_in, "<?")
|
||||
except EOFError:
|
||||
is_compact_flag = None
|
||||
|
||||
if is_compact_flag:
|
||||
compact_level_ptr = read_numpy_vector(f_in, np.uint64, "Q")
|
||||
compact_node_offsets_np = read_numpy_vector(f_in, np.uint64, "Q")
|
||||
|
||||
original_hnsw_data["entry_point"] = read_struct(f_in, "<i")
|
||||
original_hnsw_data["max_level"] = read_struct(f_in, "<i")
|
||||
original_hnsw_data["efConstruction"] = read_struct(f_in, "<i")
|
||||
original_hnsw_data["efSearch"] = read_struct(f_in, "<i")
|
||||
original_hnsw_data["dummy_upper_beam"] = read_struct(f_in, "<i")
|
||||
|
||||
_storage_fourcc = read_struct(f_in, "<I")
|
||||
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, "i")
|
||||
compact_neighbors_data = compact_neighbors_data_np.tolist()
|
||||
_storage_data = f_in.read()
|
||||
|
||||
write_compact_format(
|
||||
f_out,
|
||||
original_hnsw_data,
|
||||
assign_probas_np,
|
||||
cum_nneighbor_per_level_np,
|
||||
levels_np,
|
||||
compact_level_ptr,
|
||||
compact_node_offsets_np,
|
||||
compact_neighbors_data,
|
||||
NULL_INDEX_FOURCC,
|
||||
b"",
|
||||
)
|
||||
else:
|
||||
f_in.seek(pos_before_compact)
|
||||
|
||||
pos_before_probe = f_in.tell()
|
||||
try:
|
||||
suspected_flag = read_struct(f_in, "<B")
|
||||
if suspected_flag != 0x00:
|
||||
f_in.seek(pos_before_probe)
|
||||
except EOFError:
|
||||
f_in.seek(pos_before_probe)
|
||||
|
||||
offsets_np = read_numpy_vector(f_in, np.uint64, "Q")
|
||||
neighbors_np = read_numpy_vector(f_in, np.int32, "i")
|
||||
|
||||
original_hnsw_data["entry_point"] = read_struct(f_in, "<i")
|
||||
original_hnsw_data["max_level"] = read_struct(f_in, "<i")
|
||||
original_hnsw_data["efConstruction"] = read_struct(f_in, "<i")
|
||||
original_hnsw_data["efSearch"] = read_struct(f_in, "<i")
|
||||
original_hnsw_data["dummy_upper_beam"] = read_struct(f_in, "<i")
|
||||
|
||||
_storage_fourcc = None
|
||||
_storage_data = b""
|
||||
try:
|
||||
_storage_fourcc = read_struct(f_in, "<I")
|
||||
_storage_data = f_in.read()
|
||||
except EOFError:
|
||||
_storage_fourcc = NULL_INDEX_FOURCC
|
||||
|
||||
write_original_format(
|
||||
f_out,
|
||||
original_hnsw_data,
|
||||
assign_probas_np,
|
||||
cum_nneighbor_per_level_np,
|
||||
levels_np,
|
||||
offsets_np,
|
||||
neighbors_np,
|
||||
NULL_INDEX_FOURCC,
|
||||
b"",
|
||||
)
|
||||
|
||||
print(f"[{time.time() - start_time:.2f}s] Pruned embeddings from {input_filename}")
|
||||
return True
|
||||
except Exception as exc:
|
||||
print(f"Failed to prune embeddings: {exc}", file=sys.stderr)
|
||||
return False
|
||||
|
||||
|
||||
# --- Main Conversion Logic ---
|
||||
|
||||
|
||||
@@ -700,6 +984,29 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
|
||||
pass
|
||||
|
||||
|
||||
def prune_hnsw_embeddings_inplace(index_filename: str) -> bool:
|
||||
"""Convenience wrapper to prune embeddings in-place."""
|
||||
|
||||
temp_path = f"{index_filename}.prune.tmp"
|
||||
success = prune_hnsw_embeddings(index_filename, temp_path)
|
||||
if success:
|
||||
try:
|
||||
os.replace(temp_path, index_filename)
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
logger.error(f"Failed to replace original index with pruned version: {exc}")
|
||||
try:
|
||||
os.remove(temp_path)
|
||||
except OSError:
|
||||
pass
|
||||
return False
|
||||
else:
|
||||
try:
|
||||
os.remove(temp_path)
|
||||
except OSError:
|
||||
pass
|
||||
return success
|
||||
|
||||
|
||||
# --- Script Execution ---
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
|
||||
@@ -14,7 +14,7 @@ from leann.interface import (
|
||||
from leann.registry import register_backend
|
||||
from leann.searcher_base import BaseSearcher
|
||||
|
||||
from .convert_to_csr import convert_hnsw_graph_to_csr
|
||||
from .convert_to_csr import convert_hnsw_graph_to_csr, prune_hnsw_embeddings_inplace
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -92,6 +92,8 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
||||
|
||||
if self.is_compact:
|
||||
self._convert_to_csr(index_file)
|
||||
elif self.is_recompute:
|
||||
prune_hnsw_embeddings_inplace(str(index_file))
|
||||
|
||||
def _convert_to_csr(self, index_file: Path):
|
||||
"""Convert built index to CSR format"""
|
||||
@@ -133,10 +135,10 @@ class HNSWSearcher(BaseSearcher):
|
||||
if metric_enum is None:
|
||||
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
|
||||
|
||||
self.is_compact, self.is_pruned = (
|
||||
self.meta.get("is_compact", True),
|
||||
self.meta.get("is_pruned", True),
|
||||
)
|
||||
backend_meta_kwargs = self.meta.get("backend_kwargs", {})
|
||||
self.is_compact = self.meta.get("is_compact", backend_meta_kwargs.get("is_compact", True))
|
||||
default_pruned = backend_meta_kwargs.get("is_recompute", self.is_compact)
|
||||
self.is_pruned = bool(self.meta.get("is_pruned", default_pruned))
|
||||
|
||||
index_file = self.index_dir / f"{self.index_path.stem}.index"
|
||||
if not index_file.exists():
|
||||
|
||||
@@ -24,13 +24,26 @@ logger = logging.getLogger(__name__)
|
||||
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
|
||||
logger.setLevel(log_level)
|
||||
|
||||
# Ensure we have a handler if none exists
|
||||
# Ensure we have handlers if none exist
|
||||
if not logger.handlers:
|
||||
handler = logging.StreamHandler()
|
||||
stream_handler = logging.StreamHandler()
|
||||
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
|
||||
handler.setFormatter(formatter)
|
||||
logger.addHandler(handler)
|
||||
logger.propagate = False
|
||||
stream_handler.setFormatter(formatter)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
log_path = os.getenv("LEANN_HNSW_LOG_PATH")
|
||||
if log_path:
|
||||
try:
|
||||
file_handler = logging.FileHandler(log_path, mode="a", encoding="utf-8")
|
||||
file_formatter = logging.Formatter(
|
||||
"%(asctime)s - %(levelname)s - [pid=%(process)d] %(message)s"
|
||||
)
|
||||
file_handler.setFormatter(file_formatter)
|
||||
logger.addHandler(file_handler)
|
||||
except Exception as exc: # pragma: no cover - best effort logging
|
||||
logger.warning(f"Failed to attach file handler for log path {log_path}: {exc}")
|
||||
|
||||
logger.propagate = False
|
||||
|
||||
|
||||
def create_hnsw_embedding_server(
|
||||
|
||||
Submodule packages/leann-backend-hnsw/third_party/faiss updated: ed96ff7dba...1d51f0c074
@@ -15,6 +15,7 @@ from pathlib import Path
|
||||
from typing import Any, Literal, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from leann_backend_hnsw.convert_to_csr import prune_hnsw_embeddings_inplace
|
||||
|
||||
from leann.interface import LeannBackendSearcherInterface
|
||||
|
||||
@@ -476,9 +477,7 @@ class LeannBuilder:
|
||||
is_compact = self.backend_kwargs.get("is_compact", True)
|
||||
is_recompute = self.backend_kwargs.get("is_recompute", True)
|
||||
meta_data["is_compact"] = is_compact
|
||||
meta_data["is_pruned"] = (
|
||||
is_compact and is_recompute
|
||||
) # Pruned only if compact and recompute
|
||||
meta_data["is_pruned"] = bool(is_recompute)
|
||||
with open(leann_meta_path, "w", encoding="utf-8") as f:
|
||||
json.dump(meta_data, f, indent=2)
|
||||
|
||||
@@ -598,13 +597,157 @@ class LeannBuilder:
|
||||
is_compact = self.backend_kwargs.get("is_compact", True)
|
||||
is_recompute = self.backend_kwargs.get("is_recompute", True)
|
||||
meta_data["is_compact"] = is_compact
|
||||
meta_data["is_pruned"] = is_compact and is_recompute
|
||||
meta_data["is_pruned"] = bool(is_recompute)
|
||||
|
||||
with open(leann_meta_path, "w", encoding="utf-8") as f:
|
||||
json.dump(meta_data, f, indent=2)
|
||||
|
||||
logger.info(f"Index built successfully from precomputed embeddings: {index_path}")
|
||||
|
||||
def update_index(self, index_path: str):
|
||||
"""Append new passages and vectors to an existing HNSW index."""
|
||||
if not self.chunks:
|
||||
raise ValueError("No new chunks provided for update.")
|
||||
|
||||
path = Path(index_path)
|
||||
index_dir = path.parent
|
||||
index_name = path.name
|
||||
index_prefix = path.stem
|
||||
|
||||
meta_path = index_dir / f"{index_name}.meta.json"
|
||||
passages_file = index_dir / f"{index_name}.passages.jsonl"
|
||||
offset_file = index_dir / f"{index_name}.passages.idx"
|
||||
index_file = index_dir / f"{index_prefix}.index"
|
||||
|
||||
if not meta_path.exists() or not passages_file.exists() or not offset_file.exists():
|
||||
raise FileNotFoundError("Index metadata or passage files are missing; cannot update.")
|
||||
if not index_file.exists():
|
||||
raise FileNotFoundError(f"HNSW index file not found: {index_file}")
|
||||
|
||||
with open(meta_path, encoding="utf-8") as f:
|
||||
meta = json.load(f)
|
||||
backend_name = meta.get("backend_name")
|
||||
if backend_name != self.backend_name:
|
||||
raise ValueError(
|
||||
f"Index was built with backend '{backend_name}', cannot update with '{self.backend_name}'."
|
||||
)
|
||||
|
||||
meta_backend_kwargs = meta.get("backend_kwargs", {})
|
||||
index_is_compact = meta.get("is_compact", meta_backend_kwargs.get("is_compact", True))
|
||||
if index_is_compact:
|
||||
raise ValueError(
|
||||
"Compact HNSW indices do not support in-place updates. Rebuild required."
|
||||
)
|
||||
|
||||
distance_metric = meta_backend_kwargs.get(
|
||||
"distance_metric", self.backend_kwargs.get("distance_metric", "mips")
|
||||
).lower()
|
||||
needs_recompute = bool(
|
||||
meta.get("is_pruned")
|
||||
or meta_backend_kwargs.get("is_recompute")
|
||||
or self.backend_kwargs.get("is_recompute")
|
||||
)
|
||||
|
||||
with open(offset_file, "rb") as f:
|
||||
offset_map: dict[str, int] = pickle.load(f)
|
||||
existing_ids = set(offset_map.keys())
|
||||
|
||||
valid_chunks: list[dict[str, Any]] = []
|
||||
for chunk in self.chunks:
|
||||
text = chunk.get("text", "")
|
||||
if not isinstance(text, str) or not text.strip():
|
||||
continue
|
||||
metadata = chunk.setdefault("metadata", {})
|
||||
passage_id = chunk.get("id") or metadata.get("id")
|
||||
if passage_id and passage_id in existing_ids:
|
||||
raise ValueError(f"Passage ID '{passage_id}' already exists in the index.")
|
||||
valid_chunks.append(chunk)
|
||||
|
||||
if not valid_chunks:
|
||||
raise ValueError("No valid chunks to append.")
|
||||
|
||||
texts_to_embed = [chunk["text"] for chunk in valid_chunks]
|
||||
embeddings = compute_embeddings(
|
||||
texts_to_embed,
|
||||
self.embedding_model,
|
||||
self.embedding_mode,
|
||||
use_server=False,
|
||||
is_build=True,
|
||||
)
|
||||
|
||||
embedding_dim = embeddings.shape[1]
|
||||
expected_dim = meta.get("dimensions")
|
||||
if expected_dim is not None and expected_dim != embedding_dim:
|
||||
raise ValueError(
|
||||
f"Dimension mismatch during update: existing index uses {expected_dim}, got {embedding_dim}."
|
||||
)
|
||||
|
||||
from leann_backend_hnsw import faiss # type: ignore
|
||||
|
||||
embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
|
||||
if distance_metric == "cosine":
|
||||
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
|
||||
norms[norms == 0] = 1
|
||||
embeddings = embeddings / norms
|
||||
|
||||
index = faiss.read_index(str(index_file))
|
||||
if hasattr(index, "is_recompute"):
|
||||
index.is_recompute = needs_recompute
|
||||
if getattr(index, "storage", None) is None:
|
||||
if index.metric_type == faiss.METRIC_INNER_PRODUCT:
|
||||
storage_index = faiss.IndexFlatIP(index.d)
|
||||
else:
|
||||
storage_index = faiss.IndexFlatL2(index.d)
|
||||
index.storage = storage_index
|
||||
index.own_fields = True
|
||||
if index.d != embedding_dim:
|
||||
raise ValueError(
|
||||
f"Existing index dimension ({index.d}) does not match new embeddings ({embedding_dim})."
|
||||
)
|
||||
|
||||
base_id = index.ntotal
|
||||
for offset, chunk in enumerate(valid_chunks):
|
||||
new_id = str(base_id + offset)
|
||||
chunk.setdefault("metadata", {})["id"] = new_id
|
||||
chunk["id"] = new_id
|
||||
|
||||
index.add(embeddings.shape[0], faiss.swig_ptr(embeddings))
|
||||
faiss.write_index(index, str(index_file))
|
||||
|
||||
with open(passages_file, "a", encoding="utf-8") as f:
|
||||
for chunk in valid_chunks:
|
||||
offset = f.tell()
|
||||
json.dump(
|
||||
{
|
||||
"id": chunk["id"],
|
||||
"text": chunk["text"],
|
||||
"metadata": chunk.get("metadata", {}),
|
||||
},
|
||||
f,
|
||||
ensure_ascii=False,
|
||||
)
|
||||
f.write("\n")
|
||||
offset_map[chunk["id"]] = offset
|
||||
|
||||
with open(offset_file, "wb") as f:
|
||||
pickle.dump(offset_map, f)
|
||||
|
||||
meta["total_passages"] = len(offset_map)
|
||||
with open(meta_path, "w", encoding="utf-8") as f:
|
||||
json.dump(meta, f, indent=2)
|
||||
|
||||
logger.info(
|
||||
"Appended %d passages to index '%s'. New total: %d",
|
||||
len(valid_chunks),
|
||||
index_path,
|
||||
len(offset_map),
|
||||
)
|
||||
|
||||
self.chunks.clear()
|
||||
|
||||
if needs_recompute:
|
||||
prune_hnsw_embeddings_inplace(str(index_file))
|
||||
|
||||
|
||||
class LeannSearcher:
|
||||
def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
|
||||
|
||||
@@ -104,11 +104,7 @@ astchunk = { path = "packages/astchunk-leann", editable = true }
|
||||
[tool.ruff]
|
||||
target-version = "py39"
|
||||
line-length = 100
|
||||
extend-exclude = [
|
||||
"third_party",
|
||||
"apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann.py",
|
||||
"apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann-similarity-map.py"
|
||||
]
|
||||
extend-exclude = ["third_party"]
|
||||
|
||||
|
||||
[tool.ruff.lint]
|
||||
|
||||
Reference in New Issue
Block a user