diff --git a/packages/leann-core/src/leann/embedding_compute.py b/packages/leann-core/src/leann/embedding_compute.py index fafa1a0..7e90dbb 100644 --- a/packages/leann-core/src/leann/embedding_compute.py +++ b/packages/leann-core/src/leann/embedding_compute.py @@ -183,7 +183,7 @@ def compute_embeddings_sentence_transformers( } try: - # Try local loading first + # Try loading with advanced parameters first (newer versions) local_model_kwargs = model_kwargs.copy() local_tokenizer_kwargs = tokenizer_kwargs.copy() local_model_kwargs["local_files_only"] = True @@ -197,22 +197,55 @@ def compute_embeddings_sentence_transformers( local_files_only=True, ) logger.info("Model loaded successfully! (local + optimized)") + except TypeError as e: + if "model_kwargs" in str(e) or "tokenizer_kwargs" in str(e): + logger.warning(f"Advanced parameters not supported ({e}), using basic initialization...") + # Fallback to basic initialization for older versions + try: + model = SentenceTransformer( + model_name, + device=device, + local_files_only=True, + ) + logger.info("Model loaded successfully! (local + basic)") + except Exception as e2: + logger.warning(f"Local loading failed ({e2}), trying network download...") + model = SentenceTransformer( + model_name, + device=device, + local_files_only=False, + ) + logger.info("Model loaded successfully! (network + basic)") + else: + raise except Exception as e: logger.warning(f"Local loading failed ({e}), trying network download...") - # Fallback to network loading - network_model_kwargs = model_kwargs.copy() - network_tokenizer_kwargs = tokenizer_kwargs.copy() - network_model_kwargs["local_files_only"] = False - network_tokenizer_kwargs["local_files_only"] = False + # Fallback to network loading with advanced parameters + try: + network_model_kwargs = model_kwargs.copy() + network_tokenizer_kwargs = tokenizer_kwargs.copy() + network_model_kwargs["local_files_only"] = False + network_tokenizer_kwargs["local_files_only"] = False - model = SentenceTransformer( - model_name, - device=device, - model_kwargs=network_model_kwargs, - tokenizer_kwargs=network_tokenizer_kwargs, - local_files_only=False, - ) - logger.info("Model loaded successfully! (network + optimized)") + model = SentenceTransformer( + model_name, + device=device, + model_kwargs=network_model_kwargs, + tokenizer_kwargs=network_tokenizer_kwargs, + local_files_only=False, + ) + logger.info("Model loaded successfully! (network + optimized)") + except TypeError as e2: + if "model_kwargs" in str(e2) or "tokenizer_kwargs" in str(e2): + logger.warning(f"Advanced parameters not supported ({e2}), using basic network loading...") + model = SentenceTransformer( + model_name, + device=device, + local_files_only=False, + ) + logger.info("Model loaded successfully! (network + basic)") + else: + raise # Apply additional optimizations based on mode if use_fp16 and device in ["cuda", "mps"]: