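# Plotting script: renders two PDF bar-chart figures (Exact Match / F1 scores
# and latency) comparing gte-small and contriever-msmarco on NQ and TriviaQA.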
import os

import matplotlib.pyplot as plt
import numpy as np
# import matplotlib.ticker as mticker # Not actively used

# Directory the generated PDF figures are written to.
FIGURE_PATH = "paper_plot/figures"

try:
    # Only the directory creation is guarded; nothing else here should be
    # able to trigger the fallback.
    os.makedirs(FIGURE_PATH, exist_ok=True)
except OSError as e:
    # Best effort: report the failure and fall back to the working directory
    # instead of aborting figure generation.
    print(f"Create {FIGURE_PATH} failed: {e}. Images will be saved in the current working directory.")
    FIGURE_PATH = "."
else:
    print(f"Images will be saved to: {os.path.abspath(FIGURE_PATH)}")

# Global matplotlib style applied to every figure created by this script.
plt.rcParams.update({
    "font.family": "Helvetica",
    "ytick.direction": "in",
    "hatch.linewidth": 2,
    "font.weight": "bold",
    "axes.labelweight": "bold",
    # Text is rendered through LaTeX, so label strings must be LaTeX-safe
    # (e.g. a literal percent sign has to be written as \%).
    "text.usetex": True,
})

# Legend labels, one per compared method. Every two-element list below is
# ordered the same way: [first method ("small"), second method ("large")].
method_labels = ["gte-small (33M)", "contriever-msmarco (110M)"]
dataset_names = ["NQ", "TriviaQA"]
metrics_plot1 = ["Exact Match", "F1"]

# Results for the first method. Scores are stored as fractions and are
# multiplied by 100 when plotted; times are plotted on an axis labelled
# "Latency (s)".
small_nq_f1 = 0.2621040899
small_tq_f1 = 0.4698198059
small_nq_em_score = 0.1845
small_tq_em_score = 0.4015
small_nq_time = 1.137
small_tq_time = 1.173

# Results for the second method.
large_nq_f1 = 0.2841386117
large_tq_f1 = 0.4548340289
large_nq_em_score = 0.206
large_tq_em_score = 0.382
large_nq_time = 2.632
large_tq_time = 2.684

# dataset name -> metric name -> [small, large] score
data_scores_plot1 = {
    "NQ": {
        "Exact Match": [small_nq_em_score, large_nq_em_score],
        "F1": [small_nq_f1, large_nq_f1],
    },
    "TriviaQA": {
        "Exact Match": [small_tq_em_score, large_tq_em_score],
        "F1": [small_tq_f1, large_tq_f1],
    },
}

# dataset name -> [small, large] latency
latency_data_plot2 = {
    "NQ": [small_nq_time, large_nq_time],
    "TriviaQA": [small_tq_time, large_tq_time],
}

# Per-method bar styling, indexed like method_labels.
edgecolors = ["dimgrey", "tomato"]
hatches = ["/////", "\\\\\\\\\\"]

# Distance between the centres of neighbouring bars inside one group, and the
# drawn width of each bar (both in x-axis data units). The separation exceeds
# the width, which leaves a visible gap between the bars.
bar_center_separation_in_group = 0.42
bar_visual_width = 0.28

# Figure sizes passed to plt.subplots(figsize=...). Note that the two figures
# do NOT share a width: figure 2 is narrower than figure 1.
figsize_plot1 = (4, 2.5)
figsize_plot2 = (2.5, 2.5)

# x-range applied to every subplot of figure 1 (two bar groups centred at
# 0.5 and 1.5).
plot1_xlim_per_subplot = (0.0, 2.0)

# Keyword arguments both figures pass to fig.subplots_adjust().
common_subplots_adjust_params = dict(wspace=0.30, top=0.80, bottom=0.22, left=0.09, right=0.96)


def create_plot1_em_f1():
    """Render the Exact Match / F1 comparison figure and save it as a PDF.

    One subplot is drawn per entry of ``dataset_names``. Inside each subplot
    there is one bar group per metric in ``metrics_plot1`` and, inside each
    group, one bar per method in ``method_labels``. Scores are read from
    ``data_scores_plot1`` and shown as percentages, with the value printed
    above each bar. A shared legend is placed above the subplots.

    The figure is written to ``plot1_em_f1.pdf`` inside ``FIGURE_PATH`` and
    then closed. Returns ``None``.
    """
    fig, axs = plt.subplots(1, 2, figsize=figsize_plot1)
    # NOTE(review): fig.tight_layout() further down recomputes the subplot
    # margins, so these values are probably overridden -- confirm before
    # tuning them.
    fig.subplots_adjust(**common_subplots_adjust_params)

    num_methods = len(method_labels)
    metric_group_centers = np.array([0.5, 1.5])

    # Horizontal shift of each method's bar relative to its group centre, so
    # that the bars of a group are spread symmetrically around that centre.
    method_offsets = [
        (j - (num_methods - 1) / 2.0) * bar_center_separation_in_group
        for j in range(num_methods)
    ]

    for i, dataset_name in enumerate(dataset_names):
        ax = axs[i]
        all_subplot_scores_percent = []

        for metric_idx, metric_name in enumerate(metrics_plot1):
            metric_center_pos = metric_group_centers[metric_idx]
            current_scores_percent = [
                val * 100 for val in data_scores_plot1[dataset_name][metric_name]
            ]
            all_subplot_scores_percent.extend(current_scores_percent)

            for j, method_label in enumerate(method_labels):
                bar_center_pos = metric_center_pos + method_offsets[j]
                score = current_scores_percent[j]

                ax.bar(
                    bar_center_pos, score, width=bar_visual_width, color="white",
                    edgecolor=edgecolors[j], hatch=hatches[j], linewidth=1.5,
                    # Label each method only once so the shared legend has no
                    # duplicate entries.
                    label=method_label if i == 0 and metric_idx == 0 else None,
                )
                ax.text(
                    bar_center_pos, score + 0.8, f"{score:.1f}",
                    ha='center', va='bottom', fontsize=8, fontweight='bold',
                )

        ax.set_xticks(metric_group_centers)
        ax.set_xticklabels(metrics_plot1, fontsize=9, fontweight='bold')
        ax.set_title(dataset_name, fontsize=12, fontweight='bold')
        ax.set_xlim(plot1_xlim_per_subplot)
        ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: ' {:.0f}'.format(y)))

        if i == 0:
            # Raw string: "\%" is the LaTeX escape for a percent sign and is
            # not a valid Python escape sequence.
            ax.set_ylabel(r"Accuracy (\%)", fontsize=12, fontweight="bold")

        # Leave headroom above the tallest bar for its value label.
        max_val = max(all_subplot_scores_percent) if all_subplot_scores_percent else 0
        ax.set_ylim(0, max_val * 1.22 if max_val > 0 else 10)
        ax.tick_params(axis='y', labelsize=12)

        for spine in ax.spines.values():
            spine.set_visible(True)
            spine.set_linewidth(1.0)
            spine.set_edgecolor("black")

    handles, labels = axs[0].get_legend_handles_labels()
    fig.legend(
        handles, labels, loc="upper center", bbox_to_anchor=(0.5, 0.97), ncol=len(method_labels),
        edgecolor="black", facecolor="white", framealpha=1, shadow=False, fancybox=False,
        handlelength=1.5, handletextpad=0.4, columnspacing=0.8,
        prop={"weight": "bold", "size": 9},
    )

    save_path = os.path.join(FIGURE_PATH, "plot1_em_f1.pdf")
    # Keep the subplots out of the top 12% of the figure, where the legend sits.
    fig.tight_layout(rect=(0.0, 0.0, 1.0, 0.88))
    # Save through the figure object rather than pyplot's "current figure".
    fig.savefig(save_path, bbox_inches='tight', pad_inches=0.03)
    plt.close(fig)
    print(f"Figure 1 (Exact Match & F1) has been saved to: {save_path}")


def create_plot2_latency():
    """Render the latency comparison figure and save it as a PDF.

    One subplot is drawn per entry of ``dataset_names``; each holds a single
    bar group with one bar per method in ``method_labels``. Latencies are read
    from ``latency_data_plot2`` and printed above each bar. A shared legend is
    placed above the subplots.

    The figure is written to ``plot2_latency.pdf`` inside ``FIGURE_PATH`` and
    then closed. Returns ``None``.
    """
    fig, axs = plt.subplots(1, 2, figsize=figsize_plot2)
    # NOTE(review): fig.tight_layout() further down recomputes the subplot
    # margins, so these values are probably overridden -- confirm before
    # tuning them.
    fig.subplots_adjust(**common_subplots_adjust_params)

    num_methods = len(method_labels)
    method_group_center_in_subplot = 0.5

    # Horizontal shift of each method's bar relative to the group centre.
    method_offsets = [
        (j - (num_methods - 1) / 2.0) * bar_center_separation_in_group
        for j in range(num_methods)
    ]

    # Centre the single bar group in an x-span of width 1.0, i.e. (0.0, 1.0).
    plot2_xlim_calculated = (
        method_group_center_in_subplot - 0.5,
        method_group_center_in_subplot + 0.5,
    )

    for i, dataset_name in enumerate(dataset_names):
        ax = axs[i]
        current_latencies = latency_data_plot2[dataset_name]

        for j, method_label in enumerate(method_labels):
            bar_center_pos = method_group_center_in_subplot + method_offsets[j]
            latency = current_latencies[j]

            ax.bar(
                bar_center_pos, latency, width=bar_visual_width, color="white",
                edgecolor=edgecolors[j], hatch=hatches[j], linewidth=1.5,
                # Label each method only once so the shared legend has no
                # duplicate entries.
                label=method_label if i == 0 else None,
            )
            ax.text(
                bar_center_pos, latency + 0.05, f"{latency:.2f}",
                ha='center', va='bottom', fontsize=10, fontweight='bold',
            )

        # Set once per axes; it does not depend on the individual bars.
        ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: ' {:.0f}'.format(y)))

        # The x tick and its label are drawn in white -- presumably to hide
        # them while still reserving the vertical space a visible label uses.
        ax.set_xticks([0.5])
        ax.set_xticklabels(["Latency"], color="white", fontsize=12)
        ax.tick_params(axis='x', colors="white")
        ax.set_title(dataset_name, fontsize=13, fontweight='bold')
        ax.set_xlim(plot2_xlim_calculated)

        if i == 0:
            ax.set_ylabel("Latency (s)", fontsize=12, fontweight="bold")

        # Leave headroom above the tallest bar for its value label.
        max_latency_in_subplot = max(current_latencies) if current_latencies else 0
        ax.set_ylim(0, max_latency_in_subplot * 1.22 if max_latency_in_subplot > 0 else 1)
        ax.tick_params(axis='y', labelsize=12)

        for spine in ax.spines.values():
            spine.set_visible(True)
            spine.set_linewidth(1.0)
            spine.set_edgecolor("black")

    handles, labels = axs[0].get_legend_handles_labels()
    fig.legend(
        handles, labels, loc="upper center", bbox_to_anchor=(0.5, 0.97), ncol=num_methods,
        edgecolor="black", facecolor="white", framealpha=1, shadow=False, fancybox=False,
        handlelength=1.5, handletextpad=0.4, columnspacing=0.8,
        prop={"weight": "bold", "size": 9},
    )

    save_path = os.path.join(FIGURE_PATH, "plot2_latency.pdf")
    # Keep the subplots out of the top 12% of the figure, where the legend sits.
    fig.tight_layout(rect=(0.0, 0.0, 1.0, 0.88))
    # Save through the figure object rather than pyplot's "current figure".
    fig.savefig(save_path, bbox_inches='tight', pad_inches=0.03)
    plt.close(fig)
    print(f"Figure 2 (Latency) has been saved to: {save_path}")


# Script entry point: generate both figures when the file is run directly.
if __name__ == "__main__":
    print("Start generating figures...")
    if plt.rcParams["text.usetex"]:
        # With usetex enabled matplotlib relies on an external LaTeX install,
        # so point the user at the likely culprit up front.
        print("Info: LaTeX rendering is enabled. Ensure LaTeX is installed and configured if issues arise, or set plt.rcParams['text.usetex'] to False.")

    create_plot1_em_f1()
    create_plot2_latency()
    print("All figures have been generated.")