reader: non-destructive portability (relative hints + fallback); fix comments; sky: refine yaml

This commit is contained in:
Andy Lee
2025-08-14 01:05:01 -07:00
parent 3f81861cba
commit 0361725323
4 changed files with 213 additions and 14 deletions

View File

@@ -0,0 +1,121 @@
import time
import matplotlib.pyplot as plt
import mlx.core as mx
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
# --- Configuration ---
# Real PyTorch model benchmarked via sentence-transformers.
MODEL_NAME_TORCH = "Qwen/Qwen3-Embedding-0.6B"
# Batch sizes swept during the benchmark; also bounds the dummy-sentence pool.
BATCH_SIZES = [1, 8, 16, 32, 64, 128, 256]
NUM_RUNS = 10  # timed runs averaged per batch size
WARMUP_RUNS = 2  # untimed runs to amortize one-off setup costs
SEQ_LENGTH = 256  # simulated token count per sentence (MLX path only)
# NOTE(review): previous comment said "all-mpnet-base-v2" but the torch model
# is Qwen3-Embedding-0.6B. This dim only shapes the *simulated* MLX table, so
# it does not have to match the real model — confirm if results are compared.
EMBED_DIM = 768
# --- Generate Dummy Data ---
# One pool large enough for the biggest batch; each run slices a prefix.
DUMMY_SENTENCES = ["This is a test sentence for benchmarking." * 5] * max(BATCH_SIZES)
# --- PyTorch Benchmark Function ---
def benchmark_torch(model, sentences):
start_time = time.time()
model.encode(sentences, convert_to_numpy=True)
torch.mps.synchronize() # Ensure computation is finished on MPS
end_time = time.time()
return (end_time - start_time) * 1000 # Return time in ms
# --- Simulated MLX Benchmark Function ---
def benchmark_mlx_simulated(dummy_embedding_table, sentences):
# 1. Simulate tokenization (result is just shape)
batch_size = len(sentences)
input_ids = mx.random.randint(0, 30000, (batch_size, SEQ_LENGTH))
attention_mask = mx.ones((batch_size, SEQ_LENGTH))
start_time = time.time()
# 2. Simulate embedding lookup
embeddings = dummy_embedding_table[input_ids]
# 3. Simulate mean pooling
mask = mx.expand_dims(attention_mask, -1)
sum_embeddings = (embeddings * mask).sum(axis=1)
sum_mask = mask.sum(axis=1)
_ = sum_embeddings / sum_mask
mx.eval() # Ensure all MLX computations are finished
end_time = time.time()
return (end_time - start_time) * 1000 # Return time in ms
# --- Main Execution ---
def main():
print("--- Initializing Models ---")
# Load real PyTorch model
print(f"Loading PyTorch model: {MODEL_NAME_TORCH}")
device = "mps" if torch.backends.mps.is_available() else "cpu"
if device == "cpu":
print("Warning: MPS not available for PyTorch. Benchmark will run on CPU.")
model_torch = SentenceTransformer(MODEL_NAME_TORCH, device=device)
print(f"PyTorch model loaded on: {device}")
# Create dummy MLX embedding table
print("Creating simulated MLX model...")
dummy_vocab_size = 30522 # Typical BERT vocab size
dummy_embedding_table_mlx = mx.random.normal((dummy_vocab_size, EMBED_DIM))
mx.eval() # Ensure table is created
print("Simulated MLX model created.")
# --- Warm-up ---
print("\n--- Performing Warm-up Runs ---")
for _ in range(WARMUP_RUNS):
benchmark_torch(model_torch, DUMMY_SENTENCES[:1])
benchmark_mlx_simulated(dummy_embedding_table_mlx, DUMMY_SENTENCES[:1])
print("Warm-up complete.")
# --- Benchmarking ---
print("\n--- Starting Benchmark ---")
results_torch = []
results_mlx = []
for batch_size in BATCH_SIZES:
print(f"Benchmarking batch size: {batch_size}")
sentences_batch = DUMMY_SENTENCES[:batch_size]
# Benchmark PyTorch
torch_times = [benchmark_torch(model_torch, sentences_batch) for _ in range(NUM_RUNS)]
results_torch.append(np.mean(torch_times))
# Benchmark MLX
mlx_times = [
benchmark_mlx_simulated(dummy_embedding_table_mlx, sentences_batch)
for _ in range(NUM_RUNS)
]
results_mlx.append(np.mean(mlx_times))
print("\n--- Benchmark Results (Average time per batch in ms) ---")
print(f"Batch Sizes: {BATCH_SIZES}")
print(f"PyTorch (mps): {[f'{t:.2f}' for t in results_torch]}")
print(f"MLX (simulated): {[f'{t:.2f}' for t in results_mlx]}")
# --- Plotting ---
print("\n--- Generating Plot ---")
plt.figure(figsize=(10, 6))
plt.plot(BATCH_SIZES, results_torch, marker="o", linestyle="-", label=f"PyTorch ({device})")
plt.plot(BATCH_SIZES, results_mlx, marker="s", linestyle="-", label="MLX (Simulated)")
plt.title("Simulated Embedding Performance: MLX vs PyTorch")
plt.xlabel("Batch Size")
plt.ylabel("Average Time per Batch (ms)")
plt.xticks(BATCH_SIZES)
plt.grid(True)
plt.legend()
output_filename = "embedding_benchmark_simulated.png"
plt.savefig(output_filename)
print(f"Plot saved to {output_filename}")
if __name__ == "__main__":
main()

View File

@@ -82,17 +82,35 @@ def create_hnsw_embedding_server(
with open(passages_file) as f:
meta = json.load(f)
# Convert relative paths to absolute paths based on metadata file location
metadata_dir = Path(passages_file).parent.parent # Go up one level from the metadata file
# Resolve passage files for cross-machine portability
metadata_dir = Path(passages_file).parent # Same directory as meta.json
passage_sources = []
for source in meta["passage_sources"]:
source_copy = source.copy()
# Convert relative paths to absolute paths
if not Path(source_copy["path"]).is_absolute():
source_copy["path"] = str(metadata_dir / source_copy["path"])
if not Path(source_copy["index_path"]).is_absolute():
source_copy["index_path"] = str(metadata_dir / source_copy["index_path"])
passage_sources.append(source_copy)
src = dict(source)
# Absolute candidates from meta
cand_path = Path(src.get("path", ""))
cand_idx = Path(src.get("index_path", ""))
# Relative hints if provided
rel_path = src.get("path_relative")
rel_idx = src.get("index_path_relative")
# Defaults (siblings of meta)
default_path = metadata_dir / "documents.leann.passages.jsonl"
default_idx = metadata_dir / "documents.leann.passages.idx"
# Normalize path
if not cand_path.exists():
if rel_path and (metadata_dir / rel_path).exists():
src["path"] = str(metadata_dir / rel_path)
elif default_path.exists():
src["path"] = str(default_path)
# Normalize index_path
if not cand_idx.exists():
if rel_idx and (metadata_dir / rel_idx).exists():
src["index_path"] = str(metadata_dir / rel_idx)
elif default_idx.exists():
src["index_path"] = str(default_idx)
passage_sources.append(src)
passages = PassageManager(passage_sources)
logger.info(

View File

@@ -328,6 +328,9 @@ class LeannBuilder:
"type": "jsonl",
"path": str(passages_file),
"index_path": str(offset_file),
# Relative hints for cross-machine portability (non-breaking addition)
"path_relative": f"{index_name}.passages.jsonl",
"index_path_relative": f"{index_name}.passages.idx",
}
],
}
@@ -444,6 +447,9 @@ class LeannBuilder:
"type": "jsonl",
"path": str(passages_file),
"index_path": str(offset_file),
# Relative hints for cross-machine portability (non-breaking addition)
"path_relative": f"{index_name}.passages.jsonl",
"index_path_relative": f"{index_name}.passages.idx",
}
],
"built_from_precomputed_embeddings": True,
@@ -485,6 +491,42 @@ class LeannSearcher:
self.embedding_model = self.meta_data["embedding_model"]
# Support both old and new format
self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers")
# Best-effort portability: if meta contains absolute paths from another machine,
# and those paths do not exist locally, try relative hints or fallback sibling filenames.
try:
idx_path_obj = Path(self.meta_path_str).with_suffix("").with_suffix("")
index_dir = idx_path_obj.parent
index_name = idx_path_obj.name
default_passages = index_dir / f"{index_name}.passages.jsonl"
default_offsets = index_dir / f"{index_name}.passages.idx"
sources = self.meta_data.get("passage_sources", [])
normalized_sources: list[dict[str, Any]] = []
for src in sources:
new_src = dict(src)
raw_path = Path(new_src.get("path", ""))
raw_idx = Path(new_src.get("index_path", ""))
rel_path = new_src.get("path_relative")
rel_idx = new_src.get("index_path_relative")
# Normalize path
if not raw_path.exists():
cand = index_dir / rel_path if rel_path else default_passages
if cand.exists():
new_src["path"] = str(cand)
# Normalize idx
if not raw_idx.exists():
cand = index_dir / rel_idx if rel_idx else default_offsets
if cand.exists():
new_src["index_path"] = str(cand)
normalized_sources.append(new_src)
# Only override in-memory view; do not rewrite meta file (non-destructive)
self.meta_data["passage_sources"] = normalized_sources
except Exception:
pass
self.passage_manager = PassageManager(self.meta_data.get("passage_sources", []))
backend_factory = BACKEND_REGISTRY.get(backend_name)
if backend_factory is None:

View File

@@ -7,7 +7,7 @@ resources:
# cloud: aws
disk_size: 100
env:
envs:
# Build parameters (override with: sky launch -c leann-gpu sky/leann-build.yaml -e key=value)
index_name: my-index
docs: ./data
@@ -23,6 +23,8 @@ env:
compact: true # for HNSW only
# Optional pass-through
extra_args: ""
# Rebuild control
force: true
# Sync local paths to the remote VM. Adjust as needed.
file_mounts:
@@ -35,8 +37,17 @@ setup: |
curl -LsSf https://astral.sh/uv/install.sh | sh
export PATH="$HOME/.local/bin:$PATH"
# Install the LEANN CLI globally on the remote machine
uv tool install leann
# Ensure modern libstdc++ for FAISS (GLIBCXX >= 3.4.30)
sudo apt-get update -y
sudo apt-get install -y libstdc++6 libgomp1
# Also upgrade conda's libstdc++ in base env (Skypilot images include conda)
if command -v conda >/dev/null 2>&1; then
conda install -y -n base -c conda-forge libstdcxx-ng
fi
# Install LEANN CLI and backends into the user environment
uv pip install --upgrade pip
uv pip install leann-core leann-backend-hnsw leann-backend-diskann
run: |
export PATH="$HOME/.local/bin:$PATH"
@@ -45,9 +56,13 @@ run: |
if [ "${recompute}" = "false" ] || [ "${recompute}" = "0" ]; then
recompute_flag="--no-recompute"
fi
force_flag=""
if [ "${force}" = "true" ] || [ "${force}" = "1" ]; then
force_flag="--force"
fi
# Build command
leann build ${index_name} \
python -m leann.cli build ${index_name} \
--docs ~/leann-data \
--backend ${backend} \
--complexity ${complexity} \
@@ -55,4 +70,7 @@ run: |
--num-threads ${num_threads} \
--embedding-mode ${embedding_mode} \
--embedding-model ${embedding_model} \
${recompute_flag} ${extra_args}
${recompute_flag} ${force_flag} ${extra_args}
# Print where the index is stored for downstream rsync
echo "INDEX_OUT_DIR=~/.leann/indexes/${index_name}"