Add custom folder support and improve image loading for multi-vector … (#188)
* Add custom folder support and improve image loading for multi-vector retrieval

  - Enhanced _load_images_from_dir with recursive search support and better error handling
  - Added support for WebP format and RGB conversion for all image modes
  - Added custom folder CLI arguments (--custom-folder, --recursive, --rebuild-index)
  - Improved documentation and removed completed TODO comment

  🤖 Generated with [Claude Code](https://claude.com/claude-code)

  Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

* Format code style in leann_multi_vector.py for better readability

  🤖 Generated with [Claude Code](https://claude.com/claude-code)

  Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

---------

Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
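For orientation, here is a minimal sketch of the behaviour this PR adds, independent of the script's own CLI: recursively collect the supported image formats from a folder and normalise every page to RGB. The folder path is a placeholder, and the snippet only approximates the updated loader shown in the diff below.

# Sketch only: gather .png/.jpg/.jpeg/.webp files under a folder (recursively)
# and convert each page to RGB, mirroring what the updated loader does.
import glob
import os

from PIL import Image

folder = "./screenshots"  # placeholder path, not part of this repo
patterns = ("*.png", "*.jpg", "*.jpeg", "*.webp")
paths = sorted(
    p
    for pat in patterns
    for p in glob.glob(os.path.join(folder, "**", pat), recursive=True)
)
images = [Image.open(p).convert("RGB") for p in paths]
print(f"Loaded {len(images)} images from {folder}")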
@@ -1,4 +1,5 @@
 import concurrent.futures
+import glob
 import json
 import os
 import re
@@ -96,12 +97,63 @@ def _natural_sort_key(name: str) -> int:
     return int(m.group()) if m else 0


-def _load_images_from_dir(pages_dir: str) -> tuple[list[str], list[Image.Image]]:
-    filenames = [n for n in os.listdir(pages_dir) if n.lower().endswith((".png", ".jpg", ".jpeg"))]
-    filenames = sorted(filenames, key=_natural_sort_key)
-    filepaths = [os.path.join(pages_dir, n) for n in filenames]
-    images = [Image.open(p) for p in filepaths]
-    return filepaths, images
+def _load_images_from_dir(
+    pages_dir: str, recursive: bool = False
+) -> tuple[list[str], list[Image.Image]]:
+    """
+    Load images from a directory.
+
+    Args:
+        pages_dir: Directory path containing images
+        recursive: If True, recursively search subdirectories (default: False)
+
+    Returns:
+        Tuple of (filepaths, images)
+    """
+
+    # Supported image extensions
+    extensions = ("*.png", "*.jpg", "*.jpeg", "*.PNG", "*.JPG", "*.JPEG", "*.webp", "*.WEBP")
+
+    if recursive:
+        # Recursive search
+        filepaths = []
+        for ext in extensions:
+            pattern = os.path.join(pages_dir, "**", ext)
+            filepaths.extend(glob.glob(pattern, recursive=True))
+    else:
+        # Non-recursive search (only top-level directory)
+        filepaths = []
+        for ext in extensions:
+            pattern = os.path.join(pages_dir, ext)
+            filepaths.extend(glob.glob(pattern))
+
+    # Sort files naturally
+    filepaths = sorted(filepaths, key=lambda x: _natural_sort_key(os.path.basename(x)))
+
+    # Load images with error handling
+    images = []
+    valid_filepaths = []
+    failed_count = 0
+
+    for filepath in filepaths:
+        try:
+            img = Image.open(filepath)
+            # Convert to RGB if necessary (handles RGBA, P, etc.)
+            if img.mode != "RGB":
+                img = img.convert("RGB")
+            images.append(img)
+            valid_filepaths.append(filepath)
+        except Exception as e:
+            failed_count += 1
+            print(f"Warning: Failed to load image {filepath}: {e}")
+            continue
+
+    if failed_count > 0:
+        print(
+            f"Warning: Failed to load {failed_count} image(s) out of {len(filepaths)} total files"
+        )
+
+    return valid_filepaths, images


 def _maybe_convert_pdf_to_images(pdf_path: Optional[str], pages_dir: str, dpi: int = 200) -> None:
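A short usage sketch for the reworked _load_images_from_dir (the directory path is a placeholder; natural sorting, WebP support, and per-file error handling happen inside the function):

# Assumes the function above is importable from leann_multi_vector.py.
filepaths, images = _load_images_from_dir("./pages", recursive=True)  # placeholder path
print(f"Loaded {len(images)} images")  # only files that opened successfully are returned
assert all(img.mode == "RGB" for img in images)  # every image is normalised to RGB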
@@ -62,7 +62,7 @@ DATASET_NAME: str = "weaviate/arXiv-AI-papers-multi-vector"
 # DATASET_NAMES: Optional[list[str | tuple[str, Optional[str]]]] = None
 DATASET_NAMES = [
     "weaviate/arXiv-AI-papers-multi-vector",
-    ("lmms-lab/DocVQA", "DocVQA"),  # Specify config name for datasets with multiple configs
+    # ("lmms-lab/DocVQA", "DocVQA"),  # Specify config name for datasets with multiple configs
 ]
 # Load multiple splits to get more data (e.g., ["train", "test", "validation"])
 # Set to None to try loading all available splits automatically
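The tuple entry that is commented out above exists because some Hugging Face datasets ship several configurations; a hedged sketch of how such an entry would be expanded with datasets.load_dataset (dataset and config names are copied from the commented line, the split choice is an assumption):

from datasets import load_dataset

entry = ("lmms-lab/DocVQA", "DocVQA")  # (dataset name, config name), from the commented line
if isinstance(entry, tuple):
    name, config = entry
    ds = load_dataset(name, config, split="validation")  # split choice is an assumption
else:
    ds = load_dataset(entry, split="validation")
print(ds)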
@@ -75,6 +75,11 @@ MAX_DOCS: Optional[int] = None  # limit number of pages to index; None = all
 # Local pages (used when USE_HF_DATASET == False)
 PDF: Optional[str] = None  # e.g., "./pdfs/2004.12832v2.pdf"
 PAGES_DIR: str = "./pages"
+# Custom folder path (takes precedence over USE_HF_DATASET and PAGES_DIR)
+# If set, images will be loaded directly from this folder
+CUSTOM_FOLDER_PATH: Optional[str] = None  # e.g., "/home/ubuntu/dr-tulu/agent/screenshots"
+# Whether to recursively search subdirectories when loading from custom folder
+CUSTOM_FOLDER_RECURSIVE: bool = False  # Set to True to search subdirectories

 # Index + retrieval settings
 # Use a different index path for larger dataset to avoid overwriting existing index
@@ -83,7 +88,7 @@ INDEX_PATH: str = "./indexes/colvision_large.leann"
 # These are now command-line arguments (see CLI overrides section)
 TOPK: int = 3
 FIRST_STAGE_K: int = 500
-REBUILD_INDEX: bool = True
+REBUILD_INDEX: bool = False  # Set to True to force rebuild even if index exists

 # Artifacts
 SAVE_TOP_IMAGE: Optional[str] = "./figures/retrieved_page.png"
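Flipping the REBUILD_INDEX default from True to False only pays off if the script checks for an existing index before rebuilding; a hedged sketch of that decision (the real need_to_build_index computation in the script may differ):

import os

INDEX_PATH = "./indexes/colvision_large.leann"
REBUILD_INDEX = False

def _need_to_build_index(index_path: str, rebuild_index: bool) -> bool:
    # Rebuild when explicitly requested, or when no index exists yet.
    return rebuild_index or not os.path.exists(index_path)

need_to_build_index = _need_to_build_index(INDEX_PATH, REBUILD_INDEX)
print("build index:", need_to_build_index)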
@@ -128,12 +133,33 @@ parser.add_argument(
     default=TOPK,
     help=f"Number of top results to retrieve. Default: {TOPK}",
 )
+parser.add_argument(
+    "--custom-folder",
+    type=str,
+    default=None,
+    help="Path to a custom folder containing images to search. Takes precedence over dataset loading. Default: None",
+)
+parser.add_argument(
+    "--recursive",
+    action="store_true",
+    default=False,
+    help="Recursively search subdirectories when loading images from custom folder. Default: False",
+)
+parser.add_argument(
+    "--rebuild-index",
+    action="store_true",
+    default=False,
+    help="Force rebuild the index even if it already exists. Default: False (reuse existing index if available)",
+)
 cli_args, _unknown = parser.parse_known_args()
 SEARCH_METHOD: str = cli_args.search_method
 QUERY = cli_args.query  # Override QUERY with CLI argument if provided
 USE_FAST_PLAID: bool = cli_args.use_fast_plaid
 FAST_PLAID_INDEX_PATH: str = cli_args.fast_plaid_index_path
 TOPK: int = cli_args.topk  # Override TOPK with CLI argument if provided
+CUSTOM_FOLDER_PATH = cli_args.custom_folder if cli_args.custom_folder else CUSTOM_FOLDER_PATH  # Override with CLI argument if provided
+CUSTOM_FOLDER_RECURSIVE = cli_args.recursive if cli_args.recursive else CUSTOM_FOLDER_RECURSIVE  # Override with CLI argument if provided
+REBUILD_INDEX = cli_args.rebuild_index  # Override REBUILD_INDEX with CLI argument

 # %%

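A self-contained sketch of how the three new flags parse; the simulated argv mimics a command line such as python leann_multi_vector.py --custom-folder ./screenshots --recursive --rebuild-index (flag names are taken from the diff, the folder path is a placeholder):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--custom-folder", type=str, default=None)
parser.add_argument("--recursive", action="store_true", default=False)
parser.add_argument("--rebuild-index", action="store_true", default=False)

# Simulated command line; the real script calls parse_known_args() on sys.argv.
cli_args, _unknown = parser.parse_known_args(
    ["--custom-folder", "./screenshots", "--recursive", "--rebuild-index"]
)
print(cli_args.custom_folder, cli_args.recursive, cli_args.rebuild_index)
# -> ./screenshots True True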
@@ -180,7 +206,23 @@ else:
 # Step 2: Load data only if we need to build the index
 if need_to_build_index:
     print("Loading dataset...")
-    if USE_HF_DATASET:
+    # Check for custom folder path first (takes precedence)
+    if CUSTOM_FOLDER_PATH:
+        if not os.path.isdir(CUSTOM_FOLDER_PATH):
+            raise RuntimeError(f"Custom folder path does not exist: {CUSTOM_FOLDER_PATH}")
+        print(f"Loading images from custom folder: {CUSTOM_FOLDER_PATH}")
+        if CUSTOM_FOLDER_RECURSIVE:
+            print(" (recursive mode: searching subdirectories)")
+        filepaths, images = _load_images_from_dir(CUSTOM_FOLDER_PATH, recursive=CUSTOM_FOLDER_RECURSIVE)
+        print(f" Found {len(filepaths)} image files")
+        if not images:
+            raise RuntimeError(
+                f"No images found in {CUSTOM_FOLDER_PATH}. Ensure the folder contains image files (.png, .jpg, .jpeg, .webp)."
+            )
+        print(f" Successfully loaded {len(images)} images")
+        # Use filenames as identifiers instead of full paths for cleaner metadata
+        filepaths = [os.path.basename(fp) for fp in filepaths]
+    elif USE_HF_DATASET:
         from datasets import load_dataset, concatenate_datasets, DatasetDict

         # Determine which datasets to load
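The branch above gives the custom folder priority over the Hugging Face dataset path; a condensed sketch of that precedence (the helper name and return labels are hypothetical stand-ins, only the ordering is taken from the diff):

def _select_image_source(custom_folder_path, use_hf_dataset):
    # Custom folder wins; otherwise fall back to the HF dataset, then local pages.
    if custom_folder_path:
        return "custom_folder"
    if use_hf_dataset:
        return "hf_dataset"
    return "local_pages"

assert _select_image_source("/data/screenshots", True) == "custom_folder"  # placeholder path
assert _select_image_source(None, True) == "hf_dataset"
assert _select_image_source(None, False) == "local_pages"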
@@ -621,7 +663,6 @@ else:
         except Exception:
             print(f"Saved retrieved page (rank {rank}) to: {out_path}")

-## TODO stange results of second page of DeepSeek-V2 rather than the first page

 # %%
 # Step 6: Similarity maps for top-K results