Compare commits
7 Commits
feat/add-c
...
feature/cu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9996c29618 | ||
|
|
12951ad4d5 | ||
|
|
a878d2459b | ||
|
|
6c39a3427f | ||
|
|
17cbd07b25 | ||
|
|
3629ccf8f7 | ||
|
|
a0bbf831db |
13
README.md
13
README.md
@@ -201,7 +201,7 @@ LEANN supports RAG on various data sources including documents (`.pdf`, `.txt`,
|
|||||||
|
|
||||||
#### LLM Backend
|
#### LLM Backend
|
||||||
|
|
||||||
LEANN supports many LLM providers for text generation (HuggingFace, Ollama, and Any OpenAI compatible API).
|
LEANN supports many LLM providers for text generation (HuggingFace, Ollama, Anthropic, and any OpenAI-compatible API).
|
||||||
|
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
@@ -269,6 +269,7 @@ Below is a list of base URLs for common providers to get you started.
|
|||||||
| **SiliconFlow** | `https://api.siliconflow.cn/v1` |
|
| **SiliconFlow** | `https://api.siliconflow.cn/v1` |
|
||||||
| **Zhipu (BigModel)** | `https://open.bigmodel.cn/api/paas/v4/` |
|
| **Zhipu (BigModel)** | `https://open.bigmodel.cn/api/paas/v4/` |
|
||||||
| **Mistral AI** | `https://api.mistral.ai/v1` |
|
| **Mistral AI** | `https://api.mistral.ai/v1` |
|
||||||
|
| **Anthropic** | `https://api.anthropic.com/v1` |
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -328,7 +329,7 @@ All RAG examples share these common parameters. **Interactive mode** is availabl
|
|||||||
--embedding-mode MODE # sentence-transformers, openai, mlx, or ollama
|
--embedding-mode MODE # sentence-transformers, openai, mlx, or ollama
|
||||||
|
|
||||||
# LLM Parameters (Text generation models)
|
# LLM Parameters (Text generation models)
|
||||||
--llm TYPE # LLM backend: openai, ollama, or hf (default: openai)
|
--llm TYPE # LLM backend: openai, ollama, hf, or anthropic (default: openai)
|
||||||
--llm-model MODEL # Model name (default: gpt-4o) e.g., gpt-4o-mini, llama3.2:1b, Qwen/Qwen2.5-1.5B-Instruct
|
--llm-model MODEL # Model name (default: gpt-4o) e.g., gpt-4o-mini, llama3.2:1b, Qwen/Qwen2.5-1.5B-Instruct
|
||||||
--thinking-budget LEVEL # Thinking budget for reasoning models: low/medium/high (supported by o3, o3-mini, GPT-Oss:20b, and other reasoning models)
|
--thinking-budget LEVEL # Thinking budget for reasoning models: low/medium/high (supported by o3, o3-mini, GPT-Oss:20b, and other reasoning models)
|
||||||
|
|
||||||
@@ -1057,10 +1058,10 @@ Options:
|
|||||||
leann ask INDEX_NAME [OPTIONS]
|
leann ask INDEX_NAME [OPTIONS]
|
||||||
|
|
||||||
Options:
|
Options:
|
||||||
--llm {ollama,openai,hf} LLM provider (default: ollama)
|
--llm {ollama,openai,hf,anthropic} LLM provider (default: ollama)
|
||||||
--model MODEL Model name (default: qwen3:8b)
|
--model MODEL Model name (default: qwen3:8b)
|
||||||
--interactive Interactive chat mode
|
--interactive Interactive chat mode
|
||||||
--top-k N Retrieval count (default: 20)
|
--top-k N Retrieval count (default: 20)
|
||||||
```
|
```
|
||||||
|
|
||||||
**List Command:**
|
**List Command:**
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
|
import glob
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
@@ -11,6 +13,8 @@ import numpy as np
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _ensure_repo_paths_importable(current_file: str) -> None:
|
def _ensure_repo_paths_importable(current_file: str) -> None:
|
||||||
"""Make local leann packages importable without installing (mirrors multi-vector-leann.py)."""
|
"""Make local leann packages importable without installing (mirrors multi-vector-leann.py)."""
|
||||||
@@ -96,12 +100,63 @@ def _natural_sort_key(name: str) -> int:
|
|||||||
return int(m.group()) if m else 0
|
return int(m.group()) if m else 0
|
||||||
|
|
||||||
|
|
||||||
def _load_images_from_dir(
    pages_dir: str, recursive: bool = False
) -> tuple[list[str], list[Image.Image]]:
    """
    Load every supported image found in a directory.

    Args:
        pages_dir: Directory path containing images.
        recursive: If True, also search subdirectories (default: False).

    Returns:
        Tuple of (filepaths, images). Both lists have the same length and
        order; files that cannot be opened are reported and skipped. All
        returned images are fully loaded and in RGB mode.
    """
    # Matched case-insensitively below, so ".PNG", ".Jpg", ... all qualify.
    supported_suffixes = (".png", ".jpg", ".jpeg", ".webp")

    # Escape the directory so characters like "[" in its name are taken
    # literally rather than as glob syntax.
    root = glob.escape(pages_dir)
    pattern = os.path.join(root, "**", "*") if recursive else os.path.join(root, "*")

    # A single glob pass (instead of one pass per upper/lower-case extension)
    # avoids listing the same file twice on case-insensitive filesystems.
    filepaths = [
        path
        for path in glob.glob(pattern, recursive=recursive)
        if path.lower().endswith(supported_suffixes) and os.path.isfile(path)
    ]

    # Natural sort on the file name; the full path breaks ties so the order
    # does not depend on the order in which glob happens to return entries.
    filepaths.sort(key=lambda path: (_natural_sort_key(os.path.basename(path)), path))

    images = []
    valid_filepaths = []
    failed_count = 0

    for filepath in filepaths:
        try:
            # The context manager closes the underlying file handle once the
            # pixel data has been copied; convert() always returns a loaded
            # copy (and normalizes RGBA, P, etc. to RGB).
            with Image.open(filepath) as img:
                rgb_image = img.convert("RGB")
        except Exception as e:
            # Best effort: a single unreadable file must not abort the batch.
            failed_count += 1
            print(f"Warning: Failed to load image {filepath}: {e}")
            continue
        images.append(rgb_image)
        valid_filepaths.append(filepath)

    if failed_count > 0:
        print(
            f"Warning: Failed to load {failed_count} image(s) out of {len(filepaths)} total files"
        )

    return valid_filepaths, images
|
||||||
|
|
||||||
|
|
||||||
def _maybe_convert_pdf_to_images(pdf_path: Optional[str], pages_dir: str, dpi: int = 200) -> None:
|
def _maybe_convert_pdf_to_images(pdf_path: Optional[str], pages_dir: str, dpi: int = 200) -> None:
|
||||||
@@ -151,36 +206,99 @@ def _select_device_and_dtype():
|
|||||||
|
|
||||||
|
|
||||||
def _resolve_colvision_model(model_choice: str) -> tuple[str, str]:
    """Map a user-supplied model choice to ``(model_name, model_type)``.

    ``model_type`` is one of ``"colqwen2.5"``, ``"colqwen2"`` or ``"colpali"``.
    Shorthand names are matched case-insensitively; any other value that
    contains a known family name is used verbatim as the model name.
    """
    choice = model_choice.strip()
    lowered = choice.lower()

    shorthands = {
        "colqwen2": ("vidore/colqwen2-v1.0", "colqwen2"),
        "colqwen2.5": ("vidore/colqwen2.5-v0.2", "colqwen2.5"),
        "colqwen25": ("vidore/colqwen2.5-v0.2", "colqwen2.5"),
        "colqwen2_5": ("vidore/colqwen2.5-v0.2", "colqwen2.5"),
        "colpali": ("vidore/colpali-v1.2", "colpali"),
    }
    if lowered in shorthands:
        return shorthands[lowered]

    # IMPORTANT: check colqwen2.5 BEFORE colqwen2, since "colqwen2" is a
    # substring of every colqwen2.5 spelling.
    if any(tag in lowered for tag in ("colqwen2.5", "colqwen25", "colqwen2_5")):
        return choice, "colqwen2.5"
    if "colqwen2" in lowered:
        # Any ColQwen2 name (not only "...colqwen2-v1.0") is loaded as given
        # instead of silently falling back to ColPali.
        return choice, "colqwen2"
    if "colpali" in lowered:
        return choice, "colpali"

    # Default to colpali for backward compatibility, but say so.
    logger.warning(
        "Unrecognized ColVision model choice %r; falling back to vidore/colpali-v1.2",
        model_choice,
    )
    return "vidore/colpali-v1.2", "colpali"


def _load_colvision(model_choice: str):
    """Load a ColVision retrieval model and its processor.

    Args:
        model_choice: Shorthand ("colqwen2", "colqwen2.5", "colpali") or a
            model name containing one of those family names.

    Returns:
        Tuple of (model_name, model, processor, device_str, device, dtype).
    """
    import importlib.util
    import os

    # Set before the Hugging Face stack is imported below.
    # NOTE(review): if huggingface_hub was already imported elsewhere it has
    # presumably read these variables already - confirm.
    os.environ.setdefault("HF_ENDPOINT", "https://huggingface.co")
    # Only opt in to hf_transfer when the package is actually installed;
    # enabling it without the package makes Hub downloads fail.
    if importlib.util.find_spec("hf_transfer") is not None:
        os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")

    import torch
    from colpali_engine.models import (
        ColPali,
        ColQwen2,
        ColQwen2_5,
        ColQwen2_5_Processor,
        ColQwen2Processor,
    )
    from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
    from transformers.utils.import_utils import is_flash_attn_2_available

    # Log model loading info
    logger.info("Loading ColVision model: %s", model_choice)
    logger.info("HF_ENDPOINT: %s", os.environ.get("HF_ENDPOINT", "not set"))
    logger.info("Models will be downloaded from HuggingFace Hub, not Google Drive")

    device_str, device, dtype = _select_device_and_dtype()

    model_name, model_type = _resolve_colvision_model(model_choice)

    if model_type == "colqwen2.5":
        model_cls, processor_cls = ColQwen2_5, ColQwen2_5_Processor
    elif model_type == "colqwen2":
        model_cls, processor_cls = ColQwen2, ColQwen2Processor
    else:  # colpali
        model_cls, processor_cls = ColPali, ColPaliProcessor

    # local_files_only=False lets the weights be downloaded when not cached.
    load_kwargs = {
        "torch_dtype": torch.bfloat16,
        "device_map": device,
        "local_files_only": False,
    }
    if model_type != "colpali":
        # On CPU/MPS we must avoid flash-attn and stay eager; on CUDA prefer
        # flash-attn if available. ColPali is loaded without this argument.
        load_kwargs["attn_implementation"] = (
            "flash_attention_2"
            if (device_str == "cuda" and is_flash_attn_2_available())
            else "eager"
        )

    model = model_cls.from_pretrained(model_name, **load_kwargs).eval()
    processor = processor_cls.from_pretrained(model_name, local_files_only=False)

    return model_name, model, processor, device_str, device, dtype
|
||||||
|
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ DATASET_NAME: str = "weaviate/arXiv-AI-papers-multi-vector"
|
|||||||
# DATASET_NAMES: Optional[list[str | tuple[str, Optional[str]]]] = None
|
# DATASET_NAMES: Optional[list[str | tuple[str, Optional[str]]]] = None
|
||||||
DATASET_NAMES = [
|
DATASET_NAMES = [
|
||||||
"weaviate/arXiv-AI-papers-multi-vector",
|
"weaviate/arXiv-AI-papers-multi-vector",
|
||||||
("lmms-lab/DocVQA", "DocVQA"), # Specify config name for datasets with multiple configs
|
# ("lmms-lab/DocVQA", "DocVQA"), # Specify config name for datasets with multiple configs
|
||||||
]
|
]
|
||||||
# Load multiple splits to get more data (e.g., ["train", "test", "validation"])
|
# Load multiple splits to get more data (e.g., ["train", "test", "validation"])
|
||||||
# Set to None to try loading all available splits automatically
|
# Set to None to try loading all available splits automatically
|
||||||
@@ -75,6 +75,11 @@ MAX_DOCS: Optional[int] = None # limit number of pages to index; None = all
|
|||||||
# Local pages (used when USE_HF_DATASET == False)
|
# Local pages (used when USE_HF_DATASET == False)
|
||||||
PDF: Optional[str] = None # e.g., "./pdfs/2004.12832v2.pdf"
|
PDF: Optional[str] = None # e.g., "./pdfs/2004.12832v2.pdf"
|
||||||
PAGES_DIR: str = "./pages"
|
PAGES_DIR: str = "./pages"
|
||||||
|
# Custom folder path (takes precedence over USE_HF_DATASET and PAGES_DIR)
|
||||||
|
# If set, images will be loaded directly from this folder
|
||||||
|
CUSTOM_FOLDER_PATH: Optional[str] = None # e.g., "/home/ubuntu/dr-tulu/agent/screenshots"
|
||||||
|
# Whether to recursively search subdirectories when loading from custom folder
|
||||||
|
CUSTOM_FOLDER_RECURSIVE: bool = False # Set to True to search subdirectories
|
||||||
|
|
||||||
# Index + retrieval settings
|
# Index + retrieval settings
|
||||||
# Use a different index path for larger dataset to avoid overwriting existing index
|
# Use a different index path for larger dataset to avoid overwriting existing index
|
||||||
@@ -83,7 +88,7 @@ INDEX_PATH: str = "./indexes/colvision_large.leann"
|
|||||||
# These are now command-line arguments (see CLI overrides section)
|
# These are now command-line arguments (see CLI overrides section)
|
||||||
TOPK: int = 3
|
TOPK: int = 3
|
||||||
FIRST_STAGE_K: int = 500
|
FIRST_STAGE_K: int = 500
|
||||||
REBUILD_INDEX: bool = True
|
REBUILD_INDEX: bool = False # Set to True to force rebuild even if index exists
|
||||||
|
|
||||||
# Artifacts
|
# Artifacts
|
||||||
SAVE_TOP_IMAGE: Optional[str] = "./figures/retrieved_page.png"
|
SAVE_TOP_IMAGE: Optional[str] = "./figures/retrieved_page.png"
|
||||||
@@ -128,12 +133,33 @@ parser.add_argument(
|
|||||||
default=TOPK,
|
default=TOPK,
|
||||||
help=f"Number of top results to retrieve. Default: {TOPK}",
|
help=f"Number of top results to retrieve. Default: {TOPK}",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--custom-folder",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path to a custom folder containing images to search. Takes precedence over dataset loading. Default: None",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--recursive",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="Recursively search subdirectories when loading images from custom folder. Default: False",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--rebuild-index",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="Force rebuild the index even if it already exists. Default: False (reuse existing index if available)",
|
||||||
|
)
|
||||||
cli_args, _unknown = parser.parse_known_args()
|
cli_args, _unknown = parser.parse_known_args()
|
||||||
SEARCH_METHOD: str = cli_args.search_method
|
SEARCH_METHOD: str = cli_args.search_method
|
||||||
QUERY = cli_args.query # Override QUERY with CLI argument if provided
|
QUERY = cli_args.query # Override QUERY with CLI argument if provided
|
||||||
USE_FAST_PLAID: bool = cli_args.use_fast_plaid
|
USE_FAST_PLAID: bool = cli_args.use_fast_plaid
|
||||||
FAST_PLAID_INDEX_PATH: str = cli_args.fast_plaid_index_path
|
FAST_PLAID_INDEX_PATH: str = cli_args.fast_plaid_index_path
|
||||||
TOPK: int = cli_args.topk # Override TOPK with CLI argument if provided
|
TOPK: int = cli_args.topk # Override TOPK with CLI argument if provided
|
||||||
|
CUSTOM_FOLDER_PATH = cli_args.custom_folder if cli_args.custom_folder else CUSTOM_FOLDER_PATH # Override with CLI argument if provided
|
||||||
|
CUSTOM_FOLDER_RECURSIVE = cli_args.recursive if cli_args.recursive else CUSTOM_FOLDER_RECURSIVE # Override with CLI argument if provided
|
||||||
|
REBUILD_INDEX = cli_args.rebuild_index # Override REBUILD_INDEX with CLI argument
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
|
|
||||||
@@ -180,7 +206,23 @@ else:
|
|||||||
# Step 2: Load data only if we need to build the index
|
# Step 2: Load data only if we need to build the index
|
||||||
if need_to_build_index:
|
if need_to_build_index:
|
||||||
print("Loading dataset...")
|
print("Loading dataset...")
|
||||||
if USE_HF_DATASET:
|
# Check for custom folder path first (takes precedence)
|
||||||
|
if CUSTOM_FOLDER_PATH:
|
||||||
|
if not os.path.isdir(CUSTOM_FOLDER_PATH):
|
||||||
|
raise RuntimeError(f"Custom folder path does not exist: {CUSTOM_FOLDER_PATH}")
|
||||||
|
print(f"Loading images from custom folder: {CUSTOM_FOLDER_PATH}")
|
||||||
|
if CUSTOM_FOLDER_RECURSIVE:
|
||||||
|
print(" (recursive mode: searching subdirectories)")
|
||||||
|
filepaths, images = _load_images_from_dir(CUSTOM_FOLDER_PATH, recursive=CUSTOM_FOLDER_RECURSIVE)
|
||||||
|
print(f" Found {len(filepaths)} image files")
|
||||||
|
if not images:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"No images found in {CUSTOM_FOLDER_PATH}. Ensure the folder contains image files (.png, .jpg, .jpeg, .webp)."
|
||||||
|
)
|
||||||
|
print(f" Successfully loaded {len(images)} images")
|
||||||
|
# Use filenames as identifiers instead of full paths for cleaner metadata
|
||||||
|
filepaths = [os.path.basename(fp) for fp in filepaths]
|
||||||
|
elif USE_HF_DATASET:
|
||||||
from datasets import load_dataset, concatenate_datasets, DatasetDict
|
from datasets import load_dataset, concatenate_datasets, DatasetDict
|
||||||
|
|
||||||
# Determine which datasets to load
|
# Determine which datasets to load
|
||||||
@@ -621,7 +663,6 @@ else:
|
|||||||
except Exception:
|
except Exception:
|
||||||
print(f"Saved retrieved page (rank {rank}) to: {out_path}")
|
print(f"Saved retrieved page (rank {rank}) to: {out_path}")
|
||||||
|
|
||||||
## TODO stange results of second page of DeepSeek-V2 rather than the first page
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
# Step 6: Similarity maps for top-K results
|
# Step 6: Similarity maps for top-K results
|
||||||
|
|||||||
@@ -90,6 +90,51 @@ VIDORE_V1_TASKS = {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Task name aliases (short names -> full names)
TASK_ALIASES = {
    "arxivqa": "VidoreArxivQARetrieval",
    "docvqa": "VidoreDocVQARetrieval",
    "infovqa": "VidoreInfoVQARetrieval",
    "tabfquad": "VidoreTabfquadRetrieval",
    "tatdqa": "VidoreTatdqaRetrieval",
    "shiftproject": "VidoreShiftProjectRetrieval",
    "syntheticdocqa_ai": "VidoreSyntheticDocQAAIRetrieval",
    "syntheticdocqa_energy": "VidoreSyntheticDocQAEnergyRetrieval",
    "syntheticdocqa_government": "VidoreSyntheticDocQAGovernmentReportsRetrieval",
    "syntheticdocqa_healthcare": "VidoreSyntheticDocQAHealthcareIndustryRetrieval",
}


def normalize_task_name(task_name: str) -> str:
    """Normalize a task name, resolving aliases to full task names.

    Resolution order: exact full name, exact alias (case-insensitive), full
    name ignoring case, then the first alias that contains or is contained
    in the given name.

    Args:
        task_name: Full task name or alias, in any letter case.

    Returns:
        The matching key of ``VIDORE_V1_TASKS``, or ``task_name`` unchanged
        when nothing matches (the caller is expected to reject it).
    """
    cleaned = task_name.strip()
    if not cleaned:
        # "" is a substring of every alias, so the partial match below would
        # wrongly resolve an empty entry (e.g. from "arxivqa,") to a task.
        return task_name

    if cleaned in VIDORE_V1_TASKS:
        return cleaned

    lowered = cleaned.lower()
    if lowered in TASK_ALIASES:
        return TASK_ALIASES[lowered]

    # Accept full task names regardless of letter case.
    for full_name in VIDORE_V1_TASKS:
        if full_name.lower() == lowered:
            return full_name

    # Try partial match
    for alias, full_name in TASK_ALIASES.items():
        if alias in lowered or lowered in alias:
            return full_name

    return task_name
|
||||||
|
|
||||||
|
|
||||||
|
def get_safe_model_name(model_name: str) -> str:
    """Get a safe model name for use in file paths.

    Args:
        model_name: A model identifier (e.g. "org/model") or a path to a
            local model directory.

    Returns:
        For an existing directory: its base name, or the first 16 hex digits
        of the MD5 of ``model_name`` when the base name is empty, 100+
        characters long, or starts with ".". Otherwise ``model_name`` with
        "/", "\\" and ":" replaced by "_".
    """
    import hashlib
    import os

    # isdir() is already False for paths that do not exist.
    if os.path.isdir(model_name):
        # normpath drops trailing separators of either flavour before the
        # base name is taken.
        basename = os.path.basename(os.path.normpath(model_name))
        if basename and len(basename) < 100 and not basename.startswith("."):
            return basename
        # Use hash for very long or problematic paths
        return hashlib.md5(model_name.encode()).hexdigest()[:16]

    # For model identifiers, neutralize both path separators and ":" so the
    # result is always a single path component.
    return model_name.translate(str.maketrans({"/": "_", "\\": "_", ":": "_"}))
|
||||||
|
|
||||||
|
|
||||||
def load_vidore_v1_data(
|
def load_vidore_v1_data(
|
||||||
dataset_path: str,
|
dataset_path: str,
|
||||||
@@ -181,6 +226,9 @@ def evaluate_task(
|
|||||||
print(f"Evaluating task: {task_name}")
|
print(f"Evaluating task: {task_name}")
|
||||||
print(f"{'=' * 80}")
|
print(f"{'=' * 80}")
|
||||||
|
|
||||||
|
# Normalize task name (handle aliases)
|
||||||
|
task_name = normalize_task_name(task_name)
|
||||||
|
|
||||||
# Get task config
|
# Get task config
|
||||||
if task_name not in VIDORE_V1_TASKS:
|
if task_name not in VIDORE_V1_TASKS:
|
||||||
raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V1_TASKS.keys())}")
|
raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V1_TASKS.keys())}")
|
||||||
@@ -223,11 +271,13 @@ def evaluate_task(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Build or load index
|
# Build or load index
|
||||||
|
# Use safe model name for index path (different models need different indexes)
|
||||||
|
safe_model_name = get_safe_model_name(model_name)
|
||||||
index_path_full = index_path if not use_fast_plaid else fast_plaid_index_path
|
index_path_full = index_path if not use_fast_plaid else fast_plaid_index_path
|
||||||
if index_path_full is None:
|
if index_path_full is None:
|
||||||
index_path_full = f"./indexes/{task_name}_{model_name}"
|
index_path_full = f"./indexes/{task_name}_{safe_model_name}"
|
||||||
if use_fast_plaid:
|
if use_fast_plaid:
|
||||||
index_path_full = f"./indexes/{task_name}_{model_name}_fastplaid"
|
index_path_full = f"./indexes/{task_name}_{safe_model_name}_fastplaid"
|
||||||
|
|
||||||
index_or_retriever, corpus_ids_ordered = evaluator.build_index_from_corpus(
|
index_or_retriever, corpus_ids_ordered = evaluator.build_index_from_corpus(
|
||||||
corpus=corpus,
|
corpus=corpus,
|
||||||
@@ -281,8 +331,7 @@ def main():
|
|||||||
"--model",
|
"--model",
|
||||||
type=str,
|
type=str,
|
||||||
default="colqwen2",
|
default="colqwen2",
|
||||||
choices=["colqwen2", "colpali"],
|
help="Model to use: 'colqwen2', 'colpali', or path to a model directory (supports LoRA adapters)",
|
||||||
help="Model to use",
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--task",
|
"--task",
|
||||||
@@ -350,11 +399,11 @@ def main():
|
|||||||
|
|
||||||
# Determine tasks to evaluate
|
# Determine tasks to evaluate
|
||||||
if args.task:
|
if args.task:
|
||||||
tasks_to_eval = [args.task]
|
tasks_to_eval = [normalize_task_name(args.task)]
|
||||||
elif args.tasks.lower() == "all":
|
elif args.tasks.lower() == "all":
|
||||||
tasks_to_eval = list(VIDORE_V1_TASKS.keys())
|
tasks_to_eval = list(VIDORE_V1_TASKS.keys())
|
||||||
else:
|
else:
|
||||||
tasks_to_eval = [t.strip() for t in args.tasks.split(",")]
|
tasks_to_eval = [normalize_task_name(t.strip()) for t in args.tasks.split(",")]
|
||||||
|
|
||||||
print(f"Tasks to evaluate: {tasks_to_eval}")
|
print(f"Tasks to evaluate: {tasks_to_eval}")
|
||||||
|
|
||||||
|
|||||||
@@ -454,7 +454,7 @@ leann search my-index "your query" \
|
|||||||
|
|
||||||
### 2) Run remote builds with SkyPilot (cloud GPU)
|
### 2) Run remote builds with SkyPilot (cloud GPU)
|
||||||
|
|
||||||
Offload embedding generation and index building to a GPU VM using [SkyPilot](https://skypilot.readthedocs.io/en/latest/). A template is provided at `sky/leann-build.yaml`.
|
Offload embedding generation and index building to a GPU VM using [SkyPilot](https://docs.skypilot.co/en/latest/docs/index.html). A template is provided at `sky/leann-build.yaml`.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# One-time: install and configure SkyPilot
|
# One-time: install and configure SkyPilot
|
||||||
|
|||||||
@@ -1251,15 +1251,15 @@ class LeannChat:
|
|||||||
"Please provide the best answer you can based on this context and your knowledge."
|
"Please provide the best answer you can based on this context and your knowledge."
|
||||||
)
|
)
|
||||||
|
|
||||||
print("The context provided to the LLM is:")
|
logger.info("The context provided to the LLM is:")
|
||||||
print(f"{'Relevance':<10} | {'Chunk id':<10} | {'Content':<60} | {'Source':<80}")
|
logger.info(f"{'Relevance':<10} | {'Chunk id':<10} | {'Content':<60} | {'Source':<80}")
|
||||||
print("-" * 150)
|
logger.info("-" * 150)
|
||||||
for r in results:
|
for r in results:
|
||||||
chunk_relevance = f"{r.score:.3f}"
|
chunk_relevance = f"{r.score:.3f}"
|
||||||
chunk_id = r.id
|
chunk_id = r.id
|
||||||
chunk_content = r.text[:60]
|
chunk_content = r.text[:60]
|
||||||
chunk_source = r.metadata.get("source", "")[:80]
|
chunk_source = r.metadata.get("source", "")[:80]
|
||||||
print(
|
logger.info(
|
||||||
f"{chunk_relevance:<10} | {chunk_id:<10} | {chunk_content:<60} | {chunk_source:<80}"
|
f"{chunk_relevance:<10} | {chunk_id:<10} | {chunk_content:<60} | {chunk_source:<80}"
|
||||||
)
|
)
|
||||||
ask_time = time.time()
|
ask_time = time.time()
|
||||||
|
|||||||
@@ -12,7 +12,13 @@ from typing import Any, Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
|
from .settings import (
|
||||||
|
resolve_anthropic_api_key,
|
||||||
|
resolve_anthropic_base_url,
|
||||||
|
resolve_ollama_host,
|
||||||
|
resolve_openai_api_key,
|
||||||
|
resolve_openai_base_url,
|
||||||
|
)
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
@@ -845,6 +851,81 @@ class OpenAIChat(LLMInterface):
|
|||||||
return f"Error: Could not get a response from OpenAI. Details: {e}"
|
return f"Error: Could not get a response from OpenAI. Details: {e}"
|
||||||
|
|
||||||
|
|
||||||
|
class AnthropicChat(LLMInterface):
    """LLM interface for Anthropic Claude models."""

    def __init__(
        self,
        model: str = "claude-haiku-4-5",
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
    ):
        """Create the Anthropic client.

        Args:
            model: Anthropic model name.
            api_key: Explicit API key, passed through ``resolve_anthropic_api_key``.
            base_url: Explicit endpoint, passed through ``resolve_anthropic_base_url``.

        Raises:
            ValueError: If no API key could be resolved.
            ImportError: If the ``anthropic`` package is not installed.
        """
        self.model = model
        self.base_url = resolve_anthropic_base_url(base_url)
        self.api_key = resolve_anthropic_api_key(api_key)

        if not self.api_key:
            raise ValueError(
                "Anthropic API key is required. Set ANTHROPIC_API_KEY environment variable or pass api_key parameter."
            )

        logger.info(
            "Initializing Anthropic Chat with model='%s' and base_url='%s'",
            model,
            self.base_url,
        )

        # Keep the try body to the import alone so unrelated failures are not
        # mistaken for a missing package.
        try:
            import anthropic
        except ImportError as exc:
            raise ImportError(
                "The 'anthropic' library is required for Anthropic models. Please install it with 'pip install anthropic'."
            ) from exc

        # Allow custom Anthropic-compatible endpoints via base_url
        self.client = anthropic.Anthropic(
            api_key=self.api_key,
            base_url=self.base_url,
        )

    def ask(self, prompt: str, **kwargs) -> str:
        """Send ``prompt`` as a single user message and return the reply text.

        Recognized keyword arguments: ``max_tokens`` (default 1000),
        ``temperature`` and ``top_p``; anything else is ignored.

        Returns:
            The stripped response text, or a string starting with "Error:"
            when the request fails.
        """
        logger.info("Sending request to Anthropic with model %s", self.model)

        try:
            # Anthropic API parameters
            params = {
                "model": self.model,
                "max_tokens": kwargs.get("max_tokens", 1000),
                "messages": [{"role": "user", "content": prompt}],
            }

            # Add optional parameters only when the caller supplied them
            for option in ("temperature", "top_p"):
                if option in kwargs:
                    params[option] = kwargs[option]

            response = self.client.messages.create(**params)

            # The response content is a list of blocks and the first one is
            # not guaranteed to be text, so gather every text block rather
            # than indexing content[0].
            response_text = "".join(
                block.text
                for block in response.content
                if getattr(block, "type", None) == "text"
            )
            if not response_text:
                logger.warning(
                    "Anthropic response contained no text content (stop_reason=%s)",
                    response.stop_reason,
                )

            # Log token usage
            usage = response.usage
            logger.info(
                "Total tokens = %s, input tokens = %s, output tokens = %s",
                usage.input_tokens + usage.output_tokens,
                usage.input_tokens,
                usage.output_tokens,
            )

            if response.stop_reason == "max_tokens":
                logger.warning("The query is exceeding the maximum allowed number of tokens")

            return response_text.strip()
        except Exception as e:
            # Errors are reported to the caller as text instead of raising.
            logger.error(f"Error communicating with Anthropic: {e}")
            return f"Error: Could not get a response from Anthropic. Details: {e}"
|
||||||
|
|
||||||
|
|
||||||
class SimulatedChat(LLMInterface):
|
class SimulatedChat(LLMInterface):
|
||||||
"""A simple simulated chat for testing and development."""
|
"""A simple simulated chat for testing and development."""
|
||||||
|
|
||||||
@@ -897,6 +978,12 @@ def get_llm(llm_config: Optional[dict[str, Any]] = None) -> LLMInterface:
|
|||||||
)
|
)
|
||||||
elif llm_type == "gemini":
|
elif llm_type == "gemini":
|
||||||
return GeminiChat(model=model or "gemini-2.5-flash", api_key=llm_config.get("api_key"))
|
return GeminiChat(model=model or "gemini-2.5-flash", api_key=llm_config.get("api_key"))
|
||||||
|
elif llm_type == "anthropic":
|
||||||
|
return AnthropicChat(
|
||||||
|
model=model or "claude-3-5-sonnet-20241022",
|
||||||
|
api_key=llm_config.get("api_key"),
|
||||||
|
base_url=llm_config.get("base_url"),
|
||||||
|
)
|
||||||
elif llm_type == "simulated":
|
elif llm_type == "simulated":
|
||||||
return SimulatedChat()
|
return SimulatedChat()
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -11,7 +11,12 @@ from tqdm import tqdm
|
|||||||
from .api import LeannBuilder, LeannChat, LeannSearcher
|
from .api import LeannBuilder, LeannChat, LeannSearcher
|
||||||
from .interactive_utils import create_cli_session
|
from .interactive_utils import create_cli_session
|
||||||
from .registry import register_project_directory
|
from .registry import register_project_directory
|
||||||
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
|
from .settings import (
|
||||||
|
resolve_anthropic_base_url,
|
||||||
|
resolve_ollama_host,
|
||||||
|
resolve_openai_api_key,
|
||||||
|
resolve_openai_base_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def extract_pdf_text_with_pymupdf(file_path: str) -> str:
|
def extract_pdf_text_with_pymupdf(file_path: str) -> str:
|
||||||
@@ -291,7 +296,7 @@ Examples:
|
|||||||
"--llm",
|
"--llm",
|
||||||
type=str,
|
type=str,
|
||||||
default="ollama",
|
default="ollama",
|
||||||
choices=["simulated", "ollama", "hf", "openai"],
|
choices=["simulated", "ollama", "hf", "openai", "anthropic"],
|
||||||
help="LLM provider (default: ollama)",
|
help="LLM provider (default: ollama)",
|
||||||
)
|
)
|
||||||
ask_parser.add_argument(
|
ask_parser.add_argument(
|
||||||
@@ -341,7 +346,7 @@ Examples:
|
|||||||
"--api-key",
|
"--api-key",
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
help="API key for OpenAI-compatible APIs (defaults to OPENAI_API_KEY)",
|
help="API key for cloud LLM providers (OpenAI, Anthropic)",
|
||||||
)
|
)
|
||||||
|
|
||||||
# List command
|
# List command
|
||||||
@@ -1616,6 +1621,12 @@ Examples:
|
|||||||
resolved_api_key = resolve_openai_api_key(args.api_key)
|
resolved_api_key = resolve_openai_api_key(args.api_key)
|
||||||
if resolved_api_key:
|
if resolved_api_key:
|
||||||
llm_config["api_key"] = resolved_api_key
|
llm_config["api_key"] = resolved_api_key
|
||||||
|
elif args.llm == "anthropic":
|
||||||
|
# For Anthropic, pass base_url and API key if provided
|
||||||
|
if args.api_base:
|
||||||
|
llm_config["base_url"] = resolve_anthropic_base_url(args.api_base)
|
||||||
|
if args.api_key:
|
||||||
|
llm_config["api_key"] = args.api_key
|
||||||
|
|
||||||
chat = LeannChat(index_path=index_path, llm_config=llm_config)
|
chat = LeannChat(index_path=index_path, llm_config=llm_config)
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from typing import Any
|
|||||||
# Default fallbacks to preserve current behaviour while keeping them in one place.
|
# Default fallbacks to preserve current behaviour while keeping them in one place.
|
||||||
_DEFAULT_OLLAMA_HOST = "http://localhost:11434"
|
_DEFAULT_OLLAMA_HOST = "http://localhost:11434"
|
||||||
_DEFAULT_OPENAI_BASE_URL = "https://api.openai.com/v1"
|
_DEFAULT_OPENAI_BASE_URL = "https://api.openai.com/v1"
|
||||||
|
_DEFAULT_ANTHROPIC_BASE_URL = "https://api.anthropic.com"
|
||||||
|
|
||||||
|
|
||||||
def _clean_url(value: str) -> str:
|
def _clean_url(value: str) -> str:
|
||||||
@@ -52,6 +53,23 @@ def resolve_openai_base_url(explicit: str | None = None) -> str:
|
|||||||
return _clean_url(_DEFAULT_OPENAI_BASE_URL)
|
return _clean_url(_DEFAULT_OPENAI_BASE_URL)
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_anthropic_base_url(explicit: str | None = None) -> str:
    """Pick the base URL to use for Anthropic-compatible services.

    Sources are consulted in priority order and the first non-empty one wins:

    1. the ``explicit`` argument,
    2. the ``LEANN_ANTHROPIC_BASE_URL`` environment variable,
    3. the ``ANTHROPIC_BASE_URL`` environment variable,
    4. the ``LOCAL_ANTHROPIC_BASE_URL`` environment variable,
    5. the module default ``_DEFAULT_ANTHROPIC_BASE_URL``.

    The selected value is passed through ``_clean_url`` before it is returned.
    """
    env_names = (
        "LEANN_ANTHROPIC_BASE_URL",
        "ANTHROPIC_BASE_URL",
        "LOCAL_ANTHROPIC_BASE_URL",
    )
    sources = [explicit, *(os.getenv(name) for name in env_names)]

    # Empty strings count as "not provided", same as None.
    chosen = next((url for url in sources if url), _DEFAULT_ANTHROPIC_BASE_URL)
    return _clean_url(chosen)
|
||||||
|
|
||||||
|
|
||||||
def resolve_openai_api_key(explicit: str | None = None) -> str | None:
|
def resolve_openai_api_key(explicit: str | None = None) -> str | None:
|
||||||
"""Resolve the API key for OpenAI-compatible services."""
|
"""Resolve the API key for OpenAI-compatible services."""
|
||||||
|
|
||||||
@@ -61,6 +79,15 @@ def resolve_openai_api_key(explicit: str | None = None) -> str | None:
|
|||||||
return os.getenv("OPENAI_API_KEY")
|
return os.getenv("OPENAI_API_KEY")
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_anthropic_api_key(explicit: str | None = None) -> str | None:
    """Return the API key to use for Anthropic services, if one is available.

    A non-empty ``explicit`` value takes precedence. Otherwise the value of the
    ``ANTHROPIC_API_KEY`` environment variable is returned, which is ``None``
    when the variable is not set.
    """
    return explicit if explicit else os.getenv("ANTHROPIC_API_KEY")
|
||||||
|
|
||||||
|
|
||||||
def encode_provider_options(options: dict[str, Any] | None) -> str | None:
|
def encode_provider_options(options: dict[str, Any] | None) -> str | None:
|
||||||
"""Serialize provider options for child processes."""
|
"""Serialize provider options for child processes."""
|
||||||
|
|
||||||
|
|||||||
@@ -53,6 +53,11 @@ leann build my-project --docs $(git ls-files)
|
|||||||
# Start Claude Code
|
# Start Claude Code
|
||||||
claude
|
claude
|
||||||
```
|
```
|
||||||
|
**Performance tip**: For maximum speed when storage space is not a concern, add the `--no-recompute` flag to your build command. This materializes all tensors and stores them on disk, avoiding recomputation on subsequent builds:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
leann build my-project --docs $(git ls-files) --no-recompute
|
||||||
|
```
|
||||||
|
|
||||||
## 🚀 Advanced Usage Examples to build the index
|
## 🚀 Advanced Usage Examples to build the index
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user