feat: Add Ollama embedding support for local embedding models

Author: Andy Lee
Date: 2025-08-08 18:07:37 -07:00
parent 67fef60466
commit 068fcd71cf
5 changed files with 129 additions and 5 deletions

View File

@@ -75,7 +75,7 @@ class BaseRAGExample(ABC):
"--embedding-mode",
type=str,
default="sentence-transformers",
choices=["sentence-transformers", "openai", "mlx"],
choices=["sentence-transformers", "openai", "mlx", "ollama"],
help="Embedding backend mode (default: sentence-transformers)",
)
@@ -85,7 +85,7 @@ class BaseRAGExample(ABC):
"--llm",
type=str,
default="openai",
choices=["openai", "ollama", "hf"],
choices=["openai", "ollama", "hf", "simulated"],
help="LLM backend to use (default: openai)",
)
llm_group.add_argument(
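
With "ollama" registered as an embedding mode and "simulated" as an LLM choice, examples built on BaseRAGExample can run against purely local backends. A hypothetical invocation (the script name is illustrative, not taken from this diff):

    python examples/my_rag_example.py --embedding-mode ollama --embedding-model nomic-embed-text --llm ollama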

View File

@@ -261,7 +261,7 @@ if __name__ == "__main__":
"--embedding-mode",
type=str,
default="sentence-transformers",
choices=["sentence-transformers", "openai", "mlx"],
choices=["sentence-transformers", "openai", "mlx", "ollama"],
help="Embedding backend mode",
)
parser.add_argument(

View File

@@ -295,7 +295,7 @@ if __name__ == "__main__":
"--embedding-mode",
type=str,
default="sentence-transformers",
choices=["sentence-transformers", "openai", "mlx"],
choices=["sentence-transformers", "openai", "mlx", "ollama"],
help="Embedding backend mode",
)

View File

@@ -94,6 +94,13 @@ Examples:
"--backend", type=str, default="hnsw", choices=["hnsw", "diskann"]
)
build_parser.add_argument("--embedding-model", type=str, default="facebook/contriever")
build_parser.add_argument(
"--embedding-mode",
type=str,
default="sentence-transformers",
choices=["sentence-transformers", "openai", "mlx", "ollama"],
help="Embedding backend mode (default: sentence-transformers)",
)
build_parser.add_argument("--force", "-f", action="store_true", help="Force rebuild")
build_parser.add_argument("--graph-degree", type=int, default=32)
build_parser.add_argument("--complexity", type=int, default=64)
@@ -469,6 +476,7 @@ Examples:
     builder = LeannBuilder(
         backend_name=args.backend,
         embedding_model=args.embedding_model,
+        embedding_mode=args.embedding_mode,
         graph_degree=args.graph_degree,
         complexity=args.complexity,
         is_compact=args.compact,
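
For context, a minimal sketch of what the wired-through parameter enables via the Python API. The import path and model name are assumptions, not taken from this diff:

    # Hypothetical sketch; assumes `ollama pull nomic-embed-text` was run
    # and that LeannBuilder is importable from the package root.
    from leann import LeannBuilder

    builder = LeannBuilder(
        backend_name="hnsw",                 # same default as the CLI above
        embedding_model="nomic-embed-text",  # an Ollama embedding model
        embedding_mode="ollama",             # the parameter added in this commit
        graph_degree=32,
        complexity=64,
    )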

View File

@@ -35,7 +35,7 @@ def compute_embeddings(
     Args:
         texts: List of texts to compute embeddings for
         model_name: Model name
-        mode: Computation mode ('sentence-transformers', 'openai', 'mlx')
+        mode: Computation mode ('sentence-transformers', 'openai', 'mlx', 'ollama')
         is_build: Whether this is a build operation (shows progress bar)
         batch_size: Batch size for processing
         adaptive_optimization: Whether to use adaptive optimization based on batch size
@@ -55,6 +55,8 @@ def compute_embeddings(
         return compute_embeddings_openai(texts, model_name)
     elif mode == "mlx":
         return compute_embeddings_mlx(texts, model_name)
+    elif mode == "ollama":
+        return compute_embeddings_ollama(texts, model_name, is_build=is_build)
     else:
         raise ValueError(f"Unsupported embedding mode: {mode}")
@@ -365,3 +367,117 @@ def compute_embeddings_mlx(chunks: list[str], model_name: str, batch_size: int =
     # Stack numpy arrays
     return np.stack(all_embeddings)
+
+
+def compute_embeddings_ollama(
+    texts: list[str], model_name: str, is_build: bool = False, host: str = "http://localhost:11434"
+) -> np.ndarray:
+    """
+    Compute embeddings using Ollama API.
+
+    Args:
+        texts: List of texts to compute embeddings for
+        model_name: Ollama model name (e.g., "nomic-embed-text-v2", "mxbai-embed-large")
+        is_build: Whether this is a build operation (shows progress bar)
+        host: Ollama host URL (default: http://localhost:11434)
+
+    Returns:
+        Normalized embeddings array, shape: (len(texts), embedding_dim)
+    """
+    try:
+        import requests
+    except ImportError:
+        raise ImportError(
+            "The 'requests' library is required for Ollama embeddings. Install with: uv pip install requests"
+        )
+
+    if not texts:
+        raise ValueError("Cannot compute embeddings for empty text list")
+
+    logger.info(
+        f"Computing embeddings for {len(texts)} texts using Ollama API, model: '{model_name}'"
+    )
+
+    # Check if Ollama is running
+    try:
+        response = requests.get(f"{host}/api/version", timeout=5)
+        response.raise_for_status()
+    except Exception as e:
+        raise RuntimeError(
+            f"Could not connect to Ollama at {host}. Please ensure Ollama is running. Error: {e}"
+        )
+
+    # Check if model exists
+    try:
+        response = requests.get(f"{host}/api/tags", timeout=5)
+        response.raise_for_status()
+        models = response.json()
+        model_names = [model["name"] for model in models.get("models", [])]
+        # Check if model exists (handle versioned names like nomic-embed-text-v2:latest)
+        model_found = any(
+            model_name in name or name.startswith(f"{model_name}:") for name in model_names
+        )
+        if not model_found:
+            error_msg = f"Model '{model_name}' not found in Ollama. Available models: {', '.join(model_names)}"
+            error_msg += f"\n\nTo install the model, run:\n  ollama pull {model_name}"
+            raise ValueError(error_msg)
+    except requests.exceptions.RequestException as e:
+        logger.warning(f"Could not verify model existence: {e}")
+
+    # Process embeddings
+    all_embeddings = []
+    batch_size = 100  # Process in batches to avoid overwhelming the API
+
+    # Add progress bar if in build mode
+    try:
+        if is_build:
+            from tqdm import tqdm
+
+            batch_iterator = tqdm(
+                range(0, len(texts), batch_size),
+                desc="Computing Ollama embeddings",
+                unit="batch",
+                total=(len(texts) + batch_size - 1) // batch_size,
+            )
+        else:
+            batch_iterator = range(0, len(texts), batch_size)
+    except ImportError:
+        batch_iterator = range(0, len(texts), batch_size)
+
+    for i in batch_iterator:
+        batch_texts = texts[i : i + batch_size]
+        for text in batch_texts:
+            try:
+                response = requests.post(
+                    f"{host}/api/embeddings", json={"model": model_name, "prompt": text}, timeout=30
+                )
+                response.raise_for_status()
+                result = response.json()
+                embedding = result.get("embedding")
+                if embedding is None:
+                    raise ValueError(f"No embedding returned for text: {text[:50]}...")
+                all_embeddings.append(embedding)
+            except requests.exceptions.RequestException as e:
+                logger.error(f"Failed to get embedding for text: {e}")
+                raise RuntimeError(f"Error getting embedding from Ollama: {e}")
+            except Exception as e:
+                logger.error(f"Unexpected error: {e}")
+                raise
+
+    # Convert to numpy array and normalize (L2 normalization)
+    embeddings = np.array(all_embeddings, dtype=np.float32)
+    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
+    embeddings = embeddings / (norms + 1e-8)  # Add small epsilon to avoid division by zero
+
+    logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
+    return embeddings
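
Finally, a usage sketch for the new function itself; "nomic-embed-text" is an illustrative model name, and a local Ollama server with that model pulled is assumed:

    # Hypothetical usage; run `ollama pull nomic-embed-text` first.
    embs = compute_embeddings_ollama(
        ["LEANN builds compact vector indexes.", "Ollama serves models locally."],
        model_name="nomic-embed-text",
    )
    print(embs.shape)                    # (2, embedding_dim)
    print(np.linalg.norm(embs, axis=1))  # ~1.0 per row: rows are L2-normalized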