reader: non-destructive portability (relative hints + fallback); fix comments; sky: refine yaml
benchmark_embeddings_simulated.py (new file, 121 lines)
@@ -0,0 +1,121 @@
import time

import matplotlib.pyplot as plt
import mlx.core as mx
import numpy as np
import torch
from sentence_transformers import SentenceTransformer

# --- Configuration ---
MODEL_NAME_TORCH = "Qwen/Qwen3-Embedding-0.6B"
BATCH_SIZES = [1, 8, 16, 32, 64, 128, 256]
NUM_RUNS = 10
WARMUP_RUNS = 2
SEQ_LENGTH = 256
EMBED_DIM = 768  # Embedding dimension of the simulated MLX table

# --- Generate Dummy Data ---
DUMMY_SENTENCES = ["This is a test sentence for benchmarking." * 5] * max(BATCH_SIZES)


# --- PyTorch Benchmark Function ---
def benchmark_torch(model, sentences):
    start_time = time.time()
    model.encode(sentences, convert_to_numpy=True)
    if torch.backends.mps.is_available():
        torch.mps.synchronize()  # Ensure computation is finished on MPS
    end_time = time.time()
    return (end_time - start_time) * 1000  # Return time in ms


# --- Simulated MLX Benchmark Function ---
def benchmark_mlx_simulated(dummy_embedding_table, sentences):
    # 1. Simulate tokenization (only the shapes matter)
    batch_size = len(sentences)
    input_ids = mx.random.randint(0, 30000, (batch_size, SEQ_LENGTH))
    attention_mask = mx.ones((batch_size, SEQ_LENGTH))

    start_time = time.time()
    # 2. Simulate embedding lookup
    embeddings = dummy_embedding_table[input_ids]

    # 3. Simulate mean pooling
    mask = mx.expand_dims(attention_mask, -1)
    sum_embeddings = (embeddings * mask).sum(axis=1)
    sum_mask = mask.sum(axis=1)
    pooled = sum_embeddings / sum_mask

    mx.eval(pooled)  # Force evaluation; MLX is lazy, so the timing is meaningless without this
    end_time = time.time()
    return (end_time - start_time) * 1000  # Return time in ms


# --- Main Execution ---
def main():
    print("--- Initializing Models ---")
    # Load real PyTorch model
    print(f"Loading PyTorch model: {MODEL_NAME_TORCH}")
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    if device == "cpu":
        print("Warning: MPS not available for PyTorch. Benchmark will run on CPU.")
    model_torch = SentenceTransformer(MODEL_NAME_TORCH, device=device)
    print(f"PyTorch model loaded on: {device}")

    # Create dummy MLX embedding table
    print("Creating simulated MLX model...")
    dummy_vocab_size = 30522  # Typical BERT vocab size
    dummy_embedding_table_mlx = mx.random.normal((dummy_vocab_size, EMBED_DIM))
    mx.eval(dummy_embedding_table_mlx)  # Materialize the table before timing
    print("Simulated MLX model created.")

    # --- Warm-up ---
    print("\n--- Performing Warm-up Runs ---")
    for _ in range(WARMUP_RUNS):
        benchmark_torch(model_torch, DUMMY_SENTENCES[:1])
        benchmark_mlx_simulated(dummy_embedding_table_mlx, DUMMY_SENTENCES[:1])
    print("Warm-up complete.")

    # --- Benchmarking ---
    print("\n--- Starting Benchmark ---")
    results_torch = []
    results_mlx = []

    for batch_size in BATCH_SIZES:
        print(f"Benchmarking batch size: {batch_size}")
        sentences_batch = DUMMY_SENTENCES[:batch_size]

        # Benchmark PyTorch
        torch_times = [benchmark_torch(model_torch, sentences_batch) for _ in range(NUM_RUNS)]
        results_torch.append(np.mean(torch_times))

        # Benchmark MLX
        mlx_times = [
            benchmark_mlx_simulated(dummy_embedding_table_mlx, sentences_batch)
            for _ in range(NUM_RUNS)
        ]
        results_mlx.append(np.mean(mlx_times))

    print("\n--- Benchmark Results (Average time per batch in ms) ---")
    print(f"Batch Sizes: {BATCH_SIZES}")
    print(f"PyTorch ({device}): {[f'{t:.2f}' for t in results_torch]}")
    print(f"MLX (simulated): {[f'{t:.2f}' for t in results_mlx]}")

    # --- Plotting ---
    print("\n--- Generating Plot ---")
    plt.figure(figsize=(10, 6))
    plt.plot(BATCH_SIZES, results_torch, marker="o", linestyle="-", label=f"PyTorch ({device})")
    plt.plot(BATCH_SIZES, results_mlx, marker="s", linestyle="-", label="MLX (Simulated)")

    plt.title("Simulated Embedding Performance: MLX vs PyTorch")
    plt.xlabel("Batch Size")
    plt.ylabel("Average Time per Batch (ms)")
    plt.xticks(BATCH_SIZES)
    plt.grid(True)
    plt.legend()

    output_filename = "embedding_benchmark_simulated.png"
    plt.savefig(output_filename)
    print(f"Plot saved to {output_filename}")


if __name__ == "__main__":
    main()
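Why benchmark_mlx_simulated passes the pooled result to mx.eval: MLX builds computation graphs lazily and only executes them when a result is evaluated, so a timer that never forces evaluation measures graph construction rather than compute (mx.eval() with no arguments evaluates nothing). A minimal sketch of the difference; the shapes are arbitrary:

import time
import mlx.core as mx

x = mx.random.normal((256, 768))

t0 = time.time()
y = (x * 2).sum(axis=1)  # only records the computation graph
lazy_ms = (time.time() - t0) * 1000

t0 = time.time()
mx.eval(y)               # forces the actual computation
eval_ms = (time.time() - t0) * 1000

print(f"graph construction: {lazy_ms:.3f} ms, evaluation: {eval_ms:.3f} ms")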
@@ -82,17 +82,35 @@ def create_hnsw_embedding_server(
     with open(passages_file) as f:
         meta = json.load(f)

-    # Convert relative paths to absolute paths based on metadata file location
-    metadata_dir = Path(passages_file).parent.parent  # Go up one level from the metadata file
+    # Resolve passage files for cross-machine portability
+    metadata_dir = Path(passages_file).parent  # Same directory as meta.json
     passage_sources = []
     for source in meta["passage_sources"]:
-        source_copy = source.copy()
-        # Convert relative paths to absolute paths
-        if not Path(source_copy["path"]).is_absolute():
-            source_copy["path"] = str(metadata_dir / source_copy["path"])
-        if not Path(source_copy["index_path"]).is_absolute():
-            source_copy["index_path"] = str(metadata_dir / source_copy["index_path"])
-        passage_sources.append(source_copy)
+        src = dict(source)
+        # Absolute candidates from meta
+        cand_path = Path(src.get("path", ""))
+        cand_idx = Path(src.get("index_path", ""))
+        # Relative hints if provided
+        rel_path = src.get("path_relative")
+        rel_idx = src.get("index_path_relative")
+        # Defaults (siblings of meta)
+        default_path = metadata_dir / "documents.leann.passages.jsonl"
+        default_idx = metadata_dir / "documents.leann.passages.idx"
+
+        # Normalize path
+        if not cand_path.exists():
+            if rel_path and (metadata_dir / rel_path).exists():
+                src["path"] = str(metadata_dir / rel_path)
+            elif default_path.exists():
+                src["path"] = str(default_path)
+        # Normalize index_path
+        if not cand_idx.exists():
+            if rel_idx and (metadata_dir / rel_idx).exists():
+                src["index_path"] = str(metadata_dir / rel_idx)
+            elif default_idx.exists():
+                src["index_path"] = str(default_idx)
+
+        passage_sources.append(src)

     passages = PassageManager(passage_sources)
     logger.info(
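The resolution order above, distilled: prefer the absolute path recorded in meta if it exists locally, then the relative hint joined to the meta directory, then the hardcoded sibling default. A standalone sketch of the same logic (the function name is illustrative, not part of the codebase):

from pathlib import Path

def resolve_source_file(meta_dir: Path, absolute: str, relative_hint: str | None, default_name: str) -> str:
    # 1. Absolute path from meta (works when the index was built on this machine)
    if Path(absolute).exists():
        return absolute
    # 2. Relative hint next to meta.json (survives a move to another machine)
    if relative_hint and (meta_dir / relative_hint).exists():
        return str(meta_dir / relative_hint)
    # 3. Sibling default, e.g. documents.leann.passages.jsonl
    if (meta_dir / default_name).exists():
        return str(meta_dir / default_name)
    # Fall through: keep the original value and let downstream error handling report it
    return absolute

Note that this hunk hardcodes documents.leann.* as the final fallback, whereas the searcher hunk further down derives the default names from the meta filename.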
@@ -328,6 +328,9 @@ class LeannBuilder:
                 "type": "jsonl",
                 "path": str(passages_file),
                 "index_path": str(offset_file),
+                # Relative hints for cross-machine portability (non-breaking addition)
+                "path_relative": f"{index_name}.passages.jsonl",
+                "index_path_relative": f"{index_name}.passages.idx",
             }
         ],
     }
@@ -444,6 +447,9 @@ class LeannBuilder:
                 "type": "jsonl",
                 "path": str(passages_file),
                 "index_path": str(offset_file),
+                # Relative hints for cross-machine portability (non-breaking addition)
+                "path_relative": f"{index_name}.passages.jsonl",
+                "index_path_relative": f"{index_name}.passages.idx",
             }
         ],
         "built_from_precomputed_embeddings": True,
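For illustration, a passage_sources entry in meta.json written by the builder would now carry both forms; the absolute paths below are hypothetical:

{
    "type": "jsonl",
    "path": "/Users/alice/.leann/indexes/my-index/my-index.passages.jsonl",
    "index_path": "/Users/alice/.leann/indexes/my-index/my-index.passages.idx",
    "path_relative": "my-index.passages.jsonl",
    "index_path_relative": "my-index.passages.idx"
}

Readers prefer the recorded absolute path when it exists and otherwise fall back to the relative hint or the sibling default, so meta files written before this change keep working.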
@@ -485,6 +491,42 @@ class LeannSearcher:
         self.embedding_model = self.meta_data["embedding_model"]
         # Support both old and new format
         self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers")
+        # Best-effort portability: if meta contains absolute paths from another machine,
+        # and those paths do not exist locally, try relative hints or fallback sibling filenames.
+        try:
+            idx_path_obj = Path(self.meta_path_str).with_suffix("").with_suffix("")
+            index_dir = idx_path_obj.parent
+            index_name = idx_path_obj.name
+            default_passages = index_dir / f"{index_name}.passages.jsonl"
+            default_offsets = index_dir / f"{index_name}.passages.idx"
+
+            sources = self.meta_data.get("passage_sources", [])
+            normalized_sources: list[dict[str, Any]] = []
+            for src in sources:
+                new_src = dict(src)
+                raw_path = Path(new_src.get("path", ""))
+                raw_idx = Path(new_src.get("index_path", ""))
+                rel_path = new_src.get("path_relative")
+                rel_idx = new_src.get("index_path_relative")
+
+                # Normalize path
+                if not raw_path.exists():
+                    cand = index_dir / rel_path if rel_path else default_passages
+                    if cand.exists():
+                        new_src["path"] = str(cand)
+                # Normalize idx
+                if not raw_idx.exists():
+                    cand = index_dir / rel_idx if rel_idx else default_offsets
+                    if cand.exists():
+                        new_src["index_path"] = str(cand)
+
+                normalized_sources.append(new_src)
+
+            # Only override in-memory view; do not rewrite meta file (non-destructive)
+            self.meta_data["passage_sources"] = normalized_sources
+        except Exception:
+            pass
+
         self.passage_manager = PassageManager(self.meta_data.get("passage_sources", []))
         backend_factory = BACKEND_REGISTRY.get(backend_name)
         if backend_factory is None:
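One subtlety above: calling with_suffix("") twice strips two extensions from the meta filename to recover the index stem, from which the default sibling filenames are derived. Assuming a meta file named documents.leann.meta.json (illustrative path):

from pathlib import Path

meta_path = Path("/indexes/documents.leann.meta.json")  # hypothetical location
stem = meta_path.with_suffix("").with_suffix("")         # strips ".json", then ".meta"
print(stem.name)                                         # -> "documents.leann"
print(stem.parent / f"{stem.name}.passages.jsonl")       # -> /indexes/documents.leann.passages.jsonl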
@@ -7,7 +7,7 @@ resources:
   # cloud: aws
   disk_size: 100

-env:
+envs:
   # Build parameters (override with: sky launch -c leann-gpu sky/leann-build.yaml -e key=value)
   index_name: my-index
   docs: ./data
@@ -23,6 +23,8 @@ env:
   compact: true  # for HNSW only
   # Optional pass-through
   extra_args: ""
+  # Rebuild control
+  force: true

 # Sync local paths to the remote VM. Adjust as needed.
 file_mounts:
@@ -35,8 +37,17 @@ setup: |
   curl -LsSf https://astral.sh/uv/install.sh | sh
   export PATH="$HOME/.local/bin:$PATH"

-  # Install the LEANN CLI globally on the remote machine
-  uv tool install leann
+  # Ensure modern libstdc++ for FAISS (GLIBCXX >= 3.4.30)
+  sudo apt-get update -y
+  sudo apt-get install -y libstdc++6 libgomp1
+  # Also upgrade conda's libstdc++ in base env (SkyPilot images include conda)
+  if command -v conda >/dev/null 2>&1; then
+    conda install -y -n base -c conda-forge libstdcxx-ng
+  fi
+
+  # Install LEANN CLI and backends into the user environment
+  uv pip install --upgrade pip
+  uv pip install leann-core leann-backend-hnsw leann-backend-diskann

 run: |
   export PATH="$HOME/.local/bin:$PATH"
@@ -45,9 +56,13 @@ run: |
   if [ "${recompute}" = "false" ] || [ "${recompute}" = "0" ]; then
     recompute_flag="--no-recompute"
   fi
+  force_flag=""
+  if [ "${force}" = "true" ] || [ "${force}" = "1" ]; then
+    force_flag="--force"
+  fi

   # Build command
-  leann build ${index_name} \
+  python -m leann.cli build ${index_name} \
     --docs ~/leann-data \
     --backend ${backend} \
     --complexity ${complexity} \
@@ -55,4 +70,7 @@ run: |
     --num-threads ${num_threads} \
     --embedding-mode ${embedding_mode} \
     --embedding-model ${embedding_model} \
-    ${recompute_flag} ${extra_args}
+    ${recompute_flag} ${force_flag} ${extra_args}
+
+  # Print where the index is stored for downstream rsync
+  echo "INDEX_OUT_DIR=~/.leann/indexes/${index_name}"
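With these changes a rebuild can be forced from the launch command, e.g. sky launch -c leann-gpu sky/leann-build.yaml -e force=true -e index_name=my-index (override syntax as documented in the YAML comment above), and the final echo line surfaces the remote index directory so a follow-up rsync can pull the built index back.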