[Ollama] fix ollama recompute

yichuan520030910320
2025-08-12 00:24:20 -07:00
parent e8fca2c84a
commit b2390ccc14
2 changed files with 137 additions and 139 deletions

View File

@@ -13,7 +13,7 @@ if(APPLE)
     else()
         message(FATAL_ERROR "Could not find libomp installation. Please install with: brew install libomp")
     endif()
     set(OpenMP_C_FLAGS "-Xpreprocessor -fopenmp -I${HOMEBREW_PREFIX}/opt/libomp/include")
     set(OpenMP_CXX_FLAGS "-Xpreprocessor -fopenmp -I${HOMEBREW_PREFIX}/opt/libomp/include")
     set(OpenMP_C_LIB_NAMES "omp")

View File

@@ -6,7 +6,6 @@ Preserves all optimization parameters to ensure performance
 import logging
 import os
-from concurrent.futures import ThreadPoolExecutor, as_completed
 from typing import Any

 import numpy as np
@@ -374,7 +373,9 @@ def compute_embeddings_ollama(
     texts: list[str], model_name: str, is_build: bool = False, host: str = "http://localhost:11434"
 ) -> np.ndarray:
     """
-    Compute embeddings using Ollama API.
+    Compute embeddings using Ollama API with simplified batch processing.
+
+    Uses batch size of 32 for MPS/CPU and 128 for CUDA to optimize performance.

     Args:
         texts: List of texts to compute embeddings for
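For orientation, a minimal usage sketch of the updated function. Only the signature above is taken from the diff; the import path is an assumption, since the diff does not name the module.

# Hypothetical import path -- the diff does not show the module or package name.
from leann_embeddings import compute_embeddings_ollama

texts = ["LEANN builds a vector index.", "Ollama serves local embedding models."]
# Signature as in the diff; batch size and device selection are handled internally.
vectors = compute_embeddings_ollama(texts, "nomic-embed-text", is_build=True)
print(vectors.shape, vectors.dtype)  # (2, embedding_dim), float32 per the final conversion in the diff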
@@ -438,12 +439,19 @@ def compute_embeddings_ollama(
         if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5"]):
             embedding_models.append(model)

-    # Check if model exists (handle versioned names)
-    model_found = any(
-        model_name == name.split(":")[0] or model_name == name for name in model_names
-    )
+    # Check if model exists (handle versioned names) and resolve to full name
+    resolved_model_name = None
+    for name in model_names:
+        # Exact match
+        if model_name == name:
+            resolved_model_name = name
+            break
+        # Match without version tag (use the versioned name)
+        elif model_name == name.split(":")[0]:
+            resolved_model_name = name
+            break

-    if not model_found:
+    if not resolved_model_name:
         error_msg = f"❌ Model '{model_name}' not found in local Ollama.\n\n"

         # Suggest pulling the model
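The resolution rule above can be illustrated in isolation; the helper and the model names below are illustrative, not part of the diff.

# Standalone sketch of the resolution logic (illustrative names, not from the diff).
def resolve_model_name(model_name, model_names):
    for name in model_names:
        if model_name == name:  # exact match, including version tag
            return name
        if model_name == name.split(":")[0]:  # bare name matches a versioned local model
            return name
    return None

print(resolve_model_name("nomic-embed-text", ["nomic-embed-text:latest", "llama3:8b"]))
# -> "nomic-embed-text:latest"  (the versioned local name is used from here on)
print(resolve_model_name("mxbai-embed-large", ["llama3:8b"]))
# -> None  (triggers the "not found" error path)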
@@ -465,6 +473,11 @@ def compute_embeddings_ollama(
         error_msg += "\n📚 Browse more: https://ollama.com/library"
         raise ValueError(error_msg)

+    # Use the resolved model name for all subsequent operations
+    if resolved_model_name != model_name:
+        logger.info(f"Resolved model name '{model_name}' to '{resolved_model_name}'")
+    model_name = resolved_model_name
+
     # Verify the model supports embeddings by testing it
     try:
         test_response = requests.post(
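For reference, a sketch of the raw endpoint this verification and the embedding loop depend on; the host matches the default in the signature, and the response shape is inferred from the fields the code reads.

import requests

# One-off call to Ollama's embeddings endpoint (same pattern the function uses).
resp = requests.post(
    "http://localhost:11434/api/embeddings",
    json={"model": "nomic-embed-text:latest", "prompt": "hello world"},
    timeout=30,
)
resp.raise_for_status()
vector = resp.json().get("embedding")  # expected: list[float]; None triggers the error handling above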
@@ -485,162 +498,147 @@ def compute_embeddings_ollama(
     except requests.exceptions.RequestException as e:
         logger.warning(f"Could not verify model existence: {e}")

-    # Process embeddings with optimized concurrent processing
-    import requests
-
-    def get_single_embedding(text_idx_tuple):
-        """Helper function to get embedding for a single text."""
-        text, idx = text_idx_tuple
-        max_retries = 3
-        retry_count = 0
-
-        # Truncate very long texts to avoid API issues
-        truncated_text = text[:8000] if len(text) > 8000 else text
-
-        while retry_count < max_retries:
-            try:
-                response = requests.post(
-                    f"{host}/api/embeddings",
-                    json={"model": model_name, "prompt": truncated_text},
-                    timeout=30,
-                )
-                response.raise_for_status()
-                result = response.json()
-                embedding = result.get("embedding")
-
-                if embedding is None:
-                    raise ValueError(f"No embedding returned for text {idx}")
-
-                return idx, embedding
-            except requests.exceptions.Timeout:
-                retry_count += 1
-                if retry_count >= max_retries:
-                    logger.warning(f"Timeout for text {idx} after {max_retries} retries")
-                    return idx, None
-            except Exception as e:
-                if retry_count >= max_retries - 1:
-                    logger.error(f"Failed to get embedding for text {idx}: {e}")
-                    return idx, None
-                retry_count += 1
-
-        return idx, None
-
-    # Determine if we should use concurrent processing
-    use_concurrent = (
-        len(texts) > 5 and not is_build
-    )  # Don't use concurrent in build mode to avoid overwhelming
-    max_workers = min(4, len(texts))  # Limit concurrent requests to avoid overwhelming Ollama
-
-    all_embeddings = [None] * len(texts)  # Pre-allocate list to maintain order
-    failed_indices = []
-
-    if use_concurrent:
-        logger.info(
-            f"Using concurrent processing with {max_workers} workers for {len(texts)} texts"
-        )
-
-        with ThreadPoolExecutor(max_workers=max_workers) as executor:
-            # Submit all tasks
-            future_to_idx = {
-                executor.submit(get_single_embedding, (text, idx)): idx
-                for idx, text in enumerate(texts)
-            }
-
-            # Add progress bar for concurrent processing
-            try:
-                if is_build or len(texts) > 10:
-                    from tqdm import tqdm
-
-                    futures_iterator = tqdm(
-                        as_completed(future_to_idx),
-                        total=len(texts),
-                        desc="Computing Ollama embeddings",
-                    )
-                else:
-                    futures_iterator = as_completed(future_to_idx)
-            except ImportError:
-                futures_iterator = as_completed(future_to_idx)
-
-            # Collect results as they complete
-            for future in futures_iterator:
-                try:
-                    idx, embedding = future.result()
-                    if embedding is not None:
-                        all_embeddings[idx] = embedding
-                    else:
-                        failed_indices.append(idx)
-                except Exception as e:
-                    idx = future_to_idx[future]
-                    logger.error(f"Exception for text {idx}: {e}")
-                    failed_indices.append(idx)
-    else:
-        # Sequential processing with progress bar
-        show_progress = is_build or len(texts) > 10
-        try:
-            if show_progress:
-                from tqdm import tqdm
-
-                iterator = tqdm(
-                    enumerate(texts), total=len(texts), desc="Computing Ollama embeddings"
-                )
-            else:
-                iterator = enumerate(texts)
-        except ImportError:
-            iterator = enumerate(texts)
-
-        for idx, text in iterator:
-            result_idx, embedding = get_single_embedding((text, idx))
-            if embedding is not None:
-                all_embeddings[idx] = embedding
-            else:
-                failed_indices.append(idx)
+    # Determine batch size based on device availability
+    # Check for CUDA/MPS availability using torch if available
+    batch_size = 32  # Default for MPS/CPU
+    try:
+        import torch
+
+        if torch.cuda.is_available():
+            batch_size = 128  # CUDA gets larger batch size
+        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+            batch_size = 32  # MPS gets smaller batch size
+    except ImportError:
+        # If torch is not available, use conservative batch size
+        batch_size = 32
+
+    logger.info(f"Using batch size: {batch_size}")
+
+    def get_batch_embeddings(batch_texts):
+        """Get embeddings for a batch of texts."""
+        all_embeddings = []
+        failed_indices = []
+
+        for i, text in enumerate(batch_texts):
+            max_retries = 3
+            retry_count = 0
+
+            # Truncate very long texts to avoid API issues
+            truncated_text = text[:8000] if len(text) > 8000 else text
+
+            while retry_count < max_retries:
+                try:
+                    response = requests.post(
+                        f"{host}/api/embeddings",
+                        json={"model": model_name, "prompt": truncated_text},
+                        timeout=30,
+                    )
+                    response.raise_for_status()
+                    result = response.json()
+                    embedding = result.get("embedding")
+
+                    if embedding is None:
+                        raise ValueError(f"No embedding returned for text {i}")
+                    if not isinstance(embedding, list) or len(embedding) == 0:
+                        raise ValueError(f"Invalid embedding format for text {i}")
+
+                    all_embeddings.append(embedding)
+                    break
+                except requests.exceptions.Timeout:
+                    retry_count += 1
+                    if retry_count >= max_retries:
+                        logger.warning(f"Timeout for text {i} after {max_retries} retries")
+                        failed_indices.append(i)
+                        all_embeddings.append(None)
+                        break
+                except Exception as e:
+                    retry_count += 1
+                    if retry_count >= max_retries:
+                        logger.error(f"Failed to get embedding for text {i}: {e}")
+                        failed_indices.append(i)
+                        all_embeddings.append(None)
+                        break
+
+        return all_embeddings, failed_indices
+
+    # Process texts in batches
+    all_embeddings = []
+    all_failed_indices = []
+
+    # Setup progress bar if needed
+    show_progress = is_build or len(texts) > 10
+    try:
+        if show_progress:
+            from tqdm import tqdm
+    except ImportError:
+        show_progress = False
+
+    # Process batches
+    num_batches = (len(texts) + batch_size - 1) // batch_size
+    if show_progress:
+        batch_iterator = tqdm(range(num_batches), desc="Computing Ollama embeddings")
+    else:
+        batch_iterator = range(num_batches)
+
+    for batch_idx in batch_iterator:
+        start_idx = batch_idx * batch_size
+        end_idx = min(start_idx + batch_size, len(texts))
+        batch_texts = texts[start_idx:end_idx]
+
+        batch_embeddings, batch_failed = get_batch_embeddings(batch_texts)
+
+        # Adjust failed indices to global indices
+        global_failed = [start_idx + idx for idx in batch_failed]
+        all_failed_indices.extend(global_failed)
+        all_embeddings.extend(batch_embeddings)
     # Handle failed embeddings
-    if failed_indices:
-        if len(failed_indices) == len(texts):
+    if all_failed_indices:
+        if len(all_failed_indices) == len(texts):
             raise RuntimeError("Failed to compute any embeddings")

-        logger.warning(f"Failed to compute embeddings for {len(failed_indices)}/{len(texts)} texts")
+        logger.warning(
+            f"Failed to compute embeddings for {len(all_failed_indices)}/{len(texts)} texts"
+        )

         # Use zero embeddings as fallback for failed ones
         valid_embedding = next((e for e in all_embeddings if e is not None), None)
         if valid_embedding:
             embedding_dim = len(valid_embedding)
-            for idx in failed_indices:
-                all_embeddings[idx] = [0.0] * embedding_dim
+            for i, embedding in enumerate(all_embeddings):
+                if embedding is None:
+                    all_embeddings[i] = [0.0] * embedding_dim

-    # Remove None values and convert to numpy array
+    # Remove None values
     all_embeddings = [e for e in all_embeddings if e is not None]

-    # Validate embedding dimensions before creating numpy array
-    if all_embeddings:
-        expected_dim = len(all_embeddings[0])
-        inconsistent_dims = []
-        for i, embedding in enumerate(all_embeddings):
-            if len(embedding) != expected_dim:
-                inconsistent_dims.append((i, len(embedding)))
-
-        if inconsistent_dims:
-            error_msg = f"Ollama returned inconsistent embedding dimensions. Expected {expected_dim}, but got:\n"
-            for idx, dim in inconsistent_dims[:10]:  # Show first 10 inconsistent ones
-                error_msg += f"  - Text {idx}: {dim} dimensions\n"
-            if len(inconsistent_dims) > 10:
-                error_msg += f"  ... and {len(inconsistent_dims) - 10} more\n"
-            error_msg += (
-                f"\nThis is likely an Ollama API bug with model '{model_name}'. Please try:\n"
-            )
-            error_msg += "1. Restart Ollama service: 'ollama serve'\n"
-            error_msg += f"2. Re-pull the model: 'ollama pull {model_name}'\n"
-            error_msg += (
-                "3. Use sentence-transformers instead: --embedding-mode sentence-transformers\n"
-            )
-            error_msg += "4. Report this issue to Ollama: https://github.com/ollama/ollama/issues"
-            raise ValueError(error_msg)
+    if not all_embeddings:
+        raise RuntimeError("No valid embeddings were computed")
+
+    # Validate embedding dimensions
+    expected_dim = len(all_embeddings[0])
+    inconsistent_dims = []
+    for i, embedding in enumerate(all_embeddings):
+        if len(embedding) != expected_dim:
+            inconsistent_dims.append((i, len(embedding)))
+
+    if inconsistent_dims:
+        error_msg = f"Ollama returned inconsistent embedding dimensions. Expected {expected_dim}, but got:\n"
+        for idx, dim in inconsistent_dims[:10]:  # Show first 10 inconsistent ones
+            error_msg += f"  - Text {idx}: {dim} dimensions\n"
+        if len(inconsistent_dims) > 10:
+            error_msg += f"  ... and {len(inconsistent_dims) - 10} more\n"
+        error_msg += f"\nThis is likely an Ollama API bug with model '{model_name}'. Please try:\n"
+        error_msg += "1. Restart Ollama service: 'ollama serve'\n"
+        error_msg += f"2. Re-pull the model: 'ollama pull {model_name}'\n"
+        error_msg += (
+            "3. Use sentence-transformers instead: --embedding-mode sentence-transformers\n"
+        )
+        error_msg += "4. Report this issue to Ollama: https://github.com/ollama/ollama/issues"
+        raise ValueError(error_msg)

     # Convert to numpy array and normalize
     embeddings = np.array(all_embeddings, dtype=np.float32)
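As a quick sanity check on the batching arithmetic introduced above (pure illustration with made-up sizes, not code from the diff):

# Ceil-division batching, exactly as in the diff: (len + batch_size - 1) // batch_size
texts = [f"doc {i}" for i in range(70)]  # pretend 70 texts
batch_size = 32
num_batches = (len(texts) + batch_size - 1) // batch_size  # -> 3
for b in range(num_batches):
    start, end = b * batch_size, min((b + 1) * batch_size, len(texts))
    print(b, start, end)  # (0, 0, 32), (1, 32, 64), (2, 64, 70)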