Compare commits


1 Commit

Author: yichuan-w
SHA1: aaadb00e44
Message: Add ColQwen2.5 model support and improve model selection
- Add ColQwen2.5 and ColQwen2_5_Processor imports
- Implement smart model type detection for colqwen2, colqwen2.5, and colpali
- Add task name aliases for easier benchmark invocation
- Add safe model name handling for file paths and index naming
- Support custom model paths including LoRA adapters
- Improve model choice validation and error handling

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Date: 2025-12-05 11:35:30 +00:00
9 changed files with 28 additions and 274 deletions
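
The commit message above describes model-type detection and safe-name handling only in prose. Below is a minimal sketch of what such helpers could look like; the function names and exact matching rules are assumptions for illustration, not the code in this diff:

```python
# Hedged sketch of the behaviour the commit message describes. The helper names
# (detect_model_type, safe_model_name) are illustrative, not the diff's actual code.
import re

def detect_model_type(model_choice: str) -> str:
    """Map a model name or local path (e.g. a LoRA adapter directory) to a family."""
    name = model_choice.lower()
    if "colqwen2.5" in name or "colqwen2_5" in name:
        return "colqwen2.5"
    if "colqwen2" in name:
        return "colqwen2"
    return "colpali"  # fall back to ColPali for anything else

def safe_model_name(model_choice: str) -> str:
    """Turn a model id or filesystem path into a string safe for index/file names."""
    return re.sub(r"[^A-Za-z0-9._-]+", "_", model_choice.strip("/"))
```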

View File

@@ -201,7 +201,7 @@ LEANN supports RAG on various data sources including documents (`.pdf`, `.txt`,
#### LLM Backend
LEANN supports many LLM providers for text generation (HuggingFace, Ollama, Anthropic, and Any OpenAI compatible API).
LEANN supports many LLM providers for text generation (HuggingFace, Ollama, and Any OpenAI compatible API).
<details>
@@ -269,7 +269,6 @@ Below is a list of base URLs for common providers to get you started.
| **SiliconFlow** | `https://api.siliconflow.cn/v1` |
| **Zhipu (BigModel)** | `https://open.bigmodel.cn/api/paas/v4/` |
| **Mistral AI** | `https://api.mistral.ai/v1` |
| **Anthropic** | `https://api.anthropic.com/v1` |
@@ -329,7 +328,7 @@ All RAG examples share these common parameters. **Interactive mode** is availabl
--embedding-mode MODE # sentence-transformers, openai, mlx, or ollama
# LLM Parameters (Text generation models)
--llm TYPE # LLM backend: openai, ollama, hf, or anthropic (default: openai)
--llm TYPE # LLM backend: openai, ollama, or hf (default: openai)
--llm-model MODEL # Model name (default: gpt-4o) e.g., gpt-4o-mini, llama3.2:1b, Qwen/Qwen2.5-1.5B-Instruct
--thinking-budget LEVEL # Thinking budget for reasoning models: low/medium/high (supported by o3, o3-mini, GPT-Oss:20b, and other reasoning models)
@@ -1058,10 +1057,10 @@ Options:
leann ask INDEX_NAME [OPTIONS]
Options:
--llm {ollama,openai,hf,anthropic} LLM provider (default: ollama)
--model MODEL Model name (default: qwen3:8b)
--interactive Interactive chat mode
--top-k N Retrieval count (default: 20)
--llm {ollama,openai,hf} LLM provider (default: ollama)
--model MODEL Model name (default: qwen3:8b)
--interactive Interactive chat mode
--top-k N Retrieval count (default: 20)
```
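
For reference, the same ask flow can be driven from Python through the `LeannChat` class that the CLI wraps (see the `LeannChat(index_path=..., llm_config=...)` call later in this diff). The import path, the `llm_config` keys, and the `ask()` signature below are assumptions inferred from the surrounding diff, not a documented API:

```python
# Hedged sketch: programmatic equivalent of `leann ask`. The "type"/"model" keys
# and the ask() signature are assumptions based on the surrounding diff.
from leann.api import LeannChat  # the CLI imports LeannChat from .api

llm_config = {"type": "ollama", "model": "qwen3:8b"}
chat = LeannChat(index_path="./indexes/my-docs.leann", llm_config=llm_config)
answer = chat.ask("What does the build command do?", top_k=20)
print(answer)
```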
**List Command:**

View File

@@ -1,7 +1,5 @@
import concurrent.futures
import glob
import json
import logging
import os
import re
import sys
@@ -13,8 +11,6 @@ import numpy as np
from PIL import Image
from tqdm import tqdm
logger = logging.getLogger(__name__)
def _ensure_repo_paths_importable(current_file: str) -> None:
"""Make local leann packages importable without installing (mirrors multi-vector-leann.py)."""
@@ -100,63 +96,12 @@ def _natural_sort_key(name: str) -> int:
return int(m.group()) if m else 0
def _load_images_from_dir(
pages_dir: str, recursive: bool = False
) -> tuple[list[str], list[Image.Image]]:
"""
Load images from a directory.
Args:
pages_dir: Directory path containing images
recursive: If True, recursively search subdirectories (default: False)
Returns:
Tuple of (filepaths, images)
"""
# Supported image extensions
extensions = ("*.png", "*.jpg", "*.jpeg", "*.PNG", "*.JPG", "*.JPEG", "*.webp", "*.WEBP")
if recursive:
# Recursive search
filepaths = []
for ext in extensions:
pattern = os.path.join(pages_dir, "**", ext)
filepaths.extend(glob.glob(pattern, recursive=True))
else:
# Non-recursive search (only top-level directory)
filepaths = []
for ext in extensions:
pattern = os.path.join(pages_dir, ext)
filepaths.extend(glob.glob(pattern))
# Sort files naturally
filepaths = sorted(filepaths, key=lambda x: _natural_sort_key(os.path.basename(x)))
# Load images with error handling
images = []
valid_filepaths = []
failed_count = 0
for filepath in filepaths:
try:
img = Image.open(filepath)
# Convert to RGB if necessary (handles RGBA, P, etc.)
if img.mode != "RGB":
img = img.convert("RGB")
images.append(img)
valid_filepaths.append(filepath)
except Exception as e:
failed_count += 1
print(f"Warning: Failed to load image {filepath}: {e}")
continue
if failed_count > 0:
print(
f"Warning: Failed to load {failed_count} image(s) out of {len(filepaths)} total files"
)
return valid_filepaths, images
def _load_images_from_dir(pages_dir: str) -> tuple[list[str], list[Image.Image]]:
filenames = [n for n in os.listdir(pages_dir) if n.lower().endswith((".png", ".jpg", ".jpeg"))]
filenames = sorted(filenames, key=_natural_sort_key)
filepaths = [os.path.join(pages_dir, n) for n in filenames]
images = [Image.open(p) for p in filepaths]
return filepaths, images
def _maybe_convert_pdf_to_images(pdf_path: Optional[str], pages_dir: str, dpi: int = 200) -> None:
@@ -206,8 +151,6 @@ def _select_device_and_dtype():
def _load_colvision(model_choice: str):
import os
import torch
from colpali_engine.models import (
ColPali,
@@ -219,16 +162,6 @@ def _load_colvision(model_choice: str):
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
from transformers.utils.import_utils import is_flash_attn_2_available
# Force HuggingFace Hub to use HF endpoint, avoid Google Drive
# Set environment variables to ensure models are downloaded from HuggingFace
os.environ.setdefault("HF_ENDPOINT", "https://huggingface.co")
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
# Log model loading info
logger.info(f"Loading ColVision model: {model_choice}")
logger.info(f"HF_ENDPOINT: {os.environ.get('HF_ENDPOINT', 'not set')}")
logger.info("Models will be downloaded from HuggingFace Hub, not Google Drive")
device_str, device, dtype = _select_device_and_dtype()
# Determine model name and type
@@ -269,36 +202,29 @@ def _load_colvision(model_choice: str):
"flash_attention_2" if (device_str == "cuda" and is_flash_attn_2_available()) else "eager"
)
# Load model from HuggingFace Hub (not Google Drive)
# Use local_files_only=False to ensure download from HF if not cached
if model_type == "colqwen2.5":
model = ColQwen2_5.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map=device,
attn_implementation=attn_implementation,
local_files_only=False, # Ensure download from HuggingFace Hub
).eval()
processor = ColQwen2_5_Processor.from_pretrained(model_name, local_files_only=False)
processor = ColQwen2_5_Processor.from_pretrained(model_name)
elif model_type == "colqwen2":
model = ColQwen2.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map=device,
attn_implementation=attn_implementation,
local_files_only=False, # Ensure download from HuggingFace Hub
).eval()
processor = ColQwen2Processor.from_pretrained(model_name, local_files_only=False)
processor = ColQwen2Processor.from_pretrained(model_name)
else: # colpali
model = ColPali.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map=device,
local_files_only=False, # Ensure download from HuggingFace Hub
).eval()
processor = cast(
ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name, local_files_only=False)
)
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
return model_name, model, processor, device_str, device, dtype
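
For context, a typical late-interaction retrieval pass with the objects `_load_colvision` returns looks roughly like the sketch below, using the standard `colpali_engine` processor methods; the query text, batching, and score handling are simplified assumptions:

```python
# Hedged sketch of late-interaction (MaxSim) retrieval with the loaded model/processor.
import torch

model_name, model, processor, device_str, device, dtype = _load_colvision("colqwen2.5")

with torch.no_grad():
    # images: list[PIL.Image.Image] from _load_images_from_dir(); a small list is assumed
    image_batch = processor.process_images(images).to(device)
    image_embs = model(**image_batch)            # (pages, image_tokens, dim)
    query_batch = processor.process_queries(["How does DeepSeek-V2 use MLA?"]).to(device)
    query_embs = model(**query_batch)            # (1, query_tokens, dim)

# MaxSim scores between every query and every page
scores = processor.score_multi_vector(query_embs, image_embs)
best_page = int(scores.argmax(dim=1)[0])
```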

View File

@@ -62,7 +62,7 @@ DATASET_NAME: str = "weaviate/arXiv-AI-papers-multi-vector"
# DATASET_NAMES: Optional[list[str | tuple[str, Optional[str]]]] = None
DATASET_NAMES = [
"weaviate/arXiv-AI-papers-multi-vector",
# ("lmms-lab/DocVQA", "DocVQA"), # Specify config name for datasets with multiple configs
("lmms-lab/DocVQA", "DocVQA"), # Specify config name for datasets with multiple configs
]
# Load multiple splits to get more data (e.g., ["train", "test", "validation"])
# Set to None to try loading all available splits automatically
@@ -75,11 +75,6 @@ MAX_DOCS: Optional[int] = None # limit number of pages to index; None = all
# Local pages (used when USE_HF_DATASET == False)
PDF: Optional[str] = None # e.g., "./pdfs/2004.12832v2.pdf"
PAGES_DIR: str = "./pages"
# Custom folder path (takes precedence over USE_HF_DATASET and PAGES_DIR)
# If set, images will be loaded directly from this folder
CUSTOM_FOLDER_PATH: Optional[str] = None # e.g., "/home/ubuntu/dr-tulu/agent/screenshots"
# Whether to recursively search subdirectories when loading from custom folder
CUSTOM_FOLDER_RECURSIVE: bool = False # Set to True to search subdirectories
# Index + retrieval settings
# Use a different index path for larger dataset to avoid overwriting existing index
@@ -88,7 +83,7 @@ INDEX_PATH: str = "./indexes/colvision_large.leann"
# These are now command-line arguments (see CLI overrides section)
TOPK: int = 3
FIRST_STAGE_K: int = 500
REBUILD_INDEX: bool = False # Set to True to force rebuild even if index exists
REBUILD_INDEX: bool = True
# Artifacts
SAVE_TOP_IMAGE: Optional[str] = "./figures/retrieved_page.png"
@@ -133,33 +128,12 @@ parser.add_argument(
default=TOPK,
help=f"Number of top results to retrieve. Default: {TOPK}",
)
parser.add_argument(
"--custom-folder",
type=str,
default=None,
help="Path to a custom folder containing images to search. Takes precedence over dataset loading. Default: None",
)
parser.add_argument(
"--recursive",
action="store_true",
default=False,
help="Recursively search subdirectories when loading images from custom folder. Default: False",
)
parser.add_argument(
"--rebuild-index",
action="store_true",
default=False,
help="Force rebuild the index even if it already exists. Default: False (reuse existing index if available)",
)
cli_args, _unknown = parser.parse_known_args()
SEARCH_METHOD: str = cli_args.search_method
QUERY = cli_args.query # Override QUERY with CLI argument if provided
USE_FAST_PLAID: bool = cli_args.use_fast_plaid
FAST_PLAID_INDEX_PATH: str = cli_args.fast_plaid_index_path
TOPK: int = cli_args.topk # Override TOPK with CLI argument if provided
CUSTOM_FOLDER_PATH = cli_args.custom_folder if cli_args.custom_folder else CUSTOM_FOLDER_PATH # Override with CLI argument if provided
CUSTOM_FOLDER_RECURSIVE = cli_args.recursive if cli_args.recursive else CUSTOM_FOLDER_RECURSIVE # Override with CLI argument if provided
REBUILD_INDEX = cli_args.rebuild_index # Override REBUILD_INDEX with CLI argument
# %%
@@ -206,23 +180,7 @@ else:
# Step 2: Load data only if we need to build the index
if need_to_build_index:
print("Loading dataset...")
# Check for custom folder path first (takes precedence)
if CUSTOM_FOLDER_PATH:
if not os.path.isdir(CUSTOM_FOLDER_PATH):
raise RuntimeError(f"Custom folder path does not exist: {CUSTOM_FOLDER_PATH}")
print(f"Loading images from custom folder: {CUSTOM_FOLDER_PATH}")
if CUSTOM_FOLDER_RECURSIVE:
print(" (recursive mode: searching subdirectories)")
filepaths, images = _load_images_from_dir(CUSTOM_FOLDER_PATH, recursive=CUSTOM_FOLDER_RECURSIVE)
print(f" Found {len(filepaths)} image files")
if not images:
raise RuntimeError(
f"No images found in {CUSTOM_FOLDER_PATH}. Ensure the folder contains image files (.png, .jpg, .jpeg, .webp)."
)
print(f" Successfully loaded {len(images)} images")
# Use filenames as identifiers instead of full paths for cleaner metadata
filepaths = [os.path.basename(fp) for fp in filepaths]
elif USE_HF_DATASET:
if USE_HF_DATASET:
from datasets import load_dataset, concatenate_datasets, DatasetDict
# Determine which datasets to load
@@ -663,6 +621,7 @@ else:
except Exception:
print(f"Saved retrieved page (rank {rank}) to: {out_path}")
## TODO: strange result: the second page of DeepSeek-V2 is retrieved rather than the first page
# %%
# Step 6: Similarity maps for top-K results

View File

@@ -454,7 +454,7 @@ leann search my-index "your query" \
### 2) Run remote builds with SkyPilot (cloud GPU)
Offload embedding generation and index building to a GPU VM using [SkyPilot](https://docs.skypilot.co/en/latest/docs/index.html). A template is provided at `sky/leann-build.yaml`.
Offload embedding generation and index building to a GPU VM using [SkyPilot](https://skypilot.readthedocs.io/en/latest/). A template is provided at `sky/leann-build.yaml`.
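
If you prefer to stay in Python, SkyPilot also exposes a Python SDK. A minimal sketch, assuming the `sky/leann-build.yaml` template mentioned above and the synchronous `sky.launch` API (details vary across SkyPilot versions), would be:

```python
# Hedged sketch: launch the LEANN build template on a cloud GPU via SkyPilot's SDK.
# The cluster name is arbitrary; the bash workflow below is the documented path.
import sky

task = sky.Task.from_yaml("sky/leann-build.yaml")
sky.launch(task, cluster_name="leann-build")
```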
```bash
# One-time: install and configure SkyPilot

View File

@@ -1251,15 +1251,15 @@ class LeannChat:
"Please provide the best answer you can based on this context and your knowledge."
)
logger.info("The context provided to the LLM is:")
logger.info(f"{'Relevance':<10} | {'Chunk id':<10} | {'Content':<60} | {'Source':<80}")
logger.info("-" * 150)
print("The context provided to the LLM is:")
print(f"{'Relevance':<10} | {'Chunk id':<10} | {'Content':<60} | {'Source':<80}")
print("-" * 150)
for r in results:
chunk_relevance = f"{r.score:.3f}"
chunk_id = r.id
chunk_content = r.text[:60]
chunk_source = r.metadata.get("source", "")[:80]
logger.info(
print(
f"{chunk_relevance:<10} | {chunk_id:<10} | {chunk_content:<60} | {chunk_source:<80}"
)
ask_time = time.time()

View File

@@ -12,13 +12,7 @@ from typing import Any, Optional
import torch
from .settings import (
resolve_anthropic_api_key,
resolve_anthropic_base_url,
resolve_ollama_host,
resolve_openai_api_key,
resolve_openai_base_url,
)
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
# Configure logging
logging.basicConfig(level=logging.INFO)
@@ -851,81 +845,6 @@ class OpenAIChat(LLMInterface):
return f"Error: Could not get a response from OpenAI. Details: {e}"
class AnthropicChat(LLMInterface):
"""LLM interface for Anthropic Claude models."""
def __init__(
self,
model: str = "claude-haiku-4-5",
api_key: Optional[str] = None,
base_url: Optional[str] = None,
):
self.model = model
self.base_url = resolve_anthropic_base_url(base_url)
self.api_key = resolve_anthropic_api_key(api_key)
if not self.api_key:
raise ValueError(
"Anthropic API key is required. Set ANTHROPIC_API_KEY environment variable or pass api_key parameter."
)
logger.info(
"Initializing Anthropic Chat with model='%s' and base_url='%s'",
model,
self.base_url,
)
try:
import anthropic
# Allow custom Anthropic-compatible endpoints via base_url
self.client = anthropic.Anthropic(
api_key=self.api_key,
base_url=self.base_url,
)
except ImportError:
raise ImportError(
"The 'anthropic' library is required for Anthropic models. Please install it with 'pip install anthropic'."
)
def ask(self, prompt: str, **kwargs) -> str:
logger.info(f"Sending request to Anthropic with model {self.model}")
try:
# Anthropic API parameters
params = {
"model": self.model,
"max_tokens": kwargs.get("max_tokens", 1000),
"messages": [{"role": "user", "content": prompt}],
}
# Add optional parameters
if "temperature" in kwargs:
params["temperature"] = kwargs["temperature"]
if "top_p" in kwargs:
params["top_p"] = kwargs["top_p"]
response = self.client.messages.create(**params)
# Extract text from response
response_text = response.content[0].text
# Log token usage
print(
f"Total tokens = {response.usage.input_tokens + response.usage.output_tokens}, "
f"input tokens = {response.usage.input_tokens}, "
f"output tokens = {response.usage.output_tokens}"
)
if response.stop_reason == "max_tokens":
print("The query is exceeding the maximum allowed number of tokens")
return response_text.strip()
except Exception as e:
logger.error(f"Error communicating with Anthropic: {e}")
return f"Error: Could not get a response from Anthropic. Details: {e}"
class SimulatedChat(LLMInterface):
"""A simple simulated chat for testing and development."""
@@ -978,12 +897,6 @@ def get_llm(llm_config: Optional[dict[str, Any]] = None) -> LLMInterface:
)
elif llm_type == "gemini":
return GeminiChat(model=model or "gemini-2.5-flash", api_key=llm_config.get("api_key"))
elif llm_type == "anthropic":
return AnthropicChat(
model=model or "claude-3-5-sonnet-20241022",
api_key=llm_config.get("api_key"),
base_url=llm_config.get("base_url"),
)
elif llm_type == "simulated":
return SimulatedChat()
else:

View File

@@ -11,12 +11,7 @@ from tqdm import tqdm
from .api import LeannBuilder, LeannChat, LeannSearcher
from .interactive_utils import create_cli_session
from .registry import register_project_directory
from .settings import (
resolve_anthropic_base_url,
resolve_ollama_host,
resolve_openai_api_key,
resolve_openai_base_url,
)
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
def extract_pdf_text_with_pymupdf(file_path: str) -> str:
@@ -296,7 +291,7 @@ Examples:
"--llm",
type=str,
default="ollama",
choices=["simulated", "ollama", "hf", "openai", "anthropic"],
choices=["simulated", "ollama", "hf", "openai"],
help="LLM provider (default: ollama)",
)
ask_parser.add_argument(
@@ -346,7 +341,7 @@ Examples:
"--api-key",
type=str,
default=None,
help="API key for cloud LLM providers (OpenAI, Anthropic)",
help="API key for OpenAI-compatible APIs (defaults to OPENAI_API_KEY)",
)
# List command
@@ -1621,12 +1616,6 @@ Examples:
resolved_api_key = resolve_openai_api_key(args.api_key)
if resolved_api_key:
llm_config["api_key"] = resolved_api_key
elif args.llm == "anthropic":
# For Anthropic, pass base_url and API key if provided
if args.api_base:
llm_config["base_url"] = resolve_anthropic_base_url(args.api_base)
if args.api_key:
llm_config["api_key"] = args.api_key
chat = LeannChat(index_path=index_path, llm_config=llm_config)

View File

@@ -9,7 +9,6 @@ from typing import Any
# Default fallbacks to preserve current behaviour while keeping them in one place.
_DEFAULT_OLLAMA_HOST = "http://localhost:11434"
_DEFAULT_OPENAI_BASE_URL = "https://api.openai.com/v1"
_DEFAULT_ANTHROPIC_BASE_URL = "https://api.anthropic.com"
def _clean_url(value: str) -> str:
@@ -53,23 +52,6 @@ def resolve_openai_base_url(explicit: str | None = None) -> str:
return _clean_url(_DEFAULT_OPENAI_BASE_URL)
def resolve_anthropic_base_url(explicit: str | None = None) -> str:
"""Resolve the base URL for Anthropic-compatible services."""
candidates = (
explicit,
os.getenv("LEANN_ANTHROPIC_BASE_URL"),
os.getenv("ANTHROPIC_BASE_URL"),
os.getenv("LOCAL_ANTHROPIC_BASE_URL"),
)
for candidate in candidates:
if candidate:
return _clean_url(candidate)
return _clean_url(_DEFAULT_ANTHROPIC_BASE_URL)
def resolve_openai_api_key(explicit: str | None = None) -> str | None:
"""Resolve the API key for OpenAI-compatible services."""
@@ -79,15 +61,6 @@ def resolve_openai_api_key(explicit: str | None = None) -> str | None:
return os.getenv("OPENAI_API_KEY")
def resolve_anthropic_api_key(explicit: str | None = None) -> str | None:
"""Resolve the API key for Anthropic services."""
if explicit:
return explicit
return os.getenv("ANTHROPIC_API_KEY")
def encode_provider_options(options: dict[str, Any] | None) -> str | None:
"""Serialize provider options for child processes."""

View File

@@ -53,11 +53,6 @@ leann build my-project --docs $(git ls-files)
# Start Claude Code
claude
```
**Performance tip**: For maximum speed when storage space is not a concern, add the `--no-recompute` flag to your build command. This materializes all tensors and stores them on disk, avoiding recomputation on subsequent builds:
```bash
leann build my-project --docs $(git ls-files) --no-recompute
```
## 🚀 Advanced Usage Examples to build the index