feat: support more embedders
This commit is contained in:
@@ -44,7 +44,7 @@ class HNSWEmbeddingServerManager:
|
||||
self.server_port = None
|
||||
atexit.register(self.stop_server)
|
||||
|
||||
def start_server(self, port=5556, model_name="sentence-transformers/all-mpnet-base-v2", passages_file=None):
|
||||
def start_server(self, port=5556, model_name="sentence-transformers/all-mpnet-base-v2", passages_file=None, distance_metric="mips"):
|
||||
"""
|
||||
Start the HNSW embedding server process.
|
||||
|
||||
@@ -52,6 +52,7 @@ class HNSWEmbeddingServerManager:
|
||||
port: ZMQ port for the server
|
||||
model_name: Name of the embedding model to use
|
||||
passages_file: Optional path to passages JSON file
|
||||
distance_metric: The distance metric to use
|
||||
"""
|
||||
if self.server_process and self.server_process.poll() is None:
|
||||
print(f"INFO: Reusing existing HNSW server process for this session (PID {self.server_process.pid})")
|
||||
@@ -69,7 +70,8 @@ class HNSWEmbeddingServerManager:
|
||||
sys.executable,
|
||||
"-m", "leann_backend_hnsw.hnsw_embedding_server",
|
||||
"--zmq-port", str(port),
|
||||
"--model-name", model_name
|
||||
"--model-name", model_name,
|
||||
"--distance-metric", distance_metric
|
||||
]
|
||||
|
||||
if passages_file:
|
||||
@@ -150,21 +152,16 @@ class HNSWBackend(LeannBackendFactoryInterface):
|
||||
path = Path(index_path)
|
||||
meta_path = path.parent / f"{path.name}.meta.json"
|
||||
if not meta_path.exists():
|
||||
raise FileNotFoundError(f"Leann metadata file not found at {meta_path}. Cannot infer vector dimension for searcher.")
|
||||
raise FileNotFoundError(f"Leann metadata file not found at {meta_path}. Cannot infer vector dimension for searcher.")
|
||||
|
||||
with open(meta_path, 'r') as f:
|
||||
meta = json.load(f)
|
||||
|
||||
try:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
model = SentenceTransformer(meta.get("embedding_model"))
|
||||
dimensions = model.get_sentence_embedding_dimension()
|
||||
kwargs['dimensions'] = dimensions
|
||||
except ImportError:
|
||||
raise ImportError("sentence-transformers is required to infer embedding dimensions. Please install it.")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Could not load SentenceTransformer model to get dimension: {e}")
|
||||
|
||||
dimensions = meta.get("dimensions")
|
||||
if not dimensions:
|
||||
raise ValueError("Dimensions not found in Leann metadata. Please rebuild the index with a newer version of Leann.")
|
||||
|
||||
kwargs['dimensions'] = dimensions
|
||||
return HNSWSearcher(index_path, **kwargs)
|
||||
|
||||
class HNSWBuilder(LeannBackendBuilderInterface):
|
||||
@@ -172,10 +169,8 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
||||
self.build_params = kwargs.copy()
|
||||
|
||||
# --- Configuration defaults with standardized names ---
|
||||
# Apply defaults and write them back to the build_params dict
|
||||
# so they can be saved in the metadata file by LeannBuilder.
|
||||
self.is_compact = self.build_params.setdefault("is_compact", True)
|
||||
self.is_recompute = self.build_params.setdefault("is_recompute", True) # Default: prune embeddings
|
||||
self.is_recompute = self.build_params.setdefault("is_recompute", True)
|
||||
|
||||
# --- Additional Options ---
|
||||
self.is_skip_neighbors = self.build_params.setdefault("is_skip_neighbors", False)
|
||||
@@ -186,6 +181,7 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
||||
self.M = self.build_params.setdefault("M", 32)
|
||||
self.efConstruction = self.build_params.setdefault("efConstruction", 200)
|
||||
self.distance_metric = self.build_params.setdefault("distance_metric", "mips")
|
||||
self.dimensions = self.build_params.get("dimensions")
|
||||
|
||||
if self.is_skip_neighbors and not self.is_compact:
|
||||
raise ValueError("is_skip_neighbors can only be used with is_compact=True")
|
||||
@@ -210,30 +206,25 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
||||
|
||||
metric_str = self.distance_metric.lower()
|
||||
metric_enum = get_metric_map().get(metric_str)
|
||||
print('metric_enum', metric_enum,' metric_str', metric_str)
|
||||
if metric_enum is None:
|
||||
raise ValueError(f"Unsupported distance_metric '{metric_str}'.")
|
||||
|
||||
M = self.M
|
||||
efConstruction = self.efConstruction
|
||||
dim = data.shape[1]
|
||||
dim = self.dimensions
|
||||
if not dim:
|
||||
dim = data.shape[1]
|
||||
|
||||
print(f"INFO: Building HNSW index for {data.shape[0]} vectors with metric {metric_enum}...")
|
||||
|
||||
try:
|
||||
if metric_enum == faiss.METRIC_INNER_PRODUCT:
|
||||
index = faiss.IndexHNSWFlat(dim, M, metric_enum)
|
||||
else: # L2
|
||||
index = faiss.IndexHNSWFlat(dim, M, metric_enum)
|
||||
|
||||
index = faiss.IndexHNSWFlat(dim, M, metric_enum)
|
||||
index.hnsw.efConstruction = efConstruction
|
||||
|
||||
if metric_str == "cosine":
|
||||
faiss.normalize_L2(data)
|
||||
|
||||
print('starting to add vectors to index')
|
||||
index.add(data.shape[0], faiss.swig_ptr(data))
|
||||
print('vectors added to index')
|
||||
|
||||
index_file = index_dir / f"{index_prefix}.index"
|
||||
faiss.write_index(index, str(index_file))
|
||||
@@ -243,7 +234,6 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
||||
if self.is_compact:
|
||||
self._convert_to_csr(index_file)
|
||||
|
||||
# Generate passages file for recompute mode
|
||||
if self.is_recompute:
|
||||
self._generate_passages_file(index_dir, index_prefix, **kwargs)
|
||||
|
||||
@@ -423,13 +413,11 @@ class HNSWSearcher(LeannBackendSearcherInterface):
|
||||
|
||||
# Apply additional configuration options with strict validation
|
||||
hnsw_config.is_skip_neighbors = self.config.get("is_skip_neighbors", False)
|
||||
# If index is pruned, force recompute mode regardless of user setting
|
||||
hnsw_config.is_recompute = self.is_pruned or self.config.get("is_recompute", False)
|
||||
hnsw_config.disk_cache_ratio = self.config.get("disk_cache_ratio", 0.0)
|
||||
hnsw_config.external_storage_path = self.config.get("external_storage_path")
|
||||
hnsw_config.zmq_port = self.config.get("zmq_port", 5557)
|
||||
|
||||
# CRITICAL ASSERTION: If index is pruned, recompute MUST be enabled
|
||||
if self.is_pruned and not hnsw_config.is_recompute:
|
||||
raise RuntimeError("Index is pruned (embeddings removed) but recompute is disabled. This is impossible - recompute must be enabled for pruned indices.")
|
||||
|
||||
@@ -487,7 +475,7 @@ class HNSWSearcher(LeannBackendSearcherInterface):
|
||||
if _check_port(zmq_port):
|
||||
print(f"INFO: Embedding server already running on port {zmq_port}")
|
||||
else:
|
||||
if not self.embedding_server_manager.start_server(zmq_port, embedding_model, passages_file):
|
||||
if not self.embedding_server_manager.start_server(zmq_port, embedding_model, passages_file, self.metric_str):
|
||||
raise RuntimeError(f"Failed to start HNSW embedding server on port {zmq_port}")
|
||||
|
||||
# Give server extra time to fully initialize
|
||||
|
||||
Reference in New Issue
Block a user