Initial commit
This commit is contained in:
151
research/paper_plot/disk_cache.py
Normal file
151
research/paper_plot/disk_cache.py
Normal file
@@ -0,0 +1,151 @@
|
||||
import matplotlib
|
||||
from matplotlib.axes import Axes
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.patches as mpatches
|
||||
from matplotlib.lines import Line2D
|
||||
|
||||
# plt.rcParams["font.family"] = "Helvetica"
|
||||
plt.rcParams["ytick.direction"] = "in"
|
||||
plt.rcParams["hatch.linewidth"] = 1
|
||||
plt.rcParams["font.weight"] = "bold"
|
||||
plt.rcParams["axes.labelweight"] = "bold"
|
||||
plt.rcParams["text.usetex"] = True
|
||||
plt.rcParams["font.family"] = "sans-serif" # Use generic sans-serif family
|
||||
plt.rcParams['text.latex.preamble'] = r"""
|
||||
\usepackage{helvet} % Use Helvetica font for text
|
||||
\usepackage{sfmath} % Use sans-serif font for math
|
||||
\renewcommand{\familydefault}{\sfdefault} % Set sans-serif as default text font
|
||||
\usepackage[T1]{fontenc} % Recommended for font encoding
|
||||
"""
|
||||
# plt.rcParams['mathtext.fontset'] = 'dejavusans'
|
||||
SAVE_PTH = "./paper_plot/figures"
|
||||
font_size = 16
|
||||
|
||||
# New data in dictionary format
|
||||
datasets = ["NQ", "TriviaQA", "GPQA", "Hotpot"]
|
||||
|
||||
cache_ratios = ["4.2G\n (0\%)", "8.7G\n (2.5\%)", "13.2G\n (5\%)", "18.6G\n (8\%)", "22.2G\n (10\%)"]
|
||||
latency_data = {
|
||||
"NQ": [4.616, 4.133, 3.826, 3.511, 3.323],
|
||||
"TriviaQA": [5.777, 4.979, 4.553, 4.141, 3.916],
|
||||
"GPQA": [1.733, 1.593, 1.468, 1.336, 1.259],
|
||||
"Hotpot": [15.515, 13.479, 12.383, 11.216, 10.606],
|
||||
}
|
||||
cache_hit_counts = {
|
||||
"NQ": [0, 14.81, 23.36, 31.99, 36.73],
|
||||
"TriviaQA": [0, 18.55, 27.99, 37.06, 41.86],
|
||||
"GPQA": [0, 10.99, 20.31, 29.71, 35.01],
|
||||
"Hotpot": [0, 17.47, 26.91, 36.2, 41.06]
|
||||
}
|
||||
|
||||
# Create the figure with 4 subplots in a 2x2 grid
|
||||
fig, axes_grid = plt.subplots(2, 2, figsize=(7,6))
|
||||
axes = axes_grid.flatten() # Flatten the 2x2 grid to a 1D array
|
||||
|
||||
# Bar style settings
|
||||
width = 0.7
|
||||
x = np.arange(len(cache_ratios))
|
||||
|
||||
# Define hatch patterns for different cache ratios
|
||||
hatch_patterns = ['//', '//', '//', '//', '//']
|
||||
|
||||
# Find max cache hit value across all datasets for unified y-axis
|
||||
all_hit_counts = []
|
||||
for dataset in datasets:
|
||||
all_hit_counts.extend(cache_hit_counts[dataset])
|
||||
max_unified_hit = max(all_hit_counts) * 1.13
|
||||
|
||||
for i, dataset in enumerate(datasets):
|
||||
latencies = latency_data[dataset]
|
||||
hit_counts = cache_hit_counts[dataset]
|
||||
|
||||
for j, val in enumerate(latencies):
|
||||
container = axes[i].bar(
|
||||
x[j],
|
||||
val,
|
||||
width=width,
|
||||
color="white",
|
||||
edgecolor="black",
|
||||
linewidth=1.0,
|
||||
zorder=10,
|
||||
)
|
||||
axes[i].bar_label(
|
||||
container,
|
||||
[f"{val:.2f}"],
|
||||
fontsize=10,
|
||||
zorder=200,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
axes[i].set_title(dataset, fontsize=font_size)
|
||||
axes[i].set_xticks(x)
|
||||
axes[i].set_xticklabels(cache_ratios, fontsize=12, rotation=0, ha='center', fontweight="bold")
|
||||
|
||||
max_val_ratios = [1.35, 1.65, 1.45, 1.75]
|
||||
max_val = max(latencies) * max_val_ratios[i]
|
||||
axes[i].set_ylim(0, max_val)
|
||||
axes[i].tick_params(axis='y', labelsize=12)
|
||||
|
||||
if i % 2 == 0:
|
||||
axes[i].set_ylabel("Latency (s)", fontsize=font_size)
|
||||
axes[i].yaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter('%.1f'))
|
||||
|
||||
ax2: Axes = axes[i].twinx()
|
||||
ax2.plot(x, hit_counts,
|
||||
linestyle='--',
|
||||
marker='o',
|
||||
markersize=6,
|
||||
linewidth=1.5,
|
||||
color='k',
|
||||
markerfacecolor='none',
|
||||
zorder=20)
|
||||
|
||||
ax2.set_ylim(0, max_unified_hit)
|
||||
ax2.tick_params(axis='y', labelsize=12)
|
||||
if i % 2 == 1:
|
||||
ax2.set_ylabel(r"Cache Hit (\%)", fontsize=font_size)
|
||||
|
||||
for j, val in enumerate(hit_counts):
|
||||
if val > 0:
|
||||
ax2.annotate(f"{val:.1f}%",
|
||||
(x[j], val),
|
||||
textcoords="offset points",
|
||||
xytext=(0, 5),
|
||||
ha='center',
|
||||
va='bottom',
|
||||
fontsize=10,
|
||||
fontweight='bold')
|
||||
|
||||
# Create legend for both plots
|
||||
bar_patch = mpatches.Patch(facecolor='white', edgecolor='black', label='Latency')
|
||||
line_patch = Line2D([0], [0], color='black', linestyle='--', label='Cache Hit Rate')
|
||||
|
||||
# --- MODIFICATION FOR LEGEND AT THE TOP ---
|
||||
fig.legend(handles=[bar_patch, line_patch],
|
||||
loc='upper center', # Position the legend at the upper center
|
||||
bbox_to_anchor=(0.5, 0.995), # Anchor point (0.5 means horizontal center of figure,
|
||||
# 0.97 means 97% from the bottom, so near the top)
|
||||
ncol=3,
|
||||
fontsize=font_size-2)
|
||||
# --- END OF MODIFICATION ---
|
||||
|
||||
# Set common x-axis label - you might want to add this back if needed
|
||||
# fig.text(0.5, 0.02, "Disk Cache Size", ha='center', fontsize=font_size, fontweight='bold') # Adjusted y for potential bottom label
|
||||
|
||||
# --- MODIFICATION FOR TIGHT LAYOUT ---
|
||||
# Adjust rect to make space for the legend at the top.
|
||||
# (left, bottom, right, top_for_subplots)
|
||||
# We want subplots to occupy space from y=0 up to y=0.93 (or similar)
|
||||
# leaving the top portion (0.93 to 1.0) for the legend.
|
||||
plt.tight_layout(rect=(0, 0, 1, 0.93)) # Ensure subplots are below the legend
|
||||
# --- END OF MODIFICATION ---
|
||||
|
||||
# Create directory if it doesn't exist (optional, good practice)
|
||||
import os
|
||||
if not os.path.exists(SAVE_PTH):
|
||||
os.makedirs(SAVE_PTH)
|
||||
|
||||
plt.savefig(f"{SAVE_PTH}/disk_cache_latency.pdf", dpi=300) # Changed filename slightly for testing
|
||||
print(f"Save to {SAVE_PTH}/disk_cache_latency.pdf")
|
||||
# plt.show() # Optional: to display the plot
|
||||
Reference in New Issue
Block a user