# LEANN/research/paper_plot/main_exp.py
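"""Plot search latency vs. proportional index size for the main experiment.

Reads per-method latencies at several recall targets (main_latency.csv) and
per-method RAM/storage footprints (ram_storage.csv), then draws one scatter
subplot per dataset for each hardware platform (A10 GPU, Mac) and saves each
figure as a PDF under paper_plot/figures/.
"""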
import os

import matplotlib.lines as mlines
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.patches import FancyArrowPatch
sns.set_theme(style="ticks", font_scale=1.2)
plt.rcParams['axes.grid'] = True
plt.rcParams['axes.grid.which'] = 'major'
plt.rcParams['grid.linestyle'] = '--'
plt.rcParams['grid.color'] = 'gray'
plt.rcParams['grid.alpha'] = 0.3
plt.rcParams['xtick.minor.visible'] = False
plt.rcParams['ytick.minor.visible'] = False
plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["text.usetex"] = True
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
# Reference LLM generation times (Llama 1B / 3B / 7B): 0.085s / 0.217s / 0.472s.
# These hard-coded values are superseded by the per-(dataset, hardware)
# LLM_Gen_Time_* columns read from the CSV below:
# llm_inference_time = [0.085, 0.217, 0.472, 0]
# llm_inference_time_for_mac = [0.316, 0.717, 1.468, 0]
def parse_latency_data(csv_path):
    """Parse main_latency.csv into (dataset, hardware, method) -> [(recall, latency)]."""
    df = pd.read_csv(csv_path)
    latency_data = {}
    llm_gen_times = {}  # (dataset, hardware) -> LLM generation time in seconds
    cols_to_skip = ['Dataset', 'Hardware', 'Recall_target',
                    'LLM_Gen_Time_1B', 'LLM_Gen_Time_3B', 'LLM_Gen_Time_7B']
    for _, row in df.iterrows():
        dataset = row['Dataset']
        hardware = row['Hardware']
        recall_target_str = row['Recall_target'].replace('%', '')
        try:
            recall_target = float(recall_target_str)
        except ValueError:
            print(f"Warning: Could not parse recall_target '{row['Recall_target']}'. Skipping row.")
            continue
        if (dataset, hardware) not in llm_gen_times:  # Read once per (dataset, hardware)
            llm_time_val = pd.to_numeric(row.get('LLM_Gen_Time_1B'), errors='coerce')
            if not pd.isna(llm_time_val):
                llm_gen_times[(dataset, hardware)] = llm_time_val
            else:
                llm_gen_times[(dataset, hardware)] = np.nan  # NaN if unparsable/missing
        for col in df.columns:
            if col not in cols_to_skip:
                method_name = col
                key = (dataset, hardware, method_name)
                if key not in latency_data:
                    latency_data[key] = []
                try:
                    latency_value = float(row[method_name])
                    latency_data[key].append((recall_target, latency_value))
                except ValueError:
                    # Latency may be non-numeric (e.g., 'N/A' or empty)
                    print(f"Warning: Could not parse latency for {method_name} at "
                          f"{dataset}/{hardware}/Recall {recall_target} "
                          f"('{row[method_name]}'). Skipping this point.")
                    latency_data[key].append((recall_target, np.nan))
    # Sort each series by recall for consistent plotting
    for key in latency_data:
        latency_data[key].sort(key=lambda x: x[0])
    return latency_data, llm_gen_times
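# Expected shape of main_latency.csv (inferred from the parser above; the
# method columns shown here are illustrative):
#   Dataset,Hardware,Recall_target,HNSW,IVF,...,LLM_Gen_Time_1B,LLM_Gen_Time_3B,LLM_Gen_Time_7B
#   NQ,A10,85%,0.12,0.34,...,0.085,0.217,0.472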
def parse_storage_data(csv_path):
    """Parse ram_storage.csv into method -> {'RAM': ..., 'Storage': ...}."""
    df = pd.read_csv(csv_path)
    storage_data = {}
    # The first column labels the metric type, with one row for 'RAM' and one
    # for 'Storage'; every remaining column is a method.
    ram_row = df[df.iloc[:, 0] == 'RAM'].iloc[0]
    storage_row = df[df.iloc[:, 0] == 'Storage'].iloc[0]
    methods = df.columns[1:]  # First column is the metric-type label
    for method in methods:
        storage_data[method] = {
            'RAM': pd.to_numeric(ram_row[method], errors='coerce'),
            'Storage': pd.to_numeric(storage_row[method], errors='coerce')
        }
    return storage_data
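# Expected shape of ram_storage.csv (inferred from the parser above; method
# names are illustrative, values in GB):
#   MetricType,HNSW,IVF,DiskANN,IVF-Disk,IVF-Recompute,Our,BM25
#   RAM,...
#   Storage,...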
# Load data
latency_csv_path = 'paper_plot/data/main_latency.csv'
storage_csv_path = 'paper_plot/data/ram_storage.csv'
latency_data, llm_generation_times = parse_latency_data(latency_csv_path)
storage_info = parse_storage_data(storage_csv_path)
# --- Determine unique Datasets and Hardware combinations to plot for ---
unique_dataset_hardware_configs = sorted(list(set((d, h) for d, h, m in latency_data.keys())))
if not unique_dataset_hardware_configs:
    print("Error: No (Dataset, Hardware) combinations found in latency data. Check CSV paths and content.")
    exit()
# --- Define constants for plotting ---
all_method_names = sorted(list(set(m for d,h,m in latency_data.keys())))
if not all_method_names:
    # Fallback: latency_data is empty, but storage_info may still carry method names
    all_method_names = sorted(list(storage_info.keys()))
if not all_method_names:
    print("Error: No method names found in data. Cannot proceed with plotting.")
    exit()
method_markers = {
    'HNSW': 'o',
    'IVF': 'X',
    'DiskANN': 's',
    'IVF-Disk': 'P',
    'IVF-Recompute': '^',
    'Our': '*',
    'BM25': 'v',
    # Add more if necessary, or make it dynamic
}
method_display_names = {
    'IVF-Recompute': 'IVF-Recompute (EdgeRAG)',
    # All other methods keep their original names
}
# Ensure every method has a marker
default_markers = ['^', 'v', '<', '>', 'H', 'h', '+', 'x', '|', '_']
next_default_marker = 0
for mn in all_method_names:
    if mn not in method_markers:
        print(f"Assigning default marker to method: {mn}")
        method_markers[mn] = default_markers[next_default_marker % len(default_markers)]
        next_default_marker += 1
recall_levels_present = sorted(list(set(r for key in latency_data for r, l in latency_data[key])))
# Fixed colors for the common recall levels; any other level falls back to the
# viridis palette below
base_recall_colors = {
    85.0: "#1f77b4",  # Blue
    90.0: "#ff7f0e",  # Orange
    95.0: "#2ca02c",  # Green
}
recall_colors = {}
color_palette = sns.color_palette("viridis", n_colors=len(recall_levels_present))
for idx, r_level in enumerate(recall_levels_present):
    recall_colors[r_level] = base_recall_colors.get(r_level, color_palette[idx % len(color_palette)])
# --- Determine global x (latency) and y (storage) limits for consistent axes ---
all_latency_values = []
all_storage_values = []
raw_data_size = 76 # Raw data size in GB
for ds_hw_key in unique_dataset_hardware_configs:
    current_ds, current_hw = ds_hw_key
    for method_name in all_method_names:
        # Disk footprint for this method
        disk_storage = storage_info.get(method_name, {}).get('Storage', np.nan)
        if not np.isnan(disk_storage):
            all_storage_values.append(disk_storage)
        # Latencies for this method under (current_ds, current_hw)
        latency_key = (current_ds, current_hw, method_name)
        if latency_key in latency_data:
            for recall, latency in latency_data[latency_key]:
                if not np.isnan(latency):
                    all_latency_values.append(latency)
# Add padding to limits
min_lat = min(all_latency_values) if all_latency_values else 0.001
max_lat = max(all_latency_values) if all_latency_values else 1
min_store = min(all_storage_values) if all_storage_values else 0
max_store = max(all_storage_values) if all_storage_values else 1
# Convert storage values to proportion of raw data
min_store_proportion = min_store / raw_data_size if all_storage_values else 0
max_store_proportion = max_store / raw_data_size if all_storage_values else 0.1
# Padding for the log-scale latency axis; floor the minimum at 10^-1 (0.1 s)
lat_log_min = -1
lat_log_max = np.log10(max_lat) if max_lat > 0 else 3  # default to 1000 s
lat_padding = (lat_log_max - lat_log_min) * 0.05
global_xlim = [10**(lat_log_min - lat_padding), 10**(lat_log_max + lat_padding)]
if global_xlim[0] <= 0:
    global_xlim[0] = 0.1
# Padding for the linear-scale storage-proportion axis
store_padding = (max_store_proportion - min_store_proportion) * 0.05
global_ylim = [max(0, min_store_proportion - store_padding), max_store_proportion + store_padding]
if global_ylim[0] >= global_ylim[1]:  # Avoid an inverted or zero range
    global_ylim[1] = global_ylim[0] + 0.1
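# Worked example of the y-axis unit: a hypothetical 19 GB index over the 76 GB
# raw corpus plots at 19 / 76 = 0.25, i.e. a "Proportional Size" of 25%.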
# Order the datasets explicitly so the subplots appear in a fixed sequence
all_datasets_unsorted = list(set(ds for ds, _ in unique_dataset_hardware_configs))
desired_order = ['NQ', 'TriviaQA', 'GPQA', 'HotpotQA']
all_datasets = [ds for ds in desired_order if ds in all_datasets_unsorted]
# Append any datasets present in the data but missing from desired_order
all_datasets.extend([ds for ds in all_datasets_unsorted if ds not in desired_order])
a10_configs = [(ds, 'A10') for ds in all_datasets if (ds, 'A10') in unique_dataset_hardware_configs]
mac_configs = [(ds, 'MAC') for ds in all_datasets if (ds, 'MAC') in unique_dataset_hardware_configs]
# Create two figures - one for A10 and one for MAC
hardware_configs = [a10_configs, mac_configs]
hardware_names = ['A10', 'MAC']
for fig_idx, configs_for_this_figure in enumerate(hardware_configs):
    if not configs_for_this_figure:
        continue
    num_cols_this_figure = len(configs_for_this_figure)
    # One row, num_cols_this_figure columns
    fig, axs = plt.subplots(1, num_cols_this_figure, figsize=(7 * num_cols_this_figure, 6),
                            sharex=True, sharey=True, squeeze=False)
    # fig.suptitle(f"Latency vs. Storage ({hardware_names[fig_idx]})", fontsize=18, y=0.98)
    for subplot_idx, (current_ds, current_hw) in enumerate(configs_for_this_figure):
        ax = axs[0, subplot_idx]  # First (and only) row
        ax.set_title(f"{current_ds}", fontsize=25)  # Hardware is implied by the figure
        for method_name in all_method_names:
            marker = method_markers.get(method_name, '+')
            disk_storage = storage_info.get(method_name, {}).get('Storage', np.nan)
            latency_points_key = (current_ds, current_hw, method_name)
            if latency_points_key in latency_data:
                points_for_method = latency_data[latency_points_key]
                print(f"{current_ds}/{current_hw}/{method_name} points: {points_for_method}")
                for recall, latency in points_for_method:
                    # Skip invalid latencies (the x-axis is log scale); zero
                    # storage is fine since the y-axis is linear
                    if np.isnan(latency) or np.isnan(disk_storage) or latency <= 0:
                        continue
                    # Add the LLM generation time from the CSV to the search latency
                    current_llm_add_time = llm_generation_times.get((current_ds, current_hw))
                    if current_llm_add_time is not None and not np.isnan(current_llm_add_time):
                        latency = latency + current_llm_add_time
                    else:
                        raise ValueError(f"No LLM generation time found for {current_ds} on {current_hw}")
                    # BM25 is only valid at the 85% recall target (other points are 0)
                    if method_name == 'BM25':
                        if recall != 85.0:
                            continue
                        color = 'grey'
                    else:
                        # Color by target recall
                        color = recall_colors.get(recall, 'grey')
                    # Convert storage to a proportion of the raw data size
                    disk_storage_proportion = disk_storage / raw_data_size
                    size = 80
                    x_offset = -50
                    if current_ds == 'GPQA':
                        x_offset = -32
                    # Nudge IVF-Recompute points upward so they stay visible
                    if method_name == 'IVF-Recompute':
                        disk_storage_proportion += 0.07
                        size = 80
                    if method_name == 'DiskANN':
                        size = 50
                    if method_name == 'Our':
                        size = 140
                        disk_storage_proportion += 0.05
                        # Label our method's 95%-recall point
                        if recall == 95:
                            ax.annotate('Ours',
                                        (latency, disk_storage_proportion),
                                        xytext=(x_offset, 25),
                                        textcoords='offset points',
                                        fontsize=20,
                                        color='red',
                                        weight='bold',
                                        bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="red", alpha=0.7))
                    if method_name == 'BM25':
                        size = 70
                    # Scale every marker up uniformly for readability
                    size *= 5
                    ax.scatter(latency, disk_storage_proportion, marker=marker, color=color,
                               s=size, alpha=0.85, edgecolors='black', linewidths=0.7)
        ax.set_xscale("log")
        ax.set_yscale("linear")  # Linear y (storage proportion); log x (latency)
        # Major x ticks at powers of 10 spanning the data range
        min_power = -1
        max_power = 4
        log_ticks = [10**i for i in range(min_power, max_power + 1)]
        ax.set_xticks(log_ticks)
        # Bold LaTeX tick labels in 10^n form
        log_tick_labels = [fr'$\mathbf{{10^{{{i}}}}}$' for i in range(min_power, max_power + 1)]
        ax.set_xticklabels(log_tick_labels, fontsize=24)
        # Apply the global limits once; the axes are shared
        if subplot_idx == 0:
            ax.set_xlim(global_xlim)
            ax.set_ylim(global_ylim)
        ax.grid(True, which="major", linestyle="--", linewidth=0.6, alpha=0.7)
        ax.grid(False, which="minor")  # No minor grid lines
        # Hide the tick marks but keep the labels, with extra x-axis padding
        ax.tick_params(axis='both', which='both', length=0, labelsize=24)
        ax.tick_params(axis='x', which='both', pad=10)
        if subplot_idx == 0:  # Y-label only on the leftmost subplot
            ax.set_ylabel("Proportional Size", fontsize=24)
        ax.set_xlabel("Latency (s)", fontsize=25)  # X-label on every subplot
        # Show the y-axis as percentages of the raw data size
        ax.set_yticks([1, 2, 3])
        ax.set_yticklabels([r'100\%', r'200\%', r'300\%'])
        # Draw a wide "Better" arrow pointing toward the lower-left
        # (lower latency, smaller index)
        arrow = FancyArrowPatch(
            (0.8, 0.8),   # Start point (top-right, axes coordinates)
            (0.65, 0.6),  # End point (toward bottom-left)
            transform=ax.transAxes,
            arrowstyle='simple,head_width=40,head_length=35,tail_width=20',
            facecolor='white',
            edgecolor='black',
            linewidth=3,
            zorder=5
        )
        ax.add_patch(arrow)
        # Place the "Better" label near the arrow's midpoint
        mid_x = (0.8 + 0.65) / 2 + 0.002 + 0.01
        mid_y = (0.8 + 0.6) / 2 + 0.01
        ax.text(mid_x, mid_y, 'Better',
                transform=ax.transAxes,
                ha='center',
                va='center',
                fontsize=16,
                fontweight='bold',
                rotation=40,  # Roughly match the arrow direction
                zorder=6)  # Keep the text on top of the arrow
    # Build legend handles (once per figure)
    method_legend_handles = []
    for method, marker_style in method_markers.items():
        if method in all_method_names:
            print(f"Building legend entry for: {method}")
            # Legend markers are black; map to a display name when one exists
            if method in method_display_names:
                method = method_display_names[method]
            method_legend_handles.append(mlines.Line2D([], [], color='black', marker=marker_style,
                                                       linestyle='None', markersize=10, label=method))
    recall_legend_handles = []
    sorted_recall_levels = sorted(recall_colors.keys())
    for r_level in sorted_recall_levels:
        recall_legend_handles.append(mlines.Line2D([], [], color=recall_colors[r_level], marker='o',
                                                   linestyle='None', markersize=20,
                                                   label=fr"Target Recall={r_level:.0f}\%"))
    # Split the legend into two rows: methods on top, recall targets below.
    # Legends are only drawn on the first (A10) figure.
    if fig_idx == 0:
        # Order the methods so that 'Our' comes last
        other_methods = [m for m in all_method_names if m != 'Our']
        ordered_methods = other_methods + (['Our'] if 'Our' in all_method_names else [])
        # Rebuild the method legend handles in that order
        method_legend_handles = []
        for method in ordered_methods:
            if method in method_markers:
                marker_style = method_markers[method]
                # Map to the display name when one is defined
                display_name = method_display_names.get(method, method)
                color = 'black'
                marker_size = 22
                if method == 'Our':
                    marker_size = 27
                elif 'IVF-Recompute' in method or 'EdgeRAG' in method:
                    marker_size = 17
                elif 'DiskANN' in method:
                    marker_size = 19
                elif 'BM25' in method:
                    marker_size = 20
                method_legend_handles.append(mlines.Line2D([], [], color=color, marker=marker_style,
                                                           linestyle='None', markersize=marker_size,
                                                           label=display_name))
        # Recall-target legend (second row), placed below the method row
        recall_legend = fig.legend(handles=recall_legend_handles,
                                   loc='upper center', bbox_to_anchor=(0.5, 1.05),
                                   ncol=len(recall_legend_handles), fontsize=28)
        # Method legend (first row)
        method_legend = fig.legend(handles=method_legend_handles,
                                   loc='upper center', bbox_to_anchor=(0.5, 0.91),
                                   ncol=len(method_legend_handles), fontsize=28)
        # Register both legends with the figure
        fig.add_artist(method_legend)
        fig.add_artist(recall_legend)
    # Leave headroom at the top for the two legend rows
    plt.tight_layout(rect=(0, 0, 1.0, 0.74))
    # Ensure the output directory exists before saving
    os.makedirs('./paper_plot/figures', exist_ok=True)
    save_path = f'./paper_plot/figures/main_exp_fig_{fig_idx+1}.pdf'
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"Saved figure {fig_idx+1} to {save_path}")
    plt.show()
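# Usage note (assumed repository layout): run from the research/ directory so
# the relative CSV and figure paths resolve, e.g.
#   python paper_plot/main_exp.py
# A working LaTeX installation is required because text.usetex is enabled above.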