fix: use MLX when searching; add MLX support to embedding_server

Author: Andy Lee
Date:   2025-07-14 01:11:21 -07:00
parent  8b4654921b
commit  3da5b44d7f

8 changed files with 315 additions and 885 deletions

leann_backend_hnsw/hnsw_embedding_server.py

@@ -150,6 +150,7 @@ def create_hnsw_embedding_server(
     model_name: str = "sentence-transformers/all-mpnet-base-v2",
     custom_max_length_param: Optional[int] = None,
     distance_metric: str = "mips",
+    use_mlx: bool = False,
 ):
     """
     Create and start a ZMQ-based embedding server for HNSW backend.
@@ -167,9 +168,13 @@ def create_hnsw_embedding_server(
         custom_max_length_param: Custom max sequence length
         distance_metric: The distance metric to use
     """
-    print(f"Loading tokenizer for {model_name}...")
-    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
-    print(f"Tokenizer loaded successfully!")
+    if not use_mlx:
+        print(f"Loading tokenizer for {model_name}...")
+        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
+        print(f"Tokenizer loaded successfully!")
+    else:
+        print("Using MLX mode - tokenizer will be loaded separately")
+        tokenizer = None

     # Device setup
     mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
@@ -191,8 +196,17 @@ def create_hnsw_embedding_server(
     # Load model to the appropriate device
     print(f"Starting HNSW server on port {zmq_port} with model {model_name}")
     print(f"Loading model {model_name}... (this may take a while if downloading)")
-    model = AutoModel.from_pretrained(model_name).to(device).eval()
-    print(f"Model {model_name} loaded successfully!")
+    if use_mlx:
+        # For MLX models, we need to use the MLX embedding computation
+        print("MLX model detected - using MLX backend for embeddings")
+        model = None  # We'll handle MLX separately
+        tokenizer = None
+    else:
+        # Use standard transformers for non-MLX models
+        model = AutoModel.from_pretrained(model_name).to(device).eval()
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        print(f"Model {model_name} loaded successfully!")

     # Check port availability
     import socket
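
Taken together with the first hunk, the new flag is consumed roughly like this. A minimal sketch, assuming a zmq_port keyword (the port name appears only in the log string above) and a hypothetical MLX model id:

    # Minimal sketch; zmq_port keyword and the MLX model id are assumptions
    create_hnsw_embedding_server(
        zmq_port=5555,
        model_name="mlx-community/bge-small-en-v1.5-4bit",  # hypothetical model id
        use_mlx=True,  # skip the transformers tokenizer/model load shown above
    )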
@@ -312,8 +326,37 @@ def create_hnsw_embedding_server(
         def print_elapsed(self):
             return  # Disabled for now

+    def _process_batch_mlx(texts_batch, ids_batch, missing_ids):
+        """Process a batch of texts using MLX backend"""
+        try:
+            # Import MLX embedding computation from main API
+            from leann.api import compute_embeddings
+
+            # Compute embeddings using MLX
+            embeddings = compute_embeddings(texts_batch, model_name, use_mlx=True)
+            print(
+                f"[leann_backend_hnsw.hnsw_embedding_server LOG]: MLX embeddings computed for {len(texts_batch)} texts"
+            )
+            print(
+                f"[leann_backend_hnsw.hnsw_embedding_server LOG]: Embedding shape: {embeddings.shape}"
+            )
+            return embeddings
+        except Exception as e:
+            print(
+                f"[leann_backend_hnsw.hnsw_embedding_server LOG]: ERROR in MLX processing: {e}"
+            )
+            raise
+
     def process_batch(texts_batch, ids_batch, missing_ids):
         """Process a batch of texts and return embeddings"""
+        # Handle MLX models separately
+        if use_mlx:
+            return _process_batch_mlx(texts_batch, ids_batch, missing_ids)
+
         _is_e5_model = "e5" in model_name.lower()
         _is_bge_model = "bge" in model_name.lower()

         batch_size = len(texts_batch)
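
The MLX path above relies only on leann.api.compute_embeddings returning a 2-D array with one row per input text (that is what the shape log line reports). A minimal sketch of that assumed contract, with a hypothetical model id:

    # Assumed contract of leann.api.compute_embeddings; not defined in this diff
    from leann.api import compute_embeddings

    texts = ["hello world", "MLX embedding check"]
    embeddings = compute_embeddings(
        texts, "mlx-community/bge-small-en-v1.5-4bit", use_mlx=True
    )
    print(embeddings.shape)                    # logged by the server above
    assert embeddings.shape[0] == len(texts)   # one row per input text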
@@ -927,6 +970,12 @@ if __name__ == "__main__":
     parser.add_argument(
         "--distance-metric", type=str, default="mips", help="Distance metric to use"
     )
+    parser.add_argument(
+        "--use-mlx",
+        action="store_true",
+        default=False,
+        help="Use MLX for model inference",
+    )

     args = parser.parse_args()
@@ -942,4 +991,5 @@ if __name__ == "__main__":
         model_name=args.model_name,
         custom_max_length_param=args.custom_max_length,
         distance_metric=args.distance_metric,
+        use_mlx=args.use_mlx,
     )
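
With the argparse addition, the server can presumably be launched in MLX mode as follows. The module path is inferred from the log prefix in the hunks above, and every flag except --use-mlx and --distance-metric is an assumption:

    # Hypothetical launch command; only --use-mlx and --distance-metric appear in this diff
    python -m leann_backend_hnsw.hnsw_embedding_server \
        --model-name mlx-community/bge-small-en-v1.5-4bit \
        --distance-metric mips \
        --use-mlx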