feat(embeddings): add optional manual tokenization path (HF tokenizer+model) with mean pooling; default remains SentenceTransformer.encode
This commit is contained in:
@@ -10,6 +10,7 @@ from typing import Any
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
import time
|
||||||
|
|
||||||
# Set up logger with proper level
|
# Set up logger with proper level
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -28,6 +29,8 @@ def compute_embeddings(
|
|||||||
is_build: bool = False,
|
is_build: bool = False,
|
||||||
batch_size: int = 32,
|
batch_size: int = 32,
|
||||||
adaptive_optimization: bool = True,
|
adaptive_optimization: bool = True,
|
||||||
|
manual_tokenize: bool = False,
|
||||||
|
max_length: int = 512,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Unified embedding computation entry point
|
Unified embedding computation entry point
|
||||||
@@ -50,6 +53,8 @@ def compute_embeddings(
|
|||||||
is_build=is_build,
|
is_build=is_build,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
adaptive_optimization=adaptive_optimization,
|
adaptive_optimization=adaptive_optimization,
|
||||||
|
manual_tokenize=manual_tokenize,
|
||||||
|
max_length=max_length,
|
||||||
)
|
)
|
||||||
elif mode == "openai":
|
elif mode == "openai":
|
||||||
return compute_embeddings_openai(texts, model_name)
|
return compute_embeddings_openai(texts, model_name)
|
||||||
@@ -71,6 +76,8 @@ def compute_embeddings_sentence_transformers(
|
|||||||
batch_size: int = 32,
|
batch_size: int = 32,
|
||||||
is_build: bool = False,
|
is_build: bool = False,
|
||||||
adaptive_optimization: bool = True,
|
adaptive_optimization: bool = True,
|
||||||
|
manual_tokenize: bool = False,
|
||||||
|
max_length: int = 512,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Compute embeddings using SentenceTransformer with model caching and adaptive optimization
|
Compute embeddings using SentenceTransformer with model caching and adaptive optimization
|
||||||
@@ -214,20 +221,117 @@ def compute_embeddings_sentence_transformers(
|
|||||||
logger.info(f"Model cached: {cache_key}")
|
logger.info(f"Model cached: {cache_key}")
|
||||||
|
|
||||||
# Compute embeddings with optimized inference mode
|
# Compute embeddings with optimized inference mode
|
||||||
logger.info(f"Starting embedding computation... (batch_size: {batch_size})")
|
logger.info(
|
||||||
|
f"Starting embedding computation... (batch_size: {batch_size}, manual_tokenize={manual_tokenize})"
|
||||||
|
)
|
||||||
|
|
||||||
# Use torch.inference_mode for optimal performance
|
start_time = time.time()
|
||||||
with torch.inference_mode():
|
if not manual_tokenize:
|
||||||
embeddings = model.encode(
|
# Use SentenceTransformer's optimized encode path (default)
|
||||||
texts,
|
with torch.inference_mode():
|
||||||
batch_size=batch_size,
|
embeddings = model.encode(
|
||||||
show_progress_bar=is_build, # Don't show progress bar in server environment
|
texts,
|
||||||
convert_to_numpy=True,
|
batch_size=batch_size,
|
||||||
normalize_embeddings=False,
|
show_progress_bar=is_build, # Don't show progress bar in server environment
|
||||||
device=device,
|
convert_to_numpy=True,
|
||||||
)
|
normalize_embeddings=False,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
# Synchronize if CUDA to measure accurate wall time
|
||||||
|
try:
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
# Manual tokenization + forward pass using HF AutoTokenizer/AutoModel
|
||||||
|
try:
|
||||||
|
from transformers import AutoModel, AutoTokenizer # type: ignore
|
||||||
|
except Exception as e:
|
||||||
|
raise ImportError(
|
||||||
|
f"transformers is required for manual_tokenize=True: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Cache tokenizer and model
|
||||||
|
tok_cache_key = f"hf_tokenizer_{model_name}"
|
||||||
|
mdl_cache_key = f"hf_model_{model_name}_{device}_{use_fp16}"
|
||||||
|
if tok_cache_key in _model_cache and mdl_cache_key in _model_cache:
|
||||||
|
hf_tokenizer = _model_cache[tok_cache_key]
|
||||||
|
hf_model = _model_cache[mdl_cache_key]
|
||||||
|
logger.info("Using cached HF tokenizer/model for manual path")
|
||||||
|
else:
|
||||||
|
logger.info("Loading HF tokenizer/model for manual tokenization path")
|
||||||
|
hf_tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
model_name, use_fast=True
|
||||||
|
)
|
||||||
|
torch_dtype = torch.float16 if (use_fp16 and device == "cuda") else torch.float32
|
||||||
|
hf_model = AutoModel.from_pretrained(
|
||||||
|
model_name, torch_dtype=torch_dtype
|
||||||
|
)
|
||||||
|
hf_model.to(device)
|
||||||
|
hf_model.eval()
|
||||||
|
# Optional compile on supported devices
|
||||||
|
if device in ["cuda", "mps"]:
|
||||||
|
try:
|
||||||
|
hf_model = torch.compile(hf_model, mode="reduce-overhead", dynamic=True) # type: ignore
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
_model_cache[tok_cache_key] = hf_tokenizer
|
||||||
|
_model_cache[mdl_cache_key] = hf_model
|
||||||
|
|
||||||
|
all_embeddings: list[np.ndarray] = []
|
||||||
|
# Progress bar when building or for large inputs
|
||||||
|
show_progress = is_build or len(texts) > 32
|
||||||
|
try:
|
||||||
|
if show_progress:
|
||||||
|
from tqdm import tqdm # type: ignore
|
||||||
|
batch_iter = tqdm(
|
||||||
|
range(0, len(texts), batch_size),
|
||||||
|
desc="Embedding (manual)",
|
||||||
|
unit="batch",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
batch_iter = range(0, len(texts), batch_size)
|
||||||
|
except Exception:
|
||||||
|
batch_iter = range(0, len(texts), batch_size)
|
||||||
|
|
||||||
|
with torch.inference_mode():
|
||||||
|
for start_index in batch_iter:
|
||||||
|
end_index = min(start_index + batch_size, len(texts))
|
||||||
|
batch_texts = texts[start_index:end_index]
|
||||||
|
inputs = hf_tokenizer(
|
||||||
|
batch_texts,
|
||||||
|
padding=True,
|
||||||
|
truncation=True,
|
||||||
|
max_length=max_length,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
inputs = {k: v.to(device) for k, v in inputs.items()}
|
||||||
|
outputs = hf_model(**inputs)
|
||||||
|
last_hidden_state = outputs.last_hidden_state # (B, L, H)
|
||||||
|
attention_mask = inputs.get("attention_mask")
|
||||||
|
if attention_mask is None:
|
||||||
|
# Fallback: assume all tokens are valid
|
||||||
|
pooled = last_hidden_state.mean(dim=1)
|
||||||
|
else:
|
||||||
|
mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
|
||||||
|
masked = last_hidden_state * mask
|
||||||
|
lengths = mask.sum(dim=1).clamp(min=1)
|
||||||
|
pooled = masked.sum(dim=1) / lengths
|
||||||
|
# Move to CPU float32
|
||||||
|
batch_embeddings = pooled.detach().to("cpu").float().numpy()
|
||||||
|
all_embeddings.append(batch_embeddings)
|
||||||
|
|
||||||
|
embeddings = np.vstack(all_embeddings).astype(np.float32, copy=False)
|
||||||
|
try:
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
end_time = time.time()
|
||||||
logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
|
logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
|
||||||
|
logger.info(f"Time taken: {end_time - start_time} seconds")
|
||||||
|
|
||||||
# Validate results
|
# Validate results
|
||||||
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
||||||
|
|||||||
Reference in New Issue
Block a user