feat: implement true batch processing for Ollama embeddings
Migrate from deprecated /api/embeddings to modern /api/embed endpoint,
which supports batch inputs. This reduces HTTP overhead by sending 32
texts per request instead of making individual API calls.

Changes:
- Update endpoint from /api/embeddings to /api/embed
- Change parameter from 'prompt' (single) to 'input' (array)
- Update response parsing for batch embeddings array
- Increase timeout to 60s for batch processing
- Improve error handling for batch requests

Performance:
- Reduces API calls by 32x (batch size)
- Eliminates HTTP connection overhead per text
- Note: Ollama still processes batch items sequentially internally

Related: #151
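For reference, the request and response shapes on either side of this migration look roughly as follows (the host and model name are illustrative examples; the field names match Ollama's documented API):

    import requests

    host = "http://localhost:11434"  # example local Ollama host

    # Old endpoint: one text per call, singular "prompt" in, singular "embedding" out
    old = requests.post(
        f"{host}/api/embeddings",
        json={"model": "nomic-embed-text", "prompt": "hello"},
    ).json()
    # old == {"embedding": [0.1, 0.2, ...]}

    # New endpoint: "input" accepts a string or a list; "embeddings" holds one vector per input
    new = requests.post(
        f"{host}/api/embed",
        json={"model": "nomic-embed-text", "input": ["hello", "world"]},
    ).json()
    # new["embeddings"] == [[0.1, ...], [0.2, ...]]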
@@ -574,9 +574,10 @@ def compute_embeddings_ollama(
     host: Optional[str] = None,
 ) -> np.ndarray:
     """
-    Compute embeddings using Ollama API with simplified batch processing.
+    Compute embeddings using Ollama API with true batch processing.
 
-    Uses batch size of 32 for MPS/CPU and 128 for CUDA to optimize performance.
+    Uses the /api/embed endpoint which supports batch inputs.
+    Batch size: 32 for MPS/CPU, 128 for CUDA to optimize performance.
 
     Args:
         texts: List of texts to compute embeddings for
@@ -681,11 +682,11 @@ def compute_embeddings_ollama(
         logger.info(f"Resolved model name '{model_name}' to '{resolved_model_name}'")
         model_name = resolved_model_name
 
-    # Verify the model supports embeddings by testing it
+    # Verify the model supports embeddings by testing it with /api/embed
     try:
         test_response = requests.post(
-            f"{resolved_host}/api/embeddings",
-            json={"model": model_name, "prompt": "test"},
+            f"{resolved_host}/api/embed",
+            json={"model": model_name, "input": "test"},
             timeout=10,
         )
         if test_response.status_code != 200:
@@ -717,56 +718,55 @@ def compute_embeddings_ollama(
         # If torch is not available, use conservative batch size
         batch_size = 32
 
-    logger.info(f"Using batch size: {batch_size}")
+    logger.info(f"Using batch size: {batch_size} for true batch processing")
 
     def get_batch_embeddings(batch_texts):
-        """Get embeddings for a batch of texts."""
-        all_embeddings = []
-        failed_indices = []
-
-        for i, text in enumerate(batch_texts):
-            max_retries = 3
-            retry_count = 0
-
-            # Truncate very long texts to avoid API issues
-            truncated_text = text[:8000] if len(text) > 8000 else text
-            while retry_count < max_retries:
-                try:
-                    response = requests.post(
-                        f"{resolved_host}/api/embeddings",
-                        json={"model": model_name, "prompt": truncated_text},
-                        timeout=30,
-                    )
-                    response.raise_for_status()
-                    result = response.json()
-                    embedding = result.get("embedding")
-                    if embedding is None:
-                        raise ValueError(f"No embedding returned for text {i}")
-                    if not isinstance(embedding, list) or len(embedding) == 0:
-                        raise ValueError(f"Invalid embedding format for text {i}")
-                    all_embeddings.append(embedding)
-                    break
-                except requests.exceptions.Timeout:
-                    retry_count += 1
-                    if retry_count >= max_retries:
-                        logger.warning(f"Timeout for text {i} after {max_retries} retries")
-                        failed_indices.append(i)
-                        all_embeddings.append(None)
-                        break
-                except Exception as e:
-                    retry_count += 1
-                    if retry_count >= max_retries:
-                        logger.error(f"Failed to get embedding for text {i}: {e}")
-                        failed_indices.append(i)
-                        all_embeddings.append(None)
-                        break
-        return all_embeddings, failed_indices
+        """Get embeddings for a batch of texts using /api/embed endpoint."""
+        max_retries = 3
+        retry_count = 0
+
+        # Truncate very long texts to avoid API issues
+        truncated_texts = [text[:8000] if len(text) > 8000 else text for text in batch_texts]
+
+        while retry_count < max_retries:
+            try:
+                # Use /api/embed endpoint with "input" parameter for batch processing
+                response = requests.post(
+                    f"{resolved_host}/api/embed",
+                    json={"model": model_name, "input": truncated_texts},
+                    timeout=60,  # Increased timeout for batch processing
+                )
+                response.raise_for_status()
+
+                result = response.json()
+                batch_embeddings = result.get("embeddings")
+
+                if batch_embeddings is None:
+                    raise ValueError("No embeddings returned from API")
+
+                if not isinstance(batch_embeddings, list):
+                    raise ValueError(f"Invalid embeddings format: {type(batch_embeddings)}")
+
+                if len(batch_embeddings) != len(batch_texts):
+                    raise ValueError(
+                        f"Mismatch: requested {len(batch_texts)} embeddings, got {len(batch_embeddings)}"
+                    )
+
+                return batch_embeddings, []
+
+            except requests.exceptions.Timeout:
+                retry_count += 1
+                if retry_count >= max_retries:
+                    logger.warning(f"Timeout for batch after {max_retries} retries")
+                    return None, list(range(len(batch_texts)))
+            except Exception as e:
+                retry_count += 1
+                if retry_count >= max_retries:
+                    logger.error(f"Failed to get embeddings for batch: {e}")
+                    return None, list(range(len(batch_texts)))
+
+        return None, list(range(len(batch_texts)))
 
     # Process texts in batches
     all_embeddings = []
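The device-dependent batch-size selection sits above this hunk and is not part of the diff; a minimal sketch consistent with the visible fallback and the docstring (32 for MPS/CPU, 128 for CUDA) might look like:

    try:
        import torch
        # Prefer larger batches when a CUDA GPU is available
        batch_size = 128 if torch.cuda.is_available() else 32
    except ImportError:
        # If torch is not available, use conservative batch size
        batch_size = 32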
@@ -784,7 +784,7 @@ def compute_embeddings_ollama(
     num_batches = (len(texts) + batch_size - 1) // batch_size
 
     if show_progress:
-        batch_iterator = tqdm(range(num_batches), desc="Computing Ollama embeddings")
+        batch_iterator = tqdm(range(num_batches), desc="Computing Ollama embeddings (batched)")
     else:
         batch_iterator = range(num_batches)
 
@@ -795,10 +795,14 @@ def compute_embeddings_ollama(
 
         batch_embeddings, batch_failed = get_batch_embeddings(batch_texts)
 
-        # Adjust failed indices to global indices
-        global_failed = [start_idx + idx for idx in batch_failed]
-        all_failed_indices.extend(global_failed)
-        all_embeddings.extend(batch_embeddings)
+        if batch_embeddings is not None:
+            all_embeddings.extend(batch_embeddings)
+        else:
+            # Entire batch failed, add None placeholders
+            all_embeddings.extend([None] * len(batch_texts))
+        # Adjust failed indices to global indices
+        global_failed = [start_idx + idx for idx in batch_failed]
+        all_failed_indices.extend(global_failed)
 
     # Handle failed embeddings
     if all_failed_indices:
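A hypothetical call site, assuming the signature visible in the first hunk (the model name and host are examples; any embedding model pulled into Ollama would do):

    import numpy as np

    embeddings = compute_embeddings_ollama(
        texts=["first document", "second document"],
        model_name="nomic-embed-text",
        host="http://localhost:11434",
    )
    assert isinstance(embeddings, np.ndarray)  # per the -> np.ndarray annotation
    print(embeddings.shape)  # (2, embedding_dim)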