Merge branch 'main' into feat/claude-code-refine
This commit is contained in:
@@ -6,6 +6,7 @@ Preserves all optimization parameters to ensure performance
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -377,7 +378,7 @@ def compute_embeddings_ollama(
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
texts: List of texts to compute embeddings for
|
texts: List of texts to compute embeddings for
|
||||||
model_name: Ollama model name (e.g., "nomic-embed-text-v2", "mxbai-embed-large")
|
model_name: Ollama model name (e.g., "nomic-embed-text", "mxbai-embed-large")
|
||||||
is_build: Whether this is a build operation (shows progress bar)
|
is_build: Whether this is a build operation (shows progress bar)
|
||||||
host: Ollama host URL (default: http://localhost:11434)
|
host: Ollama host URL (default: http://localhost:11434)
|
||||||
|
|
||||||
@@ -402,57 +403,106 @@ def compute_embeddings_ollama(
|
|||||||
try:
|
try:
|
||||||
response = requests.get(f"{host}/api/version", timeout=5)
|
response = requests.get(f"{host}/api/version", timeout=5)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
except Exception as e:
|
except requests.exceptions.ConnectionError:
|
||||||
raise RuntimeError(
|
error_msg = (
|
||||||
f"Could not connect to Ollama at {host}. Please ensure Ollama is running. Error: {e}"
|
f"❌ Could not connect to Ollama at {host}.\n\n"
|
||||||
|
"Please ensure Ollama is running:\n"
|
||||||
|
" • macOS/Linux: ollama serve\n"
|
||||||
|
" • Windows: Make sure Ollama is running in the system tray\n\n"
|
||||||
|
"Installation: https://ollama.com/download"
|
||||||
)
|
)
|
||||||
|
raise RuntimeError(error_msg)
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"Unexpected error connecting to Ollama: {e}")
|
||||||
|
|
||||||
# Check if model exists
|
# Check if model exists and provide helpful suggestions
|
||||||
try:
|
try:
|
||||||
response = requests.get(f"{host}/api/tags", timeout=5)
|
response = requests.get(f"{host}/api/tags", timeout=5)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
models = response.json()
|
models = response.json()
|
||||||
model_names = [model["name"] for model in models.get("models", [])]
|
model_names = [model["name"] for model in models.get("models", [])]
|
||||||
|
|
||||||
# Check if model exists (handle versioned names like nomic-embed-text-v2:latest)
|
# Filter for embedding models (models that support embeddings)
|
||||||
|
embedding_models = []
|
||||||
|
suggested_embedding_models = [
|
||||||
|
"nomic-embed-text",
|
||||||
|
"mxbai-embed-large",
|
||||||
|
"bge-m3",
|
||||||
|
"all-minilm",
|
||||||
|
"snowflake-arctic-embed",
|
||||||
|
]
|
||||||
|
|
||||||
|
for model in model_names:
|
||||||
|
# Check if it's an embedding model (by name patterns or known models)
|
||||||
|
base_name = model.split(":")[0]
|
||||||
|
if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5"]):
|
||||||
|
embedding_models.append(model)
|
||||||
|
|
||||||
|
# Check if model exists (handle versioned names)
|
||||||
model_found = any(
|
model_found = any(
|
||||||
model_name in name or name.startswith(f"{model_name}:") for name in model_names
|
model_name == name.split(":")[0] or model_name == name for name in model_names
|
||||||
)
|
)
|
||||||
|
|
||||||
if not model_found:
|
if not model_found:
|
||||||
error_msg = f"Model '{model_name}' not found in Ollama. Available models: {', '.join(model_names)}"
|
error_msg = f"❌ Model '{model_name}' not found in local Ollama.\n\n"
|
||||||
error_msg += f"\n\nTo install the model, run:\n ollama pull {model_name}"
|
|
||||||
|
# Suggest pulling the model
|
||||||
|
error_msg += "📦 To install this embedding model:\n"
|
||||||
|
error_msg += f" ollama pull {model_name}\n\n"
|
||||||
|
|
||||||
|
# Show available embedding models
|
||||||
|
if embedding_models:
|
||||||
|
error_msg += "✅ Available embedding models:\n"
|
||||||
|
for model in embedding_models[:5]:
|
||||||
|
error_msg += f" • {model}\n"
|
||||||
|
if len(embedding_models) > 5:
|
||||||
|
error_msg += f" ... and {len(embedding_models) - 5} more\n"
|
||||||
|
else:
|
||||||
|
error_msg += "💡 Popular embedding models to install:\n"
|
||||||
|
for model in suggested_embedding_models[:3]:
|
||||||
|
error_msg += f" • ollama pull {model}\n"
|
||||||
|
|
||||||
|
error_msg += "\n📚 Browse more: https://ollama.com/library"
|
||||||
raise ValueError(error_msg)
|
raise ValueError(error_msg)
|
||||||
|
|
||||||
|
# Verify the model supports embeddings by testing it
|
||||||
|
try:
|
||||||
|
test_response = requests.post(
|
||||||
|
f"{host}/api/embeddings", json={"model": model_name, "prompt": "test"}, timeout=10
|
||||||
|
)
|
||||||
|
if test_response.status_code != 200:
|
||||||
|
error_msg = (
|
||||||
|
f"⚠️ Model '{model_name}' exists but may not support embeddings.\n\n"
|
||||||
|
f"Please use an embedding model like:\n"
|
||||||
|
)
|
||||||
|
for model in suggested_embedding_models[:3]:
|
||||||
|
error_msg += f" • {model}\n"
|
||||||
|
raise ValueError(error_msg)
|
||||||
|
except requests.exceptions.RequestException:
|
||||||
|
# If test fails, continue anyway - model might still work
|
||||||
|
pass
|
||||||
|
|
||||||
except requests.exceptions.RequestException as e:
|
except requests.exceptions.RequestException as e:
|
||||||
logger.warning(f"Could not verify model existence: {e}")
|
logger.warning(f"Could not verify model existence: {e}")
|
||||||
|
|
||||||
# Process embeddings
|
# Process embeddings with optimized concurrent processing
|
||||||
all_embeddings = []
|
import requests
|
||||||
batch_size = 100 # Process in batches to avoid overwhelming the API
|
|
||||||
|
|
||||||
# Add progress bar if in build mode
|
def get_single_embedding(text_idx_tuple):
|
||||||
try:
|
"""Helper function to get embedding for a single text."""
|
||||||
if is_build:
|
text, idx = text_idx_tuple
|
||||||
from tqdm import tqdm
|
max_retries = 3
|
||||||
|
retry_count = 0
|
||||||
|
|
||||||
batch_iterator = tqdm(
|
# Truncate very long texts to avoid API issues
|
||||||
range(0, len(texts), batch_size),
|
truncated_text = text[:8000] if len(text) > 8000 else text
|
||||||
desc="Computing Ollama embeddings",
|
|
||||||
unit="batch",
|
|
||||||
total=(len(texts) + batch_size - 1) // batch_size,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
batch_iterator = range(0, len(texts), batch_size)
|
|
||||||
except ImportError:
|
|
||||||
batch_iterator = range(0, len(texts), batch_size)
|
|
||||||
|
|
||||||
for i in batch_iterator:
|
while retry_count < max_retries:
|
||||||
batch_texts = texts[i : i + batch_size]
|
|
||||||
|
|
||||||
for text in batch_texts:
|
|
||||||
try:
|
try:
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
f"{host}/api/embeddings", json={"model": model_name, "prompt": text}, timeout=30
|
f"{host}/api/embeddings",
|
||||||
|
json={"model": model_name, "prompt": truncated_text},
|
||||||
|
timeout=30,
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
@@ -460,16 +510,112 @@ def compute_embeddings_ollama(
|
|||||||
embedding = result.get("embedding")
|
embedding = result.get("embedding")
|
||||||
|
|
||||||
if embedding is None:
|
if embedding is None:
|
||||||
raise ValueError(f"No embedding returned for text: {text[:50]}...")
|
raise ValueError(f"No embedding returned for text {idx}")
|
||||||
|
|
||||||
all_embeddings.append(embedding)
|
return idx, embedding
|
||||||
|
|
||||||
|
except requests.exceptions.Timeout:
|
||||||
|
retry_count += 1
|
||||||
|
if retry_count >= max_retries:
|
||||||
|
logger.warning(f"Timeout for text {idx} after {max_retries} retries")
|
||||||
|
return idx, None
|
||||||
|
|
||||||
except requests.exceptions.RequestException as e:
|
|
||||||
logger.error(f"Failed to get embedding for text: {e}")
|
|
||||||
raise RuntimeError(f"Error getting embedding from Ollama: {e}")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Unexpected error: {e}")
|
if retry_count >= max_retries - 1:
|
||||||
raise
|
logger.error(f"Failed to get embedding for text {idx}: {e}")
|
||||||
|
return idx, None
|
||||||
|
retry_count += 1
|
||||||
|
|
||||||
|
return idx, None
|
||||||
|
|
||||||
|
# Determine if we should use concurrent processing
|
||||||
|
use_concurrent = (
|
||||||
|
len(texts) > 5 and not is_build
|
||||||
|
) # Don't use concurrent in build mode to avoid overwhelming
|
||||||
|
max_workers = min(4, len(texts)) # Limit concurrent requests to avoid overwhelming Ollama
|
||||||
|
|
||||||
|
all_embeddings = [None] * len(texts) # Pre-allocate list to maintain order
|
||||||
|
failed_indices = []
|
||||||
|
|
||||||
|
if use_concurrent:
|
||||||
|
logger.info(
|
||||||
|
f"Using concurrent processing with {max_workers} workers for {len(texts)} texts"
|
||||||
|
)
|
||||||
|
|
||||||
|
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||||
|
# Submit all tasks
|
||||||
|
future_to_idx = {
|
||||||
|
executor.submit(get_single_embedding, (text, idx)): idx
|
||||||
|
for idx, text in enumerate(texts)
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add progress bar for concurrent processing
|
||||||
|
try:
|
||||||
|
if is_build or len(texts) > 10:
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
futures_iterator = tqdm(
|
||||||
|
as_completed(future_to_idx),
|
||||||
|
total=len(texts),
|
||||||
|
desc="Computing Ollama embeddings",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
futures_iterator = as_completed(future_to_idx)
|
||||||
|
except ImportError:
|
||||||
|
futures_iterator = as_completed(future_to_idx)
|
||||||
|
|
||||||
|
# Collect results as they complete
|
||||||
|
for future in futures_iterator:
|
||||||
|
try:
|
||||||
|
idx, embedding = future.result()
|
||||||
|
if embedding is not None:
|
||||||
|
all_embeddings[idx] = embedding
|
||||||
|
else:
|
||||||
|
failed_indices.append(idx)
|
||||||
|
except Exception as e:
|
||||||
|
idx = future_to_idx[future]
|
||||||
|
logger.error(f"Exception for text {idx}: {e}")
|
||||||
|
failed_indices.append(idx)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Sequential processing with progress bar
|
||||||
|
show_progress = is_build or len(texts) > 10
|
||||||
|
|
||||||
|
try:
|
||||||
|
if show_progress:
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
iterator = tqdm(
|
||||||
|
enumerate(texts), total=len(texts), desc="Computing Ollama embeddings"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
iterator = enumerate(texts)
|
||||||
|
except ImportError:
|
||||||
|
iterator = enumerate(texts)
|
||||||
|
|
||||||
|
for idx, text in iterator:
|
||||||
|
result_idx, embedding = get_single_embedding((text, idx))
|
||||||
|
if embedding is not None:
|
||||||
|
all_embeddings[idx] = embedding
|
||||||
|
else:
|
||||||
|
failed_indices.append(idx)
|
||||||
|
|
||||||
|
# Handle failed embeddings
|
||||||
|
if failed_indices:
|
||||||
|
if len(failed_indices) == len(texts):
|
||||||
|
raise RuntimeError("Failed to compute any embeddings")
|
||||||
|
|
||||||
|
logger.warning(f"Failed to compute embeddings for {len(failed_indices)}/{len(texts)} texts")
|
||||||
|
|
||||||
|
# Use zero embeddings as fallback for failed ones
|
||||||
|
valid_embedding = next((e for e in all_embeddings if e is not None), None)
|
||||||
|
if valid_embedding:
|
||||||
|
embedding_dim = len(valid_embedding)
|
||||||
|
for idx in failed_indices:
|
||||||
|
all_embeddings[idx] = [0.0] * embedding_dim
|
||||||
|
|
||||||
|
# Remove None values and convert to numpy array
|
||||||
|
all_embeddings = [e for e in all_embeddings if e is not None]
|
||||||
|
|
||||||
# Convert to numpy array and normalize
|
# Convert to numpy array and normalize
|
||||||
embeddings = np.array(all_embeddings, dtype=np.float32)
|
embeddings = np.array(all_embeddings, dtype=np.float32)
|
||||||
|
|||||||
Reference in New Issue
Block a user