Compare commits
2 Commits
fix/drop-p
...
feat/confi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
effeb47e94 | ||
|
|
4115613b10 |
@@ -11,6 +11,7 @@ from typing import Any
|
|||||||
import dotenv
|
import dotenv
|
||||||
from leann.api import LeannBuilder, LeannChat
|
from leann.api import LeannBuilder, LeannChat
|
||||||
from leann.registry import register_project_directory
|
from leann.registry import register_project_directory
|
||||||
|
from leann.settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
|
||||||
|
|
||||||
dotenv.load_dotenv()
|
dotenv.load_dotenv()
|
||||||
|
|
||||||
@@ -78,6 +79,24 @@ class BaseRAGExample(ABC):
|
|||||||
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
||||||
help="Embedding backend mode (default: sentence-transformers), we provide sentence-transformers, openai, mlx, or ollama",
|
help="Embedding backend mode (default: sentence-transformers), we provide sentence-transformers, openai, mlx, or ollama",
|
||||||
)
|
)
|
||||||
|
embedding_group.add_argument(
|
||||||
|
"--embedding-host",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Override Ollama-compatible embedding host",
|
||||||
|
)
|
||||||
|
embedding_group.add_argument(
|
||||||
|
"--embedding-api-base",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Base URL for OpenAI-compatible embedding services",
|
||||||
|
)
|
||||||
|
embedding_group.add_argument(
|
||||||
|
"--embedding-api-key",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="API key for embedding service (defaults to OPENAI_API_KEY)",
|
||||||
|
)
|
||||||
|
|
||||||
# LLM parameters
|
# LLM parameters
|
||||||
llm_group = parser.add_argument_group("LLM Parameters")
|
llm_group = parser.add_argument_group("LLM Parameters")
|
||||||
@@ -97,8 +116,8 @@ class BaseRAGExample(ABC):
|
|||||||
llm_group.add_argument(
|
llm_group.add_argument(
|
||||||
"--llm-host",
|
"--llm-host",
|
||||||
type=str,
|
type=str,
|
||||||
default="http://localhost:11434",
|
default=None,
|
||||||
help="Host for Ollama API (default: http://localhost:11434)",
|
help="Host for Ollama-compatible APIs (defaults to LEANN_OLLAMA_HOST/OLLAMA_HOST)",
|
||||||
)
|
)
|
||||||
llm_group.add_argument(
|
llm_group.add_argument(
|
||||||
"--thinking-budget",
|
"--thinking-budget",
|
||||||
@@ -107,6 +126,18 @@ class BaseRAGExample(ABC):
|
|||||||
default=None,
|
default=None,
|
||||||
help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.",
|
help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.",
|
||||||
)
|
)
|
||||||
|
llm_group.add_argument(
|
||||||
|
"--llm-api-base",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Base URL for OpenAI-compatible APIs",
|
||||||
|
)
|
||||||
|
llm_group.add_argument(
|
||||||
|
"--llm-api-key",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="API key for OpenAI-compatible APIs (defaults to OPENAI_API_KEY)",
|
||||||
|
)
|
||||||
|
|
||||||
# AST Chunking parameters
|
# AST Chunking parameters
|
||||||
ast_group = parser.add_argument_group("AST Chunking Parameters")
|
ast_group = parser.add_argument_group("AST Chunking Parameters")
|
||||||
@@ -205,9 +236,13 @@ class BaseRAGExample(ABC):
|
|||||||
|
|
||||||
if args.llm == "openai":
|
if args.llm == "openai":
|
||||||
config["model"] = args.llm_model or "gpt-4o"
|
config["model"] = args.llm_model or "gpt-4o"
|
||||||
|
config["base_url"] = resolve_openai_base_url(args.llm_api_base)
|
||||||
|
resolved_key = resolve_openai_api_key(args.llm_api_key)
|
||||||
|
if resolved_key:
|
||||||
|
config["api_key"] = resolved_key
|
||||||
elif args.llm == "ollama":
|
elif args.llm == "ollama":
|
||||||
config["model"] = args.llm_model or "llama3.2:1b"
|
config["model"] = args.llm_model or "llama3.2:1b"
|
||||||
config["host"] = args.llm_host
|
config["host"] = resolve_ollama_host(args.llm_host)
|
||||||
elif args.llm == "hf":
|
elif args.llm == "hf":
|
||||||
config["model"] = args.llm_model or "Qwen/Qwen2.5-1.5B-Instruct"
|
config["model"] = args.llm_model or "Qwen/Qwen2.5-1.5B-Instruct"
|
||||||
elif args.llm == "simulated":
|
elif args.llm == "simulated":
|
||||||
@@ -223,10 +258,20 @@ class BaseRAGExample(ABC):
|
|||||||
print(f"\n[Building Index] Creating {self.name} index...")
|
print(f"\n[Building Index] Creating {self.name} index...")
|
||||||
print(f"Total text chunks: {len(texts)}")
|
print(f"Total text chunks: {len(texts)}")
|
||||||
|
|
||||||
|
embedding_options: dict[str, Any] = {}
|
||||||
|
if args.embedding_mode == "ollama":
|
||||||
|
embedding_options["host"] = resolve_ollama_host(args.embedding_host)
|
||||||
|
elif args.embedding_mode == "openai":
|
||||||
|
embedding_options["base_url"] = resolve_openai_base_url(args.embedding_api_base)
|
||||||
|
resolved_embedding_key = resolve_openai_api_key(args.embedding_api_key)
|
||||||
|
if resolved_embedding_key:
|
||||||
|
embedding_options["api_key"] = resolved_embedding_key
|
||||||
|
|
||||||
builder = LeannBuilder(
|
builder = LeannBuilder(
|
||||||
backend_name=args.backend_name,
|
backend_name=args.backend_name,
|
||||||
embedding_model=args.embedding_model,
|
embedding_model=args.embedding_model,
|
||||||
embedding_mode=args.embedding_mode,
|
embedding_mode=args.embedding_mode,
|
||||||
|
embedding_options=embedding_options or None,
|
||||||
graph_degree=args.graph_degree,
|
graph_degree=args.graph_degree,
|
||||||
complexity=args.build_complexity,
|
complexity=args.build_complexity,
|
||||||
is_compact=not args.no_compact,
|
is_compact=not args.no_compact,
|
||||||
|
|||||||
@@ -83,6 +83,81 @@ ollama pull nomic-embed-text
|
|||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
## Local & Remote Inference Endpoints
|
||||||
|
|
||||||
|
> Applies to both LLMs (`leann ask`) and embeddings (`leann build`).
|
||||||
|
|
||||||
|
LEANN now treats Ollama, LM Studio, and other OpenAI-compatible runtimes as first-class providers. You can point LEANN at any compatible endpoint – either on the same machine or across the network – with a couple of flags or environment variables.
|
||||||
|
|
||||||
|
### One-Time Environment Setup
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Works for OpenAI-compatible runtimes such as LM Studio, vLLM, SGLang, llamafile, etc.
|
||||||
|
export OPENAI_API_KEY="your-key" # or leave unset for local servers that do not check keys
|
||||||
|
export OPENAI_BASE_URL="http://localhost:1234/v1"
|
||||||
|
|
||||||
|
# Ollama-compatible runtimes (Ollama, Ollama on another host, llamacpp-server, etc.)
|
||||||
|
export LEANN_OLLAMA_HOST="http://localhost:11434" # falls back to OLLAMA_HOST or LOCAL_LLM_ENDPOINT
|
||||||
|
```
|
||||||
|
|
||||||
|
LEANN also recognises `LEANN_LOCAL_LLM_HOST` (highest priority), `LEANN_OPENAI_BASE_URL`, and `LOCAL_OPENAI_BASE_URL`, so existing scripts continue to work.
|
||||||
|
|
||||||
|
### Passing Hosts Per Command
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Build an index with a remote embedding server
|
||||||
|
leann build my-notes \
|
||||||
|
--docs ./notes \
|
||||||
|
--embedding-mode openai \
|
||||||
|
--embedding-model text-embedding-qwen3-embedding-0.6b \
|
||||||
|
--embedding-api-base http://192.168.1.50:1234/v1 \
|
||||||
|
--embedding-api-key local-dev-key
|
||||||
|
|
||||||
|
# Query using a local LM Studio instance via OpenAI-compatible API
|
||||||
|
leann ask my-notes \
|
||||||
|
--llm openai \
|
||||||
|
--llm-model qwen3-8b \
|
||||||
|
--api-base http://localhost:1234/v1 \
|
||||||
|
--api-key local-dev-key
|
||||||
|
|
||||||
|
# Query an Ollama instance running on another box
|
||||||
|
leann ask my-notes \
|
||||||
|
--llm ollama \
|
||||||
|
--llm-model qwen3:14b \
|
||||||
|
--host http://192.168.1.101:11434
|
||||||
|
```
|
||||||
|
|
||||||
|
⚠️ **Make sure the endpoint is reachable**: when your inference server runs on a home/workstation and the index/search job runs in the cloud, the server must be able to reach the host you configured. Typical options include:
|
||||||
|
|
||||||
|
- Expose a public IP (and open the relevant port) on the machine that hosts LM Studio/Ollama.
|
||||||
|
- Configure router or cloud provider port forwarding.
|
||||||
|
- Tunnel traffic through tools like `tailscale`, `cloudflared`, or `ssh -R`.
|
||||||
|
|
||||||
|
When you set these options while building an index, LEANN stores them in `meta.json`. Any subsequent `leann ask` or searcher process automatically reuses the same provider settings – even when we spawn background embedding servers. This makes the “server without GPU talking to my local workstation” workflow from [issue #80](https://github.com/yichuan-w/LEANN/issues/80#issuecomment-2287230548) work out-of-the-box.
|
||||||
|
|
||||||
|
**Tip:** If your runtime does not require an API key (many local stacks don’t), leave `--api-key` unset. LEANN will skip injecting credentials.
|
||||||
|
|
||||||
|
### Python API Usage
|
||||||
|
|
||||||
|
You can pass the same configuration from Python:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from leann.api import LeannBuilder
|
||||||
|
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_mode="openai",
|
||||||
|
embedding_model="text-embedding-qwen3-embedding-0.6b",
|
||||||
|
embedding_options={
|
||||||
|
"base_url": "http://192.168.1.50:1234/v1",
|
||||||
|
"api_key": "local-dev-key",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
builder.build_index("./indexes/my-notes", chunks)
|
||||||
|
```
|
||||||
|
|
||||||
|
`embedding_options` is persisted to the index `meta.json`, so subsequent `LeannSearcher` or `LeannChat` sessions automatically reuse the same provider settings (the embedding server manager forwards them to the provider for you).
|
||||||
|
|
||||||
## Index Selection: Matching Your Scale
|
## Index Selection: Matching Your Scale
|
||||||
|
|
||||||
### HNSW (Hierarchical Navigable Small World)
|
### HNSW (Hierarchical Navigable Small World)
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import sys
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import zmq
|
import zmq
|
||||||
@@ -32,6 +32,16 @@ if not logger.handlers:
|
|||||||
logger.propagate = False
|
logger.propagate = False
|
||||||
|
|
||||||
|
|
||||||
|
_RAW_PROVIDER_OPTIONS = os.getenv("LEANN_EMBEDDING_OPTIONS")
|
||||||
|
try:
|
||||||
|
PROVIDER_OPTIONS: dict[str, Any] = (
|
||||||
|
json.loads(_RAW_PROVIDER_OPTIONS) if _RAW_PROVIDER_OPTIONS else {}
|
||||||
|
)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.warning("Failed to parse LEANN_EMBEDDING_OPTIONS; ignoring provider options")
|
||||||
|
PROVIDER_OPTIONS = {}
|
||||||
|
|
||||||
|
|
||||||
def create_diskann_embedding_server(
|
def create_diskann_embedding_server(
|
||||||
passages_file: Optional[str] = None,
|
passages_file: Optional[str] = None,
|
||||||
zmq_port: int = 5555,
|
zmq_port: int = 5555,
|
||||||
@@ -181,7 +191,12 @@ def create_diskann_embedding_server(
|
|||||||
logger.debug(f"Text lengths: {[len(t) for t in texts[:5]]}") # Show first 5
|
logger.debug(f"Text lengths: {[len(t) for t in texts[:5]]}") # Show first 5
|
||||||
|
|
||||||
# Process embeddings using unified computation
|
# Process embeddings using unified computation
|
||||||
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
embeddings = compute_embeddings(
|
||||||
|
texts,
|
||||||
|
model_name,
|
||||||
|
mode=embedding_mode,
|
||||||
|
provider_options=PROVIDER_OPTIONS,
|
||||||
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||||
)
|
)
|
||||||
@@ -296,7 +311,12 @@ def create_diskann_embedding_server(
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
# Process the request
|
# Process the request
|
||||||
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
embeddings = compute_embeddings(
|
||||||
|
texts,
|
||||||
|
model_name,
|
||||||
|
mode=embedding_mode,
|
||||||
|
provider_options=PROVIDER_OPTIONS,
|
||||||
|
)
|
||||||
logger.info(f"Computed embeddings shape: {embeddings.shape}")
|
logger.info(f"Computed embeddings shape: {embeddings.shape}")
|
||||||
|
|
||||||
# Validation
|
# Validation
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import sys
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
import msgpack
|
import msgpack
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -45,6 +45,15 @@ if log_path:
|
|||||||
|
|
||||||
logger.propagate = False
|
logger.propagate = False
|
||||||
|
|
||||||
|
_RAW_PROVIDER_OPTIONS = os.getenv("LEANN_EMBEDDING_OPTIONS")
|
||||||
|
try:
|
||||||
|
PROVIDER_OPTIONS: dict[str, Any] = (
|
||||||
|
json.loads(_RAW_PROVIDER_OPTIONS) if _RAW_PROVIDER_OPTIONS else {}
|
||||||
|
)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.warning("Failed to parse LEANN_EMBEDDING_OPTIONS; ignoring provider options")
|
||||||
|
PROVIDER_OPTIONS = {}
|
||||||
|
|
||||||
|
|
||||||
def create_hnsw_embedding_server(
|
def create_hnsw_embedding_server(
|
||||||
passages_file: Optional[str] = None,
|
passages_file: Optional[str] = None,
|
||||||
@@ -151,7 +160,12 @@ def create_hnsw_embedding_server(
|
|||||||
):
|
):
|
||||||
last_request_type = "text"
|
last_request_type = "text"
|
||||||
last_request_length = len(request)
|
last_request_length = len(request)
|
||||||
embeddings = compute_embeddings(request, model_name, mode=embedding_mode)
|
embeddings = compute_embeddings(
|
||||||
|
request,
|
||||||
|
model_name,
|
||||||
|
mode=embedding_mode,
|
||||||
|
provider_options=PROVIDER_OPTIONS,
|
||||||
|
)
|
||||||
rep_socket.send(msgpack.packb(embeddings.tolist()))
|
rep_socket.send(msgpack.packb(embeddings.tolist()))
|
||||||
e2e_end = time.time()
|
e2e_end = time.time()
|
||||||
logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
|
logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
|
||||||
@@ -200,7 +214,10 @@ def create_hnsw_embedding_server(
|
|||||||
if texts:
|
if texts:
|
||||||
try:
|
try:
|
||||||
embeddings = compute_embeddings(
|
embeddings = compute_embeddings(
|
||||||
texts, model_name, mode=embedding_mode
|
texts,
|
||||||
|
model_name,
|
||||||
|
mode=embedding_mode,
|
||||||
|
provider_options=PROVIDER_OPTIONS,
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||||
@@ -265,7 +282,12 @@ def create_hnsw_embedding_server(
|
|||||||
|
|
||||||
if texts:
|
if texts:
|
||||||
try:
|
try:
|
||||||
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
embeddings = compute_embeddings(
|
||||||
|
texts,
|
||||||
|
model_name,
|
||||||
|
mode=embedding_mode,
|
||||||
|
provider_options=PROVIDER_OPTIONS,
|
||||||
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ def compute_embeddings(
|
|||||||
use_server: bool = True,
|
use_server: bool = True,
|
||||||
port: Optional[int] = None,
|
port: Optional[int] = None,
|
||||||
is_build=False,
|
is_build=False,
|
||||||
|
provider_options: Optional[dict[str, Any]] = None,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Computes embeddings using different backends.
|
Computes embeddings using different backends.
|
||||||
@@ -72,6 +73,7 @@ def compute_embeddings(
|
|||||||
model_name,
|
model_name,
|
||||||
mode=mode,
|
mode=mode,
|
||||||
is_build=is_build,
|
is_build=is_build,
|
||||||
|
provider_options=provider_options,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -278,6 +280,7 @@ class LeannBuilder:
|
|||||||
embedding_model: str = "facebook/contriever",
|
embedding_model: str = "facebook/contriever",
|
||||||
dimensions: Optional[int] = None,
|
dimensions: Optional[int] = None,
|
||||||
embedding_mode: str = "sentence-transformers",
|
embedding_mode: str = "sentence-transformers",
|
||||||
|
embedding_options: Optional[dict[str, Any]] = None,
|
||||||
**backend_kwargs,
|
**backend_kwargs,
|
||||||
):
|
):
|
||||||
self.backend_name = backend_name
|
self.backend_name = backend_name
|
||||||
@@ -300,6 +303,7 @@ class LeannBuilder:
|
|||||||
self.embedding_model = embedding_model
|
self.embedding_model = embedding_model
|
||||||
self.dimensions = dimensions
|
self.dimensions = dimensions
|
||||||
self.embedding_mode = embedding_mode
|
self.embedding_mode = embedding_mode
|
||||||
|
self.embedding_options = embedding_options or {}
|
||||||
|
|
||||||
# Check if we need to use cosine distance for normalized embeddings
|
# Check if we need to use cosine distance for normalized embeddings
|
||||||
normalized_embeddings_models = {
|
normalized_embeddings_models = {
|
||||||
@@ -407,6 +411,7 @@ class LeannBuilder:
|
|||||||
self.embedding_model,
|
self.embedding_model,
|
||||||
self.embedding_mode,
|
self.embedding_mode,
|
||||||
use_server=False,
|
use_server=False,
|
||||||
|
provider_options=self.embedding_options,
|
||||||
)[0]
|
)[0]
|
||||||
)
|
)
|
||||||
path = Path(index_path)
|
path = Path(index_path)
|
||||||
@@ -446,6 +451,7 @@ class LeannBuilder:
|
|||||||
self.embedding_mode,
|
self.embedding_mode,
|
||||||
use_server=False,
|
use_server=False,
|
||||||
is_build=True,
|
is_build=True,
|
||||||
|
provider_options=self.embedding_options,
|
||||||
)
|
)
|
||||||
string_ids = [chunk["id"] for chunk in self.chunks]
|
string_ids = [chunk["id"] for chunk in self.chunks]
|
||||||
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
|
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
|
||||||
@@ -472,6 +478,9 @@ class LeannBuilder:
|
|||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if self.embedding_options:
|
||||||
|
meta_data["embedding_options"] = self.embedding_options
|
||||||
|
|
||||||
# Add storage status flags for HNSW backend
|
# Add storage status flags for HNSW backend
|
||||||
if self.backend_name == "hnsw":
|
if self.backend_name == "hnsw":
|
||||||
is_compact = self.backend_kwargs.get("is_compact", True)
|
is_compact = self.backend_kwargs.get("is_compact", True)
|
||||||
@@ -592,6 +601,9 @@ class LeannBuilder:
|
|||||||
"embeddings_source": str(embeddings_file),
|
"embeddings_source": str(embeddings_file),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if self.embedding_options:
|
||||||
|
meta_data["embedding_options"] = self.embedding_options
|
||||||
|
|
||||||
# Add storage status flags for HNSW backend
|
# Add storage status flags for HNSW backend
|
||||||
if self.backend_name == "hnsw":
|
if self.backend_name == "hnsw":
|
||||||
is_compact = self.backend_kwargs.get("is_compact", True)
|
is_compact = self.backend_kwargs.get("is_compact", True)
|
||||||
@@ -673,6 +685,7 @@ class LeannBuilder:
|
|||||||
self.embedding_mode,
|
self.embedding_mode,
|
||||||
use_server=False,
|
use_server=False,
|
||||||
is_build=True,
|
is_build=True,
|
||||||
|
provider_options=self.embedding_options,
|
||||||
)
|
)
|
||||||
|
|
||||||
embedding_dim = embeddings.shape[1]
|
embedding_dim = embeddings.shape[1]
|
||||||
@@ -771,6 +784,7 @@ class LeannSearcher:
|
|||||||
self.embedding_model = self.meta_data["embedding_model"]
|
self.embedding_model = self.meta_data["embedding_model"]
|
||||||
# Support both old and new format
|
# Support both old and new format
|
||||||
self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers")
|
self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers")
|
||||||
|
self.embedding_options = self.meta_data.get("embedding_options", {})
|
||||||
# Delegate portability handling to PassageManager
|
# Delegate portability handling to PassageManager
|
||||||
self.passage_manager = PassageManager(
|
self.passage_manager = PassageManager(
|
||||||
self.meta_data.get("passage_sources", []), metadata_file_path=self.meta_path_str
|
self.meta_data.get("passage_sources", []), metadata_file_path=self.meta_path_str
|
||||||
@@ -782,6 +796,8 @@ class LeannSearcher:
|
|||||||
raise ValueError(f"Backend '{backend_name}' not found.")
|
raise ValueError(f"Backend '{backend_name}' not found.")
|
||||||
final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
|
final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
|
||||||
final_kwargs["enable_warmup"] = enable_warmup
|
final_kwargs["enable_warmup"] = enable_warmup
|
||||||
|
if self.embedding_options:
|
||||||
|
final_kwargs.setdefault("embedding_options", self.embedding_options)
|
||||||
self.backend_impl: LeannBackendSearcherInterface = backend_factory.searcher(
|
self.backend_impl: LeannBackendSearcherInterface = backend_factory.searcher(
|
||||||
index_path, **final_kwargs
|
index_path, **final_kwargs
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -12,6 +12,8 @@ from typing import Any, Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -310,11 +312,12 @@ def search_hf_models(query: str, limit: int = 10) -> list[str]:
|
|||||||
|
|
||||||
|
|
||||||
def validate_model_and_suggest(
|
def validate_model_and_suggest(
|
||||||
model_name: str, llm_type: str, host: str = "http://localhost:11434"
|
model_name: str, llm_type: str, host: Optional[str] = None
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
"""Validate model name and provide suggestions if invalid"""
|
"""Validate model name and provide suggestions if invalid"""
|
||||||
if llm_type == "ollama":
|
if llm_type == "ollama":
|
||||||
available_models = check_ollama_models(host)
|
resolved_host = resolve_ollama_host(host)
|
||||||
|
available_models = check_ollama_models(resolved_host)
|
||||||
if available_models and model_name not in available_models:
|
if available_models and model_name not in available_models:
|
||||||
error_msg = f"Model '{model_name}' not found in your local Ollama installation."
|
error_msg = f"Model '{model_name}' not found in your local Ollama installation."
|
||||||
|
|
||||||
@@ -457,19 +460,19 @@ class LLMInterface(ABC):
|
|||||||
class OllamaChat(LLMInterface):
|
class OllamaChat(LLMInterface):
|
||||||
"""LLM interface for Ollama models."""
|
"""LLM interface for Ollama models."""
|
||||||
|
|
||||||
def __init__(self, model: str = "llama3:8b", host: str = "http://localhost:11434"):
|
def __init__(self, model: str = "llama3:8b", host: Optional[str] = None):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.host = host
|
self.host = resolve_ollama_host(host)
|
||||||
logger.info(f"Initializing OllamaChat with model='{model}' and host='{host}'")
|
logger.info(f"Initializing OllamaChat with model='{model}' and host='{self.host}'")
|
||||||
try:
|
try:
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
# Check if the Ollama server is responsive
|
# Check if the Ollama server is responsive
|
||||||
if host:
|
if self.host:
|
||||||
requests.get(host)
|
requests.get(self.host)
|
||||||
|
|
||||||
# Pre-check model availability with helpful suggestions
|
# Pre-check model availability with helpful suggestions
|
||||||
model_error = validate_model_and_suggest(model, "ollama", host)
|
model_error = validate_model_and_suggest(model, "ollama", self.host)
|
||||||
if model_error:
|
if model_error:
|
||||||
raise ValueError(model_error)
|
raise ValueError(model_error)
|
||||||
|
|
||||||
@@ -478,9 +481,11 @@ class OllamaChat(LLMInterface):
|
|||||||
"The 'requests' library is required for Ollama. Please install it with 'pip install requests'."
|
"The 'requests' library is required for Ollama. Please install it with 'pip install requests'."
|
||||||
)
|
)
|
||||||
except requests.exceptions.ConnectionError:
|
except requests.exceptions.ConnectionError:
|
||||||
logger.error(f"Could not connect to Ollama at {host}. Please ensure Ollama is running.")
|
logger.error(
|
||||||
|
f"Could not connect to Ollama at {self.host}. Please ensure Ollama is running."
|
||||||
|
)
|
||||||
raise ConnectionError(
|
raise ConnectionError(
|
||||||
f"Could not connect to Ollama at {host}. Please ensure Ollama is running."
|
f"Could not connect to Ollama at {self.host}. Please ensure Ollama is running."
|
||||||
)
|
)
|
||||||
|
|
||||||
def ask(self, prompt: str, **kwargs) -> str:
|
def ask(self, prompt: str, **kwargs) -> str:
|
||||||
@@ -737,21 +742,31 @@ class GeminiChat(LLMInterface):
|
|||||||
class OpenAIChat(LLMInterface):
|
class OpenAIChat(LLMInterface):
|
||||||
"""LLM interface for OpenAI models."""
|
"""LLM interface for OpenAI models."""
|
||||||
|
|
||||||
def __init__(self, model: str = "gpt-4o", api_key: Optional[str] = None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: str = "gpt-4o",
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
base_url: Optional[str] = None,
|
||||||
|
):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
|
self.base_url = resolve_openai_base_url(base_url)
|
||||||
|
self.api_key = resolve_openai_api_key(api_key)
|
||||||
|
|
||||||
if not self.api_key:
|
if not self.api_key:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter."
|
"OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter."
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Initializing OpenAI Chat with model='{model}'")
|
logger.info(
|
||||||
|
"Initializing OpenAI Chat with model='%s' and base_url='%s'",
|
||||||
|
model,
|
||||||
|
self.base_url,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import openai
|
import openai
|
||||||
|
|
||||||
self.client = openai.OpenAI(api_key=self.api_key)
|
self.client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"The 'openai' library is required for OpenAI models. Please install it with 'pip install openai'."
|
"The 'openai' library is required for OpenAI models. Please install it with 'pip install openai'."
|
||||||
@@ -841,12 +856,16 @@ def get_llm(llm_config: Optional[dict[str, Any]] = None) -> LLMInterface:
|
|||||||
if llm_type == "ollama":
|
if llm_type == "ollama":
|
||||||
return OllamaChat(
|
return OllamaChat(
|
||||||
model=model or "llama3:8b",
|
model=model or "llama3:8b",
|
||||||
host=llm_config.get("host", "http://localhost:11434"),
|
host=llm_config.get("host"),
|
||||||
)
|
)
|
||||||
elif llm_type == "hf":
|
elif llm_type == "hf":
|
||||||
return HFChat(model_name=model or "deepseek-ai/deepseek-llm-7b-chat")
|
return HFChat(model_name=model or "deepseek-ai/deepseek-llm-7b-chat")
|
||||||
elif llm_type == "openai":
|
elif llm_type == "openai":
|
||||||
return OpenAIChat(model=model or "gpt-4o", api_key=llm_config.get("api_key"))
|
return OpenAIChat(
|
||||||
|
model=model or "gpt-4o",
|
||||||
|
api_key=llm_config.get("api_key"),
|
||||||
|
base_url=llm_config.get("base_url"),
|
||||||
|
)
|
||||||
elif llm_type == "gemini":
|
elif llm_type == "gemini":
|
||||||
return GeminiChat(model=model or "gemini-2.5-flash", api_key=llm_config.get("api_key"))
|
return GeminiChat(model=model or "gemini-2.5-flash", api_key=llm_config.get("api_key"))
|
||||||
elif llm_type == "simulated":
|
elif llm_type == "simulated":
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from tqdm import tqdm
|
|||||||
|
|
||||||
from .api import LeannBuilder, LeannChat, LeannSearcher
|
from .api import LeannBuilder, LeannChat, LeannSearcher
|
||||||
from .registry import register_project_directory
|
from .registry import register_project_directory
|
||||||
|
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
|
||||||
|
|
||||||
|
|
||||||
def extract_pdf_text_with_pymupdf(file_path: str) -> str:
|
def extract_pdf_text_with_pymupdf(file_path: str) -> str:
|
||||||
@@ -123,6 +124,24 @@ Examples:
|
|||||||
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
||||||
help="Embedding backend mode (default: sentence-transformers)",
|
help="Embedding backend mode (default: sentence-transformers)",
|
||||||
)
|
)
|
||||||
|
build_parser.add_argument(
|
||||||
|
"--embedding-host",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Override Ollama-compatible embedding host",
|
||||||
|
)
|
||||||
|
build_parser.add_argument(
|
||||||
|
"--embedding-api-base",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Base URL for OpenAI-compatible embedding services",
|
||||||
|
)
|
||||||
|
build_parser.add_argument(
|
||||||
|
"--embedding-api-key",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="API key for embedding service (defaults to OPENAI_API_KEY)",
|
||||||
|
)
|
||||||
build_parser.add_argument(
|
build_parser.add_argument(
|
||||||
"--force", "-f", action="store_true", help="Force rebuild existing index"
|
"--force", "-f", action="store_true", help="Force rebuild existing index"
|
||||||
)
|
)
|
||||||
@@ -248,7 +267,12 @@ Examples:
|
|||||||
ask_parser.add_argument(
|
ask_parser.add_argument(
|
||||||
"--model", type=str, default="qwen3:8b", help="Model name (default: qwen3:8b)"
|
"--model", type=str, default="qwen3:8b", help="Model name (default: qwen3:8b)"
|
||||||
)
|
)
|
||||||
ask_parser.add_argument("--host", type=str, default="http://localhost:11434")
|
ask_parser.add_argument(
|
||||||
|
"--host",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Override Ollama-compatible host (defaults to LEANN_OLLAMA_HOST/OLLAMA_HOST)",
|
||||||
|
)
|
||||||
ask_parser.add_argument(
|
ask_parser.add_argument(
|
||||||
"--interactive", "-i", action="store_true", help="Interactive chat mode"
|
"--interactive", "-i", action="store_true", help="Interactive chat mode"
|
||||||
)
|
)
|
||||||
@@ -277,6 +301,18 @@ Examples:
|
|||||||
default=None,
|
default=None,
|
||||||
help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.",
|
help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.",
|
||||||
)
|
)
|
||||||
|
ask_parser.add_argument(
|
||||||
|
"--api-base",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Base URL for OpenAI-compatible APIs (e.g., http://localhost:10000/v1)",
|
||||||
|
)
|
||||||
|
ask_parser.add_argument(
|
||||||
|
"--api-key",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="API key for OpenAI-compatible APIs (defaults to OPENAI_API_KEY)",
|
||||||
|
)
|
||||||
|
|
||||||
# List command
|
# List command
|
||||||
subparsers.add_parser("list", help="List all indexes")
|
subparsers.add_parser("list", help="List all indexes")
|
||||||
@@ -1325,10 +1361,20 @@ Examples:
|
|||||||
|
|
||||||
print(f"Building index '{index_name}' with {args.backend} backend...")
|
print(f"Building index '{index_name}' with {args.backend} backend...")
|
||||||
|
|
||||||
|
embedding_options: dict[str, Any] = {}
|
||||||
|
if args.embedding_mode == "ollama":
|
||||||
|
embedding_options["host"] = resolve_ollama_host(args.embedding_host)
|
||||||
|
elif args.embedding_mode == "openai":
|
||||||
|
embedding_options["base_url"] = resolve_openai_base_url(args.embedding_api_base)
|
||||||
|
resolved_embedding_key = resolve_openai_api_key(args.embedding_api_key)
|
||||||
|
if resolved_embedding_key:
|
||||||
|
embedding_options["api_key"] = resolved_embedding_key
|
||||||
|
|
||||||
builder = LeannBuilder(
|
builder = LeannBuilder(
|
||||||
backend_name=args.backend,
|
backend_name=args.backend,
|
||||||
embedding_model=args.embedding_model,
|
embedding_model=args.embedding_model,
|
||||||
embedding_mode=args.embedding_mode,
|
embedding_mode=args.embedding_mode,
|
||||||
|
embedding_options=embedding_options or None,
|
||||||
graph_degree=args.graph_degree,
|
graph_degree=args.graph_degree,
|
||||||
complexity=args.complexity,
|
complexity=args.complexity,
|
||||||
is_compact=args.compact,
|
is_compact=args.compact,
|
||||||
@@ -1476,7 +1522,12 @@ Examples:
|
|||||||
|
|
||||||
llm_config = {"type": args.llm, "model": args.model}
|
llm_config = {"type": args.llm, "model": args.model}
|
||||||
if args.llm == "ollama":
|
if args.llm == "ollama":
|
||||||
llm_config["host"] = args.host
|
llm_config["host"] = resolve_ollama_host(args.host)
|
||||||
|
elif args.llm == "openai":
|
||||||
|
llm_config["base_url"] = resolve_openai_base_url(args.api_base)
|
||||||
|
resolved_api_key = resolve_openai_api_key(args.api_key)
|
||||||
|
if resolved_api_key:
|
||||||
|
llm_config["api_key"] = resolved_api_key
|
||||||
|
|
||||||
chat = LeannChat(index_path=index_path, llm_config=llm_config)
|
chat = LeannChat(index_path=index_path, llm_config=llm_config)
|
||||||
|
|
||||||
|
|||||||
@@ -7,11 +7,13 @@ Preserves all optimization parameters to ensure performance
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from typing import Any
|
from typing import Any, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
|
||||||
|
|
||||||
# Set up logger with proper level
|
# Set up logger with proper level
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||||
@@ -31,6 +33,7 @@ def compute_embeddings(
|
|||||||
adaptive_optimization: bool = True,
|
adaptive_optimization: bool = True,
|
||||||
manual_tokenize: bool = False,
|
manual_tokenize: bool = False,
|
||||||
max_length: int = 512,
|
max_length: int = 512,
|
||||||
|
provider_options: Optional[dict[str, Any]] = None,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Unified embedding computation entry point
|
Unified embedding computation entry point
|
||||||
@@ -46,6 +49,8 @@ def compute_embeddings(
|
|||||||
Returns:
|
Returns:
|
||||||
Normalized embeddings array, shape: (len(texts), embedding_dim)
|
Normalized embeddings array, shape: (len(texts), embedding_dim)
|
||||||
"""
|
"""
|
||||||
|
provider_options = provider_options or {}
|
||||||
|
|
||||||
if mode == "sentence-transformers":
|
if mode == "sentence-transformers":
|
||||||
return compute_embeddings_sentence_transformers(
|
return compute_embeddings_sentence_transformers(
|
||||||
texts,
|
texts,
|
||||||
@@ -57,11 +62,21 @@ def compute_embeddings(
|
|||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
)
|
)
|
||||||
elif mode == "openai":
|
elif mode == "openai":
|
||||||
return compute_embeddings_openai(texts, model_name)
|
return compute_embeddings_openai(
|
||||||
|
texts,
|
||||||
|
model_name,
|
||||||
|
base_url=provider_options.get("base_url"),
|
||||||
|
api_key=provider_options.get("api_key"),
|
||||||
|
)
|
||||||
elif mode == "mlx":
|
elif mode == "mlx":
|
||||||
return compute_embeddings_mlx(texts, model_name)
|
return compute_embeddings_mlx(texts, model_name)
|
||||||
elif mode == "ollama":
|
elif mode == "ollama":
|
||||||
return compute_embeddings_ollama(texts, model_name, is_build=is_build)
|
return compute_embeddings_ollama(
|
||||||
|
texts,
|
||||||
|
model_name,
|
||||||
|
is_build=is_build,
|
||||||
|
host=provider_options.get("host"),
|
||||||
|
)
|
||||||
elif mode == "gemini":
|
elif mode == "gemini":
|
||||||
return compute_embeddings_gemini(texts, model_name, is_build=is_build)
|
return compute_embeddings_gemini(texts, model_name, is_build=is_build)
|
||||||
else:
|
else:
|
||||||
@@ -353,12 +368,15 @@ def compute_embeddings_sentence_transformers(
|
|||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray:
|
def compute_embeddings_openai(
|
||||||
|
texts: list[str],
|
||||||
|
model_name: str,
|
||||||
|
base_url: Optional[str] = None,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
) -> np.ndarray:
|
||||||
# TODO: @yichuan-w add progress bar only in build mode
|
# TODO: @yichuan-w add progress bar only in build mode
|
||||||
"""Compute embeddings using OpenAI API"""
|
"""Compute embeddings using OpenAI API"""
|
||||||
try:
|
try:
|
||||||
import os
|
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
raise ImportError(f"OpenAI package not installed: {e}")
|
raise ImportError(f"OpenAI package not installed: {e}")
|
||||||
@@ -373,16 +391,18 @@ def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray:
|
|||||||
f"Found {invalid_count} empty/invalid text(s) in input. Upstream should filter before calling OpenAI."
|
f"Found {invalid_count} empty/invalid text(s) in input. Upstream should filter before calling OpenAI."
|
||||||
)
|
)
|
||||||
|
|
||||||
api_key = os.getenv("OPENAI_API_KEY")
|
resolved_base_url = resolve_openai_base_url(base_url)
|
||||||
if not api_key:
|
resolved_api_key = resolve_openai_api_key(api_key)
|
||||||
|
|
||||||
|
if not resolved_api_key:
|
||||||
raise RuntimeError("OPENAI_API_KEY environment variable not set")
|
raise RuntimeError("OPENAI_API_KEY environment variable not set")
|
||||||
|
|
||||||
# Cache OpenAI client
|
# Cache OpenAI client
|
||||||
cache_key = "openai_client"
|
cache_key = f"openai_client::{resolved_base_url}"
|
||||||
if cache_key in _model_cache:
|
if cache_key in _model_cache:
|
||||||
client = _model_cache[cache_key]
|
client = _model_cache[cache_key]
|
||||||
else:
|
else:
|
||||||
client = openai.OpenAI(api_key=api_key)
|
client = openai.OpenAI(api_key=resolved_api_key, base_url=resolved_base_url)
|
||||||
_model_cache[cache_key] = client
|
_model_cache[cache_key] = client
|
||||||
logger.info("OpenAI client cached")
|
logger.info("OpenAI client cached")
|
||||||
|
|
||||||
@@ -507,7 +527,10 @@ def compute_embeddings_mlx(chunks: list[str], model_name: str, batch_size: int =
|
|||||||
|
|
||||||
|
|
||||||
def compute_embeddings_ollama(
|
def compute_embeddings_ollama(
|
||||||
texts: list[str], model_name: str, is_build: bool = False, host: str = "http://localhost:11434"
|
texts: list[str],
|
||||||
|
model_name: str,
|
||||||
|
is_build: bool = False,
|
||||||
|
host: Optional[str] = None,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Compute embeddings using Ollama API with simplified batch processing.
|
Compute embeddings using Ollama API with simplified batch processing.
|
||||||
@@ -518,7 +541,7 @@ def compute_embeddings_ollama(
|
|||||||
texts: List of texts to compute embeddings for
|
texts: List of texts to compute embeddings for
|
||||||
model_name: Ollama model name (e.g., "nomic-embed-text", "mxbai-embed-large")
|
model_name: Ollama model name (e.g., "nomic-embed-text", "mxbai-embed-large")
|
||||||
is_build: Whether this is a build operation (shows progress bar)
|
is_build: Whether this is a build operation (shows progress bar)
|
||||||
host: Ollama host URL (default: http://localhost:11434)
|
host: Ollama host URL (defaults to environment or http://localhost:11434)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Normalized embeddings array, shape: (len(texts), embedding_dim)
|
Normalized embeddings array, shape: (len(texts), embedding_dim)
|
||||||
@@ -533,17 +556,19 @@ def compute_embeddings_ollama(
|
|||||||
if not texts:
|
if not texts:
|
||||||
raise ValueError("Cannot compute embeddings for empty text list")
|
raise ValueError("Cannot compute embeddings for empty text list")
|
||||||
|
|
||||||
|
resolved_host = resolve_ollama_host(host)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Computing embeddings for {len(texts)} texts using Ollama API, model: '{model_name}'"
|
f"Computing embeddings for {len(texts)} texts using Ollama API, model: '{model_name}', host: '{resolved_host}'"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if Ollama is running
|
# Check if Ollama is running
|
||||||
try:
|
try:
|
||||||
response = requests.get(f"{host}/api/version", timeout=5)
|
response = requests.get(f"{resolved_host}/api/version", timeout=5)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
except requests.exceptions.ConnectionError:
|
except requests.exceptions.ConnectionError:
|
||||||
error_msg = (
|
error_msg = (
|
||||||
f"❌ Could not connect to Ollama at {host}.\n\n"
|
f"❌ Could not connect to Ollama at {resolved_host}.\n\n"
|
||||||
"Please ensure Ollama is running:\n"
|
"Please ensure Ollama is running:\n"
|
||||||
" • macOS/Linux: ollama serve\n"
|
" • macOS/Linux: ollama serve\n"
|
||||||
" • Windows: Make sure Ollama is running in the system tray\n\n"
|
" • Windows: Make sure Ollama is running in the system tray\n\n"
|
||||||
@@ -555,7 +580,7 @@ def compute_embeddings_ollama(
|
|||||||
|
|
||||||
# Check if model exists and provide helpful suggestions
|
# Check if model exists and provide helpful suggestions
|
||||||
try:
|
try:
|
||||||
response = requests.get(f"{host}/api/tags", timeout=5)
|
response = requests.get(f"{resolved_host}/api/tags", timeout=5)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
models = response.json()
|
models = response.json()
|
||||||
model_names = [model["name"] for model in models.get("models", [])]
|
model_names = [model["name"] for model in models.get("models", [])]
|
||||||
@@ -618,7 +643,9 @@ def compute_embeddings_ollama(
|
|||||||
# Verify the model supports embeddings by testing it
|
# Verify the model supports embeddings by testing it
|
||||||
try:
|
try:
|
||||||
test_response = requests.post(
|
test_response = requests.post(
|
||||||
f"{host}/api/embeddings", json={"model": model_name, "prompt": "test"}, timeout=10
|
f"{resolved_host}/api/embeddings",
|
||||||
|
json={"model": model_name, "prompt": "test"},
|
||||||
|
timeout=10,
|
||||||
)
|
)
|
||||||
if test_response.status_code != 200:
|
if test_response.status_code != 200:
|
||||||
error_msg = (
|
error_msg = (
|
||||||
@@ -665,7 +692,7 @@ def compute_embeddings_ollama(
|
|||||||
while retry_count < max_retries:
|
while retry_count < max_retries:
|
||||||
try:
|
try:
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
f"{host}/api/embeddings",
|
f"{resolved_host}/api/embeddings",
|
||||||
json={"model": model_name, "prompt": truncated_text},
|
json={"model": model_name, "prompt": truncated_text},
|
||||||
timeout=30,
|
timeout=30,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -8,6 +8,8 @@ import time
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from .settings import encode_provider_options
|
||||||
|
|
||||||
# Lightweight, self-contained server manager with no cross-process inspection
|
# Lightweight, self-contained server manager with no cross-process inspection
|
||||||
|
|
||||||
# Set up logging based on environment variable
|
# Set up logging based on environment variable
|
||||||
@@ -82,16 +84,40 @@ class EmbeddingServerManager:
|
|||||||
) -> tuple[bool, int]:
|
) -> tuple[bool, int]:
|
||||||
"""Start the embedding server."""
|
"""Start the embedding server."""
|
||||||
# passages_file may be present in kwargs for server CLI, but we don't need it here
|
# passages_file may be present in kwargs for server CLI, but we don't need it here
|
||||||
|
provider_options = kwargs.pop("provider_options", None)
|
||||||
|
|
||||||
|
config_signature = {
|
||||||
|
"model_name": model_name,
|
||||||
|
"passages_file": kwargs.get("passages_file", ""),
|
||||||
|
"embedding_mode": embedding_mode,
|
||||||
|
"provider_options": provider_options or {},
|
||||||
|
}
|
||||||
|
|
||||||
# If this manager already has a live server, just reuse it
|
# If this manager already has a live server, just reuse it
|
||||||
if self.server_process and self.server_process.poll() is None and self.server_port:
|
if (
|
||||||
|
self.server_process
|
||||||
|
and self.server_process.poll() is None
|
||||||
|
and self.server_port
|
||||||
|
and self._server_config == config_signature
|
||||||
|
):
|
||||||
logger.info("Reusing in-process server")
|
logger.info("Reusing in-process server")
|
||||||
return True, self.server_port
|
return True, self.server_port
|
||||||
|
|
||||||
|
# Configuration changed, stop existing server before starting a new one
|
||||||
|
if self.server_process and self.server_process.poll() is None:
|
||||||
|
logger.info("Existing server configuration differs; restarting embedding server")
|
||||||
|
self.stop_server()
|
||||||
|
|
||||||
# For Colab environment, use a different strategy
|
# For Colab environment, use a different strategy
|
||||||
if _is_colab_environment():
|
if _is_colab_environment():
|
||||||
logger.info("Detected Colab environment, using alternative startup strategy")
|
logger.info("Detected Colab environment, using alternative startup strategy")
|
||||||
return self._start_server_colab(port, model_name, embedding_mode, **kwargs)
|
return self._start_server_colab(
|
||||||
|
port,
|
||||||
|
model_name,
|
||||||
|
embedding_mode,
|
||||||
|
provider_options=provider_options,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
# Always pick a fresh available port
|
# Always pick a fresh available port
|
||||||
try:
|
try:
|
||||||
@@ -101,13 +127,21 @@ class EmbeddingServerManager:
|
|||||||
return False, port
|
return False, port
|
||||||
|
|
||||||
# Start a new server
|
# Start a new server
|
||||||
return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs)
|
return self._start_new_server(
|
||||||
|
actual_port,
|
||||||
|
model_name,
|
||||||
|
embedding_mode,
|
||||||
|
provider_options=provider_options,
|
||||||
|
config_signature=config_signature,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
def _start_server_colab(
|
def _start_server_colab(
|
||||||
self,
|
self,
|
||||||
port: int,
|
port: int,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
embedding_mode: str = "sentence-transformers",
|
embedding_mode: str = "sentence-transformers",
|
||||||
|
provider_options: Optional[dict] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> tuple[bool, int]:
|
) -> tuple[bool, int]:
|
||||||
"""Start server with Colab-specific configuration."""
|
"""Start server with Colab-specific configuration."""
|
||||||
@@ -125,8 +159,20 @@ class EmbeddingServerManager:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# In Colab, we'll use a more direct approach
|
# In Colab, we'll use a more direct approach
|
||||||
self._launch_server_process_colab(command, actual_port)
|
self._launch_server_process_colab(
|
||||||
return self._wait_for_server_ready_colab(actual_port)
|
command,
|
||||||
|
actual_port,
|
||||||
|
provider_options=provider_options,
|
||||||
|
)
|
||||||
|
started, ready_port = self._wait_for_server_ready_colab(actual_port)
|
||||||
|
if started:
|
||||||
|
self._server_config = {
|
||||||
|
"model_name": model_name,
|
||||||
|
"passages_file": kwargs.get("passages_file", ""),
|
||||||
|
"embedding_mode": embedding_mode,
|
||||||
|
"provider_options": provider_options or {},
|
||||||
|
}
|
||||||
|
return started, ready_port
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to start embedding server in Colab: {e}")
|
logger.error(f"Failed to start embedding server in Colab: {e}")
|
||||||
return False, actual_port
|
return False, actual_port
|
||||||
@@ -134,7 +180,13 @@ class EmbeddingServerManager:
|
|||||||
# Note: No compatibility check needed; manager is per-searcher and configs are stable per instance
|
# Note: No compatibility check needed; manager is per-searcher and configs are stable per instance
|
||||||
|
|
||||||
def _start_new_server(
|
def _start_new_server(
|
||||||
self, port: int, model_name: str, embedding_mode: str, **kwargs
|
self,
|
||||||
|
port: int,
|
||||||
|
model_name: str,
|
||||||
|
embedding_mode: str,
|
||||||
|
provider_options: Optional[dict] = None,
|
||||||
|
config_signature: Optional[dict] = None,
|
||||||
|
**kwargs,
|
||||||
) -> tuple[bool, int]:
|
) -> tuple[bool, int]:
|
||||||
"""Start a new embedding server on the given port."""
|
"""Start a new embedding server on the given port."""
|
||||||
logger.info(f"Starting embedding server on port {port}...")
|
logger.info(f"Starting embedding server on port {port}...")
|
||||||
@@ -142,8 +194,20 @@ class EmbeddingServerManager:
|
|||||||
command = self._build_server_command(port, model_name, embedding_mode, **kwargs)
|
command = self._build_server_command(port, model_name, embedding_mode, **kwargs)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._launch_server_process(command, port)
|
self._launch_server_process(
|
||||||
return self._wait_for_server_ready(port)
|
command,
|
||||||
|
port,
|
||||||
|
provider_options=provider_options,
|
||||||
|
)
|
||||||
|
started, ready_port = self._wait_for_server_ready(port)
|
||||||
|
if started:
|
||||||
|
self._server_config = config_signature or {
|
||||||
|
"model_name": model_name,
|
||||||
|
"passages_file": kwargs.get("passages_file", ""),
|
||||||
|
"embedding_mode": embedding_mode,
|
||||||
|
"provider_options": provider_options or {},
|
||||||
|
}
|
||||||
|
return started, ready_port
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to start embedding server: {e}")
|
logger.error(f"Failed to start embedding server: {e}")
|
||||||
return False, port
|
return False, port
|
||||||
@@ -173,7 +237,12 @@ class EmbeddingServerManager:
|
|||||||
|
|
||||||
return command
|
return command
|
||||||
|
|
||||||
def _launch_server_process(self, command: list, port: int) -> None:
|
def _launch_server_process(
|
||||||
|
self,
|
||||||
|
command: list,
|
||||||
|
port: int,
|
||||||
|
provider_options: Optional[dict] = None,
|
||||||
|
) -> None:
|
||||||
"""Launch the server process."""
|
"""Launch the server process."""
|
||||||
project_root = Path(__file__).parent.parent.parent.parent.parent
|
project_root = Path(__file__).parent.parent.parent.parent.parent
|
||||||
logger.info(f"Command: {' '.join(command)}")
|
logger.info(f"Command: {' '.join(command)}")
|
||||||
@@ -193,14 +262,20 @@ class EmbeddingServerManager:
|
|||||||
|
|
||||||
# Start embedding server subprocess
|
# Start embedding server subprocess
|
||||||
logger.info(f"Starting server process with command: {' '.join(command)}")
|
logger.info(f"Starting server process with command: {' '.join(command)}")
|
||||||
|
env = os.environ.copy()
|
||||||
|
encoded_options = encode_provider_options(provider_options)
|
||||||
|
if encoded_options:
|
||||||
|
env["LEANN_EMBEDDING_OPTIONS"] = encoded_options
|
||||||
|
|
||||||
self.server_process = subprocess.Popen(
|
self.server_process = subprocess.Popen(
|
||||||
command,
|
command,
|
||||||
cwd=project_root,
|
cwd=project_root,
|
||||||
stdout=stdout_target,
|
stdout=stdout_target,
|
||||||
stderr=stderr_target,
|
stderr=stderr_target,
|
||||||
|
env=env,
|
||||||
)
|
)
|
||||||
self.server_port = port
|
self.server_port = port
|
||||||
# Record config for in-process reuse
|
# Record config for in-process reuse (best effort; refined later when ready)
|
||||||
try:
|
try:
|
||||||
self._server_config = {
|
self._server_config = {
|
||||||
"model_name": command[command.index("--model-name") + 1]
|
"model_name": command[command.index("--model-name") + 1]
|
||||||
@@ -212,12 +287,14 @@ class EmbeddingServerManager:
|
|||||||
"embedding_mode": command[command.index("--embedding-mode") + 1]
|
"embedding_mode": command[command.index("--embedding-mode") + 1]
|
||||||
if "--embedding-mode" in command
|
if "--embedding-mode" in command
|
||||||
else "sentence-transformers",
|
else "sentence-transformers",
|
||||||
|
"provider_options": provider_options or {},
|
||||||
}
|
}
|
||||||
except Exception:
|
except Exception:
|
||||||
self._server_config = {
|
self._server_config = {
|
||||||
"model_name": "",
|
"model_name": "",
|
||||||
"passages_file": "",
|
"passages_file": "",
|
||||||
"embedding_mode": "sentence-transformers",
|
"embedding_mode": "sentence-transformers",
|
||||||
|
"provider_options": provider_options or {},
|
||||||
}
|
}
|
||||||
logger.info(f"Server process started with PID: {self.server_process.pid}")
|
logger.info(f"Server process started with PID: {self.server_process.pid}")
|
||||||
|
|
||||||
@@ -322,16 +399,27 @@ class EmbeddingServerManager:
|
|||||||
# Removed: cross-process adoption no longer supported
|
# Removed: cross-process adoption no longer supported
|
||||||
return
|
return
|
||||||
|
|
||||||
def _launch_server_process_colab(self, command: list, port: int) -> None:
|
def _launch_server_process_colab(
|
||||||
|
self,
|
||||||
|
command: list,
|
||||||
|
port: int,
|
||||||
|
provider_options: Optional[dict] = None,
|
||||||
|
) -> None:
|
||||||
"""Launch the server process with Colab-specific settings."""
|
"""Launch the server process with Colab-specific settings."""
|
||||||
logger.info(f"Colab Command: {' '.join(command)}")
|
logger.info(f"Colab Command: {' '.join(command)}")
|
||||||
|
|
||||||
# In Colab, we need to be more careful about process management
|
# In Colab, we need to be more careful about process management
|
||||||
|
env = os.environ.copy()
|
||||||
|
encoded_options = encode_provider_options(provider_options)
|
||||||
|
if encoded_options:
|
||||||
|
env["LEANN_EMBEDDING_OPTIONS"] = encoded_options
|
||||||
|
|
||||||
self.server_process = subprocess.Popen(
|
self.server_process = subprocess.Popen(
|
||||||
command,
|
command,
|
||||||
stdout=subprocess.PIPE,
|
stdout=subprocess.PIPE,
|
||||||
stderr=subprocess.PIPE,
|
stderr=subprocess.PIPE,
|
||||||
text=True,
|
text=True,
|
||||||
|
env=env,
|
||||||
)
|
)
|
||||||
self.server_port = port
|
self.server_port = port
|
||||||
logger.info(f"Colab server process started with PID: {self.server_process.pid}")
|
logger.info(f"Colab server process started with PID: {self.server_process.pid}")
|
||||||
@@ -345,6 +433,7 @@ class EmbeddingServerManager:
|
|||||||
"model_name": "",
|
"model_name": "",
|
||||||
"passages_file": "",
|
"passages_file": "",
|
||||||
"embedding_mode": "sentence-transformers",
|
"embedding_mode": "sentence-transformers",
|
||||||
|
"provider_options": provider_options or {},
|
||||||
}
|
}
|
||||||
|
|
||||||
def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]:
|
def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]:
|
||||||
|
|||||||
@@ -41,6 +41,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
|||||||
print("WARNING: embedding_model not found in meta.json. Recompute will fail.")
|
print("WARNING: embedding_model not found in meta.json. Recompute will fail.")
|
||||||
|
|
||||||
self.embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
|
self.embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
|
||||||
|
self.embedding_options = self.meta.get("embedding_options", {})
|
||||||
|
|
||||||
self.embedding_server_manager = EmbeddingServerManager(
|
self.embedding_server_manager = EmbeddingServerManager(
|
||||||
backend_module_name=backend_module_name,
|
backend_module_name=backend_module_name,
|
||||||
@@ -77,6 +78,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
|||||||
passages_file=passages_source_file,
|
passages_file=passages_source_file,
|
||||||
distance_metric=distance_metric,
|
distance_metric=distance_metric,
|
||||||
enable_warmup=kwargs.get("enable_warmup", False),
|
enable_warmup=kwargs.get("enable_warmup", False),
|
||||||
|
provider_options=self.embedding_options,
|
||||||
)
|
)
|
||||||
if not server_started:
|
if not server_started:
|
||||||
raise RuntimeError(f"Failed to start embedding server on port {actual_port}")
|
raise RuntimeError(f"Failed to start embedding server on port {actual_port}")
|
||||||
@@ -125,7 +127,12 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
|||||||
from .embedding_compute import compute_embeddings
|
from .embedding_compute import compute_embeddings
|
||||||
|
|
||||||
embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
|
embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
|
||||||
return compute_embeddings([query], self.embedding_model, embedding_mode)
|
return compute_embeddings(
|
||||||
|
[query],
|
||||||
|
self.embedding_model,
|
||||||
|
embedding_mode,
|
||||||
|
provider_options=self.embedding_options,
|
||||||
|
)
|
||||||
|
|
||||||
def _compute_embedding_via_server(self, chunks: list, zmq_port: int) -> np.ndarray:
|
def _compute_embedding_via_server(self, chunks: list, zmq_port: int) -> np.ndarray:
|
||||||
"""Compute embeddings using the ZMQ embedding server."""
|
"""Compute embeddings using the ZMQ embedding server."""
|
||||||
|
|||||||
74
packages/leann-core/src/leann/settings.py
Normal file
74
packages/leann-core/src/leann/settings.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
"""Runtime configuration helpers for LEANN."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
# Default fallbacks to preserve current behaviour while keeping them in one place.
|
||||||
|
_DEFAULT_OLLAMA_HOST = "http://localhost:11434"
|
||||||
|
_DEFAULT_OPENAI_BASE_URL = "https://api.openai.com/v1"
|
||||||
|
|
||||||
|
|
||||||
|
def _clean_url(value: str) -> str:
|
||||||
|
"""Normalize URL strings by stripping trailing slashes."""
|
||||||
|
|
||||||
|
return value.rstrip("/") if value else value
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_ollama_host(explicit: str | None = None) -> str:
|
||||||
|
"""Resolve the Ollama-compatible endpoint to use."""
|
||||||
|
|
||||||
|
candidates = (
|
||||||
|
explicit,
|
||||||
|
os.getenv("LEANN_LOCAL_LLM_HOST"),
|
||||||
|
os.getenv("LEANN_OLLAMA_HOST"),
|
||||||
|
os.getenv("OLLAMA_HOST"),
|
||||||
|
os.getenv("LOCAL_LLM_ENDPOINT"),
|
||||||
|
)
|
||||||
|
|
||||||
|
for candidate in candidates:
|
||||||
|
if candidate:
|
||||||
|
return _clean_url(candidate)
|
||||||
|
|
||||||
|
return _clean_url(_DEFAULT_OLLAMA_HOST)
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_openai_base_url(explicit: str | None = None) -> str:
|
||||||
|
"""Resolve the base URL for OpenAI-compatible services."""
|
||||||
|
|
||||||
|
candidates = (
|
||||||
|
explicit,
|
||||||
|
os.getenv("LEANN_OPENAI_BASE_URL"),
|
||||||
|
os.getenv("OPENAI_BASE_URL"),
|
||||||
|
os.getenv("LOCAL_OPENAI_BASE_URL"),
|
||||||
|
)
|
||||||
|
|
||||||
|
for candidate in candidates:
|
||||||
|
if candidate:
|
||||||
|
return _clean_url(candidate)
|
||||||
|
|
||||||
|
return _clean_url(_DEFAULT_OPENAI_BASE_URL)
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_openai_api_key(explicit: str | None = None) -> str | None:
|
||||||
|
"""Resolve the API key for OpenAI-compatible services."""
|
||||||
|
|
||||||
|
if explicit:
|
||||||
|
return explicit
|
||||||
|
|
||||||
|
return os.getenv("OPENAI_API_KEY")
|
||||||
|
|
||||||
|
|
||||||
|
def encode_provider_options(options: dict[str, Any] | None) -> str | None:
|
||||||
|
"""Serialize provider options for child processes."""
|
||||||
|
|
||||||
|
if not options:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
return json.dumps(options)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
# Fall back to empty payload if serialization fails
|
||||||
|
return None
|
||||||
Reference in New Issue
Block a user