feat: mlx

This commit is contained in:
Andy Lee
2025-07-13 02:13:04 -07:00
parent 71ef4b7d4c
commit 48dda1cb5b
4 changed files with 278 additions and 60 deletions

View File

@@ -1,3 +1,4 @@
"""
This file contains the core API for the LEANN project, now definitively updated
with the correct, original embedding logic from the user's reference code.
@@ -17,8 +18,10 @@ from .interface import LeannBackendFactoryInterface
# --- The Correct, Verified Embedding Logic from old_code.py ---
def compute_embeddings(chunks: List[str], model_name: str) -> np.ndarray:
"""Computes embeddings using sentence-transformers for consistent results."""
def compute_embeddings(chunks: List[str], model_name: str, use_mlx: bool = False) -> np.ndarray:
"""Computes embeddings using sentence-transformers or MLX for consistent results."""
if use_mlx:
return compute_embeddings_mlx(chunks, model_name)
try:
from sentence_transformers import SentenceTransformer
except ImportError as e:
@@ -44,6 +47,45 @@ def compute_embeddings(chunks: List[str], model_name: str) -> np.ndarray:
return embeddings
def compute_embeddings_mlx(chunks: List[str], model_name: str) -> np.ndarray:
"""Computes embeddings using an MLX model."""
try:
import mlx.core as mx
from mlx_lm.utils import load
except ImportError as e:
raise RuntimeError(
f"MLX or related libraries not available. Install with: pip install mlx mlx-lm"
) from e
print(f"INFO: Computing embeddings for {len(chunks)} chunks using MLX model '{model_name}'...")
# Load model and tokenizer
model, tokenizer = load(model_name)
# Process each chunk
all_embeddings = []
for chunk in chunks:
# Tokenize
token_ids = tokenizer.encode(chunk)
# Convert to MLX array and add batch dimension
input_ids = mx.array([token_ids])
# Get embeddings
embeddings = model(input_ids)
# Mean pooling (since we only have one sequence, just take the mean)
pooled = embeddings.mean(axis=1) # Shape: (1, hidden_size)
# Convert individual embedding to numpy via list (to handle bfloat16)
pooled_list = pooled[0].tolist() # Remove batch dimension and convert to list
pooled_numpy = np.array(pooled_list, dtype=np.float32)
all_embeddings.append(pooled_numpy)
# Stack numpy arrays
return np.stack(all_embeddings)
# --- Core API Classes (Restored and Unchanged) ---
@dataclass
@@ -83,7 +125,7 @@ class PassageManager:
raise KeyError(f"Passage ID not found: {passage_id}")
class LeannBuilder:
def __init__(self, backend_name: str, embedding_model: str = "facebook/contriever-msmarco", dimensions: Optional[int] = None, **backend_kwargs):
def __init__(self, backend_name: str, embedding_model: str = "facebook/contriever-msmarco", dimensions: Optional[int] = None, use_mlx: bool = False, **backend_kwargs):
self.backend_name = backend_name
backend_factory: LeannBackendFactoryInterface | None = BACKEND_REGISTRY.get(backend_name)
if backend_factory is None:
@@ -91,6 +133,7 @@ class LeannBuilder:
self.backend_factory = backend_factory
self.embedding_model = embedding_model
self.dimensions = dimensions
self.use_mlx = use_mlx
self.backend_kwargs = backend_kwargs
self.chunks: List[Dict[str, Any]] = []
@@ -102,7 +145,7 @@ class LeannBuilder:
def build_index(self, index_path: str):
if not self.chunks: raise ValueError("No chunks added.")
if self.dimensions is None: self.dimensions = len(compute_embeddings(["dummy"], self.embedding_model)[0])
if self.dimensions is None: self.dimensions = len(compute_embeddings(["dummy"], self.embedding_model, self.use_mlx)[0])
path = Path(index_path)
index_dir = path.parent
index_name = path.name
@@ -118,7 +161,7 @@ class LeannBuilder:
offset_map[chunk["id"]] = offset
with open(offset_file, 'wb') as f: pickle.dump(offset_map, f)
texts_to_embed = [c["text"] for c in self.chunks]
embeddings = compute_embeddings(texts_to_embed, self.embedding_model)
embeddings = compute_embeddings(texts_to_embed, self.embedding_model, self.use_mlx)
string_ids = [chunk["id"] for chunk in self.chunks]
current_backend_kwargs = {**self.backend_kwargs, 'dimensions': self.dimensions}
builder_instance = self.backend_factory.builder(**current_backend_kwargs)
@@ -126,7 +169,7 @@ class LeannBuilder:
leann_meta_path = index_dir / f"{index_name}.meta.json"
meta_data = {
"version": "1.0", "backend_name": self.backend_name, "embedding_model": self.embedding_model,
"dimensions": self.dimensions, "backend_kwargs": self.backend_kwargs,
"dimensions": self.dimensions, "backend_kwargs": self.backend_kwargs, "use_mlx": self.use_mlx,
"passage_sources": [{"type": "jsonl", "path": str(passages_file), "index_path": str(offset_file)}]
}
@@ -145,6 +188,7 @@ class LeannSearcher:
with open(meta_path_str, 'r', encoding='utf-8') as f: self.meta_data = json.load(f)
backend_name = self.meta_data['backend_name']
self.embedding_model = self.meta_data['embedding_model']
self.use_mlx = self.meta_data.get('use_mlx', False)
self.passage_manager = PassageManager(self.meta_data.get('passage_sources', []))
backend_factory = BACKEND_REGISTRY.get(backend_name)
if backend_factory is None: raise ValueError(f"Backend '{backend_name}' not found.")
@@ -157,7 +201,7 @@ class LeannSearcher:
print(f" Top_k: {top_k}")
print(f" Search kwargs: {search_kwargs}")
query_embedding = compute_embeddings([query], self.embedding_model)
query_embedding = compute_embeddings([query], self.embedding_model, self.use_mlx)
print(f" Generated embedding shape: {query_embedding.shape}")
print(f"🔍 DEBUG Query embedding first 10 values: {query_embedding[0][:10]}")
print(f"🔍 DEBUG Query embedding norm: {np.linalg.norm(query_embedding[0])}")
@@ -212,4 +256,4 @@ class LeannChat:
print(f"Leann: {response}")
except (KeyboardInterrupt, EOFError):
print("\nGoodbye!")
break
break