Merge remote-tracking branch 'origin/main' into datastore-reproduce
This commit is contained in:
130
examples/LEANN_email_reader.py
Normal file
130
examples/LEANN_email_reader.py
Normal file
@@ -0,0 +1,130 @@
|
||||
import os
|
||||
import email
|
||||
from pathlib import Path
|
||||
from typing import List, Any
|
||||
from llama_index.core import Document
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
|
||||
class EmlxReader(BaseReader):
|
||||
"""
|
||||
Apple Mail .emlx file reader with embedded metadata.
|
||||
|
||||
Reads individual .emlx files from Apple Mail's storage format.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize."""
|
||||
pass
|
||||
|
||||
def load_data(self, input_dir: str, **load_kwargs: Any) -> List[Document]:
|
||||
"""
|
||||
Load data from the input directory containing .emlx files.
|
||||
|
||||
Args:
|
||||
input_dir: Directory containing .emlx files
|
||||
**load_kwargs:
|
||||
max_count (int): Maximum amount of messages to read.
|
||||
"""
|
||||
docs: List[Document] = []
|
||||
max_count = load_kwargs.get('max_count', 1000)
|
||||
count = 0
|
||||
|
||||
# Walk through the directory recursively
|
||||
for dirpath, dirnames, filenames in os.walk(input_dir):
|
||||
# Skip hidden directories
|
||||
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
|
||||
|
||||
for filename in filenames:
|
||||
if count >= max_count:
|
||||
break
|
||||
|
||||
if filename.endswith(".emlx"):
|
||||
filepath = os.path.join(dirpath, filename)
|
||||
try:
|
||||
# Read the .emlx file
|
||||
with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
|
||||
content = f.read()
|
||||
|
||||
# .emlx files have a length prefix followed by the email content
|
||||
# The first line contains the length, followed by the email
|
||||
lines = content.split('\n', 1)
|
||||
if len(lines) >= 2:
|
||||
email_content = lines[1]
|
||||
|
||||
# Parse the email using Python's email module
|
||||
try:
|
||||
msg = email.message_from_string(email_content)
|
||||
|
||||
# Extract email metadata
|
||||
subject = msg.get('Subject', 'No Subject')
|
||||
from_addr = msg.get('From', 'Unknown')
|
||||
to_addr = msg.get('To', 'Unknown')
|
||||
date = msg.get('Date', 'Unknown')
|
||||
|
||||
# Extract email body
|
||||
body = ""
|
||||
if msg.is_multipart():
|
||||
for part in msg.walk():
|
||||
if part.get_content_type() == "text/plain" or part.get_content_type() == "text/html":
|
||||
# if part.get_content_type() == "text/html":
|
||||
# continue
|
||||
body += part.get_payload(decode=True).decode('utf-8', errors='ignore')
|
||||
# break
|
||||
else:
|
||||
body = msg.get_payload(decode=True).decode('utf-8', errors='ignore')
|
||||
|
||||
# Create document content with metadata embedded in text
|
||||
doc_content = f"""
|
||||
[EMAIL METADATA]
|
||||
File: {filename}
|
||||
From: {from_addr}
|
||||
To: {to_addr}
|
||||
Subject: {subject}
|
||||
Date: {date}
|
||||
[END METADATA]
|
||||
|
||||
{body}
|
||||
"""
|
||||
|
||||
# No separate metadata - everything is in the text
|
||||
doc = Document(text=doc_content, metadata={})
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error parsing email from {filepath}: {e}")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error reading file {filepath}: {e}")
|
||||
continue
|
||||
|
||||
print(f"Loaded {len(docs)} email documents")
|
||||
return docs
|
||||
|
||||
@staticmethod
|
||||
def find_all_messages_directories(base_path: str) -> List[Path]:
|
||||
"""
|
||||
Find all Messages directories under the given base path.
|
||||
|
||||
Args:
|
||||
base_path: Base path to search for Messages directories
|
||||
|
||||
Returns:
|
||||
List of Path objects pointing to Messages directories
|
||||
"""
|
||||
base_path_obj = Path(base_path)
|
||||
messages_dirs = []
|
||||
|
||||
if not base_path_obj.exists():
|
||||
print(f"Base path {base_path} does not exist")
|
||||
return messages_dirs
|
||||
|
||||
# Find all Messages directories recursively
|
||||
for messages_dir in base_path_obj.rglob("Messages"):
|
||||
if messages_dir.is_dir():
|
||||
messages_dirs.append(messages_dir)
|
||||
print(f"Found Messages directory: {messages_dir}")
|
||||
|
||||
print(f"Found {len(messages_dirs)} Messages directories")
|
||||
return messages_dirs
|
||||
192
examples/email_data/email.py
Normal file
192
examples/email_data/email.py
Normal file
@@ -0,0 +1,192 @@
|
||||
"""
|
||||
Mbox parser.
|
||||
|
||||
Contains simple parser for mbox files.
|
||||
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
from fsspec import AbstractFileSystem
|
||||
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
from llama_index.core.schema import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MboxReader(BaseReader):
|
||||
"""
|
||||
Mbox parser.
|
||||
|
||||
Extract messages from mailbox files.
|
||||
Returns string including date, subject, sender, receiver and
|
||||
content for each message.
|
||||
|
||||
"""
|
||||
|
||||
DEFAULT_MESSAGE_FORMAT: str = (
|
||||
"Date: {_date}\n"
|
||||
"From: {_from}\n"
|
||||
"To: {_to}\n"
|
||||
"Subject: {_subject}\n"
|
||||
"Content: {_content}"
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args: Any,
|
||||
max_count: int = 0,
|
||||
message_format: str = DEFAULT_MESSAGE_FORMAT,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Init params."""
|
||||
try:
|
||||
from bs4 import BeautifulSoup # noqa
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"`beautifulsoup4` package not found: `pip install beautifulsoup4`"
|
||||
)
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
self.max_count = max_count
|
||||
self.message_format = message_format
|
||||
|
||||
def load_data(
|
||||
self,
|
||||
file: Path,
|
||||
extra_info: Optional[Dict] = None,
|
||||
fs: Optional[AbstractFileSystem] = None,
|
||||
) -> List[Document]:
|
||||
"""Parse file into string."""
|
||||
# Import required libraries
|
||||
import mailbox
|
||||
from email.parser import BytesParser
|
||||
from email.policy import default
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
if fs:
|
||||
logger.warning(
|
||||
"fs was specified but MboxReader doesn't support loading "
|
||||
"from fsspec filesystems. Will load from local filesystem instead."
|
||||
)
|
||||
|
||||
i = 0
|
||||
results: List[str] = []
|
||||
# Load file using mailbox
|
||||
bytes_parser = BytesParser(policy=default).parse
|
||||
mbox = mailbox.mbox(file, factory=bytes_parser) # type: ignore
|
||||
|
||||
# Iterate through all messages
|
||||
for _, _msg in enumerate(mbox):
|
||||
try:
|
||||
msg: mailbox.mboxMessage = _msg
|
||||
# Parse multipart messages
|
||||
if msg.is_multipart():
|
||||
for part in msg.walk():
|
||||
ctype = part.get_content_type()
|
||||
cdispo = str(part.get("Content-Disposition"))
|
||||
if "attachment" in cdispo:
|
||||
print(f"Attachment found: {part.get_filename()}")
|
||||
if ctype == "text/plain" and "attachment" not in cdispo:
|
||||
content = part.get_payload(decode=True) # decode
|
||||
break
|
||||
# Get plain message payload for non-multipart messages
|
||||
else:
|
||||
content = msg.get_payload(decode=True)
|
||||
|
||||
# Parse message HTML content and remove unneeded whitespace
|
||||
soup = BeautifulSoup(content)
|
||||
stripped_content = " ".join(soup.get_text().split())
|
||||
# Format message to include date, sender, receiver and subject
|
||||
msg_string = self.message_format.format(
|
||||
_date=msg["date"],
|
||||
_from=msg["from"],
|
||||
_to=msg["to"],
|
||||
_subject=msg["subject"],
|
||||
_content=stripped_content,
|
||||
)
|
||||
# Add message string to results
|
||||
results.append(msg_string)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse message:\n{_msg}\n with exception {e}")
|
||||
|
||||
# Increment counter and return if max count is met
|
||||
i += 1
|
||||
if self.max_count > 0 and i >= self.max_count:
|
||||
break
|
||||
|
||||
return [Document(text=result, metadata=extra_info or {}) for result in results]
|
||||
|
||||
|
||||
class EmlxMboxReader(MboxReader):
|
||||
"""
|
||||
EmlxMboxReader - Modified MboxReader that handles directories of .emlx files.
|
||||
|
||||
Extends MboxReader to work with Apple Mail's .emlx format by:
|
||||
1. Reading .emlx files from a directory
|
||||
2. Converting them to mbox format in memory
|
||||
3. Using the parent MboxReader's parsing logic
|
||||
"""
|
||||
|
||||
def load_data(
|
||||
self,
|
||||
directory: Path,
|
||||
extra_info: Optional[Dict] = None,
|
||||
fs: Optional[AbstractFileSystem] = None,
|
||||
) -> List[Document]:
|
||||
"""Parse .emlx files from directory into strings using MboxReader logic."""
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
if fs:
|
||||
logger.warning(
|
||||
"fs was specified but EmlxMboxReader doesn't support loading "
|
||||
"from fsspec filesystems. Will load from local filesystem instead."
|
||||
)
|
||||
|
||||
# Find all .emlx files in the directory
|
||||
emlx_files = list(directory.glob("*.emlx"))
|
||||
logger.info(f"Found {len(emlx_files)} .emlx files in {directory}")
|
||||
|
||||
if not emlx_files:
|
||||
logger.warning(f"No .emlx files found in {directory}")
|
||||
return []
|
||||
|
||||
# Create a temporary mbox file
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.mbox', delete=False) as temp_mbox:
|
||||
temp_mbox_path = temp_mbox.name
|
||||
|
||||
# Convert .emlx files to mbox format
|
||||
for emlx_file in emlx_files:
|
||||
try:
|
||||
# Read the .emlx file
|
||||
with open(emlx_file, 'r', encoding='utf-8', errors='ignore') as f:
|
||||
content = f.read()
|
||||
|
||||
# .emlx format: first line is length, rest is email content
|
||||
lines = content.split('\n', 1)
|
||||
if len(lines) >= 2:
|
||||
email_content = lines[1] # Skip the length line
|
||||
|
||||
# Write to mbox format (each message starts with "From " and ends with blank line)
|
||||
temp_mbox.write(f"From {emlx_file.name} {email_content}\n\n")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to process {emlx_file}: {e}")
|
||||
continue
|
||||
|
||||
# Close the temporary file so MboxReader can read it
|
||||
temp_mbox.close()
|
||||
|
||||
try:
|
||||
# Use the parent MboxReader's logic to parse the mbox file
|
||||
return super().load_data(Path(temp_mbox_path), extra_info, fs)
|
||||
finally:
|
||||
# Clean up temporary file
|
||||
try:
|
||||
os.unlink(temp_mbox_path)
|
||||
except:
|
||||
pass
|
||||
229
examples/mail_reader_leann.py
Normal file
229
examples/mail_reader_leann.py
Normal file
@@ -0,0 +1,229 @@
|
||||
import os
|
||||
import asyncio
|
||||
import dotenv
|
||||
from pathlib import Path
|
||||
from typing import List, Any
|
||||
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_path: str = "mail_index.leann", max_count: int = -1):
|
||||
"""
|
||||
Create LEANN index from multiple mail data sources.
|
||||
|
||||
Args:
|
||||
messages_dirs: List of Path objects pointing to Messages directories
|
||||
index_path: Path to save the LEANN index
|
||||
max_count: Maximum number of emails to process per directory
|
||||
"""
|
||||
print("Creating LEANN index from multiple mail data sources...")
|
||||
|
||||
# Load documents using EmlxReader from LEANN_email_reader
|
||||
from LEANN_email_reader import EmlxReader
|
||||
reader = EmlxReader()
|
||||
# from email_data.email import EmlxMboxReader
|
||||
# from pathlib import Path
|
||||
# reader = EmlxMboxReader()
|
||||
|
||||
all_documents = []
|
||||
total_processed = 0
|
||||
|
||||
# Process each Messages directory
|
||||
for i, messages_dir in enumerate(messages_dirs):
|
||||
print(f"\nProcessing Messages directory {i+1}/{len(messages_dirs)}: {messages_dir}")
|
||||
|
||||
try:
|
||||
documents = reader.load_data(messages_dir)
|
||||
if documents:
|
||||
print(f"Loaded {len(documents)} email documents from {messages_dir}")
|
||||
all_documents.extend(documents)
|
||||
total_processed += len(documents)
|
||||
|
||||
# Check if we've reached the max count
|
||||
if max_count > 0 and total_processed >= max_count:
|
||||
print(f"Reached max count of {max_count} documents")
|
||||
break
|
||||
else:
|
||||
print(f"No documents loaded from {messages_dir}")
|
||||
except Exception as e:
|
||||
print(f"Error processing {messages_dir}: {e}")
|
||||
continue
|
||||
|
||||
if not all_documents:
|
||||
print("No documents loaded from any source. Exiting.")
|
||||
return None
|
||||
|
||||
print(f"\nTotal loaded {len(all_documents)} email documents from {len(messages_dirs)} directories")
|
||||
|
||||
# Create text splitter with 256 chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
||||
|
||||
# Convert Documents to text strings and chunk them
|
||||
all_texts = []
|
||||
for doc in all_documents:
|
||||
# Split the document into chunks
|
||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
|
||||
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
|
||||
|
||||
# Create LEANN index directory
|
||||
INDEX_DIR = Path(index_path).parent
|
||||
|
||||
if not INDEX_DIR.exists():
|
||||
print(f"--- Index directory not found, building new index ---")
|
||||
INDEX_DIR.mkdir(exist_ok=True)
|
||||
|
||||
print(f"--- Building new LEANN index ---")
|
||||
|
||||
print(f"\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Use HNSW backend for better macOS compatibility
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="facebook/contriever",
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_compact=True,
|
||||
is_recompute=True,
|
||||
num_threads=1 # Force single-threaded mode
|
||||
)
|
||||
|
||||
print(f"Adding {len(all_texts)} email chunks to index...")
|
||||
for chunk_text in all_texts:
|
||||
builder.add_text(chunk_text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
print(f"\nLEANN index built at {index_path}!")
|
||||
else:
|
||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||
|
||||
return index_path
|
||||
|
||||
def create_leann_index(mail_path: str, index_path: str = "mail_index.leann", max_count: int = 1000):
|
||||
"""
|
||||
Create LEANN index from mail data.
|
||||
|
||||
Args:
|
||||
mail_path: Path to the mail directory
|
||||
index_path: Path to save the LEANN index
|
||||
max_count: Maximum number of emails to process
|
||||
"""
|
||||
print("Creating LEANN index from mail data...")
|
||||
|
||||
# Load documents using EmlxReader from LEANN_email_reader
|
||||
from LEANN_email_reader import EmlxReader
|
||||
reader = EmlxReader()
|
||||
# from email_data.email import EmlxMboxReader
|
||||
# from pathlib import Path
|
||||
# reader = EmlxMboxReader()
|
||||
documents = reader.load_data(Path(mail_path))
|
||||
|
||||
if not documents:
|
||||
print("No documents loaded. Exiting.")
|
||||
return None
|
||||
|
||||
print(f"Loaded {len(documents)} email documents")
|
||||
|
||||
# Create text splitter with 256 chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
||||
|
||||
# Convert Documents to text strings and chunk them
|
||||
all_texts = []
|
||||
for doc in documents:
|
||||
# Split the document into chunks
|
||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
|
||||
print(f"Created {len(all_texts)} text chunks from {len(documents)} documents")
|
||||
|
||||
# Create LEANN index directory
|
||||
INDEX_DIR = Path(index_path).parent
|
||||
|
||||
if not INDEX_DIR.exists():
|
||||
print(f"--- Index directory not found, building new index ---")
|
||||
INDEX_DIR.mkdir(exist_ok=True)
|
||||
|
||||
print(f"--- Building new LEANN index ---")
|
||||
|
||||
print(f"\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Use HNSW backend for better macOS compatibility
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="facebook/contriever",
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_compact=True,
|
||||
is_recompute=True,
|
||||
num_threads=1 # Force single-threaded mode
|
||||
)
|
||||
|
||||
print(f"Adding {len(all_texts)} email chunks to index...")
|
||||
for chunk_text in all_texts:
|
||||
builder.add_text(chunk_text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
print(f"\nLEANN index built at {index_path}!")
|
||||
else:
|
||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||
|
||||
return index_path
|
||||
|
||||
async def query_leann_index(index_path: str, query: str):
|
||||
"""
|
||||
Query the LEANN index.
|
||||
|
||||
Args:
|
||||
index_path: Path to the LEANN index
|
||||
query: The query string
|
||||
"""
|
||||
print(f"\n[PHASE 2] Starting Leann chat session...")
|
||||
chat = LeannChat(index_path=index_path)
|
||||
|
||||
print(f"You: {query}")
|
||||
chat_response = chat.ask(
|
||||
query,
|
||||
top_k=5,
|
||||
recompute_beighbor_embeddings=True,
|
||||
complexity=32,
|
||||
beam_width=1
|
||||
)
|
||||
print(f"Leann: {chat_response}")
|
||||
|
||||
async def main():
|
||||
# Base path to the mail data directory
|
||||
base_mail_path = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data"
|
||||
|
||||
INDEX_DIR = Path("./mail_index_leann_raw_text_all")
|
||||
INDEX_PATH = str(INDEX_DIR / "mail_documents.leann")
|
||||
|
||||
# Find all Messages directories
|
||||
from LEANN_email_reader import EmlxReader
|
||||
messages_dirs = EmlxReader.find_all_messages_directories(base_mail_path)
|
||||
|
||||
if not messages_dirs:
|
||||
print("No Messages directories found. Exiting.")
|
||||
return
|
||||
|
||||
# Create or load the LEANN index from all sources
|
||||
index_path = create_leann_index_from_multiple_sources(messages_dirs, INDEX_PATH)
|
||||
|
||||
if index_path:
|
||||
# Example queries
|
||||
queries = [
|
||||
"Hows Berkeley Graduate Student Instructor",
|
||||
"how's the icloud related advertisement saying",
|
||||
"Whats the number of class recommend to take per semester for incoming EECS students"
|
||||
|
||||
]
|
||||
|
||||
for query in queries:
|
||||
print("\n" + "="*60)
|
||||
await query_leann_index(index_path, query)
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
86
examples/mail_reader_llamaindex.py
Normal file
86
examples/mail_reader_llamaindex.py
Normal file
@@ -0,0 +1,86 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, Any
|
||||
from llama_index.core import VectorStoreIndex, StorageContext
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
# --- EMBEDDING MODEL ---
|
||||
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
||||
import torch
|
||||
|
||||
# --- END EMBEDDING MODEL ---
|
||||
|
||||
# Import EmlxReader from the new module
|
||||
from LEANN_email_reader import EmlxReader
|
||||
|
||||
def create_and_save_index(mail_path: str, save_dir: str = "mail_index_embedded", max_count: int = 1000):
|
||||
print("Creating index from mail data with embedded metadata...")
|
||||
documents = EmlxReader().load_data(mail_path, max_count=max_count)
|
||||
if not documents:
|
||||
print("No documents loaded. Exiting.")
|
||||
return None
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
||||
# Use facebook/contriever as the embedder
|
||||
embed_model = HuggingFaceEmbedding(model_name="facebook/contriever")
|
||||
# set on device
|
||||
import torch
|
||||
if torch.cuda.is_available():
|
||||
embed_model._model.to("cuda")
|
||||
# set mps
|
||||
elif torch.backends.mps.is_available():
|
||||
embed_model._model.to("mps")
|
||||
else:
|
||||
embed_model._model.to("cpu")
|
||||
index = VectorStoreIndex.from_documents(
|
||||
documents,
|
||||
transformations=[text_splitter],
|
||||
embed_model=embed_model
|
||||
)
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
index.storage_context.persist(persist_dir=save_dir)
|
||||
print(f"Index saved to {save_dir}")
|
||||
return index
|
||||
|
||||
def load_index(save_dir: str = "mail_index_embedded"):
|
||||
try:
|
||||
storage_context = StorageContext.from_defaults(persist_dir=save_dir)
|
||||
index = VectorStoreIndex.from_vector_store(
|
||||
storage_context.vector_store,
|
||||
storage_context=storage_context
|
||||
)
|
||||
print(f"Index loaded from {save_dir}")
|
||||
return index
|
||||
except Exception as e:
|
||||
print(f"Error loading index: {e}")
|
||||
return None
|
||||
|
||||
def query_index(index, query: str):
|
||||
if index is None:
|
||||
print("No index available for querying.")
|
||||
return
|
||||
query_engine = index.as_query_engine()
|
||||
response = query_engine.query(query)
|
||||
print(f"Query: {query}")
|
||||
print(f"Response: {response}")
|
||||
|
||||
def main():
|
||||
mail_path = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages"
|
||||
save_dir = "mail_index_embedded"
|
||||
if os.path.exists(save_dir) and os.path.exists(os.path.join(save_dir, "vector_store.json")):
|
||||
print("Loading existing index...")
|
||||
index = load_index(save_dir)
|
||||
else:
|
||||
print("Creating new index...")
|
||||
index = create_and_save_index(mail_path, save_dir, max_count=10000)
|
||||
if index:
|
||||
queries = [
|
||||
"Hows Berkeley Graduate Student Instructor",
|
||||
"how's the icloud related advertisement saying"
|
||||
"Whats the number of class recommend to take per semester for incoming EECS students"
|
||||
]
|
||||
for query in queries:
|
||||
print("\n" + "="*50)
|
||||
query_index(index, query)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -8,7 +8,6 @@ from llama_index.node_parser.docling import DoclingNodeParser
|
||||
from llama_index.readers.docling import DoclingReader
|
||||
from docling_core.transforms.chunker.hybrid_chunker import HybridChunker
|
||||
import asyncio
|
||||
import os
|
||||
import dotenv
|
||||
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
||||
import shutil
|
||||
@@ -22,9 +21,11 @@ file_extractor: dict[str, BaseReader] = {
|
||||
".pptx": reader,
|
||||
".pdf": reader,
|
||||
".xlsx": reader,
|
||||
".txt": reader,
|
||||
".md": reader,
|
||||
}
|
||||
node_parser = DoclingNodeParser(
|
||||
chunker=HybridChunker(tokenizer="Qwen/Qwen3-Embedding-4B", max_tokens=64)
|
||||
chunker=HybridChunker(tokenizer="facebook/contriever", max_tokens=128)
|
||||
)
|
||||
print("Loading documents...")
|
||||
documents = SimpleDirectoryReader(
|
||||
@@ -32,7 +33,7 @@ documents = SimpleDirectoryReader(
|
||||
recursive=True,
|
||||
file_extractor=file_extractor,
|
||||
encoding="utf-8",
|
||||
required_exts=[".pdf", ".docx", ".pptx", ".xlsx"]
|
||||
required_exts=[".pdf", ".docx", ".pptx", ".xlsx", ".txt", ".md"]
|
||||
).load_data(show_progress=True)
|
||||
print("Documents loaded.")
|
||||
all_texts = []
|
||||
@@ -41,7 +42,7 @@ for doc in documents:
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
|
||||
INDEX_DIR = Path("./test_pdf_index")
|
||||
INDEX_DIR = Path("./test_pdf_index_pangu_test")
|
||||
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
|
||||
|
||||
if not INDEX_DIR.exists():
|
||||
@@ -49,14 +50,15 @@ if not INDEX_DIR.exists():
|
||||
|
||||
print(f"\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# CSR compact mode with recompute
|
||||
# Use HNSW backend for better macOS compatibility
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="facebook/contriever",
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_compact=True,
|
||||
is_recompute=True
|
||||
is_recompute=True,
|
||||
num_threads=1 # Force single-threaded mode
|
||||
)
|
||||
|
||||
print(f"Loaded {len(all_texts)} text chunks from documents.")
|
||||
@@ -80,14 +82,17 @@ async def main(args):
|
||||
chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)
|
||||
|
||||
query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
|
||||
query = "What is the main idea of RL and give me 5 exapmle of classic RL algorithms?"
|
||||
query = "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"
|
||||
|
||||
print(f"You: {query}")
|
||||
chat_response = chat.ask(query, top_k=3, recompute_beighbor_embeddings=True)
|
||||
chat_response = chat.ask(query, top_k=20, recompute_beighbor_embeddings=True, complexity=32, beam_width=1)
|
||||
print(f"Leann: {chat_response}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run Leann Chat with various LLM backends.")
|
||||
parser.add_argument("--llm", type=str, default="hf", choices=["simulated", "ollama", "hf"], help="The LLM backend to use.")
|
||||
parser.add_argument("--model", type=str, default='meta-llama/Llama-3.2-3B-Instruct', help="The model name to use (e.g., 'llama3:8b' for ollama, 'deepseek-ai/deepseek-llm-7b-chat' for hf).")
|
||||
parser.add_argument("--llm", type=str, default="hf", choices=["simulated", "ollama", "hf", "openai"], help="The LLM backend to use.")
|
||||
parser.add_argument("--model", type=str, default='meta-llama/Llama-3.2-3B-Instruct', help="The model name to use (e.g., 'llama3:8b' for ollama, 'deepseek-ai/deepseek-llm-7b-chat' for hf, 'gpt-4o' for openai).")
|
||||
parser.add_argument("--host", type=str, default="http://localhost:11434", help="The host for the Ollama API.")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
319
examples/multi_vector_aggregator.py
Normal file
319
examples/multi_vector_aggregator.py
Normal file
@@ -0,0 +1,319 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Multi-Vector Aggregator for Fat Embeddings
|
||||
==========================================
|
||||
|
||||
This module implements aggregation strategies for multi-vector embeddings,
|
||||
similar to ColPali's approach where multiple patch vectors represent a single document.
|
||||
|
||||
Key features:
|
||||
- MaxSim aggregation (take maximum similarity across patches)
|
||||
- Voting-based aggregation (count patch matches)
|
||||
- Weighted aggregation (attention-score weighted)
|
||||
- Spatial clustering of matching patches
|
||||
- Document-level result consolidation
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from typing import List, Dict, Any, Tuple, Optional
|
||||
from dataclasses import dataclass
|
||||
from collections import defaultdict
|
||||
import json
|
||||
|
||||
@dataclass
|
||||
class PatchResult:
|
||||
"""Represents a single patch search result."""
|
||||
patch_id: int
|
||||
image_name: str
|
||||
image_path: str
|
||||
coordinates: Tuple[int, int, int, int] # (x1, y1, x2, y2)
|
||||
score: float
|
||||
attention_score: float
|
||||
scale: float
|
||||
metadata: Dict[str, Any]
|
||||
|
||||
@dataclass
|
||||
class AggregatedResult:
|
||||
"""Represents an aggregated document-level result."""
|
||||
image_name: str
|
||||
image_path: str
|
||||
doc_score: float
|
||||
patch_count: int
|
||||
best_patch: PatchResult
|
||||
all_patches: List[PatchResult]
|
||||
aggregation_method: str
|
||||
spatial_clusters: Optional[List[List[PatchResult]]] = None
|
||||
|
||||
class MultiVectorAggregator:
|
||||
"""
|
||||
Aggregates multiple patch-level results into document-level results.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
aggregation_method: str = "maxsim",
|
||||
spatial_clustering: bool = True,
|
||||
cluster_distance_threshold: float = 100.0):
|
||||
"""
|
||||
Initialize the aggregator.
|
||||
|
||||
Args:
|
||||
aggregation_method: "maxsim", "voting", "weighted", or "mean"
|
||||
spatial_clustering: Whether to cluster spatially close patches
|
||||
cluster_distance_threshold: Distance threshold for spatial clustering
|
||||
"""
|
||||
self.aggregation_method = aggregation_method
|
||||
self.spatial_clustering = spatial_clustering
|
||||
self.cluster_distance_threshold = cluster_distance_threshold
|
||||
|
||||
def aggregate_results(self,
|
||||
search_results: List[Dict[str, Any]],
|
||||
top_k: int = 10) -> List[AggregatedResult]:
|
||||
"""
|
||||
Aggregate patch-level search results into document-level results.
|
||||
|
||||
Args:
|
||||
search_results: List of search results from LeannSearcher
|
||||
top_k: Number of top documents to return
|
||||
|
||||
Returns:
|
||||
List of aggregated document results
|
||||
"""
|
||||
# Group results by image
|
||||
image_groups = defaultdict(list)
|
||||
|
||||
for result in search_results:
|
||||
metadata = result.metadata
|
||||
if "image_name" in metadata and "patch_id" in metadata:
|
||||
patch_result = PatchResult(
|
||||
patch_id=metadata["patch_id"],
|
||||
image_name=metadata["image_name"],
|
||||
image_path=metadata["image_path"],
|
||||
coordinates=tuple(metadata["coordinates"]),
|
||||
score=result.score,
|
||||
attention_score=metadata.get("attention_score", 0.0),
|
||||
scale=metadata.get("scale", 1.0),
|
||||
metadata=metadata
|
||||
)
|
||||
image_groups[metadata["image_name"]].append(patch_result)
|
||||
|
||||
# Aggregate each image group
|
||||
aggregated_results = []
|
||||
for image_name, patches in image_groups.items():
|
||||
if len(patches) == 0:
|
||||
continue
|
||||
|
||||
agg_result = self._aggregate_image_patches(image_name, patches)
|
||||
aggregated_results.append(agg_result)
|
||||
|
||||
# Sort by aggregated score and return top-k
|
||||
aggregated_results.sort(key=lambda x: x.doc_score, reverse=True)
|
||||
return aggregated_results[:top_k]
|
||||
|
||||
def _aggregate_image_patches(self, image_name: str, patches: List[PatchResult]) -> AggregatedResult:
|
||||
"""Aggregate patches for a single image."""
|
||||
|
||||
if self.aggregation_method == "maxsim":
|
||||
doc_score = max(patch.score for patch in patches)
|
||||
best_patch = max(patches, key=lambda p: p.score)
|
||||
|
||||
elif self.aggregation_method == "voting":
|
||||
# Count patches above threshold
|
||||
threshold = np.percentile([p.score for p in patches], 75)
|
||||
doc_score = sum(1 for patch in patches if patch.score >= threshold)
|
||||
best_patch = max(patches, key=lambda p: p.score)
|
||||
|
||||
elif self.aggregation_method == "weighted":
|
||||
# Weight by attention scores
|
||||
total_weighted_score = sum(p.score * p.attention_score for p in patches)
|
||||
total_weights = sum(p.attention_score for p in patches)
|
||||
doc_score = total_weighted_score / max(total_weights, 1e-8)
|
||||
best_patch = max(patches, key=lambda p: p.score * p.attention_score)
|
||||
|
||||
elif self.aggregation_method == "mean":
|
||||
doc_score = np.mean([patch.score for patch in patches])
|
||||
best_patch = max(patches, key=lambda p: p.score)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown aggregation method: {self.aggregation_method}")
|
||||
|
||||
# Spatial clustering if enabled
|
||||
spatial_clusters = None
|
||||
if self.spatial_clustering:
|
||||
spatial_clusters = self._cluster_patches_spatially(patches)
|
||||
|
||||
return AggregatedResult(
|
||||
image_name=image_name,
|
||||
image_path=patches[0].image_path,
|
||||
doc_score=float(doc_score),
|
||||
patch_count=len(patches),
|
||||
best_patch=best_patch,
|
||||
all_patches=sorted(patches, key=lambda p: p.score, reverse=True),
|
||||
aggregation_method=self.aggregation_method,
|
||||
spatial_clusters=spatial_clusters
|
||||
)
|
||||
|
||||
def _cluster_patches_spatially(self, patches: List[PatchResult]) -> List[List[PatchResult]]:
|
||||
"""Cluster patches that are spatially close to each other."""
|
||||
if len(patches) <= 1:
|
||||
return [patches]
|
||||
|
||||
clusters = []
|
||||
remaining_patches = patches.copy()
|
||||
|
||||
while remaining_patches:
|
||||
# Start new cluster with highest scoring remaining patch
|
||||
seed_patch = max(remaining_patches, key=lambda p: p.score)
|
||||
current_cluster = [seed_patch]
|
||||
remaining_patches.remove(seed_patch)
|
||||
|
||||
# Add nearby patches to cluster
|
||||
added_to_cluster = True
|
||||
while added_to_cluster:
|
||||
added_to_cluster = False
|
||||
for patch in remaining_patches.copy():
|
||||
if self._is_patch_nearby(patch, current_cluster):
|
||||
current_cluster.append(patch)
|
||||
remaining_patches.remove(patch)
|
||||
added_to_cluster = True
|
||||
|
||||
clusters.append(current_cluster)
|
||||
|
||||
return sorted(clusters, key=lambda cluster: max(p.score for p in cluster), reverse=True)
|
||||
|
||||
def _is_patch_nearby(self, patch: PatchResult, cluster: List[PatchResult]) -> bool:
|
||||
"""Check if a patch is spatially close to any patch in the cluster."""
|
||||
patch_center = self._get_patch_center(patch.coordinates)
|
||||
|
||||
for cluster_patch in cluster:
|
||||
cluster_center = self._get_patch_center(cluster_patch.coordinates)
|
||||
distance = np.sqrt((patch_center[0] - cluster_center[0])**2 +
|
||||
(patch_center[1] - cluster_center[1])**2)
|
||||
|
||||
if distance <= self.cluster_distance_threshold:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _get_patch_center(self, coordinates: Tuple[int, int, int, int]) -> Tuple[float, float]:
|
||||
"""Get center point of a patch."""
|
||||
x1, y1, x2, y2 = coordinates
|
||||
return ((x1 + x2) / 2, (y1 + y2) / 2)
|
||||
|
||||
def print_aggregated_results(self, results: List[AggregatedResult], max_patches_per_doc: int = 3):
|
||||
"""Pretty print aggregated results."""
|
||||
print(f"\n🔍 Aggregated Results (method: {self.aggregation_method})")
|
||||
print("=" * 80)
|
||||
|
||||
for i, result in enumerate(results):
|
||||
print(f"\n{i+1}. {result.image_name}")
|
||||
print(f" Doc Score: {result.doc_score:.4f} | Patches: {result.patch_count}")
|
||||
print(f" Path: {result.image_path}")
|
||||
|
||||
# Show best patch
|
||||
best = result.best_patch
|
||||
print(f" 🌟 Best Patch: #{best.patch_id} at {best.coordinates} (score: {best.score:.4f})")
|
||||
|
||||
# Show top patches
|
||||
print(f" 📍 Top Patches:")
|
||||
for j, patch in enumerate(result.all_patches[:max_patches_per_doc]):
|
||||
print(f" {j+1}. Patch #{patch.patch_id}: {patch.score:.4f} at {patch.coordinates}")
|
||||
|
||||
# Show spatial clusters if available
|
||||
if result.spatial_clusters and len(result.spatial_clusters) > 1:
|
||||
print(f" 🗂️ Spatial Clusters: {len(result.spatial_clusters)}")
|
||||
for j, cluster in enumerate(result.spatial_clusters[:2]): # Show top 2 clusters
|
||||
cluster_score = max(p.score for p in cluster)
|
||||
print(f" Cluster {j+1}: {len(cluster)} patches (best: {cluster_score:.4f})")
|
||||
|
||||
def demo_aggregation():
|
||||
"""Demonstrate the multi-vector aggregation functionality."""
|
||||
print("=== Multi-Vector Aggregation Demo ===")
|
||||
|
||||
# Simulate some patch-level search results
|
||||
# In real usage, these would come from LeannSearcher.search()
|
||||
|
||||
class MockResult:
|
||||
def __init__(self, score, metadata):
|
||||
self.score = score
|
||||
self.metadata = metadata
|
||||
|
||||
# Simulate results for 2 images with multiple patches each
|
||||
mock_results = [
|
||||
# Image 1: cats_and_kitchen.jpg - 4 patches
|
||||
MockResult(0.85, {
|
||||
"image_name": "cats_and_kitchen.jpg",
|
||||
"image_path": "/path/to/cats_and_kitchen.jpg",
|
||||
"patch_id": 3,
|
||||
"coordinates": [100, 50, 224, 174], # Kitchen area
|
||||
"attention_score": 0.92,
|
||||
"scale": 1.0
|
||||
}),
|
||||
MockResult(0.78, {
|
||||
"image_name": "cats_and_kitchen.jpg",
|
||||
"image_path": "/path/to/cats_and_kitchen.jpg",
|
||||
"patch_id": 7,
|
||||
"coordinates": [200, 300, 324, 424], # Cat area
|
||||
"attention_score": 0.88,
|
||||
"scale": 1.0
|
||||
}),
|
||||
MockResult(0.72, {
|
||||
"image_name": "cats_and_kitchen.jpg",
|
||||
"image_path": "/path/to/cats_and_kitchen.jpg",
|
||||
"patch_id": 12,
|
||||
"coordinates": [150, 100, 274, 224], # Appliances
|
||||
"attention_score": 0.75,
|
||||
"scale": 1.0
|
||||
}),
|
||||
MockResult(0.65, {
|
||||
"image_name": "cats_and_kitchen.jpg",
|
||||
"image_path": "/path/to/cats_and_kitchen.jpg",
|
||||
"patch_id": 15,
|
||||
"coordinates": [50, 250, 174, 374], # Furniture
|
||||
"attention_score": 0.70,
|
||||
"scale": 1.0
|
||||
}),
|
||||
|
||||
# Image 2: city_street.jpg - 3 patches
|
||||
MockResult(0.68, {
|
||||
"image_name": "city_street.jpg",
|
||||
"image_path": "/path/to/city_street.jpg",
|
||||
"patch_id": 2,
|
||||
"coordinates": [300, 100, 424, 224], # Buildings
|
||||
"attention_score": 0.80,
|
||||
"scale": 1.0
|
||||
}),
|
||||
MockResult(0.62, {
|
||||
"image_name": "city_street.jpg",
|
||||
"image_path": "/path/to/city_street.jpg",
|
||||
"patch_id": 8,
|
||||
"coordinates": [100, 350, 224, 474], # Street level
|
||||
"attention_score": 0.75,
|
||||
"scale": 1.0
|
||||
}),
|
||||
MockResult(0.55, {
|
||||
"image_name": "city_street.jpg",
|
||||
"image_path": "/path/to/city_street.jpg",
|
||||
"patch_id": 11,
|
||||
"coordinates": [400, 200, 524, 324], # Sky area
|
||||
"attention_score": 0.60,
|
||||
"scale": 1.0
|
||||
}),
|
||||
]
|
||||
|
||||
# Test different aggregation methods
|
||||
methods = ["maxsim", "voting", "weighted", "mean"]
|
||||
|
||||
for method in methods:
|
||||
print(f"\n{'='*20} {method.upper()} AGGREGATION {'='*20}")
|
||||
|
||||
aggregator = MultiVectorAggregator(
|
||||
aggregation_method=method,
|
||||
spatial_clustering=True,
|
||||
cluster_distance_threshold=100.0
|
||||
)
|
||||
|
||||
aggregated = aggregator.aggregate_results(mock_results, top_k=5)
|
||||
aggregator.print_aggregated_results(aggregated)
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo_aggregation()
|
||||
18
examples/resue_index.py
Normal file
18
examples/resue_index.py
Normal file
@@ -0,0 +1,18 @@
|
||||
import asyncio
|
||||
from leann.api import LeannChat
|
||||
from pathlib import Path
|
||||
|
||||
INDEX_DIR = Path("./test_pdf_index_huawei")
|
||||
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
|
||||
|
||||
async def main():
|
||||
print(f"\n[PHASE 2] Starting Leann chat session...")
|
||||
chat = LeannChat(index_path=INDEX_PATH)
|
||||
query = "What is the main idea of RL and give me 5 exapmle of classic RL algorithms?"
|
||||
query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
|
||||
# query = "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"
|
||||
response = chat.ask(query,top_k=20,recompute_beighbor_embeddings=True,complexity=32,beam_width=1)
|
||||
print(f"\n[PHASE 2] Response: {response}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
Reference in New Issue
Block a user