fix: faster embed

Andy Lee
2025-11-24 05:30:11 +00:00
parent 66c6aad3e4
commit 36c44b8806
4 changed files with 110 additions and 95 deletions


@@ -191,9 +191,7 @@ def create_hnsw_embedding_server(
         )
         rep_socket.send(msgpack.packb(embeddings.tolist()))
         e2e_end = time.time()
-        logger.info(
-            f"⏱️ Direct text embedding E2E time: {e2e_end - e2e_start:.6f}s"
-        )
+        logger.info(f"⏱️ Direct text embedding E2E time: {e2e_end - e2e_start:.6f}s")
 
     def _handle_distance_request(request: list[Any]) -> None:
         nonlocal last_request_type, last_request_length
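
For reference, a minimal client-side sketch of this reply path. The socket address and request payload below are assumptions for illustration; only the reply format (msgpack-packed embeddings.tolist()) comes from this diff:

import msgpack
import numpy as np
import zmq

context = zmq.Context()
req = context.socket(zmq.REQ)
req.connect("tcp://127.0.0.1:5555")  # hypothetical server address

# Hypothetical request shape; the server replies with embeddings.tolist()
req.send(msgpack.packb(["embed this sentence"]))
embeddings = np.array(msgpack.unpackb(req.recv()))
print(embeddings.shape)  # (num_texts, embedding_dim)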
@@ -253,22 +251,14 @@ def create_hnsw_embedding_server(
         except Exception as exc:
             logger.error(f"Distance computation error, using sentinels: {exc}")
-        rep_socket.send(
-            msgpack.packb([response_distances], use_single_float=True)
-        )
+        rep_socket.send(msgpack.packb([response_distances], use_single_float=True))
         e2e_end = time.time()
-        logger.info(
-            f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s"
-        )
+        logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s")
 
     def _handle_embedding_by_id(request: Any) -> None:
         nonlocal last_request_type, last_request_length
-        if (
-            isinstance(request, list)
-            and len(request) == 1
-            and isinstance(request[0], list)
-        ):
+        if isinstance(request, list) and len(request) == 1 and isinstance(request[0], list):
             node_ids = request[0]
         elif isinstance(request, list):
             node_ids = request
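
The collapsed branch above accepts both a wrapped [[ids]] payload and a flat [ids] list. The same rule as a standalone helper (hypothetical, for illustration only):

def normalize_node_ids(request):
    # Wrapped form: [[1, 2, 3]] -> [1, 2, 3]
    if isinstance(request, list) and len(request) == 1 and isinstance(request[0], list):
        return request[0]
    # Flat form: [1, 2, 3]
    if isinstance(request, list):
        return request
    raise TypeError(f"Unsupported request shape: {type(request)!r}")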
@@ -336,11 +326,9 @@ def create_hnsw_embedding_server(
             logger.error(f"Embedding computation error, returning zeros: {exc}")
         response_payload = [dims, flat_data]
-        rep_socket.send(
-            msgpack.packb(response_payload, use_single_float=True)
-        )
+        rep_socket.send(msgpack.packb(response_payload, use_single_float=True))
         e2e_end = time.time()
-        logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
+        logger.info(f"⏱️ Fallback Embed by Id E2E time: {e2e_end - e2e_start:.6f}s")
 
     try:
         while not shutdown_event.is_set():
@@ -359,9 +347,7 @@ def create_hnsw_embedding_server(
                 logger.error(f"Error unpacking ZMQ message: {exc}")
                 try:
                     safe = _build_safe_fallback()
-                    rep_socket.send(
-                        msgpack.packb(safe, use_single_float=True)
-                    )
+                    rep_socket.send(msgpack.packb(safe, use_single_float=True))
                 except Exception:
                     pass
                 continue
@@ -399,9 +385,7 @@ def create_hnsw_embedding_server(
             logger.error(f"Error in ZMQ server loop: {exc}")
             try:
                 safe = _build_safe_fallback()
-                rep_socket.send(
-                    msgpack.packb(safe, use_single_float=True)
-                )
+                rep_socket.send(msgpack.packb(safe, use_single_float=True))
             except Exception:
                 pass
     finally:
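
All of these replies are packed with use_single_float=True, which makes msgpack encode floats as 4-byte float32 instead of 8-byte float64, halving the payload size. A sketch of how a client might decode the [dims, flat_data] reply from the embedding-by-id path (field meanings inferred from the variable names above):

import msgpack
import numpy as np

def decode_embedding_reply(raw: bytes) -> np.ndarray:
    dims, flat_data = msgpack.unpackb(raw)
    # Server packs float32 (use_single_float=True); rebuild an (n, dims) array.
    return np.asarray(flat_data, dtype=np.float32).reshape(-1, dims)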


@@ -215,9 +215,14 @@ def compute_embeddings(
         Normalized embeddings array, shape: (len(texts), embedding_dim)
     """
     provider_options = provider_options or {}
+    wrapper_start_time = time.time()
+    logger.debug(
+        f"[compute_embeddings] entry: mode={mode}, model='{model_name}', text_count={len(texts)}"
+    )
 
     if mode == "sentence-transformers":
-        return compute_embeddings_sentence_transformers(
+        inner_start_time = time.time()
+        result = compute_embeddings_sentence_transformers(
             texts,
             model_name,
             is_build=is_build,
@@ -226,6 +231,14 @@ def compute_embeddings(
             manual_tokenize=manual_tokenize,
             max_length=max_length,
         )
+        inner_end_time = time.time()
+        wrapper_end_time = time.time()
+        logger.debug(
+            "[compute_embeddings] sentence-transformers timings: "
+            f"inner={inner_end_time - inner_start_time:.6f}s, "
+            f"wrapper_total={wrapper_end_time - wrapper_start_time:.6f}s"
+        )
+        return result
     elif mode == "openai":
         return compute_embeddings_openai(
             texts,
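
The paired start/end timestamps added in this hunk could also be factored into a small context manager; a sketch of that alternative, not part of this commit:

import logging
import time
from contextlib import contextmanager

logger = logging.getLogger(__name__)

@contextmanager
def timed(label: str):
    # Log the wall-clock duration of the enclosed block at debug level.
    start = time.time()
    try:
        yield
    finally:
        logger.debug(f"{label}: {time.time() - start:.6f}s")

# Usage sketch:
#     with timed("[compute_embeddings] inner"):
#         result = compute_embeddings_sentence_transformers(...)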
@@ -271,6 +284,7 @@ def compute_embeddings_sentence_transformers(
         is_build: Whether this is a build operation (shows progress bar)
         adaptive_optimization: Whether to use adaptive optimization based on batch size
     """
+    outer_start_time = time.time()
     # Handle empty input
     if not texts:
         raise ValueError("Cannot compute embeddings for empty text list")
@@ -301,7 +315,14 @@ def compute_embeddings_sentence_transformers(
     # Create cache key
     cache_key = f"sentence_transformers_{model_name}_{device}_{use_fp16}_optimized"
+    pre_model_init_end_time = time.time()
+    logger.debug(
+        "compute_embeddings_sentence_transformers pre-model-init time "
+        f"(device/batch selection etc.): {pre_model_init_end_time - outer_start_time:.6f}s"
+    )
 
     # Check if model is already cached
+    start_time = time.time()
     if cache_key in _model_cache:
         logger.info(f"Using cached optimized model: {model_name}")
         model = _model_cache[cache_key]
@@ -441,10 +462,13 @@ def compute_embeddings_sentence_transformers(
         _model_cache[cache_key] = model
         logger.info(f"Model cached: {cache_key}")
 
-    # Compute embeddings with optimized inference mode
-    logger.info(
-        f"Starting embedding computation... (batch_size: {batch_size}, manual_tokenize={manual_tokenize})"
-    )
+    end_time = time.time()
+    # Compute embeddings with optimized inference mode
+    logger.info(
+        f"Starting embedding computation... (batch_size: {batch_size}, manual_tokenize={manual_tokenize})"
+    )
+    logger.info(f"start sentence transformers {model} takes {end_time - start_time}")
 
     start_time = time.time()
     if not manual_tokenize:
@@ -465,32 +489,46 @@ def compute_embeddings_sentence_transformers(
         except Exception:
             pass
     else:
-        # Manual tokenization + forward pass using HF AutoTokenizer/AutoModel
+        # Manual tokenization + forward pass using HF AutoTokenizer/AutoModel.
+        # This path is reserved for an aggressively optimized FP pipeline
+        # (no quantization), mainly for experimentation.
         try:
             from transformers import AutoModel, AutoTokenizer  # type: ignore
         except Exception as e:
             raise ImportError(f"transformers is required for manual_tokenize=True: {e}")
 
         # Cache tokenizer and model
         tok_cache_key = f"hf_tokenizer_{model_name}"
-        mdl_cache_key = f"hf_model_{model_name}_{device}_{use_fp16}"
+        mdl_cache_key = f"hf_model_{model_name}_{device}_{use_fp16}_fp"
         if tok_cache_key in _model_cache and mdl_cache_key in _model_cache:
             hf_tokenizer = _model_cache[tok_cache_key]
             hf_model = _model_cache[mdl_cache_key]
-            logger.info("Using cached HF tokenizer/model for manual path")
+            logger.info("Using cached HF tokenizer/model for manual FP path")
         else:
-            logger.info("Loading HF tokenizer/model for manual tokenization path")
+            logger.info("Loading HF tokenizer/model for manual FP path")
             hf_tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
             torch_dtype = torch.float16 if (use_fp16 and device == "cuda") else torch.float32
-            hf_model = AutoModel.from_pretrained(model_name, torch_dtype=torch_dtype)
+            hf_model = AutoModel.from_pretrained(
+                model_name,
+                torch_dtype=torch_dtype,
+            )
             hf_model.to(device)
             hf_model.eval()
             # Optional compile on supported devices
             if device in ["cuda", "mps"]:
                 try:
-                    hf_model = torch.compile(hf_model, mode="reduce-overhead", dynamic=True)  # type: ignore
-                except Exception:
-                    pass
+                    hf_model = torch.compile(  # type: ignore
+                        hf_model, mode="reduce-overhead", dynamic=True
+                    )
+                    logger.info(
+                        f"Applied torch.compile to HF model for {model_name} "
+                        f"(device={device}, dtype={torch_dtype})"
+                    )
+                except Exception as exc:
+                    logger.warning(f"torch.compile optimization failed: {exc}")
             _model_cache[tok_cache_key] = hf_tokenizer
             _model_cache[mdl_cache_key] = hf_model
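
The compile block now logs success and failure instead of silently swallowing errors. The same guarded-compile pattern in isolation, as a self-contained sketch (helper name is hypothetical):

import logging
import torch
from transformers import AutoModel

logger = logging.getLogger(__name__)

def load_model_with_optional_compile(model_name: str, device: str) -> torch.nn.Module:
    model = AutoModel.from_pretrained(model_name).to(device).eval()
    if device in ("cuda", "mps"):
        try:
            # reduce-overhead targets small-batch latency; dynamic=True
            # avoids recompiles when sequence lengths vary.
            model = torch.compile(model, mode="reduce-overhead", dynamic=True)
        except Exception as exc:
            logger.warning(f"torch.compile failed, falling back to eager: {exc}")
    return model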
@@ -516,7 +554,6 @@ def compute_embeddings_sentence_transformers(
         for start_index in batch_iter:
             end_index = min(start_index + batch_size, len(texts))
             batch_texts = texts[start_index:end_index]
-            tokenize_start_time = time.time()
             inputs = hf_tokenizer(
                 batch_texts,
                 padding=True,
@@ -524,34 +561,17 @@ def compute_embeddings_sentence_transformers(
                 max_length=max_length,
                 return_tensors="pt",
             )
-            tokenize_end_time = time.time()
-            logger.info(
-                f"Tokenize time taken: {tokenize_end_time - tokenize_start_time} seconds"
-            )
-            # Print shapes of all input tensors for debugging
-            for k, v in inputs.items():
-                print(f"inputs[{k!r}] shape: {getattr(v, 'shape', type(v))}")
-            to_device_start_time = time.time()
             inputs = {k: v.to(device) for k, v in inputs.items()}
-            to_device_end_time = time.time()
-            logger.info(
-                f"To device time taken: {to_device_end_time - to_device_start_time} seconds"
-            )
-            forward_start_time = time.time()
             outputs = hf_model(**inputs)
-            forward_end_time = time.time()
-            logger.info(f"Forward time taken: {forward_end_time - forward_start_time} seconds")
             last_hidden_state = outputs.last_hidden_state  # (B, L, H)
             attention_mask = inputs.get("attention_mask")
             if attention_mask is None:
                 # Fallback: assume all tokens are valid
                 pooled = last_hidden_state.mean(dim=1)
             else:
                 mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
                 masked = last_hidden_state * mask
                 lengths = mask.sum(dim=1).clamp(min=1)
                 pooled = masked.sum(dim=1) / lengths
 
             # Move to CPU float32
             batch_embeddings = pooled.detach().to("cpu").float().numpy()
             all_embeddings.append(batch_embeddings)
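
A tiny worked example of the masked mean pooling retained above: padded positions are zeroed before averaging, so padding does not dilute the sentence vector (values are made up for illustration):

import torch

# (B=1, L=3, H=2); the third token is padding
last_hidden_state = torch.tensor([[[1.0, 1.0], [3.0, 3.0], [9.0, 9.0]]])
attention_mask = torch.tensor([[1, 1, 0]])

mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)  # (1, 3, 1)
masked = last_hidden_state * mask          # padding row zeroed out
lengths = mask.sum(dim=1).clamp(min=1)     # 2 real tokens
pooled = masked.sum(dim=1) / lengths       # tensor([[2., 2.]])
# A naive last_hidden_state.mean(dim=1) would give tensor([[4.3333, 4.3333]]).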
@@ -571,6 +591,12 @@ def compute_embeddings_sentence_transformers(
     if np.isnan(embeddings).any() or np.isinf(embeddings).any():
         raise RuntimeError(f"Detected NaN or Inf values in embeddings, model: {model_name}")
 
+    outer_end_time = time.time()
+    logger.debug(
+        "compute_embeddings_sentence_transformers total time "
+        f"(function entry -> return): {outer_end_time - outer_start_time:.6f}s"
+    )
+
     return embeddings
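
A hypothetical call into the wrapper, using only argument names that appear in this diff; the full signature and parameter order may differ:

embeddings = compute_embeddings(
    ["hello world", "fast embeddings"],
    model_name="sentence-transformers/all-MiniLM-L6-v2",  # assumed model
    mode="sentence-transformers",
    manual_tokenize=False,
)
print(embeddings.shape)  # (2, embedding_dim), per the docstring above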