Compare commits

...

25 Commits

Author SHA1 Message Date
Andy Lee
ec5e9ac33b feat: chat on mps 2025-07-12 06:07:43 +00:00
Andy Lee
d288946173 Merge remote-tracking branch 'origin/main' into datastore-reproduce 2025-07-12 05:42:16 +00:00
Andy Lee
0da08fbe38 refactor: chat and base searcher 2025-07-11 16:34:12 +00:00
Andy Lee
8bffb1e5b8 feat: reproducible research datas, rpj_wiki & dpr 2025-07-11 02:58:04 +00:00
yichuan520030910320
16ee9d0422 add traverse all dict interface 2025-07-10 15:59:16 -07:00
yichuan520030910320
8a961f8ab3 align the llamaindex result w leann& test attachment 2025-07-09 21:42:15 -07:00
yichuan520030910320
558126c46e add leann and llamaindex email infra, and need to align the results 2025-07-09 16:27:11 -07:00
yichuan520030910320
04c9684488 add email test code 2025-07-09 15:06:31 -07:00
Andy Lee
b744faa7e6 chore: all deps 2025-07-08 23:37:40 +00:00
Andy Lee
27b3a26e75 fix(deps): Update DiskANN with cleaned up CMake configuration 2025-07-08 23:27:05 +00:00
Andy Lee
41d872504e feat(deps): Update DiskANN to use system-installed Boost and Protobuf 2025-07-08 23:13:36 +00:00
Andy Lee
963cd05273 chore: diskann modules 2025-07-08 21:57:38 +00:00
Andy Lee
09b6e67baf chore: diskann upg boost 2025-07-08 21:44:44 +00:00
yichuan520030910320
dafb2aacab update macos env 2025-07-08 14:37:41 -07:00
Andy Lee
a6c400cd4f chroe: linux boost and protobuf 2025-07-08 21:25:43 +00:00
Andy Lee
c013e5ccce chore: linux deps 2025-07-08 13:55:39 -07:00
Andy Lee
f25a1a3840 chore: macos compatible 2025-07-08 13:32:00 -07:00
yichuan520030910320
6497e17671 add gpu chunk embedd and add complexity in hnsw 2025-07-08 18:40:52 +00:00
yichuan520030910320
44369a8138 update diskann module 2025-07-07 18:27:07 -07:00
yichuan520030910320
dfca00c21b add mac support in this repo 2025-07-07 18:22:24 -07:00
yichuan520030910320
637dab379e add workaround code 2025-07-07 23:13:47 +00:00
yichuan520030910320
6fc57eb48e add reuse code 2025-07-07 21:07:00 +00:00
yichuan520030910320
95a653993a rm useless 2025-07-06 06:47:20 +00:00
yichuan520030910320
af0959818d rm useless 2025-07-06 05:21:05 +00:00
Andy Lee
cf17c85607 Make DiskANN and HNSW work on main example (#2)
* fix: diskann zmq port and passages

* feat: auto discovery of packages and fix passage gen for diskann
2025-07-05 22:18:12 -07:00
33 changed files with 2846 additions and 1049 deletions

5
.gitignore vendored
View File

@@ -8,11 +8,15 @@ demo/indices/
*pycache*
outputs/
*.pkl
*.pdf
.history/
scripts/
lm_eval.egg-info/
demo/experiment_results/**/*.json
*.jsonl
*.eml
*.emlx
*.json
*.sh
*.txt
!CMakeLists.txt
@@ -42,6 +46,7 @@ embedding_comparison_results/
*.ivecs
*.index
*.bin
*.old
read_graph
analyze_diskann_graph

10
.gitmodules vendored
View File

@@ -4,3 +4,13 @@
[submodule "packages/leann-backend-hnsw/third_party/faiss"]
path = packages/leann-backend-hnsw/third_party/faiss
url = https://github.com/yichuan520030910320/faiss.git
[submodule "packages/leann-backend-hnsw/third_party/msgpack-c"]
path = packages/leann-backend-hnsw/third_party/msgpack-c
url = https://github.com/msgpack/msgpack-c.git
branch = cpp_master
[submodule "packages/leann-backend-hnsw/third_party/cppzmq"]
path = packages/leann-backend-hnsw/third_party/cppzmq
url = https://github.com/zeromq/cppzmq.git
[submodule "packages/leann-backend-hnsw/third_party/libzmq"]
path = packages/leann-backend-hnsw/third_party/libzmq
url = https://github.com/zeromq/libzmq.git

105
README.md
View File

@@ -28,13 +28,15 @@
### 🎯 Why Leann?
Traditional RAG systems face a fundamental trade-off:
- **💾 Storage**: Storing embeddings for millions of documents requires massive disk space
- **🔄 Freshness**: Pre-computed embeddings become stale when documents change
- **💰 Cost**: Vector databases are expensive to scale
**Leann solves this by:**
-**Zero embedding storage** - Only graph structure is persisted
-**Real-time computation** - Embeddings computed on-demand with ms latency
-**Real-time computation** - Embeddings computed on-demand with ms latency
-**Memory efficient** - Runs on consumer hardware (8GB RAM)
-**Always fresh** - No stale embeddings, ever
@@ -46,6 +48,19 @@ Traditional RAG systems face a fundamental trade-off:
git clone git@github.com:yichuan520030910320/LEANN-RAG.git leann
cd leann
git submodule update --init --recursive
```
**macOS:**
```bash
brew install llvm libomp boost protobuf
export CC=$(brew --prefix llvm)/bin/clang
export CXX=$(brew --prefix llvm)/bin/clang++
uv sync
```
**Linux (Ubuntu/Debian):**
```bash
sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler libabsl-dev libmkl-full-dev libaio-dev
uv sync
```
@@ -75,19 +90,39 @@ for result in results:
uv run examples/document_search.py
```
or you want to use python
```bash
source .venv/bin/activate
python ./examples/main_cli_example.py
```
**PDF RAG Demo (using LlamaIndex for document parsing and Leann for indexing/search)**
This demo showcases how to build a RAG system for PDF documents using Leann.
1. Place your PDF files (and other supported formats like .docx, .pptx, .xlsx) into the `examples/data/` directory.
2. Ensure you have an `OPENAI_API_KEY` set in your environment variables or in a `.env` file for the LLM to function.
1. Place your PDF files (and other supported formats like .docx, .pptx, .xlsx) into the `examples/data/` directory.
2. Ensure you have an `OPENAI_API_KEY` set in your environment variables or in a `.env` file for the LLM to function.
```bash
uv run examples/main_cli_example.py
```
### Regenerating Protobuf Files
If you modify any `.proto` files (such as `embedding.proto`), or if you see errors about protobuf version mismatch, **regenerate the C++ protobuf files** to match your installed version:
```bash
cd packages/leann-backend-diskann
protoc --cpp_out=third_party/DiskANN/include --proto_path=third_party embedding.proto
protoc --cpp_out=third_party/DiskANN/src --proto_path=third_party embedding.proto
```
This ensures the generated files are compatible with your system's protobuf library.
## ✨ Features
### 🔥 Core Features
- **📊 Multiple Distance Functions**: L2, Cosine, MIPS (Maximum Inner Product Search)
- **🏗️ Pluggable Backends**: DiskANN, HNSW/FAISS with unified API
- **🔄 Real-time Embeddings**: Dynamic computation using optimized ZMQ servers
@@ -95,6 +130,7 @@ uv run examples/main_cli_example.py
- **🎯 Graph Pruning**: Advanced techniques for memory-efficient search
### 🛠️ Technical Highlights
- **Zero-copy operations** for maximum performance
- **SIMD-optimized** distance computations (AVX2/AVX512)
- **Async embedding pipeline** with batched processing
@@ -102,6 +138,7 @@ uv run examples/main_cli_example.py
- **Recompute mode** for highest accuracy scenarios
### 🎨 Developer Experience
- **Simple Python API** - Get started in minutes
- **Extensible backend system** - Easy to add new algorithms
- **Comprehensive examples** - From basic usage to production deployment
@@ -111,19 +148,19 @@ uv run examples/main_cli_example.py
### Memory Usage Comparison
| System | 1M Documents | 10M Documents | 100M Documents |
|--------|-------------|---------------|----------------|
| Traditional Vector DB | 3.1 GB | 31 GB | 310 GB |
| **Leann** | **180 MB** | **1.2 GB** | **8.4 GB** |
| **Reduction** | **94.2%** | **96.1%** | **97.3%** |
| System | 1M Documents | 10M Documents | 100M Documents |
| --------------------- | ---------------- | ---------------- | ---------------- |
| Traditional Vector DB | 3.1 GB | 31 GB | 310 GB |
| **Leann** | **180 MB** | **1.2 GB** | **8.4 GB** |
| **Reduction** | **94.2%** | **96.1%** | **97.3%** |
### Query Performance
| Backend | Index Size | Query Time | Recall@10 |
|---------|------------|------------|-----------|
| DiskANN | 1M docs | 12ms | 0.95 |
| DiskANN + Recompute | 1M docs | 145ms | 0.98 |
| HNSW | 1M docs | 8ms | 0.93 |
| Backend | Index Size | Query Time | Recall@10 |
| ------------------- | ---------- | ---------- | --------- |
| DiskANN | 1M docs | 12ms | 0.95 |
| DiskANN + Recompute | 1M docs | 145ms | 0.98 |
| HNSW | 1M docs | 8ms | 0.93 |
*Benchmarks run on AMD Ryzen 7 with 32GB RAM*
@@ -145,26 +182,29 @@ uv run examples/main_cli_example.py
### Key Components
1. **🧠 Embedding Engine**: Real-time transformer inference with caching
2. **📊 Graph Index**: Memory-efficient navigation structures
2. **📊 Graph Index**: Memory-efficient navigation structures
3. **🔄 Search Coordinator**: Orchestrates embedding + graph search
4. **⚡ Backend Adapters**: Pluggable algorithm implementations
## 🎓 Supported Models & Backends
### 🤖 Embedding Models
- **sentence-transformers/all-mpnet-base-v2** (default)
- **sentence-transformers/all-MiniLM-L6-v2** (lightweight)
- Any HuggingFace sentence-transformer model
- Custom model support via API
### 🔧 Search Backends
### 🔧 Search Backends
- **DiskANN**: Microsoft's billion-scale ANN algorithm
- **HNSW**: Hierarchical Navigable Small World graphs
- **Coming soon**: ScaNN, Faiss-IVF, NGT
### 📏 Distance Functions
- **L2**: Euclidean distance for precise similarity
- **Cosine**: Angular similarity for normalized vectors
- **Cosine**: Angular similarity for normalized vectors
- **MIPS**: Maximum Inner Product Search for recommendation systems
## 🔬 Paper
@@ -188,6 +228,7 @@ If you find Leann useful, please cite:
## 🌍 Use Cases
### 💼 Enterprise RAG
```python
# Handle millions of documents with limited resources
builder = LeannBuilder(
@@ -198,7 +239,8 @@ builder = LeannBuilder(
)
```
### 🔬 Research & Experimentation
### 🔬 Research & Experimentation
```python
# Quick prototyping with different algorithms
for backend in ["diskann", "hnsw"]:
@@ -207,6 +249,7 @@ for backend in ["diskann", "hnsw"]:
```
### 🚀 Real-time Applications
```python
# Sub-second response times
chat = LeannChat("knowledge.leann")
@@ -219,6 +262,7 @@ response = chat.ask("What is quantum computing?")
We welcome contributions! Leann is built by the community, for the community.
### Ways to Contribute
- 🐛 **Bug Reports**: Found an issue? Let us know!
- 💡 **Feature Requests**: Have an idea? We'd love to hear it!
- 🔧 **Code Contributions**: PRs welcome for all skill levels
@@ -226,14 +270,17 @@ We welcome contributions! Leann is built by the community, for the community.
- 🧪 **Benchmarks**: Share your performance results
### Development Setup
```bash
git clone https://github.com/yourname/leann
git clone git@github.com:yichuan520030910320/LEANN-RAG.git leann
cd leann
git submodule update --init --recursive
uv sync --dev
uv run pytest tests/
```
### Quick Tests
```bash
# Sanity check all distance functions
uv run python tests/sanity_checks/test_distance_functions.py
@@ -241,17 +288,21 @@ uv run python tests/sanity_checks/test_distance_functions.py
# Verify L2 implementation
uv run python tests/sanity_checks/test_l2_verification.py
```
## ❓ FAQ
### Common Issues
#### NCCL Topology Error
**Problem**: You encounter `ncclTopoComputePaths` error during document processing:
```
ncclTopoComputePaths (system=<optimized out>, comm=comm@entry=0x5555a82fa3c0) at graph/paths.cc:688
```
**Solution**: Set these environment variables before running your script:
```bash
export NCCL_TOPO_DUMP_FILE=/tmp/nccl_topo.xml
export NCCL_DEBUG=INFO
@@ -259,23 +310,26 @@ export NCCL_DEBUG_SUBSYS=INIT,GRAPH
export NCCL_IB_DISABLE=1
export NCCL_NET_PLUGIN=none
export NCCL_SOCKET_IFNAME=ens5
```
## 📈 Roadmap
### 🎯 Q1 2024
- [x] DiskANN backend with MIPS/L2/Cosine support
- [x] HNSW backend integration
- [x] Real-time embedding pipeline
- [x] Memory-efficient graph pruning
- [X] DiskANN backend with MIPS/L2/Cosine support
- [X] HNSW backend integration
- [X] Real-time embedding pipeline
- [X] Memory-efficient graph pruning
### 🚀 Q2 2024
- [ ] Distributed search across multiple nodes
- [ ] ScaNN backend support
- [ ] Advanced caching strategies
- [ ] Kubernetes deployment guides
### 🌟 Q3 2024
- [ ] GPU-accelerated embedding computation
- [ ] Approximate distance functions
- [ ] Integration with LangChain/LlamaIndex
@@ -297,7 +351,7 @@ MIT License - see [LICENSE](LICENSE) for details.
## 🙏 Acknowledgments
- **Microsoft Research** for the DiskANN algorithm
- **Meta AI** for FAISS and optimization insights
- **Meta AI** for FAISS and optimization insights
- **HuggingFace** for the transformer ecosystem
- **Our amazing contributors** who make this possible
@@ -309,4 +363,5 @@ MIT License - see [LICENSE](LICENSE) for details.
<p align="center">
Made with ❤️ by the Leann team
</p>
</p>

View File

@@ -0,0 +1,130 @@
import os
import email
from pathlib import Path
from typing import List, Any
from llama_index.core import Document
from llama_index.core.readers.base import BaseReader
class EmlxReader(BaseReader):
"""
Apple Mail .emlx file reader with embedded metadata.
Reads individual .emlx files from Apple Mail's storage format.
"""
def __init__(self) -> None:
"""Initialize."""
pass
def load_data(self, input_dir: str, **load_kwargs: Any) -> List[Document]:
"""
Load data from the input directory containing .emlx files.
Args:
input_dir: Directory containing .emlx files
**load_kwargs:
max_count (int): Maximum amount of messages to read.
"""
docs: List[Document] = []
max_count = load_kwargs.get('max_count', 1000)
count = 0
# Walk through the directory recursively
for dirpath, dirnames, filenames in os.walk(input_dir):
# Skip hidden directories
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
for filename in filenames:
if count >= max_count:
break
if filename.endswith(".emlx"):
filepath = os.path.join(dirpath, filename)
try:
# Read the .emlx file
with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
content = f.read()
# .emlx files have a length prefix followed by the email content
# The first line contains the length, followed by the email
lines = content.split('\n', 1)
if len(lines) >= 2:
email_content = lines[1]
# Parse the email using Python's email module
try:
msg = email.message_from_string(email_content)
# Extract email metadata
subject = msg.get('Subject', 'No Subject')
from_addr = msg.get('From', 'Unknown')
to_addr = msg.get('To', 'Unknown')
date = msg.get('Date', 'Unknown')
# Extract email body
body = ""
if msg.is_multipart():
for part in msg.walk():
if part.get_content_type() == "text/plain" or part.get_content_type() == "text/html":
# if part.get_content_type() == "text/html":
# continue
body += part.get_payload(decode=True).decode('utf-8', errors='ignore')
# break
else:
body = msg.get_payload(decode=True).decode('utf-8', errors='ignore')
# Create document content with metadata embedded in text
doc_content = f"""
[EMAIL METADATA]
File: {filename}
From: {from_addr}
To: {to_addr}
Subject: {subject}
Date: {date}
[END METADATA]
{body}
"""
# No separate metadata - everything is in the text
doc = Document(text=doc_content, metadata={})
docs.append(doc)
count += 1
except Exception as e:
print(f"Error parsing email from {filepath}: {e}")
continue
except Exception as e:
print(f"Error reading file {filepath}: {e}")
continue
print(f"Loaded {len(docs)} email documents")
return docs
@staticmethod
def find_all_messages_directories(base_path: str) -> List[Path]:
"""
Find all Messages directories under the given base path.
Args:
base_path: Base path to search for Messages directories
Returns:
List of Path objects pointing to Messages directories
"""
base_path_obj = Path(base_path)
messages_dirs = []
if not base_path_obj.exists():
print(f"Base path {base_path} does not exist")
return messages_dirs
# Find all Messages directories recursively
for messages_dir in base_path_obj.rglob("Messages"):
if messages_dir.is_dir():
messages_dirs.append(messages_dir)
print(f"Found Messages directory: {messages_dir}")
print(f"Found {len(messages_dirs)} Messages directories")
return messages_dirs

View File

@@ -0,0 +1,192 @@
"""
Mbox parser.
Contains simple parser for mbox files.
"""
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional
from fsspec import AbstractFileSystem
from llama_index.core.readers.base import BaseReader
from llama_index.core.schema import Document
logger = logging.getLogger(__name__)
class MboxReader(BaseReader):
"""
Mbox parser.
Extract messages from mailbox files.
Returns string including date, subject, sender, receiver and
content for each message.
"""
DEFAULT_MESSAGE_FORMAT: str = (
"Date: {_date}\n"
"From: {_from}\n"
"To: {_to}\n"
"Subject: {_subject}\n"
"Content: {_content}"
)
def __init__(
self,
*args: Any,
max_count: int = 0,
message_format: str = DEFAULT_MESSAGE_FORMAT,
**kwargs: Any,
) -> None:
"""Init params."""
try:
from bs4 import BeautifulSoup # noqa
except ImportError:
raise ImportError(
"`beautifulsoup4` package not found: `pip install beautifulsoup4`"
)
super().__init__(*args, **kwargs)
self.max_count = max_count
self.message_format = message_format
def load_data(
self,
file: Path,
extra_info: Optional[Dict] = None,
fs: Optional[AbstractFileSystem] = None,
) -> List[Document]:
"""Parse file into string."""
# Import required libraries
import mailbox
from email.parser import BytesParser
from email.policy import default
from bs4 import BeautifulSoup
if fs:
logger.warning(
"fs was specified but MboxReader doesn't support loading "
"from fsspec filesystems. Will load from local filesystem instead."
)
i = 0
results: List[str] = []
# Load file using mailbox
bytes_parser = BytesParser(policy=default).parse
mbox = mailbox.mbox(file, factory=bytes_parser) # type: ignore
# Iterate through all messages
for _, _msg in enumerate(mbox):
try:
msg: mailbox.mboxMessage = _msg
# Parse multipart messages
if msg.is_multipart():
for part in msg.walk():
ctype = part.get_content_type()
cdispo = str(part.get("Content-Disposition"))
if "attachment" in cdispo:
print(f"Attachment found: {part.get_filename()}")
if ctype == "text/plain" and "attachment" not in cdispo:
content = part.get_payload(decode=True) # decode
break
# Get plain message payload for non-multipart messages
else:
content = msg.get_payload(decode=True)
# Parse message HTML content and remove unneeded whitespace
soup = BeautifulSoup(content)
stripped_content = " ".join(soup.get_text().split())
# Format message to include date, sender, receiver and subject
msg_string = self.message_format.format(
_date=msg["date"],
_from=msg["from"],
_to=msg["to"],
_subject=msg["subject"],
_content=stripped_content,
)
# Add message string to results
results.append(msg_string)
except Exception as e:
logger.warning(f"Failed to parse message:\n{_msg}\n with exception {e}")
# Increment counter and return if max count is met
i += 1
if self.max_count > 0 and i >= self.max_count:
break
return [Document(text=result, metadata=extra_info or {}) for result in results]
class EmlxMboxReader(MboxReader):
"""
EmlxMboxReader - Modified MboxReader that handles directories of .emlx files.
Extends MboxReader to work with Apple Mail's .emlx format by:
1. Reading .emlx files from a directory
2. Converting them to mbox format in memory
3. Using the parent MboxReader's parsing logic
"""
def load_data(
self,
directory: Path,
extra_info: Optional[Dict] = None,
fs: Optional[AbstractFileSystem] = None,
) -> List[Document]:
"""Parse .emlx files from directory into strings using MboxReader logic."""
import tempfile
import os
if fs:
logger.warning(
"fs was specified but EmlxMboxReader doesn't support loading "
"from fsspec filesystems. Will load from local filesystem instead."
)
# Find all .emlx files in the directory
emlx_files = list(directory.glob("*.emlx"))
logger.info(f"Found {len(emlx_files)} .emlx files in {directory}")
if not emlx_files:
logger.warning(f"No .emlx files found in {directory}")
return []
# Create a temporary mbox file
with tempfile.NamedTemporaryFile(mode='w', suffix='.mbox', delete=False) as temp_mbox:
temp_mbox_path = temp_mbox.name
# Convert .emlx files to mbox format
for emlx_file in emlx_files:
try:
# Read the .emlx file
with open(emlx_file, 'r', encoding='utf-8', errors='ignore') as f:
content = f.read()
# .emlx format: first line is length, rest is email content
lines = content.split('\n', 1)
if len(lines) >= 2:
email_content = lines[1] # Skip the length line
# Write to mbox format (each message starts with "From " and ends with blank line)
temp_mbox.write(f"From {emlx_file.name} {email_content}\n\n")
except Exception as e:
logger.warning(f"Failed to process {emlx_file}: {e}")
continue
# Close the temporary file so MboxReader can read it
temp_mbox.close()
try:
# Use the parent MboxReader's logic to parse the mbox file
return super().load_data(Path(temp_mbox_path), extra_info, fs)
finally:
# Clean up temporary file
try:
os.unlink(temp_mbox_path)
except:
pass

View File

@@ -0,0 +1,229 @@
import os
import asyncio
import dotenv
from pathlib import Path
from typing import List, Any
from leann.api import LeannBuilder, LeannSearcher, LeannChat
from llama_index.core.node_parser import SentenceSplitter
dotenv.load_dotenv()
def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_path: str = "mail_index.leann", max_count: int = -1):
"""
Create LEANN index from multiple mail data sources.
Args:
messages_dirs: List of Path objects pointing to Messages directories
index_path: Path to save the LEANN index
max_count: Maximum number of emails to process per directory
"""
print("Creating LEANN index from multiple mail data sources...")
# Load documents using EmlxReader from LEANN_email_reader
from LEANN_email_reader import EmlxReader
reader = EmlxReader()
# from email_data.email import EmlxMboxReader
# from pathlib import Path
# reader = EmlxMboxReader()
all_documents = []
total_processed = 0
# Process each Messages directory
for i, messages_dir in enumerate(messages_dirs):
print(f"\nProcessing Messages directory {i+1}/{len(messages_dirs)}: {messages_dir}")
try:
documents = reader.load_data(messages_dir)
if documents:
print(f"Loaded {len(documents)} email documents from {messages_dir}")
all_documents.extend(documents)
total_processed += len(documents)
# Check if we've reached the max count
if max_count > 0 and total_processed >= max_count:
print(f"Reached max count of {max_count} documents")
break
else:
print(f"No documents loaded from {messages_dir}")
except Exception as e:
print(f"Error processing {messages_dir}: {e}")
continue
if not all_documents:
print("No documents loaded from any source. Exiting.")
return None
print(f"\nTotal loaded {len(all_documents)} email documents from {len(messages_dirs)} directories")
# Create text splitter with 256 chunk size
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
# Convert Documents to text strings and chunk them
all_texts = []
for doc in all_documents:
# Split the document into chunks
nodes = text_splitter.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.get_content())
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
# Create LEANN index directory
INDEX_DIR = Path(index_path).parent
if not INDEX_DIR.exists():
print(f"--- Index directory not found, building new index ---")
INDEX_DIR.mkdir(exist_ok=True)
print(f"--- Building new LEANN index ---")
print(f"\n[PHASE 1] Building Leann index...")
# Use HNSW backend for better macOS compatibility
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="facebook/contriever",
graph_degree=32,
complexity=64,
is_compact=True,
is_recompute=True,
num_threads=1 # Force single-threaded mode
)
print(f"Adding {len(all_texts)} email chunks to index...")
for chunk_text in all_texts:
builder.add_text(chunk_text)
builder.build_index(index_path)
print(f"\nLEANN index built at {index_path}!")
else:
print(f"--- Using existing index at {INDEX_DIR} ---")
return index_path
def create_leann_index(mail_path: str, index_path: str = "mail_index.leann", max_count: int = 1000):
"""
Create LEANN index from mail data.
Args:
mail_path: Path to the mail directory
index_path: Path to save the LEANN index
max_count: Maximum number of emails to process
"""
print("Creating LEANN index from mail data...")
# Load documents using EmlxReader from LEANN_email_reader
from LEANN_email_reader import EmlxReader
reader = EmlxReader()
# from email_data.email import EmlxMboxReader
# from pathlib import Path
# reader = EmlxMboxReader()
documents = reader.load_data(Path(mail_path))
if not documents:
print("No documents loaded. Exiting.")
return None
print(f"Loaded {len(documents)} email documents")
# Create text splitter with 256 chunk size
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
# Convert Documents to text strings and chunk them
all_texts = []
for doc in documents:
# Split the document into chunks
nodes = text_splitter.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.get_content())
print(f"Created {len(all_texts)} text chunks from {len(documents)} documents")
# Create LEANN index directory
INDEX_DIR = Path(index_path).parent
if not INDEX_DIR.exists():
print(f"--- Index directory not found, building new index ---")
INDEX_DIR.mkdir(exist_ok=True)
print(f"--- Building new LEANN index ---")
print(f"\n[PHASE 1] Building Leann index...")
# Use HNSW backend for better macOS compatibility
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="facebook/contriever",
graph_degree=32,
complexity=64,
is_compact=True,
is_recompute=True,
num_threads=1 # Force single-threaded mode
)
print(f"Adding {len(all_texts)} email chunks to index...")
for chunk_text in all_texts:
builder.add_text(chunk_text)
builder.build_index(index_path)
print(f"\nLEANN index built at {index_path}!")
else:
print(f"--- Using existing index at {INDEX_DIR} ---")
return index_path
async def query_leann_index(index_path: str, query: str):
"""
Query the LEANN index.
Args:
index_path: Path to the LEANN index
query: The query string
"""
print(f"\n[PHASE 2] Starting Leann chat session...")
chat = LeannChat(index_path=index_path)
print(f"You: {query}")
chat_response = chat.ask(
query,
top_k=5,
recompute_beighbor_embeddings=True,
complexity=32,
beam_width=1
)
print(f"Leann: {chat_response}")
async def main():
# Base path to the mail data directory
base_mail_path = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data"
INDEX_DIR = Path("./mail_index_leann_raw_text_all")
INDEX_PATH = str(INDEX_DIR / "mail_documents.leann")
# Find all Messages directories
from LEANN_email_reader import EmlxReader
messages_dirs = EmlxReader.find_all_messages_directories(base_mail_path)
if not messages_dirs:
print("No Messages directories found. Exiting.")
return
# Create or load the LEANN index from all sources
index_path = create_leann_index_from_multiple_sources(messages_dirs, INDEX_PATH)
if index_path:
# Example queries
queries = [
"Hows Berkeley Graduate Student Instructor",
"how's the icloud related advertisement saying",
"Whats the number of class recommend to take per semester for incoming EECS students"
]
for query in queries:
print("\n" + "="*60)
await query_leann_index(index_path, query)
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -0,0 +1,86 @@
import os
from pathlib import Path
from typing import List, Any
from llama_index.core import VectorStoreIndex, StorageContext
from llama_index.core.node_parser import SentenceSplitter
# --- EMBEDDING MODEL ---
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
import torch
# --- END EMBEDDING MODEL ---
# Import EmlxReader from the new module
from LEANN_email_reader import EmlxReader
def create_and_save_index(mail_path: str, save_dir: str = "mail_index_embedded", max_count: int = 1000):
print("Creating index from mail data with embedded metadata...")
documents = EmlxReader().load_data(mail_path, max_count=max_count)
if not documents:
print("No documents loaded. Exiting.")
return None
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
# Use facebook/contriever as the embedder
embed_model = HuggingFaceEmbedding(model_name="facebook/contriever")
# set on device
import torch
if torch.cuda.is_available():
embed_model._model.to("cuda")
# set mps
elif torch.backends.mps.is_available():
embed_model._model.to("mps")
else:
embed_model._model.to("cpu")
index = VectorStoreIndex.from_documents(
documents,
transformations=[text_splitter],
embed_model=embed_model
)
os.makedirs(save_dir, exist_ok=True)
index.storage_context.persist(persist_dir=save_dir)
print(f"Index saved to {save_dir}")
return index
def load_index(save_dir: str = "mail_index_embedded"):
try:
storage_context = StorageContext.from_defaults(persist_dir=save_dir)
index = VectorStoreIndex.from_vector_store(
storage_context.vector_store,
storage_context=storage_context
)
print(f"Index loaded from {save_dir}")
return index
except Exception as e:
print(f"Error loading index: {e}")
return None
def query_index(index, query: str):
if index is None:
print("No index available for querying.")
return
query_engine = index.as_query_engine()
response = query_engine.query(query)
print(f"Query: {query}")
print(f"Response: {response}")
def main():
mail_path = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages"
save_dir = "mail_index_embedded"
if os.path.exists(save_dir) and os.path.exists(os.path.join(save_dir, "vector_store.json")):
print("Loading existing index...")
index = load_index(save_dir)
else:
print("Creating new index...")
index = create_and_save_index(mail_path, save_dir, max_count=10000)
if index:
queries = [
"Hows Berkeley Graduate Student Instructor",
"how's the icloud related advertisement saying"
"Whats the number of class recommend to take per semester for incoming EECS students"
]
for query in queries:
print("\n" + "="*50)
query_index(index, query)
if __name__ == "__main__":
main()

View File

@@ -1,13 +1,13 @@
import faulthandler
faulthandler.enable()
import argparse
from llama_index.core import SimpleDirectoryReader, Settings
from llama_index.core.readers.base import BaseReader
from llama_index.node_parser.docling import DoclingNodeParser
from llama_index.readers.docling import DoclingReader
from docling_core.transforms.chunker.hybrid_chunker import HybridChunker
import asyncio
import os
import dotenv
from leann.api import LeannBuilder, LeannSearcher, LeannChat
import shutil
@@ -21,9 +21,11 @@ file_extractor: dict[str, BaseReader] = {
".pptx": reader,
".pdf": reader,
".xlsx": reader,
".txt": reader,
".md": reader,
}
node_parser = DoclingNodeParser(
chunker=HybridChunker(tokenizer="Qwen/Qwen3-Embedding-4B", max_tokens=64)
chunker=HybridChunker(tokenizer="facebook/contriever", max_tokens=128)
)
print("Loading documents...")
documents = SimpleDirectoryReader(
@@ -31,7 +33,7 @@ documents = SimpleDirectoryReader(
recursive=True,
file_extractor=file_extractor,
encoding="utf-8",
required_exts=[".pdf", ".docx", ".pptx", ".xlsx"]
required_exts=[".pdf", ".docx", ".pptx", ".xlsx", ".txt", ".md"]
).load_data(show_progress=True)
print("Documents loaded.")
all_texts = []
@@ -40,7 +42,7 @@ for doc in documents:
for node in nodes:
all_texts.append(node.get_content())
INDEX_DIR = Path("./test_pdf_index")
INDEX_DIR = Path("./test_pdf_index_pangu_test")
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
if not INDEX_DIR.exists():
@@ -48,14 +50,15 @@ if not INDEX_DIR.exists():
print(f"\n[PHASE 1] Building Leann index...")
# CSR compact mode with recompute
# Use HNSW backend for better macOS compatibility
builder = LeannBuilder(
backend_name="diskann",
backend_name="hnsw",
embedding_model="facebook/contriever",
graph_degree=32,
complexity=64,
is_compact=True,
is_recompute=True
is_recompute=True,
num_threads=1 # Force single-threaded mode
)
print(f"Loaded {len(all_texts)} text chunks from documents.")
@@ -67,14 +70,30 @@ if not INDEX_DIR.exists():
else:
print(f"--- Using existing index at {INDEX_DIR} ---")
async def main():
async def main(args):
print(f"\n[PHASE 2] Starting Leann chat session...")
chat = LeannChat(index_path=INDEX_PATH)
llm_config = {
"type": args.llm,
"model": args.model,
"host": args.host
}
chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)
query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
query = "What is the main idea of RL and give me 5 exapmle of classic RL algorithms?"
query = "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"
print(f"You: {query}")
chat_response = chat.ask(query, top_k=20, recompute_beighbor_embeddings=True)
chat_response = chat.ask(query, top_k=20, recompute_beighbor_embeddings=True, complexity=32)
print(f"Leann: {chat_response}")
if __name__ == "__main__":
asyncio.run(main())
parser = argparse.ArgumentParser(description="Run Leann Chat with various LLM backends.")
parser.add_argument("--llm", type=str, default="hf", choices=["simulated", "ollama", "hf", "openai"], help="The LLM backend to use.")
parser.add_argument("--model", type=str, default='meta-llama/Llama-3.2-3B-Instruct', help="The model name to use (e.g., 'llama3:8b' for ollama, 'deepseek-ai/deepseek-llm-7b-chat' for hf, 'gpt-4o' for openai).")
parser.add_argument("--host", type=str, default="http://localhost:11434", help="The host for the Ollama API.")
args = parser.parse_args()
asyncio.run(main(args))

View File

@@ -0,0 +1,319 @@
#!/usr/bin/env python3
"""
Multi-Vector Aggregator for Fat Embeddings
==========================================
This module implements aggregation strategies for multi-vector embeddings,
similar to ColPali's approach where multiple patch vectors represent a single document.
Key features:
- MaxSim aggregation (take maximum similarity across patches)
- Voting-based aggregation (count patch matches)
- Weighted aggregation (attention-score weighted)
- Spatial clustering of matching patches
- Document-level result consolidation
"""
import numpy as np
from typing import List, Dict, Any, Tuple, Optional
from dataclasses import dataclass
from collections import defaultdict
import json
@dataclass
class PatchResult:
"""Represents a single patch search result."""
patch_id: int
image_name: str
image_path: str
coordinates: Tuple[int, int, int, int] # (x1, y1, x2, y2)
score: float
attention_score: float
scale: float
metadata: Dict[str, Any]
@dataclass
class AggregatedResult:
"""Represents an aggregated document-level result."""
image_name: str
image_path: str
doc_score: float
patch_count: int
best_patch: PatchResult
all_patches: List[PatchResult]
aggregation_method: str
spatial_clusters: Optional[List[List[PatchResult]]] = None
class MultiVectorAggregator:
"""
Aggregates multiple patch-level results into document-level results.
"""
def __init__(self,
aggregation_method: str = "maxsim",
spatial_clustering: bool = True,
cluster_distance_threshold: float = 100.0):
"""
Initialize the aggregator.
Args:
aggregation_method: "maxsim", "voting", "weighted", or "mean"
spatial_clustering: Whether to cluster spatially close patches
cluster_distance_threshold: Distance threshold for spatial clustering
"""
self.aggregation_method = aggregation_method
self.spatial_clustering = spatial_clustering
self.cluster_distance_threshold = cluster_distance_threshold
def aggregate_results(self,
search_results: List[Dict[str, Any]],
top_k: int = 10) -> List[AggregatedResult]:
"""
Aggregate patch-level search results into document-level results.
Args:
search_results: List of search results from LeannSearcher
top_k: Number of top documents to return
Returns:
List of aggregated document results
"""
# Group results by image
image_groups = defaultdict(list)
for result in search_results:
metadata = result.metadata
if "image_name" in metadata and "patch_id" in metadata:
patch_result = PatchResult(
patch_id=metadata["patch_id"],
image_name=metadata["image_name"],
image_path=metadata["image_path"],
coordinates=tuple(metadata["coordinates"]),
score=result.score,
attention_score=metadata.get("attention_score", 0.0),
scale=metadata.get("scale", 1.0),
metadata=metadata
)
image_groups[metadata["image_name"]].append(patch_result)
# Aggregate each image group
aggregated_results = []
for image_name, patches in image_groups.items():
if len(patches) == 0:
continue
agg_result = self._aggregate_image_patches(image_name, patches)
aggregated_results.append(agg_result)
# Sort by aggregated score and return top-k
aggregated_results.sort(key=lambda x: x.doc_score, reverse=True)
return aggregated_results[:top_k]
def _aggregate_image_patches(self, image_name: str, patches: List[PatchResult]) -> AggregatedResult:
"""Aggregate patches for a single image."""
if self.aggregation_method == "maxsim":
doc_score = max(patch.score for patch in patches)
best_patch = max(patches, key=lambda p: p.score)
elif self.aggregation_method == "voting":
# Count patches above threshold
threshold = np.percentile([p.score for p in patches], 75)
doc_score = sum(1 for patch in patches if patch.score >= threshold)
best_patch = max(patches, key=lambda p: p.score)
elif self.aggregation_method == "weighted":
# Weight by attention scores
total_weighted_score = sum(p.score * p.attention_score for p in patches)
total_weights = sum(p.attention_score for p in patches)
doc_score = total_weighted_score / max(total_weights, 1e-8)
best_patch = max(patches, key=lambda p: p.score * p.attention_score)
elif self.aggregation_method == "mean":
doc_score = np.mean([patch.score for patch in patches])
best_patch = max(patches, key=lambda p: p.score)
else:
raise ValueError(f"Unknown aggregation method: {self.aggregation_method}")
# Spatial clustering if enabled
spatial_clusters = None
if self.spatial_clustering:
spatial_clusters = self._cluster_patches_spatially(patches)
return AggregatedResult(
image_name=image_name,
image_path=patches[0].image_path,
doc_score=float(doc_score),
patch_count=len(patches),
best_patch=best_patch,
all_patches=sorted(patches, key=lambda p: p.score, reverse=True),
aggregation_method=self.aggregation_method,
spatial_clusters=spatial_clusters
)
def _cluster_patches_spatially(self, patches: List[PatchResult]) -> List[List[PatchResult]]:
"""Cluster patches that are spatially close to each other."""
if len(patches) <= 1:
return [patches]
clusters = []
remaining_patches = patches.copy()
while remaining_patches:
# Start new cluster with highest scoring remaining patch
seed_patch = max(remaining_patches, key=lambda p: p.score)
current_cluster = [seed_patch]
remaining_patches.remove(seed_patch)
# Add nearby patches to cluster
added_to_cluster = True
while added_to_cluster:
added_to_cluster = False
for patch in remaining_patches.copy():
if self._is_patch_nearby(patch, current_cluster):
current_cluster.append(patch)
remaining_patches.remove(patch)
added_to_cluster = True
clusters.append(current_cluster)
return sorted(clusters, key=lambda cluster: max(p.score for p in cluster), reverse=True)
def _is_patch_nearby(self, patch: PatchResult, cluster: List[PatchResult]) -> bool:
"""Check if a patch is spatially close to any patch in the cluster."""
patch_center = self._get_patch_center(patch.coordinates)
for cluster_patch in cluster:
cluster_center = self._get_patch_center(cluster_patch.coordinates)
distance = np.sqrt((patch_center[0] - cluster_center[0])**2 +
(patch_center[1] - cluster_center[1])**2)
if distance <= self.cluster_distance_threshold:
return True
return False
def _get_patch_center(self, coordinates: Tuple[int, int, int, int]) -> Tuple[float, float]:
"""Get center point of a patch."""
x1, y1, x2, y2 = coordinates
return ((x1 + x2) / 2, (y1 + y2) / 2)
def print_aggregated_results(self, results: List[AggregatedResult], max_patches_per_doc: int = 3):
"""Pretty print aggregated results."""
print(f"\n🔍 Aggregated Results (method: {self.aggregation_method})")
print("=" * 80)
for i, result in enumerate(results):
print(f"\n{i+1}. {result.image_name}")
print(f" Doc Score: {result.doc_score:.4f} | Patches: {result.patch_count}")
print(f" Path: {result.image_path}")
# Show best patch
best = result.best_patch
print(f" 🌟 Best Patch: #{best.patch_id} at {best.coordinates} (score: {best.score:.4f})")
# Show top patches
print(f" 📍 Top Patches:")
for j, patch in enumerate(result.all_patches[:max_patches_per_doc]):
print(f" {j+1}. Patch #{patch.patch_id}: {patch.score:.4f} at {patch.coordinates}")
# Show spatial clusters if available
if result.spatial_clusters and len(result.spatial_clusters) > 1:
print(f" 🗂️ Spatial Clusters: {len(result.spatial_clusters)}")
for j, cluster in enumerate(result.spatial_clusters[:2]): # Show top 2 clusters
cluster_score = max(p.score for p in cluster)
print(f" Cluster {j+1}: {len(cluster)} patches (best: {cluster_score:.4f})")
def demo_aggregation():
"""Demonstrate the multi-vector aggregation functionality."""
print("=== Multi-Vector Aggregation Demo ===")
# Simulate some patch-level search results
# In real usage, these would come from LeannSearcher.search()
class MockResult:
def __init__(self, score, metadata):
self.score = score
self.metadata = metadata
# Simulate results for 2 images with multiple patches each
mock_results = [
# Image 1: cats_and_kitchen.jpg - 4 patches
MockResult(0.85, {
"image_name": "cats_and_kitchen.jpg",
"image_path": "/path/to/cats_and_kitchen.jpg",
"patch_id": 3,
"coordinates": [100, 50, 224, 174], # Kitchen area
"attention_score": 0.92,
"scale": 1.0
}),
MockResult(0.78, {
"image_name": "cats_and_kitchen.jpg",
"image_path": "/path/to/cats_and_kitchen.jpg",
"patch_id": 7,
"coordinates": [200, 300, 324, 424], # Cat area
"attention_score": 0.88,
"scale": 1.0
}),
MockResult(0.72, {
"image_name": "cats_and_kitchen.jpg",
"image_path": "/path/to/cats_and_kitchen.jpg",
"patch_id": 12,
"coordinates": [150, 100, 274, 224], # Appliances
"attention_score": 0.75,
"scale": 1.0
}),
MockResult(0.65, {
"image_name": "cats_and_kitchen.jpg",
"image_path": "/path/to/cats_and_kitchen.jpg",
"patch_id": 15,
"coordinates": [50, 250, 174, 374], # Furniture
"attention_score": 0.70,
"scale": 1.0
}),
# Image 2: city_street.jpg - 3 patches
MockResult(0.68, {
"image_name": "city_street.jpg",
"image_path": "/path/to/city_street.jpg",
"patch_id": 2,
"coordinates": [300, 100, 424, 224], # Buildings
"attention_score": 0.80,
"scale": 1.0
}),
MockResult(0.62, {
"image_name": "city_street.jpg",
"image_path": "/path/to/city_street.jpg",
"patch_id": 8,
"coordinates": [100, 350, 224, 474], # Street level
"attention_score": 0.75,
"scale": 1.0
}),
MockResult(0.55, {
"image_name": "city_street.jpg",
"image_path": "/path/to/city_street.jpg",
"patch_id": 11,
"coordinates": [400, 200, 524, 324], # Sky area
"attention_score": 0.60,
"scale": 1.0
}),
]
# Test different aggregation methods
methods = ["maxsim", "voting", "weighted", "mean"]
for method in methods:
print(f"\n{'='*20} {method.upper()} AGGREGATION {'='*20}")
aggregator = MultiVectorAggregator(
aggregation_method=method,
spatial_clustering=True,
cluster_distance_threshold=100.0
)
aggregated = aggregator.aggregate_results(mock_results, top_k=5)
aggregator.print_aggregated_results(aggregated)
if __name__ == "__main__":
demo_aggregation()

18
examples/resue_index.py Normal file
View File

@@ -0,0 +1,18 @@
import asyncio
from leann.api import LeannChat
from pathlib import Path
INDEX_DIR = Path("./test_pdf_index_huawei")
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
async def main():
print(f"\n[PHASE 2] Starting Leann chat session...")
chat = LeannChat(index_path=INDEX_PATH)
query = "What is the main idea of RL and give me 5 exapmle of classic RL algorithms?"
query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
# query = "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"
response = chat.ask(query,top_k=20,recompute_beighbor_embeddings=True,complexity=32,beam_width=1)
print(f"\n[PHASE 2] Response: {response}")
if __name__ == "__main__":
asyncio.run(main())

157
examples/run_evaluation.py Normal file
View File

@@ -0,0 +1,157 @@
#!/usr/bin/env python3
"""
This script runs a recall evaluation on a given LEANN index.
It correctly compares results by fetching the text content for both the new search
results and the golden standard results, making the comparison robust to ID changes.
"""
import json
import argparse
import time
from pathlib import Path
import sys
import numpy as np
from typing import List, Dict, Any
import glob
import pickle
# Add project root to path to allow importing from leann
project_root = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(project_root))
from leann.api import LeannSearcher
# --- Configuration ---
NQ_QUERIES_FILE = Path("/opt/dlami/nvme/scaling_out/examples/nq_open.jsonl")
# Ground truth files for different datasets
GROUND_TRUTH_FILES = {
"rpj_wiki": "/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/flat_results_nq_k3.json",
"dpr": "/opt/dlami/nvme/scaling_out/indices/dpr/facebook/contriever-msmarco/flat_results_nq_k3.json"
}
# Old passages for different datasets
OLD_PASSAGES_GLOBS = {
"rpj_wiki": "/opt/dlami/nvme/scaling_out/passages/rpj_wiki/8-shards/raw_passages-*-of-8.pkl.jsonl",
"dpr": "/opt/dlami/nvme/scaling_out/passages/dpr/1-shards/raw_passages-*-of-1.pkl.jsonl"
}
# --- Helper Class to Load Original Passages ---
class OldPassageLoader:
"""A simplified version of the old LazyPassages class to fetch golden results by ID."""
def __init__(self, passages_glob: str):
self.jsonl_paths = sorted(glob.glob(passages_glob))
self.offsets = {}
self.fps = [open(p, "r", encoding="utf-8") for p in self.jsonl_paths]
print("Building offset map for original passages...")
for i, shard_path_str in enumerate(self.jsonl_paths):
old_idx_path = Path(shard_path_str.replace(".jsonl", ".idx"))
if not old_idx_path.exists(): continue
with open(old_idx_path, 'rb') as f:
shard_offsets = pickle.load(f)
for pid, offset in shard_offsets.items():
self.offsets[str(pid)] = (i, offset)
print("Offset map for original passages is ready.")
def get_passage_by_id(self, pid: str) -> Dict[str, Any]:
pid = str(pid)
if pid not in self.offsets:
raise ValueError(f"Passage ID {pid} not found in offsets")
file_idx, offset = self.offsets[pid]
fp = self.fps[file_idx]
fp.seek(offset)
return json.loads(fp.readline())
def __del__(self):
for fp in self.fps:
fp.close()
def load_queries(file_path: Path) -> List[str]:
queries = []
with open(file_path, 'r', encoding='utf-8') as f:
for line in f:
data = json.loads(line)
queries.append(data['query'])
return queries
def main():
parser = argparse.ArgumentParser(description="Run recall evaluation on a LEANN index.")
parser.add_argument("index_path", type=str, help="Path to the LEANN index to evaluate.")
parser.add_argument("--num-queries", type=int, default=10, help="Number of queries to evaluate.")
parser.add_argument("--top-k", type=int, default=3, help="The 'k' value for recall@k.")
parser.add_argument("--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW.")
args = parser.parse_args()
print(f"--- Final, Correct Recall Evaluation (efSearch={args.ef_search}) ---")
# Detect dataset type from index path
index_path_str = str(args.index_path)
if "rpj_wiki" in index_path_str:
dataset_type = "rpj_wiki"
elif "dpr" in index_path_str:
dataset_type = "dpr"
else:
print("WARNING: Unknown dataset type, defaulting to rpj_wiki")
dataset_type = "rpj_wiki"
print(f"INFO: Detected dataset type: {dataset_type}")
try:
searcher = LeannSearcher(args.index_path)
queries = load_queries(NQ_QUERIES_FILE)
golden_results_file = GROUND_TRUTH_FILES[dataset_type]
old_passages_glob = OLD_PASSAGES_GLOBS[dataset_type]
print(f"INFO: Using ground truth file: {golden_results_file}")
print(f"INFO: Using old passages glob: {old_passages_glob}")
with open(golden_results_file, 'r') as f:
golden_results_data = json.load(f)
old_passage_loader = OldPassageLoader(old_passages_glob)
num_eval_queries = min(args.num_queries, len(queries))
queries = queries[:num_eval_queries]
print(f"\nRunning evaluation on {num_eval_queries} queries...")
recall_scores = []
search_times = []
for i in range(num_eval_queries):
start_time = time.time()
new_results = searcher.search(queries[i], top_k=args.top_k, ef=args.ef_search)
search_times.append(time.time() - start_time)
# Correct Recall Calculation: Based on TEXT content
new_texts = {result.text for result in new_results}
golden_ids = golden_results_data["indices"][i][:args.top_k]
golden_texts = {old_passage_loader.get_passage_by_id(str(gid))['text'] for gid in golden_ids}
overlap = len(new_texts & golden_texts)
recall = overlap / len(golden_texts) if golden_texts else 0
recall_scores.append(recall)
print("\n--- EVALUATION RESULTS ---")
print(f"Query: {queries[i]}")
print(f"New Results: {new_texts}")
print(f"Golden Results: {golden_texts}")
print(f"Overlap: {overlap}")
print(f"Recall: {recall}")
print(f"Search Time: {search_times[-1]:.4f}s")
print(f"--------------------------------")
avg_recall = np.mean(recall_scores) if recall_scores else 0
avg_time = np.mean(search_times) if search_times else 0
print(f"\n🎉 --- Evaluation Complete ---")
print(f"Avg. Recall@{args.top_k} (efSearch={args.ef_search}): {avg_recall:.4f}")
print(f"Avg. Search Time: {avg_time:.4f}s")
except Exception as e:
print(f"\n❌ An error occurred during evaluation: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main()

View File

@@ -1,32 +0,0 @@
{
"version": "0.1.0",
"backend_name": "diskann",
"embedding_model": "sentence-transformers/all-mpnet-base-v2",
"num_chunks": 6,
"chunks": [
{
"text": "Python is a powerful programming language",
"metadata": {}
},
{
"text": "Machine learning transforms industries",
"metadata": {}
},
{
"text": "Neural networks process complex data",
"metadata": {}
},
{
"text": "Java is a powerful programming language",
"metadata": {}
},
{
"text": "C++ is a powerful programming language",
"metadata": {}
},
{
"text": "C# is a powerful programming language",
"metadata": {}
}
]
}

1
packages/__init__.py Normal file
View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1 @@
# This file makes the directory a Python package

View File

@@ -5,21 +5,16 @@ import struct
from pathlib import Path
from typing import Dict, Any, List
import contextlib
import threading
import time
import atexit
import socket
import subprocess
import sys
import pickle
from leann.embedding_server_manager import EmbeddingServerManager
from leann.searcher_base import BaseSearcher
from leann.registry import register_backend
from leann.interface import (
LeannBackendFactoryInterface,
LeannBackendBuilderInterface,
LeannBackendSearcherInterface
)
def _get_diskann_metrics():
from . import _diskannpy as diskannpy
return {
@@ -52,211 +47,87 @@ class DiskannBackend(LeannBackendFactoryInterface):
@staticmethod
def searcher(index_path: str, **kwargs) -> LeannBackendSearcherInterface:
path = Path(index_path)
meta_path = path.parent / f"{path.name}.meta.json"
if not meta_path.exists():
raise FileNotFoundError(f"Leann metadata file not found at {meta_path}.")
with open(meta_path, 'r') as f:
meta = json.load(f)
# Pass essential metadata to the searcher
kwargs['meta'] = meta
return DiskannSearcher(index_path, **kwargs)
class DiskannBuilder(LeannBackendBuilderInterface):
def __init__(self, **kwargs):
self.build_params = kwargs
def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs):
path = Path(index_path)
index_dir = path.parent
index_prefix = path.stem
index_dir.mkdir(parents=True, exist_ok=True)
if data.dtype != np.float32:
data = data.astype(np.float32)
if not data.flags['C_CONTIGUOUS']:
data = np.ascontiguousarray(data)
data_filename = f"{index_prefix}_data.bin"
_write_vectors_to_bin(data, index_dir / data_filename)
# Create label map: integer -> string_id
label_map = {i: str_id for i, str_id in enumerate(ids)}
label_map_file = index_dir / "leann.labels.map"
with open(label_map_file, 'wb') as f:
pickle.dump(label_map, f)
build_kwargs = {**self.build_params, **kwargs}
metric_str = build_kwargs.get("distance_metric", "mips").lower()
METRIC_MAP = _get_diskann_metrics()
metric_enum = METRIC_MAP.get(metric_str)
metric_enum = _get_diskann_metrics().get(build_kwargs.get("distance_metric", "mips").lower())
if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{metric_str}'.")
raise ValueError(f"Unsupported distance_metric.")
complexity = build_kwargs.get("complexity", 64)
graph_degree = build_kwargs.get("graph_degree", 32)
final_index_ram_limit = build_kwargs.get("search_memory_maximum", 4.0)
indexing_ram_budget = build_kwargs.get("build_memory_maximum", 8.0)
num_threads = build_kwargs.get("num_threads", 8)
pq_disk_bytes = build_kwargs.get("pq_disk_bytes", 0)
codebook_prefix = ""
print(f"INFO: Building DiskANN index for {data.shape[0]} vectors with metric {metric_enum}...")
try:
from . import _diskannpy as diskannpy
with chdir(index_dir):
diskannpy.build_disk_float_index(
metric_enum,
data_filename,
index_prefix,
complexity,
graph_degree,
final_index_ram_limit,
indexing_ram_budget,
num_threads,
pq_disk_bytes,
codebook_prefix
metric_enum, data_filename, index_prefix,
build_kwargs.get("complexity", 64), build_kwargs.get("graph_degree", 32),
build_kwargs.get("search_memory_maximum", 4.0), build_kwargs.get("build_memory_maximum", 8.0),
build_kwargs.get("num_threads", 8), build_kwargs.get("pq_disk_bytes", 0), ""
)
print(f"✅ DiskANN index built successfully at '{index_dir / index_prefix}'")
except Exception as e:
print(f"💥 ERROR: DiskANN index build failed. Exception: {e}")
raise
finally:
temp_data_file = index_dir / data_filename
if temp_data_file.exists():
os.remove(temp_data_file)
class DiskannSearcher(LeannBackendSearcherInterface):
class DiskannSearcher(BaseSearcher):
def __init__(self, index_path: str, **kwargs):
self.meta = kwargs.get("meta", {})
if not self.meta:
raise ValueError("DiskannSearcher requires metadata from .meta.json.")
super().__init__(index_path, backend_module_name="leann_backend_diskann.embedding_server", **kwargs)
from . import _diskannpy as diskannpy
self.embedding_model = self.meta.get("embedding_model")
if not self.embedding_model:
print("WARNING: embedding_model not found in meta.json. Recompute will fail if attempted.")
path = Path(index_path)
self.index_dir = path.parent
self.index_prefix = path.stem
# Load the label map
label_map_file = self.index_dir / "leann.labels.map"
if not label_map_file.exists():
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
with open(label_map_file, 'rb') as f:
self.label_map = pickle.load(f)
# Extract parameters for DiskANN
distance_metric = kwargs.get("distance_metric", "mips").lower()
METRIC_MAP = _get_diskann_metrics()
metric_enum = METRIC_MAP.get(distance_metric)
metric_enum = _get_diskann_metrics().get(distance_metric)
if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{distance_metric}'.")
num_threads = kwargs.get("num_threads", 8)
num_nodes_to_cache = kwargs.get("num_nodes_to_cache", 0)
self.num_threads = kwargs.get("num_threads", 8)
self.zmq_port = kwargs.get("zmq_port", 6666)
try:
from . import _diskannpy as diskannpy
full_index_prefix = str(self.index_dir / self.index_prefix)
self._index = diskannpy.StaticDiskFloatIndex(
metric_enum, full_index_prefix, num_threads, num_nodes_to_cache, 1, self.zmq_port, "", ""
)
self.num_threads = num_threads
self.embedding_server_manager = EmbeddingServerManager(
backend_module_name="leann_backend_diskann.embedding_server"
)
print("✅ DiskANN index loaded successfully.")
except Exception as e:
print(f"💥 ERROR: Failed to load DiskANN index. Exception: {e}")
raise
full_index_prefix = str(self.index_dir / self.index_path.stem)
self._index = diskannpy.StaticDiskFloatIndex(
metric_enum, full_index_prefix, self.num_threads,
kwargs.get("num_nodes_to_cache", 0), 1, self.zmq_port, "", ""
)
def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]:
complexity = kwargs.get("complexity", 256)
beam_width = kwargs.get("beam_width", 4)
USE_DEFERRED_FETCH = kwargs.get("USE_DEFERRED_FETCH", False)
skip_search_reorder = kwargs.get("skip_search_reorder", False)
recompute_beighbor_embeddings = kwargs.get("recompute_beighbor_embeddings", False)
dedup_node_dis = kwargs.get("dedup_node_dis", False)
prune_ratio = kwargs.get("prune_ratio", 0.0)
batch_recompute = kwargs.get("batch_recompute", False)
global_pruning = kwargs.get("global_pruning", False)
port = kwargs.get("zmq_port", self.zmq_port)
if recompute_beighbor_embeddings:
print(f"INFO: DiskANN ZMQ mode enabled - ensuring embedding server is running")
if not self.embedding_model:
raise ValueError("Cannot use recompute_beighbor_embeddings without 'embedding_model' in meta.json.")
recompute = kwargs.get("recompute_beighbor_embeddings", False)
if recompute:
meta_file_path = self.index_dir / f"{self.index_path.name}.meta.json"
if not meta_file_path.exists():
raise RuntimeError(f"FATAL: Recompute mode enabled but metadata file not found: {meta_file_path}")
zmq_port = kwargs.get("zmq_port", self.zmq_port)
self._ensure_server_running(str(meta_file_path), port=zmq_port, **kwargs)
passages_file = kwargs.get("passages_file")
if not passages_file:
# Get the passages file path from meta.json
if 'passage_sources' in self.meta and self.meta['passage_sources']:
passage_source = self.meta['passage_sources'][0]
passages_file = passage_source['path']
print(f"INFO: Found passages file from metadata: {passages_file}")
else:
raise RuntimeError(f"FATAL: Recompute mode enabled but no passage_sources found in metadata.")
server_started = self.embedding_server_manager.start_server(
port=self.zmq_port,
model_name=self.embedding_model,
distance_metric=kwargs.get("distance_metric", "mips"),
passages_file=passages_file
)
if not server_started:
raise RuntimeError(f"Failed to start DiskANN embedding server on port {self.zmq_port}")
if query.dtype != np.float32:
query = query.astype(np.float32)
if query.ndim == 1:
query = np.expand_dims(query, axis=0)
try:
labels, distances = self._index.batch_search(
query,
query.shape[0],
top_k,
complexity,
beam_width,
self.num_threads,
USE_DEFERRED_FETCH,
skip_search_reorder,
recompute_beighbor_embeddings,
dedup_node_dis,
prune_ratio,
batch_recompute,
global_pruning
)
# Convert integer labels to string IDs
string_labels = []
for batch_labels in labels:
batch_string_labels = []
for int_label in batch_labels:
if int_label in self.label_map:
batch_string_labels.append(self.label_map[int_label])
else:
batch_string_labels.append(f"unknown_{int_label}")
string_labels.append(batch_string_labels)
return {"labels": string_labels, "distances": distances}
except Exception as e:
print(f"💥 ERROR: DiskANN search failed. Exception: {e}")
batch_size = query.shape[0]
return {"labels": [[f"error_{i}" for i in range(top_k)] for _ in range(batch_size)],
"distances": np.full((batch_size, top_k), float('inf'), dtype=np.float32)}
def __del__(self):
if hasattr(self, 'embedding_server_manager'):
self.embedding_server_manager.stop_server()
labels, distances = self._index.batch_search(
query, query.shape[0], top_k,
kwargs.get("complexity", 256), kwargs.get("beam_width", 4), self.num_threads,
kwargs.get("USE_DEFERRED_FETCH", False), kwargs.get("skip_search_reorder", False),
recompute, kwargs.get("dedup_node_dis", False), kwargs.get("prune_ratio", 0.0),
kwargs.get("batch_recompute", False), kwargs.get("global_pruning", False)
)
string_labels = [[self.label_map.get(int_label, f"unknown_{int_label}") for int_label in batch_labels] for batch_labels in labels]
return {"labels": string_labels, "distances": distances}

View File

@@ -15,6 +15,8 @@ import os
from contextlib import contextmanager
import zmq
import numpy as np
from pathlib import Path
import pickle
RED = "\033[91m"
RESET = "\033[0m"
@@ -39,13 +41,76 @@ class SimplePassageLoader:
def __len__(self) -> int:
return len(self.passages_data)
def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
"""
Load passages using metadata file with PassageManager for lazy loading
"""
# Load metadata to get passage sources
with open(meta_file, 'r') as f:
meta = json.load(f)
# Import PassageManager dynamically to avoid circular imports
import sys
from pathlib import Path
# Find the leann package directory relative to this file
current_dir = Path(__file__).parent
leann_core_path = current_dir.parent.parent / "leann-core" / "src"
sys.path.insert(0, str(leann_core_path))
try:
from leann.api import PassageManager
passage_manager = PassageManager(meta['passage_sources'])
finally:
sys.path.pop(0)
# Load label map
passages_dir = Path(meta_file).parent
label_map_file = passages_dir / "leann.labels.map"
if label_map_file.exists():
import pickle
with open(label_map_file, 'rb') as f:
label_map = pickle.load(f)
print(f"Loaded label map with {len(label_map)} entries")
else:
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
print(f"Initialized lazy passage loading for {len(label_map)} passages")
class LazyPassageLoader(SimplePassageLoader):
def __init__(self, passage_manager, label_map):
self.passage_manager = passage_manager
self.label_map = label_map
# Initialize parent with empty data
super().__init__({})
def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
"""Get passage by ID with lazy loading"""
try:
int_id = int(passage_id)
if int_id in self.label_map:
string_id = self.label_map[int_id]
passage_data = self.passage_manager.get_passage(string_id)
if passage_data and passage_data.get("text"):
return {"text": passage_data["text"]}
else:
raise RuntimeError(f"FATAL: Empty text for ID {int_id} -> {string_id}")
else:
raise RuntimeError(f"FATAL: ID {int_id} not found in label_map")
except Exception as e:
raise RuntimeError(f"FATAL: Exception getting passage {passage_id}: {e}")
def __len__(self) -> int:
return len(self.label_map)
return LazyPassageLoader(passage_manager, label_map)
def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
"""
Load passages from a JSONL file with label map support
Expected format: {"id": "passage_id", "text": "passage_text", "metadata": {...}} (one per line)
"""
from pathlib import Path
import pickle
if not os.path.exists(passages_file):
raise FileNotFoundError(f"Passages file {passages_file} not found.")
@@ -140,7 +205,20 @@ def create_embedding_server_thread(
# Load passages from file if provided
if passages_file and os.path.exists(passages_file):
passages = load_passages_from_file(passages_file)
# Check if it's a metadata file or a single passages file
if passages_file.endswith('.meta.json'):
passages = load_passages_from_metadata(passages_file)
else:
# Try to find metadata file in same directory
passages_dir = Path(passages_file).parent
meta_files = list(passages_dir.glob("*.meta.json"))
if meta_files:
print(f"Found metadata file: {meta_files[0]}, using lazy loading")
passages = load_passages_from_metadata(str(meta_files[0]))
else:
# Fallback to original single file loading (will cause warnings)
print("WARNING: No metadata file found, using single file loading (may cause missing passage warnings)")
passages = load_passages_from_file(passages_file)
else:
print("WARNING: No passages file provided or file not found. Using an empty passage loader.")
passages = SimplePassageLoader()

View File

@@ -2,6 +2,33 @@
cmake_minimum_required(VERSION 3.24)
project(leann_backend_hnsw_wrapper)
# Set OpenMP path for macOS
if(APPLE)
set(OpenMP_C_FLAGS "-Xpreprocessor -fopenmp -I/opt/homebrew/opt/libomp/include")
set(OpenMP_CXX_FLAGS "-Xpreprocessor -fopenmp -I/opt/homebrew/opt/libomp/include")
set(OpenMP_C_LIB_NAMES "omp")
set(OpenMP_CXX_LIB_NAMES "omp")
set(OpenMP_omp_LIBRARY "/opt/homebrew/opt/libomp/lib/libomp.dylib")
endif()
# Build ZeroMQ from source
set(ZMQ_BUILD_TESTS OFF CACHE BOOL "" FORCE)
set(ENABLE_DRAFTS OFF CACHE BOOL "" FORCE)
set(ENABLE_PRECOMPILED OFF CACHE BOOL "" FORCE)
set(WITH_PERF_TOOL OFF CACHE BOOL "" FORCE)
set(WITH_DOCS OFF CACHE BOOL "" FORCE)
set(BUILD_SHARED OFF CACHE BOOL "" FORCE)
set(BUILD_STATIC ON CACHE BOOL "" FORCE)
add_subdirectory(third_party/libzmq)
# Add cppzmq headers
include_directories(third_party/cppzmq)
# Configure msgpack-c - disable boost dependency
set(MSGPACK_USE_BOOST OFF CACHE BOOL "" FORCE)
add_compile_definitions(MSGPACK_NO_BOOST)
include_directories(third_party/msgpack-c/include)
set(FAISS_ENABLE_PYTHON ON CACHE BOOL "" FORCE)
set(FAISS_ENABLE_GPU OFF CACHE BOOL "" FORCE)
set(FAISS_ENABLE_EXTRAS OFF CACHE BOOL "" FORCE)

View File

@@ -1,19 +1,12 @@
import numpy as np
import os
import json
import struct
from pathlib import Path
from typing import Dict, Any, List
import contextlib
import threading
import time
import atexit
import socket
import subprocess
import sys
import pickle
import shutil
from leann.embedding_server_manager import EmbeddingServerManager
from leann.searcher_base import BaseSearcher
from .convert_to_csr import convert_hnsw_graph_to_csr
from leann.registry import register_backend
@@ -39,345 +32,130 @@ class HNSWBackend(LeannBackendFactoryInterface):
@staticmethod
def searcher(index_path: str, **kwargs) -> LeannBackendSearcherInterface:
path = Path(index_path)
meta_path = path.parent / f"{path.name}.meta.json"
if not meta_path.exists():
raise FileNotFoundError(f"Leann metadata file not found at {meta_path}.")
with open(meta_path, 'r') as f:
meta = json.load(f)
kwargs['meta'] = meta
return HNSWSearcher(index_path, **kwargs)
class HNSWBuilder(LeannBackendBuilderInterface):
def __init__(self, **kwargs):
self.build_params = kwargs.copy()
# --- Configuration defaults with standardized names ---
self.is_compact = self.build_params.setdefault("is_compact", True)
self.is_recompute = self.build_params.setdefault("is_recompute", True)
# --- Additional Options ---
self.is_skip_neighbors = self.build_params.setdefault("is_skip_neighbors", False)
self.disk_cache_ratio = self.build_params.setdefault("disk_cache_ratio", 0.0)
self.external_storage_path = self.build_params.get("external_storage_path", None)
# --- Standard HNSW parameters ---
self.M = self.build_params.setdefault("M", 32)
self.efConstruction = self.build_params.setdefault("efConstruction", 200)
self.distance_metric = self.build_params.setdefault("distance_metric", "mips")
self.dimensions = self.build_params.get("dimensions")
if self.is_skip_neighbors and not self.is_compact:
raise ValueError("is_skip_neighbors can only be used with is_compact=True")
if self.is_recompute and not self.is_compact:
raise ValueError("is_recompute requires is_compact=True for efficiency")
def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs):
"""Build HNSW index using FAISS"""
from . import faiss
path = Path(index_path)
index_dir = path.parent
index_prefix = path.stem
index_dir.mkdir(parents=True, exist_ok=True)
if data.dtype != np.float32:
data = data.astype(np.float32)
if not data.flags['C_CONTIGUOUS']:
data = np.ascontiguousarray(data)
# Create label map: integer -> string_id
label_map = {i: str_id for i, str_id in enumerate(ids)}
label_map_file = index_dir / "leann.labels.map"
with open(label_map_file, 'wb') as f:
pickle.dump(label_map, f)
metric_str = self.distance_metric.lower()
metric_enum = get_metric_map().get(metric_str)
metric_enum = get_metric_map().get(self.distance_metric.lower())
if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{metric_str}'.")
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
M = self.M
efConstruction = self.efConstruction
dim = self.dimensions
if not dim:
dim = data.shape[1]
dim = self.dimensions or data.shape[1]
index = faiss.IndexHNSWFlat(dim, self.M, metric_enum)
index.hnsw.efConstruction = self.efConstruction
print(f"INFO: Building HNSW index for {data.shape[0]} vectors with metric {metric_enum}...")
try:
index = faiss.IndexHNSWFlat(dim, M, metric_enum)
index.hnsw.efConstruction = efConstruction
if metric_str == "cosine":
faiss.normalize_L2(data)
index.add(data.shape[0], faiss.swig_ptr(data))
index_file = index_dir / f"{index_prefix}.index"
faiss.write_index(index, str(index_file))
print(f"✅ HNSW index built successfully at '{index_file}'")
if self.distance_metric.lower() == "cosine":
faiss.normalize_L2(data)
if self.is_compact:
self._convert_to_csr(index_file)
except Exception as e:
print(f"💥 ERROR: HNSW index build failed. Exception: {e}")
raise
index.add(data.shape[0], faiss.swig_ptr(data))
index_file = index_dir / f"{index_prefix}.index"
faiss.write_index(index, str(index_file))
if self.is_compact:
self._convert_to_csr(index_file)
def _convert_to_csr(self, index_file: Path):
"""Convert built index to CSR format"""
try:
mode_str = "CSR-pruned" if self.is_recompute else "CSR-standard"
print(f"INFO: Converting HNSW index to {mode_str} format...")
csr_temp_file = index_file.with_suffix(".csr.tmp")
success = convert_hnsw_graph_to_csr(
str(index_file),
str(csr_temp_file),
prune_embeddings=self.is_recompute
)
if success:
print("✅ CSR conversion successful.")
import shutil
shutil.move(str(csr_temp_file), str(index_file))
print(f"INFO: Replaced original index with {mode_str} version at '{index_file}'")
else:
# Clean up and fail fast
if csr_temp_file.exists():
os.remove(csr_temp_file)
raise RuntimeError("CSR conversion failed - cannot proceed with compact format")
except Exception as e:
print(f"💥 ERROR: CSR conversion failed. Exception: {e}")
raise
mode_str = "CSR-pruned" if self.is_recompute else "CSR-standard"
print(f"INFO: Converting HNSW index to {mode_str} format...")
csr_temp_file = index_file.with_suffix(".csr.tmp")
class HNSWSearcher(LeannBackendSearcherInterface):
def _get_index_storage_status(self, index_file: Path) -> tuple[bool, bool]:
"""
Robustly determines the index's storage status by parsing the file.
Returns:
A tuple (is_compact, is_pruned).
"""
if not index_file.exists():
return False, False
with open(index_file, 'rb') as f:
try:
def read_struct(fmt):
size = struct.calcsize(fmt)
data = f.read(size)
if len(data) != size:
raise EOFError(f"File ended unexpectedly reading struct fmt '{fmt}'.")
return struct.unpack(fmt, data)[0]
success = convert_hnsw_graph_to_csr(
str(index_file),
str(csr_temp_file),
prune_embeddings=self.is_recompute
)
def skip_vector(element_size):
count = read_struct('<Q')
f.seek(count * element_size, 1)
# 1. Read up to the compact flag
read_struct('<I'); read_struct('<i'); read_struct('<q');
read_struct('<q'); read_struct('<q'); read_struct('<?')
metric_type = read_struct('<i')
if metric_type > 1: read_struct('<f')
skip_vector(8); skip_vector(4); skip_vector(4)
# 2. Check if there's a compact flag byte
# Try to read the compact flag, but handle both old and new formats
pos_before_compact = f.tell()
try:
is_compact = read_struct('<?')
print(f"INFO: Detected is_compact flag as: {is_compact}")
except (EOFError, struct.error):
# Old format without compact flag - assume non-compact
f.seek(pos_before_compact)
is_compact = False
print(f"INFO: No compact flag found, assuming is_compact=False")
# 3. Read storage FourCC to determine if pruned
is_pruned = False
try:
if is_compact:
# For compact, we need to skip pointers and scalars to get to the storage FourCC
skip_vector(8) # level_ptr
skip_vector(8) # node_offsets
read_struct('<i'); read_struct('<i'); read_struct('<i');
read_struct('<i'); read_struct('<i')
storage_fourcc = read_struct('<I')
else:
# For non-compact, we need to read the flag probe, then skip offsets and neighbors
pos_before_probe = f.tell()
flag_byte = f.read(1)
if not (flag_byte and flag_byte == b'\x00'):
f.seek(pos_before_probe)
skip_vector(8); skip_vector(4) # offsets, neighbors
read_struct('<i'); read_struct('<i'); read_struct('<i');
read_struct('<i'); read_struct('<i')
# Now we are at the storage. The entire rest is storage blob.
storage_fourcc = struct.unpack('<I', f.read(4))[0]
NULL_INDEX_FOURCC = int.from_bytes(b'null', 'little')
if storage_fourcc == NULL_INDEX_FOURCC:
is_pruned = True
except (EOFError, struct.error):
# Cannot determine pruning status, assume not pruned
pass
print(f"INFO: Detected is_pruned as: {is_pruned}")
return is_compact, is_pruned
except (EOFError, struct.error) as e:
print(f"WARNING: Could not parse index file to detect format: {e}. Assuming standard, not pruned.")
return False, False
if success:
print("✅ CSR conversion successful.")
index_file_old = index_file.with_suffix(".old")
shutil.move(str(index_file), str(index_file_old))
shutil.move(str(csr_temp_file), str(index_file))
print(f"INFO: Replaced original index with {mode_str} version at '{index_file}'")
else:
# Clean up and fail fast
if csr_temp_file.exists():
os.remove(csr_temp_file)
raise RuntimeError("CSR conversion failed - cannot proceed with compact format")
class HNSWSearcher(BaseSearcher):
def __init__(self, index_path: str, **kwargs):
super().__init__(index_path, backend_module_name="leann_backend_hnsw.hnsw_embedding_server", **kwargs)
from . import faiss
self.meta = kwargs.get("meta", {})
if not self.meta:
raise ValueError("HNSWSearcher requires metadata from .meta.json.")
self.dimensions = self.meta.get("dimensions")
if not self.dimensions:
raise ValueError("Dimensions not found in Leann metadata.")
self.distance_metric = self.meta.get("distance_metric", "mips").lower()
metric_enum = get_metric_map().get(self.distance_metric)
if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
self.embedding_model = self.meta.get("embedding_model")
if not self.embedding_model:
print("WARNING: embedding_model not found in meta.json. Recompute will fail if attempted.")
self.is_compact, self.is_pruned = (
self.meta.get('is_compact', True),
self.meta.get('is_pruned', True)
)
path = Path(index_path)
self.index_dir = path.parent
self.index_prefix = path.stem
# Load the label map
label_map_file = self.index_dir / "leann.labels.map"
if not label_map_file.exists():
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
with open(label_map_file, 'rb') as f:
self.label_map = pickle.load(f)
index_file = self.index_dir / f"{self.index_prefix}.index"
index_file = self.index_dir / f"{self.index_path.stem}.index"
if not index_file.exists():
raise FileNotFoundError(f"HNSW index file not found at {index_file}")
self.is_compact, self.is_pruned = self._get_index_storage_status(index_file)
# Validate configuration constraints
if not self.is_compact and kwargs.get("is_skip_neighbors", False):
raise ValueError("is_skip_neighbors can only be used with is_compact=True")
if kwargs.get("is_recompute", False) and kwargs.get("external_storage_path"):
raise ValueError("Cannot use both is_recompute and external_storage_path simultaneously")
hnsw_config = faiss.HNSWIndexConfig()
hnsw_config.is_compact = self.is_compact
# Apply additional configuration options with strict validation
hnsw_config.is_skip_neighbors = kwargs.get("is_skip_neighbors", False)
hnsw_config.is_recompute = self.is_pruned or kwargs.get("is_recompute", False)
hnsw_config.disk_cache_ratio = kwargs.get("disk_cache_ratio", 0.0)
hnsw_config.external_storage_path = kwargs.get("external_storage_path")
self.zmq_port = kwargs.get("zmq_port", 5557)
if self.is_pruned and not hnsw_config.is_recompute:
raise RuntimeError("Index is pruned (embeddings removed) but recompute is disabled. This is impossible - recompute must be enabled for pruned indices.")
print(f"INFO: Loading index with is_compact={self.is_compact}, is_pruned={self.is_pruned}")
print(f"INFO: Config - skip_neighbors={hnsw_config.is_skip_neighbors}, recompute={hnsw_config.is_recompute}")
self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)
if self.is_compact:
print("✅ Compact CSR format HNSW index loaded successfully.")
else:
print("✅ Standard HNSW index loaded successfully.")
self.embedding_server_manager = EmbeddingServerManager(
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
)
if self.is_pruned and not hnsw_config.is_recompute:
raise RuntimeError("Index is pruned but recompute is disabled.")
self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)
def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]:
"""Search using HNSW index with optional recompute functionality"""
from . import faiss
ef = kwargs.get("ef", 200)
if self.is_pruned:
print(f"INFO: Index is pruned - ensuring embedding server is running for recompute.")
if not self.embedding_model:
raise ValueError("Cannot use recompute mode without 'embedding_model' in meta.json.")
passages_file = kwargs.get("passages_file")
if not passages_file:
# Get the passages file path from meta.json
if 'passage_sources' in self.meta and self.meta['passage_sources']:
passage_source = self.meta['passage_sources'][0]
passages_file = passage_source['path']
print(f"INFO: Found passages file from metadata: {passages_file}")
else:
raise RuntimeError(f"FATAL: Index is pruned but no passage_sources found in metadata.")
meta_file_path = self.index_dir / f"{self.index_path.name}.meta.json"
if not meta_file_path.exists():
raise RuntimeError(f"FATAL: Index is pruned but metadata file not found: {meta_file_path}")
zmq_port = kwargs.get("zmq_port", 5557)
server_started = self.embedding_server_manager.start_server(
port=zmq_port,
model_name=self.embedding_model,
passages_file=passages_file,
distance_metric=self.distance_metric
)
if not server_started:
raise RuntimeError(f"Failed to start HNSW embedding server on port {zmq_port}")
self._ensure_server_running(str(meta_file_path), port=zmq_port, **kwargs)
if query.dtype != np.float32:
query = query.astype(np.float32)
if query.ndim == 1:
query = np.expand_dims(query, axis=0)
if self.distance_metric == "cosine":
faiss.normalize_L2(query)
try:
params = faiss.SearchParametersHNSW()
params.efSearch = ef
params.zmq_port = kwargs.get("zmq_port", self.zmq_port)
batch_size = query.shape[0]
distances = np.empty((batch_size, top_k), dtype=np.float32)
labels = np.empty((batch_size, top_k), dtype=np.int64)
self._index.search(query.shape[0], faiss.swig_ptr(query), top_k, faiss.swig_ptr(distances), faiss.swig_ptr(labels), params)
# Convert integer labels to string IDs
string_labels = []
for batch_labels in labels:
batch_string_labels = []
for int_label in batch_labels:
if int_label in self.label_map:
batch_string_labels.append(self.label_map[int_label])
else:
batch_string_labels.append(f"unknown_{int_label}")
string_labels.append(batch_string_labels)
return {"labels": string_labels, "distances": distances}
except Exception as e:
print(f"💥 ERROR: HNSW search failed. Exception: {e}")
raise
def __del__(self):
if hasattr(self, 'embedding_server_manager'):
self.embedding_server_manager.stop_server()
params = faiss.SearchParametersHNSW()
params.zmq_port = kwargs.get("zmq_port", 5557)
params.efSearch = kwargs.get("complexity", 32)
params.beam_size = kwargs.get("beam_width", 1)
batch_size = query.shape[0]
distances = np.empty((batch_size, top_k), dtype=np.float32)
labels = np.empty((batch_size, top_k), dtype=np.int64)
self._index.search(query.shape[0], faiss.swig_ptr(query), top_k, faiss.swig_ptr(distances), faiss.swig_ptr(labels), params)
string_labels = [[self.label_map.get(int_label, f"unknown_{int_label}") for int_label in batch_labels] for batch_labels in labels]
return {"labels": string_labels, "distances": distances}

View File

@@ -56,22 +56,33 @@ class SimplePassageLoader:
def __len__(self) -> int:
return len(self.passages_data)
def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
"""
Load passages from a JSONL file with label map support
Expected format: {"id": "passage_id", "text": "passage_text", "metadata": {...}} (one per line)
Load passages using metadata file with PassageManager for lazy loading
"""
if not os.path.exists(passages_file):
raise FileNotFoundError(f"Passages file {passages_file} not found.")
# Load metadata to get passage sources
with open(meta_file, 'r') as f:
meta = json.load(f)
if not passages_file.endswith('.jsonl'):
raise ValueError(f"Expected .jsonl file format, got: {passages_file}")
# Import PassageManager dynamically to avoid circular imports
import sys
import importlib.util
# Load label map (int -> string_id)
passages_dir = Path(passages_file).parent
# Find the leann package directory relative to this file
current_dir = Path(__file__).parent
leann_core_path = current_dir.parent.parent / "leann-core" / "src"
sys.path.insert(0, str(leann_core_path))
try:
from leann.api import PassageManager
passage_manager = PassageManager(meta['passage_sources'])
finally:
sys.path.pop(0)
# Load label map
passages_dir = Path(meta_file).parent
label_map_file = passages_dir / "leann.labels.map"
label_map = {}
if label_map_file.exists():
import pickle
with open(label_map_file, 'rb') as f:
@@ -80,24 +91,38 @@ def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
else:
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
# Load passages by string ID
string_id_passages = {}
with open(passages_file, 'r', encoding='utf-8') as f:
for line in f:
if line.strip():
passage = json.loads(line)
string_id_passages[passage['id']] = passage['text']
print(f"Initialized lazy passage loading for {len(label_map)} passages")
# Create int ID -> text mapping using label map
passages_data = {}
for int_id, string_id in label_map.items():
if string_id in string_id_passages:
passages_data[str(int_id)] = string_id_passages[string_id]
else:
print(f"WARNING: String ID {string_id} from label map not found in passages")
class LazyPassageLoader(SimplePassageLoader):
def __init__(self, passage_manager, label_map):
self.passage_manager = passage_manager
self.label_map = label_map
# Initialize parent with empty data
super().__init__({})
def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
"""Get passage by ID with lazy loading"""
try:
int_id = int(passage_id)
if int_id in self.label_map:
string_id = self.label_map[int_id]
passage_data = self.passage_manager.get_passage(string_id)
if passage_data and passage_data.get("text"):
return {"text": passage_data["text"]}
else:
print(f"DEBUG: Empty text for ID {int_id} -> {string_id}")
return {"text": ""}
else:
print(f"DEBUG: ID {int_id} not found in label_map")
return {"text": ""}
except Exception as e:
print(f"DEBUG: Exception getting passage {passage_id}: {e}")
return {"text": ""}
def __len__(self) -> int:
return len(self.label_map)
print(f"Loaded {len(passages_data)} passages from JSONL file {passages_file} using label map")
return SimplePassageLoader(passages_data)
return LazyPassageLoader(passage_manager, label_map)
def create_hnsw_embedding_server(
passages_file: Optional[str] = None,
@@ -183,7 +208,20 @@ def create_hnsw_embedding_server(
passages = SimplePassageLoader(passages_data)
print(f"Using provided passages data: {len(passages)} passages")
elif passages_file:
passages = load_passages_from_file(passages_file)
# Check if it's a metadata file or a single passages file
if passages_file.endswith('.meta.json'):
passages = load_passages_from_metadata(passages_file)
else:
# Try to find metadata file in same directory
passages_dir = Path(passages_file).parent
meta_files = list(passages_dir.glob("*.meta.json"))
if meta_files:
print(f"Found metadata file: {meta_files[0]}, using lazy loading")
passages = load_passages_from_metadata(str(meta_files[0]))
else:
# Fallback to original single file loading (will cause warnings)
print("WARNING: No metadata file found, using single file loading (may cause missing passage warnings)")
passages = SimplePassageLoader() # Use empty loader to avoid massive warnings
else:
passages = SimplePassageLoader()
print("No passages provided, using empty loader")
@@ -252,6 +290,11 @@ def create_hnsw_embedding_server(
_is_bge_model = "bge" in model_name.lower()
batch_size = len(texts_batch)
# Validate no empty texts
for i, text in enumerate(texts_batch):
if not text or text.strip() == "":
raise RuntimeError(f"FATAL: Empty text at batch index {i}, ID: {ids_batch[i] if i < len(ids_batch) else 'unknown'}")
# E5 model preprocessing
if _is_e5_model:
processed_texts_batch = [f"passage: {text}" for text in texts_batch]
@@ -398,14 +441,12 @@ def create_hnsw_embedding_server(
missing_ids = []
with lookup_timer.timing():
for nid in node_ids:
try:
txtinfo = passages[nid]
if txtinfo is None or txtinfo["text"] == "":
raise RuntimeError(f"FATAL: Passage with ID {nid} not found - failing fast")
else:
txt = txtinfo["text"]
except (KeyError, IndexError):
raise RuntimeError(f"FATAL: Passage with ID {nid} not found - failing fast")
print(f"DEBUG: Looking up passage ID {nid}")
txtinfo = passages[nid]
if txtinfo is None or txtinfo["text"] == "":
raise RuntimeError(f"FATAL: Passage with ID {nid} returned empty text")
txt = txtinfo["text"]
print(f"DEBUG: Found text for ID {nid}, length: {len(txt)}")
texts.append(txt)
lookup_timer.print_elapsed()

View File

@@ -1,4 +1,4 @@
# 文件: packages/leann-backend-hnsw/pyproject.toml
# packages/leann-backend-hnsw/pyproject.toml
[build-system]
requires = ["scikit-build-core>=0.10", "numpy", "swig"]
@@ -10,7 +10,6 @@ version = "0.1.0"
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
dependencies = ["leann-core==0.1.0", "numpy"]
# 回归到最标准的 scikit-build-core 配置
[tool.scikit-build]
wheel.packages = ["leann_backend_hnsw"]
editable.mode = "redirect"

View File

@@ -1,4 +1,14 @@
# packages/leann-core/src/leann/__init__.py
import os
import platform
# Fix OpenMP threading issues on macOS ARM64
if platform.system() == "Darwin":
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
os.environ["KMP_BLOCKTIME"] = "0"
from .api import LeannBuilder, LeannChat, LeannSearcher
from .registry import BACKEND_REGISTRY, autodiscover_backends

View File

@@ -1,334 +1,203 @@
from .registry import BACKEND_REGISTRY
from .interface import LeannBackendFactoryInterface
from typing import List, Dict, Any, Optional
import numpy as np
import os
"""
This file contains the core API for the LEANN project, now definitively updated
with the correct, original embedding logic from the user's reference code.
"""
import json
import pickle
import numpy as np
from pathlib import Path
import openai
from typing import List, Dict, Any, Optional
from dataclasses import dataclass, field
import uuid
import pickle
import torch
# --- Helper Functions for Embeddings ---
from .registry import BACKEND_REGISTRY
from .interface import LeannBackendFactoryInterface
def _get_openai_client():
"""Initializes and returns an OpenAI client, ensuring the API key is set."""
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
raise ValueError("OPENAI_API_KEY environment variable not set, which is required for OpenAI models.")
return openai.OpenAI(api_key=api_key)
# --- The Correct, Verified Embedding Logic from old_code.py ---
def _is_openai_model(model_name: str) -> bool:
"""Checks if the model is likely an OpenAI embedding model."""
# This is a simple check, can be improved with a more robust list.
return "ada" in model_name or "babbage" in model_name or model_name.startswith("text-embedding-")
def _compute_embeddings(chunks: List[str], model_name: str) -> np.ndarray:
"""Computes embeddings for a list of text chunks using either SentenceTransformers or OpenAI."""
if _is_openai_model(model_name):
print(f"INFO: Computing embeddings for {len(chunks)} chunks using OpenAI model '{model_name}'...")
client = _get_openai_client()
response = client.embeddings.create(model=model_name, input=chunks)
embeddings = [item.embedding for item in response.data]
else:
def compute_embeddings(chunks: List[str], model_name: str) -> np.ndarray:
"""Computes embeddings using sentence-transformers for consistent results."""
try:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer(model_name)
print(f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}'...")
embeddings = model.encode(chunks, show_progress_bar=True)
return np.asarray(embeddings, dtype=np.float32)
except ImportError as e:
raise RuntimeError(
f"sentence-transformers not available. Install with: pip install sentence-transformers"
) from e
def _get_embedding_dimensions(model_name: str) -> int:
"""Gets the embedding dimensions for a given model."""
print(f"INFO: Calculating dimensions for model '{model_name}'...")
if _is_openai_model(model_name):
client = _get_openai_client()
response = client.embeddings.create(model=model_name, input=["dummy text"])
return len(response.data[0].embedding)
else:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer(model_name)
dimension = model.get_sentence_embedding_dimension()
if dimension is None:
raise ValueError(f"Model '{model_name}' does not have a valid embedding dimension.")
return dimension
# Load model using sentence-transformers
model = SentenceTransformer(model_name)
model = model.half()
print(f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}'...")
# use acclerater GPU or MAC GPU
if torch.cuda.is_available():
model = model.to("cuda")
elif torch.backends.mps.is_available():
model = model.to("mps")
# Generate embeddings
embeddings = model.encode(chunks, convert_to_numpy=True, show_progress_bar=True, batch_size=64)
return embeddings
# --- Core API Classes (Restored and Unchanged) ---
@dataclass
class SearchResult:
"""Represents a single search result."""
id: str
score: float
text: str
metadata: Dict[str, Any] = field(default_factory=dict)
class PassageManager:
"""Manages passage data and lazy loading from JSONL files."""
def __init__(self, passage_sources: List[Dict[str, Any]]):
self.offset_maps = {}
self.passage_files = {}
self.global_offset_map = {} # Combined map for fast lookup
for source in passage_sources:
if source["type"] == "jsonl":
passage_file = source["path"]
index_file = source["index_path"]
if not os.path.exists(index_file):
if not Path(index_file).exists():
raise FileNotFoundError(f"Passage index file not found: {index_file}")
with open(index_file, 'rb') as f:
offset_map = pickle.load(f)
self.offset_maps[passage_file] = offset_map
self.passage_files[passage_file] = passage_file
self.offset_maps[passage_file] = offset_map
self.passage_files[passage_file] = passage_file
# Build global map for O(1) lookup
for passage_id, offset in offset_map.items():
self.global_offset_map[passage_id] = (passage_file, offset)
def get_passage(self, passage_id: str) -> Dict[str, Any]:
"""Lazy load a passage by ID."""
for passage_file, offset_map in self.offset_maps.items():
if passage_id in offset_map:
offset = offset_map[passage_id]
with open(passage_file, 'r', encoding='utf-8') as f:
f.seek(offset)
line = f.readline()
return json.loads(line)
if passage_id in self.global_offset_map:
passage_file, offset = self.global_offset_map[passage_id]
with open(passage_file, 'r', encoding='utf-8') as f:
f.seek(offset)
return json.loads(f.readline())
raise KeyError(f"Passage ID not found: {passage_id}")
# --- Core Classes ---
class LeannBuilder:
"""
The builder is responsible for building the index, it will compute the embeddings and then build the index.
It will also save the metadata of the index.
"""
def __init__(self, backend_name: str, embedding_model: str = "sentence-transformers/all-mpnet-base-v2", dimensions: Optional[int] = None, **backend_kwargs):
def __init__(self, backend_name: str, embedding_model: str = "facebook/contriever-msmarco", dimensions: Optional[int] = None, **backend_kwargs):
self.backend_name = backend_name
backend_factory: LeannBackendFactoryInterface | None = BACKEND_REGISTRY.get(backend_name)
if backend_factory is None:
raise ValueError(f"Backend '{backend_name}' not found or not registered.")
self.backend_factory = backend_factory
self.embedding_model = embedding_model
self.dimensions = dimensions
self.backend_kwargs = backend_kwargs
self.chunks: List[Dict[str, Any]] = []
print(f"INFO: LeannBuilder initialized with '{backend_name}' backend.")
def add_text(self, text: str, metadata: Optional[Dict[str, Any]] = None):
if metadata is None:
metadata = {}
# Check if ID is provided in metadata
passage_id = metadata.get('id')
if passage_id is None:
passage_id = str(uuid.uuid4())
else:
# Validate uniqueness
existing_ids = {chunk['id'] for chunk in self.chunks}
if passage_id in existing_ids:
raise ValueError(f"Duplicate passage ID: {passage_id}")
# Store the definitive ID with the chunk
chunk_data = {
"id": passage_id,
"text": text,
"metadata": metadata
}
if metadata is None: metadata = {}
passage_id = metadata.get('id', str(uuid.uuid4()))
chunk_data = {"id": passage_id, "text": text, "metadata": metadata}
self.chunks.append(chunk_data)
def build_index(self, index_path: str):
if not self.chunks:
raise ValueError("No chunks added. Use add_text() first.")
if self.dimensions is None:
self.dimensions = _get_embedding_dimensions(self.embedding_model)
print(f"INFO: Auto-detected dimensions for '{self.embedding_model}': {self.dimensions}")
if not self.chunks: raise ValueError("No chunks added.")
if self.dimensions is None: self.dimensions = len(compute_embeddings(["dummy"], self.embedding_model)[0])
path = Path(index_path)
index_dir = path.parent
index_name = path.name
# Ensure the directory exists
index_dir.mkdir(parents=True, exist_ok=True)
# Create the passages.jsonl file and offset index
passages_file = index_dir / f"{index_name}.passages.jsonl"
offset_file = index_dir / f"{index_name}.passages.idx"
offset_map = {}
with open(passages_file, 'w', encoding='utf-8') as f:
for chunk in self.chunks:
offset = f.tell()
passage_data = {
"id": chunk["id"],
"text": chunk["text"],
"metadata": chunk["metadata"]
}
json.dump(passage_data, f, ensure_ascii=False)
json.dump({"id": chunk["id"], "text": chunk["text"], "metadata": chunk["metadata"]}, f, ensure_ascii=False)
f.write('\n')
offset_map[chunk["id"]] = offset
# Save the offset map
with open(offset_file, 'wb') as f:
pickle.dump(offset_map, f)
# Compute embeddings
with open(offset_file, 'wb') as f: pickle.dump(offset_map, f)
texts_to_embed = [c["text"] for c in self.chunks]
embeddings = _compute_embeddings(texts_to_embed, self.embedding_model)
# Extract string IDs for the backend
embeddings = compute_embeddings(texts_to_embed, self.embedding_model)
string_ids = [chunk["id"] for chunk in self.chunks]
# Build the vector index
current_backend_kwargs = self.backend_kwargs.copy()
current_backend_kwargs['dimensions'] = self.dimensions
current_backend_kwargs = {**self.backend_kwargs, 'dimensions': self.dimensions}
builder_instance = self.backend_factory.builder(**current_backend_kwargs)
builder_instance.build(embeddings, string_ids, index_path, **current_backend_kwargs)
# Create the lightweight meta.json file
leann_meta_path = index_dir / f"{index_name}.meta.json"
meta_data = {
"version": "1.0",
"backend_name": self.backend_name,
"embedding_model": self.embedding_model,
"dimensions": self.dimensions,
"backend_kwargs": self.backend_kwargs,
"passage_sources": [
{
"type": "jsonl",
"path": str(passages_file),
"index_path": str(offset_file)
}
]
"version": "1.0", "backend_name": self.backend_name, "embedding_model": self.embedding_model,
"dimensions": self.dimensions, "backend_kwargs": self.backend_kwargs,
"passage_sources": [{"type": "jsonl", "path": str(passages_file), "index_path": str(offset_file)}]
}
with open(leann_meta_path, 'w', encoding='utf-8') as f:
json.dump(meta_data, f, indent=2)
print(f"INFO: Leann metadata saved to {leann_meta_path}")
# Add storage status flags for HNSW backend
if self.backend_name == "hnsw":
is_compact = self.backend_kwargs.get("is_compact", True)
is_recompute = self.backend_kwargs.get("is_recompute", True)
meta_data["is_compact"] = is_compact
meta_data["is_pruned"] = is_compact and is_recompute # Pruned only if compact and recompute
with open(leann_meta_path, 'w', encoding='utf-8') as f: json.dump(meta_data, f, indent=2)
class LeannSearcher:
"""
The searcher is responsible for loading the index and performing the search.
It will also load the metadata of the index.
"""
def __init__(self, index_path: str, **backend_kwargs):
leann_meta_path = Path(index_path).parent / f"{Path(index_path).name}.meta.json"
if not leann_meta_path.exists():
raise FileNotFoundError(f"Leann metadata file not found at {leann_meta_path}. Was the index built with LeannBuilder?")
with open(leann_meta_path, 'r', encoding='utf-8') as f:
self.meta_data = json.load(f)
meta_path_str = f"{index_path}.meta.json"
if not Path(meta_path_str).exists(): raise FileNotFoundError(f"Leann metadata file not found at {meta_path_str}")
with open(meta_path_str, 'r', encoding='utf-8') as f: self.meta_data = json.load(f)
backend_name = self.meta_data['backend_name']
self.embedding_model = self.meta_data['embedding_model']
# Initialize the passage manager
passage_sources = self.meta_data.get('passage_sources', [])
self.passage_manager = PassageManager(passage_sources)
self.passage_manager = PassageManager(self.meta_data.get('passage_sources', []))
backend_factory = BACKEND_REGISTRY.get(backend_name)
if backend_factory is None:
raise ValueError(f"Backend '{backend_name}' (from index file) not found or not registered.")
final_kwargs = backend_kwargs.copy()
final_kwargs['meta'] = self.meta_data
if backend_factory is None: raise ValueError(f"Backend '{backend_name}' not found.")
final_kwargs = {**self.meta_data.get('backend_kwargs', {}), **backend_kwargs}
self.backend_impl = backend_factory.searcher(index_path, **final_kwargs)
print(f"INFO: LeannSearcher initialized with '{backend_name}' backend using index '{index_path}'.")
def search(self, query: str, top_k: int = 5, **search_kwargs):
query_embedding = _compute_embeddings([query], self.embedding_model)
def search(self, query: str, top_k: int = 5, **search_kwargs) -> List[SearchResult]:
print(f"🔍 DEBUG LeannSearcher.search() called:")
print(f" Query: '{query}'")
print(f" Top_k: {top_k}")
print(f" Search kwargs: {search_kwargs}")
query_embedding = compute_embeddings([query], self.embedding_model)
print(f" Generated embedding shape: {query_embedding.shape}")
print(f"🔍 DEBUG Query embedding first 10 values: {query_embedding[0][:10]}")
print(f"🔍 DEBUG Query embedding norm: {np.linalg.norm(query_embedding[0])}")
search_kwargs['embedding_model'] = self.embedding_model
results = self.backend_impl.search(query_embedding, top_k, **search_kwargs)
print(f" Backend returned: labels={len(results.get('labels', [[]])[0])} results")
enriched_results = []
for string_id, dist in zip(results['labels'][0], results['distances'][0]):
try:
passage_data = self.passage_manager.get_passage(string_id)
enriched_results.append(SearchResult(
id=string_id,
score=dist,
text=passage_data['text'],
metadata=passage_data.get('metadata', {})
))
except KeyError:
print(f"WARNING: Passage ID '{string_id}' not found in passage files")
if 'labels' in results and 'distances' in results:
print(f" Processing {len(results['labels'][0])} passage IDs:")
for i, (string_id, dist) in enumerate(zip(results['labels'][0], results['distances'][0])):
try:
passage_data = self.passage_manager.get_passage(string_id)
enriched_results.append(SearchResult(
id=string_id, score=dist, text=passage_data['text'], metadata=passage_data.get('metadata', {})
))
print(f" {i+1}. passage_id='{string_id}' -> SUCCESS: {passage_data['text'][:60]}...")
except KeyError:
print(f" {i+1}. passage_id='{string_id}' -> ERROR: Passage not found in PassageManager!")
print(f" Final enriched results: {len(enriched_results)} passages")
return enriched_results
from .chat import get_llm
class LeannChat:
"""
The chat is responsible for the conversation with the LLM.
It will use the searcher to get the results and then use the LLM to generate the response.
"""
def __init__(self, index_path: str, backend_name: Optional[str] = None, llm_model: str = "gpt-4o", **kwargs):
if backend_name is None:
leann_meta_path = Path(index_path).parent / f"{Path(index_path).name}.meta.json"
if not leann_meta_path.exists():
raise FileNotFoundError(f"Leann metadata file not found at {leann_meta_path}.")
with open(leann_meta_path, 'r', encoding='utf-8') as f:
meta_data = json.load(f)
backend_name = meta_data['backend_name']
def __init__(self, index_path: str, llm_config: Optional[Dict[str, Any]] = None, **kwargs):
self.searcher = LeannSearcher(index_path, **kwargs)
self.llm_model = llm_model
def ask(self, question: str, top_k=5, **kwargs):
"""
Additional keyword arguments (kwargs) for advanced search customization. Example usage:
chat.ask(
"What is ANN?",
top_k=10,
complexity=64,
beam_width=8,
USE_DEFERRED_FETCH=True,
skip_search_reorder=True,
recompute_beighbor_embeddings=True,
dedup_node_dis=True,
prune_ratio=0.1,
batch_recompute=True,
global_pruning=True
)
Supported kwargs:
- complexity (int): Search complexity parameter (default: 32)
- beam_width (int): Beam width for search (default: 4)
- USE_DEFERRED_FETCH (bool): Enable deferred fetch mode (default: False)
- skip_search_reorder (bool): Skip search reorder step (default: False)
- recompute_beighbor_embeddings (bool): Enable ZMQ embedding server for neighbor recomputation (default: False)
- dedup_node_dis (bool): Deduplicate nodes by distance (default: False)
- prune_ratio (float): Pruning ratio for search (default: 0.0)
- batch_recompute (bool): Enable batch recomputation (default: False)
- global_pruning (bool): Enable global pruning (default: False)
"""
self.llm = get_llm(llm_config)
def ask(self, question: str, top_k=5, **kwargs):
results = self.searcher.search(question, top_k=top_k, **kwargs)
context = "\n\n".join([r.text for r in results])
prompt = (
"Here is some retrieved context that might help answer your question:\n\n"
f"{context}\n\n"
f"Question: {question}\n\n"
"Please provide the best answer you can based on this context and your knowledge."
)
return self.llm.ask(prompt, **kwargs.get("llm_kwargs", {}))
prompt = f"Context:\n{context}\n\nQuestion: {question}\n\nAnswer:"
print(f"DEBUG: Calling LLM with prompt: {prompt}...")
try:
client = _get_openai_client()
response = client.chat.completions.create(
model=self.llm_model,
messages=[
{"role": "system", "content": "You are a helpful assistant that answers questions based on the provided context."},
{"role": "user", "content": prompt}
]
)
return response.choices[0].message.content
except Exception as e:
print(f"ERROR: Failed to call OpenAI API: {e}")
return f"Error: Could not get a response from the LLM. {e}"
def start_interactive(self):
print("\nLeann Chat started (type 'quit' to exit)")
while True:
@@ -342,4 +211,4 @@ class LeannChat:
print(f"Leann: {response}")
except (KeyboardInterrupt, EOFError):
print("\nGoodbye!")
break
break

View File

@@ -0,0 +1,229 @@
#!/usr/bin/env python3
"""
This file contains the chat generation logic for the LEANN project,
supporting different backends like Ollama, Hugging Face Transformers, and a simulation mode.
"""
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional
import logging
import os
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class LLMInterface(ABC):
"""Abstract base class for a generic Language Model (LLM) interface."""
@abstractmethod
def ask(self, prompt: str, **kwargs) -> str:
"""
Additional keyword arguments (kwargs) for advanced search customization. Example usage:
chat.ask(
"What is ANN?",
top_k=10,
complexity=64,
beam_width=8,
USE_DEFERRED_FETCH=True,
skip_search_reorder=True,
recompute_beighbor_embeddings=True,
dedup_node_dis=True,
prune_ratio=0.1,
batch_recompute=True,
global_pruning=True
)
Supported kwargs:
- complexity (int): Search complexity parameter (default: 32)
- beam_width (int): Beam width for search (default: 4)
- USE_DEFERRED_FETCH (bool): Enable deferred fetch mode (default: False)
- skip_search_reorder (bool): Skip search reorder step (default: False)
- recompute_beighbor_embeddings (bool): Enable ZMQ embedding server for neighbor recomputation (default: False)
- dedup_node_dis (bool): Deduplicate nodes by distance (default: False)
- prune_ratio (float): Pruning ratio for search (default: 0.0)
- batch_recompute (bool): Enable batch recomputation (default: False)
- global_pruning (bool): Enable global pruning (default: False)
"""
# """
# Sends a prompt to the LLM and returns the generated text.
# Args:
# prompt: The input prompt for the LLM.
# **kwargs: Additional keyword arguments for the LLM backend.
# Returns:
# The response string from the LLM.
# """
pass
class OllamaChat(LLMInterface):
"""LLM interface for Ollama models."""
def __init__(self, model: str = "llama3:8b", host: str = "http://localhost:11434"):
self.model = model
self.host = host
logger.info(f"Initializing OllamaChat with model='{model}' and host='{host}'")
try:
import requests
# Check if the Ollama server is responsive
if host:
requests.get(host)
except ImportError:
raise ImportError("The 'requests' library is required for Ollama. Please install it with 'pip install requests'.")
except requests.exceptions.ConnectionError:
logger.error(f"Could not connect to Ollama at {host}. Please ensure Ollama is running.")
raise ConnectionError(f"Could not connect to Ollama at {host}. Please ensure Ollama is running.")
def ask(self, prompt: str, **kwargs) -> str:
import requests
import json
full_url = f"{self.host}/api/generate"
payload = {
"model": self.model,
"prompt": prompt,
"stream": False, # Keep it simple for now
"options": kwargs
}
logger.info(f"Sending request to Ollama: {payload}")
try:
response = requests.post(full_url, data=json.dumps(payload))
response.raise_for_status()
# The response from Ollama can be a stream of JSON objects, handle this
response_parts = response.text.strip().split('\n')
full_response = ""
for part in response_parts:
if part:
json_part = json.loads(part)
full_response += json_part.get("response", "")
if json_part.get("done"):
break
return full_response
except requests.exceptions.RequestException as e:
logger.error(f"Error communicating with Ollama: {e}")
return f"Error: Could not get a response from Ollama. Details: {e}"
class HFChat(LLMInterface):
"""LLM interface for local Hugging Face Transformers models."""
def __init__(self, model_name: str = "deepseek-ai/deepseek-llm-7b-chat"):
logger.info(f"Initializing HFChat with model='{model_name}'")
try:
from transformers import pipeline
import torch
except ImportError:
raise ImportError("The 'transformers' and 'torch' libraries are required for Hugging Face models. Please install them with 'pip install transformers torch'.")
# Auto-detect device
if torch.cuda.is_available():
device = "cuda"
logger.info("CUDA is available. Using GPU.")
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
device = "mps"
logger.info("MPS is available. Using Apple Silicon GPU.")
else:
device = "cpu"
logger.info("No GPU detected. Using CPU.")
self.pipeline = pipeline("text-generation", model=model_name, device=device)
def ask(self, prompt: str, **kwargs) -> str:
# Sensible defaults for text generation
params = {
"max_length": 500,
"num_return_sequences": 1,
**kwargs
}
logger.info(f"Generating text with Hugging Face model with params: {params}")
results = self.pipeline(prompt, **params)
# Handle different response formats from transformers
if isinstance(results, list) and len(results) > 0:
generated_text = results[0].get('generated_text', '') if isinstance(results[0], dict) else str(results[0])
else:
generated_text = str(results)
# Extract only the newly generated portion by removing the original prompt
if isinstance(generated_text, str) and generated_text.startswith(prompt):
response = generated_text[len(prompt):].strip()
else:
# Fallback: return the full response if prompt removal fails
response = str(generated_text)
return response
class OpenAIChat(LLMInterface):
"""LLM interface for OpenAI models."""
def __init__(self, model: str = "gpt-4o", api_key: Optional[str] = None):
self.model = model
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
if not self.api_key:
raise ValueError("OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter.")
logger.info(f"Initializing OpenAI Chat with model='{model}'")
try:
import openai
self.client = openai.OpenAI(api_key=self.api_key)
except ImportError:
raise ImportError("The 'openai' library is required for OpenAI models. Please install it with 'pip install openai'.")
def ask(self, prompt: str, **kwargs) -> str:
# Default parameters for OpenAI
params = {
"model": self.model,
"messages": [{"role": "user", "content": prompt}],
"max_tokens": kwargs.get("max_tokens", 1000),
"temperature": kwargs.get("temperature", 0.7),
**{k: v for k, v in kwargs.items() if k not in ["max_tokens", "temperature"]}
}
logger.info(f"Sending request to OpenAI with model {self.model}")
try:
response = self.client.chat.completions.create(**params)
return response.choices[0].message.content.strip()
except Exception as e:
logger.error(f"Error communicating with OpenAI: {e}")
return f"Error: Could not get a response from OpenAI. Details: {e}"
class SimulatedChat(LLMInterface):
"""A simple simulated chat for testing and development."""
def ask(self, prompt: str, **kwargs) -> str:
logger.info("Simulating LLM call...")
print("Prompt sent to LLM (simulation):", prompt[:500] + "...")
return "This is a simulated answer from the LLM based on the retrieved context."
def get_llm(llm_config: Optional[Dict[str, Any]] = None) -> LLMInterface:
"""
Factory function to get an LLM interface based on configuration.
Args:
llm_config: A dictionary specifying the LLM type and its parameters.
Example: {"type": "ollama", "model": "llama3"}
{"type": "hf", "model": "distilgpt2"}
None (for simulation mode)
Returns:
An instance of an LLMInterface subclass.
"""
if llm_config is None:
logger.info("No LLM config provided, defaulting to simulated chat.")
return SimulatedChat()
llm_type = llm_config.get("type", "simulated")
model = llm_config.get("model")
logger.info(f"Attempting to create LLM of type='{llm_type}' with model='{model}'")
if llm_type == "ollama":
return OllamaChat(model=model or "llama3:8b", host=llm_config.get("host", "http://localhost:11434"))
elif llm_type == "hf":
return HFChat(model_name=model or "deepseek-ai/deepseek-llm-7b-chat")
elif llm_type == "openai":
return OpenAIChat(model=model or "gpt-4o", api_key=llm_config.get("api_key"))
elif llm_type == "simulated":
return SimulatedChat()
else:
raise ValueError(f"Unknown LLM type: '{llm_type}'")

View File

@@ -73,15 +73,17 @@ class EmbeddingServerManager:
self.server_process = subprocess.Popen(
command,
cwd=project_root,
# stdout=subprocess.PIPE,
# stderr=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT, # Merge stderr into stdout for easier monitoring
text=True,
encoding='utf-8'
encoding='utf-8',
bufsize=1, # Line buffered
universal_newlines=True
)
self.server_port = port
print(f"INFO: Server process started with PID: {self.server_process.pid}")
max_wait, wait_interval = 30, 0.5
max_wait, wait_interval = 120, 0.5
for _ in range(int(max_wait / wait_interval)):
if _check_port(port):
print(f"✅ Embedding server is up and ready for this session.")
@@ -90,7 +92,7 @@ class EmbeddingServerManager:
return True
if self.server_process.poll() is not None:
print("❌ ERROR: Server process terminated unexpectedly during startup.")
self._log_monitor()
self._print_recent_output()
return False
time.sleep(wait_interval)
@@ -102,19 +104,32 @@ class EmbeddingServerManager:
print(f"❌ ERROR: Failed to start embedding server process: {e}")
return False
def _print_recent_output(self):
"""Print any recent output from the server process."""
if not self.server_process or not self.server_process.stdout:
return
try:
# Read any available output
import select
import sys
if select.select([self.server_process.stdout], [], [], 0)[0]:
output = self.server_process.stdout.read()
if output:
print(f"[{self.backend_module_name} OUTPUT]: {output}")
except Exception as e:
print(f"Error reading server output: {e}")
def _log_monitor(self):
"""Monitors and prints the server's stdout and stderr."""
if not self.server_process:
return
try:
if self.server_process.stdout:
for line in iter(self.server_process.stdout.readline, ''):
print(f"[{self.backend_module_name} LOG]: {line.strip()}")
self.server_process.stdout.close()
if self.server_process.stderr:
for line in iter(self.server_process.stderr.readline, ''):
print(f"[{self.backend_module_name} ERROR]: {line.strip()}")
self.server_process.stderr.close()
while True:
line = self.server_process.stdout.readline()
if not line:
break
print(f"[{self.backend_module_name} LOG]: {line.strip()}", flush=True)
except Exception as e:
print(f"Log monitor error: {e}")

View File

@@ -0,0 +1,97 @@
import json
import pickle
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, Any, List
import numpy as np
from .embedding_server_manager import EmbeddingServerManager
from .interface import LeannBackendSearcherInterface
class BaseSearcher(LeannBackendSearcherInterface, ABC):
"""
Abstract base class for Leann searchers, containing common logic for
loading metadata, managing embedding servers, and handling file paths.
"""
def __init__(self, index_path: str, backend_module_name: str, **kwargs):
"""
Initializes the BaseSearcher.
Args:
index_path: Path to the Leann index file (e.g., '.../my_index.leann').
backend_module_name: The specific embedding server module to use
(e.g., 'leann_backend_hnsw.hnsw_embedding_server').
**kwargs: Additional keyword arguments.
"""
self.index_path = Path(index_path)
self.index_dir = self.index_path.parent
self.meta = kwargs.get("meta", self._load_meta())
if not self.meta:
raise ValueError("Searcher requires metadata from .meta.json.")
self.dimensions = self.meta.get("dimensions")
if not self.dimensions:
raise ValueError("Dimensions not found in Leann metadata.")
self.embedding_model = self.meta.get("embedding_model")
if not self.embedding_model:
print("WARNING: embedding_model not found in meta.json. Recompute will fail.")
self.label_map = self._load_label_map()
self.embedding_server_manager = EmbeddingServerManager(
backend_module_name=backend_module_name
)
def _load_meta(self) -> Dict[str, Any]:
"""Loads the metadata file associated with the index."""
# This is the corrected logic for finding the meta file.
meta_path = self.index_dir / f"{self.index_path.name}.meta.json"
if not meta_path.exists():
raise FileNotFoundError(f"Leann metadata file not found at {meta_path}")
with open(meta_path, 'r', encoding='utf-8') as f:
return json.load(f)
def _load_label_map(self) -> Dict[int, str]:
"""Loads the mapping from integer IDs to string IDs."""
label_map_file = self.index_dir / "leann.labels.map"
if not label_map_file.exists():
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
with open(label_map_file, 'rb') as f:
return pickle.load(f)
def _ensure_server_running(self, passages_source_file: str, port: int, **kwargs) -> None:
"""
Ensures the embedding server is running if recompute is needed.
This is a helper for subclasses.
"""
if not self.embedding_model:
raise ValueError("Cannot use recompute mode without 'embedding_model' in meta.json.")
server_started = self.embedding_server_manager.start_server(
port=port,
model_name=self.embedding_model,
passages_file=passages_source_file,
distance_metric=kwargs.get("distance_metric"),
)
if not server_started:
raise RuntimeError(f"Failed to start embedding server on port {kwargs.get('zmq_port')}")
@abstractmethod
def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]:
"""
Search for the top_k nearest neighbors of the query vector.
Must be implemented by subclasses.
"""
pass
def __del__(self):
"""Ensures the embedding server is stopped when the searcher is destroyed."""
if hasattr(self, 'embedding_server_manager'):
self.embedding_server_manager.stop_server()

View File

@@ -21,7 +21,7 @@ dependencies = [
"colorama",
"boto3",
"protobuf==4.25.3",
"sglang[all]",
"sglang",
"ollama",
"requests>=2.25.0",
"sentence-transformers>=2.2.0",

View File

@@ -0,0 +1,147 @@
import os
import email
from pathlib import Path
from typing import List, Any
from llama_index.core import VectorStoreIndex, Document
from llama_index.core.readers.base import BaseReader
class EmlxReader(BaseReader):
"""
Apple Mail .emlx file reader.
Reads individual .emlx files from Apple Mail's storage format.
"""
def __init__(self) -> None:
"""Initialize."""
pass
def load_data(self, input_dir: str, **load_kwargs: Any) -> List[Document]:
"""
Load data from the input directory containing .emlx files.
Args:
input_dir: Directory containing .emlx files
**load_kwargs:
max_count (int): Maximum amount of messages to read.
"""
docs: List[Document] = []
max_count = load_kwargs.get('max_count', 1000)
count = 0
# Walk through the directory recursively
for dirpath, dirnames, filenames in os.walk(input_dir):
# Skip hidden directories
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
for filename in filenames:
if count >= max_count:
break
if filename.endswith(".emlx"):
filepath = os.path.join(dirpath, filename)
try:
# Read the .emlx file
with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
content = f.read()
# .emlx files have a length prefix followed by the email content
# The first line contains the length, followed by the email
lines = content.split('\n', 1)
if len(lines) >= 2:
email_content = lines[1]
# Parse the email using Python's email module
try:
msg = email.message_from_string(email_content)
# Extract email metadata
subject = msg.get('Subject', 'No Subject')
from_addr = msg.get('From', 'Unknown')
to_addr = msg.get('To', 'Unknown')
date = msg.get('Date', 'Unknown')
# Extract email body
body = ""
if msg.is_multipart():
for part in msg.walk():
if part.get_content_type() == "text/plain" or part.get_content_type() == "text/html":
body += part.get_payload(decode=True).decode('utf-8', errors='ignore')
# break
else:
body = msg.get_payload(decode=True).decode('utf-8', errors='ignore')
# Create document content
doc_content = f"""
From: {from_addr}
To: {to_addr}
Subject: {subject}
Date: {date}
{body}
"""
# Create metadata
metadata = {
'file_path': filepath,
'subject': subject,
'from': from_addr,
'to': to_addr,
'date': date,
'filename': filename
}
if count == 0:
print("--------------------------------")
print('dir path', dirpath)
print(metadata)
print(doc_content)
print("--------------------------------")
body=[]
if msg.is_multipart():
for part in msg.walk():
print("-------------------------------- get content type -------------------------------")
print(part.get_content_type())
print(part)
# body.append(part.get_payload(decode=True).decode('utf-8', errors='ignore'))
print("-------------------------------- get content type -------------------------------")
else:
body = msg.get_payload(decode=True).decode('utf-8', errors='ignore')
print(body)
print(body)
print("--------------------------------")
doc = Document(text=doc_content, metadata=metadata)
docs.append(doc)
count += 1
except Exception as e:
print(f"!!!!!!! Error parsing email from {filepath}: {e} !!!!!!!!")
continue
except Exception as e:
print(f"!!!!!!! Error reading file !!!!!!!! {filepath}: {e}")
continue
print(f"Loaded {len(docs)} email documents")
return docs
# Use the custom EmlxReader instead of MboxReader
documents = EmlxReader().load_data(
"/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages",
max_count=1000
) # Returns list of documents
# Configure the index with larger chunk size to handle long metadata
from llama_index.core.node_parser import SentenceSplitter
# Create a custom text splitter with larger chunk size
text_splitter = SentenceSplitter(chunk_size=2048, chunk_overlap=200)
index = VectorStoreIndex.from_documents(
documents,
transformations=[text_splitter]
) # Initialize index with documents
query_engine = index.as_query_engine()
res = query_engine.query("Hows Berkeley Graduate Student Instructor")
print(res)

View File

@@ -0,0 +1,213 @@
import os
import email
from pathlib import Path
from typing import List, Any
from llama_index.core import VectorStoreIndex, Document, StorageContext
from llama_index.core.readers.base import BaseReader
from llama_index.core.node_parser import SentenceSplitter
class EmlxReader(BaseReader):
"""
Apple Mail .emlx file reader.
Reads individual .emlx files from Apple Mail's storage format.
"""
def __init__(self) -> None:
"""Initialize."""
pass
def load_data(self, input_dir: str, **load_kwargs: Any) -> List[Document]:
"""
Load data from the input directory containing .emlx files.
Args:
input_dir: Directory containing .emlx files
**load_kwargs:
max_count (int): Maximum amount of messages to read.
"""
docs: List[Document] = []
max_count = load_kwargs.get('max_count', 1000)
count = 0
# Walk through the directory recursively
for dirpath, dirnames, filenames in os.walk(input_dir):
# Skip hidden directories
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
for filename in filenames:
if count >= max_count:
break
if filename.endswith(".emlx"):
filepath = os.path.join(dirpath, filename)
try:
# Read the .emlx file
with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
content = f.read()
# .emlx files have a length prefix followed by the email content
# The first line contains the length, followed by the email
lines = content.split('\n', 1)
if len(lines) >= 2:
email_content = lines[1]
# Parse the email using Python's email module
try:
msg = email.message_from_string(email_content)
# Extract email metadata
subject = msg.get('Subject', 'No Subject')
from_addr = msg.get('From', 'Unknown')
to_addr = msg.get('To', 'Unknown')
date = msg.get('Date', 'Unknown')
# Extract email body
body = ""
if msg.is_multipart():
for part in msg.walk():
if part.get_content_type() == "text/plain":
body = part.get_payload(decode=True).decode('utf-8', errors='ignore')
break
else:
body = msg.get_payload(decode=True).decode('utf-8', errors='ignore')
# Create document content
doc_content = f"""
From: {from_addr}
To: {to_addr}
Subject: {subject}
Date: {date}
{body}
"""
# Create metadata
metadata = {
'file_path': filepath,
'subject': subject,
'from': from_addr,
'to': to_addr,
'date': date,
'filename': filename
}
doc = Document(text=doc_content, metadata=metadata)
docs.append(doc)
count += 1
except Exception as e:
print(f"Error parsing email from {filepath}: {e}")
continue
except Exception as e:
print(f"Error reading file {filepath}: {e}")
continue
print(f"Loaded {len(docs)} email documents")
return docs
def create_and_save_index(mail_path: str, save_dir: str = "mail_index", max_count: int = 1000):
"""
Create the index from mail data and save it to disk.
Args:
mail_path: Path to the mail directory
save_dir: Directory to save the index
max_count: Maximum number of emails to process
"""
print("Creating index from mail data...")
# Load documents
documents = EmlxReader().load_data(mail_path, max_count=max_count)
if not documents:
print("No documents loaded. Exiting.")
return None
# Create text splitter
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=0)
# Create index
index = VectorStoreIndex.from_documents(
documents,
transformations=[text_splitter]
)
# Save the index
os.makedirs(save_dir, exist_ok=True)
index.storage_context.persist(persist_dir=save_dir)
print(f"Index saved to {save_dir}")
return index
def load_index(save_dir: str = "mail_index"):
"""
Load the saved index from disk.
Args:
save_dir: Directory where the index is saved
Returns:
Loaded index or None if loading fails
"""
try:
# Load storage context
storage_context = StorageContext.from_defaults(persist_dir=save_dir)
# Load index
index = VectorStoreIndex.from_vector_store(
storage_context.vector_store,
storage_context=storage_context
)
print(f"Index loaded from {save_dir}")
return index
except Exception as e:
print(f"Error loading index: {e}")
return None
def query_index(index, query: str):
"""
Query the loaded index.
Args:
index: The loaded index
query: The query string
"""
if index is None:
print("No index available for querying.")
return
query_engine = index.as_query_engine()
response = query_engine.query(query)
print(f"Query: {query}")
print(f"Response: {response}")
def main():
mail_path = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages"
save_dir = "mail_index"
# Check if index already exists
if os.path.exists(save_dir) and os.path.exists(os.path.join(save_dir, "vector_store.json")):
print("Loading existing index...")
index = load_index(save_dir)
else:
print("Creating new index...")
index = create_and_save_index(mail_path, save_dir, max_count=1000)
if index:
# Example queries
queries = [
"Hows Berkeley Graduate Student Instructor",
"What emails mention GSR appointments?",
"Find emails about deadlines"
]
for query in queries:
print("\n" + "="*50)
query_index(index, query)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,211 @@
import os
import email
from pathlib import Path
from typing import List, Any
from llama_index.core import VectorStoreIndex, Document, StorageContext
from llama_index.core.readers.base import BaseReader
from llama_index.core.node_parser import SentenceSplitter
class EmlxReader(BaseReader):
"""
Apple Mail .emlx file reader with reduced metadata.
Reads individual .emlx files from Apple Mail's storage format.
"""
def __init__(self) -> None:
"""Initialize."""
pass
def load_data(self, input_dir: str, **load_kwargs: Any) -> List[Document]:
"""
Load data from the input directory containing .emlx files.
Args:
input_dir: Directory containing .emlx files
**load_kwargs:
max_count (int): Maximum amount of messages to read.
"""
docs: List[Document] = []
max_count = load_kwargs.get('max_count', 1000)
count = 0
# Walk through the directory recursively
for dirpath, dirnames, filenames in os.walk(input_dir):
# Skip hidden directories
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
for filename in filenames:
if count >= max_count:
break
if filename.endswith(".emlx"):
filepath = os.path.join(dirpath, filename)
try:
# Read the .emlx file
with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
content = f.read()
# .emlx files have a length prefix followed by the email content
# The first line contains the length, followed by the email
lines = content.split('\n', 1)
if len(lines) >= 2:
email_content = lines[1]
# Parse the email using Python's email module
try:
msg = email.message_from_string(email_content)
# Extract email metadata
subject = msg.get('Subject', 'No Subject')
from_addr = msg.get('From', 'Unknown')
to_addr = msg.get('To', 'Unknown')
date = msg.get('Date', 'Unknown')
# Extract email body
body = ""
if msg.is_multipart():
for part in msg.walk():
if part.get_content_type() == "text/plain":
body = part.get_payload(decode=True).decode('utf-8', errors='ignore')
break
else:
body = msg.get_payload(decode=True).decode('utf-8', errors='ignore')
# Create document content with metadata embedded in text
doc_content = f"""
From: {from_addr}
To: {to_addr}
Subject: {subject}
Date: {date}
{body}
"""
# Create minimal metadata (only essential info)
metadata = {
'subject': subject[:50], # Truncate subject
'from': from_addr[:30], # Truncate from
'date': date[:20], # Truncate date
'filename': filename # Keep filename
}
doc = Document(text=doc_content, metadata=metadata)
docs.append(doc)
count += 1
except Exception as e:
print(f"Error parsing email from {filepath}: {e}")
continue
except Exception as e:
print(f"Error reading file {filepath}: {e}")
continue
print(f"Loaded {len(docs)} email documents")
return docs
def create_and_save_index(mail_path: str, save_dir: str = "mail_index_small", max_count: int = 1000):
"""
Create the index from mail data and save it to disk.
Args:
mail_path: Path to the mail directory
save_dir: Directory to save the index
max_count: Maximum number of emails to process
"""
print("Creating index from mail data with small chunks...")
# Load documents
documents = EmlxReader().load_data(mail_path, max_count=max_count)
if not documents:
print("No documents loaded. Exiting.")
return None
# Create text splitter with small chunk size
text_splitter = SentenceSplitter(chunk_size=512, chunk_overlap=50)
# Create index
index = VectorStoreIndex.from_documents(
documents,
transformations=[text_splitter]
)
# Save the index
os.makedirs(save_dir, exist_ok=True)
index.storage_context.persist(persist_dir=save_dir)
print(f"Index saved to {save_dir}")
return index
def load_index(save_dir: str = "mail_index_small"):
"""
Load the saved index from disk.
Args:
save_dir: Directory where the index is saved
Returns:
Loaded index or None if loading fails
"""
try:
# Load storage context
storage_context = StorageContext.from_defaults(persist_dir=save_dir)
# Load index
index = VectorStoreIndex.from_vector_store(
storage_context.vector_store,
storage_context=storage_context
)
print(f"Index loaded from {save_dir}")
return index
except Exception as e:
print(f"Error loading index: {e}")
return None
def query_index(index, query: str):
"""
Query the loaded index.
Args:
index: The loaded index
query: The query string
"""
if index is None:
print("No index available for querying.")
return
query_engine = index.as_query_engine()
response = query_engine.query(query)
print(f"Query: {query}")
print(f"Response: {response}")
def main():
mail_path = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages"
save_dir = "mail_index_small"
# Check if index already exists
if os.path.exists(save_dir) and os.path.exists(os.path.join(save_dir, "vector_store.json")):
print("Loading existing index...")
index = load_index(save_dir)
else:
print("Creating new index...")
index = create_and_save_index(mail_path, save_dir, max_count=1000)
if index:
# Example queries
queries = [
"Hows Berkeley Graduate Student Instructor",
"What emails mention GSR appointments?",
"Find emails about deadlines"
]
for query in queries:
print("\n" + "="*50)
query_index(index, query)
if __name__ == "__main__":
main()

147
test/mail_reader_test.py Normal file
View File

@@ -0,0 +1,147 @@
import os
import email
from pathlib import Path
from typing import List, Any
from llama_index.core import VectorStoreIndex, Document
from llama_index.core.readers.base import BaseReader
class EmlxReader(BaseReader):
"""
Apple Mail .emlx file reader.
Reads individual .emlx files from Apple Mail's storage format.
"""
def __init__(self) -> None:
"""Initialize."""
pass
def load_data(self, input_dir: str, **load_kwargs: Any) -> List[Document]:
"""
Load data from the input directory containing .emlx files.
Args:
input_dir: Directory containing .emlx files
**load_kwargs:
max_count (int): Maximum amount of messages to read.
"""
docs: List[Document] = []
max_count = load_kwargs.get('max_count', 1000)
count = 0
# Check if directory exists and is accessible
if not os.path.exists(input_dir):
print(f"Error: Directory '{input_dir}' does not exist")
return docs
if not os.access(input_dir, os.R_OK):
print(f"Error: Directory '{input_dir}' is not accessible (permission denied)")
print("This is likely due to macOS security restrictions on Mail app data")
return docs
print(f"Scanning directory: {input_dir}")
# Walk through the directory recursively
for dirpath, dirnames, filenames in os.walk(input_dir):
# Skip hidden directories
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
for filename in filenames:
if count >= max_count:
break
if filename.endswith(".emlx"):
filepath = os.path.join(dirpath, filename)
print(f"Found .emlx file: {filepath}")
try:
# Read the .emlx file
with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
content = f.read()
# .emlx files have a length prefix followed by the email content
# The first line contains the length, followed by the email
lines = content.split('\n', 1)
if len(lines) >= 2:
email_content = lines[1]
# Parse the email using Python's email module
try:
msg = email.message_from_string(email_content)
# Extract email metadata
subject = msg.get('Subject', 'No Subject')
from_addr = msg.get('From', 'Unknown')
to_addr = msg.get('To', 'Unknown')
date = msg.get('Date', 'Unknown')
# Extract email body
body = ""
if msg.is_multipart():
for part in msg.walk():
if part.get_content_type() == "text/plain":
body = part.get_payload(decode=True).decode('utf-8', errors='ignore')
break
else:
body = msg.get_payload(decode=True).decode('utf-8', errors='ignore')
# Create document content
doc_content = f"""
From: {from_addr}
To: {to_addr}
Subject: {subject}
Date: {date}
{body}
"""
# Create metadata
metadata = {
'file_path': filepath,
'subject': subject,
'from': from_addr,
'to': to_addr,
'date': date,
'filename': filename
}
doc = Document(text=doc_content, metadata=metadata)
docs.append(doc)
count += 1
except Exception as e:
print(f"Error parsing email from {filepath}: {e}")
continue
except Exception as e:
print(f"Error reading file {filepath}: {e}")
continue
print(f"Loaded {len(docs)} email documents")
return docs
def main():
# Use the current directory where the sample.emlx file is located
current_dir = os.path.dirname(os.path.abspath(__file__))
print("Testing EmlxReader with sample .emlx file...")
print(f"Scanning directory: {current_dir}")
# Use the custom EmlxReader
documents = EmlxReader().load_data(current_dir, max_count=1000)
if not documents:
print("No documents loaded. Make sure sample.emlx exists in the examples directory.")
return
print(f"\nSuccessfully loaded {len(documents)} document(s)")
# Initialize index with documents
index = VectorStoreIndex.from_documents(documents)
query_engine = index.as_query_engine()
print("\nTesting query: 'Hows Berkeley Graduate Student Instructor'")
res = query_engine.query("Hows Berkeley Graduate Student Instructor")
print(f"Response: {res}")
if __name__ == "__main__":
main()

99
test/query_saved_index.py Normal file
View File

@@ -0,0 +1,99 @@
import os
from llama_index.core import VectorStoreIndex, StorageContext
def load_index(save_dir: str = "mail_index"):
"""
Load the saved index from disk.
Args:
save_dir: Directory where the index is saved
Returns:
Loaded index or None if loading fails
"""
try:
# Load storage context
storage_context = StorageContext.from_defaults(persist_dir=save_dir)
# Load index
index = VectorStoreIndex.from_vector_store(
storage_context.vector_store,
storage_context=storage_context
)
print(f"Index loaded from {save_dir}")
return index
except Exception as e:
print(f"Error loading index: {e}")
return None
def query_index(index, query: str):
"""
Query the loaded index.
Args:
index: The loaded index
query: The query string
"""
if index is None:
print("No index available for querying.")
return
query_engine = index.as_query_engine()
response = query_engine.query(query)
print(f"\nQuery: {query}")
print(f"Response: {response}")
def main():
save_dir = "mail_index"
# Check if index exists
if not os.path.exists(save_dir) or not os.path.exists(os.path.join(save_dir, "vector_store.json")):
print(f"Index not found in {save_dir}")
print("Please run mail_reader_save_load.py first to create the index.")
return
# Load the index
index = load_index(save_dir)
if not index:
print("Failed to load index.")
return
print("\n" + "="*60)
print("Email Query Interface")
print("="*60)
print("Type 'quit' to exit")
print("Type 'help' for example queries")
print("="*60)
# Interactive query loop
while True:
try:
query = input("\nEnter your query: ").strip()
if query.lower() == 'quit':
print("Goodbye!")
break
elif query.lower() == 'help':
print("\nExample queries:")
print("- Hows Berkeley Graduate Student Instructor")
print("- What emails mention GSR appointments?")
print("- Find emails about deadlines")
print("- Search for emails from specific sender")
print("- Find emails about meetings")
continue
elif not query:
continue
query_index(index, query)
except KeyboardInterrupt:
print("\nGoodbye!")
break
except Exception as e:
print(f"Error processing query: {e}")
if __name__ == "__main__":
main()

View File

@@ -1,68 +0,0 @@
# HNSW Index Storage Optimization
This document explains the storage optimization features available in the HNSW backend.
## Storage Modes
The HNSW backend supports two orthogonal optimization techniques:
### 1. CSR Compression (`is_compact=True`)
- Converts the graph structure from standard format to Compressed Sparse Row (CSR) format
- Reduces memory overhead from graph adjacency storage
- Maintains all embedding data for direct access
### 2. Embedding Pruning (`is_recompute=True`)
- Removes embedding vectors from the index file
- Replaces them with a NULL storage marker
- Requires recomputation via embedding server during search
- Must be used with `is_compact=True` for efficiency
## Performance Impact
**Storage Reduction (100 vectors, 384 dimensions):**
```
Standard format: 168 KB (embeddings + graph)
CSR only: 160 KB (embeddings + compressed graph)
CSR + Pruned: 6 KB (compressed graph only)
```
**Key Benefits:**
- **CSR compression**: ~5% size reduction from graph optimization
- **Embedding pruning**: ~95% size reduction by removing embeddings
- **Combined**: Up to 96% total storage reduction
## Usage
```python
# Standard format (largest)
builder = LeannBuilder(
backend_name="hnsw",
is_compact=False,
is_recompute=False
)
# CSR compressed (medium)
builder = LeannBuilder(
backend_name="hnsw",
is_compact=True,
is_recompute=False
)
# CSR + Pruned (smallest, requires embedding server)
builder = LeannBuilder(
backend_name="hnsw",
is_compact=True, # Required for pruning
is_recompute=True # Default: enabled
)
```
## Trade-offs
| Mode | Storage | Search Speed | Memory Usage | Setup Complexity |
|------|---------|--------------|--------------|------------------|
| Standard | Largest | Fastest | Highest | Simple |
| CSR | Medium | Fast | Medium | Simple |
| CSR + Pruned | Smallest | Slower* | Lowest | Complex** |
*Requires network round-trip to embedding server for recomputation
**Needs embedding server and passages file for search

View File

@@ -1,156 +0,0 @@
#!/usr/bin/env python3
"""
Sanity check script to verify HNSW index pruning effectiveness.
Tests the difference in file sizes between pruned and non-pruned indices.
"""
import os
import sys
import tempfile
from pathlib import Path
import numpy as np
import json
# Add the project root to the Python path
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))
# Import backend packages to trigger plugin registration
import leann_backend_hnsw
from leann.api import LeannBuilder
def create_sample_documents(num_docs=1000):
"""Create sample documents for testing"""
documents = []
for i in range(num_docs):
documents.append(f"Sample document {i} with some random text content for testing purposes.")
return documents
def build_index(documents, output_dir, is_recompute=True):
"""Build HNSW index with specified recompute setting"""
index_path = os.path.join(output_dir, "test_index.hnsw")
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="sentence-transformers/all-MiniLM-L6-v2",
M=16,
efConstruction=100,
distance_metric="mips",
is_compact=True,
is_recompute=is_recompute
)
for doc in documents:
builder.add_text(doc)
builder.build_index(index_path)
return index_path
def get_file_size(filepath):
"""Get file size in bytes"""
return os.path.getsize(filepath)
def main():
print("🔍 HNSW Pruning Sanity Check")
print("=" * 50)
# Create sample data
print("📊 Creating sample documents...")
documents = create_sample_documents(num_docs=1000)
print(f" Number of documents: {len(documents)}")
with tempfile.TemporaryDirectory() as temp_dir:
print(f"📁 Working in temporary directory: {temp_dir}")
# Build index with pruning (is_recompute=True)
print("\n🔨 Building index with pruning enabled (is_recompute=True)...")
pruned_dir = os.path.join(temp_dir, "pruned")
os.makedirs(pruned_dir, exist_ok=True)
pruned_index_path = build_index(documents, pruned_dir, is_recompute=True)
# Check what files were actually created
print(f" Looking for index files at: {pruned_index_path}")
import glob
files = glob.glob(f"{pruned_index_path}*")
print(f" Found files: {files}")
# Try to find the actual index file
if os.path.exists(f"{pruned_index_path}.index"):
pruned_index_file = f"{pruned_index_path}.index"
else:
# Look for any .index file in the directory
index_files = glob.glob(f"{pruned_dir}/*.index")
if index_files:
pruned_index_file = index_files[0]
else:
raise FileNotFoundError(f"No .index file found in {pruned_dir}")
pruned_size = get_file_size(pruned_index_file)
print(f" ✅ Pruned index built successfully")
print(f" 📏 Pruned index size: {pruned_size:,} bytes ({pruned_size/1024:.1f} KB)")
# Build index without pruning (is_recompute=False)
print("\n🔨 Building index without pruning (is_recompute=False)...")
non_pruned_dir = os.path.join(temp_dir, "non_pruned")
os.makedirs(non_pruned_dir, exist_ok=True)
non_pruned_index_path = build_index(documents, non_pruned_dir, is_recompute=False)
# Check what files were actually created
print(f" Looking for index files at: {non_pruned_index_path}")
files = glob.glob(f"{non_pruned_index_path}*")
print(f" Found files: {files}")
# Try to find the actual index file
if os.path.exists(f"{non_pruned_index_path}.index"):
non_pruned_index_file = f"{non_pruned_index_path}.index"
else:
# Look for any .index file in the directory
index_files = glob.glob(f"{non_pruned_dir}/*.index")
if index_files:
non_pruned_index_file = index_files[0]
else:
raise FileNotFoundError(f"No .index file found in {non_pruned_dir}")
non_pruned_size = get_file_size(non_pruned_index_file)
print(f" ✅ Non-pruned index built successfully")
print(f" 📏 Non-pruned index size: {non_pruned_size:,} bytes ({non_pruned_size/1024:.1f} KB)")
# Compare sizes
print("\n📊 Comparison Results:")
print("=" * 30)
size_diff = non_pruned_size - pruned_size
size_ratio = pruned_size / non_pruned_size if non_pruned_size > 0 else 0
reduction_percent = (1 - size_ratio) * 100
print(f"Non-pruned index: {non_pruned_size:,} bytes ({non_pruned_size/1024:.1f} KB)")
print(f"Pruned index: {pruned_size:,} bytes ({pruned_size/1024:.1f} KB)")
print(f"Size difference: {size_diff:,} bytes ({size_diff/1024:.1f} KB)")
print(f"Size ratio: {size_ratio:.3f}")
print(f"Size reduction: {reduction_percent:.1f}%")
# Verify pruning effectiveness
print("\n🔍 Verification:")
if size_diff > 0:
print(" ✅ Pruning is effective - pruned index is smaller")
if reduction_percent > 10:
print(f" ✅ Significant size reduction: {reduction_percent:.1f}%")
else:
print(f" ⚠️ Small size reduction: {reduction_percent:.1f}%")
else:
print(" ❌ Pruning appears ineffective - no size reduction")
# Check if passages files were created
pruned_passages = f"{pruned_index_path}.passages.json"
non_pruned_passages = f"{non_pruned_index_path}.passages.json"
print(f"\n📄 Passages files:")
print(f" Pruned passages file exists: {os.path.exists(pruned_passages)}")
print(f" Non-pruned passages file exists: {os.path.exists(non_pruned_passages)}")
return True
if __name__ == "__main__":
success = main()
sys.exit(0 if success else 1)