Compare commits
4 Commits: main...feature/cu

| Author | SHA1 | Date |
|---|---|---|
| | 9996c29618 | |
| | 12951ad4d5 | |
| | a878d2459b | |
| | 6c39a3427f | |
@@ -1,5 +1,7 @@
 import concurrent.futures
+import glob
 import json
+import logging
 import os
 import re
 import sys
@@ -11,6 +13,8 @@ import numpy as np
 from PIL import Image
 from tqdm import tqdm
 
+logger = logging.getLogger(__name__)
+
 
 def _ensure_repo_paths_importable(current_file: str) -> None:
     """Make local leann packages importable without installing (mirrors multi-vector-leann.py)."""
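The new module-level logger only produces output once the host application configures logging; a minimal sketch of enabling it (standard-library boilerplate, not part of this diff):

```python
import logging

# Route INFO-level records (e.g., the model-loading messages added below) to stderr.
logging.basicConfig(level=logging.INFO)
```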
@@ -96,12 +100,63 @@ def _natural_sort_key(name: str) -> int:
     return int(m.group()) if m else 0
 
 
-def _load_images_from_dir(pages_dir: str) -> tuple[list[str], list[Image.Image]]:
-    filenames = [n for n in os.listdir(pages_dir) if n.lower().endswith((".png", ".jpg", ".jpeg"))]
-    filenames = sorted(filenames, key=_natural_sort_key)
-    filepaths = [os.path.join(pages_dir, n) for n in filenames]
-    images = [Image.open(p) for p in filepaths]
-    return filepaths, images
+def _load_images_from_dir(
+    pages_dir: str, recursive: bool = False
+) -> tuple[list[str], list[Image.Image]]:
+    """
+    Load images from a directory.
+
+    Args:
+        pages_dir: Directory path containing images
+        recursive: If True, recursively search subdirectories (default: False)
+
+    Returns:
+        Tuple of (filepaths, images)
+    """
+
+    # Supported image extensions
+    extensions = ("*.png", "*.jpg", "*.jpeg", "*.PNG", "*.JPG", "*.JPEG", "*.webp", "*.WEBP")
+
+    if recursive:
+        # Recursive search
+        filepaths = []
+        for ext in extensions:
+            pattern = os.path.join(pages_dir, "**", ext)
+            filepaths.extend(glob.glob(pattern, recursive=True))
+    else:
+        # Non-recursive search (only top-level directory)
+        filepaths = []
+        for ext in extensions:
+            pattern = os.path.join(pages_dir, ext)
+            filepaths.extend(glob.glob(pattern))
+
+    # Sort files naturally
+    filepaths = sorted(filepaths, key=lambda x: _natural_sort_key(os.path.basename(x)))
+
+    # Load images with error handling
+    images = []
+    valid_filepaths = []
+    failed_count = 0
+
+    for filepath in filepaths:
+        try:
+            img = Image.open(filepath)
+            # Convert to RGB if necessary (handles RGBA, P, etc.)
+            if img.mode != "RGB":
+                img = img.convert("RGB")
+            images.append(img)
+            valid_filepaths.append(filepath)
+        except Exception as e:
+            failed_count += 1
+            print(f"Warning: Failed to load image {filepath}: {e}")
+            continue
+
+    if failed_count > 0:
+        print(
+            f"Warning: Failed to load {failed_count} image(s) out of {len(filepaths)} total files"
+        )
+
+    return valid_filepaths, images
 
 
 def _maybe_convert_pdf_to_images(pdf_path: Optional[str], pages_dir: str, dpi: int = 200) -> None:
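A minimal usage sketch of the rewritten loader; the directory path is a hypothetical stand-in:

```python
# Load every .png/.jpg/.jpeg/.webp under ./pages and its subdirectories.
filepaths, images = _load_images_from_dir("./pages", recursive=True)
# The two lists are parallel: filepaths[i] is the source file of images[i].
print(f"Loaded {len(images)} images; first: {filepaths[0] if filepaths else 'none'}")
```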
@@ -151,6 +206,8 @@ def _select_device_and_dtype():
 
 
 def _load_colvision(model_choice: str):
+    import os
+
     import torch
     from colpali_engine.models import (
         ColPali,
@@ -162,6 +219,16 @@ def _load_colvision(model_choice: str):
     from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
     from transformers.utils.import_utils import is_flash_attn_2_available
 
+    # Force HuggingFace Hub to use HF endpoint, avoid Google Drive
+    # Set environment variables to ensure models are downloaded from HuggingFace
+    os.environ.setdefault("HF_ENDPOINT", "https://huggingface.co")
+    os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
+
+    # Log model loading info
+    logger.info(f"Loading ColVision model: {model_choice}")
+    logger.info(f"HF_ENDPOINT: {os.environ.get('HF_ENDPOINT', 'not set')}")
+    logger.info("Models will be downloaded from HuggingFace Hub, not Google Drive")
+
    device_str, device, dtype = _select_device_and_dtype()
 
     # Determine model name and type
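Because the diff uses `os.environ.setdefault`, a user-supplied endpoint always wins over the baked-in default; a small sketch of that behavior (the mirror URL is hypothetical):

```python
import os

os.environ["HF_ENDPOINT"] = "https://hf-mirror.example.com"  # hypothetical user override
os.environ.setdefault("HF_ENDPOINT", "https://huggingface.co")
# setdefault never overwrites an existing value, so the override survives:
print(os.environ["HF_ENDPOINT"])  # -> https://hf-mirror.example.com
```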
@@ -202,29 +269,36 @@ def _load_colvision(model_choice: str):
         "flash_attention_2" if (device_str == "cuda" and is_flash_attn_2_available()) else "eager"
     )
 
+    # Load model from HuggingFace Hub (not Google Drive)
+    # Use local_files_only=False to ensure download from HF if not cached
     if model_type == "colqwen2.5":
         model = ColQwen2_5.from_pretrained(
             model_name,
             torch_dtype=torch.bfloat16,
             device_map=device,
             attn_implementation=attn_implementation,
+            local_files_only=False,  # Ensure download from HuggingFace Hub
         ).eval()
-        processor = ColQwen2_5_Processor.from_pretrained(model_name)
+        processor = ColQwen2_5_Processor.from_pretrained(model_name, local_files_only=False)
     elif model_type == "colqwen2":
         model = ColQwen2.from_pretrained(
             model_name,
             torch_dtype=torch.bfloat16,
             device_map=device,
             attn_implementation=attn_implementation,
+            local_files_only=False,  # Ensure download from HuggingFace Hub
         ).eval()
-        processor = ColQwen2Processor.from_pretrained(model_name)
+        processor = ColQwen2Processor.from_pretrained(model_name, local_files_only=False)
     else:  # colpali
         model = ColPali.from_pretrained(
             model_name,
             torch_dtype=torch.bfloat16,
             device_map=device,
+            local_files_only=False,  # Ensure download from HuggingFace Hub
         ).eval()
-        processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
+        processor = cast(
+            ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name, local_files_only=False)
+        )
 
     return model_name, model, processor, device_str, device, dtype
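For context, a hedged sketch of how the returned model and processor are typically used to embed pages; the `process_images` call follows colpali-engine's documented pattern and is an assumption here, not part of the diff:

```python
import torch

model_name, model, processor, device_str, device, dtype = _load_colvision("colqwen2.5")

# Embed a batch of PIL pages into per-token multi-vector embeddings.
batch = processor.process_images(images).to(model.device)
with torch.no_grad():
    page_embeddings = model(**batch)  # one (num_tokens, dim) embedding per page
```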
@@ -62,7 +62,7 @@ DATASET_NAME: str = "weaviate/arXiv-AI-papers-multi-vector"
 # DATASET_NAMES: Optional[list[str | tuple[str, Optional[str]]]] = None
 DATASET_NAMES = [
     "weaviate/arXiv-AI-papers-multi-vector",
-    ("lmms-lab/DocVQA", "DocVQA"),  # Specify config name for datasets with multiple configs
+    # ("lmms-lab/DocVQA", "DocVQA"),  # Specify config name for datasets with multiple configs
 ]
 # Load multiple splits to get more data (e.g., ["train", "test", "validation"])
 # Set to None to try loading all available splits automatically
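Presumably each `(name, config)` tuple maps onto the two positional arguments of `datasets.load_dataset`; a sketch under that assumption:

```python
from datasets import load_dataset

# A bare string loads the default config; a tuple supplies an explicit config name.
ds = load_dataset("lmms-lab/DocVQA", "DocVQA")  # returns a DatasetDict of splits
```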
@@ -75,6 +75,11 @@ MAX_DOCS: Optional[int] = None  # limit number of pages to index; None = all
 # Local pages (used when USE_HF_DATASET == False)
 PDF: Optional[str] = None  # e.g., "./pdfs/2004.12832v2.pdf"
 PAGES_DIR: str = "./pages"
+# Custom folder path (takes precedence over USE_HF_DATASET and PAGES_DIR)
+# If set, images will be loaded directly from this folder
+CUSTOM_FOLDER_PATH: Optional[str] = None  # e.g., "/home/ubuntu/dr-tulu/agent/screenshots"
+# Whether to recursively search subdirectories when loading from custom folder
+CUSTOM_FOLDER_RECURSIVE: bool = False  # Set to True to search subdirectories
 
 # Index + retrieval settings
 # Use a different index path for larger dataset to avoid overwriting existing index
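The comments imply a three-way precedence among image sources; an illustrative sketch (the `source` variable is hypothetical):

```python
# Highest precedence first: custom folder, then HF dataset, then local pages dir.
if CUSTOM_FOLDER_PATH:
    source = CUSTOM_FOLDER_PATH
elif USE_HF_DATASET:
    source = DATASET_NAME
else:
    source = PAGES_DIR
```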
@@ -83,7 +88,7 @@ INDEX_PATH: str = "./indexes/colvision_large.leann"
 # These are now command-line arguments (see CLI overrides section)
 TOPK: int = 3
 FIRST_STAGE_K: int = 500
-REBUILD_INDEX: bool = True
+REBUILD_INDEX: bool = False  # Set to True to force rebuild even if index exists
 
 # Artifacts
 SAVE_TOP_IMAGE: Optional[str] = "./figures/retrieved_page.png"
@@ -128,12 +133,33 @@ parser.add_argument(
     default=TOPK,
     help=f"Number of top results to retrieve. Default: {TOPK}",
 )
+parser.add_argument(
+    "--custom-folder",
+    type=str,
+    default=None,
+    help="Path to a custom folder containing images to search. Takes precedence over dataset loading. Default: None",
+)
+parser.add_argument(
+    "--recursive",
+    action="store_true",
+    default=False,
+    help="Recursively search subdirectories when loading images from custom folder. Default: False",
+)
+parser.add_argument(
+    "--rebuild-index",
+    action="store_true",
+    default=False,
+    help="Force rebuild the index even if it already exists. Default: False (reuse existing index if available)",
+)
 cli_args, _unknown = parser.parse_known_args()
 SEARCH_METHOD: str = cli_args.search_method
 QUERY = cli_args.query  # Override QUERY with CLI argument if provided
 USE_FAST_PLAID: bool = cli_args.use_fast_plaid
 FAST_PLAID_INDEX_PATH: str = cli_args.fast_plaid_index_path
 TOPK: int = cli_args.topk  # Override TOPK with CLI argument if provided
+CUSTOM_FOLDER_PATH = cli_args.custom_folder if cli_args.custom_folder else CUSTOM_FOLDER_PATH  # Override with CLI argument if provided
+CUSTOM_FOLDER_RECURSIVE = cli_args.recursive if cli_args.recursive else CUSTOM_FOLDER_RECURSIVE  # Override with CLI argument if provided
+REBUILD_INDEX = cli_args.rebuild_index  # Override REBUILD_INDEX with CLI argument
 
 # %%
 
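A small sketch of the override behavior these flags produce; the argument values are hypothetical, and `parse_known_args` tolerates flags defined elsewhere in the script:

```python
# Simulate an invocation such as:
#   python script.py --custom-folder ./screenshots --recursive
cli_args, _unknown = parser.parse_known_args(
    ["--custom-folder", "./screenshots", "--recursive"]
)
assert cli_args.custom_folder == "./screenshots"
assert cli_args.recursive is True
assert cli_args.rebuild_index is False  # flag omitted, so the default is kept
```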
@@ -180,7 +206,23 @@ else:
 # Step 2: Load data only if we need to build the index
 if need_to_build_index:
     print("Loading dataset...")
-    if USE_HF_DATASET:
+    # Check for custom folder path first (takes precedence)
+    if CUSTOM_FOLDER_PATH:
+        if not os.path.isdir(CUSTOM_FOLDER_PATH):
+            raise RuntimeError(f"Custom folder path does not exist: {CUSTOM_FOLDER_PATH}")
+        print(f"Loading images from custom folder: {CUSTOM_FOLDER_PATH}")
+        if CUSTOM_FOLDER_RECURSIVE:
+            print("  (recursive mode: searching subdirectories)")
+        filepaths, images = _load_images_from_dir(CUSTOM_FOLDER_PATH, recursive=CUSTOM_FOLDER_RECURSIVE)
+        print(f"  Found {len(filepaths)} image files")
+        if not images:
+            raise RuntimeError(
+                f"No images found in {CUSTOM_FOLDER_PATH}. Ensure the folder contains image files (.png, .jpg, .jpeg, .webp)."
+            )
+        print(f"  Successfully loaded {len(images)} images")
+        # Use filenames as identifiers instead of full paths for cleaner metadata
+        filepaths = [os.path.basename(fp) for fp in filepaths]
+    elif USE_HF_DATASET:
         from datasets import load_dataset, concatenate_datasets, DatasetDict
 
         # Determine which datasets to load
@@ -621,7 +663,6 @@ else:
     except Exception:
         print(f"Saved retrieved page (rank {rank}) to: {out_path}")
 
-## TODO stange results of second page of DeepSeek-V2 rather than the first page
 
 # %%
 # Step 6: Similarity maps for top-K results
@@ -53,6 +53,11 @@ leann build my-project --docs $(git ls-files)
 # Start Claude Code
 claude
 ```
+**Performance tip**: For maximum speed when storage space is not a concern, add the `--no-recompute` flag to your build command. This materializes all embeddings and stores them on disk, avoiding recomputation at query time:
+
+```bash
+leann build my-project --docs $(git ls-files) --no-recompute
+```
 
 ## 🚀 Advanced Usage Examples