163 lines
5.7 KiB
Python
163 lines
5.7 KiB
Python
import csv
|
|
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
import csv
|
|
|
|
# Global matplotlib styling shared by every figure this script produces:
# bold Helvetica text rendered through LaTeX, inward-pointing y ticks and a
# hatch line width of 1. ``rcParams.update`` validates each key exactly like
# individual item assignment, applied in the order listed.
plt.rcParams.update(
    {
        "font.family": "Helvetica",
        "ytick.direction": "in",
        "hatch.linewidth": 1,
        "font.weight": "bold",
        "axes.labelweight": "bold",
        "text.usetex": True,
    }
)

# Directory the generated PDF figures are written to.
SAVE_PTH = "./paper_plot/figures"

# Base font size for dataset labels, the y-axis label/ticks and the legend.
font_size = 16
|
|
|
|
# Generation latency in seconds, taken from the note in the original script:
#   index 0 -> LLaMA 1B  (0.085 s)
#   index 1 -> LLaMA 3B  (0.217 s)
#   index 2 -> LLaMA 7B  (0.472 s)
#   index 3 -> 0, i.e. no generation time is added
llm_inference_time = [
    0.085,
    0.217,
    0.472,
    0,
]

# Selects the entry of llm_inference_time that is added to every plotted bar
# value; 3 picks the trailing 0, so bars show the CSV latencies unchanged.
USE_LLM_INDEX = 3
|
|
|
|
file_path = "./paper_plot/data/main_latency.csv"

# Load every row of the latency CSV (header row included) as lists of strings.
# newline="" is what the csv module expects so it can handle line endings itself.
with open(file_path, mode="r", newline="") as file:
    data = list(csv.reader(file))

# Print the raw data, one comma-joined line per CSV row, for a quick visual check.
for record in data:
    print(",".join(record))
|
|
|
|
|
|
|
|
|
|
# One output figure is produced per entry (used in the PDF file name).
models = ["A10", "MAC"]
# One subplot per dataset, labelled with these names.
datasets = ["NQ", "TriviaQA", "GPQA", "HotpotQA"]


def _to_number(cell):
    """Return *cell* converted to float when it parses as a number, else unchanged.

    The previous ``cell.isdigit()`` test only recognised unsigned integers, so
    decimal latencies such as "0.472" (and negative values) silently remained
    strings. ``float`` accepts all of these; non-numeric cells are kept as-is.
    """
    try:
        return float(cell)
    except ValueError:
        return cell


# Drop the header row and convert every numeric cell to float.
data = [[_to_number(cell) for cell in row] for row in data[1:]]
|
|
import os  # used to create SAVE_PTH before saving figures into it

# ---- Layout and styling shared by every figure (loop-invariant) ----
total_width, n = 6, 6  # n = number of algorithms drawn in each subplot
group = 1  # a single group of bars per subplot
width = total_width * 0.9 / n  # horizontal slot allotted to one algorithm
x = np.arange(group) * n
exit_idx_x = x + (total_width - width) / n  # left anchor of the bar group (1-element array)
print("exit_idx_x", exit_idx_x)

# Outline colour per algorithm j (entries beyond n are unused spares).
edgecolors = ["dimgrey", "#63B8B6", "tomato", "slategray", "mediumpurple", "green", "red", "blue", "yellow", "silver"]
# Hatch per variant m; each algorithm is drawn as a triplet of bars.
local_hatches = ["////", "\\\\", "xxxx"]
# Legend label per algorithm j (column j + 3 of each data row).
labels = [
    "HNSW",
    "IVF",
    "DiskANN",
    "IVF-Disk",
    "IVF-Recompute",
    "Our",
]

yticks = [0.01, 0.1, 1, 10, 100, 1000, 10000]  # log-scale ticks
val_limit = 15000  # bars taller than this are clipped to it
# LaTeX tick labels derived from yticks ("$10^{-2}$" ... "$10^{4}$") so the two
# lists cannot drift apart.
ytick_labels = [rf"$10^{{{int(round(np.log10(t)))}}}$" for t in yticks]

os.makedirs(SAVE_PTH, exist_ok=True)

for k, model in enumerate(models):
    fig, axes = plt.subplots(1, 4)
    fig.set_size_inches(20, 3)
    fig.subplots_adjust(wspace=0, hspace=0)

    for i, (ax, dataset) in enumerate(zip(axes, datasets)):
        ax.set_yscale("log")
        ax.set_yticks(yticks)
        ax.set_ylim(0.01, val_limit)  # log scale needs a strictly positive lower bound
        ax.tick_params(axis="y", labelsize=10)
        ax.set_xticks([])
        ax.set_xlabel(dataset, fontsize=font_size)
        ax.grid(axis="y", linestyle="--")
        ax.set_xlim(exit_idx_x[0] - 0.15 * width - 0.2, exit_idx_x[0] + (n - 0.25) * width + 0.2)

        for j in range(n):
            # The three rows for (dataset i, model k) sit at i * 6 + k * 3 + m,
            # m = 0..2 (presumably the three LLM sizes -- TODO confirm against
            # main_latency.csv); algorithm j's latency is in column j + 3.
            rows = [data[i * 6 + k * 3 + m] for m in range(3)]

            # A value of 0 in all three rows marks the algorithm as OOM: draw a
            # cross near the bottom of the axis instead of bars.
            if all(float(row[j + 3]) == 0 for row in rows):
                pos = exit_idx_x + j * width + width * 0.3  # centre of the slot
                ax.scatter(
                    pos,
                    0.02,
                    marker="x",
                    color=edgecolors[j],
                    s=width * 150,
                    linewidth=4,
                    label=labels[j] if j < len(labels) else "",
                    zorder=20,
                )
                continue

            for m, row in enumerate(rows):
                num = float(row[j + 3]) + llm_inference_time[USE_LLM_INDEX]
                pos = exit_idx_x + j * width + width * 0.3 * m
                print(f"j: {j}, m: {m}, pos: {pos}")
                # Keep values inside the visible log range [0.01, val_limit].
                plot_value = max(0.01, num) if num < val_limit else val_limit
                ax.bar(
                    pos,
                    plot_value,
                    width=width * 0.3,
                    color="white",
                    edgecolor=edgecolors[j],
                    hatch=local_hatches[m],
                    linewidth=1.0,
                    label=labels[j] if m == 0 else "",  # one legend entry per algorithm
                    zorder=10,
                )
                # TODO: add value labels (ax.bar_label) on top of each bar.

    # Only the first model's figure carries the shared legend.
    if k == 0:
        legend = axes[0].legend(
            bbox_to_anchor=(3.25, 1.02),
            ncol=7,
            loc="lower right",
            labelspacing=0.2,
            edgecolor="black",
            facecolor="white",
            framealpha=1,
            shadow=False,
            handlelength=2,
            handletextpad=0.5,
            columnspacing=0.5,
            prop={"weight": "bold", "size": font_size},
        )
        legend.set_zorder(100)

    axes[0].set_ylabel("Runtime (log scale)", fontsize=font_size, fontweight="bold")
    axes[0].set_yticklabels(ytick_labels, fontsize=font_size)
    # All panels use the same y range, so only the leftmost shows tick labels.
    for ax in axes[1:]:
        ax.set_yticklabels([])

    out_path = f"{SAVE_PTH}/speed_{model}_revised.pdf"
    fig.savefig(out_path, bbox_inches="tight", dpi=300)
    print(out_path)
    # Release the figure; otherwise one open figure accumulates per model.
    plt.close(fig)