Initial commit
This commit is contained in:
60
packages/leann-backend-hnsw/third_party/faiss/contrib/torch/clustering.py
vendored
Normal file
60
packages/leann-backend-hnsw/third_party/faiss/contrib/torch/clustering.py
vendored
Normal file
@@ -0,0 +1,60 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""
|
||||
This contrib module contains Pytorch code for k-means clustering
|
||||
"""
|
||||
import faiss
|
||||
import faiss.contrib.torch_utils
|
||||
import torch
|
||||
|
||||
# the kmeans can produce both torch and numpy centroids
|
||||
from faiss.contrib.clustering import kmeans
|
||||
|
||||
|
||||
class DatasetAssign:
|
||||
"""Wrapper for a tensor that offers a function to assign the vectors
|
||||
to centroids. All other implementations offer the same interface"""
|
||||
|
||||
def __init__(self, x):
|
||||
self.x = x
|
||||
|
||||
def count(self):
|
||||
return self.x.shape[0]
|
||||
|
||||
def dim(self):
|
||||
return self.x.shape[1]
|
||||
|
||||
def get_subset(self, indices):
|
||||
return self.x[indices]
|
||||
|
||||
def perform_search(self, centroids):
|
||||
return faiss.knn(self.x, centroids, 1)
|
||||
|
||||
def assign_to(self, centroids, weights=None):
|
||||
D, I = self.perform_search(centroids)
|
||||
|
||||
I = I.ravel()
|
||||
D = D.ravel()
|
||||
nc, d = centroids.shape
|
||||
|
||||
sum_per_centroid = torch.zeros_like(centroids)
|
||||
if weights is None:
|
||||
sum_per_centroid.index_add_(0, I, self.x)
|
||||
else:
|
||||
sum_per_centroid.index_add_(0, I, self.x * weights[:, None])
|
||||
|
||||
# the indices are still in numpy.
|
||||
return I.cpu().numpy(), D, sum_per_centroid
|
||||
|
||||
|
||||
class DatasetAssignGPU(DatasetAssign):
|
||||
|
||||
def __init__(self, res, x):
|
||||
DatasetAssign.__init__(self, x)
|
||||
self.res = res
|
||||
|
||||
def perform_search(self, centroids):
|
||||
return faiss.knn_gpu(self.res, self.x, centroids, 1)
|
||||
Reference in New Issue
Block a user