Resolve submodule conflict - update to af2a264
This commit is contained in:
31
README.md
31
README.md
@@ -146,6 +146,37 @@ This ensures the generated files are compatible with your system's protobuf libr
|
||||
|
||||
## 📊 Benchmarks
|
||||
|
||||
### How to Reproduce Evaluation Results
|
||||
|
||||
Reproducing our benchmarks is straightforward. The evaluation script is designed to be self-contained, automatically downloading all necessary data on its first run.
|
||||
|
||||
#### 1. Environment Setup
|
||||
|
||||
First, ensure you have followed the installation instructions in the [Quick Start](#-quick-start) section. This will install all core dependencies.
|
||||
|
||||
Next, install the optional development dependencies, which include the `huggingface-hub` library required for automatic data download:
|
||||
|
||||
```bash
|
||||
# This command installs all development dependencies
|
||||
uv pip install -e ".[dev]"
|
||||
```
|
||||
|
||||
#### 2. Run the Evaluation
|
||||
|
||||
Simply run the evaluation script. The first time you run it, it will detect that the data is missing, download it from Hugging Face Hub, and then proceed with the evaluation.
|
||||
|
||||
**To evaluate the DPR dataset:**
|
||||
```bash
|
||||
python examples/run_evaluation.py data/indices/dpr/dpr_diskann
|
||||
```
|
||||
|
||||
**To evaluate the RPJ-Wiki dataset:**
|
||||
```bash
|
||||
python examples/run_evaluation.py data/indices/rpj_wiki/rpj_wiki.index
|
||||
```
|
||||
|
||||
The script will print the recall and search time for each query, followed by the average results.
|
||||
|
||||
### Memory Usage Comparison
|
||||
|
||||
| System | 1M Documents | 10M Documents | 100M Documents |
|
||||
|
||||
34
build_mlx_index.py
Normal file
34
build_mlx_index.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from leann.api import LeannBuilder
|
||||
import os
|
||||
|
||||
# Define the path for our new MLX-based index
|
||||
INDEX_PATH = "./mlx_diskann_index/leann"
|
||||
|
||||
if os.path.exists(INDEX_PATH + ".meta.json"):
|
||||
print(f"Index already exists at {INDEX_PATH}. Skipping build.")
|
||||
else:
|
||||
print("Initializing LeannBuilder with MLX support...")
|
||||
# 1. Configure LeannBuilder to use MLX
|
||||
builder = LeannBuilder(
|
||||
backend_name="diskann",
|
||||
embedding_model="mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ",
|
||||
use_mlx=True
|
||||
)
|
||||
|
||||
# 2. Add documents
|
||||
print("Adding documents...")
|
||||
docs = [
|
||||
"MLX is an array framework for machine learning on Apple silicon.",
|
||||
"It was designed by Apple's machine learning research team.",
|
||||
"The mlx-community organization provides pre-trained models in MLX format.",
|
||||
"It supports operations on multi-dimensional arrays.",
|
||||
"Leann can now use MLX for its embedding models."
|
||||
]
|
||||
for doc in docs:
|
||||
builder.add_text(doc)
|
||||
|
||||
# 3. Build the index
|
||||
print(f"Building the MLX-based index at: {INDEX_PATH}")
|
||||
builder.build_index(INDEX_PATH)
|
||||
print("\nSuccessfully built the index with MLX embeddings!")
|
||||
print(f"Check the metadata file: {INDEX_PATH}.meta.json")
|
||||
82
data/.gitattributes
vendored
Normal file
82
data/.gitattributes
vendored
Normal file
@@ -0,0 +1,82 @@
|
||||
*.7z filter=lfs diff=lfs merge=lfs -text
|
||||
*.arrow filter=lfs diff=lfs merge=lfs -text
|
||||
*.bin filter=lfs diff=lfs merge=lfs -text
|
||||
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
||||
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
||||
*.ftz filter=lfs diff=lfs merge=lfs -text
|
||||
*.gz filter=lfs diff=lfs merge=lfs -text
|
||||
*.h5 filter=lfs diff=lfs merge=lfs -text
|
||||
*.joblib filter=lfs diff=lfs merge=lfs -text
|
||||
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
||||
*.lz4 filter=lfs diff=lfs merge=lfs -text
|
||||
*.mds filter=lfs diff=lfs merge=lfs -text
|
||||
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
||||
*.model filter=lfs diff=lfs merge=lfs -text
|
||||
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
||||
*.npy filter=lfs diff=lfs merge=lfs -text
|
||||
*.npz filter=lfs diff=lfs merge=lfs -text
|
||||
*.onnx filter=lfs diff=lfs merge=lfs -text
|
||||
*.ot filter=lfs diff=lfs merge=lfs -text
|
||||
*.parquet filter=lfs diff=lfs merge=lfs -text
|
||||
*.pb filter=lfs diff=lfs merge=lfs -text
|
||||
*.pickle filter=lfs diff=lfs merge=lfs -text
|
||||
*.pkl filter=lfs diff=lfs merge=lfs -text
|
||||
*.pt filter=lfs diff=lfs merge=lfs -text
|
||||
*.pth filter=lfs diff=lfs merge=lfs -text
|
||||
*.rar filter=lfs diff=lfs merge=lfs -text
|
||||
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
||||
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
||||
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
||||
*.tar filter=lfs diff=lfs merge=lfs -text
|
||||
*.tflite filter=lfs diff=lfs merge=lfs -text
|
||||
*.tgz filter=lfs diff=lfs merge=lfs -text
|
||||
*.wasm filter=lfs diff=lfs merge=lfs -text
|
||||
*.xz filter=lfs diff=lfs merge=lfs -text
|
||||
*.zip filter=lfs diff=lfs merge=lfs -text
|
||||
*.zst filter=lfs diff=lfs merge=lfs -text
|
||||
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
||||
# Audio files - uncompressed
|
||||
*.pcm filter=lfs diff=lfs merge=lfs -text
|
||||
*.sam filter=lfs diff=lfs merge=lfs -text
|
||||
*.raw filter=lfs diff=lfs merge=lfs -text
|
||||
# Audio files - compressed
|
||||
*.aac filter=lfs diff=lfs merge=lfs -text
|
||||
*.flac filter=lfs diff=lfs merge=lfs -text
|
||||
*.mp3 filter=lfs diff=lfs merge=lfs -text
|
||||
*.ogg filter=lfs diff=lfs merge=lfs -text
|
||||
*.wav filter=lfs diff=lfs merge=lfs -text
|
||||
# Image files - uncompressed
|
||||
*.bmp filter=lfs diff=lfs merge=lfs -text
|
||||
*.gif filter=lfs diff=lfs merge=lfs -text
|
||||
*.png filter=lfs diff=lfs merge=lfs -text
|
||||
*.tiff filter=lfs diff=lfs merge=lfs -text
|
||||
# Image files - compressed
|
||||
*.jpg filter=lfs diff=lfs merge=lfs -text
|
||||
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
||||
*.webp filter=lfs diff=lfs merge=lfs -text
|
||||
# Video files - compressed
|
||||
*.mp4 filter=lfs diff=lfs merge=lfs -text
|
||||
*.webm filter=lfs diff=lfs merge=lfs -text
|
||||
ground_truth/dpr/id_map.json filter=lfs diff=lfs merge=lfs -text
|
||||
indices/dpr/dpr_diskann.passages.idx filter=lfs diff=lfs merge=lfs -text
|
||||
indices/dpr/dpr_diskann.passages.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||
indices/dpr/dpr_diskann_disk.index filter=lfs diff=lfs merge=lfs -text
|
||||
indices/dpr/leann.labels.map filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/leann.labels.map filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.index filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.0.idx filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.0.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.1.idx filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.1.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.2.idx filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.2.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.3.idx filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.3.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.4.idx filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.4.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.5.idx filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.5.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.6.idx filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.6.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.7.idx filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.7.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||
44
data/README.md
Normal file
44
data/README.md
Normal file
@@ -0,0 +1,44 @@
|
||||
---
|
||||
license: mit
|
||||
---
|
||||
|
||||
# LEANN-RAG Evaluation Data
|
||||
|
||||
This repository contains the necessary data to run the recall evaluation scripts for the [LEANN-RAG](https://huggingface.co/LEANN-RAG) project.
|
||||
|
||||
## Dataset Components
|
||||
|
||||
This dataset is structured into three main parts:
|
||||
|
||||
1. **Pre-built LEANN Indices**:
|
||||
* `dpr/`: A pre-built index for the DPR dataset.
|
||||
* `rpj_wiki/`: A pre-built index for the RPJ-Wiki dataset.
|
||||
These indices were created using the `leann-core` library and are required by the `LeannSearcher`.
|
||||
|
||||
2. **Ground Truth Data**:
|
||||
* `ground_truth/`: Contains the ground truth files (`flat_results_nq_k3.json`) for both the DPR and RPJ-Wiki datasets. These files map queries to the original passage IDs from the Natural Questions benchmark, evaluated using the Contriever model.
|
||||
|
||||
3. **Queries**:
|
||||
* `queries/`: Contains the `nq_open.jsonl` file with the Natural Questions queries used for the evaluation.
|
||||
|
||||
## Usage
|
||||
|
||||
To use this data, you can download it locally using the `huggingface-hub` library. First, install the library:
|
||||
|
||||
```bash
|
||||
pip install huggingface-hub
|
||||
```
|
||||
|
||||
Then, you can download the entire dataset to a local directory (e.g., `data/`) with the following Python script:
|
||||
|
||||
```python
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
snapshot_download(
|
||||
repo_id="LEANN-RAG/leann-rag-evaluation-data",
|
||||
repo_type="dataset",
|
||||
local_dir="data"
|
||||
)
|
||||
```
|
||||
|
||||
This will download all the necessary files into a local `data` folder, preserving the repository structure. The evaluation scripts in the main [LEANN-RAG Space](https://huggingface.co/LEANN-RAG) are configured to work with this data structure.
|
||||
@@ -11,122 +11,143 @@ import time
|
||||
from pathlib import Path
|
||||
import sys
|
||||
import numpy as np
|
||||
from typing import List, Dict, Any
|
||||
import glob
|
||||
import pickle
|
||||
|
||||
# Add project root to path to allow importing from leann
|
||||
project_root = Path(__file__).resolve().parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
from typing import List
|
||||
|
||||
from leann.api import LeannSearcher
|
||||
|
||||
# --- Configuration ---
|
||||
NQ_QUERIES_FILE = Path("/opt/dlami/nvme/scaling_out/examples/nq_open.jsonl")
|
||||
|
||||
# Ground truth files for different datasets
|
||||
GROUND_TRUTH_FILES = {
|
||||
"rpj_wiki": "/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/flat_results_nq_k3.json",
|
||||
"dpr": "/opt/dlami/nvme/scaling_out/indices/dpr/facebook/contriever-msmarco/flat_results_nq_k3.json"
|
||||
}
|
||||
def download_data_if_needed(data_root: Path):
|
||||
"""Checks if the data directory exists, and if not, downloads it from HF Hub."""
|
||||
if not data_root.exists():
|
||||
print(f"Data directory '{data_root}' not found.")
|
||||
print(
|
||||
"Downloading evaluation data from Hugging Face Hub... (this may take a moment)"
|
||||
)
|
||||
try:
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
# Old passages for different datasets
|
||||
OLD_PASSAGES_GLOBS = {
|
||||
"rpj_wiki": "/opt/dlami/nvme/scaling_out/passages/rpj_wiki/8-shards/raw_passages-*-of-8.pkl.jsonl",
|
||||
"dpr": "/opt/dlami/nvme/scaling_out/passages/dpr/1-shards/raw_passages-*-of-1.pkl.jsonl"
|
||||
}
|
||||
snapshot_download(
|
||||
repo_id="LEANN-RAG/leann-rag-evaluation-data",
|
||||
repo_type="dataset",
|
||||
local_dir=data_root,
|
||||
local_dir_use_symlinks=False, # Recommended for Windows compatibility and simpler structure
|
||||
)
|
||||
print("Data download complete!")
|
||||
except ImportError:
|
||||
print(
|
||||
"Error: huggingface_hub is not installed. Please install it to download the data:"
|
||||
)
|
||||
print("uv pip install -e '.[dev]'")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"An error occurred during data download: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
# --- Helper Class to Load Original Passages ---
|
||||
class OldPassageLoader:
|
||||
"""A simplified version of the old LazyPassages class to fetch golden results by ID."""
|
||||
def __init__(self, passages_glob: str):
|
||||
self.jsonl_paths = sorted(glob.glob(passages_glob))
|
||||
self.offsets = {}
|
||||
self.fps = [open(p, "r", encoding="utf-8") for p in self.jsonl_paths]
|
||||
print("Building offset map for original passages...")
|
||||
for i, shard_path_str in enumerate(self.jsonl_paths):
|
||||
old_idx_path = Path(shard_path_str.replace(".jsonl", ".idx"))
|
||||
if not old_idx_path.exists(): continue
|
||||
with open(old_idx_path, 'rb') as f:
|
||||
shard_offsets = pickle.load(f)
|
||||
for pid, offset in shard_offsets.items():
|
||||
self.offsets[str(pid)] = (i, offset)
|
||||
print("Offset map for original passages is ready.")
|
||||
|
||||
def get_passage_by_id(self, pid: str) -> Dict[str, Any]:
|
||||
pid = str(pid)
|
||||
if pid not in self.offsets:
|
||||
raise ValueError(f"Passage ID {pid} not found in offsets")
|
||||
file_idx, offset = self.offsets[pid]
|
||||
fp = self.fps[file_idx]
|
||||
fp.seek(offset)
|
||||
return json.loads(fp.readline())
|
||||
# --- Helper Function to get Golden Passages ---
|
||||
def get_golden_texts(searcher: LeannSearcher, golden_ids: List[int]) -> set:
|
||||
"""
|
||||
Retrieves the text for golden passage IDs directly from the LeannSearcher's
|
||||
passage manager.
|
||||
"""
|
||||
golden_texts = set()
|
||||
for gid in golden_ids:
|
||||
try:
|
||||
# PassageManager uses string IDs
|
||||
passage_data = searcher.passage_manager.get_passage(str(gid))
|
||||
golden_texts.add(passage_data["text"])
|
||||
except KeyError:
|
||||
print(
|
||||
f"Warning: Golden passage ID '{gid}' not found in the index's passage data."
|
||||
)
|
||||
return golden_texts
|
||||
|
||||
def __del__(self):
|
||||
for fp in self.fps:
|
||||
fp.close()
|
||||
|
||||
def load_queries(file_path: Path) -> List[str]:
|
||||
queries = []
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
data = json.loads(line)
|
||||
queries.append(data['query'])
|
||||
queries.append(data["query"])
|
||||
return queries
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Run recall evaluation on a LEANN index.")
|
||||
parser.add_argument("index_path", type=str, help="Path to the LEANN index to evaluate.")
|
||||
parser.add_argument("--num-queries", type=int, default=10, help="Number of queries to evaluate.")
|
||||
parser.add_argument("--top-k", type=int, default=3, help="The 'k' value for recall@k.")
|
||||
parser.add_argument("--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW.")
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run recall evaluation on a LEANN index."
|
||||
)
|
||||
parser.add_argument(
|
||||
"index_path", type=str, help="Path to the LEANN index to evaluate."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-queries", type=int, default=10, help="Number of queries to evaluate."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top-k", type=int, default=3, help="The 'k' value for recall@k."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
print(f"--- Final, Correct Recall Evaluation (efSearch={args.ef_search}) ---")
|
||||
|
||||
# Detect dataset type from index path
|
||||
# --- Path Configuration ---
|
||||
# Assumes a project structure where the script is in 'examples/'
|
||||
# and data is in 'data/' at the project root.
|
||||
project_root = Path(__file__).resolve().parent.parent
|
||||
data_root = project_root / "data"
|
||||
|
||||
# Automatically download data if it doesn't exist
|
||||
download_data_if_needed(data_root)
|
||||
|
||||
# Detect dataset type from index path to select the correct ground truth
|
||||
index_path_str = str(args.index_path)
|
||||
if "rpj_wiki" in index_path_str:
|
||||
dataset_type = "rpj_wiki"
|
||||
elif "dpr" in index_path_str:
|
||||
dataset_type = "dpr"
|
||||
else:
|
||||
print("WARNING: Unknown dataset type, defaulting to rpj_wiki")
|
||||
dataset_type = "rpj_wiki"
|
||||
|
||||
# Fallback: try to infer from the index directory name
|
||||
dataset_type = Path(args.index_path).name
|
||||
print(
|
||||
f"WARNING: Could not detect dataset type from path, inferred '{dataset_type}'."
|
||||
)
|
||||
|
||||
queries_file = data_root / "queries" / "nq_open.jsonl"
|
||||
golden_results_file = (
|
||||
data_root / "ground_truth" / dataset_type / "flat_results_nq_k3.json"
|
||||
)
|
||||
|
||||
print(f"INFO: Detected dataset type: {dataset_type}")
|
||||
print(f"INFO: Using queries file: {queries_file}")
|
||||
print(f"INFO: Using ground truth file: {golden_results_file}")
|
||||
|
||||
try:
|
||||
searcher = LeannSearcher(args.index_path)
|
||||
queries = load_queries(NQ_QUERIES_FILE)
|
||||
|
||||
golden_results_file = GROUND_TRUTH_FILES[dataset_type]
|
||||
old_passages_glob = OLD_PASSAGES_GLOBS[dataset_type]
|
||||
|
||||
print(f"INFO: Using ground truth file: {golden_results_file}")
|
||||
print(f"INFO: Using old passages glob: {old_passages_glob}")
|
||||
|
||||
with open(golden_results_file, 'r') as f:
|
||||
queries = load_queries(queries_file)
|
||||
|
||||
with open(golden_results_file, "r") as f:
|
||||
golden_results_data = json.load(f)
|
||||
|
||||
old_passage_loader = OldPassageLoader(old_passages_glob)
|
||||
|
||||
num_eval_queries = min(args.num_queries, len(queries))
|
||||
queries = queries[:num_eval_queries]
|
||||
|
||||
|
||||
print(f"\nRunning evaluation on {num_eval_queries} queries...")
|
||||
recall_scores = []
|
||||
search_times = []
|
||||
|
||||
for i in range(num_eval_queries):
|
||||
start_time = time.time()
|
||||
new_results = searcher.search(queries[i], top_k=args.top_k, ef=args.ef_search)
|
||||
new_results = searcher.search(
|
||||
queries[i], top_k=args.top_k, ef=args.ef_search
|
||||
)
|
||||
search_times.append(time.time() - start_time)
|
||||
|
||||
# Correct Recall Calculation: Based on TEXT content
|
||||
new_texts = {result.text for result in new_results}
|
||||
golden_ids = golden_results_data["indices"][i][:args.top_k]
|
||||
golden_texts = {old_passage_loader.get_passage_by_id(str(gid))['text'] for gid in golden_ids}
|
||||
|
||||
# Get golden texts directly from the searcher's passage manager
|
||||
golden_ids = golden_results_data["indices"][i][: args.top_k]
|
||||
golden_texts = get_golden_texts(searcher, golden_ids)
|
||||
|
||||
overlap = len(new_texts & golden_texts)
|
||||
recall = overlap / len(golden_texts) if golden_texts else 0
|
||||
@@ -139,19 +160,21 @@ def main():
|
||||
print(f"Overlap: {overlap}")
|
||||
print(f"Recall: {recall}")
|
||||
print(f"Search Time: {search_times[-1]:.4f}s")
|
||||
print(f"--------------------------------")
|
||||
print("--------------------------------")
|
||||
|
||||
avg_recall = np.mean(recall_scores) if recall_scores else 0
|
||||
avg_time = np.mean(search_times) if search_times else 0
|
||||
|
||||
print(f"\n🎉 --- Evaluation Complete ---")
|
||||
print("\n🎉 --- Evaluation Complete ---")
|
||||
print(f"Avg. Recall@{args.top_k} (efSearch={args.ef_search}): {avg_recall:.4f}")
|
||||
print(f"Avg. Search Time: {avg_time:.4f}s")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n❌ An error occurred during evaluation: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
||||
@@ -5,7 +5,6 @@ Embedding server for leann-backend-diskann - Fixed ZMQ REQ-REP pattern
|
||||
|
||||
import pickle
|
||||
import argparse
|
||||
import threading
|
||||
import time
|
||||
import json
|
||||
from typing import Dict, Any, Optional, Union
|
||||
@@ -16,7 +15,6 @@ from contextlib import contextmanager
|
||||
import zmq
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
|
||||
RED = "\033[91m"
|
||||
RESET = "\033[0m"
|
||||
@@ -154,6 +152,7 @@ def create_embedding_server_thread(
|
||||
model_name="sentence-transformers/all-mpnet-base-v2",
|
||||
max_batch_size=128,
|
||||
passages_file: Optional[str] = None,
|
||||
use_mlx: bool = False,
|
||||
):
|
||||
"""
|
||||
在当前线程中创建并运行 embedding server
|
||||
@@ -172,36 +171,40 @@ def create_embedding_server_thread(
|
||||
print(f"{RED}Port {zmq_port} is already in use{RESET}")
|
||||
return
|
||||
|
||||
# 初始化模型
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
|
||||
import torch
|
||||
|
||||
# 选择设备
|
||||
mps_available = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
|
||||
cuda_available = torch.cuda.is_available()
|
||||
|
||||
if cuda_available:
|
||||
device = torch.device("cuda")
|
||||
print("INFO: Using CUDA device")
|
||||
elif mps_available:
|
||||
device = torch.device("mps")
|
||||
print("INFO: Using MPS device (Apple Silicon)")
|
||||
if use_mlx:
|
||||
from leann.api import compute_embeddings_mlx
|
||||
print("INFO: Using MLX for embeddings")
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
print("INFO: Using CPU device")
|
||||
|
||||
# 加载模型
|
||||
print(f"INFO: Loading model {model_name}")
|
||||
model = AutoModel.from_pretrained(model_name).to(device).eval()
|
||||
# 初始化模型
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
|
||||
import torch
|
||||
|
||||
# 优化模型
|
||||
if cuda_available or mps_available:
|
||||
try:
|
||||
model = model.half()
|
||||
model = torch.compile(model)
|
||||
print(f"INFO: Using FP16 precision with model: {model_name}")
|
||||
except Exception as e:
|
||||
print(f"WARNING: Model optimization failed: {e}")
|
||||
# 选择设备
|
||||
mps_available = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
|
||||
cuda_available = torch.cuda.is_available()
|
||||
|
||||
if cuda_available:
|
||||
device = torch.device("cuda")
|
||||
print("INFO: Using CUDA device")
|
||||
elif mps_available:
|
||||
device = torch.device("mps")
|
||||
print("INFO: Using MPS device (Apple Silicon)")
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
print("INFO: Using CPU device")
|
||||
|
||||
# 加载模型
|
||||
print(f"INFO: Loading model {model_name}")
|
||||
model = AutoModel.from_pretrained(model_name).to(device).eval()
|
||||
|
||||
# 优化模型
|
||||
if cuda_available or mps_available:
|
||||
try:
|
||||
model = model.half()
|
||||
model = torch.compile(model)
|
||||
print(f"INFO: Using FP16 precision with model: {model_name}")
|
||||
except Exception as e:
|
||||
print(f"WARNING: Model optimization failed: {e}")
|
||||
|
||||
# Load passages from file if provided
|
||||
if passages_file and os.path.exists(passages_file):
|
||||
@@ -233,7 +236,7 @@ def create_embedding_server_thread(
|
||||
self.start_time = 0
|
||||
self.end_time = 0
|
||||
|
||||
if cuda_available:
|
||||
if not use_mlx and torch.cuda.is_available():
|
||||
self.start_event = torch.cuda.Event(enable_timing=True)
|
||||
self.end_event = torch.cuda.Event(enable_timing=True)
|
||||
else:
|
||||
@@ -247,25 +250,25 @@ def create_embedding_server_thread(
|
||||
self.end()
|
||||
|
||||
def start(self):
|
||||
if cuda_available:
|
||||
if not use_mlx and torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
self.start_event.record()
|
||||
else:
|
||||
if self.device.type == "mps":
|
||||
if not use_mlx and self.device.type == "mps":
|
||||
torch.mps.synchronize()
|
||||
self.start_time = time.time()
|
||||
|
||||
def end(self):
|
||||
if cuda_available:
|
||||
if not use_mlx and torch.cuda.is_available():
|
||||
self.end_event.record()
|
||||
torch.cuda.synchronize()
|
||||
else:
|
||||
if self.device.type == "mps":
|
||||
if not use_mlx and self.device.type == "mps":
|
||||
torch.mps.synchronize()
|
||||
self.end_time = time.time()
|
||||
|
||||
def elapsed_time(self):
|
||||
if cuda_available:
|
||||
if not use_mlx and torch.cuda.is_available():
|
||||
return self.start_event.elapsed_time(self.end_event) / 1000.0
|
||||
else:
|
||||
return self.end_time - self.start_time
|
||||
@@ -273,7 +276,7 @@ def create_embedding_server_thread(
|
||||
def print_elapsed(self):
|
||||
print(f"Time taken for {self.name}: {self.elapsed_time():.6f} seconds")
|
||||
|
||||
def process_batch(texts_batch, ids_batch, missing_ids):
|
||||
def process_batch_pytorch(texts_batch, ids_batch, missing_ids):
|
||||
"""处理文本批次"""
|
||||
batch_size = len(texts_batch)
|
||||
print(f"INFO: Processing batch of size {batch_size}")
|
||||
@@ -351,7 +354,7 @@ def create_embedding_server_thread(
|
||||
print(f"INFO: Received ZMQ request from client {identity.hex()[:8]}, size {len(message)} bytes")
|
||||
|
||||
e2e_start = time.time()
|
||||
lookup_timer = DeviceTimer("text lookup", device)
|
||||
lookup_timer = DeviceTimer("text lookup")
|
||||
|
||||
# 解析请求
|
||||
req_proto = embedding_pb2.NodeEmbeddingRequest()
|
||||
@@ -397,18 +400,25 @@ def create_embedding_server_thread(
|
||||
chunk_texts = texts[i:end_idx]
|
||||
chunk_ids = node_ids[i:end_idx]
|
||||
|
||||
embeddings_chunk = process_batch(chunk_texts, chunk_ids, missing_ids)
|
||||
if use_mlx:
|
||||
embeddings_chunk = compute_embeddings_mlx(chunk_texts, model_name)
|
||||
else:
|
||||
embeddings_chunk = process_batch_pytorch(chunk_texts, chunk_ids, missing_ids)
|
||||
all_embeddings.append(embeddings_chunk)
|
||||
|
||||
if cuda_available:
|
||||
torch.cuda.empty_cache()
|
||||
elif device.type == "mps":
|
||||
torch.mps.empty_cache()
|
||||
if not use_mlx:
|
||||
if cuda_available:
|
||||
torch.cuda.empty_cache()
|
||||
elif device.type == "mps":
|
||||
torch.mps.empty_cache()
|
||||
|
||||
hidden = np.vstack(all_embeddings)
|
||||
print(f"INFO: Combined embeddings shape: {hidden.shape}")
|
||||
else:
|
||||
hidden = process_batch(texts, node_ids, missing_ids)
|
||||
if use_mlx:
|
||||
hidden = compute_embeddings_mlx(texts, model_name)
|
||||
else:
|
||||
hidden = process_batch_pytorch(texts, node_ids, missing_ids)
|
||||
|
||||
# 序列化响应
|
||||
ser_start = time.time()
|
||||
@@ -429,16 +439,16 @@ def create_embedding_server_thread(
|
||||
|
||||
print(f"INFO: Serialize time: {ser_end - ser_start:.6f} seconds")
|
||||
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
elif device.type == "mps":
|
||||
torch.mps.synchronize()
|
||||
if not use_mlx:
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
elif device.type == "mps":
|
||||
torch.mps.synchronize()
|
||||
e2e_end = time.time()
|
||||
print(f"INFO: ZMQ E2E time: {e2e_end - e2e_start:.6f} seconds")
|
||||
|
||||
except zmq.Again:
|
||||
print("INFO: ZMQ socket timeout, continuing to listen")
|
||||
# REP套接字不需要重新创建,只需要继续监听
|
||||
continue
|
||||
except Exception as e:
|
||||
print(f"ERROR: Error in ZMQ server: {e}")
|
||||
@@ -460,7 +470,6 @@ def create_embedding_server_thread(
|
||||
raise
|
||||
|
||||
|
||||
# 保持原有的 create_embedding_server 函数不变,只添加线程化版本
|
||||
def create_embedding_server(
|
||||
domain="demo",
|
||||
load_passages=True,
|
||||
@@ -473,12 +482,13 @@ def create_embedding_server(
|
||||
lazy_load_passages=False,
|
||||
model_name="sentence-transformers/all-mpnet-base-v2",
|
||||
passages_file: Optional[str] = None,
|
||||
use_mlx: bool = False,
|
||||
):
|
||||
"""
|
||||
原有的 create_embedding_server 函数保持不变
|
||||
这个是阻塞版本,用于直接运行
|
||||
"""
|
||||
create_embedding_server_thread(zmq_port, model_name, max_batch_size, passages_file)
|
||||
create_embedding_server_thread(zmq_port, model_name, max_batch_size, passages_file, use_mlx)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -495,6 +505,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--lazy-load-passages", action="store_true", default=True)
|
||||
parser.add_argument("--model-name", type=str, default="sentence-transformers/all-mpnet-base-v2",
|
||||
help="Embedding model name")
|
||||
parser.add_argument("--use-mlx", action="store_true", default=False, help="Use MLX backend for embeddings")
|
||||
args = parser.parse_args()
|
||||
|
||||
create_embedding_server(
|
||||
@@ -509,4 +520,5 @@ if __name__ == "__main__":
|
||||
lazy_load_passages=args.lazy_load_passages,
|
||||
model_name=args.model_name,
|
||||
passages_file=args.passages_file,
|
||||
)
|
||||
use_mlx=args.use_mlx,
|
||||
)
|
||||
|
||||
Submodule packages/leann-backend-diskann/third_party/DiskANN updated: c7a9d681cb...af2a26481e
@@ -1,3 +1,4 @@
|
||||
|
||||
"""
|
||||
This file contains the core API for the LEANN project, now definitively updated
|
||||
with the correct, original embedding logic from the user's reference code.
|
||||
@@ -17,8 +18,10 @@ from .interface import LeannBackendFactoryInterface
|
||||
|
||||
# --- The Correct, Verified Embedding Logic from old_code.py ---
|
||||
|
||||
def compute_embeddings(chunks: List[str], model_name: str) -> np.ndarray:
|
||||
"""Computes embeddings using sentence-transformers for consistent results."""
|
||||
def compute_embeddings(chunks: List[str], model_name: str, use_mlx: bool = False) -> np.ndarray:
|
||||
"""Computes embeddings using sentence-transformers or MLX for consistent results."""
|
||||
if use_mlx:
|
||||
return compute_embeddings_mlx(chunks, model_name)
|
||||
try:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
except ImportError as e:
|
||||
@@ -44,6 +47,45 @@ def compute_embeddings(chunks: List[str], model_name: str) -> np.ndarray:
|
||||
|
||||
return embeddings
|
||||
|
||||
def compute_embeddings_mlx(chunks: List[str], model_name: str) -> np.ndarray:
|
||||
"""Computes embeddings using an MLX model."""
|
||||
try:
|
||||
import mlx.core as mx
|
||||
from mlx_lm.utils import load
|
||||
except ImportError as e:
|
||||
raise RuntimeError(
|
||||
f"MLX or related libraries not available. Install with: pip install mlx mlx-lm"
|
||||
) from e
|
||||
|
||||
print(f"INFO: Computing embeddings for {len(chunks)} chunks using MLX model '{model_name}'...")
|
||||
|
||||
# Load model and tokenizer
|
||||
model, tokenizer = load(model_name)
|
||||
|
||||
# Process each chunk
|
||||
all_embeddings = []
|
||||
for chunk in chunks:
|
||||
# Tokenize
|
||||
token_ids = tokenizer.encode(chunk)
|
||||
|
||||
# Convert to MLX array and add batch dimension
|
||||
input_ids = mx.array([token_ids])
|
||||
|
||||
# Get embeddings
|
||||
embeddings = model(input_ids)
|
||||
|
||||
# Mean pooling (since we only have one sequence, just take the mean)
|
||||
pooled = embeddings.mean(axis=1) # Shape: (1, hidden_size)
|
||||
|
||||
# Convert individual embedding to numpy via list (to handle bfloat16)
|
||||
pooled_list = pooled[0].tolist() # Remove batch dimension and convert to list
|
||||
pooled_numpy = np.array(pooled_list, dtype=np.float32)
|
||||
all_embeddings.append(pooled_numpy)
|
||||
|
||||
# Stack numpy arrays
|
||||
return np.stack(all_embeddings)
|
||||
|
||||
|
||||
# --- Core API Classes (Restored and Unchanged) ---
|
||||
|
||||
@dataclass
|
||||
@@ -83,7 +125,7 @@ class PassageManager:
|
||||
raise KeyError(f"Passage ID not found: {passage_id}")
|
||||
|
||||
class LeannBuilder:
|
||||
def __init__(self, backend_name: str, embedding_model: str = "facebook/contriever-msmarco", dimensions: Optional[int] = None, **backend_kwargs):
|
||||
def __init__(self, backend_name: str, embedding_model: str = "facebook/contriever-msmarco", dimensions: Optional[int] = None, use_mlx: bool = False, **backend_kwargs):
|
||||
self.backend_name = backend_name
|
||||
backend_factory: LeannBackendFactoryInterface | None = BACKEND_REGISTRY.get(backend_name)
|
||||
if backend_factory is None:
|
||||
@@ -91,6 +133,7 @@ class LeannBuilder:
|
||||
self.backend_factory = backend_factory
|
||||
self.embedding_model = embedding_model
|
||||
self.dimensions = dimensions
|
||||
self.use_mlx = use_mlx
|
||||
self.backend_kwargs = backend_kwargs
|
||||
self.chunks: List[Dict[str, Any]] = []
|
||||
|
||||
@@ -102,7 +145,7 @@ class LeannBuilder:
|
||||
|
||||
def build_index(self, index_path: str):
|
||||
if not self.chunks: raise ValueError("No chunks added.")
|
||||
if self.dimensions is None: self.dimensions = len(compute_embeddings(["dummy"], self.embedding_model)[0])
|
||||
if self.dimensions is None: self.dimensions = len(compute_embeddings(["dummy"], self.embedding_model, self.use_mlx)[0])
|
||||
path = Path(index_path)
|
||||
index_dir = path.parent
|
||||
index_name = path.name
|
||||
@@ -118,7 +161,7 @@ class LeannBuilder:
|
||||
offset_map[chunk["id"]] = offset
|
||||
with open(offset_file, 'wb') as f: pickle.dump(offset_map, f)
|
||||
texts_to_embed = [c["text"] for c in self.chunks]
|
||||
embeddings = compute_embeddings(texts_to_embed, self.embedding_model)
|
||||
embeddings = compute_embeddings(texts_to_embed, self.embedding_model, self.use_mlx)
|
||||
string_ids = [chunk["id"] for chunk in self.chunks]
|
||||
current_backend_kwargs = {**self.backend_kwargs, 'dimensions': self.dimensions}
|
||||
builder_instance = self.backend_factory.builder(**current_backend_kwargs)
|
||||
@@ -126,7 +169,7 @@ class LeannBuilder:
|
||||
leann_meta_path = index_dir / f"{index_name}.meta.json"
|
||||
meta_data = {
|
||||
"version": "1.0", "backend_name": self.backend_name, "embedding_model": self.embedding_model,
|
||||
"dimensions": self.dimensions, "backend_kwargs": self.backend_kwargs,
|
||||
"dimensions": self.dimensions, "backend_kwargs": self.backend_kwargs, "use_mlx": self.use_mlx,
|
||||
"passage_sources": [{"type": "jsonl", "path": str(passages_file), "index_path": str(offset_file)}]
|
||||
}
|
||||
|
||||
@@ -145,6 +188,7 @@ class LeannSearcher:
|
||||
with open(meta_path_str, 'r', encoding='utf-8') as f: self.meta_data = json.load(f)
|
||||
backend_name = self.meta_data['backend_name']
|
||||
self.embedding_model = self.meta_data['embedding_model']
|
||||
self.use_mlx = self.meta_data.get('use_mlx', False)
|
||||
self.passage_manager = PassageManager(self.meta_data.get('passage_sources', []))
|
||||
backend_factory = BACKEND_REGISTRY.get(backend_name)
|
||||
if backend_factory is None: raise ValueError(f"Backend '{backend_name}' not found.")
|
||||
@@ -157,7 +201,7 @@ class LeannSearcher:
|
||||
print(f" Top_k: {top_k}")
|
||||
print(f" Search kwargs: {search_kwargs}")
|
||||
|
||||
query_embedding = compute_embeddings([query], self.embedding_model)
|
||||
query_embedding = compute_embeddings([query], self.embedding_model, self.use_mlx)
|
||||
print(f" Generated embedding shape: {query_embedding.shape}")
|
||||
print(f"🔍 DEBUG Query embedding first 10 values: {query_embedding[0][:10]}")
|
||||
print(f"🔍 DEBUG Query embedding norm: {np.linalg.norm(query_embedding[0])}")
|
||||
@@ -212,4 +256,4 @@ class LeannChat:
|
||||
print(f"Leann: {response}")
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print("\nGoodbye!")
|
||||
break
|
||||
break
|
||||
|
||||
@@ -41,6 +41,7 @@ dev = [
|
||||
"black>=23.0",
|
||||
"ruff>=0.1.0",
|
||||
"matplotlib",
|
||||
"huggingface-hub>=0.20.0",
|
||||
]
|
||||
|
||||
[tool.setuptools]
|
||||
|
||||
128
tests/sanity_checks/benchmark_embeddings.py
Normal file
128
tests/sanity_checks/benchmark_embeddings.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import time
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
from sentence_transformers import SentenceTransformer
|
||||
import mlx.core as mx
|
||||
from mlx_lm import load
|
||||
|
||||
# --- Configuration ---
|
||||
MODEL_NAME_TORCH = "Qwen/Qwen3-Embedding-0.6B"
|
||||
MODEL_NAME_MLX = "mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ"
|
||||
BATCH_SIZES = [1, 8, 16, 32, 64, 128]
|
||||
NUM_RUNS = 10 # Number of runs to average for each batch size
|
||||
WARMUP_RUNS = 2 # Number of warm-up runs
|
||||
|
||||
# --- Generate Dummy Data ---
|
||||
DUMMY_SENTENCES = ["This is a test sentence for benchmarking." * 5] * max(BATCH_SIZES)
|
||||
|
||||
# --- Benchmark Functions ---b
|
||||
|
||||
def benchmark_torch(model, sentences):
|
||||
start_time = time.time()
|
||||
model.encode(sentences, convert_to_numpy=True)
|
||||
end_time = time.time()
|
||||
return (end_time - start_time) * 1000 # Return time in ms
|
||||
|
||||
def benchmark_mlx(model, tokenizer, sentences):
|
||||
start_time = time.time()
|
||||
|
||||
# Tokenize sentences using MLX tokenizer
|
||||
tokens = []
|
||||
for sentence in sentences:
|
||||
token_ids = tokenizer.encode(sentence)
|
||||
tokens.append(token_ids)
|
||||
|
||||
# Pad sequences to the same length
|
||||
max_len = max(len(t) for t in tokens)
|
||||
input_ids = []
|
||||
attention_mask = []
|
||||
|
||||
for token_seq in tokens:
|
||||
# Pad sequence
|
||||
padded = token_seq + [tokenizer.eos_token_id] * (max_len - len(token_seq))
|
||||
input_ids.append(padded)
|
||||
# Create attention mask (1 for real tokens, 0 for padding)
|
||||
mask = [1] * len(token_seq) + [0] * (max_len - len(token_seq))
|
||||
attention_mask.append(mask)
|
||||
|
||||
# Convert to MLX arrays
|
||||
input_ids = mx.array(input_ids)
|
||||
attention_mask = mx.array(attention_mask)
|
||||
|
||||
# Get embeddings
|
||||
embeddings = model(input_ids)
|
||||
|
||||
# Mean pooling
|
||||
mask = mx.expand_dims(attention_mask, -1)
|
||||
sum_embeddings = (embeddings * mask).sum(axis=1)
|
||||
sum_mask = mask.sum(axis=1)
|
||||
_ = sum_embeddings / sum_mask
|
||||
|
||||
mx.eval() # Ensure computation is finished
|
||||
end_time = time.time()
|
||||
return (end_time - start_time) * 1000 # Return time in ms
|
||||
|
||||
# --- Main Execution ---
|
||||
def main():
|
||||
print("--- Initializing Models ---")
|
||||
# Load PyTorch model
|
||||
print(f"Loading PyTorch model: {MODEL_NAME_TORCH}")
|
||||
device = "mps" if torch.backends.mps.is_available() else "cpu"
|
||||
model_torch = SentenceTransformer(MODEL_NAME_TORCH, device=device)
|
||||
print(f"PyTorch model loaded on: {device}")
|
||||
|
||||
# Load MLX model
|
||||
print(f"Loading MLX model: {MODEL_NAME_MLX}")
|
||||
model_mlx, tokenizer_mlx = load(MODEL_NAME_MLX)
|
||||
print("MLX model loaded.")
|
||||
|
||||
# --- Warm-up ---
|
||||
print("\n--- Performing Warm-up Runs ---")
|
||||
for _ in range(WARMUP_RUNS):
|
||||
benchmark_torch(model_torch, DUMMY_SENTENCES[:1])
|
||||
benchmark_mlx(model_mlx, tokenizer_mlx, DUMMY_SENTENCES[:1])
|
||||
print("Warm-up complete.")
|
||||
|
||||
# --- Benchmarking ---
|
||||
print("\n--- Starting Benchmark ---")
|
||||
results_torch = []
|
||||
results_mlx = []
|
||||
|
||||
for batch_size in BATCH_SIZES:
|
||||
print(f"Benchmarking batch size: {batch_size}")
|
||||
sentences_batch = DUMMY_SENTENCES[:batch_size]
|
||||
|
||||
# Benchmark PyTorch
|
||||
torch_times = [benchmark_torch(model_torch, sentences_batch) for _ in range(NUM_RUNS)]
|
||||
results_torch.append(np.mean(torch_times))
|
||||
|
||||
# Benchmark MLX
|
||||
mlx_times = [benchmark_mlx(model_mlx, tokenizer_mlx, sentences_batch) for _ in range(NUM_RUNS)]
|
||||
results_mlx.append(np.mean(mlx_times))
|
||||
|
||||
print("\n--- Benchmark Results (Average time per batch in ms) ---")
|
||||
print(f"Batch Sizes: {BATCH_SIZES}")
|
||||
print(f"PyTorch (mps): {[f'{t:.2f}' for t in results_torch]}")
|
||||
print(f"MLX: {[f'{t:.2f}' for t in results_mlx]}")
|
||||
|
||||
# --- Plotting ---
|
||||
print("\n--- Generating Plot ---")
|
||||
plt.figure(figsize=(10, 6))
|
||||
plt.plot(BATCH_SIZES, results_torch, marker='o', linestyle='-', label=f'PyTorch ({device})')
|
||||
plt.plot(BATCH_SIZES, results_mlx, marker='s', linestyle='-', label='MLX')
|
||||
|
||||
plt.title(f'Embedding Performance: MLX vs PyTorch\nModel: {MODEL_NAME_TORCH}')
|
||||
plt.xlabel("Batch Size")
|
||||
plt.ylabel("Average Time per Batch (ms)")
|
||||
plt.xticks(BATCH_SIZES)
|
||||
plt.grid(True)
|
||||
plt.legend()
|
||||
|
||||
# Save the plot
|
||||
output_filename = "embedding_benchmark.png"
|
||||
plt.savefig(output_filename)
|
||||
print(f"Plot saved to {output_filename}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user