import concurrent.futures import json import os import re import sys from pathlib import Path from typing import Any, Optional, cast import numpy as np from PIL import Image from tqdm import tqdm def _ensure_repo_paths_importable(current_file: str) -> None: """Make local leann packages importable without installing (mirrors multi-vector-leann.py).""" _repo_root = Path(current_file).resolve().parents[3] _leann_core_src = _repo_root / "packages" / "leann-core" / "src" _leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw" if str(_leann_core_src) not in sys.path: sys.path.append(str(_leann_core_src)) if str(_leann_hnsw_pkg) not in sys.path: sys.path.append(str(_leann_hnsw_pkg)) def _find_backend_module_file() -> Optional[Path]: """Best-effort locate the backend leann_multi_vector.py file, avoiding this file.""" this_file = Path(__file__).resolve() candidates: list[Path] = [] # Common in-repo location repo_root = this_file.parents[3] candidates.append(repo_root / "packages" / "leann-backend-hnsw" / "leann_multi_vector.py") candidates.append( repo_root / "packages" / "leann-backend-hnsw" / "src" / "leann_multi_vector.py" ) for cand in candidates: try: if cand.exists() and cand.resolve() != this_file: return cand.resolve() except Exception: pass # Fallback: scan sys.path for another leann_multi_vector.py different from this file for p in list(sys.path): try: cand = Path(p) / "leann_multi_vector.py" if cand.exists() and cand.resolve() != this_file: return cand.resolve() except Exception: continue return None _BACKEND_LEANN_CLASS: Optional[type] = None def _get_backend_leann_multi_vector() -> type: """Load backend LeannMultiVector class even if this file shadows its module name.""" global _BACKEND_LEANN_CLASS if _BACKEND_LEANN_CLASS is not None: return _BACKEND_LEANN_CLASS backend_path = _find_backend_module_file() if backend_path is None: # Fallback to local implementation in this module try: cls = LeannMultiVector # type: ignore[name-defined] _BACKEND_LEANN_CLASS = cls return cls except Exception as e: raise ImportError( "Could not locate backend 'leann_multi_vector.py' and no local implementation found. " "Ensure the leann backend is available under packages/leann-backend-hnsw or installed." ) from e import importlib.util module_name = "leann_hnsw_backend_module" spec = importlib.util.spec_from_file_location(module_name, str(backend_path)) if spec is None or spec.loader is None: raise ImportError(f"Failed to create spec for backend module at {backend_path}") backend_module = importlib.util.module_from_spec(spec) sys.modules[module_name] = backend_module spec.loader.exec_module(backend_module) # type: ignore[assignment] if not hasattr(backend_module, "LeannMultiVector"): raise ImportError(f"'LeannMultiVector' not found in backend module at {backend_path}") _BACKEND_LEANN_CLASS = backend_module.LeannMultiVector return _BACKEND_LEANN_CLASS def _natural_sort_key(name: str) -> int: m = re.search(r"\d+", name) return int(m.group()) if m else 0 def _load_images_from_dir(pages_dir: str) -> tuple[list[str], list[Image.Image]]: filenames = [n for n in os.listdir(pages_dir) if n.lower().endswith((".png", ".jpg", ".jpeg"))] filenames = sorted(filenames, key=_natural_sort_key) filepaths = [os.path.join(pages_dir, n) for n in filenames] images = [Image.open(p) for p in filepaths] return filepaths, images def _maybe_convert_pdf_to_images(pdf_path: Optional[str], pages_dir: str, dpi: int = 200) -> None: if not pdf_path: return os.makedirs(pages_dir, exist_ok=True) try: from pdf2image import convert_from_path except Exception as e: raise RuntimeError( "pdf2image is required to convert PDF to images. Install via pip install pdf2image" ) from e images = convert_from_path(pdf_path, dpi=dpi) for i, image in enumerate(images): image.save(os.path.join(pages_dir, f"page_{i + 1}.png"), "PNG") def _select_device_and_dtype(): import torch from colpali_engine.utils.torch_utils import get_torch_device device_str = ( "cuda" if torch.cuda.is_available() else ( "mps" if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available() else "cpu" ) ) device = get_torch_device(device_str) # Stable dtype selection to avoid NaNs: # - CUDA: prefer bfloat16 if supported, else float16 # - MPS: use float32 (fp16 on MPS can produce NaNs in some ops) # - CPU: float32 if device_str == "cuda": dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 try: torch.backends.cuda.matmul.allow_tf32 = True # Better stability/perf on Ampere+ except Exception: pass elif device_str == "mps": dtype = torch.float32 else: dtype = torch.float32 return device_str, device, dtype def _load_colvision(model_choice: str): import torch from colpali_engine.models import ColPali, ColQwen2, ColQwen2Processor from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor from transformers.utils.import_utils import is_flash_attn_2_available device_str, device, dtype = _select_device_and_dtype() if model_choice == "colqwen2": model_name = "vidore/colqwen2-v1.0" # On CPU/MPS we must avoid flash-attn and stay eager; on CUDA prefer flash-attn if available attn_implementation = ( "flash_attention_2" if (device_str == "cuda" and is_flash_attn_2_available()) else "eager" ) model = ColQwen2.from_pretrained( model_name, torch_dtype=torch.bfloat16, device_map=device, attn_implementation=attn_implementation, ).eval() processor = ColQwen2Processor.from_pretrained(model_name) else: model_name = "vidore/colpali-v1.2" model = ColPali.from_pretrained( model_name, torch_dtype=torch.bfloat16, device_map=device, ).eval() processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name)) return model_name, model, processor, device_str, device, dtype def _embed_images(model, processor, images: list[Image.Image]) -> list[Any]: import torch from colpali_engine.utils.torch_utils import ListDataset from torch.utils.data import DataLoader # Ensure deterministic eval and autocast for stability model.eval() dataloader = DataLoader( dataset=ListDataset[Image.Image](images), batch_size=1, shuffle=False, collate_fn=lambda x: processor.process_images(x), ) doc_vecs: list[Any] = [] for batch_doc in tqdm(dataloader, desc="Embedding images"): with torch.no_grad(): batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()} # autocast on CUDA for bf16/fp16; on CPU/MPS stay in fp32 if model.device.type == "cuda": with torch.autocast( device_type="cuda", dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16, ): embeddings_doc = model(**batch_doc) else: embeddings_doc = model(**batch_doc) doc_vecs.extend(list(torch.unbind(embeddings_doc.to("cpu")))) return doc_vecs def _embed_queries(model, processor, queries: list[str]) -> list[Any]: import torch from colpali_engine.utils.torch_utils import ListDataset from torch.utils.data import DataLoader model.eval() dataloader = DataLoader( dataset=ListDataset[str](queries), batch_size=1, shuffle=False, collate_fn=lambda x: processor.process_queries(x), ) q_vecs: list[Any] = [] for batch_query in tqdm(dataloader, desc="Embedding queries"): with torch.no_grad(): batch_query = {k: v.to(model.device) for k, v in batch_query.items()} if model.device.type == "cuda": with torch.autocast( device_type="cuda", dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16, ): embeddings_query = model(**batch_query) else: embeddings_query = model(**batch_query) q_vecs.extend(list(torch.unbind(embeddings_query.to("cpu")))) return q_vecs def _build_index( index_path: str, doc_vecs: list[Any], filepaths: list[str], images: list[Image.Image] ) -> Any: LeannMultiVector = _get_backend_leann_multi_vector() dim = int(doc_vecs[0].shape[-1]) retriever = LeannMultiVector(index_path=index_path, dim=dim) retriever.create_collection() for i, vec in enumerate(doc_vecs): data = { "colbert_vecs": vec.float().numpy(), "doc_id": i, "filepath": filepaths[i], "image": images[i], # Include the original image } retriever.insert(data) retriever.create_index() return retriever def _load_retriever_if_index_exists(index_path: str) -> Optional[Any]: LeannMultiVector = _get_backend_leann_multi_vector() index_base = Path(index_path) # Check for the actual HNSW index file written by the backend + our sidecar files index_file = index_base.parent / f"{index_base.stem}.index" meta = index_base.parent / f"{index_base.name}.meta.json" labels = index_base.parent / f"{index_base.name}.labels.json" if index_file.exists() and meta.exists() and labels.exists(): try: with open(meta, encoding="utf-8") as f: meta_json = json.load(f) dim = int(meta_json.get("dimensions", 128)) except Exception: dim = 128 return LeannMultiVector(index_path=index_path, dim=dim) return None def _generate_similarity_map( model, processor, image: Image.Image, query: str, token_idx: Optional[int] = None, output_path: Optional[str] = None, ) -> tuple[int, float]: import torch from colpali_engine.interpretability import ( get_similarity_maps_from_embeddings, plot_similarity_map, ) batch_images = processor.process_images([image]).to(model.device) batch_queries = processor.process_queries([query]).to(model.device) with torch.no_grad(): image_embeddings = model.forward(**batch_images) query_embeddings = model.forward(**batch_queries) n_patches = processor.get_n_patches( image_size=image.size, spatial_merge_size=getattr(model, "spatial_merge_size", None), ) image_mask = processor.get_image_mask(batch_images) batched_similarity_maps = get_similarity_maps_from_embeddings( image_embeddings=image_embeddings, query_embeddings=query_embeddings, n_patches=n_patches, image_mask=image_mask, ) similarity_maps = batched_similarity_maps[0] # Determine token index if not provided: choose the token with highest max score if token_idx is None: per_token_max = similarity_maps.view(similarity_maps.shape[0], -1).max(dim=1).values token_idx = int(per_token_max.argmax().item()) max_sim_score = similarity_maps[token_idx, :, :].max().item() if output_path: import matplotlib.pyplot as plt fig, ax = plot_similarity_map( image=image, similarity_map=similarity_maps[token_idx], figsize=(14, 14), show_colorbar=False, ) ax.set_title(f"Token #{token_idx}. MaxSim score: {max_sim_score:.2f}", fontsize=12) os.makedirs(os.path.dirname(output_path), exist_ok=True) plt.savefig(output_path, bbox_inches="tight") plt.close(fig) return token_idx, float(max_sim_score) class QwenVL: def __init__(self, device: str): from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration from transformers.utils.import_utils import is_flash_attn_2_available attn_implementation = "flash_attention_2" if is_flash_attn_2_available() else "eager" self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained( "Qwen/Qwen2.5-VL-3B-Instruct", torch_dtype="auto", device_map=device, attn_implementation=attn_implementation, ) min_pixels = 256 * 28 * 28 max_pixels = 1280 * 28 * 28 self.processor = AutoProcessor.from_pretrained( "Qwen/Qwen2.5-VL-3B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels ) def answer(self, query: str, images: list[Image.Image], max_new_tokens: int = 128) -> str: import base64 from io import BytesIO from qwen_vl_utils import process_vision_info content = [] for img in images: buffer = BytesIO() img.save(buffer, format="jpeg") img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8") content.append({"type": "image", "image": f"data:image;base64,{img_base64}"}) content.append({"type": "text", "text": query}) messages = [{"role": "user", "content": content}] text = self.processor.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) image_inputs, video_inputs = process_vision_info(messages) inputs = self.processor( text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt" ) inputs = inputs.to(self.model.device) generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens) generated_ids_trimmed = [ out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) ] return self.processor.batch_decode( generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False )[0] # Ensure repo paths are importable for dynamic backend loading _ensure_repo_paths_importable(__file__) from leann_backend_hnsw.hnsw_backend import HNSWBuilder, HNSWSearcher # noqa: E402 class LeannMultiVector: def __init__( self, index_path: str, dim: int = 128, distance_metric: str = "mips", m: int = 16, ef_construction: int = 500, is_compact: bool = False, is_recompute: bool = False, embedding_model_name: str = "colvision", ) -> None: self.index_path = index_path self.dim = dim self.embedding_model_name = embedding_model_name self._pending_items: list[dict] = [] self._backend_kwargs = { "distance_metric": distance_metric, "M": m, "efConstruction": ef_construction, "is_compact": is_compact, "is_recompute": is_recompute, } self._labels_meta: list[dict] = [] self._docid_to_indices: dict[int, list[int]] | None = None def _meta_dict(self) -> dict: return { "version": "1.0", "backend_name": "hnsw", "embedding_model": self.embedding_model_name, "embedding_mode": "custom", "dimensions": self.dim, "backend_kwargs": self._backend_kwargs, "is_compact": self._backend_kwargs.get("is_compact", True), "is_pruned": self._backend_kwargs.get("is_compact", True) and self._backend_kwargs.get("is_recompute", True), } def create_collection(self) -> None: path = Path(self.index_path) path.parent.mkdir(parents=True, exist_ok=True) def insert(self, data: dict) -> None: self._pending_items.append( { "doc_id": int(data["doc_id"]), "filepath": data.get("filepath", ""), "colbert_vecs": [np.asarray(v, dtype=np.float32) for v in data["colbert_vecs"]], "image": data.get("image"), # PIL Image object (optional) } ) def _labels_path(self) -> Path: index_path_obj = Path(self.index_path) return index_path_obj.parent / f"{index_path_obj.name}.labels.json" def _meta_path(self) -> Path: index_path_obj = Path(self.index_path) return index_path_obj.parent / f"{index_path_obj.name}.meta.json" def _embeddings_path(self) -> Path: index_path_obj = Path(self.index_path) return index_path_obj.parent / f"{index_path_obj.name}.emb.npy" def _images_dir_path(self) -> Path: """Directory where original images are stored.""" index_path_obj = Path(self.index_path) return index_path_obj.parent / f"{index_path_obj.name}.images" def create_index(self) -> None: if not self._pending_items: return embeddings: list[np.ndarray] = [] labels_meta: list[dict] = [] # Create images directory if needed images_dir = self._images_dir_path() images_dir.mkdir(parents=True, exist_ok=True) for item in self._pending_items: doc_id = int(item["doc_id"]) filepath = item.get("filepath", "") colbert_vecs = item["colbert_vecs"] image = item.get("image") # Save image if provided image_path = "" if image is not None and isinstance(image, Image.Image): image_filename = f"doc_{doc_id}.png" image_path = str(images_dir / image_filename) image.save(image_path, "PNG") for seq_id, vec in enumerate(colbert_vecs): vec_np = np.asarray(vec, dtype=np.float32) embeddings.append(vec_np) labels_meta.append( { "id": f"{doc_id}:{seq_id}", "doc_id": doc_id, "seq_id": int(seq_id), "filepath": filepath, "image_path": image_path, # Store the path to the saved image } ) if not embeddings: return embeddings_np = np.vstack(embeddings).astype(np.float32) print(embeddings_np.shape) builder = HNSWBuilder(**{**self._backend_kwargs, "dimensions": self.dim}) ids = [str(i) for i in range(embeddings_np.shape[0])] builder.build(embeddings_np, ids, self.index_path) import json as _json with open(self._meta_path(), "w", encoding="utf-8") as f: _json.dump(self._meta_dict(), f, indent=2) with open(self._labels_path(), "w", encoding="utf-8") as f: _json.dump(labels_meta, f) # Persist embeddings for exact reranking np.save(self._embeddings_path(), embeddings_np) self._labels_meta = labels_meta def _load_labels_meta_if_needed(self) -> None: if self._labels_meta: return labels_path = self._labels_path() if labels_path.exists(): import json as _json with open(labels_path, encoding="utf-8") as f: self._labels_meta = _json.load(f) def _build_docid_to_indices_if_needed(self) -> None: if self._docid_to_indices is not None: return self._load_labels_meta_if_needed() mapping: dict[int, list[int]] = {} for idx, meta in enumerate(self._labels_meta): try: doc_id = int(meta["doc_id"]) # type: ignore[index] except Exception: continue mapping.setdefault(doc_id, []).append(idx) self._docid_to_indices = mapping def search( self, data: np.ndarray, topk: int, first_stage_k: int = 50 ) -> list[tuple[float, int]]: if data.ndim == 1: data = data.reshape(1, -1) if data.dtype != np.float32: data = data.astype(np.float32) self._load_labels_meta_if_needed() searcher = HNSWSearcher(self.index_path, meta=self._meta_dict()) raw = searcher.search( data, first_stage_k, recompute_embeddings=False, complexity=128, beam_width=1, prune_ratio=0.0, batch_size=0, ) labels = raw.get("labels") distances = raw.get("distances") if labels is None or distances is None: return [] doc_scores: dict[int, float] = {} B = len(labels) for b in range(B): per_doc_best: dict[int, float] = {} for k, sid in enumerate(labels[b]): try: idx = int(sid) except Exception: continue if 0 <= idx < len(self._labels_meta): doc_id = int(self._labels_meta[idx]["doc_id"]) # type: ignore[index] else: continue score = float(distances[b][k]) if (doc_id not in per_doc_best) or (score > per_doc_best[doc_id]): per_doc_best[doc_id] = score for doc_id, best_score in per_doc_best.items(): doc_scores[doc_id] = doc_scores.get(doc_id, 0.0) + best_score scores = sorted(((v, k) for k, v in doc_scores.items()), key=lambda x: x[0], reverse=True) return scores[:topk] if len(scores) >= topk else scores def search_exact( self, data: np.ndarray, topk: int, *, first_stage_k: int = 200, max_workers: int = 32, ) -> list[tuple[float, int]]: """ High-precision MaxSim reranking over candidate documents. Steps: 1) Run a first-stage ANN to collect candidate doc_ids (using seq-level neighbors). 2) For each candidate doc, load all its token embeddings and compute MaxSim(query_tokens, doc_tokens) exactly: sum(max(dot(q_i, d_j))). Returns top-k list of (score, doc_id). """ # Normalize inputs if data.ndim == 1: data = data.reshape(1, -1) if data.dtype != np.float32: data = data.astype(np.float32) self._load_labels_meta_if_needed() self._build_docid_to_indices_if_needed() emb_path = self._embeddings_path() if not emb_path.exists(): # Fallback to approximate if we don't have persisted embeddings return self.search(data, topk, first_stage_k=first_stage_k) # Memory-map embeddings to avoid loading all into RAM all_embeddings = np.load(emb_path, mmap_mode="r") if all_embeddings.dtype != np.float32: all_embeddings = all_embeddings.astype(np.float32) # First-stage ANN to collect candidate doc_ids searcher = HNSWSearcher(self.index_path, meta=self._meta_dict()) raw = searcher.search( data, first_stage_k, recompute_embeddings=False, complexity=128, beam_width=1, prune_ratio=0.0, batch_size=0, ) labels = raw.get("labels") if labels is None: return [] candidate_doc_ids: set[int] = set() for batch in labels: for sid in batch: try: idx = int(sid) except Exception: continue if 0 <= idx < len(self._labels_meta): candidate_doc_ids.add(int(self._labels_meta[idx]["doc_id"])) # type: ignore[index] # Exact scoring per doc (parallelized) assert self._docid_to_indices is not None def _score_one(doc_id: int) -> tuple[float, int]: token_indices = self._docid_to_indices.get(doc_id, []) if not token_indices: return (0.0, doc_id) doc_vecs = np.asarray(all_embeddings[token_indices], dtype=np.float32) # (Q, D) x (P, D)^T -> (Q, P) then MaxSim over P, sum over Q sim = np.dot(data, doc_vecs.T) # nan-safe sim = np.nan_to_num(sim, nan=-1e30, posinf=1e30, neginf=-1e30) score = sim.max(axis=2).sum(axis=1) if sim.ndim == 3 else sim.max(axis=1).sum() return (float(score), doc_id) scores: list[tuple[float, int]] = [] with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex: futures = [ex.submit(_score_one, doc_id) for doc_id in candidate_doc_ids] for fut in concurrent.futures.as_completed(futures): scores.append(fut.result()) scores.sort(key=lambda x: x[0], reverse=True) return scores[:topk] if len(scores) >= topk else scores def search_exact_all( self, data: np.ndarray, topk: int, *, max_workers: int = 32, ) -> list[tuple[float, int]]: """ Exact MaxSim over ALL documents (no ANN pre-filtering). This computes, for each document, sum_i max_j dot(q_i, d_j). It memory-maps the persisted token-embedding matrix for scalability. """ if data.ndim == 1: data = data.reshape(1, -1) if data.dtype != np.float32: data = data.astype(np.float32) self._load_labels_meta_if_needed() self._build_docid_to_indices_if_needed() emb_path = self._embeddings_path() if not emb_path.exists(): return self.search(data, topk) all_embeddings = np.load(emb_path, mmap_mode="r") if all_embeddings.dtype != np.float32: all_embeddings = all_embeddings.astype(np.float32) assert self._docid_to_indices is not None candidate_doc_ids = list(self._docid_to_indices.keys()) def _score_one(doc_id: int) -> tuple[float, int]: token_indices = self._docid_to_indices.get(doc_id, []) if not token_indices: return (0.0, doc_id) doc_vecs = np.asarray(all_embeddings[token_indices], dtype=np.float32) sim = np.dot(data, doc_vecs.T) sim = np.nan_to_num(sim, nan=-1e30, posinf=1e30, neginf=-1e30) score = sim.max(axis=2).sum(axis=1) if sim.ndim == 3 else sim.max(axis=1).sum() return (float(score), doc_id) scores: list[tuple[float, int]] = [] with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex: futures = [ex.submit(_score_one, d) for d in candidate_doc_ids] for fut in concurrent.futures.as_completed(futures): scores.append(fut.result()) scores.sort(key=lambda x: x[0], reverse=True) return scores[:topk] if len(scores) >= topk else scores def get_image(self, doc_id: int) -> Optional[Image.Image]: """ Retrieve the original image for a given doc_id from the index. Args: doc_id: The document ID Returns: PIL Image object if found, None otherwise """ self._load_labels_meta_if_needed() # Find the image_path for this doc_id (all seq_ids for same doc share the same image_path) for meta in self._labels_meta: if meta.get("doc_id") == doc_id: image_path = meta.get("image_path", "") if image_path and Path(image_path).exists(): return Image.open(image_path) break return None def get_metadata(self, doc_id: int) -> Optional[dict]: """ Retrieve metadata for a given doc_id. Args: doc_id: The document ID Returns: Dictionary with metadata (filepath, image_path, etc.) if found, None otherwise """ self._load_labels_meta_if_needed() for meta in self._labels_meta: if meta.get("doc_id") == doc_id: return { "doc_id": doc_id, "filepath": meta.get("filepath", ""), "image_path": meta.get("image_path", ""), } return None