diff --git a/packages/leann-core/src/leann/embedding_compute.py b/packages/leann-core/src/leann/embedding_compute.py
index 06fba3d..e3d9f86 100644
--- a/packages/leann-core/src/leann/embedding_compute.py
+++ b/packages/leann-core/src/leann/embedding_compute.py
@@ -574,9 +574,10 @@ def compute_embeddings_ollama(
     host: Optional[str] = None,
 ) -> np.ndarray:
     """
-    Compute embeddings using Ollama API with simplified batch processing.
+    Compute embeddings using Ollama API with true batch processing.
 
-    Uses batch size of 32 for MPS/CPU and 128 for CUDA to optimize performance.
+    Uses the /api/embed endpoint which supports batch inputs.
+    Batch size: 32 for MPS/CPU, 128 for CUDA to optimize performance.
 
     Args:
         texts: List of texts to compute embeddings for
@@ -681,11 +682,11 @@ def compute_embeddings_ollama(
         logger.info(f"Resolved model name '{model_name}' to '{resolved_model_name}'")
         model_name = resolved_model_name
 
-    # Verify the model supports embeddings by testing it
+    # Verify the model supports embeddings by testing it with /api/embed
    try:
         test_response = requests.post(
-            f"{resolved_host}/api/embeddings",
-            json={"model": model_name, "prompt": "test"},
+            f"{resolved_host}/api/embed",
+            json={"model": model_name, "input": "test"},
             timeout=10,
         )
         if test_response.status_code != 200:
@@ -717,56 +718,55 @@ def compute_embeddings_ollama(
             # If torch is not available, use conservative batch size
             batch_size = 32
 
-    logger.info(f"Using batch size: {batch_size}")
+    logger.info(f"Using batch size: {batch_size} for true batch processing")
 
     def get_batch_embeddings(batch_texts):
-        """Get embeddings for a batch of texts."""
-        all_embeddings = []
-        failed_indices = []
+        """Get embeddings for a batch of texts using /api/embed endpoint."""
+        max_retries = 3
+        retry_count = 0
 
-        for i, text in enumerate(batch_texts):
-            max_retries = 3
-            retry_count = 0
+        # Truncate very long texts to avoid API issues
+        truncated_texts = [text[:8000] if len(text) > 8000 else text for text in batch_texts]
 
-            # Truncate very long texts to avoid API issues
-            truncated_text = text[:8000] if len(text) > 8000 else text
-            while retry_count < max_retries:
-                try:
-                    response = requests.post(
-                        f"{resolved_host}/api/embeddings",
-                        json={"model": model_name, "prompt": truncated_text},
-                        timeout=30,
+        while retry_count < max_retries:
+            try:
+                # Use /api/embed endpoint with "input" parameter for batch processing
+                response = requests.post(
+                    f"{resolved_host}/api/embed",
+                    json={"model": model_name, "input": truncated_texts},
+                    timeout=60,  # Increased timeout for batch processing
+                )
+                response.raise_for_status()
+
+                result = response.json()
+                batch_embeddings = result.get("embeddings")
+
+                if batch_embeddings is None:
+                    raise ValueError("No embeddings returned from API")
+
+                if not isinstance(batch_embeddings, list):
+                    raise ValueError(f"Invalid embeddings format: {type(batch_embeddings)}")
+
+                if len(batch_embeddings) != len(batch_texts):
+                    raise ValueError(
+                        f"Mismatch: requested {len(batch_texts)} embeddings, got {len(batch_embeddings)}"
                     )
-                    response.raise_for_status()
 
-                    result = response.json()
-                    embedding = result.get("embedding")
+                return batch_embeddings, []
 
-                    if embedding is None:
-                        raise ValueError(f"No embedding returned for text {i}")
+            except requests.exceptions.Timeout:
+                retry_count += 1
+                if retry_count >= max_retries:
+                    logger.warning(f"Timeout for batch after {max_retries} retries")
+                    return None, list(range(len(batch_texts)))
 
-                    if not isinstance(embedding, list) or len(embedding) == 0:
-                        raise ValueError(f"Invalid embedding format for text {i}")
+            except Exception as e:
+                retry_count += 1
+                if retry_count >= max_retries:
+                    logger.error(f"Failed to get embeddings for batch: {e}")
+                    return None, list(range(len(batch_texts)))
 
-                    all_embeddings.append(embedding)
-                    break
-
-                except requests.exceptions.Timeout:
-                    retry_count += 1
-                    if retry_count >= max_retries:
-                        logger.warning(f"Timeout for text {i} after {max_retries} retries")
-                        failed_indices.append(i)
-                        all_embeddings.append(None)
-                        break
-
-                except Exception as e:
-                    retry_count += 1
-                    if retry_count >= max_retries:
-                        logger.error(f"Failed to get embedding for text {i}: {e}")
-                        failed_indices.append(i)
-                        all_embeddings.append(None)
-                        break
-        return all_embeddings, failed_indices
+        return None, list(range(len(batch_texts)))
 
     # Process texts in batches
     all_embeddings = []
@@ -784,7 +784,7 @@ def compute_embeddings_ollama(
     num_batches = (len(texts) + batch_size - 1) // batch_size
 
     if show_progress:
-        batch_iterator = tqdm(range(num_batches), desc="Computing Ollama embeddings")
+        batch_iterator = tqdm(range(num_batches), desc="Computing Ollama embeddings (batched)")
     else:
         batch_iterator = range(num_batches)
 
@@ -795,10 +795,14 @@ def compute_embeddings_ollama(
 
         batch_embeddings, batch_failed = get_batch_embeddings(batch_texts)
 
-        # Adjust failed indices to global indices
-        global_failed = [start_idx + idx for idx in batch_failed]
-        all_failed_indices.extend(global_failed)
-        all_embeddings.extend(batch_embeddings)
+        if batch_embeddings is not None:
+            all_embeddings.extend(batch_embeddings)
+        else:
+            # Entire batch failed, add None placeholders
+            all_embeddings.extend([None] * len(batch_texts))
+            # Adjust failed indices to global indices
+            global_failed = [start_idx + idx for idx in batch_failed]
+            all_failed_indices.extend(global_failed)
 
     # Handle failed embeddings
     if all_failed_indices: