#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Motto: Were It to Benefit My Country, I Would Lay Down My Life!
# \file: /speed_ablation.py
# \brief: Plot the speedup ablation as a grouped bar chart and save it as a PDF.
# Author: raphael hao

# %%
import os

import numpy as np
import pandas as pd

# %%
# from script.settings import DATA_PATH, FIGURE_PATH

# Read the full latency ablation table from CSV.
latency_data = pd.read_csv("./paper_plot/data/latency_ablation.csv")

# Keep only the rows whose "Metric" column equals "SpeedUp"; the plotting
# code below works exclusively on this subset.
speedup_data = latency_data.loc[latency_data["Metric"] == "SpeedUp"]

# %%
from matplotlib import pyplot as plt

# Global figure styling, applied in one pass: bold Helvetica text rendered
# through LaTeX, inward-pointing y ticks and a hatch line width of 1.5.
plt.rcParams.update(
    {
        "font.family": "Helvetica",
        "ytick.direction": "in",
        "hatch.linewidth": 1.5,
        "font.weight": "bold",
        "axes.labelweight": "bold",
        "text.usetex": True,
    }
)

# %%
# Grouped bar chart of the speedup ablation: one group of bars per dataset,
# one bar per optimisation stage, and each bar annotated with its value.

fig, ax = plt.subplots()
fig.set_size_inches(5, 1.5)
plt.subplots_adjust(wspace=0, hspace=0)

# CSV column that holds each stage's speedup, in the same order as `labels`
# (and as the per-stage entries of `edgecolors` / `hatches`).
stage_columns = [
    "Original",
    "original + two_level",
    "original + two_level + batch",
]
labels = ["Base", "Base + Two-level", "Base + Two-level + Batch"]
edgecolors = ["dimgrey", "#63B8B6", "tomato", "silver", "slategray"]
hatches = ["/////", "xxxxx", "\\\\\\\\\\"]

# Dataset names in order of first appearance; they define the bar groups and
# the x tick labels.
datasets = speedup_data["Dataset"].unique()

total_width, n = 3, len(stage_columns)
group = len(datasets)
width = total_width * 0.9 / n
x = np.arange(group) * n
# x position of the first bar in every group.
exit_idx_x = x + (total_width - width) / n

for i, dataset in enumerate(datasets):
    dataset_data = speedup_data[speedup_data["Dataset"] == dataset]

    for j, column in enumerate(stage_columns):
        # The first matching row supplies the value for this dataset/stage.
        value = dataset_data[column].iloc[0]
        bar_x = exit_idx_x[i] + j * width

        # Numeric annotation slightly above the top of the bar.
        ax.text(
            bar_x,
            value + 0.05,
            f"{value:.2f}",
            ha="center",
            va="bottom",
            fontsize=10,
            fontweight="bold",
            rotation=0,
            zorder=20,
        )

        ax.bar(
            bar_x,
            value,
            width=width * 0.8,
            color="white",
            edgecolor=edgecolors[j],
            hatch=hatches[j],
            linewidth=1.5,
            # Label only the first group so each stage appears once in the legend.
            label=labels[j] if i == 0 else None,
            zorder=10,
        )

yticks = np.arange(0.5, 2.2, 0.5)
ax.set_ylim([0.5, 2.3])
ax.set_yticks(yticks)
ax.set_yticklabels(yticks, fontsize=12)
# Centre each tick under the middle bar of its group.
ax.set_xticks(exit_idx_x + width)
ax.set_xticklabels(datasets, fontsize=10)
# ax.set_xlabel("Different Datasets", fontsize=14)
ax.legend(
    bbox_to_anchor=(-0.03, 1.4),
    ncol=3,
    loc="upper left",
    labelspacing=0.1,
    edgecolor="black",
    facecolor="white",
    framealpha=1,
    shadow=False,
    fancybox=False,
    handlelength=0.8,
    handletextpad=0.6,
    columnspacing=0.8,
    prop={"weight": "bold", "size": 10},
).set_zorder(100)
ax.set_ylabel("Speedup", fontsize=11)

# Single source of truth for the output location, shared by savefig and the
# final message so the two cannot drift apart.
output_path = "./paper_plot/figures/latency_speedup.pdf"
# savefig does not create missing directories, so make sure the target exists.
os.makedirs(os.path.dirname(output_path), exist_ok=True)
plt.savefig(output_path, bbox_inches="tight", dpi=300)

# %%

print(f"Save to {output_path}")