diff --git a/apps/base_rag_example.py b/apps/base_rag_example.py
index f5a481c..4bd62b9 100644
--- a/apps/base_rag_example.py
+++ b/apps/base_rag_example.py
@@ -75,7 +75,7 @@ class BaseRAGExample(ABC):
             "--embedding-mode",
             type=str,
             default="sentence-transformers",
-            choices=["sentence-transformers", "openai", "mlx"],
+            choices=["sentence-transformers", "openai", "mlx", "ollama"],
             help="Embedding backend mode (default: sentence-transformers)",
         )

@@ -85,7 +85,7 @@ class BaseRAGExample(ABC):
             "--llm",
             type=str,
             default="openai",
-            choices=["openai", "ollama", "hf"],
+            choices=["openai", "ollama", "hf", "simulated"],
             help="LLM backend to use (default: openai)",
         )
         llm_group.add_argument(
diff --git a/packages/leann-backend-diskann/leann_backend_diskann/diskann_embedding_server.py b/packages/leann-backend-diskann/leann_backend_diskann/diskann_embedding_server.py
index ee7423f..1928dc8 100644
--- a/packages/leann-backend-diskann/leann_backend_diskann/diskann_embedding_server.py
+++ b/packages/leann-backend-diskann/leann_backend_diskann/diskann_embedding_server.py
@@ -261,7 +261,7 @@ if __name__ == "__main__":
         "--embedding-mode",
         type=str,
         default="sentence-transformers",
-        choices=["sentence-transformers", "openai", "mlx"],
+        choices=["sentence-transformers", "openai", "mlx", "ollama"],
         help="Embedding backend mode",
     )
     parser.add_argument(
diff --git a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py
index 331477f..e9c246c 100644
--- a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py
+++ b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py
@@ -295,7 +295,7 @@ if __name__ == "__main__":
         "--embedding-mode",
         type=str,
         default="sentence-transformers",
-        choices=["sentence-transformers", "openai", "mlx"],
+        choices=["sentence-transformers", "openai", "mlx", "ollama"],
         help="Embedding backend mode",
     )

diff --git a/packages/leann-core/src/leann/cli.py b/packages/leann-core/src/leann/cli.py
index 787cadd..f307204 100644
--- a/packages/leann-core/src/leann/cli.py
+++ b/packages/leann-core/src/leann/cli.py
@@ -94,6 +94,13 @@ Examples:
         "--backend", type=str, default="hnsw", choices=["hnsw", "diskann"]
     )
     build_parser.add_argument("--embedding-model", type=str, default="facebook/contriever")
+    build_parser.add_argument(
+        "--embedding-mode",
+        type=str,
+        default="sentence-transformers",
+        choices=["sentence-transformers", "openai", "mlx", "ollama"],
+        help="Embedding backend mode (default: sentence-transformers)",
+    )
     build_parser.add_argument("--force", "-f", action="store_true", help="Force rebuild")
     build_parser.add_argument("--graph-degree", type=int, default=32)
     build_parser.add_argument("--complexity", type=int, default=64)
@@ -469,6 +476,7 @@ Examples:
         builder = LeannBuilder(
             backend_name=args.backend,
             embedding_model=args.embedding_model,
+            embedding_mode=args.embedding_mode,
             graph_degree=args.graph_degree,
             complexity=args.complexity,
             is_compact=args.compact,
diff --git a/packages/leann-core/src/leann/embedding_compute.py b/packages/leann-core/src/leann/embedding_compute.py
index 95fa9e4..b90b4cf 100644
--- a/packages/leann-core/src/leann/embedding_compute.py
+++ b/packages/leann-core/src/leann/embedding_compute.py
@@ -35,7 +35,7 @@ def compute_embeddings(
     Args:
         texts: List of texts to compute embeddings for
        model_name: Model name
-        mode: Computation mode ('sentence-transformers', 'openai', 'mlx')
+        mode: Computation mode ('sentence-transformers', 'openai', 'mlx', 'ollama')
         is_build: Whether this is a build operation (shows progress bar)
         batch_size: Batch size for processing
         adaptive_optimization: Whether to use adaptive optimization based on batch size
@@ -55,6 +55,8 @@ def compute_embeddings(
         return compute_embeddings_openai(texts, model_name)
     elif mode == "mlx":
         return compute_embeddings_mlx(texts, model_name)
+    elif mode == "ollama":
+        return compute_embeddings_ollama(texts, model_name, is_build=is_build)
     else:
         raise ValueError(f"Unsupported embedding mode: {mode}")

@@ -365,3 +367,117 @@ def compute_embeddings_mlx(chunks: list[str], model_name: str, batch_size: int =

     # Stack numpy arrays
     return np.stack(all_embeddings)
+
+
+def compute_embeddings_ollama(
+    texts: list[str], model_name: str, is_build: bool = False, host: str = "http://localhost:11434"
+) -> np.ndarray:
+    """
+    Compute embeddings using Ollama API.
+
+    Args:
+        texts: List of texts to compute embeddings for
+        model_name: Ollama model name (e.g., "nomic-embed-text-v2", "mxbai-embed-large")
+        is_build: Whether this is a build operation (shows progress bar)
+        host: Ollama host URL (default: http://localhost:11434)
+
+    Returns:
+        Normalized embeddings array, shape: (len(texts), embedding_dim)
+    """
+    try:
+        import requests
+    except ImportError:
+        raise ImportError(
+            "The 'requests' library is required for Ollama embeddings. Install with: uv pip install requests"
+        )
+
+    if not texts:
+        raise ValueError("Cannot compute embeddings for empty text list")
+
+    logger.info(
+        f"Computing embeddings for {len(texts)} texts using Ollama API, model: '{model_name}'"
+    )
+
+    # Check if Ollama is running
+    try:
+        response = requests.get(f"{host}/api/version", timeout=5)
+        response.raise_for_status()
+    except Exception as e:
+        raise RuntimeError(
+            f"Could not connect to Ollama at {host}. Please ensure Ollama is running. Error: {e}"
+        )
+
+    # Check if model exists
+    try:
+        response = requests.get(f"{host}/api/tags", timeout=5)
+        response.raise_for_status()
+        models = response.json()
+        model_names = [model["name"] for model in models.get("models", [])]
+
+        # Check if model exists (handle versioned names like nomic-embed-text-v2:latest)
+        model_found = any(
+            model_name in name or name.startswith(f"{model_name}:") for name in model_names
+        )
+
+        if not model_found:
+            error_msg = f"Model '{model_name}' not found in Ollama. Available models: {', '.join(model_names)}"
+            error_msg += f"\n\nTo install the model, run:\n  ollama pull {model_name}"
+            raise ValueError(error_msg)
+    except requests.exceptions.RequestException as e:
+        logger.warning(f"Could not verify model existence: {e}")
+
+    # Process embeddings
+    all_embeddings = []
+    batch_size = 100  # Process in batches to avoid overwhelming the API
+
+    # Add progress bar if in build mode
+    try:
+        if is_build:
+            from tqdm import tqdm
+
+            batch_iterator = tqdm(
+                range(0, len(texts), batch_size),
+                desc="Computing Ollama embeddings",
+                unit="batch",
+                total=(len(texts) + batch_size - 1) // batch_size,
+            )
+        else:
+            batch_iterator = range(0, len(texts), batch_size)
+    except ImportError:
+        batch_iterator = range(0, len(texts), batch_size)
+
+    for i in batch_iterator:
+        batch_texts = texts[i : i + batch_size]
+
+        for text in batch_texts:
+            try:
+                response = requests.post(
+                    f"{host}/api/embeddings", json={"model": model_name, "prompt": text}, timeout=30
+                )
+                response.raise_for_status()
+
+                result = response.json()
+                embedding = result.get("embedding")
+
+                if embedding is None:
+                    raise ValueError(f"No embedding returned for text: {text[:50]}...")
+
+                all_embeddings.append(embedding)
+
+            except requests.exceptions.RequestException as e:
+                logger.error(f"Failed to get embedding for text: {e}")
+                raise RuntimeError(f"Error getting embedding from Ollama: {e}")
+            except Exception as e:
+                logger.error(f"Unexpected error: {e}")
+                raise
+
+    # Convert to numpy array and normalize
+    embeddings = np.array(all_embeddings, dtype=np.float32)
+
+    # Normalize embeddings (L2 normalization)
+    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
+    embeddings = embeddings / (norms + 1e-8)  # Add small epsilon to avoid division by zero
+
+    logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
+
+    return embeddings
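
With the diff applied, the new mode is reachable through the existing compute_embeddings() dispatcher. A minimal end-to-end sketch, assuming a local Ollama daemon at the default http://localhost:11434 with an embedding model already pulled (e.g. `ollama pull nomic-embed-text`; the model name here is illustrative, not part of the diff):

    from leann.embedding_compute import compute_embeddings

    texts = [
        "Ollama serves embedding models from a local daemon.",
        "No API key is needed for local embedding backends.",
    ]

    # mode="ollama" routes through the new elif branch to compute_embeddings_ollama()
    embeddings = compute_embeddings(texts, "nomic-embed-text", mode="ollama")

    # Rows come back L2-normalized, so cosine similarity is a plain dot product
    print(embeddings.shape)                      # (2, embedding_dim)
    print(float(embeddings[0] @ embeddings[1]))  # cosine similarity of the two texts

One design note visible in the implementation: Ollama's /api/embeddings endpoint takes a single prompt per request, so the inner loop issues one HTTP POST per text; the batch_size of 100 only groups texts for the progress bar, not for the API calls.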
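
On the CLI side, the new flag lets `leann build --embedding-mode ollama` forward the mode into LeannBuilder. Sketched as the equivalent programmatic call, using the argument names visible in the cli.py hunk (the import path and the model choice are assumptions, not shown in the diff):

    from leann import LeannBuilder  # import path assumed

    builder = LeannBuilder(
        backend_name="hnsw",                 # build parser default
        embedding_model="nomic-embed-text",  # hypothetical Ollama embedding model
        embedding_mode="ollama",             # value threaded through from --embedding-mode
        graph_degree=32,                     # build parser default
        complexity=64,                       # build parser default
    )

The same "ollama" choice was added to the DiskANN and HNSW embedding servers, so the mode is accepted anywhere those servers parse --embedding-mode.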