feat: chat on mps
@@ -86,7 +86,7 @@ async def main(args):
     query = "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"

     print(f"You: {query}")
-    chat_response = chat.ask(query, top_k=20, recompute_beighbor_embeddings=True, complexity=32, beam_width=1)
+    chat_response = chat.ask(query, top_k=20, recompute_beighbor_embeddings=True, complexity=32)
     print(f"Leann: {chat_response}")

 if __name__ == "__main__":
@@ -147,8 +147,8 @@ class HNSWSearcher(BaseSearcher):

         params = faiss.SearchParametersHNSW()
         params.zmq_port = kwargs.get("zmq_port", 5557)
-        params.efSearch = kwargs.get("ef", 128)
-        params.beam_size = 2
+        params.efSearch = kwargs.get("complexity", 32)
+        params.beam_size = kwargs.get("beam_width", 1)

         batch_size = query.shape[0]
         distances = np.empty((batch_size, top_k), dtype=np.float32)
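For orientation, here is a self-contained sketch of how these search parameters drive a FAISS query. It uses only stock `faiss` (>= 1.7.3) API; the `zmq_port` and `beam_size` fields above are LEANN-fork extensions that upstream FAISS does not have:

```python
import faiss
import numpy as np

# Build a small HNSW index and search it with an explicit
# SearchParametersHNSW, mirroring the kwargs mapping in the hunk above.
dim = 64
xb = np.random.rand(1000, dim).astype(np.float32)
index = faiss.IndexHNSWFlat(dim, 32)  # graph degree M=32
index.add(xb)

params = faiss.SearchParametersHNSW()
params.efSearch = 64  # the field that kwargs.get("complexity", 32) fills in

query = np.random.rand(1, dim).astype(np.float32)
distances, labels = index.search(query, 10, params=params)
print(labels[0])  # ids of the 10 approximate nearest neighbors
```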
@@ -18,15 +18,43 @@ class LLMInterface(ABC):
     @abstractmethod
     def ask(self, prompt: str, **kwargs) -> str:
         """
-        Sends a prompt to the LLM and returns the generated text.
+        Additional keyword arguments (kwargs) for advanced search customization. Example usage:
+
+            chat.ask(
+                "What is ANN?",
+                top_k=10,
+                complexity=64,
+                beam_width=8,
+                USE_DEFERRED_FETCH=True,
+                skip_search_reorder=True,
+                recompute_beighbor_embeddings=True,
+                dedup_node_dis=True,
+                prune_ratio=0.1,
+                batch_recompute=True,
+                global_pruning=True
+            )

-        Args:
-            prompt: The input prompt for the LLM.
-            **kwargs: Additional keyword arguments for the LLM backend.
-
-        Returns:
-            The response string from the LLM.
+        Supported kwargs:
+            - complexity (int): Search complexity parameter (default: 32)
+            - beam_width (int): Beam width for search (default: 4)
+            - USE_DEFERRED_FETCH (bool): Enable deferred fetch mode (default: False)
+            - skip_search_reorder (bool): Skip search reorder step (default: False)
+            - recompute_beighbor_embeddings (bool): Enable ZMQ embedding server for neighbor recomputation (default: False)
+            - dedup_node_dis (bool): Deduplicate nodes by distance (default: False)
+            - prune_ratio (float): Pruning ratio for search (default: 0.0)
+            - batch_recompute (bool): Enable batch recomputation (default: False)
+            - global_pruning (bool): Enable global pruning (default: False)
         """
+        # """
+        # Sends a prompt to the LLM and returns the generated text.
+
+        # Args:
+        #     prompt: The input prompt for the LLM.
+        #     **kwargs: Additional keyword arguments for the LLM backend.
+
+        # Returns:
+        #     The response string from the LLM.
+        # """
         pass


 class OllamaChat(LLMInterface):
@@ -82,10 +110,22 @@ class HFChat(LLMInterface):
         logger.info(f"Initializing HFChat with model='{model_name}'")
         try:
             from transformers import pipeline
+            import torch
         except ImportError:
-            raise ImportError("The 'transformers' library is required for Hugging Face models. Please install it with 'pip install transformers'.")
+            raise ImportError("The 'transformers' and 'torch' libraries are required for Hugging Face models. Please install them with 'pip install transformers torch'.")

-        self.pipeline = pipeline("text-generation", model=model_name)
+        # Auto-detect device
+        if torch.cuda.is_available():
+            device = "cuda"
+            logger.info("CUDA is available. Using GPU.")
+        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+            device = "mps"
+            logger.info("MPS is available. Using Apple Silicon GPU.")
+        else:
+            device = "cpu"
+            logger.info("No GPU detected. Using CPU.")
+
+        self.pipeline = pipeline("text-generation", model=model_name, device=device)

     def ask(self, prompt: str, **kwargs) -> str:
         # Sensible defaults for text generation
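The hunk above is the heart of this commit; a hedged usage sketch follows. The import path and model name are illustrative assumptions, not taken from the diff:

```python
# Hypothetical usage: module path and model name are placeholders.
from leann.chat import HFChat  # assumed import path

# On Apple Silicon the constructor now logs "MPS is available. Using Apple
# Silicon GPU." and places the transformers pipeline on the mps device.
chat = HFChat(model_name="Qwen/Qwen2-0.5B-Instruct")
print(chat.ask("What is approximate nearest neighbor search?"))
```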
@@ -1,68 +0,0 @@
# HNSW Index Storage Optimization

This document explains the storage optimization features available in the HNSW backend.

## Storage Modes

The HNSW backend supports two orthogonal optimization techniques:

### 1. CSR Compression (`is_compact=True`)
- Converts the graph structure from standard format to Compressed Sparse Row (CSR) format
- Reduces memory overhead from graph adjacency storage
- Maintains all embedding data for direct access
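To make the CSR layout concrete, here is a minimal sketch (illustrative, not LEANN code): the per-node adjacency lists become one flat `neighbors` array plus an `offsets` array, removing fixed-size per-node slots:

```python
import numpy as np

# Toy graph: node i's neighbors are adjacency[i].
adjacency = [[1, 2], [0, 2, 3], [0, 1], [1]]

# CSR form: neighbors of node i live at neighbors[offsets[i]:offsets[i + 1]].
offsets = np.zeros(len(adjacency) + 1, dtype=np.int64)
offsets[1:] = np.cumsum([len(nbrs) for nbrs in adjacency])
neighbors = np.array([n for nbrs in adjacency for n in nbrs], dtype=np.int32)

for i, nbrs in enumerate(adjacency):
    assert list(neighbors[offsets[i]:offsets[i + 1]]) == nbrs
print(offsets)    # [0 2 5 7 8]
print(neighbors)  # [1 2 0 2 3 0 1 1]
```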
### 2. Embedding Pruning (`is_recompute=True`)
- Removes embedding vectors from the index file
- Replaces them with a NULL storage marker
- Requires recomputation via embedding server during search
- Must be used with `is_compact=True` for efficiency
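Conceptually, a pruned index keeps only the graph; vectors are re-derived from passage text whenever the search visits a node. A toy stand-in for that flow (the real backend queries a ZMQ embedding server; this sketch recomputes locally with a placeholder embedder):

```python
import numpy as np

passages = {0: "first passage", 1: "second passage"}  # stored outside the index

def embed(text: str) -> np.ndarray:
    # Placeholder for the real sentence-transformer call.
    rng = np.random.default_rng(abs(hash(text)) % (2**32))
    return rng.random(384, dtype=np.float32)

def score(node_id: int, query_vec: np.ndarray) -> float:
    vec = embed(passages[node_id])        # recomputed on demand, never stored
    return float(np.dot(vec, query_vec))  # MIPS score

q = embed("a query")
print(score(0, q), score(1, q))
```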
## Performance Impact

**Storage Reduction (100 vectors, 384 dimensions):**
```
Standard format: 168 KB (embeddings + graph)
CSR only:        160 KB (embeddings + compressed graph)
CSR + Pruned:      6 KB (compressed graph only)
```

**Key Benefits:**
- **CSR compression**: ~5% size reduction from graph optimization
- **Embedding pruning**: ~95% size reduction by removing embeddings
- **Combined**: Up to 96% total storage reduction
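These figures line up with back-of-the-envelope arithmetic; the raw embeddings account for nearly the entire index:

```python
# 100 vectors x 384 dims x 4 bytes (float32)
embedding_bytes = 100 * 384 * 4
print(f"{embedding_bytes / 1024:.1f} KB")  # 150.0 KB of the ~160 KB CSR index,
                                           # so pruning leaves only ~6 KB of graph
```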
## Usage

```python
# Standard format (largest)
builder = LeannBuilder(
    backend_name="hnsw",
    is_compact=False,
    is_recompute=False
)

# CSR compressed (medium)
builder = LeannBuilder(
    backend_name="hnsw",
    is_compact=True,
    is_recompute=False
)

# CSR + Pruned (smallest, requires embedding server)
builder = LeannBuilder(
    backend_name="hnsw",
    is_compact=True,   # Required for pruning
    is_recompute=True  # Default: enabled
)
```

## Trade-offs

| Mode | Storage | Search Speed | Memory Usage | Setup Complexity |
|------|---------|--------------|--------------|------------------|
| Standard | Largest | Fastest | Highest | Simple |
| CSR | Medium | Fast | Medium | Simple |
| CSR + Pruned | Smallest | Slower\* | Lowest | Complex\*\* |

\*Requires network round-trip to embedding server for recomputation
\*\*Needs embedding server and passages file for search
@@ -1,156 +0,0 @@
#!/usr/bin/env python3
"""
Sanity check script to verify HNSW index pruning effectiveness.
Tests the difference in file sizes between pruned and non-pruned indices.
"""

import os
import sys
import tempfile
from pathlib import Path
import numpy as np
import json

# Add the project root to the Python path
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))

# Import backend packages to trigger plugin registration
import leann_backend_hnsw

from leann.api import LeannBuilder


def create_sample_documents(num_docs=1000):
    """Create sample documents for testing"""
    documents = []
    for i in range(num_docs):
        documents.append(f"Sample document {i} with some random text content for testing purposes.")
    return documents


def build_index(documents, output_dir, is_recompute=True):
    """Build HNSW index with specified recompute setting"""
    index_path = os.path.join(output_dir, "test_index.hnsw")

    builder = LeannBuilder(
        backend_name="hnsw",
        embedding_model="sentence-transformers/all-MiniLM-L6-v2",
        M=16,
        efConstruction=100,
        distance_metric="mips",
        is_compact=True,
        is_recompute=is_recompute
    )

    for doc in documents:
        builder.add_text(doc)

    builder.build_index(index_path)

    return index_path


def get_file_size(filepath):
    """Get file size in bytes"""
    return os.path.getsize(filepath)


def main():
    print("🔍 HNSW Pruning Sanity Check")
    print("=" * 50)

    # Create sample data
    print("📊 Creating sample documents...")
    documents = create_sample_documents(num_docs=1000)
    print(f"   Number of documents: {len(documents)}")

    with tempfile.TemporaryDirectory() as temp_dir:
        print(f"📁 Working in temporary directory: {temp_dir}")

        # Build index with pruning (is_recompute=True)
        print("\n🔨 Building index with pruning enabled (is_recompute=True)...")
        pruned_dir = os.path.join(temp_dir, "pruned")
        os.makedirs(pruned_dir, exist_ok=True)

        pruned_index_path = build_index(documents, pruned_dir, is_recompute=True)
        # Check what files were actually created
        print(f"   Looking for index files at: {pruned_index_path}")
        import glob
        files = glob.glob(f"{pruned_index_path}*")
        print(f"   Found files: {files}")

        # Try to find the actual index file
        if os.path.exists(f"{pruned_index_path}.index"):
            pruned_index_file = f"{pruned_index_path}.index"
        else:
            # Look for any .index file in the directory
            index_files = glob.glob(f"{pruned_dir}/*.index")
            if index_files:
                pruned_index_file = index_files[0]
            else:
                raise FileNotFoundError(f"No .index file found in {pruned_dir}")

        pruned_size = get_file_size(pruned_index_file)
        print(f"   ✅ Pruned index built successfully")
        print(f"   📏 Pruned index size: {pruned_size:,} bytes ({pruned_size/1024:.1f} KB)")

        # Build index without pruning (is_recompute=False)
        print("\n🔨 Building index without pruning (is_recompute=False)...")
        non_pruned_dir = os.path.join(temp_dir, "non_pruned")
        os.makedirs(non_pruned_dir, exist_ok=True)

        non_pruned_index_path = build_index(documents, non_pruned_dir, is_recompute=False)
        # Check what files were actually created
        print(f"   Looking for index files at: {non_pruned_index_path}")
        files = glob.glob(f"{non_pruned_index_path}*")
        print(f"   Found files: {files}")

        # Try to find the actual index file
        if os.path.exists(f"{non_pruned_index_path}.index"):
            non_pruned_index_file = f"{non_pruned_index_path}.index"
        else:
            # Look for any .index file in the directory
            index_files = glob.glob(f"{non_pruned_dir}/*.index")
            if index_files:
                non_pruned_index_file = index_files[0]
            else:
                raise FileNotFoundError(f"No .index file found in {non_pruned_dir}")

        non_pruned_size = get_file_size(non_pruned_index_file)
        print(f"   ✅ Non-pruned index built successfully")
        print(f"   📏 Non-pruned index size: {non_pruned_size:,} bytes ({non_pruned_size/1024:.1f} KB)")

        # Compare sizes
        print("\n📊 Comparison Results:")
        print("=" * 30)
        size_diff = non_pruned_size - pruned_size
        size_ratio = pruned_size / non_pruned_size if non_pruned_size > 0 else 0
        reduction_percent = (1 - size_ratio) * 100

        print(f"Non-pruned index: {non_pruned_size:,} bytes ({non_pruned_size/1024:.1f} KB)")
        print(f"Pruned index: {pruned_size:,} bytes ({pruned_size/1024:.1f} KB)")
        print(f"Size difference: {size_diff:,} bytes ({size_diff/1024:.1f} KB)")
        print(f"Size ratio: {size_ratio:.3f}")
        print(f"Size reduction: {reduction_percent:.1f}%")

        # Verify pruning effectiveness
        print("\n🔍 Verification:")
        if size_diff > 0:
            print("   ✅ Pruning is effective - pruned index is smaller")
            if reduction_percent > 10:
                print(f"   ✅ Significant size reduction: {reduction_percent:.1f}%")
            else:
                print(f"   ⚠️ Small size reduction: {reduction_percent:.1f}%")
        else:
            print("   ❌ Pruning appears ineffective - no size reduction")

        # Check if passages files were created
        pruned_passages = f"{pruned_index_path}.passages.json"
        non_pruned_passages = f"{non_pruned_index_path}.passages.json"

        print(f"\n📄 Passages files:")
        print(f"   Pruned passages file exists: {os.path.exists(pruned_passages)}")
        print(f"   Non-pruned passages file exists: {os.path.exists(non_pruned_passages)}")

    return True


if __name__ == "__main__":
    success = main()
    sys.exit(0 if success else 1)