fix: mlx when searching, added to embedding_server

This commit is contained in:
Andy Lee
2025-07-14 01:11:21 -07:00
parent 8b4654921b
commit 3da5b44d7f
8 changed files with 315 additions and 885 deletions

View File

@@ -303,6 +303,41 @@ Once the index is built, you can ask questions like:
</details>
## ⚡ Performance Comparison
### LEANN vs Faiss HNSW
We benchmarked LEANN against the popular Faiss HNSW implementation to demonstrate the significant memory and storage savings our approach provides:
```bash
# Run the comparison benchmark
python examples/compare_faiss_vs_leann.py
```
#### 🎯 Results Summary
| Metric | Faiss HNSW | LEANN HNSW | **Improvement** |
|--------|------------|-------------|-----------------|
| **Peak Memory** | 887.0 MB | 618.2 MB | **1.4x less** (268.8 MB saved) |
| **Storage Size** | 5.5 MB | 0.5 MB | **11.4x smaller** (5.0 MB saved) |
#### 📈 Key Takeaways
- **🧠 Memory Efficiency**: LEANN uses **30% less memory** during index building and querying
- **💾 Storage Optimization**: LEANN requires **91% less storage** for the same dataset
- **🔄 On-demand Computing**: Storage savings come from computing embeddings at query time instead of pre-storing them (see the sketch below)
- **⚖️ Fair Comparison**: Both systems tested on identical hardware with the same 2,573 document dataset
> **Note**: Results may vary based on dataset size, hardware configuration, and query patterns. The comparison excludes text storage to focus purely on index structures.
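To make the on-demand idea concrete, here is a minimal sketch (illustrative only, not LEANN's actual code) contrasting pre-stored vectors with recomputing only the vectors a query touches:
```python
import numpy as np

docs = ["doc one", "doc two", "doc three"]

def embed(texts):
    # stand-in for a real embedding model (random 768-dim vectors)
    return np.random.rand(len(texts), 768).astype(np.float32)

# Faiss-style: every vector is stored alongside the index
# (~num_docs * dim * 4 bytes on disk).
stored_vectors = embed(docs)

# LEANN-style: keep only the graph structure and passage IDs on disk,
# and recompute the handful of vectors a query actually visits.
def search(query, candidate_ids):
    q = embed([query])[0]
    candidates = embed([docs[i] for i in candidate_ids])  # recomputed on demand
    return candidate_ids[int(np.argmax(candidates @ q))]

print(search("doc two", [0, 1, 2]))  # arbitrary with random embeddings; the point is what gets stored
```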
*Benchmark results were obtained on Apple Silicon under consistent test conditions.*
## 📊 Benchmarks

View File

@@ -150,6 +150,7 @@ def create_hnsw_embedding_server(
    model_name: str = "sentence-transformers/all-mpnet-base-v2",
    custom_max_length_param: Optional[int] = None,
    distance_metric: str = "mips",
    use_mlx: bool = False,
):
    """
    Create and start a ZMQ-based embedding server for HNSW backend.
@@ -167,9 +168,13 @@ def create_hnsw_embedding_server(
        custom_max_length_param: Custom max sequence length
        distance_metric: The distance metric to use
    """
    if not use_mlx:
        print(f"Loading tokenizer for {model_name}...")
        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
        print(f"Tokenizer loaded successfully!")
    else:
        print("Using MLX mode - tokenizer will be loaded separately")
        tokenizer = None

    # Device setup
    mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
@@ -191,8 +196,17 @@ def create_hnsw_embedding_server(
    # Load model to the appropriate device
    print(f"Starting HNSW server on port {zmq_port} with model {model_name}")
    print(f"Loading model {model_name}... (this may take a while if downloading)")

    if use_mlx:
        # For MLX models, we need to use the MLX embedding computation
        print("MLX model detected - using MLX backend for embeddings")
        model = None  # We'll handle MLX separately
        tokenizer = None
    else:
        # Use standard transformers for non-MLX models
        model = AutoModel.from_pretrained(model_name).to(device).eval()
        tokenizer = AutoTokenizer.from_pretrained(model_name)
    print(f"Model {model_name} loaded successfully!")

    # Check port availability
    import socket
@@ -312,8 +326,37 @@ def create_hnsw_embedding_server(
        def print_elapsed(self):
            return  # Disabled for now

    def _process_batch_mlx(texts_batch, ids_batch, missing_ids):
        """Process a batch of texts using MLX backend"""
        try:
            # Import MLX embedding computation from main API
            from leann.api import compute_embeddings

            # Compute embeddings using MLX
            embeddings = compute_embeddings(texts_batch, model_name, use_mlx=True)
            print(
                f"[leann_backend_hnsw.hnsw_embedding_server LOG]: MLX embeddings computed for {len(texts_batch)} texts"
            )
            print(
                f"[leann_backend_hnsw.hnsw_embedding_server LOG]: Embedding shape: {embeddings.shape}"
            )
            return embeddings
        except Exception as e:
            print(
                f"[leann_backend_hnsw.hnsw_embedding_server LOG]: ERROR in MLX processing: {e}"
            )
            raise

    def process_batch(texts_batch, ids_batch, missing_ids):
        """Process a batch of texts and return embeddings"""
        # Handle MLX models separately
        if use_mlx:
            return _process_batch_mlx(texts_batch, ids_batch, missing_ids)

        _is_e5_model = "e5" in model_name.lower()
        _is_bge_model = "bge" in model_name.lower()
        batch_size = len(texts_batch)
@@ -927,6 +970,12 @@ if __name__ == "__main__":
    parser.add_argument(
        "--distance-metric", type=str, default="mips", help="Distance metric to use"
    )
    parser.add_argument(
        "--use-mlx",
        action="store_true",
        default=False,
        help="Use MLX for model inference",
    )

    args = parser.parse_args()
@@ -942,4 +991,5 @@ if __name__ == "__main__":
        model_name=args.model_name,
        custom_max_length_param=args.custom_max_length,
        distance_metric=args.distance_metric,
        use_mlx=args.use_mlx,
    )
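For reference, a rough sketch of starting the server in MLX mode from Python. The module path follows the log prefix used in this file, and `zmq_port` is an assumed parameter name inferred from the startup log message; only `model_name`, `distance_metric`, and `use_mlx` are confirmed by this diff.
```python
# Sketch only: zmq_port is an assumed parameter name; the call blocks and
# serves embedding requests over ZMQ once started.
from leann_backend_hnsw.hnsw_embedding_server import create_hnsw_embedding_server

create_hnsw_embedding_server(
    zmq_port=5557,  # assumed parameter name
    model_name="mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ",
    distance_metric="mips",
    use_mlx=True,   # skips AutoTokenizer/AutoModel and routes batches to MLX
)
```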

View File

@@ -1,4 +1,3 @@
""" """
This file contains the core API for the LEANN project, now definitively updated This file contains the core API for the LEANN project, now definitively updated
with the correct, original embedding logic from the user's reference code. with the correct, original embedding logic from the user's reference code.
@@ -18,7 +17,10 @@ from .interface import LeannBackendFactoryInterface
# --- The Correct, Verified Embedding Logic from old_code.py ---


def compute_embeddings(
    chunks: List[str], model_name: str, use_mlx: bool = False
) -> np.ndarray:
    """Computes embeddings using sentence-transformers or MLX for consistent results."""
    if use_mlx:
        return compute_embeddings_mlx(chunks, model_name)
@@ -33,7 +35,9 @@ def compute_embeddings(chunks: List[str], model_name: str, use_mlx: bool = False
    model = SentenceTransformer(model_name)
    model = model.half()
    print(
        f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}'..."
    )

    # use accelerator GPU or Mac GPU
    if torch.cuda.is_available():
@@ -43,10 +47,13 @@ def compute_embeddings(chunks: List[str], model_name: str, use_mlx: bool = False
    # Generate embeddings
    # an OOM warning here means we need to turn down the batch size
    embeddings = model.encode(
        chunks, convert_to_numpy=True, show_progress_bar=True, batch_size=256
    )
    return embeddings


def compute_embeddings_mlx(chunks: List[str], model_name: str) -> np.ndarray:
    """Computes embeddings using an MLX model."""
    try:
@@ -54,10 +61,12 @@ def compute_embeddings_mlx(chunks: List[str], model_name: str) -> np.ndarray:
        from mlx_lm.utils import load
    except ImportError as e:
        raise RuntimeError(
            f"MLX or related libraries not available. Install with: uv pip install mlx mlx-lm"
        ) from e

    print(
        f"INFO: Computing embeddings for {len(chunks)} chunks using MLX model '{model_name}'..."
    )

    # Load model and tokenizer
    model, tokenizer = load(model_name)
@@ -88,6 +97,7 @@ def compute_embeddings_mlx(chunks: List[str], model_name: str) -> np.ndarray:
# --- Core API Classes (Restored and Unchanged) ---


@dataclass
class SearchResult:
    id: str
@@ -95,6 +105,7 @@ class SearchResult:
    text: str
    metadata: Dict[str, Any] = field(default_factory=dict)


class PassageManager:
    def __init__(self, passage_sources: List[Dict[str, Any]]):
        self.offset_maps = {}
@@ -106,8 +117,10 @@ class PassageManager:
            passage_file = source["path"]
            index_file = source["index_path"]
            if not Path(index_file).exists():
                raise FileNotFoundError(
                    f"Passage index file not found: {index_file}"
                )
            with open(index_file, "rb") as f:
                offset_map = pickle.load(f)
            self.offset_maps[passage_file] = offset_map
            self.passage_files[passage_file] = passage_file
@@ -119,15 +132,25 @@ class PassageManager:
    def get_passage(self, passage_id: str) -> Dict[str, Any]:
        if passage_id in self.global_offset_map:
            passage_file, offset = self.global_offset_map[passage_id]
            with open(passage_file, "r", encoding="utf-8") as f:
                f.seek(offset)
                return json.loads(f.readline())
        raise KeyError(f"Passage ID not found: {passage_id}")


class LeannBuilder:
    def __init__(
        self,
        backend_name: str,
        embedding_model: str = "facebook/contriever-msmarco",
        dimensions: Optional[int] = None,
        use_mlx: bool = False,
        **backend_kwargs,
    ):
        self.backend_name = backend_name
        backend_factory: LeannBackendFactoryInterface | None = BACKEND_REGISTRY.get(
            backend_name
        )
        if backend_factory is None:
            raise ValueError(f"Backend '{backend_name}' not found or not registered.")
        self.backend_factory = backend_factory
@@ -138,14 +161,19 @@ class LeannBuilder:
        self.chunks: List[Dict[str, Any]] = []

    def add_text(self, text: str, metadata: Optional[Dict[str, Any]] = None):
        if metadata is None:
            metadata = {}
        passage_id = metadata.get("id", str(uuid.uuid4()))
        chunk_data = {"id": passage_id, "text": text, "metadata": metadata}
        self.chunks.append(chunk_data)

    def build_index(self, index_path: str):
        if not self.chunks:
            raise ValueError("No chunks added.")
        if self.dimensions is None:
            self.dimensions = len(
                compute_embeddings(["dummy"], self.embedding_model, self.use_mlx)[0]
            )
        path = Path(index_path)
        index_dir = path.parent
        index_name = path.name
@@ -153,24 +181,47 @@ class LeannBuilder:
        passages_file = index_dir / f"{index_name}.passages.jsonl"
        offset_file = index_dir / f"{index_name}.passages.idx"
        offset_map = {}
        with open(passages_file, "w", encoding="utf-8") as f:
            for chunk in self.chunks:
                offset = f.tell()
                json.dump(
                    {
                        "id": chunk["id"],
                        "text": chunk["text"],
                        "metadata": chunk["metadata"],
                    },
                    f,
                    ensure_ascii=False,
                )
                f.write("\n")
                offset_map[chunk["id"]] = offset
        with open(offset_file, "wb") as f:
            pickle.dump(offset_map, f)
        texts_to_embed = [c["text"] for c in self.chunks]
        embeddings = compute_embeddings(
            texts_to_embed, self.embedding_model, self.use_mlx
        )
        string_ids = [chunk["id"] for chunk in self.chunks]
        current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
        builder_instance = self.backend_factory.builder(**current_backend_kwargs)
        builder_instance.build(
            embeddings, string_ids, index_path, **current_backend_kwargs
        )
        leann_meta_path = index_dir / f"{index_name}.meta.json"
        meta_data = {
            "version": "1.0",
            "backend_name": self.backend_name,
            "embedding_model": self.embedding_model,
            "dimensions": self.dimensions,
            "backend_kwargs": self.backend_kwargs,
            "use_mlx": self.use_mlx,
            "passage_sources": [
                {
                    "type": "jsonl",
                    "path": str(passages_file),
                    "index_path": str(offset_file),
                }
            ],
        }

        # Add storage status flags for HNSW backend
@@ -178,21 +229,28 @@ class LeannBuilder:
        is_compact = self.backend_kwargs.get("is_compact", True)
        is_recompute = self.backend_kwargs.get("is_recompute", True)
        meta_data["is_compact"] = is_compact
        meta_data["is_pruned"] = (
            is_compact and is_recompute
        )  # Pruned only if compact and recompute
        with open(leann_meta_path, "w", encoding="utf-8") as f:
            json.dump(meta_data, f, indent=2)


class LeannSearcher:
    def __init__(self, index_path: str, **backend_kwargs):
        meta_path_str = f"{index_path}.meta.json"
        if not Path(meta_path_str).exists():
            raise FileNotFoundError(f"Leann metadata file not found at {meta_path_str}")
        with open(meta_path_str, "r", encoding="utf-8") as f:
            self.meta_data = json.load(f)
        backend_name = self.meta_data["backend_name"]
        self.embedding_model = self.meta_data["embedding_model"]
        self.use_mlx = self.meta_data.get("use_mlx", False)
        self.passage_manager = PassageManager(self.meta_data.get("passage_sources", []))
        backend_factory = BACKEND_REGISTRY.get(backend_name)
        if backend_factory is None:
            raise ValueError(f"Backend '{backend_name}' not found.")
        final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
        self.backend_impl = backend_factory.searcher(index_path, **final_kwargs)

    def search(self, query: str, top_k: int = 5, **search_kwargs) -> List[SearchResult]:
@@ -201,34 +259,55 @@ class LeannSearcher:
print(f" Top_k: {top_k}") print(f" Top_k: {top_k}")
print(f" Search kwargs: {search_kwargs}") print(f" Search kwargs: {search_kwargs}")
query_embedding = compute_embeddings([query], self.embedding_model, self.use_mlx) query_embedding = compute_embeddings(
[query], self.embedding_model, self.use_mlx
)
print(f" Generated embedding shape: {query_embedding.shape}") print(f" Generated embedding shape: {query_embedding.shape}")
print(f"🔍 DEBUG Query embedding first 10 values: {query_embedding[0][:10]}") print(f"🔍 DEBUG Query embedding first 10 values: {query_embedding[0][:10]}")
print(f"🔍 DEBUG Query embedding norm: {np.linalg.norm(query_embedding[0])}") print(f"🔍 DEBUG Query embedding norm: {np.linalg.norm(query_embedding[0])}")
# Add use_mlx to search kwargs
search_kwargs["use_mlx"] = self.use_mlx
results = self.backend_impl.search(query_embedding, top_k, **search_kwargs) results = self.backend_impl.search(query_embedding, top_k, **search_kwargs)
print(f" Backend returned: labels={len(results.get('labels', [[]])[0])} results") print(
f" Backend returned: labels={len(results.get('labels', [[]])[0])} results"
)
enriched_results = [] enriched_results = []
if 'labels' in results and 'distances' in results: if "labels" in results and "distances" in results:
print(f" Processing {len(results['labels'][0])} passage IDs:") print(f" Processing {len(results['labels'][0])} passage IDs:")
for i, (string_id, dist) in enumerate(zip(results['labels'][0], results['distances'][0])): for i, (string_id, dist) in enumerate(
zip(results["labels"][0], results["distances"][0])
):
try: try:
passage_data = self.passage_manager.get_passage(string_id) passage_data = self.passage_manager.get_passage(string_id)
enriched_results.append(SearchResult( enriched_results.append(
id=string_id, score=dist, text=passage_data['text'], metadata=passage_data.get('metadata', {}) SearchResult(
)) id=string_id,
print(f" {i+1}. passage_id='{string_id}' -> SUCCESS: {passage_data['text'][:60]}...") score=dist,
text=passage_data["text"],
metadata=passage_data.get("metadata", {}),
)
)
print(
f" {i + 1}. passage_id='{string_id}' -> SUCCESS: {passage_data['text'][:60]}..."
)
except KeyError: except KeyError:
print(f" {i+1}. passage_id='{string_id}' -> ERROR: Passage not found in PassageManager!") print(
f" {i + 1}. passage_id='{string_id}' -> ERROR: Passage not found in PassageManager!"
)
print(f" Final enriched results: {len(enriched_results)} passages") print(f" Final enriched results: {len(enriched_results)} passages")
return enriched_results return enriched_results
from .chat import get_llm


class LeannChat:
    def __init__(
        self, index_path: str, llm_config: Optional[Dict[str, Any]] = None, **kwargs
    ):
        self.searcher = LeannSearcher(index_path, **kwargs)
        self.llm = get_llm(llm_config)
@@ -248,7 +327,7 @@ class LeannChat:
        while True:
            try:
                user_input = input("You: ").strip()
                if user_input.lower() in ["quit", "exit"]:
                    break
                if not user_input:
                    continue
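Because `use_mlx` is persisted in the index's `.meta.json` and read back by `LeannSearcher`, no extra flag is needed at query time; a minimal sketch (the index path is hypothetical):
```python
# Sketch only: the searcher restores use_mlx from <index>.meta.json and
# forwards it to the backend via search_kwargs, as shown above.
from leann.api import LeannSearcher

searcher = LeannSearcher("indexes/mlx_demo.leann")  # hypothetical path
for result in searcher.search("What is MLX?", top_k=3):
    print(f"{result.score:.3f}  {result.text[:60]}")
```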

View File

@@ -310,9 +310,12 @@ class EmbeddingServerManager:
command.extend(["--passages-file", str(kwargs["passages_file"])]) command.extend(["--passages-file", str(kwargs["passages_file"])])
# if "distance_metric" in kwargs and kwargs["distance_metric"]: # if "distance_metric" in kwargs and kwargs["distance_metric"]:
# command.extend(["--distance-metric", kwargs["distance_metric"]]) # command.extend(["--distance-metric", kwargs["distance_metric"]])
if "use_mlx" in kwargs and kwargs["use_mlx"]:
command.extend(["--use-mlx"])
project_root = Path(__file__).parent.parent.parent.parent.parent project_root = Path(__file__).parent.parent.parent.parent.parent
print(f"INFO: Running command from project root: {project_root}") print(f"INFO: Running command from project root: {project_root}")
print(f"INFO: Command: {' '.join(command)}") # Debug: show actual command
self.server_process = subprocess.Popen( self.server_process = subprocess.Popen(
command, command,

View File

@@ -78,9 +78,10 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
            model_name=self.embedding_model,
            passages_file=passages_source_file,
            distance_metric=kwargs.get("distance_metric"),
            use_mlx=kwargs.get("use_mlx", False),
        )
        if not server_started:
            raise RuntimeError(f"Failed to start embedding server on port {port}")

    @abstractmethod
    def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]:

View File

@@ -34,6 +34,8 @@ dependencies = [
"msgpack>=1.1.1", "msgpack>=1.1.1",
"llama-index-vector-stores-faiss>=0.4.0", "llama-index-vector-stores-faiss>=0.4.0",
"llama-index-embeddings-huggingface>=0.5.5", "llama-index-embeddings-huggingface>=0.5.5",
"mlx>=0.26.3",
"mlx-lm>=0.26.0",
] ]
[project.optional-dependencies] [project.optional-dependencies]

View File

@@ -12,7 +12,7 @@ else:
    builder = LeannBuilder(
        backend_name="hnsw",
        embedding_model="mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ",
        use_mlx=True,
    )

    # 2. Add documents
@@ -22,7 +22,7 @@ else:
"It was designed by Apple's machine learning research team.", "It was designed by Apple's machine learning research team.",
"The mlx-community organization provides pre-trained models in MLX format.", "The mlx-community organization provides pre-trained models in MLX format.",
"It supports operations on multi-dimensional arrays.", "It supports operations on multi-dimensional arrays.",
"Leann can now use MLX for its embedding models." "Leann can now use MLX for its embedding models.",
] ]
for doc in docs: for doc in docs:
builder.add_text(doc) builder.add_text(doc)
@@ -34,9 +34,11 @@ else:
print(f"Check the metadata file: {INDEX_PATH}.meta.json") print(f"Check the metadata file: {INDEX_PATH}.meta.json")
chat = LeannChat(index_path=INDEX_PATH) chat = LeannChat(index_path=INDEX_PATH)
# add query # add query
query = "MLX is an array framework for machine learning on Apple silicon." query = "MLX is an array framework for machine learning on Apple silicon."
print(f"Query: {query}") print(f"Query: {query}")
response = chat.ask(query, top_k=3, recompute_beighbor_embeddings=True, complexity=3, beam_width=1) response = chat.ask(
print(f"Response: {response}") query, top_k=3, recompute_beighbor_embeddings=True, complexity=3, beam_width=1
)
print(f"Response: {response}")

uv.lock (generated, 882 changed lines)

File diff suppressed because it is too large.