fix: faster embed
issue_159.py (81 changed lines)
@@ -9,120 +9,125 @@ Configuration:
 - backend: hnsw
 """
 
-import time
 import os
+import time
 from pathlib import Path
 
 from leann.api import LeannBuilder, LeannSearcher
 
+os.environ["LEANN_LOG_LEVEL"] = "DEBUG"
 
 # Configuration matching the issue
 INDEX_PATH = "./test_issue_159.leann"
 EMBEDDING_MODEL = "BAAI/bge-large-zh-v1.5"
 BACKEND_NAME = "hnsw"
-BEAM_WIDTH = 10  # Note: beam_width is mainly for DiskANN, not HNSW
 
 
 def generate_test_data(num_chunks=90000, chunk_size=2000):
     """Generate test data similar to 180MB text (~90K chunks)"""
     # Each chunk is approximately 2000 characters
     # 90K chunks * 2000 chars ≈ 180MB
     chunks = []
-    base_text = "这是一个测试文档。LEANN是一个创新的向量数据库,通过图基选择性重计算实现97%的存储节省。"
+    base_text = (
+        "这是一个测试文档。LEANN是一个创新的向量数据库,通过图基选择性重计算实现97%的存储节省。"
+    )
 
     for i in range(num_chunks):
         chunk = f"{base_text} 文档编号: {i}. " * (chunk_size // len(base_text) + 1)
         chunks.append(chunk[:chunk_size])
 
     return chunks
 
 
 def test_search_performance():
     """Test search performance with different configurations"""
     print("=" * 80)
     print("Testing LEANN Search Performance (Issue #159)")
     print("=" * 80)
 
-    # Check if index exists - skip build if it does
-    index_path = Path(INDEX_PATH)
-    if True:
+    meta_path = Path(f"{INDEX_PATH}.meta.json")
+    if meta_path.exists():
         print(f"\n✓ Index already exists at {INDEX_PATH}")
         print(" Skipping build phase. Delete the index to rebuild.")
     else:
-        print(f"\n📦 Building index...")
+        print("\n📦 Building index...")
         print(f" Backend: {BACKEND_NAME}")
         print(f" Embedding Model: {EMBEDDING_MODEL}")
-        print(f" Generating test data (~90K chunks, ~180MB)...")
+        print(" Generating test data (~90K chunks, ~180MB)...")
 
         chunks = generate_test_data(num_chunks=90000)
         print(f" Generated {len(chunks)} chunks")
-        print(f" Total text size: {sum(len(c) for c in chunks) / (1024*1024):.2f} MB")
+        print(f" Total text size: {sum(len(c) for c in chunks) / (1024 * 1024):.2f} MB")
 
         builder = LeannBuilder(
             backend_name=BACKEND_NAME,
             embedding_model=EMBEDDING_MODEL,
         )
 
-        print(f" Adding chunks to builder...")
+        print(" Adding chunks to builder...")
         start_time = time.time()
         for i, chunk in enumerate(chunks):
             builder.add_text(chunk)
             if (i + 1) % 10000 == 0:
                 print(f" Added {i + 1}/{len(chunks)} chunks...")
 
-        print(f" Building index...")
+        print(" Building index...")
         build_start = time.time()
         builder.build_index(INDEX_PATH)
         build_time = time.time() - build_start
         print(f" ✓ Index built in {build_time:.2f} seconds")
 
     # Test search with different complexity values
-    print(f"\n🔍 Testing search performance...")
+    print("\n🔍 Testing search performance...")
     searcher = LeannSearcher(INDEX_PATH)
 
     test_query = "LEANN向量数据库存储优化"
 
     # Test with default complexity (64)
-    print(f"\n Test 1: Default complexity (64) `1 ")
+    print("\n Test 1: Default complexity (64) `1 ")
     print(f" Query: '{test_query}'")
     start_time = time.time()
-    results = searcher.search(test_query, top_k=10, complexity=64, beam_width=BEAM_WIDTH)
+    results = searcher.search(test_query, top_k=10, complexity=64)
     search_time = time.time() - start_time
     print(f" ✓ Search completed in {search_time:.2f} seconds")
     print(f" Results: {len(results)} items")
 
     # Test with default complexity (64)
-    print(f"\n Test 1: Default complexity (64)")
+    print("\n Test 1: Default complexity (64)")
     print(f" Query: '{test_query}'")
     start_time = time.time()
-    results = searcher.search(test_query, top_k=10, complexity=64, beam_width=BEAM_WIDTH)
+    results = searcher.search(test_query, top_k=10, complexity=64)
     search_time = time.time() - start_time
     print(f" ✓ Search completed in {search_time:.2f} seconds")
     print(f" Results: {len(results)} items")
 
     # Test with lower complexity (32)
-    print(f"\n Test 2: Lower complexity (32)")
+    print("\n Test 2: Lower complexity (32)")
     print(f" Query: '{test_query}'")
     start_time = time.time()
-    results = searcher.search(test_query, top_k=10, complexity=32, beam_width=BEAM_WIDTH)
+    results = searcher.search(test_query, top_k=10, complexity=32)
     search_time = time.time() - start_time
     print(f" ✓ Search completed in {search_time:.2f} seconds")
     print(f" Results: {len(results)} items")
 
     # Test with even lower complexity (16)
-    print(f"\n Test 3: Lower complexity (16)")
+    print("\n Test 3: Lower complexity (16)")
     print(f" Query: '{test_query}'")
     start_time = time.time()
-    results = searcher.search(test_query, top_k=10, complexity=16, beam_width=BEAM_WIDTH)
+    results = searcher.search(test_query, top_k=10, complexity=16)
     search_time = time.time() - start_time
     print(f" ✓ Search completed in {search_time:.2f} seconds")
     print(f" Results: {len(results)} items")
 
     # Test with minimal complexity (8)
-    print(f"\n Test 4: Minimal complexity (8)")
+    print("\n Test 4: Minimal complexity (8)")
     print(f" Query: '{test_query}'")
     start_time = time.time()
-    results = searcher.search(test_query, top_k=10, complexity=8, beam_width=BEAM_WIDTH)
+    results = searcher.search(test_query, top_k=10, complexity=8)
     search_time = time.time() - start_time
     print(f" ✓ Search completed in {search_time:.2f} seconds")
     print(f" Results: {len(results)} items")
 
     print("\n" + "=" * 80)
     print("Performance Analysis:")
     print("=" * 80)
@@ -139,6 +144,6 @@ def test_search_performance():
     print("- Consider using DiskANN backend for better performance on large datasets")
     print("- Or use a smaller embedding model if speed is critical")
 
 
 if __name__ == "__main__":
     test_search_performance()
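Note: the test-script changes above drop the DiskANN-oriented beam_width argument from the HNSW search calls (the removed comment already noted it does not apply to HNSW) and key the build-skip check on the index's .meta.json file instead of an unconditional `if True:`. The four near-identical timing blocks could also be folded into one loop; a minimal sketch, assuming only the LeannSearcher.search(query, top_k, complexity) signature visible in the diff:

    import time

    from leann.api import LeannSearcher

    searcher = LeannSearcher("./test_issue_159.leann")
    query = "LEANN向量数据库存储优化"

    # Sweep the HNSW search complexity; lower values trade recall for latency.
    for complexity in (64, 32, 16, 8):
        start = time.time()
        results = searcher.search(query, top_k=10, complexity=complexity)
        print(f"complexity={complexity}: {len(results)} results in {time.time() - start:.2f}s")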
@@ -191,9 +191,7 @@ def create_hnsw_embedding_server(
             )
         rep_socket.send(msgpack.packb(embeddings.tolist()))
         e2e_end = time.time()
-        logger.info(
-            f"⏱️ Direct text embedding E2E time: {e2e_end - e2e_start:.6f}s"
-        )
+        logger.info(f"⏱️ Direct text embedding E2E time: {e2e_end - e2e_start:.6f}s")
 
     def _handle_distance_request(request: list[Any]) -> None:
         nonlocal last_request_type, last_request_length
@@ -253,22 +251,14 @@ def create_hnsw_embedding_server(
         except Exception as exc:
             logger.error(f"Distance computation error, using sentinels: {exc}")
 
-        rep_socket.send(
-            msgpack.packb([response_distances], use_single_float=True)
-        )
+        rep_socket.send(msgpack.packb([response_distances], use_single_float=True))
         e2e_end = time.time()
-        logger.info(
-            f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s"
-        )
+        logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s")
 
     def _handle_embedding_by_id(request: Any) -> None:
         nonlocal last_request_type, last_request_length
 
-        if (
-            isinstance(request, list)
-            and len(request) == 1
-            and isinstance(request[0], list)
-        ):
+        if isinstance(request, list) and len(request) == 1 and isinstance(request[0], list):
             node_ids = request[0]
         elif isinstance(request, list):
             node_ids = request
@@ -336,11 +326,9 @@ def create_hnsw_embedding_server(
             logger.error(f"Embedding computation error, returning zeros: {exc}")
 
         response_payload = [dims, flat_data]
-        rep_socket.send(
-            msgpack.packb(response_payload, use_single_float=True)
-        )
+        rep_socket.send(msgpack.packb(response_payload, use_single_float=True))
         e2e_end = time.time()
-        logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
+        logger.info(f"⏱️ Fallback Embed by Id E2E time: {e2e_end - e2e_start:.6f}s")
 
     try:
         while not shutdown_event.is_set():
@@ -359,9 +347,7 @@ def create_hnsw_embedding_server(
                 logger.error(f"Error unpacking ZMQ message: {exc}")
                 try:
                     safe = _build_safe_fallback()
-                    rep_socket.send(
-                        msgpack.packb(safe, use_single_float=True)
-                    )
+                    rep_socket.send(msgpack.packb(safe, use_single_float=True))
                 except Exception:
                     pass
                 continue
@@ -399,9 +385,7 @@ def create_hnsw_embedding_server(
             logger.error(f"Error in ZMQ server loop: {exc}")
             try:
                 safe = _build_safe_fallback()
-                rep_socket.send(
-                    msgpack.packb(safe, use_single_float=True)
-                )
+                rep_socket.send(msgpack.packb(safe, use_single_float=True))
             except Exception:
                 pass
     finally:
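Note: the create_hnsw_embedding_server hunks above are mostly mechanical (multi-line rep_socket.send(...) and logger.info(...) calls collapsed to one line) plus one clearer log label, "Fallback Embed by Id E2E time" instead of the generic "ZMQ E2E time". Each handler follows the same timing shape; a rough sketch of that pattern, where reply_with_timing and handle are illustrative stand-ins rather than helpers that exist in the codebase:

    import time

    import msgpack

    def reply_with_timing(rep_socket, request, handle, logger, label):
        # Time the full receive -> compute -> reply cycle, as the server's logs do.
        e2e_start = time.time()
        payload = handle(request)
        rep_socket.send(msgpack.packb(payload, use_single_float=True))
        e2e_end = time.time()
        logger.info(f"⏱️ {label} E2E time: {e2e_end - e2e_start:.6f}s")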
Submodule packages/leann-backend-hnsw/third_party/faiss updated: e2d243c40d...301bf24f14
@@ -215,9 +215,14 @@ def compute_embeddings(
         Normalized embeddings array, shape: (len(texts), embedding_dim)
     """
     provider_options = provider_options or {}
+    wrapper_start_time = time.time()
+    logger.debug(
+        f"[compute_embeddings] entry: mode={mode}, model='{model_name}', text_count={len(texts)}"
+    )
 
     if mode == "sentence-transformers":
-        return compute_embeddings_sentence_transformers(
+        inner_start_time = time.time()
+        result = compute_embeddings_sentence_transformers(
             texts,
             model_name,
             is_build=is_build,
@@ -226,6 +231,14 @@ def compute_embeddings(
             manual_tokenize=manual_tokenize,
             max_length=max_length,
         )
+        inner_end_time = time.time()
+        wrapper_end_time = time.time()
+        logger.debug(
+            "[compute_embeddings] sentence-transformers timings: "
+            f"inner={inner_end_time - inner_start_time:.6f}s, "
+            f"wrapper_total={wrapper_end_time - wrapper_start_time:.6f}s"
+        )
+        return result
     elif mode == "openai":
         return compute_embeddings_openai(
             texts,
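Note: the two hunks above wrap the sentence-transformers branch of compute_embeddings with entry/exit timestamps so a DEBUG log can separate the inner compute_embeddings_sentence_transformers call from the wrapper's own overhead. If the same pattern were wanted elsewhere, a small context manager would do the job; this is a hedged sketch, not code the commit adds:

    import logging
    import time
    from contextlib import contextmanager

    logger = logging.getLogger(__name__)

    @contextmanager
    def timed(label: str):
        # Log the wall-clock time of the wrapped block at DEBUG level.
        start = time.time()
        try:
            yield
        finally:
            logger.debug(f"{label}: {time.time() - start:.6f}s")

    # Mirrors the diff's wrapper/inner split:
    # with timed("[compute_embeddings] wrapper_total"):
    #     with timed("[compute_embeddings] inner"):
    #         result = compute_embeddings_sentence_transformers(...)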
@@ -271,6 +284,7 @@ def compute_embeddings_sentence_transformers(
         is_build: Whether this is a build operation (shows progress bar)
         adaptive_optimization: Whether to use adaptive optimization based on batch size
     """
+    outer_start_time = time.time()
     # Handle empty input
     if not texts:
         raise ValueError("Cannot compute embeddings for empty text list")
@@ -301,7 +315,14 @@ def compute_embeddings_sentence_transformers(
     # Create cache key
     cache_key = f"sentence_transformers_{model_name}_{device}_{use_fp16}_optimized"
 
+    pre_model_init_end_time = time.time()
+    logger.debug(
+        "compute_embeddings_sentence_transformers pre-model-init time "
+        f"(device/batch selection etc.): {pre_model_init_end_time - outer_start_time:.6f}s"
+    )
+
     # Check if model is already cached
+    start_time = time.time()
     if cache_key in _model_cache:
         logger.info(f"Using cached optimized model: {model_name}")
         model = _model_cache[cache_key]
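Note: this hunk times the work done before the model cache is consulted (device and batch-size selection) separately from the cache lookup itself, which is where a cold start pays for model loading. A minimal sketch of the cache pattern those timestamps bracket, using sentence-transformers directly and the same cache-key shape; the loader details are an assumption, not the project's actual code:

    import time

    from sentence_transformers import SentenceTransformer

    _model_cache: dict[str, SentenceTransformer] = {}

    def get_cached_model(model_name: str, device: str, use_fp16: bool) -> SentenceTransformer:
        cache_key = f"sentence_transformers_{model_name}_{device}_{use_fp16}_optimized"
        start = time.time()
        hit = cache_key in _model_cache
        if not hit:
            model = SentenceTransformer(model_name, device=device)
            if use_fp16 and device == "cuda":
                model = model.half()  # optional fp16 on GPU
            _model_cache[cache_key] = model
        print(f"model ready in {time.time() - start:.6f}s (cache hit={hit})")
        return _model_cache[cache_key]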
@@ -441,10 +462,13 @@ def compute_embeddings_sentence_transformers(
         _model_cache[cache_key] = model
         logger.info(f"Model cached: {cache_key}")
 
-    # Compute embeddings with optimized inference mode
-    logger.info(
-        f"Starting embedding computation... (batch_size: {batch_size}, manual_tokenize={manual_tokenize})"
-    )
+    end_time = time.time()
+    # Compute embeddings with optimized inference mode
+    logger.info(
+        f"Starting embedding computation... (batch_size: {batch_size}, manual_tokenize={manual_tokenize})"
+    )
+    logger.info(f"start sentence transformers {model} takes {end_time - start_time}")
 
     start_time = time.time()
     if not manual_tokenize:
@@ -465,32 +489,46 @@ def compute_embeddings_sentence_transformers(
             except Exception:
                 pass
     else:
-        # Manual tokenization + forward pass using HF AutoTokenizer/AutoModel
+        # Manual tokenization + forward pass using HF AutoTokenizer/AutoModel.
+        # This path is reserved for an aggressively optimized FP pipeline
+        # (no quantization), mainly for experimentation.
         try:
             from transformers import AutoModel, AutoTokenizer  # type: ignore
         except Exception as e:
             raise ImportError(f"transformers is required for manual_tokenize=True: {e}")
 
-        # Cache tokenizer and model
         tok_cache_key = f"hf_tokenizer_{model_name}"
-        mdl_cache_key = f"hf_model_{model_name}_{device}_{use_fp16}"
+        mdl_cache_key = f"hf_model_{model_name}_{device}_{use_fp16}_fp"
 
         if tok_cache_key in _model_cache and mdl_cache_key in _model_cache:
             hf_tokenizer = _model_cache[tok_cache_key]
             hf_model = _model_cache[mdl_cache_key]
-            logger.info("Using cached HF tokenizer/model for manual path")
+            logger.info("Using cached HF tokenizer/model for manual FP path")
         else:
-            logger.info("Loading HF tokenizer/model for manual tokenization path")
+            logger.info("Loading HF tokenizer/model for manual FP path")
             hf_tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
 
             torch_dtype = torch.float16 if (use_fp16 and device == "cuda") else torch.float32
-            hf_model = AutoModel.from_pretrained(model_name, torch_dtype=torch_dtype)
+            hf_model = AutoModel.from_pretrained(
+                model_name,
+                torch_dtype=torch_dtype,
+            )
             hf_model.to(device)
 
             hf_model.eval()
             # Optional compile on supported devices
             if device in ["cuda", "mps"]:
                 try:
-                    hf_model = torch.compile(hf_model, mode="reduce-overhead", dynamic=True)  # type: ignore
-                except Exception:
-                    pass
+                    hf_model = torch.compile(  # type: ignore
+                        hf_model, mode="reduce-overhead", dynamic=True
+                    )
+                    logger.info(
+                        f"Applied torch.compile to HF model for {model_name} "
+                        f"(device={device}, dtype={torch_dtype})"
+                    )
+                except Exception as exc:
+                    logger.warning(f"torch.compile optimization failed: {exc}")
 
             _model_cache[tok_cache_key] = hf_tokenizer
             _model_cache[mdl_cache_key] = hf_model
 
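Note: the hunk above marks the manual-tokenization branch as an FP-only experimental path, gives its cached model a distinct _fp cache key, and makes the torch.compile outcome visible (info on success, warning on failure) instead of silently swallowing errors. A condensed sketch of that load step, assuming only the libraries the diff already imports:

    import torch
    from transformers import AutoModel, AutoTokenizer

    def load_hf_fp_model(model_name: str, device: str, use_fp16: bool):
        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
        torch_dtype = torch.float16 if (use_fp16 and device == "cuda") else torch.float32
        model = AutoModel.from_pretrained(model_name, torch_dtype=torch_dtype).to(device).eval()
        if device in ("cuda", "mps"):
            try:
                # reduce-overhead mode targets small, repeated forward passes
                model = torch.compile(model, mode="reduce-overhead", dynamic=True)
            except Exception as exc:
                print(f"torch.compile optimization failed: {exc}")  # fall back to eager
        return tokenizer, model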
@@ -516,7 +554,6 @@ def compute_embeddings_sentence_transformers(
             for start_index in batch_iter:
                 end_index = min(start_index + batch_size, len(texts))
                 batch_texts = texts[start_index:end_index]
-                tokenize_start_time = time.time()
                 inputs = hf_tokenizer(
                     batch_texts,
                     padding=True,
@@ -524,34 +561,17 @@ def compute_embeddings_sentence_transformers(
                     max_length=max_length,
                     return_tensors="pt",
                 )
-                tokenize_end_time = time.time()
-                logger.info(
-                    f"Tokenize time taken: {tokenize_end_time - tokenize_start_time} seconds"
-                )
-                # Print shapes of all input tensors for debugging
-                for k, v in inputs.items():
-                    print(f"inputs[{k!r}] shape: {getattr(v, 'shape', type(v))}")
-                to_device_start_time = time.time()
                 inputs = {k: v.to(device) for k, v in inputs.items()}
-                to_device_end_time = time.time()
-                logger.info(
-                    f"To device time taken: {to_device_end_time - to_device_start_time} seconds"
-                )
-                forward_start_time = time.time()
                 outputs = hf_model(**inputs)
-                forward_end_time = time.time()
-                logger.info(f"Forward time taken: {forward_end_time - forward_start_time} seconds")
                 last_hidden_state = outputs.last_hidden_state  # (B, L, H)
                 attention_mask = inputs.get("attention_mask")
                 if attention_mask is None:
-                    # Fallback: assume all tokens are valid
                     pooled = last_hidden_state.mean(dim=1)
                 else:
                     mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
                     masked = last_hidden_state * mask
                     lengths = mask.sum(dim=1).clamp(min=1)
                     pooled = masked.sum(dim=1) / lengths
-                # Move to CPU float32
                 batch_embeddings = pooled.detach().to("cpu").float().numpy()
                 all_embeddings.append(batch_embeddings)
 
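Note: this hunk removes the per-batch instrumentation from the manual path's hot loop: the tokenize/to-device/forward timing logs and the tensor-shape prints, which ran once per batch, are dropped, leaving only tokenization, the forward pass, and masked mean pooling. The pooling that remains, as a standalone sketch of what the loop computes:

    import torch

    def masked_mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor | None = None) -> torch.Tensor:
        # last_hidden_state: (B, L, H); attention_mask: (B, L) or None.
        if attention_mask is None:
            return last_hidden_state.mean(dim=1)
        mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
        lengths = mask.sum(dim=1).clamp(min=1)
        return (last_hidden_state * mask).sum(dim=1) / lengths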
@@ -571,6 +591,12 @@ def compute_embeddings_sentence_transformers(
     if np.isnan(embeddings).any() or np.isinf(embeddings).any():
         raise RuntimeError(f"Detected NaN or Inf values in embeddings, model: {model_name}")
 
+    outer_end_time = time.time()
+    logger.debug(
+        "compute_embeddings_sentence_transformers total time "
+        f"(function entry -> return): {outer_end_time - outer_start_time:.6f}s"
+    )
+
     return embeddings
 
 