fix micro bench and fix pre commit
@@ -6,8 +6,6 @@ results and the golden standard results, making the comparison robust to ID changes
 """

 import argparse
-import logging
-import os
 import json
 import sys
 import time

@@ -16,11 +14,6 @@ from pathlib import Path
 import numpy as np
 from leann.api import LeannBuilder, LeannChat, LeannSearcher

-# Configure logging level (default INFO; override with LEANN_LOG_LEVEL)
-_log_level_str = os.getenv("LEANN_LOG_LEVEL", "INFO").upper()
-_log_level = getattr(logging, _log_level_str, logging.INFO)
-logging.basicConfig(level=_log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
-

 def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
     """Checks if the data directory exists, and if not, downloads it from HF Hub."""

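The removed block is what wired the root logger to the LEANN_LOG_LEVEL environment variable. For reference, a minimal self-contained sketch of that pattern, in case a caller wants to reinstate it on their side (the logger name below is illustrative):

import logging
import os

# Default to INFO; getattr falls back to INFO again if the env var
# holds something that is not a valid logging level name.
_level = getattr(logging, os.getenv("LEANN_LOG_LEVEL", "INFO").upper(), logging.INFO)
logging.basicConfig(level=_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s")

logging.getLogger("benchmark").info("log level set from LEANN_LOG_LEVEL")

Dropping a module-level basicConfig call is common hygiene: importing the script otherwise forces a global logging configuration onto the host application.
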
@@ -20,7 +20,7 @@ except ImportError:
 @dataclass
 class BenchmarkConfig:
-    model_path: str = "facebook/contriever"
+    model_path: str = "facebook/contriever-msmarco"
     batch_sizes: list[int] = None
     seq_length: int = 256
     num_runs: int = 5

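The default checkpoint moves from facebook/contriever to its MS MARCO fine-tuned variant. The loading code is not part of this diff; a hedged sketch of how such a Hugging Face checkpoint is typically loaded and run at the configured seq_length of 256:

import torch
from transformers import AutoModel, AutoTokenizer

MODEL_PATH = "facebook/contriever-msmarco"  # default from BenchmarkConfig

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModel.from_pretrained(MODEL_PATH)
model.eval()  # inference only; disables dropout

# Pad to the fixed sequence length the benchmark uses (seq_length=256).
inputs = tokenizer(
    ["a test query"],
    return_tensors="pt",
    padding="max_length",
    truncation=True,
    max_length=256,
)
with torch.no_grad():
    outputs = model(**inputs)
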
@@ -34,7 +34,7 @@ class BenchmarkConfig:
     def __post_init__(self):
         if self.batch_sizes is None:
-            self.batch_sizes = [1, 2, 4, 8, 16, 32, 64]
+            self.batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]


 class MLXBenchmark:

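The None default plus __post_init__ exists because dataclasses reject mutable defaults such as a list literal. An equivalent sketch using field(default_factory=...), shown only as an alternative idiom, not as what the project does:

from dataclasses import dataclass, field

@dataclass
class BenchmarkConfig:
    model_path: str = "facebook/contriever-msmarco"
    # default_factory builds a fresh list per instance, avoiding the
    # shared-mutable-default problem a bare list default would cause.
    batch_sizes: list[int] = field(
        default_factory=lambda: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
    )
    seq_length: int = 256
    num_runs: int = 5

Either way each instance gets its own list; the default_factory form keeps the default visible in the field declaration and lets batch_sizes be typed as list[int] rather than an implicit Optional.
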
@@ -179,10 +179,16 @@ class Benchmark:
     def _run_inference(self, input_ids: torch.Tensor) -> float:
         attention_mask = torch.ones_like(input_ids)
+
+        # print shape of input_ids and attention_mask
+        print(f"input_ids shape: {input_ids.shape}")
+        print(f"attention_mask shape: {attention_mask.shape}")
         start_time = time.time()
         with torch.no_grad():
             self.model(input_ids=input_ids, attention_mask=attention_mask)
         if torch.cuda.is_available():
             torch.cuda.synchronize()
+        if torch.backends.mps.is_available():
+            torch.mps.synchronize()
         end_time = time.time()

         return end_time - start_time

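The added synchronize calls are what make the timing honest: CUDA and MPS kernels launch asynchronously, so without blocking on the device, end_time can be captured before the forward pass has actually finished. A self-contained sketch of the same pattern, with a placeholder model standing in for the benchmarked encoder:

import time

import torch
from torch import nn

model = nn.Linear(256, 256)  # placeholder, not the real encoder
x = torch.randn(32, 256)

def timed_forward(model: nn.Module, x: torch.Tensor) -> float:
    """Time one forward pass, waiting for any async device work to finish."""
    start = time.time()
    with torch.no_grad():
        model(x)
    # Block until queued GPU kernels complete so the wall-clock delta
    # reflects real compute time rather than just the launch overhead.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    if torch.backends.mps.is_available():
        torch.mps.synchronize()
    return time.time() - start

print(f"forward pass took {timed_forward(model, x):.4f}s")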