style: format
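Every hunk below makes the same mechanical change: statements that overflow the formatter's line-length limit are wrapped inside their own parentheses, one element per line. As a minimal, self-contained illustration (assuming a black/ruff-format style limit of 88 columns, which this repository's formatter appears to use):

    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)
    mode_str, index_file = "compact", "index.leann"  # stand-in values

    # Before: one statement that overflows the line limit.
    logger.info(f"INFO: Replaced original index with {mode_str} version at '{index_file}'")

    # After formatting: the argument moves onto its own line inside the call.
    logger.info(
        f"INFO: Replaced original index with {mode_str} version at '{index_file}'"
    )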
@@ -118,12 +118,16 @@ class HNSWBuilder(LeannBackendBuilderInterface):
             # index_file_old = index_file.with_suffix(".old")
             # shutil.move(str(index_file), str(index_file_old))
             shutil.move(str(csr_temp_file), str(index_file))
-            logger.info(f"INFO: Replaced original index with {mode_str} version at '{index_file}'")
+            logger.info(
+                f"INFO: Replaced original index with {mode_str} version at '{index_file}'"
+            )
         else:
             # Clean up and fail fast
             if csr_temp_file.exists():
                 os.remove(csr_temp_file)
-            raise RuntimeError("CSR conversion failed - cannot proceed with compact format")
+            raise RuntimeError(
+                "CSR conversion failed - cannot proceed with compact format"
+            )
 
 
 class HNSWSearcher(BaseSearcher):
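The logic being reformatted here is a write-temp-then-swap pattern. A minimal sketch of the same flow (replace_index and converted_ok are hypothetical names; the real surrounding code sits outside this hunk):

    import os
    import shutil
    from pathlib import Path

    def replace_index(index_file: Path, csr_temp_file: Path, converted_ok: bool) -> None:
        # Hypothetical sketch mirroring the diff's replace-or-fail-fast flow.
        if converted_ok:
            # A move within one filesystem is effectively a rename, so readers
            # never observe a half-written index file.
            shutil.move(str(csr_temp_file), str(index_file))
        else:
            # Clean up the temp file and fail fast, as in the hunk above.
            if csr_temp_file.exists():
                os.remove(csr_temp_file)
            raise RuntimeError("CSR conversion failed - cannot proceed with compact format")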
@@ -212,7 +216,9 @@ class HNSWSearcher(BaseSearcher):
         )
         if recompute_embeddings:
             if zmq_port is None:
-                raise ValueError("zmq_port must be provided if recompute_embeddings is True")
+                raise ValueError(
+                    "zmq_port must be provided if recompute_embeddings is True"
+                )
 
         if query.dtype != np.float32:
             query = query.astype(np.float32)
@@ -221,7 +227,9 @@ class HNSWSearcher(BaseSearcher):
 
         params = faiss.SearchParametersHNSW()
         if zmq_port is not None:
-            params.zmq_port = zmq_port  # C++ code won't use this if recompute_embeddings is False
+            params.zmq_port = (
+                zmq_port  # C++ code won't use this if recompute_embeddings is False
+            )
         params.efSearch = complexity
         params.beam_size = beam_width
 
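For context, SearchParametersHNSW carries per-query knobs into FAISS. A runnable sketch against stock faiss (pip install faiss-cpu); note that zmq_port and beam_size are fields of LEANN's patched FAISS, not of the upstream library, so only efSearch appears here:

    import faiss
    import numpy as np

    d = 64
    xb = np.random.default_rng(0).normal(size=(1000, d)).astype(np.float32)
    index = faiss.IndexHNSWFlat(d, 32)  # HNSW graph with 32 neighbors per node
    index.add(xb)

    params = faiss.SearchParametersHNSW()
    params.efSearch = 64  # plays the role of `complexity` in the code above
    distances, labels = index.search(xb[:1], 5, params=params)
    print(labels)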
@@ -229,7 +237,8 @@ class HNSWSearcher(BaseSearcher):
         # This prevents early termination when all scores are in a narrow range
         embedding_model = self.meta.get("embedding_model", "").lower()
         if self.distance_metric == "cosine" and any(
-            openai_model in embedding_model for openai_model in ["text-embedding", "openai"]
+            openai_model in embedding_model
+            for openai_model in ["text-embedding", "openai"]
         ):
             params.check_relative_distance = False
         else:
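The comment above refers to FAISS's relative-distance check terminating early when scores bunch together. A quick numpy illustration of why cosine scores over unit-norm, high-dimensional vectors (OpenAI-style embeddings are unit-norm) occupy a narrow band; this uses synthetic data, not real embeddings:

    import numpy as np

    rng = np.random.default_rng(0)
    emb = rng.normal(size=(1000, 1536)).astype(np.float32)
    emb /= np.linalg.norm(emb, axis=1, keepdims=True)  # unit-normalize

    sims = emb[1:] @ emb[0]  # cosine similarity == dot product for unit vectors
    print(f"similarity range: [{sims.min():.3f}, {sims.max():.3f}]")
    # Scores cluster in a narrow band in high dimensions, so a relative-distance
    # cutoff can fire before the graph search has actually converged.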
@@ -244,7 +253,9 @@ class HNSWSearcher(BaseSearcher):
             params.send_neigh_times_ratio = 0.0
         elif pruning_strategy == "proportional":
             params.local_prune = False
-            params.send_neigh_times_ratio = 1.0  # Any value > 1e-6 triggers proportional mode
+            params.send_neigh_times_ratio = (
+                1.0  # Any value > 1e-6 triggers proportional mode
+            )
         else:  # "global"
             params.local_prune = False
             params.send_neigh_times_ratio = 0.0
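Restating how the C++ side reads these two fields, per the comment in the hunk, as a tiny decoder. The first branch's condition is cut off above, so treating local_prune=True as a "local" strategy is an assumption:

    from types import SimpleNamespace

    EPS = 1e-6  # threshold named in the comment above

    def pruning_mode(params) -> str:
        # Assumed decoding: local_prune wins; otherwise any ratio > 1e-6
        # selects proportional mode, and 0.0 selects global.
        if params.local_prune:
            return "local"
        return "proportional" if params.send_neigh_times_ratio > EPS else "global"

    params = SimpleNamespace(local_prune=False, send_neigh_times_ratio=1.0)
    print(pruning_mode(params))  # -> proportional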
@@ -266,7 +277,9 @@ class HNSWSearcher(BaseSearcher):
             params,
         )
         search_time = time.time() - search_time
-        logger.info(f" Search time in HNSWSearcher.search() backend: {search_time} seconds")
+        logger.info(
+            f" Search time in HNSWSearcher.search() backend: {search_time} seconds"
+        )
         if self._id_map:
 
             def map_label(x: int) -> str:
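A side note on the timing idiom above: the code reuses search_time for both the start timestamp and the elapsed interval. An equivalent sketch with time.perf_counter(), the usual monotonic clock for measuring intervals (a suggestion only; it is not what this diff changes):

    import time

    start = time.perf_counter()
    # ... index.search(...) would run here ...
    elapsed = time.perf_counter() - start
    print(f" Search time in HNSWSearcher.search() backend: {elapsed} seconds")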
@@ -274,10 +287,13 @@ class HNSWSearcher(BaseSearcher):
                     return self._id_map[x]
                 return str(x)
 
-            string_labels = [[map_label(int(l)) for l in batch_labels] for batch_labels in labels]
+            string_labels = [
+                [map_label(int(l)) for l in batch_labels] for batch_labels in labels
+            ]
         else:
             string_labels = [
-                [str(int_label) for int_label in batch_labels] for batch_labels in labels
+                [str(int_label) for int_label in batch_labels]
+                for batch_labels in labels
             ]
 
         return {"labels": string_labels, "distances": distances}
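This final hunk reformats the label post-processing: FAISS returns integer ids, and _id_map (when present) translates them back to string ids, with str(x) as the fallback. A self-contained sketch of the same mapping (the bounds-check guard is an assumption, since map_label's condition is cut off above):

    from typing import Optional

    import numpy as np

    def to_string_labels(labels: np.ndarray, id_map: Optional[list]) -> list:
        def map_label(x: int) -> str:
            if id_map is not None and 0 <= x < len(id_map):
                return id_map[x]
            return str(x)  # fall back to the raw integer id

        return [[map_label(int(l)) for l in batch] for batch in labels]

    print(to_string_labels(np.array([[0, 2, 7]]), ["doc-a", "doc-b", "doc-c"]))
    # -> [['doc-a', 'doc-c', '7']]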