From ec5e9ac33b7a2fb27cb8afb7cb7db1c4ecbd7a2c Mon Sep 17 00:00:00 2001 From: Andy Lee Date: Sat, 12 Jul 2025 06:07:43 +0000 Subject: [PATCH] feat: chat on mps --- examples/main_cli_example.py | 2 +- .../leann_backend_hnsw/hnsw_backend.py | 4 +- packages/leann-core/src/leann/chat.py | 62 +++++-- tests/sanity_checks/README_hnsw_pruning.md | 68 -------- tests/sanity_checks/test_hnsw_pruning.py | 156 ------------------ 5 files changed, 54 insertions(+), 238 deletions(-) delete mode 100644 tests/sanity_checks/README_hnsw_pruning.md delete mode 100644 tests/sanity_checks/test_hnsw_pruning.py diff --git a/examples/main_cli_example.py b/examples/main_cli_example.py index 7509a29..fc87cfc 100644 --- a/examples/main_cli_example.py +++ b/examples/main_cli_example.py @@ -86,7 +86,7 @@ async def main(args): query = "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发" print(f"You: {query}") - chat_response = chat.ask(query, top_k=20, recompute_beighbor_embeddings=True, complexity=32, beam_width=1) + chat_response = chat.ask(query, top_k=20, recompute_beighbor_embeddings=True, complexity=32) print(f"Leann: {chat_response}") if __name__ == "__main__": diff --git a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py index 98b96ef..02d91c9 100644 --- a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py +++ b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py @@ -147,8 +147,8 @@ class HNSWSearcher(BaseSearcher): params = faiss.SearchParametersHNSW() params.zmq_port = kwargs.get("zmq_port", 5557) - params.efSearch = kwargs.get("ef", 128) - params.beam_size = 2 + params.efSearch = kwargs.get("complexity", 32) + params.beam_size = kwargs.get("beam_width", 1) batch_size = query.shape[0] distances = np.empty((batch_size, top_k), dtype=np.float32) diff --git a/packages/leann-core/src/leann/chat.py b/packages/leann-core/src/leann/chat.py index 1a9bc9b..df83a57 100644 --- a/packages/leann-core/src/leann/chat.py +++ b/packages/leann-core/src/leann/chat.py @@ -18,15 +18,43 @@ class LLMInterface(ABC): @abstractmethod def ask(self, prompt: str, **kwargs) -> str: """ - Sends a prompt to the LLM and returns the generated text. - - Args: - prompt: The input prompt for the LLM. - **kwargs: Additional keyword arguments for the LLM backend. - - Returns: - The response string from the LLM. + Additional keyword arguments (kwargs) for advanced search customization. Example usage: + chat.ask( + "What is ANN?", + top_k=10, + complexity=64, + beam_width=8, + USE_DEFERRED_FETCH=True, + skip_search_reorder=True, + recompute_beighbor_embeddings=True, + dedup_node_dis=True, + prune_ratio=0.1, + batch_recompute=True, + global_pruning=True + ) + + Supported kwargs: + - complexity (int): Search complexity parameter (default: 32) + - beam_width (int): Beam width for search (default: 4) + - USE_DEFERRED_FETCH (bool): Enable deferred fetch mode (default: False) + - skip_search_reorder (bool): Skip search reorder step (default: False) + - recompute_beighbor_embeddings (bool): Enable ZMQ embedding server for neighbor recomputation (default: False) + - dedup_node_dis (bool): Deduplicate nodes by distance (default: False) + - prune_ratio (float): Pruning ratio for search (default: 0.0) + - batch_recompute (bool): Enable batch recomputation (default: False) + - global_pruning (bool): Enable global pruning (default: False) """ + + # """ + # Sends a prompt to the LLM and returns the generated text. + + # Args: + # prompt: The input prompt for the LLM. + # **kwargs: Additional keyword arguments for the LLM backend. + + # Returns: + # The response string from the LLM. + # """ pass class OllamaChat(LLMInterface): @@ -82,10 +110,22 @@ class HFChat(LLMInterface): logger.info(f"Initializing HFChat with model='{model_name}'") try: from transformers import pipeline + import torch except ImportError: - raise ImportError("The 'transformers' library is required for Hugging Face models. Please install it with 'pip install transformers'.") - - self.pipeline = pipeline("text-generation", model=model_name) + raise ImportError("The 'transformers' and 'torch' libraries are required for Hugging Face models. Please install them with 'pip install transformers torch'.") + + # Auto-detect device + if torch.cuda.is_available(): + device = "cuda" + logger.info("CUDA is available. Using GPU.") + elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + device = "mps" + logger.info("MPS is available. Using Apple Silicon GPU.") + else: + device = "cpu" + logger.info("No GPU detected. Using CPU.") + + self.pipeline = pipeline("text-generation", model=model_name, device=device) def ask(self, prompt: str, **kwargs) -> str: # Sensible defaults for text generation diff --git a/tests/sanity_checks/README_hnsw_pruning.md b/tests/sanity_checks/README_hnsw_pruning.md deleted file mode 100644 index f5f6d94..0000000 --- a/tests/sanity_checks/README_hnsw_pruning.md +++ /dev/null @@ -1,68 +0,0 @@ -# HNSW Index Storage Optimization - -This document explains the storage optimization features available in the HNSW backend. - -## Storage Modes - -The HNSW backend supports two orthogonal optimization techniques: - -### 1. CSR Compression (`is_compact=True`) -- Converts the graph structure from standard format to Compressed Sparse Row (CSR) format -- Reduces memory overhead from graph adjacency storage -- Maintains all embedding data for direct access - -### 2. Embedding Pruning (`is_recompute=True`) -- Removes embedding vectors from the index file -- Replaces them with a NULL storage marker -- Requires recomputation via embedding server during search -- Must be used with `is_compact=True` for efficiency - -## Performance Impact - -**Storage Reduction (100 vectors, 384 dimensions):** -``` -Standard format: 168 KB (embeddings + graph) -CSR only: 160 KB (embeddings + compressed graph) -CSR + Pruned: 6 KB (compressed graph only) -``` - -**Key Benefits:** -- **CSR compression**: ~5% size reduction from graph optimization -- **Embedding pruning**: ~95% size reduction by removing embeddings -- **Combined**: Up to 96% total storage reduction - -## Usage - -```python -# Standard format (largest) -builder = LeannBuilder( - backend_name="hnsw", - is_compact=False, - is_recompute=False -) - -# CSR compressed (medium) -builder = LeannBuilder( - backend_name="hnsw", - is_compact=True, - is_recompute=False -) - -# CSR + Pruned (smallest, requires embedding server) -builder = LeannBuilder( - backend_name="hnsw", - is_compact=True, # Required for pruning - is_recompute=True # Default: enabled -) -``` - -## Trade-offs - -| Mode | Storage | Search Speed | Memory Usage | Setup Complexity | -|------|---------|--------------|--------------|------------------| -| Standard | Largest | Fastest | Highest | Simple | -| CSR | Medium | Fast | Medium | Simple | -| CSR + Pruned | Smallest | Slower* | Lowest | Complex** | - -*Requires network round-trip to embedding server for recomputation -**Needs embedding server and passages file for search \ No newline at end of file diff --git a/tests/sanity_checks/test_hnsw_pruning.py b/tests/sanity_checks/test_hnsw_pruning.py deleted file mode 100644 index 8d17626..0000000 --- a/tests/sanity_checks/test_hnsw_pruning.py +++ /dev/null @@ -1,156 +0,0 @@ -#!/usr/bin/env python3 -""" -Sanity check script to verify HNSW index pruning effectiveness. -Tests the difference in file sizes between pruned and non-pruned indices. -""" - -import os -import sys -import tempfile -from pathlib import Path -import numpy as np -import json - -# Add the project root to the Python path -project_root = Path(__file__).parent.parent.parent -sys.path.insert(0, str(project_root)) - -# Import backend packages to trigger plugin registration -import leann_backend_hnsw - -from leann.api import LeannBuilder - -def create_sample_documents(num_docs=1000): - """Create sample documents for testing""" - documents = [] - for i in range(num_docs): - documents.append(f"Sample document {i} with some random text content for testing purposes.") - return documents - -def build_index(documents, output_dir, is_recompute=True): - """Build HNSW index with specified recompute setting""" - index_path = os.path.join(output_dir, "test_index.hnsw") - - builder = LeannBuilder( - backend_name="hnsw", - embedding_model="sentence-transformers/all-MiniLM-L6-v2", - M=16, - efConstruction=100, - distance_metric="mips", - is_compact=True, - is_recompute=is_recompute - ) - - for doc in documents: - builder.add_text(doc) - - builder.build_index(index_path) - - return index_path - -def get_file_size(filepath): - """Get file size in bytes""" - return os.path.getsize(filepath) - -def main(): - print("🔍 HNSW Pruning Sanity Check") - print("=" * 50) - - # Create sample data - print("📊 Creating sample documents...") - documents = create_sample_documents(num_docs=1000) - print(f" Number of documents: {len(documents)}") - - with tempfile.TemporaryDirectory() as temp_dir: - print(f"📁 Working in temporary directory: {temp_dir}") - - # Build index with pruning (is_recompute=True) - print("\n🔨 Building index with pruning enabled (is_recompute=True)...") - pruned_dir = os.path.join(temp_dir, "pruned") - os.makedirs(pruned_dir, exist_ok=True) - - pruned_index_path = build_index(documents, pruned_dir, is_recompute=True) - # Check what files were actually created - print(f" Looking for index files at: {pruned_index_path}") - import glob - files = glob.glob(f"{pruned_index_path}*") - print(f" Found files: {files}") - - # Try to find the actual index file - if os.path.exists(f"{pruned_index_path}.index"): - pruned_index_file = f"{pruned_index_path}.index" - else: - # Look for any .index file in the directory - index_files = glob.glob(f"{pruned_dir}/*.index") - if index_files: - pruned_index_file = index_files[0] - else: - raise FileNotFoundError(f"No .index file found in {pruned_dir}") - - pruned_size = get_file_size(pruned_index_file) - print(f" ✅ Pruned index built successfully") - print(f" 📏 Pruned index size: {pruned_size:,} bytes ({pruned_size/1024:.1f} KB)") - - # Build index without pruning (is_recompute=False) - print("\n🔨 Building index without pruning (is_recompute=False)...") - non_pruned_dir = os.path.join(temp_dir, "non_pruned") - os.makedirs(non_pruned_dir, exist_ok=True) - - non_pruned_index_path = build_index(documents, non_pruned_dir, is_recompute=False) - # Check what files were actually created - print(f" Looking for index files at: {non_pruned_index_path}") - files = glob.glob(f"{non_pruned_index_path}*") - print(f" Found files: {files}") - - # Try to find the actual index file - if os.path.exists(f"{non_pruned_index_path}.index"): - non_pruned_index_file = f"{non_pruned_index_path}.index" - else: - # Look for any .index file in the directory - index_files = glob.glob(f"{non_pruned_dir}/*.index") - if index_files: - non_pruned_index_file = index_files[0] - else: - raise FileNotFoundError(f"No .index file found in {non_pruned_dir}") - - non_pruned_size = get_file_size(non_pruned_index_file) - print(f" ✅ Non-pruned index built successfully") - print(f" 📏 Non-pruned index size: {non_pruned_size:,} bytes ({non_pruned_size/1024:.1f} KB)") - - # Compare sizes - print("\n📊 Comparison Results:") - print("=" * 30) - size_diff = non_pruned_size - pruned_size - size_ratio = pruned_size / non_pruned_size if non_pruned_size > 0 else 0 - reduction_percent = (1 - size_ratio) * 100 - - print(f"Non-pruned index: {non_pruned_size:,} bytes ({non_pruned_size/1024:.1f} KB)") - print(f"Pruned index: {pruned_size:,} bytes ({pruned_size/1024:.1f} KB)") - print(f"Size difference: {size_diff:,} bytes ({size_diff/1024:.1f} KB)") - print(f"Size ratio: {size_ratio:.3f}") - print(f"Size reduction: {reduction_percent:.1f}%") - - # Verify pruning effectiveness - print("\n🔍 Verification:") - if size_diff > 0: - print(" ✅ Pruning is effective - pruned index is smaller") - if reduction_percent > 10: - print(f" ✅ Significant size reduction: {reduction_percent:.1f}%") - else: - print(f" ⚠️ Small size reduction: {reduction_percent:.1f}%") - else: - print(" ❌ Pruning appears ineffective - no size reduction") - - # Check if passages files were created - pruned_passages = f"{pruned_index_path}.passages.json" - non_pruned_passages = f"{non_pruned_index_path}.passages.json" - - print(f"\n📄 Passages files:") - print(f" Pruned passages file exists: {os.path.exists(pruned_passages)}") - print(f" Non-pruned passages file exists: {os.path.exists(non_pruned_passages)}") - - return True - -if __name__ == "__main__": - success = main() - sys.exit(0 if success else 1) \ No newline at end of file