robust multi-vector
This commit is contained in:
@@ -1,13 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import concurrent.futures
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def _ensure_repo_paths_importable(current_file: str) -> None:
|
||||
"""Make local leann packages importable without installing (mirrors multi-vector-leann.py)."""
|
||||
_repo_root = Path(current_file).resolve().parents[3]
|
||||
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
|
||||
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
|
||||
@@ -17,6 +22,380 @@ def _ensure_repo_paths_importable(current_file: str) -> None:
|
||||
sys.path.append(str(_leann_hnsw_pkg))
|
||||
|
||||
|
||||
def _find_backend_module_file() -> Optional[Path]:
|
||||
"""Best-effort locate the backend leann_multi_vector.py file, avoiding this file."""
|
||||
this_file = Path(__file__).resolve()
|
||||
candidates: list[Path] = []
|
||||
|
||||
# Common in-repo location
|
||||
repo_root = this_file.parents[3]
|
||||
candidates.append(repo_root / "packages" / "leann-backend-hnsw" / "leann_multi_vector.py")
|
||||
candidates.append(
|
||||
repo_root / "packages" / "leann-backend-hnsw" / "src" / "leann_multi_vector.py"
|
||||
)
|
||||
|
||||
for cand in candidates:
|
||||
try:
|
||||
if cand.exists() and cand.resolve() != this_file:
|
||||
return cand.resolve()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Fallback: scan sys.path for another leann_multi_vector.py different from this file
|
||||
for p in list(sys.path):
|
||||
try:
|
||||
cand = Path(p) / "leann_multi_vector.py"
|
||||
if cand.exists() and cand.resolve() != this_file:
|
||||
return cand.resolve()
|
||||
except Exception:
|
||||
continue
|
||||
return None
|
||||
|
||||
|
||||
_BACKEND_LEANN_CLASS: Optional[type] = None
|
||||
|
||||
|
||||
def _get_backend_leann_multi_vector() -> type:
|
||||
"""Load backend LeannMultiVector class even if this file shadows its module name."""
|
||||
global _BACKEND_LEANN_CLASS
|
||||
if _BACKEND_LEANN_CLASS is not None:
|
||||
return _BACKEND_LEANN_CLASS
|
||||
|
||||
backend_path = _find_backend_module_file()
|
||||
if backend_path is None:
|
||||
# Fallback to local implementation in this module
|
||||
try:
|
||||
cls = LeannMultiVector # type: ignore[name-defined]
|
||||
_BACKEND_LEANN_CLASS = cls
|
||||
return cls
|
||||
except Exception as e:
|
||||
raise ImportError(
|
||||
"Could not locate backend 'leann_multi_vector.py' and no local implementation found. "
|
||||
"Ensure the leann backend is available under packages/leann-backend-hnsw or installed."
|
||||
) from e
|
||||
|
||||
import importlib.util
|
||||
|
||||
module_name = "leann_hnsw_backend_module"
|
||||
spec = importlib.util.spec_from_file_location(module_name, str(backend_path))
|
||||
if spec is None or spec.loader is None:
|
||||
raise ImportError(f"Failed to create spec for backend module at {backend_path}")
|
||||
backend_module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[module_name] = backend_module
|
||||
spec.loader.exec_module(backend_module) # type: ignore[assignment]
|
||||
|
||||
if not hasattr(backend_module, "LeannMultiVector"):
|
||||
raise ImportError(f"'LeannMultiVector' not found in backend module at {backend_path}")
|
||||
_BACKEND_LEANN_CLASS = backend_module.LeannMultiVector
|
||||
return _BACKEND_LEANN_CLASS
|
||||
|
||||
|
||||
def _natural_sort_key(name: str) -> int:
|
||||
m = re.search(r"\d+", name)
|
||||
return int(m.group()) if m else 0
|
||||
|
||||
|
||||
def _load_images_from_dir(pages_dir: str) -> tuple[list[str], list[Image.Image]]:
|
||||
filenames = [n for n in os.listdir(pages_dir) if n.lower().endswith((".png", ".jpg", ".jpeg"))]
|
||||
filenames = sorted(filenames, key=_natural_sort_key)
|
||||
filepaths = [os.path.join(pages_dir, n) for n in filenames]
|
||||
images = [Image.open(p) for p in filepaths]
|
||||
return filepaths, images
|
||||
|
||||
|
||||
def _maybe_convert_pdf_to_images(pdf_path: Optional[str], pages_dir: str, dpi: int = 200) -> None:
|
||||
if not pdf_path:
|
||||
return
|
||||
os.makedirs(pages_dir, exist_ok=True)
|
||||
try:
|
||||
from pdf2image import convert_from_path
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
"pdf2image is required to convert PDF to images. Install via pip install pdf2image"
|
||||
) from e
|
||||
images = convert_from_path(pdf_path, dpi=dpi)
|
||||
for i, image in enumerate(images):
|
||||
image.save(os.path.join(pages_dir, f"page_{i + 1}.png"), "PNG")
|
||||
|
||||
|
||||
def _select_device_and_dtype():
|
||||
import torch
|
||||
from colpali_engine.utils.torch_utils import get_torch_device
|
||||
|
||||
device_str = (
|
||||
"cuda"
|
||||
if torch.cuda.is_available()
|
||||
else (
|
||||
"mps"
|
||||
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
|
||||
else "cpu"
|
||||
)
|
||||
)
|
||||
device = get_torch_device(device_str)
|
||||
# Stable dtype selection to avoid NaNs:
|
||||
# - CUDA: prefer bfloat16 if supported, else float16
|
||||
# - MPS: use float32 (fp16 on MPS can produce NaNs in some ops)
|
||||
# - CPU: float32
|
||||
if device_str == "cuda":
|
||||
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
||||
try:
|
||||
torch.backends.cuda.matmul.allow_tf32 = True # Better stability/perf on Ampere+
|
||||
except Exception:
|
||||
pass
|
||||
elif device_str == "mps":
|
||||
dtype = torch.float32
|
||||
else:
|
||||
dtype = torch.float32
|
||||
return device_str, device, dtype
|
||||
|
||||
|
||||
def _load_colvision(model_choice: str):
|
||||
import torch
|
||||
from colpali_engine.models import ColPali, ColQwen2, ColQwen2Processor
|
||||
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
|
||||
from transformers.utils.import_utils import is_flash_attn_2_available
|
||||
|
||||
device_str, device, dtype = _select_device_and_dtype()
|
||||
|
||||
if model_choice == "colqwen2":
|
||||
model_name = "vidore/colqwen2-v1.0"
|
||||
# On CPU/MPS we must avoid flash-attn and stay eager; on CUDA prefer flash-attn if available
|
||||
attn_implementation = (
|
||||
"flash_attention_2"
|
||||
if (device_str == "cuda" and is_flash_attn_2_available())
|
||||
else "eager"
|
||||
)
|
||||
model = ColQwen2.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map=device,
|
||||
attn_implementation=attn_implementation,
|
||||
).eval()
|
||||
processor = ColQwen2Processor.from_pretrained(model_name)
|
||||
else:
|
||||
model_name = "vidore/colpali-v1.2"
|
||||
model = ColPali.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map=device,
|
||||
).eval()
|
||||
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
|
||||
|
||||
return model_name, model, processor, device_str, device, dtype
|
||||
|
||||
|
||||
def _embed_images(model, processor, images: list[Image.Image]) -> list[Any]:
|
||||
import torch
|
||||
from colpali_engine.utils.torch_utils import ListDataset
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
# Ensure deterministic eval and autocast for stability
|
||||
model.eval()
|
||||
|
||||
dataloader = DataLoader(
|
||||
dataset=ListDataset[Image.Image](images),
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
collate_fn=lambda x: processor.process_images(x),
|
||||
)
|
||||
|
||||
doc_vecs: list[Any] = []
|
||||
for batch_doc in tqdm(dataloader, desc="Embedding images"):
|
||||
with torch.no_grad():
|
||||
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
|
||||
# autocast on CUDA for bf16/fp16; on CPU/MPS stay in fp32
|
||||
if model.device.type == "cuda":
|
||||
with torch.autocast(
|
||||
device_type="cuda",
|
||||
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
|
||||
):
|
||||
embeddings_doc = model(**batch_doc)
|
||||
else:
|
||||
embeddings_doc = model(**batch_doc)
|
||||
doc_vecs.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
|
||||
return doc_vecs
|
||||
|
||||
|
||||
def _embed_queries(model, processor, queries: list[str]) -> list[Any]:
|
||||
import torch
|
||||
from colpali_engine.utils.torch_utils import ListDataset
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
model.eval()
|
||||
|
||||
dataloader = DataLoader(
|
||||
dataset=ListDataset[str](queries),
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
collate_fn=lambda x: processor.process_queries(x),
|
||||
)
|
||||
|
||||
q_vecs: list[Any] = []
|
||||
for batch_query in tqdm(dataloader, desc="Embedding queries"):
|
||||
with torch.no_grad():
|
||||
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
|
||||
if model.device.type == "cuda":
|
||||
with torch.autocast(
|
||||
device_type="cuda",
|
||||
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
|
||||
):
|
||||
embeddings_query = model(**batch_query)
|
||||
else:
|
||||
embeddings_query = model(**batch_query)
|
||||
q_vecs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
|
||||
return q_vecs
|
||||
|
||||
|
||||
def _build_index(
|
||||
index_path: str, doc_vecs: list[Any], filepaths: list[str], images: list[Image.Image]
|
||||
) -> Any:
|
||||
LeannMultiVector = _get_backend_leann_multi_vector()
|
||||
dim = int(doc_vecs[0].shape[-1])
|
||||
retriever = LeannMultiVector(index_path=index_path, dim=dim)
|
||||
retriever.create_collection()
|
||||
for i, vec in enumerate(doc_vecs):
|
||||
data = {
|
||||
"colbert_vecs": vec.float().numpy(),
|
||||
"doc_id": i,
|
||||
"filepath": filepaths[i],
|
||||
"image": images[i], # Include the original image
|
||||
}
|
||||
retriever.insert(data)
|
||||
retriever.create_index()
|
||||
return retriever
|
||||
|
||||
|
||||
def _load_retriever_if_index_exists(index_path: str) -> Optional[Any]:
|
||||
LeannMultiVector = _get_backend_leann_multi_vector()
|
||||
index_base = Path(index_path)
|
||||
# Check for the actual HNSW index file written by the backend + our sidecar files
|
||||
index_file = index_base.parent / f"{index_base.stem}.index"
|
||||
meta = index_base.parent / f"{index_base.name}.meta.json"
|
||||
labels = index_base.parent / f"{index_base.name}.labels.json"
|
||||
if index_file.exists() and meta.exists() and labels.exists():
|
||||
try:
|
||||
with open(meta, encoding="utf-8") as f:
|
||||
meta_json = json.load(f)
|
||||
dim = int(meta_json.get("dimensions", 128))
|
||||
except Exception:
|
||||
dim = 128
|
||||
return LeannMultiVector(index_path=index_path, dim=dim)
|
||||
return None
|
||||
|
||||
|
||||
def _generate_similarity_map(
|
||||
model,
|
||||
processor,
|
||||
image: Image.Image,
|
||||
query: str,
|
||||
token_idx: Optional[int] = None,
|
||||
output_path: Optional[str] = None,
|
||||
) -> tuple[int, float]:
|
||||
import torch
|
||||
from colpali_engine.interpretability import (
|
||||
get_similarity_maps_from_embeddings,
|
||||
plot_similarity_map,
|
||||
)
|
||||
|
||||
batch_images = processor.process_images([image]).to(model.device)
|
||||
batch_queries = processor.process_queries([query]).to(model.device)
|
||||
|
||||
with torch.no_grad():
|
||||
image_embeddings = model.forward(**batch_images)
|
||||
query_embeddings = model.forward(**batch_queries)
|
||||
|
||||
n_patches = processor.get_n_patches(
|
||||
image_size=image.size,
|
||||
spatial_merge_size=getattr(model, "spatial_merge_size", None),
|
||||
)
|
||||
image_mask = processor.get_image_mask(batch_images)
|
||||
|
||||
batched_similarity_maps = get_similarity_maps_from_embeddings(
|
||||
image_embeddings=image_embeddings,
|
||||
query_embeddings=query_embeddings,
|
||||
n_patches=n_patches,
|
||||
image_mask=image_mask,
|
||||
)
|
||||
|
||||
similarity_maps = batched_similarity_maps[0]
|
||||
|
||||
# Determine token index if not provided: choose the token with highest max score
|
||||
if token_idx is None:
|
||||
per_token_max = similarity_maps.view(similarity_maps.shape[0], -1).max(dim=1).values
|
||||
token_idx = int(per_token_max.argmax().item())
|
||||
|
||||
max_sim_score = similarity_maps[token_idx, :, :].max().item()
|
||||
|
||||
if output_path:
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
fig, ax = plot_similarity_map(
|
||||
image=image,
|
||||
similarity_map=similarity_maps[token_idx],
|
||||
figsize=(14, 14),
|
||||
show_colorbar=False,
|
||||
)
|
||||
ax.set_title(f"Token #{token_idx}. MaxSim score: {max_sim_score:.2f}", fontsize=12)
|
||||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||
plt.savefig(output_path, bbox_inches="tight")
|
||||
plt.close(fig)
|
||||
|
||||
return token_idx, float(max_sim_score)
|
||||
|
||||
|
||||
class QwenVL:
|
||||
def __init__(self, device: str):
|
||||
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
|
||||
from transformers.utils.import_utils import is_flash_attn_2_available
|
||||
|
||||
attn_implementation = "flash_attention_2" if is_flash_attn_2_available() else "eager"
|
||||
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||
"Qwen/Qwen2.5-VL-3B-Instruct",
|
||||
torch_dtype="auto",
|
||||
device_map=device,
|
||||
attn_implementation=attn_implementation,
|
||||
)
|
||||
|
||||
min_pixels = 256 * 28 * 28
|
||||
max_pixels = 1280 * 28 * 28
|
||||
self.processor = AutoProcessor.from_pretrained(
|
||||
"Qwen/Qwen2.5-VL-3B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels
|
||||
)
|
||||
|
||||
def answer(self, query: str, images: list[Image.Image], max_new_tokens: int = 128) -> str:
|
||||
import base64
|
||||
from io import BytesIO
|
||||
|
||||
from qwen_vl_utils import process_vision_info
|
||||
|
||||
content = []
|
||||
for img in images:
|
||||
buffer = BytesIO()
|
||||
img.save(buffer, format="jpeg")
|
||||
img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
content.append({"type": "image", "image": f"data:image;base64,{img_base64}"})
|
||||
content.append({"type": "text", "text": query})
|
||||
messages = [{"role": "user", "content": content}]
|
||||
|
||||
text = self.processor.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
image_inputs, video_inputs = process_vision_info(messages)
|
||||
inputs = self.processor(
|
||||
text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt"
|
||||
)
|
||||
inputs = inputs.to(self.model.device)
|
||||
|
||||
generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
|
||||
generated_ids_trimmed = [
|
||||
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
||||
]
|
||||
return self.processor.batch_decode(
|
||||
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
)[0]
|
||||
|
||||
|
||||
# Ensure repo paths are importable for dynamic backend loading
|
||||
_ensure_repo_paths_importable(__file__)
|
||||
|
||||
from leann_backend_hnsw.hnsw_backend import HNSWBuilder, HNSWSearcher # noqa: E402
|
||||
@@ -71,6 +450,7 @@ class LeannMultiVector:
|
||||
"doc_id": int(data["doc_id"]),
|
||||
"filepath": data.get("filepath", ""),
|
||||
"colbert_vecs": [np.asarray(v, dtype=np.float32) for v in data["colbert_vecs"]],
|
||||
"image": data.get("image"), # PIL Image object (optional)
|
||||
}
|
||||
)
|
||||
|
||||
@@ -86,6 +466,11 @@ class LeannMultiVector:
|
||||
index_path_obj = Path(self.index_path)
|
||||
return index_path_obj.parent / f"{index_path_obj.name}.emb.npy"
|
||||
|
||||
def _images_dir_path(self) -> Path:
|
||||
"""Directory where original images are stored."""
|
||||
index_path_obj = Path(self.index_path)
|
||||
return index_path_obj.parent / f"{index_path_obj.name}.images"
|
||||
|
||||
def create_index(self) -> None:
|
||||
if not self._pending_items:
|
||||
return
|
||||
@@ -93,10 +478,23 @@ class LeannMultiVector:
|
||||
embeddings: list[np.ndarray] = []
|
||||
labels_meta: list[dict] = []
|
||||
|
||||
# Create images directory if needed
|
||||
images_dir = self._images_dir_path()
|
||||
images_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for item in self._pending_items:
|
||||
doc_id = int(item["doc_id"])
|
||||
filepath = item.get("filepath", "")
|
||||
colbert_vecs = item["colbert_vecs"]
|
||||
image = item.get("image")
|
||||
|
||||
# Save image if provided
|
||||
image_path = ""
|
||||
if image is not None and isinstance(image, Image.Image):
|
||||
image_filename = f"doc_{doc_id}.png"
|
||||
image_path = str(images_dir / image_filename)
|
||||
image.save(image_path, "PNG")
|
||||
|
||||
for seq_id, vec in enumerate(colbert_vecs):
|
||||
vec_np = np.asarray(vec, dtype=np.float32)
|
||||
embeddings.append(vec_np)
|
||||
@@ -106,6 +504,7 @@ class LeannMultiVector:
|
||||
"doc_id": doc_id,
|
||||
"seq_id": int(seq_id),
|
||||
"filepath": filepath,
|
||||
"image_path": image_path, # Store the path to the saved image
|
||||
}
|
||||
)
|
||||
|
||||
@@ -113,7 +512,6 @@ class LeannMultiVector:
|
||||
return
|
||||
|
||||
embeddings_np = np.vstack(embeddings).astype(np.float32)
|
||||
# print shape of embeddings_np
|
||||
print(embeddings_np.shape)
|
||||
|
||||
builder = HNSWBuilder(**{**self._backend_kwargs, "dimensions": self.dim})
|
||||
@@ -338,3 +736,45 @@ class LeannMultiVector:
|
||||
|
||||
scores.sort(key=lambda x: x[0], reverse=True)
|
||||
return scores[:topk] if len(scores) >= topk else scores
|
||||
|
||||
def get_image(self, doc_id: int) -> Optional[Image.Image]:
|
||||
"""
|
||||
Retrieve the original image for a given doc_id from the index.
|
||||
|
||||
Args:
|
||||
doc_id: The document ID
|
||||
|
||||
Returns:
|
||||
PIL Image object if found, None otherwise
|
||||
"""
|
||||
self._load_labels_meta_if_needed()
|
||||
|
||||
# Find the image_path for this doc_id (all seq_ids for same doc share the same image_path)
|
||||
for meta in self._labels_meta:
|
||||
if meta.get("doc_id") == doc_id:
|
||||
image_path = meta.get("image_path", "")
|
||||
if image_path and Path(image_path).exists():
|
||||
return Image.open(image_path)
|
||||
break
|
||||
return None
|
||||
|
||||
def get_metadata(self, doc_id: int) -> Optional[dict]:
|
||||
"""
|
||||
Retrieve metadata for a given doc_id.
|
||||
|
||||
Args:
|
||||
doc_id: The document ID
|
||||
|
||||
Returns:
|
||||
Dictionary with metadata (filepath, image_path, etc.) if found, None otherwise
|
||||
"""
|
||||
self._load_labels_meta_if_needed()
|
||||
|
||||
for meta in self._labels_meta:
|
||||
if meta.get("doc_id") == doc_id:
|
||||
return {
|
||||
"doc_id": doc_id,
|
||||
"filepath": meta.get("filepath", ""),
|
||||
"image_path": meta.get("image_path", ""),
|
||||
}
|
||||
return None
|
||||
|
||||
@@ -2,35 +2,31 @@
|
||||
# %%
|
||||
# uv pip install matplotlib qwen_vl_utils
|
||||
import os
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, cast
|
||||
from typing import Any, Optional
|
||||
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def _ensure_repo_paths_importable(current_file: str) -> None:
|
||||
"""Make local leann packages importable without installing (mirrors multi-vector-leann.py)."""
|
||||
_repo_root = Path(current_file).resolve().parents[3]
|
||||
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
|
||||
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
|
||||
if str(_leann_core_src) not in sys.path:
|
||||
sys.path.append(str(_leann_core_src))
|
||||
if str(_leann_hnsw_pkg) not in sys.path:
|
||||
sys.path.append(str(_leann_hnsw_pkg))
|
||||
|
||||
from leann_multi_vector import ( # utility functions/classes
|
||||
_ensure_repo_paths_importable,
|
||||
_load_images_from_dir,
|
||||
_maybe_convert_pdf_to_images,
|
||||
_load_colvision,
|
||||
_embed_images,
|
||||
_embed_queries,
|
||||
_build_index,
|
||||
_load_retriever_if_index_exists,
|
||||
_generate_similarity_map,
|
||||
QwenVL,
|
||||
)
|
||||
|
||||
_ensure_repo_paths_importable(__file__)
|
||||
|
||||
from leann_multi_vector import LeannMultiVector # noqa: E402
|
||||
|
||||
# %%
|
||||
# Config
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
QUERY = "How does DeepSeek-V2 compare against the LLaMA family of LLMs?"
|
||||
QUERY = "The paper talk about the latent video generative model and data curation in the related work part?"
|
||||
MODEL: str = "colqwen2" # "colpali" or "colqwen2"
|
||||
|
||||
# Data source: set to True to use the Hugging Face dataset example (recommended)
|
||||
@@ -45,7 +41,7 @@ PAGES_DIR: str = "./pages"
|
||||
|
||||
# Index + retrieval settings
|
||||
INDEX_PATH: str = "./indexes/colvision.leann"
|
||||
TOPK: int = 1
|
||||
TOPK: int = 3
|
||||
FIRST_STAGE_K: int = 500
|
||||
REBUILD_INDEX: bool = False
|
||||
|
||||
@@ -55,338 +51,57 @@ SIMILARITY_MAP: bool = True
|
||||
SIM_TOKEN_IDX: int = 13 # -1 means auto-select the most salient token
|
||||
SIM_OUTPUT: str = "./figures/similarity_map.png"
|
||||
ANSWER: bool = True
|
||||
MAX_NEW_TOKENS: int = 128
|
||||
|
||||
|
||||
# %%
|
||||
# Helpers
|
||||
def _natural_sort_key(name: str) -> int:
|
||||
m = re.search(r"\d+", name)
|
||||
return int(m.group()) if m else 0
|
||||
|
||||
|
||||
def _load_images_from_dir(pages_dir: str) -> tuple[list[str], list[Image.Image]]:
|
||||
filenames = [n for n in os.listdir(pages_dir) if n.lower().endswith((".png", ".jpg", ".jpeg"))]
|
||||
filenames = sorted(filenames, key=_natural_sort_key)
|
||||
filepaths = [os.path.join(pages_dir, n) for n in filenames]
|
||||
images = [Image.open(p) for p in filepaths]
|
||||
return filepaths, images
|
||||
|
||||
|
||||
def _maybe_convert_pdf_to_images(pdf_path: Optional[str], pages_dir: str, dpi: int = 200) -> None:
|
||||
if not pdf_path:
|
||||
return
|
||||
os.makedirs(pages_dir, exist_ok=True)
|
||||
try:
|
||||
from pdf2image import convert_from_path
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
"pdf2image is required to convert PDF to images. Install via pip install pdf2image"
|
||||
) from e
|
||||
images = convert_from_path(pdf_path, dpi=dpi)
|
||||
for i, image in enumerate(images):
|
||||
image.save(os.path.join(pages_dir, f"page_{i + 1}.png"), "PNG")
|
||||
|
||||
|
||||
def _select_device_and_dtype():
|
||||
import torch
|
||||
from colpali_engine.utils.torch_utils import get_torch_device
|
||||
|
||||
device_str = (
|
||||
"cuda"
|
||||
if torch.cuda.is_available()
|
||||
else (
|
||||
"mps"
|
||||
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
|
||||
else "cpu"
|
||||
)
|
||||
)
|
||||
device = get_torch_device(device_str)
|
||||
# Stable dtype selection to avoid NaNs:
|
||||
# - CUDA: prefer bfloat16 if supported, else float16
|
||||
# - MPS: use float32 (fp16 on MPS can produce NaNs in some ops)
|
||||
# - CPU: float32
|
||||
if device_str == "cuda":
|
||||
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
||||
try:
|
||||
torch.backends.cuda.matmul.allow_tf32 = True # Better stability/perf on Ampere+
|
||||
except Exception:
|
||||
pass
|
||||
elif device_str == "mps":
|
||||
dtype = torch.float32
|
||||
else:
|
||||
dtype = torch.float32
|
||||
return device_str, device, dtype
|
||||
|
||||
|
||||
def _load_colvision(model_choice: str):
|
||||
import torch
|
||||
from colpali_engine.models import ColPali, ColQwen2, ColQwen2Processor
|
||||
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
|
||||
from transformers.utils.import_utils import is_flash_attn_2_available
|
||||
|
||||
device_str, device, dtype = _select_device_and_dtype()
|
||||
|
||||
if model_choice == "colqwen2":
|
||||
model_name = "vidore/colqwen2-v1.0"
|
||||
# On CPU/MPS we must avoid flash-attn and stay eager; on CUDA prefer flash-attn if available
|
||||
attn_implementation = (
|
||||
"flash_attention_2"
|
||||
if (device_str == "cuda" and is_flash_attn_2_available())
|
||||
else "eager"
|
||||
)
|
||||
model = ColQwen2.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map=device,
|
||||
attn_implementation=attn_implementation,
|
||||
).eval()
|
||||
processor = ColQwen2Processor.from_pretrained(model_name)
|
||||
else:
|
||||
model_name = "vidore/colpali-v1.2"
|
||||
model = ColPali.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map=device,
|
||||
).eval()
|
||||
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
|
||||
|
||||
return model_name, model, processor, device_str, device, dtype
|
||||
|
||||
|
||||
def _embed_images(model, processor, images: list[Image.Image]) -> list[Any]:
|
||||
import torch
|
||||
from colpali_engine.utils.torch_utils import ListDataset
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
# Ensure deterministic eval and autocast for stability
|
||||
model.eval()
|
||||
|
||||
dataloader = DataLoader(
|
||||
dataset=ListDataset[Image.Image](images),
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
collate_fn=lambda x: processor.process_images(x),
|
||||
)
|
||||
|
||||
doc_vecs: list[Any] = []
|
||||
for batch_doc in tqdm(dataloader, desc="Embedding images"):
|
||||
with torch.no_grad():
|
||||
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
|
||||
# autocast on CUDA for bf16/fp16; on CPU/MPS stay in fp32
|
||||
if model.device.type == "cuda":
|
||||
with torch.autocast(
|
||||
device_type="cuda",
|
||||
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
|
||||
):
|
||||
embeddings_doc = model(**batch_doc)
|
||||
else:
|
||||
embeddings_doc = model(**batch_doc)
|
||||
doc_vecs.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
|
||||
return doc_vecs
|
||||
|
||||
|
||||
def _embed_queries(model, processor, queries: list[str]) -> list[Any]:
|
||||
import torch
|
||||
from colpali_engine.utils.torch_utils import ListDataset
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
model.eval()
|
||||
|
||||
dataloader = DataLoader(
|
||||
dataset=ListDataset[str](queries),
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
collate_fn=lambda x: processor.process_queries(x),
|
||||
)
|
||||
|
||||
q_vecs: list[Any] = []
|
||||
for batch_query in tqdm(dataloader, desc="Embedding queries"):
|
||||
with torch.no_grad():
|
||||
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
|
||||
if model.device.type == "cuda":
|
||||
with torch.autocast(
|
||||
device_type="cuda",
|
||||
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
|
||||
):
|
||||
embeddings_query = model(**batch_query)
|
||||
else:
|
||||
embeddings_query = model(**batch_query)
|
||||
q_vecs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
|
||||
return q_vecs
|
||||
|
||||
|
||||
def _build_index(index_path: str, doc_vecs: list[Any], filepaths: list[str]) -> LeannMultiVector:
|
||||
dim = int(doc_vecs[0].shape[-1])
|
||||
retriever = LeannMultiVector(index_path=index_path, dim=dim)
|
||||
retriever.create_collection()
|
||||
for i, vec in enumerate(doc_vecs):
|
||||
data = {
|
||||
"colbert_vecs": vec.float().numpy(),
|
||||
"doc_id": i,
|
||||
"filepath": filepaths[i],
|
||||
}
|
||||
retriever.insert(data)
|
||||
retriever.create_index()
|
||||
return retriever
|
||||
|
||||
|
||||
def _load_retriever_if_index_exists(index_path: str) -> Optional[LeannMultiVector]:
|
||||
index_base = Path(index_path)
|
||||
# Rough heuristic: index dir exists AND meta+labels files exist
|
||||
meta = index_base.parent / f"{index_base.name}.meta.json"
|
||||
labels = index_base.parent / f"{index_base.name}.labels.json"
|
||||
if index_base.exists() and meta.exists() and labels.exists():
|
||||
try:
|
||||
with open(meta, "r", encoding="utf-8") as f:
|
||||
meta_json = json.load(f)
|
||||
dim = int(meta_json.get("dimensions", 128))
|
||||
except Exception:
|
||||
dim = 128
|
||||
return LeannMultiVector(index_path=index_path, dim=dim)
|
||||
return None
|
||||
|
||||
|
||||
def _generate_similarity_map(
|
||||
model,
|
||||
processor,
|
||||
image: Image.Image,
|
||||
query: str,
|
||||
token_idx: Optional[int] = None,
|
||||
output_path: Optional[str] = None,
|
||||
) -> tuple[int, float]:
|
||||
import torch
|
||||
from colpali_engine.interpretability import (
|
||||
get_similarity_maps_from_embeddings,
|
||||
plot_similarity_map,
|
||||
)
|
||||
|
||||
batch_images = processor.process_images([image]).to(model.device)
|
||||
batch_queries = processor.process_queries([query]).to(model.device)
|
||||
|
||||
with torch.no_grad():
|
||||
image_embeddings = model.forward(**batch_images)
|
||||
query_embeddings = model.forward(**batch_queries)
|
||||
|
||||
n_patches = processor.get_n_patches(
|
||||
image_size=image.size,
|
||||
spatial_merge_size=getattr(model, "spatial_merge_size", None),
|
||||
)
|
||||
image_mask = processor.get_image_mask(batch_images)
|
||||
|
||||
batched_similarity_maps = get_similarity_maps_from_embeddings(
|
||||
image_embeddings=image_embeddings,
|
||||
query_embeddings=query_embeddings,
|
||||
n_patches=n_patches,
|
||||
image_mask=image_mask,
|
||||
)
|
||||
|
||||
similarity_maps = batched_similarity_maps[0]
|
||||
|
||||
# Determine token index if not provided: choose the token with highest max score
|
||||
if token_idx is None:
|
||||
per_token_max = similarity_maps.view(similarity_maps.shape[0], -1).max(dim=1).values
|
||||
token_idx = int(per_token_max.argmax().item())
|
||||
|
||||
max_sim_score = similarity_maps[token_idx, :, :].max().item()
|
||||
|
||||
if output_path:
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
fig, ax = plot_similarity_map(
|
||||
image=image,
|
||||
similarity_map=similarity_maps[token_idx],
|
||||
figsize=(14, 14),
|
||||
show_colorbar=False,
|
||||
)
|
||||
ax.set_title(f"Token #{token_idx}. MaxSim score: {max_sim_score:.2f}", fontsize=12)
|
||||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||
plt.savefig(output_path, bbox_inches="tight")
|
||||
plt.close(fig)
|
||||
|
||||
return token_idx, float(max_sim_score)
|
||||
|
||||
|
||||
class QwenVL:
|
||||
def __init__(self, device: str):
|
||||
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
|
||||
from transformers.utils.import_utils import is_flash_attn_2_available
|
||||
|
||||
attn_implementation = "flash_attention_2" if is_flash_attn_2_available() else "eager"
|
||||
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||
"Qwen/Qwen2.5-VL-3B-Instruct",
|
||||
torch_dtype="auto",
|
||||
device_map=device,
|
||||
attn_implementation=attn_implementation,
|
||||
)
|
||||
|
||||
min_pixels = 256 * 28 * 28
|
||||
max_pixels = 1280 * 28 * 28
|
||||
self.processor = AutoProcessor.from_pretrained(
|
||||
"Qwen/Qwen2.5-VL-3B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels
|
||||
)
|
||||
|
||||
def answer(self, query: str, images: list[Image.Image], max_new_tokens: int = 128) -> str:
|
||||
import base64
|
||||
from io import BytesIO
|
||||
|
||||
from qwen_vl_utils import process_vision_info
|
||||
|
||||
content = []
|
||||
for img in images:
|
||||
buffer = BytesIO()
|
||||
img.save(buffer, format="jpeg")
|
||||
img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
content.append({"type": "image", "image": f"data:image;base64,{img_base64}"})
|
||||
content.append({"type": "text", "text": query})
|
||||
messages = [{"role": "user", "content": content}]
|
||||
|
||||
text = self.processor.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
image_inputs, video_inputs = process_vision_info(messages)
|
||||
inputs = self.processor(
|
||||
text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt"
|
||||
)
|
||||
inputs = inputs.to(self.model.device)
|
||||
|
||||
generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
|
||||
generated_ids_trimmed = [
|
||||
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
||||
]
|
||||
return self.processor.batch_decode(
|
||||
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
)[0]
|
||||
MAX_NEW_TOKENS: int = 1024
|
||||
|
||||
|
||||
# %%
|
||||
|
||||
# Step 1: Prepare data
|
||||
if USE_HF_DATASET:
|
||||
from datasets import load_dataset
|
||||
# Step 1: Check if we can skip data loading (index already exists)
|
||||
retriever: Optional[Any] = None
|
||||
need_to_build_index = REBUILD_INDEX
|
||||
|
||||
dataset = load_dataset(DATASET_NAME, split=DATASET_SPLIT)
|
||||
N = len(dataset) if MAX_DOCS is None else min(MAX_DOCS, len(dataset))
|
||||
filepaths: list[str] = []
|
||||
images: list[Image.Image] = []
|
||||
for i in tqdm(range(N), desc="Loading dataset", total=N ):
|
||||
p = dataset[i]
|
||||
# Compose a descriptive identifier for printing later
|
||||
identifier = f"arXiv:{p['paper_arxiv_id']}|title:{p['paper_title']}|page:{int(p['page_number'])}|id:{p['page_id']}"
|
||||
print(identifier)
|
||||
filepaths.append(identifier)
|
||||
images.append(p["page_image"]) # PIL Image
|
||||
if not REBUILD_INDEX:
|
||||
retriever = _load_retriever_if_index_exists(INDEX_PATH)
|
||||
if retriever is not None:
|
||||
print(f"✓ Index loaded from {INDEX_PATH}")
|
||||
print(f"✓ Images available at: {retriever._images_dir_path()}")
|
||||
need_to_build_index = False
|
||||
else:
|
||||
print(f"Index not found, will build new index")
|
||||
need_to_build_index = True
|
||||
|
||||
# Step 2: Load data only if we need to build the index
|
||||
if need_to_build_index:
|
||||
print("Loading dataset...")
|
||||
if USE_HF_DATASET:
|
||||
from datasets import load_dataset
|
||||
|
||||
dataset = load_dataset(DATASET_NAME, split=DATASET_SPLIT)
|
||||
N = len(dataset) if MAX_DOCS is None else min(MAX_DOCS, len(dataset))
|
||||
filepaths: list[str] = []
|
||||
images: list[Image.Image] = []
|
||||
for i in tqdm(range(N), desc="Loading dataset", total=N):
|
||||
p = dataset[i]
|
||||
# Compose a descriptive identifier for printing later
|
||||
identifier = f"arXiv:{p['paper_arxiv_id']}|title:{p['paper_title']}|page:{int(p['page_number'])}|id:{p['page_id']}"
|
||||
filepaths.append(identifier)
|
||||
images.append(p["page_image"]) # PIL Image
|
||||
else:
|
||||
_maybe_convert_pdf_to_images(PDF, PAGES_DIR)
|
||||
filepaths, images = _load_images_from_dir(PAGES_DIR)
|
||||
if not images:
|
||||
raise RuntimeError(
|
||||
f"No images found in {PAGES_DIR}. Provide PDF path in PDF variable or ensure images exist."
|
||||
)
|
||||
print(f"Loaded {len(images)} images")
|
||||
else:
|
||||
_maybe_convert_pdf_to_images(PDF, PAGES_DIR)
|
||||
filepaths, images = _load_images_from_dir(PAGES_DIR)
|
||||
if not images:
|
||||
raise RuntimeError(
|
||||
f"No images found in {PAGES_DIR}. Provide PDF path in PDF variable or ensure images exist."
|
||||
)
|
||||
print("Skipping dataset loading (using existing index)")
|
||||
filepaths = [] # Not needed when using existing index
|
||||
images = [] # Not needed when using existing index
|
||||
|
||||
|
||||
# %%
|
||||
# Step 2: Load model and processor
|
||||
# Step 3: Load model and processor (only if we need to build index or perform search)
|
||||
model_name, model, processor, device_str, device, dtype = _load_colvision(MODEL)
|
||||
print(f"Using model={model_name}, device={device_str}, dtype={dtype}")
|
||||
|
||||
@@ -394,30 +109,39 @@ print(f"Using model={model_name}, device={device_str}, dtype={dtype}")
|
||||
# %%
|
||||
|
||||
# %%
|
||||
# Step 3: Build or load index
|
||||
retriever: Optional[LeannMultiVector] = None
|
||||
if not REBUILD_INDEX:
|
||||
retriever = _load_retriever_if_index_exists(INDEX_PATH)
|
||||
|
||||
if retriever is None:
|
||||
# Step 4: Build index if needed
|
||||
if need_to_build_index and retriever is None:
|
||||
print("Building index...")
|
||||
doc_vecs = _embed_images(model, processor, images)
|
||||
retriever = _build_index(INDEX_PATH, doc_vecs, filepaths)
|
||||
retriever = _build_index(INDEX_PATH, doc_vecs, filepaths, images)
|
||||
print(f"✓ Index built and images saved to: {retriever._images_dir_path()}")
|
||||
# Clear memory
|
||||
del images, filepaths, doc_vecs
|
||||
|
||||
# Note: Images are now stored in the index, retriever will load them on-demand from disk
|
||||
|
||||
|
||||
# %%
|
||||
# Step 4: Embed query and search
|
||||
# Step 5: Embed query and search
|
||||
q_vec = _embed_queries(model, processor, [QUERY])[0]
|
||||
results = retriever.search(q_vec.float().numpy(), topk=TOPK, first_stage_k=FIRST_STAGE_K)
|
||||
results = retriever.search(q_vec.float().numpy(), topk=TOPK)
|
||||
if not results:
|
||||
print("No results found.")
|
||||
else:
|
||||
print(f'Top {len(results)} results for query: "{QUERY}"')
|
||||
top_images: list[Image.Image] = []
|
||||
for rank, (score, doc_id) in enumerate(results, start=1):
|
||||
path = filepaths[doc_id]
|
||||
# Retrieve image from index instead of memory
|
||||
image = retriever.get_image(doc_id)
|
||||
if image is None:
|
||||
print(f"Warning: Could not retrieve image for doc_id {doc_id}")
|
||||
continue
|
||||
|
||||
metadata = retriever.get_metadata(doc_id)
|
||||
path = metadata.get("filepath", "unknown") if metadata else "unknown"
|
||||
# For HF dataset, path is a descriptive identifier, not a real file path
|
||||
print(f"{rank}) MaxSim: {score:.4f}, Page: {path}")
|
||||
top_images.append(images[doc_id])
|
||||
top_images.append(image)
|
||||
|
||||
if SAVE_TOP_IMAGE:
|
||||
from pathlib import Path as _Path
|
||||
@@ -430,12 +154,17 @@ else:
|
||||
else:
|
||||
out_path = base / f"retrieved_page_rank{rank}.png"
|
||||
img.save(str(out_path))
|
||||
print(f"Saved retrieved page (rank {rank}) to: {out_path}")
|
||||
# Print the retrieval score (document-level MaxSim) alongside the saved path
|
||||
try:
|
||||
score, _doc_id = results[rank - 1]
|
||||
print(f"Saved retrieved page (rank {rank}) [MaxSim={score:.4f}] to: {out_path}")
|
||||
except Exception:
|
||||
print(f"Saved retrieved page (rank {rank}) to: {out_path}")
|
||||
|
||||
## TODO stange results of second page of DeepSeek-V2 rather than the first page
|
||||
|
||||
# %%
|
||||
# Step 5: Similarity maps for top-K results
|
||||
# Step 6: Similarity maps for top-K results
|
||||
if results and SIMILARITY_MAP:
|
||||
token_idx = None if SIM_TOKEN_IDX < 0 else int(SIM_TOKEN_IDX)
|
||||
from pathlib import Path as _Path
|
||||
@@ -472,7 +201,7 @@ if results and SIMILARITY_MAP:
|
||||
|
||||
|
||||
# %%
|
||||
# Step 6: Optional answer generation
|
||||
# Step 7: Optional answer generation
|
||||
if results and ANSWER:
|
||||
qwen = QwenVL(device=device_str)
|
||||
response = qwen.answer(QUERY, top_images[:TOPK], max_new_tokens=MAX_NEW_TOKENS)
|
||||
|
||||
@@ -57,6 +57,8 @@ dependencies = [
|
||||
"tree-sitter-c-sharp>=0.20.0",
|
||||
"tree-sitter-typescript>=0.20.0",
|
||||
"torchvision>=0.23.0",
|
||||
"einops",
|
||||
"seaborn",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
||||
Reference in New Issue
Block a user