feat: implement true batch processing for Ollama embeddings

Migrate from the deprecated /api/embeddings endpoint to the modern /api/embed
endpoint, which supports batch inputs. This reduces HTTP overhead by sending
32 texts per request instead of making one API call per text.
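
For illustration, the request and response shapes change roughly as follows (a sketch only; the host and model name below are placeholders, not values taken from this patch):

    import requests

    host = "http://localhost:11434"   # assumed default Ollama host
    model = "nomic-embed-text"        # placeholder embedding model

    # Old endpoint: one text per call, singular "prompt" in, "embedding" out
    old = requests.post(
        f"{host}/api/embeddings",
        json={"model": model, "prompt": "hello world"},
        timeout=10,
    )
    print(len(old.json()["embedding"]))       # one vector

    # New endpoint: a whole batch per call, "input" list in, "embeddings" list out
    new = requests.post(
        f"{host}/api/embed",
        json={"model": model, "input": ["hello world", "goodbye world"]},
        timeout=10,
    )
    print(len(new.json()["embeddings"]))      # one vector per input text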

Changes:
- Update endpoint from /api/embeddings to /api/embed
- Change parameter from 'prompt' (single) to 'input' (array)
- Update response parsing for batch embeddings array
- Increase timeout to 60s for batch processing
- Improve error handling for batch requests

Performance:
- Reduces API calls by 32x (batch size)
- Eliminates HTTP connection overhead per text
- Note: Ollama still processes batch items sequentially internally
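
As a rough illustration of the call-count reduction (hypothetical numbers; the actual factor equals the batch size in use):

    # 3,200 texts with the MPS/CPU batch size of 32:
    num_texts = 3200
    batch_size = 32
    old_calls = num_texts                                   # one POST per text
    new_calls = (num_texts + batch_size - 1) // batch_size  # ceiling division, as in the patch
    print(old_calls, new_calls)  # 3200 vs 100 -> 32x fewer HTTP round trips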

Related: #151
Author: ww2283
Date: 2025-10-23 11:30:09 -04:00
Parent: 45b87ce128
Commit: d226f72bc0

@@ -574,9 +574,10 @@ def compute_embeddings_ollama(
     host: Optional[str] = None,
 ) -> np.ndarray:
     """
-    Compute embeddings using Ollama API with simplified batch processing.
-    Uses batch size of 32 for MPS/CPU and 128 for CUDA to optimize performance.
+    Compute embeddings using Ollama API with true batch processing.
+    Uses the /api/embed endpoint which supports batch inputs.
+    Batch size: 32 for MPS/CPU, 128 for CUDA to optimize performance.

     Args:
         texts: List of texts to compute embeddings for
@@ -681,11 +682,11 @@ def compute_embeddings_ollama(
         logger.info(f"Resolved model name '{model_name}' to '{resolved_model_name}'")
         model_name = resolved_model_name

-    # Verify the model supports embeddings by testing it
+    # Verify the model supports embeddings by testing it with /api/embed
     try:
         test_response = requests.post(
-            f"{resolved_host}/api/embeddings",
-            json={"model": model_name, "prompt": "test"},
+            f"{resolved_host}/api/embed",
+            json={"model": model_name, "input": "test"},
             timeout=10,
         )
         if test_response.status_code != 200:
@@ -717,56 +718,55 @@ def compute_embeddings_ollama(
         # If torch is not available, use conservative batch size
         batch_size = 32

-    logger.info(f"Using batch size: {batch_size}")
+    logger.info(f"Using batch size: {batch_size} for true batch processing")

     def get_batch_embeddings(batch_texts):
-        """Get embeddings for a batch of texts."""
-        all_embeddings = []
-        failed_indices = []
-
-        for i, text in enumerate(batch_texts):
-            max_retries = 3
-            retry_count = 0
-
-            # Truncate very long texts to avoid API issues
-            truncated_text = text[:8000] if len(text) > 8000 else text
-
-            while retry_count < max_retries:
-                try:
-                    response = requests.post(
-                        f"{resolved_host}/api/embeddings",
-                        json={"model": model_name, "prompt": truncated_text},
-                        timeout=30,
-                    )
-                    response.raise_for_status()
-                    result = response.json()
-
-                    embedding = result.get("embedding")
-                    if embedding is None:
-                        raise ValueError(f"No embedding returned for text {i}")
-
-                    if not isinstance(embedding, list) or len(embedding) == 0:
-                        raise ValueError(f"Invalid embedding format for text {i}")
-
-                    all_embeddings.append(embedding)
-                    break
-
-                except requests.exceptions.Timeout:
-                    retry_count += 1
-                    if retry_count >= max_retries:
-                        logger.warning(f"Timeout for text {i} after {max_retries} retries")
-                        failed_indices.append(i)
-                        all_embeddings.append(None)
-                        break
-                except Exception as e:
-                    retry_count += 1
-                    if retry_count >= max_retries:
-                        logger.error(f"Failed to get embedding for text {i}: {e}")
-                        failed_indices.append(i)
-                        all_embeddings.append(None)
-                        break
-
-        return all_embeddings, failed_indices
+        """Get embeddings for a batch of texts using /api/embed endpoint."""
+        max_retries = 3
+        retry_count = 0
+
+        # Truncate very long texts to avoid API issues
+        truncated_texts = [text[:8000] if len(text) > 8000 else text for text in batch_texts]
+
+        while retry_count < max_retries:
+            try:
+                # Use /api/embed endpoint with "input" parameter for batch processing
+                response = requests.post(
+                    f"{resolved_host}/api/embed",
+                    json={"model": model_name, "input": truncated_texts},
+                    timeout=60,  # Increased timeout for batch processing
+                )
+                response.raise_for_status()
+                result = response.json()
+
+                batch_embeddings = result.get("embeddings")
+                if batch_embeddings is None:
+                    raise ValueError("No embeddings returned from API")
+
+                if not isinstance(batch_embeddings, list):
+                    raise ValueError(f"Invalid embeddings format: {type(batch_embeddings)}")
+
+                if len(batch_embeddings) != len(batch_texts):
+                    raise ValueError(
+                        f"Mismatch: requested {len(batch_texts)} embeddings, got {len(batch_embeddings)}"
+                    )
+
+                return batch_embeddings, []
+
+            except requests.exceptions.Timeout:
+                retry_count += 1
+                if retry_count >= max_retries:
+                    logger.warning(f"Timeout for batch after {max_retries} retries")
+                    return None, list(range(len(batch_texts)))
+            except Exception as e:
+                retry_count += 1
+                if retry_count >= max_retries:
+                    logger.error(f"Failed to get embeddings for batch: {e}")
+                    return None, list(range(len(batch_texts)))
+
+        return None, list(range(len(batch_texts)))

     # Process texts in batches
     all_embeddings = []
@@ -784,7 +784,7 @@ def compute_embeddings_ollama(
     num_batches = (len(texts) + batch_size - 1) // batch_size

     if show_progress:
-        batch_iterator = tqdm(range(num_batches), desc="Computing Ollama embeddings")
+        batch_iterator = tqdm(range(num_batches), desc="Computing Ollama embeddings (batched)")
     else:
         batch_iterator = range(num_batches)
@@ -795,10 +795,14 @@ def compute_embeddings_ollama(
         batch_embeddings, batch_failed = get_batch_embeddings(batch_texts)

-        # Adjust failed indices to global indices
-        global_failed = [start_idx + idx for idx in batch_failed]
-        all_failed_indices.extend(global_failed)
-        all_embeddings.extend(batch_embeddings)
+        if batch_embeddings is not None:
+            all_embeddings.extend(batch_embeddings)
+        else:
+            # Entire batch failed, add None placeholders
+            all_embeddings.extend([None] * len(batch_texts))
+
+        # Adjust failed indices to global indices
+        global_failed = [start_idx + idx for idx in batch_failed]
+        all_failed_indices.extend(global_failed)

     # Handle failed embeddings
     if all_failed_indices:
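
A minimal standalone sketch of the batched call that the new get_batch_embeddings helper performs (assumes a local Ollama server and a placeholder embedding model; retries and failure handling are omitted here):

    import numpy as np
    import requests

    host = "http://localhost:11434"   # assumed default Ollama host
    model = "nomic-embed-text"        # placeholder embedding model
    texts = ["first chunk", "second chunk", "third chunk"]

    # One POST for the whole batch via /api/embed with an "input" list
    resp = requests.post(
        f"{host}/api/embed",
        json={"model": model, "input": texts},
        timeout=60,
    )
    resp.raise_for_status()
    embeddings = resp.json()["embeddings"]    # one vector per input text
    assert len(embeddings) == len(texts)

    matrix = np.array(embeddings, dtype=np.float32)
    print(matrix.shape)  # (3, embedding_dim)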