From a13c527e39323f943a470e51b2a1253464500f8b Mon Sep 17 00:00:00 2001 From: Andy Lee Date: Thu, 17 Jul 2025 17:02:47 -0700 Subject: [PATCH] feat: openai embeddings --- examples/openai_hnsw_example.py | 108 ++++++++++++++++++ .../leann_backend_diskann/embedding_server.py | 69 ++++++++--- .../hnsw_embedding_server.py | 75 +++++++++--- packages/leann-core/src/leann/api.py | 94 +++++++++++++-- .../src/leann/embedding_server_manager.py | 6 +- .../leann-core/src/leann/searcher_base.py | 8 +- 6 files changed, 311 insertions(+), 49 deletions(-) create mode 100644 examples/openai_hnsw_example.py diff --git a/examples/openai_hnsw_example.py b/examples/openai_hnsw_example.py new file mode 100644 index 0000000..451aeb9 --- /dev/null +++ b/examples/openai_hnsw_example.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python3 +""" +OpenAI Embedding Example + +Complete example showing how to build and search with OpenAI embeddings using HNSW backend. +""" + +import os +import dotenv +from pathlib import Path +from leann.api import LeannBuilder, LeannSearcher + +# Load environment variables +dotenv.load_dotenv() + +def main(): + # Check if OpenAI API key is available + api_key = os.getenv("OPENAI_API_KEY") + if not api_key: + print("ERROR: OPENAI_API_KEY environment variable not set") + return False + + print(f"βœ… OpenAI API key found: {api_key[:10]}...") + + # Sample texts + sample_texts = [ + "Machine learning is a powerful technology that enables computers to learn from data.", + "Natural language processing helps computers understand and generate human language.", + "Deep learning uses neural networks with multiple layers to solve complex problems.", + "Computer vision allows machines to interpret and understand visual information.", + "Reinforcement learning trains agents to make decisions through trial and error.", + "Data science combines statistics, math, and programming to extract insights from data.", + "Artificial intelligence aims to create machines that can perform human-like tasks.", + "Python is a popular programming language used extensively in data science and AI.", + "Neural networks are inspired by the structure and function of the human brain.", + "Big data refers to extremely large datasets that require special tools to process." 
+ ] + + INDEX_DIR = Path("./simple_openai_test_index") + INDEX_PATH = str(INDEX_DIR / "simple_test.leann") + + print(f"\n=== Building Index with OpenAI Embeddings ===") + print(f"Index path: {INDEX_PATH}") + + try: + # Use proper configuration for OpenAI embeddings + builder = LeannBuilder( + backend_name="hnsw", + embedding_model="text-embedding-3-small", + embedding_mode="openai", + # HNSW settings for OpenAI embeddings + M=16, # Smaller graph degree + efConstruction=64, # Smaller construction complexity + is_compact=True, # Enable compact storage for recompute + is_recompute=True, # MUST enable for OpenAI embeddings + num_threads=1, + ) + + print(f"Adding {len(sample_texts)} texts to the index...") + for i, text in enumerate(sample_texts): + metadata = {"id": f"doc_{i}", "topic": "AI"} + builder.add_text(text, metadata) + + print("Building index...") + builder.build_index(INDEX_PATH) + print(f"βœ… Index built successfully!") + + except Exception as e: + print(f"❌ Error building index: {e}") + import traceback + traceback.print_exc() + return False + + print(f"\n=== Testing Search ===") + + try: + searcher = LeannSearcher(INDEX_PATH) + + test_queries = [ + "What is machine learning?", + "How do neural networks work?", + "Programming languages for data science" + ] + + for query in test_queries: + print(f"\nπŸ” Query: '{query}'") + results = searcher.search(query, top_k=3) + + print(f" Found {len(results)} results:") + for i, result in enumerate(results): + print(f" {i+1}. Score: {result.score:.4f}") + print(f" Text: {result.text[:80]}...") + + print(f"\nβœ… Search test completed successfully!") + return True + + except Exception as e: + print(f"❌ Error during search: {e}") + import traceback + traceback.print_exc() + return False + +if __name__ == "__main__": + success = main() + if success: + print(f"\nπŸŽ‰ Simple OpenAI index test completed successfully!") + else: + print(f"\nπŸ’₯ Simple OpenAI index test failed!") \ No newline at end of file diff --git a/packages/leann-backend-diskann/leann_backend_diskann/embedding_server.py b/packages/leann-backend-diskann/leann_backend_diskann/embedding_server.py index d517c0f..de58e3a 100644 --- a/packages/leann-backend-diskann/leann_backend_diskann/embedding_server.py +++ b/packages/leann-backend-diskann/leann_backend_diskann/embedding_server.py @@ -162,7 +162,7 @@ def create_embedding_server_thread( model_name="sentence-transformers/all-mpnet-base-v2", max_batch_size=128, passages_file: Optional[str] = None, - use_mlx: bool = False, + embedding_mode: str = "sentence-transformers", enable_warmup: bool = False, ): """ @@ -182,10 +182,27 @@ def create_embedding_server_thread( print(f"{RED}Port {zmq_port} is already in use{RESET}") return - if use_mlx: + # Auto-detect mode based on model name if not explicitly set + if embedding_mode == "sentence-transformers" and model_name.startswith("text-embedding-"): + embedding_mode = "openai" + + if embedding_mode == "mlx": from leann.api import compute_embeddings_mlx + import torch print("INFO: Using MLX for embeddings") - else: + # Set device to CPU for compatibility with DeviceTimer class + device = torch.device("cpu") + cuda_available = False + mps_available = False + elif embedding_mode == "openai": + from leann.api import compute_embeddings_openai + import torch + print("INFO: Using OpenAI API for embeddings") + # Set device to CPU for compatibility with DeviceTimer class + device = torch.device("cpu") + cuda_available = False + mps_available = False + elif embedding_mode == "sentence-transformers": # 
εˆε§‹εŒ–ζ¨‘εž‹ tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True) import torch @@ -216,6 +233,8 @@ def create_embedding_server_thread( print(f"INFO: Using FP16 precision with model: {model_name}") except Exception as e: print(f"WARNING: Model optimization failed: {e}") + else: + raise ValueError(f"Unsupported embedding mode: {embedding_mode}. Supported modes: sentence-transformers, mlx, openai") # Load passages from file if provided if passages_file and os.path.exists(passages_file): @@ -303,7 +322,7 @@ def create_embedding_server_thread( self.start_time = 0 self.end_time = 0 - if not use_mlx and torch.cuda.is_available(): + if embedding_mode == "sentence-transformers" and torch.cuda.is_available(): self.start_event = torch.cuda.Event(enable_timing=True) self.end_event = torch.cuda.Event(enable_timing=True) else: @@ -317,25 +336,25 @@ def create_embedding_server_thread( self.end() def start(self): - if not use_mlx and torch.cuda.is_available(): + if embedding_mode == "sentence-transformers" and torch.cuda.is_available(): torch.cuda.synchronize() self.start_event.record() else: - if not use_mlx and self.device.type == "mps": + if embedding_mode == "sentence-transformers" and self.device.type == "mps": torch.mps.synchronize() self.start_time = time.time() def end(self): - if not use_mlx and torch.cuda.is_available(): + if embedding_mode == "sentence-transformers" and torch.cuda.is_available(): self.end_event.record() torch.cuda.synchronize() else: - if not use_mlx and self.device.type == "mps": + if embedding_mode == "sentence-transformers" and self.device.type == "mps": torch.mps.synchronize() self.end_time = time.time() def elapsed_time(self): - if not use_mlx and torch.cuda.is_available(): + if embedding_mode == "sentence-transformers" and torch.cuda.is_available(): return self.start_event.elapsed_time(self.end_event) / 1000.0 else: return self.end_time - self.start_time @@ -571,13 +590,15 @@ def create_embedding_server_thread( chunk_texts = texts[i:end_idx] chunk_ids = node_ids[i:end_idx] - if use_mlx: + if embedding_mode == "mlx": embeddings_chunk = compute_embeddings_mlx(chunk_texts, model_name) - else: + elif embedding_mode == "openai": + embeddings_chunk = compute_embeddings_openai(chunk_texts, model_name) + else: # sentence-transformers embeddings_chunk = process_batch_pytorch(chunk_texts, chunk_ids, missing_ids) all_embeddings.append(embeddings_chunk) - if not use_mlx: + if embedding_mode == "sentence-transformers": if cuda_available: torch.cuda.empty_cache() elif device.type == "mps": @@ -586,9 +607,11 @@ def create_embedding_server_thread( hidden = np.vstack(all_embeddings) print(f"INFO: Combined embeddings shape: {hidden.shape}") else: - if use_mlx: + if embedding_mode == "mlx": hidden = compute_embeddings_mlx(texts, model_name) - else: + elif embedding_mode == "openai": + hidden = compute_embeddings_openai(texts, model_name) + else: # sentence-transformers hidden = process_batch_pytorch(texts, node_ids, missing_ids) # εΊεˆ—εŒ–ε“εΊ” @@ -610,7 +633,7 @@ def create_embedding_server_thread( print(f"INFO: Serialize time: {ser_end - ser_start:.6f} seconds") - if not use_mlx: + if embedding_mode == "sentence-transformers": if device.type == "cuda": torch.cuda.synchronize() elif device.type == "mps": @@ -653,14 +676,14 @@ def create_embedding_server( lazy_load_passages=False, model_name="sentence-transformers/all-mpnet-base-v2", passages_file: Optional[str] = None, - use_mlx: bool = False, + embedding_mode: str = "sentence-transformers", enable_warmup: bool = 
False, ): """ εŽŸζœ‰ηš„ create_embedding_server ε‡½ζ•°δΏζŒδΈε˜ θΏ™δΈͺζ˜―ι˜»ε‘žη‰ˆζœ¬οΌŒη”¨δΊŽη›΄ζŽ₯运葌 """ - create_embedding_server_thread(zmq_port, model_name, max_batch_size, passages_file, use_mlx, enable_warmup) + create_embedding_server_thread(zmq_port, model_name, max_batch_size, passages_file, embedding_mode, enable_warmup) if __name__ == "__main__": @@ -677,9 +700,17 @@ if __name__ == "__main__": parser.add_argument("--lazy-load-passages", action="store_true", default=True) parser.add_argument("--model-name", type=str, default="sentence-transformers/all-mpnet-base-v2", help="Embedding model name") - parser.add_argument("--use-mlx", action="store_true", default=False, help="Use MLX backend for embeddings") + parser.add_argument("--embedding-mode", type=str, default="sentence-transformers", + choices=["sentence-transformers", "mlx", "openai"], + help="Embedding backend mode") + parser.add_argument("--use-mlx", action="store_true", default=False, help="Use MLX backend for embeddings (deprecated: use --embedding-mode mlx)") parser.add_argument("--disable-warmup", action="store_true", default=False, help="Disable warmup requests on server start") args = parser.parse_args() + + # Handle backward compatibility with use_mlx + embedding_mode = args.embedding_mode + if args.use_mlx: + embedding_mode = "mlx" create_embedding_server( domain=args.domain, @@ -693,6 +724,6 @@ if __name__ == "__main__": lazy_load_passages=args.lazy_load_passages, model_name=args.model_name, passages_file=args.passages_file, - use_mlx=args.use_mlx, + embedding_mode=embedding_mode, enable_warmup=not args.disable_warmup, ) diff --git a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py index c19f581..f60085a 100644 --- a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py +++ b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py @@ -150,7 +150,7 @@ def create_hnsw_embedding_server( model_name: str = "sentence-transformers/all-mpnet-base-v2", custom_max_length_param: Optional[int] = None, distance_metric: str = "mips", - use_mlx: bool = False, + embedding_mode: str = "sentence-transformers", enable_warmup: bool = False, ): """ @@ -170,13 +170,22 @@ def create_hnsw_embedding_server( distance_metric: The distance metric to use enable_warmup: Whether to perform warmup requests on server start """ - if not use_mlx: + # Handle different embedding modes directly in HNSW server + + # Auto-detect mode based on model name if not explicitly set + if embedding_mode == "sentence-transformers" and model_name.startswith("text-embedding-"): + embedding_mode = "openai" + + if embedding_mode == "openai": + print(f"Using OpenAI API mode for {model_name}") + tokenizer = None # No local tokenizer needed for OpenAI API + elif embedding_mode == "mlx": + print(f"Using MLX mode for {model_name}") + tokenizer = None # MLX handles tokenization separately + else: # sentence-transformers print(f"Loading tokenizer for {model_name}...") tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True) print(f"Tokenizer loaded successfully!") - else: - print("Using MLX mode - tokenizer will be loaded separately") - tokenizer = None # Device setup mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available() @@ -199,15 +208,17 @@ def create_hnsw_embedding_server( print(f"Starting HNSW server on port {zmq_port} with model {model_name}") print(f"Loading model {model_name}... 
(this may take a while if downloading)") - if use_mlx: + if embedding_mode == "mlx": # For MLX models, we need to use the MLX embedding computation print("MLX model detected - using MLX backend for embeddings") model = None # We'll handle MLX separately - tokenizer = None + elif embedding_mode == "openai": + # For OpenAI API, no local model needed + print("OpenAI API mode - no local model loading required") + model = None else: - # Use standard transformers for non-MLX models + # Use standard transformers for sentence-transformers models model = AutoModel.from_pretrained(model_name).to(device).eval() - tokenizer = AutoTokenizer.from_pretrained(model_name) print(f"Model {model_name} loaded successfully!") # Check port availability @@ -355,9 +366,12 @@ def create_hnsw_embedding_server( def process_batch(texts_batch, ids_batch, missing_ids): """Process a batch of texts and return embeddings""" - # Handle MLX models separately - if use_mlx: + # Handle different embedding modes + if embedding_mode == "mlx": return _process_batch_mlx(texts_batch, ids_batch, missing_ids) + elif embedding_mode == "openai": + from leann.api import compute_embeddings_openai + return compute_embeddings_openai(texts_batch, model_name) _is_e5_model = "e5" in model_name.lower() _is_bge_model = "bge" in model_name.lower() @@ -795,14 +809,33 @@ def create_hnsw_embedding_server( ) continue - # Standard embedding request + # Handle direct text embedding request (for OpenAI mode) + if embedding_mode == "openai" and isinstance(request_payload, list) and len(request_payload) > 0: + # Check if this is a direct text request (list of strings) + if all(isinstance(item, str) for item in request_payload): + print(f"Processing direct text embedding request for {len(request_payload)} texts") + + try: + from leann.api import compute_embeddings_openai + embeddings = compute_embeddings_openai(request_payload, model_name) + response = embeddings.tolist() + socket.send(msgpack.packb(response)) + e2e_end = time.time() + print(f"Text embedding E2E time: {e2e_end - e2e_start:.6f} seconds") + continue + except Exception as e: + print(f"ERROR: Failed to compute OpenAI embeddings: {e}") + socket.send(msgpack.packb([])) + continue + + # Standard embedding request (passage ID lookup) if ( not isinstance(request_payload, list) or len(request_payload) != 1 or not isinstance(request_payload[0], list) ): print( - f"Error: Invalid MessagePack request format. Expected [[ids...]], got: {type(request_payload)}" + f"Error: Invalid MessagePack request format. 
Expected [[ids...]] or [texts...], got: {type(request_payload)}" ) socket.send(msgpack.packb([[], []])) continue @@ -986,11 +1019,18 @@ if __name__ == "__main__": parser.add_argument( "--distance-metric", type=str, default="mips", help="Distance metric to use" ) + parser.add_argument( + "--embedding-mode", + type=str, + default="sentence-transformers", + choices=["sentence-transformers", "mlx", "openai"], + help="Embedding backend mode" + ) parser.add_argument( "--use-mlx", action="store_true", default=False, - help="Use MLX for model inference", + help="Use MLX for model inference (deprecated: use --embedding-mode mlx)", ) parser.add_argument( "--disable-warmup", @@ -1000,6 +1040,11 @@ if __name__ == "__main__": ) args = parser.parse_args() + + # Handle backward compatibility with use_mlx + embedding_mode = args.embedding_mode + if args.use_mlx: + embedding_mode = "mlx" # Create and start the HNSW embedding server create_hnsw_embedding_server( @@ -1013,6 +1058,6 @@ if __name__ == "__main__": model_name=args.model_name, custom_max_length_param=args.custom_max_length, distance_metric=args.distance_metric, - use_mlx=args.use_mlx, + embedding_mode=embedding_mode, enable_warmup=not args.disable_warmup, ) diff --git a/packages/leann-core/src/leann/api.py b/packages/leann-core/src/leann/api.py index 4e903c1..2042ac8 100644 --- a/packages/leann-core/src/leann/api.py +++ b/packages/leann-core/src/leann/api.py @@ -18,11 +18,40 @@ from .chat import get_llm def compute_embeddings( - chunks: List[str], model_name: str, use_mlx: bool = False + chunks: List[str], + model_name: str, + mode: str = "sentence-transformers" ) -> np.ndarray: - """Computes embeddings using sentence-transformers or MLX for consistent results.""" - if use_mlx: + """ + Computes embeddings using different backends. + + Args: + chunks: List of text chunks to embed + model_name: Name of the embedding model + mode: Embedding backend mode. Options: + - "sentence-transformers": Use sentence-transformers library (default) + - "mlx": Use MLX backend for Apple Silicon + - "openai": Use OpenAI embedding API + + Returns: + numpy array of embeddings + """ + # Auto-detect mode based on model name if not explicitly set + if mode == "sentence-transformers" and model_name.startswith("text-embedding-"): + mode = "openai" + + if mode == "mlx": return compute_embeddings_mlx(chunks, model_name) + elif mode == "openai": + return compute_embeddings_openai(chunks, model_name) + elif mode == "sentence-transformers": + return compute_embeddings_sentence_transformers(chunks, model_name) + else: + raise ValueError(f"Unsupported embedding mode: {mode}. Supported modes: sentence-transformers, mlx, openai") + + +def compute_embeddings_sentence_transformers(chunks: List[str], model_name: str) -> np.ndarray: + """Computes embeddings using sentence-transformers library.""" try: from sentence_transformers import SentenceTransformer except ImportError as e: @@ -53,6 +82,49 @@ def compute_embeddings( return embeddings +def compute_embeddings_openai(chunks: List[str], model_name: str) -> np.ndarray: + """Computes embeddings using OpenAI API.""" + try: + import openai + import os + except ImportError as e: + raise RuntimeError( + "openai not available. 
Install with: uv pip install openai" + ) from e + + # Get API key from environment + api_key = os.getenv("OPENAI_API_KEY") + if not api_key: + raise RuntimeError("OPENAI_API_KEY environment variable not set") + + client = openai.OpenAI(api_key=api_key) + + print(f"INFO: Computing embeddings for {len(chunks)} chunks using OpenAI model '{model_name}'...") + + # OpenAI has a limit on batch size and input length + max_batch_size = 100 # Conservative batch size + all_embeddings = [] + + for i in range(0, len(chunks), max_batch_size): + batch_chunks = chunks[i:i + max_batch_size] + print(f"INFO: Processing batch {i//max_batch_size + 1}/{(len(chunks) + max_batch_size - 1)//max_batch_size}") + + try: + response = client.embeddings.create( + model=model_name, + input=batch_chunks + ) + batch_embeddings = [embedding.embedding for embedding in response.data] + all_embeddings.extend(batch_embeddings) + except Exception as e: + print(f"ERROR: Failed to get embeddings for batch starting at {i}: {e}") + raise + + embeddings = np.array(all_embeddings, dtype=np.float32) + print(f"INFO: Generated {len(embeddings)} embeddings with dimension {embeddings.shape[1]}") + return embeddings + + def compute_embeddings_mlx(chunks: List[str], model_name: str) -> np.ndarray: """Computes embeddings using an MLX model.""" try: @@ -140,7 +212,7 @@ class LeannBuilder: backend_name: str, embedding_model: str = "facebook/contriever-msmarco", dimensions: Optional[int] = None, - use_mlx: bool = False, + embedding_mode: str = "sentence-transformers", **backend_kwargs, ): self.backend_name = backend_name @@ -152,7 +224,7 @@ class LeannBuilder: self.backend_factory = backend_factory self.embedding_model = embedding_model self.dimensions = dimensions - self.use_mlx = use_mlx + self.embedding_mode = embedding_mode self.backend_kwargs = backend_kwargs self.chunks: List[Dict[str, Any]] = [] @@ -168,7 +240,7 @@ class LeannBuilder: raise ValueError("No chunks added.") if self.dimensions is None: self.dimensions = len( - compute_embeddings(["dummy"], self.embedding_model, self.use_mlx)[0] + compute_embeddings(["dummy"], self.embedding_model, self.embedding_mode)[0] ) path = Path(index_path) index_dir = path.parent @@ -195,7 +267,7 @@ class LeannBuilder: pickle.dump(offset_map, f) texts_to_embed = [c["text"] for c in self.chunks] embeddings = compute_embeddings( - texts_to_embed, self.embedding_model, self.use_mlx + texts_to_embed, self.embedding_model, self.embedding_mode ) string_ids = [chunk["id"] for chunk in self.chunks] current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions} @@ -210,7 +282,7 @@ class LeannBuilder: "embedding_model": self.embedding_model, "dimensions": self.dimensions, "backend_kwargs": self.backend_kwargs, - "use_mlx": self.use_mlx, + "embedding_mode": self.embedding_mode, "passage_sources": [ { "type": "jsonl", @@ -241,7 +313,11 @@ class LeannSearcher: self.meta_data = json.load(f) backend_name = self.meta_data["backend_name"] self.embedding_model = self.meta_data["embedding_model"] - self.use_mlx = self.meta_data.get("use_mlx", False) + # Support both old and new format + self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers") + # Backward compatibility with use_mlx + if self.meta_data.get("use_mlx", False): + self.embedding_mode = "mlx" self.passage_manager = PassageManager(self.meta_data.get("passage_sources", [])) backend_factory = BACKEND_REGISTRY.get(backend_name) if backend_factory is None: diff --git 
a/packages/leann-core/src/leann/embedding_server_manager.py b/packages/leann-core/src/leann/embedding_server_manager.py index f409241..2a5e302 100644 --- a/packages/leann-core/src/leann/embedding_server_manager.py +++ b/packages/leann-core/src/leann/embedding_server_manager.py @@ -177,7 +177,7 @@ class EmbeddingServerManager: self.server_port: Optional[int] = None # atexit.register(self.stop_server) - def start_server(self, port: int, model_name: str, **kwargs) -> bool: + def start_server(self, port: int, model_name: str, embedding_mode: str = "sentence-transformers", **kwargs) -> bool: """ Starts the embedding server process. @@ -310,8 +310,8 @@ class EmbeddingServerManager: command.extend(["--passages-file", str(kwargs["passages_file"])]) # if "distance_metric" in kwargs and kwargs["distance_metric"]: # command.extend(["--distance-metric", kwargs["distance_metric"]]) - if "use_mlx" in kwargs and kwargs["use_mlx"]: - command.extend(["--use-mlx"]) + if embedding_mode != "sentence-transformers": + command.extend(["--embedding-mode", embedding_mode]) if "enable_warmup" in kwargs and not kwargs["enable_warmup"]: command.extend(["--disable-warmup"]) diff --git a/packages/leann-core/src/leann/searcher_base.py b/packages/leann-core/src/leann/searcher_base.py index 55c9843..0f40a85 100644 --- a/packages/leann-core/src/leann/searcher_base.py +++ b/packages/leann-core/src/leann/searcher_base.py @@ -78,12 +78,14 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC): "Cannot use recompute mode without 'embedding_model' in meta.json." ) + embedding_mode = self.meta.get("embedding_mode", "sentence-transformers") + server_started = self.embedding_server_manager.start_server( port=port, model_name=self.embedding_model, passages_file=passages_source_file, distance_metric=kwargs.get("distance_metric"), - use_mlx=kwargs.get("use_mlx", False), + embedding_mode=embedding_mode, enable_warmup=kwargs.get("enable_warmup", False), ) if not server_started: @@ -120,8 +122,8 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC): # Fallback to direct computation from .api import compute_embeddings - use_mlx = self.meta.get("use_mlx", False) - return compute_embeddings([query], self.embedding_model, use_mlx) + embedding_mode = self.meta.get("embedding_mode", "sentence-transformers") + return compute_embeddings([query], self.embedding_model, embedding_mode) def _compute_embedding_via_server(self, chunks: list, zmq_port: int) -> np.ndarray: """Compute embeddings using the ZMQ embedding server."""
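
Usage note (not part of the patch): the new compute_embeddings() entry point in leann.api dispatches on its mode argument ("sentence-transformers", "mlx", "openai") and auto-detects the OpenAI path for model names starting with "text-embedding-". A minimal sketch of calling it directly, assuming OPENAI_API_KEY is exported and the leann package from this patch is installed; the sample texts below are placeholders:

    import numpy as np
    from leann.api import compute_embeddings

    texts = [
        "LEANN builds lightweight vector indexes.",
        "OpenAI embeddings are computed via the embeddings API.",
    ]

    # Explicit mode selection: routes to compute_embeddings_openai()
    emb = compute_embeddings(texts, "text-embedding-3-small", mode="openai")
    print(emb.shape, emb.dtype)  # e.g. (2, 1536) float32

    # Auto-detection: a "text-embedding-*" model name switches the default
    # "sentence-transformers" mode over to the OpenAI path
    emb_auto = compute_embeddings(texts, "text-embedding-3-small")
    assert emb_auto.shape == emb.shape

The same embedding_mode value is what LeannBuilder persists in the index metadata, so LeannSearcher and the embedding servers pick the matching backend (with use_mlx kept as a deprecated alias for embedding_mode="mlx").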