add ColQwen2 support and pass ruff

yichuan-w
2025-09-22 22:01:29 +00:00
parent 72455bb269
commit 94d9a203a2
7 changed files with 98815 additions and 99376 deletions


@@ -1,13 +1,12 @@
## Jupyter-style notebook script
#%%
# %%
# uv pip install matplotlib qwen_vl_utils
import os
import re
import sys
from pathlib import Path
from typing import List, Optional, Tuple, cast, Any
from typing import Any, Optional, cast
import numpy as np
from PIL import Image
from tqdm import tqdm
@@ -27,7 +26,7 @@ _ensure_repo_paths_importable(__file__)
from leann_multi_vector import LeannMultiVector # noqa: E402
#%%
# %%
# Config
os.environ["TOKENIZERS_PARALLELISM"] = "false"
QUERY = "How does DeepSeek-V2 compare against the LLaMA family of LLMs?"
@@ -46,26 +45,26 @@ PAGES_DIR: str = "./pages"
# Index + retrieval settings
INDEX_PATH: str = "./indexes/colvision.leann"
TOPK: int = 1
FIRST_STAGE_K: int = 50
FIRST_STAGE_K: int = 500
REBUILD_INDEX: bool = False
# Artifacts
SAVE_TOP_IMAGE: Optional[str] = "./figures/retrieved_page.png"
SIMILARITY_MAP: bool = True
SIM_TOKEN_IDX: int = -1 # -1 means auto-select the most salient token
SIM_TOKEN_IDX: int = 13  # query-token index for the similarity map; -1 means auto-select the most salient token
SIM_OUTPUT: str = "./figures/similarity_map.png"
ANSWER: bool = True
MAX_NEW_TOKENS: int = 128
#%%
# %%
# Helpers
def _natural_sort_key(name: str) -> int:
m = re.search(r"\d+", name)
return int(m.group()) if m else 0
def _load_images_from_dir(pages_dir: str) -> Tuple[List[str], List[Image.Image]]:
def _load_images_from_dir(pages_dir: str) -> tuple[list[str], list[Image.Image]]:
filenames = [n for n in os.listdir(pages_dir) if n.lower().endswith((".png", ".jpg", ".jpeg"))]
filenames = sorted(filenames, key=_natural_sort_key)
filepaths = [os.path.join(pages_dir, n) for n in filenames]
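Not part of the diff: the config hunk above raises FIRST_STAGE_K from 50 to 500 and pins SIM_TOKEN_IDX to 13. Below is a minimal sketch of how the two retrieval knobs are consumed in Step 4 further down; reading first_stage_k as the size of the candidate pool before exact reranking is an assumption based on its name and the jump from 50 to 500.

q_vec = _embed_queries(model, processor, [QUERY])[0]  # (num_query_tokens, dim)
results = retriever.search(
    q_vec.float().numpy(),        # query token embeddings as float32 numpy
    topk=TOPK,                    # final number of pages returned (1 here)
    first_stage_k=FIRST_STAGE_K,  # assumed: candidate pages scanned before reranking
)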
@@ -80,7 +79,9 @@ def _maybe_convert_pdf_to_images(pdf_path: Optional[str], pages_dir: str, dpi: i
try:
from pdf2image import convert_from_path
except Exception as e:
raise RuntimeError("pdf2image is required to convert PDF to images. Install via pip install pdf2image") from e
raise RuntimeError(
"pdf2image is required to convert PDF to images. Install via pip install pdf2image"
) from e
images = convert_from_path(pdf_path, dpi=dpi)
for i, image in enumerate(images):
image.save(os.path.join(pages_dir, f"page_{i + 1}.png"), "PNG")
@@ -93,7 +94,11 @@ def _select_device_and_dtype():
device_str = (
"cuda"
if torch.cuda.is_available()
else ("mps" if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available() else "cpu")
else (
"mps"
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
else "cpu"
)
)
device = get_torch_device(device_str)
# Stable dtype selection to avoid NaNs:
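Not part of the diff: a standalone sketch of the device and dtype policy that this hunk only reformats, assuming the same preference order (CUDA, then MPS, then CPU) and the fp32-on-CPU/MPS rule stated in the autocast comments later in the script.

import torch

def select_device_and_dtype_sketch():
    # Prefer CUDA, then Apple MPS, then CPU.
    if torch.cuda.is_available():
        device_str = "cuda"
    elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        device_str = "mps"
    else:
        device_str = "cpu"
    # Conservative dtype: bf16 only on CUDA; fp32 elsewhere to avoid NaNs.
    dtype = torch.bfloat16 if device_str == "cuda" else torch.float32
    return device_str, dtype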
@@ -115,17 +120,20 @@ def _select_device_and_dtype():
def _load_colvision(model_choice: str):
import torch
from colpali_engine.models import ColPali
from colpali_engine.models import ColPali, ColQwen2, ColQwen2Processor
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
from transformers.utils.import_utils import is_flash_attn_2_available
from colpali_engine.models import ColQwen2, ColQwen2Processor
device_str, device, dtype = _select_device_and_dtype()
if model_choice == "colqwen2":
model_name = "vidore/colqwen2-v1.0"
# On CPU/MPS we must avoid flash-attn and stay eager; on CUDA prefer flash-attn if available
attn_implementation = "flash_attention_2" if (device_str == "cuda" and is_flash_attn_2_available()) else "eager"
attn_implementation = (
"flash_attention_2"
if (device_str == "cuda" and is_flash_attn_2_available())
else "eager"
)
model = ColQwen2.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
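Not part of the diff: a minimal loading sketch for the colqwen2 branch. The hunk truncates before the remaining from_pretrained kwargs, so device_map here is an assumption; attn_implementation mirrors the selection computed just above.

import torch
from colpali_engine.models import ColQwen2, ColQwen2Processor

model = ColQwen2.from_pretrained(
    "vidore/colqwen2-v1.0",
    torch_dtype=torch.bfloat16,
    device_map=device_str,                    # assumption: place on the selected device
    attn_implementation=attn_implementation,  # "flash_attention_2" on CUDA, else "eager"
).eval()
processor = ColQwen2Processor.from_pretrained("vidore/colqwen2-v1.0")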
@@ -145,7 +153,7 @@ def _load_colvision(model_choice: str):
return model_name, model, processor, device_str, device, dtype
def _embed_images(model, processor, images: List[Image.Image]) -> List[Any]:
def _embed_images(model, processor, images: list[Image.Image]) -> list[Any]:
import torch
from colpali_engine.utils.torch_utils import ListDataset
from torch.utils.data import DataLoader
@@ -160,13 +168,16 @@ def _embed_images(model, processor, images: List[Image.Image]) -> List[Any]:
collate_fn=lambda x: processor.process_images(x),
)
doc_vecs: List[Any] = []
doc_vecs: list[Any] = []
for batch_doc in dataloader:
with torch.no_grad():
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
# autocast on CUDA for bf16/fp16; on CPU/MPS stay in fp32
if model.device.type == "cuda":
with torch.autocast(device_type="cuda", dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16):
with torch.autocast(
device_type="cuda",
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
):
embeddings_doc = model(**batch_doc)
else:
embeddings_doc = model(**batch_doc)
@@ -174,7 +185,7 @@ def _embed_images(model, processor, images: List[Image.Image]) -> List[Any]:
return doc_vecs
def _embed_queries(model, processor, queries: List[str]) -> List[Any]:
def _embed_queries(model, processor, queries: list[str]) -> list[Any]:
import torch
from colpali_engine.utils.torch_utils import ListDataset
from torch.utils.data import DataLoader
@@ -188,12 +199,15 @@ def _embed_queries(model, processor, queries: List[str]) -> List[Any]:
collate_fn=lambda x: processor.process_queries(x),
)
q_vecs: List[Any] = []
q_vecs: list[Any] = []
for batch_query in dataloader:
with torch.no_grad():
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
if model.device.type == "cuda":
with torch.autocast(device_type="cuda", dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16):
with torch.autocast(
device_type="cuda",
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
):
embeddings_query = model(**batch_query)
else:
embeddings_query = model(**batch_query)
@@ -201,7 +215,7 @@ def _embed_queries(model, processor, queries: List[str]) -> List[Any]:
return q_vecs
def _build_index(index_path: str, doc_vecs: List[Any], filepaths: List[str]) -> LeannMultiVector:
def _build_index(index_path: str, doc_vecs: list[Any], filepaths: list[str]) -> LeannMultiVector:
dim = int(doc_vecs[0].shape[-1])
retriever = LeannMultiVector(index_path=index_path, dim=dim)
retriever.create_collection()
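Not part of the diff: a sketch of the multi-vector shapes feeding _build_index. Each page is embedded as a matrix of token vectors, and the index dimension is read off the last axis exactly as in the hunk above; 128 as the typical projection size for ColPali/ColQwen2 checkpoints is an assumption. The per-document insertion calls fall outside this hunk and are omitted.

doc_vecs = _embed_images(model, processor, images)  # one (num_image_tokens, dim) tensor per page
dim = int(doc_vecs[0].shape[-1])                    # typically 128 for ColPali/ColQwen2
retriever = _build_index(INDEX_PATH, doc_vecs, filepaths)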
@@ -233,7 +247,7 @@ def _generate_similarity_map(
query: str,
token_idx: Optional[int] = None,
output_path: Optional[str] = None,
) -> Tuple[int, float]:
) -> tuple[int, float]:
import torch
from colpali_engine.interpretability import (
get_similarity_maps_from_embeddings,
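Not part of the diff: the script delegates to colpali_engine's interpretability helpers, but the underlying idea can be sketched in plain torch, assuming L2-normalized embeddings and ignoring the padding and special-token handling the library performs for real.

import torch

def token_patch_similarity_sketch(q_vec, doc_vec, token_idx, n_rows, n_cols):
    # q_vec: (num_query_tokens, dim); doc_vec: (num_image_patches, dim).
    # With normalized embeddings, a dot product behaves like cosine similarity.
    sims = doc_vec @ q_vec[token_idx]     # similarity of one query token to every patch
    return sims.reshape(n_rows, n_cols)   # heatmap over the page's patch grid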
@@ -288,8 +302,8 @@ def _generate_similarity_map(
class QwenVL:
def __init__(self, device: str):
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
from transformers.utils.import_utils import is_flash_attn_2_available
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
attn_implementation = "flash_attention_2" if is_flash_attn_2_available() else "eager"
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
@@ -305,11 +319,12 @@ class QwenVL:
"Qwen/Qwen2.5-VL-3B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels
)
def answer(self, query: str, images: List[Image.Image], max_new_tokens: int = 128) -> str:
from qwen_vl_utils import process_vision_info
def answer(self, query: str, images: list[Image.Image], max_new_tokens: int = 128) -> str:
import base64
from io import BytesIO
from qwen_vl_utils import process_vision_info
content = []
for img in images:
buffer = BytesIO()
@@ -319,17 +334,25 @@ class QwenVL:
content.append({"type": "text", "text": query})
messages = [{"role": "user", "content": content}]
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
text = self.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = self.processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt")
inputs = self.processor(
text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt"
)
inputs = inputs.to(self.model.device)
generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
generated_ids_trimmed = [
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
return self.processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0]
#%%
# %%
# Step 1: Prepare data
if USE_HF_DATASET:
@@ -337,32 +360,33 @@ if USE_HF_DATASET:
dataset = load_dataset(DATASET_NAME, split=DATASET_SPLIT)
N = len(dataset) if MAX_DOCS is None else min(MAX_DOCS, len(dataset))
filepaths: List[str] = []
images: List[Image.Image] = []
filepaths: list[str] = []
images: list[Image.Image] = []
for i in tqdm(range(N), desc="Loading dataset"):
p = dataset[i]
# Compose a descriptive identifier for printing later
identifier = (
f"arXiv:{p['paper_arxiv_id']}|title:{p['paper_title']}|page:{int(p['page_number'])}|id:{p['page_id']}"
)
identifier = f"arXiv:{p['paper_arxiv_id']}|title:{p['paper_title']}|page:{int(p['page_number'])}|id:{p['page_id']}"
print(identifier)
filepaths.append(identifier)
images.append(p["page_image"]) # PIL Image
else:
_maybe_convert_pdf_to_images(PDF, PAGES_DIR)
filepaths, images = _load_images_from_dir(PAGES_DIR)
if not images:
raise RuntimeError(f"No images found in {PAGES_DIR}. Provide PDF path in PDF variable or ensure images exist.")
raise RuntimeError(
f"No images found in {PAGES_DIR}. Provide PDF path in PDF variable or ensure images exist."
)
#%%
# %%
# Step 2: Load model and processor
model_name, model, processor, device_str, device, dtype = _load_colvision(MODEL)
print(f"Using model={model_name}, device={device_str}, dtype={dtype}")
#%%
# %%
#%%
# %%
# Step 3: Build or load index
retriever: Optional[LeannMultiVector] = None
if not REBUILD_INDEX:
@@ -377,8 +401,7 @@ if retriever is None:
retriever = _build_index(INDEX_PATH, doc_vecs, filepaths)
#%%
# %%
# Step 4: Embed query and search
q_vec = _embed_queries(model, processor, [QUERY])[0]
results = retriever.search(q_vec.float().numpy(), topk=TOPK, first_stage_k=FIRST_STAGE_K)
@@ -386,7 +409,7 @@ if not results:
print("No results found.")
else:
print(f'Top {len(results)} results for query: "{QUERY}"')
top_images: List[Image.Image] = []
top_images: list[Image.Image] = []
for rank, (score, doc_id) in enumerate(results, start=1):
path = filepaths[doc_id]
# For HF dataset, path is a descriptive identifier, not a real file path
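Not part of the diff: LeannMultiVector's scoring is not shown in this commit, but ColPali/ColQwen-style retrieval is conventionally ranked with late-interaction MaxSim, sketched below as a reading aid (the standard formulation, not necessarily the index's exact implementation).

import numpy as np

def maxsim_score(q_vec: np.ndarray, doc_vec: np.ndarray) -> float:
    # q_vec: (num_query_tokens, dim); doc_vec: (num_doc_tokens, dim).
    sims = q_vec @ doc_vec.T              # (num_query_tokens, num_doc_tokens)
    return float(sims.max(axis=1).sum())  # best doc token per query token, summed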
@@ -395,9 +418,10 @@ else:
if SAVE_TOP_IMAGE:
from pathlib import Path as _Path
base = _Path(SAVE_TOP_IMAGE)
base.parent.mkdir(parents=True, exist_ok=True)
for rank, img in enumerate(top_images[: TOPK], start=1):
for rank, img in enumerate(top_images[:TOPK], start=1):
if base.suffix:
out_path = base.parent / f"{base.stem}_rank{rank}{base.suffix}"
else:
@@ -405,14 +429,16 @@ else:
img.save(str(out_path))
print(f"Saved retrieved page (rank {rank}) to: {out_path}")
## TODO: strange result; retrieval returns the second page of DeepSeek-V2 rather than the first page
#%%
# %%
# Step 5: Similarity maps for top-K results
if results and SIMILARITY_MAP:
token_idx = None if SIM_TOKEN_IDX < 0 else int(SIM_TOKEN_IDX)
from pathlib import Path as _Path
output_base = _Path(SIM_OUTPUT) if SIM_OUTPUT else None
for rank, img in enumerate(top_images[: TOPK], start=1):
for rank, img in enumerate(top_images[:TOPK], start=1):
if output_base:
if output_base.suffix:
out_dir = output_base.parent
@@ -433,17 +459,19 @@ if results and SIMILARITY_MAP:
output_path=out_path,
)
if out_path:
print(f"Saved similarity map for rank {rank}, token #{chosen_idx} (max={max_sim:.2f}) to: {out_path}")
print(
f"Saved similarity map for rank {rank}, token #{chosen_idx} (max={max_sim:.2f}) to: {out_path}"
)
else:
print(f"Computed similarity map for rank {rank}, token #{chosen_idx} (max={max_sim:.2f})")
print(
f"Computed similarity map for rank {rank}, token #{chosen_idx} (max={max_sim:.2f})"
)
#%%
# %%
# Step 6: Optional answer generation
if results and ANSWER:
qwen = QwenVL(device=device_str)
response = qwen.answer(QUERY, top_images[: TOPK], max_new_tokens=MAX_NEW_TOKENS)
response = qwen.answer(QUERY, top_images[:TOPK], max_new_tokens=MAX_NEW_TOKENS)
print("\nAnswer:")
print(response)