diff --git a/.gitignore b/.gitignore
index 71b34cb..ea88898 100755
--- a/.gitignore
+++ b/.gitignore
@@ -35,7 +35,7 @@ build/
 nprobe_logs/
 micro/results
 micro/contriever-INT8
-examples/data/
+examples/data/*
 !examples/data/2501.14312v1 (1).pdf
 !examples/data/2506.08276v1.pdf
 !examples/data/PrideandPrejudice.txt
diff --git a/examples/mail_reader_leann.py b/examples/mail_reader_leann.py
index 4c1a990..fbdde1b 100644
--- a/examples/mail_reader_leann.py
+++ b/examples/mail_reader_leann.py
@@ -24,7 +24,7 @@ def get_mail_path():
 # Default mail path for macOS
 # DEFAULT_MAIL_PATH = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data"
 
-def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_path: str = "mail_index.leann", max_count: int = -1, include_html: bool = False):
+def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_path: str = "mail_index.leann", max_count: int = -1, include_html: bool = False, embedding_model: str = "facebook/contriever"):
     """
     Create LEANN index from multiple mail data sources.
 
@@ -101,7 +101,7 @@ def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_pa
     # Use HNSW backend for better macOS compatibility
     builder = LeannBuilder(
         backend_name="hnsw",
-        embedding_model="facebook/contriever",
+        embedding_model=embedding_model,
         graph_degree=32,
         complexity=64,
         is_compact=True,
@@ -120,7 +120,7 @@ def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_pa
     return index_path
 
-def create_leann_index(mail_path: str, index_path: str = "mail_index.leann", max_count: int = 1000, include_html: bool = False):
+def create_leann_index(mail_path: str, index_path: str = "mail_index.leann", max_count: int = 1000, include_html: bool = False, embedding_model: str = "facebook/contriever"):
     """
     Create LEANN index from mail data.
 
@@ -180,7 +180,7 @@ def create_leann_index(mail_path: str, index_path: str = "mail_index.leann", max
     # Use HNSW backend for better macOS compatibility
     builder = LeannBuilder(
         backend_name="hnsw",
-        embedding_model="facebook/contriever",
+        embedding_model=embedding_model,
         graph_degree=32,
         complexity=64,
         is_compact=True,
@@ -239,6 +239,8 @@ async def main():
                         help='Single query to run (default: runs example queries)')
     parser.add_argument('--include-html', action='store_true', default=False,
                         help='Include HTML content in email processing (default: False)')
+    parser.add_argument('--embedding-model', type=str, default="facebook/contriever",
+                        help='Embedding model to use (default: facebook/contriever)')
 
     args = parser.parse_args()
 
@@ -263,7 +265,7 @@ async def main():
     print(f"Found {len(messages_dirs)} Messages directories.")
 
     # Create or load the LEANN index from all sources
-    index_path = create_leann_index_from_multiple_sources(messages_dirs, INDEX_PATH, args.max_emails, args.include_html)
+    index_path = create_leann_index_from_multiple_sources(messages_dirs, INDEX_PATH, args.max_emails, args.include_html, args.embedding_model)
 
     if index_path:
         if args.query:
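The new embedding_model parameter above is passed straight through to LeannBuilder, so any model name the embedding backend accepts can be supplied from the CLI. A minimal sketch of that builder call with a non-default model, reusing the arguments shown in the hunks above; the import path and the alternative model name are assumptions for illustration, not part of this diff:

# Sketch only: supplying a non-default embedding model to the index builder.
from leann import LeannBuilder  # import path assumed

builder = LeannBuilder(
    backend_name="hnsw",
    embedding_model="sentence-transformers/all-MiniLM-L6-v2",  # illustrative alternative to facebook/contriever
    graph_degree=32,
    complexity=64,
    is_compact=True,
)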
diff --git a/packages/leann-core/src/leann/api.py b/packages/leann-core/src/leann/api.py
index 2042ac8..26bc644 100644
--- a/packages/leann-core/src/leann/api.py
+++ b/packages/leann-core/src/leann/api.py
@@ -76,7 +76,7 @@ def compute_embeddings_sentence_transformers(chunks: List[str], model_name: str)
     # Generate embeddings
     # give use an warning if OOM here means we need to turn down the batch size
     embeddings = model.encode(
-        chunks, convert_to_numpy=True, show_progress_bar=True, batch_size=256
+        chunks, convert_to_numpy=True, show_progress_bar=True, batch_size=8
     )
     return embeddings
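The encode batch size drops from 256 to 8, trading embedding throughput for memory headroom on machines where the larger batch can run out of memory (as the comment above already hints). A possible refinement, not part of this change, is an adaptive retry that halves the batch size on out-of-memory errors; a minimal sketch around the same model.encode call:

# Sketch only (not in this diff): halve the batch size when encoding hits an OOM error.
def encode_with_fallback(model, chunks, batch_size=256, min_batch_size=1):
    while batch_size >= min_batch_size:
        try:
            return model.encode(
                chunks, convert_to_numpy=True, show_progress_bar=True, batch_size=batch_size
            )
        except RuntimeError as e:  # CUDA/MPS out-of-memory typically surfaces as RuntimeError
            if "out of memory" not in str(e).lower():
                raise
            batch_size //= 2
    raise RuntimeError("Encoding failed even at the minimum batch size")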
diff --git a/test/simple_mac_tpt_test.py b/test/simple_mac_tpt_test.py
new file mode 100644
index 0000000..6aaac13
--- /dev/null
+++ b/test/simple_mac_tpt_test.py
@@ -0,0 +1,314 @@
+import time
+from dataclasses import dataclass
+from typing import Dict, List
+
+import numpy as np
+import torch
+from torch import nn
+from transformers import AutoModel, BitsAndBytesConfig
+from tqdm import tqdm
+
+# Add MLX imports
+try:
+    import mlx.core as mx
+    from mlx_lm.utils import load
+    MLX_AVAILABLE = True
+except ImportError as e:
+    print("MLX not available. Install with: uv pip install mlx mlx-lm")
+    MLX_AVAILABLE = False
+
+@dataclass
+class BenchmarkConfig:
+    model_path: str = "facebook/contriever"
+    batch_sizes: List[int] = None
+    seq_length: int = 256
+    num_runs: int = 5
+    use_fp16: bool = True
+    use_int4: bool = False
+    use_int8: bool = False
+    use_cuda_graphs: bool = False
+    use_flash_attention: bool = False
+    use_linear8bitlt: bool = False
+    use_mlx: bool = False  # New flag for MLX testing
+
+    def __post_init__(self):
+        if self.batch_sizes is None:
+            self.batch_sizes = [1, 2, 4, 8, 16, 32, 64]
+
+class MLXBenchmark:
+    """MLX-specific benchmark for embedding models"""
+
+    def __init__(self, config: BenchmarkConfig):
+        self.config = config
+        self.model, self.tokenizer = self._load_model()
+
+    def _load_model(self):
+        """Load MLX model and tokenizer following the API pattern"""
+        print(f"Loading MLX model from {self.config.model_path}...")
+        try:
+            model, tokenizer = load(self.config.model_path)
+            print("MLX model loaded successfully")
+            return model, tokenizer
+        except Exception as e:
+            print(f"Error loading MLX model: {e}")
+            raise
+
+    def _create_random_batch(self, batch_size: int):
+        """Create random input batches for MLX testing - same as PyTorch"""
+        return torch.randint(
+            0, 1000,
+            (batch_size, self.config.seq_length),
+            dtype=torch.long
+        )
+
+    def _run_inference(self, input_ids: torch.Tensor) -> float:
+        """Run MLX inference with same input as PyTorch"""
+        start_time = time.time()
+        try:
+            # Convert PyTorch tensor to MLX array
+            input_ids_mlx = mx.array(input_ids.numpy())
+
+            # Get embeddings
+            embeddings = self.model(input_ids_mlx)
+
+            # Mean pooling (following the API pattern)
+            pooled = embeddings.mean(axis=1)
+
+            # Convert to numpy (following the API pattern)
+            pooled_numpy = np.array(pooled.tolist(), dtype=np.float32)
+
+            # Force computation
+            _ = pooled_numpy.shape
+
+        except Exception as e:
+            print(f"MLX inference error: {e}")
+            return float('inf')
+        end_time = time.time()
+
+        return end_time - start_time
+
+    def run(self) -> Dict[int, Dict[str, float]]:
+        """Run the MLX benchmark across all batch sizes"""
+        results = {}
+
+        print(f"Starting MLX benchmark with model: {self.config.model_path}")
+        print(f"Testing batch sizes: {self.config.batch_sizes}")
+
+        for batch_size in self.config.batch_sizes:
+            print(f"\n=== Testing MLX batch size: {batch_size} ===")
+            times = []
+
+            # Create input batch (same as PyTorch)
+            input_ids = self._create_random_batch(batch_size)
+
+            # Warm up
+            print("Warming up...")
+            for _ in range(3):
+                try:
+                    self._run_inference(input_ids[:2])  # Warm up with smaller batch
+                except Exception as e:
+                    print(f"Warmup error: {e}")
+                    break
+
+            # Run benchmark
+            for i in tqdm(range(self.config.num_runs), desc=f"MLX Batch size {batch_size}"):
+                try:
+                    elapsed_time = self._run_inference(input_ids)
+                    if elapsed_time != float('inf'):
+                        times.append(elapsed_time)
+                except Exception as e:
+                    print(f"Error during MLX inference: {e}")
+                    break
+
+            if not times:
+                print(f"Skipping batch size {batch_size} due to errors")
+                continue
+
+            # Calculate statistics
+            avg_time = np.mean(times)
+            std_time = np.std(times)
+            throughput = batch_size / avg_time
+
+            results[batch_size] = {
+                "avg_time": avg_time,
+                "std_time": std_time,
+                "throughput": throughput,
+                "min_time": np.min(times),
+                "max_time": np.max(times),
+            }
+
+            print(f"MLX Results for batch size {batch_size}:")
+            print(f"  Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
+            print(f"  Min Time: {np.min(times):.4f}s")
+            print(f"  Max Time: {np.max(times):.4f}s")
+            print(f"  Throughput: {throughput:.2f} sequences/second")
+
+        return results
+
+class Benchmark:
+    def __init__(self, config: BenchmarkConfig):
+        self.config = config
+        self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+        self.model = self._load_model()
+
+    def _load_model(self) -> nn.Module:
+        print(f"Loading model from {self.config.model_path}...")
+
+        model = AutoModel.from_pretrained(self.config.model_path)
+        if self.config.use_fp16:
+            model = model.half()
+        model = torch.compile(model)
+        model = model.to(self.device)
+
+        model.eval()
+        return model
+
+    def _create_random_batch(self, batch_size: int) -> torch.Tensor:
+        return torch.randint(
+            0, 1000,
+            (batch_size, self.config.seq_length),
+            device=self.device,
+            dtype=torch.long
+        )
+
+    def _run_inference(self, input_ids: torch.Tensor) -> float:
+        attention_mask = torch.ones_like(input_ids)
+
+        start_time = time.time()
+        with torch.no_grad():
+            output = self.model(input_ids=input_ids, attention_mask=attention_mask)
+        end_time = time.time()
+
+        return end_time - start_time
+
+    def run(self) -> Dict[int, Dict[str, float]]:
+        results = {}
+
+        if torch.cuda.is_available():
+            torch.cuda.reset_peak_memory_stats()
+
+        for batch_size in self.config.batch_sizes:
+            print(f"\nTesting batch size: {batch_size}")
+            times = []
+
+            input_ids = self._create_random_batch(batch_size)
+
+            for i in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
+                try:
+                    elapsed_time = self._run_inference(input_ids)
+                    times.append(elapsed_time)
+                except Exception as e:
+                    print(f"Error during inference: {e}")
+                    break
+
+            if not times:
+                continue
+
+            avg_time = np.mean(times)
+            std_time = np.std(times)
+            throughput = batch_size / avg_time
+
+            results[batch_size] = {
+                "avg_time": avg_time,
+                "std_time": std_time,
+                "throughput": throughput,
+            }
+
+            print(f"Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
+            print(f"Throughput: {throughput:.2f} sequences/second")
+
+        if torch.cuda.is_available():
+            peak_memory_gb = torch.cuda.max_memory_allocated() / (1024 ** 3)
+        else:
+            peak_memory_gb = 0.0
+
+        for batch_size in results:
+            results[batch_size]["peak_memory_gb"] = peak_memory_gb
+
+        return results
+
+def run_benchmark():
+    """Main function to run the benchmark with optimized parameters."""
+    config = BenchmarkConfig()
+
+    try:
+        benchmark = Benchmark(config)
+        results = benchmark.run()
+
+        max_throughput = max(results[batch_size]["throughput"] for batch_size in results)
+        avg_throughput = np.mean([results[batch_size]["throughput"] for batch_size in results])
+
+        return {
+            "max_throughput": max_throughput,
+            "avg_throughput": avg_throughput,
+            "results": results
+        }
+
+    except Exception as e:
+        print(f"Benchmark failed: {e}")
+        return {
+            "max_throughput": 0.0,
+            "avg_throughput": 0.0,
+            "error": str(e)
+        }
+
+def run_mlx_benchmark():
+    """Run MLX-specific benchmark"""
+    if not MLX_AVAILABLE:
+        print("MLX not available, skipping MLX benchmark")
+        return {
+            "max_throughput": 0.0,
+            "avg_throughput": 0.0,
+            "error": "MLX not available"
+        }
+
+    config = BenchmarkConfig(
+        model_path="mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ",
+        use_mlx=True
+    )
+
+    try:
+        benchmark = MLXBenchmark(config)
+        results = benchmark.run()
+
+        if not results:
+            return {
+                "max_throughput": 0.0,
+                "avg_throughput": 0.0,
+                "error": "No valid results"
+            }
+
+        max_throughput = max(results[batch_size]["throughput"] for batch_size in results)
+        avg_throughput = np.mean([results[batch_size]["throughput"] for batch_size in results])
+
+        return {
+            "max_throughput": max_throughput,
+            "avg_throughput": avg_throughput,
+            "results": results
+        }
+
+    except Exception as e:
+        print(f"MLX benchmark failed: {e}")
+        return {
+            "max_throughput": 0.0,
+            "avg_throughput": 0.0,
+            "error": str(e)
+        }
+
+if __name__ == "__main__":
+    print("=== PyTorch Benchmark ===")
+    pytorch_result = run_benchmark()
+    print(f"PyTorch Max throughput: {pytorch_result['max_throughput']:.2f} sequences/second")
+    print(f"PyTorch Average throughput: {pytorch_result['avg_throughput']:.2f} sequences/second")
+
+    print("\n=== MLX Benchmark ===")
+    mlx_result = run_mlx_benchmark()
+    print(f"MLX Max throughput: {mlx_result['max_throughput']:.2f} sequences/second")
+    print(f"MLX Average throughput: {mlx_result['avg_throughput']:.2f} sequences/second")
+
+    # Compare results
+    if pytorch_result['max_throughput'] > 0 and mlx_result['max_throughput'] > 0:
+        speedup = mlx_result['max_throughput'] / pytorch_result['max_throughput']
+        print(f"\n=== Comparison ===")
+        print(f"MLX is {speedup:.2f}x {'faster' if speedup > 1 else 'slower'} than PyTorch")
\ No newline at end of file
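For reference, the benchmark classes added above can also be driven with a custom configuration rather than the defaults in run_benchmark(); a short sketch, where the batch sizes and run count are arbitrary examples and the import assumes test/ is on sys.path:

# Sketch only: run the PyTorch benchmark with a custom sweep and print throughput.
from simple_mac_tpt_test import Benchmark, BenchmarkConfig

config = BenchmarkConfig(batch_sizes=[1, 8, 32], num_runs=3, seq_length=128)
results = Benchmark(config).run()
for batch_size, stats in results.items():
    # run() reports throughput as batch_size / avg_time
    print(f"batch={batch_size}: {stats['throughput']:.2f} seq/s (avg {stats['avg_time']:.4f}s)")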