Initial commit
This commit is contained in:
@@ -0,0 +1,165 @@
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
# Global Matplotlib style shared by every panel of the accuracy figure.
# text.usetex routes all text through LaTeX, so label strings must be
# LaTeX-safe.
plt.rcParams.update(
    {
        "font.family": "Helvetica",
        "ytick.direction": "in",
        "hatch.linewidth": 1.5,
        "font.weight": "bold",
        "axes.labelweight": "bold",
        "text.usetex": True,
    }
)

# Directory the finished figure is written to.
FIGURE_PATH = "./paper_plot/figures"

# Accuracy table: one row per model, one "<dataset> <metric>" column per bar.
acc_data = pd.read_csv("./paper_plot/data/acc.csv")

# One row of four panels (one per dataset), 9 x 2.5 inches overall.
fig, axs = plt.subplots(1, 4, figsize=(9, 2.5))

# Panel order and the two metric groups drawn inside each panel.
datasets = ["NQ", "TriviaQA", "GPQA", "HotpotQA"]
metrics = ["Exact Match", "F1"]

# Bar geometry. `width` is the distance between the centres of neighbouring
# bars inside a group; each bar is drawn at `bar_width_plotting_factor` of
# that distance (0.64 * 0.8 ~= 0.51), leaving a gap between bars.
n = 3  # Number of models
width = 0.64
bar_width_plotting_factor = 0.8

# Per-model styling, indexed in the same order as `labels`.
edgecolors = ["dimgrey", "#63B8B6", "tomato"]
hatches = ["/////", "xxxxx", "\\\\\\\\\\"]
labels = ["BM25", "PQ Compressed", "Ours"]
|
||||
|
||||
# Create plots for each dataset
# Bar geometry is the same for every panel, so it is defined once up front.
group_centers = [1.0, 3.0]  # x centres of the Exact Match and F1 bar groups
bar_offsets = [-width, 0, width]  # one offset per entry of `labels`

for i, dataset in enumerate(datasets):
    ax = axs[i]

    # Rows of acc_data are read positionally and are assumed to be in the same
    # order as `labels`; stored fractions are converted to percentages.
    em_values = [
        acc_data.loc[row, f"{dataset} Exact Match"] * 100
        for row in range(len(labels))
    ]
    f1_values = [
        acc_data.loc[row, f"{dataset} F1"] * 100
        for row in range(len(labels))
    ]

    # Draw one group of bars per metric, one bar per model.
    for metric_idx, (metric_group_center, values_to_plot) in enumerate(
        zip(group_centers, (em_values, f1_values))
    ):
        for j, model_label in enumerate(labels):
            x_pos = metric_group_center + bar_offsets[j]
            bar_value = values_to_plot[j]

            ax.bar(
                x_pos,
                bar_value,
                width=width * bar_width_plotting_factor,
                color="white",
                edgecolor=edgecolors[j],
                hatch=hatches[j],
                linewidth=1.5,
                # Register each model once (first panel, first group) so the
                # shared legend has exactly one entry per model.
                label=model_label if i == 0 and metric_idx == 0 else None,
            )

            # Numeric value just above the bar.
            ax.text(
                x_pos,
                bar_value + 0.1,
                f"{bar_value:.1f}",
                ha='center',
                va='bottom',
                fontsize=9,
                fontweight='bold',
            )

    # One tick per metric group, centred under its bars.
    ax.set_xticks(group_centers)
    ax.set_xticklabels(metrics, fontsize=12)

    ax.set_title(dataset, fontsize=14)

    if i == 0:
        # Raw string: "\%" is the LaTeX escape for a literal percent sign and
        # is not a valid Python escape sequence.
        ax.set_ylabel(r"Accuracy (\%)", fontsize=12, fontweight="bold")
    else:
        # Each panel has its own y-range, so tick labels stay visible on the
        # remaining panels; they are only set to a smaller size.
        ax.tick_params(axis='y', labelsize=10)

    # Y-limits follow the data range of the panel.
    all_values = em_values + f1_values
    max_val = max(all_values)
    min_val = min(all_values)

    if dataset == "GPQA":
        # GPQA scores are very low; a fixed 0-10 range keeps the bars readable.
        ax.set_ylim(0, 10.0)
    else:
        # 10% headroom above the tallest bar leaves space for the value text.
        ax.set_ylim(min_val * 0.9, max_val * 1.1)

    # Integer tick labels with a leading space as padding.
    ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: ' {:.0f}'.format(y)))

    # A margin of 1.0 on each side of the outer group centres.
    ax.set_xlim(group_centers[0] - 1.0, group_centers[1] + 1.0)

    # Shared legend, attached to the first panel and anchored in that panel's
    # axes coordinates (x > 1 places it to the right of the first panel).
    if i == 0:
        ax.legend(
            bbox_to_anchor=(2.21, 1.35),
            ncol=3,
            loc="upper center",
            labelspacing=0.1,
            edgecolor="black",
            facecolor="white",
            framealpha=1,
            shadow=False,
            fancybox=False,
            handlelength=1.0,
            handletextpad=0.6,
            columnspacing=0.8,
            prop={"weight": "bold", "size": 12},
        )

# Save figure with a tight bounding box and minimal padding.
plt.savefig(FIGURE_PATH + "/accuracy_comparison.pdf", bbox_inches='tight', pad_inches=0.05)
plt.show()
|
||||
@@ -0,0 +1,309 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
# Motto: Were It to Benefit My Country, I Would Lay Down My Life!
|
||||
# \file: /hnsw_degree_visit_plot_binned_academic.py
|
||||
# \brief: Generates a binned bar plot of HNSW node average per-query visit probability
|
||||
# per degree bin, styled for academic publications, with caching.
|
||||
# Author: raphael hao (Original script by user, styling and caching adapted by Gemini)
|
||||
|
||||
# %%
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
import re
|
||||
from collections import Counter
|
||||
import os # For robust filepath manipulation
|
||||
import math # For calculating scaling factor
|
||||
import pickle # For caching data
|
||||
|
||||
# %%
|
||||
# --- Matplotlib parameters for academic paper style ---
# text.usetex renders all text through LaTeX (a LaTeX installation is needed).
plt.rcParams.update(
    {
        "font.family": "Helvetica",
        "ytick.direction": "in",
        "hatch.linewidth": 1.5,
        "font.weight": "bold",
        "axes.labelweight": "bold",
        "text.usetex": True,
    }
)

# --- Shared edge-colour palette ---
edgecolors_ref = ["dimgrey", "#63B8B6", "tomato", "silver", "slategray"]

# %%
# --- Input / output locations ---
degree_file = '/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/hnsw/degree_distribution.txt'
visit_log_file = './re.log'
output_image_file = './paper_plot/figures/hnsw_visit_count_per_degree_corrected.pdf'
# Pickle file holding the binned plot data between runs.
CACHE_FILE_PATH = './binned_plot_data_cache.pkl'

# --- Run configuration ---
# FORCE_RECOMPUTE = True ignores the cache; deleting CACHE_FILE_PATH by hand
# has the same effect.
FORCE_RECOMPUTE = False
# Number of queries the visit counts were collected over; used to turn
# per-node visit counts into per-query probabilities.
NUMBER_OF_QUERIES = 1000.0

# Make sure the figure directory exists before anything is saved into it.
output_dir = os.path.dirname(output_image_file)
figure_dir_missing = bool(output_dir) and not os.path.exists(output_dir)
if figure_dir_missing:
    os.makedirs(output_dir)
    print(f"Created directory: {output_dir}")
|
||||
|
||||
# %%
|
||||
# --- Try the cache first; anything unusable falls through to recomputation ---
df_plot_data = None
bin_size_for_plot = None  # bin size that df_plot_data was built with

if not FORCE_RECOMPUTE and os.path.exists(CACHE_FILE_PATH):
    try:
        # NOTE: pickle.load can execute arbitrary code; the cache file must
        # only ever be one written by this script.
        with open(CACHE_FILE_PATH, 'rb') as cache_file:
            cache_content = pickle.load(cache_file)
        df_plot_data = cache_content['data']
        bin_size_for_plot = cache_content['bin_size']

        # The cache is accepted only when it has the expected shape: a
        # DataFrame with both columns used for plotting, plus an int bin size.
        required_columns = ('degree_bin_label', 'average_visit_count_per_node_in_bin')
        cache_is_valid = (
            isinstance(df_plot_data, pd.DataFrame)
            and all(column in df_plot_data.columns for column in required_columns)
            and isinstance(bin_size_for_plot, int)
        )

        if cache_is_valid:
            print(f"Successfully loaded binned data from cache: {CACHE_FILE_PATH}")

            # Display-only relabel: with a bin size of 5 the first bin is
            # stored as "0-4"; it is shown as "1-4" on the assumption that
            # the graph has no degree-0 nodes.
            lowest_bin_mask = df_plot_data['degree_bin_label'] == '0-4'
            if bin_size_for_plot == 5 and '0-4' in df_plot_data['degree_bin_label'].values:
                df_plot_data.loc[lowest_bin_mask, 'degree_bin_label'] = '1-4'
                print("Modified degree_bin_label from '0-4' to '1-4' for display purpose.")
        else:
            print("Cached data is not in the expected format or missing 'average_visit_count_per_node_in_bin'. Recomputing.")
            df_plot_data = None  # triggers recomputation
    except Exception as e:
        # Best effort: a corrupt or unreadable cache is never fatal.
        print(f"Error loading from cache: {e}. Recomputing.")
        df_plot_data = None  # triggers recomputation
|
||||
|
||||
if df_plot_data is None:
    print("Cache not found, invalid, or recompute forced. Computing data from scratch...")

    # --- 1. Read Degree Distribution File ---
    # One integer degree per line; the 0-based line index is used as node id
    # (blank lines are skipped but still consume an index).
    degrees_data = []
    try:
        with open(degree_file, 'r') as f:
            for i, line in enumerate(f):
                line_stripped = line.strip()
                if line_stripped:
                    degrees_data.append({'node_id': i, 'degree': int(line_stripped)})
    except FileNotFoundError:
        # NOTE(review): a missing degree file silently switches to synthetic
        # data; the resulting figure must not be mistaken for real results.
        print(f"Error: Degree file '{degree_file}' not found. Using dummy data for degrees.")
        degrees_data = [{'node_id': i, 'degree': (i % 20) + 1 } for i in range(200)]
        degrees_data.extend([{'node_id': 200+i, 'degree': i} for i in range(58, 67)]) # For 60-64 bin
        degrees_data.extend([{'node_id': 300+i, 'degree': (i % 5)+1} for i in range(10)]) # Low degrees
        degrees_data.extend([{'node_id': 400+i, 'degree': 80 + (i%5)} for i in range(10)]) # High degrees

    if not degrees_data:
        # SystemExit with a message prints it to stderr and exits with a
        # non-zero status; the bare exit() helper is only installed by `site`
        # and would report success.
        raise SystemExit("Critical Error: No data loaded or generated for degrees. Exiting.")
    df_degrees = pd.DataFrame(degrees_data)
    print(f"Successfully loaded/generated {len(df_degrees)} degree entries.")

    # --- 2. Read Visit Log File and Count Frequencies ---
    visit_counts = Counter()
    node_id_pattern = re.compile(r"Vis(i)?ted node: (\d+)")
    try:
        with open(visit_log_file, 'r') as f_log:
            for line_num, line in enumerate(f_log, 1):
                match = node_id_pattern.search(line)
                if match:
                    try:
                        node_id = int(match.group(2))
                        visit_counts[node_id] += 1  # one more visit for this node
                    except ValueError:
                        print(f"Warning: Non-integer node_id in log '{visit_log_file}' line {line_num}: {line.strip()}")
    except FileNotFoundError:
        print(f"Warning: Visit log file '{visit_log_file}' not found. Using dummy visit counts.")
        # Synthetic visits for a fixed 90% sample of the nodes. Sampling whole
        # rows gives node id and degree together, avoiding a full-table
        # filter per node (quadratic on a real-sized graph).
        sampled_nodes = df_degrees.sample(frac=0.9, random_state=1234)
        for node_id_val, degree_val in zip(sampled_nodes['node_id'], sampled_nodes['degree']):
            # Expected visits shrink with degree; the node id picks one of
            # four magnitudes so several probability scales are exercised.
            degree_factor = 100 / (max(1, degree_val) + 1)
            if node_id_val % 23 == 0:    # Very low probability
                lambda_val = 0.0005 * degree_factor
            elif node_id_val % 11 == 0:  # Low probability
                lambda_val = 0.05 * degree_factor
            elif node_id_val % 5 == 0:   # Moderate probability
                lambda_val = 2.5 * degree_factor
            else:                        # Higher probability
                lambda_val = 50 * degree_factor
            # Poisson draws are never negative, so no clamping is needed.
            visit_counts[node_id_val] = np.random.poisson(lambda_val)

    if not visit_counts:
        print("Warning: No visit data parsed/generated. Plot may show zero visits.")
        df_visits = pd.DataFrame(columns=['node_id', 'visit_count'])
    else:
        df_visits = pd.DataFrame(
            [{'node_id': nid, 'visit_count': count} for nid, count in visit_counts.items()]
        )
        print(f"Parsed/generated {len(df_visits)} unique visited nodes, totaling {sum(visit_counts.values())} visits (simulated over {NUMBER_OF_QUERIES} queries).")

    # --- 3. Merge Degree Data with Visit Data ---
    # Left join keeps every node; nodes absent from the log get 0 visits.
    df_merged = pd.merge(df_degrees, df_visits, on='node_id', how='left')
    df_merged['visit_count'] = df_merged['visit_count'].fillna(0).astype(float)
    print(f"Merged data contains {len(df_merged)} entries.")

    # --- 5. Bin degrees and average the visit count per node in each bin ---
    # df_degrees cannot be empty here: the script aborts above when no degree
    # rows were loaded or generated.
    current_bin_size = 5
    bin_size_for_plot = current_bin_size

    print(f"\nBinning degrees into groups of {current_bin_size} for average visit count calculation...")

    df_merged_with_bins = df_merged.copy()
    df_merged_with_bins['degree_bin_start'] = (df_merged_with_bins['degree'] // current_bin_size) * current_bin_size

    df_binned_analysis = df_merged_with_bins.groupby('degree_bin_start').agg(
        total_visit_count_in_bin=('visit_count', 'sum'),
        node_count_in_bin=('node_id', 'nunique')
    ).reset_index()

    # Average number of visits per node of the bin over NUMBER_OF_QUERIES
    # queries; this is the value stored in the cache. Every bin produced by
    # the groupby holds at least one node, so the division is always defined.
    df_binned_analysis['average_visit_count_per_node_in_bin'] = \
        df_binned_analysis['total_visit_count_in_bin'] / df_binned_analysis['node_count_in_bin']

    # Inclusive "start-end" label, e.g. "5-9" for a bin size of 5.
    df_binned_analysis['degree_bin_label'] = df_binned_analysis['degree_bin_start'].astype(str) + '-' + \
        (df_binned_analysis['degree_bin_start'] + current_bin_size - 1).astype(str)

    # NOTE(review): the 60-64 bin is excluded from the figure by hand; the
    # reason is not recorded in this script.
    bin_to_drop_label = '60-64'
    original_length = len(df_binned_analysis)
    df_plot_data_intermediate = df_binned_analysis[df_binned_analysis['degree_bin_label'] != bin_to_drop_label].copy()
    if len(df_plot_data_intermediate) < original_length:
        print(f"\nManually dropped the bin: '{bin_to_drop_label}'")
    else:
        print(f"\nNote: Bin '{bin_to_drop_label}' not found for dropping or already removed.")

    df_plot_data = df_plot_data_intermediate

    print(f"\nBinned data (average visit count per node in bin over {NUMBER_OF_QUERIES} queries) for plotting prepared:")
    print(df_plot_data[['degree_bin_label', 'average_visit_count_per_node_in_bin']].head())

    # Persist the result so later runs can skip the recomputation.
    if not df_plot_data.empty:
        try:
            with open(CACHE_FILE_PATH, 'wb') as f:
                pickle.dump({'data': df_plot_data, 'bin_size': bin_size_for_plot}, f)
            print(f"Saved computed binned data to cache: {CACHE_FILE_PATH}")
        except Exception as e:
            # A failed cache write must not prevent the plot from being drawn.
            print(f"Error saving data to cache: {e}")
    else:
        print("Computed data for binned plot is empty, not saving to cache.")
|
||||
|
||||
# %%
|
||||
# --- 6. Plotting (Binned Bar Chart - Academic Style) ---

if df_plot_data is not None and not df_plot_data.empty and 'average_visit_count_per_node_in_bin' in df_plot_data.columns:
    # The figure is written to the configured output path unchanged.
    binned_output_image_file = output_image_file

    fig, ax = plt.subplots(figsize=(6, 2.5))

    df_plot_data_plotting = df_plot_data.copy()

    # Display-only relabel of the lowest bin, applied here so the figure is
    # identical whether the data came from the cache or was just computed.
    # With a bin size of 5 the first bin is stored as "0-4" and shown as
    # "1-4", on the assumption that the graph has no degree-0 nodes.
    if bin_size_for_plot == 5:
        df_plot_data_plotting['degree_bin_label'] = \
            df_plot_data_plotting['degree_bin_label'].replace('0-4', '1-4')

    # Per-query probability: (average visits over N queries) / N.
    df_plot_data_plotting['per_query_visit_probability'] = \
        df_plot_data_plotting['average_visit_count_per_node_in_bin'] / NUMBER_OF_QUERIES

    max_probability = df_plot_data_plotting['per_query_visit_probability'].max()

    # Defaults: plot the raw probabilities under the unscaled label.
    y_axis_values_to_plot = df_plot_data_plotting['per_query_visit_probability']
    y_axis_label = r"Per-Query Node Visit Probability in Bin"
    apply_scaling_to_label_and_values = False
    exponent_for_label_display = 0

    if pd.notna(max_probability) and max_probability > 0:
        potential_exponent = math.floor(math.log10(max_probability))

        # Values whose order of magnitude lies in [-3, -1] are plotted as-is;
        # anything smaller or larger is divided by 10**exponent and the
        # factor is moved into the axis label.
        if potential_exponent <= -4 or potential_exponent >= 0:
            apply_scaling_to_label_and_values = True
            exponent_for_label_display = potential_exponent
            y_axis_label = rf"Visit Probability ($\times 10^{{{exponent_for_label_display}}}$)"
            y_axis_values_to_plot = df_plot_data_plotting['per_query_visit_probability'] / (10**exponent_for_label_display)
            print(f"Plotting with Max per-query probability: {max_probability:.2e}, Exponent for label: {exponent_for_label_display}. Y-axis values scaled for plot.")
        else:
            print(f"Plotting with Max per-query probability: {max_probability:.2e}. Plotting direct probabilities without label scaling (exponent {potential_exponent} is within no-scale range [-3, -1]).")

    elif pd.notna(max_probability) and max_probability == 0:
        print("Max per-query probability is 0. Plotting direct probabilities (all zeros).")
    else:
        print(f"Max per-query probability is NaN or invalid ({max_probability}). Plotting direct probabilities without scaling if possible.")

    ax.bar(
        df_plot_data_plotting['degree_bin_label'],
        y_axis_values_to_plot,
        color='white',
        edgecolor=edgecolors_ref[0],
        linewidth=1.5,
        width=0.8
    )

    ax.set_xlabel('Node Degree', fontsize=10.5, labelpad=6)
    # labelpad moves the y-axis label away from the tick labels.
    ax.set_ylabel(y_axis_label, fontsize=10.5, labelpad=10)

    # NOTE(review): with text.usetex enabled a bare "%" starts a LaTeX
    # comment, so this suffix may not be rendered; the plotted values are also
    # scaled probabilities rather than percentages. Confirm the intended tick
    # format before changing it.
    ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, pos: f"{x:.0f}%"))

    # Shrink or rotate the bin labels as the number of bins grows.
    num_bins = len(df_plot_data_plotting)
    if num_bins > 12:
        # Re-setting the current ticks fixes their positions before the
        # labels are replaced.
        ax.set_xticks(ax.get_xticks())
        ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right", fontsize=9)
    elif num_bins > 8:
        ax.tick_params(axis='x', labelsize=9)
    else:
        ax.tick_params(axis='x', labelsize=10)

    ax.tick_params(axis='y', labelsize=10)

    # Upper y-limit: 5% headroom above the tallest bar, never below a small
    # minimum so an almost-flat chart still has a visible range.
    padding_factor = 0.05
    current_max_y_on_axis = y_axis_values_to_plot.max()

    upper_y_limit = 0.1  # used when every bar is zero
    if pd.notna(current_max_y_on_axis):
        if current_max_y_on_axis > 0:
            # Unscaled values are always below 1 here (a maximum of 1 or more
            # has exponent >= 0 and is therefore scaled), so only two floors
            # are needed.
            if apply_scaling_to_label_and_values and exponent_for_label_display >= 0:
                min_meaningful_limit = 0.1
            else:
                min_meaningful_limit = 0.01
            upper_y_limit = max(min_meaningful_limit, current_max_y_on_axis * (1 + padding_factor))
        ax.set_ylim(0, upper_y_limit)
    else:
        ax.set_ylim(0, 1.0)  # fallback for NaN data

    plt.tight_layout()
    plt.savefig(binned_output_image_file, bbox_inches="tight", dpi=300)
    print(f"Binned bar chart saved to {binned_output_image_file}")
    plt.show()
    plt.close(fig)
else:
    # Report which precondition for plotting failed.
    if df_plot_data is None:
        print("Data for plotting (df_plot_data) is None. Skipping plot generation.")
    elif df_plot_data.empty:
        print("Data for plotting (df_plot_data) is empty. Skipping plot generation.")
    elif 'average_visit_count_per_node_in_bin' not in df_plot_data.columns:
        print("Essential column 'average_visit_count_per_node_in_bin' is missing in df_plot_data. Skipping plot generation.")

# %%
print("Script finished.")
|
||||
@@ -0,0 +1,7 @@
|
||||
In this paper, we present LiteANN, a storage-efficient approximate nearest neighbor (ANN) search index optimized for resource-constrained personal devices. LiteANN combines a compact graph-based structure with an efficient on-the-fly recomputation strategy to enable fast and accurate retrieval with minimal storage overhead. Our evaluation shows that LiteANN reduces index size to under 5% of the original raw data – up to 50× smaller than standard indexes – while achieving 90% top-3 recall in under 2 seconds on real-world question-answering benchmarks.
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
# --- Graph sources (kept in step with the plotting script) ---
# Each entry pairs the label used as the key in the cache file with the
# directory that holds that graph's degree statistics. Keeping them together
# means a label can never drift away from its path.
_GRAPH_SOURCES = (
    ("HNSW-Base", "/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/hnsw/"),
    ("DegreeGuide", "/opt/dlami/nvme/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki/1-shards/indices/99_4_degree_based_hnsw_IP_M32_efC256/"),
    ("HNSW-D9", "/opt/dlami/nvme/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki/1-shards/indices/d9_hnsw_IP_M8_efC128/"),
    ("RandCut", "/opt/dlami/nvme/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki/1-shards/indices/half_edges_IP_M32_efC128/"),
)
BIG_GRAPH_PATHS = [source_dir for _, source_dir in _GRAPH_SOURCES]
BIG_GRAPH_LABELS = [source_label for source_label, _ in _GRAPH_SOURCES]

# Name of the per-graph degree file inside each source directory.
STATS_FILE_NAME = "degree_distribution.txt"

# Average degrees are static, so they are not cached here; only the degree
# arrays are.
# BIG_GRAPH_AVG_DEG = [18, 9, 9, 9]

# --- Cache file location ---
DATA_CACHE_DIR = "./paper_plot/data/"
CACHE_FILE_NAME = "big_graph_degree_data.npz"  # .npz holds several named arrays
|
||||
|
||||
def create_degree_data_cache(
    graph_paths=None,
    graph_labels=None,
    cache_dir=None,
    cache_file_name=None,
    stats_file_name=None,
):
    """Collect per-graph degree arrays and store them in one ``.npz`` file.

    For every graph directory the degree file is read as integers and stored
    under the graph's label. A missing, unreadable or empty file is stored as
    an empty integer array so that every label is always present in the cache.

    Args:
        graph_paths: Directories containing a degree file each. Defaults to
            ``BIG_GRAPH_PATHS``.
        graph_labels: Cache keys, one per entry of ``graph_paths`` and in the
            same order. Defaults to ``BIG_GRAPH_LABELS``.
        cache_dir: Directory the cache file is written to (created when
            missing). Defaults to ``DATA_CACHE_DIR``.
        cache_file_name: File name of the cache. Defaults to
            ``CACHE_FILE_NAME``.
        stats_file_name: Name of the degree file inside each graph directory.
            Defaults to ``STATS_FILE_NAME``.

    Returns:
        None. Progress and errors are reported on standard output.

    Raises:
        ValueError: If there are fewer labels than graph paths.
    """
    graph_paths = BIG_GRAPH_PATHS if graph_paths is None else graph_paths
    graph_labels = BIG_GRAPH_LABELS if graph_labels is None else graph_labels
    cache_dir = DATA_CACHE_DIR if cache_dir is None else cache_dir
    cache_file_name = CACHE_FILE_NAME if cache_file_name is None else cache_file_name
    stats_file_name = STATS_FILE_NAME if stats_file_name is None else stats_file_name

    # Fail before any work is done instead of part-way through the loop.
    if len(graph_labels) < len(graph_paths):
        raise ValueError(
            f"Need one label per graph path: got {len(graph_labels)} labels "
            f"for {len(graph_paths)} paths."
        )

    os.makedirs(cache_dir, exist_ok=True)
    cache_file_path = os.path.join(cache_dir, cache_file_name)

    cached_data = {}
    print(f"Starting data caching process for {len(graph_paths)} graph types...")

    for base_path, method_label in zip(graph_paths, graph_labels):
        degree_file_path = os.path.join(base_path, stats_file_name)

        print(f"Processing: {method_label} from {degree_file_path}")

        try:
            # ndmin=1 keeps a one-line file as a 1-element array; without it
            # loadtxt returns a 0-d array, len() fails on it and the value
            # would be discarded by the error handler below.
            degrees = np.loadtxt(degree_file_path, dtype=int, ndmin=1)
        except FileNotFoundError:
            print(f"  [ERROR] Degree file not found: {degree_file_path}. Skipping {method_label}.")
            # An empty array (rather than None) loads back without special handling.
            cached_data[method_label] = np.array([], dtype=int)
        except Exception as e:
            print(f"  [ERROR] An error occurred loading {degree_file_path} for {method_label}: {e}")
            cached_data[method_label] = np.array([], dtype=int)
        else:
            if degrees.size == 0:
                print(f"  [WARN] Degree file is empty: {degree_file_path}. Storing as empty array for {method_label}.")
                cached_data[method_label] = np.array([], dtype=int)
            else:
                cached_data[method_label] = degrees
                print(f"  [INFO] Loaded {len(degrees)} degrees for {method_label}. Max degree: {np.max(degrees)}")

    if not cached_data:
        print("[ERROR] No data was successfully processed or loaded. Cache file will not be created.")
        return

    try:
        # All arrays go into a single compressed archive, keyed by label.
        np.savez_compressed(cache_file_path, **cached_data)
        print(f"\n[SUCCESS] Degree distribution data successfully cached to: {os.path.abspath(cache_file_path)}")
        print("Cached arrays (keys):", list(cached_data.keys()))
    except Exception as e:
        print(f"\n[ERROR] Failed to save data to cache file {cache_file_path}: {e}")


if __name__ == "__main__":
    print("--- Degree Distribution Data Caching Script ---")
    create_degree_data_cache()
    print("--- Caching script finished. ---")
|
||||
@@ -0,0 +1,4 @@
|
||||
Model,NQ Exact Match,NQ F1,TriviaQA Exact Match,TriviaQA F1,GPQA Exact Match,GPQA F1,HotpotQA Exact Match,HotpotQA F1
|
||||
BM25,0.192,0.277,0.406,0.474,0.020089,0.04524,0.162,0.239
|
||||
PQ 5,0.2075,0.291,0.422,0.495,0.0201,0.0445,0.148,0.219
|
||||
Ours,0.265,0.361,0.533,0.604,0.02008,0.0452,0.182,0.2729
|
||||
|
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:1296720e79196bbdf38f051043c1b054667803726a24036c0b6a87cedb204ea5
|
||||
size 227482438
|
||||
@@ -0,0 +1,21 @@
|
||||
2,1,512,1024,0.541,0.326,1.659509202
|
||||
2,2,512,1024,0.979,0.621,1.576489533
|
||||
2,4,512,1024,1.846,0.977,1.889457523
|
||||
2,8,512,1024,3.575,1.943,1.83993824
|
||||
2,16,512,1024,7.035,3.733,1.884543263
|
||||
2,32,512,1024,15.655,8.517,1.838088529
|
||||
2,64,512,1024,32.772,17.43,1.88020654
|
||||
4,1,512,1024,2.675,1.38,1.938405797
|
||||
4,2,512,1024,5.397,2.339,2.307396323
|
||||
4,4,512,1024,10.672,4.944,2.158576052
|
||||
4,8,512,1024,21.061,9.266,2.272933305
|
||||
4,16,512,1024,46.332,18.334,2.527108105
|
||||
4,32,512,1024,99.607,36.156,2.754923111
|
||||
4,64,512,1024,186.348,72.356,2.575432583
|
||||
8,1,512,1024,7.325,4.087,1.792268167
|
||||
8,2,512,1024,14.109,7.491,1.883460152
|
||||
8,4,512,1024,28.499,14.013,2.033754371
|
||||
8,8,512,1024,65.222,27.453,2.375769497
|
||||
8,16,512,1024,146.294,52.55,2.783901047
|
||||
8,32,512,1024,277.099,103.61,2.674442621
|
||||
8,64,512,1024,512.979,208.36,2.461984066
|
||||
|
@@ -0,0 +1,9 @@
|
||||
Dataset,Metric,Original,original + batch,original + two_level,original + two_level + batch
|
||||
NQ,Latency,6.9,5.8,4.2,3.7
|
||||
NQ,SpeedUp,1,1.18965517,1.64285714,1.86486486
|
||||
TriviaQA,Latency,17.054,14.542,12.046,10.83
|
||||
TriviaQA,SpeedUp,1,1.17274103,1.41573967,1.57469990
|
||||
GPQA,Latency,9.164,7.639,6.798,5.77
|
||||
GPQA,SpeedUp,1,1.19963346,1.34804354,1.58821490
|
||||
HotpotQA,Latency,60.279,39.827,50.664,29.868
|
||||
HotpotQA,SpeedUp,1,1.51352098,1.18977972,2.01817999
|
||||
|
@@ -0,0 +1,25 @@
|
||||
Dataset,Hardware,Recall_target,HNSW,IVF,DiskANN,IVF-Disk,IVF-Recompute,Our,BM25,LLM_Gen_Time_1B,LLM_Gen_Time_3B,LLM_Gen_Time_7B
|
||||
NQ,A10,85%,0.046,1.656,0.017,2.996,482.53,3.323,0.021,0.085,0.217,0.472
|
||||
NQ,A10,90%,0.051,2.552,0.028,3.437,769.04,4.616,0,0.085,0.217,0.472
|
||||
NQ,A10,95%,0.055,5.163,0.070,5.602,1436.26,19.494,0,0.085,0.217,0.472
|
||||
NQ,MAC,85%,0,0,0.152,2.199,1535.10,7.971,0.033,0.316,0.717,1.468
|
||||
NQ,MAC,90%,0,0,0.37,2.936,2446.60,13.843,0,0.316,0.717,1.468
|
||||
NQ,MAC,95%,0,0,1.207,4.191,4569.29,44.363,0,0.316,0.717,1.468
|
||||
TriviaQA,A10,85%,0.042,1.772,0.032,2.464,560.5,3.752,0.033,0.139,0.156,0.315
|
||||
TriviaQA,A10,90%,0.043,3.541,0.057,3.651,997.81,5.777,0,0.139,0.156,0.315
|
||||
TriviaQA,A10,95%,0.053,7.168,0.090,5.458,2005.33,20.944,0,0.139,0.156,0.315
|
||||
TriviaQA,MAC,85%,0,0,0.481,1.875,1783.14787,8.889,0.036,0.325,0.692,1.415
|
||||
TriviaQA,MAC,90%,0,0,0.984,2.639,3174.410301,17.145,0,0.325,0.692,1.415
|
||||
TriviaQA,MAC,95%,0,0,1.578,3.884,6379.712245,47.909,0,0.325,0.692,1.415
|
||||
GPQA,A10,85%,0.041,0.134,0.024,0.048,40.16,1.897,0.137,0.443,0.396,0.651
|
||||
GPQA,A10,90%,0.042,0.174,0.034,0.06,54.71,1.733,0,0.443,0.396,0.651
|
||||
GPQA,A10,95%,0.045,0.292,0.051,0.11,97.67,4.033,0,0.443,0.396,0.651
|
||||
GPQA,MAC,85%,0,0,0.144,0.087,127.7707505,4.762,0.100,0.37,0.813,1.676
|
||||
GPQA,MAC,90%,0,0,0.288,0.108,174.0647409,5.223,0,0.37,0.813,1.676
|
||||
GPQA,MAC,95%,0,0,0.497,0.132,310.7380142,9.715,0,0.37,0.813,1.676
|
||||
HotpotQA,A10,85%,0.044,2.519,0.054,4.048,724.26,10.358,0.70,0.144,0.196,0.420
|
||||
HotpotQA,A10,90%,0.049,3.867,0.109,5.045,1173.67,15.515,0,0.144,0.196,0.420
|
||||
HotpotQA,A10,95%,0.07,10.928,0.412,8.659,3079.57,61.757,0,0.144,0.196,0.420
|
||||
HotpotQA,MAC,85%,0,0,0.974,2.844,2304.125187,23.636,0.052,0.144,0.196,0.420
|
||||
HotpotQA,MAC,90%,0,0,1.913,3.542,3415.736201,44.803,0,0.144,0.196,0.420
|
||||
HotpotQA,MAC,95%,0,0,5.783,6.764,9797.244043,140.62,0,0.144,0.196,0.420
|
||||
|
@@ -0,0 +1,25 @@
|
||||
Dataset,Hardware,Recall_target,HNSW,IVF,DiskANN,IVF-Disk,IVF-Recompute,Our,
|
||||
NQ,A10,85%,0.046,1.656,0.017,2.996,482.53,4.243,
|
||||
NQ,A10,90%,0.051,2.552,0.028,3.437,769.04,8.136,
|
||||
NQ,A10,95%,0.055,5.163,0.070,5.602,1436.26,27.275,
|
||||
NQ,MAC,85%,0,0,0.152,2.199,1535.10,10.672,
|
||||
NQ,MAC,90%,0,0,0.37,2.936,2446.60,19.941,
|
||||
NQ,MAC,95%,0,0,1.207,4.191,4569.29,61.383,
|
||||
TriviaQA,A10,85%,0.042,1.772,0.032,2.464,560.5,5.612,
|
||||
TriviaQA,A10,90%,0.043,3.541,0.057,3.651,997.81,10.737,
|
||||
TriviaQA,A10,95%,0.053,7.168,0.090,5.458,2005.33,36.387,
|
||||
TriviaQA,MAC,85%,0,0,0.481,1.875,1783.14787,12.825,
|
||||
TriviaQA,MAC,90%,0,0,0.984,2.639,3174.410301,24.977,
|
||||
TriviaQA,MAC,95%,0,0,1.578,3.884,6379.712245,85.734,
|
||||
GPQA,A10,85%,0.041,0.134,0.024,0.048,40.16,2.269,
|
||||
GPQA,A10,90%,0.042,0.174,0.034,0.06,54.71,3.200,
|
||||
GPQA,A10,95%,0.045,0.292,0.051,0.11,97.67,7.445,
|
||||
GPQA,MAC,85%,0,0,0.144,0.087,127.7707505,6.123,
|
||||
GPQA,MAC,90%,0,0,0.288,0.108,174.0647409,8.507,
|
||||
GPQA,MAC,95%,0,0,0.497,0.132,310.7380142,19.577,
|
||||
HotpotQA,A10,85%,0.044,2.519,0.054,4.048,724.26,14.713,
|
||||
HotpotQA,A10,90%,0.049,3.867,0.109,5.045,1173.67,33.561,
|
||||
HotpotQA,A10,95%,0.07,10.928,0.412,8.659,3079.57,68.626,
|
||||
HotpotQA,MAC,85%,0,0,0.974,2.844,2304.125187,34.783,
|
||||
HotpotQA,MAC,90%,0,0,1.913,3.542,3415.736201,53.004,
|
||||
HotpotQA,MAC,95%,0,0,5.783,6.764,9797.244043,95.413,
|
||||
|
@@ -0,0 +1,3 @@
|
||||
Hardware,HNSW,IVF,DiskANN,IVF-Disk,IVF-Recompute,Our,BM25
|
||||
RAM,190,171,10,0,0,0,0
|
||||
Storage,185.4,171,240,171,0.5,5,59
|
||||
|
@@ -0,0 +1,12 @@
|
||||
Torch,8,55.592
|
||||
Torch,16,75.439
|
||||
Torch,32,110.025
|
||||
Torch,64,186.496
|
||||
Tutel,8,56.718
|
||||
Tutel,16,82.121
|
||||
Tutel,32,125.070
|
||||
Tutel,64,216.191
|
||||
BRT,8,56.725
|
||||
BRT,16,79.291
|
||||
BRT,32,93.180
|
||||
BRT,64,118.923
|
||||
|
@@ -0,0 +1,6 @@
|
||||
Disk cache size,0,2.5%(180G*2.5%),5%,8%,10%
|
||||
Latency,,,,,
|
||||
NQ,4.616,4.133,3.826,3.511,3.323
|
||||
TriviaQA,5.777,4.979,4.553,4.141,3.916
|
||||
GPQA,1.733,1.593,1.468,1.336,1.259
|
||||
Hotpot,15.515,13.479,12.383,11.216,10.606
|
||||
|
@@ -0,0 +1,151 @@
|
||||
import matplotlib
|
||||
from matplotlib.axes import Axes
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.patches as mpatches
|
||||
from matplotlib.lines import Line2D
|
||||
|
||||
import os

# Plot style: bold sans-serif text typeset through LaTeX (Helvetica via the preamble).
plt.rcParams["ytick.direction"] = "in"
plt.rcParams["hatch.linewidth"] = 1
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
plt.rcParams["text.usetex"] = True
plt.rcParams["font.family"] = "sans-serif"  # Use generic sans-serif family
plt.rcParams['text.latex.preamble'] = r"""
\usepackage{helvet} % Use Helvetica font for text
\usepackage{sfmath} % Use sans-serif font for math
\renewcommand{\familydefault}{\sfdefault} % Set sans-serif as default text font
\usepackage[T1]{fontenc} % Recommended for font encoding
"""

SAVE_PTH = "./paper_plot/figures"
font_size = 16

datasets = ["NQ", "TriviaQA", "GPQA", "Hotpot"]

# X tick labels: cache size with the cache ratio underneath.
# "\\%" produces a literal backslash-percent: a bare "%" would start a comment
# in LaTeX because text.usetex is enabled.
cache_ratios = [
    "4.2G\n (0\\%)",
    "8.7G\n (2.5\\%)",
    "13.2G\n (5\\%)",
    "18.6G\n (8\\%)",
    "22.2G\n (10\\%)",
]

# Latency in seconds per dataset, one value per entry of cache_ratios.
latency_data = {
    "NQ": [4.616, 4.133, 3.826, 3.511, 3.323],
    "TriviaQA": [5.777, 4.979, 4.553, 4.141, 3.916],
    "GPQA": [1.733, 1.593, 1.468, 1.336, 1.259],
    "Hotpot": [15.515, 13.479, 12.383, 11.216, 10.606],
}

# Cache hit rate in percent per dataset, one value per entry of cache_ratios.
cache_hit_counts = {
    "NQ": [0, 14.81, 23.36, 31.99, 36.73],
    "TriviaQA": [0, 18.55, 27.99, 37.06, 41.86],
    "GPQA": [0, 10.99, 20.31, 29.71, 35.01],
    "Hotpot": [0, 17.47, 26.91, 36.2, 41.06],
}

# One subplot per dataset, laid out as a 2x2 grid and addressed as a flat array.
fig, axes_grid = plt.subplots(2, 2, figsize=(7, 6))
axes = axes_grid.flatten()

# Bar style settings
width = 0.7
x = np.arange(len(cache_ratios))

# Every right-hand (hit rate) axis shares the same upper limit so the dashed
# curves are comparable across subplots; 13% headroom leaves room for labels.
max_unified_hit = max(max(cache_hit_counts[name]) for name in datasets) * 1.13

# Per-subplot headroom factor applied to the tallest latency bar.
max_val_ratios = [1.35, 1.65, 1.45, 1.75]

for i, dataset in enumerate(datasets):
    ax = axes[i]
    latencies = latency_data[dataset]
    hit_counts = cache_hit_counts[dataset]

    # Latency bars (left axis), each labelled with its value.
    bars = ax.bar(
        x,
        latencies,
        width=width,
        color="white",
        edgecolor="black",
        linewidth=1.0,
        zorder=10,
    )
    ax.bar_label(
        bars,
        [f"{val:.2f}" for val in latencies],
        fontsize=10,
        zorder=200,
        fontweight="bold",
    )

    ax.set_title(dataset, fontsize=font_size)
    ax.set_xticks(x)
    ax.set_xticklabels(cache_ratios, fontsize=12, rotation=0, ha='center', fontweight="bold")

    ax.set_ylim(0, max(latencies) * max_val_ratios[i])
    ax.tick_params(axis='y', labelsize=12)

    # Only the left column carries the latency axis label.
    if i % 2 == 0:
        ax.set_ylabel("Latency (s)", fontsize=font_size)
    # NOTE(review): the formatter is applied to every subplot here; confirm it
    # was not meant for the left column only.
    ax.yaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter('%.1f'))

    # Cache hit rate curve on a secondary (right) axis.
    ax2: Axes = ax.twinx()
    ax2.plot(
        x,
        hit_counts,
        linestyle='--',
        marker='o',
        markersize=6,
        linewidth=1.5,
        color='k',
        markerfacecolor='none',
        zorder=20,
    )
    ax2.set_ylim(0, max_unified_hit)
    ax2.tick_params(axis='y', labelsize=12)
    # Only the right column carries the hit rate axis label.
    if i % 2 == 1:
        ax2.set_ylabel(r"Cache Hit (\%)", fontsize=font_size)

    # Label every non-zero point of the curve. The percent sign is escaped
    # because the text is typeset by LaTeX, where a bare "%" opens a comment.
    for j, val in enumerate(hit_counts):
        if val > 0:
            ax2.annotate(
                rf"{val:.1f}\%",
                (x[j], val),
                textcoords="offset points",
                xytext=(0, 5),
                ha='center',
                va='bottom',
                fontsize=10,
                fontweight='bold',
            )

# Shared legend: one proxy handle for the bars and one for the dashed curve.
bar_patch = mpatches.Patch(facecolor='white', edgecolor='black', label='Latency')
line_patch = Line2D([0], [0], color='black', linestyle='--', label='Cache Hit Rate')

# Legend centred horizontally and anchored at 99.5% of the figure height.
fig.legend(
    handles=[bar_patch, line_patch],
    loc='upper center',
    bbox_to_anchor=(0.5, 0.995),
    ncol=3,
    fontsize=font_size - 2,
)

# Keep the subplots inside the lower 93% of the figure so the legend above
# them does not overlap; rect is (left, bottom, right, top).
plt.tight_layout(rect=(0, 0, 1, 0.93))

# Make sure the output directory exists before saving.
os.makedirs(SAVE_PTH, exist_ok=True)

plt.savefig(f"{SAVE_PTH}/disk_cache_latency.pdf", dpi=300)
print(f"Save to {SAVE_PTH}/disk_cache_latency.pdf")
|
||||
Binary file not shown.
Binary file not shown.
|
After Width: | Height: | Size: 130 KiB |
Binary file not shown.
Binary file not shown.
|
After Width: | Height: | Size: 100 KiB |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
After Width: | Height: | Size: 41 KiB |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,107 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
# Motto: Were It to Benefit My Country, I Would Lay Down My Life!
|
||||
# \file: /gpu_utilization_plot.py
|
||||
# \brief: Plots GPU throughput vs. batch size to show utilization with equally spaced x-axis.
|
||||
# Author: AI Assistant
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd # Using pandas for data structuring, similar to example
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
# Apply styling similar to the example script
|
||||
import os

# Academic plot style: bold Helvetica text, inward ticks, LaTeX text rendering.
plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["ytick.direction"] = "in"
plt.rcParams["xtick.direction"] = "in"
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
plt.rcParams["text.usetex"] = True  # Enables LaTeX for text rendering

# Benchmark results: one row per batch size.
data = {
    'batch_size': [1, 4, 8, 10, 16, 20, 32, 40, 64, 128, 256],
    'avg_time_s': [
        0.0031, 0.0057, 0.0100, 0.0114, 0.0186, 0.0234,
        0.0359, 0.0422, 0.0626, 0.1259, 0.2454,
    ],
    'throughput_seq_s': [
        318.10, 696.77, 798.95, 874.70, 859.58, 855.19,
        890.80, 946.93, 1022.75, 1017.03, 1043.17,
    ],
}
benchmark_df = pd.DataFrame(data)

fig, ax = plt.subplots()
fig.set_size_inches(8, 5)

# Batch sizes are plotted at equally spaced positions (their row index) and
# the real batch size is shown as the tick label.
x_indices = np.arange(len(benchmark_df))

ax.plot(
    x_indices,
    benchmark_df['throughput_seq_s'],
    marker='o',
    linestyle='-',
    color="#63B8B6",
    linewidth=2,
    markersize=6,
)

ax.set_xlabel("Batch Size", fontsize=14)
ax.set_ylabel("Throughput (sequences/second)", fontsize=14)

# Y axis starts at 200 and ends at the maximum throughput rounded up to the
# next multiple of 100, with a tick every 100 units (top tick included).
y_min_val = 200
y_max_val = np.ceil(benchmark_df['throughput_seq_s'].max() / 100) * 100
ax.set_ylim((y_min_val, y_max_val))
ax.set_yticks(np.arange(y_min_val, y_max_val + 1, 100))

ax.set_xticks(x_indices)
ax.set_xticklabels(benchmark_df['batch_size'])
ax.tick_params(axis='x', rotation=45, labelsize=10)
ax.tick_params(axis='y', labelsize=12)

# Light dotted grid for readability.
ax.grid(True, linestyle=':', linewidth=0.5, color='grey', alpha=0.7, zorder=0)

# Adjust layout to prevent labels from being cut off.
plt.tight_layout()

# Save the figure, creating the output directory first so a fresh checkout
# does not fail in savefig.
output_filename = "./paper_plot/figures/gpu_throughput_vs_batch_size_equispaced.pdf"
os.makedirs(os.path.dirname(output_filename), exist_ok=True)
plt.savefig(output_filename, bbox_inches="tight", dpi=300)
print(f"Plot saved to {output_filename}")

# Display the plot (optional, depending on environment).
plt.show()
|
||||
@@ -0,0 +1,245 @@
|
||||
import argparse
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import os
|
||||
import matplotlib.ticker as ticker # Import ticker for formatting
|
||||
|
||||
# --- Global academic plot style ---
# Bold Helvetica text, outward ticks, no grid, LaTeX text rendering
# (no explicit LaTeX preamble is configured).
plt.rcParams.update({
    "font.family": "Helvetica",
    "font.weight": "bold",
    "axes.labelweight": "bold",
    "axes.titleweight": "bold",
    "ytick.direction": "out",
    "xtick.direction": "out",
    "axes.grid": False,
    "text.usetex": True,
})

# --- Graph variants ---
# Keys under which each variant's degree array is stored in the cache file.
BIG_GRAPH_LABELS = ["HNSW-Base", "DegreeGuide", "HNSW-D9", "RandCut"]

# Subplot titles, in the same order as BIG_GRAPH_LABELS.
BIG_GRAPH_LABELS_IN_FIGURE = [
    "Original HNSW",
    "Our Pruning Method",
    "Small M",
    "Random Prune",
]

LABEL_FONT_SIZE = 12

# Average degree printed in each subplot; static values, same order as above.
BIG_GRAPH_AVG_DEG = [18, 9, 9, 9]

# --- Cache file and output locations ---
DATA_CACHE_DIR = "./paper_plot/data/"
CACHE_FILE_NAME = "big_graph_degree_data.npz"
OUTPUT_DIR = "./paper_plot/figures/"

# The figure directory is created at import time so saving cannot fail on it.
os.makedirs(OUTPUT_DIR, exist_ok=True)
OUTPUT_FILE_BIG_GRAPH = os.path.join(OUTPUT_DIR, "degree_distribution.pdf")

# One histogram colour per graph variant.
HIST_COLORS = ['slategray', 'tomato', '#63B8B6', 'cornflowerblue']
|
||||
|
||||
|
||||
def _integer_bin_edges(values):
    """Return unit-width histogram bin edges covering every value in *values*."""
    lowest = np.min(values)
    highest = np.max(values)
    first_edge = np.floor(lowest - 0.5)
    last_edge = np.ceil(highest + 0.5)
    if lowest == highest:
        # A single distinct value gets one bin around it.
        return np.array([first_edge, last_edge])
    return np.linspace(first_edge, last_edge, int(last_edge - first_edge) + 1)


def plot_degree_distributions_from_cache(output_image_path: str):
    """Plot the degree distribution of every BIG_GRAPH variant on a 2x2 grid.

    The degree arrays are read from the ``.npz`` cache file located at
    ``DATA_CACHE_DIR / CACHE_FILE_NAME``. A label missing from the cache is
    reported and its subplot shows "Data unavailable" instead of a histogram.

    Args:
        output_image_path: Path the figure is written to.

    Returns:
        None. Problems (missing cache file, unreadable cache, no usable data)
        are reported on stdout and the function returns without saving.
    """
    cache_file_path = os.path.join(DATA_CACHE_DIR, CACHE_FILE_NAME)

    if not os.path.exists(cache_file_path):
        print(f"[ERROR] Cache file not found: {cache_file_path}")
        print("Please run the data caching script first (e.g., cache_degree_data.py).")
        return

    try:
        # One array per label, in the order of BIG_GRAPH_LABELS.
        all_degrees_data = []
        with np.load(cache_file_path) as loaded_data:
            for label in BIG_GRAPH_LABELS:
                if label in loaded_data:
                    all_degrees_data.append(loaded_data[label])
                else:
                    print(f"[WARN] Label '{label}' not found in cache file. Plotting may be incomplete.")
                    all_degrees_data.append(np.array([], dtype=int))
        print(f"[INFO] Successfully loaded data from cache: {cache_file_path}")
    except Exception as e:
        # Best effort: any failure while reading the cache aborts the plot.
        print(f"[ERROR] Failed to load or process data from cache file {cache_file_path}: {e}")
        return

    # The "DegreeGuide" degrees are rescaled by 60/64. The values are then no
    # longer integers, which is why that variant gets its own binning below.
    for i, label in enumerate(BIG_GRAPH_LABELS):
        if label == "DegreeGuide" and all_degrees_data[i].size > 0:
            all_degrees_data[i] = all_degrees_data[i] * 60 / 64

    # Empty placeholders (missing labels) must not reach np.min / np.max.
    available = [degrees for degrees in all_degrees_data if degrees.size > 0]
    if not available:
        print("[ERROR] No valid degree data loaded from cache. Cannot generate plot.")
        return

    overall_min_deg = np.floor(min(np.min(degrees) for degrees in available) - 0.5)
    overall_max_deg = np.ceil(max(np.max(degrees) for degrees in available) + 0.5)
    print(f"overall_min_deg: {overall_min_deg}, overall_max_deg: {overall_max_deg}")

    # Tallest histogram bar over all variants; used for the shared y limit.
    max_y_raw_counts = 0
    for i, degrees in enumerate(all_degrees_data):
        if degrees.size == 0:
            continue
        label = BIG_GRAPH_LABELS[i]
        print(f"for method {label}, min_deg_local: {np.min(degrees)}, max_deg_local: {np.max(degrees)}")

        # NOTE(review): the original indentation was lost; the rescaled variant
        # is assumed to always use per-value bins, even for a single value.
        if label == "DegreeGuide":
            # One bin per distinct (rescaled) value.
            bin_edges = np.concatenate([np.unique(degrees) - 0.1, [np.inf]])
        else:
            bin_edges = _integer_bin_edges(degrees)

        counts, _ = np.histogram(degrees, bins=bin_edges)
        if counts.size > 0:
            max_y_raw_counts = max(max_y_raw_counts, np.max(counts))

    if max_y_raw_counts == 0:
        max_y_raw_counts = 10

    fig, axes = plt.subplots(2, 2, figsize=(7, 4), sharex=True, sharey=True)
    try:
        axes = axes.flatten()  # Flatten the 2x2 axes array for easy iteration

        for i, ax in enumerate(axes):
            degrees = all_degrees_data[i]
            ax.set_title(BIG_GRAPH_LABELS_IN_FIGURE[i], fontsize=LABEL_FONT_SIZE)

            if degrees.size > 0:
                if BIG_GRAPH_LABELS[i] == "DegreeGuide":
                    # Each bin starts 0.1 below a distinct value. Integers
                    # scaled by 60/64 are 0.9375 apart, so closing the last
                    # bin at +0.8375 gives it the same 0.9375 width.
                    unique_values = np.unique(degrees)
                    plot_bin_edges = np.concatenate(
                        [unique_values - 0.1, [unique_values[-1] + 0.8375]]
                    )
                else:
                    plot_bin_edges = _integer_bin_edges(degrees)

                ax.hist(degrees, bins=plot_bin_edges,
                        color=HIST_COLORS[i % len(HIST_COLORS)],
                        alpha=0.85)

                avg_deg_val = BIG_GRAPH_AVG_DEG[i]
                ax.text(0.95, 0.88, f"Avg Degree: {avg_deg_val}",
                        transform=ax.transAxes, fontsize=15,
                        verticalalignment='top', horizontalalignment='right',
                        bbox=dict(facecolor='white', alpha=0.6, edgecolor='none', pad=0.3))
            else:
                ax.text(0.5, 0.5, 'Data unavailable', horizontalalignment='center',
                        verticalalignment='center', transform=ax.transAxes, fontsize=9)

            ax.set_xlim(0, overall_max_deg)
            # NOTE(review): a lower limit of 0 is requested before switching to
            # a log scale; confirm the resulting lower bound suits the figure.
            ax.set_ylim(0, max_y_raw_counts * 1.12)
            ax.set_yscale('log')

            for spine_pos in ['top', 'right', 'bottom', 'left']:
                ax.spines[spine_pos].set_edgecolor('black')
                ax.spines[spine_pos].set_linewidth(1.0)

            ax.tick_params(axis='x', which='both', bottom=True, top=False,
                           labelbottom=True, length=4, width=1, labelsize=12)
            # Only the left column shows y tick labels.
            ax.tick_params(axis='y', which='both', left=True, right=False,
                           labelleft=(i % 2 == 0), length=4, width=1, labelsize=12)

            ax.xaxis.set_major_formatter(ticker.ScalarFormatter())
            ax.ticklabel_format(style='plain', axis='x', useOffset=False)

        # Y labels on the left column of both rows, one shared x label below.
        axes[0].set_ylabel(r"Number of Nodes", fontsize=12)
        axes[2].set_ylabel(r"Number of Nodes", fontsize=12)
        fig.text(0.54, 0.02, "Node Degree", ha='center', va='bottom', fontsize=15)

        plt.tight_layout(rect=(0.06, 0.05, 0.98, 0.88))

        plt.savefig(output_image_path, dpi=300, bbox_inches='tight', pad_inches=0.05)
        print(f"[LOG] Plot saved to {output_image_path}")
    finally:
        # Always release the figure, even when plotting or saving failed.
        if plt.fignum_exists(fig.number):
            plt.close(fig)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Report whether text is rendered through LaTeX for this run.
    print(
        "INFO: LaTeX rendering is enabled via rcParams."
        if plt.rcParams["text.usetex"]
        else "INFO: LaTeX rendering is disabled (text.usetex=False)."
    )

    print(f"INFO: Plots will be saved to '{OUTPUT_FILE_BIG_GRAPH}'")

    plot_degree_distributions_from_cache(OUTPUT_FILE_BIG_GRAPH)

    print("INFO: Degree distribution plot from cache has been generated.")
|
||||
@@ -0,0 +1,330 @@
|
||||
# python faiss/demo/plot_graph_struct.py faiss/demo/output.log
|
||||
# python faiss/demo/plot_graph_struct.py large_graph_recompute.log
|
||||
import argparse
|
||||
import re
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
# Recall targets to compare, with one line style and line width per target.
recall_levels = [0.90, 0.92, 0.94, 0.96]
line_styles = ['--'] + ['-'] * 3
line_widths = [1] + [1.5] * 3

# Display names of the compared methods. The commented-out names in the
# original list were HNSW-Base, DegreeGuide, HNSW-D9 and RandCut.
MAPPED_METHOD_NAMES = [
    "Original HNSW",
    "Our Pruning Method",
    "Small M",
    "Random Prune",
]

# Output locations of the two figures.
PERFORMANCE_PLOT_PATH = './paper_plot/figures/H_hnsw_performance_comparison.pdf'
SAVED_PATH = './paper_plot/figures/H_hnsw_recall_comparison.pdf'
|
||||
|
||||
def _map_method_name(raw_name):
    """Translate a raw index identifier from the log into its display name."""
    if raw_name == 'hnsw_IP_M30_efC128':
        return MAPPED_METHOD_NAMES[0]
    if raw_name.startswith('99_4_degree'):
        return MAPPED_METHOD_NAMES[1]
    if raw_name.startswith('d9_hnsw'):
        return MAPPED_METHOD_NAMES[2]
    if raw_name.startswith('half'):
        return MAPPED_METHOD_NAMES[3]
    # No rule applies: keep the identifier as it appears in the log.
    return raw_name


def extract_data_from_log(log_content):
    """Extract method names, recall lists, recompute lists and average degrees.

    Args:
        log_content: Full text of a benchmark log.

    Returns:
        A tuple ``(methods, recall_lists, recompute_lists, avg_neighbors)`` of
        four lists truncated to the same length. When at least four entries of
        each kind are found, only the first four are kept and the third and
        fourth are swapped.

    Raises:
        ValueError, SyntaxError: If a matched list is not a valid literal.
    """
    import ast

    # The second alternative covers identifiers that themselves contain a dot,
    # which the first alternative's capture group cannot match.
    method_pattern = r"Building HNSW index with ([^\.]+)\.\.\.|Building HNSW index with ([^\n]+)..."
    recall_list_pattern = r"recall_list: (\[[\d\., ]+\])"
    recompute_list_pattern = r"recompute_list: (\[[\d\., ]+\])"
    avg_neighbors_pattern = r"neighbors per node: ([\d\.]+)"

    raw_names = [
        (first or second).strip().rstrip('.')
        for first, second in re.findall(method_pattern, log_content)
    ]
    recall_lists_str = re.findall(recall_list_pattern, log_content)
    recompute_lists_str = re.findall(recompute_list_pattern, log_content)
    avg_neighbors_str_list = re.findall(avg_neighbors_pattern, log_content)

    # The regex-extracted names are used only when every kind of record was
    # found at least four times; otherwise names come from splitting the log
    # on the "Building HNSW index with " marker.
    found_by_regex = min(
        len(raw_names),
        len(recall_lists_str),
        len(recompute_lists_str),
        len(avg_neighbors_str_list),
    )
    if found_by_regex < 4:
        sections = log_content.split("Building HNSW index with ")[1:]
        raw_names = [section.split("\n")[0].strip().rstrip('.') for section in sections]

    methods = [_map_method_name(raw_name) for raw_name in raw_names]
    avg_neighbors = [float(avg) for avg in avg_neighbors_str_list]

    # Swap the third and fourth entries so that colours assigned by position
    # stay consistent; entries beyond the fourth are dropped.
    if min(len(methods), len(recall_lists_str), len(recompute_lists_str), len(avg_neighbors)) >= 4:
        order = (0, 1, 3, 2)
        methods = [methods[k] for k in order]
        recall_lists_str = [recall_lists_str[k] for k in order]
        recompute_lists_str = [recompute_lists_str[k] for k in order]
        avg_neighbors = [avg_neighbors[k] for k in order]

    # Known log state: when the first logged average is "17.35", the measured
    # averages are replaced by fixed values (applied after the reordering).
    if len(avg_neighbors) > 0 and avg_neighbors_str_list[0] == "17.35":
        target_avg_neighbors = [18, 9, 9, 9]
        avg_neighbors = target_avg_neighbors[:len(avg_neighbors)]

    # literal_eval only accepts Python literals, so text from the log can
    # never be executed as code.
    recall_lists = [ast.literal_eval(recall_list) for recall_list in recall_lists_str]
    recompute_lists = [ast.literal_eval(recompute_list) for recompute_list in recompute_lists_str]

    # Final truncation so that all four lists have the same length.
    min_length = min(len(methods), len(recall_lists), len(recompute_lists), len(avg_neighbors))
    return (
        methods[:min_length],
        recall_lists[:min_length],
        recompute_lists[:min_length],
        avg_neighbors[:min_length],
    )
|
||||
|
||||
|
||||
def plot_recall_comparison(methods, recall_lists, recompute_lists, avg_neighbors, current_recall_levels):
    """Create a line chart comparing computation costs at different recall levels, with academic style.

    For each method, the cost at a recall target is the recompute count of the
    first measurement whose recall is >= that target (``None`` when the target
    is never reached). Costs are plotted in units of 10^4 nodes, cost-ratio
    annotations are added at the 0.92 and 0.96 targets when present, and the
    figure is saved to ``SAVED_PATH`` and then shown.

    Args:
        methods: Display names, one per curve.
        recall_lists: Per-method sequences of measured recall values.
        recompute_lists: Per-method sequences of recompute counts, index-aligned
            with ``recall_lists``.
        avg_neighbors: Per-method average degree; shown truncated to int in the legend.
        current_recall_levels: List of recall targets used as x positions.
    """
    plt.rcParams["font.family"] = "Helvetica"
    plt.rcParams["ytick.direction"] = "in"
    plt.rcParams["font.weight"] = "bold"
    plt.rcParams["axes.labelweight"] = "bold"
    plt.rcParams["text.usetex"] = True  # Ensure LaTeX is available or set to False

    # computation_costs[i][k]: recompute count of method i at the first measurement
    # reaching current_recall_levels[k], or None if that level is never reached.
    computation_costs = []
    for i, _ in enumerate(methods):
        method_costs = []
        for level in current_recall_levels:
            recall_idx = next((idx for idx, recall in enumerate(recall_lists[i]) if recall >= level), None)
            method_costs.append(recompute_lists[i][recall_idx] if recall_idx is not None else None)
        computation_costs.append(method_costs)

    fig, ax = plt.subplots(figsize=(5, 2.5))

    academic_colors = ['slategray', 'tomato', 'cornflowerblue', '#63B8B6']
    markers = ['o', '*', '^', 'D', 'v', 'P']

    for i, method_name in enumerate(methods):
        # Plot in units of 10^4 nodes; NaN leaves a gap where a level is never reached.
        y_values_plot = [val / 10000 if val is not None else np.nan for val in computation_costs[i]]

        is_baseline = method_name == MAPPED_METHOD_NAMES[0]  # original HNSW baseline
        is_ours = method_name == MAPPED_METHOD_NAMES[1]  # our pruning method

        linestyle = '--' if is_baseline else '-'
        if is_ours:
            marker_size = 12
        elif method_name == MAPPED_METHOD_NAMES[2]:  # small-M variant
            marker_size = 7.5
        else:
            marker_size = 8
        # Draw our method on top of the other curves.
        zorder = 10 if is_ours else 1

        # Cosmetic nudges of the first point so overlapping curves stay readable.
        if y_values_plot:
            if method_name == MAPPED_METHOD_NAMES[3]:  # random prune
                y_values_plot[0] += 0.12  # To prevent overlap with our method
            elif is_ours:
                y_values_plot[0] -= 0.06  # To prevent overlap with original hnsw

        ax.plot(current_recall_levels, y_values_plot,
                label=f"{method_name} (Avg Degree: {int(avg_neighbors[i])})",
                color=academic_colors[i % len(academic_colors)],
                marker=markers[i % len(markers)],
                markeredgecolor='#FFFFFF80',  # translucent white outline; revisit if it looks bad
                markersize=marker_size, linewidth=2, linestyle=linestyle, zorder=zorder)

    ax.set_xlabel('Recall Target', fontsize=9, fontweight="bold")
    ax.set_ylabel(r'Nodes to Recompute ($\mathbf{\times 10^4}$)', fontsize=9, fontweight="bold")
    ax.set_xticks(current_recall_levels)
    # Raw f-string: "\%" is a literal backslash-percent for LaTeX, not a Python escape.
    ax.set_xticklabels([rf'{level*100:.0f}\%' for level in current_recall_levels], fontsize=10)
    ax.tick_params(axis='y', labelsize=10)

    # Legend above the axes; the font comes from `prop` (a separate fontsize would be ignored).
    ax.legend(loc='lower center', bbox_to_anchor=(0.5, 1.02), ncol=2,
              edgecolor="black", facecolor="white", framealpha=1,
              shadow=False, fancybox=False, prop={"weight": "normal", "size": 8})

    # Spines adjustment for academic look (no grid lines)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_linewidth(1.0)
    ax.spines['bottom'].set_linewidth(1.0)

    # Performance-gap annotations relative to our pruning method.
    _annotate_cost_gap(ax, methods, computation_costs, current_recall_levels,
                       0.92, "Our Pruning Method", "Small M")
    _annotate_cost_gap(ax, methods, computation_costs, current_recall_levels,
                       0.96, "Our Pruning Method", "Random Prune")

    plt.tight_layout(pad=0.5)
    plt.savefig(SAVED_PATH, bbox_inches="tight", dpi=300)
    plt.show()


def _annotate_cost_gap(ax, methods, computation_costs, current_recall_levels,
                       recall_level, base_name, compare_name):
    """Draw a double-headed arrow and a "N.NNx" ratio label between two methods at one recall level.

    The ratio is ``compare cost / base cost``. Nothing is drawn when the level
    or either method is absent, when either method never reaches the level, or
    when the base cost is not positive.
    """
    if recall_level not in current_recall_levels:
        return
    if base_name not in methods or compare_name not in methods:
        return

    level_idx = current_recall_levels.index(recall_level)
    raw_base = computation_costs[methods.index(base_name)][level_idx]
    raw_compare = computation_costs[methods.index(compare_name)][level_idx]
    # None means the level was never reached; this must be checked before scaling.
    if raw_base is None or raw_compare is None:
        return

    cost_base = raw_base / 10000
    cost_compare = raw_compare / 10000
    if cost_base <= 0:
        return

    ratio = cost_compare / cost_base
    ax.annotate("", xy=(recall_level, cost_compare),
                xytext=(recall_level, cost_base),
                arrowprops=dict(arrowstyle="<->", color='#333333',
                                lw=1.5, mutation_scale=15,
                                shrinkA=3, shrinkB=3),
                zorder=10)  # Arrow drawn first

    # Label sits on the arrow at its midpoint, kept 5% inside the current y-limits.
    text_y_pos = (cost_base + cost_compare) / 2
    plot_ymin, plot_ymax = ax.get_ylim()
    margin = (plot_ymax - plot_ymin) * 0.05
    if text_y_pos < plot_ymin + margin:
        text_y_pos = plot_ymin + margin
    if text_y_pos > plot_ymax - margin:
        text_y_pos = plot_ymax - margin

    ax.text(recall_level, text_y_pos, f"{ratio:.2f}x",
            fontsize=9, color='black',
            va='center', ha='center',
            # Opaque white box so the label visually interrupts the arrow line.
            bbox=dict(boxstyle='square,pad=0.25', fc='white', ec='white', alpha=1.0),
            zorder=11)  # Text on top of arrow
|
||||
|
||||
# --- Main script execution ---
parser = argparse.ArgumentParser()
# nargs="?" makes the positional optional; without it argparse always requires
# the argument and the default path below could never be used.
parser.add_argument("log_file", type=str, nargs="?", default="./demo/output.log")
args = parser.parse_args()

try:
    with open(args.log_file, 'r') as f:
        log_content = f.read()
except FileNotFoundError:
    print(f"Error: Log file '{args.log_file}' not found.")
    # Exit with a non-zero status so callers can detect the failure.
    raise SystemExit(1)

methods, recall_lists, recompute_lists, avg_neighbors = extract_data_from_log(log_content)

if methods:
    plot_recall_comparison(methods, recall_lists, recompute_lists, avg_neighbors, recall_levels)
    print(f"Recall comparison plot saved to {SAVED_PATH}")

    # Text summary: per method, the cost at the first measurement reaching each recall level.
    print("\nMethod Summary:")
    for i, method in enumerate(methods):
        print(f"{method}:")
        if i < len(avg_neighbors):  # Check index bounds
            print(f" - Average neighbors per node: {avg_neighbors[i]:.2f}")

        for level in recall_levels:
            if i < len(recall_lists) and i < len(recompute_lists):  # Check index bounds
                recall_idx = next((idx for idx, recall_val in enumerate(recall_lists[i]) if recall_val >= level), None)
                if recall_idx is not None:
                    print(f" - Computations needed for {level*100:.0f}% recall: {recompute_lists[i][recall_idx]:.0f}")
                else:
                    print(f" - Does not reach {level*100:.0f}% recall in the test")
            else:
                print(f" - Data missing for recall/recompute lists for method {method}")
        print()
else:
    print("No data extracted from the log file. Cannot generate plots or summary.")
|
||||
@@ -0,0 +1,441 @@
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
import matplotlib.lines as mlines
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from matplotlib.patches import FancyArrowPatch
|
||||
|
||||
# Global figure style: apply the seaborn "ticks" theme first, then override
# individual Matplotlib settings (dashed major grid, no minor ticks, bold
# Helvetica text rendered through LaTeX).
sns.set_theme(style="ticks", font_scale=1.2)
plt.rcParams.update({
    'axes.grid': True,
    'axes.grid.which': 'major',
    'grid.linestyle': '--',
    'grid.color': 'gray',
    'grid.alpha': 0.3,
    'xtick.minor.visible': False,
    'ytick.minor.visible': False,
    "font.family": "Helvetica",
    "text.usetex": True,
    "font.weight": "bold",
    "axes.labelweight": "bold",
})
|
||||
# Generation(LLama 1B) Generation(LLama 3B) Generation(LLama 7B)
|
||||
# 0.085s 0.217s 0.472s
|
||||
# llm_inference_time=[0.085, 0.217, 0.472, 0] # Will be replaced by CSV data
|
||||
# llm_inference_time_for_mac = [0.316, 0.717, 1.468, 0] # Will be replaced by CSV data
|
||||
|
||||
def parse_latency_data(csv_path):
    """Parse the latency CSV into per-method latency points and LLM generation times.

    Args:
        csv_path: Path of a CSV with 'Dataset', 'Hardware' and 'Recall_target'
            columns. 'Recall_target' may be written with or without a trailing
            '%'. The 'LLM_Gen_Time_*' columns are metadata; every remaining
            column is treated as one method's latency.

    Returns:
        A ``(latency_data, llm_gen_times)`` tuple:
        ``latency_data`` maps ``(dataset, hardware, method)`` to a list of
        ``(recall_target, latency)`` tuples sorted by recall target, with
        unparsable latencies stored as NaN;
        ``llm_gen_times`` maps ``(dataset, hardware)`` to the 'LLM_Gen_Time_1B'
        value of the first row seen for that pair (NaN if missing/unparsable).
    """
    df = pd.read_csv(csv_path)
    latency_data = {}
    llm_gen_times = {}  # (dataset, hardware) -> LLM generation time

    # Column roles do not change per row, so resolve them once up front.
    cols_to_skip = {'Dataset', 'Hardware', 'Recall_target',
                    'LLM_Gen_Time_1B', 'LLM_Gen_Time_3B', 'LLM_Gen_Time_7B'}
    method_columns = [col for col in df.columns if col not in cols_to_skip]

    for _, row in df.iterrows():
        dataset = row['Dataset']
        hardware = row['Hardware']

        # str() first: when no cell contains '%', pandas parses the column as
        # numbers, which have no .replace method.
        try:
            recall_target = float(str(row['Recall_target']).replace('%', ''))
        except ValueError:
            recall_target = np.nan
        if np.isnan(recall_target):
            print(f"Warning: Could not parse recall_target '{row['Recall_target']}'. Skipping row.")
            continue

        if (dataset, hardware) not in llm_gen_times:  # Read once per (dataset, hardware)
            llm_time_val = pd.to_numeric(row.get('LLM_Gen_Time_1B'), errors='coerce')
            # Normalise every "missing" flavour to a plain NaN.
            llm_gen_times[(dataset, hardware)] = llm_time_val if not pd.isna(llm_time_val) else np.nan

        for method_name in method_columns:
            points = latency_data.setdefault((dataset, hardware, method_name), [])
            try:
                latency_value = float(row[method_name])
            except (TypeError, ValueError):
                # Non-numeric latency: keep the point as NaN so the recall
                # levels stay aligned across methods.
                print(f"Warning: Could not parse latency for {method_name} at {dataset}/{hardware}/Recall {recall_target} ('{row[method_name]}'). Skipping this point.")
                latency_value = np.nan
            points.append((recall_target, latency_value))

    # Sort by recall for consistent plotting
    for points in latency_data.values():
        points.sort(key=lambda point: point[0])
    return latency_data, llm_gen_times
|
||||
|
||||
def parse_storage_data(csv_path):
    """Read the RAM/storage CSV into ``{method: {'RAM': value, 'Storage': value}}``.

    The first column holds the metric label and every other column is one
    method. The first row labelled 'RAM' and the first row labelled 'Storage'
    are used; values that cannot be parsed as numbers become NaN. An
    ``IndexError`` propagates if either row is absent.
    """
    table = pd.read_csv(csv_path)
    metric_labels = table.iloc[:, 0]

    # Locate both metric rows up front so a missing row fails before any output is built.
    rows_by_metric = {
        metric: table[metric_labels == metric].iloc[0]
        for metric in ('RAM', 'Storage')
    }

    return {
        method: {
            metric: pd.to_numeric(metric_row[method], errors='coerce')
            for metric, metric_row in rows_by_metric.items()
        }
        for method in table.columns[1:]  # first column is the metric label
    }
|
||||
|
||||
# Load data
latency_csv_path = 'paper_plot/data/main_latency.csv'
storage_csv_path = 'paper_plot/data/ram_storage.csv'
latency_data, llm_generation_times = parse_latency_data(latency_csv_path)
storage_info = parse_storage_data(storage_csv_path)

# --- Determine unique Datasets and Hardware combinations to plot for ---
unique_dataset_hardware_configs = sorted({(d, h) for d, h, m in latency_data})

if not unique_dataset_hardware_configs:
    print("Error: No (Dataset, Hardware) combinations found in latency data. Check CSV paths and content.")
    # Non-zero status so callers can tell the script failed.
    raise SystemExit(1)

# --- Define constants for plotting ---
# latency_data is known to be non-empty here, so at least one method name exists.
all_method_names = sorted({m for d, h, m in latency_data})

method_markers = {
    'HNSW': 'o',
    'IVF': 'X',
    'DiskANN': 's',
    'IVF-Disk': 'P',
    'IVF-Recompute': '^',
    'Our': '*',
    'BM25': "v"
}
# Legend display names; methods not listed here keep their own name.
method_display_names = {
    'IVF-Recompute': 'IVF-Recompute (EdgeRAG)',
}

# Ensure all methods have a marker: unknown methods cycle through the defaults.
default_markers = ['^', 'v', '<', '>', 'H', 'h', '+', 'x', '|', '_']
next_default_marker = 0
for mn in all_method_names:
    if mn not in method_markers:
        print(f"mn: {mn}")
        method_markers[mn] = default_markers[next_default_marker % len(default_markers)]
        next_default_marker += 1

recall_levels_present = sorted({r for points in latency_data.values() for r, _ in points})
# Fixed colors for the common recall levels; any other level falls back to viridis.
base_recall_colors = {
    85.0: "#1f77b4",  # Blue
    90.0: "#ff7f0e",  # Orange
    95.0: "#2ca02c",  # Green
}
color_palette = sns.color_palette("viridis", n_colors=len(recall_levels_present))
recall_colors = {
    r_level: base_recall_colors.get(r_level, color_palette[idx % len(color_palette)])
    for idx, r_level in enumerate(recall_levels_present)
}

# --- Determine global x (latency) and y (storage) limits for consistent axes ---
all_latency_values = []
all_storage_values = []
raw_data_size = 76  # Raw data size in GB

for current_ds, current_hw in unique_dataset_hardware_configs:
    for method_name in all_method_names:
        disk_storage = storage_info.get(method_name, {}).get('Storage', np.nan)
        if not np.isnan(disk_storage):
            all_storage_values.append(disk_storage)

        for recall, latency in latency_data.get((current_ds, current_hw, method_name), []):
            if not np.isnan(latency):
                all_latency_values.append(latency)

min_lat = min(all_latency_values) if all_latency_values else 0.001
max_lat = max(all_latency_values) if all_latency_values else 1
min_store = min(all_storage_values) if all_storage_values else 0
max_store = max(all_storage_values) if all_storage_values else 1

# Convert storage values to proportion of raw data
min_store_proportion = min_store / raw_data_size if all_storage_values else 0
max_store_proportion = max_store / raw_data_size if all_storage_values else 0.1

# Log-scale latency axis: lower bound fixed at 10^-1 (0.1 s), upper bound from
# the data, with 5% padding applied in log space on both sides.
lat_log_min = -1
lat_log_max = np.log10(max_lat) if max_lat > 0 else 3  # default to 1000 s
lat_padding = (lat_log_max - lat_log_min) * 0.05
global_xlim = [10**(lat_log_min - lat_padding), 10**(lat_log_max + lat_padding)]

# Linear storage-proportion axis with 5% padding, never starting below zero.
store_padding = (max_store_proportion - min_store_proportion) * 0.05
global_ylim = [max(0, min_store_proportion - store_padding), max_store_proportion + store_padding]
if global_ylim[0] >= global_ylim[1]:  # Avoid inverted or zero range
    global_ylim[1] = global_ylim[0] + 0.1

# Datasets are plotted in a fixed order; any dataset not listed in desired_order
# is appended afterwards, sorted so the result does not depend on set ordering.
all_datasets_unsorted = list(set(ds for ds, _ in unique_dataset_hardware_configs))
desired_order = ['NQ', 'TriviaQA', 'GPQA', 'HotpotQA']
all_datasets = [ds for ds in desired_order if ds in all_datasets_unsorted]
all_datasets.extend(sorted(ds for ds in all_datasets_unsorted if ds not in desired_order))

a10_configs = [(ds, 'A10') for ds in all_datasets if (ds, 'A10') in unique_dataset_hardware_configs]
mac_configs = [(ds, 'MAC') for ds in all_datasets if (ds, 'MAC') in unique_dataset_hardware_configs]

# One figure per hardware platform: A10 first, then MAC
hardware_configs = [a10_configs, mac_configs]
hardware_names = ['A10', 'MAC']
|
||||
|
||||
for fig_idx, configs_for_this_figure in enumerate(hardware_configs):
    if not configs_for_this_figure:
        continue

    # One row of subplots, one column per dataset, with shared axes.
    num_cols_this_figure = len(configs_for_this_figure)
    fig, axs = plt.subplots(1, num_cols_this_figure, figsize=(7 * num_cols_this_figure, 6),
                            sharex=True, sharey=True, squeeze=False)

    for subplot_idx, (current_ds, current_hw) in enumerate(configs_for_this_figure):
        ax = axs[0, subplot_idx]  # Accessing column in the first row
        ax.set_title(f"{current_ds}", fontsize=25)

        for method_name in all_method_names:
            marker = method_markers.get(method_name, '+')
            disk_storage = storage_info.get(method_name, {}).get('Storage', np.nan)

            latency_points_key = (current_ds, current_hw, method_name)
            if latency_points_key in latency_data:
                points_for_method = latency_data[latency_points_key]
                print(f"points_for_method: {points_for_method}")
                for recall, latency in points_for_method:
                    # Skip unusable points: the log-scale x-axis needs a positive
                    # latency, and a point without a storage value has no y position.
                    if np.isnan(latency) or np.isnan(disk_storage) or latency <= 0:
                        continue

                    # Add LLM generation time from CSV
                    current_llm_add_time = llm_generation_times.get((current_ds, current_hw))
                    if current_llm_add_time is not None and not np.isnan(current_llm_add_time):
                        latency = latency + current_llm_add_time
                    else:
                        raise ValueError(f"No LLM generation time found for {current_ds} on {current_hw}")

                    if method_name == 'BM25':
                        # BM25 is only valid for 85% recall points (other points are 0)
                        if recall != 85.0:
                            continue
                        color = 'grey'
                    else:
                        # Use the color for target recall
                        color = recall_colors.get(recall, 'grey')

                    # Convert storage to proportion of the raw data size
                    disk_storage_proportion = disk_storage / raw_data_size
                    size = 80

                    # Horizontal offset (in points) of the 'Ours' label
                    x_offset = -50
                    if current_ds == 'GPQA':
                        x_offset = -32

                    # Small vertical offset so IVF-Recompute points are more visible
                    if method_name == 'IVF-Recompute':
                        disk_storage_proportion += 0.07
                    if method_name == 'DiskANN':
                        size = 50
                    if method_name == 'Our':
                        size = 140
                        disk_storage_proportion += 0.05
                        # Label our method once per subplot, at the 95% recall point
                        if recall == 95:
                            ax.annotate('Ours',
                                        (latency, disk_storage_proportion),
                                        xytext=(x_offset, 25),
                                        textcoords='offset points',
                                        fontsize=20,
                                        color='red',
                                        weight='bold',
                                        bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="red", alpha=0.7))
                    if method_name == 'BM25':
                        size = 70
                    size *= 5

                    ax.scatter(latency, disk_storage_proportion, marker=marker, color=color,
                               s=size, alpha=0.85, edgecolors='black', linewidths=0.7)

        ax.set_xscale("log")
        ax.set_yscale("linear")

        # Decade ticks with bold LaTeX labels in 10^n format
        min_power = -1
        max_power = 4
        log_ticks = [10**i for i in range(min_power, max_power+1)]
        ax.set_xticks(log_ticks)
        log_tick_labels = [fr'$\mathbf{{10^{{{i}}}}}$' for i in range(min_power, max_power+1)]
        ax.set_xticklabels(log_tick_labels, fontsize=24)

        # Apply global limits once; the axes are shared across subplots
        if subplot_idx == 0:
            ax.set_xlim(global_xlim)
            ax.set_ylim(global_ylim)

        ax.grid(True, which="major", linestyle="--", linewidth=0.6, alpha=0.7)
        # Remove minor grid lines completely
        ax.grid(False, which="minor")

        # Hide tick marks on both axes, then pad only the x tick labels
        ax.tick_params(axis='both', which='both', length=0, labelsize=24)
        ax.tick_params(axis='x', which='both', pad=10)

        if subplot_idx == 0:  # Y-label only for the leftmost subplot
            ax.set_ylabel("Proportional Size", fontsize=24)
        ax.set_xlabel("Latency (s)", fontsize=25)

        # Display 100%, 200%, 300% on the y-axis ("\%" is a literal backslash-percent for LaTeX)
        ax.set_yticks([1, 2, 3])
        ax.set_yticklabels([r'100\%', r'200\%', r'300\%'])

        # "Better" arrow in axes coordinates, pointing toward the lower-left corner
        arrow = FancyArrowPatch(
            (0.8, 0.8),   # Start point (top-right)
            (0.65, 0.6),  # End point (toward bottom-left)
            transform=ax.transAxes,
            arrowstyle='simple,head_width=40,head_length=35,tail_width=20',
            facecolor='white',
            edgecolor='black',
            linewidth=3,
            zorder=5
        )
        ax.add_patch(arrow)

        # Text at the arrow's midpoint, with small manual offsets
        mid_x = (0.8 + 0.65) / 2 + 0.002 + 0.01
        mid_y = (0.8 + 0.6) / 2 + 0.01
        ax.text(mid_x, mid_y, 'Better',
                transform=ax.transAxes,
                ha='center',
                va='center',
                fontsize=16,
                fontweight='bold',
                rotation=40,  # Rotate to match arrow direction
                zorder=6)     # Ensure text is on top of arrow

    # Legends are drawn on the first figure only: one row of recall colors and
    # one row of method markers.
    if fig_idx == 0:
        recall_legend_handles = [
            mlines.Line2D([], [], color=recall_colors[r_level], marker='o', linestyle='None',
                          markersize=20, label=rf"Target Recall={r_level:.0f}\%")
            for r_level in sorted(recall_colors)
        ]

        # Method order for the legend: every other method first, 'Our' last
        other_methods = [m for m in all_method_names if m != 'Our']
        ordered_methods = other_methods + (['Our'] if 'Our' in all_method_names else [])

        method_legend_handles = []
        for method in ordered_methods:
            if method in method_markers:
                marker_style = method_markers[method]
                # Use the display-name mapping for the legend label
                display_name = method_display_names.get(method, method)
                # Per-method marker sizes in the legend
                marker_size = 22
                if method == 'Our':
                    marker_size = 27
                elif 'IVF-Recompute' in method or 'EdgeRAG' in method:
                    marker_size = 17
                elif 'DiskANN' in method:
                    marker_size = 19
                elif 'BM25' in method:
                    marker_size = 20
                method_legend_handles.append(mlines.Line2D([], [], color='black', marker=marker_style,
                                                           linestyle='None', markersize=marker_size,
                                                           label=display_name))

        # fig.legend registers each legend with the figure, so no add_artist call
        # is needed (adding them again would draw every legend twice).
        recall_legend = fig.legend(handles=recall_legend_handles,
                                   loc='upper center', bbox_to_anchor=(0.5, 1.05),
                                   ncol=len(recall_legend_handles), fontsize=28)
        method_legend = fig.legend(handles=method_legend_handles,
                                   loc='upper center', bbox_to_anchor=(0.5, 0.91),
                                   ncol=len(method_legend_handles), fontsize=28)

    # Restrict the subplots to the lower 74% of the figure, leaving room at the top for the legend rows
    plt.tight_layout(rect=(0, 0, 1.0, 0.74))

    save_path = f'./paper_plot/figures/main_exp_fig_{fig_idx+1}.pdf'
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"Saved figure {fig_idx+1} to {save_path}")
    plt.show()
|
||||
@@ -0,0 +1,163 @@
|
||||
import csv
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import csv
|
||||
|
||||
plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["ytick.direction"] = "in"
plt.rcParams["hatch.linewidth"] = 1
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
plt.rcParams["text.usetex"] = True
SAVE_PTH = "./paper_plot/figures"
font_size = 16

# Generation(LLama 1B)  Generation(LLama 3B)  Generation(LLama 7B)
# 0.085s                0.217s                0.472s
llm_inference_time = [0.085, 0.217, 0.472, 0]

USE_LLM_INDEX = 3  # index 3 adds 0 s, i.e. no LLM generation time

file_path = "./paper_plot/data/main_latency.csv"

with open(file_path, mode="r", newline="") as file:
    reader = csv.reader(file)
    data = list(reader)

# Print the raw data
for row in data:
    print(",".join(row))

models = ["A10", "MAC"]
datasets = ["NQ", "TriviaQA", "GPQA", "HotpotQA"]
# Drop the header row; cells stay strings and are parsed with float() where they are used.
data = data[1:]

# Bar geometry: n method slots per subplot, three thin bars per slot.
total_width, n = 6, 6
group = 1
width = total_width * 0.9 / n
x = np.arange(group) * n
exit_idx_x = x + (total_width - width) / n
edgecolors = ["dimgrey", "#63B8B6", "tomato", "slategray", "mediumpurple", "green", "red", "blue", "yellow", "silver"]
local_hatches = ["////", "\\\\", "xxxx"]

labels = [
    "HNSW",
    "IVF",
    "DiskANN",
    "IVF-Disk",
    "IVF-Recompute",
    "Our",
]

yticks = [0.01, 0.1, 1, 10, 100, 1000, 10000]  # Log scale ticks
val_limit = 15000  # Upper limit for the plot

for k, model in enumerate(models):
    fig, axes = plt.subplots(1, 4)
    fig.set_size_inches(20, 3)
    plt.subplots_adjust(wspace=0, hspace=0)

    for i in range(4):
        axes[i].set_yscale('log')  # Set y-axis to logarithmic scale
        axes[i].set_yticks(yticks)
        axes[i].set_ylim(0.01, val_limit)  # Lower limit should be > 0 for log scale
        axes[i].tick_params(axis="y", labelsize=10)
        axes[i].set_xticks([])
        axes[i].set_xlabel(datasets[i], fontsize=font_size)
        axes[i].grid(axis="y", linestyle="--")
        axes[i].set_xlim(exit_idx_x[0] - 0.15 * width - 0.2, exit_idx_x[0] + (n-0.25)* width + 0.2)

        # Row layout used by the indexing below: 6 rows per dataset, 3 rows per
        # hardware within a dataset; method values start at column 3.
        # NOTE(review): the 3 rows are presumably the three recall targets - confirm against the CSV.
        first_row = i * 6 + k * 3
        for j in range(n):
            print('exit_idx_x', exit_idx_x)
            values = [float(data[first_row + m][j + 3]) for m in range(3)]

            if all(value == 0 for value in values):
                # All three entries are 0 (treated as OOM): draw a cross instead of bars
                pos = exit_idx_x + j * width + width * 0.3  # Center position for cross
                marker_size = width * 150  # Size of the cross
                axes[i].scatter(pos, 0.02, marker='x', color=edgecolors[j], s=marker_size,
                                linewidth=4, label=labels[j] if j < len(labels) else "", zorder=20)
            else:
                # One bar per row, each with its own hatch pattern
                for m, value in enumerate(values):
                    num = value + llm_inference_time[USE_LLM_INDEX]
                    pos = exit_idx_x + j * width + width * 0.3 * m
                    print(f"j: {j}, m: {m}, pos: {pos}")
                    # For log scale, values must be positive; also cap at the plot limit
                    plot_value = max(0.01, num) if num < val_limit else val_limit
                    axes[i].bar(
                        pos,
                        plot_value,
                        width=width * 0.3,
                        color="white",
                        edgecolor=edgecolors[j],
                        hatch=local_hatches[m],
                        linewidth=1.0,
                        label=labels[j] if m == 0 else "",  # Only add label for the first bar
                        zorder=10,
                    )

    # The legend is drawn only on the first figure
    if k == 0:
        axes[0].legend(
            bbox_to_anchor=(3.25, 1.02),
            ncol=7,
            loc="lower right",
            labelspacing=0.2,
            edgecolor="black",
            facecolor="white",
            framealpha=1,
            shadow=False,
            handlelength=2,
            handletextpad=0.5,
            columnspacing=0.5,
            prop={"weight": "bold", "size": font_size},
        ).set_zorder(100)

    # Only the leftmost subplot shows the y label and y tick labels
    axes[0].set_ylabel("Runtime (log scale)", fontsize=font_size, fontweight="bold")
    axes[0].set_yticklabels([r"$10^{-2}$", r"$10^{-1}$", r"$10^{0}$", r"$10^{1}$", r"$10^{2}$", r"$10^{3}$", r"$10^{4}$"], fontsize=font_size)
    axes[1].set_yticklabels([])
    axes[2].set_yticklabels([])
    axes[3].set_yticklabels([])

    output_path = f"{SAVE_PTH}/speed_{model}_revised.pdf"
    plt.savefig(output_path, bbox_inches="tight", dpi=300)
    print(output_path)
|
||||
@@ -0,0 +1,85 @@
|
||||
"""Plot RAM and storage usage per method as two side-by-side horizontal bar charts.

Reads ./paper_plot/data/ram_storage.csv and writes
./paper_plot/figures/ram_storage_double_column.pdf.
"""
import os

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from matplotlib.gridspec import GridSpec

# Global matplotlib style; text.usetex=True means labels are rendered by LaTeX.
plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["ytick.direction"] = "in"
plt.rcParams["hatch.linewidth"] = 2
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
plt.rcParams["text.usetex"] = True

DATA_FILE = "./paper_plot/data/ram_storage.csv"
OUTPUT_FILE = "./paper_plot/figures/ram_storage_double_column.pdf"

# Load the RAM and Storage data directly from CSV.
data = pd.read_csv(DATA_FILE)

# Move the "Our" column to the end so it is drawn last in column order.
cols = list(data.columns)
if "Our" in cols and cols[-1] != "Our":
    cols.remove("Our")
    cols.append("Our")
    data = data[cols]

# One figure, two panels: RAM on the left, storage on the right.
fig = plt.figure(figsize=(12, 3))
gs = GridSpec(1, 2, figure=fig)
ax1 = fig.add_subplot(gs[0, 0])
ax2 = fig.add_subplot(gs[0, 1])

# Visual style elements (only the first two edge colors are used below).
edgecolors = ["dimgrey", "#63B8B6", "tomato", "slategray", "silver", "navy"]
hatches = ["/////", "\\\\\\\\\\"]

# Every column after the first is one method. The order is reversed so the
# last column ("Our" after the reordering above) gets y-position 0.
methods = list(data.columns[1:])[::-1]
num_methods = len(methods)
y_positions = np.arange(num_methods)
bar_width = 0.6


def _plot_usage_panel(ax, row_index, style_index, label, title):
    """Draw one horizontal bar panel from a single CSV row.

    Args:
        ax: Axes to draw into.
        row_index: Positional row of ``data`` holding the values
            (assumes row 0 is RAM and row 1 is storage -- verify against the CSV).
        style_index: Index into ``edgecolors`` / ``hatches``.
        label: Legend label attached to the bars.
        title: Panel title.

    Returns:
        The bar container returned by ``Axes.barh``.
    """
    bars = ax.barh(
        y_positions,
        data.iloc[row_index, 1:].values[::-1],  # reversed to match `methods`
        height=bar_width,
        color="white",
        edgecolor=edgecolors[style_index],
        hatch=hatches[style_index],
        linewidth=1.0,
        label=label,
        zorder=10,
    )
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_yticks(y_positions)
    ax.set_yticklabels(methods, fontsize=14)
    ax.set_xlabel("Size (\\textit{GB})", fontsize=14)
    ax.xaxis.set_tick_params(labelsize=14)
    return bars


ram_bars = _plot_usage_panel(ax1, 0, 0, "RAM", "RAM Usage")
storage_bars = _plot_usage_panel(ax2, 1, 1, "Storage", "Storage Usage")

plt.tight_layout()
# Create the output directory first so savefig does not fail on a fresh checkout.
os.makedirs(os.path.dirname(OUTPUT_FILE), exist_ok=True)
plt.savefig(OUTPUT_FILE, bbox_inches="tight", dpi=300)
print(f"Saving the figure to {OUTPUT_FILE}")
|
||||
@@ -0,0 +1,141 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
# Motto: Were It to Benefit My Country, I Would Lay Down My Life!
|
||||
# \file: /bottleneck_breakdown.py
|
||||
# \brief: Illustrates the query time bottleneck on consumer devices (Final Version - Font & Legend Adjust).
|
||||
# Author: Gemini Assistant (adapted from user's style and feedback)
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.ticker import FuncFormatter # Not strictly needed for just font, but imported if user wants to try
|
||||
|
||||
# Matplotlib style shared with the other paper figures.
plt.rcParams["font.family"] = "Helvetica"  # Primary font family
plt.rcParams["ytick.direction"] = "in"
plt.rcParams["xtick.direction"] = "in"
plt.rcParams["hatch.linewidth"] = 1.0
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"

plt.rcParams["text.usetex"] = True
# Attempt to make LaTeX use Helvetica as the main font.
plt.rcParams['text.latex.preamble'] = r"""
\usepackage{helvet} % helvetica font
\usepackage{sansmath} % helvetica for math
\sansmath % activate sansmath
\renewcommand{\familydefault}{\sfdefault} % make sans-serif the default family
"""

# The three stacked segments, in stacking order (left to right).
labels_raw = [  # Raw labels; none contain characters that need LaTeX escaping
    'IO: Text + PQ Lookup',
    'CPU: Tokenize + Distance Compute',
    'GPU: Embedding Recompute',
]
# Time in ms of each segment, in the same order as labels_raw.
times_ms = np.array([
    8.009,   # IO: Text + PQ Lookup
    16.197,  # CPU: Tokenize + Distance Compute
    76.512,  # GPU: Embedding Recompute
])

total_time_ms = times_ms.sum()
percentages = (times_ms / total_time_ms) * 100

# Legend entries: "<label>\n(<time>ms, <percent>\%)". The percent sign is
# escaped because the text is rendered through LaTeX.
labels_legend = []
for label, time, perc in zip(labels_raw, times_ms, percentages):
    perc_str = f"{perc:.1f}" + r"\%"
    labels_legend.append(f"{label}\n({time:.1f}ms, {perc_str})")

# One edge color / hatch per segment; only the first three entries are used.
edgecolors_list = ["dimgrey", "#63B8B6", "tomato", "silver", "slategray"]
hatches_list = ["/////", "xxxxx", "\\\\\\\\\\"]

edgecolors = edgecolors_list[:3]
hatches = hatches_list[:3]
fill_color = "white"

# Create the figure; the right margin is left free for the external legend
# and the bottom margin for the x-label.
fig, ax = plt.subplots()
fig.set_size_inches(7, 1.5)
plt.subplots_adjust(left=0.12, right=0.72, top=0.95, bottom=0.25)

# Draw the single horizontal stacked bar, one segment at a time.
bar_height = 0.2
y_pos = 0

left_offset = 0  # running x-position where the next segment starts
for segment_ms, edgecolor, hatch, legend_label in zip(
    times_ms, edgecolors, hatches, labels_legend
):
    ax.barh(
        y_pos,
        segment_ms,
        height=bar_height,
        left=left_offset,
        color=fill_color,
        edgecolor=edgecolor,
        hatch=hatch,
        linewidth=1.5,
        label=legend_label,
        zorder=10
    )
    text_x_pos = left_offset + segment_ms / 2
    # Only annotate segments wider than 3% of the total so the text fits.
    if segment_ms > total_time_ms * 0.03:
        ax.text(
            text_x_pos,
            y_pos,
            f"{segment_ms:.1f}ms",
            ha='center',
            va='center',
            fontsize=8,
            fontweight='bold',
            color='black',
            zorder=20,
            bbox=dict(facecolor='white', edgecolor='none', pad=0.5, alpha=0.8)
        )
    left_offset += segment_ms

# Set plot limits and labels.
ax.set_xlim([0, total_time_ms * 1.02])
ax.set_xlabel("Time (ms)", fontsize=14, fontweight='bold', x=0.75, )

# The y-axis carries no information for a single bar: hide ticks and labels.
ax.set_yticks([])
ax.set_yticklabels([])

# Legend placed outside the axes on the right: the anchor point is in axes
# coordinates, so x > 1 is to the right of the plot and y = 0.5 is mid-height.
ax.legend(
    bbox_to_anchor=(1.03, 0.5),
    ncol=1,  # single column for a taller, narrower legend
    loc="center left",  # the legend's left-center sits on the anchor point
    labelspacing=0.5,
    edgecolor="black",
    facecolor="white",
    framealpha=1,
    shadow=False,
    fancybox=False,
    handlelength=1.5,
    handletextpad=0.6,
    columnspacing=1.5,
    prop={"weight": "bold", "size": 9},
).set_zorder(100)

# Save the figure. tight_layout is intentionally not called; margins are set
# with subplots_adjust above to leave room for the external legend.
output_filename = "./bottleneck_breakdown.pdf"
plt.savefig(output_filename, bbox_inches="tight", dpi=300)
print(f"Saved plot to {output_filename}")
|
||||
|
||||
# plt.show() # Uncomment to display plot interactively
|
||||
@@ -0,0 +1,226 @@
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
# import matplotlib.ticker as mticker # Not actively used
|
||||
import os
|
||||
|
||||
# Directory the PDFs are written to. If it cannot be created, fall back to the
# current working directory.
FIGURE_PATH = "paper_plot/figures"

try:
    os.makedirs(FIGURE_PATH, exist_ok=True)
    print(f"Images will be saved to: {os.path.abspath(FIGURE_PATH)}")
except OSError as err:
    print(f"Create {FIGURE_PATH} failed: {err}. Images will be saved in the current working directory.")
    FIGURE_PATH = "."

# Global matplotlib style; text.usetex=True means labels are rendered by LaTeX.
plt.rcParams.update({
    "font.family": "Helvetica",
    "ytick.direction": "in",
    "hatch.linewidth": 2,
    "font.weight": "bold",
    "axes.labelweight": "bold",
    "text.usetex": True,
})

# Bar order inside every group: index 0 is the small model, index 1 the large one.
method_labels = ["gte-small (33M)", "contriever-msmarco (110M)"]
dataset_names = ["NQ", "TriviaQA"]
metrics_plot1 = ["Exact Match", "F1"]

# Measurements for the small model (scores are fractions in [0, 1]; the
# latency plot labels the *_time values as seconds).
small_nq_f1, small_tq_f1 = 0.2621040899, 0.4698198059
small_nq_em_score, small_tq_em_score = 0.1845, 0.4015
small_nq_time, small_tq_time = 1.137, 1.173

# Measurements for the large model.
large_nq_f1, large_tq_f1 = 0.2841386117, 0.4548340289
large_nq_em_score, large_tq_em_score = 0.206, 0.382
large_nq_time, large_tq_time = 2.632, 2.684

# dataset -> metric -> [small, large], in the same order as method_labels.
data_scores_plot1 = {
    "NQ": {
        "Exact Match": [small_nq_em_score, large_nq_em_score],
        "F1": [small_nq_f1, large_nq_f1],
    },
    "TriviaQA": {
        "Exact Match": [small_tq_em_score, large_tq_em_score],
        "F1": [small_tq_f1, large_tq_f1],
    },
}
# dataset -> [small, large] latency.
latency_data_plot2 = {
    "NQ": [small_nq_time, large_nq_time],
    "TriviaQA": [small_tq_time, large_tq_time],
}

# One edge color / hatch per method.
edgecolors = ["dimgrey", "tomato"]
hatches = ["/////", "\\\\\\\\\\"]

# Distance between the centers of the two bars of one group, and the drawn
# width of each bar (both in x-axis data units).
bar_center_separation_in_group = 0.42
bar_visual_width = 0.28

# Figure sizes in inches (width, height). Note the two figures have different widths.
figsize_plot1 = (4, 2.5)
figsize_plot2 = (2.5, 2.5)

# x-limits applied to every subplot of plot 1 (module-level so it can be shared).
plot1_xlim_per_subplot = (0.0, 2.0)

# Margins / spacing passed to fig.subplots_adjust by both plotting functions.
common_subplots_adjust_params = dict(wspace=0.30, top=0.80, bottom=0.22, left=0.09, right=0.96)
|
||||
|
||||
|
||||
def create_plot1_em_f1():
    """Draw the Exact Match / F1 accuracy comparison and save it as a PDF.

    Creates one subplot per entry of ``dataset_names``. Each subplot holds one
    bar group per metric in ``metrics_plot1`` and one bar per method in
    ``method_labels``. Scores from ``data_scores_plot1`` are multiplied by 100
    and each bar is annotated with its value. The figure is written to
    ``<FIGURE_PATH>/plot1_em_f1.pdf`` and then closed.
    """
    fig, axs = plt.subplots(1, 2, figsize=figsize_plot1)
    # NOTE(review): the fig.tight_layout() call below probably overrides most
    # of these margins -- confirm which of the two is actually intended.
    fig.subplots_adjust(**common_subplots_adjust_params)

    num_methods = len(method_labels)
    metric_group_centers = np.array([0.5, 1.5])  # x-center of each metric group

    for i, dataset_name in enumerate(dataset_names):
        ax = axs[i]
        subplot_scores_percent = []  # every value drawn in this subplot

        for metric_idx, metric_name in enumerate(metrics_plot1):
            metric_center_pos = metric_group_centers[metric_idx]
            current_scores_percent = [
                val * 100 for val in data_scores_plot1[dataset_name][metric_name]
            ]
            subplot_scores_percent.extend(current_scores_percent)

            for j, method_label in enumerate(method_labels):
                # Spread the bars symmetrically around the group center.
                offset = (j - (num_methods - 1) / 2.0) * bar_center_separation_in_group
                bar_center_pos = metric_center_pos + offset

                ax.bar(
                    bar_center_pos, current_scores_percent[j], width=bar_visual_width, color="white",
                    edgecolor=edgecolors[j], hatch=hatches[j], linewidth=1.5,
                    # Label only the first group of the first subplot so each
                    # method appears once in the legend.
                    label=method_label if i == 0 and metric_idx == 0 else None
                )
                ax.text(
                    bar_center_pos, current_scores_percent[j] + 0.8, f"{current_scores_percent[j]:.1f}",
                    ha='center', va='bottom', fontsize=8, fontweight='bold'
                )

        ax.set_xticks(metric_group_centers)
        ax.set_xticklabels(metrics_plot1, fontsize=9, fontweight='bold')
        ax.set_title(dataset_name, fontsize=12, fontweight='bold')
        ax.set_xlim(plot1_xlim_per_subplot)  # same x-range for every subplot
        ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: ' {:.0f}'.format(y)))

        if i == 0:
            # Raw string: "\%" is not a valid Python escape sequence, and the
            # backslash must reach LaTeX intact (text.usetex is enabled).
            ax.set_ylabel(r"Accuracy (\%)", fontsize=12, fontweight="bold")

        # Leave 22% headroom above the tallest bar for the value annotations.
        max_val = max(subplot_scores_percent) if subplot_scores_percent else 0
        ax.set_ylim(0, max_val * 1.22 if max_val > 0 else 10)
        ax.tick_params(axis='y', labelsize=12)

        for spine in ax.spines.values():
            spine.set_visible(True)
            spine.set_linewidth(1.0)
            spine.set_edgecolor("black")

    # Single figure-level legend built from the handles of the first subplot.
    handles, labels = axs[0].get_legend_handles_labels()
    fig.legend(
        handles, labels, loc="upper center", bbox_to_anchor=(0.5, 0.97), ncol=len(method_labels),
        edgecolor="black", facecolor="white", framealpha=1, shadow=False, fancybox=False,
        handlelength=1.5, handletextpad=0.4, columnspacing=0.8,
        prop={"weight": "bold", "size": 9}
    )

    save_path = os.path.join(FIGURE_PATH, "plot1_em_f1.pdf")
    # Keep the top 12% of the figure free for the figure-level legend.
    fig.tight_layout(rect=(0.0, 0.0, 1.0, 0.88))
    plt.savefig(save_path, bbox_inches='tight', pad_inches=0.03)
    plt.close(fig)
    print(f"Figure 1 (Exact Match & F1) has been saved to: {save_path}")
|
||||
|
||||
def create_plot2_latency():
    """Draw the per-dataset latency comparison and save it as a PDF.

    Creates one subplot per entry of ``dataset_names`` with one bar per method
    in ``method_labels``, using the values in ``latency_data_plot2``. Each bar
    is annotated with its value to two decimals. The figure is written to
    ``<FIGURE_PATH>/plot2_latency.pdf`` and then closed.
    """
    fig, axs = plt.subplots(1, 2, figsize=figsize_plot2)
    # NOTE(review): the fig.tight_layout() call below probably overrides most
    # of these margins -- confirm which of the two is actually intended.
    fig.subplots_adjust(**common_subplots_adjust_params)

    num_methods = len(method_labels)
    method_group_center_in_subplot = 0.5

    # Center the single bar group in an x-span of width 1.0, i.e. (0.0, 1.0).
    plot2_xlim_calculated = (method_group_center_in_subplot - 0.5, method_group_center_in_subplot + 0.5)

    for i, dataset_name in enumerate(dataset_names):
        ax = axs[i]
        current_latencies = latency_data_plot2[dataset_name]

        for j, method_label in enumerate(method_labels):
            # Spread the bars symmetrically around the group center.
            offset = (j - (num_methods - 1) / 2.0) * bar_center_separation_in_group
            bar_center_pos = method_group_center_in_subplot + offset

            ax.bar(
                bar_center_pos, current_latencies[j], width=bar_visual_width, color="white",
                edgecolor=edgecolors[j], hatch=hatches[j], linewidth=1.5,
                # Label only the first subplot so each method appears once in the legend.
                label=method_label if i == 0 else None
            )
            ax.text(
                bar_center_pos, current_latencies[j] + 0.05, f"{current_latencies[j]:.2f}",
                ha='center', va='bottom', fontsize=10, fontweight='bold'
            )

        # The formatter is a per-axis setting, so install it once per subplot.
        ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: ' {:.0f}'.format(y)))

        ax.set_xticks([0.5])
        # The x tick and its label are drawn in white, i.e. invisible on a white
        # background (presumably to keep the same vertical layout as plot 1 --
        # TODO confirm).
        ax.set_xticklabels(["Latency"], color="white", fontsize=12)
        ax.tick_params(axis='x', colors="white")
        ax.set_title(dataset_name, fontsize=13, fontweight='bold')
        ax.set_xlim(plot2_xlim_calculated)

        if i == 0:
            ax.set_ylabel("Latency (s)", fontsize=12, fontweight="bold")

        # Leave 22% headroom above the tallest bar for the value annotations.
        max_latency_in_subplot = max(current_latencies) if current_latencies else 0
        ax.set_ylim(0, max_latency_in_subplot * 1.22 if max_latency_in_subplot > 0 else 1)
        ax.tick_params(axis='y', labelsize=12)

        for spine in ax.spines.values():
            spine.set_visible(True)
            spine.set_linewidth(1.0)
            spine.set_edgecolor("black")

    # Single figure-level legend built from the handles of the first subplot.
    handles, labels = axs[0].get_legend_handles_labels()
    fig.legend(
        handles, labels, loc="upper center", bbox_to_anchor=(0.5, 0.97), ncol=num_methods,
        edgecolor="black", facecolor="white", framealpha=1, shadow=False, fancybox=False,
        handlelength=1.5, handletextpad=0.4, columnspacing=0.8,
        prop={"weight": "bold", "size": 9}
    )

    save_path = os.path.join(FIGURE_PATH, "plot2_latency.pdf")
    # Keep the top 12% of the figure free for the figure-level legend.
    fig.tight_layout(rect=(0.0, 0.0, 1.0, 0.88))
    plt.savefig(save_path, bbox_inches='tight', pad_inches=0.03)
    plt.close(fig)
    print(f"Figure 2 (Latency) has been saved to: {save_path}")
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: render both figures, one after the other.
    print("Start generating figures...")
    latex_enabled = plt.rcParams["text.usetex"]
    if latex_enabled:
        print("Info: LaTeX rendering is enabled. Ensure LaTeX is installed and configured if issues arise, or set plt.rcParams['text.usetex'] to False.")

    for make_figure in (create_plot1_em_f1, create_plot2_latency):
        make_figure()
    print("All figures have been generated.")
|
||||
@@ -0,0 +1,111 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
# Motto: Were It to Benefit My Country, I Would Lay Down My Life!
|
||||
# \file: /speed_ablation.py
|
||||
# \brief:
|
||||
# Author: raphael hao
|
||||
|
||||
# %%
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
# %%
|
||||
# from script.settings import DATA_PATH, FIGURE_PATH
|
||||
|
||||
# Load the latency ablation data and keep only the rows of the SpeedUp metric.
latency_data = pd.read_csv("./paper_plot/data/latency_ablation.csv")
speedup_data = latency_data[latency_data['Metric'] == 'SpeedUp']

# %%
import os

from matplotlib import pyplot as plt

# Global matplotlib style; text.usetex=True means labels are rendered by LaTeX.
plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["ytick.direction"] = "in"
plt.rcParams["hatch.linewidth"] = 1.5
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
plt.rcParams["text.usetex"] = True

# %%

fig, ax = plt.subplots()
fig.set_size_inches(5, 1.5)
plt.subplots_adjust(wspace=0, hspace=0)

# Layout: each dataset is one group of n bars; consecutive groups start n
# x-units apart and the bars of a group are `width` apart.
total_width, n = 3, 3
datasets = speedup_data['Dataset'].unique()
group = len(datasets)
width = total_width * 0.9 / n
x = np.arange(group) * n
exit_idx_x = x + (total_width - width) / n  # x-position of the first bar of each group
edgecolors = ["dimgrey", "#63B8B6", "tomato", "silver", "slategray"]
hatches = ["/////", "xxxxx", "\\\\\\\\\\"]
labels = ["Base", "Base + Two-level", "Base + Two-level + Batch"]
# CSV column read for each bar of a group, in the same order as `labels`.
value_columns = ["Original", "original + two_level", "original + two_level + batch"]

for i, dataset in enumerate(datasets):
    dataset_data = speedup_data[speedup_data['Dataset'] == dataset]

    for j, column in enumerate(value_columns):
        value = dataset_data[column].values[0]  # first matching row for this dataset
        bar_x = exit_idx_x[i] + j * width

        # Value annotation just above the bar.
        ax.text(
            bar_x,
            value + 0.05,
            f"{value:.2f}",
            ha='center',
            va='bottom',
            fontsize=10,
            fontweight='bold',
            rotation=0,
            zorder=20,
        )

        ax.bar(
            bar_x,
            value,
            width=width * 0.8,
            color="white",
            edgecolor=edgecolors[j],
            hatch=hatches[j],
            linewidth=1.5,
            # Label only the first group so each variant appears once in the legend.
            label=labels[j] if i == 0 else None,
            zorder=10,
        )

ax.set_ylim([0.5, 2.3])
ax.set_yticks(np.arange(0.5, 2.2, 0.5))
ax.set_yticklabels(np.arange(0.5, 2.2, 0.5), fontsize=12)
# exit_idx_x + width is the position of the middle bar, i.e. the group center.
ax.set_xticks(exit_idx_x + width)
ax.set_xticklabels(datasets, fontsize=10)
# Legend anchored above the axes (y = 1.4 in axes coordinates).
ax.legend(
    bbox_to_anchor=(-0.03, 1.4),
    ncol=3,
    loc="upper left",
    labelspacing=0.1,
    edgecolor="black",
    facecolor="white",
    framealpha=1,
    shadow=False,
    fancybox=False,
    handlelength=0.8,
    handletextpad=0.6,
    columnspacing=0.8,
    prop={"weight": "bold", "size": 10},
).set_zorder(100)
ax.set_ylabel("Speedup", fontsize=11)

output_path = "./paper_plot/figures/latency_speedup.pdf"
# Create the output directory first so savefig does not fail on a fresh checkout.
os.makedirs(os.path.dirname(output_path), exist_ok=True)
plt.savefig(output_path, bbox_inches="tight", dpi=300)

# %%

print(f"Save to {output_path}")
|
||||
Reference in New Issue
Block a user