Initial commit
research/paper_plot/acc_fig.py (new file, 165 lines)
@@ -0,0 +1,165 @@
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Set plot parameters
plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["ytick.direction"] = "in"
plt.rcParams["hatch.linewidth"] = 1.5
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
plt.rcParams["text.usetex"] = True

# Path settings
FIGURE_PATH = "./paper_plot/figures"

# Load accuracy data
acc_data = pd.read_csv("./paper_plot/data/acc.csv")

# Create figure with 4 subplots (one for each dataset)
fig, axs = plt.subplots(1, 4)
fig.set_size_inches(9, 2.5)

# Reduce the spacing between subplots
# plt.subplots_adjust(wspace=0.2)

# Define datasets and their columns
datasets = ["NQ", "TriviaQA", "GPQA", "HotpotQA"]
metrics = ["Exact Match", "F1"]

# Define bar settings - make bars thicker
# total_width, n = 0.9, 3  # increased total width and n for three models
# width = total_width / n
# 'width' below defines the distance between the centers of adjacent bars within a group.
# It is also the base for calculating the actual plotted bar width.
# The original two bars had centers 1.0 apart; for three bars we need a smaller distance.
# A center-to-center distance of 0.64 with a 0.8 width-scaling factor gives an actual
# bar width of ~0.51 and a group span of ~1.79, similar to the original's ~1.76.
n = 3  # Number of models
width = 0.64  # Distance between centers of adjacent bars in a group
bar_width_plotting_factor = 0.8  # Bar takes 80% of the space defined by 'width'

# Colors and hatches
edgecolors = ["dimgrey", "#63B8B6", "tomato"]  # Added color for PQ 5
hatches = ["/////", "xxxxx", "\\\\\\\\\\"]  # Added hatch for PQ 5
labels = ["BM25", "PQ Compressed", "Ours"]  # Added PQ 5

# Create plots for each dataset
for i, dataset in enumerate(datasets):
    ax = axs[i]

    # Get data for this dataset and convert to percentages
    em_values = [
        acc_data.loc[0, f"{dataset} Exact Match"] * 100,
        acc_data.loc[1, f"{dataset} Exact Match"] * 100,
        acc_data.loc[2, f"{dataset} Exact Match"] * 100,  # Added PQ 5 EM data
    ]
    f1_values = [
        acc_data.loc[0, f"{dataset} F1"] * 100,
        acc_data.loc[1, f"{dataset} F1"] * 100,
        acc_data.loc[2, f"{dataset} F1"] * 100,  # Added PQ 5 F1 data
    ]

    # Define x positions for bars
    # For both the EM and F1 groups: center - width, center, center + width
    group_centers = [1.0, 3.0]  # Centers for EM and F1 groups
    bar_offsets = [-width, 0, width]

    # Plot all bars on the same axis
    for metric_idx, metric_group_center in enumerate(group_centers):
        values_to_plot = em_values if metric_idx == 0 else f1_values
        for j, model_label in enumerate(labels):
            x_pos = metric_group_center + bar_offsets[j]
            bar_value = values_to_plot[j]

            ax.bar(
                x_pos,
                bar_value,
                width=width * bar_width_plotting_factor,  # Use the scaling factor for bar width
                color="white",
                edgecolor=edgecolors[j],
                hatch=hatches[j],
                linewidth=1.5,
                label=model_label if i == 0 and metric_idx == 0 else None,  # Label only once
            )

            # Add value on top of bar
            ax.text(x_pos, bar_value + 0.1,
                    f"{bar_value:.1f}", ha='center', va='bottom',
                    fontsize=9, fontweight='bold')  # Reduced fontsize for text on bars

    # Set x-ticks and labels
    ax.set_xticks(group_centers)  # Position ticks at the center of each group
    xticklabels = ax.set_xticklabels(metrics, fontsize=12)

    # Optionally shift these labels slightly to the right (in data coordinates).
    # Given group_centers of 1.0 and 3.0, a small shift of 0.05 to 0.15 is appropriate.
    # horizontal_shift = 0.7
    # for label in xticklabels:
    #     current_x_pos = label.get_position()[0]
    #     label.set_position((current_x_pos + horizontal_shift, label.get_position()[1]))
    #     label.set_horizontalalignment('center')

    # Set title
    ax.set_title(dataset, fontsize=14)

    # Set y-label for the first subplot only
    if i == 0:
        ax.set_ylabel(r"Accuracy (\%)", fontsize=12, fontweight="bold")
    else:
        # Keep y-tick labels small for the other subplots to save space
        ax.tick_params(axis='y', labelsize=10)

    # Set y-limits based on data range
    all_values = em_values + f1_values
    max_val = max(all_values)
    min_val = min(all_values)

    # Special handling for GPQA, which has very low values
    if dataset == "GPQA":
        ax.set_ylim(0, 10.0)  # Set a fixed range for GPQA
    else:
        # Reduce the extra space above the bars
        ax.set_ylim(min_val * 0.9, max_val * 1.1)  # Adjusted upper limit for text

    # Format y-ticks as percentages
    ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: ' {:.0f}'.format(y)))

    # Set x-limits to space the bars with less blank space
    # ax.set_xlim(group_centers[0] - total_width, group_centers[1] + total_width)
    # Match the original (0, 4) range for group_centers (1, 3), i.e. a margin of 1.0
    ax.set_xlim(group_centers[0] - 1.0, group_centers[1] + 1.0)

    # Add a box around the subplot
    # for spine in ax.spines.values():
    #     spine.set_visible(True)
    #     spine.set_linewidth(1.0)

    # Add legend to first subplot
    if i == 0:
        ax.legend(
            bbox_to_anchor=(2.21, 1.35),  # Adjusted anchor if needed
            ncol=3,  # Three columns for three labels
            loc="upper center",
            labelspacing=0.1,
            edgecolor="black",
            facecolor="white",
            framealpha=1,
            shadow=False,
            fancybox=False,
            handlelength=1.0,
            handletextpad=0.6,
            columnspacing=0.8,
            prop={"weight": "bold", "size": 12},
        )

# Save figure with tight layout but no additional padding
plt.savefig(FIGURE_PATH + "/accuracy_comparison.pdf", bbox_inches='tight', pad_inches=0.05)
plt.show()
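A quick arithmetic check of the bar-geometry comment in acc_fig.py above (a minimal sketch restating the constants already used there, not part of the committed script):
width, factor = 0.64, 0.8
bar_w = width * factor            # 0.512 plotted bar width
group_span = 2 * width + bar_w    # 1.792, close to the quoted ~1.79 (original two-bar span ~1.76)
print(bar_w, group_span)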
research/paper_plot/analyze_visits.py (new file, 309 lines)
@@ -0,0 +1,309 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
# Motto: Were It to Benefit My Country, I Would Lay Down My Life!
|
||||
# \file: /hnsw_degree_visit_plot_binned_academic.py
|
||||
# \brief: Generates a binned bar plot of HNSW node average per-query visit probability
|
||||
# per degree bin, styled for academic publications, with caching.
|
||||
# Author: raphael hao (Original script by user, styling and caching adapted by Gemini)
|
||||
|
||||
# %%
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
import re
|
||||
from collections import Counter
|
||||
import os # For robust filepath manipulation
|
||||
import math # For calculating scaling factor
|
||||
import pickle # For caching data
|
||||
|
||||
# %%
|
||||
# --- Matplotlib parameters for academic paper style (from reference) ---
|
||||
plt.rcParams["font.family"] = "Helvetica"
|
||||
plt.rcParams["ytick.direction"] = "in"
|
||||
plt.rcParams["hatch.linewidth"] = 1.5
|
||||
plt.rcParams["font.weight"] = "bold"
|
||||
plt.rcParams["axes.labelweight"] = "bold"
|
||||
plt.rcParams["text.usetex"] = True # Use LaTeX for text rendering (if available)
|
||||
|
||||
# --- Define styles from reference ---
|
||||
edgecolors_ref = ["dimgrey", "#63B8B6", "tomato", "silver", "slategray"]
|
||||
|
||||
# %%
|
||||
# --- File Paths ---
|
||||
degree_file = '/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/hnsw/degree_distribution.txt'
|
||||
visit_log_file = './re.log'
|
||||
output_image_file = './paper_plot/figures/hnsw_visit_count_per_degree_corrected.pdf'
|
||||
# --- CACHE FILE PATH: Keep this consistent ---
|
||||
CACHE_FILE_PATH = './binned_plot_data_cache.pkl'
|
||||
|
||||
# --- Configuration ---
|
||||
# Set to True to bypass cache and force recomputation.
|
||||
# Otherwise, delete CACHE_FILE_PATH manually to force recomputation.
|
||||
FORCE_RECOMPUTE = False
|
||||
NUMBER_OF_QUERIES = 1000.0 # Number of queries the visit_counts are based on
|
||||
|
||||
# Create directory for figures if it doesn't exist
|
||||
output_dir = os.path.dirname(output_image_file)
|
||||
if output_dir and not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
print(f"Created directory: {output_dir}")
|
||||
|
||||
# %%
|
||||
# --- Attempt to load data from cache or compute ---
|
||||
df_plot_data = None
|
||||
bin_size_for_plot = None # Will hold the bin_size associated with df_plot_data
|
||||
|
||||
if not FORCE_RECOMPUTE and os.path.exists(CACHE_FILE_PATH):
|
||||
try:
|
||||
with open(CACHE_FILE_PATH, 'rb') as f:
|
||||
cache_content = pickle.load(f)
|
||||
df_plot_data = cache_content['data']
|
||||
bin_size_for_plot = cache_content['bin_size']
|
||||
# Basic validation of cached data
|
||||
# Expecting 'average_visit_count_per_node_in_bin' (raw average over NUMBER_OF_QUERIES)
|
||||
if not isinstance(df_plot_data, pd.DataFrame) or \
|
||||
'degree_bin_label' not in df_plot_data.columns or \
|
||||
'average_visit_count_per_node_in_bin' not in df_plot_data.columns or \
|
||||
not isinstance(bin_size_for_plot, int):
|
||||
print("Cached data is not in the expected format or missing 'average_visit_count_per_node_in_bin'. Recomputing.")
|
||||
df_plot_data = None # Invalidate to trigger recomputation
|
||||
else:
|
||||
print(f"Successfully loaded binned data from cache: {CACHE_FILE_PATH}")
|
||||
|
||||
# --- Modify the label loaded from cache for display purposes ---
# This modification only happens when data is loaded from the cache and meets specific conditions.
# Assumption: if the cached bin_size_for_plot is 5, the label "0-4" actually represents nodes
# with degree 1-4, since the graph is guaranteed to contain no 0-degree nodes.
|
||||
if df_plot_data is not None and 'degree_bin_label' in df_plot_data.columns and bin_size_for_plot == 5:
|
||||
# Check if "0-4" label exists
|
||||
if '0-4' in df_plot_data['degree_bin_label'].values:
|
||||
# Use .loc to ensure the modification is on the original DataFrame
|
||||
df_plot_data.loc[df_plot_data['degree_bin_label'] == '0-4', 'degree_bin_label'] = '1-4'
|
||||
print("Modified degree_bin_label from '0-4' to '1-4' for display purpose.")
|
||||
except Exception as e:
|
||||
print(f"Error loading from cache: {e}. Recomputing.")
|
||||
df_plot_data = None # Invalidate to trigger recomputation
|
||||
|
||||
if df_plot_data is None:
|
||||
print("Cache not found, invalid, or recompute forced. Computing data from scratch...")
|
||||
# --- 1. Read Degree Distribution File ---
|
||||
degrees_data = []
|
||||
try:
|
||||
with open(degree_file, 'r') as f:
|
||||
for i, line in enumerate(f):
|
||||
line_stripped = line.strip()
|
||||
if line_stripped:
|
||||
degrees_data.append({'node_id': i, 'degree': int(line_stripped)})
|
||||
except FileNotFoundError:
|
||||
print(f"Error: Degree file '{degree_file}' not found. Using dummy data for degrees.")
|
||||
degrees_data = [{'node_id': i, 'degree': (i % 20) + 1 } for i in range(200)]
|
||||
degrees_data.extend([{'node_id': 200+i, 'degree': i} for i in range(58, 67)]) # For 60-64 bin
|
||||
degrees_data.extend([{'node_id': 300+i, 'degree': (i % 5)+1} for i in range(10)]) # Low degrees
|
||||
degrees_data.extend([{'node_id': 400+i, 'degree': 80 + (i%5)} for i in range(10)]) # High degrees
|
||||
|
||||
|
||||
if not degrees_data:
|
||||
print(f"Critical Error: No data loaded or generated for degrees. Exiting.")
|
||||
exit()
|
||||
df_degrees = pd.DataFrame(degrees_data)
|
||||
print(f"Successfully loaded/generated {len(df_degrees)} degree entries.")
|
||||
|
||||
# --- 2. Read Visit Log File and Count Frequencies ---
|
||||
visit_counts = Counter()
|
||||
node_id_pattern = re.compile(r"Vis(i)?ted node: (\d+)")
|
||||
try:
|
||||
with open(visit_log_file, 'r') as f_log:
|
||||
for line_num, line in enumerate(f_log, 1):
|
||||
match = node_id_pattern.search(line)
|
||||
if match:
|
||||
try:
|
||||
node_id = int(match.group(2))
|
||||
visit_counts[node_id] += 1 # Increment visit count for the node
|
||||
except ValueError:
|
||||
print(f"Warning: Non-integer node_id in log '{visit_log_file}' line {line_num}: {line.strip()}")
|
||||
except FileNotFoundError:
|
||||
print(f"Warning: Visit log file '{visit_log_file}' not found. Using dummy visit counts.")
|
||||
if not df_degrees.empty:
|
||||
for node_id_val in df_degrees['node_id'].sample(frac=0.9, random_state=1234): # Seed for reproducibility
|
||||
degree_val = df_degrees[df_degrees['node_id'] == node_id_val]['degree'].iloc[0]
|
||||
# Generate visit counts to test different probability magnitudes
|
||||
if node_id_val % 23 == 0: # Very low probability
|
||||
lambda_val = 0.0005 * (100 / (max(1,degree_val) + 1)) # avg visits over 1k queries
|
||||
elif node_id_val % 11 == 0: # Low probability
|
||||
lambda_val = 0.05 * (100 / (max(1,degree_val) + 1))
|
||||
elif node_id_val % 5 == 0: # Moderate probability
|
||||
lambda_val = 2.5 * (100 / (max(1,degree_val) + 1))
|
||||
else: # Higher probability (but still < 1000 visits for a single node usually)
|
||||
lambda_val = 50 * (100 / (max(1,degree_val) + 1))
|
||||
visit_counts[node_id_val] = np.random.poisson(lambda_val)
|
||||
if visit_counts[node_id_val] < 0: visit_counts[node_id_val] = 0
|
||||
|
||||
if not visit_counts:
|
||||
print(f"Warning: No visit data parsed/generated. Plot may show zero visits.")
|
||||
df_visits = pd.DataFrame(columns=['node_id', 'visit_count'])
|
||||
else:
|
||||
df_visits_list = [{'node_id': nid, 'visit_count': count} for nid, count in visit_counts.items()]
|
||||
df_visits = pd.DataFrame(df_visits_list)
|
||||
print(f"Parsed/generated {len(df_visits)} unique visited nodes, totaling {sum(visit_counts.values())} visits (simulated over {NUMBER_OF_QUERIES} queries).")
|
||||
|
||||
# --- 3. Merge Degree Data with Visit Data ---
|
||||
df_merged = pd.merge(df_degrees, df_visits, on='node_id', how='left')
|
||||
df_merged['visit_count'] = df_merged['visit_count'].fillna(0).astype(float) # visit_count is total over NUMBER_OF_QUERIES
|
||||
print(f"Merged data contains {len(df_merged)} entries.")
|
||||
|
||||
# --- 5. Binning Degrees and Calculating Average Visit Count per Node in Bin (over NUMBER_OF_QUERIES) ---
|
||||
current_bin_size = 5
|
||||
bin_size_for_plot = current_bin_size
|
||||
|
||||
if not df_degrees.empty:
|
||||
print(f"\nBinning degrees into groups of {current_bin_size} for average visit count calculation...")
|
||||
|
||||
df_merged_with_bins = df_merged.copy()
|
||||
df_merged_with_bins['degree_bin_start'] = (df_merged_with_bins['degree'] // current_bin_size) * current_bin_size
|
||||
|
||||
df_binned_analysis = df_merged_with_bins.groupby('degree_bin_start').agg(
|
||||
total_visit_count_in_bin=('visit_count', 'sum'),
|
||||
node_count_in_bin=('node_id', 'nunique')
|
||||
).reset_index()
|
||||
|
||||
# This is the average number of times a node in this bin was visited over NUMBER_OF_QUERIES queries.
|
||||
# This value is what gets cached.
|
||||
df_binned_analysis['average_visit_count_per_node_in_bin'] = 0.0
|
||||
df_binned_analysis.loc[df_binned_analysis['node_count_in_bin'] > 0, 'average_visit_count_per_node_in_bin'] = \
|
||||
df_binned_analysis['total_visit_count_in_bin'] / df_binned_analysis['node_count_in_bin']
|
||||
|
||||
df_binned_analysis['degree_bin_label'] = df_binned_analysis['degree_bin_start'].astype(str) + '-' + \
|
||||
(df_binned_analysis['degree_bin_start'] + current_bin_size - 1).astype(str)
|
||||
|
||||
bin_to_drop_label = '60-64'
|
||||
original_length = len(df_binned_analysis)
|
||||
df_plot_data_intermediate = df_binned_analysis[df_binned_analysis['degree_bin_label'] != bin_to_drop_label].copy()
|
||||
if len(df_plot_data_intermediate) < original_length:
|
||||
print(f"\nManually dropped the bin: '{bin_to_drop_label}'")
|
||||
else:
|
||||
print(f"\nNote: Bin '{bin_to_drop_label}' not found for dropping or already removed.")
|
||||
|
||||
df_plot_data = df_plot_data_intermediate
|
||||
|
||||
print(f"\nBinned data (average visit count per node in bin over {NUMBER_OF_QUERIES} queries) for plotting prepared:")
|
||||
print(df_plot_data[['degree_bin_label', 'average_visit_count_per_node_in_bin']].head())
|
||||
|
||||
if df_plot_data is not None and not df_plot_data.empty:
|
||||
try:
|
||||
with open(CACHE_FILE_PATH, 'wb') as f:
|
||||
pickle.dump({'data': df_plot_data, 'bin_size': bin_size_for_plot}, f)
|
||||
print(f"Saved computed binned data to cache: {CACHE_FILE_PATH}")
|
||||
except Exception as e:
|
||||
print(f"Error saving data to cache: {e}")
|
||||
elif df_plot_data is None or df_plot_data.empty:
|
||||
print("Computed data for binned plot is empty, not saving to cache.")
|
||||
else:
|
||||
print("Degree data (df_degrees) is empty. Cannot perform binning.")
|
||||
df_plot_data = pd.DataFrame()
|
||||
bin_size_for_plot = current_bin_size
|
||||
|
||||
# %%
|
||||
# --- 6. Plotting (Binned Bar Chart - Academic Style) ---
|
||||
|
||||
if df_plot_data is not None and not df_plot_data.empty and 'average_visit_count_per_node_in_bin' in df_plot_data.columns:
|
||||
base_name, ext = os.path.splitext(output_image_file)
|
||||
# --- OUTPUT PDF FILE NAME: Keep this consistent ---
|
||||
binned_output_image_file = base_name + ext
|
||||
|
||||
fig, ax = plt.subplots(figsize=(6, 2.5)) # Adjusted figure size
|
||||
|
||||
df_plot_data_plotting = df_plot_data.copy()
|
||||
# Calculate per-query probability: (avg visits over N queries) / N
|
||||
df_plot_data_plotting['per_query_visit_probability'] = \
|
||||
df_plot_data_plotting['average_visit_count_per_node_in_bin'] / NUMBER_OF_QUERIES
|
||||
|
||||
max_probability = df_plot_data_plotting['per_query_visit_probability'].max()
|
||||
|
||||
y_axis_values_to_plot = df_plot_data_plotting['per_query_visit_probability']
|
||||
y_axis_label = r"Per-Query Node Visit Probability in Bin" # Base label
|
||||
|
||||
apply_scaling_to_label_and_values = False # Initialize flag
|
||||
exponent_for_label_display = 0 # Initialize exponent
|
||||
|
||||
if pd.notna(max_probability) and max_probability > 0:
|
||||
potential_exponent = math.floor(math.log10(max_probability))
|
||||
|
||||
if potential_exponent <= -4 or potential_exponent >= 0:
|
||||
apply_scaling_to_label_and_values = True
|
||||
exponent_for_label_display = potential_exponent
|
||||
# No specific adjustment for potential_exponent >=0 here, it's handled by the general logic.
|
||||
|
||||
if apply_scaling_to_label_and_values:
|
||||
y_axis_label = rf"Visit Probability ($\times 10^{{{exponent_for_label_display}}}$)"
|
||||
y_axis_values_to_plot = df_plot_data_plotting['per_query_visit_probability'] / (10**exponent_for_label_display)
|
||||
print(f"Plotting with Max per-query probability: {max_probability:.2e}, Exponent for label: {exponent_for_label_display}. Y-axis values scaled for plot.")
|
||||
else:
|
||||
print(f"Plotting with Max per-query probability: {max_probability:.2e}. Plotting direct probabilities without label scaling (exponent {potential_exponent} is within no-scale range [-3, -1]).")
|
||||
|
||||
elif pd.notna(max_probability) and max_probability == 0:
|
||||
print("Max per-query probability is 0. Plotting direct probabilities (all zeros).")
|
||||
else:
|
||||
print(f"Max per-query probability is NaN or invalid ({max_probability}). Plotting direct probabilities without scaling if possible.")
|
||||
|
||||
ax.bar(
|
||||
df_plot_data_plotting['degree_bin_label'],
|
||||
y_axis_values_to_plot,
|
||||
color='white',
|
||||
edgecolor=edgecolors_ref[0],
|
||||
linewidth=1.5,
|
||||
width=0.8
|
||||
)
|
||||
|
||||
ax.set_xlabel('Node Degree', fontsize=10.5, labelpad=6)
|
||||
# MODIFIED LINE: Added labelpad to move the y-axis label to the left
|
||||
ax.set_ylabel(y_axis_label, fontsize=10.5, labelpad=10)
|
||||
|
||||
ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, pos: f"{x:.0f}%"))
|
||||
|
||||
num_bins = len(df_plot_data_plotting)
|
||||
if num_bins > 12:
|
||||
ax.set_xticks(ax.get_xticks())
|
||||
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right", fontsize=9)
|
||||
elif num_bins > 8:
|
||||
ax.tick_params(axis='x', labelsize=9)
|
||||
else:
|
||||
ax.tick_params(axis='x', labelsize=10)
|
||||
|
||||
ax.tick_params(axis='y', labelsize=10)
|
||||
|
||||
padding_factor = 0.05
|
||||
current_max_y_on_axis = y_axis_values_to_plot.max()
|
||||
|
||||
upper_y_limit = 0.1 # Default small upper limit
|
||||
if pd.notna(current_max_y_on_axis):
|
||||
if current_max_y_on_axis > 0:
|
||||
# Adjust minimum visible range based on whether scaling was applied and the exponent
|
||||
min_meaningful_limit = 0.01
|
||||
if apply_scaling_to_label_and_values and exponent_for_label_display >= 0 : # Numbers on axis are smaller due to positive exponent scaling
|
||||
min_meaningful_limit = 0.1 # If original numbers were e.g. 2500 (2.5 x 10^3), scaled axis is 2.5, 0.1 is fine
|
||||
elif not apply_scaling_to_label_and_values and pd.notna(max_probability) and max_probability >=1: # Direct large probabilities
|
||||
min_meaningful_limit = 1 # If max prob is 2.5 (250%), axis value 2.5, needs larger base limit
|
||||
|
||||
upper_y_limit = max(min_meaningful_limit, current_max_y_on_axis * (1 + padding_factor))
|
||||
|
||||
else: # current_max_y_on_axis is 0
|
||||
upper_y_limit = 0.1
|
||||
ax.set_ylim(0, upper_y_limit)
|
||||
else:
|
||||
ax.set_ylim(0, 1.0) # Default for empty or NaN data
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(binned_output_image_file, bbox_inches="tight", dpi=300)
|
||||
print(f"Binned bar chart saved to {binned_output_image_file}")
|
||||
plt.show()
|
||||
plt.close(fig)
|
||||
else:
|
||||
if df_plot_data is None:
|
||||
print("Data for plotting (df_plot_data) is None. Skipping plot generation.")
|
||||
elif df_plot_data.empty:
|
||||
print("Data for plotting (df_plot_data) is empty. Skipping plot generation.")
|
||||
elif 'average_visit_count_per_node_in_bin' not in df_plot_data.columns:
|
||||
print("Essential column 'average_visit_count_per_node_in_bin' is missing in df_plot_data. Skipping plot generation.")
|
||||
|
||||
# %%
|
||||
print("Script finished.")
|
||||
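For intuition about the axis-scaling branch in analyze_visits.py above, a small worked example with assumed numbers (not output of the script):
import math
avg_visits = 0.37                   # average visits per node in a bin over NUMBER_OF_QUERIES = 1000 queries
prob = avg_visits / 1000.0          # 3.7e-4 per-query visit probability
exp = math.floor(math.log10(prob))  # -4, so scaling applies and the label becomes "Visit Probability (x 10^-4)"
print(prob / 10 ** exp)             # 3.7 is the bar height actually drawn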
research/paper_plot/b.md (new file, 7 lines)
@@ -0,0 +1,7 @@
In this paper, we present LiteANN, a storage-efficient approximate nearest neighbor (ANN) search index optimized for resource-constrained personal devices. LiteANN combines a compact graph-based structure with an efficient on-the-fly recomputation strategy to enable fast and accurate retrieval with minimal storage overhead. Our evaluation shows that LiteANN reduces index size to under 5% of the original raw data – up to 50× smaller than standard indexes – while achieving 90% top-3 recall in under 2 seconds on real-world question-answering benchmarks.
research/paper_plot/cache_degree_data.py (new file, 81 lines)
@@ -0,0 +1,81 @@
import numpy as np
import os

# --- Configuration for Data Paths and Labels (Mirrors plotting script for consistency) ---
BIG_GRAPH_PATHS = [
    "/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/hnsw/",
    "/opt/dlami/nvme/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki/1-shards/indices/99_4_degree_based_hnsw_IP_M32_efC256/",
    "/opt/dlami/nvme/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki/1-shards/indices/d9_hnsw_IP_M8_efC128/",
    "/opt/dlami/nvme/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki/1-shards/indices/half_edges_IP_M32_efC128/"
]
STATS_FILE_NAME = "degree_distribution.txt"
BIG_GRAPH_LABELS = [  # These will be used as keys in the cached file
    "HNSW-Base",
    "DegreeGuide",
    "HNSW-D9",
    "RandCut",
]
# Average degrees are static and can be used directly in the plotting script or also cached.
# For simplicity here, we focus on caching the dynamic degree arrays.
# BIG_GRAPH_AVG_DEG = [18, 9, 9, 9]

# --- Cache File Configuration ---
DATA_CACHE_DIR = "./paper_plot/data/"
CACHE_FILE_NAME = "big_graph_degree_data.npz"  # Using .npz for multiple arrays


def create_degree_data_cache():
    """
    Reads degree distribution data from the specified text files and saves it
    into a compressed NumPy (.npz) cache file.
    """
    os.makedirs(DATA_CACHE_DIR, exist_ok=True)
    cache_file_path = os.path.join(DATA_CACHE_DIR, CACHE_FILE_NAME)

    cached_data = {}
    print(f"Starting data caching process for {len(BIG_GRAPH_PATHS)} graph types...")

    for i, base_path in enumerate(BIG_GRAPH_PATHS):
        method_label = BIG_GRAPH_LABELS[i]
        degree_file_path = os.path.join(base_path, STATS_FILE_NAME)

        print(f"Processing: {method_label} from {degree_file_path}")

        try:
            # Load degrees as integers
            degrees = np.loadtxt(degree_file_path, dtype=int)

            if degrees.size == 0:
                print(f" [WARN] Degree file is empty: {degree_file_path}. Storing as empty array for {method_label}.")
                # Store an empty array; for .npz an empty array is fine.
                cached_data[method_label] = np.array([], dtype=int)
            else:
                # Store the loaded degrees array with the method label as the key
                cached_data[method_label] = degrees
                print(f" [INFO] Loaded {len(degrees)} degrees for {method_label}. Max degree: {np.max(degrees)}")

        except FileNotFoundError:
            print(f" [ERROR] Degree file not found: {degree_file_path}. Skipping {method_label}.")
            # Storing None would require special handling when loading; an empty array is safer for np.load.
            cached_data[method_label] = np.array([], dtype=int)  # Store empty array if file not found
        except Exception as e:
            print(f" [ERROR] An error occurred loading {degree_file_path} for {method_label}: {e}")
            cached_data[method_label] = np.array([], dtype=int)  # Store empty array on other errors

    if not cached_data:
        print("[ERROR] No data was successfully processed or loaded. Cache file will not be created.")
        return

    try:
        # Save all collected degree arrays into a single .npz file.
        # Using savez_compressed for a potentially smaller file size.
        np.savez_compressed(cache_file_path, **cached_data)
        print(f"\n[SUCCESS] Degree distribution data successfully cached to: {os.path.abspath(cache_file_path)}")
        print("Cached arrays (keys):", list(cached_data.keys()))
    except Exception as e:
        print(f"\n[ERROR] Failed to save data to cache file {cache_file_path}: {e}")


if __name__ == "__main__":
    print("--- Degree Distribution Data Caching Script ---")
    create_degree_data_cache()
    print("--- Caching script finished. ---")
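For reference, a minimal sketch of reading the cache back, mirroring what graph_dist.py does (the keys are the BIG_GRAPH_LABELS entries):
import numpy as np
with np.load("./paper_plot/data/big_graph_degree_data.npz") as cached:
    degrees = cached["HNSW-Base"]  # one integer degree array per graph label
print(len(degrees), degrees.max() if degrees.size else "empty")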
research/paper_plot/data/acc.csv (new file, 4 lines)
@@ -0,0 +1,4 @@
Model,NQ Exact Match,NQ F1,TriviaQA Exact Match,TriviaQA F1,GPQA Exact Match,GPQA F1,HotpotQA Exact Match,HotpotQA F1
BM25,0.192,0.277,0.406,0.474,0.020089,0.04524,0.162,0.239
PQ 5,0.2075,0.291,0.422,0.495,0.0201,0.0445,0.148,0.219
Ours,0.265,0.361,0.533,0.604,0.02008,0.0452,0.182,0.2729
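The rows above are indexed 0-2 (BM25, PQ 5, Ours) by acc_fig.py, which converts the fractions to percentages; a small usage sketch:
import pandas as pd
acc = pd.read_csv("./paper_plot/data/acc.csv")
print(acc.loc[2, "NQ Exact Match"] * 100)  # 26.5, the "Ours" NQ Exact Match bar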
research/paper_plot/data/big_graph_degree_data.npz (new file, Git LFS pointer, 3 lines)
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1296720e79196bbdf38f051043c1b054667803726a24036c0b6a87cedb204ea5
size 227482438
research/paper_plot/data/branches.csv (new file, 21 lines)
@@ -0,0 +1,21 @@
2,1,512,1024,0.541,0.326,1.659509202
2,2,512,1024,0.979,0.621,1.576489533
2,4,512,1024,1.846,0.977,1.889457523
2,8,512,1024,3.575,1.943,1.83993824
2,16,512,1024,7.035,3.733,1.884543263
2,32,512,1024,15.655,8.517,1.838088529
2,64,512,1024,32.772,17.43,1.88020654
4,1,512,1024,2.675,1.38,1.938405797
4,2,512,1024,5.397,2.339,2.307396323
4,4,512,1024,10.672,4.944,2.158576052
4,8,512,1024,21.061,9.266,2.272933305
4,16,512,1024,46.332,18.334,2.527108105
4,32,512,1024,99.607,36.156,2.754923111
4,64,512,1024,186.348,72.356,2.575432583
8,1,512,1024,7.325,4.087,1.792268167
8,2,512,1024,14.109,7.491,1.883460152
8,4,512,1024,28.499,14.013,2.033754371
8,8,512,1024,65.222,27.453,2.375769497
8,16,512,1024,146.294,52.55,2.783901047
8,32,512,1024,277.099,103.61,2.674442621
8,64,512,1024,512.979,208.36,2.461984066
research/paper_plot/data/latency_ablation.csv (new file, 9 lines)
@@ -0,0 +1,9 @@
Dataset,Metric,Original,original + batch,original + two_level,original + two_level + batch
NQ,Latency,6.9,5.8,4.2,3.7
NQ,SpeedUp,1,1.18965517,1.64285714,1.86486486
TriviaQA,Latency,17.054,14.542,12.046,10.83
TriviaQA,SpeedUp,1,1.17274103,1.41573967,1.57469990
GPQA,Latency,9.164,7.639,6.798,5.77
GPQA,SpeedUp,1,1.19963346,1.34804354,1.58821490
HotpotQA,Latency,60.279,39.827,50.664,29.868
HotpotQA,SpeedUp,1,1.51352098,1.18977972,2.01817999
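The SpeedUp rows above are the Original latency divided by each variant's latency; for example, for NQ:
print(6.9 / 3.7)  # 1.8648..., matching the "original + two_level + batch" SpeedUp of 1.86486486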
research/paper_plot/data/main_latency.csv (new file, 25 lines)
@@ -0,0 +1,25 @@
Dataset,Hardware,Recall_target,HNSW,IVF,DiskANN,IVF-Disk,IVF-Recompute,Our,BM25,LLM_Gen_Time_1B,LLM_Gen_Time_3B,LLM_Gen_Time_7B
NQ,A10,85%,0.046,1.656,0.017,2.996,482.53,3.323,0.021,0.085,0.217,0.472
NQ,A10,90%,0.051,2.552,0.028,3.437,769.04,4.616,0,0.085,0.217,0.472
NQ,A10,95%,0.055,5.163,0.070,5.602,1436.26,19.494,0,0.085,0.217,0.472
NQ,MAC,85%,0,0,0.152,2.199,1535.10,7.971,0.033,0.316,0.717,1.468
NQ,MAC,90%,0,0,0.37,2.936,2446.60,13.843,0,0.316,0.717,1.468
NQ,MAC,95%,0,0,1.207,4.191,4569.29,44.363,0,0.316,0.717,1.468
TriviaQA,A10,85%,0.042,1.772,0.032,2.464,560.5,3.752,0.033,0.139,0.156,0.315
TriviaQA,A10,90%,0.043,3.541,0.057,3.651,997.81,5.777,0,0.139,0.156,0.315
TriviaQA,A10,95%,0.053,7.168,0.090,5.458,2005.33,20.944,0,0.139,0.156,0.315
TriviaQA,MAC,85%,0,0,0.481,1.875,1783.14787,8.889,0.036,0.325,0.692,1.415
TriviaQA,MAC,90%,0,0,0.984,2.639,3174.410301,17.145,0,0.325,0.692,1.415
TriviaQA,MAC,95%,0,0,1.578,3.884,6379.712245,47.909,0,0.325,0.692,1.415
GPQA,A10,85%,0.041,0.134,0.024,0.048,40.16,1.897,0.137,0.443,0.396,0.651
GPQA,A10,90%,0.042,0.174,0.034,0.06,54.71,1.733,0,0.443,0.396,0.651
GPQA,A10,95%,0.045,0.292,0.051,0.11,97.67,4.033,0,0.443,0.396,0.651
GPQA,MAC,85%,0,0,0.144,0.087,127.7707505,4.762,0.100,0.37,0.813,1.676
GPQA,MAC,90%,0,0,0.288,0.108,174.0647409,5.223,0,0.37,0.813,1.676
GPQA,MAC,95%,0,0,0.497,0.132,310.7380142,9.715,0,0.37,0.813,1.676
HotpotQA,A10,85%,0.044,2.519,0.054,4.048,724.26,10.358,0.70,0.144,0.196,0.420
HotpotQA,A10,90%,0.049,3.867,0.109,5.045,1173.67,15.515,0,0.144,0.196,0.420
HotpotQA,A10,95%,0.07,10.928,0.412,8.659,3079.57,61.757,0,0.144,0.196,0.420
HotpotQA,MAC,85%,0,0,0.974,2.844,2304.125187,23.636,0.052,0.144,0.196,0.420
HotpotQA,MAC,90%,0,0,1.913,3.542,3415.736201,44.803,0,0.144,0.196,0.420
HotpotQA,MAC,95%,0,0,5.783,6.764,9797.244043,140.62,0,0.144,0.196,0.420
research/paper_plot/data/main_latency_small.csv (new file, 25 lines)
@@ -0,0 +1,25 @@
Dataset,Hardware,Recall_target,HNSW,IVF,DiskANN,IVF-Disk,IVF-Recompute,Our,
NQ,A10,85%,0.046,1.656,0.017,2.996,482.53,4.243,
NQ,A10,90%,0.051,2.552,0.028,3.437,769.04,8.136,
NQ,A10,95%,0.055,5.163,0.070,5.602,1436.26,27.275,
NQ,MAC,85%,0,0,0.152,2.199,1535.10,10.672,
NQ,MAC,90%,0,0,0.37,2.936,2446.60,19.941,
NQ,MAC,95%,0,0,1.207,4.191,4569.29,61.383,
TriviaQA,A10,85%,0.042,1.772,0.032,2.464,560.5,5.612,
TriviaQA,A10,90%,0.043,3.541,0.057,3.651,997.81,10.737,
TriviaQA,A10,95%,0.053,7.168,0.090,5.458,2005.33,36.387,
TriviaQA,MAC,85%,0,0,0.481,1.875,1783.14787,12.825,
TriviaQA,MAC,90%,0,0,0.984,2.639,3174.410301,24.977,
TriviaQA,MAC,95%,0,0,1.578,3.884,6379.712245,85.734,
GPQA,A10,85%,0.041,0.134,0.024,0.048,40.16,2.269,
GPQA,A10,90%,0.042,0.174,0.034,0.06,54.71,3.200,
GPQA,A10,95%,0.045,0.292,0.051,0.11,97.67,7.445,
GPQA,MAC,85%,0,0,0.144,0.087,127.7707505,6.123,
GPQA,MAC,90%,0,0,0.288,0.108,174.0647409,8.507,
GPQA,MAC,95%,0,0,0.497,0.132,310.7380142,19.577,
HotpotQA,A10,85%,0.044,2.519,0.054,4.048,724.26,14.713,
HotpotQA,A10,90%,0.049,3.867,0.109,5.045,1173.67,33.561,
HotpotQA,A10,95%,0.07,10.928,0.412,8.659,3079.57,68.626,
HotpotQA,MAC,85%,0,0,0.974,2.844,2304.125187,34.783,
HotpotQA,MAC,90%,0,0,1.913,3.542,3415.736201,53.004,
HotpotQA,MAC,95%,0,0,5.783,6.764,9797.244043,95.413,
research/paper_plot/data/ram_storage.csv (new file, 3 lines)
@@ -0,0 +1,3 @@
Hardware,HNSW,IVF,DiskANN,IVF-Disk,IVF-Recompute,Our,BM25
RAM,190,171,10,0,0,0,0
Storage,185.4,171,240,171,0.5,5,59
research/paper_plot/data/swithc_e2e.csv (new file, 12 lines)
@@ -0,0 +1,12 @@
Torch,8,55.592
Torch,16,75.439
Torch,32,110.025
Torch,64,186.496
Tutel,8,56.718
Tutel,16,82.121
Tutel,32,125.070
Tutel,64,216.191
BRT,8,56.725
BRT,16,79.291
BRT,32,93.180
BRT,64,118.923
research/paper_plot/data/vary_cache.csv (new file, 6 lines)
@@ -0,0 +1,6 @@
Disk cache size,0,2.5%(180G*2.5%),5%,8%,10%
Latency,,,,,
NQ,4.616,4.133,3.826,3.511,3.323
TriviaQA,5.777,4.979,4.553,4.141,3.916
GPQA,1.733,1.593,1.468,1.336,1.259
Hotpot,15.515,13.479,12.383,11.216,10.606
research/paper_plot/disk_cache.py (new file, 151 lines)
@@ -0,0 +1,151 @@
import matplotlib
from matplotlib.axes import Axes
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.lines import Line2D

# plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["ytick.direction"] = "in"
plt.rcParams["hatch.linewidth"] = 1
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
plt.rcParams["text.usetex"] = True
plt.rcParams["font.family"] = "sans-serif"  # Use generic sans-serif family
plt.rcParams['text.latex.preamble'] = r"""
\usepackage{helvet} % Use Helvetica font for text
\usepackage{sfmath} % Use sans-serif font for math
\renewcommand{\familydefault}{\sfdefault} % Set sans-serif as default text font
\usepackage[T1]{fontenc} % Recommended for font encoding
"""
# plt.rcParams['mathtext.fontset'] = 'dejavusans'
SAVE_PTH = "./paper_plot/figures"
font_size = 16

# New data in dictionary format
datasets = ["NQ", "TriviaQA", "GPQA", "Hotpot"]

cache_ratios = ["4.2G\n (0\%)", "8.7G\n (2.5\%)", "13.2G\n (5\%)", "18.6G\n (8\%)", "22.2G\n (10\%)"]
latency_data = {
    "NQ": [4.616, 4.133, 3.826, 3.511, 3.323],
    "TriviaQA": [5.777, 4.979, 4.553, 4.141, 3.916],
    "GPQA": [1.733, 1.593, 1.468, 1.336, 1.259],
    "Hotpot": [15.515, 13.479, 12.383, 11.216, 10.606],
}
cache_hit_counts = {
    "NQ": [0, 14.81, 23.36, 31.99, 36.73],
    "TriviaQA": [0, 18.55, 27.99, 37.06, 41.86],
    "GPQA": [0, 10.99, 20.31, 29.71, 35.01],
    "Hotpot": [0, 17.47, 26.91, 36.2, 41.06]
}

# Create the figure with 4 subplots in a 2x2 grid
fig, axes_grid = plt.subplots(2, 2, figsize=(7, 6))
axes = axes_grid.flatten()  # Flatten the 2x2 grid to a 1D array

# Bar style settings
width = 0.7
x = np.arange(len(cache_ratios))

# Define hatch patterns for different cache ratios
hatch_patterns = ['//', '//', '//', '//', '//']

# Find max cache hit value across all datasets for a unified y-axis
all_hit_counts = []
for dataset in datasets:
    all_hit_counts.extend(cache_hit_counts[dataset])
max_unified_hit = max(all_hit_counts) * 1.13

for i, dataset in enumerate(datasets):
    latencies = latency_data[dataset]
    hit_counts = cache_hit_counts[dataset]

    for j, val in enumerate(latencies):
        container = axes[i].bar(
            x[j],
            val,
            width=width,
            color="white",
            edgecolor="black",
            linewidth=1.0,
            zorder=10,
        )
        axes[i].bar_label(
            container,
            [f"{val:.2f}"],
            fontsize=10,
            zorder=200,
            fontweight="bold",
        )

    axes[i].set_title(dataset, fontsize=font_size)
    axes[i].set_xticks(x)
    axes[i].set_xticklabels(cache_ratios, fontsize=12, rotation=0, ha='center', fontweight="bold")

    max_val_ratios = [1.35, 1.65, 1.45, 1.75]
    max_val = max(latencies) * max_val_ratios[i]
    axes[i].set_ylim(0, max_val)
    axes[i].tick_params(axis='y', labelsize=12)

    if i % 2 == 0:
        axes[i].set_ylabel("Latency (s)", fontsize=font_size)
    axes[i].yaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter('%.1f'))

    ax2: Axes = axes[i].twinx()
    ax2.plot(x, hit_counts,
             linestyle='--',
             marker='o',
             markersize=6,
             linewidth=1.5,
             color='k',
             markerfacecolor='none',
             zorder=20)

    ax2.set_ylim(0, max_unified_hit)
    ax2.tick_params(axis='y', labelsize=12)
    if i % 2 == 1:
        ax2.set_ylabel(r"Cache Hit (\%)", fontsize=font_size)

    for j, val in enumerate(hit_counts):
        if val > 0:
            ax2.annotate(f"{val:.1f}%",
                         (x[j], val),
                         textcoords="offset points",
                         xytext=(0, 5),
                         ha='center',
                         va='bottom',
                         fontsize=10,
                         fontweight='bold')

# Create legend for both plots
bar_patch = mpatches.Patch(facecolor='white', edgecolor='black', label='Latency')
line_patch = Line2D([0], [0], color='black', linestyle='--', label='Cache Hit Rate')

# --- MODIFICATION FOR LEGEND AT THE TOP ---
fig.legend(handles=[bar_patch, line_patch],
           loc='upper center',  # Position the legend at the upper center
           bbox_to_anchor=(0.5, 0.995),  # Anchor point (0.5 = horizontal center of the figure, 0.995 = just below the top edge)
           ncol=3,
           fontsize=font_size - 2)
# --- END OF MODIFICATION ---

# Set common x-axis label - add this back if needed
# fig.text(0.5, 0.02, "Disk Cache Size", ha='center', fontsize=font_size, fontweight='bold')

# --- MODIFICATION FOR TIGHT LAYOUT ---
# Adjust rect to make space for the legend at the top: (left, bottom, right, top_for_subplots).
# Subplots occupy the figure up to y=0.93, leaving the top portion (0.93 to 1.0) for the legend.
plt.tight_layout(rect=(0, 0, 1, 0.93))  # Ensure subplots are below the legend
# --- END OF MODIFICATION ---

# Create the figures directory if it doesn't exist
import os
if not os.path.exists(SAVE_PTH):
    os.makedirs(SAVE_PTH)

plt.savefig(f"{SAVE_PTH}/disk_cache_latency.pdf", dpi=300)
print(f"Save to {SAVE_PTH}/disk_cache_latency.pdf")
# plt.show()  # Optional: to display the plot
New binary files (contents not shown):
BIN  research/paper_plot/figures/H_hnsw_performance_comparison.pdf
BIN  research/paper_plot/figures/H_hnsw_performance_comparison.png (130 KiB)
BIN  research/paper_plot/figures/H_hnsw_recall_comparison.pdf
BIN  research/paper_plot/figures/H_hnsw_recall_comparison.png (100 KiB)
BIN  research/paper_plot/figures/accuracy_comparison.pdf
BIN  research/paper_plot/figures/degree_distribution.pdf
BIN  research/paper_plot/figures/degree_distribution_small.pdf
BIN  research/paper_plot/figures/disk_cache_latency.pdf
BIN  research/paper_plot/figures/figure15.pdf
BIN  research/paper_plot/figures/gpu_throughput_vs_batch_size.pdf
BIN  (unnamed image file, 41 KiB; path not captured)
BIN  research/paper_plot/figures/latency_speedup.pdf
BIN  research/paper_plot/figures/main_exp_fig_1.pdf
BIN  research/paper_plot/figures/main_exp_fig_2.pdf
BIN  research/paper_plot/figures/plot1_em_f1.pdf
BIN  research/paper_plot/figures/plot2_latency.pdf
BIN  research/paper_plot/figures/ram_storage_double_column.pdf
BIN  research/paper_plot/figures/sparse_a2a_branches.pdf
BIN  research/paper_plot/figures/speed_A10_revised.pdf
BIN  research/paper_plot/figures/speed_MAC_revised.pdf
research/paper_plot/gpu_under.py (new file, 107 lines)
@@ -0,0 +1,107 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Motto: Were It to Benefit My Country, I Would Lay Down My Life!
# \file: /gpu_utilization_plot.py
# \brief: Plots GPU throughput vs. batch size to show utilization with an equally spaced x-axis.
# Author: AI Assistant

import numpy as np
import pandas as pd  # Using pandas for data structuring, similar to the example
from matplotlib import pyplot as plt

# Apply styling similar to the example script
plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["ytick.direction"] = "in"
plt.rcParams["xtick.direction"] = "in"
# plt.rcParams["hatch.linewidth"] = 1.5  # Not used for line plots
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
plt.rcParams["text.usetex"] = True  # Enables LaTeX for text rendering

# New benchmark data (4th set)
data = {
    'batch_size': [1, 4, 8, 10, 16, 20, 32, 40, 64, 128, 256],
    'avg_time_s': [
        0.0031, 0.0057, 0.0100, 0.0114, 0.0186, 0.0234,
        0.0359, 0.0422, 0.0626, 0.1259, 0.2454,
    ],
    'throughput_seq_s': [
        318.10, 696.77, 798.95, 874.70, 859.58, 855.19,
        890.80, 946.93, 1022.75, 1017.03, 1043.17,
    ]
}
benchmark_df = pd.DataFrame(data)

# Create the plot
# Width increased slightly to fit more x-axis labels
fig, ax = plt.subplots()
fig.set_size_inches(8, 5)

# Generate equally spaced x-coordinates (indices)
x_indices = np.arange(len(benchmark_df))

# Plot throughput vs. batch size (using indices for the x-axis)
ax.plot(
    x_indices,  # Use equally spaced indices for plotting
    benchmark_df['throughput_seq_s'],
    marker='o',  # Add markers to data points
    linestyle='-',
    color="#63B8B6",  # A color inspired by the example's 'edgecolors'
    linewidth=2,
    markersize=6,
    # label="Model Throughput"  # Label for legend if needed; legend not shown by default
)

# Set axis labels
ax.set_xlabel("Batch Size", fontsize=14)
ax.set_ylabel("Throughput (sequences/second)", fontsize=14)

# Customize the y-axis for the new data range:
# start the y-axis at 200 so all points (min ~318 seq/s) stay visible while the curve fills the frame.
y_min_val = 200
# Round y_max_val up to the nearest 100, as max throughput > 1000
y_max_val = np.ceil(benchmark_df['throughput_seq_s'].max() / 100) * 100
ax.set_ylim((y_min_val, y_max_val))
# Set y-ticks every 100 units, ensuring the top tick is included.
ax.set_yticks(np.arange(y_min_val, y_max_val + 1, 100))

# Customize the x-axis for equally spaced ticks:
# Set tick positions to the indices
ax.set_xticks(x_indices)
# Set tick labels to the actual batch_size values
ax.set_xticklabels(benchmark_df['batch_size'])
ax.tick_params(axis='x', rotation=45, labelsize=10)  # Rotate x-axis labels
ax.tick_params(axis='y', labelsize=12)

# Add a light grid for better readability, common in academic plots
ax.grid(True, linestyle=':', linewidth=0.5, color='grey', alpha=0.7, zorder=0)

# No title (as requested)
# ax.set_title("GPU Throughput vs. Batch Size", fontsize=16)

# Optional: add a legend to label the single line
# ax.legend(
#     loc="center right",  # Location might need adjustment due to the data shape
#     edgecolor="black",
#     facecolor="white",
#     framealpha=1.0,
#     shadow=False,
#     fancybox=False,
#     prop={"weight": "bold", "size": 10}
# ).set_zorder(100)

# Adjust layout to prevent labels from being cut off
plt.tight_layout()

# Save the figure
output_filename = "./paper_plot/figures/gpu_throughput_vs_batch_size_equispaced.pdf"
plt.savefig(output_filename, bbox_inches="tight", dpi=300)
print(f"Plot saved to {output_filename}")

# Display the plot (optional, depending on environment)
plt.show()

# %%
# This cell only mimics the '%%' cell structure from the example; no code is needed here.
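As a sanity check derived from the table in gpu_under.py itself (a sketch, not part of the committed script), throughput_seq_s is approximately batch_size / avg_time_s:
for bs, t, thr in [(8, 0.0100, 798.95), (64, 0.0626, 1022.75), (256, 0.2454, 1043.17)]:
    print(round(bs / t, 1), thr)  # 800.0 vs 798.95, 1022.4 vs 1022.75, 1043.2 vs 1043.17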
245
research/paper_plot/graph_dist.py
Normal file
245
research/paper_plot/graph_dist.py
Normal file
@@ -0,0 +1,245 @@
|
||||
import argparse
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import os
|
||||
import matplotlib.ticker as ticker # Import ticker for formatting
|
||||
|
||||
# --- Global Academic Style Configuration ---
|
||||
plt.rcParams["font.family"] = "Helvetica"
|
||||
plt.rcParams["font.weight"] = "bold"
|
||||
plt.rcParams["axes.labelweight"] = "bold"
|
||||
plt.rcParams["axes.titleweight"] = "bold"
|
||||
|
||||
plt.rcParams["ytick.direction"] = "out"
|
||||
plt.rcParams["xtick.direction"] = "out"
|
||||
|
||||
plt.rcParams["axes.grid"] = False # Grid lines are off
|
||||
|
||||
plt.rcParams["text.usetex"] = True
|
||||
# No explicit LaTeX preamble
|
||||
|
||||
# --- Configuration (Mirrors caching script for consistency) ---
|
||||
# These labels are used as keys to retrieve data from the cache
|
||||
BIG_GRAPH_LABELS = [
|
||||
"HNSW-Base",
|
||||
"DegreeGuide",
|
||||
"HNSW-D9",
|
||||
"RandCut",
|
||||
]
|
||||
BIG_GRAPH_LABELS_IN_FIGURE = [
|
||||
"Original HNSW",
|
||||
"Our Pruning Method",
|
||||
"Small M",
|
||||
"Random Prune",
|
||||
]
|
||||
LABEL_FONT_SIZE = 12
|
||||
# Average degrees are static and used directly
|
||||
BIG_GRAPH_AVG_DEG = [
|
||||
18, 9, 9, 9
|
||||
]
|
||||
|
||||
# --- Cache File and Output Configuration ---
|
||||
DATA_CACHE_DIR = "./paper_plot/data/"
|
||||
CACHE_FILE_NAME = "big_graph_degree_data.npz"
|
||||
OUTPUT_DIR = "./paper_plot/figures/"
|
||||
os.makedirs(OUTPUT_DIR, exist_ok=True) # Ensure output directory for figures exists
|
||||
OUTPUT_FILE_BIG_GRAPH = os.path.join(OUTPUT_DIR, "degree_distribution.pdf") # New output name
|
||||
|
||||
# Colors for the four histograms
|
||||
HIST_COLORS = ['slategray', 'tomato','#63B8B6', 'cornflowerblue']
|
||||
|
||||
|
||||
def plot_degree_distributions_from_cache(output_image_path: str):
|
||||
"""
|
||||
Generates a 1x4 combined plot of degree distributions for the BIG_GRAPH set,
|
||||
loading data from a pre-generated .npz cache file.
|
||||
"""
|
||||
cache_file_path = os.path.join(DATA_CACHE_DIR, CACHE_FILE_NAME)
|
||||
|
||||
if not os.path.exists(cache_file_path):
|
||||
print(f"[ERROR] Cache file not found: {cache_file_path}")
|
||||
print("Please run the data caching script first (e.g., cache_degree_data.py).")
|
||||
return
|
||||
|
||||
try:
|
||||
# Load the cached data
|
||||
with np.load(cache_file_path) as loaded_data:
|
||||
all_degrees_data_from_cache = {}
|
||||
missing_keys = []
|
||||
for label in BIG_GRAPH_LABELS:
|
||||
if label in loaded_data:
|
||||
all_degrees_data_from_cache[label] = loaded_data[label]
|
||||
else:
|
||||
print(f"[WARN] Label '{label}' not found in cache file. Plotting may be incomplete.")
|
||||
all_degrees_data_from_cache[label] = np.array([], dtype=int) # Use empty array for missing data
|
||||
missing_keys.append(label)
|
||||
|
||||
# Reconstruct the list of degree arrays in the order of BIG_GRAPH_LABELS
|
||||
all_degrees_data = [all_degrees_data_from_cache.get(label, np.array([], dtype=int)) for label in BIG_GRAPH_LABELS]
|
||||
|
||||
print(f"[INFO] Successfully loaded data from cache: {cache_file_path}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"[ERROR] Failed to load or process data from cache file {cache_file_path}: {e}")
|
||||
return
|
||||
|
||||
try:
|
||||
fig, axes = plt.subplots(2, 2, figsize=(7, 4), sharex=True, sharey=True)
|
||||
axes = axes.flatten() # Flatten the 2x2 axes array for easy iteration
|
||||
|
||||
active_degrees_data = all_degrees_data
|
||||
for i, method in enumerate(BIG_GRAPH_LABELS):
|
||||
if method == "DegreeGuide":
|
||||
# Random span these 60 datas to 64
|
||||
arr = active_degrees_data[i]
|
||||
print(arr[:10])
|
||||
# arr[arr > 54] -= 4
|
||||
print(type(arr))
|
||||
print(np.max(arr))
|
||||
arr2 = arr * 60 / 64
|
||||
# print(np.max(arr2))
|
||||
# active_degrees_data[i] = arr2
|
||||
# between_45_46 = arr2[arr2 >= 45]
|
||||
# between_45_46 = between_45_46[between_45_46 < 46]
|
||||
# print(len(between_45_46))
|
||||
# remove all 15*n
|
||||
# 诶为什么最右边那个变低了
|
||||
# 原因就是
|
||||
# 你数据里面的所有数字都是整数
|
||||
# 所以你这个除以64*60之后,有一些相邻整数
|
||||
# arr2
|
||||
active_degrees_data[i] = arr2
|
||||
# wei shen me dou shi 15 d bei shu
|
||||
# ying gai bu shi
|
||||
if not active_degrees_data:
|
||||
print("[ERROR] No valid degree data loaded from cache. Cannot generate plot.")
|
||||
if 'fig' in locals() and plt.fignum_exists(fig.number):
|
||||
plt.close(fig)
|
||||
return
|
||||
|
||||
overall_min_deg = min(np.min(d) for d in active_degrees_data)
|
||||
overall_max_deg = max(np.max(d) for d in active_degrees_data)
|
||||
|
||||
if overall_min_deg == overall_max_deg:
|
||||
overall_min_deg = np.floor(overall_min_deg - 0.5)
|
||||
overall_max_deg = np.ceil(overall_max_deg + 0.5)
|
||||
else:
|
||||
overall_min_deg = np.floor(overall_min_deg - 0.5)
|
||||
overall_max_deg = np.ceil(overall_max_deg + 0.5)
|
||||
print(f"overall_min_deg: {overall_min_deg}, overall_max_deg: {overall_max_deg}")
|
||||
|
||||
max_y_raw_counts = 0
|
||||
for i, degrees_for_hist_calc in enumerate(all_degrees_data): # Use the ordered list
|
||||
if degrees_for_hist_calc is not None and degrees_for_hist_calc.size > 0:
|
||||
min_deg_local = np.min(degrees_for_hist_calc)
|
||||
max_deg_local = np.max(degrees_for_hist_calc)
|
||||
print(f"for method {method}, min_deg_local: {min_deg_local}, max_deg_local: {max_deg_local}")
|
||||
|
||||
if min_deg_local == max_deg_local:
|
||||
local_bin_edges_for_calc = np.array([np.floor(min_deg_local - 0.5), np.ceil(max_deg_local + 0.5)])
|
||||
else:
|
||||
num_local_bins_for_calc = int(np.ceil(max_deg_local + 0.5) - np.floor(min_deg_local - 0.5))
|
||||
local_bin_edges_for_calc = np.linspace(np.floor(min_deg_local - 0.5),
|
||||
np.ceil(max_deg_local + 0.5),
|
||||
num_local_bins_for_calc + 1)
|
||||
if i == 1:
|
||||
unique_data = np.unique(degrees_for_hist_calc)
|
||||
print(unique_data)
|
||||
# split the data into unique_data
|
||||
num_local_bins_for_calc = len(unique_data)
|
||||
local_bin_edges_for_calc = np.concatenate([unique_data-0.1, [np.inf]])
|
||||
|
||||
counts, _ = np.histogram(degrees_for_hist_calc, bins=local_bin_edges_for_calc)
|
||||
if counts.size > 0:
|
||||
max_y_raw_counts = max(max_y_raw_counts, np.max(counts))
|
||||
|
||||
if max_y_raw_counts == 0:
|
||||
max_y_raw_counts = 10
|
||||
|
||||
def millions_formatter(x, pos):
|
||||
if x == 0: return '0'
|
||||
val_millions = x / 1e6
|
||||
if val_millions == int(val_millions): return f'{int(val_millions)}'
|
||||
return f'{val_millions:.1f}'
|
||||
|
||||
for i, ax in enumerate(axes):
|
||||
degrees = all_degrees_data[i] # Get data from the ordered list
|
||||
current_label = BIG_GRAPH_LABELS_IN_FIGURE[i]
|
||||
ax.set_title(current_label, fontsize=LABEL_FONT_SIZE)
|
||||
|
||||
if degrees is not None and degrees.size > 0:
|
||||
min_deg_local_plot = np.min(degrees)
|
||||
max_deg_local_plot = np.max(degrees)
|
||||
|
||||
if min_deg_local_plot == max_deg_local_plot:
|
||||
plot_bin_edges = np.array([np.floor(min_deg_local_plot - 0.5), np.ceil(max_deg_local_plot + 0.5)])
|
||||
else:
|
||||
num_plot_bins = int(np.ceil(max_deg_local_plot + 0.5) - np.floor(min_deg_local_plot - 0.5))
|
||||
plot_bin_edges = np.linspace(np.floor(min_deg_local_plot - 0.5),
|
||||
np.ceil(max_deg_local_plot + 0.5),
|
||||
num_plot_bins + 1)
|
||||
if i == 1:
|
||||
unique_data = np.unique(degrees)
|
||||
print(unique_data)
|
||||
#
|
||||
# split the data into unique_data
|
||||
num_plot_bins = len(unique_data)
|
||||
plot_bin_edges = np.concatenate([unique_data-0.1, [unique_data[-1] + 0.8375]])
|
||||
|
||||
ax.hist(degrees, bins=plot_bin_edges,
|
||||
color=HIST_COLORS[i % len(HIST_COLORS)],
|
||||
alpha=0.85)
|
||||
|
||||
avg_deg_val = BIG_GRAPH_AVG_DEG[i]
|
||||
ax.text(0.95, 0.88, f"Avg Degree: {avg_deg_val}",
|
||||
transform=ax.transAxes, fontsize=15,
|
||||
verticalalignment='top', horizontalalignment='right',
|
||||
bbox=dict(facecolor='white', alpha=0.6, edgecolor='none', pad=0.3))
|
||||
else:
|
||||
ax.text(0.5, 0.5, 'Data unavailable', horizontalalignment='center',
|
||||
verticalalignment='center', transform=ax.transAxes, fontsize=9)
|
||||
|
||||
ax.set_xlim(0, overall_max_deg)
|
||||
ax.set_ylim(0, max_y_raw_counts * 1.12)
|
||||
ax.set_yscale('log')
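# Note: the lower y-limit of 0 set above cannot be shown on a log axis; matplotlib
# clips it to a positive value once the scale is switched to 'log'.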
|
||||
|
||||
for spine_pos in ['top', 'right', 'bottom', 'left']:
|
||||
ax.spines[spine_pos].set_edgecolor('black')
|
||||
ax.spines[spine_pos].set_linewidth(1.0)
|
||||
|
||||
# ax.spines['top'].set_visible(False)
|
||||
# ax.spines['right'].set_visible(False)
|
||||
|
||||
ax.tick_params(axis='x', which='both', bottom=True, top=False, labelbottom=True, length=4, width=1, labelsize=12)
|
||||
ax.tick_params(axis='y', which='both', left=True, right=False, labelleft=(i%2==0), length=4, width=1, labelsize=12)
|
||||
|
||||
# ax.yaxis.set_major_formatter(ticker.FuncFormatter(lambda x, pos: millions_formatter(x, pos)))
|
||||
|
||||
ax.xaxis.set_major_formatter(ticker.ScalarFormatter())
|
||||
ax.ticklabel_format(style='plain', axis='x', useOffset=False)
|
||||
|
||||
axes[0].set_ylabel(r"Number of Nodes", fontsize=12)
|
||||
axes[2].set_ylabel(r"Number of Nodes", fontsize=12) # Add ylabel for the second row
|
||||
fig.text(0.54, 0.02, "Node Degree", ha='center', va='bottom', fontsize=15)
|
||||
|
||||
plt.tight_layout(rect=(0.06, 0.05, 0.98, 0.88))
|
||||
|
||||
plt.savefig(output_image_path, dpi=300, bbox_inches='tight', pad_inches=0.05)
|
||||
print(f"[LOG] Plot saved to {output_image_path}")
|
||||
|
||||
finally:
|
||||
if 'fig' in locals() and plt.fignum_exists(fig.number):
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if plt.rcParams["text.usetex"]:
|
||||
print("INFO: LaTeX rendering is enabled via rcParams.")
|
||||
else:
|
||||
print("INFO: LaTeX rendering is disabled (text.usetex=False).")
|
||||
|
||||
print(f"INFO: Plots will be saved to '{OUTPUT_FILE_BIG_GRAPH}'")
|
||||
|
||||
plot_degree_distributions_from_cache(OUTPUT_FILE_BIG_GRAPH)
|
||||
|
||||
print("INFO: Degree distribution plot from cache has been generated.")
|
||||
330
research/paper_plot/graph_pruning_ablation.py
Normal file
@@ -0,0 +1,330 @@
|
||||
# python faiss/demo/plot_graph_struct.py faiss/demo/output.log
|
||||
# python faiss/demo/plot_graph_struct.py large_graph_recompute.log
|
||||
import argparse
|
||||
import re
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
# Modified recall_levels and corresponding styles/widths from previous step
|
||||
recall_levels = [0.90, 0.92, 0.94, 0.96]
|
||||
line_styles = ['--', '-', '-', '-']
|
||||
line_widths = [1, 1.5, 1.5, 1.5]
|
||||
|
||||
MAPPED_METHOD_NAMES = [
|
||||
# 'HNSW-Base',
|
||||
# 'DegreeGuide',
|
||||
# 'HNSW-D9',
|
||||
# 'RandCut',
|
||||
"Original HNSW",
|
||||
"Our Pruning Method",
|
||||
"Small M",
|
||||
"Random Prune",
|
||||
]
|
||||
|
||||
PERFORMANCE_PLOT_PATH = './paper_plot/figures/H_hnsw_performance_comparison.pdf'
|
||||
SAVED_PATH = './paper_plot/figures/H_hnsw_recall_comparison.pdf'
|
||||
|
||||
def extract_data_from_log(log_content):
|
||||
"""Extract method names, recall lists, and recompute lists from the log file."""
|
||||
|
||||
method_pattern = r"Building HNSW index with ([^\.]+)\.\.\.|Building HNSW index with ([^\n]+)..."
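# The first alternative captures a name terminated by a literal "..."; the second is a
# looser fallback whose trailing dots are unescaped and therefore match any three characters.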
|
||||
recall_list_pattern = r"recall_list: (\[[\d\., ]+\])"
|
||||
recompute_list_pattern = r"recompute_list: (\[[\d\., ]+\])"
|
||||
avg_neighbors_pattern = r"neighbors per node: ([\d\.]+)"
|
||||
|
||||
method_matches = re.findall(method_pattern, log_content)
|
||||
# Temporary list for raw method identifiers from regex
|
||||
_methods_raw_identifiers_regex = []
|
||||
for match in method_matches:
|
||||
method_ident = match[0] if match[0] else match[1]
|
||||
_methods_raw_identifiers_regex.append(method_ident.strip().rstrip('.'))
|
||||
|
||||
recall_lists_str = re.findall(recall_list_pattern, log_content)
|
||||
recompute_lists_str = re.findall(recompute_list_pattern, log_content)
|
||||
avg_neighbors_str_list = re.findall(avg_neighbors_pattern, log_content) # Keep as string list for now
|
||||
|
||||
# Determine if regex approach was sufficient, similar to original logic
|
||||
# This check helps decide if we use regex-extracted names or fallback to split-parsing
|
||||
_min_len_for_regex_path = min(
|
||||
len(_methods_raw_identifiers_regex) if _methods_raw_identifiers_regex else 0,
|
||||
len(recall_lists_str) if recall_lists_str else 0,
|
||||
len(recompute_lists_str) if recompute_lists_str else 0,
|
||||
len(avg_neighbors_str_list) if avg_neighbors_str_list else 0
|
||||
)
|
||||
|
||||
methods = [] # This will hold the final display names
|
||||
|
||||
if _min_len_for_regex_path < 4 : # Fallback path if regex didn't get enough (e.g., for 4 methods)
|
||||
# print("Regex approach failed or yielded insufficient data, trying direct extraction...")
|
||||
sections = log_content.split("Building HNSW index with ")[1:]
|
||||
methods_temp = []
|
||||
for section in sections:
|
||||
method_name_raw = section.split("\n")[0].strip().rstrip('.')
|
||||
# Apply new short names in fallback
|
||||
if method_name_raw == 'hnsw_IP_M30_efC128': mapped_name = MAPPED_METHOD_NAMES[0]
|
||||
elif method_name_raw.startswith('99_4_degree'): mapped_name = MAPPED_METHOD_NAMES[1]
|
||||
elif method_name_raw.startswith('d9_hnsw'): mapped_name = MAPPED_METHOD_NAMES[2]
|
||||
elif method_name_raw.startswith('half'): mapped_name = MAPPED_METHOD_NAMES[3]
|
||||
else: mapped_name = method_name_raw # Fallback to raw if no rule
|
||||
methods_temp.append(mapped_name)
|
||||
methods = methods_temp
|
||||
# If fallback provides fewer than 4 methods, reordering later might not apply or error
|
||||
# print(f"Direct extraction found {len(methods)} methods: {methods}")
|
||||
else: # Regex path considered sufficient
|
||||
methods_temp = []
|
||||
for raw_name in _methods_raw_identifiers_regex:
|
||||
# Apply new short names for regex path too
|
||||
if raw_name == 'hnsw_IP_M30_efC128': mapped_name = MAPPED_METHOD_NAMES[0]
|
||||
elif raw_name.startswith('99_4_degree'): mapped_name = MAPPED_METHOD_NAMES[1]
|
||||
elif raw_name.startswith('d9_hnsw'): mapped_name = MAPPED_METHOD_NAMES[2]
|
||||
elif raw_name.startswith('half'): mapped_name = MAPPED_METHOD_NAMES[3] # Assumes 'half' is a good prefix
|
||||
else: mapped_name = raw_name # Fallback to cleaned raw name
|
||||
methods_temp.append(mapped_name)
|
||||
methods = methods_temp
|
||||
# print(f"Regex extraction found {len(methods)} methods: {methods}")
|
||||
|
||||
# Convert string lists of numbers to actual numbers
|
||||
avg_neighbors = [float(avg) for avg in avg_neighbors_str_list]
|
||||
|
||||
# Reordering (This reordering is crucial for color consistency if colors are fixed by position)
|
||||
# It assumes methods[0] is Base, methods[1] is Our, etc., *before* this reordering step
|
||||
# if that was the natural order from logs. The reordering swaps 3rd and 4th items.
|
||||
if len(methods) >= 4 and \
|
||||
len(recall_lists_str) >= 4 and \
|
||||
len(recompute_lists_str) >= 4 and \
|
||||
len(avg_neighbors) >= 4:
|
||||
# This reordering means:
|
||||
# Original order assumed: HNSW-Base, DegreeGuide, HNSW-D9, RandCut
|
||||
# After reorder: HNSW-Base, DegreeGuide, RandCut, HNSW-D9
|
||||
methods = [methods[0], methods[1], methods[3], methods[2]]
|
||||
recall_lists_str = [recall_lists_str[0], recall_lists_str[1], recall_lists_str[3], recall_lists_str[2]]
|
||||
recompute_lists_str = [recompute_lists_str[0], recompute_lists_str[1], recompute_lists_str[3], recompute_lists_str[2]]
|
||||
avg_neighbors = [avg_neighbors[0], avg_neighbors[1], avg_neighbors[3], avg_neighbors[2]]
|
||||
# else:
|
||||
# print("Warning: Not enough elements to perform standard reordering. Using data as found.")
|
||||
|
||||
|
||||
if len(avg_neighbors) > 0 and avg_neighbors_str_list[0] == "17.35": # Note: avg_neighbors_str_list used for string comparison
|
||||
target_avg_neighbors = [18, 9, 9, 9] # This seems to be a specific adjustment based on a known log state
|
||||
current_len = len(avg_neighbors)
|
||||
# Ensure this reordering matches the one applied to `methods` if avg_neighbors were reordered with them
|
||||
# If avg_neighbors was reordered, this hardcoding might need adjustment or be applied pre-reorder.
|
||||
# For now, assume it applies to the (potentially reordered) avg_neighbors list.
|
||||
avg_neighbors = target_avg_neighbors[:current_len]
|
||||
|
||||
|
||||
recall_lists = [eval(recall_list) for recall_list in recall_lists_str]
|
||||
recompute_lists = [eval(recompute_list) for recompute_list in recompute_lists_str]
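# eval() assumes the log content is trusted and well-formed; ast.literal_eval would be a
# safer drop-in for parsing these bracketed lists of numbers.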
|
||||
|
||||
# Final truncation to ensure all lists have the same minimum length
|
||||
min_length = min(len(methods), len(recall_lists), len(recompute_lists), len(avg_neighbors))
|
||||
|
||||
methods = methods[:min_length]
|
||||
recall_lists = recall_lists[:min_length]
|
||||
recompute_lists = recompute_lists[:min_length]
|
||||
avg_neighbors = avg_neighbors[:min_length]
|
||||
|
||||
return methods, recall_lists, recompute_lists, avg_neighbors
|
||||
|
||||
|
||||
def plot_recall_comparison(methods, recall_lists, recompute_lists, avg_neighbors, current_recall_levels):
|
||||
"""Create a line chart comparing computation costs at different recall levels, with academic style."""
|
||||
plt.rcParams["font.family"] = "Helvetica"
|
||||
plt.rcParams["ytick.direction"] = "in"
|
||||
# plt.rcParams["hatch.linewidth"] = 1.5 # From example, but not used in line plot
|
||||
plt.rcParams["font.weight"] = "bold"
|
||||
plt.rcParams["axes.labelweight"] = "bold"
|
||||
plt.rcParams["text.usetex"] = True # Ensure LaTeX is available or set to False
|
||||
|
||||
computation_costs = []
|
||||
for i, method_name in enumerate(methods): # methods now contains short names
|
||||
method_costs = []
|
||||
for level in current_recall_levels:
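# Take the first operating point whose measured recall reaches the target; this assumes
# recall_lists[i] is ordered by increasing recompute budget.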
|
||||
recall_idx = next((idx for idx, recall in enumerate(recall_lists[i]) if recall >= level), None)
|
||||
if recall_idx is not None:
|
||||
method_costs.append(recompute_lists[i][recall_idx])
|
||||
else:
|
||||
method_costs.append(None)
|
||||
computation_costs.append(method_costs)
|
||||
|
||||
fig, ax = plt.subplots(figsize=(5,2.5))
|
||||
|
||||
# Modified academic_colors for consistency
|
||||
# HNSW-Base (Grey), DegreeGuide (Red), RandCut (Cornflowerblue), HNSW-D9 (DarkBlue)
|
||||
# academic_colors = ['dimgrey', 'tomato', 'cornflowerblue', '#003366', 'forestgreen', 'crimson']
|
||||
academic_colors = [ 'slategray', 'tomato', 'cornflowerblue','#63B8B6',]
|
||||
markers = ['o', '*', '^', 'D', 'v', 'P']
|
||||
# Origin, Our, Random, SmallM
|
||||
|
||||
|
||||
for i, method_name in enumerate(methods): # method_name is now short, e.g., 'HNSW-Base'
|
||||
color_idx = i % len(academic_colors)
|
||||
marker_idx = i % len(markers)
|
||||
|
||||
# Scale to units of 10^4 recomputed nodes; unreached recall targets (None) become NaN so the line breaks there
y_values_plot = [val / 10000 if val is not None else np.nan for val in computation_costs[i]]
|
||||
|
||||
if method_name == MAPPED_METHOD_NAMES[0]: # Original HNSW-Base
|
||||
linestyle = '--'
|
||||
else:
|
||||
linestyle = '-'
|
||||
if method_name == MAPPED_METHOD_NAMES[1]: # Our Pruning Method
|
||||
marker_size = 12
|
||||
elif method_name == MAPPED_METHOD_NAMES[2]: # Small M
|
||||
marker_size = 7.5
|
||||
else:
|
||||
marker_size = 8
|
||||
if method_name == MAPPED_METHOD_NAMES[1]: # Our Pruning Method
|
||||
zorder = 10
|
||||
else:
|
||||
zorder = 1
|
||||
|
||||
# for random prune
|
||||
if method_name == MAPPED_METHOD_NAMES[3]:
|
||||
y_values_plot[0] += 0.12 # To prevent overlap with our method
|
||||
elif method_name == MAPPED_METHOD_NAMES[1]:
|
||||
y_values_plot[0] -= 0.06 # To prevent overlap with original hnsw
|
||||
|
||||
ax.plot(current_recall_levels, y_values_plot,
|
||||
label=f"{method_name} (Avg Degree: {int(avg_neighbors[i])})", # Uses new short names
|
||||
color=academic_colors[color_idx], marker=markers[marker_idx], markeredgecolor='#FFFFFF80', # TODO: check whether this translucent white marker edge looks good
|
||||
markersize=marker_size, linewidth=2, linestyle=linestyle, zorder=zorder)
|
||||
|
||||
ax.set_xlabel('Recall Target', fontsize=9, fontweight="bold")
|
||||
ax.set_ylabel('Nodes to Recompute', fontsize=9, fontweight="bold")
|
||||
ax.set_xticks(current_recall_levels)
|
||||
ax.set_xticklabels([f'{level*100:.0f}\%' for level in current_recall_levels], fontsize=10)
|
||||
ax.tick_params(axis='y', labelsize=10)
|
||||
|
||||
ax.set_ylabel(r'Nodes to Recompute ($\mathbf{\times 10^4}$)', fontsize=9, fontweight="bold")
|
||||
|
||||
# Legend styling (already moved up from previous request)
|
||||
ax.legend(loc='lower center', bbox_to_anchor=(0.5, 1.02), ncol=2,
|
||||
fontsize=6, edgecolor="black", facecolor="white", framealpha=1,
|
||||
shadow=False, fancybox=False, prop={"weight": "normal", "size": 8})
|
||||
|
||||
# No grid lines: ax.grid(True, linestyle='--', alpha=0.7)
|
||||
|
||||
# Spines adjustment for academic look
|
||||
ax.spines['top'].set_visible(False)
|
||||
ax.spines['right'].set_visible(False)
|
||||
ax.spines['left'].set_linewidth(1.0)
|
||||
ax.spines['bottom'].set_linewidth(1.0)
|
||||
|
||||
annot_recall_level_92 = 0.92
|
||||
if annot_recall_level_92 in current_recall_levels:
|
||||
annot_recall_idx_92 = current_recall_levels.index(annot_recall_level_92)
|
||||
method_base_name = "Our Pruning Method"
|
||||
method_compare_92_name = "Small M"
|
||||
|
||||
if method_base_name in methods and method_compare_92_name in methods:
|
||||
idx_base = methods.index(method_base_name)
|
||||
idx_compare_92 = methods.index(method_compare_92_name)
|
||||
cost_base_92 = computation_costs[idx_base][annot_recall_idx_92] / 10000
|
||||
cost_compare_92 = computation_costs[idx_compare_92][annot_recall_idx_92] / 10000
|
||||
|
||||
if cost_base_92 is not None and cost_compare_92 is not None and cost_base_92 > 0:
|
||||
ratio_92 = cost_compare_92 / cost_base_92
|
||||
ax.annotate("", xy=(annot_recall_level_92, cost_compare_92),
|
||||
xytext=(annot_recall_level_92, cost_base_92),
|
||||
arrowprops=dict(arrowstyle="<->", color='#333333',
|
||||
lw=1.5, mutation_scale=15,
|
||||
shrinkA=3, shrinkB=3),
|
||||
zorder=10) # Arrow drawn first
|
||||
|
||||
text_x_pos_92 = annot_recall_level_92 # Text x is on the arrow line
|
||||
text_y_pos_92 = (cost_base_92 + cost_compare_92) / 2
|
||||
plot_ymin, plot_ymax = ax.get_ylim() # Boundary checks
|
||||
if text_y_pos_92 < plot_ymin + (plot_ymax-plot_ymin)*0.05: text_y_pos_92 = plot_ymin + (plot_ymax-plot_ymin)*0.05
|
||||
if text_y_pos_92 > plot_ymax - (plot_ymax-plot_ymin)*0.05: text_y_pos_92 = plot_ymax - (plot_ymax-plot_ymin)*0.05
|
||||
|
||||
ax.text(text_x_pos_92, text_y_pos_92, f"{ratio_92:.2f}x",
|
||||
fontsize=9, color='black',
|
||||
va='center', ha='center', # Centered horizontally and vertically
|
||||
bbox=dict(boxstyle='square,pad=0.25', # Creates space around text
|
||||
fc='white', # Face color matches plot background
|
||||
ec='white', # Edge color matches plot background
|
||||
alpha=1.0), # Fully opaque
|
||||
zorder=11) # Text on top of arrow
|
||||
|
||||
# --- Annotation for performance gap at 96% recall (0.96) ---
|
||||
annot_recall_level_96 = 0.96
|
||||
if annot_recall_level_96 in current_recall_levels:
|
||||
annot_recall_idx_96 = current_recall_levels.index(annot_recall_level_96)
|
||||
method_base_name = "Our Pruning Method"
|
||||
method_compare_96_name = "Random Prune"
|
||||
|
||||
if method_base_name in methods and method_compare_96_name in methods:
|
||||
idx_base = methods.index(method_base_name)
|
||||
idx_compare_96 = methods.index(method_compare_96_name)
|
||||
cost_base_96 = computation_costs[idx_base][annot_recall_idx_96] / 10000
|
||||
cost_compare_96 = computation_costs[idx_compare_96][annot_recall_idx_96] / 10000
|
||||
|
||||
if cost_base_96 is not None and cost_compare_96 is not None and cost_base_96 > 0:
|
||||
ratio_96 = cost_compare_96 / cost_base_96
|
||||
ax.annotate("", xy=(annot_recall_level_96, cost_compare_96),
|
||||
xytext=(annot_recall_level_96, cost_base_96),
|
||||
arrowprops=dict(arrowstyle="<->", color='#333333',
|
||||
lw=1.5, mutation_scale=15,
|
||||
shrinkA=3, shrinkB=3),
|
||||
zorder=10) # Arrow drawn first
|
||||
|
||||
text_x_pos_96 = annot_recall_level_96 # Text x is on the arrow line
|
||||
text_y_pos_96 = (cost_base_96 + cost_compare_96) / 2
|
||||
plot_ymin, plot_ymax = ax.get_ylim() # Boundary checks
|
||||
if text_y_pos_96 < plot_ymin + (plot_ymax-plot_ymin)*0.05: text_y_pos_96 = plot_ymin + (plot_ymax-plot_ymin)*0.05
|
||||
if text_y_pos_96 > plot_ymax - (plot_ymax-plot_ymin)*0.05: text_y_pos_96 = plot_ymax - (plot_ymax-plot_ymin)*0.05
|
||||
|
||||
ax.text(text_x_pos_96, text_y_pos_96, f"{ratio_96:.2f}x",
|
||||
fontsize=9, color='black',
|
||||
va='center', ha='center', # Centered horizontally and vertically
|
||||
bbox=dict(boxstyle='square,pad=0.25', # Creates space around text
|
||||
fc='white', # Face color matches plot background
|
||||
ec='white', # Edge color matches plot background
|
||||
alpha=1.0), # Fully opaque
|
||||
zorder=11) # Text on top of arrow
|
||||
|
||||
|
||||
plt.tight_layout(pad=0.5)
|
||||
plt.savefig(SAVED_PATH, bbox_inches="tight", dpi=300)
|
||||
plt.show()
|
||||
|
||||
# --- Main script execution ---
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("log_file", type=str, default="./demo/output.log")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
with open(args.log_file, 'r') as f:
|
||||
log_content = f.read()
|
||||
except FileNotFoundError:
|
||||
print(f"Error: Log file '{args.log_file}' not found.")
|
||||
exit()
|
||||
|
||||
methods, recall_lists, recompute_lists, avg_neighbors = extract_data_from_log(log_content)
|
||||
|
||||
if methods:
|
||||
# plot_performance(methods, recall_lists, recompute_lists, avg_neighbors)
|
||||
# print(f"Performance plot saved to {PERFORMANCE_PLOT_PATH}")
|
||||
|
||||
plot_recall_comparison(methods, recall_lists, recompute_lists, avg_neighbors, recall_levels)
|
||||
print(f"Recall comparison plot saved to {SAVED_PATH}")
|
||||
|
||||
print("\nMethod Summary:")
|
||||
for i, method in enumerate(methods):
|
||||
print(f"{method}:")
|
||||
if i < len(avg_neighbors): # Check index bounds
|
||||
print(f" - Average neighbors per node: {avg_neighbors[i]:.2f}")
|
||||
|
||||
for level in recall_levels:
|
||||
if i < len(recall_lists) and i < len(recompute_lists): # Check index bounds
|
||||
recall_idx = next((idx for idx, recall_val in enumerate(recall_lists[i]) if recall_val >= level), None)
|
||||
if recall_idx is not None:
|
||||
print(f" - Computations needed for {level*100:.0f}% recall: {recompute_lists[i][recall_idx]:.0f}")
|
||||
else:
|
||||
print(f" - Does not reach {level*100:.0f}% recall in the test")
|
||||
else:
|
||||
print(f" - Data missing for recall/recompute lists for method {method}")
|
||||
print()
|
||||
else:
|
||||
print("No data extracted from the log file. Cannot generate plots or summary.")
|
||||
441
research/paper_plot/main_exp.py
Normal file
@@ -0,0 +1,441 @@
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
import matplotlib.lines as mlines
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from matplotlib.patches import FancyArrowPatch
|
||||
|
||||
sns.set_theme(style="ticks", font_scale=1.2)
|
||||
plt.rcParams['axes.grid'] = True
|
||||
plt.rcParams['axes.grid.which'] = 'major'
|
||||
plt.rcParams['grid.linestyle'] = '--'
|
||||
plt.rcParams['grid.color'] = 'gray'
|
||||
plt.rcParams['grid.alpha'] = 0.3
|
||||
plt.rcParams['xtick.minor.visible'] = False
|
||||
plt.rcParams['ytick.minor.visible'] = False
|
||||
plt.rcParams["font.family"] = "Helvetica"
|
||||
plt.rcParams["text.usetex"] = True
|
||||
plt.rcParams["font.weight"] = "bold"
|
||||
plt.rcParams["axes.labelweight"] = "bold"
|
||||
# Generation(LLama 1B) Generation(LLama 3B) Generation(LLama 7B)
|
||||
# 0.085s 0.217s 0.472s
|
||||
# llm_inference_time=[0.085, 0.217, 0.472, 0] # Will be replaced by CSV data
|
||||
# llm_inference_time_for_mac = [0.316, 0.717, 1.468, 0] # Will be replaced by CSV data
|
||||
|
||||
def parse_latency_data(csv_path):
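"""Return ({(dataset, hardware, method): [(recall_target, latency_s), ...]}, {(dataset, hardware): LLM gen time})."""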
|
||||
df = pd.read_csv(csv_path)
|
||||
latency_data = {}
|
||||
llm_gen_times = {} # To store LLM generation times: (dataset, hardware) -> time
|
||||
|
||||
for _, row in df.iterrows():
|
||||
dataset = row['Dataset']
|
||||
hardware = row['Hardware']
|
||||
recall_target_str = row['Recall_target'].replace('%', '')
|
||||
try:
|
||||
recall_target = float(recall_target_str)
|
||||
except ValueError:
|
||||
print(f"Warning: Could not parse recall_target '{row['Recall_target']}'. Skipping row.")
|
||||
continue
|
||||
|
||||
if (dataset, hardware) not in llm_gen_times: # Read once per (dataset, hardware)
|
||||
llm_time_val = pd.to_numeric(row.get('LLM_Gen_Time_1B'), errors='coerce')
|
||||
if not pd.isna(llm_time_val):
|
||||
llm_gen_times[(dataset, hardware)] = llm_time_val
|
||||
else:
|
||||
llm_gen_times[(dataset, hardware)] = np.nan # Store NaN if unparsable/missing
|
||||
|
||||
cols_to_skip = ['Dataset', 'Hardware', 'Recall_target',
|
||||
'LLM_Gen_Time_1B', 'LLM_Gen_Time_3B', 'LLM_Gen_Time_7B']
|
||||
|
||||
for col in df.columns:
|
||||
if col not in cols_to_skip:
|
||||
method_name = col
|
||||
key = (dataset, hardware, method_name)
|
||||
if key not in latency_data:
|
||||
latency_data[key] = []
|
||||
try:
|
||||
latency_value = float(row[method_name])
|
||||
latency_data[key].append((recall_target, latency_value))
|
||||
except ValueError:
|
||||
# Handle cases where latency might be non-numeric (e.g., 'N/A' or empty)
|
||||
print(f"Warning: Could not parse latency for {method_name} at {dataset}/{hardware}/Recall {recall_target} ('{row[method_name]}'). Skipping this point.")
|
||||
latency_data[key].append((recall_target, np.nan)) # Or skip appending
|
||||
|
||||
# Sort by recall for consistent plotting
|
||||
for key in latency_data:
|
||||
latency_data[key].sort(key=lambda x: x[0])
|
||||
return latency_data, llm_gen_times
|
||||
|
||||
def parse_storage_data(csv_path):
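"""Return {method: {'RAM': value, 'Storage': value}} parsed from the CSV (plotted as GB elsewhere)."""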
|
||||
df = pd.read_csv(csv_path)
|
||||
storage_data = {}
|
||||
# Assuming the first column is 'MetricType' (RAM/Storage) and subsequent columns are methods
|
||||
# And the header row is like: MetricType, Method1, Method2, ...
|
||||
# Transpose to make methods as rows for easier lookup might be an option,
|
||||
# but let's try direct parsing.
|
||||
|
||||
# Find the row for RAM and Storage
|
||||
ram_row = df[df.iloc[:, 0] == 'RAM'].iloc[0]
|
||||
storage_row = df[df.iloc[:, 0] == 'Storage'].iloc[0]
|
||||
|
||||
methods = df.columns[1:] # First column is the metric type label
|
||||
for method in methods:
|
||||
storage_data[method] = {
|
||||
'RAM': pd.to_numeric(ram_row[method], errors='coerce'),
|
||||
'Storage': pd.to_numeric(storage_row[method], errors='coerce')
|
||||
}
|
||||
return storage_data
|
||||
|
||||
# Load data
|
||||
latency_csv_path = 'paper_plot/data/main_latency.csv'
|
||||
storage_csv_path = 'paper_plot/data/ram_storage.csv'
|
||||
latency_data, llm_generation_times = parse_latency_data(latency_csv_path)
|
||||
storage_info = parse_storage_data(storage_csv_path)
|
||||
|
||||
# --- Determine unique Datasets and Hardware combinations to plot for ---
|
||||
unique_dataset_hardware_configs = sorted(list(set((d, h) for d, h, m in latency_data.keys())))
|
||||
|
||||
if not unique_dataset_hardware_configs:
|
||||
print("Error: No (Dataset, Hardware) combinations found in latency data. Check CSV paths and content.")
|
||||
exit()
|
||||
|
||||
# --- Define constants for plotting ---
|
||||
all_method_names = sorted(list(set(m for d,h,m in latency_data.keys())))
|
||||
if not all_method_names:
|
||||
# Fallback if latency_data is empty but storage_info might have method names
|
||||
all_method_names = sorted(list(storage_info.keys()))
|
||||
|
||||
if not all_method_names:
|
||||
print("Error: No method names found in data. Cannot proceed with plotting.")
|
||||
exit()
|
||||
|
||||
method_markers = {
|
||||
'HNSW': 'o',
|
||||
'IVF': 'X',
|
||||
'DiskANN': 's',
|
||||
'IVF-Disk': 'P',
|
||||
'IVF-Recompute': '^',
|
||||
'Our': '*',
|
||||
'BM25': "v"
|
||||
# Add more if necessary, or make it dynamic
|
||||
}
|
||||
method_display_names = {
|
||||
'IVF-Recompute': 'IVF-Recompute (EdgeRAG)',
|
||||
# all other methods keep their original names
|
||||
}
|
||||
|
||||
# Ensure all methods have a marker
|
||||
default_markers = ['^', 'v', '<', '>', 'H', 'h', '+', 'x', '|', '_']
|
||||
next_default_marker = 0
|
||||
for mn in all_method_names:
|
||||
if mn not in method_markers:
|
||||
print(f"mn: {mn}")
|
||||
method_markers[mn] = default_markers[next_default_marker % len(default_markers)]
|
||||
next_default_marker +=1
|
||||
|
||||
recall_levels_present = sorted(list(set(r for key in latency_data for r, l in latency_data[key])))
|
||||
# Define colors for up to a few common recall levels, add more if needed
|
||||
base_recall_colors = {
|
||||
85.0: "#1f77b4", # Blue
|
||||
90.0: "#ff7f0e", # Orange
|
||||
95.0: "#2ca02c", # Green
|
||||
# Add more if other recall % values exist
|
||||
}
|
||||
recall_colors = {}
|
||||
color_palette = sns.color_palette("viridis", n_colors=len(recall_levels_present))
|
||||
|
||||
for idx, r_level in enumerate(recall_levels_present):
|
||||
recall_colors[r_level] = base_recall_colors.get(r_level, color_palette[idx % len(color_palette)])
|
||||
|
||||
|
||||
# --- Determine global x (latency) and y (storage) limits for consistent axes ---
|
||||
all_latency_values = []
|
||||
all_storage_values = []
|
||||
raw_data_size = 76 # Raw data size in GB
|
||||
|
||||
for ds_hw_key in unique_dataset_hardware_configs:
|
||||
current_ds, current_hw = ds_hw_key
|
||||
for method_name in all_method_names:
|
||||
# Get storage for this method
|
||||
disk_storage = storage_info.get(method_name, {}).get('Storage', np.nan)
|
||||
if not np.isnan(disk_storage):
|
||||
all_storage_values.append(disk_storage)
|
||||
|
||||
# Get latencies for this method under current_ds, current_hw
|
||||
latency_key = (current_ds, current_hw, method_name)
|
||||
if latency_key in latency_data:
|
||||
for recall, latency in latency_data[latency_key]:
|
||||
if not np.isnan(latency):
|
||||
all_latency_values.append(latency)
|
||||
|
||||
# Add padding to limits
|
||||
min_lat = min(all_latency_values) if all_latency_values else 0.001
|
||||
max_lat = max(all_latency_values) if all_latency_values else 1
|
||||
min_store = min(all_storage_values) if all_storage_values else 0
|
||||
max_store = max(all_storage_values) if all_storage_values else 1
|
||||
|
||||
# Convert storage values to proportion of raw data
|
||||
min_store_proportion = min_store / raw_data_size if all_storage_values else 0
|
||||
max_store_proportion = max_store / raw_data_size if all_storage_values else 0.1
|
||||
|
||||
# Padding for log scale latency - adjust minimum to be more reasonable
|
||||
lat_log_min = -1 # Changed from -2 to -1 to set minimum to 10^-1 (0.1s)
|
||||
lat_log_max = np.log10(max_lat) if max_lat > 0 else 3 # default to 1000 s
|
||||
lat_padding = (lat_log_max - lat_log_min) * 0.05
|
||||
global_xlim = [10**(lat_log_min - lat_padding), 10**(lat_log_max + lat_padding)]
|
||||
if global_xlim[0] <= 0: global_xlim[0] = 0.1 # Changed from 0.01 to 0.1
|
||||
|
||||
# Padding for linear scale storage proportion
|
||||
store_padding = (max_store_proportion - min_store_proportion) * 0.05
|
||||
global_ylim = [max(0, min_store_proportion - store_padding), max_store_proportion + store_padding]
|
||||
if global_ylim[0] >= global_ylim[1]: # Avoid inverted or zero range
|
||||
global_ylim[1] = global_ylim[0] + 0.1
|
||||
|
||||
# After loading the data and before plotting, add this code to reorder the datasets
|
||||
# Find where you define all_datasets (around line 95)
|
||||
|
||||
# Original code:
|
||||
all_datasets = sorted(list(set(ds for ds, _ in unique_dataset_hardware_configs)))
|
||||
|
||||
# Replace with this to specify the exact order:
|
||||
all_datasets_unsorted = list(set(ds for ds, _ in unique_dataset_hardware_configs))
|
||||
desired_order = ['NQ', 'TriviaQA', 'GPQA','HotpotQA']
|
||||
all_datasets = [ds for ds in desired_order if ds in all_datasets_unsorted]
|
||||
# Add any datasets that might be in the data but not in our desired_order list
|
||||
all_datasets.extend([ds for ds in all_datasets_unsorted if ds not in desired_order])
|
||||
|
||||
# Then the rest of your code remains the same:
|
||||
a10_configs = [(ds, 'A10') for ds in all_datasets if (ds, 'A10') in unique_dataset_hardware_configs]
|
||||
mac_configs = [(ds, 'MAC') for ds in all_datasets if (ds, 'MAC') in unique_dataset_hardware_configs]
|
||||
|
||||
# Create two figures - one for A10 and one for MAC
|
||||
hardware_configs = [a10_configs, mac_configs]
|
||||
hardware_names = ['A10', 'MAC']
|
||||
|
||||
for fig_idx, configs_for_this_figure in enumerate(hardware_configs):
|
||||
if not configs_for_this_figure:
|
||||
continue
|
||||
|
||||
num_cols_this_figure = len(configs_for_this_figure)
|
||||
# 1 row, num_cols_this_figure columns
|
||||
fig, axs = plt.subplots(1, num_cols_this_figure, figsize=(7 * num_cols_this_figure, 6), sharex=True, sharey=True, squeeze=False)
|
||||
|
||||
# fig.suptitle(f"Latency vs. Storage ({hardware_names[fig_idx]})", fontsize=18, y=0.98)
|
||||
|
||||
for subplot_idx, (current_ds, current_hw) in enumerate(configs_for_this_figure):
|
||||
ax = axs[0, subplot_idx] # Accessing column in the first row
|
||||
ax.set_title(f"{current_ds}", fontsize=25) # No need to show hardware in title since it's in suptitle
|
||||
|
||||
for method_name in all_method_names:
|
||||
marker = method_markers.get(method_name, '+')
|
||||
disk_storage = storage_info.get(method_name, {}).get('Storage', np.nan)
|
||||
|
||||
latency_points_key = (current_ds, current_hw, method_name)
|
||||
if latency_points_key in latency_data:
|
||||
points_for_method = latency_data[latency_points_key]
|
||||
print(f"points_for_method: {points_for_method}")
|
||||
for recall, latency in points_for_method:
|
||||
# Only skip if latency is invalid (since we need log scale for x-axis)
|
||||
# But allow zero storage since y-axis is now linear
|
||||
if np.isnan(latency) or np.isnan(disk_storage) or latency <= 0:
|
||||
continue
|
||||
|
||||
# Add LLM generation time from CSV
|
||||
current_llm_add_time = llm_generation_times.get((current_ds, current_hw))
|
||||
if current_llm_add_time is not None and not np.isnan(current_llm_add_time):
|
||||
latency = latency + current_llm_add_time
|
||||
else:
|
||||
raise ValueError(f"No LLM generation time found for {current_ds} on {current_hw}")
|
||||
|
||||
# Special handling for BM25
|
||||
if method_name == 'BM25':
|
||||
# BM25 is only valid for 85% recall points (other points are 0)
|
||||
if recall != 85.0:
|
||||
continue
|
||||
color = 'grey'
|
||||
else:
|
||||
# Use the color for target recall
|
||||
color = recall_colors.get(recall, 'grey')
|
||||
|
||||
# Convert storage to proportion
|
||||
disk_storage_proportion = disk_storage / raw_data_size
|
||||
size = 80
|
||||
|
||||
x_offset = -50
|
||||
if current_ds == 'GPQA':
|
||||
x_offset = -32
|
||||
|
||||
# Apply a small vertical offset to IVF-Recompute points to make them more visible
|
||||
if method_name == 'IVF-Recompute':
|
||||
# Add a small vertical offset (adjust the 0.05 value as needed)
|
||||
disk_storage_proportion += 0.07
|
||||
size = 80
|
||||
if method_name == 'DiskANN':
|
||||
size = 50
|
||||
if method_name == 'Our':
|
||||
size = 140
|
||||
disk_storage_proportion += 0.05
|
||||
# Add "Pareto Frontier" label to Our method points
|
||||
|
||||
if recall == 95:
|
||||
ax.annotate('Ours',
|
||||
(latency, disk_storage_proportion),
|
||||
xytext=(x_offset, 25), # horizontal offset set per dataset above, 25 points upward
|
||||
textcoords='offset points',
|
||||
fontsize=20,
|
||||
color='red',
|
||||
weight='bold',
|
||||
bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="red", alpha=0.7))
|
||||
# Increase size for BM25 points
|
||||
if method_name == 'BM25':
|
||||
size = 70
|
||||
size *= 5
|
||||
|
||||
ax.scatter(latency, disk_storage_proportion, marker=marker, color=color,
|
||||
s=size, alpha=0.85, edgecolors='black', linewidths=0.7)
|
||||
|
||||
|
||||
|
||||
|
||||
ax.set_xscale("log")
|
||||
ax.set_yscale("linear") # CHANGED from log scale to linear scale for Y-axis
|
||||
|
||||
# Generate appropriate powers of 10 based on your data range
|
||||
min_power = -1
|
||||
max_power = 4
|
||||
log_ticks = [10**i for i in range(min_power, max_power+1)]
|
||||
|
||||
# Set custom tick positions
|
||||
ax.set_xticks(log_ticks)
|
||||
|
||||
# Create custom bold LaTeX labels with 10^n format
|
||||
log_tick_labels = [fr'$\mathbf{{10^{{{i}}}}}$' for i in range(min_power, max_power+1)]
|
||||
ax.set_xticklabels(log_tick_labels, fontsize=24)
|
||||
|
||||
# Apply global limits
|
||||
if subplot_idx == 0:
|
||||
ax.set_xlim(global_xlim)
|
||||
ax.set_ylim(global_ylim)
|
||||
|
||||
ax.grid(True, which="major", linestyle="--", linewidth=0.6, alpha=0.7)
|
||||
# Remove minor grid lines completely
|
||||
ax.grid(False, which="minor")
|
||||
|
||||
# Remove ticks
|
||||
# First set the shared parameters for both axes
|
||||
ax.tick_params(axis='both', which='both', length=0, labelsize=24)
|
||||
|
||||
# Then set the padding only for the x-axis
|
||||
ax.tick_params(axis='x', which='both', pad=10)
|
||||
|
||||
if subplot_idx == 0: # Y-label only for the leftmost subplot
|
||||
ax.set_ylabel("Proportional Size", fontsize=24)
|
||||
|
||||
# X-label for all subplots in a 1xN layout can be okay, or just the middle/last one.
|
||||
# Let's put it on all for now.
|
||||
ax.set_xlabel("Latency (s)", fontsize=25)
|
||||
|
||||
# Display 100%, 200%, 300% for yaxis
|
||||
ax.set_yticks([1, 2, 3])
|
||||
ax.set_yticklabels(['100\\%', '200\\%', '300\\%'])
|
||||
|
||||
# Create a custom arrow with "Better" text inside
|
||||
# Create the arrow patch with a wider shaft
|
||||
arrow = FancyArrowPatch(
|
||||
(0.8, 0.8), # Start point (top-right)
|
||||
(0.65, 0.6), # End point (toward bottom-left)
|
||||
transform=ax.transAxes,
|
||||
arrowstyle='simple,head_width=40,head_length=35,tail_width=20', # Increased arrow dimensions
|
||||
facecolor='white',
|
||||
edgecolor='black',
|
||||
linewidth=3, # Thicker outline
|
||||
zorder=5
|
||||
)
|
||||
|
||||
# Add the arrow to the plot
|
||||
ax.add_patch(arrow)
|
||||
|
||||
# Calculate the midpoint of the arrow for text placement
|
||||
mid_x = (0.8 + 0.65) / 2 + 0.002 + 0.01
|
||||
mid_y = (0.8 + 0.6) / 2 + 0.01
|
||||
|
||||
# Add the "Better" text at the midpoint of the arrow
|
||||
ax.text(mid_x, mid_y, 'Better',
|
||||
transform=ax.transAxes,
|
||||
ha='center',
|
||||
va='center',
|
||||
fontsize=16, # Increased font size from 12 to 16
|
||||
fontweight='bold',
|
||||
rotation=40, # Rotate to match arrow direction
|
||||
zorder=6) # Ensure text is on top of arrow
|
||||
|
||||
# Create legends (once per figure)
|
||||
method_legend_handles = []
|
||||
for method, marker_style in method_markers.items():
|
||||
if method in all_method_names:
|
||||
print(f"method: {method}")
|
||||
# Use black color for BM25 in the legend
|
||||
if method == 'BM25':
|
||||
method_legend_handles.append(mlines.Line2D([], [], color='black', marker=marker_style, linestyle='None',
|
||||
markersize=10, label=method))
|
||||
else:
|
||||
if method in method_display_names:
|
||||
method = method_display_names[method]
|
||||
method_legend_handles.append(mlines.Line2D([], [], color='black', marker=marker_style, linestyle='None',
|
||||
markersize=10, label=method))
|
||||
|
||||
recall_legend_handles = []
|
||||
sorted_recall_levels = sorted(recall_colors.keys())
|
||||
for r_level in sorted_recall_levels:
|
||||
recall_legend_handles.append(mlines.Line2D([], [], color=recall_colors[r_level], marker='o', linestyle='None',
|
||||
markersize=20, label=f"Target Recall={r_level:.0f}\%"))
|
||||
|
||||
# Split the legend into two rows: methods on the first row, recall targets on the second
|
||||
if fig_idx == 0:
|
||||
# First exclude 'Our' from the method list
|
||||
other_methods = [m for m in all_method_names if m != 'Our']
|
||||
# Build the method list in the desired order (put 'Our' last)
|
||||
ordered_methods = other_methods + (['Our'] if 'Our' in all_method_names else [])
|
||||
|
||||
# Build the method legend handles in the new order
|
||||
method_legend_handles = []
|
||||
for method in ordered_methods:
|
||||
if method in method_markers:
|
||||
marker_style = method_markers[method]
|
||||
# Use the display-name mapping
|
||||
display_name = method_display_names.get(method, method)
|
||||
color = 'black'
|
||||
marker_size = 22
|
||||
if method == 'Our':
|
||||
marker_size = 27
|
||||
elif 'IVF-Recompute' in method or 'EdgeRAG' in method:
|
||||
marker_size = 17
|
||||
elif 'DiskANN' in method:
|
||||
marker_size = 19
|
||||
elif 'BM25' in method:
|
||||
marker_size = 20
|
||||
method_legend_handles.append(mlines.Line2D([], [], color=color, marker=marker_style,
|
||||
linestyle='None', markersize=marker_size, label=display_name))
|
||||
|
||||
# Create the recall legend (second row); its position is adjusted to sit below the method legend
|
||||
recall_legend = fig.legend(handles=recall_legend_handles,
|
||||
loc='upper center', bbox_to_anchor=(0.5, 1.05), # lowered y coordinate so it sits below the first row
|
||||
ncol=len(recall_legend_handles), fontsize=28)
|
||||
|
||||
|
||||
# Create the method legend (first row)
|
||||
method_legend = fig.legend(handles=method_legend_handles,
|
||||
loc='upper center', bbox_to_anchor=(0.5, 0.91),
|
||||
ncol=len(method_legend_handles), fontsize=28)
|
||||
|
||||
# Add both legends to the figure
|
||||
fig.add_artist(method_legend)
|
||||
fig.add_artist(recall_legend)
|
||||
|
||||
# Adjust the layout to leave more room for the two legend rows at the top
|
||||
plt.tight_layout(rect=(0, 0, 1.0, 0.74)) # shrink the top of the axes region to leave room for the two legend rows
|
||||
|
||||
save_path = f'./paper_plot/figures/main_exp_fig_{fig_idx+1}.pdf'
|
||||
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
||||
print(f"Saved figure {fig_idx+1} to {save_path}")
|
||||
plt.show()
|
||||
163
research/paper_plot/main_latency.py
Normal file
@@ -0,0 +1,163 @@
|
||||
import csv
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
plt.rcParams["font.family"] = "Helvetica"
|
||||
plt.rcParams["ytick.direction"] = "in"
|
||||
plt.rcParams["hatch.linewidth"] = 1
|
||||
plt.rcParams["font.weight"] = "bold"
|
||||
plt.rcParams["axes.labelweight"] = "bold"
|
||||
plt.rcParams["text.usetex"] = True
|
||||
SAVE_PTH = "./paper_plot/figures"
|
||||
font_size = 16
|
||||
|
||||
# Generation(LLama 1B) Generation(LLama 3B) Generation(LLama 7B)
|
||||
# 0.085s 0.217s 0.472s
|
||||
llm_inference_time=[0.085, 0.217, 0.472, 0]
|
||||
|
||||
USE_LLM_INDEX = 3 # +0
|
||||
|
||||
file_path = "./paper_plot/data/main_latency.csv"
|
||||
|
||||
with open(file_path, mode="r", newline="") as file:
|
||||
reader = csv.reader(file)
|
||||
data = list(reader)
|
||||
|
||||
# Print the raw CSV rows
|
||||
for row in data:
|
||||
print(",".join(row))
|
||||
|
||||
|
||||
|
||||
|
||||
models = ["A10", "MAC"]
|
||||
datasets = ["NQ", "TriviaQA", "GPQA", "HotpotQA"]
|
||||
data = [[float(cell) if cell.isdigit() else cell for cell in row] for row in data[1:]]
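# Note: str.isdigit() is False for values containing a decimal point, so most latency cells
# remain strings here and are converted with float(...) when the bars are drawn below.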
|
||||
for k, model in enumerate(models):
|
||||
|
||||
fig, axes = plt.subplots(1, 4)
|
||||
fig.set_size_inches(20, 3)
|
||||
plt.subplots_adjust(wspace=0, hspace=0)
|
||||
|
||||
total_width, n = 6, 6
|
||||
group = 1
|
||||
width = total_width * 0.9 / n
|
||||
x = np.arange(group) * n
|
||||
exit_idx_x = x + (total_width - width) / n
|
||||
edgecolors = ["dimgrey", "#63B8B6", "tomato", "slategray", "mediumpurple", "green", "red", "blue", "yellow", "silver"]
|
||||
# hatches = ["", "\\\\", "//", "||", "x", "--", "..", "", "\\\\", "//", "||", "x", "--", ".."]
|
||||
hatches =["\\\\\\","\\\\"]
|
||||
|
||||
labels = [
|
||||
"HNSW",
|
||||
"IVF",
|
||||
"DiskANN",
|
||||
"IVF-Disk",
|
||||
"IVF-Recompute",
|
||||
"Our",
|
||||
# "DGL-OnDisk",
|
||||
]
|
||||
if k == 0:
|
||||
x_labels = "GraphSAGE"
|
||||
else:
|
||||
x_labels = "GAT"
|
||||
|
||||
yticks = [0.01, 0.1, 1, 10, 100, 1000,10000] # Log scale ticks
|
||||
val_limit = 15000 # Upper limit for the plot
|
||||
|
||||
for i in range(4):
|
||||
axes[i].set_yscale('log') # Set y-axis to logarithmic scale
|
||||
axes[i].set_yticks(yticks)
|
||||
axes[i].set_ylim(0.01, val_limit) # Lower limit should be > 0 for log scale
|
||||
|
||||
axes[i].tick_params(axis="y", labelsize=10)
|
||||
|
||||
axes[i].set_xticks([])
|
||||
# axes[i].set_xticklabels()
|
||||
axes[i].set_xlabel(datasets[i], fontsize=font_size)
|
||||
axes[i].grid(axis="y", linestyle="--")
|
||||
axes[i].set_xlim(exit_idx_x[0] - 0.15 * width - 0.2, exit_idx_x[0] + (n-0.25)* width + 0.2)
|
||||
for j in range(n):
|
||||
##TODO add label
|
||||
|
||||
# num = float(data[i * 2 + k][j + 3])
|
||||
# plot_label = [num]
|
||||
# if j == 6 and i == 3:
|
||||
# plot_label = ["N/A"]
|
||||
# num = 0
|
||||
local_hatches=["////","\\\\","xxxx"]
|
||||
# here add 3 bars rather than one bar TODO
|
||||
print('exit_idx_x',exit_idx_x)
|
||||
|
||||
# Check if all three models for this algorithm are OOM (data = 0)
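# Row-layout assumption for `data`: 6 rows per dataset (2 hardware settings x 3 runs,
# indexed by k and m), with per-method numbers starting at column index 3.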
|
||||
is_oom = True
|
||||
for m in range(3):
|
||||
if float(data[i * 6 + k*3 + m][j + 3]) != 0:
|
||||
is_oom = False
|
||||
break
|
||||
|
||||
if is_oom:
|
||||
# Draw a cross for OOM instead of bars
|
||||
pos = exit_idx_x + j * width + width * 0.3 # Center position for cross
|
||||
marker_size = width * 150 # Size of the cross
|
||||
axes[i].scatter(pos, 0.02, marker='x', color=edgecolors[j], s=marker_size,
|
||||
linewidth=4, label=labels[j] if j < len(labels) else "", zorder=20)
|
||||
else:
|
||||
# Create three separate bar calls instead of trying to plot multiple bars at once
|
||||
for m in range(3):
|
||||
num = float(data[i * 6 + k*3 +m][j + 3]) +llm_inference_time[USE_LLM_INDEX]
|
||||
plot_label = [num]
|
||||
pos = exit_idx_x + j * width + width * 0.3 * m
|
||||
print(f"j: {j}, m: {m}, pos: {pos}")
|
||||
# For log scale, we need to ensure values are positive
|
||||
plot_value = max(0.01, num) if num < val_limit else val_limit
|
||||
container = axes[i].bar(
|
||||
pos,
|
||||
plot_value,
|
||||
width=width * 0.3,
|
||||
color="white",
|
||||
edgecolor=edgecolors[j],
|
||||
# edgecolor="k",
|
||||
hatch=local_hatches[m], # Use different hatches for each of the 3 bars
|
||||
linewidth=1.0,
|
||||
label=labels[j] if m == 0 else "", # Only add label for the first bar
|
||||
zorder=10,
|
||||
)
|
||||
# axes[i].bar_label(
|
||||
# container,
|
||||
# plot_label,
|
||||
# fontsize=font_size - 2,
|
||||
# zorder=200,
|
||||
# fontweight="bold",
|
||||
# )
|
||||
|
||||
if k == 0:
|
||||
axes[0].legend(
|
||||
bbox_to_anchor=(3.25, 1.02),
|
||||
ncol=7,
|
||||
loc="lower right",
|
||||
# fontsize=font_size,
|
||||
# markerscale=3,
|
||||
labelspacing=0.2,
|
||||
edgecolor="black",
|
||||
facecolor="white",
|
||||
framealpha=1,
|
||||
shadow=False,
|
||||
# fancybox=False,
|
||||
handlelength=2,
|
||||
handletextpad=0.5,
|
||||
columnspacing=0.5,
|
||||
prop={"weight": "bold", "size": font_size},
|
||||
).set_zorder(100)
|
||||
|
||||
axes[0].set_ylabel("Runtime (log scale)", fontsize=font_size, fontweight="bold")
|
||||
axes[0].set_yticklabels([r"$10^{-2}$", r"$10^{-1}$", r"$10^{0}$", r"$10^{1}$", r"$10^{2}$", r"$10^{3}$",r"$10^{4}$"], fontsize=font_size)
|
||||
axes[1].set_yticklabels([])
|
||||
axes[2].set_yticklabels([])
|
||||
axes[3].set_yticklabels([])
|
||||
|
||||
plt.savefig(f"{SAVE_PTH }/speed_{model}_revised.pdf", bbox_inches="tight", dpi=300)
|
||||
## print save
|
||||
print(f"{SAVE_PTH }/speed_{model}_revised.pdf")
|
||||
85
research/paper_plot/main_memory_storage.py
Normal file
@@ -0,0 +1,85 @@
|
||||
from matplotlib import pyplot as plt
|
||||
from matplotlib.gridspec import GridSpec
|
||||
|
||||
# Comment Test
|
||||
|
||||
# om script.settings import DATA_PATH, FIGURE_PATH
|
||||
# DATA_PATH ="/home/ubuntu/Power-RAG/paper_plot/data"
|
||||
# FIGURE_PATH = "/home/ubuntu/Power-RAG/paper_plot/figures"
|
||||
plt.rcParams["font.family"] = "Helvetica"
|
||||
plt.rcParams["ytick.direction"] = "in"
|
||||
plt.rcParams["hatch.linewidth"] = 2
|
||||
plt.rcParams["font.weight"] = "bold"
|
||||
plt.rcParams["axes.labelweight"] = "bold"
|
||||
plt.rcParams["text.usetex"] = True
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
# Load the RAM and Storage data directly from CSV
|
||||
data = pd.read_csv("./paper_plot/data/ram_storage.csv")
|
||||
|
||||
# Explicitly reorder columns to ensure "Our" is at the end
|
||||
cols = list(data.columns)
|
||||
if "Our" in cols and cols[-1] != "Our":
|
||||
cols.remove("Our")
|
||||
cols.append("Our")
|
||||
data = data[cols]
|
||||
|
||||
# Set up the figure with two columns
|
||||
fig = plt.figure(figsize=(12, 3))
|
||||
gs = GridSpec(1, 2, figure=fig)
|
||||
ax1 = fig.add_subplot(gs[0, 0]) # Left panel for RAM
|
||||
ax2 = fig.add_subplot(gs[0, 1]) # Right panel for Storage
|
||||
|
||||
# Define the visual style elements
|
||||
edgecolors = ["dimgrey", "#63B8B6", "tomato", "slategray", "silver", "navy"]
|
||||
hatches = ["/////", "\\\\\\\\\\"]
|
||||
|
||||
# Calculate positions for the bars
|
||||
methods = data.columns[1:] # Skip the 'Hardware' column
|
||||
num_methods = len(methods)
|
||||
# Reverse the order of methods for display (to have "Our" at the bottom)
|
||||
methods = list(methods)[::-1]
|
||||
y_positions = np.arange(num_methods)
|
||||
bar_width = 0.6
|
||||
|
||||
# Plot RAM data in left panel
|
||||
ram_bars = ax1.barh(
|
||||
y_positions,
|
||||
data.iloc[0, 1:].values[::-1], # Reverse the data to match reversed methods
|
||||
height=bar_width,
|
||||
color="white",
|
||||
edgecolor=edgecolors[0],
|
||||
hatch=hatches[0],
|
||||
linewidth=1.0,
|
||||
label="RAM",
|
||||
zorder=10,
|
||||
)
|
||||
ax1.set_title("RAM Usage", fontsize=14, fontweight='bold')
|
||||
ax1.set_yticks(y_positions)
|
||||
ax1.set_yticklabels(methods, fontsize=14)
|
||||
ax1.set_xlabel("Size (\\textit{GB})", fontsize=14)
|
||||
ax1.xaxis.set_tick_params(labelsize=14)
|
||||
|
||||
# Plot Storage data in right panel
|
||||
storage_bars = ax2.barh(
|
||||
y_positions,
|
||||
data.iloc[1, 1:].values[::-1], # Reverse the data to match reversed methods
|
||||
height=bar_width,
|
||||
color="white",
|
||||
edgecolor=edgecolors[1],
|
||||
hatch=hatches[1],
|
||||
linewidth=1.0,
|
||||
label="Storage",
|
||||
zorder=10,
|
||||
)
|
||||
ax2.set_title("Storage Usage", fontsize=14, fontweight='bold')
|
||||
ax2.set_yticks(y_positions)
|
||||
ax2.set_yticklabels(methods, fontsize=14)
|
||||
ax2.set_xlabel("Size (\\textit{GB})", fontsize=14)
|
||||
ax2.xaxis.set_tick_params(labelsize=14)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig("./paper_plot/figures/ram_storage_double_column.pdf", bbox_inches="tight", dpi=300)
|
||||
print("Saving the figure to ./paper_plot/figures/ram_storage_double_column.pdf")
|
||||
141
research/paper_plot/recompute_bottle.py
Normal file
@@ -0,0 +1,141 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
# Motto: Were It to Benefit My Country, I Would Lay Down My Life!
|
||||
# \file: /bottleneck_breakdown.py
|
||||
# \brief: Illustrates the query time bottleneck on consumer devices (Final Version - Font & Legend Adjust).
|
||||
# Author: Gemini Assistant (adapted from user's style and feedback)
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.ticker import FuncFormatter # Not strictly needed for just font, but imported if user wants to try
|
||||
|
||||
# Set matplotlib styles similar to the example
|
||||
plt.rcParams["font.family"] = "Helvetica" # Primary font family
|
||||
plt.rcParams["ytick.direction"] = "in"
|
||||
plt.rcParams["xtick.direction"] = "in"
|
||||
plt.rcParams["hatch.linewidth"] = 1.0
|
||||
plt.rcParams["font.weight"] = "bold"
|
||||
plt.rcParams["axes.labelweight"] = "bold"
|
||||
|
||||
plt.rcParams["text.usetex"] = True
|
||||
# Attempt to make LaTeX use Helvetica as the main font
|
||||
plt.rcParams['text.latex.preamble'] = r"""
|
||||
\usepackage{helvet} % helvetica font
|
||||
\usepackage{sansmath} % helvetica for math
|
||||
\sansmath % activate sansmath
|
||||
\renewcommand{\familydefault}{\sfdefault} % make sans-serif the default family
|
||||
"""
|
||||
|
||||
|
||||
# Final Data for the breakdown (3 Segments)
|
||||
labels_raw = [ # Raw labels before potential LaTeX escaping
|
||||
'IO: Text + PQ Lookup',
|
||||
'CPU: Tokenize + Distance Compute',
|
||||
'GPU: Embedding Recompute',
|
||||
]
|
||||
# Times in ms, ordered for stacking
|
||||
times_ms = np.array([
|
||||
8.009, # Quantization
|
||||
16.197, # Search
|
||||
76.512, # Embedding Recomputation
|
||||
])
|
||||
|
||||
total_time_ms = times_ms.sum()
|
||||
percentages = (times_ms / total_time_ms) * 100
|
||||
|
||||
# Prepare labels for legend, escaping for LaTeX if active
|
||||
labels_legend = []
|
||||
# st1 = r'&' # Not needed as current labels_raw don't have '&'
|
||||
for label, time, perc in zip(labels_raw, times_ms, percentages):
|
||||
# Construct the percentage string carefully for LaTeX
|
||||
perc_str = f"{perc:.1f}" + r"\%" # Correct way to form 'NN.N\%'
|
||||
# label_tex = label.replace('&', st1) # Use if '&' is in labels_raw
|
||||
label_tex = label # Current labels_raw are clean for LaTeX
|
||||
labels_legend.append(
|
||||
f"{label_tex}\n({time:.1f}ms, {perc_str})"
|
||||
)
|
||||
|
||||
# Styling based on user's script
|
||||
# Using first 3 from the provided lists
|
||||
edgecolors_list = ["dimgrey", "#63B8B6", "tomato", "silver", "slategray"]
|
||||
hatches_list = ["/////", "xxxxx", "\\\\\\\\\\"]
|
||||
|
||||
edgecolors = edgecolors_list[:3]
|
||||
hatches = hatches_list[:3]
|
||||
fill_color = "white"
|
||||
|
||||
# Create the figure and axes
|
||||
# Adjusted figure size to potentially accommodate legend on the right
|
||||
fig, ax = plt.subplots()
|
||||
fig.set_size_inches(7, 1.5) # Width increased slightly, height adjusted
|
||||
# Adjusted right margin for external legend, bottom for x-label
|
||||
plt.subplots_adjust(left=0.12, right=0.72, top=0.95, bottom=0.25)
|
||||
|
||||
# Create the horizontal stacked bar
|
||||
bar_height = 0.2
|
||||
y_pos = 0
|
||||
|
||||
left_offset = 0
|
||||
for i in range(len(times_ms)):
|
||||
ax.barh(
|
||||
y_pos,
|
||||
times_ms[i],
|
||||
height=bar_height,
|
||||
left=left_offset,
|
||||
color=fill_color,
|
||||
edgecolor=edgecolors[i],
|
||||
hatch=hatches[i],
|
||||
linewidth=1.5,
|
||||
label=labels_legend[i],
|
||||
zorder=10
|
||||
)
|
||||
text_x_pos = left_offset + times_ms[i] / 2
|
||||
if times_ms[i] > total_time_ms * 0.03: # Threshold for displaying text
|
||||
ax.text(
|
||||
text_x_pos,
|
||||
y_pos,
|
||||
f"{times_ms[i]:.1f}ms",
|
||||
ha='center',
|
||||
va='center',
|
||||
fontsize=8,
|
||||
fontweight='bold',
|
||||
color='black',
|
||||
zorder=20,
|
||||
bbox=dict(facecolor='white', edgecolor='none', pad=0.5, alpha=0.8)
|
||||
)
|
||||
left_offset += times_ms[i]
|
||||
|
||||
# Set plot limits and labels
|
||||
ax.set_xlim([0, total_time_ms * 1.02])
|
||||
ax.set_xlabel("Time (ms)", fontsize=14, fontweight='bold', x=0.75, )
|
||||
|
||||
# Y-axis: Remove y-ticks and labels
|
||||
ax.set_yticks([])
|
||||
ax.set_yticklabels([])
|
||||
|
||||
# Legend: Placed to the right of the plot
|
||||
ax.legend(
|
||||
# (x, y) for anchor, (0,0) is bottom left, (1,1) is top right of AXES
|
||||
# To place outside on the right, x should be > 1
|
||||
bbox_to_anchor=(1.03, 0.5), # x > 1 means outside to the right, y=0.5 for vertical center
|
||||
ncol=1, # Single column for a taller, narrower legend
|
||||
loc="center left", # Anchor the legend's left-center to bbox_to_anchor point
|
||||
labelspacing=0.5, # Adjust spacing
|
||||
edgecolor="black",
|
||||
facecolor="white",
|
||||
framealpha=1,
|
||||
shadow=False,
|
||||
fancybox=False,
|
||||
handlelength=1.5,
|
||||
handletextpad=0.6,
|
||||
columnspacing=1.5,
|
||||
prop={"weight": "bold", "size": 9},
|
||||
).set_zorder(100)
|
||||
|
||||
# Save the figure (using the original generic name as requested)
|
||||
output_filename = "./bottleneck_breakdown.pdf"
|
||||
# plt.tight_layout() # tight_layout might conflict with external legend; adjust subplots_adjust instead
|
||||
plt.savefig(output_filename, bbox_inches="tight", dpi=300)
|
||||
print(f"Saved plot to {output_filename}")
|
||||
|
||||
# plt.show() # Uncomment to display plot interactively
|
||||
226
research/paper_plot/small_emb.py
Normal file
@@ -0,0 +1,226 @@
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
# import matplotlib.ticker as mticker # Not actively used
|
||||
import os
|
||||
|
||||
FIGURE_PATH = "paper_plot/figures"
|
||||
|
||||
try:
|
||||
os.makedirs(FIGURE_PATH, exist_ok=True)
|
||||
print(f"Images will be saved to: {os.path.abspath(FIGURE_PATH)}")
|
||||
except OSError as e:
|
||||
print(f"Create {FIGURE_PATH} failed: {e}. Images will be saved in the current working directory.")
|
||||
FIGURE_PATH = "."
|
||||
|
||||
plt.rcParams["font.family"] = "Helvetica"
|
||||
plt.rcParams["ytick.direction"] = "in"
|
||||
plt.rcParams["hatch.linewidth"] = 2
|
||||
plt.rcParams["font.weight"] = "bold"
|
||||
plt.rcParams["axes.labelweight"] = "bold"
|
||||
plt.rcParams["text.usetex"] = True
|
||||
|
||||
method_labels = ["gte-small (33M)", "contriever-msmarco (110M)"]
|
||||
dataset_names = ["NQ", "TriviaQA"]
|
||||
metrics_plot1 = ["Exact Match", "F1"]
|
||||
|
||||
small_nq_f1 = 0.2621040899
|
||||
small_tq_f1 = 0.4698198059
|
||||
small_nq_em_score = 0.1845
|
||||
small_tq_em_score = 0.4015
|
||||
small_nq_time = 1.137
|
||||
small_tq_time = 1.173
|
||||
|
||||
large_nq_f1 = 0.2841386117
|
||||
large_tq_f1 = 0.4548340289
|
||||
large_nq_em_score = 0.206
|
||||
large_tq_em_score = 0.382
|
||||
large_nq_time = 2.632
|
||||
large_tq_time = 2.684
|
||||
|
||||
data_scores_plot1 = {
|
||||
"NQ": {"Exact Match": [small_nq_em_score, large_nq_em_score], "F1": [small_nq_f1, large_nq_f1]},
|
||||
"TriviaQA": {"Exact Match": [small_tq_em_score, large_tq_em_score], "F1": [small_tq_f1, large_tq_f1]}
|
||||
}
|
||||
latency_data_plot2 = {
|
||||
"NQ": [small_nq_time, large_nq_time],
|
||||
"TriviaQA": [small_tq_time, large_tq_time]
|
||||
}
|
||||
|
||||
edgecolors = ["dimgrey", "tomato"]
|
||||
hatches = ["/////", "\\\\\\\\\\"]
|
||||
|
||||
# Changed: bar_center_separation_in_group increased for larger gap
|
||||
bar_center_separation_in_group = 0.42
|
||||
# Changed: bar_visual_width decreased for narrower bars
|
||||
bar_visual_width = 0.28
|
||||
|
||||
figsize_plot1 = (4, 2.5)
|
||||
# Changed: figsize_plot2 width adjusted to match figsize_plot1 for legend/caption alignment
|
||||
figsize_plot2 = (2.5, 2.5)
|
||||
|
||||
# Define plot1_xlim_per_subplot globally so it can be accessed by create_plot2_latency
|
||||
plot1_xlim_per_subplot = (0.0, 2.0) # Explicit xlim for plot 1 subplots
|
||||
|
||||
common_subplots_adjust_params = dict(wspace=0.30, top=0.80, bottom=0.22, left=0.09, right=0.96)
|
||||
|
||||
|
||||
def create_plot1_em_f1():
|
||||
fig, axs = plt.subplots(1, 2, figsize=figsize_plot1)
|
||||
fig.subplots_adjust(**common_subplots_adjust_params)
|
||||
|
||||
num_methods = len(method_labels)
|
||||
metric_group_centers = np.array([0.5, 1.5])
|
||||
# plot1_xlim_per_subplot is now global
|
||||
|
||||
for i, dataset_name in enumerate(dataset_names):
|
||||
ax = axs[i]
|
||||
for metric_idx, metric_name in enumerate(metrics_plot1):
|
||||
metric_center_pos = metric_group_centers[metric_idx]
|
||||
current_scores_raw = data_scores_plot1[dataset_name][metric_name]
|
||||
current_scores_percent = [val * 100 for val in current_scores_raw]
|
||||
|
||||
for j, method_label in enumerate(method_labels):
|
||||
offset = (j - (num_methods - 1) / 2.0) * bar_center_separation_in_group
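# Center the group of bars on the metric position: with two methods the offsets work out
# to -0.21 and +0.21 (half of bar_center_separation_in_group on either side).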
|
||||
bar_center_pos = metric_center_pos + offset

                ax.bar(
                    bar_center_pos, current_scores_percent[j], width=bar_visual_width, color="white",
                    edgecolor=edgecolors[j], hatch=hatches[j], linewidth=1.5,
                    label=method_label if i == 0 and metric_idx == 0 else None
                )
                ax.text(
                    bar_center_pos, current_scores_percent[j] + 0.8, f"{current_scores_percent[j]:.1f}",
                    ha='center', va='bottom', fontsize=8, fontweight='bold'
                )

        ax.set_xticks(metric_group_centers)
        ax.set_xticklabels(metrics_plot1, fontsize=9, fontweight='bold')
        ax.set_title(dataset_name, fontsize=12, fontweight='bold')
        ax.set_xlim(plot1_xlim_per_subplot)  # Apply consistent xlim
        ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: ' {:.0f}'.format(y)))

        if i == 0:
            ax.set_ylabel(r"Accuracy (\%)", fontsize=12, fontweight="bold")

        all_subplot_scores_percent = []
        for metric_name_iter in metrics_plot1:
            all_subplot_scores_percent.extend([val * 100 for val in data_scores_plot1[dataset_name][metric_name_iter]])

        max_val = max(all_subplot_scores_percent) if all_subplot_scores_percent else 0
        ax.set_ylim(0, max_val * 1.22 if max_val > 0 else 10)
        ax.tick_params(axis='y', labelsize=12)

        for spine in ax.spines.values():
            spine.set_visible(True)
            spine.set_linewidth(1.0)
            spine.set_edgecolor("black")

    handles, labels = axs[0].get_legend_handles_labels()
    fig.legend(
        handles, labels, loc="upper center", bbox_to_anchor=(0.5, 0.97), ncol=len(method_labels),
        edgecolor="black", facecolor="white", framealpha=1, shadow=False, fancybox=False,
        handlelength=1.5, handletextpad=0.4, columnspacing=0.8,
        prop={"weight": "bold", "size": 9}
    )

    # fig.text(0.5, 0.06, "(a) EM \& F1", ha='center', va='center', fontweight='bold', fontsize=11)

    save_path = os.path.join(FIGURE_PATH, "plot1_em_f1.pdf")
    # plt.tight_layout() # Adjusted call below
    fig.tight_layout(rect=(0.0, 0.0, 1.0, 0.88))  # Adjusted to make space for fig.text and fig.legend
    plt.savefig(save_path, bbox_inches='tight', pad_inches=0.03)
    plt.close(fig)
    print(f"Figure 1 (Exact Match & F1) has been saved to: {save_path}")


def create_plot2_latency():
    fig, axs = plt.subplots(1, 2, figsize=figsize_plot2)
    fig.subplots_adjust(**common_subplots_adjust_params)

    num_methods = len(method_labels)
    method_group_center_in_subplot = 0.5

    # Calculate bar extents to determine focused xlim
    bar_positions_calc = []
    for j_idx in range(num_methods):
        offset_calc = (j_idx - (num_methods - 1) / 2.0) * bar_center_separation_in_group
        bar_center_pos_calc = method_group_center_in_subplot + offset_calc
        bar_positions_calc.append(bar_center_pos_calc)

    min_bar_actual_edge = min(bar_positions_calc) - bar_visual_width / 2.0
    max_bar_actual_edge = max(bar_positions_calc) + bar_visual_width / 2.0

    # Define padding around the bars
    # Option 1: Fixed padding (e.g., 0.15 as derived from plot 1 visual)
    # padding_val = 0.15
    # plot2_xlim_calculated = (min_bar_actual_edge - padding_val, max_bar_actual_edge + padding_val)
    # This would be (0.15 - 0.15, 0.85 + 0.15) = (0.0, 1.0)

    # Option 2: Center the group (0.5) in a span of 1.0
    plot2_xlim_calculated = (method_group_center_in_subplot - 0.5, method_group_center_in_subplot + 0.5)
    # This is (0.5 - 0.5, 0.5 + 0.5) = (0.0, 1.0)
    # This is simpler and achieves (0.0, 1.0) directly.

    for i, dataset_name in enumerate(dataset_names):
        ax = axs[i]
        current_latencies = latency_data_plot2[dataset_name]

        for j, method_label in enumerate(method_labels):
            offset = (j - (num_methods - 1) / 2.0) * bar_center_separation_in_group
            bar_center_pos = method_group_center_in_subplot + offset

            ax.bar(
                bar_center_pos, current_latencies[j], width=bar_visual_width, color="white",
                edgecolor=edgecolors[j], hatch=hatches[j], linewidth=1.5,
                label=method_label if i == 0 else None
            )
            ax.text(
                bar_center_pos, current_latencies[j] + 0.05, f"{current_latencies[j]:.2f}",
                ha='center', va='bottom', fontsize=10, fontweight='bold'
            )
        ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: ' {:.0f}'.format(y)))

        ax.set_xticks([0.5])
        ax.set_xticklabels(["Latency"], color="white", fontsize=12)
        # Hide the x tick marks and labels by drawing them in white
        ax.tick_params(axis='x', colors="white")
        ax.set_title(dataset_name, fontsize=13, fontweight='bold')
        ax.set_xlim(plot2_xlim_calculated)

        if i == 0:
            ax.set_ylabel("Latency (s)", fontsize=12, fontweight="bold")

        max_latency_in_subplot = max(current_latencies) if current_latencies else 0
        ax.set_ylim(0, max_latency_in_subplot * 1.22 if max_latency_in_subplot > 0 else 1)
        ax.tick_params(axis='y', labelsize=12)

        for spine in ax.spines.values():
            spine.set_visible(True)
            spine.set_linewidth(1.0)
            spine.set_edgecolor("black")

    handles, labels = axs[0].get_legend_handles_labels()
    fig.legend(
        handles, labels, loc="upper center", bbox_to_anchor=(0.5, 0.97), ncol=num_methods,
        edgecolor="black", facecolor="white", framealpha=1, shadow=False, fancybox=False,
        handlelength=1.5, handletextpad=0.4, columnspacing=0.8,
        prop={"weight": "bold", "size": 9}
    )

    # fig.text(0.5, 0.06, "(b) Latency", ha='center', va='center', fontweight='bold', fontsize=11)

    save_path = os.path.join(FIGURE_PATH, "plot2_latency.pdf")
    # plt.tight_layout() # Adjusted call below
    fig.tight_layout(rect=(0.0, 0.0, 1.0, 0.88))  # Adjusted to make space for fig.text and fig.legend
    plt.savefig(save_path, bbox_inches='tight', pad_inches=0.03)
    plt.close(fig)
    print(f"Figure 2 (Latency) has been saved to: {save_path}")


if __name__ == "__main__":
    print("Start generating figures...")
    if plt.rcParams["text.usetex"]:
        print("Info: LaTeX rendering is enabled. If rendering fails, ensure LaTeX is installed, or set plt.rcParams['text.usetex'] to False.")

    create_plot1_em_f1()
    create_plot2_latency()
    print("All figures have been generated.")

111
research/paper_plot/speed_ablation.py
Normal file
@@ -0,0 +1,111 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Motto: Were It to Benefit My Country, I Would Lay Down My Life!
# \file: /speed_ablation.py
# \brief:
# Author: raphael hao

# %%
import numpy as np
import pandas as pd

# %%
# from script.settings import DATA_PATH, FIGURE_PATH

# Load the latency ablation data
latency_data = pd.read_csv("./paper_plot/data/latency_ablation.csv")
# Filter for SpeedUp metric only
speedup_data = latency_data[latency_data['Metric'] == 'SpeedUp']
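# Expected CSV layout (inferred from the column accesses below): one row per (Dataset, Metric)
# with columns 'Dataset', 'Metric', 'Original', 'original + two_level', 'original + two_level + batch'.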

# %%
from matplotlib import pyplot as plt

plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["ytick.direction"] = "in"
plt.rcParams["hatch.linewidth"] = 1.5
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
plt.rcParams["text.usetex"] = True

# %%

fig, ax = plt.subplots()
fig.set_size_inches(5, 1.5)
plt.subplots_adjust(wspace=0, hspace=0)

total_width, n = 3, 3
group = len(speedup_data['Dataset'].unique())
width = total_width * 0.9 / n
x = np.arange(group) * n
exit_idx_x = x + (total_width - width) / n
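# With total_width = 3 and n = 3: width = 0.9, groups start at x = 0, 3, 6, ...,
# and the three bars of group i are centered at exit_idx_x[i] + j * width,
# i.e. x + 0.7, x + 1.6, and x + 2.5; the x tick at exit_idx_x + width marks the middle bar.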

edgecolors = ["dimgrey", "#63B8B6", "tomato", "silver", "slategray"]
hatches = ["/////", "xxxxx", "\\\\\\\\\\"]
labels = ["Base", "Base + Two-level", "Base + Two-level + Batch"]

datasets = speedup_data['Dataset'].unique()

for i, dataset in enumerate(datasets):
    dataset_data = speedup_data[speedup_data['Dataset'] == dataset]

    for j in range(n):
        if j == 0:
            value = dataset_data['Original'].values[0]
        elif j == 1:
            value = dataset_data['original + two_level'].values[0]
        else:
            value = dataset_data['original + two_level + batch'].values[0]

        ax.text(
            exit_idx_x[i] + j * width,
            value + 0.05,
            f"{value:.2f}",
            ha='center',
            va='bottom',
            fontsize=10,
            fontweight='bold',
            rotation=0,
            zorder=20,
        )

        ax.bar(
            exit_idx_x[i] + j * width,
            value,
            width=width * 0.8,
            color="white",
            edgecolor=edgecolors[j],
            hatch=hatches[j],
            linewidth=1.5,
            label=labels[j] if i == 0 else None,
            zorder=10,
        )


ax.set_ylim([0.5, 2.3])
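# Note: the y-axis is truncated at 0.5 rather than starting at 0,
# so bar heights are not proportional to the absolute speedup values.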
ax.set_yticks(np.arange(0.5, 2.2, 0.5))
ax.set_yticklabels(np.arange(0.5, 2.2, 0.5), fontsize=12)
ax.set_xticks(exit_idx_x + width)
ax.set_xticklabels(datasets, fontsize=10)
# ax.set_xlabel("Different Datasets", fontsize=14)
ax.legend(
    bbox_to_anchor=(-0.03, 1.4),
    ncol=3,
    loc="upper left",
    labelspacing=0.1,
    edgecolor="black",
    facecolor="white",
    framealpha=1,
    shadow=False,
    fancybox=False,
    handlelength=0.8,
    handletextpad=0.6,
    columnspacing=0.8,
    prop={"weight": "bold", "size": 10},
).set_zorder(100)
ax.set_ylabel("Speedup", fontsize=11)

plt.savefig("./paper_plot/figures/latency_speedup.pdf", bbox_inches="tight", dpi=300)

# %%

print("Saved to ./paper_plot/figures/latency_speedup.pdf")