add embedding api
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -35,7 +35,7 @@ build/
|
|||||||
nprobe_logs/
|
nprobe_logs/
|
||||||
micro/results
|
micro/results
|
||||||
micro/contriever-INT8
|
micro/contriever-INT8
|
||||||
examples/data/
|
examples/data/*
|
||||||
!examples/data/2501.14312v1 (1).pdf
|
!examples/data/2501.14312v1 (1).pdf
|
||||||
!examples/data/2506.08276v1.pdf
|
!examples/data/2506.08276v1.pdf
|
||||||
!examples/data/PrideandPrejudice.txt
|
!examples/data/PrideandPrejudice.txt
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ def get_mail_path():
|
|||||||
# Default mail path for macOS
|
# Default mail path for macOS
|
||||||
# DEFAULT_MAIL_PATH = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data"
|
# DEFAULT_MAIL_PATH = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data"
|
||||||
|
|
||||||
def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_path: str = "mail_index.leann", max_count: int = -1, include_html: bool = False):
|
def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_path: str = "mail_index.leann", max_count: int = -1, include_html: bool = False, embedding_model: str = "facebook/contriever"):
|
||||||
"""
|
"""
|
||||||
Create LEANN index from multiple mail data sources.
|
Create LEANN index from multiple mail data sources.
|
||||||
|
|
||||||
@@ -101,7 +101,7 @@ def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_pa
|
|||||||
# Use HNSW backend for better macOS compatibility
|
# Use HNSW backend for better macOS compatibility
|
||||||
builder = LeannBuilder(
|
builder = LeannBuilder(
|
||||||
backend_name="hnsw",
|
backend_name="hnsw",
|
||||||
embedding_model="facebook/contriever",
|
embedding_model=embedding_model,
|
||||||
graph_degree=32,
|
graph_degree=32,
|
||||||
complexity=64,
|
complexity=64,
|
||||||
is_compact=True,
|
is_compact=True,
|
||||||
@@ -120,7 +120,7 @@ def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_pa
|
|||||||
|
|
||||||
return index_path
|
return index_path
|
||||||
|
|
||||||
def create_leann_index(mail_path: str, index_path: str = "mail_index.leann", max_count: int = 1000, include_html: bool = False):
|
def create_leann_index(mail_path: str, index_path: str = "mail_index.leann", max_count: int = 1000, include_html: bool = False, embedding_model: str = "facebook/contriever"):
|
||||||
"""
|
"""
|
||||||
Create LEANN index from mail data.
|
Create LEANN index from mail data.
|
||||||
|
|
||||||
@@ -180,7 +180,7 @@ def create_leann_index(mail_path: str, index_path: str = "mail_index.leann", max
|
|||||||
# Use HNSW backend for better macOS compatibility
|
# Use HNSW backend for better macOS compatibility
|
||||||
builder = LeannBuilder(
|
builder = LeannBuilder(
|
||||||
backend_name="hnsw",
|
backend_name="hnsw",
|
||||||
embedding_model="facebook/contriever",
|
embedding_model=embedding_model,
|
||||||
graph_degree=32,
|
graph_degree=32,
|
||||||
complexity=64,
|
complexity=64,
|
||||||
is_compact=True,
|
is_compact=True,
|
||||||
@@ -239,6 +239,8 @@ async def main():
|
|||||||
help='Single query to run (default: runs example queries)')
|
help='Single query to run (default: runs example queries)')
|
||||||
parser.add_argument('--include-html', action='store_true', default=False,
|
parser.add_argument('--include-html', action='store_true', default=False,
|
||||||
help='Include HTML content in email processing (default: False)')
|
help='Include HTML content in email processing (default: False)')
|
||||||
|
parser.add_argument('--embedding-model', type=str, default="facebook/contriever",
|
||||||
|
help='Embedding model to use (default: facebook/contriever)')
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@@ -263,7 +265,7 @@ async def main():
|
|||||||
print(f"Found {len(messages_dirs)} Messages directories.")
|
print(f"Found {len(messages_dirs)} Messages directories.")
|
||||||
|
|
||||||
# Create or load the LEANN index from all sources
|
# Create or load the LEANN index from all sources
|
||||||
index_path = create_leann_index_from_multiple_sources(messages_dirs, INDEX_PATH, args.max_emails, args.include_html)
|
index_path = create_leann_index_from_multiple_sources(messages_dirs, INDEX_PATH, args.max_emails, args.include_html, args.embedding_model)
|
||||||
|
|
||||||
if index_path:
|
if index_path:
|
||||||
if args.query:
|
if args.query:
|
||||||
|
|||||||
@@ -76,7 +76,7 @@ def compute_embeddings_sentence_transformers(chunks: List[str], model_name: str)
|
|||||||
# Generate embeddings
|
# Generate embeddings
|
||||||
# give use an warning if OOM here means we need to turn down the batch size
|
# give use an warning if OOM here means we need to turn down the batch size
|
||||||
embeddings = model.encode(
|
embeddings = model.encode(
|
||||||
chunks, convert_to_numpy=True, show_progress_bar=True, batch_size=256
|
chunks, convert_to_numpy=True, show_progress_bar=True, batch_size=8
|
||||||
)
|
)
|
||||||
|
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|||||||
314
test/simple_mac_tpt_test.py
Normal file
314
test/simple_mac_tpt_test.py
Normal file
@@ -0,0 +1,314 @@
|
|||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from transformers import AutoModel, BitsAndBytesConfig
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
# Add MLX imports
|
||||||
|
try:
|
||||||
|
import mlx.core as mx
|
||||||
|
from mlx_lm.utils import load
|
||||||
|
MLX_AVAILABLE = True
|
||||||
|
except ImportError as e:
|
||||||
|
print("MLX not available. Install with: uv pip install mlx mlx-lm")
|
||||||
|
MLX_AVAILABLE = False
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BenchmarkConfig:
|
||||||
|
model_path: str = "facebook/contriever"
|
||||||
|
batch_sizes: List[int] = None
|
||||||
|
seq_length: int = 256
|
||||||
|
num_runs: int = 5
|
||||||
|
use_fp16: bool = True
|
||||||
|
use_int4: bool = False
|
||||||
|
use_int8: bool = False
|
||||||
|
use_cuda_graphs: bool = False
|
||||||
|
use_flash_attention: bool = False
|
||||||
|
use_linear8bitlt: bool = False
|
||||||
|
use_mlx: bool = False # New flag for MLX testing
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.batch_sizes is None:
|
||||||
|
self.batch_sizes = [1, 2, 4, 8, 16, 32, 64]
|
||||||
|
|
||||||
|
class MLXBenchmark:
|
||||||
|
"""MLX-specific benchmark for embedding models"""
|
||||||
|
|
||||||
|
def __init__(self, config: BenchmarkConfig):
|
||||||
|
self.config = config
|
||||||
|
self.model, self.tokenizer = self._load_model()
|
||||||
|
|
||||||
|
def _load_model(self):
|
||||||
|
"""Load MLX model and tokenizer following the API pattern"""
|
||||||
|
print(f"Loading MLX model from {self.config.model_path}...")
|
||||||
|
try:
|
||||||
|
model, tokenizer = load(self.config.model_path)
|
||||||
|
print("MLX model loaded successfully")
|
||||||
|
return model, tokenizer
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error loading MLX model: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _create_random_batch(self, batch_size: int):
|
||||||
|
"""Create random input batches for MLX testing - same as PyTorch"""
|
||||||
|
return torch.randint(
|
||||||
|
0, 1000,
|
||||||
|
(batch_size, self.config.seq_length),
|
||||||
|
dtype=torch.long
|
||||||
|
)
|
||||||
|
|
||||||
|
def _run_inference(self, input_ids: torch.Tensor) -> float:
|
||||||
|
"""Run MLX inference with same input as PyTorch"""
|
||||||
|
start_time = time.time()
|
||||||
|
try:
|
||||||
|
# Convert PyTorch tensor to MLX array
|
||||||
|
input_ids_mlx = mx.array(input_ids.numpy())
|
||||||
|
|
||||||
|
# Get embeddings
|
||||||
|
embeddings = self.model(input_ids_mlx)
|
||||||
|
|
||||||
|
# Mean pooling (following the API pattern)
|
||||||
|
pooled = embeddings.mean(axis=1)
|
||||||
|
|
||||||
|
# Convert to numpy (following the API pattern)
|
||||||
|
pooled_numpy = np.array(pooled.tolist(), dtype=np.float32)
|
||||||
|
|
||||||
|
# Force computation
|
||||||
|
_ = pooled_numpy.shape
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"MLX inference error: {e}")
|
||||||
|
return float('inf')
|
||||||
|
end_time = time.time()
|
||||||
|
|
||||||
|
return end_time - start_time
|
||||||
|
|
||||||
|
def run(self) -> Dict[int, Dict[str, float]]:
|
||||||
|
"""Run the MLX benchmark across all batch sizes"""
|
||||||
|
results = {}
|
||||||
|
|
||||||
|
print(f"Starting MLX benchmark with model: {self.config.model_path}")
|
||||||
|
print(f"Testing batch sizes: {self.config.batch_sizes}")
|
||||||
|
|
||||||
|
for batch_size in self.config.batch_sizes:
|
||||||
|
print(f"\n=== Testing MLX batch size: {batch_size} ===")
|
||||||
|
times = []
|
||||||
|
|
||||||
|
# Create input batch (same as PyTorch)
|
||||||
|
input_ids = self._create_random_batch(batch_size)
|
||||||
|
|
||||||
|
# Warm up
|
||||||
|
print("Warming up...")
|
||||||
|
for _ in range(3):
|
||||||
|
try:
|
||||||
|
self._run_inference(input_ids[:2]) # Warm up with smaller batch
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Warmup error: {e}")
|
||||||
|
break
|
||||||
|
|
||||||
|
# Run benchmark
|
||||||
|
for i in tqdm(range(self.config.num_runs), desc=f"MLX Batch size {batch_size}"):
|
||||||
|
try:
|
||||||
|
elapsed_time = self._run_inference(input_ids)
|
||||||
|
if elapsed_time != float('inf'):
|
||||||
|
times.append(elapsed_time)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error during MLX inference: {e}")
|
||||||
|
break
|
||||||
|
|
||||||
|
if not times:
|
||||||
|
print(f"Skipping batch size {batch_size} due to errors")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Calculate statistics
|
||||||
|
avg_time = np.mean(times)
|
||||||
|
std_time = np.std(times)
|
||||||
|
throughput = batch_size / avg_time
|
||||||
|
|
||||||
|
results[batch_size] = {
|
||||||
|
"avg_time": avg_time,
|
||||||
|
"std_time": std_time,
|
||||||
|
"throughput": throughput,
|
||||||
|
"min_time": np.min(times),
|
||||||
|
"max_time": np.max(times),
|
||||||
|
}
|
||||||
|
|
||||||
|
print(f"MLX Results for batch size {batch_size}:")
|
||||||
|
print(f" Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
|
||||||
|
print(f" Min Time: {np.min(times):.4f}s")
|
||||||
|
print(f" Max Time: {np.max(times):.4f}s")
|
||||||
|
print(f" Throughput: {throughput:.2f} sequences/second")
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
class Benchmark:
|
||||||
|
def __init__(self, config: BenchmarkConfig):
|
||||||
|
self.config = config
|
||||||
|
self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
||||||
|
self.model = self._load_model()
|
||||||
|
|
||||||
|
def _load_model(self) -> nn.Module:
|
||||||
|
print(f"Loading model from {self.config.model_path}...")
|
||||||
|
|
||||||
|
|
||||||
|
model = AutoModel.from_pretrained(self.config.model_path)
|
||||||
|
if self.config.use_fp16:
|
||||||
|
model = model.half()
|
||||||
|
model = torch.compile(model)
|
||||||
|
model = model.to(self.device)
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
return model
|
||||||
|
|
||||||
|
def _create_random_batch(self, batch_size: int) -> torch.Tensor:
|
||||||
|
return torch.randint(
|
||||||
|
0, 1000,
|
||||||
|
(batch_size, self.config.seq_length),
|
||||||
|
device=self.device,
|
||||||
|
dtype=torch.long
|
||||||
|
)
|
||||||
|
|
||||||
|
def _run_inference(self, input_ids: torch.Tensor) -> float:
|
||||||
|
attention_mask = torch.ones_like(input_ids)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
with torch.no_grad():
|
||||||
|
output = self.model(input_ids=input_ids, attention_mask=attention_mask)
|
||||||
|
end_time = time.time()
|
||||||
|
|
||||||
|
return end_time - start_time
|
||||||
|
|
||||||
|
def run(self) -> Dict[int, Dict[str, float]]:
|
||||||
|
results = {}
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.reset_peak_memory_stats()
|
||||||
|
|
||||||
|
for batch_size in self.config.batch_sizes:
|
||||||
|
print(f"\nTesting batch size: {batch_size}")
|
||||||
|
times = []
|
||||||
|
|
||||||
|
input_ids = self._create_random_batch(batch_size)
|
||||||
|
|
||||||
|
for i in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
|
||||||
|
try:
|
||||||
|
elapsed_time = self._run_inference(input_ids)
|
||||||
|
times.append(elapsed_time)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error during inference: {e}")
|
||||||
|
break
|
||||||
|
|
||||||
|
if not times:
|
||||||
|
continue
|
||||||
|
|
||||||
|
avg_time = np.mean(times)
|
||||||
|
std_time = np.std(times)
|
||||||
|
throughput = batch_size / avg_time
|
||||||
|
|
||||||
|
results[batch_size] = {
|
||||||
|
"avg_time": avg_time,
|
||||||
|
"std_time": std_time,
|
||||||
|
"throughput": throughput,
|
||||||
|
}
|
||||||
|
|
||||||
|
print(f"Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
|
||||||
|
print(f"Throughput: {throughput:.2f} sequences/second")
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024 ** 3)
|
||||||
|
else:
|
||||||
|
peak_memory_gb = 0.0
|
||||||
|
|
||||||
|
for batch_size in results:
|
||||||
|
results[batch_size]["peak_memory_gb"] = peak_memory_gb
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def run_benchmark():
|
||||||
|
"""Main function to run the benchmark with optimized parameters."""
|
||||||
|
config = BenchmarkConfig()
|
||||||
|
|
||||||
|
try:
|
||||||
|
benchmark = Benchmark(config)
|
||||||
|
results = benchmark.run()
|
||||||
|
|
||||||
|
max_throughput = max(results[batch_size]["throughput"] for batch_size in results)
|
||||||
|
avg_throughput = np.mean([results[batch_size]["throughput"] for batch_size in results])
|
||||||
|
|
||||||
|
return {
|
||||||
|
"max_throughput": max_throughput,
|
||||||
|
"avg_throughput": avg_throughput,
|
||||||
|
"results": results
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Benchmark failed: {e}")
|
||||||
|
return {
|
||||||
|
"max_throughput": 0.0,
|
||||||
|
"avg_throughput": 0.0,
|
||||||
|
"error": str(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
def run_mlx_benchmark():
|
||||||
|
"""Run MLX-specific benchmark"""
|
||||||
|
if not MLX_AVAILABLE:
|
||||||
|
print("MLX not available, skipping MLX benchmark")
|
||||||
|
return {
|
||||||
|
"max_throughput": 0.0,
|
||||||
|
"avg_throughput": 0.0,
|
||||||
|
"error": "MLX not available"
|
||||||
|
}
|
||||||
|
|
||||||
|
config = BenchmarkConfig(
|
||||||
|
model_path="mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ",
|
||||||
|
use_mlx=True
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
benchmark = MLXBenchmark(config)
|
||||||
|
results = benchmark.run()
|
||||||
|
|
||||||
|
if not results:
|
||||||
|
return {
|
||||||
|
"max_throughput": 0.0,
|
||||||
|
"avg_throughput": 0.0,
|
||||||
|
"error": "No valid results"
|
||||||
|
}
|
||||||
|
|
||||||
|
max_throughput = max(results[batch_size]["throughput"] for batch_size in results)
|
||||||
|
avg_throughput = np.mean([results[batch_size]["throughput"] for batch_size in results])
|
||||||
|
|
||||||
|
return {
|
||||||
|
"max_throughput": max_throughput,
|
||||||
|
"avg_throughput": avg_throughput,
|
||||||
|
"results": results
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"MLX benchmark failed: {e}")
|
||||||
|
return {
|
||||||
|
"max_throughput": 0.0,
|
||||||
|
"avg_throughput": 0.0,
|
||||||
|
"error": str(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print("=== PyTorch Benchmark ===")
|
||||||
|
pytorch_result = run_benchmark()
|
||||||
|
print(f"PyTorch Max throughput: {pytorch_result['max_throughput']:.2f} sequences/second")
|
||||||
|
print(f"PyTorch Average throughput: {pytorch_result['avg_throughput']:.2f} sequences/second")
|
||||||
|
|
||||||
|
print("\n=== MLX Benchmark ===")
|
||||||
|
mlx_result = run_mlx_benchmark()
|
||||||
|
print(f"MLX Max throughput: {mlx_result['max_throughput']:.2f} sequences/second")
|
||||||
|
print(f"MLX Average throughput: {mlx_result['avg_throughput']:.2f} sequences/second")
|
||||||
|
|
||||||
|
# Compare results
|
||||||
|
if pytorch_result['max_throughput'] > 0 and mlx_result['max_throughput'] > 0:
|
||||||
|
speedup = mlx_result['max_throughput'] / pytorch_result['max_throughput']
|
||||||
|
print(f"\n=== Comparison ===")
|
||||||
|
print(f"MLX is {speedup:.2f}x {'faster' if speedup > 1 else 'slower'} than PyTorch")
|
||||||
Reference in New Issue
Block a user