- Fix ambiguous fullwidth characters (commas, parentheses) in strings and comments - Replace Chinese comments with English equivalents - Fix unused imports with proper noqa annotations for intentional imports - Fix bare except clauses with specific exception types - Fix redefined variables and undefined names - Add ruff noqa annotations for generated protobuf files - Add lint and format check to GitHub Actions CI pipeline
130 lines
4.1 KiB
Python
130 lines
4.1 KiB
Python
import argparse
|
|
import os
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
# Add the project root to Python path so we can import from examples
|
|
project_root = Path(__file__).parent.parent
|
|
sys.path.insert(0, str(project_root))
|
|
|
|
import torch
|
|
from llama_index.core import StorageContext, VectorStoreIndex
|
|
from llama_index.core.node_parser import SentenceSplitter
|
|
|
|
# --- EMBEDDING MODEL ---
|
|
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
|
|
|
# --- END EMBEDDING MODEL ---
|
|
# Import EmlxReader from the new module
|
|
from examples.email_data.LEANN_email_reader import EmlxReader
|
|
|
|
|
|
def create_and_save_index(
|
|
mail_path: str,
|
|
save_dir: str = "mail_index_embedded",
|
|
max_count: int = 1000,
|
|
include_html: bool = False,
|
|
):
|
|
print("Creating index from mail data with embedded metadata...")
|
|
documents = EmlxReader(include_html=include_html).load_data(mail_path, max_count=max_count)
|
|
if not documents:
|
|
print("No documents loaded. Exiting.")
|
|
return None
|
|
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
|
# Use facebook/contriever as the embedder
|
|
embed_model = HuggingFaceEmbedding(model_name="facebook/contriever")
|
|
# set on device
|
|
|
|
if torch.cuda.is_available():
|
|
embed_model._model.to("cuda")
|
|
# set mps
|
|
elif torch.backends.mps.is_available():
|
|
embed_model._model.to("mps")
|
|
else:
|
|
embed_model._model.to("cpu")
|
|
index = VectorStoreIndex.from_documents(
|
|
documents, transformations=[text_splitter], embed_model=embed_model
|
|
)
|
|
os.makedirs(save_dir, exist_ok=True)
|
|
index.storage_context.persist(persist_dir=save_dir)
|
|
print(f"Index saved to {save_dir}")
|
|
return index
|
|
|
|
|
|
def load_index(save_dir: str = "mail_index_embedded"):
|
|
try:
|
|
storage_context = StorageContext.from_defaults(persist_dir=save_dir)
|
|
index = VectorStoreIndex.from_vector_store(
|
|
storage_context.vector_store, storage_context=storage_context
|
|
)
|
|
print(f"Index loaded from {save_dir}")
|
|
return index
|
|
except Exception as e:
|
|
print(f"Error loading index: {e}")
|
|
return None
|
|
|
|
|
|
def query_index(index, query: str):
|
|
if index is None:
|
|
print("No index available for querying.")
|
|
return
|
|
query_engine = index.as_query_engine()
|
|
response = query_engine.query(query)
|
|
print(f"Query: {query}")
|
|
print(f"Response: {response}")
|
|
|
|
|
|
def main():
|
|
# Parse command line arguments
|
|
parser = argparse.ArgumentParser(
|
|
description="LlamaIndex Mail Reader - Create and query email index"
|
|
)
|
|
parser.add_argument(
|
|
"--mail-path",
|
|
type=str,
|
|
default="/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages",
|
|
help="Path to mail data directory",
|
|
)
|
|
parser.add_argument(
|
|
"--save-dir",
|
|
type=str,
|
|
default="mail_index_embedded",
|
|
help="Directory to store the index (default: mail_index_embedded)",
|
|
)
|
|
parser.add_argument(
|
|
"--max-emails", type=int, default=10000, help="Maximum number of emails to process"
|
|
)
|
|
parser.add_argument(
|
|
"--include-html",
|
|
action="store_true",
|
|
default=False,
|
|
help="Include HTML content in email processing (default: False)",
|
|
)
|
|
|
|
args = parser.parse_args()
|
|
|
|
mail_path = args.mail_path
|
|
save_dir = args.save_dir
|
|
|
|
if os.path.exists(save_dir) and os.path.exists(os.path.join(save_dir, "vector_store.json")):
|
|
print("Loading existing index...")
|
|
index = load_index(save_dir)
|
|
else:
|
|
print("Creating new index...")
|
|
index = create_and_save_index(
|
|
mail_path, save_dir, max_count=args.max_emails, include_html=args.include_html
|
|
)
|
|
if index:
|
|
queries = [
|
|
"Hows Berkeley Graduate Student Instructor",
|
|
"how's the icloud related advertisement saying",
|
|
"Whats the number of class recommend to take per semester for incoming EECS students",
|
|
]
|
|
for query in queries:
|
|
print("\n" + "=" * 50)
|
|
query_index(index, query)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|