Add Anthropic LLM support (#185)
* Add Anthropic LLM support
* Update skypilot link
* Handle anthropic base_url
* Address ruff format finding

Signed-off-by: droctothorpe <mythicalsunlight@gmail.com>
`README.md`:

```diff
@@ -201,7 +201,7 @@ LEANN supports RAG on various data sources including documents (`.pdf`, `.txt`,

 #### LLM Backend

-LEANN supports many LLM providers for text generation (HuggingFace, Ollama, and Any OpenAI compatible API).
+LEANN supports many LLM providers for text generation (HuggingFace, Ollama, Anthropic, and Any OpenAI compatible API).


 <details>
```
```diff
@@ -269,6 +269,7 @@ Below is a list of base URLs for common providers to get you started.
 | **SiliconFlow** | `https://api.siliconflow.cn/v1` |
 | **Zhipu (BigModel)** | `https://open.bigmodel.cn/api/paas/v4/` |
 | **Mistral AI** | `https://api.mistral.ai/v1` |
+| **Anthropic** | `https://api.anthropic.com/v1` |

```
```diff
@@ -328,7 +329,7 @@ All RAG examples share these common parameters. **Interactive mode** is availabl
 --embedding-mode MODE    # sentence-transformers, openai, mlx, or ollama

 # LLM Parameters (Text generation models)
---llm TYPE               # LLM backend: openai, ollama, or hf (default: openai)
+--llm TYPE               # LLM backend: openai, ollama, hf, or anthropic (default: openai)
 --llm-model MODEL        # Model name (default: gpt-4o) e.g., gpt-4o-mini, llama3.2:1b, Qwen/Qwen2.5-1.5B-Instruct
 --thinking-budget LEVEL  # Thinking budget for reasoning models: low/medium/high (supported by o3, o3-mini, GPT-Oss:20b, and other reasoning models)

```
````diff
@@ -1057,10 +1058,10 @@ Options:
 leann ask INDEX_NAME [OPTIONS]

 Options:
---llm {ollama,openai,hf}            LLM provider (default: ollama)
+--llm {ollama,openai,hf,anthropic}  LLM provider (default: ollama)
 --model MODEL                       Model name (default: qwen3:8b)
 --interactive                       Interactive chat mode
 --top-k N                           Retrieval count (default: 20)
 ```

 **List Command:**
````
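With `anthropic` wired into the `leann ask` options above, a basic invocation looks like the sketch below. The index name and Claude model are illustrative placeholders, not defaults introduced by this commit:

```bash
# Sketch: "my-docs" and the model name are placeholders.
export ANTHROPIC_API_KEY="sk-ant-your-key"
leann ask my-docs --llm anthropic --model claude-3-5-sonnet-20241022
```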
Docs (file path not shown in this view):

````diff
@@ -454,7 +454,7 @@ leann search my-index "your query" \

 ### 2) Run remote builds with SkyPilot (cloud GPU)

-Offload embedding generation and index building to a GPU VM using [SkyPilot](https://skypilot.readthedocs.io/en/latest/). A template is provided at `sky/leann-build.yaml`.
+Offload embedding generation and index building to a GPU VM using [SkyPilot](https://docs.skypilot.co/en/latest/docs/index.html). A template is provided at `sky/leann-build.yaml`.

 ```bash
 # One-time: install and configure SkyPilot
````
LLM interface module (path not shown in this view):

```diff
@@ -12,7 +12,13 @@ from typing import Any, Optional

 import torch

-from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
+from .settings import (
+    resolve_anthropic_api_key,
+    resolve_anthropic_base_url,
+    resolve_ollama_host,
+    resolve_openai_api_key,
+    resolve_openai_base_url,
+)

 # Configure logging
 logging.basicConfig(level=logging.INFO)
```
```diff
@@ -845,6 +851,81 @@ class OpenAIChat(LLMInterface):
             return f"Error: Could not get a response from OpenAI. Details: {e}"


+class AnthropicChat(LLMInterface):
+    """LLM interface for Anthropic Claude models."""
+
+    def __init__(
+        self,
+        model: str = "claude-haiku-4-5",
+        api_key: Optional[str] = None,
+        base_url: Optional[str] = None,
+    ):
+        self.model = model
+        self.base_url = resolve_anthropic_base_url(base_url)
+        self.api_key = resolve_anthropic_api_key(api_key)
+
+        if not self.api_key:
+            raise ValueError(
+                "Anthropic API key is required. Set ANTHROPIC_API_KEY environment variable or pass api_key parameter."
+            )
+
+        logger.info(
+            "Initializing Anthropic Chat with model='%s' and base_url='%s'",
+            model,
+            self.base_url,
+        )
+
+        try:
+            import anthropic
+
+            # Allow custom Anthropic-compatible endpoints via base_url
+            self.client = anthropic.Anthropic(
+                api_key=self.api_key,
+                base_url=self.base_url,
+            )
+        except ImportError:
+            raise ImportError(
+                "The 'anthropic' library is required for Anthropic models. Please install it with 'pip install anthropic'."
+            )
+
+    def ask(self, prompt: str, **kwargs) -> str:
+        logger.info(f"Sending request to Anthropic with model {self.model}")
+
+        try:
+            # Anthropic API parameters
+            params = {
+                "model": self.model,
+                "max_tokens": kwargs.get("max_tokens", 1000),
+                "messages": [{"role": "user", "content": prompt}],
+            }
+
+            # Add optional parameters
+            if "temperature" in kwargs:
+                params["temperature"] = kwargs["temperature"]
+            if "top_p" in kwargs:
+                params["top_p"] = kwargs["top_p"]
+
+            response = self.client.messages.create(**params)
+
+            # Extract text from response
+            response_text = response.content[0].text
+
+            # Log token usage
+            print(
+                f"Total tokens = {response.usage.input_tokens + response.usage.output_tokens}, "
+                f"input tokens = {response.usage.input_tokens}, "
+                f"output tokens = {response.usage.output_tokens}"
+            )
+
+            if response.stop_reason == "max_tokens":
+                print("The query is exceeding the maximum allowed number of tokens")
+
+            return response_text.strip()
+        except Exception as e:
+            logger.error(f"Error communicating with Anthropic: {e}")
+            return f"Error: Could not get a response from Anthropic. Details: {e}"
+
+
 class SimulatedChat(LLMInterface):
     """A simple simulated chat for testing and development."""
```
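For reference, a minimal usage sketch of the new class. The import path is hypothetical, since this diff does not show the module's filename; the `ask` keyword arguments mirror the ones the method actually reads above:

```python
# Hypothetical import path; adjust to wherever LEANN defines AnthropicChat.
from leann.llm_backends import AnthropicChat

# Reads ANTHROPIC_API_KEY from the environment when api_key is not passed.
chat = AnthropicChat(model="claude-3-5-sonnet-20241022")
answer = chat.ask(
    "Summarize the retrieved passages.",
    max_tokens=500,   # forwarded into messages.create (defaults to 1000)
    temperature=0.2,  # optional; only sent if provided
)
print(answer)
```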
```diff
@@ -897,6 +978,12 @@ def get_llm(llm_config: Optional[dict[str, Any]] = None) -> LLMInterface:
         )
     elif llm_type == "gemini":
         return GeminiChat(model=model or "gemini-2.5-flash", api_key=llm_config.get("api_key"))
+    elif llm_type == "anthropic":
+        return AnthropicChat(
+            model=model or "claude-3-5-sonnet-20241022",
+            api_key=llm_config.get("api_key"),
+            base_url=llm_config.get("base_url"),
+        )
     elif llm_type == "simulated":
         return SimulatedChat()
     else:
```
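The dispatch above suggests a plain dict drives backend selection; here is a sketch, noting that the key holding `llm_type` is not visible in this hunk and is assumed to be `"type"`:

```python
# Sketch: the "type" key is an assumption; "model", "api_key", and "base_url"
# mirror the llm_config.get(...) calls in the hunk above.
llm_config = {
    "type": "anthropic",
    "model": "claude-3-5-sonnet-20241022",  # omit to use the fallback default
    "base_url": None,  # None falls through to resolve_anthropic_base_url()
}
llm = get_llm(llm_config)  # assumes get_llm is in scope (same module)
print(llm.ask("Hello from LEANN"))
```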
CLI module (path not shown in this view):

```diff
@@ -11,7 +11,12 @@ from tqdm import tqdm
 from .api import LeannBuilder, LeannChat, LeannSearcher
 from .interactive_utils import create_cli_session
 from .registry import register_project_directory
-from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
+from .settings import (
+    resolve_anthropic_base_url,
+    resolve_ollama_host,
+    resolve_openai_api_key,
+    resolve_openai_base_url,
+)


 def extract_pdf_text_with_pymupdf(file_path: str) -> str:
```
```diff
@@ -291,7 +296,7 @@ Examples:
         "--llm",
         type=str,
         default="ollama",
-        choices=["simulated", "ollama", "hf", "openai"],
+        choices=["simulated", "ollama", "hf", "openai", "anthropic"],
         help="LLM provider (default: ollama)",
     )
     ask_parser.add_argument(
```
```diff
@@ -341,7 +346,7 @@ Examples:
         "--api-key",
         type=str,
         default=None,
-        help="API key for OpenAI-compatible APIs (defaults to OPENAI_API_KEY)",
+        help="API key for cloud LLM providers (OpenAI, Anthropic)",
     )

     # List command
```
```diff
@@ -1616,6 +1621,12 @@ Examples:
             resolved_api_key = resolve_openai_api_key(args.api_key)
             if resolved_api_key:
                 llm_config["api_key"] = resolved_api_key
+        elif args.llm == "anthropic":
+            # For Anthropic, pass base_url and API key if provided
+            if args.api_base:
+                llm_config["base_url"] = resolve_anthropic_base_url(args.api_base)
+            if args.api_key:
+                llm_config["api_key"] = args.api_key

         chat = LeannChat(index_path=index_path, llm_config=llm_config)

```
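This branch means `--api-base` can target an Anthropic-compatible gateway; for example, with a made-up proxy URL:

```bash
# Sketch: the gateway URL is hypothetical.
leann ask my-docs --llm anthropic \
  --api-base https://claude-gateway.internal.example \
  --api-key "$ANTHROPIC_API_KEY"
```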
Settings module (path not shown in this view):

```diff
@@ -9,6 +9,7 @@ from typing import Any
 # Default fallbacks to preserve current behaviour while keeping them in one place.
 _DEFAULT_OLLAMA_HOST = "http://localhost:11434"
 _DEFAULT_OPENAI_BASE_URL = "https://api.openai.com/v1"
+_DEFAULT_ANTHROPIC_BASE_URL = "https://api.anthropic.com"


 def _clean_url(value: str) -> str:
```
```diff
@@ -52,6 +53,23 @@ def resolve_openai_base_url(explicit: str | None = None) -> str:
     return _clean_url(_DEFAULT_OPENAI_BASE_URL)


+def resolve_anthropic_base_url(explicit: str | None = None) -> str:
+    """Resolve the base URL for Anthropic-compatible services."""
+
+    candidates = (
+        explicit,
+        os.getenv("LEANN_ANTHROPIC_BASE_URL"),
+        os.getenv("ANTHROPIC_BASE_URL"),
+        os.getenv("LOCAL_ANTHROPIC_BASE_URL"),
+    )
+
+    for candidate in candidates:
+        if candidate:
+            return _clean_url(candidate)
+
+    return _clean_url(_DEFAULT_ANTHROPIC_BASE_URL)
+
+
 def resolve_openai_api_key(explicit: str | None = None) -> str | None:
     """Resolve the API key for OpenAI-compatible services."""
```
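The resolver tries the explicit argument, then `LEANN_ANTHROPIC_BASE_URL`, `ANTHROPIC_BASE_URL`, and `LOCAL_ANTHROPIC_BASE_URL`, before falling back to the default constant; a quick sketch of that precedence:

```python
import os

# Assumes resolve_anthropic_base_url has been imported from the settings module.
os.environ["ANTHROPIC_BASE_URL"] = "https://gateway.example/anthropic"

print(resolve_anthropic_base_url())
# -> "https://gateway.example/anthropic" (env var beats the built-in default)

print(resolve_anthropic_base_url("https://override.example"))
# -> "https://override.example" (explicit argument wins over everything)

del os.environ["ANTHROPIC_BASE_URL"]
print(resolve_anthropic_base_url())
# -> "https://api.anthropic.com" (the default constant)
```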
```diff
@@ -61,6 +79,15 @@ def resolve_openai_api_key(explicit: str | None = None) -> str | None:
     return os.getenv("OPENAI_API_KEY")


+def resolve_anthropic_api_key(explicit: str | None = None) -> str | None:
+    """Resolve the API key for Anthropic services."""
+
+    if explicit:
+        return explicit
+
+    return os.getenv("ANTHROPIC_API_KEY")
+
+
 def encode_provider_options(options: dict[str, Any] | None) -> str | None:
     """Serialize provider options for child processes."""
```