* refactor: Unify examples interface with BaseRAGExample

- Create BaseRAGExample base class for all RAG examples
- Refactor 4 examples to use unified interface:
  - document_rag.py (replaces main_cli_example.py)
  - email_rag.py (replaces mail_reader_leann.py)
  - browser_rag.py (replaces google_history_reader_leann.py)
  - wechat_rag.py (replaces wechat_history_reader_leann.py)
- Maintain 100% parameter compatibility with original files
- Add interactive mode support for all examples
- Unify parameter names (--max-items replaces --max-emails/--max-entries)
- Update README.md with new examples usage
- Add PARAMETER_CONSISTENCY.md documenting all parameter mappings
- Keep main_cli_example.py for backward compatibility with migration notice

All default values, LeannBuilder parameters, and chunking settings remain identical to ensure full compatibility with existing indexes.

* fix: Update CI tests for new unified examples interface

- Rename test_main_cli.py to test_document_rag.py
- Update all references from main_cli_example.py to document_rag.py
- Update tests/README.md documentation

The tests now properly test the new unified interface while maintaining the same test coverage and functionality.

* fix: Fix pre-commit issues and update tests

- Fix import sorting and unused imports
- Update type annotations to use built-in types (list, dict) instead of typing.List/Dict
- Fix trailing whitespace and end-of-file issues
- Fix Chinese fullwidth comma to regular comma
- Update test_main_cli.py to test_document_rag.py
- Add backward compatibility test for main_cli_example.py
- Pass all pre-commit hooks (ruff, ruff-format, etc.)

* refactor: Remove old example scripts and migration references

- Delete old example scripts (mail_reader_leann.py, google_history_reader_leann.py, etc.)
- Remove migration hints and backward compatibility
- Update tests to use new unified examples directly
- Clean up all references to old script names
- Users now only see the new unified interface

* fix: Restore embedding-mode parameter to all examples

- All examples now have --embedding-mode parameter (unified interface benefit)
- Default is 'sentence-transformers' (consistent with original behavior)
- Users can now use OpenAI or MLX embeddings with any data source
- Maintains functional equivalence with original scripts

* docs: Improve parameter categorization in README

- Clearly separate core (shared) vs specific parameters
- Move LLM and embedding examples to 'Example Commands' section
- Add descriptive comments for all specific parameters
- Keep only truly data-source-specific parameters in specific sections

* docs: Make example commands more representative

- Add default values to parameter descriptions
- Replace generic examples with real-world use cases
- Focus on data-source-specific features in examples
- Remove redundant demonstrations of common parameters

* docs: Reorganize parameter documentation structure

- Move common parameters to a dedicated section before all examples
- Rename sections to 'X-Specific Arguments' for clarity
- Remove duplicate common parameters from individual examples
- Better information architecture for users

* docs: polish applications

* docs: Add CLI installation instructions

- Add two installation options: venv and global uv tool
- Clearly explain when to use each option
- Make CLI more accessible for daily use

* docs: Clarify CLI global installation process

- Explain the transition from venv to global installation
- Add upgrade command for global installation
- Make it clear that global install allows usage without venv activation

* docs: Add collapsible section for CLI installation

- Wrap CLI installation instructions in details/summary tags
- Keep consistent with other collapsible sections in README
- Improve document readability and navigation

* style: format

* docs: Fix collapsible sections

- Make Common Parameters collapsible (as it's lengthy reference material)
- Keep CLI Installation visible (important for users to see immediately)
- Better information hierarchy

* docs: Add introduction for Common Parameters section

- Add 'Flexible Configuration' heading with descriptive sentence
- Create parallel structure with 'Generation Model Setup' section
- Improve document flow and readability

* docs: nit

* fix: Fix issues in unified examples

- Add smart path detection for data directory
- Fix add_texts -> add_text method call
- Handle both running from project root and examples directory

* fix: Fix async/await and add_text issues in unified examples

- Remove incorrect await from chat.ask() calls (not async)
- Fix add_texts -> add_text method calls
- Verify search-complexity correctly maps to efSearch parameter
- All examples now run successfully

* feat: Address review comments

- Add complexity parameter to LeannChat initialization (default: search_complexity)
- Fix chunk-size default in README documentation (256, not 2048)
- Add more index building parameters as CLI arguments:
  - --backend-name (hnsw/diskann)
  - --graph-degree (default: 32)
  - --build-complexity (default: 64)
  - --no-compact (disable compact storage)
  - --no-recompute (disable embedding recomputation)
- Update README to document all new parameters

* feat: Add chunk-size parameters and improve file type filtering

- Add --chunk-size and --chunk-overlap parameters to all RAG examples
- Preserve original default values for each data source:
  - Document: 256/128 (optimized for general documents)
  - Email: 256/25 (smaller overlap for email threads)
  - Browser: 256/128 (standard for web content)
  - WeChat: 192/64 (smaller chunks for chat messages)
- Make --file-types optional filter instead of restriction in document_rag
- Update README to clarify interactive mode and parameter usage
- Fix LLM default model documentation (gpt-4o, not gpt-4o-mini)

* feat: Update documentation based on review feedback

- Add MLX embedding example to README
- Clarify examples/data content description (two papers, Pride and Prejudice, Chinese README)
- Move chunk parameters to common parameters section
- Remove duplicate chunk parameters from document-specific section

* docs: Emphasize diverse data sources in examples/data description

* fix: update default embedding models for better performance

- Change WeChat, Browser, and Email RAG examples to use all-MiniLM-L6-v2
- Previous Qwen/Qwen3-Embedding-0.6B was too slow for these use cases
- all-MiniLM-L6-v2 is a fast 384-dim model, ideal for large-scale personal data

* add response highlight

* change rebuild logic

* fix some example

* feat: check if k is larger than #docs

* fix: WeChat history reader bugs and refactor wechat_rag to use unified architecture

* fix email wrong -1 to process all file

* refactor: reorganize all examples/ and test/

* refactor: reorganize examples and add link checker

* fix: add __init__.py

* fix: handle certificate errors in link checker

* fix wechat

* merge

* docs: update README to use proper module imports for apps

- Change from 'python apps/xxx.py' to 'python -m apps.xxx'
- More professional and pythonic module calling
- Ensures proper module resolution and imports
- Better separation between apps/ (production tools) and examples/ (demos)

---------

Co-authored-by: yichuan520030910320 <yichuan_wang@berkeley.edu>
312 lines
10 KiB
Python
import time
from dataclasses import dataclass

import numpy as np
import torch
from torch import nn
from tqdm import tqdm
from transformers import AutoModel

# Add MLX imports
try:
    import mlx.core as mx
    from mlx_lm.utils import load

    MLX_AVAILABLE = True
except ImportError:
    print("MLX not available. Install with: uv pip install mlx mlx-lm")
    MLX_AVAILABLE = False


@dataclass
class BenchmarkConfig:
    model_path: str = "facebook/contriever"
    batch_sizes: list[int] = None  # None -> default list filled in by __post_init__
    seq_length: int = 256
    num_runs: int = 5
    use_fp16: bool = True
    use_int4: bool = False
    use_int8: bool = False
    use_cuda_graphs: bool = False
    use_flash_attention: bool = False
    use_linear8bitlt: bool = False
    use_mlx: bool = False  # New flag for MLX testing

    def __post_init__(self):
        if self.batch_sizes is None:
            self.batch_sizes = [1, 2, 4, 8, 16, 32, 64]


class MLXBenchmark:
    """MLX-specific benchmark for embedding models"""

    def __init__(self, config: BenchmarkConfig):
        self.config = config
        self.model, self.tokenizer = self._load_model()

    def _load_model(self):
        """Load MLX model and tokenizer following the API pattern"""
        print(f"Loading MLX model from {self.config.model_path}...")
        try:
            model, tokenizer = load(self.config.model_path)
            print("MLX model loaded successfully")
            return model, tokenizer
        except Exception as e:
            print(f"Error loading MLX model: {e}")
            raise

    def _create_random_batch(self, batch_size: int):
        """Create random input batches for MLX testing - same as PyTorch"""
        return torch.randint(0, 1000, (batch_size, self.config.seq_length), dtype=torch.long)

    def _run_inference(self, input_ids: torch.Tensor) -> float:
        """Run MLX inference with same input as PyTorch"""
        start_time = time.perf_counter()
        try:
            # Convert PyTorch tensor to MLX array
            input_ids_mlx = mx.array(input_ids.numpy())

            # Get embeddings
            embeddings = self.model(input_ids_mlx)

            # Mean pooling (following the API pattern)
            pooled = embeddings.mean(axis=1)

            # MLX evaluates lazily; force the computation before the clock stops
            mx.eval(pooled)

            # Convert to numpy (following the API pattern); the result is discarded
            _ = np.array(pooled.tolist(), dtype=np.float32)
        except Exception as e:
            print(f"MLX inference error: {e}")
            return float("inf")
        end_time = time.perf_counter()

        return end_time - start_time

    def run(self) -> dict[int, dict[str, float]]:
        """Run the MLX benchmark across all batch sizes"""
        results = {}

        print(f"Starting MLX benchmark with model: {self.config.model_path}")
        print(f"Testing batch sizes: {self.config.batch_sizes}")

        for batch_size in self.config.batch_sizes:
            print(f"\n=== Testing MLX batch size: {batch_size} ===")
            times = []

            # Create input batch (same as PyTorch)
            input_ids = self._create_random_batch(batch_size)

            # Warm up
            print("Warming up...")
            for _ in range(3):
                try:
                    self._run_inference(input_ids[:2])  # Warm up with smaller batch
                except Exception as e:
                    print(f"Warmup error: {e}")
                    break

            # Run benchmark
            for _i in tqdm(range(self.config.num_runs), desc=f"MLX Batch size {batch_size}"):
                try:
                    elapsed_time = self._run_inference(input_ids)
                    if elapsed_time != float("inf"):
                        times.append(elapsed_time)
                except Exception as e:
                    print(f"Error during MLX inference: {e}")
                    break

            if not times:
                print(f"Skipping batch size {batch_size} due to errors")
                continue

            # Calculate statistics
            avg_time = np.mean(times)
            std_time = np.std(times)
            throughput = batch_size / avg_time

            results[batch_size] = {
                "avg_time": avg_time,
                "std_time": std_time,
                "throughput": throughput,
                "min_time": np.min(times),
                "max_time": np.max(times),
            }

            print(f"MLX Results for batch size {batch_size}:")
            print(f" Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
            print(f" Min Time: {np.min(times):.4f}s")
            print(f" Max Time: {np.max(times):.4f}s")
            print(f" Throughput: {throughput:.2f} sequences/second")

        return results


class Benchmark:
    def __init__(self, config: BenchmarkConfig):
        self.config = config
        self.device = (
            "cuda"
            if torch.cuda.is_available()
            else "mps"
            if torch.backends.mps.is_available()
            else "cpu"
        )
        self.model = self._load_model()

    def _load_model(self) -> nn.Module:
        print(f"Loading model from {self.config.model_path}...")

        model = AutoModel.from_pretrained(self.config.model_path)
        if self.config.use_fp16:
            model = model.half()
        model = torch.compile(model)
        model = model.to(self.device)

        model.eval()
        return model

    def _create_random_batch(self, batch_size: int) -> torch.Tensor:
        return torch.randint(
            0,
            1000,
            (batch_size, self.config.seq_length),
            device=self.device,
            dtype=torch.long,
        )

    def _run_inference(self, input_ids: torch.Tensor) -> float:
        attention_mask = torch.ones_like(input_ids)

        start_time = time.perf_counter()
        with torch.no_grad():
            self.model(input_ids=input_ids, attention_mask=attention_mask)
        # CUDA and MPS kernels run asynchronously; wait for them before stopping the clock
        if self.device == "cuda":
            torch.cuda.synchronize()
        elif self.device == "mps":
            torch.mps.synchronize()
        end_time = time.perf_counter()

        return end_time - start_time

    def run(self) -> dict[int, dict[str, float]]:
        results = {}

        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()

        for batch_size in self.config.batch_sizes:
            print(f"\nTesting batch size: {batch_size}")
            times = []

            input_ids = self._create_random_batch(batch_size)

            # Warm up: torch.compile compiles on the first call and may recompile
            # when the batch shape changes, so keep that out of the timed runs
            for _ in range(3):
                try:
                    self._run_inference(input_ids)
                except Exception as e:
                    print(f"Warmup error: {e}")
                    break

            for _i in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
                try:
                    elapsed_time = self._run_inference(input_ids)
                    times.append(elapsed_time)
                except Exception as e:
                    print(f"Error during inference: {e}")
                    break

            if not times:
                continue

            avg_time = np.mean(times)
            std_time = np.std(times)
            throughput = batch_size / avg_time

            results[batch_size] = {
                "avg_time": avg_time,
                "std_time": std_time,
                "throughput": throughput,
            }

            print(f"Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
            print(f"Throughput: {throughput:.2f} sequences/second")

        # The peak-memory counter is reset only once above, so this is the peak
        # across all batch sizes, not a per-batch-size figure
        if torch.cuda.is_available():
            peak_memory_gb = torch.cuda.max_memory_allocated() / (1024**3)
        else:
            peak_memory_gb = 0.0

        for batch_size in results:
            results[batch_size]["peak_memory_gb"] = peak_memory_gb

        return results


def run_benchmark():
    """Run the PyTorch benchmark with the default BenchmarkConfig."""
    config = BenchmarkConfig()

    try:
        benchmark = Benchmark(config)
        results = benchmark.run()

        if not results:
            return {
                "max_throughput": 0.0,
                "avg_throughput": 0.0,
                "error": "No valid results",
            }

        max_throughput = max(results[batch_size]["throughput"] for batch_size in results)
        avg_throughput = np.mean([results[batch_size]["throughput"] for batch_size in results])

        return {
            "max_throughput": max_throughput,
            "avg_throughput": avg_throughput,
            "results": results,
        }

    except Exception as e:
        print(f"Benchmark failed: {e}")
        return {"max_throughput": 0.0, "avg_throughput": 0.0, "error": str(e)}


def run_mlx_benchmark():
    """Run MLX-specific benchmark"""
    if not MLX_AVAILABLE:
        print("MLX not available, skipping MLX benchmark")
        return {
            "max_throughput": 0.0,
            "avg_throughput": 0.0,
            "error": "MLX not available",
        }

    config = BenchmarkConfig(model_path="mlx-community/all-MiniLM-L6-v2-4bit", use_mlx=True)

    try:
        benchmark = MLXBenchmark(config)
        results = benchmark.run()

        if not results:
            return {
                "max_throughput": 0.0,
                "avg_throughput": 0.0,
                "error": "No valid results",
            }

        max_throughput = max(results[batch_size]["throughput"] for batch_size in results)
        avg_throughput = np.mean([results[batch_size]["throughput"] for batch_size in results])

        return {
            "max_throughput": max_throughput,
            "avg_throughput": avg_throughput,
            "results": results,
        }

    except Exception as e:
        print(f"MLX benchmark failed: {e}")
        return {"max_throughput": 0.0, "avg_throughput": 0.0, "error": str(e)}


if __name__ == "__main__":
    print("=== PyTorch Benchmark ===")
    pytorch_result = run_benchmark()
    print(f"PyTorch Max throughput: {pytorch_result['max_throughput']:.2f} sequences/second")
    print(f"PyTorch Average throughput: {pytorch_result['avg_throughput']:.2f} sequences/second")

    print("\n=== MLX Benchmark ===")
    mlx_result = run_mlx_benchmark()
    print(f"MLX Max throughput: {mlx_result['max_throughput']:.2f} sequences/second")
    print(f"MLX Average throughput: {mlx_result['avg_throughput']:.2f} sequences/second")

    # Compare results. The two runs use different default models
    # (facebook/contriever for PyTorch, mlx-community/all-MiniLM-L6-v2-4bit for MLX),
    # so the ratio reflects model and framework differences together.
    if pytorch_result["max_throughput"] > 0 and mlx_result["max_throughput"] > 0:
        speedup = mlx_result["max_throughput"] / pytorch_result["max_throughput"]
        print("\n=== Comparison ===")
        if speedup >= 1:
            print(f"MLX is {speedup:.2f}x faster than PyTorch")
        else:
            print(f"MLX is {1 / speedup:.2f}x slower than PyTorch")
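Both `_run_inference` methods read a clock around a single forward pass. On CUDA and MPS, PyTorch queues kernels asynchronously, and MLX builds its graph lazily until something like `mx.eval` or `.tolist()` forces it, so a timer can stop before the work has actually finished unless the device is synchronized. The sketch below is a framework-agnostic version of that timing loop, assuming a hypothetical helper named `time_call` that receives the work as a zero-argument callable plus an optional `sync` callable (for example `torch.cuda.synchronize` when timing on CUDA). Warm-up calls, which absorb one-time costs such as `torch.compile` compilation, are excluded from the returned timings.

```python
import time
from typing import Callable, Optional


def time_call(
    fn: Callable[[], object],
    num_runs: int = 5,
    num_warmup: int = 3,
    sync: Optional[Callable[[], None]] = None,
) -> list[float]:
    """Time fn() num_runs times after num_warmup untimed calls (hypothetical helper).

    If sync is given, it is called after every fn() so queued asynchronous
    device work finishes before the clock is read.
    """
    # Untimed warm-up calls absorb one-time costs (compilation, caches, allocation)
    for _ in range(num_warmup):
        fn()
        if sync is not None:
            sync()

    times = []
    for _ in range(num_runs):
        start = time.perf_counter()  # monotonic, high-resolution clock
        fn()
        if sync is not None:
            sync()  # wait for the device before stopping the clock
        times.append(time.perf_counter() - start)
    return times
```

For the PyTorch path, a possible call (assuming `model`, `input_ids`, and `attention_mask` as in `Benchmark._run_inference`, inside a `torch.no_grad()` block) is `time_call(lambda: model(input_ids=input_ids, attention_mask=attention_mask), sync=torch.cuda.synchronize)`.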
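Each `run` method reduces its per-run timings to an average, a standard deviation, and a throughput of `batch_size / avg_time` sequences per second, and the `__main__` block divides the two peak throughputs to report a speedup. A minimal sketch of that arithmetic, using hypothetical helper names `summarize` and `speedup`: for three runs of 0.10 s, 0.12 s, and 0.14 s at batch size 32, the average is 0.12 s and the throughput is 32 / 0.12 ≈ 266.67 sequences per second. Note that `np.std` defaults to the population standard deviation (`ddof=0`).

```python
import numpy as np


def summarize(times: list[float], batch_size: int) -> dict[str, float]:
    """Per-batch-size statistics shaped like the benchmark's results dict (hypothetical helper)."""
    avg_time = float(np.mean(times))
    return {
        "avg_time": avg_time,
        "std_time": float(np.std(times)),  # population std (ddof=0), NumPy's default
        "throughput": batch_size / avg_time,  # sequences per second
        "min_time": float(np.min(times)),
        "max_time": float(np.max(times)),
    }


def speedup(mlx_max: float, pytorch_max: float) -> float:
    """Ratio of peak throughputs; a value above 1 means the MLX run was faster."""
    return mlx_max / pytorch_max


stats = summarize([0.10, 0.12, 0.14], batch_size=32)
print(f"{stats['throughput']:.2f} sequences/second")  # prints 266.67 sequences/second
```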
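`MLXBenchmark._run_inference` pools with a plain `embeddings.mean(axis=1)`. That matches attention-mask-weighted mean pooling here only because the benchmark's random batches contain no padding; with real, padded inputs the padded positions would be averaged in as well. A NumPy sketch of masked mean pooling, assuming token embeddings of shape `(batch, seq_len, dim)` and a 0/1 mask of shape `(batch, seq_len)`; the function name `masked_mean_pool` is hypothetical.

```python
import numpy as np


def masked_mean_pool(embeddings: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Average token embeddings over the sequence axis, ignoring padded positions."""
    mask = attention_mask[:, :, None].astype(embeddings.dtype)  # (batch, seq_len, 1)
    summed = (embeddings * mask).sum(axis=1)  # (batch, dim): padded tokens contribute 0
    counts = np.clip(mask.sum(axis=1), 1e-9, None)  # (batch, 1); avoids division by zero
    return summed / counts


emb = np.array([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])  # last token is padding
mask = np.array([[1, 1, 0]])
print(masked_mean_pool(emb, mask))  # [[2. 3.]]
```

With an all-ones mask, as in the benchmark, the result equals `embeddings.mean(axis=1)`.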