refactor: nits
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
import numpy as np
|
||||
import os
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List, Literal
|
||||
import pickle
|
||||
@@ -18,7 +17,7 @@ from leann.interface import (
|
||||
|
||||
|
||||
def get_metric_map():
|
||||
from . import faiss
|
||||
from . import faiss # type: ignore
|
||||
|
||||
return {
|
||||
"mips": faiss.METRIC_INNER_PRODUCT,
|
||||
@@ -49,7 +48,7 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
||||
self.dimensions = self.build_params.get("dimensions")
|
||||
|
||||
def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs):
|
||||
from . import faiss
|
||||
from . import faiss # type: ignore
|
||||
|
||||
path = Path(index_path)
|
||||
index_dir = path.parent
|
||||
@@ -117,7 +116,7 @@ class HNSWSearcher(BaseSearcher):
|
||||
backend_module_name="leann_backend_hnsw.hnsw_embedding_server",
|
||||
**kwargs,
|
||||
)
|
||||
from . import faiss
|
||||
from . import faiss # type: ignore
|
||||
|
||||
self.distance_metric = self.meta.get("distance_metric", "mips").lower()
|
||||
metric_enum = get_metric_map().get(self.distance_metric)
|
||||
|
||||
Reference in New Issue
Block a user