fix: resolve all ruff linting errors and add lint CI check

- Fix ambiguous fullwidth characters (commas, parentheses) in strings and comments
- Replace Chinese comments with English equivalents
- Fix unused imports with proper noqa annotations for intentional imports
- Fix bare except clauses with specific exception types
- Fix redefined variables and undefined names
- Add ruff noqa annotations for generated protobuf files
- Add lint and format check to GitHub Actions CI pipeline
This commit is contained in:
Andy Lee
2025-07-26 22:35:12 -07:00
parent 8537a6b17e
commit b3e9ee96fa
53 changed files with 5655 additions and 5220 deletions

View File

@@ -1,5 +1,6 @@
import os
from leann.api import LeannBuilder, LeannSearcher, LeannChat
from leann.api import LeannBuilder, LeannChat
# Define the path for our new MLX-based index
INDEX_PATH = "./mlx_diskann_index/leann"
@@ -38,7 +39,5 @@ chat = LeannChat(index_path=INDEX_PATH)
# add query
query = "MLX is an array framework for machine learning on Apple silicon."
print(f"Query: {query}")
response = chat.ask(
query, top_k=3, recompute_beighbor_embeddings=True, complexity=3, beam_width=1
)
response = chat.ask(query, top_k=3, recompute_beighbor_embeddings=True, complexity=3, beam_width=1)
print(f"Response: {response}")

View File

@@ -1,76 +1,84 @@
import os
import email
from pathlib import Path
from typing import List, Any
from llama_index.core import VectorStoreIndex, Document
import os
from typing import Any
from llama_index.core import Document, VectorStoreIndex
from llama_index.core.readers.base import BaseReader
class EmlxReader(BaseReader):
"""
Apple Mail .emlx file reader.
Reads individual .emlx files from Apple Mail's storage format.
"""
def __init__(self) -> None:
"""Initialize."""
pass
def load_data(self, input_dir: str, **load_kwargs: Any) -> List[Document]:
def load_data(self, input_dir: str, **load_kwargs: Any) -> list[Document]:
"""
Load data from the input directory containing .emlx files.
Args:
input_dir: Directory containing .emlx files
**load_kwargs:
max_count (int): Maximum amount of messages to read.
"""
docs: List[Document] = []
max_count = load_kwargs.get('max_count', 1000)
docs: list[Document] = []
max_count = load_kwargs.get("max_count", 1000)
count = 0
# Walk through the directory recursively
for dirpath, dirnames, filenames in os.walk(input_dir):
# Skip hidden directories
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
for filename in filenames:
if count >= max_count:
break
if filename.endswith(".emlx"):
filepath = os.path.join(dirpath, filename)
try:
# Read the .emlx file
with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
with open(filepath, encoding="utf-8", errors="ignore") as f:
content = f.read()
# .emlx files have a length prefix followed by the email content
# The first line contains the length, followed by the email
lines = content.split('\n', 1)
lines = content.split("\n", 1)
if len(lines) >= 2:
email_content = lines[1]
# Parse the email using Python's email module
try:
msg = email.message_from_string(email_content)
# Extract email metadata
subject = msg.get('Subject', 'No Subject')
from_addr = msg.get('From', 'Unknown')
to_addr = msg.get('To', 'Unknown')
date = msg.get('Date', 'Unknown')
subject = msg.get("Subject", "No Subject")
from_addr = msg.get("From", "Unknown")
to_addr = msg.get("To", "Unknown")
date = msg.get("Date", "Unknown")
# Extract email body
body = ""
if msg.is_multipart():
for part in msg.walk():
if part.get_content_type() == "text/plain" or part.get_content_type() == "text/html":
body += part.get_payload(decode=True).decode('utf-8', errors='ignore')
if (
part.get_content_type() == "text/plain"
or part.get_content_type() == "text/html"
):
body += part.get_payload(decode=True).decode(
"utf-8", errors="ignore"
)
# break
else:
body = msg.get_payload(decode=True).decode('utf-8', errors='ignore')
body = msg.get_payload(decode=True).decode(
"utf-8", errors="ignore"
)
# Create document content
doc_content = f"""
From: {from_addr}
@@ -80,32 +88,38 @@ Date: {date}
{body}
"""
# Create metadata
metadata = {
'file_path': filepath,
'subject': subject,
'from': from_addr,
'to': to_addr,
'date': date,
'filename': filename
"file_path": filepath,
"subject": subject,
"from": from_addr,
"to": to_addr,
"date": date,
"filename": filename,
}
if count == 0:
print("--------------------------------")
print('dir path', dirpath)
print("dir path", dirpath)
print(metadata)
print(doc_content)
print("--------------------------------")
body=[]
body = []
if msg.is_multipart():
for part in msg.walk():
print("-------------------------------- get content type -------------------------------")
print(
"-------------------------------- get content type -------------------------------"
)
print(part.get_content_type())
print(part)
# body.append(part.get_payload(decode=True).decode('utf-8', errors='ignore'))
print("-------------------------------- get content type -------------------------------")
print(
"-------------------------------- get content type -------------------------------"
)
else:
body = msg.get_payload(decode=True).decode('utf-8', errors='ignore')
body = msg.get_payload(decode=True).decode(
"utf-8", errors="ignore"
)
print(body)
print(body)
@@ -113,22 +127,23 @@ Date: {date}
doc = Document(text=doc_content, metadata=metadata)
docs.append(doc)
count += 1
except Exception as e:
print(f"!!!!!!! Error parsing email from {filepath}: {e} !!!!!!!!")
continue
except Exception as e:
print(f"!!!!!!! Error reading file !!!!!!!! {filepath}: {e}")
continue
print(f"Loaded {len(docs)} email documents")
return docs
# Use the custom EmlxReader instead of MboxReader
documents = EmlxReader().load_data(
"/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages",
max_count=1000
"/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages",
max_count=1000,
) # Returns list of documents
# Configure the index with larger chunk size to handle long metadata
@@ -138,10 +153,9 @@ from llama_index.core.node_parser import SentenceSplitter
text_splitter = SentenceSplitter(chunk_size=2048, chunk_overlap=200)
index = VectorStoreIndex.from_documents(
documents,
transformations=[text_splitter]
documents, transformations=[text_splitter]
) # Initialize index with documents
query_engine = index.as_query_engine()
res = query_engine.query("Hows Berkeley Graduate Student Instructor")
print(res)
print(res)

View File

@@ -1,77 +1,82 @@
import os
import email
from pathlib import Path
from typing import List, Any
from llama_index.core import VectorStoreIndex, Document, StorageContext
from llama_index.core.readers.base import BaseReader
import os
from typing import Any
from llama_index.core import Document, StorageContext, VectorStoreIndex
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.readers.base import BaseReader
class EmlxReader(BaseReader):
"""
Apple Mail .emlx file reader.
Reads individual .emlx files from Apple Mail's storage format.
"""
def __init__(self) -> None:
"""Initialize."""
pass
def load_data(self, input_dir: str, **load_kwargs: Any) -> List[Document]:
def load_data(self, input_dir: str, **load_kwargs: Any) -> list[Document]:
"""
Load data from the input directory containing .emlx files.
Args:
input_dir: Directory containing .emlx files
**load_kwargs:
max_count (int): Maximum amount of messages to read.
"""
docs: List[Document] = []
max_count = load_kwargs.get('max_count', 1000)
docs: list[Document] = []
max_count = load_kwargs.get("max_count", 1000)
count = 0
# Walk through the directory recursively
for dirpath, dirnames, filenames in os.walk(input_dir):
# Skip hidden directories
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
for filename in filenames:
if count >= max_count:
break
if filename.endswith(".emlx"):
filepath = os.path.join(dirpath, filename)
try:
# Read the .emlx file
with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
with open(filepath, encoding="utf-8", errors="ignore") as f:
content = f.read()
# .emlx files have a length prefix followed by the email content
# The first line contains the length, followed by the email
lines = content.split('\n', 1)
lines = content.split("\n", 1)
if len(lines) >= 2:
email_content = lines[1]
# Parse the email using Python's email module
try:
msg = email.message_from_string(email_content)
# Extract email metadata
subject = msg.get('Subject', 'No Subject')
from_addr = msg.get('From', 'Unknown')
to_addr = msg.get('To', 'Unknown')
date = msg.get('Date', 'Unknown')
subject = msg.get("Subject", "No Subject")
from_addr = msg.get("From", "Unknown")
to_addr = msg.get("To", "Unknown")
date = msg.get("Date", "Unknown")
# Extract email body
body = ""
if msg.is_multipart():
for part in msg.walk():
if part.get_content_type() == "text/plain":
body = part.get_payload(decode=True).decode('utf-8', errors='ignore')
body = part.get_payload(decode=True).decode(
"utf-8", errors="ignore"
)
break
else:
body = msg.get_payload(decode=True).decode('utf-8', errors='ignore')
body = msg.get_payload(decode=True).decode(
"utf-8", errors="ignore"
)
# Create document content
doc_content = f"""
From: {from_addr}
@@ -81,97 +86,96 @@ Date: {date}
{body}
"""
# Create metadata
metadata = {
'file_path': filepath,
'subject': subject,
'from': from_addr,
'to': to_addr,
'date': date,
'filename': filename
"file_path": filepath,
"subject": subject,
"from": from_addr,
"to": to_addr,
"date": date,
"filename": filename,
}
doc = Document(text=doc_content, metadata=metadata)
docs.append(doc)
count += 1
except Exception as e:
print(f"Error parsing email from {filepath}: {e}")
continue
except Exception as e:
print(f"Error reading file {filepath}: {e}")
continue
print(f"Loaded {len(docs)} email documents")
return docs
def create_and_save_index(mail_path: str, save_dir: str = "mail_index", max_count: int = 1000):
"""
Create the index from mail data and save it to disk.
Args:
mail_path: Path to the mail directory
save_dir: Directory to save the index
max_count: Maximum number of emails to process
"""
print("Creating index from mail data...")
# Load documents
documents = EmlxReader().load_data(mail_path, max_count=max_count)
if not documents:
print("No documents loaded. Exiting.")
return None
# Create text splitter
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=0)
# Create index
index = VectorStoreIndex.from_documents(
documents,
transformations=[text_splitter]
)
index = VectorStoreIndex.from_documents(documents, transformations=[text_splitter])
# Save the index
os.makedirs(save_dir, exist_ok=True)
index.storage_context.persist(persist_dir=save_dir)
print(f"Index saved to {save_dir}")
return index
def load_index(save_dir: str = "mail_index"):
"""
Load the saved index from disk.
Args:
save_dir: Directory where the index is saved
Returns:
Loaded index or None if loading fails
"""
try:
# Load storage context
storage_context = StorageContext.from_defaults(persist_dir=save_dir)
# Load index
index = VectorStoreIndex.from_vector_store(
storage_context.vector_store,
storage_context=storage_context
storage_context.vector_store, storage_context=storage_context
)
print(f"Index loaded from {save_dir}")
return index
except Exception as e:
print(f"Error loading index: {e}")
return None
def query_index(index, query: str):
"""
Query the loaded index.
Args:
index: The loaded index
query: The query string
@@ -179,16 +183,17 @@ def query_index(index, query: str):
if index is None:
print("No index available for querying.")
return
query_engine = index.as_query_engine()
response = query_engine.query(query)
print(f"Query: {query}")
print(f"Response: {response}")
def main():
mail_path = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages"
save_dir = "mail_index"
# Check if index already exists
if os.path.exists(save_dir) and os.path.exists(os.path.join(save_dir, "vector_store.json")):
print("Loading existing index...")
@@ -196,18 +201,19 @@ def main():
else:
print("Creating new index...")
index = create_and_save_index(mail_path, save_dir, max_count=1000)
if index:
# Example queries
queries = [
"Hows Berkeley Graduate Student Instructor",
"What emails mention GSR appointments?",
"Find emails about deadlines"
"Find emails about deadlines",
]
for query in queries:
print("\n" + "="*50)
print("\n" + "=" * 50)
query_index(index, query)
if __name__ == "__main__":
main()
main()

View File

@@ -1,77 +1,82 @@
import os
import email
from pathlib import Path
from typing import List, Any
from llama_index.core import VectorStoreIndex, Document, StorageContext
from llama_index.core.readers.base import BaseReader
import os
from typing import Any
from llama_index.core import Document, StorageContext, VectorStoreIndex
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.readers.base import BaseReader
class EmlxReader(BaseReader):
"""
Apple Mail .emlx file reader with reduced metadata.
Reads individual .emlx files from Apple Mail's storage format.
"""
def __init__(self) -> None:
"""Initialize."""
pass
def load_data(self, input_dir: str, **load_kwargs: Any) -> List[Document]:
def load_data(self, input_dir: str, **load_kwargs: Any) -> list[Document]:
"""
Load data from the input directory containing .emlx files.
Args:
input_dir: Directory containing .emlx files
**load_kwargs:
max_count (int): Maximum amount of messages to read.
"""
docs: List[Document] = []
max_count = load_kwargs.get('max_count', 1000)
docs: list[Document] = []
max_count = load_kwargs.get("max_count", 1000)
count = 0
# Walk through the directory recursively
for dirpath, dirnames, filenames in os.walk(input_dir):
# Skip hidden directories
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
for filename in filenames:
if count >= max_count:
break
if filename.endswith(".emlx"):
filepath = os.path.join(dirpath, filename)
try:
# Read the .emlx file
with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
with open(filepath, encoding="utf-8", errors="ignore") as f:
content = f.read()
# .emlx files have a length prefix followed by the email content
# The first line contains the length, followed by the email
lines = content.split('\n', 1)
lines = content.split("\n", 1)
if len(lines) >= 2:
email_content = lines[1]
# Parse the email using Python's email module
try:
msg = email.message_from_string(email_content)
# Extract email metadata
subject = msg.get('Subject', 'No Subject')
from_addr = msg.get('From', 'Unknown')
to_addr = msg.get('To', 'Unknown')
date = msg.get('Date', 'Unknown')
subject = msg.get("Subject", "No Subject")
from_addr = msg.get("From", "Unknown")
to_addr = msg.get("To", "Unknown")
date = msg.get("Date", "Unknown")
# Extract email body
body = ""
if msg.is_multipart():
for part in msg.walk():
if part.get_content_type() == "text/plain":
body = part.get_payload(decode=True).decode('utf-8', errors='ignore')
body = part.get_payload(decode=True).decode(
"utf-8", errors="ignore"
)
break
else:
body = msg.get_payload(decode=True).decode('utf-8', errors='ignore')
body = msg.get_payload(decode=True).decode(
"utf-8", errors="ignore"
)
# Create document content with metadata embedded in text
doc_content = f"""
From: {from_addr}
@@ -81,95 +86,96 @@ Date: {date}
{body}
"""
# Create minimal metadata (only essential info)
metadata = {
'subject': subject[:50], # Truncate subject
'from': from_addr[:30], # Truncate from
'date': date[:20], # Truncate date
'filename': filename # Keep filename
"subject": subject[:50], # Truncate subject
"from": from_addr[:30], # Truncate from
"date": date[:20], # Truncate date
"filename": filename, # Keep filename
}
doc = Document(text=doc_content, metadata=metadata)
docs.append(doc)
count += 1
except Exception as e:
print(f"Error parsing email from {filepath}: {e}")
continue
except Exception as e:
print(f"Error reading file {filepath}: {e}")
continue
print(f"Loaded {len(docs)} email documents")
return docs
def create_and_save_index(mail_path: str, save_dir: str = "mail_index_small", max_count: int = 1000):
def create_and_save_index(
mail_path: str, save_dir: str = "mail_index_small", max_count: int = 1000
):
"""
Create the index from mail data and save it to disk.
Args:
mail_path: Path to the mail directory
save_dir: Directory to save the index
max_count: Maximum number of emails to process
"""
print("Creating index from mail data with small chunks...")
# Load documents
documents = EmlxReader().load_data(mail_path, max_count=max_count)
if not documents:
print("No documents loaded. Exiting.")
return None
# Create text splitter with small chunk size
text_splitter = SentenceSplitter(chunk_size=512, chunk_overlap=50)
# Create index
index = VectorStoreIndex.from_documents(
documents,
transformations=[text_splitter]
)
index = VectorStoreIndex.from_documents(documents, transformations=[text_splitter])
# Save the index
os.makedirs(save_dir, exist_ok=True)
index.storage_context.persist(persist_dir=save_dir)
print(f"Index saved to {save_dir}")
return index
def load_index(save_dir: str = "mail_index_small"):
"""
Load the saved index from disk.
Args:
save_dir: Directory where the index is saved
Returns:
Loaded index or None if loading fails
"""
try:
# Load storage context
storage_context = StorageContext.from_defaults(persist_dir=save_dir)
# Load index
index = VectorStoreIndex.from_vector_store(
storage_context.vector_store,
storage_context=storage_context
storage_context.vector_store, storage_context=storage_context
)
print(f"Index loaded from {save_dir}")
return index
except Exception as e:
print(f"Error loading index: {e}")
return None
def query_index(index, query: str):
"""
Query the loaded index.
Args:
index: The loaded index
query: The query string
@@ -177,16 +183,17 @@ def query_index(index, query: str):
if index is None:
print("No index available for querying.")
return
query_engine = index.as_query_engine()
response = query_engine.query(query)
print(f"Query: {query}")
print(f"Response: {response}")
def main():
mail_path = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages"
save_dir = "mail_index_small"
# Check if index already exists
if os.path.exists(save_dir) and os.path.exists(os.path.join(save_dir, "vector_store.json")):
print("Loading existing index...")
@@ -194,18 +201,19 @@ def main():
else:
print("Creating new index...")
index = create_and_save_index(mail_path, save_dir, max_count=1000)
if index:
# Example queries
queries = [
"Hows Berkeley Graduate Student Instructor",
"What emails mention GSR appointments?",
"Find emails about deadlines"
"Find emails about deadlines",
]
for query in queries:
print("\n" + "="*50)
print("\n" + "=" * 50)
query_index(index, query)
if __name__ == "__main__":
main()
main()

View File

@@ -1,89 +1,94 @@
import os
import email
from pathlib import Path
from typing import List, Any
from llama_index.core import VectorStoreIndex, Document
import os
from typing import Any
from llama_index.core import Document, VectorStoreIndex
from llama_index.core.readers.base import BaseReader
class EmlxReader(BaseReader):
"""
Apple Mail .emlx file reader.
Reads individual .emlx files from Apple Mail's storage format.
"""
def __init__(self) -> None:
"""Initialize."""
pass
def load_data(self, input_dir: str, **load_kwargs: Any) -> List[Document]:
def load_data(self, input_dir: str, **load_kwargs: Any) -> list[Document]:
"""
Load data from the input directory containing .emlx files.
Args:
input_dir: Directory containing .emlx files
**load_kwargs:
max_count (int): Maximum amount of messages to read.
"""
docs: List[Document] = []
max_count = load_kwargs.get('max_count', 1000)
docs: list[Document] = []
max_count = load_kwargs.get("max_count", 1000)
count = 0
# Check if directory exists and is accessible
if not os.path.exists(input_dir):
print(f"Error: Directory '{input_dir}' does not exist")
return docs
if not os.access(input_dir, os.R_OK):
print(f"Error: Directory '{input_dir}' is not accessible (permission denied)")
print("This is likely due to macOS security restrictions on Mail app data")
return docs
print(f"Scanning directory: {input_dir}")
# Walk through the directory recursively
for dirpath, dirnames, filenames in os.walk(input_dir):
# Skip hidden directories
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
for filename in filenames:
if count >= max_count:
break
if filename.endswith(".emlx"):
filepath = os.path.join(dirpath, filename)
print(f"Found .emlx file: {filepath}")
try:
# Read the .emlx file
with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
with open(filepath, encoding="utf-8", errors="ignore") as f:
content = f.read()
# .emlx files have a length prefix followed by the email content
# The first line contains the length, followed by the email
lines = content.split('\n', 1)
lines = content.split("\n", 1)
if len(lines) >= 2:
email_content = lines[1]
# Parse the email using Python's email module
try:
msg = email.message_from_string(email_content)
# Extract email metadata
subject = msg.get('Subject', 'No Subject')
from_addr = msg.get('From', 'Unknown')
to_addr = msg.get('To', 'Unknown')
date = msg.get('Date', 'Unknown')
subject = msg.get("Subject", "No Subject")
from_addr = msg.get("From", "Unknown")
to_addr = msg.get("To", "Unknown")
date = msg.get("Date", "Unknown")
# Extract email body
body = ""
if msg.is_multipart():
for part in msg.walk():
if part.get_content_type() == "text/plain":
body = part.get_payload(decode=True).decode('utf-8', errors='ignore')
body = part.get_payload(decode=True).decode(
"utf-8", errors="ignore"
)
break
else:
body = msg.get_payload(decode=True).decode('utf-8', errors='ignore')
body = msg.get_payload(decode=True).decode(
"utf-8", errors="ignore"
)
# Create document content
doc_content = f"""
From: {from_addr}
@@ -93,55 +98,57 @@ Date: {date}
{body}
"""
# Create metadata
metadata = {
'file_path': filepath,
'subject': subject,
'from': from_addr,
'to': to_addr,
'date': date,
'filename': filename
"file_path": filepath,
"subject": subject,
"from": from_addr,
"to": to_addr,
"date": date,
"filename": filename,
}
doc = Document(text=doc_content, metadata=metadata)
docs.append(doc)
count += 1
except Exception as e:
print(f"Error parsing email from {filepath}: {e}")
continue
except Exception as e:
print(f"Error reading file {filepath}: {e}")
continue
print(f"Loaded {len(docs)} email documents")
return docs
def main():
# Use the current directory where the sample.emlx file is located
current_dir = os.path.dirname(os.path.abspath(__file__))
print("Testing EmlxReader with sample .emlx file...")
print(f"Scanning directory: {current_dir}")
# Use the custom EmlxReader
documents = EmlxReader().load_data(current_dir, max_count=1000)
if not documents:
print("No documents loaded. Make sure sample.emlx exists in the examples directory.")
return
print(f"\nSuccessfully loaded {len(documents)} document(s)")
# Initialize index with documents
index = VectorStoreIndex.from_documents(documents)
query_engine = index.as_query_engine()
print("\nTesting query: 'Hows Berkeley Graduate Student Instructor'")
res = query_engine.query("Hows Berkeley Graduate Student Instructor")
print(f"Response: {res}")
if __name__ == "__main__":
main()
main()

View File

@@ -2,20 +2,20 @@
import argparse
import time
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
from torch import nn
from transformers import AutoModel, BitsAndBytesConfig
from tqdm import tqdm
from contextlib import contextmanager
from transformers import AutoModel, BitsAndBytesConfig
@dataclass
class BenchmarkConfig:
model_path: str
batch_sizes: List[int]
batch_sizes: list[int]
seq_length: int
num_runs: int
use_fp16: bool = True
@@ -28,47 +28,44 @@ class BenchmarkConfig:
class GraphContainer:
"""Container for managing graphs for different batch sizes (CUDA graphs on NVIDIA, regular on others)."""
def __init__(self, model: nn.Module, seq_length: int):
self.model = model
self.seq_length = seq_length
self.graphs: Dict[int, 'GraphWrapper'] = {}
def get_or_create(self, batch_size: int) -> 'GraphWrapper':
self.graphs: dict[int, GraphWrapper] = {}
def get_or_create(self, batch_size: int) -> "GraphWrapper":
if batch_size not in self.graphs:
self.graphs[batch_size] = GraphWrapper(
self.model, batch_size, self.seq_length
)
self.graphs[batch_size] = GraphWrapper(self.model, batch_size, self.seq_length)
return self.graphs[batch_size]
class GraphWrapper:
"""Wrapper for graph capture and replay (CUDA graphs on NVIDIA, regular on others)."""
def __init__(self, model: nn.Module, batch_size: int, seq_length: int):
self.model = model
self.device = self._get_device()
self.static_input = self._create_random_batch(batch_size, seq_length)
self.static_attention_mask = torch.ones_like(self.static_input)
# Warm up
self._warmup()
# Only use CUDA graphs on NVIDIA GPUs
if torch.cuda.is_available() and hasattr(torch.cuda, 'CUDAGraph'):
if torch.cuda.is_available() and hasattr(torch.cuda, "CUDAGraph"):
# Capture graph
self.graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.graph):
self.static_output = self.model(
input_ids=self.static_input,
attention_mask=self.static_attention_mask
input_ids=self.static_input, attention_mask=self.static_attention_mask
)
self.use_cuda_graph = True
else:
# For MPS or CPU, just store the model
self.use_cuda_graph = False
self.static_output = None
def _get_device(self) -> str:
if torch.cuda.is_available():
return "cuda"
@@ -76,22 +73,17 @@ class GraphWrapper:
return "mps"
else:
return "cpu"
def _create_random_batch(self, batch_size: int, seq_length: int) -> torch.Tensor:
return torch.randint(
0, 1000, (batch_size, seq_length),
device=self.device,
dtype=torch.long
0, 1000, (batch_size, seq_length), device=self.device, dtype=torch.long
)
def _warmup(self, num_warmup: int = 3):
with torch.no_grad():
for _ in range(num_warmup):
self.model(
input_ids=self.static_input,
attention_mask=self.static_attention_mask
)
self.model(input_ids=self.static_input, attention_mask=self.static_attention_mask)
def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
if self.use_cuda_graph:
self.static_input.copy_(input_ids)
@@ -105,14 +97,14 @@ class GraphWrapper:
class ModelOptimizer:
"""Applies various optimizations to the model."""
@staticmethod
def optimize(model: nn.Module, config: BenchmarkConfig) -> nn.Module:
print("\nApplying model optimizations:")
if model is None:
raise ValueError("Cannot optimize None model")
# Move to GPU
if torch.cuda.is_available():
model = model.cuda()
@@ -124,53 +116,59 @@ class ModelOptimizer:
model = model.cpu()
device = "cpu"
print(f"- Model moved to {device}")
# FP16
if config.use_fp16 and not config.use_int4:
model = model.half()
# use torch compile
model = torch.compile(model)
print("- Using FP16 precision")
# Check if using SDPA (only on CUDA)
if torch.cuda.is_available() and torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6:
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
if (
torch.cuda.is_available()
and torch.version.cuda
and float(torch.version.cuda[:3]) >= 11.6
):
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
else:
print("- PyTorch SDPA not available")
# Flash Attention (only on CUDA)
if config.use_flash_attention and torch.cuda.is_available():
try:
from flash_attn.flash_attention import FlashAttention
from flash_attn.flash_attention import FlashAttention # noqa: F401
print("- Flash Attention 2 available")
if hasattr(model.config, "attention_mode"):
model.config.attention_mode = "flash_attention_2"
print(" - Enabled Flash Attention 2 mode")
except ImportError:
print("- Flash Attention not available")
# Memory efficient attention (only on CUDA)
if torch.cuda.is_available():
try:
from xformers.ops import memory_efficient_attention
if hasattr(model, 'enable_xformers_memory_efficient_attention'):
from xformers.ops import memory_efficient_attention # noqa: F401
if hasattr(model, "enable_xformers_memory_efficient_attention"):
model.enable_xformers_memory_efficient_attention()
print("- Enabled xformers memory efficient attention")
else:
print("- Model doesn't support xformers")
except (ImportError, AttributeError):
print("- Xformers not available")
model.eval()
print("- Model set to eval mode")
return model
class Timer:
"""Handles accurate GPU timing using GPU events or CPU timing."""
def __init__(self):
if torch.cuda.is_available():
self.start_event = torch.cuda.Event(enable_timing=True)
@@ -182,7 +180,7 @@ class Timer:
else:
# CPU timing
self.use_gpu_timing = False
@contextmanager
def timing(self):
if self.use_gpu_timing:
@@ -195,7 +193,7 @@ class Timer:
start_time = time.time()
yield
self.cpu_elapsed = time.time() - start_time
def elapsed_time(self) -> float:
if self.use_gpu_timing:
return self.start_event.elapsed_time(self.end_event) / 1000 # ms to seconds
@@ -205,14 +203,14 @@ class Timer:
class Benchmark:
"""Main benchmark runner."""
def __init__(self, config: BenchmarkConfig):
self.config = config
try:
self.model = self._load_model()
if self.model is None:
raise ValueError("Model initialization failed - model is None")
# Only use CUDA graphs on NVIDIA GPUs
if config.use_cuda_graphs and torch.cuda.is_available():
self.graphs = GraphContainer(self.model, config.seq_length)
@@ -220,25 +218,27 @@ class Benchmark:
self.graphs = None
self.timer = Timer()
except Exception as e:
print(f"ERROR in benchmark initialization: {str(e)}")
print(f"ERROR in benchmark initialization: {e!s}")
raise
def _load_model(self) -> nn.Module:
print(f"Loading model from {self.config.model_path}...")
try:
# Int4 quantization using HuggingFace integration
if self.config.use_int4:
import bitsandbytes as bnb
print(f"- bitsandbytes version: {bnb.__version__}")
# 检查是否使用自定义的8bit量化
if hasattr(self.config, 'use_linear8bitlt') and self.config.use_linear8bitlt:
# Check if using custom 8bit quantization
if hasattr(self.config, "use_linear8bitlt") and self.config.use_linear8bitlt:
print("- Using custom Linear8bitLt replacement for all linear layers")
# 加载原始模型(不使用量化配置)
# Load original model (without quantization config)
import bitsandbytes as bnb
import torch
# set default to half
torch.set_default_dtype(torch.float16)
compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
@@ -246,112 +246,120 @@ class Benchmark:
self.config.model_path,
torch_dtype=compute_dtype,
)
# 定义替换函数
# Define replacement function
def replace_linear_with_linear8bitlt(model):
"""递归地将模型中的所有nn.Linear层替换为Linear8bitLt"""
"""Recursively replace all nn.Linear layers with Linear8bitLt"""
for name, module in list(model.named_children()):
if isinstance(module, nn.Linear):
# 获取原始线性层的参数
# Get original linear layer parameters
in_features = module.in_features
out_features = module.out_features
bias = module.bias is not None
# 创建8bit线性层
# Create 8bit linear layer
# print size
print(f"in_features: {in_features}, out_features: {out_features}")
new_module = bnb.nn.Linear8bitLt(
in_features,
out_features,
bias=bias,
has_fp16_weights=False
in_features, out_features, bias=bias, has_fp16_weights=False
)
# 复制权重和偏置
# Copy weights and bias
new_module.weight.data = module.weight.data
if bias:
new_module.bias.data = module.bias.data
# 替换模块
# Replace module
setattr(model, name, new_module)
else:
# 递归处理子模块
# Process child modules recursively
replace_linear_with_linear8bitlt(module)
return model
# 替换所有线性层
# Replace all linear layers
model = replace_linear_with_linear8bitlt(model)
# add torch compile
model = torch.compile(model)
# 将模型移到GPU量化发生在这里
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
# Move model to GPU (quantization happens here)
device = (
"cuda"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)
model = model.to(device)
print("- All linear layers replaced with Linear8bitLt")
else:
# 使用原来的Int4量化方法
# Use original Int4 quantization method
print("- Using bitsandbytes for Int4 quantization")
# Create quantization config
compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=compute_dtype,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4"
bnb_4bit_quant_type="nf4",
)
print("- Quantization config:", quantization_config)
# Load model directly with quantization config
model = AutoModel.from_pretrained(
self.config.model_path,
quantization_config=quantization_config,
torch_dtype=compute_dtype,
device_map="auto" # Let HF decide on device mapping
device_map="auto", # Let HF decide on device mapping
)
# Check if model loaded successfully
if model is None:
raise ValueError("Model loading returned None")
print(f"- Model type: {type(model)}")
# Apply optimizations directly here
print("\nApplying model optimizations:")
if hasattr(self.config, 'use_linear8bitlt') and self.config.use_linear8bitlt:
if hasattr(self.config, "use_linear8bitlt") and self.config.use_linear8bitlt:
print("- Model moved to GPU with Linear8bitLt quantization")
else:
# Skip moving to GPU since device_map="auto" already did that
print("- Model already on GPU due to device_map='auto'")
# Skip FP16 conversion since we specified compute_dtype
print(f"- Using {compute_dtype} for compute dtype")
# Check CUDA and SDPA
if torch.cuda.is_available() and torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6:
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
if (
torch.cuda.is_available()
and torch.version.cuda
and float(torch.version.cuda[:3]) >= 11.6
):
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
else:
print("- PyTorch SDPA not available")
# Try xformers if available (only on CUDA)
if torch.cuda.is_available():
try:
from xformers.ops import memory_efficient_attention
if hasattr(model, 'enable_xformers_memory_efficient_attention'):
from xformers.ops import memory_efficient_attention # noqa: F401
if hasattr(model, "enable_xformers_memory_efficient_attention"):
model.enable_xformers_memory_efficient_attention()
print("- Enabled xformers memory efficient attention")
else:
print("- Model doesn't support xformers")
except (ImportError, AttributeError):
print("- Xformers not available")
# Set to eval mode
model.eval()
print("- Model set to eval mode")
@@ -365,76 +373,79 @@ class Benchmark:
llm_int8_threshold=6.0,
llm_int8_has_fp16_weight=False,
)
model = AutoModel.from_pretrained(
self.config.model_path,
quantization_config=quantization_config,
torch_dtype=compute_dtype,
device_map="auto"
device_map="auto",
)
if model is None:
raise ValueError("Model loading returned None")
print(f"- Model type: {type(model)}")
model.eval()
print("- Model set to eval mode")
else:
# Standard loading for FP16/FP32
model = AutoModel.from_pretrained(self.config.model_path)
print("- Model loaded in standard precision")
print(f"- Model type: {type(model)}")
# Apply standard optimizations
# set default to half
import torch
torch.set_default_dtype(torch.bfloat16)
model = ModelOptimizer.optimize(model, self.config)
model = model.half()
# add torch compile
model = torch.compile(model)
# Final check to ensure model is not None
if model is None:
raise ValueError("Model is None after optimization")
print(f"- Final model type: {type(model)}")
return model
except Exception as e:
print(f"ERROR loading model: {str(e)}")
print(f"ERROR loading model: {e!s}")
import traceback
traceback.print_exc()
raise
def _create_random_batch(self, batch_size: int) -> torch.Tensor:
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
return torch.randint(
0, 1000,
(batch_size, self.config.seq_length),
device=device,
dtype=torch.long
device = (
"cuda"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)
return torch.randint(
0, 1000, (batch_size, self.config.seq_length), device=device, dtype=torch.long
)
def _run_inference(
self,
input_ids: torch.Tensor,
graph_wrapper: Optional[GraphWrapper] = None
) -> Tuple[float, torch.Tensor]:
self, input_ids: torch.Tensor, graph_wrapper: GraphWrapper | None = None
) -> tuple[float, torch.Tensor]:
attention_mask = torch.ones_like(input_ids)
with torch.no_grad(), self.timer.timing():
if graph_wrapper is not None:
output = graph_wrapper(input_ids, attention_mask)
else:
output = self.model(input_ids=input_ids, attention_mask=attention_mask)
return self.timer.elapsed_time(), output
def run(self) -> Dict[int, Dict[str, float]]:
def run(self) -> dict[int, dict[str, float]]:
results = {}
# Reset peak memory stats
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()
@@ -443,22 +454,20 @@ class Benchmark:
pass
else:
print("- No GPU memory stats available")
for batch_size in self.config.batch_sizes:
print(f"\nTesting batch size: {batch_size}")
times = []
# Get or create graph for this batch size
graph_wrapper = (
self.graphs.get_or_create(batch_size)
if self.graphs is not None
else None
self.graphs.get_or_create(batch_size) if self.graphs is not None else None
)
# Pre-allocate input tensor
input_ids = self._create_random_batch(batch_size)
print(f"Input shape: {input_ids.shape}")
# Run benchmark
for i in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
try:
@@ -469,44 +478,44 @@ class Benchmark:
except Exception as e:
print(f"Error during inference: {e}")
break
if not times:
print(f"No successful runs for batch size {batch_size}, skipping")
continue
# Calculate statistics
avg_time = np.mean(times)
std_time = np.std(times)
throughput = batch_size / avg_time
results[batch_size] = {
"avg_time": avg_time,
"std_time": std_time,
"throughput": throughput,
}
print(f"Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
print(f"Throughput: {throughput:.2f} sequences/second")
# Log memory usage
if torch.cuda.is_available():
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024 ** 3)
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024**3)
elif torch.backends.mps.is_available():
# MPS doesn't have max_memory_allocated, use 0
peak_memory_gb = 0.0
else:
peak_memory_gb = 0.0
print("- No GPU memory usage available")
if peak_memory_gb > 0:
print(f"\nPeak GPU memory usage: {peak_memory_gb:.2f} GB")
else:
print("\n- GPU memory usage not available")
# Add memory info to results
for batch_size in results:
results[batch_size]["peak_memory_gb"] = peak_memory_gb
return results
@@ -566,14 +575,14 @@ def main():
action="store_true",
help="Enable Linear8bitLt quantization for all linear layers",
)
args = parser.parse_args()
# Print arguments for debugging
print("\nCommand line arguments:")
for arg, value in vars(args).items():
print(f"- {arg}: {value}")
config = BenchmarkConfig(
model_path=args.model_path,
batch_sizes=[int(bs) for bs in args.batch_sizes.split(",")],
@@ -586,45 +595,56 @@ def main():
use_flash_attention=args.use_flash_attention,
use_linear8bitlt=args.use_linear8bitlt,
)
# Print configuration for debugging
print("\nBenchmark configuration:")
for field, value in vars(config).items():
print(f"- {field}: {value}")
try:
benchmark = Benchmark(config)
results = benchmark.run()
# Save results to file
import json
import os
# Create results directory if it doesn't exist
os.makedirs("results", exist_ok=True)
# Generate filename based on configuration
precision_type = "int4" if config.use_int4 else "int8" if config.use_int8 else "fp16" if config.use_fp16 else "fp32"
precision_type = (
"int4"
if config.use_int4
else "int8"
if config.use_int8
else "fp16"
if config.use_fp16
else "fp32"
)
model_name = os.path.basename(config.model_path)
output_file = f"results/benchmark_{model_name}_{precision_type}.json"
# Save results
with open(output_file, "w") as f:
json.dump(
{
"config": {k: str(v) if isinstance(v, list) else v for k, v in vars(config).items()},
"results": {str(k): v for k, v in results.items()}
},
f,
indent=2
"config": {
k: str(v) if isinstance(v, list) else v for k, v in vars(config).items()
},
"results": {str(k): v for k, v in results.items()},
},
f,
indent=2,
)
print(f"Results saved to {output_file}")
except Exception as e:
print(f"Benchmark failed: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main()
main()

View File

@@ -1,37 +1,39 @@
import os
from llama_index.core import VectorStoreIndex, StorageContext
from llama_index.core import StorageContext, VectorStoreIndex
def load_index(save_dir: str = "mail_index"):
"""
Load the saved index from disk.
Args:
save_dir: Directory where the index is saved
Returns:
Loaded index or None if loading fails
"""
try:
# Load storage context
storage_context = StorageContext.from_defaults(persist_dir=save_dir)
# Load index
index = VectorStoreIndex.from_vector_store(
storage_context.vector_store,
storage_context=storage_context
storage_context.vector_store, storage_context=storage_context
)
print(f"Index loaded from {save_dir}")
return index
except Exception as e:
print(f"Error loading index: {e}")
return None
def query_index(index, query: str):
"""
Query the loaded index.
Args:
index: The loaded index
query: The query string
@@ -39,44 +41,47 @@ def query_index(index, query: str):
if index is None:
print("No index available for querying.")
return
query_engine = index.as_query_engine()
response = query_engine.query(query)
print(f"\nQuery: {query}")
print(f"Response: {response}")
def main():
save_dir = "mail_index"
# Check if index exists
if not os.path.exists(save_dir) or not os.path.exists(os.path.join(save_dir, "vector_store.json")):
if not os.path.exists(save_dir) or not os.path.exists(
os.path.join(save_dir, "vector_store.json")
):
print(f"Index not found in {save_dir}")
print("Please run mail_reader_save_load.py first to create the index.")
return
# Load the index
index = load_index(save_dir)
if not index:
print("Failed to load index.")
return
print("\n" + "="*60)
print("\n" + "=" * 60)
print("Email Query Interface")
print("="*60)
print("=" * 60)
print("Type 'quit' to exit")
print("Type 'help' for example queries")
print("="*60)
print("=" * 60)
# Interactive query loop
while True:
try:
query = input("\nEnter your query: ").strip()
if query.lower() == 'quit':
if query.lower() == "quit":
print("Goodbye!")
break
elif query.lower() == 'help':
elif query.lower() == "help":
print("\nExample queries:")
print("- Hows Berkeley Graduate Student Instructor")
print("- What emails mention GSR appointments?")
@@ -86,14 +91,15 @@ def main():
continue
elif not query:
continue
query_index(index, query)
except KeyboardInterrupt:
print("\nGoodbye!")
break
except Exception as e:
print(f"Error processing query: {e}")
if __name__ == "__main__":
main()
main()

View File

@@ -1,43 +1,46 @@
import time
import numpy as np
import matplotlib.pyplot as plt
import torch
from sentence_transformers import SentenceTransformer
import mlx.core as mx
import numpy as np
import torch
from mlx_lm import load
from sentence_transformers import SentenceTransformer
# --- Configuration ---
MODEL_NAME_TORCH = "Qwen/Qwen3-Embedding-0.6B"
MODEL_NAME_MLX = "mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ"
BATCH_SIZES = [1, 8, 16, 32, 64, 128]
NUM_RUNS = 10 # Number of runs to average for each batch size
WARMUP_RUNS = 2 # Number of warm-up runs
WARMUP_RUNS = 2 # Number of warm-up runs
# --- Generate Dummy Data ---
DUMMY_SENTENCES = ["This is a test sentence for benchmarking." * 5] * max(BATCH_SIZES)
# --- Benchmark Functions ---b
def benchmark_torch(model, sentences):
start_time = time.time()
model.encode(sentences, convert_to_numpy=True)
end_time = time.time()
return (end_time - start_time) * 1000 # Return time in ms
def benchmark_mlx(model, tokenizer, sentences):
start_time = time.time()
# Tokenize sentences using MLX tokenizer
tokens = []
for sentence in sentences:
token_ids = tokenizer.encode(sentence)
tokens.append(token_ids)
# Pad sequences to the same length
max_len = max(len(t) for t in tokens)
input_ids = []
attention_mask = []
for token_seq in tokens:
# Pad sequence
padded = token_seq + [tokenizer.eos_token_id] * (max_len - len(token_seq))
@@ -45,24 +48,25 @@ def benchmark_mlx(model, tokenizer, sentences):
# Create attention mask (1 for real tokens, 0 for padding)
mask = [1] * len(token_seq) + [0] * (max_len - len(token_seq))
attention_mask.append(mask)
# Convert to MLX arrays
input_ids = mx.array(input_ids)
attention_mask = mx.array(attention_mask)
# Get embeddings
embeddings = model(input_ids)
# Mean pooling
mask = mx.expand_dims(attention_mask, -1)
sum_embeddings = (embeddings * mask).sum(axis=1)
sum_mask = mask.sum(axis=1)
_ = sum_embeddings / sum_mask
mx.eval() # Ensure computation is finished
end_time = time.time()
return (end_time - start_time) * 1000 # Return time in ms
# --- Main Execution ---
def main():
print("--- Initializing Models ---")
@@ -92,13 +96,15 @@ def main():
for batch_size in BATCH_SIZES:
print(f"Benchmarking batch size: {batch_size}")
sentences_batch = DUMMY_SENTENCES[:batch_size]
# Benchmark PyTorch
torch_times = [benchmark_torch(model_torch, sentences_batch) for _ in range(NUM_RUNS)]
results_torch.append(np.mean(torch_times))
# Benchmark MLX
mlx_times = [benchmark_mlx(model_mlx, tokenizer_mlx, sentences_batch) for _ in range(NUM_RUNS)]
mlx_times = [
benchmark_mlx(model_mlx, tokenizer_mlx, sentences_batch) for _ in range(NUM_RUNS)
]
results_mlx.append(np.mean(mlx_times))
print("\n--- Benchmark Results (Average time per batch in ms) ---")
@@ -109,20 +115,21 @@ def main():
# --- Plotting ---
print("\n--- Generating Plot ---")
plt.figure(figsize=(10, 6))
plt.plot(BATCH_SIZES, results_torch, marker='o', linestyle='-', label=f'PyTorch ({device})')
plt.plot(BATCH_SIZES, results_mlx, marker='s', linestyle='-', label='MLX')
plt.plot(BATCH_SIZES, results_torch, marker="o", linestyle="-", label=f"PyTorch ({device})")
plt.plot(BATCH_SIZES, results_mlx, marker="s", linestyle="-", label="MLX")
plt.title(f'Embedding Performance: MLX vs PyTorch\nModel: {MODEL_NAME_TORCH}')
plt.title(f"Embedding Performance: MLX vs PyTorch\nModel: {MODEL_NAME_TORCH}")
plt.xlabel("Batch Size")
plt.ylabel("Average Time per Batch (ms)")
plt.xticks(BATCH_SIZES)
plt.grid(True)
plt.legend()
# Save the plot
output_filename = "embedding_benchmark.png"
plt.savefig(output_filename)
print(f"Plot saved to {output_filename}")
if __name__ == "__main__":
main()

View File

@@ -3,49 +3,52 @@
Debug script to test ZMQ communication with the exact same setup as main_cli_example.py
"""
import zmq
import time
import threading
import sys
sys.path.append('packages/leann-backend-diskann')
import time
import zmq
sys.path.append("packages/leann-backend-diskann")
from leann_backend_diskann import embedding_pb2
def test_zmq_with_same_model():
print("=== Testing ZMQ with same model as main_cli_example.py ===")
# Test the exact same model that main_cli_example.py uses
model_name = "sentence-transformers/all-mpnet-base-v2"
# Start server with the same model
import subprocess
server_cmd = [
sys.executable, "-m",
sys.executable,
"-m",
"packages.leann-backend-diskann.leann_backend_diskann.embedding_server",
"--zmq-port", "5556", # Use different port to avoid conflicts
"--model-name", model_name
"--zmq-port",
"5556", # Use different port to avoid conflicts
"--model-name",
model_name,
]
print(f"Starting server with command: {' '.join(server_cmd)}")
server_process = subprocess.Popen(
server_cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True
server_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
)
# Wait for server to start
print("Waiting for server to start...")
time.sleep(10)
# Check if server is running
if server_process.poll() is not None:
stdout, stderr = server_process.communicate()
print(f"Server failed to start. stdout: {stdout}")
print(f"Server failed to start. stderr: {stderr}")
return False
print(f"Server started with PID: {server_process.pid}")
try:
# Test client
context = zmq.Context()
@@ -53,39 +56,39 @@ def test_zmq_with_same_model():
socket.connect("tcp://127.0.0.1:5556")
socket.setsockopt(zmq.RCVTIMEO, 30000) # 30 second timeout like C++
socket.setsockopt(zmq.SNDTIMEO, 30000)
# Create request with same format as C++
request = embedding_pb2.NodeEmbeddingRequest()
request.node_ids.extend([0, 1, 2, 3, 4]) # Test with some node IDs
print(f"Sending request with {len(request.node_ids)} node IDs...")
start_time = time.time()
# Send request
socket.send(request.SerializeToString())
# Receive response
response_data = socket.recv()
end_time = time.time()
print(f"Received response in {end_time - start_time:.3f} seconds")
print(f"Response size: {len(response_data)} bytes")
# Parse response
response = embedding_pb2.NodeEmbeddingResponse()
response.ParseFromString(response_data)
print(f"Response dimensions: {list(response.dimensions)}")
print(f"Embeddings data size: {len(response.embeddings_data)} bytes")
print(f"Missing IDs: {list(response.missing_ids)}")
# Calculate expected size
if len(response.dimensions) == 2:
batch_size = response.dimensions[0]
embedding_dim = response.dimensions[1]
expected_bytes = batch_size * embedding_dim * 4 # 4 bytes per float
print(f"Expected bytes: {expected_bytes}, Actual: {len(response.embeddings_data)}")
if len(response.embeddings_data) == expected_bytes:
print("✅ Response format is correct!")
return True
@@ -95,7 +98,7 @@ def test_zmq_with_same_model():
else:
print("❌ Invalid response dimensions!")
return False
except Exception as e:
print(f"❌ Error during ZMQ test: {e}")
return False
@@ -105,9 +108,10 @@ def test_zmq_with_same_model():
server_process.wait()
print("Server terminated")
if __name__ == "__main__":
success = test_zmq_with_same_model()
if success:
print("\n✅ ZMQ communication test passed!")
else:
print("\n❌ ZMQ communication test failed!")
print("\n❌ ZMQ communication test failed!")

View File

@@ -1,26 +1,27 @@
import time
from dataclasses import dataclass
from typing import Dict, List
import numpy as np
import torch
from torch import nn
from transformers import AutoModel, BitsAndBytesConfig
from tqdm import tqdm
from transformers import AutoModel
# Add MLX imports
try:
import mlx.core as mx
from mlx_lm.utils import load
MLX_AVAILABLE = True
except ImportError as e:
except ImportError:
print("MLX not available. Install with: uv pip install mlx mlx-lm")
MLX_AVAILABLE = False
@dataclass
class BenchmarkConfig:
model_path: str = "facebook/contriever"
batch_sizes: List[int] = None
batch_sizes: list[int] = None
seq_length: int = 256
num_runs: int = 5
use_fp16: bool = True
@@ -30,18 +31,19 @@ class BenchmarkConfig:
use_flash_attention: bool = False
use_linear8bitlt: bool = False
use_mlx: bool = False # New flag for MLX testing
def __post_init__(self):
if self.batch_sizes is None:
self.batch_sizes = [1, 2, 4, 8, 16, 32, 64]
class MLXBenchmark:
"""MLX-specific benchmark for embedding models"""
def __init__(self, config: BenchmarkConfig):
self.config = config
self.model, self.tokenizer = self._load_model()
def _load_model(self):
"""Load MLX model and tokenizer following the API pattern"""
print(f"Loading MLX model from {self.config.model_path}...")
@@ -52,55 +54,51 @@ class MLXBenchmark:
except Exception as e:
print(f"Error loading MLX model: {e}")
raise
def _create_random_batch(self, batch_size: int):
"""Create random input batches for MLX testing - same as PyTorch"""
return torch.randint(
0, 1000,
(batch_size, self.config.seq_length),
dtype=torch.long
)
return torch.randint(0, 1000, (batch_size, self.config.seq_length), dtype=torch.long)
def _run_inference(self, input_ids: torch.Tensor) -> float:
"""Run MLX inference with same input as PyTorch"""
start_time = time.time()
try:
# Convert PyTorch tensor to MLX array
input_ids_mlx = mx.array(input_ids.numpy())
# Get embeddings
embeddings = self.model(input_ids_mlx)
# Mean pooling (following the API pattern)
pooled = embeddings.mean(axis=1)
# Convert to numpy (following the API pattern)
pooled_numpy = np.array(pooled.tolist(), dtype=np.float32)
# Force computation
_ = pooled_numpy.shape
except Exception as e:
print(f"MLX inference error: {e}")
return float('inf')
return float("inf")
end_time = time.time()
return end_time - start_time
def run(self) -> Dict[int, Dict[str, float]]:
def run(self) -> dict[int, dict[str, float]]:
"""Run the MLX benchmark across all batch sizes"""
results = {}
print(f"Starting MLX benchmark with model: {self.config.model_path}")
print(f"Testing batch sizes: {self.config.batch_sizes}")
for batch_size in self.config.batch_sizes:
print(f"\n=== Testing MLX batch size: {batch_size} ===")
times = []
# Create input batch (same as PyTorch)
input_ids = self._create_random_batch(batch_size)
# Warm up
print("Warming up...")
for _ in range(3):
@@ -109,26 +107,26 @@ class MLXBenchmark:
except Exception as e:
print(f"Warmup error: {e}")
break
# Run benchmark
for i in tqdm(range(self.config.num_runs), desc=f"MLX Batch size {batch_size}"):
for _i in tqdm(range(self.config.num_runs), desc=f"MLX Batch size {batch_size}"):
try:
elapsed_time = self._run_inference(input_ids)
if elapsed_time != float('inf'):
if elapsed_time != float("inf"):
times.append(elapsed_time)
except Exception as e:
print(f"Error during MLX inference: {e}")
break
if not times:
print(f"Skipping batch size {batch_size} due to errors")
continue
# Calculate statistics
avg_time = np.mean(times)
std_time = np.std(times)
throughput = batch_size / avg_time
results[batch_size] = {
"avg_time": avg_time,
"std_time": std_time,
@@ -136,179 +134,166 @@ class MLXBenchmark:
"min_time": np.min(times),
"max_time": np.max(times),
}
print(f"MLX Results for batch size {batch_size}:")
print(f" Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
print(f" Min Time: {np.min(times):.4f}s")
print(f" Max Time: {np.max(times):.4f}s")
print(f" Throughput: {throughput:.2f} sequences/second")
return results
class Benchmark:
def __init__(self, config: BenchmarkConfig):
self.config = config
self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
self.device = (
"cuda"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)
self.model = self._load_model()
def _load_model(self) -> nn.Module:
print(f"Loading model from {self.config.model_path}...")
model = AutoModel.from_pretrained(self.config.model_path)
if self.config.use_fp16:
model = model.half()
model = torch.compile(model)
model = model.to(self.device)
model.eval()
return model
def _create_random_batch(self, batch_size: int) -> torch.Tensor:
return torch.randint(
0, 1000,
(batch_size, self.config.seq_length),
device=self.device,
dtype=torch.long
0, 1000, (batch_size, self.config.seq_length), device=self.device, dtype=torch.long
)
def _run_inference(self, input_ids: torch.Tensor) -> float:
attention_mask = torch.ones_like(input_ids)
start_time = time.time()
with torch.no_grad():
output = self.model(input_ids=input_ids, attention_mask=attention_mask)
self.model(input_ids=input_ids, attention_mask=attention_mask)
end_time = time.time()
return end_time - start_time
def run(self) -> Dict[int, Dict[str, float]]:
def run(self) -> dict[int, dict[str, float]]:
results = {}
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()
for batch_size in self.config.batch_sizes:
print(f"\nTesting batch size: {batch_size}")
times = []
input_ids = self._create_random_batch(batch_size)
for i in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
for _i in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
try:
elapsed_time = self._run_inference(input_ids)
times.append(elapsed_time)
except Exception as e:
print(f"Error during inference: {e}")
break
if not times:
continue
avg_time = np.mean(times)
std_time = np.std(times)
throughput = batch_size / avg_time
results[batch_size] = {
"avg_time": avg_time,
"std_time": std_time,
"throughput": throughput,
}
print(f"Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
print(f"Throughput: {throughput:.2f} sequences/second")
if torch.cuda.is_available():
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024 ** 3)
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024**3)
else:
peak_memory_gb = 0.0
for batch_size in results:
results[batch_size]["peak_memory_gb"] = peak_memory_gb
return results
def run_benchmark():
"""Main function to run the benchmark with optimized parameters."""
config = BenchmarkConfig()
try:
benchmark = Benchmark(config)
results = benchmark.run()
max_throughput = max(results[batch_size]["throughput"] for batch_size in results)
avg_throughput = np.mean([results[batch_size]["throughput"] for batch_size in results])
return {
"max_throughput": max_throughput,
"avg_throughput": avg_throughput,
"results": results
"results": results,
}
except Exception as e:
print(f"Benchmark failed: {e}")
return {
"max_throughput": 0.0,
"avg_throughput": 0.0,
"error": str(e)
}
return {"max_throughput": 0.0, "avg_throughput": 0.0, "error": str(e)}
def run_mlx_benchmark():
"""Run MLX-specific benchmark"""
if not MLX_AVAILABLE:
print("MLX not available, skipping MLX benchmark")
return {
"max_throughput": 0.0,
"avg_throughput": 0.0,
"error": "MLX not available"
}
config = BenchmarkConfig(
model_path="mlx-community/all-MiniLM-L6-v2-4bit",
use_mlx=True
)
return {"max_throughput": 0.0, "avg_throughput": 0.0, "error": "MLX not available"}
config = BenchmarkConfig(model_path="mlx-community/all-MiniLM-L6-v2-4bit", use_mlx=True)
try:
benchmark = MLXBenchmark(config)
results = benchmark.run()
if not results:
return {
"max_throughput": 0.0,
"avg_throughput": 0.0,
"error": "No valid results"
}
return {"max_throughput": 0.0, "avg_throughput": 0.0, "error": "No valid results"}
max_throughput = max(results[batch_size]["throughput"] for batch_size in results)
avg_throughput = np.mean([results[batch_size]["throughput"] for batch_size in results])
return {
"max_throughput": max_throughput,
"avg_throughput": avg_throughput,
"results": results
"results": results,
}
except Exception as e:
print(f"MLX benchmark failed: {e}")
return {
"max_throughput": 0.0,
"avg_throughput": 0.0,
"error": str(e)
}
return {"max_throughput": 0.0, "avg_throughput": 0.0, "error": str(e)}
if __name__ == "__main__":
print("=== PyTorch Benchmark ===")
pytorch_result = run_benchmark()
print(f"PyTorch Max throughput: {pytorch_result['max_throughput']:.2f} sequences/second")
print(f"PyTorch Average throughput: {pytorch_result['avg_throughput']:.2f} sequences/second")
print("\n=== MLX Benchmark ===")
mlx_result = run_mlx_benchmark()
print(f"MLX Max throughput: {mlx_result['max_throughput']:.2f} sequences/second")
print(f"MLX Average throughput: {mlx_result['avg_throughput']:.2f} sequences/second")
# Compare results
if pytorch_result['max_throughput'] > 0 and mlx_result['max_throughput'] > 0:
speedup = mlx_result['max_throughput'] / pytorch_result['max_throughput']
print(f"\n=== Comparison ===")
print(f"MLX is {speedup:.2f}x {'faster' if speedup > 1 else 'slower'} than PyTorch")
if pytorch_result["max_throughput"] > 0 and mlx_result["max_throughput"] > 0:
speedup = mlx_result["max_throughput"] / pytorch_result["max_throughput"]
print("\n=== Comparison ===")
print(f"MLX is {speedup:.2f}x {'faster' if speedup > 1 else 'slower'} than PyTorch")