add colqwen stuff and pass ruff

This commit is contained in:
yichuan-w
2025-09-22 22:01:29 +00:00
parent 72455bb269
commit 94d9a203a2
7 changed files with 98815 additions and 99376 deletions

View File

@@ -2,7 +2,6 @@ from __future__ import annotations
import sys
from pathlib import Path
from typing import List, Tuple
import numpy as np
@@ -28,8 +27,8 @@ class LeannMultiVector:
index_path: str,
dim: int = 128,
distance_metric: str = "mips",
M: int = 16,
efConstruction: int = 500,
m: int = 16,
ef_construction: int = 500,
is_compact: bool = False,
is_recompute: bool = False,
embedding_model_name: str = "colvision",
@@ -37,15 +36,15 @@ class LeannMultiVector:
self.index_path = index_path
self.dim = dim
self.embedding_model_name = embedding_model_name
self._pending_items: List[dict] = []
self._pending_items: list[dict] = []
self._backend_kwargs = {
"distance_metric": distance_metric,
"M": M,
"efConstruction": efConstruction,
"M": m,
"efConstruction": ef_construction,
"is_compact": is_compact,
"is_recompute": is_recompute,
}
self._labels_meta: List[dict] = []
self._labels_meta: list[dict] = []
def _meta_dict(self) -> dict:
return {
@@ -85,8 +84,8 @@ class LeannMultiVector:
if not self._pending_items:
return
embeddings: List[np.ndarray] = []
labels_meta: List[dict] = []
embeddings: list[np.ndarray] = []
labels_meta: list[dict] = []
for item in self._pending_items:
doc_id = int(item["doc_id"])
@@ -108,12 +107,15 @@ class LeannMultiVector:
return
embeddings_np = np.vstack(embeddings).astype(np.float32)
# print shape of embeddings_np
print(embeddings_np.shape)
builder = HNSWBuilder(**{**self._backend_kwargs, "dimensions": self.dim})
ids = [str(i) for i in range(embeddings_np.shape[0])]
builder.build(embeddings_np, ids, self.index_path)
import json as _json
with open(self._meta_path(), "w", encoding="utf-8") as f:
_json.dump(self._meta_dict(), f, indent=2)
with open(self._labels_path(), "w", encoding="utf-8") as f:
@@ -127,10 +129,13 @@ class LeannMultiVector:
labels_path = self._labels_path()
if labels_path.exists():
import json as _json
with open(labels_path, encoding="utf-8") as f:
self._labels_meta = _json.load(f)
def search(self, data: np.ndarray, topk: int, first_stage_k: int = 50) -> List[Tuple[float, int]]:
def search(
self, data: np.ndarray, topk: int, first_stage_k: int = 50
) -> list[tuple[float, int]]:
if data.ndim == 1:
data = data.reshape(1, -1)
if data.dtype != np.float32:
@@ -175,5 +180,3 @@ class LeannMultiVector:
scores = sorted(((v, k) for k, v in doc_scores.items()), key=lambda x: x[0], reverse=True)
return scores[:topk] if len(scores) >= topk else scores