# File: LEANN/research/paper_plot/small_emb.py
# Last commit: yichuan520030910320 46f6cc100b Initial commit
# Date: 2025-06-30 09:05:05 +00:00
# Size: 227 lines, 9.3 KiB, Python

import matplotlib.pyplot as plt
import numpy as np
# import matplotlib.ticker as mticker # Not actively used
import os
# Directory (relative to the current working directory) that receives the PDFs.
FIGURE_PATH = "paper_plot/figures"
try:
    os.makedirs(FIGURE_PATH, exist_ok=True)
    print(f"Images will be saved to: {os.path.abspath(FIGURE_PATH)}")
except OSError as e:
    # Best effort: if the directory cannot be created, fall back to the
    # current working directory instead of aborting the script.
    print(f"Create {FIGURE_PATH} failed: {e}. Images will be saved in the current working directory.")
    FIGURE_PATH = "."

# Global matplotlib style shared by both figures. "text.usetex" routes all
# text through LaTeX, so label strings must be valid LaTeX.
plt.rcParams.update(
    {
        "font.family": "Helvetica",
        "ytick.direction": "in",
        "hatch.linewidth": 2,
        "font.weight": "bold",
        "axes.labelweight": "bold",
        "text.usetex": True,
    }
)
# Ordering used by every per-method list below:
# index 0 is gte-small, index 1 is contriever-msmarco.
method_labels = ["gte-small (33M)", "contriever-msmarco (110M)"]
dataset_names = ["NQ", "TriviaQA"]
metrics_plot1 = ["Exact Match", "F1"]

# Raw measurements, one row per (model, dataset): exact match, F1, latency.
# Scores are stored as fractions and scaled by 100 when plotted; the times
# are plotted on the "Latency (s)" axis.
small_nq_em_score, small_nq_f1, small_nq_time = 0.1845, 0.2621040899, 1.137
small_tq_em_score, small_tq_f1, small_tq_time = 0.4015, 0.4698198059, 1.173
large_nq_em_score, large_nq_f1, large_nq_time = 0.206, 0.2841386117, 2.632
large_tq_em_score, large_tq_f1, large_tq_time = 0.382, 0.4548340289, 2.684

# dataset -> metric -> [small model score, large model score]
data_scores_plot1 = {
    "NQ": {
        "Exact Match": [small_nq_em_score, large_nq_em_score],
        "F1": [small_nq_f1, large_nq_f1],
    },
    "TriviaQA": {
        "Exact Match": [small_tq_em_score, large_tq_em_score],
        "F1": [small_tq_f1, large_tq_f1],
    },
}

# dataset -> [small model latency, large model latency]
latency_data_plot2 = {
    "NQ": [small_nq_time, large_nq_time],
    "TriviaQA": [small_tq_time, large_tq_time],
}
# Per-method bar styling, indexed in the same order as method_labels.
edgecolors = ["dimgrey", "tomato"]
hatches = ["/////", "\\\\\\\\\\"]

# Distance between the centers of neighbouring bars inside one group.
bar_center_separation_in_group = 0.42
# Drawn width of a single bar; smaller than the separation, so bars in a
# group do not touch.
bar_visual_width = 0.28

# Figure sizes passed to plt.subplots(figsize=...).
figsize_plot1 = (4, 2.5)
figsize_plot2 = (2.5, 2.5)

# x-range applied to every subplot of plot 1; kept at module level so other
# plotting functions can read it as well.
plot1_xlim_per_subplot = (0.0, 2.0)

# Keyword arguments forwarded to fig.subplots_adjust() by both figures.
common_subplots_adjust_params = {
    "wspace": 0.30,
    "top": 0.80,
    "bottom": 0.22,
    "left": 0.09,
    "right": 0.96,
}
def create_plot1_em_f1():
    """Render the Exact Match / F1 comparison figure and save it as a PDF.

    One subplot is drawn per entry of ``dataset_names``. Each subplot holds
    one bar group per metric in ``metrics_plot1`` and, inside each group,
    one bar per method in ``method_labels``. Scores are read from
    ``data_scores_plot1``, multiplied by 100, and printed above each bar
    with one decimal place. The figure is written to ``plot1_em_f1.pdf``
    inside ``FIGURE_PATH`` and then closed.

    Returns:
        None.
    """
    fig, axs = plt.subplots(1, 2, figsize=figsize_plot1)
    fig.subplots_adjust(**common_subplots_adjust_params)

    num_methods = len(method_labels)
    # x positions of the two metric groups inside plot1_xlim_per_subplot.
    metric_group_centers = np.array([0.5, 1.5])

    for i, dataset_name in enumerate(dataset_names):
        ax = axs[i]
        # Every plotted value of this subplot, collected to size the y axis.
        all_subplot_scores_percent = []

        for metric_idx, metric_name in enumerate(metrics_plot1):
            metric_center_pos = metric_group_centers[metric_idx]
            current_scores_raw = data_scores_plot1[dataset_name][metric_name]
            current_scores_percent = [val * 100 for val in current_scores_raw]
            all_subplot_scores_percent.extend(current_scores_percent)

            for j, method_label in enumerate(method_labels):
                # Spread the bars symmetrically around the group center.
                offset = (j - (num_methods - 1) / 2.0) * bar_center_separation_in_group
                bar_center_pos = metric_center_pos + offset
                ax.bar(
                    bar_center_pos, current_scores_percent[j], width=bar_visual_width, color="white",
                    edgecolor=edgecolors[j], hatch=hatches[j], linewidth=1.5,
                    # Label each method only once so the shared legend has no duplicates.
                    label=method_label if i == 0 and metric_idx == 0 else None
                )
                ax.text(
                    bar_center_pos, current_scores_percent[j] + 0.8, f"{current_scores_percent[j]:.1f}",
                    ha='center', va='bottom', fontsize=8, fontweight='bold'
                )

        ax.set_xticks(metric_group_centers)
        ax.set_xticklabels(metrics_plot1, fontsize=9, fontweight='bold')
        ax.set_title(dataset_name, fontsize=12, fontweight='bold')
        ax.set_xlim(plot1_xlim_per_subplot)
        ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: ' {:.0f}'.format(y)))
        if i == 0:
            # Raw string: the backslash is a LaTeX escape for "%" (text.usetex
            # is enabled), not a Python escape sequence.
            ax.set_ylabel(r"Accuracy (\%)", fontsize=12, fontweight="bold")

        # Leave 22% headroom above the tallest bar for its value annotation.
        max_val = max(all_subplot_scores_percent) if all_subplot_scores_percent else 0
        ax.set_ylim(0, max_val * 1.22 if max_val > 0 else 10)
        ax.tick_params(axis='y', labelsize=12)

        for spine in ax.spines.values():
            spine.set_visible(True)
            spine.set_linewidth(1.0)
            spine.set_edgecolor("black")

    # One figure-level legend built from the labelled bars of the first subplot.
    handles, labels = axs[0].get_legend_handles_labels()
    fig.legend(
        handles, labels, loc="upper center", bbox_to_anchor=(0.5, 0.97), ncol=len(method_labels),
        edgecolor="black", facecolor="white", framealpha=1, shadow=False, fancybox=False,
        handlelength=1.5, handletextpad=0.4, columnspacing=0.8,
        prop={"weight": "bold", "size": 9}
    )
    # Optional in-figure caption, currently disabled:
    # fig.text(0.5, 0.06, "(a) EM \& F1", ha='center', va='center', fontweight='bold', fontsize=11)

    save_path = os.path.join(FIGURE_PATH, "plot1_em_f1.pdf")
    # Keep the subplots out of the top 12% of the figure, where the legend sits.
    fig.tight_layout(rect=(0.0, 0.0, 1.0, 0.88))
    # Save through the figure object so the output never depends on which
    # figure pyplot currently considers active.
    fig.savefig(save_path, bbox_inches='tight', pad_inches=0.03)
    plt.close(fig)
    print(f"Figure 1 (Exact Match & F1) has been saved to: {save_path}")
def create_plot2_latency():
    """Render the latency comparison figure and save it as a PDF.

    One subplot is drawn per entry of ``dataset_names``. Each subplot holds
    a single bar group with one bar per method in ``method_labels``, using
    the values from ``latency_data_plot2``; each bar is annotated with its
    value to two decimal places. The figure is written to
    ``plot2_latency.pdf`` inside ``FIGURE_PATH`` and then closed.

    Returns:
        None.
    """
    fig, axs = plt.subplots(1, 2, figsize=figsize_plot2)
    fig.subplots_adjust(**common_subplots_adjust_params)

    num_methods = len(method_labels)
    method_group_center_in_subplot = 0.5
    # Center the single bar group in an x-range of width 1.0, i.e. (0.0, 1.0).
    plot2_xlim_calculated = (method_group_center_in_subplot - 0.5, method_group_center_in_subplot + 0.5)

    for i, dataset_name in enumerate(dataset_names):
        ax = axs[i]
        current_latencies = latency_data_plot2[dataset_name]

        for j, method_label in enumerate(method_labels):
            # Spread the bars symmetrically around the group center.
            offset = (j - (num_methods - 1) / 2.0) * bar_center_separation_in_group
            bar_center_pos = method_group_center_in_subplot + offset
            ax.bar(
                bar_center_pos, current_latencies[j], width=bar_visual_width, color="white",
                edgecolor=edgecolors[j], hatch=hatches[j], linewidth=1.5,
                # Label each method only once so the shared legend has no duplicates.
                label=method_label if i == 0 else None
            )
            ax.text(
                bar_center_pos, current_latencies[j] + 0.05, f"{current_latencies[j]:.2f}",
                ha='center', va='bottom', fontsize=10, fontweight='bold'
            )

        ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: ' {:.0f}'.format(y)))
        ax.set_xticks([method_group_center_in_subplot])
        # The x tick and its label are drawn in white, so they still take up
        # layout space but are invisible on a white background.
        # NOTE(review): presumably this keeps the bottom margin in line with
        # plot 1 - confirm.
        ax.set_xticklabels(["Latency"], color="white", fontsize=12)
        ax.tick_params(axis='x', colors="white")
        ax.set_title(dataset_name, fontsize=13, fontweight='bold')
        ax.set_xlim(plot2_xlim_calculated)
        if i == 0:
            ax.set_ylabel("Latency (s)", fontsize=12, fontweight="bold")

        # Leave 22% headroom above the tallest bar for its value annotation.
        max_latency_in_subplot = max(current_latencies) if current_latencies else 0
        ax.set_ylim(0, max_latency_in_subplot * 1.22 if max_latency_in_subplot > 0 else 1)
        ax.tick_params(axis='y', labelsize=12)

        for spine in ax.spines.values():
            spine.set_visible(True)
            spine.set_linewidth(1.0)
            spine.set_edgecolor("black")

    # One figure-level legend built from the labelled bars of the first subplot.
    handles, labels = axs[0].get_legend_handles_labels()
    fig.legend(
        handles, labels, loc="upper center", bbox_to_anchor=(0.5, 0.97), ncol=num_methods,
        edgecolor="black", facecolor="white", framealpha=1, shadow=False, fancybox=False,
        handlelength=1.5, handletextpad=0.4, columnspacing=0.8,
        prop={"weight": "bold", "size": 9}
    )
    # Optional in-figure caption, currently disabled:
    # fig.text(0.5, 0.06, "(b) Latency", ha='center', va='center', fontweight='bold', fontsize=11)

    save_path = os.path.join(FIGURE_PATH, "plot2_latency.pdf")
    # Keep the subplots out of the top 12% of the figure, where the legend sits.
    fig.tight_layout(rect=(0.0, 0.0, 1.0, 0.88))
    # Save through the figure object so the output never depends on which
    # figure pyplot currently considers active.
    fig.savefig(save_path, bbox_inches='tight', pad_inches=0.03)
    plt.close(fig)
    print(f"Figure 2 (Latency) has been saved to: {save_path}")
if __name__ == "__main__":
    # Script entry point: announce the run, then build both figures in order.
    print("Start generating figures...")
    if plt.rcParams["text.usetex"]:
        # LaTeX text rendering needs a working LaTeX installation; tell the
        # user how to opt out if rendering fails.
        print("Info: LaTeX rendering is enabled. Ensure LaTeX is installed and configured if issues arise, or set plt.rcParams['text.usetex'] to False.")
    for build_figure in (create_plot1_em_f1, create_plot2_latency):
        build_figure()
    print("All figures have been generated.")