diff --git a/.gitignore b/.gitignore
index 356504c..26740c5 100755
--- a/.gitignore
+++ b/.gitignore
@@ -8,11 +8,15 @@ demo/indices/
 *pycache*
 outputs/
 *.pkl
+*.pdf
 .history/
 scripts/
 lm_eval.egg-info/
 demo/experiment_results/**/*.json
 *.jsonl
+*.eml
+*.emlx
+*.json
 *.sh
 *.txt
 !CMakeLists.txt
@@ -42,6 +46,7 @@ embedding_comparison_results/
 *.ivecs
 *.index
 *.bin
+*.old
 read_graph
 analyze_diskann_graph
diff --git a/.gitmodules b/.gitmodules
index 8c49b3e..1899ae5 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -4,3 +4,13 @@
 [submodule "packages/leann-backend-hnsw/third_party/faiss"]
 	path = packages/leann-backend-hnsw/third_party/faiss
 	url = https://github.com/yichuan520030910320/faiss.git
+[submodule "packages/leann-backend-hnsw/third_party/msgpack-c"]
+	path = packages/leann-backend-hnsw/third_party/msgpack-c
+	url = https://github.com/msgpack/msgpack-c.git
+	branch = cpp_master
+[submodule "packages/leann-backend-hnsw/third_party/cppzmq"]
+	path = packages/leann-backend-hnsw/third_party/cppzmq
+	url = https://github.com/zeromq/cppzmq.git
+[submodule "packages/leann-backend-hnsw/third_party/libzmq"]
+	path = packages/leann-backend-hnsw/third_party/libzmq
+	url = https://github.com/zeromq/libzmq.git
diff --git a/README.md b/README.md
index 2771f44..e51f8c9 100755
--- a/README.md
+++ b/README.md
@@ -28,13 +28,15 @@ ### 🎯 Why Leann?
 
 Traditional RAG systems face a fundamental trade-off:
+
 - **πŸ’Ύ Storage**: Storing embeddings for millions of documents requires massive disk space
 - **πŸ”„ Freshness**: Pre-computed embeddings become stale when documents change
 - **πŸ’° Cost**: Vector databases are expensive to scale
 
 **Leann solves this by:**
+
 - βœ… **Zero embedding storage** - Only graph structure is persisted
-- βœ… **Real-time computation** - Embeddings computed on-demand with ms latency 
+- βœ… **Real-time computation** - Embeddings computed on-demand with ms latency
 - βœ… **Memory efficient** - Runs on consumer hardware (8GB RAM)
 - βœ… **Always fresh** - No stale embeddings, ever
 
@@ -46,6 +48,19 @@ Traditional RAG systems face a fundamental trade-off:
 git clone git@github.com:yichuan520030910320/LEANN-RAG.git leann
 cd leann
 git submodule update --init --recursive
+```
+
+**macOS:**
+```bash
+brew install llvm libomp boost protobuf
+export CC=$(brew --prefix llvm)/bin/clang
+export CXX=$(brew --prefix llvm)/bin/clang++
+uv sync
+```
+
+**Linux (Ubuntu/Debian):**
+```bash
+sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler libabsl-dev libmkl-full-dev libaio-dev
 uv sync
 ```
 
@@ -75,19 +90,39 @@ for result in results:
 uv run examples/document_search.py
 ```
 
+Or, if you prefer to activate the virtual environment and run the example with Python directly:
+
+```bash
+source .venv/bin/activate
+python ./examples/main_cli_example.py
+```
 
 **PDF RAG Demo (using LlamaIndex for document parsing and Leann for indexing/search)**
 
 This demo showcases how to build a RAG system for PDF documents using Leann.
-1. Place your PDF files (and other supported formats like .docx, .pptx, .xlsx) into the `examples/data/` directory.
-2. Ensure you have an `OPENAI_API_KEY` set in your environment variables or in a `.env` file for the LLM to function.
+
+1. Place your PDF files (and other supported formats like .docx, .pptx, .xlsx) into the `examples/data/` directory.
+2. Ensure you have an `OPENAI_API_KEY` set in your environment variables or in a `.env` file for the LLM to function. 
```bash uv run examples/main_cli_example.py ``` +### Regenerating Protobuf Files + +If you modify any `.proto` files (such as `embedding.proto`), or if you see errors about protobuf version mismatch, **regenerate the C++ protobuf files** to match your installed version: + +```bash +cd packages/leann-backend-diskann +protoc --cpp_out=third_party/DiskANN/include --proto_path=third_party embedding.proto +protoc --cpp_out=third_party/DiskANN/src --proto_path=third_party embedding.proto +``` + +This ensures the generated files are compatible with your system's protobuf library. + ## ✨ Features ### πŸ”₯ Core Features + - **πŸ“Š Multiple Distance Functions**: L2, Cosine, MIPS (Maximum Inner Product Search) - **πŸ—οΈ Pluggable Backends**: DiskANN, HNSW/FAISS with unified API - **πŸ”„ Real-time Embeddings**: Dynamic computation using optimized ZMQ servers @@ -95,6 +130,7 @@ uv run examples/main_cli_example.py - **🎯 Graph Pruning**: Advanced techniques for memory-efficient search ### πŸ› οΈ Technical Highlights + - **Zero-copy operations** for maximum performance - **SIMD-optimized** distance computations (AVX2/AVX512) - **Async embedding pipeline** with batched processing @@ -102,6 +138,7 @@ uv run examples/main_cli_example.py - **Recompute mode** for highest accuracy scenarios ### 🎨 Developer Experience + - **Simple Python API** - Get started in minutes - **Extensible backend system** - Easy to add new algorithms - **Comprehensive examples** - From basic usage to production deployment @@ -111,19 +148,19 @@ uv run examples/main_cli_example.py ### Memory Usage Comparison -| System | 1M Documents | 10M Documents | 100M Documents | -|--------|-------------|---------------|----------------| -| Traditional Vector DB | 3.1 GB | 31 GB | 310 GB | -| **Leann** | **180 MB** | **1.2 GB** | **8.4 GB** | -| **Reduction** | **94.2%** | **96.1%** | **97.3%** | +| System | 1M Documents | 10M Documents | 100M Documents | +| --------------------- | ---------------- | ---------------- | ---------------- | +| Traditional Vector DB | 3.1 GB | 31 GB | 310 GB | +| **Leann** | **180 MB** | **1.2 GB** | **8.4 GB** | +| **Reduction** | **94.2%** | **96.1%** | **97.3%** | ### Query Performance -| Backend | Index Size | Query Time | Recall@10 | -|---------|------------|------------|-----------| -| DiskANN | 1M docs | 12ms | 0.95 | -| DiskANN + Recompute | 1M docs | 145ms | 0.98 | -| HNSW | 1M docs | 8ms | 0.93 | +| Backend | Index Size | Query Time | Recall@10 | +| ------------------- | ---------- | ---------- | --------- | +| DiskANN | 1M docs | 12ms | 0.95 | +| DiskANN + Recompute | 1M docs | 145ms | 0.98 | +| HNSW | 1M docs | 8ms | 0.93 | *Benchmarks run on AMD Ryzen 7 with 32GB RAM* @@ -145,26 +182,29 @@ uv run examples/main_cli_example.py ### Key Components 1. **🧠 Embedding Engine**: Real-time transformer inference with caching -2. **πŸ“Š Graph Index**: Memory-efficient navigation structures +2. **πŸ“Š Graph Index**: Memory-efficient navigation structures 3. **πŸ”„ Search Coordinator**: Orchestrates embedding + graph search 4. 
**⚑ Backend Adapters**: Pluggable algorithm implementations ## πŸŽ“ Supported Models & Backends ### πŸ€– Embedding Models + - **sentence-transformers/all-mpnet-base-v2** (default) - **sentence-transformers/all-MiniLM-L6-v2** (lightweight) - Any HuggingFace sentence-transformer model - Custom model support via API -### πŸ”§ Search Backends +### πŸ”§ Search Backends + - **DiskANN**: Microsoft's billion-scale ANN algorithm - **HNSW**: Hierarchical Navigable Small World graphs - **Coming soon**: ScaNN, Faiss-IVF, NGT ### πŸ“ Distance Functions + - **L2**: Euclidean distance for precise similarity -- **Cosine**: Angular similarity for normalized vectors +- **Cosine**: Angular similarity for normalized vectors - **MIPS**: Maximum Inner Product Search for recommendation systems ## πŸ”¬ Paper @@ -188,6 +228,7 @@ If you find Leann useful, please cite: ## 🌍 Use Cases ### πŸ’Ό Enterprise RAG + ```python # Handle millions of documents with limited resources builder = LeannBuilder( @@ -198,7 +239,8 @@ builder = LeannBuilder( ) ``` -### πŸ”¬ Research & Experimentation +### πŸ”¬ Research & Experimentation + ```python # Quick prototyping with different algorithms for backend in ["diskann", "hnsw"]: @@ -207,6 +249,7 @@ for backend in ["diskann", "hnsw"]: ``` ### πŸš€ Real-time Applications + ```python # Sub-second response times chat = LeannChat("knowledge.leann") @@ -219,6 +262,7 @@ response = chat.ask("What is quantum computing?") We welcome contributions! Leann is built by the community, for the community. ### Ways to Contribute + - πŸ› **Bug Reports**: Found an issue? Let us know! - πŸ’‘ **Feature Requests**: Have an idea? We'd love to hear it! - πŸ”§ **Code Contributions**: PRs welcome for all skill levels @@ -226,14 +270,17 @@ We welcome contributions! Leann is built by the community, for the community. 
- πŸ§ͺ **Benchmarks**: Share your performance results ### Development Setup + ```bash -git clone https://github.com/yourname/leann +git clone git@github.com:yichuan520030910320/LEANN-RAG.git leann cd leann +git submodule update --init --recursive uv sync --dev uv run pytest tests/ ``` ### Quick Tests + ```bash # Sanity check all distance functions uv run python tests/sanity_checks/test_distance_functions.py @@ -241,17 +288,21 @@ uv run python tests/sanity_checks/test_distance_functions.py # Verify L2 implementation uv run python tests/sanity_checks/test_l2_verification.py ``` + ## ❓ FAQ ### Common Issues #### NCCL Topology Error + **Problem**: You encounter `ncclTopoComputePaths` error during document processing: + ``` ncclTopoComputePaths (system=, comm=comm@entry=0x5555a82fa3c0) at graph/paths.cc:688 ``` **Solution**: Set these environment variables before running your script: + ```bash export NCCL_TOPO_DUMP_FILE=/tmp/nccl_topo.xml export NCCL_DEBUG=INFO @@ -259,23 +310,26 @@ export NCCL_DEBUG_SUBSYS=INIT,GRAPH export NCCL_IB_DISABLE=1 export NCCL_NET_PLUGIN=none export NCCL_SOCKET_IFNAME=ens5 - +``` ## πŸ“ˆ Roadmap ### 🎯 Q1 2024 -- [x] DiskANN backend with MIPS/L2/Cosine support -- [x] HNSW backend integration -- [x] Real-time embedding pipeline -- [x] Memory-efficient graph pruning + +- [X] DiskANN backend with MIPS/L2/Cosine support +- [X] HNSW backend integration +- [X] Real-time embedding pipeline +- [X] Memory-efficient graph pruning ### πŸš€ Q2 2024 + - [ ] Distributed search across multiple nodes - [ ] ScaNN backend support - [ ] Advanced caching strategies - [ ] Kubernetes deployment guides ### 🌟 Q3 2024 + - [ ] GPU-accelerated embedding computation - [ ] Approximate distance functions - [ ] Integration with LangChain/LlamaIndex @@ -297,7 +351,7 @@ MIT License - see [LICENSE](LICENSE) for details. ## πŸ™ Acknowledgments - **Microsoft Research** for the DiskANN algorithm -- **Meta AI** for FAISS and optimization insights +- **Meta AI** for FAISS and optimization insights - **HuggingFace** for the transformer ecosystem - **Our amazing contributors** who make this possible @@ -309,4 +363,5 @@ MIT License - see [LICENSE](LICENSE) for details.

 Made with ❀️ by the Leann team
 
-
\ No newline at end of file
+

+ diff --git a/examples/LEANN_email_reader.py b/examples/LEANN_email_reader.py new file mode 100644 index 0000000..81e9bc2 --- /dev/null +++ b/examples/LEANN_email_reader.py @@ -0,0 +1,130 @@ +import os +import email +from pathlib import Path +from typing import List, Any +from llama_index.core import Document +from llama_index.core.readers.base import BaseReader + +class EmlxReader(BaseReader): + """ + Apple Mail .emlx file reader with embedded metadata. + + Reads individual .emlx files from Apple Mail's storage format. + """ + + def __init__(self) -> None: + """Initialize.""" + pass + + def load_data(self, input_dir: str, **load_kwargs: Any) -> List[Document]: + """ + Load data from the input directory containing .emlx files. + + Args: + input_dir: Directory containing .emlx files + **load_kwargs: + max_count (int): Maximum amount of messages to read. + """ + docs: List[Document] = [] + max_count = load_kwargs.get('max_count', 1000) + count = 0 + + # Walk through the directory recursively + for dirpath, dirnames, filenames in os.walk(input_dir): + # Skip hidden directories + dirnames[:] = [d for d in dirnames if not d.startswith(".")] + + for filename in filenames: + if count >= max_count: + break + + if filename.endswith(".emlx"): + filepath = os.path.join(dirpath, filename) + try: + # Read the .emlx file + with open(filepath, 'r', encoding='utf-8', errors='ignore') as f: + content = f.read() + + # .emlx files have a length prefix followed by the email content + # The first line contains the length, followed by the email + lines = content.split('\n', 1) + if len(lines) >= 2: + email_content = lines[1] + + # Parse the email using Python's email module + try: + msg = email.message_from_string(email_content) + + # Extract email metadata + subject = msg.get('Subject', 'No Subject') + from_addr = msg.get('From', 'Unknown') + to_addr = msg.get('To', 'Unknown') + date = msg.get('Date', 'Unknown') + + # Extract email body + body = "" + if msg.is_multipart(): + for part in msg.walk(): + if part.get_content_type() == "text/plain" or part.get_content_type() == "text/html": + # if part.get_content_type() == "text/html": + # continue + body += part.get_payload(decode=True).decode('utf-8', errors='ignore') + # break + else: + body = msg.get_payload(decode=True).decode('utf-8', errors='ignore') + + # Create document content with metadata embedded in text + doc_content = f""" +[EMAIL METADATA] +File: {filename} +From: {from_addr} +To: {to_addr} +Subject: {subject} +Date: {date} +[END METADATA] + +{body} +""" + + # No separate metadata - everything is in the text + doc = Document(text=doc_content, metadata={}) + docs.append(doc) + count += 1 + + except Exception as e: + print(f"Error parsing email from {filepath}: {e}") + continue + + except Exception as e: + print(f"Error reading file {filepath}: {e}") + continue + + print(f"Loaded {len(docs)} email documents") + return docs + + @staticmethod + def find_all_messages_directories(base_path: str) -> List[Path]: + """ + Find all Messages directories under the given base path. 
+ + Args: + base_path: Base path to search for Messages directories + + Returns: + List of Path objects pointing to Messages directories + """ + base_path_obj = Path(base_path) + messages_dirs = [] + + if not base_path_obj.exists(): + print(f"Base path {base_path} does not exist") + return messages_dirs + + # Find all Messages directories recursively + for messages_dir in base_path_obj.rglob("Messages"): + if messages_dir.is_dir(): + messages_dirs.append(messages_dir) + print(f"Found Messages directory: {messages_dir}") + + print(f"Found {len(messages_dirs)} Messages directories") + return messages_dirs \ No newline at end of file diff --git a/examples/email_data/email.py b/examples/email_data/email.py new file mode 100644 index 0000000..689618b --- /dev/null +++ b/examples/email_data/email.py @@ -0,0 +1,192 @@ +""" +Mbox parser. + +Contains simple parser for mbox files. + +""" + +import logging +from pathlib import Path +from typing import Any, Dict, List, Optional +from fsspec import AbstractFileSystem + +from llama_index.core.readers.base import BaseReader +from llama_index.core.schema import Document + +logger = logging.getLogger(__name__) + + +class MboxReader(BaseReader): + """ + Mbox parser. + + Extract messages from mailbox files. + Returns string including date, subject, sender, receiver and + content for each message. + + """ + + DEFAULT_MESSAGE_FORMAT: str = ( + "Date: {_date}\n" + "From: {_from}\n" + "To: {_to}\n" + "Subject: {_subject}\n" + "Content: {_content}" + ) + + def __init__( + self, + *args: Any, + max_count: int = 0, + message_format: str = DEFAULT_MESSAGE_FORMAT, + **kwargs: Any, + ) -> None: + """Init params.""" + try: + from bs4 import BeautifulSoup # noqa + except ImportError: + raise ImportError( + "`beautifulsoup4` package not found: `pip install beautifulsoup4`" + ) + + super().__init__(*args, **kwargs) + self.max_count = max_count + self.message_format = message_format + + def load_data( + self, + file: Path, + extra_info: Optional[Dict] = None, + fs: Optional[AbstractFileSystem] = None, + ) -> List[Document]: + """Parse file into string.""" + # Import required libraries + import mailbox + from email.parser import BytesParser + from email.policy import default + + from bs4 import BeautifulSoup + + if fs: + logger.warning( + "fs was specified but MboxReader doesn't support loading " + "from fsspec filesystems. Will load from local filesystem instead." 
+ ) + + i = 0 + results: List[str] = [] + # Load file using mailbox + bytes_parser = BytesParser(policy=default).parse + mbox = mailbox.mbox(file, factory=bytes_parser) # type: ignore + + # Iterate through all messages + for _, _msg in enumerate(mbox): + try: + msg: mailbox.mboxMessage = _msg + # Parse multipart messages + if msg.is_multipart(): + for part in msg.walk(): + ctype = part.get_content_type() + cdispo = str(part.get("Content-Disposition")) + if "attachment" in cdispo: + print(f"Attachment found: {part.get_filename()}") + if ctype == "text/plain" and "attachment" not in cdispo: + content = part.get_payload(decode=True) # decode + break + # Get plain message payload for non-multipart messages + else: + content = msg.get_payload(decode=True) + + # Parse message HTML content and remove unneeded whitespace + soup = BeautifulSoup(content) + stripped_content = " ".join(soup.get_text().split()) + # Format message to include date, sender, receiver and subject + msg_string = self.message_format.format( + _date=msg["date"], + _from=msg["from"], + _to=msg["to"], + _subject=msg["subject"], + _content=stripped_content, + ) + # Add message string to results + results.append(msg_string) + except Exception as e: + logger.warning(f"Failed to parse message:\n{_msg}\n with exception {e}") + + # Increment counter and return if max count is met + i += 1 + if self.max_count > 0 and i >= self.max_count: + break + + return [Document(text=result, metadata=extra_info or {}) for result in results] + + +class EmlxMboxReader(MboxReader): + """ + EmlxMboxReader - Modified MboxReader that handles directories of .emlx files. + + Extends MboxReader to work with Apple Mail's .emlx format by: + 1. Reading .emlx files from a directory + 2. Converting them to mbox format in memory + 3. Using the parent MboxReader's parsing logic + """ + + def load_data( + self, + directory: Path, + extra_info: Optional[Dict] = None, + fs: Optional[AbstractFileSystem] = None, + ) -> List[Document]: + """Parse .emlx files from directory into strings using MboxReader logic.""" + import tempfile + import os + + if fs: + logger.warning( + "fs was specified but EmlxMboxReader doesn't support loading " + "from fsspec filesystems. Will load from local filesystem instead." 
+ ) + + # Find all .emlx files in the directory + emlx_files = list(directory.glob("*.emlx")) + logger.info(f"Found {len(emlx_files)} .emlx files in {directory}") + + if not emlx_files: + logger.warning(f"No .emlx files found in {directory}") + return [] + + # Create a temporary mbox file + with tempfile.NamedTemporaryFile(mode='w', suffix='.mbox', delete=False) as temp_mbox: + temp_mbox_path = temp_mbox.name + + # Convert .emlx files to mbox format + for emlx_file in emlx_files: + try: + # Read the .emlx file + with open(emlx_file, 'r', encoding='utf-8', errors='ignore') as f: + content = f.read() + + # .emlx format: first line is length, rest is email content + lines = content.split('\n', 1) + if len(lines) >= 2: + email_content = lines[1] # Skip the length line + + # Write to mbox format (each message starts with "From " and ends with blank line) + temp_mbox.write(f"From {emlx_file.name} {email_content}\n\n") + + except Exception as e: + logger.warning(f"Failed to process {emlx_file}: {e}") + continue + + # Close the temporary file so MboxReader can read it + temp_mbox.close() + + try: + # Use the parent MboxReader's logic to parse the mbox file + return super().load_data(Path(temp_mbox_path), extra_info, fs) + finally: + # Clean up temporary file + try: + os.unlink(temp_mbox_path) + except: + pass \ No newline at end of file diff --git a/examples/mail_reader_leann.py b/examples/mail_reader_leann.py new file mode 100644 index 0000000..dae6df8 --- /dev/null +++ b/examples/mail_reader_leann.py @@ -0,0 +1,229 @@ +import os +import asyncio +import dotenv +from pathlib import Path +from typing import List, Any +from leann.api import LeannBuilder, LeannSearcher, LeannChat +from llama_index.core.node_parser import SentenceSplitter + +dotenv.load_dotenv() + +def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_path: str = "mail_index.leann", max_count: int = -1): + """ + Create LEANN index from multiple mail data sources. + + Args: + messages_dirs: List of Path objects pointing to Messages directories + index_path: Path to save the LEANN index + max_count: Maximum number of emails to process per directory + """ + print("Creating LEANN index from multiple mail data sources...") + + # Load documents using EmlxReader from LEANN_email_reader + from LEANN_email_reader import EmlxReader + reader = EmlxReader() + # from email_data.email import EmlxMboxReader + # from pathlib import Path + # reader = EmlxMboxReader() + + all_documents = [] + total_processed = 0 + + # Process each Messages directory + for i, messages_dir in enumerate(messages_dirs): + print(f"\nProcessing Messages directory {i+1}/{len(messages_dirs)}: {messages_dir}") + + try: + documents = reader.load_data(messages_dir) + if documents: + print(f"Loaded {len(documents)} email documents from {messages_dir}") + all_documents.extend(documents) + total_processed += len(documents) + + # Check if we've reached the max count + if max_count > 0 and total_processed >= max_count: + print(f"Reached max count of {max_count} documents") + break + else: + print(f"No documents loaded from {messages_dir}") + except Exception as e: + print(f"Error processing {messages_dir}: {e}") + continue + + if not all_documents: + print("No documents loaded from any source. 
Exiting.") + return None + + print(f"\nTotal loaded {len(all_documents)} email documents from {len(messages_dirs)} directories") + + # Create text splitter with 256 chunk size + text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25) + + # Convert Documents to text strings and chunk them + all_texts = [] + for doc in all_documents: + # Split the document into chunks + nodes = text_splitter.get_nodes_from_documents([doc]) + for node in nodes: + all_texts.append(node.get_content()) + + print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents") + + # Create LEANN index directory + INDEX_DIR = Path(index_path).parent + + if not INDEX_DIR.exists(): + print(f"--- Index directory not found, building new index ---") + INDEX_DIR.mkdir(exist_ok=True) + + print(f"--- Building new LEANN index ---") + + print(f"\n[PHASE 1] Building Leann index...") + + # Use HNSW backend for better macOS compatibility + builder = LeannBuilder( + backend_name="hnsw", + embedding_model="facebook/contriever", + graph_degree=32, + complexity=64, + is_compact=True, + is_recompute=True, + num_threads=1 # Force single-threaded mode + ) + + print(f"Adding {len(all_texts)} email chunks to index...") + for chunk_text in all_texts: + builder.add_text(chunk_text) + + builder.build_index(index_path) + print(f"\nLEANN index built at {index_path}!") + else: + print(f"--- Using existing index at {INDEX_DIR} ---") + + return index_path + +def create_leann_index(mail_path: str, index_path: str = "mail_index.leann", max_count: int = 1000): + """ + Create LEANN index from mail data. + + Args: + mail_path: Path to the mail directory + index_path: Path to save the LEANN index + max_count: Maximum number of emails to process + """ + print("Creating LEANN index from mail data...") + + # Load documents using EmlxReader from LEANN_email_reader + from LEANN_email_reader import EmlxReader + reader = EmlxReader() + # from email_data.email import EmlxMboxReader + # from pathlib import Path + # reader = EmlxMboxReader() + documents = reader.load_data(Path(mail_path)) + + if not documents: + print("No documents loaded. 
Exiting.") + return None + + print(f"Loaded {len(documents)} email documents") + + # Create text splitter with 256 chunk size + text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25) + + # Convert Documents to text strings and chunk them + all_texts = [] + for doc in documents: + # Split the document into chunks + nodes = text_splitter.get_nodes_from_documents([doc]) + for node in nodes: + all_texts.append(node.get_content()) + + print(f"Created {len(all_texts)} text chunks from {len(documents)} documents") + + # Create LEANN index directory + INDEX_DIR = Path(index_path).parent + + if not INDEX_DIR.exists(): + print(f"--- Index directory not found, building new index ---") + INDEX_DIR.mkdir(exist_ok=True) + + print(f"--- Building new LEANN index ---") + + print(f"\n[PHASE 1] Building Leann index...") + + # Use HNSW backend for better macOS compatibility + builder = LeannBuilder( + backend_name="hnsw", + embedding_model="facebook/contriever", + graph_degree=32, + complexity=64, + is_compact=True, + is_recompute=True, + num_threads=1 # Force single-threaded mode + ) + + print(f"Adding {len(all_texts)} email chunks to index...") + for chunk_text in all_texts: + builder.add_text(chunk_text) + + builder.build_index(index_path) + print(f"\nLEANN index built at {index_path}!") + else: + print(f"--- Using existing index at {INDEX_DIR} ---") + + return index_path + +async def query_leann_index(index_path: str, query: str): + """ + Query the LEANN index. + + Args: + index_path: Path to the LEANN index + query: The query string + """ + print(f"\n[PHASE 2] Starting Leann chat session...") + chat = LeannChat(index_path=index_path) + + print(f"You: {query}") + chat_response = chat.ask( + query, + top_k=5, + recompute_beighbor_embeddings=True, + complexity=32, + beam_width=1 + ) + print(f"Leann: {chat_response}") + +async def main(): + # Base path to the mail data directory + base_mail_path = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data" + + INDEX_DIR = Path("./mail_index_leann_raw_text_all") + INDEX_PATH = str(INDEX_DIR / "mail_documents.leann") + + # Find all Messages directories + from LEANN_email_reader import EmlxReader + messages_dirs = EmlxReader.find_all_messages_directories(base_mail_path) + + if not messages_dirs: + print("No Messages directories found. 
Exiting.") + return + + # Create or load the LEANN index from all sources + index_path = create_leann_index_from_multiple_sources(messages_dirs, INDEX_PATH) + + if index_path: + # Example queries + queries = [ + "Hows Berkeley Graduate Student Instructor", + "how's the icloud related advertisement saying", + "Whats the number of class recommend to take per semester for incoming EECS students" + + ] + + for query in queries: + print("\n" + "="*60) + await query_leann_index(index_path, query) + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/examples/mail_reader_llamaindex.py b/examples/mail_reader_llamaindex.py new file mode 100644 index 0000000..97abe10 --- /dev/null +++ b/examples/mail_reader_llamaindex.py @@ -0,0 +1,86 @@ +import os +from pathlib import Path +from typing import List, Any +from llama_index.core import VectorStoreIndex, StorageContext +from llama_index.core.node_parser import SentenceSplitter + +# --- EMBEDDING MODEL --- +from llama_index.embeddings.huggingface import HuggingFaceEmbedding +import torch + +# --- END EMBEDDING MODEL --- + +# Import EmlxReader from the new module +from LEANN_email_reader import EmlxReader + +def create_and_save_index(mail_path: str, save_dir: str = "mail_index_embedded", max_count: int = 1000): + print("Creating index from mail data with embedded metadata...") + documents = EmlxReader().load_data(mail_path, max_count=max_count) + if not documents: + print("No documents loaded. Exiting.") + return None + text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25) + # Use facebook/contriever as the embedder + embed_model = HuggingFaceEmbedding(model_name="facebook/contriever") + # set on device + import torch + if torch.cuda.is_available(): + embed_model._model.to("cuda") + # set mps + elif torch.backends.mps.is_available(): + embed_model._model.to("mps") + else: + embed_model._model.to("cpu") + index = VectorStoreIndex.from_documents( + documents, + transformations=[text_splitter], + embed_model=embed_model + ) + os.makedirs(save_dir, exist_ok=True) + index.storage_context.persist(persist_dir=save_dir) + print(f"Index saved to {save_dir}") + return index + +def load_index(save_dir: str = "mail_index_embedded"): + try: + storage_context = StorageContext.from_defaults(persist_dir=save_dir) + index = VectorStoreIndex.from_vector_store( + storage_context.vector_store, + storage_context=storage_context + ) + print(f"Index loaded from {save_dir}") + return index + except Exception as e: + print(f"Error loading index: {e}") + return None + +def query_index(index, query: str): + if index is None: + print("No index available for querying.") + return + query_engine = index.as_query_engine() + response = query_engine.query(query) + print(f"Query: {query}") + print(f"Response: {response}") + +def main(): + mail_path = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages" + save_dir = "mail_index_embedded" + if os.path.exists(save_dir) and os.path.exists(os.path.join(save_dir, "vector_store.json")): + print("Loading existing index...") + index = load_index(save_dir) + else: + print("Creating new index...") + index = create_and_save_index(mail_path, save_dir, max_count=10000) + if index: + queries = [ + "Hows Berkeley Graduate Student Instructor", + "how's the icloud related advertisement saying" + "Whats the number of class recommend to take per semester for incoming EECS students" + ] + for query in queries: + 
print("\n" + "="*50) + query_index(index, query) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/main_cli_example.py b/examples/main_cli_example.py index a54feaa..7509a29 100644 --- a/examples/main_cli_example.py +++ b/examples/main_cli_example.py @@ -8,7 +8,6 @@ from llama_index.node_parser.docling import DoclingNodeParser from llama_index.readers.docling import DoclingReader from docling_core.transforms.chunker.hybrid_chunker import HybridChunker import asyncio -import os import dotenv from leann.api import LeannBuilder, LeannSearcher, LeannChat import shutil @@ -22,9 +21,11 @@ file_extractor: dict[str, BaseReader] = { ".pptx": reader, ".pdf": reader, ".xlsx": reader, + ".txt": reader, + ".md": reader, } node_parser = DoclingNodeParser( - chunker=HybridChunker(tokenizer="Qwen/Qwen3-Embedding-4B", max_tokens=64) + chunker=HybridChunker(tokenizer="facebook/contriever", max_tokens=128) ) print("Loading documents...") documents = SimpleDirectoryReader( @@ -32,7 +33,7 @@ documents = SimpleDirectoryReader( recursive=True, file_extractor=file_extractor, encoding="utf-8", - required_exts=[".pdf", ".docx", ".pptx", ".xlsx"] + required_exts=[".pdf", ".docx", ".pptx", ".xlsx", ".txt", ".md"] ).load_data(show_progress=True) print("Documents loaded.") all_texts = [] @@ -41,7 +42,7 @@ for doc in documents: for node in nodes: all_texts.append(node.get_content()) -INDEX_DIR = Path("./test_pdf_index") +INDEX_DIR = Path("./test_pdf_index_pangu_test") INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann") if not INDEX_DIR.exists(): @@ -49,14 +50,15 @@ if not INDEX_DIR.exists(): print(f"\n[PHASE 1] Building Leann index...") - # CSR compact mode with recompute + # Use HNSW backend for better macOS compatibility builder = LeannBuilder( backend_name="hnsw", embedding_model="facebook/contriever", graph_degree=32, complexity=64, is_compact=True, - is_recompute=True + is_recompute=True, + num_threads=1 # Force single-threaded mode ) print(f"Loaded {len(all_texts)} text chunks from documents.") @@ -80,14 +82,17 @@ async def main(args): chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config) query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?" + query = "What is the main idea of RL and give me 5 exapmle of classic RL algorithms?" 
+ query = "δ»€δΉˆζ˜―η›˜ε€ε€§ζ¨‘εž‹δ»₯εŠη›˜ε€εΌ€ε‘θΏ‡η¨‹δΈ­ι‡εˆ°δΊ†δ»€δΉˆι˜΄ζš—ι’οΌŒδ»»εŠ‘δ»€δΈ€θˆ¬εœ¨δ»€δΉˆεŸŽεΈ‚ι’ε‘" + print(f"You: {query}") - chat_response = chat.ask(query, top_k=3, recompute_beighbor_embeddings=True) + chat_response = chat.ask(query, top_k=20, recompute_beighbor_embeddings=True, complexity=32, beam_width=1) print(f"Leann: {chat_response}") if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run Leann Chat with various LLM backends.") - parser.add_argument("--llm", type=str, default="hf", choices=["simulated", "ollama", "hf"], help="The LLM backend to use.") - parser.add_argument("--model", type=str, default='meta-llama/Llama-3.2-3B-Instruct', help="The model name to use (e.g., 'llama3:8b' for ollama, 'deepseek-ai/deepseek-llm-7b-chat' for hf).") + parser.add_argument("--llm", type=str, default="hf", choices=["simulated", "ollama", "hf", "openai"], help="The LLM backend to use.") + parser.add_argument("--model", type=str, default='meta-llama/Llama-3.2-3B-Instruct', help="The model name to use (e.g., 'llama3:8b' for ollama, 'deepseek-ai/deepseek-llm-7b-chat' for hf, 'gpt-4o' for openai).") parser.add_argument("--host", type=str, default="http://localhost:11434", help="The host for the Ollama API.") args = parser.parse_args() diff --git a/examples/multi_vector_aggregator.py b/examples/multi_vector_aggregator.py new file mode 100644 index 0000000..f2f1c2a --- /dev/null +++ b/examples/multi_vector_aggregator.py @@ -0,0 +1,319 @@ +#!/usr/bin/env python3 +""" +Multi-Vector Aggregator for Fat Embeddings +========================================== + +This module implements aggregation strategies for multi-vector embeddings, +similar to ColPali's approach where multiple patch vectors represent a single document. + +Key features: +- MaxSim aggregation (take maximum similarity across patches) +- Voting-based aggregation (count patch matches) +- Weighted aggregation (attention-score weighted) +- Spatial clustering of matching patches +- Document-level result consolidation +""" + +import numpy as np +from typing import List, Dict, Any, Tuple, Optional +from dataclasses import dataclass +from collections import defaultdict +import json + +@dataclass +class PatchResult: + """Represents a single patch search result.""" + patch_id: int + image_name: str + image_path: str + coordinates: Tuple[int, int, int, int] # (x1, y1, x2, y2) + score: float + attention_score: float + scale: float + metadata: Dict[str, Any] + +@dataclass +class AggregatedResult: + """Represents an aggregated document-level result.""" + image_name: str + image_path: str + doc_score: float + patch_count: int + best_patch: PatchResult + all_patches: List[PatchResult] + aggregation_method: str + spatial_clusters: Optional[List[List[PatchResult]]] = None + +class MultiVectorAggregator: + """ + Aggregates multiple patch-level results into document-level results. + """ + + def __init__(self, + aggregation_method: str = "maxsim", + spatial_clustering: bool = True, + cluster_distance_threshold: float = 100.0): + """ + Initialize the aggregator. 
+ + Args: + aggregation_method: "maxsim", "voting", "weighted", or "mean" + spatial_clustering: Whether to cluster spatially close patches + cluster_distance_threshold: Distance threshold for spatial clustering + """ + self.aggregation_method = aggregation_method + self.spatial_clustering = spatial_clustering + self.cluster_distance_threshold = cluster_distance_threshold + + def aggregate_results(self, + search_results: List[Dict[str, Any]], + top_k: int = 10) -> List[AggregatedResult]: + """ + Aggregate patch-level search results into document-level results. + + Args: + search_results: List of search results from LeannSearcher + top_k: Number of top documents to return + + Returns: + List of aggregated document results + """ + # Group results by image + image_groups = defaultdict(list) + + for result in search_results: + metadata = result.metadata + if "image_name" in metadata and "patch_id" in metadata: + patch_result = PatchResult( + patch_id=metadata["patch_id"], + image_name=metadata["image_name"], + image_path=metadata["image_path"], + coordinates=tuple(metadata["coordinates"]), + score=result.score, + attention_score=metadata.get("attention_score", 0.0), + scale=metadata.get("scale", 1.0), + metadata=metadata + ) + image_groups[metadata["image_name"]].append(patch_result) + + # Aggregate each image group + aggregated_results = [] + for image_name, patches in image_groups.items(): + if len(patches) == 0: + continue + + agg_result = self._aggregate_image_patches(image_name, patches) + aggregated_results.append(agg_result) + + # Sort by aggregated score and return top-k + aggregated_results.sort(key=lambda x: x.doc_score, reverse=True) + return aggregated_results[:top_k] + + def _aggregate_image_patches(self, image_name: str, patches: List[PatchResult]) -> AggregatedResult: + """Aggregate patches for a single image.""" + + if self.aggregation_method == "maxsim": + doc_score = max(patch.score for patch in patches) + best_patch = max(patches, key=lambda p: p.score) + + elif self.aggregation_method == "voting": + # Count patches above threshold + threshold = np.percentile([p.score for p in patches], 75) + doc_score = sum(1 for patch in patches if patch.score >= threshold) + best_patch = max(patches, key=lambda p: p.score) + + elif self.aggregation_method == "weighted": + # Weight by attention scores + total_weighted_score = sum(p.score * p.attention_score for p in patches) + total_weights = sum(p.attention_score for p in patches) + doc_score = total_weighted_score / max(total_weights, 1e-8) + best_patch = max(patches, key=lambda p: p.score * p.attention_score) + + elif self.aggregation_method == "mean": + doc_score = np.mean([patch.score for patch in patches]) + best_patch = max(patches, key=lambda p: p.score) + + else: + raise ValueError(f"Unknown aggregation method: {self.aggregation_method}") + + # Spatial clustering if enabled + spatial_clusters = None + if self.spatial_clustering: + spatial_clusters = self._cluster_patches_spatially(patches) + + return AggregatedResult( + image_name=image_name, + image_path=patches[0].image_path, + doc_score=float(doc_score), + patch_count=len(patches), + best_patch=best_patch, + all_patches=sorted(patches, key=lambda p: p.score, reverse=True), + aggregation_method=self.aggregation_method, + spatial_clusters=spatial_clusters + ) + + def _cluster_patches_spatially(self, patches: List[PatchResult]) -> List[List[PatchResult]]: + """Cluster patches that are spatially close to each other.""" + if len(patches) <= 1: + return [patches] + + clusters = [] + 
remaining_patches = patches.copy() + + while remaining_patches: + # Start new cluster with highest scoring remaining patch + seed_patch = max(remaining_patches, key=lambda p: p.score) + current_cluster = [seed_patch] + remaining_patches.remove(seed_patch) + + # Add nearby patches to cluster + added_to_cluster = True + while added_to_cluster: + added_to_cluster = False + for patch in remaining_patches.copy(): + if self._is_patch_nearby(patch, current_cluster): + current_cluster.append(patch) + remaining_patches.remove(patch) + added_to_cluster = True + + clusters.append(current_cluster) + + return sorted(clusters, key=lambda cluster: max(p.score for p in cluster), reverse=True) + + def _is_patch_nearby(self, patch: PatchResult, cluster: List[PatchResult]) -> bool: + """Check if a patch is spatially close to any patch in the cluster.""" + patch_center = self._get_patch_center(patch.coordinates) + + for cluster_patch in cluster: + cluster_center = self._get_patch_center(cluster_patch.coordinates) + distance = np.sqrt((patch_center[0] - cluster_center[0])**2 + + (patch_center[1] - cluster_center[1])**2) + + if distance <= self.cluster_distance_threshold: + return True + + return False + + def _get_patch_center(self, coordinates: Tuple[int, int, int, int]) -> Tuple[float, float]: + """Get center point of a patch.""" + x1, y1, x2, y2 = coordinates + return ((x1 + x2) / 2, (y1 + y2) / 2) + + def print_aggregated_results(self, results: List[AggregatedResult], max_patches_per_doc: int = 3): + """Pretty print aggregated results.""" + print(f"\nπŸ” Aggregated Results (method: {self.aggregation_method})") + print("=" * 80) + + for i, result in enumerate(results): + print(f"\n{i+1}. {result.image_name}") + print(f" Doc Score: {result.doc_score:.4f} | Patches: {result.patch_count}") + print(f" Path: {result.image_path}") + + # Show best patch + best = result.best_patch + print(f" 🌟 Best Patch: #{best.patch_id} at {best.coordinates} (score: {best.score:.4f})") + + # Show top patches + print(f" πŸ“ Top Patches:") + for j, patch in enumerate(result.all_patches[:max_patches_per_doc]): + print(f" {j+1}. 
Patch #{patch.patch_id}: {patch.score:.4f} at {patch.coordinates}") + + # Show spatial clusters if available + if result.spatial_clusters and len(result.spatial_clusters) > 1: + print(f" πŸ—‚οΈ Spatial Clusters: {len(result.spatial_clusters)}") + for j, cluster in enumerate(result.spatial_clusters[:2]): # Show top 2 clusters + cluster_score = max(p.score for p in cluster) + print(f" Cluster {j+1}: {len(cluster)} patches (best: {cluster_score:.4f})") + +def demo_aggregation(): + """Demonstrate the multi-vector aggregation functionality.""" + print("=== Multi-Vector Aggregation Demo ===") + + # Simulate some patch-level search results + # In real usage, these would come from LeannSearcher.search() + + class MockResult: + def __init__(self, score, metadata): + self.score = score + self.metadata = metadata + + # Simulate results for 2 images with multiple patches each + mock_results = [ + # Image 1: cats_and_kitchen.jpg - 4 patches + MockResult(0.85, { + "image_name": "cats_and_kitchen.jpg", + "image_path": "/path/to/cats_and_kitchen.jpg", + "patch_id": 3, + "coordinates": [100, 50, 224, 174], # Kitchen area + "attention_score": 0.92, + "scale": 1.0 + }), + MockResult(0.78, { + "image_name": "cats_and_kitchen.jpg", + "image_path": "/path/to/cats_and_kitchen.jpg", + "patch_id": 7, + "coordinates": [200, 300, 324, 424], # Cat area + "attention_score": 0.88, + "scale": 1.0 + }), + MockResult(0.72, { + "image_name": "cats_and_kitchen.jpg", + "image_path": "/path/to/cats_and_kitchen.jpg", + "patch_id": 12, + "coordinates": [150, 100, 274, 224], # Appliances + "attention_score": 0.75, + "scale": 1.0 + }), + MockResult(0.65, { + "image_name": "cats_and_kitchen.jpg", + "image_path": "/path/to/cats_and_kitchen.jpg", + "patch_id": 15, + "coordinates": [50, 250, 174, 374], # Furniture + "attention_score": 0.70, + "scale": 1.0 + }), + + # Image 2: city_street.jpg - 3 patches + MockResult(0.68, { + "image_name": "city_street.jpg", + "image_path": "/path/to/city_street.jpg", + "patch_id": 2, + "coordinates": [300, 100, 424, 224], # Buildings + "attention_score": 0.80, + "scale": 1.0 + }), + MockResult(0.62, { + "image_name": "city_street.jpg", + "image_path": "/path/to/city_street.jpg", + "patch_id": 8, + "coordinates": [100, 350, 224, 474], # Street level + "attention_score": 0.75, + "scale": 1.0 + }), + MockResult(0.55, { + "image_name": "city_street.jpg", + "image_path": "/path/to/city_street.jpg", + "patch_id": 11, + "coordinates": [400, 200, 524, 324], # Sky area + "attention_score": 0.60, + "scale": 1.0 + }), + ] + + # Test different aggregation methods + methods = ["maxsim", "voting", "weighted", "mean"] + + for method in methods: + print(f"\n{'='*20} {method.upper()} AGGREGATION {'='*20}") + + aggregator = MultiVectorAggregator( + aggregation_method=method, + spatial_clustering=True, + cluster_distance_threshold=100.0 + ) + + aggregated = aggregator.aggregate_results(mock_results, top_k=5) + aggregator.print_aggregated_results(aggregated) + +if __name__ == "__main__": + demo_aggregation() \ No newline at end of file diff --git a/examples/resue_index.py b/examples/resue_index.py new file mode 100644 index 0000000..24dc3a1 --- /dev/null +++ b/examples/resue_index.py @@ -0,0 +1,18 @@ +import asyncio +from leann.api import LeannChat +from pathlib import Path + +INDEX_DIR = Path("./test_pdf_index_huawei") +INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann") + +async def main(): + print(f"\n[PHASE 2] Starting Leann chat session...") + chat = LeannChat(index_path=INDEX_PATH) + query = "What is the main 
idea of RL and give me 5 exapmle of classic RL algorithms?" + query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?" + # query = "δ»€δΉˆζ˜―η›˜ε€ε€§ζ¨‘εž‹δ»₯εŠη›˜ε€εΌ€ε‘θΏ‡η¨‹δΈ­ι‡εˆ°δΊ†δ»€δΉˆι˜΄ζš—ι’οΌŒδ»»εŠ‘δ»€δΈ€θˆ¬εœ¨δ»€δΉˆεŸŽεΈ‚ι’ε‘" + response = chat.ask(query,top_k=20,recompute_beighbor_embeddings=True,complexity=32,beam_width=1) + print(f"\n[PHASE 2] Response: {response}") + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/knowledge.leann.meta.json b/knowledge.leann.meta.json deleted file mode 100644 index 6a0d839..0000000 --- a/knowledge.leann.meta.json +++ /dev/null @@ -1,32 +0,0 @@ -{ - "version": "0.1.0", - "backend_name": "diskann", - "embedding_model": "sentence-transformers/all-mpnet-base-v2", - "num_chunks": 6, - "chunks": [ - { - "text": "Python is a powerful programming language", - "metadata": {} - }, - { - "text": "Machine learning transforms industries", - "metadata": {} - }, - { - "text": "Neural networks process complex data", - "metadata": {} - }, - { - "text": "Java is a powerful programming language", - "metadata": {} - }, - { - "text": "C++ is a powerful programming language", - "metadata": {} - }, - { - "text": "C# is a powerful programming language", - "metadata": {} - } - ] -} \ No newline at end of file diff --git a/packages/__init__.py b/packages/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/packages/__init__.py @@ -0,0 +1 @@ + diff --git a/packages/leann-backend-diskann/__init__.py b/packages/leann-backend-diskann/__init__.py new file mode 100644 index 0000000..3ff6d44 --- /dev/null +++ b/packages/leann-backend-diskann/__init__.py @@ -0,0 +1 @@ +# This file makes the directory a Python package \ No newline at end of file diff --git a/packages/leann-backend-diskann/leann_backend_diskann/embedding_server.py b/packages/leann-backend-diskann/leann_backend_diskann/embedding_server.py index 6de653a..8c09e37 100644 --- a/packages/leann-backend-diskann/leann_backend_diskann/embedding_server.py +++ b/packages/leann-backend-diskann/leann_backend_diskann/embedding_server.py @@ -15,6 +15,8 @@ import os from contextlib import contextmanager import zmq import numpy as np +from pathlib import Path +import pickle RED = "\033[91m" RESET = "\033[0m" @@ -109,8 +111,6 @@ def load_passages_from_file(passages_file: str) -> SimplePassageLoader: Load passages from a JSONL file with label map support Expected format: {"id": "passage_id", "text": "passage_text", "metadata": {...}} (one per line) """ - from pathlib import Path - import pickle if not os.path.exists(passages_file): raise FileNotFoundError(f"Passages file {passages_file} not found.") @@ -210,7 +210,6 @@ def create_embedding_server_thread( passages = load_passages_from_metadata(passages_file) else: # Try to find metadata file in same directory - from pathlib import Path passages_dir = Path(passages_file).parent meta_files = list(passages_dir.glob("*.meta.json")) if meta_files: diff --git a/packages/leann-backend-hnsw/CMakeLists.txt b/packages/leann-backend-hnsw/CMakeLists.txt index 6865da3..bcadd12 100644 --- a/packages/leann-backend-hnsw/CMakeLists.txt +++ b/packages/leann-backend-hnsw/CMakeLists.txt @@ -2,6 +2,33 @@ cmake_minimum_required(VERSION 3.24) project(leann_backend_hnsw_wrapper) +# Set OpenMP path for macOS +if(APPLE) + set(OpenMP_C_FLAGS "-Xpreprocessor -fopenmp -I/opt/homebrew/opt/libomp/include") + set(OpenMP_CXX_FLAGS 
"-Xpreprocessor -fopenmp -I/opt/homebrew/opt/libomp/include") + set(OpenMP_C_LIB_NAMES "omp") + set(OpenMP_CXX_LIB_NAMES "omp") + set(OpenMP_omp_LIBRARY "/opt/homebrew/opt/libomp/lib/libomp.dylib") +endif() + +# Build ZeroMQ from source +set(ZMQ_BUILD_TESTS OFF CACHE BOOL "" FORCE) +set(ENABLE_DRAFTS OFF CACHE BOOL "" FORCE) +set(ENABLE_PRECOMPILED OFF CACHE BOOL "" FORCE) +set(WITH_PERF_TOOL OFF CACHE BOOL "" FORCE) +set(WITH_DOCS OFF CACHE BOOL "" FORCE) +set(BUILD_SHARED OFF CACHE BOOL "" FORCE) +set(BUILD_STATIC ON CACHE BOOL "" FORCE) +add_subdirectory(third_party/libzmq) + +# Add cppzmq headers +include_directories(third_party/cppzmq) + +# Configure msgpack-c - disable boost dependency +set(MSGPACK_USE_BOOST OFF CACHE BOOL "" FORCE) +add_compile_definitions(MSGPACK_NO_BOOST) +include_directories(third_party/msgpack-c/include) + set(FAISS_ENABLE_PYTHON ON CACHE BOOL "" FORCE) set(FAISS_ENABLE_GPU OFF CACHE BOOL "" FORCE) set(FAISS_ENABLE_EXTRAS OFF CACHE BOOL "" FORCE) diff --git a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py index 229819a..98b96ef 100644 --- a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py +++ b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py @@ -4,6 +4,7 @@ import json from pathlib import Path from typing import Dict, Any, List import pickle +import shutil from leann.searcher_base import BaseSearcher from .convert_to_csr import convert_hnsw_graph_to_csr @@ -77,17 +78,29 @@ class HNSWBuilder(LeannBackendBuilderInterface): self._convert_to_csr(index_file) def _convert_to_csr(self, index_file: Path): + """Convert built index to CSR format""" + mode_str = "CSR-pruned" if self.is_recompute else "CSR-standard" + print(f"INFO: Converting HNSW index to {mode_str} format...") + csr_temp_file = index_file.with_suffix(".csr.tmp") + success = convert_hnsw_graph_to_csr( - str(index_file), str(csr_temp_file), prune_embeddings=self.is_recompute + str(index_file), + str(csr_temp_file), + prune_embeddings=self.is_recompute ) + if success: - import shutil + print("βœ… CSR conversion successful.") + index_file_old = index_file.with_suffix(".old") + shutil.move(str(index_file), str(index_file_old)) shutil.move(str(csr_temp_file), str(index_file)) + print(f"INFO: Replaced original index with {mode_str} version at '{index_file}'") else: + # Clean up and fail fast if csr_temp_file.exists(): os.remove(csr_temp_file) - raise RuntimeError("CSR conversion failed") + raise RuntimeError("CSR conversion failed - cannot proceed with compact format") class HNSWSearcher(BaseSearcher): def __init__(self, index_path: str, **kwargs): @@ -99,7 +112,10 @@ class HNSWSearcher(BaseSearcher): if metric_enum is None: raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.") - self.is_compact, self.is_pruned = self._get_index_storage_status_from_meta() + self.is_compact, self.is_pruned = ( + self.meta.get('is_compact', True), + self.meta.get('is_pruned', True) + ) index_file = self.index_dir / f"{self.index_path.stem}.index" if not index_file.exists(): @@ -114,11 +130,6 @@ class HNSWSearcher(BaseSearcher): self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config) - def _get_index_storage_status_from_meta(self) -> tuple[bool, bool]: - is_compact = self.meta.get('is_compact', True) - is_pruned = self.meta.get('is_pruned', True) - return is_compact, is_pruned - def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]: from . 
import faiss diff --git a/packages/leann-core/src/leann/__init__.py b/packages/leann-core/src/leann/__init__.py index 2f19395..673e71e 100644 --- a/packages/leann-core/src/leann/__init__.py +++ b/packages/leann-core/src/leann/__init__.py @@ -1,4 +1,14 @@ # packages/leann-core/src/leann/__init__.py +import os +import platform + +# Fix OpenMP threading issues on macOS ARM64 +if platform.system() == "Darwin": + os.environ["OMP_NUM_THREADS"] = "1" + os.environ["MKL_NUM_THREADS"] = "1" + os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" + os.environ["KMP_BLOCKTIME"] = "0" + from .api import LeannBuilder, LeannChat, LeannSearcher from .registry import BACKEND_REGISTRY, autodiscover_backends diff --git a/packages/leann-core/src/leann/api.py b/packages/leann-core/src/leann/api.py index 00ac2f2..2e90d8d 100644 --- a/packages/leann-core/src/leann/api.py +++ b/packages/leann-core/src/leann/api.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python3 """ This file contains the core API for the LEANN project, now definitively updated with the correct, original embedding logic from the user's reference code. @@ -11,6 +10,7 @@ from pathlib import Path from typing import List, Dict, Any, Optional from dataclasses import dataclass, field import uuid +import torch from .registry import BACKEND_REGISTRY from .interface import LeannBackendFactoryInterface @@ -25,13 +25,22 @@ def compute_embeddings(chunks: List[str], model_name: str) -> np.ndarray: raise RuntimeError( f"sentence-transformers not available. Install with: pip install sentence-transformers" ) from e - + # Load model using sentence-transformers model = SentenceTransformer(model_name) - + + model = model.half() + print(f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}'...") + # use acclerater GPU or MAC GPU + + if torch.cuda.is_available(): + model = model.to("cuda") + elif torch.backends.mps.is_available(): + model = model.to("mps") + # Generate embeddings embeddings = model.encode(chunks, convert_to_numpy=True, show_progress_bar=True, batch_size=64) - + return embeddings # --- Core API Classes (Restored and Unchanged) --- @@ -181,5 +190,25 @@ class LeannChat: def ask(self, question: str, top_k=5, **kwargs): results = self.searcher.search(question, top_k=top_k, **kwargs) context = "\n\n".join([r.text for r in results]) - prompt = f"Context:\n{context}\n\nQuestion: {question}\n\nAnswer:" - return self.llm.ask(prompt, **kwargs.get("llm_kwargs", {})) \ No newline at end of file + prompt = ( + "Here is some retrieved context that might help answer your question:\n\n" + f"{context}\n\n" + f"Question: {question}\n\n" + "Please provide the best answer you can based on this context and your knowledge." 
+ ) + return self.llm.ask(prompt, **kwargs.get("llm_kwargs", {})) + + def start_interactive(self): + print("\nLeann Chat started (type 'quit' to exit)") + while True: + try: + user_input = input("You: ").strip() + if user_input.lower() in ['quit', 'exit']: + break + if not user_input: + continue + response = self.ask(user_input) + print(f"Leann: {response}") + except (KeyboardInterrupt, EOFError): + print("\nGoodbye!") + break \ No newline at end of file diff --git a/packages/leann-core/src/leann/chat.py b/packages/leann-core/src/leann/chat.py index 5f50dd5..1a9bc9b 100644 --- a/packages/leann-core/src/leann/chat.py +++ b/packages/leann-core/src/leann/chat.py @@ -7,6 +7,7 @@ supporting different backends like Ollama, Hugging Face Transformers, and a simu from abc import ABC, abstractmethod from typing import Dict, Any, Optional import logging +import os # Configure logging logging.basicConfig(level=logging.INFO) @@ -95,7 +96,57 @@ class HFChat(LLMInterface): } logger.info(f"Generating text with Hugging Face model with params: {params}") results = self.pipeline(prompt, **params) - return results[0]['generated_text'] + + # Handle different response formats from transformers + if isinstance(results, list) and len(results) > 0: + generated_text = results[0].get('generated_text', '') if isinstance(results[0], dict) else str(results[0]) + else: + generated_text = str(results) + + # Extract only the newly generated portion by removing the original prompt + if isinstance(generated_text, str) and generated_text.startswith(prompt): + response = generated_text[len(prompt):].strip() + else: + # Fallback: return the full response if prompt removal fails + response = str(generated_text) + + return response + +class OpenAIChat(LLMInterface): + """LLM interface for OpenAI models.""" + def __init__(self, model: str = "gpt-4o", api_key: Optional[str] = None): + self.model = model + self.api_key = api_key or os.getenv("OPENAI_API_KEY") + + if not self.api_key: + raise ValueError("OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter.") + + logger.info(f"Initializing OpenAI Chat with model='{model}'") + + try: + import openai + self.client = openai.OpenAI(api_key=self.api_key) + except ImportError: + raise ImportError("The 'openai' library is required for OpenAI models. Please install it with 'pip install openai'.") + + def ask(self, prompt: str, **kwargs) -> str: + # Default parameters for OpenAI + params = { + "model": self.model, + "messages": [{"role": "user", "content": prompt}], + "max_tokens": kwargs.get("max_tokens", 1000), + "temperature": kwargs.get("temperature", 0.7), + **{k: v for k, v in kwargs.items() if k not in ["max_tokens", "temperature"]} + } + + logger.info(f"Sending request to OpenAI with model {self.model}") + + try: + response = self.client.chat.completions.create(**params) + return response.choices[0].message.content.strip() + except Exception as e: + logger.error(f"Error communicating with OpenAI: {e}") + return f"Error: Could not get a response from OpenAI. 
Details: {e}" class SimulatedChat(LLMInterface): """A simple simulated chat for testing and development.""" @@ -127,9 +178,11 @@ def get_llm(llm_config: Optional[Dict[str, Any]] = None) -> LLMInterface: logger.info(f"Attempting to create LLM of type='{llm_type}' with model='{model}'") if llm_type == "ollama": - return OllamaChat(model=model, host=llm_config.get("host")) + return OllamaChat(model=model or "llama3:8b", host=llm_config.get("host", "http://localhost:11434")) elif llm_type == "hf": - return HFChat(model_name=model) + return HFChat(model_name=model or "deepseek-ai/deepseek-llm-7b-chat") + elif llm_type == "openai": + return OpenAIChat(model=model or "gpt-4o", api_key=llm_config.get("api_key")) elif llm_type == "simulated": return SimulatedChat() else: diff --git a/pyproject.toml b/pyproject.toml index 2ce90a7..349689c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ dependencies = [ "colorama", "boto3", "protobuf==4.25.3", - "sglang[all]", + "sglang", "ollama", "requests>=2.25.0", "sentence-transformers>=2.2.0", diff --git a/test/mail_reader_llamaindex.py b/test/mail_reader_llamaindex.py new file mode 100644 index 0000000..d0a8bdc --- /dev/null +++ b/test/mail_reader_llamaindex.py @@ -0,0 +1,147 @@ +import os +import email +from pathlib import Path +from typing import List, Any +from llama_index.core import VectorStoreIndex, Document +from llama_index.core.readers.base import BaseReader + +class EmlxReader(BaseReader): + """ + Apple Mail .emlx file reader. + + Reads individual .emlx files from Apple Mail's storage format. + """ + + def __init__(self) -> None: + """Initialize.""" + pass + + def load_data(self, input_dir: str, **load_kwargs: Any) -> List[Document]: + """ + Load data from the input directory containing .emlx files. + + Args: + input_dir: Directory containing .emlx files + **load_kwargs: + max_count (int): Maximum amount of messages to read. 
+ """ + docs: List[Document] = [] + max_count = load_kwargs.get('max_count', 1000) + count = 0 + + # Walk through the directory recursively + for dirpath, dirnames, filenames in os.walk(input_dir): + # Skip hidden directories + dirnames[:] = [d for d in dirnames if not d.startswith(".")] + + for filename in filenames: + if count >= max_count: + break + + if filename.endswith(".emlx"): + filepath = os.path.join(dirpath, filename) + try: + # Read the .emlx file + with open(filepath, 'r', encoding='utf-8', errors='ignore') as f: + content = f.read() + + # .emlx files have a length prefix followed by the email content + # The first line contains the length, followed by the email + lines = content.split('\n', 1) + if len(lines) >= 2: + email_content = lines[1] + + # Parse the email using Python's email module + try: + msg = email.message_from_string(email_content) + + # Extract email metadata + subject = msg.get('Subject', 'No Subject') + from_addr = msg.get('From', 'Unknown') + to_addr = msg.get('To', 'Unknown') + date = msg.get('Date', 'Unknown') + + # Extract email body + body = "" + if msg.is_multipart(): + for part in msg.walk(): + if part.get_content_type() == "text/plain" or part.get_content_type() == "text/html": + body += part.get_payload(decode=True).decode('utf-8', errors='ignore') + # break + else: + body = msg.get_payload(decode=True).decode('utf-8', errors='ignore') + + # Create document content + doc_content = f""" +From: {from_addr} +To: {to_addr} +Subject: {subject} +Date: {date} + +{body} +""" + + # Create metadata + metadata = { + 'file_path': filepath, + 'subject': subject, + 'from': from_addr, + 'to': to_addr, + 'date': date, + 'filename': filename + } + if count == 0: + print("--------------------------------") + print('dir path', dirpath) + print(metadata) + print(doc_content) + print("--------------------------------") + body=[] + if msg.is_multipart(): + for part in msg.walk(): + print("-------------------------------- get content type -------------------------------") + print(part.get_content_type()) + print(part) + # body.append(part.get_payload(decode=True).decode('utf-8', errors='ignore')) + print("-------------------------------- get content type -------------------------------") + else: + body = msg.get_payload(decode=True).decode('utf-8', errors='ignore') + print(body) + + print(body) + print("--------------------------------") + doc = Document(text=doc_content, metadata=metadata) + docs.append(doc) + count += 1 + + except Exception as e: + print(f"!!!!!!! Error parsing email from {filepath}: {e} !!!!!!!!") + continue + + except Exception as e: + print(f"!!!!!!! Error reading file !!!!!!!! 
{filepath}: {e}") + continue + + print(f"Loaded {len(docs)} email documents") + return docs + +# Use the custom EmlxReader instead of MboxReader +documents = EmlxReader().load_data( + "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages", + max_count=1000 +) # Returns list of documents + +# Configure the index with larger chunk size to handle long metadata +from llama_index.core.node_parser import SentenceSplitter + +# Create a custom text splitter with larger chunk size +text_splitter = SentenceSplitter(chunk_size=2048, chunk_overlap=200) + +index = VectorStoreIndex.from_documents( + documents, + transformations=[text_splitter] +) # Initialize index with documents + +query_engine = index.as_query_engine() +res = query_engine.query("Hows Berkeley Graduate Student Instructor") +print(res) \ No newline at end of file diff --git a/test/mail_reader_save_load.py b/test/mail_reader_save_load.py new file mode 100644 index 0000000..60329b5 --- /dev/null +++ b/test/mail_reader_save_load.py @@ -0,0 +1,213 @@ +import os +import email +from pathlib import Path +from typing import List, Any +from llama_index.core import VectorStoreIndex, Document, StorageContext +from llama_index.core.readers.base import BaseReader +from llama_index.core.node_parser import SentenceSplitter + +class EmlxReader(BaseReader): + """ + Apple Mail .emlx file reader. + + Reads individual .emlx files from Apple Mail's storage format. + """ + + def __init__(self) -> None: + """Initialize.""" + pass + + def load_data(self, input_dir: str, **load_kwargs: Any) -> List[Document]: + """ + Load data from the input directory containing .emlx files. + + Args: + input_dir: Directory containing .emlx files + **load_kwargs: + max_count (int): Maximum amount of messages to read. 
+ """ + docs: List[Document] = [] + max_count = load_kwargs.get('max_count', 1000) + count = 0 + + # Walk through the directory recursively + for dirpath, dirnames, filenames in os.walk(input_dir): + # Skip hidden directories + dirnames[:] = [d for d in dirnames if not d.startswith(".")] + + for filename in filenames: + if count >= max_count: + break + + if filename.endswith(".emlx"): + filepath = os.path.join(dirpath, filename) + try: + # Read the .emlx file + with open(filepath, 'r', encoding='utf-8', errors='ignore') as f: + content = f.read() + + # .emlx files have a length prefix followed by the email content + # The first line contains the length, followed by the email + lines = content.split('\n', 1) + if len(lines) >= 2: + email_content = lines[1] + + # Parse the email using Python's email module + try: + msg = email.message_from_string(email_content) + + # Extract email metadata + subject = msg.get('Subject', 'No Subject') + from_addr = msg.get('From', 'Unknown') + to_addr = msg.get('To', 'Unknown') + date = msg.get('Date', 'Unknown') + + # Extract email body + body = "" + if msg.is_multipart(): + for part in msg.walk(): + if part.get_content_type() == "text/plain": + body = part.get_payload(decode=True).decode('utf-8', errors='ignore') + break + else: + body = msg.get_payload(decode=True).decode('utf-8', errors='ignore') + + # Create document content + doc_content = f""" +From: {from_addr} +To: {to_addr} +Subject: {subject} +Date: {date} + +{body} +""" + + # Create metadata + metadata = { + 'file_path': filepath, + 'subject': subject, + 'from': from_addr, + 'to': to_addr, + 'date': date, + 'filename': filename + } + + doc = Document(text=doc_content, metadata=metadata) + docs.append(doc) + count += 1 + + except Exception as e: + print(f"Error parsing email from {filepath}: {e}") + continue + + except Exception as e: + print(f"Error reading file {filepath}: {e}") + continue + + print(f"Loaded {len(docs)} email documents") + return docs + +def create_and_save_index(mail_path: str, save_dir: str = "mail_index", max_count: int = 1000): + """ + Create the index from mail data and save it to disk. + + Args: + mail_path: Path to the mail directory + save_dir: Directory to save the index + max_count: Maximum number of emails to process + """ + print("Creating index from mail data...") + + # Load documents + documents = EmlxReader().load_data(mail_path, max_count=max_count) + + if not documents: + print("No documents loaded. Exiting.") + return None + + # Create text splitter + text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=0) + + # Create index + index = VectorStoreIndex.from_documents( + documents, + transformations=[text_splitter] + ) + + # Save the index + os.makedirs(save_dir, exist_ok=True) + index.storage_context.persist(persist_dir=save_dir) + print(f"Index saved to {save_dir}") + + return index + +def load_index(save_dir: str = "mail_index"): + """ + Load the saved index from disk. + + Args: + save_dir: Directory where the index is saved + + Returns: + Loaded index or None if loading fails + """ + try: + # Load storage context + storage_context = StorageContext.from_defaults(persist_dir=save_dir) + + # Load index + index = VectorStoreIndex.from_vector_store( + storage_context.vector_store, + storage_context=storage_context + ) + + print(f"Index loaded from {save_dir}") + return index + + except Exception as e: + print(f"Error loading index: {e}") + return None + +def query_index(index, query: str): + """ + Query the loaded index. 
+ + Args: + index: The loaded index + query: The query string + """ + if index is None: + print("No index available for querying.") + return + + query_engine = index.as_query_engine() + response = query_engine.query(query) + print(f"Query: {query}") + print(f"Response: {response}") + +def main(): + mail_path = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages" + save_dir = "mail_index" + + # Check if index already exists + if os.path.exists(save_dir) and os.path.exists(os.path.join(save_dir, "vector_store.json")): + print("Loading existing index...") + index = load_index(save_dir) + else: + print("Creating new index...") + index = create_and_save_index(mail_path, save_dir, max_count=1000) + + if index: + # Example queries + queries = [ + "Hows Berkeley Graduate Student Instructor", + "What emails mention GSR appointments?", + "Find emails about deadlines" + ] + + for query in queries: + print("\n" + "="*50) + query_index(index, query) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/mail_reader_small_chunks.py b/test/mail_reader_small_chunks.py new file mode 100644 index 0000000..024a1d1 --- /dev/null +++ b/test/mail_reader_small_chunks.py @@ -0,0 +1,211 @@ +import os +import email +from pathlib import Path +from typing import List, Any +from llama_index.core import VectorStoreIndex, Document, StorageContext +from llama_index.core.readers.base import BaseReader +from llama_index.core.node_parser import SentenceSplitter + +class EmlxReader(BaseReader): + """ + Apple Mail .emlx file reader with reduced metadata. + + Reads individual .emlx files from Apple Mail's storage format. + """ + + def __init__(self) -> None: + """Initialize.""" + pass + + def load_data(self, input_dir: str, **load_kwargs: Any) -> List[Document]: + """ + Load data from the input directory containing .emlx files. + + Args: + input_dir: Directory containing .emlx files + **load_kwargs: + max_count (int): Maximum amount of messages to read. 
+ """ + docs: List[Document] = [] + max_count = load_kwargs.get('max_count', 1000) + count = 0 + + # Walk through the directory recursively + for dirpath, dirnames, filenames in os.walk(input_dir): + # Skip hidden directories + dirnames[:] = [d for d in dirnames if not d.startswith(".")] + + for filename in filenames: + if count >= max_count: + break + + if filename.endswith(".emlx"): + filepath = os.path.join(dirpath, filename) + try: + # Read the .emlx file + with open(filepath, 'r', encoding='utf-8', errors='ignore') as f: + content = f.read() + + # .emlx files have a length prefix followed by the email content + # The first line contains the length, followed by the email + lines = content.split('\n', 1) + if len(lines) >= 2: + email_content = lines[1] + + # Parse the email using Python's email module + try: + msg = email.message_from_string(email_content) + + # Extract email metadata + subject = msg.get('Subject', 'No Subject') + from_addr = msg.get('From', 'Unknown') + to_addr = msg.get('To', 'Unknown') + date = msg.get('Date', 'Unknown') + + # Extract email body + body = "" + if msg.is_multipart(): + for part in msg.walk(): + if part.get_content_type() == "text/plain": + body = part.get_payload(decode=True).decode('utf-8', errors='ignore') + break + else: + body = msg.get_payload(decode=True).decode('utf-8', errors='ignore') + + # Create document content with metadata embedded in text + doc_content = f""" +From: {from_addr} +To: {to_addr} +Subject: {subject} +Date: {date} + +{body} +""" + + # Create minimal metadata (only essential info) + metadata = { + 'subject': subject[:50], # Truncate subject + 'from': from_addr[:30], # Truncate from + 'date': date[:20], # Truncate date + 'filename': filename # Keep filename + } + + doc = Document(text=doc_content, metadata=metadata) + docs.append(doc) + count += 1 + + except Exception as e: + print(f"Error parsing email from {filepath}: {e}") + continue + + except Exception as e: + print(f"Error reading file {filepath}: {e}") + continue + + print(f"Loaded {len(docs)} email documents") + return docs + +def create_and_save_index(mail_path: str, save_dir: str = "mail_index_small", max_count: int = 1000): + """ + Create the index from mail data and save it to disk. + + Args: + mail_path: Path to the mail directory + save_dir: Directory to save the index + max_count: Maximum number of emails to process + """ + print("Creating index from mail data with small chunks...") + + # Load documents + documents = EmlxReader().load_data(mail_path, max_count=max_count) + + if not documents: + print("No documents loaded. Exiting.") + return None + + # Create text splitter with small chunk size + text_splitter = SentenceSplitter(chunk_size=512, chunk_overlap=50) + + # Create index + index = VectorStoreIndex.from_documents( + documents, + transformations=[text_splitter] + ) + + # Save the index + os.makedirs(save_dir, exist_ok=True) + index.storage_context.persist(persist_dir=save_dir) + print(f"Index saved to {save_dir}") + + return index + +def load_index(save_dir: str = "mail_index_small"): + """ + Load the saved index from disk. 
+ + Args: + save_dir: Directory where the index is saved + + Returns: + Loaded index or None if loading fails + """ + try: + # Load storage context + storage_context = StorageContext.from_defaults(persist_dir=save_dir) + + # Load index + index = VectorStoreIndex.from_vector_store( + storage_context.vector_store, + storage_context=storage_context + ) + + print(f"Index loaded from {save_dir}") + return index + + except Exception as e: + print(f"Error loading index: {e}") + return None + +def query_index(index, query: str): + """ + Query the loaded index. + + Args: + index: The loaded index + query: The query string + """ + if index is None: + print("No index available for querying.") + return + + query_engine = index.as_query_engine() + response = query_engine.query(query) + print(f"Query: {query}") + print(f"Response: {response}") + +def main(): + mail_path = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages" + save_dir = "mail_index_small" + + # Check if index already exists + if os.path.exists(save_dir) and os.path.exists(os.path.join(save_dir, "vector_store.json")): + print("Loading existing index...") + index = load_index(save_dir) + else: + print("Creating new index...") + index = create_and_save_index(mail_path, save_dir, max_count=1000) + + if index: + # Example queries + queries = [ + "Hows Berkeley Graduate Student Instructor", + "What emails mention GSR appointments?", + "Find emails about deadlines" + ] + + for query in queries: + print("\n" + "="*50) + query_index(index, query) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/mail_reader_test.py b/test/mail_reader_test.py new file mode 100644 index 0000000..9dfd6b6 --- /dev/null +++ b/test/mail_reader_test.py @@ -0,0 +1,147 @@ +import os +import email +from pathlib import Path +from typing import List, Any +from llama_index.core import VectorStoreIndex, Document +from llama_index.core.readers.base import BaseReader + +class EmlxReader(BaseReader): + """ + Apple Mail .emlx file reader. + + Reads individual .emlx files from Apple Mail's storage format. + """ + + def __init__(self) -> None: + """Initialize.""" + pass + + def load_data(self, input_dir: str, **load_kwargs: Any) -> List[Document]: + """ + Load data from the input directory containing .emlx files. + + Args: + input_dir: Directory containing .emlx files + **load_kwargs: + max_count (int): Maximum amount of messages to read. 
+ """ + docs: List[Document] = [] + max_count = load_kwargs.get('max_count', 1000) + count = 0 + + # Check if directory exists and is accessible + if not os.path.exists(input_dir): + print(f"Error: Directory '{input_dir}' does not exist") + return docs + + if not os.access(input_dir, os.R_OK): + print(f"Error: Directory '{input_dir}' is not accessible (permission denied)") + print("This is likely due to macOS security restrictions on Mail app data") + return docs + + print(f"Scanning directory: {input_dir}") + + # Walk through the directory recursively + for dirpath, dirnames, filenames in os.walk(input_dir): + # Skip hidden directories + dirnames[:] = [d for d in dirnames if not d.startswith(".")] + + for filename in filenames: + if count >= max_count: + break + + if filename.endswith(".emlx"): + filepath = os.path.join(dirpath, filename) + print(f"Found .emlx file: {filepath}") + try: + # Read the .emlx file + with open(filepath, 'r', encoding='utf-8', errors='ignore') as f: + content = f.read() + + # .emlx files have a length prefix followed by the email content + # The first line contains the length, followed by the email + lines = content.split('\n', 1) + if len(lines) >= 2: + email_content = lines[1] + + # Parse the email using Python's email module + try: + msg = email.message_from_string(email_content) + + # Extract email metadata + subject = msg.get('Subject', 'No Subject') + from_addr = msg.get('From', 'Unknown') + to_addr = msg.get('To', 'Unknown') + date = msg.get('Date', 'Unknown') + + # Extract email body + body = "" + if msg.is_multipart(): + for part in msg.walk(): + if part.get_content_type() == "text/plain": + body = part.get_payload(decode=True).decode('utf-8', errors='ignore') + break + else: + body = msg.get_payload(decode=True).decode('utf-8', errors='ignore') + + # Create document content + doc_content = f""" +From: {from_addr} +To: {to_addr} +Subject: {subject} +Date: {date} + +{body} +""" + + # Create metadata + metadata = { + 'file_path': filepath, + 'subject': subject, + 'from': from_addr, + 'to': to_addr, + 'date': date, + 'filename': filename + } + + doc = Document(text=doc_content, metadata=metadata) + docs.append(doc) + count += 1 + + except Exception as e: + print(f"Error parsing email from {filepath}: {e}") + continue + + except Exception as e: + print(f"Error reading file {filepath}: {e}") + continue + + print(f"Loaded {len(docs)} email documents") + return docs + +def main(): + # Use the current directory where the sample.emlx file is located + current_dir = os.path.dirname(os.path.abspath(__file__)) + + print("Testing EmlxReader with sample .emlx file...") + print(f"Scanning directory: {current_dir}") + + # Use the custom EmlxReader + documents = EmlxReader().load_data(current_dir, max_count=1000) + + if not documents: + print("No documents loaded. 
Make sure sample.emlx exists in the examples directory.") + return + + print(f"\nSuccessfully loaded {len(documents)} document(s)") + + # Initialize index with documents + index = VectorStoreIndex.from_documents(documents) + query_engine = index.as_query_engine() + + print("\nTesting query: 'Hows Berkeley Graduate Student Instructor'") + res = query_engine.query("Hows Berkeley Graduate Student Instructor") + print(f"Response: {res}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/query_saved_index.py b/test/query_saved_index.py new file mode 100644 index 0000000..ac3989d --- /dev/null +++ b/test/query_saved_index.py @@ -0,0 +1,99 @@ +import os +from llama_index.core import VectorStoreIndex, StorageContext + +def load_index(save_dir: str = "mail_index"): + """ + Load the saved index from disk. + + Args: + save_dir: Directory where the index is saved + + Returns: + Loaded index or None if loading fails + """ + try: + # Load storage context + storage_context = StorageContext.from_defaults(persist_dir=save_dir) + + # Load index + index = VectorStoreIndex.from_vector_store( + storage_context.vector_store, + storage_context=storage_context + ) + + print(f"Index loaded from {save_dir}") + return index + + except Exception as e: + print(f"Error loading index: {e}") + return None + +def query_index(index, query: str): + """ + Query the loaded index. + + Args: + index: The loaded index + query: The query string + """ + if index is None: + print("No index available for querying.") + return + + query_engine = index.as_query_engine() + response = query_engine.query(query) + print(f"\nQuery: {query}") + print(f"Response: {response}") + +def main(): + save_dir = "mail_index" + + # Check if index exists + if not os.path.exists(save_dir) or not os.path.exists(os.path.join(save_dir, "vector_store.json")): + print(f"Index not found in {save_dir}") + print("Please run mail_reader_save_load.py first to create the index.") + return + + # Load the index + index = load_index(save_dir) + + if not index: + print("Failed to load index.") + return + + print("\n" + "="*60) + print("Email Query Interface") + print("="*60) + print("Type 'quit' to exit") + print("Type 'help' for example queries") + print("="*60) + + # Interactive query loop + while True: + try: + query = input("\nEnter your query: ").strip() + + if query.lower() == 'quit': + print("Goodbye!") + break + elif query.lower() == 'help': + print("\nExample queries:") + print("- Hows Berkeley Graduate Student Instructor") + print("- What emails mention GSR appointments?") + print("- Find emails about deadlines") + print("- Search for emails from specific sender") + print("- Find emails about meetings") + continue + elif not query: + continue + + query_index(index, query) + + except KeyboardInterrupt: + print("\nGoodbye!") + break + except Exception as e: + print(f"Error processing query: {e}") + +if __name__ == "__main__": + main() \ No newline at end of file
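
The `compute_embeddings()` hunk in `packages/leann-core/src/leann/api.py` calls `model.half()` unconditionally and then moves the model to CUDA or MPS. Below is a minimal standalone sketch of that device selection, assuming `sentence-transformers` and `torch` are installed; as a variation on the patch it gates fp16 on an accelerator actually being available, since half precision on CPU is typically slower. The model name is illustrative only.

```python
# Sketch of the device/precision selection added to compute_embeddings() (not the exact patch).
import torch
from sentence_transformers import SentenceTransformer

def embed(chunks, model_name="sentence-transformers/all-MiniLM-L6-v2"):
    model = SentenceTransformer(model_name)
    # Pick an accelerator first, then optionally drop to fp16; leave fp32 on CPU/MPS.
    if torch.cuda.is_available():
        model = model.to("cuda").half()
    elif torch.backends.mps.is_available():
        model = model.to("mps")
    return model.encode(
        chunks, convert_to_numpy=True, show_progress_bar=True, batch_size=64
    )
```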
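
The new `OpenAIChat` backend in `packages/leann-core/src/leann/chat.py` is reachable through `get_llm()` when `llm_type == "openai"`. A usage sketch follows, assuming the `leann-core` package is importable as `leann` and that the config dict uses `"type"` and `"model"` keys (those key names are not visible in this hunk); the API key value is a placeholder.

```python
import os

from leann.chat import get_llm

# OpenAIChat reads OPENAI_API_KEY when no api_key is supplied in the config.
os.environ.setdefault("OPENAI_API_KEY", "sk-...")  # placeholder

llm = get_llm({"type": "openai", "model": "gpt-4o"})
print(llm.ask("Summarize the retrieved context in one sentence.",
              max_tokens=64, temperature=0.2))
```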
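
The four `test/mail_reader_*.py` scripts each embed the same `EmlxReader` logic: strip the leading byte-count line of the `.emlx` file, parse the remainder with the stdlib `email` module, and pull a plain-text body. A self-contained sketch of that shared handling (Apple Mail also appends a property-list blob after the message; like the scripts above, this ignores it):

```python
import email
from email.message import Message

def parse_emlx(path: str) -> Message:
    """Parse one Apple Mail .emlx file into an email.message.Message."""
    with open(path, "r", encoding="utf-8", errors="ignore") as f:
        content = f.read()
    # The first line is the message length in bytes; the email follows it.
    _, _, payload = content.partition("\n")
    return email.message_from_string(payload)

def plain_text_body(msg: Message) -> str:
    """Return the first text/plain part, or the whole payload for non-multipart mail."""
    if msg.is_multipart():
        for part in msg.walk():
            if part.get_content_type() == "text/plain":
                return part.get_payload(decode=True).decode("utf-8", errors="ignore")
        return ""
    return msg.get_payload(decode=True).decode("utf-8", errors="ignore")
```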
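
`mail_reader_save_load.py`, `mail_reader_small_chunks.py`, and `query_saved_index.py` all reload the persisted index with `VectorStoreIndex.from_vector_store()`. A commonly used alternative, sketched here under the same default local persistence directory, is `load_index_from_storage`, which reconstructs the index from the full storage context (docstore, index store, and vector store):

```python
from llama_index.core import StorageContext, load_index_from_storage

# Reload the index persisted by create_and_save_index() into "mail_index".
storage_context = StorageContext.from_defaults(persist_dir="mail_index")
index = load_index_from_storage(storage_context)

print(index.as_query_engine().query("What emails mention GSR appointments?"))
```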