update some search in copali
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import sys
+import concurrent.futures
 from pathlib import Path
 
 import numpy as np
@@ -45,6 +46,7 @@ class LeannMultiVector:
             "is_recompute": is_recompute,
         }
         self._labels_meta: list[dict] = []
+        self._docid_to_indices: dict[int, list[int]] | None = None
 
     def _meta_dict(self) -> dict:
         return {
@@ -80,6 +82,10 @@ class LeannMultiVector:
         index_path_obj = Path(self.index_path)
         return index_path_obj.parent / f"{index_path_obj.name}.meta.json"
 
+    def _embeddings_path(self) -> Path:
+        index_path_obj = Path(self.index_path)
+        return index_path_obj.parent / f"{index_path_obj.name}.emb.npy"
+
     def create_index(self) -> None:
         if not self._pending_items:
             return
@@ -121,6 +127,9 @@ class LeannMultiVector:
         with open(self._labels_path(), "w", encoding="utf-8") as f:
             _json.dump(labels_meta, f)
 
+        # Persist embeddings for exact reranking
+        np.save(self._embeddings_path(), embeddings_np)
+
         self._labels_meta = labels_meta
 
     def _load_labels_meta_if_needed(self) -> None:
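
Aside (not part of the diff): the .emb.npy file persisted above is what enables exact reranking later. A minimal sketch of the save/memory-map round trip, using a throwaway path and random stand-in embeddings:

import numpy as np

emb = np.random.rand(10, 128).astype(np.float32)        # stand-in token embeddings
np.save("/tmp/example.emb.npy", emb)                    # what create_index does via np.save
view = np.load("/tmp/example.emb.npy", mmap_mode="r")   # rows are paged in lazily from disk
assert view.shape == (10, 128) and view.dtype == np.float32
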
@@ -133,6 +142,19 @@ class LeannMultiVector:
         with open(labels_path, encoding="utf-8") as f:
             self._labels_meta = _json.load(f)
 
+    def _build_docid_to_indices_if_needed(self) -> None:
+        if self._docid_to_indices is not None:
+            return
+        self._load_labels_meta_if_needed()
+        mapping: dict[int, list[int]] = {}
+        for idx, meta in enumerate(self._labels_meta):
+            try:
+                doc_id = int(meta["doc_id"])  # type: ignore[index]
+            except Exception:
+                continue
+            mapping.setdefault(doc_id, []).append(idx)
+        self._docid_to_indices = mapping
+
     def search(
         self, data: np.ndarray, topk: int, first_stage_k: int = 50
     ) -> list[tuple[float, int]]:
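
Aside (not part of the diff): _build_docid_to_indices_if_needed groups token-level rows by document. A toy sketch of the structure it produces, with made-up labels:

labels_meta = [{"doc_id": 0}, {"doc_id": 0}, {"doc_id": 1}, {"doc_id": 0}]
mapping: dict[int, list[int]] = {}
for idx, meta in enumerate(labels_meta):
    mapping.setdefault(int(meta["doc_id"]), []).append(idx)
assert mapping == {0: [0, 1, 3], 1: [2]}  # doc_id -> row indices into the embedding matrix
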
@@ -180,3 +202,139 @@ class LeannMultiVector:
 
         scores = sorted(((v, k) for k, v in doc_scores.items()), key=lambda x: x[0], reverse=True)
         return scores[:topk] if len(scores) >= topk else scores
+
+    def search_exact(
+        self,
+        data: np.ndarray,
+        topk: int,
+        *,
+        first_stage_k: int = 200,
+        max_workers: int = 32,
+    ) -> list[tuple[float, int]]:
+        """
+        High-precision MaxSim reranking over candidate documents.
+
+        Steps:
+        1) Run a first-stage ANN to collect candidate doc_ids (using seq-level neighbors).
+        2) For each candidate doc, load all its token embeddings and compute
+           MaxSim(query_tokens, doc_tokens) exactly: sum(max(dot(q_i, d_j))).
+
+        Returns a top-k list of (score, doc_id).
+        """
+        # Normalize inputs
+        if data.ndim == 1:
+            data = data.reshape(1, -1)
+        if data.dtype != np.float32:
+            data = data.astype(np.float32)
+
+        self._load_labels_meta_if_needed()
+        self._build_docid_to_indices_if_needed()
+
+        emb_path = self._embeddings_path()
+        if not emb_path.exists():
+            # Fall back to approximate search if there are no persisted embeddings
+            return self.search(data, topk, first_stage_k=first_stage_k)
+
+        # Memory-map embeddings to avoid loading them all into RAM
+        all_embeddings = np.load(emb_path, mmap_mode="r")
+        if all_embeddings.dtype != np.float32:
+            all_embeddings = all_embeddings.astype(np.float32)
+
+        # First-stage ANN to collect candidate doc_ids
+        searcher = HNSWSearcher(self.index_path, meta=self._meta_dict())
+        raw = searcher.search(
+            data,
+            first_stage_k,
+            recompute_embeddings=False,
+            complexity=128,
+            beam_width=1,
+            prune_ratio=0.0,
+            batch_size=0,
+        )
+        labels = raw.get("labels")
+        if labels is None:
+            return []
+        candidate_doc_ids: set[int] = set()
+        for batch in labels:
+            for sid in batch:
+                try:
+                    idx = int(sid)
+                except Exception:
+                    continue
+                if 0 <= idx < len(self._labels_meta):
+                    candidate_doc_ids.add(int(self._labels_meta[idx]["doc_id"]))  # type: ignore[index]
+
+        # Exact scoring per doc (parallelized)
+        assert self._docid_to_indices is not None
+
+        def _score_one(doc_id: int) -> tuple[float, int]:
+            token_indices = self._docid_to_indices.get(doc_id, [])
+            if not token_indices:
+                return (0.0, doc_id)
+            doc_vecs = np.asarray(all_embeddings[token_indices], dtype=np.float32)
+            # (Q, D) x (P, D)^T -> (Q, P); MaxSim = max over P, summed over Q
+            sim = np.dot(data, doc_vecs.T)
+            # NaN-safe: clamp non-finite similarities
+            sim = np.nan_to_num(sim, nan=-1e30, posinf=1e30, neginf=-1e30)
+            score = sim.max(axis=2).sum(axis=1) if sim.ndim == 3 else sim.max(axis=1).sum()
+            return (float(score), doc_id)
+
+        scores: list[tuple[float, int]] = []
+        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
+            futures = [ex.submit(_score_one, doc_id) for doc_id in candidate_doc_ids]
+            for fut in concurrent.futures.as_completed(futures):
+                scores.append(fut.result())
+
+        scores.sort(key=lambda x: x[0], reverse=True)
+        return scores[:topk] if len(scores) >= topk else scores
+
+    def search_exact_all(
+        self,
+        data: np.ndarray,
+        topk: int,
+        *,
+        max_workers: int = 32,
+    ) -> list[tuple[float, int]]:
+        """
+        Exact MaxSim over ALL documents (no ANN pre-filtering).
+
+        This computes, for each document, sum_i max_j dot(q_i, d_j).
+        It memory-maps the persisted token-embedding matrix for scalability.
+        """
+        if data.ndim == 1:
+            data = data.reshape(1, -1)
+        if data.dtype != np.float32:
+            data = data.astype(np.float32)
+
+        self._load_labels_meta_if_needed()
+        self._build_docid_to_indices_if_needed()
+
+        emb_path = self._embeddings_path()
+        if not emb_path.exists():
+            return self.search(data, topk)
+
+        all_embeddings = np.load(emb_path, mmap_mode="r")
+        if all_embeddings.dtype != np.float32:
+            all_embeddings = all_embeddings.astype(np.float32)
+
+        assert self._docid_to_indices is not None
+        candidate_doc_ids = list(self._docid_to_indices.keys())
+
+        def _score_one(doc_id: int) -> tuple[float, int]:
+            token_indices = self._docid_to_indices.get(doc_id, [])
+            if not token_indices:
+                return (0.0, doc_id)
+            doc_vecs = np.asarray(all_embeddings[token_indices], dtype=np.float32)
+            sim = np.dot(data, doc_vecs.T)
+            sim = np.nan_to_num(sim, nan=-1e30, posinf=1e30, neginf=-1e30)
+            score = sim.max(axis=2).sum(axis=1) if sim.ndim == 3 else sim.max(axis=1).sum()
+            return (float(score), doc_id)
+
+        scores: list[tuple[float, int]] = []
+        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
+            futures = [ex.submit(_score_one, d) for d in candidate_doc_ids]
+            for fut in concurrent.futures.as_completed(futures):
+                scores.append(fut.result())
+
+        scores.sort(key=lambda x: x[0], reverse=True)
+        return scores[:topk] if len(scores) >= topk else scores
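
Aside (not part of the diff): the MaxSim score computed in _score_one, sum_i max_j dot(q_i, d_j), can be checked against a brute-force loop on toy data:

import numpy as np

rng = np.random.default_rng(0)
query_tokens = rng.standard_normal((4, 128)).astype(np.float32)  # Q=4 query tokens
doc_tokens = rng.standard_normal((30, 128)).astype(np.float32)   # P=30 doc tokens

sim = query_tokens @ doc_tokens.T   # (Q, P) pairwise dot products
maxsim = sim.max(axis=1).sum()      # best doc token per query token, summed

loop = sum(max(float(q @ d) for d in doc_tokens) for q in query_tokens)
assert np.isclose(maxsim, loop, rtol=1e-5)
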
@@ -2,6 +2,7 @@
 # %%
 # uv pip install matplotlib qwen_vl_utils
 import os
+import json
 import re
 import sys
 from pathlib import Path
@@ -230,12 +231,18 @@ def _build_index(index_path: str, doc_vecs: list[Any], filepaths: list[str]) ->
     return retriever
 
 
-def _load_retriever_if_index_exists(index_path: str, dim: int) -> Optional[LeannMultiVector]:
+def _load_retriever_if_index_exists(index_path: str) -> Optional[LeannMultiVector]:
     index_base = Path(index_path)
     # Rough heuristic: index dir exists AND meta+labels files exist
    meta = index_base.parent / f"{index_base.name}.meta.json"
     labels = index_base.parent / f"{index_base.name}.labels.json"
     if index_base.exists() and meta.exists() and labels.exists():
+        try:
+            with open(meta, "r", encoding="utf-8") as f:
+                meta_json = json.load(f)
+            dim = int(meta_json.get("dimensions", 128))
+        except Exception:
+            dim = 128
         return LeannMultiVector(index_path=index_path, dim=dim)
     return None
 
@@ -390,11 +397,7 @@ print(f"Using model={model_name}, device={device_str}, dtype={dtype}")
 # Step 3: Build or load index
 retriever: Optional[LeannMultiVector] = None
 if not REBUILD_INDEX:
-    try:
-        one_vec = _embed_images(model, processor, [images[0]])[0]
-        retriever = _load_retriever_if_index_exists(INDEX_PATH, dim=int(one_vec.shape[-1]))
-    except Exception:
-        retriever = None
+    retriever = _load_retriever_if_index_exists(INDEX_PATH)
 
 if retriever is None:
     doc_vecs = _embed_images(model, processor, images)
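
Aside (not part of the diff): a hypothetical end-to-end call under the new API, assuming an index already exists at INDEX_PATH with dim 128; the query array here is a random stand-in for real model output:

import numpy as np

retriever = _load_retriever_if_index_exists(INDEX_PATH)  # dim is now read from meta.json
if retriever is not None:
    query_vec = np.random.rand(16, 128).astype(np.float32)  # placeholder (Q, D) query tokens
    fast = retriever.search_exact(query_vec, topk=5, first_stage_k=200)  # ANN-filtered rerank
    full = retriever.search_exact_all(query_vec, topk=5)                 # brute-force MaxSim
    print(fast)  # list of (score, doc_id), highest MaxSim first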