# Source: LEANN/research/paper_plot/disk_cache.py
# Commit 46f6cc100b "Initial commit" by yichuan520030910320, 2025-06-30 09:05:05 +00:00
# (151 lines, 5.3 KiB, Python)

import matplotlib
from matplotlib.axes import Axes
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.lines import Line2D
# Global figure style: inward y ticks, 1pt hatch lines, bold text and axis
# labels, and LaTeX text rendering. With usetex on, the preamble below selects
# Helvetica for text (helvet + \sfdefault) and a sans-serif math font (sfmath).
# plt.rcParams["font.family"] = "Helvetica"
plt.rcParams.update(
    {
        "ytick.direction": "in",
        "hatch.linewidth": 1,
        "font.weight": "bold",
        "axes.labelweight": "bold",
        "text.usetex": True,
        "font.family": "sans-serif",  # generic family; the concrete face comes from the preamble
        "text.latex.preamble": r"""
\usepackage{helvet} % Use Helvetica font for text
\usepackage{sfmath} % Use sans-serif font for math
\renewcommand{\familydefault}{\sfdefault} % Set sans-serif as default text font
\usepackage[T1]{fontenc} % Recommended for font encoding
""",
    }
)
# plt.rcParams['mathtext.fontset'] = 'dejavusans'
# Output directory for the generated figure (relative to the working directory).
SAVE_PTH = "./paper_plot/figures"
font_size = 16

# Benchmark results. Every per-dataset list has one value per entry of
# `cache_ratios`, in the same order.
datasets = ["NQ", "TriviaQA", "GPQA", "Hotpot"]
# x tick labels: cache size on disk, with the cached fraction on a second line.
# "\\%" is a literal backslash + percent. LaTeX (text.usetex) needs the escape
# so "%" is not read as a comment. Spelling it "\%" gives the same string but
# is an invalid Python escape sequence (SyntaxWarning on Python 3.12+).
cache_ratios = [
    "4.2G\n (0\\%)",
    "8.7G\n (2.5\\%)",
    "13.2G\n (5\\%)",
    "18.6G\n (8\\%)",
    "22.2G\n (10\\%)",
]
# Latency in seconds (plotted as bars against "Latency (s)").
latency_data = {
    "NQ": [4.616, 4.133, 3.826, 3.511, 3.323],
    "TriviaQA": [5.777, 4.979, 4.553, 4.141, 3.916],
    "GPQA": [1.733, 1.593, 1.468, 1.336, 1.259],
    "Hotpot": [15.515, 13.479, 12.383, 11.216, 10.606],
}
# Cache hit rate in percent (plotted as a line against "Cache Hit (%)").
cache_hit_counts = {
    "NQ": [0, 14.81, 23.36, 31.99, 36.73],
    "TriviaQA": [0, 18.55, 27.99, 37.06, 41.86],
    "GPQA": [0, 10.99, 20.31, 29.71, 35.01],
    "Hotpot": [0, 17.47, 26.91, 36.2, 41.06],
}
# Figure: a 2x2 grid with one panel per dataset. It is flattened (row-major)
# so that axes[i] is the panel for datasets[i].
fig, axes_grid = plt.subplots(2, 2, figsize=(7, 6))
axes = axes_grid.flatten()

# Bar geometry: one bar per cache size, centred on integer positions 0..N-1.
width = 0.7
x = np.arange(len(cache_ratios))

# Hatch pattern per cache size (the same "//" for all five sizes).
hatch_patterns = ["//"] * 5

# All cache-hit values across every dataset, flattened in dataset order. Their
# maximum (plus 13%) becomes the shared upper limit of every right-hand axis,
# so the four panels use the same cache-hit scale.
all_hit_counts = [hit for name in datasets for hit in cache_hit_counts[name]]
max_unified_hit = 1.13 * max(all_hit_counts)
# Per-panel multipliers applied to the tallest bar to set each latency (left)
# axis upper limit, in `datasets` order. They are loop-invariant, so they are
# defined once.
max_val_ratios = [1.35, 1.65, 1.45, 1.75]
# With text.usetex enabled, a bare "%" starts a LaTeX comment and swallows the
# rest of the label, so escape it as the tick labels and axis label already do.
percent_sign = r"\%" if plt.rcParams["text.usetex"] else "%"

for i, dataset in enumerate(datasets):
    ax = axes[i]
    latencies = latency_data[dataset]
    hit_counts = cache_hit_counts[dataset]

    # Latency bars on the left axis, each labelled with its value.
    for j, val in enumerate(latencies):
        container = ax.bar(
            x[j],
            val,
            width=width,
            color="white",
            edgecolor="black",
            linewidth=1.0,
            zorder=10,
        )
        ax.bar_label(
            container,
            [f"{val:.2f}"],
            fontsize=10,
            zorder=200,
            fontweight="bold",
        )

    ax.set_title(dataset, fontsize=font_size)
    ax.set_xticks(x)
    ax.set_xticklabels(cache_ratios, fontsize=12, rotation=0, ha='center', fontweight="bold")
    max_val = max(latencies) * max_val_ratios[i]
    ax.set_ylim(0, max_val)
    ax.tick_params(axis='y', labelsize=12)
    # The grid is flattened row-major, so even indices are the left column;
    # only those panels carry the latency axis label.
    if i % 2 == 0:
        ax.set_ylabel("Latency (s)", fontsize=font_size)
    ax.yaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter('%.1f'))

    # Cache-hit curve on a twin right-hand axis. Every panel shares the same
    # y-range (max_unified_hit).
    ax2: Axes = ax.twinx()
    ax2.plot(
        x,
        hit_counts,
        linestyle='--',
        marker='o',
        markersize=6,
        linewidth=1.5,
        color='k',
        markerfacecolor='none',
        zorder=20,
    )
    ax2.set_ylim(0, max_unified_hit)
    ax2.tick_params(axis='y', labelsize=12)
    # Odd indices are the right column; only those carry the cache-hit label.
    if i % 2 == 1:
        ax2.set_ylabel(r"Cache Hit (\%)", fontsize=font_size)
    # Annotate each nonzero hit rate 5 points above its marker; zeros are skipped.
    for j, val in enumerate(hit_counts):
        if val > 0:
            ax2.annotate(
                f"{val:.1f}{percent_sign}",
                (x[j], val),
                textcoords="offset points",
                xytext=(0, 5),
                ha='center',
                va='bottom',
                fontsize=10,
                fontweight='bold',
            )
# Figure-level legend shared by all four panels, built from proxy artists.
# The bar proxy matches the white, black-edged latency bars. The line proxy
# mirrors the cache-hit curves: black, dashed, 1.5pt, hollow circle markers.
bar_patch = mpatches.Patch(facecolor='white', edgecolor='black', label='Latency')
line_patch = Line2D(
    [0],
    [0],
    color='black',
    linestyle='--',
    linewidth=1.5,
    marker='o',
    markersize=6,
    markerfacecolor='none',
    label='Cache Hit Rate',
)
fig.legend(
    handles=[bar_patch, line_patch],
    loc='upper center',
    # Anchor in figure coordinates: horizontally centred, 99.5% of the way up.
    bbox_to_anchor=(0.5, 0.995),
    ncol=2,  # one column per entry -> a single row
    fontsize=font_size - 2,
)
# Optional shared x-axis label for the whole figure (currently disabled):
# fig.text(0.5, 0.02, "Disk Cache Size", ha='center', fontsize=font_size, fontweight='bold')

# Pack the subplots into rect=(left, bottom, right, top), given in figure
# coordinates. They span the full width and rise only to 0.93, keeping the
# strip from 0.93 to 1.0 clear for the figure-level legend.
subplot_area = (0, 0, 1, 0.93)
fig.tight_layout(rect=subplot_area)
# Save the figure as a PDF under SAVE_PTH, creating the directory if needed.
import os

# exist_ok=True makes this a no-op when the directory exists. It avoids the
# check-then-create race of os.path.exists() followed by os.makedirs().
os.makedirs(SAVE_PTH, exist_ok=True)
out_path = f"{SAVE_PTH}/disk_cache_latency.pdf"
plt.savefig(out_path, dpi=300)
print(f"Save to {out_path}")
# plt.show()  # Optional: display the plot interactively