From 03617253230eb9d154a7565fce1b008f139f0506 Mon Sep 17 00:00:00 2001
From: Andy Lee
Date: Thu, 14 Aug 2025 01:05:01 -0700
Subject: [PATCH] reader: non-destructive portability (relative hints + fallback); fix comments; sky: refine yaml

---
Reviewer notes (text between "---" and the first "diff --git" is not part of the commit):
- benchmark: pass arrays to mx.eval() (a bare call evaluates nothing), materialize the
  random inputs before the timer starts, only call torch.mps.synchronize() when MPS is
  available (main() supports a CPU fallback), and print the real device in the results.
- reader (hnsw server + LeannSearcher): test the stored paths with Path.is_file() instead
  of Path.exists(); a missing "path"/"index_path" key becomes Path("") == ".", which
  "exists" and used to suppress the relative-hint/sibling fallback.
- All fixes are line-count neutral, so hunk headers and the diffstat are unchanged.
- NOTE(review): this copy was rebuilt from a whitespace-collapsed patch. Indentation of
  context/removed lines is inferred; check it against the tree before applying.

 benchmark_embeddings_simulated.py             | 121 ++++++++++++++++++
 .../hnsw_embedding_server.py                  |  36 ++++--
 packages/leann-core/src/leann/api.py          |  42 ++++++
 sky/leann-build.yaml                          |  28 +++-
 4 files changed, 213 insertions(+), 14 deletions(-)
 create mode 100644 benchmark_embeddings_simulated.py

diff --git a/benchmark_embeddings_simulated.py b/benchmark_embeddings_simulated.py
new file mode 100644
index 0000000..f951388
--- /dev/null
+++ b/benchmark_embeddings_simulated.py
@@ -0,0 +1,121 @@
+import time
+
+import matplotlib.pyplot as plt
+import mlx.core as mx
+import numpy as np
+import torch
+from sentence_transformers import SentenceTransformer
+
+# --- Configuration ---
+MODEL_NAME_TORCH = "Qwen/Qwen3-Embedding-0.6B"
+BATCH_SIZES = [1, 8, 16, 32, 64, 128, 256]
+NUM_RUNS = 10
+WARMUP_RUNS = 2
+SEQ_LENGTH = 256
+EMBED_DIM = 768  # Simulated MLX table width (all-mpnet-base-v2 size); not read from the model
+
+# --- Generate Dummy Data ---
+DUMMY_SENTENCES = ["This is a test sentence for benchmarking." * 5] * max(BATCH_SIZES)
+
+
+# --- PyTorch Benchmark Function: wall-clock ms for one encode() call ---
+def benchmark_torch(model, sentences):
+    start_time = time.time()
+    model.encode(sentences, convert_to_numpy=True)
+    if torch.backends.mps.is_available():  # main() may have fallen back to CPU
+        torch.mps.synchronize()  # Ensure computation is finished on MPS
+    return (time.time() - start_time) * 1000  # Return time in ms
+
+
+# --- Simulated MLX Benchmark Function: wall-clock ms for lookup + mean pooling ---
+def benchmark_mlx_simulated(dummy_embedding_table, sentences):
+    # 1. Simulate tokenization (result is just shape)
+    batch_size = len(sentences)
+    input_ids = mx.random.randint(0, 30000, (batch_size, SEQ_LENGTH))
+    attention_mask = mx.ones((batch_size, SEQ_LENGTH))
+    mx.eval(input_ids, attention_mask)  # MLX is lazy: build inputs before timing
+
+    start_time = time.time()
+    # 2. Simulate embedding lookup
+    embeddings = dummy_embedding_table[input_ids]
+
+    # 3. Simulate mean pooling
+    mask = mx.expand_dims(attention_mask, -1)
+    sum_embeddings = (embeddings * mask).sum(axis=1)
+    sum_mask = mask.sum(axis=1)
+    pooled = sum_embeddings / sum_mask
+
+    mx.eval(pooled)  # Force evaluation; a bare mx.eval() evaluates nothing
+    return (time.time() - start_time) * 1000  # Return time in ms
+
+
+# --- Main Execution ---
+def main():
+    print("--- Initializing Models ---")
+    # Load real PyTorch model
+    print(f"Loading PyTorch model: {MODEL_NAME_TORCH}")
+    device = "mps" if torch.backends.mps.is_available() else "cpu"
+    if device == "cpu":
+        print("Warning: MPS not available for PyTorch. Benchmark will run on CPU.")
+    model_torch = SentenceTransformer(MODEL_NAME_TORCH, device=device)
+    print(f"PyTorch model loaded on: {device}")
+
+    # Create dummy MLX embedding table
+    print("Creating simulated MLX model...")
+    dummy_vocab_size = 30522  # Typical BERT vocab size
+    dummy_embedding_table_mlx = mx.random.normal((dummy_vocab_size, EMBED_DIM))
+    mx.eval(dummy_embedding_table_mlx)  # Ensure table is created
+    print("Simulated MLX model created.")
+
+    # --- Warm-up ---
+    print("\n--- Performing Warm-up Runs ---")
+    for _ in range(WARMUP_RUNS):
+        benchmark_torch(model_torch, DUMMY_SENTENCES[:1])
+        benchmark_mlx_simulated(dummy_embedding_table_mlx, DUMMY_SENTENCES[:1])
+    print("Warm-up complete.")
+
+    # --- Benchmarking ---
+    print("\n--- Starting Benchmark ---")
+    results_torch = []
+    results_mlx = []
+
+    for batch_size in BATCH_SIZES:
+        print(f"Benchmarking batch size: {batch_size}")
+        sentences_batch = DUMMY_SENTENCES[:batch_size]
+
+        # Benchmark PyTorch
+        torch_times = [benchmark_torch(model_torch, sentences_batch) for _ in range(NUM_RUNS)]
+        results_torch.append(np.mean(torch_times))
+
+        # Benchmark MLX
+        mlx_times = [
+            benchmark_mlx_simulated(dummy_embedding_table_mlx, sentences_batch)
+            for _ in range(NUM_RUNS)
+        ]
+        results_mlx.append(np.mean(mlx_times))
+
+    print("\n--- Benchmark Results (Average time per batch in ms) ---")
+    print(f"Batch Sizes: {BATCH_SIZES}")
+    print(f"PyTorch ({device}): {[f'{t:.2f}' for t in results_torch]}")
+    print(f"MLX (simulated): {[f'{t:.2f}' for t in results_mlx]}")
+
+    # --- Plotting ---
+    print("\n--- Generating Plot ---")
+    plt.figure(figsize=(10, 6))
+    plt.plot(BATCH_SIZES, results_torch, marker="o", linestyle="-", label=f"PyTorch ({device})")
+    plt.plot(BATCH_SIZES, results_mlx, marker="s", linestyle="-", label="MLX (Simulated)")
+
+    plt.title("Simulated Embedding Performance: MLX vs PyTorch")
+    plt.xlabel("Batch Size")
+    plt.ylabel("Average Time per Batch (ms)")
+    plt.xticks(BATCH_SIZES)
+    plt.grid(True)
+    plt.legend()
+
+    output_filename = "embedding_benchmark_simulated.png"
+    plt.savefig(output_filename)
+    print(f"Plot saved to {output_filename}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py
index 013ae5a..77ca57d 100644
--- a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py
+++ b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py
@@ -82,17 +82,35 @@ def create_hnsw_embedding_server(
     with open(passages_file) as f:
         meta = json.load(f)

-    # Convert relative paths to absolute paths based on metadata file location
-    metadata_dir = Path(passages_file).parent.parent  # Go up one level from the metadata file
+    # Resolve passage files for cross-machine portability
+    metadata_dir = Path(passages_file).parent  # Same directory as meta.json
     passage_sources = []
     for source in meta["passage_sources"]:
-        source_copy = source.copy()
-        # Convert relative paths to absolute paths
-        if not Path(source_copy["path"]).is_absolute():
-            source_copy["path"] = str(metadata_dir / source_copy["path"])
-        if not Path(source_copy["index_path"]).is_absolute():
-            source_copy["index_path"] = str(metadata_dir / source_copy["index_path"])
-        passage_sources.append(source_copy)
+        src = dict(source)
+        # Absolute candidates from meta
+        cand_path = Path(src.get("path", ""))
+        cand_idx = Path(src.get("index_path", ""))
+        # Relative hints if provided
+        rel_path = src.get("path_relative")
+        rel_idx = src.get("index_path_relative")
+        # Defaults (siblings of meta)
+        default_path = metadata_dir / "documents.leann.passages.jsonl"
+        default_idx = metadata_dir / "documents.leann.passages.idx"
+
+        # Normalize path (is_file, not exists: a missing key gives Path(""), i.e. ".")
+        if not cand_path.is_file():
+            if rel_path and (metadata_dir / rel_path).exists():
+                src["path"] = str(metadata_dir / rel_path)
+            elif default_path.exists():
+                src["path"] = str(default_path)
+        # Normalize index_path (same is_file reasoning as above)
+        if not cand_idx.is_file():
+            if rel_idx and (metadata_dir / rel_idx).exists():
+                src["index_path"] = str(metadata_dir / rel_idx)
+            elif default_idx.exists():
+                src["index_path"] = str(default_idx)
+
+        passage_sources.append(src)

     passages = PassageManager(passage_sources)
     logger.info(
diff --git a/packages/leann-core/src/leann/api.py b/packages/leann-core/src/leann/api.py
index bc060b0..6a9b504 100644
--- a/packages/leann-core/src/leann/api.py
+++ b/packages/leann-core/src/leann/api.py
@@ -328,6 +328,9 @@ class LeannBuilder:
                     "type": "jsonl",
                     "path": str(passages_file),
                     "index_path": str(offset_file),
+                    # Relative hints for cross-machine portability (non-breaking addition)
+                    "path_relative": f"{index_name}.passages.jsonl",
+                    "index_path_relative": f"{index_name}.passages.idx",
                 }
             ],
         }
@@ -444,6 +447,9 @@ class LeannBuilder:
                     "type": "jsonl",
                     "path": str(passages_file),
                     "index_path": str(offset_file),
+                    # Relative hints for cross-machine portability (non-breaking addition)
+                    "path_relative": f"{index_name}.passages.jsonl",
+                    "index_path_relative": f"{index_name}.passages.idx",
                 }
             ],
             "built_from_precomputed_embeddings": True,
@@ -485,6 +491,42 @@ class LeannSearcher:
         self.embedding_model = self.meta_data["embedding_model"]
         # Support both old and new format
         self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers")
+        # Best-effort portability: if meta contains absolute paths from another machine,
+        # and those paths do not exist locally, try relative hints or fallback sibling filenames.
+        try:
+            idx_path_obj = Path(self.meta_path_str).with_suffix("").with_suffix("")
+            index_dir = idx_path_obj.parent
+            index_name = idx_path_obj.name
+            default_passages = index_dir / f"{index_name}.passages.jsonl"
+            default_offsets = index_dir / f"{index_name}.passages.idx"
+
+            sources = self.meta_data.get("passage_sources", [])
+            normalized_sources: list[dict[str, Any]] = []
+            for src in sources:
+                new_src = dict(src)
+                raw_path = Path(new_src.get("path", ""))
+                raw_idx = Path(new_src.get("index_path", ""))
+                rel_path = new_src.get("path_relative")
+                rel_idx = new_src.get("index_path_relative")
+
+                # Normalize path (is_file, not exists: a missing key gives Path(""), i.e. ".")
+                if not raw_path.is_file():
+                    cand = index_dir / rel_path if rel_path else default_passages
+                    if cand.exists():
+                        new_src["path"] = str(cand)
+                # Normalize idx (same is_file reasoning as above)
+                if not raw_idx.is_file():
+                    cand = index_dir / rel_idx if rel_idx else default_offsets
+                    if cand.exists():
+                        new_src["index_path"] = str(cand)
+
+                normalized_sources.append(new_src)
+
+            # Only override in-memory view; do not rewrite meta file (non-destructive)
+            self.meta_data["passage_sources"] = normalized_sources
+        except Exception:
+            pass
+
         self.passage_manager = PassageManager(self.meta_data.get("passage_sources", []))
         backend_factory = BACKEND_REGISTRY.get(backend_name)
         if backend_factory is None:
diff --git a/sky/leann-build.yaml b/sky/leann-build.yaml
index f5db04b..53fd909 100644
--- a/sky/leann-build.yaml
+++ b/sky/leann-build.yaml
@@ -7,7 +7,7 @@ resources:
   # cloud: aws
   disk_size: 100

-env:
+envs:
   # Build parameters (override with: sky launch -c leann-gpu sky/leann-build.yaml -e key=value)
   index_name: my-index
   docs: ./data
@@ -23,6 +23,8 @@ env:
   compact: true  # for HNSW only
   # Optional pass-through
   extra_args: ""
+  # Rebuild control
+  force: true

 # Sync local paths to the remote VM. Adjust as needed.
 file_mounts:
@@ -35,8 +37,17 @@ setup: |
   curl -LsSf https://astral.sh/uv/install.sh | sh
   export PATH="$HOME/.local/bin:$PATH"

-  # Install the LEANN CLI globally on the remote machine
-  uv tool install leann
+  # Ensure modern libstdc++ for FAISS (GLIBCXX >= 3.4.30)
+  sudo apt-get update -y
+  sudo apt-get install -y libstdc++6 libgomp1
+  # Also upgrade conda's libstdc++ in base env (Skypilot images include conda)
+  if command -v conda >/dev/null 2>&1; then
+    conda install -y -n base -c conda-forge libstdcxx-ng
+  fi
+
+  # Install LEANN CLI and backends into the user environment
+  uv pip install --upgrade pip
+  uv pip install leann-core leann-backend-hnsw leann-backend-diskann

 run: |
   export PATH="$HOME/.local/bin:$PATH"
@@ -45,9 +56,13 @@ run: |
   if [ "${recompute}" = "false" ] || [ "${recompute}" = "0" ]; then
     recompute_flag="--no-recompute"
   fi
+  force_flag=""
+  if [ "${force}" = "true" ] || [ "${force}" = "1" ]; then
+    force_flag="--force"
+  fi

   # Build command
-  leann build ${index_name} \
+  python -m leann.cli build ${index_name} \
     --docs ~/leann-data \
     --backend ${backend} \
     --complexity ${complexity} \
@@ -55,4 +70,7 @@ run: |
     --num-threads ${num_threads} \
     --embedding-mode ${embedding_mode} \
     --embedding-model ${embedding_model} \
-    ${recompute_flag} ${extra_args}
+    ${recompute_flag} ${force_flag} ${extra_args}
+
+  # Print where the index is stored for downstream rsync
+  echo "INDEX_OUT_DIR=~/.leann/indexes/${index_name}"