update some search in copali

yichuan-w
2025-11-08 08:53:03 +00:00
parent 2406c41eef
commit dc6c9f696e
2 changed files with 167 additions and 6 deletions

View File

@@ -1,6 +1,7 @@
 from __future__ import annotations
 import sys
+import concurrent.futures
 from pathlib import Path
 import numpy as np
@@ -45,6 +46,7 @@ class LeannMultiVector:
             "is_recompute": is_recompute,
         }
         self._labels_meta: list[dict] = []
+        self._docid_to_indices: dict[int, list[int]] | None = None

     def _meta_dict(self) -> dict:
         return {
@@ -80,6 +82,10 @@
         index_path_obj = Path(self.index_path)
         return index_path_obj.parent / f"{index_path_obj.name}.meta.json"

+    def _embeddings_path(self) -> Path:
+        index_path_obj = Path(self.index_path)
+        return index_path_obj.parent / f"{index_path_obj.name}.emb.npy"
+
     def create_index(self) -> None:
         if not self._pending_items:
             return
@@ -121,6 +127,9 @@
         with open(self._labels_path(), "w", encoding="utf-8") as f:
             _json.dump(labels_meta, f)
+        # Persist embeddings for exact reranking
+        np.save(self._embeddings_path(), embeddings_np)
+
         self._labels_meta = labels_meta

     def _load_labels_meta_if_needed(self) -> None:
@@ -133,6 +142,19 @@
         with open(labels_path, encoding="utf-8") as f:
             self._labels_meta = _json.load(f)

+    def _build_docid_to_indices_if_needed(self) -> None:
+        if self._docid_to_indices is not None:
+            return
+        self._load_labels_meta_if_needed()
+        mapping: dict[int, list[int]] = {}
+        for idx, meta in enumerate(self._labels_meta):
+            try:
+                doc_id = int(meta["doc_id"])  # type: ignore[index]
+            except Exception:
+                continue
+            mapping.setdefault(doc_id, []).append(idx)
+        self._docid_to_indices = mapping
+
     def search(
         self, data: np.ndarray, topk: int, first_stage_k: int = 50
     ) -> list[tuple[float, int]]:
@@ -180,3 +202,139 @@
         scores = sorted(((v, k) for k, v in doc_scores.items()), key=lambda x: x[0], reverse=True)
         return scores[:topk] if len(scores) >= topk else scores
+
+    def search_exact(
+        self,
+        data: np.ndarray,
+        topk: int,
+        *,
+        first_stage_k: int = 200,
+        max_workers: int = 32,
+    ) -> list[tuple[float, int]]:
+        """
+        High-precision MaxSim reranking over candidate documents.
+
+        Steps:
+        1) Run a first-stage ANN search to collect candidate doc_ids
+           (using sequence-level neighbors).
+        2) For each candidate doc, load all of its token embeddings and compute
+           MaxSim(query_tokens, doc_tokens) exactly: sum_i max_j dot(q_i, d_j).
+
+        Returns the top-k list of (score, doc_id) tuples.
+        """
+        # Normalize inputs
+        if data.ndim == 1:
+            data = data.reshape(1, -1)
+        if data.dtype != np.float32:
+            data = data.astype(np.float32)
+
+        self._load_labels_meta_if_needed()
+        self._build_docid_to_indices_if_needed()
+
+        emb_path = self._embeddings_path()
+        if not emb_path.exists():
+            # Fall back to approximate search if no embeddings were persisted
+            return self.search(data, topk, first_stage_k=first_stage_k)
+
+        # Memory-map the embeddings to avoid loading them all into RAM
+        all_embeddings = np.load(emb_path, mmap_mode="r")
+        if all_embeddings.dtype != np.float32:
+            all_embeddings = all_embeddings.astype(np.float32)
+
+        # First-stage ANN search to collect candidate doc_ids
+        searcher = HNSWSearcher(self.index_path, meta=self._meta_dict())
+        raw = searcher.search(
+            data,
+            first_stage_k,
+            recompute_embeddings=False,
+            complexity=128,
+            beam_width=1,
+            prune_ratio=0.0,
+            batch_size=0,
+        )
+        labels = raw.get("labels")
+        if labels is None:
+            return []
+
+        candidate_doc_ids: set[int] = set()
+        for batch in labels:
+            for sid in batch:
+                try:
+                    idx = int(sid)
+                except Exception:
+                    continue
+                if 0 <= idx < len(self._labels_meta):
+                    candidate_doc_ids.add(int(self._labels_meta[idx]["doc_id"]))  # type: ignore[index]
+
+        # Exact scoring per doc (parallelized)
+        assert self._docid_to_indices is not None
+
+        def _score_one(doc_id: int) -> tuple[float, int]:
+            token_indices = self._docid_to_indices.get(doc_id, [])
+            if not token_indices:
+                return (0.0, doc_id)
+            doc_vecs = np.asarray(all_embeddings[token_indices], dtype=np.float32)
+            # (Q, D) @ (P, D).T -> (Q, P), then max over doc tokens, sum over query tokens
+            sim = np.dot(data, doc_vecs.T)
+            # nan-safe
+            sim = np.nan_to_num(sim, nan=-1e30, posinf=1e30, neginf=-1e30)
+            score = sim.max(axis=2).sum(axis=1) if sim.ndim == 3 else sim.max(axis=1).sum()
+            return (float(score), doc_id)
+
+        scores: list[tuple[float, int]] = []
+        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
+            futures = [ex.submit(_score_one, doc_id) for doc_id in candidate_doc_ids]
+            for fut in concurrent.futures.as_completed(futures):
+                scores.append(fut.result())
+
+        scores.sort(key=lambda x: x[0], reverse=True)
+        return scores[:topk] if len(scores) >= topk else scores
+
+    def search_exact_all(
+        self,
+        data: np.ndarray,
+        topk: int,
+        *,
+        max_workers: int = 32,
+    ) -> list[tuple[float, int]]:
+        """
+        Exact MaxSim over ALL documents (no ANN pre-filtering).
+
+        For each document this computes sum_i max_j dot(q_i, d_j), memory-mapping
+        the persisted token-embedding matrix for scalability.
+        """
+        if data.ndim == 1:
+            data = data.reshape(1, -1)
+        if data.dtype != np.float32:
+            data = data.astype(np.float32)
+
+        self._load_labels_meta_if_needed()
+        self._build_docid_to_indices_if_needed()
+
+        emb_path = self._embeddings_path()
+        if not emb_path.exists():
+            return self.search(data, topk)
+
+        all_embeddings = np.load(emb_path, mmap_mode="r")
+        if all_embeddings.dtype != np.float32:
+            all_embeddings = all_embeddings.astype(np.float32)
+
+        assert self._docid_to_indices is not None
+        candidate_doc_ids = list(self._docid_to_indices.keys())
+
+        def _score_one(doc_id: int) -> tuple[float, int]:
+            token_indices = self._docid_to_indices.get(doc_id, [])
+            if not token_indices:
+                return (0.0, doc_id)
+            doc_vecs = np.asarray(all_embeddings[token_indices], dtype=np.float32)
+            sim = np.dot(data, doc_vecs.T)
+            sim = np.nan_to_num(sim, nan=-1e30, posinf=1e30, neginf=-1e30)
+            score = sim.max(axis=2).sum(axis=1) if sim.ndim == 3 else sim.max(axis=1).sum()
+            return (float(score), doc_id)
+
+        scores: list[tuple[float, int]] = []
+        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
+            futures = [ex.submit(_score_one, d) for d in candidate_doc_ids]
+            for fut in concurrent.futures.as_completed(futures):
+                scores.append(fut.result())
+
+        scores.sort(key=lambda x: x[0], reverse=True)
+        return scores[:topk] if len(scores) >= topk else scores
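
Note: both new methods score candidates with the late-interaction MaxSim objective, sum_i max_j dot(q_i, d_j), over query token vectors q_i and document token vectors d_j. A minimal standalone sketch of just that computation, with toy shapes and illustrative names that are not part of this commit:

    import numpy as np

    # Toy shapes: Q = 4 query tokens, P = 37 doc tokens, D = 128 dims
    rng = np.random.default_rng(0)
    query_tokens = rng.standard_normal((4, 128)).astype(np.float32)  # (Q, D)
    doc_tokens = rng.standard_normal((37, 128)).astype(np.float32)   # (P, D)

    # (Q, D) @ (D, P) -> (Q, P): similarity of every query token to every doc token
    sim = query_tokens @ doc_tokens.T

    # MaxSim: best-matching doc token per query token, summed over query tokens
    score = sim.max(axis=1).sum()
    print(float(score))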

View File

@@ -2,6 +2,7 @@
 # %%
 # uv pip install matplotlib qwen_vl_utils
 import os
+import json
 import re
 import sys
 from pathlib import Path
@@ -230,12 +231,18 @@ def _build_index(index_path: str, doc_vecs: list[Any], filepaths: list[str]) ->
     return retriever


-def _load_retriever_if_index_exists(index_path: str, dim: int) -> Optional[LeannMultiVector]:
+def _load_retriever_if_index_exists(index_path: str) -> Optional[LeannMultiVector]:
     index_base = Path(index_path)
     # Rough heuristic: index dir exists AND meta+labels files exist
     meta = index_base.parent / f"{index_base.name}.meta.json"
     labels = index_base.parent / f"{index_base.name}.labels.json"
     if index_base.exists() and meta.exists() and labels.exists():
+        try:
+            with open(meta, "r", encoding="utf-8") as f:
+                meta_json = json.load(f)
+            dim = int(meta_json.get("dimensions", 128))
+        except Exception:
+            dim = 128
         return LeannMultiVector(index_path=index_path, dim=dim)
     return None
@@ -390,11 +397,7 @@ print(f"Using model={model_name}, device={device_str}, dtype={dtype}")
 # Step 3: Build or load index
 retriever: Optional[LeannMultiVector] = None
 if not REBUILD_INDEX:
-    try:
-        one_vec = _embed_images(model, processor, [images[0]])[0]
-        retriever = _load_retriever_if_index_exists(INDEX_PATH, dim=int(one_vec.shape[-1]))
-    except Exception:
-        retriever = None
+    retriever = _load_retriever_if_index_exists(INDEX_PATH)

 if retriever is None:
     doc_vecs = _embed_images(model, processor, images)
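
For orientation, a hedged usage sketch of the three search paths left in place by this commit. It assumes a built index at INDEX_PATH and a hypothetical query_vecs, a float32 (Q, D) matrix of query-token embeddings from the same embedder; neither line below appears in the commit itself:

    # Assumed names: INDEX_PATH (index prefix) and query_vecs (query-token matrix)
    retriever = _load_retriever_if_index_exists(INDEX_PATH)
    if retriever is not None:
        approx = retriever.search(query_vecs, topk=5)                             # ANN only
        reranked = retriever.search_exact(query_vecs, topk=5, first_stage_k=200)  # ANN + exact rerank
        exact = retriever.search_exact_all(query_vecs, topk=5)                    # brute-force MaxSim
        # Each call returns a list of (score, doc_id) tuples, best first.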