import time

import matplotlib.pyplot as plt
import mlx.core as mx
import numpy as np
import torch
from mlx_lm import load
from sentence_transformers import SentenceTransformer

# --- Configuration ---
MODEL_NAME_TORCH = "Qwen/Qwen3-Embedding-0.6B"
MODEL_NAME_MLX = "mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ"
BATCH_SIZES = [1, 8, 16, 32, 64, 128]
NUM_RUNS = 10  # Number of runs to average for each batch size
WARMUP_RUNS = 2  # Number of warm-up runs

# --- Generate Dummy Data ---
DUMMY_SENTENCES = ["This is a test sentence for benchmarking." * 5] * max(BATCH_SIZES)
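# Each entry is the base sentence repeated five times via string multiply, so
# every input in a batch is identical and moderately long (roughly 50 tokens).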

# --- Benchmark Functions ---


def benchmark_torch(model, sentences):
    start_time = time.perf_counter()
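    # encode() handles tokenization, the forward pass, and pooling internally,
    # so the timing below covers the full end-to-end embedding pipeline.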
    model.encode(sentences, convert_to_numpy=True)
    end_time = time.perf_counter()
    return (end_time - start_time) * 1000  # Return time in ms


def benchmark_mlx(model, tokenizer, sentences):
    start_time = time.perf_counter()

    # Tokenize sentences using the MLX tokenizer
    tokens = []
    for sentence in sentences:
        token_ids = tokenizer.encode(sentence)
        tokens.append(token_ids)

    # Pad sequences to the same length
    max_len = max(len(t) for t in tokens)
    input_ids = []
    attention_mask = []

    for token_seq in tokens:
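        # The EOS id doubles as the pad id here; that is fine for timing,
        # since the attention mask keeps padding out of the pooling below.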
        # Pad sequence
        padded = token_seq + [tokenizer.eos_token_id] * (max_len - len(token_seq))
        input_ids.append(padded)
        # Create attention mask (1 for real tokens, 0 for padding)
        mask = [1] * len(token_seq) + [0] * (max_len - len(token_seq))
        attention_mask.append(mask)

    # Convert to MLX arrays
    input_ids = mx.array(input_ids)
    attention_mask = mx.array(attention_mask)

    # Get embeddings
    embeddings = model(input_ids)
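    # Note: the mask is not passed to the model itself, so padding tokens do
    # take part in the forward pass; only the pooling below excludes them.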

    # Mean pooling: mask-weighted average over the sequence dimension
    mask = mx.expand_dims(attention_mask, -1)
    sum_embeddings = (embeddings * mask).sum(axis=1)
    sum_mask = mask.sum(axis=1)
    pooled = sum_embeddings / sum_mask

    # MLX is lazy: without an explicit mx.eval the graph above never actually
    # runs, and the measured time would be meaningless.
    mx.eval(pooled)  # Ensure computation is finished
    end_time = time.perf_counter()
    return (end_time - start_time) * 1000  # Return time in ms
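
# Aside: mean pooling keeps the benchmark simple, but the Qwen3-Embedding
# model card describes last-token pooling. An untested sketch of that variant
# (assuming MLX's numpy-style integer-array indexing) would be:
#     last_idx = attention_mask.sum(axis=1) - 1
#     pooled = embeddings[mx.arange(input_ids.shape[0]), last_idx]
# Either way, the measured time is dominated by the forward pass.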


# --- Main Execution ---
def main():
    print("--- Initializing Models ---")

    # Load PyTorch model
    print(f"Loading PyTorch model: {MODEL_NAME_TORCH}")
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    model_torch = SentenceTransformer(MODEL_NAME_TORCH, device=device)
    print(f"PyTorch model loaded on: {device}")

    # Load MLX model
    print(f"Loading MLX model: {MODEL_NAME_MLX}")
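    # mlx_lm.load returns a (model, tokenizer) pair; the tokenizer wraps the
    # underlying Hugging Face tokenizer, so encode() and eos_token_id work.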
    model_mlx, tokenizer_mlx = load(MODEL_NAME_MLX)
    print("MLX model loaded.")

    # --- Warm-up ---
    print("\n--- Performing Warm-up Runs ---")
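    # First calls pay one-time costs (lazy weight loading, kernel compilation,
    # cache allocation), so they are excluded from the measurements.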
    for _ in range(WARMUP_RUNS):
        benchmark_torch(model_torch, DUMMY_SENTENCES[:1])
        benchmark_mlx(model_mlx, tokenizer_mlx, DUMMY_SENTENCES[:1])
    print("Warm-up complete.")

    # --- Benchmarking ---
    print("\n--- Starting Benchmark ---")
    results_torch = []
    results_mlx = []

    for batch_size in BATCH_SIZES:
        print(f"Benchmarking batch size: {batch_size}")
        sentences_batch = DUMMY_SENTENCES[:batch_size]

        # Benchmark PyTorch
        torch_times = [benchmark_torch(model_torch, sentences_batch) for _ in range(NUM_RUNS)]
        results_torch.append(np.mean(torch_times))

        # Benchmark MLX
        mlx_times = [
            benchmark_mlx(model_mlx, tokenizer_mlx, sentences_batch) for _ in range(NUM_RUNS)
        ]
        results_mlx.append(np.mean(mlx_times))
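        # (np.std over torch_times and mlx_times would quantify run-to-run
        # variance, e.g. for error bars on the plot.)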

    print("\n--- Benchmark Results (Average time per batch in ms) ---")
    print(f"Batch Sizes: {BATCH_SIZES}")
    print(f"PyTorch ({device}): {[f'{t:.2f}' for t in results_torch]}")
    print(f"MLX: {[f'{t:.2f}' for t in results_mlx]}")

    # --- Plotting ---
    print("\n--- Generating Plot ---")
    plt.figure(figsize=(10, 6))
    plt.plot(
        BATCH_SIZES,
        results_torch,
        marker="o",
        linestyle="-",
        label=f"PyTorch ({device})",
    )
    plt.plot(BATCH_SIZES, results_mlx, marker="s", linestyle="-", label="MLX")

    plt.title(f"Embedding Performance: MLX vs PyTorch\nModel: {MODEL_NAME_TORCH}")
    plt.xlabel("Batch Size")
    plt.ylabel("Average Time per Batch (ms)")
    plt.xticks(BATCH_SIZES)
    plt.grid(True)
    plt.legend()

    # Save the plot
    output_filename = "embedding_benchmark.png"
    plt.savefig(output_filename)
    print(f"Plot saved to {output_filename}")


if __name__ == "__main__":
    main()