feat: mlx
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
|
||||
"""
|
||||
This file contains the core API for the LEANN project, now definitively updated
|
||||
with the correct, original embedding logic from the user's reference code.
|
||||
@@ -17,8 +18,10 @@ from .interface import LeannBackendFactoryInterface
|
||||
|
||||
# --- The Correct, Verified Embedding Logic from old_code.py ---
|
||||
|
||||
def compute_embeddings(chunks: List[str], model_name: str) -> np.ndarray:
|
||||
"""Computes embeddings using sentence-transformers for consistent results."""
|
||||
def compute_embeddings(chunks: List[str], model_name: str, use_mlx: bool = False) -> np.ndarray:
|
||||
"""Computes embeddings using sentence-transformers or MLX for consistent results."""
|
||||
if use_mlx:
|
||||
return compute_embeddings_mlx(chunks, model_name)
|
||||
try:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
except ImportError as e:
|
||||
@@ -44,6 +47,45 @@ def compute_embeddings(chunks: List[str], model_name: str) -> np.ndarray:
|
||||
|
||||
return embeddings
|
||||
|
||||
def compute_embeddings_mlx(chunks: List[str], model_name: str) -> np.ndarray:
    """Compute one mean-pooled embedding per chunk using an MLX model.

    Each chunk is tokenized and passed through the model on its own
    (batch size 1). The model output is averaged over axis 1 and converted
    to a float32 NumPy vector; the vectors are stacked in input order.

    Args:
        chunks: Texts to embed. Must not be empty.
        model_name: Model identifier handed to ``mlx_lm.utils.load``.

    Returns:
        A float32 array holding one stacked embedding per chunk.

    Raises:
        ValueError: If ``chunks`` is empty.
        RuntimeError: If ``mlx`` or ``mlx_lm`` cannot be imported.
    """
    # Fail fast: without this guard an empty input would load the whole model
    # and only then die inside np.stack([]) with an unhelpful ValueError.
    if not chunks:
        raise ValueError("compute_embeddings_mlx requires at least one chunk.")

    try:
        import mlx.core as mx
        from mlx_lm.utils import load
    except ImportError as e:
        raise RuntimeError(
            "MLX or related libraries not available. Install with: pip install mlx mlx-lm"
        ) from e

    print(f"INFO: Computing embeddings for {len(chunks)} chunks using MLX model '{model_name}'...")

    # Load model and tokenizer
    model, tokenizer = load(model_name)

    all_embeddings = []
    for chunk in chunks:
        token_ids = tokenizer.encode(chunk)

        # Wrap in a list to add the batch dimension.
        input_ids = mx.array([token_ids])

        # NOTE(review): assumes the model's __call__ returns per-token hidden
        # states shaped (1, seq_len, hidden_size) - confirm for the models in
        # use, since some mlx_lm models return vocabulary logits here instead.
        embeddings = model(input_ids)

        # Mean pooling over axis 1; only one sequence is in the batch, so no
        # attention mask is applied.
        pooled = embeddings.mean(axis=1)

        # Go through a Python list rather than converting directly, so that
        # bfloat16 outputs (which NumPy cannot represent) are handled.
        pooled_list = pooled[0].tolist()  # drop the batch dimension
        all_embeddings.append(np.array(pooled_list, dtype=np.float32))

    return np.stack(all_embeddings)
|
||||
|
||||
|
||||
# --- Core API Classes (Restored and Unchanged) ---
|
||||
|
||||
@dataclass
|
||||
@@ -83,7 +125,7 @@ class PassageManager:
|
||||
raise KeyError(f"Passage ID not found: {passage_id}")
|
||||
|
||||
class LeannBuilder:
|
||||
def __init__(self, backend_name: str, embedding_model: str = "facebook/contriever-msmarco", dimensions: Optional[int] = None, **backend_kwargs):
|
||||
def __init__(self, backend_name: str, embedding_model: str = "facebook/contriever-msmarco", dimensions: Optional[int] = None, use_mlx: bool = False, **backend_kwargs):
|
||||
self.backend_name = backend_name
|
||||
backend_factory: LeannBackendFactoryInterface | None = BACKEND_REGISTRY.get(backend_name)
|
||||
if backend_factory is None:
|
||||
@@ -91,6 +133,7 @@ class LeannBuilder:
|
||||
self.backend_factory = backend_factory
|
||||
self.embedding_model = embedding_model
|
||||
self.dimensions = dimensions
|
||||
self.use_mlx = use_mlx
|
||||
self.backend_kwargs = backend_kwargs
|
||||
self.chunks: List[Dict[str, Any]] = []
|
||||
|
||||
@@ -102,7 +145,7 @@ class LeannBuilder:
|
||||
|
||||
def build_index(self, index_path: str):
|
||||
if not self.chunks: raise ValueError("No chunks added.")
|
||||
if self.dimensions is None: self.dimensions = len(compute_embeddings(["dummy"], self.embedding_model)[0])
|
||||
if self.dimensions is None: self.dimensions = len(compute_embeddings(["dummy"], self.embedding_model, self.use_mlx)[0])
|
||||
path = Path(index_path)
|
||||
index_dir = path.parent
|
||||
index_name = path.name
|
||||
@@ -118,7 +161,7 @@ class LeannBuilder:
|
||||
offset_map[chunk["id"]] = offset
|
||||
with open(offset_file, 'wb') as f: pickle.dump(offset_map, f)
|
||||
texts_to_embed = [c["text"] for c in self.chunks]
|
||||
embeddings = compute_embeddings(texts_to_embed, self.embedding_model)
|
||||
embeddings = compute_embeddings(texts_to_embed, self.embedding_model, self.use_mlx)
|
||||
string_ids = [chunk["id"] for chunk in self.chunks]
|
||||
current_backend_kwargs = {**self.backend_kwargs, 'dimensions': self.dimensions}
|
||||
builder_instance = self.backend_factory.builder(**current_backend_kwargs)
|
||||
@@ -126,7 +169,7 @@ class LeannBuilder:
|
||||
leann_meta_path = index_dir / f"{index_name}.meta.json"
|
||||
meta_data = {
|
||||
"version": "1.0", "backend_name": self.backend_name, "embedding_model": self.embedding_model,
|
||||
"dimensions": self.dimensions, "backend_kwargs": self.backend_kwargs,
|
||||
"dimensions": self.dimensions, "backend_kwargs": self.backend_kwargs, "use_mlx": self.use_mlx,
|
||||
"passage_sources": [{"type": "jsonl", "path": str(passages_file), "index_path": str(offset_file)}]
|
||||
}
|
||||
|
||||
@@ -145,6 +188,7 @@ class LeannSearcher:
|
||||
with open(meta_path_str, 'r', encoding='utf-8') as f: self.meta_data = json.load(f)
|
||||
backend_name = self.meta_data['backend_name']
|
||||
self.embedding_model = self.meta_data['embedding_model']
|
||||
self.use_mlx = self.meta_data.get('use_mlx', False)
|
||||
self.passage_manager = PassageManager(self.meta_data.get('passage_sources', []))
|
||||
backend_factory = BACKEND_REGISTRY.get(backend_name)
|
||||
if backend_factory is None: raise ValueError(f"Backend '{backend_name}' not found.")
|
||||
@@ -157,7 +201,7 @@ class LeannSearcher:
|
||||
print(f" Top_k: {top_k}")
|
||||
print(f" Search kwargs: {search_kwargs}")
|
||||
|
||||
query_embedding = compute_embeddings([query], self.embedding_model)
|
||||
query_embedding = compute_embeddings([query], self.embedding_model, self.use_mlx)
|
||||
print(f" Generated embedding shape: {query_embedding.shape}")
|
||||
print(f"🔍 DEBUG Query embedding first 10 values: {query_embedding[0][:10]}")
|
||||
print(f"🔍 DEBUG Query embedding norm: {np.linalg.norm(query_embedding[0])}")
|
||||
@@ -212,4 +256,4 @@ class LeannChat:
|
||||
print(f"Leann: {response}")
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print("\nGoodbye!")
|
||||
break
|
||||
break
|
||||
|
||||
Reference in New Issue
Block a user