add colqwen stuff and pass ruff
This commit is contained in:
@@ -2,7 +2,6 @@ from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -28,8 +27,8 @@ class LeannMultiVector:
|
||||
index_path: str,
|
||||
dim: int = 128,
|
||||
distance_metric: str = "mips",
|
||||
M: int = 16,
|
||||
efConstruction: int = 500,
|
||||
m: int = 16,
|
||||
ef_construction: int = 500,
|
||||
is_compact: bool = False,
|
||||
is_recompute: bool = False,
|
||||
embedding_model_name: str = "colvision",
|
||||
@@ -37,15 +36,15 @@ class LeannMultiVector:
|
||||
self.index_path = index_path
|
||||
self.dim = dim
|
||||
self.embedding_model_name = embedding_model_name
|
||||
self._pending_items: List[dict] = []
|
||||
self._pending_items: list[dict] = []
|
||||
self._backend_kwargs = {
|
||||
"distance_metric": distance_metric,
|
||||
"M": M,
|
||||
"efConstruction": efConstruction,
|
||||
"M": m,
|
||||
"efConstruction": ef_construction,
|
||||
"is_compact": is_compact,
|
||||
"is_recompute": is_recompute,
|
||||
}
|
||||
self._labels_meta: List[dict] = []
|
||||
self._labels_meta: list[dict] = []
|
||||
|
||||
def _meta_dict(self) -> dict:
|
||||
return {
|
||||
@@ -85,8 +84,8 @@ class LeannMultiVector:
|
||||
if not self._pending_items:
|
||||
return
|
||||
|
||||
embeddings: List[np.ndarray] = []
|
||||
labels_meta: List[dict] = []
|
||||
embeddings: list[np.ndarray] = []
|
||||
labels_meta: list[dict] = []
|
||||
|
||||
for item in self._pending_items:
|
||||
doc_id = int(item["doc_id"])
|
||||
@@ -108,12 +107,15 @@ class LeannMultiVector:
|
||||
return
|
||||
|
||||
embeddings_np = np.vstack(embeddings).astype(np.float32)
|
||||
# print shape of embeddings_np
|
||||
print(embeddings_np.shape)
|
||||
|
||||
builder = HNSWBuilder(**{**self._backend_kwargs, "dimensions": self.dim})
|
||||
ids = [str(i) for i in range(embeddings_np.shape[0])]
|
||||
builder.build(embeddings_np, ids, self.index_path)
|
||||
|
||||
import json as _json
|
||||
|
||||
with open(self._meta_path(), "w", encoding="utf-8") as f:
|
||||
_json.dump(self._meta_dict(), f, indent=2)
|
||||
with open(self._labels_path(), "w", encoding="utf-8") as f:
|
||||
@@ -127,10 +129,13 @@ class LeannMultiVector:
|
||||
labels_path = self._labels_path()
|
||||
if labels_path.exists():
|
||||
import json as _json
|
||||
|
||||
with open(labels_path, encoding="utf-8") as f:
|
||||
self._labels_meta = _json.load(f)
|
||||
|
||||
def search(self, data: np.ndarray, topk: int, first_stage_k: int = 50) -> List[Tuple[float, int]]:
|
||||
def search(
|
||||
self, data: np.ndarray, topk: int, first_stage_k: int = 50
|
||||
) -> list[tuple[float, int]]:
|
||||
if data.ndim == 1:
|
||||
data = data.reshape(1, -1)
|
||||
if data.dtype != np.float32:
|
||||
@@ -175,5 +180,3 @@ class LeannMultiVector:
|
||||
|
||||
scores = sorted(((v, k) for k, v in doc_scores.items()), key=lambda x: x[0], reverse=True)
|
||||
return scores[:topk] if len(scores) >= topk else scores
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user