This commit is contained in:
yichuan520030910320
2025-07-21 21:54:27 -07:00
3 changed files with 9 additions and 8 deletions

View File

@@ -57,7 +57,6 @@ def create_hnsw_embedding_server(
 finally:
 sys.path.pop(0)
 # Check port availability
 import socket
@@ -152,7 +151,9 @@ def create_hnsw_embedding_server(
 raise
 # Process embeddings
-embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
+embeddings = compute_embeddings(
+texts, model_name, mode=embedding_mode
+)
 print(
 f"INFO: Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
 )
@@ -278,7 +279,7 @@ if __name__ == "__main__":
 "--embedding-mode",
 type=str,
 default="sentence-transformers",
-choices=["sentence-transformers", "openai"],
+choices=["sentence-transformers", "openai", "mlx"],
 help="Embedding backend mode",
 )

View File

@@ -42,8 +42,10 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
 "WARNING: embedding_model not found in meta.json. Recompute will fail."
 )
+self.embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
 self.embedding_server_manager = EmbeddingServerManager(
-backend_module_name=backend_module_name
+backend_module_name=backend_module_name,
 )
 def _load_meta(self) -> Dict[str, Any]:
@@ -67,14 +69,12 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
 "Cannot use recompute mode without 'embedding_model' in meta.json."
 )
-embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
 server_started, actual_port = self.embedding_server_manager.start_server(
 port=port,
 model_name=self.embedding_model,
+embedding_mode=self.embedding_mode,
 passages_file=passages_source_file,
 distance_metric=kwargs.get("distance_metric"),
-embedding_mode=embedding_mode,
 enable_warmup=kwargs.get("enable_warmup", False),
 )
 if not server_started:

View File

@@ -12,7 +12,7 @@ else:
 builder = LeannBuilder(
 backend_name="hnsw",
 embedding_model="mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ",
-use_mlx=True,
+embedding_mode="mlx",
 )
 # 2. Add documents