diff --git a/packages/leann-core/src/leann/embedding_compute.py b/packages/leann-core/src/leann/embedding_compute.py
index 0fc8c4e..d6abce5 100644
--- a/packages/leann-core/src/leann/embedding_compute.py
+++ b/packages/leann-core/src/leann/embedding_compute.py
@@ -10,6 +10,7 @@ from typing import Any
 
 import numpy as np
 import torch
+import time
 
 # Set up logger with proper level
 logger = logging.getLogger(__name__)
@@ -28,6 +29,8 @@ def compute_embeddings(
     is_build: bool = False,
     batch_size: int = 32,
     adaptive_optimization: bool = True,
+    manual_tokenize: bool = False,
+    max_length: int = 512,
 ) -> np.ndarray:
     """
     Unified embedding computation entry point
@@ -50,6 +53,8 @@
             is_build=is_build,
             batch_size=batch_size,
             adaptive_optimization=adaptive_optimization,
+            manual_tokenize=manual_tokenize,
+            max_length=max_length,
         )
     elif mode == "openai":
         return compute_embeddings_openai(texts, model_name)
@@ -71,6 +76,8 @@ def compute_embeddings_sentence_transformers(
     batch_size: int = 32,
     is_build: bool = False,
     adaptive_optimization: bool = True,
+    manual_tokenize: bool = False,
+    max_length: int = 512,
 ) -> np.ndarray:
     """
     Compute embeddings using SentenceTransformer with model caching and adaptive optimization
@@ -214,20 +221,117 @@ def compute_embeddings_sentence_transformers(
         logger.info(f"Model cached: {cache_key}")
 
     # Compute embeddings with optimized inference mode
-    logger.info(f"Starting embedding computation... (batch_size: {batch_size})")
+    logger.info(
+        f"Starting embedding computation... (batch_size: {batch_size}, manual_tokenize={manual_tokenize})"
+    )
 
-    # Use torch.inference_mode for optimal performance
-    with torch.inference_mode():
-        embeddings = model.encode(
-            texts,
-            batch_size=batch_size,
-            show_progress_bar=is_build,  # Don't show progress bar in server environment
-            convert_to_numpy=True,
-            normalize_embeddings=False,
-            device=device,
-        )
+    start_time = time.time()
+    if not manual_tokenize:
+        # Use SentenceTransformer's optimized encode path (default)
+        with torch.inference_mode():
+            embeddings = model.encode(
+                texts,
+                batch_size=batch_size,
+                show_progress_bar=is_build,  # Don't show progress bar in server environment
+                convert_to_numpy=True,
+                normalize_embeddings=False,
+                device=device,
+            )
+        # Synchronize if CUDA to measure accurate wall time
+        try:
+            if torch.cuda.is_available():
+                torch.cuda.synchronize()
+        except Exception:
+            pass
+    else:
+        # Manual tokenization + forward pass using HF AutoTokenizer/AutoModel
+        try:
+            from transformers import AutoModel, AutoTokenizer  # type: ignore
+        except Exception as e:
+            raise ImportError(
+                f"transformers is required for manual_tokenize=True: {e}"
+            )
+        # Cache tokenizer and model
+        tok_cache_key = f"hf_tokenizer_{model_name}"
+        mdl_cache_key = f"hf_model_{model_name}_{device}_{use_fp16}"
+        if tok_cache_key in _model_cache and mdl_cache_key in _model_cache:
+            hf_tokenizer = _model_cache[tok_cache_key]
+            hf_model = _model_cache[mdl_cache_key]
+            logger.info("Using cached HF tokenizer/model for manual path")
+        else:
+            logger.info("Loading HF tokenizer/model for manual tokenization path")
+            hf_tokenizer = AutoTokenizer.from_pretrained(
+                model_name, use_fast=True
+            )
+            torch_dtype = torch.float16 if (use_fp16 and device == "cuda") else torch.float32
+            hf_model = AutoModel.from_pretrained(
+                model_name, torch_dtype=torch_dtype
+            )
+            hf_model.to(device)
+            hf_model.eval()
+            # Optional compile on supported devices
+            if device in ["cuda", "mps"]:
+                try:
+                    hf_model = torch.compile(hf_model, mode="reduce-overhead", dynamic=True)  # type: ignore
+                except Exception:
+                    pass
+            _model_cache[tok_cache_key] = hf_tokenizer
+            _model_cache[mdl_cache_key] = hf_model
+
+        all_embeddings: list[np.ndarray] = []
+        # Progress bar when building or for large inputs
+        show_progress = is_build or len(texts) > 32
+        try:
+            if show_progress:
+                from tqdm import tqdm  # type: ignore
+                batch_iter = tqdm(
+                    range(0, len(texts), batch_size),
+                    desc="Embedding (manual)",
+                    unit="batch",
+                )
+            else:
+                batch_iter = range(0, len(texts), batch_size)
+        except Exception:
+            batch_iter = range(0, len(texts), batch_size)
+
+        with torch.inference_mode():
+            for start_index in batch_iter:
+                end_index = min(start_index + batch_size, len(texts))
+                batch_texts = texts[start_index:end_index]
+                inputs = hf_tokenizer(
+                    batch_texts,
+                    padding=True,
+                    truncation=True,
+                    max_length=max_length,
+                    return_tensors="pt",
+                )
+                inputs = {k: v.to(device) for k, v in inputs.items()}
+                outputs = hf_model(**inputs)
+                last_hidden_state = outputs.last_hidden_state  # (B, L, H)
+                attention_mask = inputs.get("attention_mask")
+                if attention_mask is None:
+                    # Fallback: assume all tokens are valid
+                    pooled = last_hidden_state.mean(dim=1)
+                else:
+                    mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
+                    masked = last_hidden_state * mask
+                    lengths = mask.sum(dim=1).clamp(min=1)
+                    pooled = masked.sum(dim=1) / lengths
+                # Move to CPU float32
+                batch_embeddings = pooled.detach().to("cpu").float().numpy()
+                all_embeddings.append(batch_embeddings)
+
+        embeddings = np.vstack(all_embeddings).astype(np.float32, copy=False)
+        try:
+            if torch.cuda.is_available():
+                torch.cuda.synchronize()
+        except Exception:
+            pass
+
+    end_time = time.time()
 
     logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
+    logger.info(f"Time taken: {end_time - start_time} seconds")
 
     # Validate results
     if np.isnan(embeddings).any() or np.isinf(embeddings).any():
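
Usage sketch (not part of the patch): a minimal call exercising the new manual-tokenization branch. The manual_tokenize, max_length, and batch_size parameters come from the diff above; the leading texts/model_name arguments, the mode keyword, and the model name shown here are assumptions about the existing compute_embeddings signature rather than confirmed API.

    from leann.embedding_compute import compute_embeddings

    texts = [
        "Manual tokenization uses the HF AutoTokenizer/AutoModel path.",
        "Embeddings are mean-pooled over the attention mask.",
    ]
    # manual_tokenize=True takes the new branch: tokenize with truncation at
    # max_length, run the model under torch.inference_mode(), mean-pool over
    # the attention mask, and return float32 numpy embeddings (unnormalized).
    embeddings = compute_embeddings(
        texts,
        "sentence-transformers/all-MiniLM-L6-v2",  # assumed model; any HF encoder should work
        mode="sentence-transformers",
        manual_tokenize=True,
        max_length=256,
        batch_size=16,
    )
    print(embeddings.shape, embeddings.dtype)  # e.g. (2, 384) float32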