From 4ccbbf3e6b7073ebb1d51089e30dea51a6dd0754 Mon Sep 17 00:00:00 2001
From: aakash
Date: Mon, 6 Oct 2025 14:51:12 -0700
Subject: [PATCH] fix: Add comprehensive SentenceTransformer version compatibility

- Handle both old and new sentence-transformers versions
- Gracefully fall back from advanced parameters to basic initialization
- Catch TypeError for model_kwargs/tokenizer_kwargs and use the basic SentenceTransformer init
- Ensure compatibility across different CI environments and local setups
- Maintain optimization benefits where supported while ensuring broad compatibility

This resolves test failures in CI environments with older sentence-transformers versions.
---
 .../leann-core/src/leann/embedding_compute.py | 61 ++++++++++++++-----
 1 file changed, 47 insertions(+), 14 deletions(-)

diff --git a/packages/leann-core/src/leann/embedding_compute.py b/packages/leann-core/src/leann/embedding_compute.py
index fafa1a0..7e90dbb 100644
--- a/packages/leann-core/src/leann/embedding_compute.py
+++ b/packages/leann-core/src/leann/embedding_compute.py
@@ -183,7 +183,7 @@ def compute_embeddings_sentence_transformers(
     }

     try:
-        # Try local loading first
+        # Try loading with advanced parameters first (newer versions)
         local_model_kwargs = model_kwargs.copy()
         local_tokenizer_kwargs = tokenizer_kwargs.copy()
         local_model_kwargs["local_files_only"] = True
@@ -197,22 +197,55 @@
             local_files_only=True,
         )
         logger.info("Model loaded successfully! (local + optimized)")
+    except TypeError as e:
+        if "model_kwargs" in str(e) or "tokenizer_kwargs" in str(e):
+            logger.warning(f"Advanced parameters not supported ({e}), using basic initialization...")
+            # Fall back to basic initialization for older versions
+            try:
+                model = SentenceTransformer(
+                    model_name,
+                    device=device,
+                    local_files_only=True,
+                )
+                logger.info("Model loaded successfully! (local + basic)")
+            except Exception as e2:
+                logger.warning(f"Local loading failed ({e2}), trying network download...")
+                model = SentenceTransformer(
+                    model_name,
+                    device=device,
+                    local_files_only=False,
+                )
+                logger.info("Model loaded successfully! (network + basic)")
+        else:
+            raise
     except Exception as e:
         logger.warning(f"Local loading failed ({e}), trying network download...")
-        # Fallback to network loading
-        network_model_kwargs = model_kwargs.copy()
-        network_tokenizer_kwargs = tokenizer_kwargs.copy()
-        network_model_kwargs["local_files_only"] = False
-        network_tokenizer_kwargs["local_files_only"] = False
+        # Fall back to network loading with advanced parameters
+        try:
+            network_model_kwargs = model_kwargs.copy()
+            network_tokenizer_kwargs = tokenizer_kwargs.copy()
+            network_model_kwargs["local_files_only"] = False
+            network_tokenizer_kwargs["local_files_only"] = False

-        model = SentenceTransformer(
-            model_name,
-            device=device,
-            model_kwargs=network_model_kwargs,
-            tokenizer_kwargs=network_tokenizer_kwargs,
-            local_files_only=False,
-        )
-        logger.info("Model loaded successfully! (network + optimized)")
+            model = SentenceTransformer(
+                model_name,
+                device=device,
+                model_kwargs=network_model_kwargs,
+                tokenizer_kwargs=network_tokenizer_kwargs,
+                local_files_only=False,
+            )
+            logger.info("Model loaded successfully! (network + optimized)")
+        except TypeError as e2:
+            if "model_kwargs" in str(e2) or "tokenizer_kwargs" in str(e2):
+                logger.warning(f"Advanced parameters not supported ({e2}), using basic network loading...")
+                model = SentenceTransformer(
+                    model_name,
+                    device=device,
+                    local_files_only=False,
+                )
+                logger.info("Model loaded successfully! (network + basic)")
+            else:
+                raise

     # Apply additional optimizations based on mode
     if use_fp16 and device in ["cuda", "mps"]:
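
Reviewer note (not part of the patch): the compatibility trick above boils down to "try the newer constructor signature, fall back on TypeError". Below is a minimal, self-contained sketch of that pattern, assuming sentence-transformers is installed; the helper name load_st_model, the example kwargs, and the model id are illustrative and not taken from the repository.

import logging

from sentence_transformers import SentenceTransformer

logger = logging.getLogger(__name__)


def load_st_model(model_name: str, device: str = "cpu") -> SentenceTransformer:
    """Load a model, tolerating sentence-transformers versions that lack
    model_kwargs/tokenizer_kwargs support (they raise TypeError)."""
    advanced = {
        "model_kwargs": {"torch_dtype": "auto"},   # illustrative kwargs only
        "tokenizer_kwargs": {"use_fast": True},
    }
    try:
        # Newer releases accept the extra keyword arguments.
        return SentenceTransformer(model_name, device=device, **advanced)
    except TypeError as exc:
        # Older releases reject the keywords; retry with the basic signature.
        if "model_kwargs" in str(exc) or "tokenizer_kwargs" in str(exc):
            logger.warning("Advanced parameters unsupported (%s); using basic init", exc)
            return SentenceTransformer(model_name, device=device)
        raise


# Example usage:
# model = load_st_model("sentence-transformers/all-MiniLM-L6-v2")

Keying the TypeError check to the offending keyword names, as the patch does, avoids silently swallowing unrelated constructor errors.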