From d3f85678ecaaa1c5a24cb39c19e30b24a9115535 Mon Sep 17 00:00:00 2001
From: Andy Lee
Date: Tue, 22 Jul 2025 19:38:22 -0700
Subject: [PATCH] perf: much faster loading and embedding serving

---
 .../leann-backend-diskann/third_party/DiskANN |   2 +-
 packages/leann-core/src/leann/api.py          |   8 +-
 .../leann-core/src/leann/embedding_compute.py | 115 ++++++++++++++----
 3 files changed, 100 insertions(+), 25 deletions(-)

diff --git a/packages/leann-backend-diskann/third_party/DiskANN b/packages/leann-backend-diskann/third_party/DiskANN
index af2a264..25339b0 160000
--- a/packages/leann-backend-diskann/third_party/DiskANN
+++ b/packages/leann-backend-diskann/third_party/DiskANN
@@ -1 +1 @@
-Subproject commit af2a26481e65232b57b82d96e68833cdee9f7635
+Subproject commit 25339b03413b5067c25b6092ea3e0f77ef8515c8
diff --git a/packages/leann-core/src/leann/api.py b/packages/leann-core/src/leann/api.py
index 111a52b..31d9ab1 100644
--- a/packages/leann-core/src/leann/api.py
+++ b/packages/leann-core/src/leann/api.py
@@ -421,9 +421,9 @@ class LeannSearcher:
         logger.info(f"  Top_k: {top_k}")
         logger.info(f"  Additional kwargs: {kwargs}")
 
-        start_time = time.time()
-
         zmq_port = None
+
+        start_time = time.time()
         if recompute_embeddings:
             zmq_port = self.backend_impl._ensure_server_running(
                 self.meta_path_str,
@@ -431,6 +431,10 @@
                 **kwargs,
             )
             del expected_zmq_port
+            zmq_time = time.time() - start_time
+            logger.info(f"  Server launch time: {zmq_time:.3f} seconds")
+
+        start_time = time.time()
 
         query_embedding = self.backend_impl.compute_query_embedding(
             query,
diff --git a/packages/leann-core/src/leann/embedding_compute.py b/packages/leann-core/src/leann/embedding_compute.py
index 9f6be79..ce85055 100644
--- a/packages/leann-core/src/leann/embedding_compute.py
+++ b/packages/leann-core/src/leann/embedding_compute.py
@@ -25,6 +25,8 @@ def compute_embeddings(
     model_name: str,
     mode: str = "sentence-transformers",
     is_build: bool = False,
+    batch_size: int = 32,
+    adaptive_optimization: bool = True,
 ) -> np.ndarray:
     """
     Unified embedding computation entry point
@@ -33,13 +35,20 @@
         texts: List of texts to compute embeddings for
         model_name: Model name
         mode: Computation mode ('sentence-transformers', 'openai', 'mlx')
+        is_build: Whether this is a build operation (shows a progress bar)
+        batch_size: Batch size for processing
+        adaptive_optimization: Whether to adaptively pick a device-tuned batch size
 
     Returns:
         Normalized embeddings array, shape: (len(texts), embedding_dim)
     """
     if mode == "sentence-transformers":
         return compute_embeddings_sentence_transformers(
-            texts, model_name, is_build=is_build
+            texts,
+            model_name,
+            is_build=is_build,
+            batch_size=batch_size,
+            adaptive_optimization=adaptive_optimization,
         )
     elif mode == "openai":
         return compute_embeddings_openai(texts, model_name)
@@ -56,9 +65,19 @@ def compute_embeddings_sentence_transformers(
     device: str = "auto",
     batch_size: int = 32,
     is_build: bool = False,
+    adaptive_optimization: bool = True,
 ) -> np.ndarray:
     """
-    Compute embeddings using SentenceTransformer with model caching
+    Compute embeddings using SentenceTransformer with model caching and adaptive optimization
+
+    Args:
+        texts: List of texts to compute embeddings for
+        model_name: Model name
+        use_fp16: Whether to use FP16 precision
+        device: Device to use ('auto', 'cuda', 'mps', 'cpu')
+        batch_size: Batch size for processing
+        is_build: Whether this is a build operation (shows a progress bar)
+        adaptive_optimization: Whether to adaptively pick a device-tuned batch size
     """
     # Handle empty input
     if not texts:
@@ -76,28 +95,68 @@ def compute_embeddings_sentence_transformers(
     else:
         device = "cpu"
 
+    # Apply optimizations based on benchmark results
+    if adaptive_optimization:
+        # Use optimal batch_size constants for different devices based on benchmark results
+        if device == "mps":
+            batch_size = 128  # MPS optimal batch size from benchmark
+            if model_name == "Qwen/Qwen3-Embedding-0.6B":
+                batch_size = 64
+        elif device == "cuda":
+            batch_size = 256  # CUDA optimal batch size
+        # Keep original batch_size for CPU
+
     # Create cache key
-    cache_key = f"sentence_transformers_{model_name}_{device}_{use_fp16}"
+    cache_key = f"sentence_transformers_{model_name}_{device}_{use_fp16}_optimized"
 
     # Check if model is already cached
     if cache_key in _model_cache:
-        logger.info(f"Using cached model: {model_name}")
+        logger.info(f"Using cached optimized model: {model_name}")
         model = _model_cache[cache_key]
     else:
-        logger.info(f"Loading and caching SentenceTransformer model: {model_name}")
+        logger.info(
+            f"Loading and caching optimized SentenceTransformer model: {model_name}"
+        )
         from sentence_transformers import SentenceTransformer
 
         logger.info(f"Using device: {device}")
 
-        # Prepare model and tokenizer optimization parameters
+        # Apply hardware optimizations
+        if device == "cuda":
+            # TODO: Haven't tested this yet
+            torch.backends.cuda.matmul.allow_tf32 = True
+            torch.backends.cudnn.allow_tf32 = True
+            torch.backends.cudnn.benchmark = True
+            torch.backends.cudnn.deterministic = False
+            torch.cuda.set_per_process_memory_fraction(0.9)
+        elif device == "mps":
+            try:
+                if hasattr(torch.mps, "set_per_process_memory_fraction"):
+                    torch.mps.set_per_process_memory_fraction(0.9)
+            except AttributeError:
+                logger.warning(
+                    "Some MPS optimizations not available in this PyTorch version"
+                )
+        elif device == "cpu":
+            # TODO: Haven't tested this yet
+            torch.set_num_threads(min(8, os.cpu_count() or 4))
+            try:
+                torch.backends.mkldnn.enabled = True
+            except AttributeError:
+                pass
+
+        # Prepare optimized model and tokenizer parameters
         model_kwargs = {
             "torch_dtype": torch.float16 if use_fp16 else torch.float32,
             "low_cpu_mem_usage": True,
             "_fast_init": True,
+            "attn_implementation": "eager",  # Use eager attention for speed
         }
 
         tokenizer_kwargs = {
             "use_fast": True,
+            "padding": True,
+            "truncation": True,
         }
 
         try:
@@ -128,32 +187,44 @@
             )
 
         logger.info("Model loaded successfully! (network + optimized)")
 
-        # Apply additional optimizations (if supported)
+        # Apply additional optimizations based on mode
         if use_fp16 and device in ["cuda", "mps"]:
             try:
                 model = model.half()
-                model = torch.compile(model)
-                logger.info(
-                    f"Using FP16 precision and compile optimization: {model_name}"
-                )
+                logger.info(f"Applied FP16 precision: {model_name}")
             except Exception as e:
-                logger.warning(f"FP16 or compile optimization failed: {e}")
+                logger.warning(f"FP16 optimization failed: {e}")
+
+        # Apply torch.compile optimization
+        if device in ["cuda", "mps"]:
+            try:
+                model = torch.compile(model, mode="reduce-overhead", dynamic=True)
+                logger.info(f"Applied torch.compile optimization: {model_name}")
+            except Exception as e:
+                logger.warning(f"torch.compile optimization failed: {e}")
+
+        # Set model to eval mode and disable gradients for inference
+        model.eval()
+        for param in model.parameters():
+            param.requires_grad_(False)
 
         # Cache the model
         _model_cache[cache_key] = model
         logger.info(f"Model cached: {cache_key}")
 
-    # Compute embeddings
-    logger.info("Starting embedding computation...")
+    # Compute embeddings with optimized inference mode
+    logger.info(f"Starting embedding computation... (batch_size: {batch_size})")
 
-    embeddings = model.encode(
-        texts,
-        batch_size=batch_size,
-        show_progress_bar=is_build,  # Don't show progress bar in server environment
-        convert_to_numpy=True,
-        normalize_embeddings=False,
-        device=device,
-    )
+    # Use torch.inference_mode for optimal performance
+    with torch.inference_mode():
+        embeddings = model.encode(
+            texts,
+            batch_size=batch_size,
+            show_progress_bar=is_build,  # Don't show progress bar in server environment
+            convert_to_numpy=True,
+            normalize_embeddings=False,
+            device=device,
+        )
 
     logger.info(
         f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}"
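
Review note: a minimal usage sketch of the new compute_embeddings signature introduced above, assuming the module is importable as leann.embedding_compute (matching the file path in this patch); the model name and texts are illustrative, not part of the patch.

import numpy as np

from leann.embedding_compute import compute_embeddings

texts = [
    "LEANN recomputes embeddings at query time.",
    "Batch size is now tuned per device.",
]

# With adaptive_optimization=True, the device-tuned constants above
# (128 on MPS, 64 for Qwen/Qwen3-Embedding-0.6B on MPS, 256 on CUDA)
# override the batch_size given here; on CPU it is kept unchanged.
embeddings = compute_embeddings(
    texts,
    model_name="sentence-transformers/all-MiniLM-L6-v2",  # illustrative model
    mode="sentence-transformers",
    batch_size=32,
    adaptive_optimization=True,
)
assert isinstance(embeddings, np.ndarray)
print(embeddings.shape)  # (2, embedding_dim)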
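
Review note: the serving-path changes combine four independent inference-time optimizations (eval mode, frozen parameters, best-effort torch.compile, and torch.inference_mode around encode). Below is a standalone sketch of that pattern on plain sentence-transformers, with an illustrative model, in case reviewers want to benchmark the pieces in isolation.

import torch
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")  # illustrative

# Freeze for inference: eval mode plus disabled gradients.
model.eval()
for param in model.parameters():
    param.requires_grad_(False)

# torch.compile is best-effort, mirroring the patch's try/except guard.
if torch.cuda.is_available() or torch.backends.mps.is_available():
    try:
        model = torch.compile(model, mode="reduce-overhead", dynamic=True)
    except Exception:
        pass  # fall back to the uncompiled model

# inference_mode disables autograd tracking entirely, which is slightly
# cheaper than no_grad for pure inference workloads.
with torch.inference_mode():
    embeddings = model.encode(
        ["hello", "world"],
        batch_size=2,
        convert_to_numpy=True,
    )
print(embeddings.shape)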