|
|
|
@@ -6,6 +6,7 @@ Preserves all optimization parameters to ensure performance
|
|
|
|
|
|
|
|
|
|
import logging
|
|
|
|
|
import os
|
|
|
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
|
|
|
from typing import Any
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
@@ -35,7 +36,7 @@ def compute_embeddings(
|
|
|
|
|
Args:
|
|
|
|
|
texts: List of texts to compute embeddings for
|
|
|
|
|
model_name: Model name
|
|
|
|
|
mode: Computation mode ('sentence-transformers', 'openai', 'mlx')
|
|
|
|
|
mode: Computation mode ('sentence-transformers', 'openai', 'mlx', 'ollama')
|
|
|
|
|
is_build: Whether this is a build operation (shows progress bar)
|
|
|
|
|
batch_size: Batch size for processing
|
|
|
|
|
adaptive_optimization: Whether to use adaptive optimization based on batch size
|
|
|
|
@@ -55,6 +56,8 @@ def compute_embeddings(
|
|
|
|
|
return compute_embeddings_openai(texts, model_name)
|
|
|
|
|
elif mode == "mlx":
|
|
|
|
|
return compute_embeddings_mlx(texts, model_name)
|
|
|
|
|
elif mode == "ollama":
|
|
|
|
|
return compute_embeddings_ollama(texts, model_name, is_build=is_build)
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError(f"Unsupported embedding mode: {mode}")
|
|
|
|
|
|
|
|
|
@@ -365,3 +368,262 @@ def compute_embeddings_mlx(chunks: list[str], model_name: str, batch_size: int =
|
|
|
|
|
|
|
|
|
|
# Stack numpy arrays
|
|
|
|
|
return np.stack(all_embeddings)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def compute_embeddings_ollama(
|
|
|
|
|
texts: list[str], model_name: str, is_build: bool = False, host: str = "http://localhost:11434"
|
|
|
|
|
) -> np.ndarray:
|
|
|
|
|
"""
|
|
|
|
|
Compute embeddings using Ollama API.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
texts: List of texts to compute embeddings for
|
|
|
|
|
model_name: Ollama model name (e.g., "nomic-embed-text", "mxbai-embed-large")
|
|
|
|
|
is_build: Whether this is a build operation (shows progress bar)
|
|
|
|
|
host: Ollama host URL (default: http://localhost:11434)
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Normalized embeddings array, shape: (len(texts), embedding_dim)
|
|
|
|
|
"""
|
|
|
|
|
try:
|
|
|
|
|
import requests
|
|
|
|
|
except ImportError:
|
|
|
|
|
raise ImportError(
|
|
|
|
|
"The 'requests' library is required for Ollama embeddings. Install with: uv pip install requests"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if not texts:
|
|
|
|
|
raise ValueError("Cannot compute embeddings for empty text list")
|
|
|
|
|
|
|
|
|
|
logger.info(
|
|
|
|
|
f"Computing embeddings for {len(texts)} texts using Ollama API, model: '{model_name}'"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Check if Ollama is running
|
|
|
|
|
try:
|
|
|
|
|
response = requests.get(f"{host}/api/version", timeout=5)
|
|
|
|
|
response.raise_for_status()
|
|
|
|
|
except requests.exceptions.ConnectionError:
|
|
|
|
|
error_msg = (
|
|
|
|
|
f"❌ Could not connect to Ollama at {host}.\n\n"
|
|
|
|
|
"Please ensure Ollama is running:\n"
|
|
|
|
|
" • macOS/Linux: ollama serve\n"
|
|
|
|
|
" • Windows: Make sure Ollama is running in the system tray\n\n"
|
|
|
|
|
"Installation: https://ollama.com/download"
|
|
|
|
|
)
|
|
|
|
|
raise RuntimeError(error_msg)
|
|
|
|
|
except Exception as e:
|
|
|
|
|
raise RuntimeError(f"Unexpected error connecting to Ollama: {e}")
|
|
|
|
|
|
|
|
|
|
# Check if model exists and provide helpful suggestions
|
|
|
|
|
try:
|
|
|
|
|
response = requests.get(f"{host}/api/tags", timeout=5)
|
|
|
|
|
response.raise_for_status()
|
|
|
|
|
models = response.json()
|
|
|
|
|
model_names = [model["name"] for model in models.get("models", [])]
|
|
|
|
|
|
|
|
|
|
# Filter for embedding models (models that support embeddings)
|
|
|
|
|
embedding_models = []
|
|
|
|
|
suggested_embedding_models = [
|
|
|
|
|
"nomic-embed-text",
|
|
|
|
|
"mxbai-embed-large",
|
|
|
|
|
"bge-m3",
|
|
|
|
|
"all-minilm",
|
|
|
|
|
"snowflake-arctic-embed",
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
for model in model_names:
|
|
|
|
|
# Check if it's an embedding model (by name patterns or known models)
|
|
|
|
|
base_name = model.split(":")[0]
|
|
|
|
|
if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5"]):
|
|
|
|
|
embedding_models.append(model)
|
|
|
|
|
|
|
|
|
|
# Check if model exists (handle versioned names)
|
|
|
|
|
model_found = any(
|
|
|
|
|
model_name == name.split(":")[0] or model_name == name for name in model_names
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if not model_found:
|
|
|
|
|
error_msg = f"❌ Model '{model_name}' not found in local Ollama.\n\n"
|
|
|
|
|
|
|
|
|
|
# Suggest pulling the model
|
|
|
|
|
error_msg += "📦 To install this embedding model:\n"
|
|
|
|
|
error_msg += f" ollama pull {model_name}\n\n"
|
|
|
|
|
|
|
|
|
|
# Show available embedding models
|
|
|
|
|
if embedding_models:
|
|
|
|
|
error_msg += "✅ Available embedding models:\n"
|
|
|
|
|
for model in embedding_models[:5]:
|
|
|
|
|
error_msg += f" • {model}\n"
|
|
|
|
|
if len(embedding_models) > 5:
|
|
|
|
|
error_msg += f" ... and {len(embedding_models) - 5} more\n"
|
|
|
|
|
else:
|
|
|
|
|
error_msg += "💡 Popular embedding models to install:\n"
|
|
|
|
|
for model in suggested_embedding_models[:3]:
|
|
|
|
|
error_msg += f" • ollama pull {model}\n"
|
|
|
|
|
|
|
|
|
|
error_msg += "\n📚 Browse more: https://ollama.com/library"
|
|
|
|
|
raise ValueError(error_msg)
|
|
|
|
|
|
|
|
|
|
# Verify the model supports embeddings by testing it
|
|
|
|
|
try:
|
|
|
|
|
test_response = requests.post(
|
|
|
|
|
f"{host}/api/embeddings", json={"model": model_name, "prompt": "test"}, timeout=10
|
|
|
|
|
)
|
|
|
|
|
if test_response.status_code != 200:
|
|
|
|
|
error_msg = (
|
|
|
|
|
f"⚠️ Model '{model_name}' exists but may not support embeddings.\n\n"
|
|
|
|
|
f"Please use an embedding model like:\n"
|
|
|
|
|
)
|
|
|
|
|
for model in suggested_embedding_models[:3]:
|
|
|
|
|
error_msg += f" • {model}\n"
|
|
|
|
|
raise ValueError(error_msg)
|
|
|
|
|
except requests.exceptions.RequestException:
|
|
|
|
|
# If test fails, continue anyway - model might still work
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
except requests.exceptions.RequestException as e:
|
|
|
|
|
logger.warning(f"Could not verify model existence: {e}")
|
|
|
|
|
|
|
|
|
|
# Process embeddings with optimized concurrent processing
|
|
|
|
|
import requests
|
|
|
|
|
|
|
|
|
|
def get_single_embedding(text_idx_tuple):
|
|
|
|
|
"""Helper function to get embedding for a single text."""
|
|
|
|
|
text, idx = text_idx_tuple
|
|
|
|
|
max_retries = 3
|
|
|
|
|
retry_count = 0
|
|
|
|
|
|
|
|
|
|
# Truncate very long texts to avoid API issues
|
|
|
|
|
truncated_text = text[:8000] if len(text) > 8000 else text
|
|
|
|
|
|
|
|
|
|
while retry_count < max_retries:
|
|
|
|
|
try:
|
|
|
|
|
response = requests.post(
|
|
|
|
|
f"{host}/api/embeddings",
|
|
|
|
|
json={"model": model_name, "prompt": truncated_text},
|
|
|
|
|
timeout=30,
|
|
|
|
|
)
|
|
|
|
|
response.raise_for_status()
|
|
|
|
|
|
|
|
|
|
result = response.json()
|
|
|
|
|
embedding = result.get("embedding")
|
|
|
|
|
|
|
|
|
|
if embedding is None:
|
|
|
|
|
raise ValueError(f"No embedding returned for text {idx}")
|
|
|
|
|
|
|
|
|
|
return idx, embedding
|
|
|
|
|
|
|
|
|
|
except requests.exceptions.Timeout:
|
|
|
|
|
retry_count += 1
|
|
|
|
|
if retry_count >= max_retries:
|
|
|
|
|
logger.warning(f"Timeout for text {idx} after {max_retries} retries")
|
|
|
|
|
return idx, None
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
if retry_count >= max_retries - 1:
|
|
|
|
|
logger.error(f"Failed to get embedding for text {idx}: {e}")
|
|
|
|
|
return idx, None
|
|
|
|
|
retry_count += 1
|
|
|
|
|
|
|
|
|
|
return idx, None
|
|
|
|
|
|
|
|
|
|
# Determine if we should use concurrent processing
|
|
|
|
|
use_concurrent = (
|
|
|
|
|
len(texts) > 5 and not is_build
|
|
|
|
|
) # Don't use concurrent in build mode to avoid overwhelming
|
|
|
|
|
max_workers = min(4, len(texts)) # Limit concurrent requests to avoid overwhelming Ollama
|
|
|
|
|
|
|
|
|
|
all_embeddings = [None] * len(texts) # Pre-allocate list to maintain order
|
|
|
|
|
failed_indices = []
|
|
|
|
|
|
|
|
|
|
if use_concurrent:
|
|
|
|
|
logger.info(
|
|
|
|
|
f"Using concurrent processing with {max_workers} workers for {len(texts)} texts"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
|
|
|
|
# Submit all tasks
|
|
|
|
|
future_to_idx = {
|
|
|
|
|
executor.submit(get_single_embedding, (text, idx)): idx
|
|
|
|
|
for idx, text in enumerate(texts)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
# Add progress bar for concurrent processing
|
|
|
|
|
try:
|
|
|
|
|
if is_build or len(texts) > 10:
|
|
|
|
|
from tqdm import tqdm
|
|
|
|
|
|
|
|
|
|
futures_iterator = tqdm(
|
|
|
|
|
as_completed(future_to_idx),
|
|
|
|
|
total=len(texts),
|
|
|
|
|
desc="Computing Ollama embeddings",
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
futures_iterator = as_completed(future_to_idx)
|
|
|
|
|
except ImportError:
|
|
|
|
|
futures_iterator = as_completed(future_to_idx)
|
|
|
|
|
|
|
|
|
|
# Collect results as they complete
|
|
|
|
|
for future in futures_iterator:
|
|
|
|
|
try:
|
|
|
|
|
idx, embedding = future.result()
|
|
|
|
|
if embedding is not None:
|
|
|
|
|
all_embeddings[idx] = embedding
|
|
|
|
|
else:
|
|
|
|
|
failed_indices.append(idx)
|
|
|
|
|
except Exception as e:
|
|
|
|
|
idx = future_to_idx[future]
|
|
|
|
|
logger.error(f"Exception for text {idx}: {e}")
|
|
|
|
|
failed_indices.append(idx)
|
|
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
# Sequential processing with progress bar
|
|
|
|
|
show_progress = is_build or len(texts) > 10
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
if show_progress:
|
|
|
|
|
from tqdm import tqdm
|
|
|
|
|
|
|
|
|
|
iterator = tqdm(
|
|
|
|
|
enumerate(texts), total=len(texts), desc="Computing Ollama embeddings"
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
iterator = enumerate(texts)
|
|
|
|
|
except ImportError:
|
|
|
|
|
iterator = enumerate(texts)
|
|
|
|
|
|
|
|
|
|
for idx, text in iterator:
|
|
|
|
|
result_idx, embedding = get_single_embedding((text, idx))
|
|
|
|
|
if embedding is not None:
|
|
|
|
|
all_embeddings[idx] = embedding
|
|
|
|
|
else:
|
|
|
|
|
failed_indices.append(idx)
|
|
|
|
|
|
|
|
|
|
# Handle failed embeddings
|
|
|
|
|
if failed_indices:
|
|
|
|
|
if len(failed_indices) == len(texts):
|
|
|
|
|
raise RuntimeError("Failed to compute any embeddings")
|
|
|
|
|
|
|
|
|
|
logger.warning(f"Failed to compute embeddings for {len(failed_indices)}/{len(texts)} texts")
|
|
|
|
|
|
|
|
|
|
# Use zero embeddings as fallback for failed ones
|
|
|
|
|
valid_embedding = next((e for e in all_embeddings if e is not None), None)
|
|
|
|
|
if valid_embedding:
|
|
|
|
|
embedding_dim = len(valid_embedding)
|
|
|
|
|
for idx in failed_indices:
|
|
|
|
|
all_embeddings[idx] = [0.0] * embedding_dim
|
|
|
|
|
|
|
|
|
|
# Remove None values and convert to numpy array
|
|
|
|
|
all_embeddings = [e for e in all_embeddings if e is not None]
|
|
|
|
|
|
|
|
|
|
# Convert to numpy array and normalize
|
|
|
|
|
embeddings = np.array(all_embeddings, dtype=np.float32)
|
|
|
|
|
|
|
|
|
|
# Normalize embeddings (L2 normalization)
|
|
|
|
|
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
|
|
|
|
|
embeddings = embeddings / (norms + 1e-8) # Add small epsilon to avoid division by zero
|
|
|
|
|
|
|
|
|
|
logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
|
|
|
|
|
|
|
|
|
|
return embeddings
|
|
|
|
|