fix: run faiss in subprocess to prevent kmp
This commit is contained in:
@@ -4,10 +4,12 @@ Memory comparison between Faiss HNSW and LEANN HNSW backend
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import psutil
|
import psutil
|
||||||
import gc
|
import gc
|
||||||
|
import subprocess
|
||||||
|
|
||||||
# Setup logging
|
# Setup logging
|
||||||
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
||||||
@@ -50,104 +52,40 @@ class MemoryTracker:
|
|||||||
|
|
||||||
|
|
||||||
def test_faiss_hnsw():
|
def test_faiss_hnsw():
|
||||||
"""Test Faiss HNSW Vector Store"""
|
"""Test Faiss HNSW Vector Store in subprocess"""
|
||||||
print("\n" + "=" * 50)
|
print("\n" + "=" * 50)
|
||||||
print("TESTING FAISS HNSW VECTOR STORE")
|
print("TESTING FAISS HNSW VECTOR STORE")
|
||||||
print("=" * 50)
|
print("=" * 50)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import faiss
|
result = subprocess.run([sys.executable, "examples/test_faiss_only.py"], capture_output=True, text=True, timeout=300)
|
||||||
from llama_index.core import (
|
|
||||||
SimpleDirectoryReader,
|
print(result.stdout)
|
||||||
VectorStoreIndex,
|
if result.stderr:
|
||||||
StorageContext,
|
print("Stderr:", result.stderr)
|
||||||
Settings,
|
|
||||||
)
|
if result.returncode != 0:
|
||||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
return {
|
||||||
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
"peak_memory": float("inf"),
|
||||||
except ImportError as e:
|
"error": f"Process failed with code {result.returncode}",
|
||||||
print(f"❌ Missing dependencies for Faiss test: {e}")
|
}
|
||||||
print("Please install:")
|
|
||||||
print(" pip install faiss-cpu")
|
# Parse peak memory from output
|
||||||
print(" pip install llama-index-vector-stores-faiss")
|
lines = result.stdout.split('\n')
|
||||||
print(" pip install llama-index-embeddings-huggingface")
|
peak_memory = 0.0
|
||||||
|
|
||||||
|
for line in lines:
|
||||||
|
if "Peak Memory:" in line:
|
||||||
|
peak_memory = float(line.split("Peak Memory:")[1].split("MB")[0].strip())
|
||||||
|
|
||||||
|
return {"peak_memory": peak_memory}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
return {
|
return {
|
||||||
"build_time": float("inf"),
|
|
||||||
"peak_memory": float("inf"),
|
"peak_memory": float("inf"),
|
||||||
"error": str(e),
|
"error": str(e),
|
||||||
}
|
}
|
||||||
|
|
||||||
tracker = MemoryTracker("Faiss HNSW")
|
|
||||||
|
|
||||||
# Import and setup
|
|
||||||
tracker.checkpoint("Initial")
|
|
||||||
|
|
||||||
tracker.checkpoint("After imports")
|
|
||||||
|
|
||||||
# Setup embedding model (same as LEANN)
|
|
||||||
embed_model = HuggingFaceEmbedding(model_name="facebook/contriever")
|
|
||||||
Settings.embed_model = embed_model
|
|
||||||
|
|
||||||
tracker.checkpoint("After embedding model setup")
|
|
||||||
|
|
||||||
# Create Faiss index
|
|
||||||
d = 768 # facebook/contriever embedding dimension
|
|
||||||
faiss_index = faiss.IndexHNSWFlat(d, 32) # M=32 same as LEANN
|
|
||||||
faiss_index.hnsw.efConstruction = 64 # same as LEANN complexity
|
|
||||||
|
|
||||||
tracker.checkpoint("After Faiss index creation")
|
|
||||||
|
|
||||||
# Load documents
|
|
||||||
documents = SimpleDirectoryReader(
|
|
||||||
"examples/data",
|
|
||||||
recursive=True,
|
|
||||||
encoding="utf-8",
|
|
||||||
required_exts=[".pdf", ".txt", ".md"],
|
|
||||||
).load_data()
|
|
||||||
|
|
||||||
tracker.checkpoint("After document loading")
|
|
||||||
|
|
||||||
# Create vector store and index
|
|
||||||
vector_store = FaissVectorStore(faiss_index=faiss_index)
|
|
||||||
storage_context = StorageContext.from_defaults(vector_store=vector_store)
|
|
||||||
|
|
||||||
# Build index
|
|
||||||
print("Building Faiss HNSW index...")
|
|
||||||
start_time = time.time()
|
|
||||||
index = VectorStoreIndex.from_documents(documents, storage_context=storage_context)
|
|
||||||
build_time = time.time() - start_time
|
|
||||||
|
|
||||||
tracker.checkpoint("After index building")
|
|
||||||
|
|
||||||
# Save index
|
|
||||||
index.storage_context.persist("./storage_faiss")
|
|
||||||
tracker.checkpoint("After index saving")
|
|
||||||
|
|
||||||
# Test queries
|
|
||||||
query_engine = index.as_query_engine(similarity_top_k=20)
|
|
||||||
|
|
||||||
print("Running queries...")
|
|
||||||
queries = [
|
|
||||||
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
|
|
||||||
"What is LEANN and how does it work?",
|
|
||||||
"华为诺亚方舟实验室的主要研究内容",
|
|
||||||
]
|
|
||||||
|
|
||||||
for i, query in enumerate(queries):
|
|
||||||
start_time = time.time()
|
|
||||||
response = query_engine.query(query)
|
|
||||||
query_time = time.time() - start_time
|
|
||||||
print(f"Query {i + 1} time: {query_time:.3f}s")
|
|
||||||
tracker.checkpoint(f"After query {i + 1}")
|
|
||||||
|
|
||||||
peak_memory = tracker.summary()
|
|
||||||
|
|
||||||
# Clean up
|
|
||||||
del index, vector_store, storage_context, faiss_index
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
return {"build_time": build_time, "peak_memory": peak_memory, "tracker": tracker}
|
|
||||||
|
|
||||||
|
|
||||||
def test_leann_hnsw():
|
def test_leann_hnsw():
|
||||||
"""Test LEANN HNSW Backend"""
|
"""Test LEANN HNSW Backend"""
|
||||||
@@ -213,13 +151,11 @@ def test_leann_hnsw():
|
|||||||
tracker.checkpoint("After builder setup")
|
tracker.checkpoint("After builder setup")
|
||||||
|
|
||||||
print("Building LEANN HNSW index...")
|
print("Building LEANN HNSW index...")
|
||||||
start_time = time.time()
|
|
||||||
|
|
||||||
for chunk_text in all_texts:
|
for chunk_text in all_texts:
|
||||||
builder.add_text(chunk_text)
|
builder.add_text(chunk_text)
|
||||||
|
|
||||||
builder.build_index(INDEX_PATH)
|
builder.build_index(INDEX_PATH)
|
||||||
build_time = time.time() - start_time
|
|
||||||
|
|
||||||
tracker.checkpoint("After index building")
|
tracker.checkpoint("After index building")
|
||||||
|
|
||||||
@@ -278,22 +214,37 @@ def test_leann_hnsw():
|
|||||||
|
|
||||||
for i, query in enumerate(queries):
|
for i, query in enumerate(queries):
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
response = chat.ask(
|
_ = chat.ask(query, top_k=20, recompute_beighbor_embeddings=True, complexity=32)
|
||||||
query, top_k=20, recompute_beighbor_embeddings=True, complexity=32
|
|
||||||
)
|
|
||||||
query_time = time.time() - start_time
|
query_time = time.time() - start_time
|
||||||
print(f"Query {i + 1} time: {query_time:.3f}s")
|
print(f"Query {i + 1} time: {query_time:.3f}s")
|
||||||
tracker.checkpoint(f"After query {i + 1}")
|
tracker.checkpoint(f"After query {i + 1}")
|
||||||
|
|
||||||
peak_memory = tracker.summary()
|
peak_memory = tracker.summary()
|
||||||
|
|
||||||
# Clean up
|
# Get storage size before cleanup - only index files (exclude text data)
|
||||||
del chat, builder
|
storage_size = 0
|
||||||
if INDEX_DIR.exists():
|
if INDEX_DIR.exists():
|
||||||
shutil.rmtree(INDEX_DIR)
|
total_size = 0
|
||||||
|
for dirpath, dirnames, filenames in os.walk(str(INDEX_DIR)):
|
||||||
|
for filename in filenames:
|
||||||
|
# Only count actual index files, skip text data and backups
|
||||||
|
if filename.endswith(('.old', '.tmp', '.bak', '.jsonl', '.json')):
|
||||||
|
continue
|
||||||
|
# Count .index, .idx, .map files (actual index structures)
|
||||||
|
if filename.endswith(('.index', '.idx', '.map')):
|
||||||
|
filepath = os.path.join(dirpath, filename)
|
||||||
|
total_size += os.path.getsize(filepath)
|
||||||
|
storage_size = total_size / (1024 * 1024) # Convert to MB
|
||||||
|
|
||||||
|
# Clean up (but keep directory for storage size comparison)
|
||||||
|
del chat, builder
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
return {"build_time": build_time, "peak_memory": peak_memory, "tracker": tracker}
|
return {
|
||||||
|
"peak_memory": peak_memory,
|
||||||
|
"storage_size": storage_size,
|
||||||
|
"tracker": tracker,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
@@ -316,36 +267,61 @@ def main():
|
|||||||
print("FINAL COMPARISON")
|
print("FINAL COMPARISON")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Get storage sizes
|
||||||
|
faiss_storage_size = 0
|
||||||
|
leann_storage_size = leann_results.get("storage_size", 0)
|
||||||
|
|
||||||
|
# Get Faiss storage size using Python
|
||||||
|
if os.path.exists("./storage_faiss"):
|
||||||
|
total_size = 0
|
||||||
|
for dirpath, dirnames, filenames in os.walk("./storage_faiss"):
|
||||||
|
for filename in filenames:
|
||||||
|
filepath = os.path.join(dirpath, filename)
|
||||||
|
total_size += os.path.getsize(filepath)
|
||||||
|
faiss_storage_size = total_size / (1024 * 1024) # Convert to MB
|
||||||
|
|
||||||
|
# LEANN storage size is already captured in leann_results
|
||||||
|
|
||||||
print(f"Faiss HNSW:")
|
print(f"Faiss HNSW:")
|
||||||
if "error" in faiss_results:
|
if "error" in faiss_results:
|
||||||
print(f" ❌ Failed: {faiss_results['error']}")
|
print(f" ❌ Failed: {faiss_results['error']}")
|
||||||
else:
|
else:
|
||||||
print(f" Build Time: {faiss_results['build_time']:.3f}s")
|
|
||||||
print(f" Peak Memory: {faiss_results['peak_memory']:.1f} MB")
|
print(f" Peak Memory: {faiss_results['peak_memory']:.1f} MB")
|
||||||
|
print(f" Storage Size: {faiss_storage_size:.1f} MB")
|
||||||
|
|
||||||
print(f"\nLEANN HNSW:")
|
print(f"\nLEANN HNSW:")
|
||||||
print(f" Build Time: {leann_results['build_time']:.3f}s")
|
|
||||||
print(f" Peak Memory: {leann_results['peak_memory']:.1f} MB")
|
print(f" Peak Memory: {leann_results['peak_memory']:.1f} MB")
|
||||||
|
print(f" Storage Size: {leann_storage_size:.1f} MB")
|
||||||
|
|
||||||
# Calculate improvements only if Faiss test succeeded
|
# Calculate improvements only if Faiss test succeeded
|
||||||
if "error" not in faiss_results:
|
if "error" not in faiss_results:
|
||||||
time_ratio = faiss_results["build_time"] / leann_results["build_time"]
|
|
||||||
memory_ratio = faiss_results["peak_memory"] / leann_results["peak_memory"]
|
memory_ratio = faiss_results["peak_memory"] / leann_results["peak_memory"]
|
||||||
|
|
||||||
print(f"\nLEANN vs Faiss:")
|
print(f"\nLEANN vs Faiss:")
|
||||||
print(
|
print(f" Memory Usage: {memory_ratio:.1f}x less")
|
||||||
f" Build Time: {time_ratio:.2f}x {'faster' if time_ratio > 1 else 'slower'}"
|
|
||||||
)
|
# Storage comparison - be clear about which is larger
|
||||||
print(
|
if leann_storage_size > faiss_storage_size:
|
||||||
f" Memory Usage: {memory_ratio:.2f}x {'less' if memory_ratio > 1 else 'more'}"
|
storage_ratio = leann_storage_size / faiss_storage_size
|
||||||
)
|
print(f" Storage Size: {storage_ratio:.1f}x larger (LEANN uses more storage)")
|
||||||
|
elif faiss_storage_size > leann_storage_size:
|
||||||
|
storage_ratio = faiss_storage_size / leann_storage_size
|
||||||
|
print(f" Storage Size: {storage_ratio:.1f}x smaller (LEANN uses less storage)")
|
||||||
|
else:
|
||||||
|
print(f" Storage Size: similar")
|
||||||
|
|
||||||
print(
|
print(f"\nSavings:")
|
||||||
f"\nMemory Savings: {faiss_results['peak_memory'] - leann_results['peak_memory']:.1f} MB"
|
memory_saving = faiss_results['peak_memory'] - leann_results['peak_memory']
|
||||||
)
|
storage_diff = faiss_storage_size - leann_storage_size
|
||||||
|
print(f" Memory: {memory_saving:.1f} MB")
|
||||||
|
if storage_diff >= 0:
|
||||||
|
print(f" Storage: {storage_diff:.1f} MB saved")
|
||||||
|
else:
|
||||||
|
print(f" Storage: {abs(storage_diff):.1f} MB additional used")
|
||||||
else:
|
else:
|
||||||
print(f"\n✅ LEANN HNSW ran successfully!")
|
print(f"\n✅ LEANN HNSW ran successfully!")
|
||||||
print(f"📊 LEANN Memory Usage: {leann_results['peak_memory']:.1f} MB")
|
print(f"📊 LEANN Memory Usage: {leann_results['peak_memory']:.1f} MB")
|
||||||
|
print(f"📊 LEANN Storage Size: {leann_storage_size:.1f} MB")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -485,8 +485,6 @@ def create_hnsw_embedding_server(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
request_payload = msgpack.unpackb(message_bytes)
|
request_payload = msgpack.unpackb(message_bytes)
|
||||||
print(f"DEBUG: Raw request_payload: {request_payload}")
|
|
||||||
print(f"DEBUG: request_payload type: {type(request_payload)}")
|
|
||||||
if isinstance(request_payload, list):
|
if isinstance(request_payload, list):
|
||||||
print(f"DEBUG: request_payload length: {len(request_payload)}")
|
print(f"DEBUG: request_payload length: {len(request_payload)}")
|
||||||
for i, item in enumerate(request_payload):
|
for i, item in enumerate(request_payload):
|
||||||
|
|||||||
Reference in New Issue
Block a user