import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.lines as mlines
import pandas as pd
import numpy as np
from matplotlib.patches import FancyArrowPatch

# ---------------------------------------------------------------------------
# Global plot styling: LaTeX text rendering, bold fonts, dashed major grid,
# no minor ticks.
# ---------------------------------------------------------------------------
sns.set_theme(style="ticks", font_scale=1.2)
plt.rcParams['axes.grid'] = True
plt.rcParams['axes.grid.which'] = 'major'
plt.rcParams['grid.linestyle'] = '--'
plt.rcParams['grid.color'] = 'gray'
plt.rcParams['grid.alpha'] = 0.3
plt.rcParams['xtick.minor.visible'] = False
plt.rcParams['ytick.minor.visible'] = False
plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["text.usetex"] = True
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"

# Historical hard-coded LLM generation times, now replaced by CSV data:
#   Generation(LLama 1B)  Generation(LLama 3B)  Generation(LLama 7B)
#   0.085s                0.217s                0.472s
# llm_inference_time = [0.085, 0.217, 0.472, 0]
# llm_inference_time_for_mac = [0.316, 0.717, 1.468, 0]


def parse_latency_data(csv_path):
    """Parse the main latency CSV.

    Expected columns: Dataset, Hardware, Recall_target (e.g. "95%"),
    LLM_Gen_Time_{1B,3B,7B}, plus one column per retrieval method.

    Returns:
        latency_data: dict mapping (dataset, hardware, method) to a list of
            (recall_target, latency) tuples sorted by recall target.
        llm_gen_times: dict mapping (dataset, hardware) to the 1B LLM
            generation time (NaN when missing or unparsable).
    """
    df = pd.read_csv(csv_path)
    latency_data = {}
    llm_gen_times = {}

    # Every column NOT in this list is treated as a method latency column.
    cols_to_skip = ['Dataset', 'Hardware', 'Recall_target',
                    'LLM_Gen_Time_1B', 'LLM_Gen_Time_3B', 'LLM_Gen_Time_7B']

    for _, row in df.iterrows():
        dataset = row['Dataset']
        hardware = row['Hardware']
        recall_target_str = row['Recall_target'].replace('%', '')
        try:
            recall_target = float(recall_target_str)
        except ValueError:
            print(f"Warning: Could not parse recall_target '{row['Recall_target']}'. Skipping row.")
            continue

        if (dataset, hardware) not in llm_gen_times:
            # Read once per (dataset, hardware); errors='coerce' already
            # yields NaN for missing/unparsable cells, so no extra branch
            # is needed.
            llm_gen_times[(dataset, hardware)] = pd.to_numeric(
                row.get('LLM_Gen_Time_1B'), errors='coerce')

        for method_name in df.columns:
            if method_name in cols_to_skip:
                continue
            key = (dataset, hardware, method_name)
            points = latency_data.setdefault(key, [])
            try:
                # TypeError covers None cells; ValueError covers 'N/A',
                # empty strings, and other non-numeric text.
                points.append((recall_target, float(row[method_name])))
            except (ValueError, TypeError):
                print(f"Warning: Could not parse latency for {method_name} at {dataset}/{hardware}/Recall {recall_target} ('{row[method_name]}'). Skipping this point.")
                # Keep the point as NaN so downstream code sees the gap.
                points.append((recall_target, np.nan))

    # Sort each series by recall target for consistent plotting.
    for key in latency_data:
        latency_data[key].sort(key=lambda x: x[0])
    return latency_data, llm_gen_times


def parse_storage_data(csv_path):
    """Parse the RAM/storage CSV.

    The first column labels the metric ('RAM' or 'Storage'); every
    subsequent column is a method.

    Returns a dict: method -> {'RAM': float, 'Storage': float}, with NaN
    for unparsable cells.
    """
    df = pd.read_csv(csv_path)
    # Locate the single RAM row and the single Storage row.
    ram_row = df[df.iloc[:, 0] == 'RAM'].iloc[0]
    storage_row = df[df.iloc[:, 0] == 'Storage'].iloc[0]
    methods = df.columns[1:]  # first column is the metric-type label
    return {
        method: {
            'RAM': pd.to_numeric(ram_row[method], errors='coerce'),
            'Storage': pd.to_numeric(storage_row[method], errors='coerce'),
        }
        for method in methods
    }


# --- Load data -------------------------------------------------------------
latency_csv_path = 'paper_plot/data/main_latency.csv'
storage_csv_path = 'paper_plot/data/ram_storage.csv'
latency_data, llm_generation_times = parse_latency_data(latency_csv_path)
storage_info = parse_storage_data(storage_csv_path)

# --- Determine unique (Dataset, Hardware) combinations to plot -------------
unique_dataset_hardware_configs = sorted(set((d, h) for d, h, m in latency_data.keys()))
if not unique_dataset_hardware_configs:
    print("Error: No (Dataset, Hardware) combinations found in latency data. Check CSV paths and content.")
    exit()

# --- Plotting constants ----------------------------------------------------
all_method_names = sorted(set(m for d, h, m in latency_data.keys()))
if not all_method_names:
    # Fallback if latency_data is empty but storage_info might have method names.
    all_method_names = sorted(storage_info.keys())
if not all_method_names:
    print("Error: No method names found in data. Cannot proceed with plotting.")
    exit()

method_markers = {
    'HNSW': 'o',
    'IVF': 'X',
    'DiskANN': 's',
    'IVF-Disk': 'P',
    'IVF-Recompute': '^',
    'Our': '*',
    'BM25': "v",
    # Add more if necessary, or make it dynamic.
}
method_display_names = {
    'IVF-Recompute': 'IVF-Recompute (EdgeRAG)',
    # All other methods keep their original names.
}

# Ensure every method has a marker; unknown methods get a default one.
default_markers = ['^', 'v', '<', '>', 'H', 'h', '+', 'x', '|', '_']
next_default_marker = 0
for mn in all_method_names:
    if mn not in method_markers:
        print(f"mn: {mn}")
        method_markers[mn] = default_markers[next_default_marker % len(default_markers)]
        next_default_marker += 1

recall_levels_present = sorted(set(r for key in latency_data for r, l in latency_data[key]))
# Fixed colors for the common recall targets; anything else falls back to a
# viridis palette entry.
base_recall_colors = {
    85.0: "#1f77b4",  # blue
    90.0: "#ff7f0e",  # orange
    95.0: "#2ca02c",  # green
}
recall_colors = {}
color_palette = sns.color_palette("viridis", n_colors=len(recall_levels_present))
for idx, r_level in enumerate(recall_levels_present):
    recall_colors[r_level] = base_recall_colors.get(r_level, color_palette[idx % len(color_palette)])

# --- Global x (latency) and y (storage) limits for consistent axes ---------
all_latency_values = []
all_storage_values = []
raw_data_size = 76  # raw corpus size in GB; storage is plotted relative to it

for current_ds, current_hw in unique_dataset_hardware_configs:
    for method_name in all_method_names:
        # Disk storage for this method (shared across datasets/hardware).
        disk_storage = storage_info.get(method_name, {}).get('Storage', np.nan)
        if not np.isnan(disk_storage):
            all_storage_values.append(disk_storage)
        # Latencies for this method under the current (dataset, hardware).
        latency_key = (current_ds, current_hw, method_name)
        if latency_key in latency_data:
            for recall, latency in latency_data[latency_key]:
                if not np.isnan(latency):
                    all_latency_values.append(latency)

max_lat = max(all_latency_values) if all_latency_values else 1
min_store = min(all_storage_values) if all_storage_values else 0
max_store = max(all_storage_values) if all_storage_values else 1

# Convert storage values to a proportion of the raw data size.
min_store_proportion = min_store / raw_data_size if all_storage_values else 0
max_store_proportion = max_store / raw_data_size if all_storage_values else 0.1

# Padding for the log-scale latency axis. The lower bound is pinned at
# 10^-1 (0.1 s) rather than derived from the data.
lat_log_min = -1
lat_log_max = np.log10(max_lat) if max_lat > 0 else 3  # default to 1000 s
lat_padding = (lat_log_max - lat_log_min) * 0.05
global_xlim = [10**(lat_log_min - lat_padding), 10**(lat_log_max + lat_padding)]
if global_xlim[0] <= 0:  # defensive guard; a power of ten is always positive
    global_xlim[0] = 0.1

# Padding for the linear-scale storage-proportion axis.
store_padding = (max_store_proportion - min_store_proportion) * 0.05
global_ylim = [max(0, min_store_proportion - store_padding),
               max_store_proportion + store_padding]
if global_ylim[0] >= global_ylim[1]:  # avoid an inverted or zero range
    global_ylim[1] = global_ylim[0] + 0.1

# --- Order datasets explicitly; any extras are appended at the end ---------
all_datasets_unsorted = list(set(ds for ds, _ in unique_dataset_hardware_configs))
desired_order = ['NQ', 'TriviaQA', 'GPQA', 'HotpotQA']
all_datasets = [ds for ds in desired_order if ds in all_datasets_unsorted]
all_datasets.extend([ds for ds in all_datasets_unsorted if ds not in desired_order])

a10_configs = [(ds, 'A10') for ds in all_datasets
               if (ds, 'A10') in unique_dataset_hardware_configs]
mac_configs = [(ds, 'MAC') for ds in all_datasets
               if (ds, 'MAC') in unique_dataset_hardware_configs]

# Create two figures - one for A10 and one for MAC.
hardware_configs = [a10_configs, mac_configs]
hardware_names = ['A10', 'MAC']

for fig_idx, configs_for_this_figure in enumerate(hardware_configs):
    if not configs_for_this_figure:
        continue
    num_cols_this_figure = len(configs_for_this_figure)
    # 1 row, one column per dataset; sharex/sharey keep axes comparable.
    fig, axs = plt.subplots(1, num_cols_this_figure,
                            figsize=(7 * num_cols_this_figure, 6),
                            sharex=True, sharey=True, squeeze=False)
    # fig.suptitle(f"Latency vs. Storage ({hardware_names[fig_idx]})", fontsize=18, y=0.98)

    for subplot_idx, (current_ds, current_hw) in enumerate(configs_for_this_figure):
        ax = axs[0, subplot_idx]  # accessing column in the first row
        # Hardware is implied by the figure, so the title is just the dataset.
        ax.set_title(f"{current_ds}", fontsize=25)

        for method_name in all_method_names:
            marker = method_markers.get(method_name, '+')
            disk_storage = storage_info.get(method_name, {}).get('Storage', np.nan)
            latency_points_key = (current_ds, current_hw, method_name)
            if latency_points_key not in latency_data:
                continue
            points_for_method = latency_data[latency_points_key]
            print(f"points_for_method: {points_for_method}")

            for recall, latency in points_for_method:
                # Skip invalid latencies (x-axis is log scale, so latency
                # must be positive); zero storage is fine on a linear y-axis.
                if np.isnan(latency) or np.isnan(disk_storage) or latency <= 0:
                    continue

                # Add the LLM generation time from the CSV on top of the
                # retrieval latency.
                current_llm_add_time = llm_generation_times.get((current_ds, current_hw))
                if current_llm_add_time is not None and not np.isnan(current_llm_add_time):
                    latency = latency + current_llm_add_time
                else:
                    raise ValueError(f"No LLM generation time found for {current_ds} on {current_hw}")

                if method_name == 'BM25':
                    # BM25 is only valid at the 85% recall target (other
                    # points are placeholders) and is drawn in grey.
                    if recall != 85.0:
                        continue
                    color = 'grey'
                else:
                    color = recall_colors.get(recall, 'grey')

                # Convert storage to a proportion of the raw data size.
                disk_storage_proportion = disk_storage / raw_data_size
                size = 80
                x_offset = -50
                if current_ds == 'GPQA':
                    x_offset = -32

                # Nudge IVF-Recompute points upward so they stay visible.
                if method_name == 'IVF-Recompute':
                    disk_storage_proportion += 0.07
                    size = 80
                if method_name == 'DiskANN':
                    size = 50
                if method_name == 'Our':
                    size = 140
                    disk_storage_proportion += 0.05
                    # Label our method's 95%-recall point.
                    if recall == 95:
                        ax.annotate('Ours', (latency, disk_storage_proportion),
                                    xytext=(x_offset, 25),
                                    textcoords='offset points',
                                    fontsize=20, color='red', weight='bold',
                                    bbox=dict(boxstyle="round,pad=0.3",
                                              fc="white", ec="red", alpha=0.7))
                if method_name == 'BM25':
                    size = 70
                # Uniform size bump applied to every plotted point.
                size *= 5

                ax.scatter(latency, disk_storage_proportion, marker=marker,
                           color=color, s=size, alpha=0.85,
                           edgecolors='black', linewidths=0.7)

        ax.set_xscale("log")
        ax.set_yscale("linear")  # linear Y-axis (changed from log scale)

        # Powers of ten for the x-axis, with bold LaTeX 10^n labels.
        min_power = -1
        max_power = 4
        log_ticks = [10**i for i in range(min_power, max_power + 1)]
        ax.set_xticks(log_ticks)
        log_tick_labels = [fr'$\mathbf{{10^{{{i}}}}}$' for i in range(min_power, max_power + 1)]
        ax.set_xticklabels(log_tick_labels, fontsize=24)

        # Apply global limits once; sharex/sharey propagates to siblings.
        if subplot_idx == 0:
            ax.set_xlim(global_xlim)
            ax.set_ylim(global_ylim)

        ax.grid(True, which="major", linestyle="--", linewidth=0.6, alpha=0.7)
        ax.grid(False, which="minor")  # remove minor grid lines completely

        # Hide tick marks; keep labels, with extra padding below the x-axis.
        ax.tick_params(axis='both', which='both', length=0, labelsize=24)
        ax.tick_params(axis='x', which='both', pad=10)

        if subplot_idx == 0:  # y-label only for the leftmost subplot
            ax.set_ylabel("Proportional Size", fontsize=24)
        # X-label on every subplot of the 1xN layout.
        ax.set_xlabel("Latency (s)", fontsize=25)

        # Show 100%/200%/300% on the y-axis (escaped for usetex).
        ax.set_yticks([1, 2, 3])
        ax.set_yticklabels([r'100\%', r'200\%', r'300\%'])

        # Wide-shafted "Better" arrow pointing toward low latency/storage.
        arrow = FancyArrowPatch(
            (0.8, 0.8),    # start point (top-right)
            (0.65, 0.6),   # end point (toward bottom-left)
            transform=ax.transAxes,
            arrowstyle='simple,head_width=40,head_length=35,tail_width=20',
            facecolor='white',
            edgecolor='black',
            linewidth=3,
            zorder=5,
        )
        ax.add_patch(arrow)
        # Place the label at (roughly) the arrow midpoint.
        mid_x = (0.8 + 0.65) / 2 + 0.002 + 0.01
        mid_y = (0.8 + 0.6) / 2 + 0.01
        ax.text(mid_x, mid_y, 'Better', transform=ax.transAxes,
                ha='center', va='center',
                fontsize=16, fontweight='bold',
                rotation=40,  # rotate to match the arrow direction
                zorder=6)     # ensure text is on top of the arrow

    # Create legend handles (once per figure).
    method_legend_handles = []
    for method, marker_style in method_markers.items():
        if method in all_method_names:
            print(f"method: {method}")
            if method == 'BM25':
                # Use black for BM25 in the legend.
                method_legend_handles.append(mlines.Line2D(
                    [], [], color='black', marker=marker_style,
                    linestyle='None', markersize=10, label=method))
            else:
                if method in method_display_names:
                    method = method_display_names[method]
                method_legend_handles.append(mlines.Line2D(
                    [], [], color='black', marker=marker_style,
                    linestyle='None', markersize=10, label=method))

    recall_legend_handles = []
    sorted_recall_levels = sorted(recall_colors.keys())
    for r_level in sorted_recall_levels:
        recall_legend_handles.append(mlines.Line2D(
            [], [], color=recall_colors[r_level], marker='o',
            linestyle='None', markersize=20,
            label=f"Target Recall={r_level:.0f}\%"))

    # Two-row figure legend (first figure only): methods on the first row,
    # recall targets on the second.
    if fig_idx == 0:
        # Reorder so that 'Our' appears last.
        other_methods = [m for m in all_method_names if m != 'Our']
        ordered_methods = other_methods + (['Our'] if 'Our' in all_method_names else [])

        # Rebuild the method handles in the new order with per-method sizes.
        method_legend_handles = []
        for method in ordered_methods:
            if method in method_markers:
                marker_style = method_markers[method]
                display_name = method_display_names.get(method, method)
                color = 'black'
                marker_size = 22
                if method == 'Our':
                    marker_size = 27
                elif 'IVF-Recompute' in method or 'EdgeRAG' in method:
                    marker_size = 17
                elif 'DiskANN' in method:
                    marker_size = 19
                elif 'BM25' in method:
                    marker_size = 20
                method_legend_handles.append(mlines.Line2D(
                    [], [], color=color, marker=marker_style,
                    linestyle='None', markersize=marker_size,
                    label=display_name))

        # Recall legend (second row), placed below the method legend.
        recall_legend = fig.legend(handles=recall_legend_handles,
                                   loc='upper center',
                                   bbox_to_anchor=(0.5, 1.05),
                                   ncol=len(recall_legend_handles),
                                   fontsize=28)
        # Method legend (first row).
        method_legend = fig.legend(handles=method_legend_handles,
                                   loc='upper center',
                                   bbox_to_anchor=(0.5, 0.91),
                                   ncol=len(method_legend_handles),
                                   fontsize=28)
        fig.add_artist(method_legend)
        fig.add_artist(recall_legend)

    # Leave extra headroom for the two legend rows.
    plt.tight_layout(rect=(0, 0, 1.0, 0.74))
    save_path = f'./paper_plot/figures/main_exp_fig_{fig_idx+1}.pdf'
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"Saved figure {fig_idx+1} to {save_path}")

plt.show()