From 48dda1cb5be18bd5b5e78257b58c2d8ff8f36243 Mon Sep 17 00:00:00 2001
From: Andy Lee
Date: Sun, 13 Jul 2025 02:13:04 -0700
Subject: [PATCH] feat: mlx

---
 build_mlx_index.py                                |  34 +++++
 .../leann_backend_diskann/embedding_server.py     | 116 +++++++++-------
 packages/leann-core/src/leann/api.py              |  60 ++++++--
 tests/sanity_checks/benchmark_embeddings.py       | 128 ++++++++++++++++++
 4 files changed, 278 insertions(+), 60 deletions(-)
 create mode 100644 build_mlx_index.py
 create mode 100644 tests/sanity_checks/benchmark_embeddings.py

diff --git a/build_mlx_index.py b/build_mlx_index.py
new file mode 100644
index 0000000..c7c6e4f
--- /dev/null
+++ b/build_mlx_index.py
@@ -0,0 +1,34 @@
+from leann.api import LeannBuilder
+import os
+
+# Define the path for our new MLX-based index
+INDEX_PATH = "./mlx_diskann_index/leann"
+
+if os.path.exists(INDEX_PATH + ".meta.json"):
+    print(f"Index already exists at {INDEX_PATH}. Skipping build.")
+else:
+    print("Initializing LeannBuilder with MLX support...")
+    # 1. Configure LeannBuilder to use MLX
+    builder = LeannBuilder(
+        backend_name="diskann",
+        embedding_model="mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ",
+        use_mlx=True
+    )
+
+    # 2. Add documents
+    print("Adding documents...")
+    docs = [
+        "MLX is an array framework for machine learning on Apple silicon.",
+        "It was designed by Apple's machine learning research team.",
+        "The mlx-community organization provides pre-trained models in MLX format.",
+        "It supports operations on multi-dimensional arrays.",
+        "Leann can now use MLX for its embedding models."
+    ]
+    for doc in docs:
+        builder.add_text(doc)
+
+    # 3. Build the index
+    print(f"Building the MLX-based index at: {INDEX_PATH}")
+    builder.build_index(INDEX_PATH)
+    print("\nSuccessfully built the index with MLX embeddings!")
+    print(f"Check the metadata file: {INDEX_PATH}.meta.json")
diff --git a/packages/leann-backend-diskann/leann_backend_diskann/embedding_server.py b/packages/leann-backend-diskann/leann_backend_diskann/embedding_server.py
index 8c09e37..49442d0 100644
--- a/packages/leann-backend-diskann/leann_backend_diskann/embedding_server.py
+++ b/packages/leann-backend-diskann/leann_backend_diskann/embedding_server.py
@@ -5,7 +5,6 @@ Embedding server for leann-backend-diskann - Fixed ZMQ REQ-REP pattern
 import pickle
 import argparse
-import threading
 import time
 import json
 from typing import Dict, Any, Optional, Union

@@ -16,7 +15,6 @@
 from contextlib import contextmanager
 import zmq
 import numpy as np
 from pathlib import Path
-import pickle
 RED = "\033[91m"
 RESET = "\033[0m"
@@ -154,6 +152,7 @@ def create_embedding_server_thread(
     model_name="sentence-transformers/all-mpnet-base-v2",
     max_batch_size=128,
     passages_file: Optional[str] = None,
+    use_mlx: bool = False,
 ):
     """
     在当前线程中创建并运行 embedding server
@@ -172,36 +171,40 @@
         print(f"{RED}Port {zmq_port} is already in use{RESET}")
         return

-    # 初始化模型
-    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
-    import torch
-
-    # 选择设备
-    mps_available = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
-    cuda_available = torch.cuda.is_available()
-
-    if cuda_available:
-        device = torch.device("cuda")
-        print("INFO: Using CUDA device")
-    elif mps_available:
-        device = torch.device("mps")
-        print("INFO: Using MPS device (Apple Silicon)")
+    if use_mlx:
+        from leann.api import compute_embeddings_mlx
+        print("INFO: Using MLX for embeddings")
     else:
-        device = torch.device("cpu")
-        print("INFO: Using CPU device")
-
-    # 加载模型
-    print(f"INFO: Loading model {model_name}")
-    model = AutoModel.from_pretrained(model_name).to(device).eval()
+        # 初始化模型
+        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
+        import torch

-    # 优化模型
-    if cuda_available or mps_available:
-        try:
-            model = model.half()
-            model = torch.compile(model)
-            print(f"INFO: Using FP16 precision with model: {model_name}")
-        except Exception as e:
-            print(f"WARNING: Model optimization failed: {e}")
+        # 选择设备
+        mps_available = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
+        cuda_available = torch.cuda.is_available()
+
+        if cuda_available:
+            device = torch.device("cuda")
+            print("INFO: Using CUDA device")
+        elif mps_available:
+            device = torch.device("mps")
+            print("INFO: Using MPS device (Apple Silicon)")
+        else:
+            device = torch.device("cpu")
+            print("INFO: Using CPU device")
+
+        # 加载模型
+        print(f"INFO: Loading model {model_name}")
+        model = AutoModel.from_pretrained(model_name).to(device).eval()
+
+        # 优化模型
+        if cuda_available or mps_available:
+            try:
+                model = model.half()
+                model = torch.compile(model)
+                print(f"INFO: Using FP16 precision with model: {model_name}")
+            except Exception as e:
+                print(f"WARNING: Model optimization failed: {e}")

     # Load passages from file if provided
     if passages_file and os.path.exists(passages_file):
@@ -233,7 +236,7 @@
             self.start_time = 0
             self.end_time = 0

-            if cuda_available:
+            if not use_mlx and torch.cuda.is_available():
                 self.start_event = torch.cuda.Event(enable_timing=True)
                 self.end_event = torch.cuda.Event(enable_timing=True)
             else:
@@ -247,25 +250,25 @@
             self.end()

         def start(self):
-            if cuda_available:
+            if not use_mlx and torch.cuda.is_available():
                 torch.cuda.synchronize()
                 self.start_event.record()
             else:
-                if self.device.type == "mps":
+                if not use_mlx and self.device.type == "mps":
                     torch.mps.synchronize()
                 self.start_time = time.time()

         def end(self):
-            if cuda_available:
+            if not use_mlx and torch.cuda.is_available():
                 self.end_event.record()
                 torch.cuda.synchronize()
             else:
-                if self.device.type == "mps":
+                if not use_mlx and self.device.type == "mps":
                     torch.mps.synchronize()
                 self.end_time = time.time()

         def elapsed_time(self):
-            if cuda_available:
+            if not use_mlx and torch.cuda.is_available():
                 return self.start_event.elapsed_time(self.end_event) / 1000.0
             else:
                 return self.end_time - self.start_time
@@ -273,7 +276,7 @@
         def print_elapsed(self):
             print(f"Time taken for {self.name}: {self.elapsed_time():.6f} seconds")

-    def process_batch(texts_batch, ids_batch, missing_ids):
+    def process_batch_pytorch(texts_batch, ids_batch, missing_ids):
         """处理文本批次"""
         batch_size = len(texts_batch)
         print(f"INFO: Processing batch of size {batch_size}")
@@ -351,7 +354,7 @@
             print(f"INFO: Received ZMQ request from client {identity.hex()[:8]}, size {len(message)} bytes")
             e2e_start = time.time()
-            lookup_timer = DeviceTimer("text lookup", device)
+            lookup_timer = DeviceTimer("text lookup")

             # 解析请求
             req_proto = embedding_pb2.NodeEmbeddingRequest()
@@ -397,18 +400,25 @@
                     chunk_texts = texts[i:end_idx]
                     chunk_ids = node_ids[i:end_idx]

-                    embeddings_chunk = process_batch(chunk_texts, chunk_ids, missing_ids)
+                    if use_mlx:
+                        embeddings_chunk = compute_embeddings_mlx(chunk_texts, model_name)
+                    else:
+                        embeddings_chunk = process_batch_pytorch(chunk_texts, chunk_ids, missing_ids)
                     all_embeddings.append(embeddings_chunk)

-                    if cuda_available:
-                        torch.cuda.empty_cache()
-                    elif device.type == "mps":
-                        torch.mps.empty_cache()
+                    if not use_mlx:
+                        if cuda_available:
+                            torch.cuda.empty_cache()
+                        elif device.type == "mps":
+                            torch.mps.empty_cache()

                 hidden = np.vstack(all_embeddings)
                 print(f"INFO: Combined embeddings shape: {hidden.shape}")
             else:
-                hidden = process_batch(texts, node_ids, missing_ids)
+                if use_mlx:
+                    hidden = compute_embeddings_mlx(texts, model_name)
+                else:
+                    hidden = process_batch_pytorch(texts, node_ids, missing_ids)

             # 序列化响应
             ser_start = time.time()
@@ -429,16 +439,16 @@
             print(f"INFO: Serialize time: {ser_end - ser_start:.6f} seconds")

-            if device.type == "cuda":
-                torch.cuda.synchronize()
-            elif device.type == "mps":
-                torch.mps.synchronize()
+            if not use_mlx:
+                if device.type == "cuda":
+                    torch.cuda.synchronize()
+                elif device.type == "mps":
+                    torch.mps.synchronize()

             e2e_end = time.time()
             print(f"INFO: ZMQ E2E time: {e2e_end - e2e_start:.6f} seconds")

         except zmq.Again:
             print("INFO: ZMQ socket timeout, continuing to listen")
-            # REP套接字不需要重新创建,只需要继续监听
             continue
         except Exception as e:
             print(f"ERROR: Error in ZMQ server: {e}")
@@ -460,7 +470,6 @@
         raise


-# 保持原有的 create_embedding_server 函数不变,只添加线程化版本
 def create_embedding_server(
     domain="demo",
     load_passages=True,
@@ -473,12 +482,13 @@
     lazy_load_passages=False,
     model_name="sentence-transformers/all-mpnet-base-v2",
     passages_file: Optional[str] = None,
+    use_mlx: bool = False,
 ):
     """
     原有的 create_embedding_server 函数保持不变
     这个是阻塞版本,用于直接运行
     """
-    create_embedding_server_thread(zmq_port, model_name, max_batch_size, passages_file)
+    create_embedding_server_thread(zmq_port, model_name, max_batch_size, passages_file, use_mlx)


 if __name__ == "__main__":
@@ -495,6 +505,7 @@
     parser.add_argument("--lazy-load-passages", action="store_true", default=True)
     parser.add_argument("--model-name", type=str, default="sentence-transformers/all-mpnet-base-v2",
                         help="Embedding model name")
+    parser.add_argument("--use-mlx", action="store_true", default=False, help="Use MLX backend for embeddings")

     args = parser.parse_args()
     create_embedding_server(
@@ -509,4 +520,5 @@ if __name__ == "__main__":
         lazy_load_passages=args.lazy_load_passages,
         model_name=args.model_name,
         passages_file=args.passages_file,
-    )
\ No newline at end of file
+        use_mlx=args.use_mlx,
+    )
diff --git a/packages/leann-core/src/leann/api.py b/packages/leann-core/src/leann/api.py
index b8f2783..7677f4e 100644
--- a/packages/leann-core/src/leann/api.py
+++ b/packages/leann-core/src/leann/api.py
@@ -1,3 +1,4 @@
+
 """
 This file contains the core API for the LEANN project, now definitively
 updated with the correct, original embedding logic from the user's reference code.
@@ -17,8 +18,10 @@ from .interface import LeannBackendFactoryInterface

 # --- The Correct, Verified Embedding Logic from old_code.py ---

-def compute_embeddings(chunks: List[str], model_name: str) -> np.ndarray:
-    """Computes embeddings using sentence-transformers for consistent results."""
+def compute_embeddings(chunks: List[str], model_name: str, use_mlx: bool = False) -> np.ndarray:
+    """Computes embeddings using sentence-transformers or MLX for consistent results."""
+    if use_mlx:
+        return compute_embeddings_mlx(chunks, model_name)
     try:
         from sentence_transformers import SentenceTransformer
     except ImportError as e:
@@ -44,6 +47,45 @@
     return embeddings


+def compute_embeddings_mlx(chunks: List[str], model_name: str) -> np.ndarray:
+    """Computes embeddings using an MLX model."""
+    try:
+        import mlx.core as mx
+        from mlx_lm.utils import load
+    except ImportError as e:
+        raise RuntimeError(
+            f"MLX or related libraries not available. Install with: pip install mlx mlx-lm"
+        ) from e
+
+    print(f"INFO: Computing embeddings for {len(chunks)} chunks using MLX model '{model_name}'...")
+
+    # Load model and tokenizer
+    model, tokenizer = load(model_name)
+
+    # Process each chunk
+    all_embeddings = []
+    for chunk in chunks:
+        # Tokenize
+        token_ids = tokenizer.encode(chunk)
+
+        # Convert to MLX array and add batch dimension
+        input_ids = mx.array([token_ids])
+
+        # Get embeddings
+        embeddings = model(input_ids)
+
+        # Mean pooling (since we only have one sequence, just take the mean)
+        pooled = embeddings.mean(axis=1)  # Shape: (1, hidden_size)
+
+        # Convert individual embedding to numpy via list (to handle bfloat16)
+        pooled_list = pooled[0].tolist()  # Remove batch dimension and convert to list
+        pooled_numpy = np.array(pooled_list, dtype=np.float32)
+        all_embeddings.append(pooled_numpy)
+
+    # Stack numpy arrays
+    return np.stack(all_embeddings)
+
+
 # --- Core API Classes (Restored and Unchanged) ---

 @dataclass
@@ -83,7 +125,7 @@
         raise KeyError(f"Passage ID not found: {passage_id}")

 class LeannBuilder:
-    def __init__(self, backend_name: str, embedding_model: str = "facebook/contriever-msmarco", dimensions: Optional[int] = None, **backend_kwargs):
+    def __init__(self, backend_name: str, embedding_model: str = "facebook/contriever-msmarco", dimensions: Optional[int] = None, use_mlx: bool = False, **backend_kwargs):
         self.backend_name = backend_name
         backend_factory: LeannBackendFactoryInterface | None = BACKEND_REGISTRY.get(backend_name)
         if backend_factory is None:
@@ -91,6 +133,7 @@
         self.backend_factory = backend_factory
         self.embedding_model = embedding_model
         self.dimensions = dimensions
+        self.use_mlx = use_mlx
         self.backend_kwargs = backend_kwargs
         self.chunks: List[Dict[str, Any]] = []

@@ -102,7 +145,7 @@

     def build_index(self, index_path: str):
         if not self.chunks: raise ValueError("No chunks added.")
-        if self.dimensions is None: self.dimensions = len(compute_embeddings(["dummy"], self.embedding_model)[0])
+        if self.dimensions is None: self.dimensions = len(compute_embeddings(["dummy"], self.embedding_model, self.use_mlx)[0])
         path = Path(index_path)
         index_dir = path.parent
         index_name = path.name
@@ -118,7 +161,7 @@
             offset_map[chunk["id"]] = offset
         with open(offset_file, 'wb') as f: pickle.dump(offset_map, f)
         texts_to_embed = [c["text"] for c in self.chunks]
-        embeddings = compute_embeddings(texts_to_embed, self.embedding_model)
+        embeddings = compute_embeddings(texts_to_embed, self.embedding_model, self.use_mlx)
         string_ids = [chunk["id"] for chunk in self.chunks]
         current_backend_kwargs = {**self.backend_kwargs, 'dimensions': self.dimensions}
         builder_instance = self.backend_factory.builder(**current_backend_kwargs)
@@ -126,7 +169,7 @@
         leann_meta_path = index_dir / f"{index_name}.meta.json"
         meta_data = {
             "version": "1.0", "backend_name": self.backend_name, "embedding_model": self.embedding_model,
-            "dimensions": self.dimensions, "backend_kwargs": self.backend_kwargs,
+            "dimensions": self.dimensions, "backend_kwargs": self.backend_kwargs, "use_mlx": self.use_mlx,
             "passage_sources": [{"type": "jsonl", "path": str(passages_file), "index_path": str(offset_file)}]
         }

@@ -145,6 +188,7 @@
         with open(meta_path_str, 'r', encoding='utf-8') as f: self.meta_data = json.load(f)
         backend_name = self.meta_data['backend_name']
         self.embedding_model = self.meta_data['embedding_model']
+        self.use_mlx = self.meta_data.get('use_mlx', False)
         self.passage_manager = PassageManager(self.meta_data.get('passage_sources', []))
         backend_factory = BACKEND_REGISTRY.get(backend_name)
         if backend_factory is None: raise ValueError(f"Backend '{backend_name}' not found.")
@@ -157,7 +201,7 @@
         print(f" Top_k: {top_k}")
         print(f" Search kwargs: {search_kwargs}")
-        query_embedding = compute_embeddings([query], self.embedding_model)
+        query_embedding = compute_embeddings([query], self.embedding_model, self.use_mlx)
         print(f" Generated embedding shape: {query_embedding.shape}")
         print(f"🔍 DEBUG Query embedding first 10 values: {query_embedding[0][:10]}")
         print(f"🔍 DEBUG Query embedding norm: {np.linalg.norm(query_embedding[0])}")
@@ -212,4 +256,4 @@
                 print(f"Leann: {response}")
             except (KeyboardInterrupt, EOFError):
                 print("\nGoodbye!")
-                break
\ No newline at end of file
+                break
diff --git a/tests/sanity_checks/benchmark_embeddings.py b/tests/sanity_checks/benchmark_embeddings.py
new file mode 100644
index 0000000..0154204
--- /dev/null
+++ b/tests/sanity_checks/benchmark_embeddings.py
@@ -0,0 +1,128 @@
+import time
+import numpy as np
+import matplotlib.pyplot as plt
+import torch
+from sentence_transformers import SentenceTransformer
+import mlx.core as mx
+from mlx_lm import load
+
+# --- Configuration ---
+MODEL_NAME_TORCH = "Qwen/Qwen3-Embedding-0.6B"
+MODEL_NAME_MLX = "mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ"
+BATCH_SIZES = [1, 8, 16, 32, 64, 128]
+NUM_RUNS = 10  # Number of runs to average for each batch size
+WARMUP_RUNS = 2  # Number of warm-up runs
+
+# --- Generate Dummy Data ---
+DUMMY_SENTENCES = ["This is a test sentence for benchmarking." * 5] * max(BATCH_SIZES)
+
+# --- Benchmark Functions ---
+
+def benchmark_torch(model, sentences):
+    start_time = time.time()
+    model.encode(sentences, convert_to_numpy=True)
+    end_time = time.time()
+    return (end_time - start_time) * 1000  # Return time in ms
+
+def benchmark_mlx(model, tokenizer, sentences):
+    start_time = time.time()
+
+    # Tokenize sentences using MLX tokenizer
+    tokens = []
+    for sentence in sentences:
+        token_ids = tokenizer.encode(sentence)
+        tokens.append(token_ids)
+
+    # Pad sequences to the same length
+    max_len = max(len(t) for t in tokens)
+    input_ids = []
+    attention_mask = []
+
+    for token_seq in tokens:
+        # Pad sequence
+        padded = token_seq + [tokenizer.eos_token_id] * (max_len - len(token_seq))
+        input_ids.append(padded)
+        # Create attention mask (1 for real tokens, 0 for padding)
+        mask = [1] * len(token_seq) + [0] * (max_len - len(token_seq))
+        attention_mask.append(mask)
+
+    # Convert to MLX arrays
+    input_ids = mx.array(input_ids)
+    attention_mask = mx.array(attention_mask)
+
+    # Get embeddings
+    embeddings = model(input_ids)
+
+    # Mean pooling
+    mask = mx.expand_dims(attention_mask, -1)
+    sum_embeddings = (embeddings * mask).sum(axis=1)
+    sum_mask = mask.sum(axis=1)
+    pooled = sum_embeddings / sum_mask
+
+    mx.eval(pooled)  # Force the lazy computation so it is included in the timing
+    end_time = time.time()
+    return (end_time - start_time) * 1000  # Return time in ms
+
+# --- Main Execution ---
+def main():
+    print("--- Initializing Models ---")
+    # Load PyTorch model
+    print(f"Loading PyTorch model: {MODEL_NAME_TORCH}")
+    device = "mps" if torch.backends.mps.is_available() else "cpu"
+    model_torch = SentenceTransformer(MODEL_NAME_TORCH, device=device)
+    print(f"PyTorch model loaded on: {device}")
+
+    # Load MLX model
+    print(f"Loading MLX model: {MODEL_NAME_MLX}")
+    model_mlx, tokenizer_mlx = load(MODEL_NAME_MLX)
+    print("MLX model loaded.")
+
+    # --- Warm-up ---
+    print("\n--- Performing Warm-up Runs ---")
+    for _ in range(WARMUP_RUNS):
+        benchmark_torch(model_torch, DUMMY_SENTENCES[:1])
+        benchmark_mlx(model_mlx, tokenizer_mlx, DUMMY_SENTENCES[:1])
+    print("Warm-up complete.")
+
+    # --- Benchmarking ---
+    print("\n--- Starting Benchmark ---")
+    results_torch = []
+    results_mlx = []
+
+    for batch_size in BATCH_SIZES:
+        print(f"Benchmarking batch size: {batch_size}")
+        sentences_batch = DUMMY_SENTENCES[:batch_size]
+
+        # Benchmark PyTorch
+        torch_times = [benchmark_torch(model_torch, sentences_batch) for _ in range(NUM_RUNS)]
+        results_torch.append(np.mean(torch_times))
+
+        # Benchmark MLX
+        mlx_times = [benchmark_mlx(model_mlx, tokenizer_mlx, sentences_batch) for _ in range(NUM_RUNS)]
+        results_mlx.append(np.mean(mlx_times))
+
+    print("\n--- Benchmark Results (Average time per batch in ms) ---")
+    print(f"Batch Sizes: {BATCH_SIZES}")
+    print(f"PyTorch (mps): {[f'{t:.2f}' for t in results_torch]}")
+    print(f"MLX: {[f'{t:.2f}' for t in results_mlx]}")
+
+    # --- Plotting ---
+    print("\n--- Generating Plot ---")
+    plt.figure(figsize=(10, 6))
+    plt.plot(BATCH_SIZES, results_torch, marker='o', linestyle='-', label=f'PyTorch ({device})')
+    plt.plot(BATCH_SIZES, results_mlx, marker='s', linestyle='-', label='MLX')
+
+    plt.title(f'Embedding Performance: MLX vs PyTorch\nModel: {MODEL_NAME_TORCH}')
+    plt.xlabel("Batch Size")
+    plt.ylabel("Average Time per Batch (ms)")
+    plt.xticks(BATCH_SIZES)
+    plt.grid(True)
+    plt.legend()
+
+    # Save the plot
+    output_filename = "embedding_benchmark.png"
+    plt.savefig(output_filename)
+    print(f"Plot saved to {output_filename}")
+
+if __name__ == "__main__":
+    main()
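
A minimal query-side sketch to pair with build_mlx_index.py above; it is not part of the patch. It assumes LeannSearcher takes the same index prefix that LeannBuilder.build_index received and that search(query, top_k=...) returns the backend's scored results; no MLX flag is passed at query time because LeannSearcher reads use_mlx back from the .meta.json written during the build, as the api.py hunk shows.

from leann.api import LeannSearcher

INDEX_PATH = "./mlx_diskann_index/leann"  # same prefix used by build_mlx_index.py

# The searcher loads embedding_model and use_mlx from INDEX_PATH + ".meta.json",
# so the query below is embedded with the same MLX model used at build time.
searcher = LeannSearcher(INDEX_PATH)  # constructor argument assumed to be the index prefix
results = searcher.search("What is MLX?", top_k=3)  # result format depends on the backend
print(results)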