update vidore
This commit is contained in:
@@ -2,13 +2,18 @@
|
||||
# %%
|
||||
# uv pip install matplotlib qwen_vl_utils
|
||||
import argparse
|
||||
import faulthandler
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
# Enable faulthandler to get stack trace on segfault
|
||||
faulthandler.enable()
|
||||
|
||||
|
||||
from leann_multi_vector import ( # utility functions/classes
|
||||
_ensure_repo_paths_importable,
|
||||
@@ -20,6 +25,11 @@ from leann_multi_vector import ( # utility functions/classes
|
||||
_build_index,
|
||||
_load_retriever_if_index_exists,
|
||||
_generate_similarity_map,
|
||||
_build_fast_plaid_index,
|
||||
_load_fast_plaid_index_if_exists,
|
||||
_search_fast_plaid,
|
||||
_get_fast_plaid_image,
|
||||
_get_fast_plaid_metadata,
|
||||
QwenVL,
|
||||
)
|
||||
|
||||
@@ -69,6 +79,8 @@ PAGES_DIR: str = "./pages"
|
||||
# Index + retrieval settings
|
||||
# Use a different index path for larger dataset to avoid overwriting existing index
|
||||
INDEX_PATH: str = "./indexes/colvision_large.leann"
|
||||
# Fast-Plaid index settings (alternative to LEANN index)
|
||||
# These are now command-line arguments (see CLI overrides section)
|
||||
TOPK: int = 3
|
||||
FIRST_STAGE_K: int = 500
|
||||
REBUILD_INDEX: bool = False
|
||||
@@ -98,24 +110,64 @@ parser.add_argument(
|
||||
default=QUERY,
|
||||
help=f"Query string to search for. Default: '{QUERY}'",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use-fast-plaid",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Set to True to use fast-plaid instead of LEANN. Default: False",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fast-plaid-index-path",
|
||||
type=str,
|
||||
default="./indexes/colvision_fastplaid",
|
||||
help="Path to the Fast-Plaid index. Default: './indexes/colvision_fastplaid'",
|
||||
)
|
||||
cli_args, _unknown = parser.parse_known_args()
|
||||
SEARCH_METHOD: str = cli_args.search_method
|
||||
QUERY = cli_args.query # Override QUERY with CLI argument if provided
|
||||
USE_FAST_PLAID: bool = cli_args.use_fast_plaid
|
||||
FAST_PLAID_INDEX_PATH: str = cli_args.fast_plaid_index_path
|
||||
|
||||
# %%
|
||||
|
||||
# Step 1: Check if we can skip data loading (index already exists)
|
||||
retriever: Optional[Any] = None
|
||||
fast_plaid_index: Optional[Any] = None
|
||||
need_to_build_index = REBUILD_INDEX
|
||||
|
||||
if not REBUILD_INDEX:
|
||||
retriever = _load_retriever_if_index_exists(INDEX_PATH)
|
||||
if retriever is not None:
|
||||
print(f"✓ Index loaded from {INDEX_PATH}")
|
||||
print(f"✓ Images available at: {retriever._images_dir_path()}")
|
||||
need_to_build_index = False
|
||||
if USE_FAST_PLAID:
|
||||
# Fast-Plaid index handling
|
||||
if not REBUILD_INDEX:
|
||||
try:
|
||||
fast_plaid_index = _load_fast_plaid_index_if_exists(FAST_PLAID_INDEX_PATH)
|
||||
if fast_plaid_index is not None:
|
||||
print(f"✓ Fast-Plaid index found at {FAST_PLAID_INDEX_PATH}")
|
||||
need_to_build_index = False
|
||||
else:
|
||||
print(f"Fast-Plaid index not found, will build new index")
|
||||
need_to_build_index = True
|
||||
except Exception as e:
|
||||
# If loading fails (e.g., memory error, corrupted index), rebuild
|
||||
print(f"Warning: Failed to load Fast-Plaid index: {e}")
|
||||
print("Will rebuild the index...")
|
||||
need_to_build_index = True
|
||||
fast_plaid_index = None
|
||||
else:
|
||||
print(f"Index not found, will build new index")
|
||||
print(f"REBUILD_INDEX=True, will rebuild Fast-Plaid index")
|
||||
need_to_build_index = True
|
||||
else:
|
||||
# Original LEANN index handling
|
||||
if not REBUILD_INDEX:
|
||||
retriever = _load_retriever_if_index_exists(INDEX_PATH)
|
||||
if retriever is not None:
|
||||
print(f"✓ Index loaded from {INDEX_PATH}")
|
||||
print(f"✓ Images available at: {retriever._images_dir_path()}")
|
||||
need_to_build_index = False
|
||||
else:
|
||||
print(f"Index not found, will build new index")
|
||||
need_to_build_index = True
|
||||
else:
|
||||
print(f"REBUILD_INDEX=True, will rebuild index")
|
||||
need_to_build_index = True
|
||||
|
||||
# Step 2: Load data only if we need to build the index
|
||||
@@ -347,6 +399,19 @@ if need_to_build_index:
|
||||
f"No images found in {PAGES_DIR}. Provide PDF path in PDF variable or ensure images exist."
|
||||
)
|
||||
print(f"Loaded {len(images)} images")
|
||||
|
||||
# Memory check before loading model
|
||||
try:
|
||||
import psutil
|
||||
import torch
|
||||
process = psutil.Process(os.getpid())
|
||||
mem_info = process.memory_info()
|
||||
print(f"Memory usage after loading images: {mem_info.rss / 1024 / 1024 / 1024:.2f} GB")
|
||||
if torch.cuda.is_available():
|
||||
print(f"GPU memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
|
||||
print(f"GPU memory reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
print("Skipping dataset loading (using existing index)")
|
||||
filepaths = [] # Not needed when using existing index
|
||||
@@ -355,23 +420,91 @@ else:
|
||||
|
||||
# %%
|
||||
# Step 3: Load model and processor (only if we need to build index or perform search)
|
||||
model_name, model, processor, device_str, device, dtype = _load_colvision(MODEL)
|
||||
print(f"Using model={model_name}, device={device_str}, dtype={dtype}")
|
||||
print("Step 3: Loading model and processor...")
|
||||
print(f" Model: {MODEL}")
|
||||
try:
|
||||
import sys
|
||||
print(f" Python version: {sys.version}")
|
||||
print(f" Python executable: {sys.executable}")
|
||||
|
||||
model_name, model, processor, device_str, device, dtype = _load_colvision(MODEL)
|
||||
print(f"✓ Using model={model_name}, device={device_str}, dtype={dtype}")
|
||||
|
||||
# Memory check after loading model
|
||||
try:
|
||||
import psutil
|
||||
import torch
|
||||
process = psutil.Process(os.getpid())
|
||||
mem_info = process.memory_info()
|
||||
print(f" Memory usage after loading model: {mem_info.rss / 1024 / 1024 / 1024:.2f} GB")
|
||||
if torch.cuda.is_available():
|
||||
print(f" GPU memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
|
||||
print(f" GPU memory reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception as e:
|
||||
print(f"✗ Error loading model: {type(e).__name__}: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
|
||||
# %%
|
||||
|
||||
# %%
|
||||
# Step 4: Build index if needed
|
||||
if need_to_build_index and retriever is None:
|
||||
print("Building index...")
|
||||
doc_vecs = _embed_images(model, processor, images)
|
||||
retriever = _build_index(INDEX_PATH, doc_vecs, filepaths, images)
|
||||
print(f"✓ Index built and images saved to: {retriever._images_dir_path()}")
|
||||
# Clear memory
|
||||
del images, filepaths, doc_vecs
|
||||
if need_to_build_index:
|
||||
print("Step 4: Building index...")
|
||||
print(f" Number of images: {len(images)}")
|
||||
print(f" Number of filepaths: {len(filepaths)}")
|
||||
|
||||
try:
|
||||
print(" Embedding images...")
|
||||
doc_vecs = _embed_images(model, processor, images)
|
||||
print(f" Embedded {len(doc_vecs)} documents")
|
||||
print(f" First doc vec shape: {doc_vecs[0].shape if len(doc_vecs) > 0 else 'N/A'}")
|
||||
except Exception as e:
|
||||
print(f"Error embedding images: {type(e).__name__}: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
if USE_FAST_PLAID:
|
||||
# Build Fast-Plaid index
|
||||
print(" Building Fast-Plaid index...")
|
||||
try:
|
||||
fast_plaid_index, build_secs = _build_fast_plaid_index(
|
||||
FAST_PLAID_INDEX_PATH, doc_vecs, filepaths, images
|
||||
)
|
||||
from pathlib import Path
|
||||
print(f"✓ Fast-Plaid index built in {build_secs:.3f}s")
|
||||
print(f"✓ Index saved to: {FAST_PLAID_INDEX_PATH}")
|
||||
print(f"✓ Images saved to: {Path(FAST_PLAID_INDEX_PATH) / 'images'}")
|
||||
except Exception as e:
|
||||
print(f"Error building Fast-Plaid index: {type(e).__name__}: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
raise
|
||||
finally:
|
||||
# Clear memory
|
||||
print(" Clearing memory...")
|
||||
del images, filepaths, doc_vecs
|
||||
else:
|
||||
# Build original LEANN index
|
||||
try:
|
||||
retriever = _build_index(INDEX_PATH, doc_vecs, filepaths, images)
|
||||
print(f"✓ Index built and images saved to: {retriever._images_dir_path()}")
|
||||
except Exception as e:
|
||||
print(f"Error building LEANN index: {type(e).__name__}: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
raise
|
||||
finally:
|
||||
# Clear memory
|
||||
print(" Clearing memory...")
|
||||
del images, filepaths, doc_vecs
|
||||
|
||||
# Note: Images are now stored in the index, retriever will load them on-demand from disk
|
||||
# Note: Images are now stored separately, retriever/fast_plaid_index will reference them
|
||||
|
||||
|
||||
# %%
|
||||
@@ -380,44 +513,67 @@ _t0 = time.perf_counter()
|
||||
q_vec = _embed_queries(model, processor, [QUERY])[0]
|
||||
query_embed_secs = time.perf_counter() - _t0
|
||||
|
||||
query_np = q_vec.float().numpy()
|
||||
|
||||
print(f"[Search] Method: {SEARCH_METHOD}")
|
||||
print(f"[Timing] Query embedding: {query_embed_secs:.3f}s")
|
||||
|
||||
# Run the selected search method and time it
|
||||
_t0 = time.perf_counter()
|
||||
if SEARCH_METHOD == "ann":
|
||||
results = retriever.search(query_np, topk=TOPK, first_stage_k=FIRST_STAGE_K)
|
||||
search_secs = time.perf_counter() - _t0
|
||||
print(f"[Timing] Search (ANN): {search_secs:.3f}s (first_stage_k={FIRST_STAGE_K})")
|
||||
elif SEARCH_METHOD == "exact":
|
||||
results = retriever.search_exact(query_np, topk=TOPK, first_stage_k=FIRST_STAGE_K)
|
||||
search_secs = time.perf_counter() - _t0
|
||||
print(f"[Timing] Search (Exact rerank): {search_secs:.3f}s (first_stage_k={FIRST_STAGE_K})")
|
||||
elif SEARCH_METHOD == "exact-all":
|
||||
results = retriever.search_exact_all(query_np, topk=TOPK)
|
||||
search_secs = time.perf_counter() - _t0
|
||||
print(f"[Timing] Search (Exact all): {search_secs:.3f}s")
|
||||
if USE_FAST_PLAID:
|
||||
# Fast-Plaid search
|
||||
if fast_plaid_index is None:
|
||||
fast_plaid_index = _load_fast_plaid_index_if_exists(FAST_PLAID_INDEX_PATH)
|
||||
if fast_plaid_index is None:
|
||||
raise RuntimeError(f"Fast-Plaid index not found at {FAST_PLAID_INDEX_PATH}")
|
||||
|
||||
results, search_secs = _search_fast_plaid(fast_plaid_index, q_vec, TOPK)
|
||||
print(f"[Timing] Fast-Plaid Search: {search_secs:.3f}s")
|
||||
else:
|
||||
results = []
|
||||
# Original LEANN search
|
||||
query_np = q_vec.float().numpy()
|
||||
|
||||
if SEARCH_METHOD == "ann":
|
||||
results = retriever.search(query_np, topk=TOPK, first_stage_k=FIRST_STAGE_K)
|
||||
search_secs = time.perf_counter() - _t0
|
||||
print(f"[Timing] Search (ANN): {search_secs:.3f}s (first_stage_k={FIRST_STAGE_K})")
|
||||
elif SEARCH_METHOD == "exact":
|
||||
results = retriever.search_exact(query_np, topk=TOPK, first_stage_k=FIRST_STAGE_K)
|
||||
search_secs = time.perf_counter() - _t0
|
||||
print(f"[Timing] Search (Exact rerank): {search_secs:.3f}s (first_stage_k={FIRST_STAGE_K})")
|
||||
elif SEARCH_METHOD == "exact-all":
|
||||
results = retriever.search_exact_all(query_np, topk=TOPK)
|
||||
search_secs = time.perf_counter() - _t0
|
||||
print(f"[Timing] Search (Exact all): {search_secs:.3f}s")
|
||||
else:
|
||||
results = []
|
||||
if not results:
|
||||
print("No results found.")
|
||||
else:
|
||||
print(f'Top {len(results)} results for query: "{QUERY}"')
|
||||
top_images: list[Image.Image] = []
|
||||
for rank, (score, doc_id) in enumerate(results, start=1):
|
||||
# Retrieve image from index instead of memory
|
||||
image = retriever.get_image(doc_id)
|
||||
if image is None:
|
||||
print(f"Warning: Could not retrieve image for doc_id {doc_id}")
|
||||
continue
|
||||
# Retrieve image and metadata based on index type
|
||||
if USE_FAST_PLAID:
|
||||
# Fast-Plaid: load image and get metadata
|
||||
image = _get_fast_plaid_image(FAST_PLAID_INDEX_PATH, doc_id)
|
||||
if image is None:
|
||||
print(f"Warning: Could not find image for doc_id {doc_id}")
|
||||
continue
|
||||
|
||||
metadata = _get_fast_plaid_metadata(FAST_PLAID_INDEX_PATH, doc_id)
|
||||
path = metadata.get("filepath", f"doc_{doc_id}") if metadata else f"doc_{doc_id}"
|
||||
top_images.append(image)
|
||||
else:
|
||||
# Original LEANN: retrieve from retriever
|
||||
image = retriever.get_image(doc_id)
|
||||
if image is None:
|
||||
print(f"Warning: Could not retrieve image for doc_id {doc_id}")
|
||||
continue
|
||||
|
||||
metadata = retriever.get_metadata(doc_id)
|
||||
path = metadata.get("filepath", "unknown") if metadata else "unknown"
|
||||
metadata = retriever.get_metadata(doc_id)
|
||||
path = metadata.get("filepath", "unknown") if metadata else "unknown"
|
||||
top_images.append(image)
|
||||
|
||||
# For HF dataset, path is a descriptive identifier, not a real file path
|
||||
print(f"{rank}) MaxSim: {score:.4f}, Page: {path}")
|
||||
top_images.append(image)
|
||||
|
||||
if SAVE_TOP_IMAGE:
|
||||
from pathlib import Path as _Path
|
||||
|
||||
Reference in New Issue
Block a user