Merge remote-tracking branch 'origin/main' into datastore-reproduce

.gitignore (vendored): 5 lines changed

@@ -8,11 +8,15 @@ demo/indices/
 *pycache*
 outputs/
 *.pkl
+*.pdf
 .history/
 scripts/
 lm_eval.egg-info/
 demo/experiment_results/**/*.json
 *.jsonl
+*.eml
+*.emlx
+*.json
 *.sh
 *.txt
 !CMakeLists.txt
@@ -42,6 +46,7 @@ embedding_comparison_results/
 *.ivecs
 *.index
 *.bin
+*.old

 read_graph
 analyze_diskann_graph

.gitmodules (vendored): 10 lines changed

@@ -4,3 +4,13 @@
 [submodule "packages/leann-backend-hnsw/third_party/faiss"]
     path = packages/leann-backend-hnsw/third_party/faiss
     url = https://github.com/yichuan520030910320/faiss.git
+[submodule "packages/leann-backend-hnsw/third_party/msgpack-c"]
+    path = packages/leann-backend-hnsw/third_party/msgpack-c
+    url = https://github.com/msgpack/msgpack-c.git
+    branch = cpp_master
+[submodule "packages/leann-backend-hnsw/third_party/cppzmq"]
+    path = packages/leann-backend-hnsw/third_party/cppzmq
+    url = https://github.com/zeromq/cppzmq.git
+[submodule "packages/leann-backend-hnsw/third_party/libzmq"]
+    path = packages/leann-backend-hnsw/third_party/libzmq
+    url = https://github.com/zeromq/libzmq.git

README.md: 105 lines changed

@@ -28,13 +28,15 @@
 ### 🎯 Why Leann?

 Traditional RAG systems face a fundamental trade-off:

 - **💾 Storage**: Storing embeddings for millions of documents requires massive disk space
 - **🔄 Freshness**: Pre-computed embeddings become stale when documents change
 - **💰 Cost**: Vector databases are expensive to scale

 **Leann solves this by:**

 - ✅ **Zero embedding storage** - Only graph structure is persisted
 - ✅ **Real-time computation** - Embeddings computed on-demand with ms latency
 - ✅ **Memory efficient** - Runs on consumer hardware (8GB RAM)
 - ✅ **Always fresh** - No stale embeddings, ever

@@ -46,6 +48,19 @@ Traditional RAG systems face a fundamental trade-off:
 git clone git@github.com:yichuan520030910320/LEANN-RAG.git leann
 cd leann
 git submodule update --init --recursive
+```
+
+**macOS:**
+```bash
+brew install llvm libomp boost protobuf
+export CC=$(brew --prefix llvm)/bin/clang
+export CXX=$(brew --prefix llvm)/bin/clang++
+uv sync
+```
+
+**Linux (Ubuntu/Debian):**
+```bash
+sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler libabsl-dev libmkl-full-dev libaio-dev
 uv sync
 ```

@@ -75,19 +90,39 @@ for result in results:
 uv run examples/document_search.py
 ```

+Or, if you want to use Python directly:
+
+```bash
+source .venv/bin/activate
+python ./examples/main_cli_example.py
+```
 **PDF RAG Demo (using LlamaIndex for document parsing and Leann for indexing/search)**

 This demo showcases how to build a RAG system for PDF documents using Leann.
+
 1. Place your PDF files (and other supported formats like .docx, .pptx, .xlsx) into the `examples/data/` directory.
 2. Ensure you have an `OPENAI_API_KEY` set in your environment variables or in a `.env` file for the LLM to function.

 ```bash
 uv run examples/main_cli_example.py
 ```

+### Regenerating Protobuf Files
+
+If you modify any `.proto` files (such as `embedding.proto`), or if you see errors about a protobuf version mismatch, **regenerate the C++ protobuf files** to match your installed version:
+
+```bash
+cd packages/leann-backend-diskann
+protoc --cpp_out=third_party/DiskANN/include --proto_path=third_party embedding.proto
+protoc --cpp_out=third_party/DiskANN/src --proto_path=third_party embedding.proto
+```
+
+This ensures the generated files are compatible with your system's protobuf library.
+
 ## ✨ Features

 ### 🔥 Core Features

 - **📊 Multiple Distance Functions**: L2, Cosine, MIPS (Maximum Inner Product Search)
 - **🏗️ Pluggable Backends**: DiskANN, HNSW/FAISS with unified API
 - **🔄 Real-time Embeddings**: Dynamic computation using optimized ZMQ servers
@@ -95,6 +130,7 @@ uv run examples/main_cli_example.py
 - **🎯 Graph Pruning**: Advanced techniques for memory-efficient search

 ### 🛠️ Technical Highlights

 - **Zero-copy operations** for maximum performance
 - **SIMD-optimized** distance computations (AVX2/AVX512)
 - **Async embedding pipeline** with batched processing
@@ -102,6 +138,7 @@ uv run examples/main_cli_example.py
 - **Recompute mode** for highest accuracy scenarios

 ### 🎨 Developer Experience

 - **Simple Python API** - Get started in minutes
 - **Extensible backend system** - Easy to add new algorithms
 - **Comprehensive examples** - From basic usage to production deployment
@@ -111,19 +148,19 @@ uv run examples/main_cli_example.py

 ### Memory Usage Comparison

 | System | 1M Documents | 10M Documents | 100M Documents |
-|--------|-------------|---------------|----------------|
+| --------------------- | ---------------- | ---------------- | ---------------- |
 | Traditional Vector DB | 3.1 GB | 31 GB | 310 GB |
 | **Leann** | **180 MB** | **1.2 GB** | **8.4 GB** |
 | **Reduction** | **94.2%** | **96.1%** | **97.3%** |

 ### Query Performance

 | Backend | Index Size | Query Time | Recall@10 |
-|---------|------------|------------|-----------|
+| ------------------- | ---------- | ---------- | --------- |
 | DiskANN | 1M docs | 12ms | 0.95 |
 | DiskANN + Recompute | 1M docs | 145ms | 0.98 |
 | HNSW | 1M docs | 8ms | 0.93 |

 *Benchmarks run on AMD Ryzen 7 with 32GB RAM*

@@ -145,26 +182,29 @@ uv run examples/main_cli_example.py
 ### Key Components

 1. **🧠 Embedding Engine**: Real-time transformer inference with caching
 2. **📊 Graph Index**: Memory-efficient navigation structures
 3. **🔄 Search Coordinator**: Orchestrates embedding + graph search
 4. **⚡ Backend Adapters**: Pluggable algorithm implementations

 ## 🎓 Supported Models & Backends

 ### 🤖 Embedding Models

 - **sentence-transformers/all-mpnet-base-v2** (default)
 - **sentence-transformers/all-MiniLM-L6-v2** (lightweight)
 - Any HuggingFace sentence-transformer model
 - Custom model support via API

 ### 🔧 Search Backends

 - **DiskANN**: Microsoft's billion-scale ANN algorithm
 - **HNSW**: Hierarchical Navigable Small World graphs
 - **Coming soon**: ScaNN, Faiss-IVF, NGT

 ### 📏 Distance Functions

 - **L2**: Euclidean distance for precise similarity
 - **Cosine**: Angular similarity for normalized vectors
 - **MIPS**: Maximum Inner Product Search for recommendation systems

 ## 🔬 Paper

@@ -188,6 +228,7 @@ If you find Leann useful, please cite:
 ## 🌍 Use Cases

 ### 💼 Enterprise RAG

 ```python
 # Handle millions of documents with limited resources
 builder = LeannBuilder(
@@ -198,7 +239,8 @@ builder = LeannBuilder(
 )
 ```

 ### 🔬 Research & Experimentation

 ```python
 # Quick prototyping with different algorithms
 for backend in ["diskann", "hnsw"]:
@@ -207,6 +249,7 @@ for backend in ["diskann", "hnsw"]:
 ```

 ### 🚀 Real-time Applications

 ```python
 # Sub-second response times
 chat = LeannChat("knowledge.leann")
@@ -219,6 +262,7 @@ response = chat.ask("What is quantum computing?")
 We welcome contributions! Leann is built by the community, for the community.

 ### Ways to Contribute

 - 🐛 **Bug Reports**: Found an issue? Let us know!
 - 💡 **Feature Requests**: Have an idea? We'd love to hear it!
 - 🔧 **Code Contributions**: PRs welcome for all skill levels
@@ -226,14 +270,17 @@ We welcome contributions! Leann is built by the community, for the community.
 - 🧪 **Benchmarks**: Share your performance results

 ### Development Setup

 ```bash
-git clone https://github.com/yourname/leann
+git clone git@github.com:yichuan520030910320/LEANN-RAG.git leann
 cd leann
+git submodule update --init --recursive
 uv sync --dev
 uv run pytest tests/
 ```

 ### Quick Tests

 ```bash
 # Sanity check all distance functions
 uv run python tests/sanity_checks/test_distance_functions.py
@@ -241,17 +288,21 @@ uv run python tests/sanity_checks/test_distance_functions.py
 # Verify L2 implementation
 uv run python tests/sanity_checks/test_l2_verification.py
 ```

 ## ❓ FAQ

 ### Common Issues

 #### NCCL Topology Error

 **Problem**: You encounter a `ncclTopoComputePaths` error during document processing:

 ```
 ncclTopoComputePaths (system=<optimized out>, comm=comm@entry=0x5555a82fa3c0) at graph/paths.cc:688
 ```

 **Solution**: Set these environment variables before running your script:

 ```bash
 export NCCL_TOPO_DUMP_FILE=/tmp/nccl_topo.xml
 export NCCL_DEBUG=INFO
@@ -259,23 +310,26 @@ export NCCL_DEBUG_SUBSYS=INIT,GRAPH
 export NCCL_IB_DISABLE=1
 export NCCL_NET_PLUGIN=none
 export NCCL_SOCKET_IFNAME=ens5
+```

 ## 📈 Roadmap

 ### 🎯 Q1 2024
-- [x] DiskANN backend with MIPS/L2/Cosine support
-- [x] HNSW backend integration
-- [x] Real-time embedding pipeline
-- [x] Memory-efficient graph pruning
+
+- [X] DiskANN backend with MIPS/L2/Cosine support
+- [X] HNSW backend integration
+- [X] Real-time embedding pipeline
+- [X] Memory-efficient graph pruning

 ### 🚀 Q2 2024

 - [ ] Distributed search across multiple nodes
 - [ ] ScaNN backend support
 - [ ] Advanced caching strategies
 - [ ] Kubernetes deployment guides

 ### 🌟 Q3 2024

 - [ ] GPU-accelerated embedding computation
 - [ ] Approximate distance functions
 - [ ] Integration with LangChain/LlamaIndex
@@ -297,7 +351,7 @@ MIT License - see [LICENSE](LICENSE) for details.
 ## 🙏 Acknowledgments

 - **Microsoft Research** for the DiskANN algorithm
 - **Meta AI** for FAISS and optimization insights
 - **HuggingFace** for the transformer ecosystem
 - **Our amazing contributors** who make this possible

@@ -309,4 +363,5 @@ MIT License - see [LICENSE](LICENSE) for details.

 <p align="center">
 Made with ❤️ by the Leann team
 </p>
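
The README hunks above only show fragments of the Python API, so here is a minimal end-to-end sketch assembled from snippets in this diff: the builder parameters and `chat.ask` call are copied from `examples/mail_reader_leann.py`, while the `LeannSearcher` constructor and `search()` call are assumptions inferred from `examples/multi_vector_aggregator.py` (which consumes results with `.score` and `.metadata`) and may differ from the real signatures.

```python
# Minimal sketch; LeannSearcher usage below is assumed, not confirmed by this diff.
from leann.api import LeannBuilder, LeannSearcher, LeannChat

# 1. Build an index: only the graph structure is persisted, embeddings are recomputed at query time.
builder = LeannBuilder(
    backend_name="hnsw",                    # or "diskann"
    embedding_model="facebook/contriever",
    graph_degree=32,
    complexity=64,
    is_compact=True,
    is_recompute=True,
)
for chunk in ["Leann stores no embeddings.", "Embeddings are recomputed on demand."]:
    builder.add_text(chunk)
builder.build_index("demo_index.leann")

# 2. Search (assumed interface: results expose .score and .metadata).
searcher = LeannSearcher("demo_index.leann")
results = searcher.search("how does Leann save storage?", top_k=3)
for result in results:
    print(result.score, result.metadata)

# 3. Chat over the index, as in the new mail example.
chat = LeannChat(index_path="demo_index.leann")
print(chat.ask("How does Leann save storage?", top_k=3))
```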

examples/LEANN_email_reader.py (new file): 130 lines

@@ -0,0 +1,130 @@
import os
import email
from pathlib import Path
from typing import List, Any
from llama_index.core import Document
from llama_index.core.readers.base import BaseReader

class EmlxReader(BaseReader):
    """
    Apple Mail .emlx file reader with embedded metadata.

    Reads individual .emlx files from Apple Mail's storage format.
    """

    def __init__(self) -> None:
        """Initialize."""
        pass

    def load_data(self, input_dir: str, **load_kwargs: Any) -> List[Document]:
        """
        Load data from the input directory containing .emlx files.

        Args:
            input_dir: Directory containing .emlx files
            **load_kwargs:
                max_count (int): Maximum amount of messages to read.
        """
        docs: List[Document] = []
        max_count = load_kwargs.get('max_count', 1000)
        count = 0

        # Walk through the directory recursively
        for dirpath, dirnames, filenames in os.walk(input_dir):
            # Skip hidden directories
            dirnames[:] = [d for d in dirnames if not d.startswith(".")]

            for filename in filenames:
                if count >= max_count:
                    break

                if filename.endswith(".emlx"):
                    filepath = os.path.join(dirpath, filename)
                    try:
                        # Read the .emlx file
                        with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
                            content = f.read()

                        # .emlx files have a length prefix followed by the email content
                        # The first line contains the length, followed by the email
                        lines = content.split('\n', 1)
                        if len(lines) >= 2:
                            email_content = lines[1]

                            # Parse the email using Python's email module
                            try:
                                msg = email.message_from_string(email_content)

                                # Extract email metadata
                                subject = msg.get('Subject', 'No Subject')
                                from_addr = msg.get('From', 'Unknown')
                                to_addr = msg.get('To', 'Unknown')
                                date = msg.get('Date', 'Unknown')

                                # Extract email body
                                body = ""
                                if msg.is_multipart():
                                    for part in msg.walk():
                                        if part.get_content_type() == "text/plain" or part.get_content_type() == "text/html":
                                            # if part.get_content_type() == "text/html":
                                            #     continue
                                            body += part.get_payload(decode=True).decode('utf-8', errors='ignore')
                                            # break
                                else:
                                    body = msg.get_payload(decode=True).decode('utf-8', errors='ignore')

                                # Create document content with metadata embedded in text
                                doc_content = f"""
[EMAIL METADATA]
File: {filename}
From: {from_addr}
To: {to_addr}
Subject: {subject}
Date: {date}
[END METADATA]

{body}
"""

                                # No separate metadata - everything is in the text
                                doc = Document(text=doc_content, metadata={})
                                docs.append(doc)
                                count += 1

                            except Exception as e:
                                print(f"Error parsing email from {filepath}: {e}")
                                continue

                    except Exception as e:
                        print(f"Error reading file {filepath}: {e}")
                        continue

        print(f"Loaded {len(docs)} email documents")
        return docs

    @staticmethod
    def find_all_messages_directories(base_path: str) -> List[Path]:
        """
        Find all Messages directories under the given base path.

        Args:
            base_path: Base path to search for Messages directories

        Returns:
            List of Path objects pointing to Messages directories
        """
        base_path_obj = Path(base_path)
        messages_dirs = []

        if not base_path_obj.exists():
            print(f"Base path {base_path} does not exist")
            return messages_dirs

        # Find all Messages directories recursively
        for messages_dir in base_path_obj.rglob("Messages"):
            if messages_dir.is_dir():
                messages_dirs.append(messages_dir)
                print(f"Found Messages directory: {messages_dir}")

        print(f"Found {len(messages_dirs)} Messages directories")
        return messages_dirs
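
A minimal usage sketch for the reader above, using only the two methods defined in this file and run from the `examples/` directory so the module is importable; the Apple Mail base path is a placeholder, not a path from this commit.

```python
# Hypothetical driver for EmlxReader (the base path below is a placeholder).
from LEANN_email_reader import EmlxReader

base_path = "/Users/me/Library/Mail/V10"  # placeholder Apple Mail data directory
reader = EmlxReader()

for messages_dir in EmlxReader.find_all_messages_directories(base_path):
    # load_data walks the directory, strips the .emlx length prefix, and returns
    # llama_index Documents with the metadata embedded in the text itself.
    docs = reader.load_data(str(messages_dir), max_count=100)
    for doc in docs[:3]:
        print(doc.text[:200])
```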

examples/email_data/email.py (new file): 192 lines

@@ -0,0 +1,192 @@
"""
Mbox parser.

Contains simple parser for mbox files.

"""

import logging
from pathlib import Path
from typing import Any, Dict, List, Optional
from fsspec import AbstractFileSystem

from llama_index.core.readers.base import BaseReader
from llama_index.core.schema import Document

logger = logging.getLogger(__name__)


class MboxReader(BaseReader):
    """
    Mbox parser.

    Extract messages from mailbox files.
    Returns string including date, subject, sender, receiver and
    content for each message.

    """

    DEFAULT_MESSAGE_FORMAT: str = (
        "Date: {_date}\n"
        "From: {_from}\n"
        "To: {_to}\n"
        "Subject: {_subject}\n"
        "Content: {_content}"
    )

    def __init__(
        self,
        *args: Any,
        max_count: int = 0,
        message_format: str = DEFAULT_MESSAGE_FORMAT,
        **kwargs: Any,
    ) -> None:
        """Init params."""
        try:
            from bs4 import BeautifulSoup  # noqa
        except ImportError:
            raise ImportError(
                "`beautifulsoup4` package not found: `pip install beautifulsoup4`"
            )

        super().__init__(*args, **kwargs)
        self.max_count = max_count
        self.message_format = message_format

    def load_data(
        self,
        file: Path,
        extra_info: Optional[Dict] = None,
        fs: Optional[AbstractFileSystem] = None,
    ) -> List[Document]:
        """Parse file into string."""
        # Import required libraries
        import mailbox
        from email.parser import BytesParser
        from email.policy import default

        from bs4 import BeautifulSoup

        if fs:
            logger.warning(
                "fs was specified but MboxReader doesn't support loading "
                "from fsspec filesystems. Will load from local filesystem instead."
            )

        i = 0
        results: List[str] = []
        # Load file using mailbox
        bytes_parser = BytesParser(policy=default).parse
        mbox = mailbox.mbox(file, factory=bytes_parser)  # type: ignore

        # Iterate through all messages
        for _, _msg in enumerate(mbox):
            try:
                msg: mailbox.mboxMessage = _msg
                # Parse multipart messages
                if msg.is_multipart():
                    for part in msg.walk():
                        ctype = part.get_content_type()
                        cdispo = str(part.get("Content-Disposition"))
                        if "attachment" in cdispo:
                            print(f"Attachment found: {part.get_filename()}")
                        if ctype == "text/plain" and "attachment" not in cdispo:
                            content = part.get_payload(decode=True)  # decode
                            break
                # Get plain message payload for non-multipart messages
                else:
                    content = msg.get_payload(decode=True)

                # Parse message HTML content and remove unneeded whitespace
                soup = BeautifulSoup(content)
                stripped_content = " ".join(soup.get_text().split())
                # Format message to include date, sender, receiver and subject
                msg_string = self.message_format.format(
                    _date=msg["date"],
                    _from=msg["from"],
                    _to=msg["to"],
                    _subject=msg["subject"],
                    _content=stripped_content,
                )
                # Add message string to results
                results.append(msg_string)
            except Exception as e:
                logger.warning(f"Failed to parse message:\n{_msg}\n with exception {e}")

            # Increment counter and return if max count is met
            i += 1
            if self.max_count > 0 and i >= self.max_count:
                break

        return [Document(text=result, metadata=extra_info or {}) for result in results]


class EmlxMboxReader(MboxReader):
    """
    EmlxMboxReader - Modified MboxReader that handles directories of .emlx files.

    Extends MboxReader to work with Apple Mail's .emlx format by:
    1. Reading .emlx files from a directory
    2. Converting them to mbox format in memory
    3. Using the parent MboxReader's parsing logic
    """

    def load_data(
        self,
        directory: Path,
        extra_info: Optional[Dict] = None,
        fs: Optional[AbstractFileSystem] = None,
    ) -> List[Document]:
        """Parse .emlx files from directory into strings using MboxReader logic."""
        import tempfile
        import os

        if fs:
            logger.warning(
                "fs was specified but EmlxMboxReader doesn't support loading "
                "from fsspec filesystems. Will load from local filesystem instead."
            )

        # Find all .emlx files in the directory
        emlx_files = list(directory.glob("*.emlx"))
        logger.info(f"Found {len(emlx_files)} .emlx files in {directory}")

        if not emlx_files:
            logger.warning(f"No .emlx files found in {directory}")
            return []

        # Create a temporary mbox file
        with tempfile.NamedTemporaryFile(mode='w', suffix='.mbox', delete=False) as temp_mbox:
            temp_mbox_path = temp_mbox.name

            # Convert .emlx files to mbox format
            for emlx_file in emlx_files:
                try:
                    # Read the .emlx file
                    with open(emlx_file, 'r', encoding='utf-8', errors='ignore') as f:
                        content = f.read()

                    # .emlx format: first line is length, rest is email content
                    lines = content.split('\n', 1)
                    if len(lines) >= 2:
                        email_content = lines[1]  # Skip the length line

                        # Write to mbox format (each message starts with "From " and ends with blank line)
                        temp_mbox.write(f"From {emlx_file.name} {email_content}\n\n")

                except Exception as e:
                    logger.warning(f"Failed to process {emlx_file}: {e}")
                    continue

            # Close the temporary file so MboxReader can read it
            temp_mbox.close()

        try:
            # Use the parent MboxReader's logic to parse the mbox file
            return super().load_data(Path(temp_mbox_path), extra_info, fs)
        finally:
            # Clean up temporary file
            try:
                os.unlink(temp_mbox_path)
            except:
                pass
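
A short usage sketch for the two readers above, assuming the script is run from the `examples/` directory so that `email_data.email` is importable; both paths are placeholders, not paths from this commit.

```python
# Hypothetical usage of MboxReader / EmlxMboxReader; both paths are placeholders.
from pathlib import Path
from email_data.email import MboxReader, EmlxMboxReader

# Parse at most 50 messages from a classic mbox archive.
mbox_docs = MboxReader(max_count=50).load_data(Path("/path/to/archive.mbox"))

# Parse a directory of Apple Mail .emlx files via the temporary-mbox conversion shim.
emlx_docs = EmlxMboxReader().load_data(Path("/path/to/Messages"))

print(f"{len(mbox_docs)} mbox messages, {len(emlx_docs)} emlx messages")
```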

examples/mail_reader_leann.py (new file): 229 lines

@@ -0,0 +1,229 @@
import os
import asyncio
import dotenv
from pathlib import Path
from typing import List, Any
from leann.api import LeannBuilder, LeannSearcher, LeannChat
from llama_index.core.node_parser import SentenceSplitter

dotenv.load_dotenv()

def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_path: str = "mail_index.leann", max_count: int = -1):
    """
    Create LEANN index from multiple mail data sources.

    Args:
        messages_dirs: List of Path objects pointing to Messages directories
        index_path: Path to save the LEANN index
        max_count: Maximum number of emails to process per directory
    """
    print("Creating LEANN index from multiple mail data sources...")

    # Load documents using EmlxReader from LEANN_email_reader
    from LEANN_email_reader import EmlxReader
    reader = EmlxReader()
    # from email_data.email import EmlxMboxReader
    # from pathlib import Path
    # reader = EmlxMboxReader()

    all_documents = []
    total_processed = 0

    # Process each Messages directory
    for i, messages_dir in enumerate(messages_dirs):
        print(f"\nProcessing Messages directory {i+1}/{len(messages_dirs)}: {messages_dir}")

        try:
            documents = reader.load_data(messages_dir)
            if documents:
                print(f"Loaded {len(documents)} email documents from {messages_dir}")
                all_documents.extend(documents)
                total_processed += len(documents)

                # Check if we've reached the max count
                if max_count > 0 and total_processed >= max_count:
                    print(f"Reached max count of {max_count} documents")
                    break
            else:
                print(f"No documents loaded from {messages_dir}")
        except Exception as e:
            print(f"Error processing {messages_dir}: {e}")
            continue

    if not all_documents:
        print("No documents loaded from any source. Exiting.")
        return None

    print(f"\nTotal loaded {len(all_documents)} email documents from {len(messages_dirs)} directories")

    # Create text splitter with 256 chunk size
    text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)

    # Convert Documents to text strings and chunk them
    all_texts = []
    for doc in all_documents:
        # Split the document into chunks
        nodes = text_splitter.get_nodes_from_documents([doc])
        for node in nodes:
            all_texts.append(node.get_content())

    print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")

    # Create LEANN index directory
    INDEX_DIR = Path(index_path).parent

    if not INDEX_DIR.exists():
        print(f"--- Index directory not found, building new index ---")
        INDEX_DIR.mkdir(exist_ok=True)

        print(f"--- Building new LEANN index ---")

        print(f"\n[PHASE 1] Building Leann index...")

        # Use HNSW backend for better macOS compatibility
        builder = LeannBuilder(
            backend_name="hnsw",
            embedding_model="facebook/contriever",
            graph_degree=32,
            complexity=64,
            is_compact=True,
            is_recompute=True,
            num_threads=1  # Force single-threaded mode
        )

        print(f"Adding {len(all_texts)} email chunks to index...")
        for chunk_text in all_texts:
            builder.add_text(chunk_text)

        builder.build_index(index_path)
        print(f"\nLEANN index built at {index_path}!")
    else:
        print(f"--- Using existing index at {INDEX_DIR} ---")

    return index_path

def create_leann_index(mail_path: str, index_path: str = "mail_index.leann", max_count: int = 1000):
    """
    Create LEANN index from mail data.

    Args:
        mail_path: Path to the mail directory
        index_path: Path to save the LEANN index
        max_count: Maximum number of emails to process
    """
    print("Creating LEANN index from mail data...")

    # Load documents using EmlxReader from LEANN_email_reader
    from LEANN_email_reader import EmlxReader
    reader = EmlxReader()
    # from email_data.email import EmlxMboxReader
    # from pathlib import Path
    # reader = EmlxMboxReader()
    documents = reader.load_data(Path(mail_path))

    if not documents:
        print("No documents loaded. Exiting.")
        return None

    print(f"Loaded {len(documents)} email documents")

    # Create text splitter with 256 chunk size
    text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)

    # Convert Documents to text strings and chunk them
    all_texts = []
    for doc in documents:
        # Split the document into chunks
        nodes = text_splitter.get_nodes_from_documents([doc])
        for node in nodes:
            all_texts.append(node.get_content())

    print(f"Created {len(all_texts)} text chunks from {len(documents)} documents")

    # Create LEANN index directory
    INDEX_DIR = Path(index_path).parent

    if not INDEX_DIR.exists():
        print(f"--- Index directory not found, building new index ---")
        INDEX_DIR.mkdir(exist_ok=True)

        print(f"--- Building new LEANN index ---")

        print(f"\n[PHASE 1] Building Leann index...")

        # Use HNSW backend for better macOS compatibility
        builder = LeannBuilder(
            backend_name="hnsw",
            embedding_model="facebook/contriever",
            graph_degree=32,
            complexity=64,
            is_compact=True,
            is_recompute=True,
            num_threads=1  # Force single-threaded mode
        )

        print(f"Adding {len(all_texts)} email chunks to index...")
        for chunk_text in all_texts:
            builder.add_text(chunk_text)

        builder.build_index(index_path)
        print(f"\nLEANN index built at {index_path}!")
    else:
        print(f"--- Using existing index at {INDEX_DIR} ---")

    return index_path

async def query_leann_index(index_path: str, query: str):
    """
    Query the LEANN index.

    Args:
        index_path: Path to the LEANN index
        query: The query string
    """
    print(f"\n[PHASE 2] Starting Leann chat session...")
    chat = LeannChat(index_path=index_path)

    print(f"You: {query}")
    chat_response = chat.ask(
        query,
        top_k=5,
        recompute_beighbor_embeddings=True,
        complexity=32,
        beam_width=1
    )
    print(f"Leann: {chat_response}")

async def main():
    # Base path to the mail data directory
    base_mail_path = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data"

    INDEX_DIR = Path("./mail_index_leann_raw_text_all")
    INDEX_PATH = str(INDEX_DIR / "mail_documents.leann")

    # Find all Messages directories
    from LEANN_email_reader import EmlxReader
    messages_dirs = EmlxReader.find_all_messages_directories(base_mail_path)

    if not messages_dirs:
        print("No Messages directories found. Exiting.")
        return

    # Create or load the LEANN index from all sources
    index_path = create_leann_index_from_multiple_sources(messages_dirs, INDEX_PATH)

    if index_path:
        # Example queries
        queries = [
            "Hows Berkeley Graduate Student Instructor",
            "how's the icloud related advertisement saying",
            "Whats the number of class recommend to take per semester for incoming EECS students"
        ]

        for query in queries:
            print("\n" + "="*60)
            await query_leann_index(index_path, query)

if __name__ == "__main__":
    asyncio.run(main())

examples/mail_reader_llamaindex.py (new file): 86 lines

@@ -0,0 +1,86 @@
import os
from pathlib import Path
from typing import List, Any
from llama_index.core import VectorStoreIndex, StorageContext
from llama_index.core.node_parser import SentenceSplitter

# --- EMBEDDING MODEL ---
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
import torch

# --- END EMBEDDING MODEL ---

# Import EmlxReader from the new module
from LEANN_email_reader import EmlxReader

def create_and_save_index(mail_path: str, save_dir: str = "mail_index_embedded", max_count: int = 1000):
    print("Creating index from mail data with embedded metadata...")
    documents = EmlxReader().load_data(mail_path, max_count=max_count)
    if not documents:
        print("No documents loaded. Exiting.")
        return None
    text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
    # Use facebook/contriever as the embedder
    embed_model = HuggingFaceEmbedding(model_name="facebook/contriever")
    # set on device
    import torch
    if torch.cuda.is_available():
        embed_model._model.to("cuda")
    # set mps
    elif torch.backends.mps.is_available():
        embed_model._model.to("mps")
    else:
        embed_model._model.to("cpu")
    index = VectorStoreIndex.from_documents(
        documents,
        transformations=[text_splitter],
        embed_model=embed_model
    )
    os.makedirs(save_dir, exist_ok=True)
    index.storage_context.persist(persist_dir=save_dir)
    print(f"Index saved to {save_dir}")
    return index

def load_index(save_dir: str = "mail_index_embedded"):
    try:
        storage_context = StorageContext.from_defaults(persist_dir=save_dir)
        index = VectorStoreIndex.from_vector_store(
            storage_context.vector_store,
            storage_context=storage_context
        )
        print(f"Index loaded from {save_dir}")
        return index
    except Exception as e:
        print(f"Error loading index: {e}")
        return None

def query_index(index, query: str):
    if index is None:
        print("No index available for querying.")
        return
    query_engine = index.as_query_engine()
    response = query_engine.query(query)
    print(f"Query: {query}")
    print(f"Response: {response}")

def main():
    mail_path = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages"
    save_dir = "mail_index_embedded"
    if os.path.exists(save_dir) and os.path.exists(os.path.join(save_dir, "vector_store.json")):
        print("Loading existing index...")
        index = load_index(save_dir)
    else:
        print("Creating new index...")
        index = create_and_save_index(mail_path, save_dir, max_count=10000)
    if index:
        queries = [
            "Hows Berkeley Graduate Student Instructor",
            "how's the icloud related advertisement saying",
            "Whats the number of class recommend to take per semester for incoming EECS students"
        ]
        for query in queries:
            print("\n" + "="*50)
            query_index(index, query)

if __name__ == "__main__":
    main()

@@ -8,7 +8,6 @@ from llama_index.node_parser.docling import DoclingNodeParser
 from llama_index.readers.docling import DoclingReader
 from docling_core.transforms.chunker.hybrid_chunker import HybridChunker
 import asyncio
-import os
 import dotenv
 from leann.api import LeannBuilder, LeannSearcher, LeannChat
 import shutil
@@ -22,9 +21,11 @@ file_extractor: dict[str, BaseReader] = {
     ".pptx": reader,
     ".pdf": reader,
     ".xlsx": reader,
+    ".txt": reader,
+    ".md": reader,
 }
 node_parser = DoclingNodeParser(
-    chunker=HybridChunker(tokenizer="Qwen/Qwen3-Embedding-4B", max_tokens=64)
+    chunker=HybridChunker(tokenizer="facebook/contriever", max_tokens=128)
 )
 print("Loading documents...")
 documents = SimpleDirectoryReader(
@@ -32,7 +33,7 @@ documents = SimpleDirectoryReader(
     recursive=True,
     file_extractor=file_extractor,
     encoding="utf-8",
-    required_exts=[".pdf", ".docx", ".pptx", ".xlsx"]
+    required_exts=[".pdf", ".docx", ".pptx", ".xlsx", ".txt", ".md"]
 ).load_data(show_progress=True)
 print("Documents loaded.")
 all_texts = []
@@ -41,7 +42,7 @@ for doc in documents:
     for node in nodes:
         all_texts.append(node.get_content())

-INDEX_DIR = Path("./test_pdf_index")
+INDEX_DIR = Path("./test_pdf_index_pangu_test")
 INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")

 if not INDEX_DIR.exists():
@@ -49,14 +50,15 @@ if not INDEX_DIR.exists():

     print(f"\n[PHASE 1] Building Leann index...")

-    # CSR compact mode with recompute
+    # Use HNSW backend for better macOS compatibility
     builder = LeannBuilder(
         backend_name="hnsw",
         embedding_model="facebook/contriever",
         graph_degree=32,
         complexity=64,
         is_compact=True,
-        is_recompute=True
+        is_recompute=True,
+        num_threads=1  # Force single-threaded mode
     )

     print(f"Loaded {len(all_texts)} text chunks from documents.")
@@ -80,14 +82,17 @@ async def main(args):
     chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)

     query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
+    query = "What is the main idea of RL and give me 5 exapmle of classic RL algorithms?"
+    query = "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"  # roughly: "What is the Pangu large model, what dark sides came up during Pangu's development, and in which city are task orders usually issued?"

     print(f"You: {query}")
-    chat_response = chat.ask(query, top_k=3, recompute_beighbor_embeddings=True)
+    chat_response = chat.ask(query, top_k=20, recompute_beighbor_embeddings=True, complexity=32, beam_width=1)
     print(f"Leann: {chat_response}")

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Run Leann Chat with various LLM backends.")
-    parser.add_argument("--llm", type=str, default="hf", choices=["simulated", "ollama", "hf"], help="The LLM backend to use.")
+    parser.add_argument("--llm", type=str, default="hf", choices=["simulated", "ollama", "hf", "openai"], help="The LLM backend to use.")
-    parser.add_argument("--model", type=str, default='meta-llama/Llama-3.2-3B-Instruct', help="The model name to use (e.g., 'llama3:8b' for ollama, 'deepseek-ai/deepseek-llm-7b-chat' for hf).")
+    parser.add_argument("--model", type=str, default='meta-llama/Llama-3.2-3B-Instruct', help="The model name to use (e.g., 'llama3:8b' for ollama, 'deepseek-ai/deepseek-llm-7b-chat' for hf, 'gpt-4o' for openai).")
     parser.add_argument("--host", type=str, default="http://localhost:11434", help="The host for the Ollama API.")
     args = parser.parse_args()

examples/multi_vector_aggregator.py (new file): 319 lines (capture truncated below)

@@ -0,0 +1,319 @@
#!/usr/bin/env python3
"""
Multi-Vector Aggregator for Fat Embeddings
==========================================

This module implements aggregation strategies for multi-vector embeddings,
similar to ColPali's approach where multiple patch vectors represent a single document.

Key features:
- MaxSim aggregation (take maximum similarity across patches)
- Voting-based aggregation (count patch matches)
- Weighted aggregation (attention-score weighted)
- Spatial clustering of matching patches
- Document-level result consolidation
"""

import numpy as np
from typing import List, Dict, Any, Tuple, Optional
from dataclasses import dataclass
from collections import defaultdict
import json

@dataclass
class PatchResult:
    """Represents a single patch search result."""
    patch_id: int
    image_name: str
    image_path: str
    coordinates: Tuple[int, int, int, int]  # (x1, y1, x2, y2)
    score: float
    attention_score: float
    scale: float
    metadata: Dict[str, Any]

@dataclass
class AggregatedResult:
    """Represents an aggregated document-level result."""
    image_name: str
    image_path: str
    doc_score: float
    patch_count: int
    best_patch: PatchResult
    all_patches: List[PatchResult]
    aggregation_method: str
    spatial_clusters: Optional[List[List[PatchResult]]] = None

class MultiVectorAggregator:
    """
    Aggregates multiple patch-level results into document-level results.
    """

    def __init__(self,
                 aggregation_method: str = "maxsim",
                 spatial_clustering: bool = True,
                 cluster_distance_threshold: float = 100.0):
        """
        Initialize the aggregator.

        Args:
            aggregation_method: "maxsim", "voting", "weighted", or "mean"
            spatial_clustering: Whether to cluster spatially close patches
            cluster_distance_threshold: Distance threshold for spatial clustering
        """
        self.aggregation_method = aggregation_method
        self.spatial_clustering = spatial_clustering
        self.cluster_distance_threshold = cluster_distance_threshold

    def aggregate_results(self,
                          search_results: List[Dict[str, Any]],
                          top_k: int = 10) -> List[AggregatedResult]:
        """
        Aggregate patch-level search results into document-level results.

        Args:
            search_results: List of search results from LeannSearcher
            top_k: Number of top documents to return

        Returns:
            List of aggregated document results
        """
        # Group results by image
        image_groups = defaultdict(list)

        for result in search_results:
            metadata = result.metadata
            if "image_name" in metadata and "patch_id" in metadata:
                patch_result = PatchResult(
                    patch_id=metadata["patch_id"],
                    image_name=metadata["image_name"],
                    image_path=metadata["image_path"],
                    coordinates=tuple(metadata["coordinates"]),
                    score=result.score,
                    attention_score=metadata.get("attention_score", 0.0),
                    scale=metadata.get("scale", 1.0),
                    metadata=metadata
                )
                image_groups[metadata["image_name"]].append(patch_result)

        # Aggregate each image group
        aggregated_results = []
        for image_name, patches in image_groups.items():
            if len(patches) == 0:
                continue

            agg_result = self._aggregate_image_patches(image_name, patches)
            aggregated_results.append(agg_result)

        # Sort by aggregated score and return top-k
        aggregated_results.sort(key=lambda x: x.doc_score, reverse=True)
        return aggregated_results[:top_k]

    def _aggregate_image_patches(self, image_name: str, patches: List[PatchResult]) -> AggregatedResult:
        """Aggregate patches for a single image."""

        if self.aggregation_method == "maxsim":
            doc_score = max(patch.score for patch in patches)
            best_patch = max(patches, key=lambda p: p.score)

        elif self.aggregation_method == "voting":
            # Count patches above threshold
            threshold = np.percentile([p.score for p in patches], 75)
            doc_score = sum(1 for patch in patches if patch.score >= threshold)
            best_patch = max(patches, key=lambda p: p.score)

        elif self.aggregation_method == "weighted":
            # Weight by attention scores
            total_weighted_score = sum(p.score * p.attention_score for p in patches)
            total_weights = sum(p.attention_score for p in patches)
            doc_score = total_weighted_score / max(total_weights, 1e-8)
            best_patch = max(patches, key=lambda p: p.score * p.attention_score)

        elif self.aggregation_method == "mean":
            doc_score = np.mean([patch.score for patch in patches])
            best_patch = max(patches, key=lambda p: p.score)

        else:
            raise ValueError(f"Unknown aggregation method: {self.aggregation_method}")

        # Spatial clustering if enabled
        spatial_clusters = None
        if self.spatial_clustering:
            spatial_clusters = self._cluster_patches_spatially(patches)

        return AggregatedResult(
            image_name=image_name,
            image_path=patches[0].image_path,
            doc_score=float(doc_score),
            patch_count=len(patches),
            best_patch=best_patch,
            all_patches=sorted(patches, key=lambda p: p.score, reverse=True),
            aggregation_method=self.aggregation_method,
            spatial_clusters=spatial_clusters
        )

    def _cluster_patches_spatially(self, patches: List[PatchResult]) -> List[List[PatchResult]]:
        """Cluster patches that are spatially close to each other."""
        if len(patches) <= 1:
            return [patches]

        clusters = []
        remaining_patches = patches.copy()

        while remaining_patches:
            # Start new cluster with highest scoring remaining patch
            seed_patch = max(remaining_patches, key=lambda p: p.score)
            current_cluster = [seed_patch]
            remaining_patches.remove(seed_patch)

            # Add nearby patches to cluster
            added_to_cluster = True
            while added_to_cluster:
                added_to_cluster = False
                for patch in remaining_patches.copy():
                    if self._is_patch_nearby(patch, current_cluster):
                        current_cluster.append(patch)
                        remaining_patches.remove(patch)
                        added_to_cluster = True

            clusters.append(current_cluster)

        return sorted(clusters, key=lambda cluster: max(p.score for p in cluster), reverse=True)

    def _is_patch_nearby(self, patch: PatchResult, cluster: List[PatchResult]) -> bool:
        """Check if a patch is spatially close to any patch in the cluster."""
        patch_center = self._get_patch_center(patch.coordinates)

        for cluster_patch in cluster:
            cluster_center = self._get_patch_center(cluster_patch.coordinates)
            distance = np.sqrt((patch_center[0] - cluster_center[0])**2 +
                               (patch_center[1] - cluster_center[1])**2)

            if distance <= self.cluster_distance_threshold:
                return True

        return False

    def _get_patch_center(self, coordinates: Tuple[int, int, int, int]) -> Tuple[float, float]:
        """Get center point of a patch."""
        x1, y1, x2, y2 = coordinates
        return ((x1 + x2) / 2, (y1 + y2) / 2)

    def print_aggregated_results(self, results: List[AggregatedResult], max_patches_per_doc: int = 3):
        """Pretty print aggregated results."""
        print(f"\n🔍 Aggregated Results (method: {self.aggregation_method})")
        print("=" * 80)

        for i, result in enumerate(results):
            print(f"\n{i+1}. {result.image_name}")
            print(f"   Doc Score: {result.doc_score:.4f} | Patches: {result.patch_count}")
            print(f"   Path: {result.image_path}")

            # Show best patch
            best = result.best_patch
            print(f"   🌟 Best Patch: #{best.patch_id} at {best.coordinates} (score: {best.score:.4f})")

            # Show top patches
            print(f"   📍 Top Patches:")
            for j, patch in enumerate(result.all_patches[:max_patches_per_doc]):
                print(f"      {j+1}. Patch #{patch.patch_id}: {patch.score:.4f} at {patch.coordinates}")

            # Show spatial clusters if available
            if result.spatial_clusters and len(result.spatial_clusters) > 1:
                print(f"   🗂️ Spatial Clusters: {len(result.spatial_clusters)}")
                for j, cluster in enumerate(result.spatial_clusters[:2]):  # Show top 2 clusters
                    cluster_score = max(p.score for p in cluster)
                    print(f"      Cluster {j+1}: {len(cluster)} patches (best: {cluster_score:.4f})")

def demo_aggregation():
    """Demonstrate the multi-vector aggregation functionality."""
    print("=== Multi-Vector Aggregation Demo ===")

    # Simulate some patch-level search results
    # In real usage, these would come from LeannSearcher.search()

    class MockResult:
        def __init__(self, score, metadata):
            self.score = score
            self.metadata = metadata

    # Simulate results for 2 images with multiple patches each
    mock_results = [
        # Image 1: cats_and_kitchen.jpg - 4 patches
        MockResult(0.85, {
            "image_name": "cats_and_kitchen.jpg",
            "image_path": "/path/to/cats_and_kitchen.jpg",
            "patch_id": 3,
            "coordinates": [100, 50, 224, 174],  # Kitchen area
            "attention_score": 0.92,
            "scale": 1.0
        }),
        MockResult(0.78, {
            "image_name": "cats_and_kitchen.jpg",
            "image_path": "/path/to/cats_and_kitchen.jpg",
            "patch_id": 7,
            "coordinates": [200, 300, 324, 424],  # Cat area
            "attention_score": 0.88,
            "scale": 1.0
        }),
        MockResult(0.72, {
            "image_name": "cats_and_kitchen.jpg",
            "image_path": "/path/to/cats_and_kitchen.jpg",
            "patch_id": 12,
            "coordinates": [150, 100, 274, 224],  # Appliances
            "attention_score": 0.75,
            "scale": 1.0
        }),
        MockResult(0.65, {
            "image_name": "cats_and_kitchen.jpg",
            "image_path": "/path/to/cats_and_kitchen.jpg",
            "patch_id": 15,
            "coordinates": [50, 250, 174, 374],  # Furniture
            "attention_score": 0.70,
            "scale": 1.0
        }),

        # Image 2: city_street.jpg - 3 patches
        MockResult(0.68, {
            "image_name": "city_street.jpg",
            "image_path": "/path/to/city_street.jpg",
            "patch_id": 2,
            "coordinates": [300, 100, 424, 224],  # Buildings
            "attention_score": 0.80,
            "scale": 1.0
|
||||||
|
}),
|
||||||
|
MockResult(0.62, {
|
||||||
|
"image_name": "city_street.jpg",
|
||||||
|
"image_path": "/path/to/city_street.jpg",
|
||||||
|
"patch_id": 8,
|
||||||
|
"coordinates": [100, 350, 224, 474], # Street level
|
||||||
|
"attention_score": 0.75,
|
||||||
|
"scale": 1.0
|
||||||
|
}),
|
||||||
|
MockResult(0.55, {
|
||||||
|
"image_name": "city_street.jpg",
|
||||||
|
"image_path": "/path/to/city_street.jpg",
|
||||||
|
"patch_id": 11,
|
||||||
|
"coordinates": [400, 200, 524, 324], # Sky area
|
||||||
|
"attention_score": 0.60,
|
||||||
|
"scale": 1.0
|
||||||
|
}),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Test different aggregation methods
|
||||||
|
methods = ["maxsim", "voting", "weighted", "mean"]
|
||||||
|
|
||||||
|
for method in methods:
|
||||||
|
print(f"\n{'='*20} {method.upper()} AGGREGATION {'='*20}")
|
||||||
|
|
||||||
|
aggregator = MultiVectorAggregator(
|
||||||
|
aggregation_method=method,
|
||||||
|
spatial_clustering=True,
|
||||||
|
cluster_distance_threshold=100.0
|
||||||
|
)
|
||||||
|
|
||||||
|
aggregated = aggregator.aggregate_results(mock_results, top_k=5)
|
||||||
|
aggregator.print_aggregated_results(aggregated)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
demo_aggregation()
|
||||||
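For a quick sense of how the three aggregation branches shown above differ, here is a minimal arithmetic sketch that reuses only the formulas visible in the code and the mock `cats_and_kitchen.jpg` patch scores from `demo_aggregation()`. It is standalone illustration, not part of the library (the `maxsim` branch is not shown in this excerpt, so it is omitted).

```python
# Toy check of the "voting", "weighted", and "mean" aggregation formulas above,
# using the four cats_and_kitchen.jpg mock patches.
import numpy as np

scores = np.array([0.85, 0.78, 0.72, 0.65])
attention = np.array([0.92, 0.88, 0.75, 0.70])

mean_score = scores.mean()                                    # 0.75
weighted = (scores * attention).sum() / attention.sum()       # ~0.758
threshold = np.percentile(scores, 75)                         # ~0.7975
voting = int((scores >= threshold).sum())                     # 1 patch clears the 75th percentile

print(mean_score, round(weighted, 3), voting)
```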
18 examples/resue_index.py Normal file
@@ -0,0 +1,18 @@
import asyncio
from leann.api import LeannChat
from pathlib import Path

INDEX_DIR = Path("./test_pdf_index_huawei")
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")

async def main():
    print(f"\n[PHASE 2] Starting Leann chat session...")
    chat = LeannChat(index_path=INDEX_PATH)
    query = "What is the main idea of RL and give me 5 examples of classic RL algorithms?"
    query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explores to achieve the Fairness and Efficiency trade-off?"
    # query = "What is the Pangu large model, what dark-side issues came up during Pangu's development, and in which city are work orders usually issued?"
    response = chat.ask(query, top_k=20, recompute_beighbor_embeddings=True, complexity=32, beam_width=1)
    print(f"\n[PHASE 2] Response: {response}")

if __name__ == "__main__":
    asyncio.run(main())
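A note on the script above: `LeannChat.ask()` is a plain synchronous method per its definition later in this commit, so the asyncio wrapper is optional. Below is a minimal synchronous sketch of the same reuse flow, assuming the index at `INDEX_PATH` was already built in an earlier phase; the extra keyword arguments are simply forwarded to the backend searcher through `**kwargs`.

```python
# Minimal synchronous variant of examples/resue_index.py.
# Assumes ./test_pdf_index_huawei/pdf_documents.leann already exists.
from leann.api import LeannChat

chat = LeannChat(index_path="./test_pdf_index_huawei/pdf_documents.leann")
response = chat.ask(
    "What are the main techniques LEANN explores to reduce storage overhead?",
    top_k=20,        # retrieved chunks placed into the prompt
    complexity=32,   # backend-specific knob, forwarded via **kwargs
    beam_width=1,    # backend-specific knob, forwarded via **kwargs
)
print(response)
```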
@@ -1,32 +0,0 @@
{
  "version": "0.1.0",
  "backend_name": "diskann",
  "embedding_model": "sentence-transformers/all-mpnet-base-v2",
  "num_chunks": 6,
  "chunks": [
    {
      "text": "Python is a powerful programming language",
      "metadata": {}
    },
    {
      "text": "Machine learning transforms industries",
      "metadata": {}
    },
    {
      "text": "Neural networks process complex data",
      "metadata": {}
    },
    {
      "text": "Java is a powerful programming language",
      "metadata": {}
    },
    {
      "text": "C++ is a powerful programming language",
      "metadata": {}
    },
    {
      "text": "C# is a powerful programming language",
      "metadata": {}
    }
  ]
}
1 packages/__init__.py Normal file
@@ -0,0 +1 @@

1 packages/leann-backend-diskann/__init__.py Normal file
@@ -0,0 +1 @@
# This file makes the directory a Python package
@@ -15,6 +15,8 @@ import os
 from contextlib import contextmanager
 import zmq
 import numpy as np
+from pathlib import Path
+import pickle
 
 RED = "\033[91m"
 RESET = "\033[0m"
@@ -109,8 +111,6 @@ def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
     Load passages from a JSONL file with label map support
     Expected format: {"id": "passage_id", "text": "passage_text", "metadata": {...}} (one per line)
     """
-    from pathlib import Path
-    import pickle
-
     if not os.path.exists(passages_file):
         raise FileNotFoundError(f"Passages file {passages_file} not found.")
@@ -210,7 +210,6 @@ def create_embedding_server_thread(
         passages = load_passages_from_metadata(passages_file)
     else:
         # Try to find metadata file in same directory
-        from pathlib import Path
         passages_dir = Path(passages_file).parent
         meta_files = list(passages_dir.glob("*.meta.json"))
         if meta_files:
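The loader above documents the expected JSONL passage format, one JSON object per line with `id`, `text`, and `metadata` fields. A minimal sketch of producing and reading such a file follows; the file name `passages.jsonl` is only illustrative.

```python
# Sketch of the JSONL passage format documented in load_passages_from_file:
# {"id": "passage_id", "text": "passage_text", "metadata": {...}} per line.
import json

passages = [
    {"id": "p0", "text": "Python is a powerful programming language", "metadata": {}},
    {"id": "p1", "text": "Machine learning transforms industries", "metadata": {}},
]

with open("passages.jsonl", "w") as f:       # hypothetical file name
    for p in passages:
        f.write(json.dumps(p) + "\n")

with open("passages.jsonl") as f:
    loaded = [json.loads(line) for line in f]
print(loaded[0]["text"])
```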
@@ -2,6 +2,33 @@
 cmake_minimum_required(VERSION 3.24)
 project(leann_backend_hnsw_wrapper)
 
+# Set OpenMP path for macOS
+if(APPLE)
+    set(OpenMP_C_FLAGS "-Xpreprocessor -fopenmp -I/opt/homebrew/opt/libomp/include")
+    set(OpenMP_CXX_FLAGS "-Xpreprocessor -fopenmp -I/opt/homebrew/opt/libomp/include")
+    set(OpenMP_C_LIB_NAMES "omp")
+    set(OpenMP_CXX_LIB_NAMES "omp")
+    set(OpenMP_omp_LIBRARY "/opt/homebrew/opt/libomp/lib/libomp.dylib")
+endif()
+
+# Build ZeroMQ from source
+set(ZMQ_BUILD_TESTS OFF CACHE BOOL "" FORCE)
+set(ENABLE_DRAFTS OFF CACHE BOOL "" FORCE)
+set(ENABLE_PRECOMPILED OFF CACHE BOOL "" FORCE)
+set(WITH_PERF_TOOL OFF CACHE BOOL "" FORCE)
+set(WITH_DOCS OFF CACHE BOOL "" FORCE)
+set(BUILD_SHARED OFF CACHE BOOL "" FORCE)
+set(BUILD_STATIC ON CACHE BOOL "" FORCE)
+add_subdirectory(third_party/libzmq)
+
+# Add cppzmq headers
+include_directories(third_party/cppzmq)
+
+# Configure msgpack-c - disable boost dependency
+set(MSGPACK_USE_BOOST OFF CACHE BOOL "" FORCE)
+add_compile_definitions(MSGPACK_NO_BOOST)
+include_directories(third_party/msgpack-c/include)
+
 set(FAISS_ENABLE_PYTHON ON CACHE BOOL "" FORCE)
 set(FAISS_ENABLE_GPU OFF CACHE BOOL "" FORCE)
 set(FAISS_ENABLE_EXTRAS OFF CACHE BOOL "" FORCE)
@@ -4,6 +4,7 @@ import json
 from pathlib import Path
 from typing import Dict, Any, List
 import pickle
+import shutil
 
 from leann.searcher_base import BaseSearcher
 from .convert_to_csr import convert_hnsw_graph_to_csr
@@ -77,17 +78,29 @@ class HNSWBuilder(LeannBackendBuilderInterface):
         self._convert_to_csr(index_file)
 
     def _convert_to_csr(self, index_file: Path):
+        """Convert built index to CSR format"""
+        mode_str = "CSR-pruned" if self.is_recompute else "CSR-standard"
+        print(f"INFO: Converting HNSW index to {mode_str} format...")
+
         csr_temp_file = index_file.with_suffix(".csr.tmp")
 
         success = convert_hnsw_graph_to_csr(
-            str(index_file), str(csr_temp_file), prune_embeddings=self.is_recompute
+            str(index_file),
+            str(csr_temp_file),
+            prune_embeddings=self.is_recompute
         )
 
         if success:
-            import shutil
+            print("✅ CSR conversion successful.")
+            index_file_old = index_file.with_suffix(".old")
+            shutil.move(str(index_file), str(index_file_old))
             shutil.move(str(csr_temp_file), str(index_file))
+            print(f"INFO: Replaced original index with {mode_str} version at '{index_file}'")
         else:
+            # Clean up and fail fast
             if csr_temp_file.exists():
                 os.remove(csr_temp_file)
-            raise RuntimeError("CSR conversion failed")
+            raise RuntimeError("CSR conversion failed - cannot proceed with compact format")
 
 class HNSWSearcher(BaseSearcher):
     def __init__(self, index_path: str, **kwargs):
@@ -99,7 +112,10 @@ class HNSWSearcher(BaseSearcher):
         if metric_enum is None:
             raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
 
-        self.is_compact, self.is_pruned = self._get_index_storage_status_from_meta()
+        self.is_compact, self.is_pruned = (
+            self.meta.get('is_compact', True),
+            self.meta.get('is_pruned', True)
+        )
 
         index_file = self.index_dir / f"{self.index_path.stem}.index"
         if not index_file.exists():
@@ -114,11 +130,6 @@ class HNSWSearcher(BaseSearcher):
 
         self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)
 
-    def _get_index_storage_status_from_meta(self) -> tuple[bool, bool]:
-        is_compact = self.meta.get('is_compact', True)
-        is_pruned = self.meta.get('is_pruned', True)
-        return is_compact, is_pruned
-
     def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]:
         from . import faiss
 
@@ -1,4 +1,14 @@
 # packages/leann-core/src/leann/__init__.py
+import os
+import platform
+
+# Fix OpenMP threading issues on macOS ARM64
+if platform.system() == "Darwin":
+    os.environ["OMP_NUM_THREADS"] = "1"
+    os.environ["MKL_NUM_THREADS"] = "1"
+    os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
+    os.environ["KMP_BLOCKTIME"] = "0"
+
 from .api import LeannBuilder, LeannChat, LeannSearcher
 from .registry import BACKEND_REGISTRY, autodiscover_backends
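These environment variables generally only help if they are in place before an OpenMP-backed library (faiss, torch) initializes its runtime, so on macOS the safer order is to import `leann` before those libraries, or to export the variables yourself. A small sketch under that assumption:

```python
# Assumed usage on macOS: import leann first so the env vars set in
# leann/__init__.py are present before any OpenMP runtime spins up.
import platform, os
import leann  # on Darwin sets OMP_NUM_THREADS, MKL_NUM_THREADS, KMP_DUPLICATE_LIB_OK, KMP_BLOCKTIME

if platform.system() == "Darwin":
    assert os.environ["OMP_NUM_THREADS"] == "1"
```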
@@ -1,4 +1,3 @@
-#!/usr/bin/env python3
 """
 This file contains the core API for the LEANN project, now definitively updated
 with the correct, original embedding logic from the user's reference code.
@@ -11,6 +10,7 @@ from pathlib import Path
 from typing import List, Dict, Any, Optional
 from dataclasses import dataclass, field
 import uuid
+import torch
 
 from .registry import BACKEND_REGISTRY
 from .interface import LeannBackendFactoryInterface
@@ -25,13 +25,22 @@ def compute_embeddings(chunks: List[str], model_name: str) -> np.ndarray:
         raise RuntimeError(
             f"sentence-transformers not available. Install with: pip install sentence-transformers"
         ) from e
 
     # Load model using sentence-transformers
     model = SentenceTransformer(model_name)
+    model = model.half()
+    print(f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}'...")
+    # Use an accelerator if available: CUDA GPU or Apple Silicon GPU (MPS)
+
+    if torch.cuda.is_available():
+        model = model.to("cuda")
+    elif torch.backends.mps.is_available():
+        model = model.to("mps")
+
     # Generate embeddings
     embeddings = model.encode(chunks, convert_to_numpy=True, show_progress_bar=True, batch_size=64)
 
     return embeddings
 
 # --- Core API Classes (Restored and Unchanged) ---
@@ -181,5 +190,25 @@ class LeannChat:
     def ask(self, question: str, top_k=5, **kwargs):
         results = self.searcher.search(question, top_k=top_k, **kwargs)
         context = "\n\n".join([r.text for r in results])
-        prompt = f"Context:\n{context}\n\nQuestion: {question}\n\nAnswer:"
-        return self.llm.ask(prompt, **kwargs.get("llm_kwargs", {}))
+        prompt = (
+            "Here is some retrieved context that might help answer your question:\n\n"
+            f"{context}\n\n"
+            f"Question: {question}\n\n"
+            "Please provide the best answer you can based on this context and your knowledge."
+        )
+        return self.llm.ask(prompt, **kwargs.get("llm_kwargs", {}))
+
+    def start_interactive(self):
+        print("\nLeann Chat started (type 'quit' to exit)")
+        while True:
+            try:
+                user_input = input("You: ").strip()
+                if user_input.lower() in ['quit', 'exit']:
+                    break
+                if not user_input:
+                    continue
+                response = self.ask(user_input)
+                print(f"Leann: {response}")
+            except (KeyboardInterrupt, EOFError):
+                print("\nGoodbye!")
+                break
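To see what the LLM actually receives after this change, here is a standalone sketch that rebuilds the prompt exactly as `LeannChat.ask()` does above; the two context strings are toy stand-ins for searcher results (in the real flow they come from `result.text` on the search hits).

```python
# Toy reconstruction of the prompt template used by LeannChat.ask().
results = [
    "Python is a powerful programming language",
    "Machine learning transforms industries",
]
question = "Which language is described as powerful?"

context = "\n\n".join(results)
prompt = (
    "Here is some retrieved context that might help answer your question:\n\n"
    f"{context}\n\n"
    f"Question: {question}\n\n"
    "Please provide the best answer you can based on this context and your knowledge."
)
print(prompt)
```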
@@ -7,6 +7,7 @@ supporting different backends like Ollama, Hugging Face Transformers, and a simu
 from abc import ABC, abstractmethod
 from typing import Dict, Any, Optional
 import logging
+import os
 
 # Configure logging
 logging.basicConfig(level=logging.INFO)
@@ -95,7 +96,57 @@ class HFChat(LLMInterface):
         }
         logger.info(f"Generating text with Hugging Face model with params: {params}")
         results = self.pipeline(prompt, **params)
-        return results[0]['generated_text']
+
+        # Handle different response formats from transformers
+        if isinstance(results, list) and len(results) > 0:
+            generated_text = results[0].get('generated_text', '') if isinstance(results[0], dict) else str(results[0])
+        else:
+            generated_text = str(results)
+
+        # Extract only the newly generated portion by removing the original prompt
+        if isinstance(generated_text, str) and generated_text.startswith(prompt):
+            response = generated_text[len(prompt):].strip()
+        else:
+            # Fallback: return the full response if prompt removal fails
+            response = str(generated_text)
+
+        return response
+
+
+class OpenAIChat(LLMInterface):
+    """LLM interface for OpenAI models."""
+    def __init__(self, model: str = "gpt-4o", api_key: Optional[str] = None):
+        self.model = model
+        self.api_key = api_key or os.getenv("OPENAI_API_KEY")
+
+        if not self.api_key:
+            raise ValueError("OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter.")
+
+        logger.info(f"Initializing OpenAI Chat with model='{model}'")
+
+        try:
+            import openai
+            self.client = openai.OpenAI(api_key=self.api_key)
+        except ImportError:
+            raise ImportError("The 'openai' library is required for OpenAI models. Please install it with 'pip install openai'.")
+
+    def ask(self, prompt: str, **kwargs) -> str:
+        # Default parameters for OpenAI
+        params = {
+            "model": self.model,
+            "messages": [{"role": "user", "content": prompt}],
+            "max_tokens": kwargs.get("max_tokens", 1000),
+            "temperature": kwargs.get("temperature", 0.7),
+            **{k: v for k, v in kwargs.items() if k not in ["max_tokens", "temperature"]}
+        }
+
+        logger.info(f"Sending request to OpenAI with model {self.model}")
+
+        try:
+            response = self.client.chat.completions.create(**params)
+            return response.choices[0].message.content.strip()
+        except Exception as e:
+            logger.error(f"Error communicating with OpenAI: {e}")
+            return f"Error: Could not get a response from OpenAI. Details: {e}"
+
 class SimulatedChat(LLMInterface):
     """A simple simulated chat for testing and development."""
@@ -127,9 +178,11 @@ def get_llm(llm_config: Optional[Dict[str, Any]] = None) -> LLMInterface:
     logger.info(f"Attempting to create LLM of type='{llm_type}' with model='{model}'")
 
     if llm_type == "ollama":
-        return OllamaChat(model=model, host=llm_config.get("host"))
+        return OllamaChat(model=model or "llama3:8b", host=llm_config.get("host", "http://localhost:11434"))
     elif llm_type == "hf":
-        return HFChat(model_name=model)
+        return HFChat(model_name=model or "deepseek-ai/deepseek-llm-7b-chat")
+    elif llm_type == "openai":
+        return OpenAIChat(model=model or "gpt-4o", api_key=llm_config.get("api_key"))
    elif llm_type == "simulated":
         return SimulatedChat()
     else:
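With the dispatch above, `get_llm()` now covers four backend types. The sketch below lists plausible config dicts for each; the `"type"`/`"model"` key names are an assumption (the hunk only shows the `llm_type` and `model` variables, not how they are read from the dict), and the call itself is left commented because the module's import path is not shown in this excerpt.

```python
# Hedged sketch of llm_config dicts for the four backend types handled above.
openai_cfg = {"type": "openai", "model": "gpt-4o"}   # needs OPENAI_API_KEY or an "api_key" entry
ollama_cfg = {"type": "ollama"}                      # defaults: llama3:8b at http://localhost:11434
hf_cfg = {"type": "hf"}                              # defaults to deepseek-ai/deepseek-llm-7b-chat
sim_cfg = {"type": "simulated"}                      # no external dependencies

# llm = get_llm(openai_cfg)        # would return an OpenAIChat instance
# print(llm.ask("Hello, Leann!"))
```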
@@ -21,7 +21,7 @@ dependencies = [
     "colorama",
     "boto3",
     "protobuf==4.25.3",
-    "sglang[all]",
+    "sglang",
     "ollama",
     "requests>=2.25.0",
     "sentence-transformers>=2.2.0",
147 test/mail_reader_llamaindex.py Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
import os
|
||||||
|
import email
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Any
|
||||||
|
from llama_index.core import VectorStoreIndex, Document
|
||||||
|
from llama_index.core.readers.base import BaseReader
|
||||||
|
|
||||||
|
class EmlxReader(BaseReader):
|
||||||
|
"""
|
||||||
|
Apple Mail .emlx file reader.
|
||||||
|
|
||||||
|
Reads individual .emlx files from Apple Mail's storage format.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
"""Initialize."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def load_data(self, input_dir: str, **load_kwargs: Any) -> List[Document]:
|
||||||
|
"""
|
||||||
|
Load data from the input directory containing .emlx files.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_dir: Directory containing .emlx files
|
||||||
|
**load_kwargs:
|
||||||
|
max_count (int): Maximum amount of messages to read.
|
||||||
|
"""
|
||||||
|
docs: List[Document] = []
|
||||||
|
max_count = load_kwargs.get('max_count', 1000)
|
||||||
|
count = 0
|
||||||
|
|
||||||
|
# Walk through the directory recursively
|
||||||
|
for dirpath, dirnames, filenames in os.walk(input_dir):
|
||||||
|
# Skip hidden directories
|
||||||
|
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
|
||||||
|
|
||||||
|
for filename in filenames:
|
||||||
|
if count >= max_count:
|
||||||
|
break
|
||||||
|
|
||||||
|
if filename.endswith(".emlx"):
|
||||||
|
filepath = os.path.join(dirpath, filename)
|
||||||
|
try:
|
||||||
|
# Read the .emlx file
|
||||||
|
with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
|
||||||
|
content = f.read()
|
||||||
|
|
||||||
|
# .emlx files have a length prefix followed by the email content
|
||||||
|
# The first line contains the length, followed by the email
|
||||||
|
lines = content.split('\n', 1)
|
||||||
|
if len(lines) >= 2:
|
||||||
|
email_content = lines[1]
|
||||||
|
|
||||||
|
# Parse the email using Python's email module
|
||||||
|
try:
|
||||||
|
msg = email.message_from_string(email_content)
|
||||||
|
|
||||||
|
# Extract email metadata
|
||||||
|
subject = msg.get('Subject', 'No Subject')
|
||||||
|
from_addr = msg.get('From', 'Unknown')
|
||||||
|
to_addr = msg.get('To', 'Unknown')
|
||||||
|
date = msg.get('Date', 'Unknown')
|
||||||
|
|
||||||
|
# Extract email body
|
||||||
|
body = ""
|
||||||
|
if msg.is_multipart():
|
||||||
|
for part in msg.walk():
|
||||||
|
if part.get_content_type() == "text/plain" or part.get_content_type() == "text/html":
|
||||||
|
body += part.get_payload(decode=True).decode('utf-8', errors='ignore')
|
||||||
|
# break
|
||||||
|
else:
|
||||||
|
body = msg.get_payload(decode=True).decode('utf-8', errors='ignore')
|
||||||
|
|
||||||
|
# Create document content
|
||||||
|
doc_content = f"""
|
||||||
|
From: {from_addr}
|
||||||
|
To: {to_addr}
|
||||||
|
Subject: {subject}
|
||||||
|
Date: {date}
|
||||||
|
|
||||||
|
{body}
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Create metadata
|
||||||
|
metadata = {
|
||||||
|
'file_path': filepath,
|
||||||
|
'subject': subject,
|
||||||
|
'from': from_addr,
|
||||||
|
'to': to_addr,
|
||||||
|
'date': date,
|
||||||
|
'filename': filename
|
||||||
|
}
|
||||||
|
if count == 0:
|
||||||
|
print("--------------------------------")
|
||||||
|
print('dir path', dirpath)
|
||||||
|
print(metadata)
|
||||||
|
print(doc_content)
|
||||||
|
print("--------------------------------")
|
||||||
|
body=[]
|
||||||
|
if msg.is_multipart():
|
||||||
|
for part in msg.walk():
|
||||||
|
print("-------------------------------- get content type -------------------------------")
|
||||||
|
print(part.get_content_type())
|
||||||
|
print(part)
|
||||||
|
# body.append(part.get_payload(decode=True).decode('utf-8', errors='ignore'))
|
||||||
|
print("-------------------------------- get content type -------------------------------")
|
||||||
|
else:
|
||||||
|
body = msg.get_payload(decode=True).decode('utf-8', errors='ignore')
|
||||||
|
print(body)
|
||||||
|
|
||||||
|
print(body)
|
||||||
|
print("--------------------------------")
|
||||||
|
doc = Document(text=doc_content, metadata=metadata)
|
||||||
|
docs.append(doc)
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"!!!!!!! Error parsing email from {filepath}: {e} !!!!!!!!")
|
||||||
|
continue
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"!!!!!!! Error reading file !!!!!!!! {filepath}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
print(f"Loaded {len(docs)} email documents")
|
||||||
|
return docs
|
||||||
|
|
||||||
|
# Use the custom EmlxReader instead of MboxReader
|
||||||
|
documents = EmlxReader().load_data(
|
||||||
|
"/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages",
|
||||||
|
max_count=1000
|
||||||
|
) # Returns list of documents
|
||||||
|
|
||||||
|
# Configure the index with larger chunk size to handle long metadata
|
||||||
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
|
||||||
|
# Create a custom text splitter with larger chunk size
|
||||||
|
text_splitter = SentenceSplitter(chunk_size=2048, chunk_overlap=200)
|
||||||
|
|
||||||
|
index = VectorStoreIndex.from_documents(
|
||||||
|
documents,
|
||||||
|
transformations=[text_splitter]
|
||||||
|
) # Initialize index with documents
|
||||||
|
|
||||||
|
query_engine = index.as_query_engine()
|
||||||
|
res = query_engine.query("Hows Berkeley Graduate Student Instructor")
|
||||||
|
print(res)
|
||||||
213 test/mail_reader_save_load.py Normal file
@@ -0,0 +1,213 @@
|
|||||||
|
import os
|
||||||
|
import email
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Any
|
||||||
|
from llama_index.core import VectorStoreIndex, Document, StorageContext
|
||||||
|
from llama_index.core.readers.base import BaseReader
|
||||||
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
|
||||||
|
class EmlxReader(BaseReader):
|
||||||
|
"""
|
||||||
|
Apple Mail .emlx file reader.
|
||||||
|
|
||||||
|
Reads individual .emlx files from Apple Mail's storage format.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
"""Initialize."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def load_data(self, input_dir: str, **load_kwargs: Any) -> List[Document]:
|
||||||
|
"""
|
||||||
|
Load data from the input directory containing .emlx files.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_dir: Directory containing .emlx files
|
||||||
|
**load_kwargs:
|
||||||
|
max_count (int): Maximum amount of messages to read.
|
||||||
|
"""
|
||||||
|
docs: List[Document] = []
|
||||||
|
max_count = load_kwargs.get('max_count', 1000)
|
||||||
|
count = 0
|
||||||
|
|
||||||
|
# Walk through the directory recursively
|
||||||
|
for dirpath, dirnames, filenames in os.walk(input_dir):
|
||||||
|
# Skip hidden directories
|
||||||
|
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
|
||||||
|
|
||||||
|
for filename in filenames:
|
||||||
|
if count >= max_count:
|
||||||
|
break
|
||||||
|
|
||||||
|
if filename.endswith(".emlx"):
|
||||||
|
filepath = os.path.join(dirpath, filename)
|
||||||
|
try:
|
||||||
|
# Read the .emlx file
|
||||||
|
with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
|
||||||
|
content = f.read()
|
||||||
|
|
||||||
|
# .emlx files have a length prefix followed by the email content
|
||||||
|
# The first line contains the length, followed by the email
|
||||||
|
lines = content.split('\n', 1)
|
||||||
|
if len(lines) >= 2:
|
||||||
|
email_content = lines[1]
|
||||||
|
|
||||||
|
# Parse the email using Python's email module
|
||||||
|
try:
|
||||||
|
msg = email.message_from_string(email_content)
|
||||||
|
|
||||||
|
# Extract email metadata
|
||||||
|
subject = msg.get('Subject', 'No Subject')
|
||||||
|
from_addr = msg.get('From', 'Unknown')
|
||||||
|
to_addr = msg.get('To', 'Unknown')
|
||||||
|
date = msg.get('Date', 'Unknown')
|
||||||
|
|
||||||
|
# Extract email body
|
||||||
|
body = ""
|
||||||
|
if msg.is_multipart():
|
||||||
|
for part in msg.walk():
|
||||||
|
if part.get_content_type() == "text/plain":
|
||||||
|
body = part.get_payload(decode=True).decode('utf-8', errors='ignore')
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
body = msg.get_payload(decode=True).decode('utf-8', errors='ignore')
|
||||||
|
|
||||||
|
# Create document content
|
||||||
|
doc_content = f"""
|
||||||
|
From: {from_addr}
|
||||||
|
To: {to_addr}
|
||||||
|
Subject: {subject}
|
||||||
|
Date: {date}
|
||||||
|
|
||||||
|
{body}
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Create metadata
|
||||||
|
metadata = {
|
||||||
|
'file_path': filepath,
|
||||||
|
'subject': subject,
|
||||||
|
'from': from_addr,
|
||||||
|
'to': to_addr,
|
||||||
|
'date': date,
|
||||||
|
'filename': filename
|
||||||
|
}
|
||||||
|
|
||||||
|
doc = Document(text=doc_content, metadata=metadata)
|
||||||
|
docs.append(doc)
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error parsing email from {filepath}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error reading file {filepath}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
print(f"Loaded {len(docs)} email documents")
|
||||||
|
return docs
|
||||||
|
|
||||||
|
def create_and_save_index(mail_path: str, save_dir: str = "mail_index", max_count: int = 1000):
|
||||||
|
"""
|
||||||
|
Create the index from mail data and save it to disk.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mail_path: Path to the mail directory
|
||||||
|
save_dir: Directory to save the index
|
||||||
|
max_count: Maximum number of emails to process
|
||||||
|
"""
|
||||||
|
print("Creating index from mail data...")
|
||||||
|
|
||||||
|
# Load documents
|
||||||
|
documents = EmlxReader().load_data(mail_path, max_count=max_count)
|
||||||
|
|
||||||
|
if not documents:
|
||||||
|
print("No documents loaded. Exiting.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Create text splitter
|
||||||
|
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=0)
|
||||||
|
|
||||||
|
# Create index
|
||||||
|
index = VectorStoreIndex.from_documents(
|
||||||
|
documents,
|
||||||
|
transformations=[text_splitter]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save the index
|
||||||
|
os.makedirs(save_dir, exist_ok=True)
|
||||||
|
index.storage_context.persist(persist_dir=save_dir)
|
||||||
|
print(f"Index saved to {save_dir}")
|
||||||
|
|
||||||
|
return index
|
||||||
|
|
||||||
|
def load_index(save_dir: str = "mail_index"):
|
||||||
|
"""
|
||||||
|
Load the saved index from disk.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
save_dir: Directory where the index is saved
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Loaded index or None if loading fails
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Load storage context
|
||||||
|
storage_context = StorageContext.from_defaults(persist_dir=save_dir)
|
||||||
|
|
||||||
|
# Load index
|
||||||
|
index = VectorStoreIndex.from_vector_store(
|
||||||
|
storage_context.vector_store,
|
||||||
|
storage_context=storage_context
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Index loaded from {save_dir}")
|
||||||
|
return index
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error loading index: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def query_index(index, query: str):
|
||||||
|
"""
|
||||||
|
Query the loaded index.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index: The loaded index
|
||||||
|
query: The query string
|
||||||
|
"""
|
||||||
|
if index is None:
|
||||||
|
print("No index available for querying.")
|
||||||
|
return
|
||||||
|
|
||||||
|
query_engine = index.as_query_engine()
|
||||||
|
response = query_engine.query(query)
|
||||||
|
print(f"Query: {query}")
|
||||||
|
print(f"Response: {response}")
|
||||||
|
|
||||||
|
def main():
|
||||||
|
mail_path = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages"
|
||||||
|
save_dir = "mail_index"
|
||||||
|
|
||||||
|
# Check if index already exists
|
||||||
|
if os.path.exists(save_dir) and os.path.exists(os.path.join(save_dir, "vector_store.json")):
|
||||||
|
print("Loading existing index...")
|
||||||
|
index = load_index(save_dir)
|
||||||
|
else:
|
||||||
|
print("Creating new index...")
|
||||||
|
index = create_and_save_index(mail_path, save_dir, max_count=1000)
|
||||||
|
|
||||||
|
if index:
|
||||||
|
# Example queries
|
||||||
|
queries = [
|
||||||
|
"Hows Berkeley Graduate Student Instructor",
|
||||||
|
"What emails mention GSR appointments?",
|
||||||
|
"Find emails about deadlines"
|
||||||
|
]
|
||||||
|
|
||||||
|
for query in queries:
|
||||||
|
print("\n" + "="*50)
|
||||||
|
query_index(index, query)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
211 test/mail_reader_small_chunks.py Normal file
@@ -0,0 +1,211 @@
|
|||||||
|
import os
|
||||||
|
import email
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Any
|
||||||
|
from llama_index.core import VectorStoreIndex, Document, StorageContext
|
||||||
|
from llama_index.core.readers.base import BaseReader
|
||||||
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
|
||||||
|
class EmlxReader(BaseReader):
|
||||||
|
"""
|
||||||
|
Apple Mail .emlx file reader with reduced metadata.
|
||||||
|
|
||||||
|
Reads individual .emlx files from Apple Mail's storage format.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
"""Initialize."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def load_data(self, input_dir: str, **load_kwargs: Any) -> List[Document]:
|
||||||
|
"""
|
||||||
|
Load data from the input directory containing .emlx files.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_dir: Directory containing .emlx files
|
||||||
|
**load_kwargs:
|
||||||
|
max_count (int): Maximum amount of messages to read.
|
||||||
|
"""
|
||||||
|
docs: List[Document] = []
|
||||||
|
max_count = load_kwargs.get('max_count', 1000)
|
||||||
|
count = 0
|
||||||
|
|
||||||
|
# Walk through the directory recursively
|
||||||
|
for dirpath, dirnames, filenames in os.walk(input_dir):
|
||||||
|
# Skip hidden directories
|
||||||
|
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
|
||||||
|
|
||||||
|
for filename in filenames:
|
||||||
|
if count >= max_count:
|
||||||
|
break
|
||||||
|
|
||||||
|
if filename.endswith(".emlx"):
|
||||||
|
filepath = os.path.join(dirpath, filename)
|
||||||
|
try:
|
||||||
|
# Read the .emlx file
|
||||||
|
with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
|
||||||
|
content = f.read()
|
||||||
|
|
||||||
|
# .emlx files have a length prefix followed by the email content
|
||||||
|
# The first line contains the length, followed by the email
|
||||||
|
lines = content.split('\n', 1)
|
||||||
|
if len(lines) >= 2:
|
||||||
|
email_content = lines[1]
|
||||||
|
|
||||||
|
# Parse the email using Python's email module
|
||||||
|
try:
|
||||||
|
msg = email.message_from_string(email_content)
|
||||||
|
|
||||||
|
# Extract email metadata
|
||||||
|
subject = msg.get('Subject', 'No Subject')
|
||||||
|
from_addr = msg.get('From', 'Unknown')
|
||||||
|
to_addr = msg.get('To', 'Unknown')
|
||||||
|
date = msg.get('Date', 'Unknown')
|
||||||
|
|
||||||
|
# Extract email body
|
||||||
|
body = ""
|
||||||
|
if msg.is_multipart():
|
||||||
|
for part in msg.walk():
|
||||||
|
if part.get_content_type() == "text/plain":
|
||||||
|
body = part.get_payload(decode=True).decode('utf-8', errors='ignore')
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
body = msg.get_payload(decode=True).decode('utf-8', errors='ignore')
|
||||||
|
|
||||||
|
# Create document content with metadata embedded in text
|
||||||
|
doc_content = f"""
|
||||||
|
From: {from_addr}
|
||||||
|
To: {to_addr}
|
||||||
|
Subject: {subject}
|
||||||
|
Date: {date}
|
||||||
|
|
||||||
|
{body}
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Create minimal metadata (only essential info)
|
||||||
|
metadata = {
|
||||||
|
'subject': subject[:50], # Truncate subject
|
||||||
|
'from': from_addr[:30], # Truncate from
|
||||||
|
'date': date[:20], # Truncate date
|
||||||
|
'filename': filename # Keep filename
|
||||||
|
}
|
||||||
|
|
||||||
|
doc = Document(text=doc_content, metadata=metadata)
|
||||||
|
docs.append(doc)
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error parsing email from {filepath}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error reading file {filepath}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
print(f"Loaded {len(docs)} email documents")
|
||||||
|
return docs
|
||||||
|
|
||||||
|
def create_and_save_index(mail_path: str, save_dir: str = "mail_index_small", max_count: int = 1000):
|
||||||
|
"""
|
||||||
|
Create the index from mail data and save it to disk.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mail_path: Path to the mail directory
|
||||||
|
save_dir: Directory to save the index
|
||||||
|
max_count: Maximum number of emails to process
|
||||||
|
"""
|
||||||
|
print("Creating index from mail data with small chunks...")
|
||||||
|
|
||||||
|
# Load documents
|
||||||
|
documents = EmlxReader().load_data(mail_path, max_count=max_count)
|
||||||
|
|
||||||
|
if not documents:
|
||||||
|
print("No documents loaded. Exiting.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Create text splitter with small chunk size
|
||||||
|
text_splitter = SentenceSplitter(chunk_size=512, chunk_overlap=50)
|
||||||
|
|
||||||
|
# Create index
|
||||||
|
index = VectorStoreIndex.from_documents(
|
||||||
|
documents,
|
||||||
|
transformations=[text_splitter]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save the index
|
||||||
|
os.makedirs(save_dir, exist_ok=True)
|
||||||
|
index.storage_context.persist(persist_dir=save_dir)
|
||||||
|
print(f"Index saved to {save_dir}")
|
||||||
|
|
||||||
|
return index
|
||||||
|
|
||||||
|
def load_index(save_dir: str = "mail_index_small"):
|
||||||
|
"""
|
||||||
|
Load the saved index from disk.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
save_dir: Directory where the index is saved
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Loaded index or None if loading fails
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Load storage context
|
||||||
|
storage_context = StorageContext.from_defaults(persist_dir=save_dir)
|
||||||
|
|
||||||
|
# Load index
|
||||||
|
index = VectorStoreIndex.from_vector_store(
|
||||||
|
storage_context.vector_store,
|
||||||
|
storage_context=storage_context
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Index loaded from {save_dir}")
|
||||||
|
return index
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error loading index: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def query_index(index, query: str):
|
||||||
|
"""
|
||||||
|
Query the loaded index.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index: The loaded index
|
||||||
|
query: The query string
|
||||||
|
"""
|
||||||
|
if index is None:
|
||||||
|
print("No index available for querying.")
|
||||||
|
return
|
||||||
|
|
||||||
|
query_engine = index.as_query_engine()
|
||||||
|
response = query_engine.query(query)
|
||||||
|
print(f"Query: {query}")
|
||||||
|
print(f"Response: {response}")
|
||||||
|
|
||||||
|
def main():
|
||||||
|
mail_path = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages"
|
||||||
|
save_dir = "mail_index_small"
|
||||||
|
|
||||||
|
# Check if index already exists
|
||||||
|
if os.path.exists(save_dir) and os.path.exists(os.path.join(save_dir, "vector_store.json")):
|
||||||
|
print("Loading existing index...")
|
||||||
|
index = load_index(save_dir)
|
||||||
|
else:
|
||||||
|
print("Creating new index...")
|
||||||
|
index = create_and_save_index(mail_path, save_dir, max_count=1000)
|
||||||
|
|
||||||
|
if index:
|
||||||
|
# Example queries
|
||||||
|
queries = [
|
||||||
|
"Hows Berkeley Graduate Student Instructor",
|
||||||
|
"What emails mention GSR appointments?",
|
||||||
|
"Find emails about deadlines"
|
||||||
|
]
|
||||||
|
|
||||||
|
for query in queries:
|
||||||
|
print("\n" + "="*50)
|
||||||
|
query_index(index, query)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
147 test/mail_reader_test.py Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
import os
|
||||||
|
import email
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Any
|
||||||
|
from llama_index.core import VectorStoreIndex, Document
|
||||||
|
from llama_index.core.readers.base import BaseReader
|
||||||
|
|
||||||
|
class EmlxReader(BaseReader):
|
||||||
|
"""
|
||||||
|
Apple Mail .emlx file reader.
|
||||||
|
|
||||||
|
Reads individual .emlx files from Apple Mail's storage format.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
"""Initialize."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def load_data(self, input_dir: str, **load_kwargs: Any) -> List[Document]:
|
||||||
|
"""
|
||||||
|
Load data from the input directory containing .emlx files.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_dir: Directory containing .emlx files
|
||||||
|
**load_kwargs:
|
||||||
|
max_count (int): Maximum amount of messages to read.
|
||||||
|
"""
|
||||||
|
docs: List[Document] = []
|
||||||
|
max_count = load_kwargs.get('max_count', 1000)
|
||||||
|
count = 0
|
||||||
|
|
||||||
|
# Check if directory exists and is accessible
|
||||||
|
if not os.path.exists(input_dir):
|
||||||
|
print(f"Error: Directory '{input_dir}' does not exist")
|
||||||
|
return docs
|
||||||
|
|
||||||
|
if not os.access(input_dir, os.R_OK):
|
||||||
|
print(f"Error: Directory '{input_dir}' is not accessible (permission denied)")
|
||||||
|
print("This is likely due to macOS security restrictions on Mail app data")
|
||||||
|
return docs
|
||||||
|
|
||||||
|
print(f"Scanning directory: {input_dir}")
|
||||||
|
|
||||||
|
# Walk through the directory recursively
|
||||||
|
for dirpath, dirnames, filenames in os.walk(input_dir):
|
||||||
|
# Skip hidden directories
|
||||||
|
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
|
||||||
|
|
||||||
|
for filename in filenames:
|
||||||
|
if count >= max_count:
|
||||||
|
break
|
||||||
|
|
||||||
|
if filename.endswith(".emlx"):
|
||||||
|
filepath = os.path.join(dirpath, filename)
|
||||||
|
print(f"Found .emlx file: {filepath}")
|
||||||
|
try:
|
||||||
|
# Read the .emlx file
|
||||||
|
with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
|
||||||
|
content = f.read()
|
||||||
|
|
||||||
|
# .emlx files have a length prefix followed by the email content
|
||||||
|
# The first line contains the length, followed by the email
|
||||||
|
lines = content.split('\n', 1)
|
||||||
|
if len(lines) >= 2:
|
||||||
|
email_content = lines[1]
|
||||||
|
|
||||||
|
# Parse the email using Python's email module
|
||||||
|
try:
|
||||||
|
msg = email.message_from_string(email_content)
|
||||||
|
|
||||||
|
# Extract email metadata
|
||||||
|
subject = msg.get('Subject', 'No Subject')
|
||||||
|
from_addr = msg.get('From', 'Unknown')
|
||||||
|
to_addr = msg.get('To', 'Unknown')
|
||||||
|
date = msg.get('Date', 'Unknown')
|
||||||
|
|
||||||
|
# Extract email body
|
||||||
|
body = ""
|
||||||
|
if msg.is_multipart():
|
||||||
|
for part in msg.walk():
|
||||||
|
if part.get_content_type() == "text/plain":
|
||||||
|
body = part.get_payload(decode=True).decode('utf-8', errors='ignore')
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
body = msg.get_payload(decode=True).decode('utf-8', errors='ignore')
|
||||||
|
|
||||||
|
# Create document content
|
||||||
|
doc_content = f"""
|
||||||
|
From: {from_addr}
|
||||||
|
To: {to_addr}
|
||||||
|
Subject: {subject}
|
||||||
|
Date: {date}
|
||||||
|
|
||||||
|
{body}
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Create metadata
|
||||||
|
metadata = {
|
||||||
|
'file_path': filepath,
|
||||||
|
'subject': subject,
|
||||||
|
'from': from_addr,
|
||||||
|
'to': to_addr,
|
||||||
|
'date': date,
|
||||||
|
'filename': filename
|
||||||
|
}
|
||||||
|
|
||||||
|
doc = Document(text=doc_content, metadata=metadata)
|
||||||
|
docs.append(doc)
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error parsing email from {filepath}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error reading file {filepath}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
print(f"Loaded {len(docs)} email documents")
|
||||||
|
return docs
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Use the current directory where the sample.emlx file is located
|
||||||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
|
print("Testing EmlxReader with sample .emlx file...")
|
||||||
|
print(f"Scanning directory: {current_dir}")
|
||||||
|
|
||||||
|
# Use the custom EmlxReader
|
||||||
|
documents = EmlxReader().load_data(current_dir, max_count=1000)
|
||||||
|
|
||||||
|
if not documents:
|
||||||
|
print("No documents loaded. Make sure sample.emlx exists in the examples directory.")
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f"\nSuccessfully loaded {len(documents)} document(s)")
|
||||||
|
|
||||||
|
# Initialize index with documents
|
||||||
|
index = VectorStoreIndex.from_documents(documents)
|
||||||
|
query_engine = index.as_query_engine()
|
||||||
|
|
||||||
|
print("\nTesting query: 'Hows Berkeley Graduate Student Instructor'")
|
||||||
|
res = query_engine.query("Hows Berkeley Graduate Student Instructor")
|
||||||
|
print(f"Response: {res}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
99 test/query_saved_index.py Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
import os
|
||||||
|
from llama_index.core import VectorStoreIndex, StorageContext
|
||||||
|
|
||||||
|
def load_index(save_dir: str = "mail_index"):
|
||||||
|
"""
|
||||||
|
Load the saved index from disk.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
save_dir: Directory where the index is saved
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Loaded index or None if loading fails
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Load storage context
|
||||||
|
storage_context = StorageContext.from_defaults(persist_dir=save_dir)
|
||||||
|
|
||||||
|
# Load index
|
||||||
|
index = VectorStoreIndex.from_vector_store(
|
||||||
|
storage_context.vector_store,
|
||||||
|
storage_context=storage_context
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Index loaded from {save_dir}")
|
||||||
|
return index
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error loading index: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def query_index(index, query: str):
|
||||||
|
"""
|
||||||
|
Query the loaded index.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index: The loaded index
|
||||||
|
query: The query string
|
||||||
|
"""
|
||||||
|
if index is None:
|
||||||
|
print("No index available for querying.")
|
||||||
|
return
|
||||||
|
|
||||||
|
query_engine = index.as_query_engine()
|
||||||
|
response = query_engine.query(query)
|
||||||
|
print(f"\nQuery: {query}")
|
||||||
|
print(f"Response: {response}")
|
||||||
|
|
||||||
|
def main():
|
||||||
|
save_dir = "mail_index"
|
||||||
|
|
||||||
|
# Check if index exists
|
||||||
|
if not os.path.exists(save_dir) or not os.path.exists(os.path.join(save_dir, "vector_store.json")):
|
||||||
|
print(f"Index not found in {save_dir}")
|
||||||
|
print("Please run mail_reader_save_load.py first to create the index.")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Load the index
|
||||||
|
index = load_index(save_dir)
|
||||||
|
|
||||||
|
if not index:
|
||||||
|
print("Failed to load index.")
|
||||||
|
return
|
||||||
|
|
||||||
|
print("\n" + "="*60)
|
||||||
|
print("Email Query Interface")
|
||||||
|
print("="*60)
|
||||||
|
print("Type 'quit' to exit")
|
||||||
|
print("Type 'help' for example queries")
|
||||||
|
print("="*60)
|
||||||
|
|
||||||
|
# Interactive query loop
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
query = input("\nEnter your query: ").strip()
|
||||||
|
|
||||||
|
if query.lower() == 'quit':
|
||||||
|
print("Goodbye!")
|
||||||
|
break
|
||||||
|
elif query.lower() == 'help':
|
||||||
|
print("\nExample queries:")
|
||||||
|
print("- Hows Berkeley Graduate Student Instructor")
|
||||||
|
print("- What emails mention GSR appointments?")
|
||||||
|
print("- Find emails about deadlines")
|
||||||
|
print("- Search for emails from specific sender")
|
||||||
|
print("- Find emails about meetings")
|
||||||
|
continue
|
||||||
|
elif not query:
|
||||||
|
continue
|
||||||
|
|
||||||
|
query_index(index, query)
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\nGoodbye!")
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing query: {e}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||