diff --git a/benchmarks/laion/.gitignore b/benchmarks/laion/.gitignore
new file mode 100644
index 0000000..adbb97d
--- /dev/null
+++ b/benchmarks/laion/.gitignore
@@ -0,0 +1 @@
+data/
\ No newline at end of file
diff --git a/benchmarks/laion/README.md b/benchmarks/laion/README.md
new file mode 100644
index 0000000..38650f0
--- /dev/null
+++ b/benchmarks/laion/README.md
@@ -0,0 +1,169 @@
+# LAION Multimodal Benchmark
+
+A multimodal benchmark for evaluating image retrieval performance with LEANN, using CLIP embeddings on a LAION-400M subset.
+
+## Overview
+
+This benchmark evaluates:
+- **Recall@3** for caption-to-image retrieval, measured against an exact FAISS baseline
+- **Search complexity tuning** via binary search for a target recall
+- **Index size and storage efficiency** (compact vs. non-compact indexes)
+- **Search timing** with and without embedding recomputation
+
+## Dataset Configuration
+
+- **Dataset**: LAION-400M subset, streamed from HuggingFace `datasets`
+- **Embeddings**: CLIP ViT-L/14 image embeddings (768 dimensions, sentence-transformers `clip-ViT-L-14`)
+- **Queries**: random captions sampled from the downloaded subset
+- **Ground Truth**: exact top-3 images per caption from a FAISS flat index
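+
+CLIP places images and captions in one shared embedding space, which is what makes caption-to-image retrieval work here. A rough sketch of that property (assumes a local `cat.jpg`; any sentence-transformers CLIP checkpoint behaves the same way):
+
+```python
+from PIL import Image
+from sentence_transformers import SentenceTransformer
+
+model = SentenceTransformer("clip-ViT-L-14")
+
+# The same model encodes PIL images and raw text into the same 768-dim space.
+image_vec = model.encode([Image.open("cat.jpg")], normalize_embeddings=True)
+text_vec = model.encode(["a photo of a cat"], normalize_embeddings=True)
+
+# With normalized vectors, the dot product is the cosine similarity.
+print((image_vec @ text_vec.T)[0, 0])
+```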
+
+## Quick Start
+
+### 1. Set up the benchmark
+
+```bash
+cd benchmarks/laion
+python setup_laion.py --num-samples 10000 --num-queries 200
+```
+
+This will:
+- Download a LAION subset (images and captions, fetched in parallel)
+- Generate 768-dim CLIP image embeddings
+- Build compact and non-compact LEANN indexes with the HNSW backend
+- Build a FAISS flat baseline for ground truth
+- Create the evaluation queries
+
+### 2. Run the evaluation
+
+```bash
+# Run all evaluation stages
+python evaluate_laion.py --index data/laion_index.leann --complexity 32
+
+# Run specific stages
+python evaluate_laion.py --index data/laion_index.leann --stage 2                  # Recall@3
+python evaluate_laion.py --index data/laion_index.leann --stage 3                  # complexity search
+python evaluate_laion.py --index data/laion_index.leann --stage 4 --complexity 32  # index comparison
+```
+
+### 3. Save results
+
+```bash
+python evaluate_laion.py --index data/laion_index.leann --complexity 32 --output results.json
+```
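+
+The CLI wraps the regular LEANN Python API, so the index can also be queried directly. A minimal sketch (paths as produced by `setup_laion.py`):
+
+```python
+from leann import LeannSearcher
+
+searcher = LeannSearcher("data/laion_index.leann")
+results = searcher.search(
+    "a dog catching a frisbee",  # caption-style text query
+    top_k=3,
+    complexity=32,
+    recompute_embeddings=True,  # required for the compact (pruned) index
+)
+for r in results:
+    print(r.id)  # image IDs, e.g. laion_000042
+searcher.cleanup()
+```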
+
+## Configuration Options
+
+### Setup Options
+```bash
+python setup_laion.py \
+    --num-samples 10000 \
+    --num-queries 200 \
+    --index-path data/laion_index.leann \
+    --backend hnsw
+```
+
+### Evaluation Options
+```bash
+python evaluate_laion.py \
+    --index data/laion_index.leann \
+    --queries data/evaluation_queries.jsonl \
+    --baseline-dir baseline \
+    --complexity 64 \
+    --stage all
+```
+
+## Evaluation Stages
+
+### Stage 2: Recall Evaluation
+- Evaluates Recall@3 for caption queries (LEANN vs. the FAISS flat ground truth)
+- Ground truth is the exact top-3 images for each caption's CLIP text embedding
+
+### Stage 3: Complexity Search
+- Binary-searches the smallest complexity that reaches 90% Recall@3
+- Runs on a non-compact index with `recompute_embeddings=False` for speed, then verifies the optimum on the compact index
+
+### Stage 4: Index Comparison
+- Compares the compact (pruned, recompute at query time) and non-compact (stored embeddings) indexes
+- Reports `.index` file sizes, storage savings, and the speed vs. accuracy tradeoffs of recomputation
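+
+Recall@3 is intersection-over-ground-truth: how many of FAISS's exact top-3 image IDs the LEANN search also returns. A self-contained sketch of the metric:
+
+```python
+def recall_at_k(test_ids: set[str], baseline_ids: set[str], k: int = 3) -> float:
+    """Fraction of the exact top-k ground truth recovered by the ANN search."""
+    return len(test_ids & baseline_ids) / k
+
+# Example: LEANN returns 2 of the 3 exact nearest images -> recall 0.667
+print(recall_at_k({"img_1", "img_2", "img_9"}, {"img_1", "img_2", "img_3"}))
+```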
+
+## Output Metrics
+
+### Timing Metrics
+- Average search time per query for both index types
+- Speed ratio between no-recompute and recompute search
+
+### Recall Metrics
+- Recall@3 percentage per tested complexity
+- Optimal complexity for the 90% recall target
+
+### Index Metrics
+- Core `.index` size (MB); passages and metadata are reported separately
+- Component breakdown (index, passages, metadata)
+- Storage saving of the compact index relative to the non-compact one
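+
+Storage saving is computed as `(non_compact - compact) / non_compact * 100`. For example (hypothetical sizes), a 40 MB non-compact index reduced to a 12 MB compact index is a `(40 - 12) / 40 * 100 = 70%` saving.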
+
+## Example Results
+
+```
+🎯 LAION MULTIMODAL BENCHMARK RESULTS
+============================================================
+
+📁 Index Information:
+  Total size: 145.2 MB
+  Backend: hnsw
+  Embedding model: clip-ViT-L-14
+  Total passages: 10000
+
+⚡ Search Performance:
+  Total queries: 200
+  Average search time: 0.023s
+  Median search time: 0.021s
+  Min/Max search time: 0.012s / 0.089s
+  Std dev: 0.008s
+  Complexity: 64
+  Top-K: 3
+
+📊 Recall Performance:
+  Recall@3: 85.5%
+  Queries with ground truth: 200
+
+⚙️ Complexity Analysis:
+  Complexity 16: 0.015s avg
+  Complexity 32: 0.019s avg
+  Complexity 64: 0.023s avg
+  Complexity 128: 0.031s avg
+
+🚀 Performance Summary:
+  Searches per second: 43.5
+  Latency (ms): 23.0ms
+```
+
+## Directory Structure
+
+```
+benchmarks/laion/
+├── setup_laion.py                 # Setup script
+├── evaluate_laion.py              # Evaluation script
+├── README.md                      # This file
+├── baseline/                      # FAISS flat ground-truth index
+└── data/                          # Generated data
+    ├── laion_images/              # Downloaded images
+    ├── laion_metadata.jsonl       # Image metadata
+    ├── laion_passages.jsonl       # LEANN passages
+    ├── clip_image_embeddings.npy  # CLIP image embeddings
+    ├── evaluation_queries.jsonl   # Evaluation queries
+    └── laion_index.leann*         # LEANN index files
+```
+
+## Notes
+
+- Images are fetched from their original LAION URLs; some links are dead, so the setup script over-samples candidates and keeps the first `num_samples` successful downloads
+- CLIP embeddings are computed with sentence-transformers `clip-ViT-L-14` (768-dim, normalized); the same model encodes both images and captions
+- Adjust `--num-samples` and `--num-queries` to match available bandwidth and compute
+- The non-compact index stores full embeddings (larger on disk) but supports `recompute_embeddings=False`, which Stage 3 exploits for fast complexity tuning
\ No newline at end of file
diff --git a/benchmarks/laion/evaluate_laion.py b/benchmarks/laion/evaluate_laion.py
new file mode 100644
index 0000000..bbe7c54
--- /dev/null
+++ b/benchmarks/laion/evaluate_laion.py
@@ -0,0 +1,630 @@
+"""
+LAION Multimodal Benchmark Evaluation Script - Modular Recall-based Evaluation
+"""
+
+import argparse
+import json
+import os
+import pickle
+import time
+from pathlib import Path
+
+import numpy as np
+from leann import LeannSearcher
+from leann_backend_hnsw import faiss
+from sentence_transformers import SentenceTransformer
+
+
+class RecallEvaluator:
+    """Stage 2: Evaluate Recall@3 (LEANN vs FAISS baseline for multimodal retrieval)"""
+
+    def __init__(self, index_path: str, baseline_dir: str):
+        self.index_path = index_path
+        self.baseline_dir = baseline_dir
+        self.searcher = LeannSearcher(index_path)
+
+        # Load FAISS flat baseline (image embeddings)
+        baseline_index_path = os.path.join(baseline_dir, "faiss_flat.index")
+        metadata_path = os.path.join(baseline_dir, "metadata.pkl")
+
+        self.faiss_index = faiss.read_index(baseline_index_path)
+        with open(metadata_path, "rb") as f:
+            self.image_ids = pickle.load(f)
+        print(f"📚 Loaded FAISS flat baseline with {self.faiss_index.ntotal} image vectors")
+
+        # Load sentence-transformers CLIP for text embedding (ViT-L/14)
+        self.st_clip = SentenceTransformer("clip-ViT-L-14")
+
+    def evaluate_recall_at_3(
+        self, captions: list[str], complexity: int = 64, recompute_embeddings: bool = True
+    ) -> float:
+        """Evaluate recall@3 for multimodal retrieval: caption queries -> image results"""
+        recompute_str = "with recompute" if recompute_embeddings else "no recompute"
+        print(f"🔍 Evaluating recall@3 with complexity={complexity} ({recompute_str})...")
+
+        total_recall = 0.0
+        num_queries = len(captions)
+
+        for i, caption in enumerate(captions):
+            # Get ground truth: search with FAISS flat using caption text embedding
+            # Generate CLIP text embedding for caption via sentence-transformers (normalized)
+            query_embedding = self.st_clip.encode(
+                [caption], convert_to_numpy=True, normalize_embeddings=True, show_progress_bar=False
+            ).astype(np.float32)
+
+            # Search FAISS flat for ground truth using LEANN's modified faiss API
+            n = query_embedding.shape[0]  # Number of queries
+            k = 3  # Number of nearest neighbors
+            distances = np.zeros((n, k), dtype=np.float32)
+            labels = np.zeros((n, k), dtype=np.int64)
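+
+            # NOTE: the faiss build bundled with leann_backend_hnsw exposes the
+            # raw C++ search signature -- search(n, x_ptr, k, D_ptr, I_ptr) with
+            # caller-allocated output buffers -- rather than the usual
+            # `D, I = index.search(x, k)` Python wrapper.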
+            self.faiss_index.search(
+                n,
+                faiss.swig_ptr(query_embedding),
+                k,
+                faiss.swig_ptr(distances),
+                faiss.swig_ptr(labels),
+            )
+
+            # Extract the results (image IDs from FAISS)
+            baseline_ids = {self.image_ids[idx] for idx in labels[0]}
+
+            # Search with LEANN at specified complexity (using caption as text query)
+            test_results = self.searcher.search(
+                caption,
+                top_k=3,
+                complexity=complexity,
+                recompute_embeddings=recompute_embeddings,
+            )
+            test_ids = {result.id for result in test_results}
+
+            # Calculate recall@3 = |intersection| / |ground_truth|
+            intersection = test_ids.intersection(baseline_ids)
+            recall = len(intersection) / 3.0  # Ground truth size is 3
+            total_recall += recall
+
+            if i < 3:  # Show first few examples
+                print(f"  Query {i + 1}: '{caption[:50]}...' -> Recall@3: {recall:.3f}")
+                print(f"    FAISS ground truth: {list(baseline_ids)}")
+                print(f"    LEANN results (C={complexity}, {recompute_str}): {list(test_ids)}")
+                print(f"    Intersection: {list(intersection)}")
+
+        avg_recall = total_recall / num_queries
+        print(f"📊 Average Recall@3: {avg_recall:.3f} ({avg_recall * 100:.1f}%)")
+        return avg_recall
+
+    def cleanup(self):
+        """Cleanup resources"""
+        if hasattr(self, "searcher"):
+            self.searcher.cleanup()
+
+
+class LAIONEvaluator:
+    def __init__(self, index_path: str):
+        self.index_path = index_path
+        self.searcher = LeannSearcher(index_path)
+
+    def load_queries(self, queries_file: str) -> list[str]:
+        """Load caption queries from evaluation file"""
+        captions = []
+        with open(queries_file, encoding="utf-8") as f:
+            for line in f:
+                if line.strip():
+                    query_data = json.loads(line)
+                    captions.append(query_data["query"])
+
+        print(f"📊 Loaded {len(captions)} caption queries")
+        return captions
+
+    def analyze_index_sizes(self) -> dict:
+        """Analyze index sizes, emphasizing .index only (exclude passages)."""
+        print("📏 Analyzing index sizes (.index only)...")
+
+        # Get all index-related files
+        index_path = Path(self.index_path)
+        index_dir = index_path.parent
+        index_name = index_path.stem  # Remove .leann extension
+
+        sizes: dict[str, float] = {}
+
+        # Core index files
+        index_file = index_dir / f"{index_name}.index"
+        meta_file = index_dir / f"{index_path.name}.meta.json"  # Keep .leann for meta file
+        passages_file = index_dir / f"{index_path.name}.passages.jsonl"  # Keep .leann for passages
+        passages_idx_file = index_dir / f"{index_path.name}.passages.idx"  # Keep .leann for idx
+
+        # Core index size (.index only)
+        index_mb = index_file.stat().st_size / (1024 * 1024) if index_file.exists() else 0.0
+        sizes["index_only_mb"] = index_mb
+
+        # Other files for reference (not counted in index_only_mb)
+        sizes["metadata_mb"] = meta_file.stat().st_size / (1024 * 1024) if meta_file.exists() else 0.0
+        sizes["passages_text_mb"] = (
+            passages_file.stat().st_size / (1024 * 1024) if passages_file.exists() else 0.0
+        )
+        sizes["passages_index_mb"] = (
+            passages_idx_file.stat().st_size / (1024 * 1024) if passages_idx_file.exists() else 0.0
+        )
+
+        print(f"  📁 .index size: {index_mb:.1f} MB")
+        if sizes["metadata_mb"]:
+            print(f"  🧾 metadata: {sizes['metadata_mb']:.3f} MB")
+        if sizes["passages_text_mb"] or sizes["passages_index_mb"]:
+            print(
+                f"  (passages excluded) text: {sizes['passages_text_mb']:.1f} MB, idx: {sizes['passages_index_mb']:.1f} MB"
+            )
+
+        return sizes
+
+    def create_non_compact_index_for_comparison(self, non_compact_index_path: str) -> dict:
+        """Create a non-compact index for comparison purposes"""
+        print("🏗️ Building non-compact index from existing passages...")
+
+        # Load existing passages from current index
+        from leann import LeannBuilder
+
+        current_index_path = Path(self.index_path)
+        current_index_dir = current_index_path.parent
+        current_index_name = current_index_path.name
+
+        # Read metadata to get passage source
+        meta_path = current_index_dir / f"{current_index_name}.meta.json"
+        with open(meta_path) as f:
+            meta = json.load(f)
+
+        passage_source = meta["passage_sources"][0]
+        passage_file = passage_source["path"]
+
+        # Convert relative path to absolute
+        if not Path(passage_file).is_absolute():
+            passage_file = current_index_dir / Path(passage_file).name
+
+        print(f"📄 Loading passages from {passage_file}...")
+
+        # Load CLIP embeddings
+        embeddings_file = current_index_dir / "clip_image_embeddings.npy"
+        embeddings = np.load(embeddings_file)
+        print(f"📏 Loaded embeddings shape: {embeddings.shape}")
+
+        # Build non-compact index with same passages and embeddings
+        builder = LeannBuilder(
+            backend_name="hnsw",
+            # Use CLIP text encoder (ViT-L/14) to match image embeddings (768-dim)
+            embedding_model="clip-ViT-L-14",
+            embedding_mode="sentence-transformers",
+            is_recompute=False,  # Disable recompute (store embeddings)
+            is_compact=False,  # Disable compact storage
+            distance_metric="cosine",
+            **{
+                k: v
+                for k, v in meta.get("backend_kwargs", {}).items()
+                if k not in ["is_recompute", "is_compact", "distance_metric"]
+            },
+        )
+
+        # Prepare ids and add passages
+        ids: list[str] = []
+        with open(passage_file, encoding="utf-8") as f:
+            for line in f:
+                if line.strip():
+                    data = json.loads(line)
+                    ids.append(str(data["id"]))
+                    # Ensure metadata contains the id used by the vector index
+                    metadata = {**data.get("metadata", {}), "id": data["id"]}
+                    builder.add_text(text=data["text"], metadata=metadata)
+
+        if len(ids) != embeddings.shape[0]:
+            raise ValueError(
+                f"IDs count ({len(ids)}) does not match embeddings ({embeddings.shape[0]})."
+            )
+
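+        # build_index_from_embeddings expects a pickle holding an (ids, embeddings)
+        # tuple: a list[str] of length N plus a float32 array of shape (N, dim),
+        # in matching order (the same layout setup_laion.py writes).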
+        # Persist a pickle for build_index_from_embeddings
+        pkl_path = current_index_dir / "clip_image_embeddings.pkl"
+        with open(pkl_path, "wb") as pf:
+            pickle.dump((ids, embeddings.astype(np.float32)), pf)
+
+        print(
+            f"🔨 Building non-compact index at {non_compact_index_path} from precomputed embeddings..."
+        )
+        builder.build_index_from_embeddings(non_compact_index_path, str(pkl_path))
+
+        # Analyze the non-compact index size
+        temp_evaluator = LAIONEvaluator(non_compact_index_path)
+        non_compact_sizes = temp_evaluator.analyze_index_sizes()
+        non_compact_sizes["index_type"] = "non_compact"
+
+        return non_compact_sizes
+
+    def compare_index_performance(
+        self, non_compact_path: str, compact_path: str, test_captions: list, complexity: int
+    ) -> dict:
+        """Compare performance between non-compact and compact indexes"""
+        print("⚡ Comparing search performance between indexes...")
+
+        # Test queries
+        test_queries = test_captions[:5]
+
+        results = {
+            "non_compact": {"search_times": []},
+            "compact": {"search_times": []},
+            "avg_search_times": {},
+            "speed_ratio": 0.0,
+        }
+
+        # Test non-compact index (no recompute)
+        print("  🔍 Testing non-compact index (no recompute)...")
+        non_compact_searcher = LeannSearcher(non_compact_path)
+
+        for caption in test_queries:
+            start_time = time.time()
+            non_compact_searcher.search(
+                caption, top_k=3, complexity=complexity, recompute_embeddings=False
+            )
+            search_time = time.time() - start_time
+            results["non_compact"]["search_times"].append(search_time)
+
+        # Test compact index (with recompute)
+        print("  🔍 Testing compact index (with recompute)...")
+        compact_searcher = LeannSearcher(compact_path)
+
+        for caption in test_queries:
+            start_time = time.time()
+            compact_searcher.search(
+                caption, top_k=3, complexity=complexity, recompute_embeddings=True
+            )
+            search_time = time.time() - start_time
+            results["compact"]["search_times"].append(search_time)
+
+        # Calculate averages
+        results["avg_search_times"]["non_compact"] = sum(
+            results["non_compact"]["search_times"]
+        ) / len(results["non_compact"]["search_times"])
+        results["avg_search_times"]["compact"] = sum(results["compact"]["search_times"]) / len(
+            results["compact"]["search_times"]
+        )
+
+        # Performance ratio
+        if results["avg_search_times"]["compact"] > 0:
+            results["speed_ratio"] = (
+                results["avg_search_times"]["non_compact"] / results["avg_search_times"]["compact"]
+            )
+        else:
+            results["speed_ratio"] = float("inf")
+
+        print(
+            f"  Non-compact (no recompute): {results['avg_search_times']['non_compact']:.3f}s avg"
+        )
+        print(f"  Compact (with recompute): {results['avg_search_times']['compact']:.3f}s avg")
+        print(f"  Speed ratio: {results['speed_ratio']:.2f}x")
+
+        # Cleanup
+        non_compact_searcher.cleanup()
+        compact_searcher.cleanup()
+
+        return results
+    def _print_results(self, timing_metrics: dict):
+        """Print evaluation results"""
+        print("\n🎯 LAION MULTIMODAL BENCHMARK RESULTS")
+        print("=" * 60)
+
+        # Index comparison analysis
+        if "current_index" in timing_metrics and "non_compact_index" in timing_metrics:
+            print("\n📁 Index Comparison Analysis:")
+            current = timing_metrics["current_index"]
+            non_compact = timing_metrics["non_compact_index"]
+
+            print(f"  Compact index (current): {current.get('index_only_mb', 0):.1f} MB")
+            print(
+                f"  Non-compact index (with embeddings): {non_compact.get('index_only_mb', 0):.1f} MB"
+            )
+            print(
+                f"  Storage saving by compact: {timing_metrics.get('storage_saving_percent', 0):.1f}%"
+            )
+
+            print("  Component breakdown (non-compact):")
+            print(f"    - Main index: {non_compact.get('index_only_mb', 0):.1f} MB")
+            print(f"    - Passages text: {non_compact.get('passages_text_mb', 0):.1f} MB")
+            print(f"    - Passages index: {non_compact.get('passages_index_mb', 0):.1f} MB")
+            print(f"    - Metadata: {non_compact.get('metadata_mb', 0):.1f} MB")
+
+        # Performance comparison
+        if "performance_comparison" in timing_metrics:
+            perf = timing_metrics["performance_comparison"]
+            print("\n⚡ Performance Comparison:")
+            print(
+                f"  Non-compact (no recompute): {perf.get('avg_search_times', {}).get('non_compact', 0):.3f}s avg"
+            )
+            print(
+                f"  Compact (with recompute): {perf.get('avg_search_times', {}).get('compact', 0):.3f}s avg"
+            )
+            print(f"  Speed ratio: {perf.get('speed_ratio', 0):.2f}x")
+
+        # Legacy single index analysis (fallback)
+        if "total_with_embeddings" in timing_metrics and "current_index" not in timing_metrics:
+            print("\n📁 Index Size Analysis:")
+            print(
+                f"  Index with embeddings: {timing_metrics.get('total_with_embeddings', 0):.1f} MB"
+            )
+            print(
+                f"  Estimated pruned index: {timing_metrics.get('total_without_embeddings', 0):.1f} MB"
+            )
+            print(f"  Compression ratio: {timing_metrics.get('compression_ratio', 0):.2f}x")
+
+    def cleanup(self):
+        """Cleanup resources"""
+        if self.searcher:
+            self.searcher.cleanup()
+
+
+def main():
+    parser = argparse.ArgumentParser(description="LAION Multimodal Benchmark Evaluation")
+    parser.add_argument("--index", required=True, help="Path to LEANN index")
+    parser.add_argument(
+        "--queries", default="data/evaluation_queries.jsonl", help="Path to evaluation queries"
+    )
+    parser.add_argument(
+        "--stage",
+        choices=["2", "3", "4", "all"],
+        default="all",
+        help="Which stage to run (2=recall, 3=complexity, 4=index comparison)",
+    )
+    parser.add_argument("--complexity", type=int, default=None, help="Complexity for search")
+    parser.add_argument("--baseline-dir", default="baseline", help="Baseline output directory")
+    parser.add_argument("--output", help="Save results to JSON file")
+
+    args = parser.parse_args()
+
+    try:
+        # Check if baseline exists
+        baseline_index_path = os.path.join(args.baseline_dir, "faiss_flat.index")
+        if not os.path.exists(baseline_index_path):
+            print(f"❌ FAISS baseline not found at {baseline_index_path}")
+            print("💡 Please run setup_laion.py first to build the baseline")
+            exit(1)
+
+        if args.stage == "2" or args.stage == "all":
+            # Stage 2: Recall@3 evaluation
+            print("🚀 Starting Stage 2: Recall@3 evaluation for multimodal retrieval")
+
+            evaluator = RecallEvaluator(args.index, args.baseline_dir)
+
+            # Load caption queries for testing
+            laion_evaluator = LAIONEvaluator(args.index)
+            captions = laion_evaluator.load_queries(args.queries)
+
+            # Test with queries for robust measurement
+            test_captions = captions[:100]  # Use subset for speed
+            print(f"🧪 Testing with {len(test_captions)} caption queries")
+
+            # Test with complexity 64
+            complexity = 64
+            recall = evaluator.evaluate_recall_at_3(test_captions, complexity)
+            print(f"📈 Recall@3 at complexity {complexity}: {recall * 100:.1f}%")
+
+            evaluator.cleanup()
+            print("✅ Stage 2 completed!\n")
+
+        # Shared non-compact index path for Stage 3 and 4
+        non_compact_index_path = args.index.replace(".leann", "_noncompact.leann")
+        complexity = args.complexity
+
+        if args.stage == "3" or args.stage == "all":
+            # Stage 3: Binary search for 90% recall complexity
+            print("🚀 Starting Stage 3: Binary search for 90% recall complexity")
+            print(
+                "💡 Creating non-compact index for fast binary search with recompute_embeddings=False"
+            )
+
+            # Create non-compact index for binary search
+            print("🏗️ Creating non-compact index for binary search...")
+            evaluator = LAIONEvaluator(args.index)
+            evaluator.create_non_compact_index_for_comparison(non_compact_index_path)
+
+            # Use non-compact index for binary search
+            binary_search_evaluator = RecallEvaluator(non_compact_index_path, args.baseline_dir)
+
+            # Load caption queries for testing
+            captions = evaluator.load_queries(args.queries)
+
+            # Use subset for robust measurement
+            test_captions = captions[:50]  # Smaller subset for binary search speed
+            print(f"🧪 Testing with {len(test_captions)} caption queries")
+
+            # Binary search for 90% recall complexity
+            target_recall = 0.9
+            min_complexity, max_complexity = 1, 64
+
+            print(f"🔍 Binary search for {target_recall * 100}% recall complexity...")
+            print(f"Search range: {min_complexity} to {max_complexity}")
+
+            best_complexity = None
+            best_recall = 0.0
+
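+            # Binary search assumes recall is (approximately) monotonic in
+            # complexity: a larger candidate queue can only widen the graph
+            # search, which holds closely enough for HNSW tuning in practice.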
+            while min_complexity <= max_complexity:
+                mid_complexity = (min_complexity + max_complexity) // 2
+
+                print(
+                    f"\n🧪 Testing complexity {mid_complexity} (no recompute, non-compact index)..."
+                )
+                # Use recompute_embeddings=False on non-compact index for fast binary search
+                recall = binary_search_evaluator.evaluate_recall_at_3(
+                    test_captions, mid_complexity, recompute_embeddings=False
+                )
+
+                print(
+                    f"  Complexity {mid_complexity}: Recall@3 = {recall:.3f} ({recall * 100:.1f}%)"
+                )
+
+                if recall >= target_recall:
+                    best_complexity = mid_complexity
+                    best_recall = recall
+                    max_complexity = mid_complexity - 1
+                    print("  ✅ Target reached! Searching for lower complexity...")
+                else:
+                    min_complexity = mid_complexity + 1
+                    print("  ❌ Below target. Searching for higher complexity...")
+
+            if best_complexity is not None:
+                print("\n🎯 Optimal complexity found!")
+                print(f"  Complexity: {best_complexity}")
+                print(f"  Recall@3: {best_recall:.3f} ({best_recall * 100:.1f}%)")
+
+                # Test a few complexities around the optimal one for verification
+                print("\n🔬 Verification test around optimal complexity:")
+                verification_complexities = [
+                    max(1, best_complexity - 2),
+                    max(1, best_complexity - 1),
+                    best_complexity,
+                    best_complexity + 1,
+                    best_complexity + 2,
+                ]
+
+                for complexity in verification_complexities:
+                    if complexity <= 512:  # reasonable upper bound
+                        recall = binary_search_evaluator.evaluate_recall_at_3(
+                            test_captions, complexity, recompute_embeddings=False
+                        )
+                        status = "✅" if recall >= target_recall else "❌"
+                        print(f"  {status} Complexity {complexity:3d}: {recall * 100:5.1f}%")
+
+                # Now test the optimal complexity with compact index and recompute for comparison
+                print(
+                    f"\n🔄 Testing optimal complexity {best_complexity} on compact index WITH recompute..."
+                )
+                compact_evaluator = RecallEvaluator(args.index, args.baseline_dir)
+                recall_with_recompute = compact_evaluator.evaluate_recall_at_3(
+                    test_captions[:10], best_complexity, recompute_embeddings=True
+                )
+                print(
+                    f"  ✅ Complexity {best_complexity} (compact index with recompute): {recall_with_recompute * 100:.1f}%"
+                )
+                complexity = best_complexity
+                print(
+                    f"  📊 Recall difference: {abs(best_recall - recall_with_recompute) * 100:.2f}%"
+                )
+                compact_evaluator.cleanup()
+            else:
+                print(f"\n❌ Could not find complexity achieving {target_recall * 100}% recall")
+                print("All tested complexities were below target.")
+
+            # Cleanup evaluators (keep non-compact index for Stage 4)
+            binary_search_evaluator.cleanup()
+            evaluator.cleanup()
+
+            print("✅ Stage 3 completed! Non-compact index saved for Stage 4.\n")
+ ) + + # Download images in parallel + async def download_batch(): + semaphore = asyncio.Semaphore(20) # Limit to 20 concurrent downloads + connector = aiohttp.TCPConnector(limit=100, limit_per_host=20) + timeout = aiohttp.ClientTimeout(total=30) + + progress_bar = tqdm(total=len(candidates[: num_samples * 2]), desc="Downloading images") + + async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session: + tasks = [] + for candidate in candidates[: num_samples * 2]: # Try 2x more than needed + task = self.download_single_image(session, candidate, semaphore, progress_bar) + tasks.append(task) + + # Wait for all downloads + results = await asyncio.gather(*tasks, return_exceptions=True) + progress_bar.close() + + # Filter successful downloads + successful = [r for r in results if r is not None and not isinstance(r, Exception)] + return successful[:num_samples] + + # Run async download + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + samples = loop.run_until_complete(download_batch()) + finally: + loop.close() + + # Save metadata + with open(self.metadata_file, "w", encoding="utf-8") as f: + for sample in samples: + f.write(json.dumps(sample) + "\n") + + print(f"โœ… Downloaded {len(samples)} real LAION samples with async parallel downloading") + return samples + + def generate_clip_image_embeddings(self, samples: list[dict]): + """Generate CLIP image embeddings for downloaded images""" + print("๐Ÿ” Generating CLIP image embeddings...") + + # Load sentence-transformers CLIP (ViT-L/14, 768-dim) for image embeddings + # This single model can encode both images and text. + model = SentenceTransformer("clip-ViT-L-14") + + embeddings = [] + valid_samples = [] + + for sample in tqdm(samples, desc="Processing images"): + try: + # Load image + image_path = sample["image_path"] + image = Image.open(image_path).convert("RGB") + + # Encode image to 768-dim embedding via sentence-transformers (normalized) + vec = model.encode( + [image], + convert_to_numpy=True, + normalize_embeddings=True, + batch_size=1, + show_progress_bar=False, + )[0] + embeddings.append(vec.astype(np.float32)) + valid_samples.append(sample) + + except Exception as e: + print(f" โš ๏ธ Failed to process {sample['id']}: {e}") + # Skip invalid images + + embeddings = np.array(embeddings, dtype=np.float32) + + # Save embeddings + embeddings_file = self.data_dir / "clip_image_embeddings.npy" + np.save(embeddings_file, embeddings) + print(f"โœ… Generated {len(embeddings)} image embeddings, shape: {embeddings.shape}") + + return embeddings, valid_samples + + def build_faiss_baseline( + self, embeddings: np.ndarray, samples: list[dict], output_dir: str = "baseline" + ): + """Build FAISS flat baseline using CLIP image embeddings""" + print("๐Ÿ”จ Building FAISS Flat baseline...") + + from leann_backend_hnsw import faiss + + os.makedirs(output_dir, exist_ok=True) + baseline_path = os.path.join(output_dir, "faiss_flat.index") + metadata_path = os.path.join(output_dir, "metadata.pkl") + + if os.path.exists(baseline_path) and os.path.exists(metadata_path): + print(f"โœ… Baseline already exists at {baseline_path}") + return baseline_path + + # Extract image IDs (must be present) + if not samples or "id" not in samples[0]: + raise KeyError("samples missing 'id' field for FAISS baseline") + image_ids: list[str] = [str(sample["id"]) for sample in samples] + + print(f"๐Ÿ“ Embedding shape: {embeddings.shape}") + print(f"๐Ÿ“„ Processing {len(image_ids)} images") + + # Build FAISS flat index + print("๐Ÿ—๏ธ 
diff --git a/benchmarks/laion/setup_laion.py b/benchmarks/laion/setup_laion.py
new file mode 100644
index 0000000..64f01d9
--- /dev/null
+++ b/benchmarks/laion/setup_laion.py
@@ -0,0 +1,576 @@
+"""
+LAION Multimodal Benchmark Setup Script
+Downloads a LAION subset and builds LEANN indexes with CLIP embeddings
+"""
+
+import argparse
+import asyncio
+import io
+import json
+import os
+import pickle
+import time
+from pathlib import Path
+
+import aiohttp
+import numpy as np
+from datasets import load_dataset
+from leann import LeannBuilder
+from PIL import Image
+from sentence_transformers import SentenceTransformer
+from tqdm import tqdm
+
+
+class LAIONSetup:
+    def __init__(self, data_dir: str = "data"):
+        self.data_dir = Path(data_dir)
+        self.images_dir = self.data_dir / "laion_images"
+        self.metadata_file = self.data_dir / "laion_metadata.jsonl"
+
+        # Create directories
+        self.data_dir.mkdir(exist_ok=True)
+        self.images_dir.mkdir(exist_ok=True)
+
+    async def download_single_image(self, session, sample_data, semaphore, progress_bar):
+        """Download a single image asynchronously"""
+        async with semaphore:  # Limit concurrent downloads
+            try:
+                image_url = sample_data["url"]
+                image_path = sample_data["image_path"]
+
+                # Skip if already exists
+                if os.path.exists(image_path):
+                    progress_bar.update(1)
+                    return sample_data
+
+                async with session.get(image_url, timeout=10) as response:
+                    if response.status == 200:
+                        content = await response.read()
+
+                        # Verify it's a valid image
+                        try:
+                            img = Image.open(io.BytesIO(content))
+                            img = img.convert("RGB")
+                            img.save(image_path, "JPEG")
+                            progress_bar.update(1)
+                            return sample_data
+                        except Exception:
+                            progress_bar.update(1)
+                            return None  # Skip invalid images
+                    else:
+                        progress_bar.update(1)
+                        return None
+
+            except Exception:
+                progress_bar.update(1)
+                return None
+ ) + + pkl_path = self.data_dir / "clip_image_embeddings.pkl" + with open(pkl_path, "wb") as pf: + pickle.dump((ids, embeddings.astype(np.float32)), pf) + print(f"๐Ÿ’พ Saved (ids, embeddings) pickle to {pkl_path}") + + # Initialize builder - compact with recompute + # Note: For multimodal case, we need to handle embeddings differently + # Let's try using sentence-transformers mode but with custom embeddings + builder = LeannBuilder( + backend_name=backend, + # Use CLIP text encoder (ViT-L/14) to match image space (768-dim) + embedding_model="clip-ViT-L-14", + embedding_mode="sentence-transformers", + # HNSW params (or forwarded to chosen backend) + graph_degree=32, + complexity=64, + # Compact/pruned with recompute at query time + is_recompute=True, + is_compact=True, + distance_metric="cosine", # CLIP uses normalized vectors; cosine is appropriate + num_threads=4, + ) + + # Add passages (text + metadata) + print("๐Ÿ“š Adding passages...") + self._add_passages_with_embeddings(builder, passages_file, embeddings) + + print(f"๐Ÿ”จ Building compact index at {index_path} from precomputed embeddings...") + builder.build_index_from_embeddings(index_path, str(pkl_path)) + + build_time = time.time() - start_time + print(f"โœ… Compact index built in {build_time:.2f}s") + + # Analyze index size + self._analyze_index_size(index_path) + + return index_path + + def build_non_compact_index( + self, passages_file: Path, embeddings: np.ndarray, index_path: str, backend: str = "hnsw" + ): + """Build non-compact LEANN index with CLIP embeddings (recompute=False, compact=False)""" + print(f"๐Ÿ—๏ธ Building non-compact LEANN index with {backend} backend...") + + start_time = time.time() + + # Ensure embeddings are saved (npy + pickle) + npy_path = self.data_dir / "clip_image_embeddings.npy" + if not npy_path.exists(): + np.save(npy_path, embeddings) + print(f"๐Ÿ’พ Saved CLIP embeddings to {npy_path}") + # Prepare ids in same order as passages_file + ids: list[str] = [] + with open(passages_file, encoding="utf-8") as f: + for line in f: + if line.strip(): + rec = json.loads(line) + ids.append(str(rec["id"])) + if len(ids) != embeddings.shape[0]: + raise ValueError( + f"IDs count ({len(ids)}) does not match embeddings ({embeddings.shape[0]})." 
+ ) + pkl_path = self.data_dir / "clip_image_embeddings.pkl" + if not pkl_path.exists(): + with open(pkl_path, "wb") as pf: + pickle.dump((ids, embeddings.astype(np.float32)), pf) + print(f"๐Ÿ’พ Saved (ids, embeddings) pickle to {pkl_path}") + + # Initialize builder - non-compact without recompute + builder = LeannBuilder( + backend_name=backend, + embedding_model="clip-ViT-L-14", + embedding_mode="sentence-transformers", + graph_degree=32, + complexity=64, + is_recompute=False, # Store embeddings (no recompute needed) + is_compact=False, # Store full index (not pruned) + distance_metric="cosine", + num_threads=4, + ) + + # Add passages - embeddings will be loaded from file + print("๐Ÿ“š Adding passages...") + self._add_passages_with_embeddings(builder, passages_file, embeddings) + + print(f"๐Ÿ”จ Building non-compact index at {index_path} from precomputed embeddings...") + builder.build_index_from_embeddings(index_path, str(pkl_path)) + + build_time = time.time() - start_time + print(f"โœ… Non-compact index built in {build_time:.2f}s") + + # Analyze index size + self._analyze_index_size(index_path) + + return index_path + + def _add_passages_with_embeddings(self, builder, passages_file: Path, embeddings: np.ndarray): + """Helper to add passages with pre-computed CLIP embeddings""" + with open(passages_file, encoding="utf-8") as f: + for i, line in enumerate(tqdm(f, desc="Adding passages")): + if line.strip(): + passage = json.loads(line) + + # Add image metadata - LEANN will handle embeddings separately + # Note: We store image metadata and caption text for searchability + # Important: ensure passage ID in metadata matches vector ID + builder.add_text( + text=passage["text"], # Image caption for searchability + metadata={**passage["metadata"], "id": passage["id"]}, + ) + + def _analyze_index_size(self, index_path: str): + """Analyze index file sizes""" + print("๐Ÿ“ Analyzing index sizes...") + + index_path = Path(index_path) + index_dir = index_path.parent + index_name = index_path.name # e.g., laion_index.leann + index_prefix = index_path.stem # e.g., laion_index + + files = [ + (f"{index_prefix}.index", ".index", "core"), + (f"{index_name}.meta.json", ".meta.json", "core"), + (f"{index_name}.ids.txt", ".ids.txt", "core"), + (f"{index_name}.passages.jsonl", ".passages.jsonl", "passages"), + (f"{index_name}.passages.idx", ".passages.idx", "passages"), + ] + + def _fmt_size(bytes_val: int) -> str: + if bytes_val < 1024: + return f"{bytes_val} B" + kb = bytes_val / 1024 + if kb < 1024: + return f"{kb:.1f} KB" + mb = kb / 1024 + if mb < 1024: + return f"{mb:.2f} MB" + gb = mb / 1024 + return f"{gb:.2f} GB" + + total_index_only_mb = 0.0 + total_all_mb = 0.0 + for filename, label, group in files: + file_path = index_dir / filename + if file_path.exists(): + size_bytes = file_path.stat().st_size + print(f" {label}: {_fmt_size(size_bytes)}") + size_mb = size_bytes / (1024 * 1024) + total_all_mb += size_mb + if group == "core": + total_index_only_mb += size_mb + else: + print(f" {label}: (missing)") + print(f" Total (index only, exclude passages): {total_index_only_mb:.2f} MB") + print(f" Total (including passages): {total_all_mb:.2f} MB") + + def create_evaluation_queries(self, samples: list[dict], num_queries: int = 200): + """Create evaluation queries from captions""" + print(f"๐Ÿ“ Creating {num_queries} evaluation queries...") + + # Sample random captions as queries + import random + + random.seed(42) # For reproducibility + + query_samples = random.sample(samples, min(num_queries, 
+        dimension = embeddings.shape[1]
+        index = faiss.IndexFlatIP(dimension)
+
+        # Add embeddings to flat index
+        embeddings_f32 = embeddings.astype(np.float32)
+        index.add(embeddings_f32.shape[0], faiss.swig_ptr(embeddings_f32))
+
+        # Save index and metadata
+        faiss.write_index(index, baseline_path)
+        with open(metadata_path, "wb") as f:
+            pickle.dump(image_ids, f)
+
+        print(f"✅ FAISS baseline saved to {baseline_path}")
+        print(f"✅ Metadata saved to {metadata_path}")
+        print(f"📊 Total vectors: {index.ntotal}")
+
+        return baseline_path
+
+    def create_leann_passages(self, samples: list[dict]):
+        """Create LEANN-compatible passages from LAION data"""
+        print("📝 Creating LEANN passages...")
+
+        passages_file = self.data_dir / "laion_passages.jsonl"
+
+        with open(passages_file, "w", encoding="utf-8") as f:
+            for i, sample in enumerate(samples):
+                passage = {
+                    "id": sample["id"],
+                    "text": sample["caption"],  # Use caption as searchable text
+                    "metadata": {
+                        "image_url": sample["url"],
+                        "image_path": sample.get("image_path", ""),
+                        "width": sample["width"],
+                        "height": sample["height"],
+                        "similarity": sample["similarity"],
+                        "image_index": i,  # Index for embedding lookup
+                    },
+                }
+                f.write(json.dumps(passage) + "\n")
+
+        print(f"✅ Created {len(samples)} passages")
+        return passages_file
+
+    def build_compact_index(
+        self, passages_file: Path, embeddings: np.ndarray, index_path: str, backend: str = "hnsw"
+    ):
+        """Build compact LEANN index with CLIP embeddings (recompute=True, compact=True)"""
+        print(f"🏗️ Building compact LEANN index with {backend} backend...")
+
+        start_time = time.time()
+
+        # Save CLIP embeddings (npy) and also a pickle with (ids, embeddings)
+        npy_path = self.data_dir / "clip_image_embeddings.npy"
+        np.save(npy_path, embeddings)
+        print(f"💾 Saved CLIP embeddings to {npy_path}")
+
+        # Prepare ids in the same order as passages_file (matches embeddings order)
+        ids: list[str] = []
+        with open(passages_file, encoding="utf-8") as f:
+            for line in f:
+                if line.strip():
+                    rec = json.loads(line)
+                    ids.append(str(rec["id"]))
+
+        if len(ids) != embeddings.shape[0]:
+            raise ValueError(
+                f"IDs count ({len(ids)}) does not match embeddings ({embeddings.shape[0]})."
+            )
+
+        pkl_path = self.data_dir / "clip_image_embeddings.pkl"
+        with open(pkl_path, "wb") as pf:
+            pickle.dump((ids, embeddings.astype(np.float32)), pf)
+        print(f"💾 Saved (ids, embeddings) pickle to {pkl_path}")
+
+        # Initialize builder - compact with recompute.
+        # For this multimodal case the captions are the stored passages, while
+        # the precomputed CLIP image embeddings are supplied separately through
+        # build_index_from_embeddings below.
+        builder = LeannBuilder(
+            backend_name=backend,
+            # Use CLIP text encoder (ViT-L/14) to match image embeddings (768-dim)
+            embedding_model="clip-ViT-L-14",
+            embedding_mode="sentence-transformers",
+            # HNSW params (or forwarded to chosen backend)
+            graph_degree=32,
+            complexity=64,
+            # Compact/pruned with recompute at query time
+            is_recompute=True,
+            is_compact=True,
+            distance_metric="cosine",  # CLIP uses normalized vectors; cosine is appropriate
+            num_threads=4,
+        )
+
+        # Add passages (text + metadata)
+        print("📚 Adding passages...")
+        self._add_passages_with_embeddings(builder, passages_file, embeddings)
+
+        print(f"🔨 Building compact index at {index_path} from precomputed embeddings...")
+        builder.build_index_from_embeddings(index_path, str(pkl_path))
+
+        build_time = time.time() - start_time
+        print(f"✅ Compact index built in {build_time:.2f}s")
+
+        # Analyze index size
+        self._analyze_index_size(index_path)
+
+        return index_path
+ ) + + embeddings = np.load(embeddings_file) + print(f"๐Ÿ“Š Loaded {len(embeddings)} embeddings from {embeddings_file}") + + # Step 4: Build LEANN indexes (both compact and non-compact) + if not args.skip_build: + print("\n๐Ÿ—๏ธ Step 4: Build LEANN indexes with CLIP image embeddings") + + # Build compact index (production mode - small, recompute required) + compact_index_path = args.index_path + print(f"Building compact index: {compact_index_path}") + setup.build_compact_index(passages_file, embeddings, compact_index_path, args.backend) + + # Build non-compact index (comparison mode - large, fast search) + non_compact_index_path = args.index_path.replace(".leann", "_noncompact.leann") + print(f"Building non-compact index: {non_compact_index_path}") + setup.build_non_compact_index( + passages_file, embeddings, non_compact_index_path, args.backend + ) + + # Step 5: Build FAISS flat baseline + print("\n๐Ÿ”จ Step 5: Build FAISS flat baseline") + if not args.skip_download: + baseline_path = setup.build_faiss_baseline(embeddings, valid_samples) + else: + # Load valid_samples from passages file for FAISS baseline + valid_samples = [] + with open(passages_file, encoding="utf-8") as f: + for line in f: + if line.strip(): + passage = json.loads(line) + valid_samples.append({"id": passage["id"], "caption": passage["text"]}) + baseline_path = setup.build_faiss_baseline(embeddings, valid_samples) + + # Step 6: Create evaluation queries + print("\n๐Ÿ“ Step 6: Create evaluation queries") + queries_file = setup.create_evaluation_queries(valid_samples, args.num_queries) + else: + print("โญ๏ธ Skipping index building") + baseline_path = "data/baseline/faiss_index.bin" + queries_file = setup.data_dir / "evaluation_queries.jsonl" + + print("\n๐ŸŽ‰ Setup completed successfully!") + print("๐Ÿ“Š Summary:") + if not args.skip_download: + print(f" Downloaded samples: {len(samples)}") + print(f" Valid samples with embeddings: {len(valid_samples)}") + else: + print(f" Loaded {len(embeddings)} embeddings") + + if not args.skip_build: + print(f" Compact index: {compact_index_path}") + print(f" Non-compact index: {non_compact_index_path}") + print(f" FAISS baseline: {baseline_path}") + print(f" Queries: {queries_file}") + + print("\n๐Ÿ”ง Next steps:") + print(f" Run evaluation: python evaluate_laion.py --index {compact_index_path}") + print(f" Or compare with: python evaluate_laion.py --index {non_compact_index_path}") + else: + print(" Skipped building indexes") + + except KeyboardInterrupt: + print("\nโš ๏ธ Setup interrupted by user") + exit(1) + except Exception as e: + print(f"\nโŒ Setup failed: {e}") + exit(1) + + +if __name__ == "__main__": + main() diff --git a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py index 2ec6e39..db8fa7d 100644 --- a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py +++ b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py @@ -89,6 +89,15 @@ class HNSWBuilder(LeannBackendBuilderInterface): index_file = index_dir / f"{index_prefix}.index" faiss.write_index(index, str(index_file)) + # Persist ID map so searcher can map FAISS integer labels back to passage IDs + try: + idmap_file = index_dir / f"{index_prefix}.ids.txt" + with open(idmap_file, "w", encoding="utf-8") as f: + for id_str in ids: + f.write(str(id_str) + "\n") + except Exception as e: + logger.warning(f"Failed to write ID map: {e}") + if self.is_compact: self._convert_to_csr(index_file) @@ -149,6 +158,16 @@ class 
 
@@ -149,6 +158,16 @@ class HNSWSearcher(BaseSearcher):
 
         self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)
 
+        # Load ID map if available
+        self._id_map: list[str] = []
+        try:
+            idmap_file = self.index_dir / f"{self.index_path.stem}.ids.txt"
+            if idmap_file.exists():
+                with open(idmap_file, encoding="utf-8") as f:
+                    self._id_map = [line.rstrip("\n") for line in f]
+        except Exception as e:
+            logger.warning(f"Failed to load ID map: {e}")
+
     def search(
         self,
         query: np.ndarray,
@@ -244,7 +263,17 @@ class HNSWSearcher(BaseSearcher):
             faiss.swig_ptr(labels),
             params,
         )
-        string_labels = [[str(int_label) for int_label in batch_labels] for batch_labels in labels]
+        if self._id_map:
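+            # ids.txt is written in insertion order by the builder, so FAISS
+            # integer label n corresponds to line n of the file.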
+            def map_label(x: int) -> str:
+                if 0 <= x < len(self._id_map):
+                    return self._id_map[x]
+                return str(x)
+
+            string_labels = [[map_label(int(label)) for label in batch_labels] for batch_labels in labels]
+        else:
+            string_labels = [
+                [str(int_label) for int_label in batch_labels] for batch_labels in labels
+            ]
 
         return {"labels": string_labels, "distances": distances}
diff --git a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py
index 7c472ad..d384991 100644
--- a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py
+++ b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py
@@ -94,6 +94,35 @@ def create_hnsw_embedding_server(
         f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
     )
 
+    # Attempt to load ID map (maps FAISS integer labels -> passage IDs)
+    id_map: list[str] = []
+    try:
+        meta_path = Path(passages_file)
+        base = meta_path.name
+        if base.endswith(".meta.json"):
+            base = base[: -len(".meta.json")]  # e.g., laion_index.leann
+        if base.endswith(".leann"):
+            base = base[: -len(".leann")]  # e.g., laion_index
+        idmap_file = meta_path.parent / f"{base}.ids.txt"
+        if idmap_file.exists():
+            with open(idmap_file, encoding="utf-8") as f:
+                id_map = [line.rstrip("\n") for line in f]
+            logger.info(f"Loaded ID map with {len(id_map)} entries from {idmap_file}")
+        else:
+            logger.warning(f"ID map file not found at {idmap_file}; will use raw labels")
+    except Exception as e:
+        logger.warning(f"Failed to load ID map: {e}")
+
+    def _map_node_id(nid) -> str:
+        try:
+            if id_map is not None and len(id_map) > 0 and isinstance(nid, (int, np.integer)):
+                idx = int(nid)
+                if 0 <= idx < len(id_map):
+                    return id_map[idx]
+        except Exception:
+            pass
+        return str(nid)
+
     # (legacy ZMQ thread removed; using shutdown-capable server only)
 
     def zmq_server_thread_with_shutdown(shutdown_event):
@@ -170,13 +199,14 @@ def create_hnsw_embedding_server(
                     found_indices: list[int] = []
                     for idx, nid in enumerate(node_ids):
                         try:
-                            passage_data = passages.get_passage(str(nid))
+                            passage_id = _map_node_id(nid)
+                            passage_data = passages.get_passage(passage_id)
                             txt = passage_data.get("text", "")
                             if isinstance(txt, str) and len(txt) > 0:
                                 texts.append(txt)
                                 found_indices.append(idx)
                             else:
-                                logger.error(f"Empty text for passage ID {nid}")
+                                logger.error(f"Empty text for passage ID {passage_id}")
                         except KeyError:
                             logger.error(f"Passage ID {nid} not found")
                         except Exception as e:
@@ -240,13 +270,14 @@ def create_hnsw_embedding_server(
                     found_indices: list[int] = []
                     for idx, nid in enumerate(node_ids):
                         try:
-                            passage_data = passages.get_passage(str(nid))
+                            passage_id = _map_node_id(nid)
+                            passage_data = passages.get_passage(passage_id)
                             txt = passage_data.get("text", "")
                             if isinstance(txt, str) and len(txt) > 0:
                                 texts.append(txt)
                                 found_indices.append(idx)
                             else:
-                                logger.error(f"Empty text for passage ID {nid}")
+                                logger.error(f"Empty text for passage ID {passage_id}")
                         except KeyError:
                             logger.error(f"Passage with ID {nid} not found")
                         except Exception as e:
diff --git a/packages/leann-core/src/leann/api.py b/packages/leann-core/src/leann/api.py
index ec32569..47c66b5 100644
--- a/packages/leann-core/src/leann/api.py
+++ b/packages/leann-core/src/leann/api.py
@@ -372,6 +372,14 @@ class LeannBuilder:
             is_build=True,
         )
         string_ids = [chunk["id"] for chunk in self.chunks]
+        # Persist ID map alongside index so backends that return integer labels can remap to passage IDs
+        try:
+            idmap_file = index_dir / f"{index_name[: -len('.leann')] if index_name.endswith('.leann') else index_name}.ids.txt"
+            with open(idmap_file, "w", encoding="utf-8") as f:
+                for sid in string_ids:
+                    f.write(str(sid) + "\n")
+        except Exception:
+            pass
         current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
         builder_instance = self.backend_factory.builder(**current_backend_kwargs)
         builder_instance.build(embeddings, string_ids, index_path, **current_backend_kwargs)
@@ -490,6 +498,14 @@ class LeannBuilder:
         # Build the vector index using precomputed embeddings
         string_ids = [str(id_val) for id_val in ids]
+        # Persist ID map (order == embeddings order)
+        try:
+            idmap_file = index_dir / f"{index_name[: -len('.leann')] if index_name.endswith('.leann') else index_name}.ids.txt"
+            with open(idmap_file, "w", encoding="utf-8") as f:
+                for sid in string_ids:
+                    f.write(str(sid) + "\n")
+        except Exception:
+            pass
         current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
         builder_instance = self.backend_factory.builder(**current_backend_kwargs)
         builder_instance.build(embeddings, string_ids, index_path)
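
Note for reviewers: the new `ids.txt` sidecar written by this patch is a plain newline-delimited file whose line order matches the order in which vectors were added, so a FAISS integer label `n` maps to line `n`. A minimal sketch of the round trip (file names per this patch; `labels` stands in for the output of a FAISS search):

```python
from pathlib import Path

import numpy as np


def load_id_map(index_dir: str, index_prefix: str) -> list[str]:
    """Read the newline-delimited ID map written next to the .index file."""
    idmap_file = Path(index_dir) / f"{index_prefix}.ids.txt"
    return idmap_file.read_text(encoding="utf-8").splitlines()


id_map = load_id_map("data", "laion_index")
labels = np.array([[0, 42, 7]])  # integer labels from a FAISS search
passage_ids = [
    [id_map[i] if 0 <= i < len(id_map) else str(i) for i in row] for row in labels
]
print(passage_ids)  # e.g. [['laion_000000', 'laion_000042', 'laion_000007']]
```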