fix: auto-detect normalized embeddings and use cosine distance
- Add automatic detection for normalized embedding models (OpenAI, Voyage AI, Cohere)
- Automatically set distance_metric='cosine' for normalized embeddings
- Add warnings when using non-optimal distance metrics
- Implement manual L2 normalization in HNSW backend (custom Faiss build lacks normalize_L2)
- Fix DiskANN zmq_port compatibility with lazy loading strategy
- Add documentation for normalized embeddings feature

This fixes the low accuracy issue when using the OpenAI text-embedding-3-small model with the default MIPS metric.
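For illustration, a minimal usage sketch of the auto-detection. This is hypothetical: the `leann.api` import path, the `backend_name` argument, and the exact `LeannBuilder` keyword arguments are assumed from the diff below, not confirmed by this commit.

import warnings

from leann.api import LeannBuilder  # import path assumed

# No distance_metric given: the builder detects the normalized OpenAI model
# and injects distance_metric="cosine" into backend_kwargs, with a UserWarning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    builder = LeannBuilder(
        backend_name="hnsw",  # backend selection argument assumed
        embedding_model="text-embedding-3-small",
        embedding_mode="openai",
    )

assert builder.backend_kwargs["distance_metric"] == "cosine"
assert any(issubclass(w.category, UserWarning) for w in caught)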
@@ -163,18 +163,44 @@ class DiskannSearcher(BaseSearcher):
         self.num_threads = kwargs.get("num_threads", 8)
-        fake_zmq_port = 6666
+        # For DiskANN, we need to reinitialize the index when zmq_port changes
+        # Store the initialization parameters for later use
         full_index_prefix = str(self.index_dir / self.index_path.stem)
-        self._index = diskannpy.StaticDiskFloatIndex(
-            metric_enum,
-            full_index_prefix,
-            self.num_threads,
-            kwargs.get("num_nodes_to_cache", 0),
-            1,
-            fake_zmq_port,  # Initial port, can be updated at runtime
-            "",
-            "",
-        )
+        self._init_params = {
+            "metric_enum": metric_enum,
+            "full_index_prefix": full_index_prefix,
+            "num_threads": self.num_threads,
+            "num_nodes_to_cache": kwargs.get("num_nodes_to_cache", 0),
+            "cache_mechanism": 1,
+            "pq_prefix": "",
+            "partition_prefix": "",
+        }
+        self._diskannpy = diskannpy
+        self._current_zmq_port = None
+        self._index = None
+        logger.debug("DiskANN searcher initialized (index will be loaded on first search)")
+
+    def _ensure_index_loaded(self, zmq_port: int):
+        """Ensure the index is loaded with the correct zmq_port."""
+        if self._index is None or self._current_zmq_port != zmq_port:
+            # Need to (re)load the index with the correct zmq_port
+            with suppress_cpp_output_if_needed():
+                if self._index is not None:
+                    logger.debug(f"Reloading DiskANN index with new zmq_port: {zmq_port}")
+                else:
+                    logger.debug(f"Loading DiskANN index with zmq_port: {zmq_port}")
+
+                self._index = self._diskannpy.StaticDiskFloatIndex(
+                    self._init_params["metric_enum"],
+                    self._init_params["full_index_prefix"],
+                    self._init_params["num_threads"],
+                    self._init_params["num_nodes_to_cache"],
+                    self._init_params["cache_mechanism"],
+                    zmq_port,
+                    self._init_params["pq_prefix"],
+                    self._init_params["partition_prefix"],
+                )
+                self._current_zmq_port = zmq_port

     def search(
         self,
@@ -212,14 +238,15 @@ class DiskannSearcher(BaseSearcher):
         Returns:
             Dict with 'labels' (list of lists) and 'distances' (ndarray)
         """
-        # Handle zmq_port compatibility: DiskANN can now update port at runtime
+        # Handle zmq_port compatibility: Ensure index is loaded with correct port
         if recompute_embeddings:
             if zmq_port is None:
                 raise ValueError("zmq_port must be provided if recompute_embeddings is True")
-            current_port = self._index.get_zmq_port()
-            if zmq_port != current_port:
-                logger.debug(f"Updating DiskANN zmq_port from {current_port} to {zmq_port}")
-                self._index.set_zmq_port(zmq_port)
+            self._ensure_index_loaded(zmq_port)
+        else:
+            # If not recomputing, we still need an index, use a default port
+            if self._index is None:
+                self._ensure_index_loaded(6666)  # Default port when not recomputing

         # DiskANN doesn't support "proportional" strategy
         if pruning_strategy == "proportional":
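For context, a sketch of how the lazy loading behaves at call time. Hypothetical: `searcher`, `query`, and the `top_k` argument are assumed; only `recompute_embeddings`, `zmq_port`, and `pruning_strategy` are visible in the diff.

# First call: no index yet, so _ensure_index_loaded builds a
# StaticDiskFloatIndex bound to zmq_port=7001.
results = searcher.search(query, top_k=10, recompute_embeddings=True, zmq_port=7001)

# Same port again: the cached index is reused, no reload.
results = searcher.search(query, top_k=10, recompute_embeddings=True, zmq_port=7001)

# New port: the index is reconstructed from the stored _init_params,
# replacing the previous get_zmq_port()/set_zmq_port() path.
results = searcher.search(query, top_k=10, recompute_embeddings=True, zmq_port=7002)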
@@ -28,6 +28,12 @@ def get_metric_map():
     }


+def normalize_l2(data: np.ndarray) -> np.ndarray:
+    norms = np.linalg.norm(data, axis=1, keepdims=True)
+    norms[norms == 0] = 1  # Avoid division by zero
+    return data / norms
+
+
 @register_backend("hnsw")
 class HNSWBackend(LeannBackendFactoryInterface):
     @staticmethod
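A quick self-contained check of the helper (the function body is copied verbatim from the hunk above):

import numpy as np

def normalize_l2(data: np.ndarray) -> np.ndarray:
    norms = np.linalg.norm(data, axis=1, keepdims=True)
    norms[norms == 0] = 1  # Avoid division by zero
    return data / norms

x = np.random.rand(4, 8).astype(np.float32)
u = normalize_l2(x)
# Every row now has unit L2 norm.
assert np.allclose(np.linalg.norm(u, axis=1), 1.0, atol=1e-5)

Note that unlike faiss.normalize_L2, which normalizes its argument in place, this helper returns a new array, which is why the call sites below change to the `data = normalize_l2(data)` form.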
@@ -76,7 +82,7 @@ class HNSWBuilder(LeannBackendBuilderInterface):
         index.hnsw.efConstruction = self.efConstruction

         if self.distance_metric.lower() == "cosine":
-            faiss.normalize_L2(data)
+            data = normalize_l2(data)

         index.add(data.shape[0], faiss.swig_ptr(data))
         index_file = index_dir / f"{index_prefix}.index"
@@ -186,7 +192,7 @@ class HNSWSearcher(BaseSearcher):
         if query.dtype != np.float32:
             query = query.astype(np.float32)
         if self.distance_metric == "cosine":
-            faiss.normalize_L2(query)
+            query = normalize_l2(query)

         params = faiss.SearchParametersHNSW()
         if zmq_port is not None:
@@ -12,6 +12,7 @@ from pathlib import Path
 from typing import Any, Literal

 import numpy as np
+import warnings

 from leann.interface import LeannBackendSearcherInterface
@@ -163,6 +164,76 @@ class LeannBuilder:
         self.embedding_model = embedding_model
         self.dimensions = dimensions
         self.embedding_mode = embedding_mode

+        # Check if we need to use cosine distance for normalized embeddings
+        normalized_embeddings_models = {
+            # OpenAI models
+            ("openai", "text-embedding-ada-002"),
+            ("openai", "text-embedding-3-small"),
+            ("openai", "text-embedding-3-large"),
+            # Voyage AI models
+            ("voyage", "voyage-2"),
+            ("voyage", "voyage-3"),
+            ("voyage", "voyage-large-2"),
+            ("voyage", "voyage-multilingual-2"),
+            ("voyage", "voyage-code-2"),
+            # Cohere models
+            ("cohere", "embed-english-v3.0"),
+            ("cohere", "embed-multilingual-v3.0"),
+            ("cohere", "embed-english-light-v3.0"),
+            ("cohere", "embed-multilingual-light-v3.0"),
+        }
+
+        # Also check for patterns in model names
+        is_normalized = False
+        current_model_lower = embedding_model.lower()
+        current_mode_lower = embedding_mode.lower()
+
+        # Check exact matches
+        for mode, model in normalized_embeddings_models:
+            if (current_mode_lower == mode and current_model_lower == model) or (
+                mode in current_mode_lower and model in current_model_lower
+            ):
+                is_normalized = True
+                break
+
+        # Check patterns
+        if not is_normalized:
+            # OpenAI patterns
+            if "openai" in current_mode_lower or "openai" in current_model_lower:
+                if any(
+                    pattern in current_model_lower
+                    for pattern in ["text-embedding", "ada", "3-small", "3-large"]
+                ):
+                    is_normalized = True
+            # Voyage patterns
+            elif "voyage" in current_mode_lower or "voyage" in current_model_lower:
+                is_normalized = True
+            # Cohere patterns
+            elif "cohere" in current_mode_lower or "cohere" in current_model_lower:
+                if "embed" in current_model_lower:
+                    is_normalized = True
+
+        # Handle distance metric
+        if is_normalized and "distance_metric" not in backend_kwargs:
+            backend_kwargs["distance_metric"] = "cosine"
+            warnings.warn(
+                f"Detected normalized embeddings model '{embedding_model}' with mode '{embedding_mode}'. "
+                f"Automatically setting distance_metric='cosine' for optimal performance. "
+                f"Normalized embeddings (L2 norm = 1) should use cosine similarity instead of MIPS.",
+                UserWarning,
+                stacklevel=2,
+            )
+        elif is_normalized and backend_kwargs.get("distance_metric", "").lower() != "cosine":
+            current_metric = backend_kwargs.get("distance_metric", "mips")
+            warnings.warn(
+                f"Warning: Using '{current_metric}' distance metric with normalized embeddings model "
+                f"'{embedding_model}' may lead to suboptimal search results. "
+                f"Consider using 'cosine' distance metric for better performance.",
+                UserWarning,
+                stacklevel=2,
+            )
+
         self.backend_kwargs = backend_kwargs
         self.chunks: list[dict[str, Any]] = []
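And a sketch of the second warning path, where a normalized model is paired with an explicit non-cosine metric. Same assumed constructor as above; `distance_metric` is presumed to be collected into `backend_kwargs`.

import warnings

from leann.api import LeannBuilder  # import path assumed

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    builder = LeannBuilder(
        backend_name="hnsw",  # backend selection argument assumed
        embedding_model="voyage-3",
        embedding_mode="voyage",
        distance_metric="mips",  # explicit choice is respected, not overridden
    )

# The builder keeps 'mips' but warns that it may yield suboptimal results.
assert any("suboptimal" in str(w.message) for w in caught)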