From db7ba27ff6edbb64f82c0d1e5ef8e1659fba1640 Mon Sep 17 00:00:00 2001 From: Andy Lee Date: Tue, 23 Sep 2025 15:12:13 -0700 Subject: [PATCH] feat: Add support for configurable local LLM endpoints (#115) * feat: support configurable local llm endpoints * docs --- apps/base_rag_example.py | 51 +++++++- docs/configuration-guide.md | 75 ++++++++++++ .../diskann_embedding_server.py | 26 +++- .../hnsw_embedding_server.py | 30 ++++- packages/leann-core/src/leann/api.py | 16 +++ packages/leann-core/src/leann/chat.py | 51 +++++--- packages/leann-core/src/leann/cli.py | 55 ++++++++- .../leann-core/src/leann/embedding_compute.py | 63 +++++++--- .../src/leann/embedding_server_manager.py | 111 ++++++++++++++++-- .../leann-core/src/leann/searcher_base.py | 9 +- packages/leann-core/src/leann/settings.py | 74 ++++++++++++ 11 files changed, 503 insertions(+), 58 deletions(-) create mode 100644 packages/leann-core/src/leann/settings.py diff --git a/apps/base_rag_example.py b/apps/base_rag_example.py index be1be04..d07c3d1 100644 --- a/apps/base_rag_example.py +++ b/apps/base_rag_example.py @@ -11,6 +11,7 @@ from typing import Any import dotenv from leann.api import LeannBuilder, LeannChat from leann.registry import register_project_directory +from leann.settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url dotenv.load_dotenv() @@ -78,6 +79,24 @@ class BaseRAGExample(ABC): choices=["sentence-transformers", "openai", "mlx", "ollama"], help="Embedding backend mode (default: sentence-transformers), we provide sentence-transformers, openai, mlx, or ollama", ) + embedding_group.add_argument( + "--embedding-host", + type=str, + default=None, + help="Override Ollama-compatible embedding host", + ) + embedding_group.add_argument( + "--embedding-api-base", + type=str, + default=None, + help="Base URL for OpenAI-compatible embedding services", + ) + embedding_group.add_argument( + "--embedding-api-key", + type=str, + default=None, + help="API key for embedding service (defaults to OPENAI_API_KEY)", + ) # LLM parameters llm_group = parser.add_argument_group("LLM Parameters") @@ -97,8 +116,8 @@ class BaseRAGExample(ABC): llm_group.add_argument( "--llm-host", type=str, - default="http://localhost:11434", - help="Host for Ollama API (default: http://localhost:11434)", + default=None, + help="Host for Ollama-compatible APIs (defaults to LEANN_OLLAMA_HOST/OLLAMA_HOST)", ) llm_group.add_argument( "--thinking-budget", @@ -107,6 +126,18 @@ class BaseRAGExample(ABC): default=None, help="Thinking budget for reasoning models (low/medium/high). 
Supported by GPT-Oss:20b and other reasoning models.", ) + llm_group.add_argument( + "--llm-api-base", + type=str, + default=None, + help="Base URL for OpenAI-compatible APIs", + ) + llm_group.add_argument( + "--llm-api-key", + type=str, + default=None, + help="API key for OpenAI-compatible APIs (defaults to OPENAI_API_KEY)", + ) # AST Chunking parameters ast_group = parser.add_argument_group("AST Chunking Parameters") @@ -205,9 +236,13 @@ class BaseRAGExample(ABC): if args.llm == "openai": config["model"] = args.llm_model or "gpt-4o" + config["base_url"] = resolve_openai_base_url(args.llm_api_base) + resolved_key = resolve_openai_api_key(args.llm_api_key) + if resolved_key: + config["api_key"] = resolved_key elif args.llm == "ollama": config["model"] = args.llm_model or "llama3.2:1b" - config["host"] = args.llm_host + config["host"] = resolve_ollama_host(args.llm_host) elif args.llm == "hf": config["model"] = args.llm_model or "Qwen/Qwen2.5-1.5B-Instruct" elif args.llm == "simulated": @@ -223,10 +258,20 @@ class BaseRAGExample(ABC): print(f"\n[Building Index] Creating {self.name} index...") print(f"Total text chunks: {len(texts)}") + embedding_options: dict[str, Any] = {} + if args.embedding_mode == "ollama": + embedding_options["host"] = resolve_ollama_host(args.embedding_host) + elif args.embedding_mode == "openai": + embedding_options["base_url"] = resolve_openai_base_url(args.embedding_api_base) + resolved_embedding_key = resolve_openai_api_key(args.embedding_api_key) + if resolved_embedding_key: + embedding_options["api_key"] = resolved_embedding_key + builder = LeannBuilder( backend_name=args.backend_name, embedding_model=args.embedding_model, embedding_mode=args.embedding_mode, + embedding_options=embedding_options or None, graph_degree=args.graph_degree, complexity=args.build_complexity, is_compact=not args.no_compact, diff --git a/docs/configuration-guide.md b/docs/configuration-guide.md index c1402a1..ac170f1 100644 --- a/docs/configuration-guide.md +++ b/docs/configuration-guide.md @@ -83,6 +83,81 @@ ollama pull nomic-embed-text +## Local & Remote Inference Endpoints + +> Applies to both LLMs (`leann ask`) and embeddings (`leann build`). + +LEANN now treats Ollama, LM Studio, and other OpenAI-compatible runtimes as first-class providers. You can point LEANN at any compatible endpoint – either on the same machine or across the network – with a couple of flags or environment variables. + +### One-Time Environment Setup + +```bash +# Works for OpenAI-compatible runtimes such as LM Studio, vLLM, SGLang, llamafile, etc. +export OPENAI_API_KEY="your-key" # or leave unset for local servers that do not check keys +export OPENAI_BASE_URL="http://localhost:1234/v1" + +# Ollama-compatible runtimes (Ollama, Ollama on another host, llamacpp-server, etc.) +export LEANN_OLLAMA_HOST="http://localhost:11434" # falls back to OLLAMA_HOST or LOCAL_LLM_ENDPOINT +``` + +LEANN also recognises `LEANN_LOCAL_LLM_HOST` (highest priority), `LEANN_OPENAI_BASE_URL`, and `LOCAL_OPENAI_BASE_URL`, so existing scripts continue to work. 
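+
+To double-check which endpoints LEANN will actually use, the resolver helpers added in `leann.settings` can be called directly. A minimal sketch (the printed values depend on your environment):
+
+```python
+from leann.settings import resolve_ollama_host, resolve_openai_base_url
+
+# Explicit arguments win, then the LEANN_*-prefixed variables, then the generic
+# OLLAMA_HOST / OPENAI_BASE_URL variables, then the built-in defaults.
+print(resolve_ollama_host())       # http://localhost:11434 unless overridden
+print(resolve_openai_base_url())   # https://api.openai.com/v1 unless overridden
+```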
+ +### Passing Hosts Per Command + +```bash +# Build an index with a remote embedding server +leann build my-notes \ + --docs ./notes \ + --embedding-mode openai \ + --embedding-model text-embedding-qwen3-embedding-0.6b \ + --embedding-api-base http://192.168.1.50:1234/v1 \ + --embedding-api-key local-dev-key + +# Query using a local LM Studio instance via OpenAI-compatible API +leann ask my-notes \ + --llm openai \ + --llm-model qwen3-8b \ + --api-base http://localhost:1234/v1 \ + --api-key local-dev-key + +# Query an Ollama instance running on another box +leann ask my-notes \ + --llm ollama \ + --llm-model qwen3:14b \ + --host http://192.168.1.101:11434 +``` + +⚠️ **Make sure the endpoint is reachable**: when your inference server runs on a home/workstation and the index/search job runs in the cloud, the server must be able to reach the host you configured. Typical options include: + +- Expose a public IP (and open the relevant port) on the machine that hosts LM Studio/Ollama. +- Configure router or cloud provider port forwarding. +- Tunnel traffic through tools like `tailscale`, `cloudflared`, or `ssh -R`. + +When you set these options while building an index, LEANN stores them in `meta.json`. Any subsequent `leann ask` or searcher process automatically reuses the same provider settings – even when we spawn background embedding servers. This makes the “server without GPU talking to my local workstation” workflow from [issue #80](https://github.com/yichuan-w/LEANN/issues/80#issuecomment-2287230548) work out-of-the-box. + +**Tip:** If your runtime does not require an API key (many local stacks don’t), leave `--api-key` unset. LEANN will skip injecting credentials. + +### Python API Usage + +You can pass the same configuration from Python: + +```python +from leann.api import LeannBuilder + +builder = LeannBuilder( + backend_name="hnsw", + embedding_mode="openai", + embedding_model="text-embedding-qwen3-embedding-0.6b", + embedding_options={ + "base_url": "http://192.168.1.50:1234/v1", + "api_key": "local-dev-key", + }, +) +builder.build_index("./indexes/my-notes", chunks) +``` + +`embedding_options` is persisted to the index `meta.json`, so subsequent `LeannSearcher` or `LeannChat` sessions automatically reuse the same provider settings (the embedding server manager forwards them to the provider for you). 
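+
+A later search session therefore needs no endpoint flags at all. A minimal sketch (the query string and `top_k` value are only illustrative):
+
+```python
+from leann.api import LeannSearcher
+
+# embedding_options is read back from meta.json, so the embedding server that
+# LEANN spawns talks to the same endpoint that was used at build time.
+searcher = LeannSearcher("./indexes/my-notes")
+results = searcher.search("remote embedding endpoints", top_k=3)
+```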
+ ## Index Selection: Matching Your Scale ### HNSW (Hierarchical Navigable Small World) diff --git a/packages/leann-backend-diskann/leann_backend_diskann/diskann_embedding_server.py b/packages/leann-backend-diskann/leann_backend_diskann/diskann_embedding_server.py index 8389ddf..592fddb 100644 --- a/packages/leann-backend-diskann/leann_backend_diskann/diskann_embedding_server.py +++ b/packages/leann-backend-diskann/leann_backend_diskann/diskann_embedding_server.py @@ -10,7 +10,7 @@ import sys import threading import time from pathlib import Path -from typing import Optional +from typing import Any, Optional import numpy as np import zmq @@ -32,6 +32,16 @@ if not logger.handlers: logger.propagate = False +_RAW_PROVIDER_OPTIONS = os.getenv("LEANN_EMBEDDING_OPTIONS") +try: + PROVIDER_OPTIONS: dict[str, Any] = ( + json.loads(_RAW_PROVIDER_OPTIONS) if _RAW_PROVIDER_OPTIONS else {} + ) +except json.JSONDecodeError: + logger.warning("Failed to parse LEANN_EMBEDDING_OPTIONS; ignoring provider options") + PROVIDER_OPTIONS = {} + + def create_diskann_embedding_server( passages_file: Optional[str] = None, zmq_port: int = 5555, @@ -181,7 +191,12 @@ def create_diskann_embedding_server( logger.debug(f"Text lengths: {[len(t) for t in texts[:5]]}") # Show first 5 # Process embeddings using unified computation - embeddings = compute_embeddings(texts, model_name, mode=embedding_mode) + embeddings = compute_embeddings( + texts, + model_name, + mode=embedding_mode, + provider_options=PROVIDER_OPTIONS, + ) logger.info( f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}" ) @@ -296,7 +311,12 @@ def create_diskann_embedding_server( continue # Process the request - embeddings = compute_embeddings(texts, model_name, mode=embedding_mode) + embeddings = compute_embeddings( + texts, + model_name, + mode=embedding_mode, + provider_options=PROVIDER_OPTIONS, + ) logger.info(f"Computed embeddings shape: {embeddings.shape}") # Validation diff --git a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py index d2c4852..b6a0eb0 100644 --- a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py +++ b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py @@ -10,7 +10,7 @@ import sys import threading import time from pathlib import Path -from typing import Optional +from typing import Any, Optional import msgpack import numpy as np @@ -45,6 +45,15 @@ if log_path: logger.propagate = False +_RAW_PROVIDER_OPTIONS = os.getenv("LEANN_EMBEDDING_OPTIONS") +try: + PROVIDER_OPTIONS: dict[str, Any] = ( + json.loads(_RAW_PROVIDER_OPTIONS) if _RAW_PROVIDER_OPTIONS else {} + ) +except json.JSONDecodeError: + logger.warning("Failed to parse LEANN_EMBEDDING_OPTIONS; ignoring provider options") + PROVIDER_OPTIONS = {} + def create_hnsw_embedding_server( passages_file: Optional[str] = None, @@ -151,7 +160,12 @@ def create_hnsw_embedding_server( ): last_request_type = "text" last_request_length = len(request) - embeddings = compute_embeddings(request, model_name, mode=embedding_mode) + embeddings = compute_embeddings( + request, + model_name, + mode=embedding_mode, + provider_options=PROVIDER_OPTIONS, + ) rep_socket.send(msgpack.packb(embeddings.tolist())) e2e_end = time.time() logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s") @@ -200,7 +214,10 @@ def create_hnsw_embedding_server( if texts: try: embeddings = compute_embeddings( - texts, model_name, mode=embedding_mode + texts, + 
model_name, + mode=embedding_mode, + provider_options=PROVIDER_OPTIONS, ) logger.info( f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}" @@ -265,7 +282,12 @@ def create_hnsw_embedding_server( if texts: try: - embeddings = compute_embeddings(texts, model_name, mode=embedding_mode) + embeddings = compute_embeddings( + texts, + model_name, + mode=embedding_mode, + provider_options=PROVIDER_OPTIONS, + ) logger.info( f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}" ) diff --git a/packages/leann-core/src/leann/api.py b/packages/leann-core/src/leann/api.py index 0d4040a..dfa4685 100644 --- a/packages/leann-core/src/leann/api.py +++ b/packages/leann-core/src/leann/api.py @@ -39,6 +39,7 @@ def compute_embeddings( use_server: bool = True, port: Optional[int] = None, is_build=False, + provider_options: Optional[dict[str, Any]] = None, ) -> np.ndarray: """ Computes embeddings using different backends. @@ -72,6 +73,7 @@ def compute_embeddings( model_name, mode=mode, is_build=is_build, + provider_options=provider_options, ) @@ -278,6 +280,7 @@ class LeannBuilder: embedding_model: str = "facebook/contriever", dimensions: Optional[int] = None, embedding_mode: str = "sentence-transformers", + embedding_options: Optional[dict[str, Any]] = None, **backend_kwargs, ): self.backend_name = backend_name @@ -300,6 +303,7 @@ class LeannBuilder: self.embedding_model = embedding_model self.dimensions = dimensions self.embedding_mode = embedding_mode + self.embedding_options = embedding_options or {} # Check if we need to use cosine distance for normalized embeddings normalized_embeddings_models = { @@ -407,6 +411,7 @@ class LeannBuilder: self.embedding_model, self.embedding_mode, use_server=False, + provider_options=self.embedding_options, )[0] ) path = Path(index_path) @@ -446,6 +451,7 @@ class LeannBuilder: self.embedding_mode, use_server=False, is_build=True, + provider_options=self.embedding_options, ) string_ids = [chunk["id"] for chunk in self.chunks] current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions} @@ -472,6 +478,9 @@ class LeannBuilder: ], } + if self.embedding_options: + meta_data["embedding_options"] = self.embedding_options + # Add storage status flags for HNSW backend if self.backend_name == "hnsw": is_compact = self.backend_kwargs.get("is_compact", True) @@ -592,6 +601,9 @@ class LeannBuilder: "embeddings_source": str(embeddings_file), } + if self.embedding_options: + meta_data["embedding_options"] = self.embedding_options + # Add storage status flags for HNSW backend if self.backend_name == "hnsw": is_compact = self.backend_kwargs.get("is_compact", True) @@ -673,6 +685,7 @@ class LeannBuilder: self.embedding_mode, use_server=False, is_build=True, + provider_options=self.embedding_options, ) embedding_dim = embeddings.shape[1] @@ -771,6 +784,7 @@ class LeannSearcher: self.embedding_model = self.meta_data["embedding_model"] # Support both old and new format self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers") + self.embedding_options = self.meta_data.get("embedding_options", {}) # Delegate portability handling to PassageManager self.passage_manager = PassageManager( self.meta_data.get("passage_sources", []), metadata_file_path=self.meta_path_str @@ -782,6 +796,8 @@ class LeannSearcher: raise ValueError(f"Backend '{backend_name}' not found.") final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs} final_kwargs["enable_warmup"] = enable_warmup + if self.embedding_options: + 
final_kwargs.setdefault("embedding_options", self.embedding_options) self.backend_impl: LeannBackendSearcherInterface = backend_factory.searcher( index_path, **final_kwargs ) diff --git a/packages/leann-core/src/leann/chat.py b/packages/leann-core/src/leann/chat.py index 391c59d..8135daf 100644 --- a/packages/leann-core/src/leann/chat.py +++ b/packages/leann-core/src/leann/chat.py @@ -12,6 +12,8 @@ from typing import Any, Optional import torch +from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url + # Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -310,11 +312,12 @@ def search_hf_models(query: str, limit: int = 10) -> list[str]: def validate_model_and_suggest( - model_name: str, llm_type: str, host: str = "http://localhost:11434" + model_name: str, llm_type: str, host: Optional[str] = None ) -> Optional[str]: """Validate model name and provide suggestions if invalid""" if llm_type == "ollama": - available_models = check_ollama_models(host) + resolved_host = resolve_ollama_host(host) + available_models = check_ollama_models(resolved_host) if available_models and model_name not in available_models: error_msg = f"Model '{model_name}' not found in your local Ollama installation." @@ -457,19 +460,19 @@ class LLMInterface(ABC): class OllamaChat(LLMInterface): """LLM interface for Ollama models.""" - def __init__(self, model: str = "llama3:8b", host: str = "http://localhost:11434"): + def __init__(self, model: str = "llama3:8b", host: Optional[str] = None): self.model = model - self.host = host - logger.info(f"Initializing OllamaChat with model='{model}' and host='{host}'") + self.host = resolve_ollama_host(host) + logger.info(f"Initializing OllamaChat with model='{model}' and host='{self.host}'") try: import requests # Check if the Ollama server is responsive - if host: - requests.get(host) + if self.host: + requests.get(self.host) # Pre-check model availability with helpful suggestions - model_error = validate_model_and_suggest(model, "ollama", host) + model_error = validate_model_and_suggest(model, "ollama", self.host) if model_error: raise ValueError(model_error) @@ -478,9 +481,11 @@ class OllamaChat(LLMInterface): "The 'requests' library is required for Ollama. Please install it with 'pip install requests'." ) except requests.exceptions.ConnectionError: - logger.error(f"Could not connect to Ollama at {host}. Please ensure Ollama is running.") + logger.error( + f"Could not connect to Ollama at {self.host}. Please ensure Ollama is running." + ) raise ConnectionError( - f"Could not connect to Ollama at {host}. Please ensure Ollama is running." + f"Could not connect to Ollama at {self.host}. Please ensure Ollama is running." ) def ask(self, prompt: str, **kwargs) -> str: @@ -737,21 +742,31 @@ class GeminiChat(LLMInterface): class OpenAIChat(LLMInterface): """LLM interface for OpenAI models.""" - def __init__(self, model: str = "gpt-4o", api_key: Optional[str] = None): + def __init__( + self, + model: str = "gpt-4o", + api_key: Optional[str] = None, + base_url: Optional[str] = None, + ): self.model = model - self.api_key = api_key or os.getenv("OPENAI_API_KEY") + self.base_url = resolve_openai_base_url(base_url) + self.api_key = resolve_openai_api_key(api_key) if not self.api_key: raise ValueError( "OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter." 
) - logger.info(f"Initializing OpenAI Chat with model='{model}'") + logger.info( + "Initializing OpenAI Chat with model='%s' and base_url='%s'", + model, + self.base_url, + ) try: import openai - self.client = openai.OpenAI(api_key=self.api_key) + self.client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url) except ImportError: raise ImportError( "The 'openai' library is required for OpenAI models. Please install it with 'pip install openai'." @@ -841,12 +856,16 @@ def get_llm(llm_config: Optional[dict[str, Any]] = None) -> LLMInterface: if llm_type == "ollama": return OllamaChat( model=model or "llama3:8b", - host=llm_config.get("host", "http://localhost:11434"), + host=llm_config.get("host"), ) elif llm_type == "hf": return HFChat(model_name=model or "deepseek-ai/deepseek-llm-7b-chat") elif llm_type == "openai": - return OpenAIChat(model=model or "gpt-4o", api_key=llm_config.get("api_key")) + return OpenAIChat( + model=model or "gpt-4o", + api_key=llm_config.get("api_key"), + base_url=llm_config.get("base_url"), + ) elif llm_type == "gemini": return GeminiChat(model=model or "gemini-2.5-flash", api_key=llm_config.get("api_key")) elif llm_type == "simulated": diff --git a/packages/leann-core/src/leann/cli.py b/packages/leann-core/src/leann/cli.py index 2d514e2..1b1e298 100644 --- a/packages/leann-core/src/leann/cli.py +++ b/packages/leann-core/src/leann/cli.py @@ -9,6 +9,7 @@ from tqdm import tqdm from .api import LeannBuilder, LeannChat, LeannSearcher from .registry import register_project_directory +from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url def extract_pdf_text_with_pymupdf(file_path: str) -> str: @@ -123,6 +124,24 @@ Examples: choices=["sentence-transformers", "openai", "mlx", "ollama"], help="Embedding backend mode (default: sentence-transformers)", ) + build_parser.add_argument( + "--embedding-host", + type=str, + default=None, + help="Override Ollama-compatible embedding host", + ) + build_parser.add_argument( + "--embedding-api-base", + type=str, + default=None, + help="Base URL for OpenAI-compatible embedding services", + ) + build_parser.add_argument( + "--embedding-api-key", + type=str, + default=None, + help="API key for embedding service (defaults to OPENAI_API_KEY)", + ) build_parser.add_argument( "--force", "-f", action="store_true", help="Force rebuild existing index" ) @@ -248,7 +267,12 @@ Examples: ask_parser.add_argument( "--model", type=str, default="qwen3:8b", help="Model name (default: qwen3:8b)" ) - ask_parser.add_argument("--host", type=str, default="http://localhost:11434") + ask_parser.add_argument( + "--host", + type=str, + default=None, + help="Override Ollama-compatible host (defaults to LEANN_OLLAMA_HOST/OLLAMA_HOST)", + ) ask_parser.add_argument( "--interactive", "-i", action="store_true", help="Interactive chat mode" ) @@ -277,6 +301,18 @@ Examples: default=None, help="Thinking budget for reasoning models (low/medium/high). 
Supported by GPT-Oss:20b and other reasoning models.", ) + ask_parser.add_argument( + "--api-base", + type=str, + default=None, + help="Base URL for OpenAI-compatible APIs (e.g., http://localhost:10000/v1)", + ) + ask_parser.add_argument( + "--api-key", + type=str, + default=None, + help="API key for OpenAI-compatible APIs (defaults to OPENAI_API_KEY)", + ) # List command subparsers.add_parser("list", help="List all indexes") @@ -1325,10 +1361,20 @@ Examples: print(f"Building index '{index_name}' with {args.backend} backend...") + embedding_options: dict[str, Any] = {} + if args.embedding_mode == "ollama": + embedding_options["host"] = resolve_ollama_host(args.embedding_host) + elif args.embedding_mode == "openai": + embedding_options["base_url"] = resolve_openai_base_url(args.embedding_api_base) + resolved_embedding_key = resolve_openai_api_key(args.embedding_api_key) + if resolved_embedding_key: + embedding_options["api_key"] = resolved_embedding_key + builder = LeannBuilder( backend_name=args.backend, embedding_model=args.embedding_model, embedding_mode=args.embedding_mode, + embedding_options=embedding_options or None, graph_degree=args.graph_degree, complexity=args.complexity, is_compact=args.compact, @@ -1476,7 +1522,12 @@ Examples: llm_config = {"type": args.llm, "model": args.model} if args.llm == "ollama": - llm_config["host"] = args.host + llm_config["host"] = resolve_ollama_host(args.host) + elif args.llm == "openai": + llm_config["base_url"] = resolve_openai_base_url(args.api_base) + resolved_api_key = resolve_openai_api_key(args.api_key) + if resolved_api_key: + llm_config["api_key"] = resolved_api_key chat = LeannChat(index_path=index_path, llm_config=llm_config) diff --git a/packages/leann-core/src/leann/embedding_compute.py b/packages/leann-core/src/leann/embedding_compute.py index 83f112a..a01bd3b 100644 --- a/packages/leann-core/src/leann/embedding_compute.py +++ b/packages/leann-core/src/leann/embedding_compute.py @@ -7,11 +7,13 @@ Preserves all optimization parameters to ensure performance import logging import os import time -from typing import Any +from typing import Any, Optional import numpy as np import torch +from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url + # Set up logger with proper level logger = logging.getLogger(__name__) LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper() @@ -31,6 +33,7 @@ def compute_embeddings( adaptive_optimization: bool = True, manual_tokenize: bool = False, max_length: int = 512, + provider_options: Optional[dict[str, Any]] = None, ) -> np.ndarray: """ Unified embedding computation entry point @@ -46,6 +49,8 @@ def compute_embeddings( Returns: Normalized embeddings array, shape: (len(texts), embedding_dim) """ + provider_options = provider_options or {} + if mode == "sentence-transformers": return compute_embeddings_sentence_transformers( texts, @@ -57,11 +62,21 @@ def compute_embeddings( max_length=max_length, ) elif mode == "openai": - return compute_embeddings_openai(texts, model_name) + return compute_embeddings_openai( + texts, + model_name, + base_url=provider_options.get("base_url"), + api_key=provider_options.get("api_key"), + ) elif mode == "mlx": return compute_embeddings_mlx(texts, model_name) elif mode == "ollama": - return compute_embeddings_ollama(texts, model_name, is_build=is_build) + return compute_embeddings_ollama( + texts, + model_name, + is_build=is_build, + host=provider_options.get("host"), + ) elif mode == "gemini": return compute_embeddings_gemini(texts, 
model_name, is_build=is_build) else: @@ -353,12 +368,15 @@ def compute_embeddings_sentence_transformers( return embeddings -def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray: +def compute_embeddings_openai( + texts: list[str], + model_name: str, + base_url: Optional[str] = None, + api_key: Optional[str] = None, +) -> np.ndarray: # TODO: @yichuan-w add progress bar only in build mode """Compute embeddings using OpenAI API""" try: - import os - import openai except ImportError as e: raise ImportError(f"OpenAI package not installed: {e}") @@ -373,16 +391,18 @@ def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray: f"Found {invalid_count} empty/invalid text(s) in input. Upstream should filter before calling OpenAI." ) - api_key = os.getenv("OPENAI_API_KEY") - if not api_key: + resolved_base_url = resolve_openai_base_url(base_url) + resolved_api_key = resolve_openai_api_key(api_key) + + if not resolved_api_key: raise RuntimeError("OPENAI_API_KEY environment variable not set") # Cache OpenAI client - cache_key = "openai_client" + cache_key = f"openai_client::{resolved_base_url}" if cache_key in _model_cache: client = _model_cache[cache_key] else: - client = openai.OpenAI(api_key=api_key) + client = openai.OpenAI(api_key=resolved_api_key, base_url=resolved_base_url) _model_cache[cache_key] = client logger.info("OpenAI client cached") @@ -507,7 +527,10 @@ def compute_embeddings_mlx(chunks: list[str], model_name: str, batch_size: int = def compute_embeddings_ollama( - texts: list[str], model_name: str, is_build: bool = False, host: str = "http://localhost:11434" + texts: list[str], + model_name: str, + is_build: bool = False, + host: Optional[str] = None, ) -> np.ndarray: """ Compute embeddings using Ollama API with simplified batch processing. 
@@ -518,7 +541,7 @@ def compute_embeddings_ollama( texts: List of texts to compute embeddings for model_name: Ollama model name (e.g., "nomic-embed-text", "mxbai-embed-large") is_build: Whether this is a build operation (shows progress bar) - host: Ollama host URL (default: http://localhost:11434) + host: Ollama host URL (defaults to environment or http://localhost:11434) Returns: Normalized embeddings array, shape: (len(texts), embedding_dim) @@ -533,17 +556,19 @@ def compute_embeddings_ollama( if not texts: raise ValueError("Cannot compute embeddings for empty text list") + resolved_host = resolve_ollama_host(host) + logger.info( - f"Computing embeddings for {len(texts)} texts using Ollama API, model: '{model_name}'" + f"Computing embeddings for {len(texts)} texts using Ollama API, model: '{model_name}', host: '{resolved_host}'" ) # Check if Ollama is running try: - response = requests.get(f"{host}/api/version", timeout=5) + response = requests.get(f"{resolved_host}/api/version", timeout=5) response.raise_for_status() except requests.exceptions.ConnectionError: error_msg = ( - f"❌ Could not connect to Ollama at {host}.\n\n" + f"❌ Could not connect to Ollama at {resolved_host}.\n\n" "Please ensure Ollama is running:\n" " • macOS/Linux: ollama serve\n" " • Windows: Make sure Ollama is running in the system tray\n\n" @@ -555,7 +580,7 @@ def compute_embeddings_ollama( # Check if model exists and provide helpful suggestions try: - response = requests.get(f"{host}/api/tags", timeout=5) + response = requests.get(f"{resolved_host}/api/tags", timeout=5) response.raise_for_status() models = response.json() model_names = [model["name"] for model in models.get("models", [])] @@ -618,7 +643,9 @@ def compute_embeddings_ollama( # Verify the model supports embeddings by testing it try: test_response = requests.post( - f"{host}/api/embeddings", json={"model": model_name, "prompt": "test"}, timeout=10 + f"{resolved_host}/api/embeddings", + json={"model": model_name, "prompt": "test"}, + timeout=10, ) if test_response.status_code != 200: error_msg = ( @@ -665,7 +692,7 @@ def compute_embeddings_ollama( while retry_count < max_retries: try: response = requests.post( - f"{host}/api/embeddings", + f"{resolved_host}/api/embeddings", json={"model": model_name, "prompt": truncated_text}, timeout=30, ) diff --git a/packages/leann-core/src/leann/embedding_server_manager.py b/packages/leann-core/src/leann/embedding_server_manager.py index 3d7c31e..7a2fcb2 100644 --- a/packages/leann-core/src/leann/embedding_server_manager.py +++ b/packages/leann-core/src/leann/embedding_server_manager.py @@ -8,6 +8,8 @@ import time from pathlib import Path from typing import Optional +from .settings import encode_provider_options + # Lightweight, self-contained server manager with no cross-process inspection # Set up logging based on environment variable @@ -82,16 +84,40 @@ class EmbeddingServerManager: ) -> tuple[bool, int]: """Start the embedding server.""" # passages_file may be present in kwargs for server CLI, but we don't need it here + provider_options = kwargs.pop("provider_options", None) + + config_signature = { + "model_name": model_name, + "passages_file": kwargs.get("passages_file", ""), + "embedding_mode": embedding_mode, + "provider_options": provider_options or {}, + } # If this manager already has a live server, just reuse it - if self.server_process and self.server_process.poll() is None and self.server_port: + if ( + self.server_process + and self.server_process.poll() is None + and self.server_port + and 
self._server_config == config_signature + ): logger.info("Reusing in-process server") return True, self.server_port + # Configuration changed, stop existing server before starting a new one + if self.server_process and self.server_process.poll() is None: + logger.info("Existing server configuration differs; restarting embedding server") + self.stop_server() + # For Colab environment, use a different strategy if _is_colab_environment(): logger.info("Detected Colab environment, using alternative startup strategy") - return self._start_server_colab(port, model_name, embedding_mode, **kwargs) + return self._start_server_colab( + port, + model_name, + embedding_mode, + provider_options=provider_options, + **kwargs, + ) # Always pick a fresh available port try: @@ -101,13 +127,21 @@ class EmbeddingServerManager: return False, port # Start a new server - return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs) + return self._start_new_server( + actual_port, + model_name, + embedding_mode, + provider_options=provider_options, + config_signature=config_signature, + **kwargs, + ) def _start_server_colab( self, port: int, model_name: str, embedding_mode: str = "sentence-transformers", + provider_options: Optional[dict] = None, **kwargs, ) -> tuple[bool, int]: """Start server with Colab-specific configuration.""" @@ -125,8 +159,20 @@ class EmbeddingServerManager: try: # In Colab, we'll use a more direct approach - self._launch_server_process_colab(command, actual_port) - return self._wait_for_server_ready_colab(actual_port) + self._launch_server_process_colab( + command, + actual_port, + provider_options=provider_options, + ) + started, ready_port = self._wait_for_server_ready_colab(actual_port) + if started: + self._server_config = { + "model_name": model_name, + "passages_file": kwargs.get("passages_file", ""), + "embedding_mode": embedding_mode, + "provider_options": provider_options or {}, + } + return started, ready_port except Exception as e: logger.error(f"Failed to start embedding server in Colab: {e}") return False, actual_port @@ -134,7 +180,13 @@ class EmbeddingServerManager: # Note: No compatibility check needed; manager is per-searcher and configs are stable per instance def _start_new_server( - self, port: int, model_name: str, embedding_mode: str, **kwargs + self, + port: int, + model_name: str, + embedding_mode: str, + provider_options: Optional[dict] = None, + config_signature: Optional[dict] = None, + **kwargs, ) -> tuple[bool, int]: """Start a new embedding server on the given port.""" logger.info(f"Starting embedding server on port {port}...") @@ -142,8 +194,20 @@ class EmbeddingServerManager: command = self._build_server_command(port, model_name, embedding_mode, **kwargs) try: - self._launch_server_process(command, port) - return self._wait_for_server_ready(port) + self._launch_server_process( + command, + port, + provider_options=provider_options, + ) + started, ready_port = self._wait_for_server_ready(port) + if started: + self._server_config = config_signature or { + "model_name": model_name, + "passages_file": kwargs.get("passages_file", ""), + "embedding_mode": embedding_mode, + "provider_options": provider_options or {}, + } + return started, ready_port except Exception as e: logger.error(f"Failed to start embedding server: {e}") return False, port @@ -173,7 +237,12 @@ class EmbeddingServerManager: return command - def _launch_server_process(self, command: list, port: int) -> None: + def _launch_server_process( + self, + command: list, + port: int, + 
provider_options: Optional[dict] = None, + ) -> None: """Launch the server process.""" project_root = Path(__file__).parent.parent.parent.parent.parent logger.info(f"Command: {' '.join(command)}") @@ -193,14 +262,20 @@ class EmbeddingServerManager: # Start embedding server subprocess logger.info(f"Starting server process with command: {' '.join(command)}") + env = os.environ.copy() + encoded_options = encode_provider_options(provider_options) + if encoded_options: + env["LEANN_EMBEDDING_OPTIONS"] = encoded_options + self.server_process = subprocess.Popen( command, cwd=project_root, stdout=stdout_target, stderr=stderr_target, + env=env, ) self.server_port = port - # Record config for in-process reuse + # Record config for in-process reuse (best effort; refined later when ready) try: self._server_config = { "model_name": command[command.index("--model-name") + 1] @@ -212,12 +287,14 @@ class EmbeddingServerManager: "embedding_mode": command[command.index("--embedding-mode") + 1] if "--embedding-mode" in command else "sentence-transformers", + "provider_options": provider_options or {}, } except Exception: self._server_config = { "model_name": "", "passages_file": "", "embedding_mode": "sentence-transformers", + "provider_options": provider_options or {}, } logger.info(f"Server process started with PID: {self.server_process.pid}") @@ -322,16 +399,27 @@ class EmbeddingServerManager: # Removed: cross-process adoption no longer supported return - def _launch_server_process_colab(self, command: list, port: int) -> None: + def _launch_server_process_colab( + self, + command: list, + port: int, + provider_options: Optional[dict] = None, + ) -> None: """Launch the server process with Colab-specific settings.""" logger.info(f"Colab Command: {' '.join(command)}") # In Colab, we need to be more careful about process management + env = os.environ.copy() + encoded_options = encode_provider_options(provider_options) + if encoded_options: + env["LEANN_EMBEDDING_OPTIONS"] = encoded_options + self.server_process = subprocess.Popen( command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, + env=env, ) self.server_port = port logger.info(f"Colab server process started with PID: {self.server_process.pid}") @@ -345,6 +433,7 @@ class EmbeddingServerManager: "model_name": "", "passages_file": "", "embedding_mode": "sentence-transformers", + "provider_options": provider_options or {}, } def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]: diff --git a/packages/leann-core/src/leann/searcher_base.py b/packages/leann-core/src/leann/searcher_base.py index ff368c8..4726605 100644 --- a/packages/leann-core/src/leann/searcher_base.py +++ b/packages/leann-core/src/leann/searcher_base.py @@ -41,6 +41,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC): print("WARNING: embedding_model not found in meta.json. 
Recompute will fail.") self.embedding_mode = self.meta.get("embedding_mode", "sentence-transformers") + self.embedding_options = self.meta.get("embedding_options", {}) self.embedding_server_manager = EmbeddingServerManager( backend_module_name=backend_module_name, @@ -77,6 +78,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC): passages_file=passages_source_file, distance_metric=distance_metric, enable_warmup=kwargs.get("enable_warmup", False), + provider_options=self.embedding_options, ) if not server_started: raise RuntimeError(f"Failed to start embedding server on port {actual_port}") @@ -125,7 +127,12 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC): from .embedding_compute import compute_embeddings embedding_mode = self.meta.get("embedding_mode", "sentence-transformers") - return compute_embeddings([query], self.embedding_model, embedding_mode) + return compute_embeddings( + [query], + self.embedding_model, + embedding_mode, + provider_options=self.embedding_options, + ) def _compute_embedding_via_server(self, chunks: list, zmq_port: int) -> np.ndarray: """Compute embeddings using the ZMQ embedding server.""" diff --git a/packages/leann-core/src/leann/settings.py b/packages/leann-core/src/leann/settings.py new file mode 100644 index 0000000..3f0042e --- /dev/null +++ b/packages/leann-core/src/leann/settings.py @@ -0,0 +1,74 @@ +"""Runtime configuration helpers for LEANN.""" + +from __future__ import annotations + +import json +import os +from typing import Any + +# Default fallbacks to preserve current behaviour while keeping them in one place. +_DEFAULT_OLLAMA_HOST = "http://localhost:11434" +_DEFAULT_OPENAI_BASE_URL = "https://api.openai.com/v1" + + +def _clean_url(value: str) -> str: + """Normalize URL strings by stripping trailing slashes.""" + + return value.rstrip("/") if value else value + + +def resolve_ollama_host(explicit: str | None = None) -> str: + """Resolve the Ollama-compatible endpoint to use.""" + + candidates = ( + explicit, + os.getenv("LEANN_LOCAL_LLM_HOST"), + os.getenv("LEANN_OLLAMA_HOST"), + os.getenv("OLLAMA_HOST"), + os.getenv("LOCAL_LLM_ENDPOINT"), + ) + + for candidate in candidates: + if candidate: + return _clean_url(candidate) + + return _clean_url(_DEFAULT_OLLAMA_HOST) + + +def resolve_openai_base_url(explicit: str | None = None) -> str: + """Resolve the base URL for OpenAI-compatible services.""" + + candidates = ( + explicit, + os.getenv("LEANN_OPENAI_BASE_URL"), + os.getenv("OPENAI_BASE_URL"), + os.getenv("LOCAL_OPENAI_BASE_URL"), + ) + + for candidate in candidates: + if candidate: + return _clean_url(candidate) + + return _clean_url(_DEFAULT_OPENAI_BASE_URL) + + +def resolve_openai_api_key(explicit: str | None = None) -> str | None: + """Resolve the API key for OpenAI-compatible services.""" + + if explicit: + return explicit + + return os.getenv("OPENAI_API_KEY") + + +def encode_provider_options(options: dict[str, Any] | None) -> str | None: + """Serialize provider options for child processes.""" + + if not options: + return None + + try: + return json.dumps(options) + except (TypeError, ValueError): + # Fall back to empty payload if serialization fails + return None
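
The options dictionary itself reaches spawned embedding servers through a single environment variable. A minimal sketch of that round trip, using only the helpers and variable name added in this patch (the endpoint values are illustrative):

```python
import json
import os

from leann.settings import encode_provider_options

options = {"base_url": "http://192.168.1.50:1234/v1", "api_key": "local-dev-key"}
encoded = encode_provider_options(options)  # JSON string, or None when there is nothing to pass

env = os.environ.copy()
if encoded:
    env["LEANN_EMBEDDING_OPTIONS"] = encoded  # EmbeddingServerManager sets this on the child process

# The diskann/hnsw embedding servers parse the variable back into PROVIDER_OPTIONS:
raw = env.get("LEANN_EMBEDDING_OPTIONS")
provider_options = json.loads(raw) if raw else {}
assert provider_options == options
```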