Initial commit
This commit is contained in:
91
packages/leann-backend-hnsw/third_party/faiss/contrib/client_server.py
vendored
Executable file
91
packages/leann-backend-hnsw/third_party/faiss/contrib/client_server.py
vendored
Executable file
@@ -0,0 +1,91 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from multiprocessing.pool import ThreadPool
|
||||
import faiss
|
||||
from typing import List, Tuple
|
||||
|
||||
from . import rpc
|
||||
|
||||
############################################################
|
||||
# Server implementation
|
||||
############################################################
|
||||
|
||||
|
||||
class SearchServer(rpc.Server):
|
||||
""" Assign version that can be exposed via RPC """
|
||||
|
||||
def __init__(self, s: int, index: faiss.Index):
|
||||
rpc.Server.__init__(self, s)
|
||||
self.index = index
|
||||
self.index_ivf = faiss.extract_index_ivf(index)
|
||||
|
||||
def set_nprobe(self, nprobe: int) -> int:
|
||||
""" set nprobe field """
|
||||
self.index_ivf.nprobe = nprobe
|
||||
|
||||
def get_ntotal(self) -> int:
|
||||
return self.index.ntotal
|
||||
|
||||
def __getattr__(self, f):
|
||||
# all other functions get forwarded to the index
|
||||
return getattr(self.index, f)
|
||||
|
||||
|
||||
def run_index_server(index: faiss.Index, port: int, v6: bool = False):
|
||||
""" serve requests for that index forerver """
|
||||
rpc.run_server(
|
||||
lambda s: SearchServer(s, index),
|
||||
port, v6=v6)
|
||||
|
||||
|
||||
############################################################
|
||||
# Client implementation
|
||||
############################################################
|
||||
|
||||
class ClientIndex:
|
||||
"""manages a set of distance sub-indexes. The sub_indexes search a
|
||||
subset of the inverted lists. Searches are merged afterwards
|
||||
"""
|
||||
|
||||
def __init__(self, machine_ports: List[Tuple[str, int]], v6: bool = False):
|
||||
""" connect to a series of (host, port) pairs """
|
||||
self.sub_indexes = []
|
||||
for machine, port in machine_ports:
|
||||
self.sub_indexes.append(rpc.Client(machine, port, v6))
|
||||
|
||||
self.ni = len(self.sub_indexes)
|
||||
# pool of threads. Each thread manages one sub-index.
|
||||
self.pool = ThreadPool(self.ni)
|
||||
# test connection...
|
||||
self.ntotal = self.get_ntotal()
|
||||
self.verbose = False
|
||||
|
||||
def set_nprobe(self, nprobe: int) -> None:
|
||||
self.pool.map(
|
||||
lambda idx: idx.set_nprobe(nprobe),
|
||||
self.sub_indexes
|
||||
)
|
||||
|
||||
def set_omp_num_threads(self, nt: int) -> None:
|
||||
self.pool.map(
|
||||
lambda idx: idx.set_omp_num_threads(nt),
|
||||
self.sub_indexes
|
||||
)
|
||||
|
||||
def get_ntotal(self) -> None:
|
||||
return sum(self.pool.map(
|
||||
lambda idx: idx.get_ntotal(),
|
||||
self.sub_indexes
|
||||
))
|
||||
|
||||
def search(self, x, k: int):
|
||||
|
||||
rh = faiss.ResultHeap(x.shape[0], k)
|
||||
|
||||
for Di, Ii in self.pool.imap(lambda idx: idx.search(x, k), self.sub_indexes):
|
||||
rh.add_result(Di, Ii)
|
||||
rh.finalize()
|
||||
return rh.D, rh.I
|
||||
Reference in New Issue
Block a user