fix: auto-detect normalized embeddings and use cosine distance

- Add automatic detection for normalized embedding models (OpenAI, Voyage AI, Cohere)
- Automatically set distance_metric='cosine' for normalized embeddings
- Add warnings when using non-optimal distance metrics
- Implement manual L2 normalization in HNSW backend (custom Faiss build lacks normalize_L2)
- Fix DiskANN zmq_port compatibility with lazy loading strategy
- Add documentation for normalized embeddings feature

This fixes the low accuracy issue when using OpenAI text-embedding-3-small model with default MIPS metric.
This commit is contained in:
Andy Lee
2025-07-27 20:21:05 -07:00
parent 48207c3b69
commit 9a5c197acd
6 changed files with 223 additions and 26 deletions

View File

@@ -163,18 +163,44 @@ class DiskannSearcher(BaseSearcher):
self.num_threads = kwargs.get("num_threads", 8)
fake_zmq_port = 6666
# For DiskANN, we need to reinitialize the index when zmq_port changes
# Store the initialization parameters for later use
full_index_prefix = str(self.index_dir / self.index_path.stem)
self._index = diskannpy.StaticDiskFloatIndex(
metric_enum,
full_index_prefix,
self.num_threads,
kwargs.get("num_nodes_to_cache", 0),
1,
fake_zmq_port, # Initial port, can be updated at runtime
"",
"",
)
self._init_params = {
"metric_enum": metric_enum,
"full_index_prefix": full_index_prefix,
"num_threads": self.num_threads,
"num_nodes_to_cache": kwargs.get("num_nodes_to_cache", 0),
"cache_mechanism": 1,
"pq_prefix": "",
"partition_prefix": "",
}
self._diskannpy = diskannpy
self._current_zmq_port = None
self._index = None
logger.debug("DiskANN searcher initialized (index will be loaded on first search)")
def _ensure_index_loaded(self, zmq_port: int):
"""Ensure the index is loaded with the correct zmq_port."""
if self._index is None or self._current_zmq_port != zmq_port:
# Need to (re)load the index with the correct zmq_port
with suppress_cpp_output_if_needed():
if self._index is not None:
logger.debug(f"Reloading DiskANN index with new zmq_port: {zmq_port}")
else:
logger.debug(f"Loading DiskANN index with zmq_port: {zmq_port}")
self._index = self._diskannpy.StaticDiskFloatIndex(
self._init_params["metric_enum"],
self._init_params["full_index_prefix"],
self._init_params["num_threads"],
self._init_params["num_nodes_to_cache"],
self._init_params["cache_mechanism"],
zmq_port,
self._init_params["pq_prefix"],
self._init_params["partition_prefix"],
)
self._current_zmq_port = zmq_port
def search(
self,
@@ -212,14 +238,15 @@ class DiskannSearcher(BaseSearcher):
Returns:
Dict with 'labels' (list of lists) and 'distances' (ndarray)
"""
# Handle zmq_port compatibility: DiskANN can now update port at runtime
# Handle zmq_port compatibility: Ensure index is loaded with correct port
if recompute_embeddings:
if zmq_port is None:
raise ValueError("zmq_port must be provided if recompute_embeddings is True")
current_port = self._index.get_zmq_port()
if zmq_port != current_port:
logger.debug(f"Updating DiskANN zmq_port from {current_port} to {zmq_port}")
self._index.set_zmq_port(zmq_port)
self._ensure_index_loaded(zmq_port)
else:
# If not recomputing, we still need an index, use a default port
if self._index is None:
self._ensure_index_loaded(6666) # Default port when not recomputing
# DiskANN doesn't support "proportional" strategy
if pruning_strategy == "proportional":