reproduce docvqa results and add debug file

This commit is contained in:
yichuan-w
2025-12-03 08:54:55 +00:00
parent 07afe546ea
commit 1c690e4a8a
5 changed files with 450 additions and 222 deletions

View File

@@ -83,7 +83,7 @@ INDEX_PATH: str = "./indexes/colvision_large.leann"
# These are now command-line arguments (see CLI overrides section)
TOPK: int = 3
FIRST_STAGE_K: int = 500
REBUILD_INDEX: bool = False
REBUILD_INDEX: bool = True
# Artifacts
SAVE_TOP_IMAGE: Optional[str] = "./figures/retrieved_page.png"
@@ -122,11 +122,18 @@ parser.add_argument(
default="./indexes/colvision_fastplaid",
help="Path to the Fast-Plaid index. Default: './indexes/colvision_fastplaid'",
)
parser.add_argument(
"--topk",
type=int,
default=TOPK,
help=f"Number of top results to retrieve. Default: {TOPK}",
)
cli_args, _unknown = parser.parse_known_args()
SEARCH_METHOD: str = cli_args.search_method
QUERY = cli_args.query # Override QUERY with CLI argument if provided
USE_FAST_PLAID: bool = cli_args.use_fast_plaid
FAST_PLAID_INDEX_PATH: str = cli_args.fast_plaid_index_path
TOPK: int = cli_args.topk # Override TOPK with CLI argument if provided
# %%
@@ -399,7 +406,7 @@ if need_to_build_index:
f"No images found in {PAGES_DIR}. Provide PDF path in PDF variable or ensure images exist."
)
print(f"Loaded {len(images)} images")
# Memory check before loading model
try:
import psutil
@@ -426,10 +433,10 @@ try:
import sys
print(f" Python version: {sys.version}")
print(f" Python executable: {sys.executable}")
model_name, model, processor, device_str, device, dtype = _load_colvision(MODEL)
print(f"✓ Using model={model_name}, device={device_str}, dtype={dtype}")
# Memory check after loading model
try:
import psutil
@@ -457,7 +464,7 @@ if need_to_build_index:
print("Step 4: Building index...")
print(f" Number of images: {len(images)}")
print(f" Number of filepaths: {len(filepaths)}")
try:
print(" Embedding images...")
doc_vecs = _embed_images(model, processor, images)
@@ -468,7 +475,7 @@ if need_to_build_index:
import traceback
traceback.print_exc()
raise
if USE_FAST_PLAID:
# Build Fast-Plaid index
print(" Building Fast-Plaid index...")
@@ -523,13 +530,13 @@ if USE_FAST_PLAID:
fast_plaid_index = _load_fast_plaid_index_if_exists(FAST_PLAID_INDEX_PATH)
if fast_plaid_index is None:
raise RuntimeError(f"Fast-Plaid index not found at {FAST_PLAID_INDEX_PATH}")
results, search_secs = _search_fast_plaid(fast_plaid_index, q_vec, TOPK)
print(f"[Timing] Fast-Plaid Search: {search_secs:.3f}s")
else:
# Original LEANN search
query_np = q_vec.float().numpy()
if SEARCH_METHOD == "ann":
results = retriever.search(query_np, topk=TOPK, first_stage_k=FIRST_STAGE_K)
search_secs = time.perf_counter() - _t0
@@ -548,7 +555,10 @@ if not results:
print("No results found.")
else:
print(f'Top {len(results)} results for query: "{QUERY}"')
print("\n[DEBUG] Retrieval details:")
top_images: list[Image.Image] = []
image_hashes = {} # Track image hashes to detect duplicates
for rank, (score, doc_id) in enumerate(results, start=1):
# Retrieve image and metadata based on index type
if USE_FAST_PLAID:
@@ -557,7 +567,7 @@ else:
if image is None:
print(f"Warning: Could not find image for doc_id {doc_id}")
continue
metadata = _get_fast_plaid_metadata(FAST_PLAID_INDEX_PATH, doc_id)
path = metadata.get("filepath", f"doc_{doc_id}") if metadata else f"doc_{doc_id}"
top_images.append(image)
@@ -571,9 +581,27 @@ else:
metadata = retriever.get_metadata(doc_id)
path = metadata.get("filepath", "unknown") if metadata else "unknown"
top_images.append(image)
# For HF dataset, path is a descriptive identifier, not a real file path
print(f"{rank}) MaxSim: {score:.4f}, Page: {path}")
# Calculate image hash to detect duplicates
import hashlib
import io
# Convert image to bytes for hashing
img_bytes = io.BytesIO()
image.save(img_bytes, format='PNG')
image_bytes = img_bytes.getvalue()
image_hash = hashlib.md5(image_bytes).hexdigest()[:8]
# Check if this image was already seen
duplicate_info = ""
if image_hash in image_hashes:
duplicate_info = f" [DUPLICATE of rank {image_hashes[image_hash]}]"
else:
image_hashes[image_hash] = rank
# Print detailed information
print(f"{rank}) doc_id={doc_id}, MaxSim={score:.4f}, Page={path}, ImageHash={image_hash}{duplicate_info}")
if metadata:
print(f" Metadata: {metadata}")
if SAVE_TOP_IMAGE:
from pathlib import Path as _Path