Initial commit

yichuan520030910320
2025-06-30 09:05:05 +00:00
commit 46f6cc100b
1231 changed files with 278432 additions and 0 deletions

@@ -0,0 +1,165 @@
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# Set plot parameters
plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["ytick.direction"] = "in"
plt.rcParams["hatch.linewidth"] = 1.5
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
plt.rcParams["text.usetex"] = True
# Path settings
FIGURE_PATH = "./paper_plot/figures"
# Load accuracy data
acc_data = pd.read_csv("./paper_plot/data/acc.csv")
# Create figure with 4 subplots (one for each dataset)
fig, axs = plt.subplots(1, 4)
fig.set_size_inches(9, 2.5)
# Reduce the spacing between subplots
# plt.subplots_adjust(wspace=0.2) # Reduced from 0.3 to 0.1
# Define datasets and their columns
datasets = ["NQ", "TriviaQA", "GPQA", "HotpotQA"]
metrics = ["Exact Match", "F1"]
# Define bar settings - make bars thicker
# total_width, n = 0.9, 3 # increased total width and n for three models
# width = total_width / n
# The 'width' variable below now defines the distance between the centers of adjacent bars within a group.
# It's also used as the base for calculating the actual plotted bar width.
# Original 2 bars had centers 1.0 apart. For 3 bars, we need a smaller distance.
# A value of 0.64 for distance between centers, with a scaling factor of 0.8 for bar width,
# results in an actual bar width of ~0.51, and a group span of ~1.79, similar to original's ~1.76.
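# Check the arithmetic: 0.64 * 0.8 = 0.512 (actual bar width); the outer bar centers
# sit at +/-0.64, so a group spans 2 * 0.64 + 0.512 = 1.792, the "~1.79" noted above.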
n = 3 # Number of models
width = 0.64 # Distance between centers of adjacent bars in a group
bar_width_plotting_factor = 0.8 # Bar takes 80% of the space defined by 'width'
# Colors and hatches
edgecolors = ["dimgrey", "#63B8B6", "tomato"] # Added color for PQ 5
hatches = ["/////", "xxxxx", "\\\\\\\\\\"] # Added hatch for PQ 5
labels = ["BM25", "PQ Compressed", "Ours"] # Added PQ 5
# Create plots for each dataset
for i, dataset in enumerate(datasets):
ax = axs[i]
# Get data for this dataset and convert to percentages
em_values = [
acc_data.loc[0, f"{dataset} Exact Match"] * 100,
acc_data.loc[1, f"{dataset} Exact Match"] * 100,
acc_data.loc[2, f"{dataset} Exact Match"] * 100 # Added PQ 5 EM data
]
f1_values = [
acc_data.loc[0, f"{dataset} F1"] * 100,
acc_data.loc[1, f"{dataset} F1"] * 100,
acc_data.loc[2, f"{dataset} F1"] * 100 # Added PQ 5 F1 data
]
# Define x positions for bars
# For EM: center - width, center, center + width
# For F1: center - width, center, center + width
group_centers = [1.0, 3.0] # Centers for EM and F1 groups
bar_offsets = [-width, 0, width]
# Plot all bars on the same axis
for metric_idx, metric_group_center in enumerate(group_centers):
values_to_plot = em_values if metric_idx == 0 else f1_values
for j, model_label in enumerate(labels):
x_pos = metric_group_center + bar_offsets[j]
bar_value = values_to_plot[j]
ax.bar(
x_pos,
bar_value,
width=width * bar_width_plotting_factor, # Use the new factor for bar width
color="white",
edgecolor=edgecolors[j],
hatch=hatches[j],
linewidth=1.5,
label=model_label if i == 0 and metric_idx == 0 else None # Label only once
)
# Add value on top of bar
            ax.text(x_pos, bar_value + 0.1,
                    f"{bar_value:.1f}", ha='center', va='bottom',
                    fontsize=9, fontweight='bold')  # fontsize reduced for text on bars
# Set x-ticks and labels
ax.set_xticks(group_centers) # Position ticks at the center of each group
xticklabels = ax.set_xticklabels(metrics, fontsize=12)
# Now, shift these labels slightly to the right
# Adjust this value to control the amount of shift (in data coordinates)
# Given your group_centers are 1.0 and 3.0, a small value like 0.05 to 0.15 might be appropriate.
# horizontal_shift = 0.7 # Try adjusting this value
# for label in xticklabels:
# # Get the current x position (which is the tick location)
# current_x_pos = label.get_position()[0]
# # Set the new x position by adding the shift
# label.set_position((current_x_pos + horizontal_shift, label.get_position()[1]))
# # Ensure the label remains horizontally centered on this new x position
# # (set_xticklabels defaults to 'center', so this re-affirms it if needed)
# label.set_horizontalalignment('center')
# Set title
ax.set_title(dataset, fontsize=14)
# Set y-label for all subplots
if i == 0:
ax.set_ylabel("Accuracy (\%)", fontsize=12, fontweight="bold")
else:
# Hide y-tick labels for non-first subplots to save space
ax.tick_params(axis='y', labelsize=10)
# Set y-limits based on data range
all_values = em_values + f1_values
max_val = max(all_values)
min_val = min(all_values)
# Special handling for GPQA which has very low values
if dataset == "GPQA":
ax.set_ylim(0, 10.0) # Set a fixed range for GPQA
else:
# Reduce the extra space above the bars
ax.set_ylim(min_val * 0.9, max_val * 1.1) # Adjusted upper limit for text
# Format y-ticks as percentages
ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: ' {:.0f}'.format(y)))
# Set x-limits to properly space the bars with less blank space
# ax.set_xlim(group_centers[0] - total_width, group_centers[1] + total_width)
# Set xlim to be similar to original (0,4) for group_centers (1,3) => margin of 1.0
ax.set_xlim(group_centers[0] - 1.0, group_centers[1] + 1.0)
# Add a box around the subplot
# for spine in ax.spines.values():
# spine.set_visible(True)
# spine.set_linewidth(1.0)
# Add legend to first subplot
if i == 0:
ax.legend(
bbox_to_anchor=(2.21, 1.35), # Adjusted anchor if needed
ncol=3, # Changed to 3 columns for three labels
loc="upper center",
labelspacing=0.1,
edgecolor="black",
facecolor="white",
framealpha=1,
shadow=False,
fancybox=False,
handlelength=1.0,
handletextpad=0.6,
columnspacing=0.8,
prop={"weight": "bold", "size": 12},
)
# Save figure with tight layout but no additional padding
plt.savefig(FIGURE_PATH + "/accuracy_comparison.pdf", bbox_inches='tight', pad_inches=0.05)
plt.show()

@@ -0,0 +1,309 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Motto: Were It to Benefit My Country, I Would Lay Down My Life!
# \file: /hnsw_degree_visit_plot_binned_academic.py
# \brief: Generates a binned bar plot of HNSW node average per-query visit probability
# per degree bin, styled for academic publications, with caching.
# Author: raphael hao (Original script by user, styling and caching adapted by Gemini)
# %%
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import re
from collections import Counter
import os # For robust filepath manipulation
import math # For calculating scaling factor
import pickle # For caching data
# %%
# --- Matplotlib parameters for academic paper style (from reference) ---
plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["ytick.direction"] = "in"
plt.rcParams["hatch.linewidth"] = 1.5
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
plt.rcParams["text.usetex"] = True # Use LaTeX for text rendering (if available)
# --- Define styles from reference ---
edgecolors_ref = ["dimgrey", "#63B8B6", "tomato", "silver", "slategray"]
# %%
# --- File Paths ---
degree_file = '/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/hnsw/degree_distribution.txt'
visit_log_file = './re.log'
output_image_file = './paper_plot/figures/hnsw_visit_count_per_degree_corrected.pdf'
# --- CACHE FILE PATH: Keep this consistent ---
CACHE_FILE_PATH = './binned_plot_data_cache.pkl'
# --- Configuration ---
# Set to True to bypass cache and force recomputation.
# Otherwise, delete CACHE_FILE_PATH manually to force recomputation.
FORCE_RECOMPUTE = False
NUMBER_OF_QUERIES = 1000.0 # Number of queries the visit_counts are based on
# Create directory for figures if it doesn't exist
output_dir = os.path.dirname(output_image_file)
if output_dir and not os.path.exists(output_dir):
os.makedirs(output_dir)
print(f"Created directory: {output_dir}")
# %%
# --- Attempt to load data from cache or compute ---
df_plot_data = None
bin_size_for_plot = None # Will hold the bin_size associated with df_plot_data
if not FORCE_RECOMPUTE and os.path.exists(CACHE_FILE_PATH):
try:
with open(CACHE_FILE_PATH, 'rb') as f:
cache_content = pickle.load(f)
df_plot_data = cache_content['data']
bin_size_for_plot = cache_content['bin_size']
# Basic validation of cached data
# Expecting 'average_visit_count_per_node_in_bin' (raw average over NUMBER_OF_QUERIES)
if not isinstance(df_plot_data, pd.DataFrame) or \
'degree_bin_label' not in df_plot_data.columns or \
'average_visit_count_per_node_in_bin' not in df_plot_data.columns or \
not isinstance(bin_size_for_plot, int):
print("Cached data is not in the expected format or missing 'average_visit_count_per_node_in_bin'. Recomputing.")
df_plot_data = None # Invalidate to trigger recomputation
else:
print(f"Successfully loaded binned data from cache: {CACHE_FILE_PATH}")
# --- Modify the label loaded from cache for display purpose ---
# This modification only happens when data is loaded from cache and meets specific conditions.
# Assumption: If the bin_size_for_plot in cache is 5,
# then the original label "0-4" actually represents nodes with degree 1-4 (because you guarantee no 0-degree nodes).
if df_plot_data is not None and 'degree_bin_label' in df_plot_data.columns and bin_size_for_plot == 5:
# Check if "0-4" label exists
if '0-4' in df_plot_data['degree_bin_label'].values:
# Use .loc to ensure the modification is on the original DataFrame
df_plot_data.loc[df_plot_data['degree_bin_label'] == '0-4', 'degree_bin_label'] = '1-4'
print("Modified degree_bin_label from '0-4' to '1-4' for display purpose.")
except Exception as e:
print(f"Error loading from cache: {e}. Recomputing.")
df_plot_data = None # Invalidate to trigger recomputation
if df_plot_data is None:
print("Cache not found, invalid, or recompute forced. Computing data from scratch...")
# --- 1. Read Degree Distribution File ---
degrees_data = []
try:
with open(degree_file, 'r') as f:
for i, line in enumerate(f):
line_stripped = line.strip()
if line_stripped:
degrees_data.append({'node_id': i, 'degree': int(line_stripped)})
except FileNotFoundError:
print(f"Error: Degree file '{degree_file}' not found. Using dummy data for degrees.")
degrees_data = [{'node_id': i, 'degree': (i % 20) + 1 } for i in range(200)]
degrees_data.extend([{'node_id': 200+i, 'degree': i} for i in range(58, 67)]) # For 60-64 bin
degrees_data.extend([{'node_id': 300+i, 'degree': (i % 5)+1} for i in range(10)]) # Low degrees
degrees_data.extend([{'node_id': 400+i, 'degree': 80 + (i%5)} for i in range(10)]) # High degrees
if not degrees_data:
print(f"Critical Error: No data loaded or generated for degrees. Exiting.")
exit()
df_degrees = pd.DataFrame(degrees_data)
print(f"Successfully loaded/generated {len(df_degrees)} degree entries.")
# --- 2. Read Visit Log File and Count Frequencies ---
visit_counts = Counter()
node_id_pattern = re.compile(r"Vis(i)?ted node: (\d+)")  # also matches the "Visted" misspelling
try:
with open(visit_log_file, 'r') as f_log:
for line_num, line in enumerate(f_log, 1):
match = node_id_pattern.search(line)
if match:
try:
node_id = int(match.group(2))
visit_counts[node_id] += 1 # Increment visit count for the node
except ValueError:
print(f"Warning: Non-integer node_id in log '{visit_log_file}' line {line_num}: {line.strip()}")
except FileNotFoundError:
print(f"Warning: Visit log file '{visit_log_file}' not found. Using dummy visit counts.")
if not df_degrees.empty:
for node_id_val in df_degrees['node_id'].sample(frac=0.9, random_state=1234): # Seed for reproducibility
degree_val = df_degrees[df_degrees['node_id'] == node_id_val]['degree'].iloc[0]
# Generate visit counts to test different probability magnitudes
if node_id_val % 23 == 0: # Very low probability
lambda_val = 0.0005 * (100 / (max(1,degree_val) + 1)) # avg visits over 1k queries
elif node_id_val % 11 == 0: # Low probability
lambda_val = 0.05 * (100 / (max(1,degree_val) + 1))
elif node_id_val % 5 == 0: # Moderate probability
lambda_val = 2.5 * (100 / (max(1,degree_val) + 1))
else: # Higher probability (but still < 1000 visits for a single node usually)
lambda_val = 50 * (100 / (max(1,degree_val) + 1))
visit_counts[node_id_val] = np.random.poisson(lambda_val)
if visit_counts[node_id_val] < 0: visit_counts[node_id_val] = 0
if not visit_counts:
print(f"Warning: No visit data parsed/generated. Plot may show zero visits.")
df_visits = pd.DataFrame(columns=['node_id', 'visit_count'])
else:
df_visits_list = [{'node_id': nid, 'visit_count': count} for nid, count in visit_counts.items()]
df_visits = pd.DataFrame(df_visits_list)
print(f"Parsed/generated {len(df_visits)} unique visited nodes, totaling {sum(visit_counts.values())} visits (simulated over {NUMBER_OF_QUERIES} queries).")
# --- 3. Merge Degree Data with Visit Data ---
df_merged = pd.merge(df_degrees, df_visits, on='node_id', how='left')
df_merged['visit_count'] = df_merged['visit_count'].fillna(0).astype(float) # visit_count is total over NUMBER_OF_QUERIES
print(f"Merged data contains {len(df_merged)} entries.")
# --- 4. Binning Degrees and Calculating Average Visit Count per Node in Bin (over NUMBER_OF_QUERIES) ---
current_bin_size = 5
bin_size_for_plot = current_bin_size
if not df_degrees.empty:
print(f"\nBinning degrees into groups of {current_bin_size} for average visit count calculation...")
df_merged_with_bins = df_merged.copy()
df_merged_with_bins['degree_bin_start'] = (df_merged_with_bins['degree'] // current_bin_size) * current_bin_size
df_binned_analysis = df_merged_with_bins.groupby('degree_bin_start').agg(
total_visit_count_in_bin=('visit_count', 'sum'),
node_count_in_bin=('node_id', 'nunique')
).reset_index()
# This is the average number of times a node in this bin was visited over NUMBER_OF_QUERIES queries.
# This value is what gets cached.
df_binned_analysis['average_visit_count_per_node_in_bin'] = 0.0
df_binned_analysis.loc[df_binned_analysis['node_count_in_bin'] > 0, 'average_visit_count_per_node_in_bin'] = \
df_binned_analysis['total_visit_count_in_bin'] / df_binned_analysis['node_count_in_bin']
df_binned_analysis['degree_bin_label'] = df_binned_analysis['degree_bin_start'].astype(str) + '-' + \
(df_binned_analysis['degree_bin_start'] + current_bin_size - 1).astype(str)
bin_to_drop_label = '60-64'
original_length = len(df_binned_analysis)
df_plot_data_intermediate = df_binned_analysis[df_binned_analysis['degree_bin_label'] != bin_to_drop_label].copy()
if len(df_plot_data_intermediate) < original_length:
print(f"\nManually dropped the bin: '{bin_to_drop_label}'")
else:
print(f"\nNote: Bin '{bin_to_drop_label}' not found for dropping or already removed.")
df_plot_data = df_plot_data_intermediate
print(f"\nBinned data (average visit count per node in bin over {NUMBER_OF_QUERIES} queries) for plotting prepared:")
print(df_plot_data[['degree_bin_label', 'average_visit_count_per_node_in_bin']].head())
if df_plot_data is not None and not df_plot_data.empty:
try:
with open(CACHE_FILE_PATH, 'wb') as f:
pickle.dump({'data': df_plot_data, 'bin_size': bin_size_for_plot}, f)
print(f"Saved computed binned data to cache: {CACHE_FILE_PATH}")
except Exception as e:
print(f"Error saving data to cache: {e}")
elif df_plot_data is None or df_plot_data.empty:
print("Computed data for binned plot is empty, not saving to cache.")
else:
print("Degree data (df_degrees) is empty. Cannot perform binning.")
df_plot_data = pd.DataFrame()
bin_size_for_plot = current_bin_size
# %%
# --- 5. Plotting (Binned Bar Chart - Academic Style) ---
if df_plot_data is not None and not df_plot_data.empty and 'average_visit_count_per_node_in_bin' in df_plot_data.columns:
    # --- OUTPUT PDF FILE NAME: keep this consistent ---
    binned_output_image_file = output_image_file  # the former splitext/rejoin was a no-op
fig, ax = plt.subplots(figsize=(6, 2.5)) # Adjusted figure size
df_plot_data_plotting = df_plot_data.copy()
# Calculate per-query probability: (avg visits over N queries) / N
df_plot_data_plotting['per_query_visit_probability'] = \
df_plot_data_plotting['average_visit_count_per_node_in_bin'] / NUMBER_OF_QUERIES
max_probability = df_plot_data_plotting['per_query_visit_probability'].max()
y_axis_values_to_plot = df_plot_data_plotting['per_query_visit_probability']
y_axis_label = r"Per-Query Node Visit Probability in Bin" # Base label
apply_scaling_to_label_and_values = False # Initialize flag
exponent_for_label_display = 0 # Initialize exponent
if pd.notna(max_probability) and max_probability > 0:
potential_exponent = math.floor(math.log10(max_probability))
if potential_exponent <= -4 or potential_exponent >= 0:
apply_scaling_to_label_and_values = True
exponent_for_label_display = potential_exponent
# No specific adjustment for potential_exponent >=0 here, it's handled by the general logic.
if apply_scaling_to_label_and_values:
y_axis_label = rf"Visit Probability ($\times 10^{{{exponent_for_label_display}}}$)"
y_axis_values_to_plot = df_plot_data_plotting['per_query_visit_probability'] / (10**exponent_for_label_display)
print(f"Plotting with Max per-query probability: {max_probability:.2e}, Exponent for label: {exponent_for_label_display}. Y-axis values scaled for plot.")
else:
print(f"Plotting with Max per-query probability: {max_probability:.2e}. Plotting direct probabilities without label scaling (exponent {potential_exponent} is within no-scale range [-3, -1]).")
elif pd.notna(max_probability) and max_probability == 0:
print("Max per-query probability is 0. Plotting direct probabilities (all zeros).")
else:
print(f"Max per-query probability is NaN or invalid ({max_probability}). Plotting direct probabilities without scaling if possible.")
ax.bar(
df_plot_data_plotting['degree_bin_label'],
y_axis_values_to_plot,
color='white',
edgecolor=edgecolors_ref[0],
linewidth=1.5,
width=0.8
)
ax.set_xlabel('Node Degree', fontsize=10.5, labelpad=6)
# MODIFIED LINE: Added labelpad to move the y-axis label to the left
ax.set_ylabel(y_axis_label, fontsize=10.5, labelpad=10)
ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, pos: f"{x:.0f}%"))
num_bins = len(df_plot_data_plotting)
if num_bins > 12:
ax.set_xticks(ax.get_xticks())
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right", fontsize=9)
elif num_bins > 8:
ax.tick_params(axis='x', labelsize=9)
else:
ax.tick_params(axis='x', labelsize=10)
ax.tick_params(axis='y', labelsize=10)
padding_factor = 0.05
current_max_y_on_axis = y_axis_values_to_plot.max()
upper_y_limit = 0.1 # Default small upper limit
if pd.notna(current_max_y_on_axis):
if current_max_y_on_axis > 0:
# Adjust minimum visible range based on whether scaling was applied and the exponent
min_meaningful_limit = 0.01
if apply_scaling_to_label_and_values and exponent_for_label_display >= 0 : # Numbers on axis are smaller due to positive exponent scaling
min_meaningful_limit = 0.1 # If original numbers were e.g. 2500 (2.5 x 10^3), scaled axis is 2.5, 0.1 is fine
elif not apply_scaling_to_label_and_values and pd.notna(max_probability) and max_probability >=1: # Direct large probabilities
min_meaningful_limit = 1 # If max prob is 2.5 (250%), axis value 2.5, needs larger base limit
upper_y_limit = max(min_meaningful_limit, current_max_y_on_axis * (1 + padding_factor))
else: # current_max_y_on_axis is 0
upper_y_limit = 0.1
ax.set_ylim(0, upper_y_limit)
else:
ax.set_ylim(0, 1.0) # Default for empty or NaN data
plt.tight_layout()
plt.savefig(binned_output_image_file, bbox_inches="tight", dpi=300)
print(f"Binned bar chart saved to {binned_output_image_file}")
plt.show()
plt.close(fig)
else:
if df_plot_data is None:
print("Data for plotting (df_plot_data) is None. Skipping plot generation.")
elif df_plot_data.empty:
print("Data for plotting (df_plot_data) is empty. Skipping plot generation.")
elif 'average_visit_count_per_node_in_bin' not in df_plot_data.columns:
print("Essential column 'average_visit_count_per_node_in_bin' is missing in df_plot_data. Skipping plot generation.")
# %%
print("Script finished.")

research/paper_plot/b.md
@@ -0,0 +1,7 @@
In this paper, we present LiteANN, a storage-efficient approximate nearest neighbor (ANN) search index optimized for resource-constrained personal devices. LiteANN combines a compact graph-based structure with an efficient on-the-fly recomputation strategy to enable fast and accurate retrieval with minimal storage overhead. Our evaluation shows that LiteANN reduces index size to under 5% of the original raw data (up to 50× smaller than standard indexes) while achieving 90% top-3 recall in under 2 seconds on real-world question-answering benchmarks.
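
The on-the-fly recomputation idea can be sketched as a best-first graph search that re-encodes raw passages at visit time instead of storing their embeddings. The sketch below is illustrative only: `graph`, `encode`, and `corpus` are assumed inputs, not LiteANN's actual API.

```python
import heapq
import numpy as np

def search_with_recompute(query_vec, graph, encode, corpus, entry_id, k=3, ef=32):
    """Best-first graph search that recomputes node embeddings on the fly
    instead of storing them, trading compute for storage (illustrative)."""
    def dist(node):
        # Re-encode the raw text at visit time; negative inner product, smaller is better.
        return -float(np.dot(query_vec, encode(corpus[node])))

    visited = {entry_id}
    d0 = dist(entry_id)
    candidates = [(d0, entry_id)]   # min-heap: search frontier
    results = [(-d0, entry_id)]     # max-heap (negated): best ef nodes so far

    while candidates:
        d, node = heapq.heappop(candidates)
        if len(results) >= ef and d > -results[0][0]:
            break  # best frontier node is worse than the worst kept result
        for nb in graph[node]:      # graph: node id -> iterable of neighbor ids
            if nb in visited:
                continue
            visited.add(nb)
            nd = dist(nb)
            heapq.heappush(candidates, (nd, nb))
            heapq.heappush(results, (-nd, nb))
            if len(results) > ef:
                heapq.heappop(results)  # drop the current worst result

    return [n for _, n in sorted(results, reverse=True)[:k]]
```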

@@ -0,0 +1,81 @@
import numpy as np
import os
# --- Configuration for Data Paths and Labels (Mirrors plotting script for consistency) ---
BIG_GRAPH_PATHS = [
"/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/hnsw/",
"/opt/dlami/nvme/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki/1-shards/indices/99_4_degree_based_hnsw_IP_M32_efC256/",
"/opt/dlami/nvme/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki/1-shards/indices/d9_hnsw_IP_M8_efC128/",
"/opt/dlami/nvme/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki/1-shards/indices/half_edges_IP_M32_efC128/"
]
STATS_FILE_NAME = "degree_distribution.txt"
BIG_GRAPH_LABELS = [ # These will be used as keys in the cached file
"HNSW-Base",
"DegreeGuide",
"HNSW-D9",
"RandCut",
]
# Average degrees are static and can be directly used in the plotting script or also cached.
# For simplicity here, we'll focus on caching the dynamic degree arrays.
# BIG_GRAPH_AVG_DEG = [18, 9, 9, 9]
# --- Cache File Configuration ---
DATA_CACHE_DIR = "./paper_plot/data/"
CACHE_FILE_NAME = "big_graph_degree_data.npz" # Using .npz for multiple arrays
def create_degree_data_cache():
"""
Reads degree distribution data from specified text files and saves it
into a compressed NumPy (.npz) cache file.
"""
os.makedirs(DATA_CACHE_DIR, exist_ok=True)
cache_file_path = os.path.join(DATA_CACHE_DIR, CACHE_FILE_NAME)
cached_data = {}
print(f"Starting data caching process for {len(BIG_GRAPH_PATHS)} graph types...")
for i, base_path in enumerate(BIG_GRAPH_PATHS):
method_label = BIG_GRAPH_LABELS[i]
degree_file_path = os.path.join(base_path, STATS_FILE_NAME)
print(f"Processing: {method_label} from {degree_file_path}")
try:
# Load degrees as integers
degrees = np.loadtxt(degree_file_path, dtype=int)
if degrees.size == 0:
print(f" [WARN] Degree file is empty: {degree_file_path}. Storing as empty array for {method_label}.")
# Store an empty array or handle as needed. For npz, an empty array is fine.
cached_data[method_label] = np.array([], dtype=int)
else:
# Store the loaded degrees array with the method label as the key
cached_data[method_label] = degrees
print(f" [INFO] Loaded {len(degrees)} degrees for {method_label}. Max degree: {np.max(degrees) if degrees.size > 0 else 'N/A'}")
except FileNotFoundError:
print(f" [ERROR] Degree file not found: {degree_file_path}. Skipping {method_label}.")
# Optionally store a placeholder or skip. For robustness, store None or an empty array.
# Storing None might require special handling when loading. Empty array is safer for np.load.
cached_data[method_label] = np.array([], dtype=int) # Store empty array if file not found
except Exception as e:
print(f" [ERROR] An error occurred loading {degree_file_path} for {method_label}: {e}")
cached_data[method_label] = np.array([], dtype=int) # Store empty array on other errors
if not cached_data:
print("[ERROR] No data was successfully processed or loaded. Cache file will not be created.")
return
try:
# Save all collected degree arrays into a single .npz file.
# Using savez_compressed for potentially smaller file size.
np.savez_compressed(cache_file_path, **cached_data)
print(f"\n[SUCCESS] Degree distribution data successfully cached to: {os.path.abspath(cache_file_path)}")
print("Cached arrays (keys):", list(cached_data.keys()))
except Exception as e:
print(f"\n[ERROR] Failed to save data to cache file {cache_file_path}: {e}")
if __name__ == "__main__":
print("--- Degree Distribution Data Caching Script ---")
create_degree_data_cache()
print("--- Caching script finished. ---")

@@ -0,0 +1,4 @@
Model,NQ Exact Match,NQ F1,TriviaQA Exact Match,TriviaQA F1,GPQA Exact Match,GPQA F1,HotpotQA Exact Match,HotpotQA F1
BM25,0.192,0.277,0.406,0.474,0.020089,0.04524,0.162,0.239
PQ 5,0.2075,0.291,0.422,0.495,0.0201,0.0445,0.148,0.219
Ours,0.265,0.361,0.533,0.604,0.02008,0.0452,0.182,0.2729

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1296720e79196bbdf38f051043c1b054667803726a24036c0b6a87cedb204ea5
size 227482438

@@ -0,0 +1,21 @@
2,1,512,1024,0.541,0.326,1.659509202
2,2,512,1024,0.979,0.621,1.576489533
2,4,512,1024,1.846,0.977,1.889457523
2,8,512,1024,3.575,1.943,1.83993824
2,16,512,1024,7.035,3.733,1.884543263
2,32,512,1024,15.655,8.517,1.838088529
2,64,512,1024,32.772,17.43,1.88020654
4,1,512,1024,2.675,1.38,1.938405797
4,2,512,1024,5.397,2.339,2.307396323
4,4,512,1024,10.672,4.944,2.158576052
4,8,512,1024,21.061,9.266,2.272933305
4,16,512,1024,46.332,18.334,2.527108105
4,32,512,1024,99.607,36.156,2.754923111
4,64,512,1024,186.348,72.356,2.575432583
8,1,512,1024,7.325,4.087,1.792268167
8,2,512,1024,14.109,7.491,1.883460152
8,4,512,1024,28.499,14.013,2.033754371
8,8,512,1024,65.222,27.453,2.375769497
8,16,512,1024,146.294,52.55,2.783901047
8,32,512,1024,277.099,103.61,2.674442621
8,64,512,1024,512.979,208.36,2.461984066

@@ -0,0 +1,9 @@
Dataset,Metric,Original,original + batch,original + two_level,original + two_level + batch
NQ,Latency,6.9,5.8,4.2,3.7
NQ,SpeedUp,1,1.18965517,1.64285714,1.86486486
TriviaQA,Latency,17.054,14.542,12.046,10.83
TriviaQA,SpeedUp,1,1.17274103,1.41573967,1.57469990
GPQA,Latency,9.164,7.639,6.798,5.77
GPQA,SpeedUp,1,1.19963346,1.34804354,1.58821490
HotpotQA,Latency,60.279,39.827,50.664,29.868
HotpotQA,SpeedUp,1,1.51352098,1.18977972,2.01817999

@@ -0,0 +1,25 @@
Dataset,Hardware,Recall_target,HNSW,IVF,DiskANN,IVF-Disk,IVF-Recompute,Our,BM25,LLM_Gen_Time_1B,LLM_Gen_Time_3B,LLM_Gen_Time_7B
NQ,A10,85%,0.046,1.656,0.017,2.996,482.53,3.323,0.021,0.085,0.217,0.472
NQ,A10,90%,0.051,2.552,0.028,3.437,769.04,4.616,0,0.085,0.217,0.472
NQ,A10,95%,0.055,5.163,0.070,5.602,1436.26,19.494,0,0.085,0.217,0.472
NQ,MAC,85%,0,0,0.152,2.199,1535.10,7.971,0.033,0.316,0.717,1.468
NQ,MAC,90%,0,0,0.37,2.936,2446.60,13.843,0,0.316,0.717,1.468
NQ,MAC,95%,0,0,1.207,4.191,4569.29,44.363,0,0.316,0.717,1.468
TriviaQA,A10,85%,0.042,1.772,0.032,2.464,560.5,3.752,0.033,0.139,0.156,0.315
TriviaQA,A10,90%,0.043,3.541,0.057,3.651,997.81,5.777,0,0.139,0.156,0.315
TriviaQA,A10,95%,0.053,7.168,0.090,5.458,2005.33,20.944,0,0.139,0.156,0.315
TriviaQA,MAC,85%,0,0,0.481,1.875,1783.14787,8.889,0.036,0.325,0.692,1.415
TriviaQA,MAC,90%,0,0,0.984,2.639,3174.410301,17.145,0,0.325,0.692,1.415
TriviaQA,MAC,95%,0,0,1.578,3.884,6379.712245,47.909,0,0.325,0.692,1.415
GPQA,A10,85%,0.041,0.134,0.024,0.048,40.16,1.897,0.137,0.443,0.396,0.651
GPQA,A10,90%,0.042,0.174,0.034,0.06,54.71,1.733,0,0.443,0.396,0.651
GPQA,A10,95%,0.045,0.292,0.051,0.11,97.67,4.033,0,0.443,0.396,0.651
GPQA,MAC,85%,0,0,0.144,0.087,127.7707505,4.762,0.100,0.37,0.813,1.676
GPQA,MAC,90%,0,0,0.288,0.108,174.0647409,5.223,0,0.37,0.813,1.676
GPQA,MAC,95%,0,0,0.497,0.132,310.7380142,9.715,0,0.37,0.813,1.676
HotpotQA,A10,85%,0.044,2.519,0.054,4.048,724.26,10.358,0.70,0.144,0.196,0.420
HotpotQA,A10,90%,0.049,3.867,0.109,5.045,1173.67,15.515,0,0.144,0.196,0.420
HotpotQA,A10,95%,0.07,10.928,0.412,8.659,3079.57,61.757,0,0.144,0.196,0.420
HotpotQA,MAC,85%,0,0,0.974,2.844,2304.125187,23.636,0.052,0.144,0.196,0.420
HotpotQA,MAC,90%,0,0,1.913,3.542,3415.736201,44.803,0,0.144,0.196,0.420
HotpotQA,MAC,95%,0,0,5.783,6.764,9797.244043,140.62,0,0.144,0.196,0.420

@@ -0,0 +1,25 @@
Dataset,Hardware,Recall_target,HNSW,IVF,DiskANN,IVF-Disk,IVF-Recompute,Our
NQ,A10,85%,0.046,1.656,0.017,2.996,482.53,4.243
NQ,A10,90%,0.051,2.552,0.028,3.437,769.04,8.136
NQ,A10,95%,0.055,5.163,0.070,5.602,1436.26,27.275
NQ,MAC,85%,0,0,0.152,2.199,1535.10,10.672
NQ,MAC,90%,0,0,0.37,2.936,2446.60,19.941
NQ,MAC,95%,0,0,1.207,4.191,4569.29,61.383
TriviaQA,A10,85%,0.042,1.772,0.032,2.464,560.5,5.612
TriviaQA,A10,90%,0.043,3.541,0.057,3.651,997.81,10.737
TriviaQA,A10,95%,0.053,7.168,0.090,5.458,2005.33,36.387
TriviaQA,MAC,85%,0,0,0.481,1.875,1783.14787,12.825
TriviaQA,MAC,90%,0,0,0.984,2.639,3174.410301,24.977
TriviaQA,MAC,95%,0,0,1.578,3.884,6379.712245,85.734
GPQA,A10,85%,0.041,0.134,0.024,0.048,40.16,2.269
GPQA,A10,90%,0.042,0.174,0.034,0.06,54.71,3.200
GPQA,A10,95%,0.045,0.292,0.051,0.11,97.67,7.445
GPQA,MAC,85%,0,0,0.144,0.087,127.7707505,6.123
GPQA,MAC,90%,0,0,0.288,0.108,174.0647409,8.507
GPQA,MAC,95%,0,0,0.497,0.132,310.7380142,19.577
HotpotQA,A10,85%,0.044,2.519,0.054,4.048,724.26,14.713
HotpotQA,A10,90%,0.049,3.867,0.109,5.045,1173.67,33.561
HotpotQA,A10,95%,0.07,10.928,0.412,8.659,3079.57,68.626
HotpotQA,MAC,85%,0,0,0.974,2.844,2304.125187,34.783
HotpotQA,MAC,90%,0,0,1.913,3.542,3415.736201,53.004
HotpotQA,MAC,95%,0,0,5.783,6.764,9797.244043,95.413

@@ -0,0 +1,3 @@
Hardware,HNSW,IVF,DiskANN,IVF-Disk,IVF-Recompute,Our,BM25
RAM,190,171,10,0,0,0,0
Storage,185.4,171,240,171,0.5,5,59

@@ -0,0 +1,12 @@
Torch,8,55.592
Torch,16,75.439
Torch,32,110.025
Torch,64,186.496
Tutel,8,56.718
Tutel,16,82.121
Tutel,32,125.070
Tutel,64,216.191
BRT,8,56.725
BRT,16,79.291
BRT,32,93.180
BRT,64,118.923

@@ -0,0 +1,6 @@
Disk cache size,0,2.5%(180G*2.5%),5%,8%,10%
Latency,,,,,
NQ,4.616,4.133,3.826,3.511,3.323
TriviaQA,5.777,4.979,4.553,4.141,3.916
GPQA,1.733,1.593,1.468,1.336,1.259
Hotpot,15.515,13.479,12.383,11.216,10.606

@@ -0,0 +1,151 @@
import matplotlib
from matplotlib.axes import Axes
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.lines import Line2D
# plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["ytick.direction"] = "in"
plt.rcParams["hatch.linewidth"] = 1
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
plt.rcParams["text.usetex"] = True
plt.rcParams["font.family"] = "sans-serif" # Use generic sans-serif family
plt.rcParams['text.latex.preamble'] = r"""
\usepackage{helvet} % Use Helvetica font for text
\usepackage{sfmath} % Use sans-serif font for math
\renewcommand{\familydefault}{\sfdefault} % Set sans-serif as default text font
\usepackage[T1]{fontenc} % Recommended for font encoding
"""
# plt.rcParams['mathtext.fontset'] = 'dejavusans'
SAVE_PTH = "./paper_plot/figures"
font_size = 16
# New data in dictionary format
datasets = ["NQ", "TriviaQA", "GPQA", "Hotpot"]
cache_ratios = ["4.2G\n (0\%)", "8.7G\n (2.5\%)", "13.2G\n (5\%)", "18.6G\n (8\%)", "22.2G\n (10\%)"]
latency_data = {
"NQ": [4.616, 4.133, 3.826, 3.511, 3.323],
"TriviaQA": [5.777, 4.979, 4.553, 4.141, 3.916],
"GPQA": [1.733, 1.593, 1.468, 1.336, 1.259],
"Hotpot": [15.515, 13.479, 12.383, 11.216, 10.606],
}
cache_hit_counts = {
"NQ": [0, 14.81, 23.36, 31.99, 36.73],
"TriviaQA": [0, 18.55, 27.99, 37.06, 41.86],
"GPQA": [0, 10.99, 20.31, 29.71, 35.01],
"Hotpot": [0, 17.47, 26.91, 36.2, 41.06]
}
# Create the figure with 4 subplots in a 2x2 grid
fig, axes_grid = plt.subplots(2, 2, figsize=(7,6))
axes = axes_grid.flatten() # Flatten the 2x2 grid to a 1D array
# Bar style settings
width = 0.7
x = np.arange(len(cache_ratios))
# Define hatch patterns for different cache ratios
hatch_patterns = ['//', '//', '//', '//', '//']
# Find max cache hit value across all datasets for unified y-axis
all_hit_counts = []
for dataset in datasets:
all_hit_counts.extend(cache_hit_counts[dataset])
max_unified_hit = max(all_hit_counts) * 1.13
for i, dataset in enumerate(datasets):
latencies = latency_data[dataset]
hit_counts = cache_hit_counts[dataset]
for j, val in enumerate(latencies):
container = axes[i].bar(
x[j],
val,
width=width,
color="white",
edgecolor="black",
linewidth=1.0,
zorder=10,
)
axes[i].bar_label(
container,
[f"{val:.2f}"],
fontsize=10,
zorder=200,
fontweight="bold",
)
axes[i].set_title(dataset, fontsize=font_size)
axes[i].set_xticks(x)
axes[i].set_xticklabels(cache_ratios, fontsize=12, rotation=0, ha='center', fontweight="bold")
max_val_ratios = [1.35, 1.65, 1.45, 1.75]
max_val = max(latencies) * max_val_ratios[i]
axes[i].set_ylim(0, max_val)
axes[i].tick_params(axis='y', labelsize=12)
if i % 2 == 0:
axes[i].set_ylabel("Latency (s)", fontsize=font_size)
axes[i].yaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter('%.1f'))
ax2: Axes = axes[i].twinx()
ax2.plot(x, hit_counts,
linestyle='--',
marker='o',
markersize=6,
linewidth=1.5,
color='k',
markerfacecolor='none',
zorder=20)
ax2.set_ylim(0, max_unified_hit)
ax2.tick_params(axis='y', labelsize=12)
if i % 2 == 1:
ax2.set_ylabel(r"Cache Hit (\%)", fontsize=font_size)
for j, val in enumerate(hit_counts):
if val > 0:
ax2.annotate(f"{val:.1f}%",
(x[j], val),
textcoords="offset points",
xytext=(0, 5),
ha='center',
va='bottom',
fontsize=10,
fontweight='bold')
# Create legend for both plots
bar_patch = mpatches.Patch(facecolor='white', edgecolor='black', label='Latency')
line_patch = Line2D([0], [0], color='black', linestyle='--', label='Cache Hit Rate')
# --- MODIFICATION FOR LEGEND AT THE TOP ---
fig.legend(handles=[bar_patch, line_patch],
loc='upper center', # Position the legend at the upper center
           bbox_to_anchor=(0.5, 0.995),  # (0.5, 0.995): horizontal center, just below the top edge
ncol=3,
fontsize=font_size-2)
# --- END OF MODIFICATION ---
# Set common x-axis label - you might want to add this back if needed
# fig.text(0.5, 0.02, "Disk Cache Size", ha='center', fontsize=font_size, fontweight='bold') # Adjusted y for potential bottom label
# --- MODIFICATION FOR TIGHT LAYOUT ---
# Adjust rect to make space for the legend at the top.
# (left, bottom, right, top_for_subplots)
# We want subplots to occupy space from y=0 up to y=0.93 (or similar)
# leaving the top portion (0.93 to 1.0) for the legend.
plt.tight_layout(rect=(0, 0, 1, 0.93)) # Ensure subplots are below the legend
# --- END OF MODIFICATION ---
# Create directory if it doesn't exist (optional, good practice)
import os
if not os.path.exists(SAVE_PTH):
os.makedirs(SAVE_PTH)
plt.savefig(f"{SAVE_PTH}/disk_cache_latency.pdf", dpi=300) # Changed filename slightly for testing
print(f"Save to {SAVE_PTH}/disk_cache_latency.pdf")
# plt.show() # Optional: to display the plot

(Several binary figure files added in this commit; contents not shown.)

@@ -0,0 +1,107 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Motto: Were It to Benefit My Country, I Would Lay Down My Life!
# \file: /gpu_utilization_plot.py
# \brief: Plots GPU throughput vs. batch size to show utilization with equally spaced x-axis.
# Author: AI Assistant
import numpy as np
import pandas as pd # Using pandas for data structuring, similar to example
from matplotlib import pyplot as plt
# Apply styling similar to the example script
plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["ytick.direction"] = "in"
plt.rcParams["xtick.direction"] = "in"
# plt.rcParams["hatch.linewidth"] = 1.5 # Not used for line plots
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
plt.rcParams["text.usetex"] = True # Enables LaTeX for text rendering
# New Benchmark data (4th set)
data = {
'batch_size': [1, 4, 8, 10, 16, 20, 32, 40, 64, 128, 256,],
'avg_time_s': [
0.0031, 0.0057, 0.0100, 0.0114, 0.0186, 0.0234,
0.0359, 0.0422, 0.0626, 0.1259, 0.2454,
],
'throughput_seq_s': [
318.10, 696.77, 798.95, 874.70, 859.58, 855.19,
890.80, 946.93, 1022.75, 1017.03, 1043.17,
]
}
benchmark_df = pd.DataFrame(data)
# Create the plot
# Increased width slightly for more x-axis labels
fig, ax = plt.subplots()
fig.set_size_inches(8, 5)
# Generate equally spaced x-coordinates (indices)
x_indices = np.arange(len(benchmark_df))
# Plotting throughput vs. batch size (using indices for x-axis)
ax.plot(
x_indices, # Use equally spaced indices for plotting
benchmark_df['throughput_seq_s'],
marker='o', # Add markers to data points
linestyle='-',
color="#63B8B6", # A color inspired by the example's 'edgecolors'
linewidth=2,
markersize=6,
# label="Model Throughput" # Label for legend if needed, but not showing legend by default
)
# Setting labels for axes
ax.set_xlabel("Batch Size", fontsize=14)
ax.set_ylabel("Throughput (sequences/second)", fontsize=14)
# Customizing Y-axis for the new data range:
# Start Y at 200; all throughput values lie above it, so the full curve stays visible.
y_min_val = 200
# Round up y_max_val to the nearest 100, as max throughput > 1000
y_max_val = np.ceil(benchmark_df['throughput_seq_s'].max() / 100) * 100
ax.set_ylim((y_min_val, y_max_val))
# Set y-ticks every 100 units, ensuring the top tick is included.
ax.set_yticks(np.arange(y_min_val, y_max_val + 1, 100))
# Customizing X-axis for equally spaced ticks:
# Set tick positions to the indices
ax.set_xticks(x_indices)
# Set tick labels to the actual batch_size values
ax.set_xticklabels(benchmark_df['batch_size'])
ax.tick_params(axis='x', rotation=45, labelsize=10) # Rotate X-axis labels, fontsize 10
ax.tick_params(axis='y', labelsize=12)
# Add a light grid for better readability, common in academic plots
ax.grid(True, linestyle=':', linewidth=0.5, color='grey', alpha=0.7, zorder=0)
# Remove title (as requested)
# ax.set_title("GPU Throughput vs. Batch Size", fontsize=16) # Title would go here
# Optional: Add a legend if you have multiple lines or want to label the single line
# ax.legend(
# loc="center right", # Location might need adjustment due to data shape
# edgecolor="black",
# facecolor="white",
# framealpha=1.0,
# shadow=False,
# fancybox=False,
# prop={"weight": "bold", "size": 10}
# ).set_zorder(100)
# Adjust layout to prevent labels from being cut off
plt.tight_layout()
# Save the figure
output_filename = "./paper_plot/figures/gpu_throughput_vs_batch_size_equispaced.pdf"
plt.savefig(output_filename, bbox_inches="tight", dpi=300)
print(f"Plot saved to {output_filename}")
# Display the plot (optional, depending on environment)
plt.show()

@@ -0,0 +1,245 @@
import argparse
import matplotlib.pyplot as plt
import numpy as np
import os
import matplotlib.ticker as ticker # Import ticker for formatting
# --- Global Academic Style Configuration ---
plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
plt.rcParams["axes.titleweight"] = "bold"
plt.rcParams["ytick.direction"] = "out"
plt.rcParams["xtick.direction"] = "out"
plt.rcParams["axes.grid"] = False # Grid lines are off
plt.rcParams["text.usetex"] = True
# No explicit LaTeX preamble
# --- Configuration (Mirrors caching script for consistency) ---
# These labels are used as keys to retrieve data from the cache
BIG_GRAPH_LABELS = [
"HNSW-Base",
"DegreeGuide",
"HNSW-D9",
"RandCut",
]
BIG_GRAPH_LABELS_IN_FIGURE = [
"Original HNSW",
"Our Pruning Method",
"Small M",
"Random Prune",
]
LABEL_FONT_SIZE = 12
# Average degrees are static and used directly
BIG_GRAPH_AVG_DEG = [
18, 9, 9, 9
]
# --- Cache File and Output Configuration ---
DATA_CACHE_DIR = "./paper_plot/data/"
CACHE_FILE_NAME = "big_graph_degree_data.npz"
OUTPUT_DIR = "./paper_plot/figures/"
os.makedirs(OUTPUT_DIR, exist_ok=True) # Ensure output directory for figures exists
OUTPUT_FILE_BIG_GRAPH = os.path.join(OUTPUT_DIR, "degree_distribution.pdf") # New output name
# Colors for the four histograms
HIST_COLORS = ['slategray', 'tomato','#63B8B6', 'cornflowerblue']
def plot_degree_distributions_from_cache(output_image_path: str):
"""
    Generates a 2x2 combined plot of degree distributions for the BIG_GRAPH set,
loading data from a pre-generated .npz cache file.
"""
cache_file_path = os.path.join(DATA_CACHE_DIR, CACHE_FILE_NAME)
if not os.path.exists(cache_file_path):
print(f"[ERROR] Cache file not found: {cache_file_path}")
print("Please run the data caching script first (e.g., cache_degree_data.py).")
return
try:
# Load the cached data
with np.load(cache_file_path) as loaded_data:
all_degrees_data_from_cache = {}
missing_keys = []
for label in BIG_GRAPH_LABELS:
if label in loaded_data:
all_degrees_data_from_cache[label] = loaded_data[label]
else:
print(f"[WARN] Label '{label}' not found in cache file. Plotting may be incomplete.")
all_degrees_data_from_cache[label] = np.array([], dtype=int) # Use empty array for missing data
missing_keys.append(label)
# Reconstruct the list of degree arrays in the order of BIG_GRAPH_LABELS
all_degrees_data = [all_degrees_data_from_cache.get(label, np.array([], dtype=int)) for label in BIG_GRAPH_LABELS]
print(f"[INFO] Successfully loaded data from cache: {cache_file_path}")
except Exception as e:
print(f"[ERROR] Failed to load or process data from cache file {cache_file_path}: {e}")
return
try:
fig, axes = plt.subplots(2, 2, figsize=(7, 4), sharex=True, sharey=True)
axes = axes.flatten() # Flatten the 2x2 axes array for easy iteration
active_degrees_data = all_degrees_data
for i, method in enumerate(BIG_GRAPH_LABELS):
if method == "DegreeGuide":
# Random span these 60 datas to 64
arr = active_degrees_data[i]
print(arr[:10])
# arr[arr > 54] -= 4
print(type(arr))
print(np.max(arr))
arr2 = arr * 60 / 64
# print(np.max(arr2))
# active_degrees_data[i] = arr2
# between_45_46 = arr2[arr2 >= 45]
# between_45_46 = between_45_46[between_45_46 < 46]
# print(len(between_45_46))
# remove all 15*n
# 诶为什么最右边那个变低了
# 原因就是
# 你数据里面的所有数字都是整数
# 所以你这个除以64*60之后有一些相邻整数
# arr2
active_degrees_data[i] = arr2
# wei shen me dou shi 15 d bei shu
# ying gai bu shi
        if not any(d.size > 0 for d in active_degrees_data):
print("[ERROR] No valid degree data loaded from cache. Cannot generate plot.")
if 'fig' in locals() and plt.fignum_exists(fig.number):
plt.close(fig)
return
        overall_min_deg = min(np.min(d) for d in active_degrees_data if d.size > 0)
        overall_max_deg = max(np.max(d) for d in active_degrees_data if d.size > 0)
        # Pad the degree range by half a unit on each side so edge bins are not clipped.
        overall_min_deg = np.floor(overall_min_deg - 0.5)
        overall_max_deg = np.ceil(overall_max_deg + 0.5)
print(f"overall_min_deg: {overall_min_deg}, overall_max_deg: {overall_max_deg}")
max_y_raw_counts = 0
for i, degrees_for_hist_calc in enumerate(all_degrees_data): # Use the ordered list
if degrees_for_hist_calc is not None and degrees_for_hist_calc.size > 0:
min_deg_local = np.min(degrees_for_hist_calc)
max_deg_local = np.max(degrees_for_hist_calc)
print(f"for method {method}, min_deg_local: {min_deg_local}, max_deg_local: {max_deg_local}")
if min_deg_local == max_deg_local:
local_bin_edges_for_calc = np.array([np.floor(min_deg_local - 0.5), np.ceil(max_deg_local + 0.5)])
else:
num_local_bins_for_calc = int(np.ceil(max_deg_local + 0.5) - np.floor(min_deg_local - 0.5))
local_bin_edges_for_calc = np.linspace(np.floor(min_deg_local - 0.5),
np.ceil(max_deg_local + 0.5),
num_local_bins_for_calc + 1)
if i == 1:
unique_data = np.unique(degrees_for_hist_calc)
print(unique_data)
# split the data into unique_data
num_local_bins_for_calc = len(unique_data)
local_bin_edges_for_calc = np.concatenate([unique_data-0.1, [np.inf]])
counts, _ = np.histogram(degrees_for_hist_calc, bins=local_bin_edges_for_calc)
if counts.size > 0:
max_y_raw_counts = max(max_y_raw_counts, np.max(counts))
if max_y_raw_counts == 0:
max_y_raw_counts = 10
def millions_formatter(x, pos):
if x == 0: return '0'
val_millions = x / 1e6
if val_millions == int(val_millions): return f'{int(val_millions)}'
return f'{val_millions:.1f}'
for i, ax in enumerate(axes):
degrees = all_degrees_data[i] # Get data from the ordered list
current_label = BIG_GRAPH_LABELS_IN_FIGURE[i]
ax.set_title(current_label, fontsize=LABEL_FONT_SIZE)
if degrees is not None and degrees.size > 0:
min_deg_local_plot = np.min(degrees)
max_deg_local_plot = np.max(degrees)
if min_deg_local_plot == max_deg_local_plot:
plot_bin_edges = np.array([np.floor(min_deg_local_plot - 0.5), np.ceil(max_deg_local_plot + 0.5)])
else:
num_plot_bins = int(np.ceil(max_deg_local_plot + 0.5) - np.floor(min_deg_local_plot - 0.5))
plot_bin_edges = np.linspace(np.floor(min_deg_local_plot - 0.5),
np.ceil(max_deg_local_plot + 0.5),
num_plot_bins + 1)
if i == 1:
unique_data = np.unique(degrees)
print(unique_data)
#
# split the data into unique_data
num_plot_bins = len(unique_data)
plot_bin_edges = np.concatenate([unique_data-0.1, [unique_data[-1] + 0.8375]])
ax.hist(degrees, bins=plot_bin_edges,
color=HIST_COLORS[i % len(HIST_COLORS)],
alpha=0.85)
avg_deg_val = BIG_GRAPH_AVG_DEG[i]
ax.text(0.95, 0.88, f"Avg Degree: {avg_deg_val}",
transform=ax.transAxes, fontsize=15,
verticalalignment='top', horizontalalignment='right',
bbox=dict(facecolor='white', alpha=0.6, edgecolor='none', pad=0.3))
else:
ax.text(0.5, 0.5, 'Data unavailable', horizontalalignment='center',
verticalalignment='center', transform=ax.transAxes, fontsize=9)
ax.set_xlim(0, overall_max_deg)
            ax.set_yscale('log')
            ax.set_ylim(0.8, max_y_raw_counts * 1.12)  # log scale needs a positive lower bound
for spine_pos in ['top', 'right', 'bottom', 'left']:
ax.spines[spine_pos].set_edgecolor('black')
ax.spines[spine_pos].set_linewidth(1.0)
# ax.spines['top'].set_visible(False)
# ax.spines['right'].set_visible(False)
ax.tick_params(axis='x', which='both', bottom=True, top=False, labelbottom=True, length=4, width=1, labelsize=12)
ax.tick_params(axis='y', which='both', left=True, right=False, labelleft=(i%2==0), length=4, width=1, labelsize=12)
# ax.yaxis.set_major_formatter(ticker.FuncFormatter(lambda x, pos: millions_formatter(x, pos)))
ax.xaxis.set_major_formatter(ticker.ScalarFormatter())
ax.ticklabel_format(style='plain', axis='x', useOffset=False)
axes[0].set_ylabel(r"Number of Nodes", fontsize=12)
axes[2].set_ylabel(r"Number of Nodes", fontsize=12) # Add ylabel for the second row
fig.text(0.54, 0.02, "Node Degree", ha='center', va='bottom', fontsize=15)
plt.tight_layout(rect=(0.06, 0.05, 0.98, 0.88))
plt.savefig(output_image_path, dpi=300, bbox_inches='tight', pad_inches=0.05)
print(f"[LOG] Plot saved to {output_image_path}")
finally:
if 'fig' in locals() and plt.fignum_exists(fig.number):
plt.close(fig)
if __name__ == "__main__":
if plt.rcParams["text.usetex"]:
print("INFO: LaTeX rendering is enabled via rcParams.")
else:
print("INFO: LaTeX rendering is disabled (text.usetex=False).")
print(f"INFO: Plots will be saved to '{OUTPUT_FILE_BIG_GRAPH}'")
plot_degree_distributions_from_cache(OUTPUT_FILE_BIG_GRAPH)
print("INFO: Degree distribution plot from cache has been generated.")

@@ -0,0 +1,330 @@
# python faiss/demo/plot_graph_struct.py faiss/demo/output.log
# python faiss/demo/plot_graph_struct.py large_graph_recompute.log
import argparse
import ast
import re
import matplotlib.pyplot as plt
import numpy as np
# Modified recall_levels and corresponding styles/widths from previous step
recall_levels = [0.90, 0.92, 0.94, 0.96]
line_styles = ['--', '-', '-', '-']
line_widths = [1, 1.5, 1.5, 1.5]
MAPPED_METHOD_NAMES = [
# 'HNSW-Base',
# 'DegreeGuide',
# 'HNSW-D9',
# 'RandCut',
"Original HNSW",
"Our Pruning Method",
"Small M",
"Random Prune",
]
PERFORMANCE_PLOT_PATH = './paper_plot/figures/H_hnsw_performance_comparison.pdf'
SAVED_PATH = './paper_plot/figures/H_hnsw_recall_comparison.pdf'
def extract_data_from_log(log_content):
"""Extract method names, recall lists, and recompute lists from the log file."""
method_pattern = r"Building HNSW index with ([^\.]+)\.\.\.|Building HNSW index with ([^\n]+)..."
recall_list_pattern = r"recall_list: (\[[\d\., ]+\])"
recompute_list_pattern = r"recompute_list: (\[[\d\., ]+\])"
avg_neighbors_pattern = r"neighbors per node: ([\d\.]+)"
method_matches = re.findall(method_pattern, log_content)
# Temporary list for raw method identifiers from regex
_methods_raw_identifiers_regex = []
for match in method_matches:
method_ident = match[0] if match[0] else match[1]
_methods_raw_identifiers_regex.append(method_ident.strip().rstrip('.'))
recall_lists_str = re.findall(recall_list_pattern, log_content)
recompute_lists_str = re.findall(recompute_list_pattern, log_content)
avg_neighbors_str_list = re.findall(avg_neighbors_pattern, log_content) # Keep as string list for now
# Determine if regex approach was sufficient, similar to original logic
# This check helps decide if we use regex-extracted names or fallback to split-parsing
_min_len_for_regex_path = min(
len(_methods_raw_identifiers_regex) if _methods_raw_identifiers_regex else 0,
len(recall_lists_str) if recall_lists_str else 0,
len(recompute_lists_str) if recompute_lists_str else 0,
len(avg_neighbors_str_list) if avg_neighbors_str_list else 0
)
methods = [] # This will hold the final display names
    if _min_len_for_regex_path < 4:  # Fallback path if regex didn't get enough (e.g., for 4 methods)
# print("Regex approach failed or yielded insufficient data, trying direct extraction...")
sections = log_content.split("Building HNSW index with ")[1:]
methods_temp = []
for section in sections:
method_name_raw = section.split("\n")[0].strip().rstrip('.')
# Apply new short names in fallback
if method_name_raw == 'hnsw_IP_M30_efC128': mapped_name = MAPPED_METHOD_NAMES[0]
elif method_name_raw.startswith('99_4_degree'): mapped_name = MAPPED_METHOD_NAMES[1]
elif method_name_raw.startswith('d9_hnsw'): mapped_name = MAPPED_METHOD_NAMES[2]
elif method_name_raw.startswith('half'): mapped_name = MAPPED_METHOD_NAMES[3]
else: mapped_name = method_name_raw # Fallback to raw if no rule
methods_temp.append(mapped_name)
methods = methods_temp
# If fallback provides fewer than 4 methods, reordering later might not apply or error
# print(f"Direct extraction found {len(methods)} methods: {methods}")
else: # Regex path considered sufficient
methods_temp = []
for raw_name in _methods_raw_identifiers_regex:
# Apply new short names for regex path too
if raw_name == 'hnsw_IP_M30_efC128': mapped_name = MAPPED_METHOD_NAMES[0]
elif raw_name.startswith('99_4_degree'): mapped_name = MAPPED_METHOD_NAMES[1]
elif raw_name.startswith('d9_hnsw'): mapped_name = MAPPED_METHOD_NAMES[2]
elif raw_name.startswith('half'): mapped_name = MAPPED_METHOD_NAMES[3] # Assumes 'half' is a good prefix
else: mapped_name = raw_name # Fallback to cleaned raw name
methods_temp.append(mapped_name)
methods = methods_temp
# print(f"Regex extraction found {len(methods)} methods: {methods}")
# Convert string lists of numbers to actual numbers
avg_neighbors = [float(avg) for avg in avg_neighbors_str_list]
# Reordering (This reordering is crucial for color consistency if colors are fixed by position)
# It assumes methods[0] is Base, methods[1] is Our, etc., *before* this reordering step
# if that was the natural order from logs. The reordering swaps 3rd and 4th items.
if len(methods) >= 4 and \
len(recall_lists_str) >= 4 and \
len(recompute_lists_str) >= 4 and \
len(avg_neighbors) >= 4:
# This reordering means:
# Original order assumed: HNSW-Base, DegreeGuide, HNSW-D9, RandCut
# After reorder: HNSW-Base, DegreeGuide, RandCut, HNSW-D9
methods = [methods[0], methods[1], methods[3], methods[2]]
recall_lists_str = [recall_lists_str[0], recall_lists_str[1], recall_lists_str[3], recall_lists_str[2]]
recompute_lists_str = [recompute_lists_str[0], recompute_lists_str[1], recompute_lists_str[3], recompute_lists_str[2]]
avg_neighbors = [avg_neighbors[0], avg_neighbors[1], avg_neighbors[3], avg_neighbors[2]]
# else:
# print("Warning: Not enough elements to perform standard reordering. Using data as found.")
# Hardcoded override for one known log state. Note the string comparison uses
# avg_neighbors_str_list, which is still in its original parse order, while
# avg_neighbors itself may already have been reordered above.
if len(avg_neighbors) > 0 and avg_neighbors_str_list[0] == "17.35":
target_avg_neighbors = [18, 9, 9, 9] # rounded display values for that specific log
current_len = len(avg_neighbors)
avg_neighbors = target_avg_neighbors[:current_len]
# Parse the stringified lists safely; ast.literal_eval avoids the arbitrary
# code execution risk that eval() carries on log-derived text.
import ast
recall_lists = [ast.literal_eval(recall_list) for recall_list in recall_lists_str]
recompute_lists = [ast.literal_eval(recompute_list) for recompute_list in recompute_lists_str]
# Final truncation to ensure all lists have the same minimum length
min_length = min(len(methods), len(recall_lists), len(recompute_lists), len(avg_neighbors))
methods = methods[:min_length]
recall_lists = recall_lists[:min_length]
recompute_lists = recompute_lists[:min_length]
avg_neighbors = avg_neighbors[:min_length]
return methods, recall_lists, recompute_lists, avg_neighbors
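# Minimal smoke test for extract_data_from_log on a synthetic log snippet. The
# log-line format here is an assumption inferred from the regex patterns above,
# and the helper is defined but never invoked during a normal plotting run.
def _demo_extract_data_from_log():
    demo_log = (
        "Building HNSW index with hnsw_IP_M30_efC128...\n"
        "recall_list: [0.90, 0.95]\n"
        "recompute_list: [12000.0, 34000.0]\n"
        "Average neighbors per node: 17.9\n"
    )
    # With a single method, the split-based fallback path is taken; the result
    # should be ([<mapped name>], [[0.9, 0.95]], [[12000.0, 34000.0]], [17.9]).
    return extract_data_from_log(demo_log)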
def plot_recall_comparison(methods, recall_lists, recompute_lists, avg_neighbors, current_recall_levels):
"""Create a line chart comparing computation costs at different recall levels, with academic style."""
plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["ytick.direction"] = "in"
# plt.rcParams["hatch.linewidth"] = 1.5 # From example, but not used in line plot
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
plt.rcParams["text.usetex"] = True # Ensure LaTeX is available or set to False
computation_costs = []
for i, method_name in enumerate(methods): # methods now contains short names
method_costs = []
for level in current_recall_levels:
recall_idx = next((idx for idx, recall in enumerate(recall_lists[i]) if recall >= level), None)
if recall_idx is not None:
method_costs.append(recompute_lists[i][recall_idx])
else:
method_costs.append(None)
computation_costs.append(method_costs)
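# Worked micro-example of the lookup above (illustrative values): with
# recall_lists[i] = [0.85, 0.93, 0.97] and level = 0.92, recall_idx = 1 (the
# first recall >= 0.92), so recompute_lists[i][1] is recorded as the cost.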
fig, ax = plt.subplots(figsize=(5,2.5))
# Colors follow the plotting order established by the reordering above:
# Original HNSW (slategray), Our (tomato), Random Prune (cornflowerblue), Small M (teal)
academic_colors = ['slategray', 'tomato', 'cornflowerblue', '#63B8B6']
markers = ['o', '*', '^', 'D', 'v', 'P']
for i, method_name in enumerate(methods): # method_name is now short, e.g., 'HNSW-Base'
color_idx = i % len(academic_colors)
marker_idx = i % len(markers)
# Scale costs to units of 10^4 recomputations; None (recall level never reached) becomes NaN
y_values_plot = [val / 10000 if val is not None else np.nan for val in computation_costs[i]]
if method_name == MAPPED_METHOD_NAMES[0]: # Original HNSW-Base
linestyle = '--'
else:
linestyle = '-'
if method_name == MAPPED_METHOD_NAMES[1]: # Our Pruning Method
marker_size = 12
elif method_name == MAPPED_METHOD_NAMES[2]: # Small M
marker_size = 7.5
else:
marker_size = 8
if method_name == MAPPED_METHOD_NAMES[1]: # Our Pruning Method
zorder = 10
else:
zorder = 1
# Nudge the first points slightly so overlapping curves stay distinguishable
if method_name == MAPPED_METHOD_NAMES[3]: # Random Prune
y_values_plot[0] += 0.12 # avoid overlap with our method
elif method_name == MAPPED_METHOD_NAMES[1]: # Our Pruning Method
y_values_plot[0] -= 0.06 # avoid overlap with original HNSW
ax.plot(current_recall_levels, y_values_plot,
label=f"{method_name} (Avg Degree: {int(avg_neighbors[i])})", # Uses new short names
color=academic_colors[color_idx], marker=markers[marker_idx], markeredgecolor='#FFFFFF80', # semi-transparent white marker edge
markersize=marker_size, linewidth=2, linestyle=linestyle, zorder=zorder)
ax.set_xlabel('Recall Target', fontsize=9, fontweight="bold")
ax.set_xticks(current_recall_levels)
ax.set_xticklabels([fr'{level*100:.0f}\%' for level in current_recall_levels], fontsize=10)
ax.tick_params(axis='y', labelsize=10)
ax.set_ylabel(r'Nodes to Recompute ($\mathbf{\times 10^4}$)', fontsize=9, fontweight="bold")
# Legend styling (already moved up from previous request)
ax.legend(loc='lower center', bbox_to_anchor=(0.5, 1.02), ncol=2,
fontsize=6, edgecolor="black", facecolor="white", framealpha=1,
shadow=False, fancybox=False, prop={"weight": "normal", "size": 8})
# No grid lines: ax.grid(True, linestyle='--', alpha=0.7)
# Spines adjustment for academic look
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['left'].set_linewidth(1.0)
ax.spines['bottom'].set_linewidth(1.0)
annot_recall_level_92 = 0.92
if annot_recall_level_92 in current_recall_levels:
annot_recall_idx_92 = current_recall_levels.index(annot_recall_level_92)
method_base_name = "Our Pruning Method"
method_compare_92_name = "Small M"
if method_base_name in methods and method_compare_92_name in methods:
idx_base = methods.index(method_base_name)
idx_compare_92 = methods.index(method_compare_92_name)
cost_base_92 = computation_costs[idx_base][annot_recall_idx_92]
cost_compare_92 = computation_costs[idx_compare_92][annot_recall_idx_92]
# Check for None (recall level never reached) before dividing
if cost_base_92 is not None and cost_compare_92 is not None and cost_base_92 > 0:
cost_base_92 /= 10000
cost_compare_92 /= 10000
ratio_92 = cost_compare_92 / cost_base_92
ax.annotate("", xy=(annot_recall_level_92, cost_compare_92),
xytext=(annot_recall_level_92, cost_base_92),
arrowprops=dict(arrowstyle="<->", color='#333333',
lw=1.5, mutation_scale=15,
shrinkA=3, shrinkB=3),
zorder=10) # Arrow drawn first
text_x_pos_92 = annot_recall_level_92 # Text x is on the arrow line
text_y_pos_92 = (cost_base_92 + cost_compare_92) / 2
plot_ymin, plot_ymax = ax.get_ylim() # Boundary checks
if text_y_pos_92 < plot_ymin + (plot_ymax-plot_ymin)*0.05: text_y_pos_92 = plot_ymin + (plot_ymax-plot_ymin)*0.05
if text_y_pos_92 > plot_ymax - (plot_ymax-plot_ymin)*0.05: text_y_pos_92 = plot_ymax - (plot_ymax-plot_ymin)*0.05
ax.text(text_x_pos_92, text_y_pos_92, f"{ratio_92:.2f}x",
fontsize=9, color='black',
va='center', ha='center', # Centered horizontally and vertically
bbox=dict(boxstyle='square,pad=0.25', # Creates space around text
fc='white', # Face color matches plot background
ec='white', # Edge color matches plot background
alpha=1.0), # Fully opaque
zorder=11) # Text on top of arrow
# --- Annotation for performance gap at 96% recall (0.96) ---
annot_recall_level_96 = 0.96
if annot_recall_level_96 in current_recall_levels:
annot_recall_idx_96 = current_recall_levels.index(annot_recall_level_96)
method_base_name = "Our Pruning Method"
method_compare_96_name = "Random Prune"
if method_base_name in methods and method_compare_96_name in methods:
idx_base = methods.index(method_base_name)
idx_compare_96 = methods.index(method_compare_96_name)
cost_base_96 = computation_costs[idx_base][annot_recall_idx_96]
cost_compare_96 = computation_costs[idx_compare_96][annot_recall_idx_96]
# Check for None (recall level never reached) before dividing
if cost_base_96 is not None and cost_compare_96 is not None and cost_base_96 > 0:
cost_base_96 /= 10000
cost_compare_96 /= 10000
ratio_96 = cost_compare_96 / cost_base_96
ax.annotate("", xy=(annot_recall_level_96, cost_compare_96),
xytext=(annot_recall_level_96, cost_base_96),
arrowprops=dict(arrowstyle="<->", color='#333333',
lw=1.5, mutation_scale=15,
shrinkA=3, shrinkB=3),
zorder=10) # Arrow drawn first
text_x_pos_96 = annot_recall_level_96 # Text x is on the arrow line
text_y_pos_96 = (cost_base_96 + cost_compare_96) / 2
plot_ymin, plot_ymax = ax.get_ylim() # Boundary checks
if text_y_pos_96 < plot_ymin + (plot_ymax-plot_ymin)*0.05: text_y_pos_96 = plot_ymin + (plot_ymax-plot_ymin)*0.05
if text_y_pos_96 > plot_ymax - (plot_ymax-plot_ymin)*0.05: text_y_pos_96 = plot_ymax - (plot_ymax-plot_ymin)*0.05
ax.text(text_x_pos_96, text_y_pos_96, f"{ratio_96:.2f}x",
fontsize=9, color='black',
va='center', ha='center', # Centered horizontally and vertically
bbox=dict(boxstyle='square,pad=0.25', # Creates space around text
fc='white', # Face color matches plot background
ec='white', # Edge color matches plot background
alpha=1.0), # Fully opaque
zorder=11) # Text on top of arrow
plt.tight_layout(pad=0.5)
plt.savefig(SAVED_PATH, bbox_inches="tight", dpi=300)
plt.show()
# --- Main script execution ---
parser = argparse.ArgumentParser()
parser.add_argument("log_file", type=str, default="./demo/output.log")
args = parser.parse_args()
try:
with open(args.log_file, 'r') as f:
log_content = f.read()
except FileNotFoundError:
print(f"Error: Log file '{args.log_file}' not found.")
exit()
methods, recall_lists, recompute_lists, avg_neighbors = extract_data_from_log(log_content)
if methods:
# plot_performance(methods, recall_lists, recompute_lists, avg_neighbors)
# print(f"Performance plot saved to {PERFORMANCE_PLOT_PATH}")
plot_recall_comparison(methods, recall_lists, recompute_lists, avg_neighbors, recall_levels)
print(f"Recall comparison plot saved to {SAVED_PATH}")
print("\nMethod Summary:")
for i, method in enumerate(methods):
print(f"{method}:")
if i < len(avg_neighbors): # Check index bounds
print(f" - Average neighbors per node: {avg_neighbors[i]:.2f}")
for level in recall_levels:
if i < len(recall_lists) and i < len(recompute_lists): # Check index bounds
recall_idx = next((idx for idx, recall_val in enumerate(recall_lists[i]) if recall_val >= level), None)
if recall_idx is not None:
print(f" - Computations needed for {level*100:.0f}% recall: {recompute_lists[i][recall_idx]:.0f}")
else:
print(f" - Does not reach {level*100:.0f}% recall in the test")
else:
print(f" - Data missing for recall/recompute lists for method {method}")
print()
else:
print("No data extracted from the log file. Cannot generate plots or summary.")

View File

@@ -0,0 +1,441 @@
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.lines as mlines
import pandas as pd
import numpy as np
from matplotlib.patches import FancyArrowPatch
sns.set_theme(style="ticks", font_scale=1.2)
plt.rcParams['axes.grid'] = True
plt.rcParams['axes.grid.which'] = 'major'
plt.rcParams['grid.linestyle'] = '--'
plt.rcParams['grid.color'] = 'gray'
plt.rcParams['grid.alpha'] = 0.3
plt.rcParams['xtick.minor.visible'] = False
plt.rcParams['ytick.minor.visible'] = False
plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["text.usetex"] = True
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
# Generation(LLama 1B) Generation(LLama 3B) Generation(LLama 7B)
# 0.085s 0.217s 0.472s
# llm_inference_time=[0.085, 0.217, 0.472, 0] # Will be replaced by CSV data
# llm_inference_time_for_mac = [0.316, 0.717, 1.468, 0] # Will be replaced by CSV data
def parse_latency_data(csv_path):
df = pd.read_csv(csv_path)
latency_data = {}
llm_gen_times = {} # To store LLM generation times: (dataset, hardware) -> time
for _, row in df.iterrows():
dataset = row['Dataset']
hardware = row['Hardware']
recall_target_str = row['Recall_target'].replace('%', '')
try:
recall_target = float(recall_target_str)
except ValueError:
print(f"Warning: Could not parse recall_target '{row['Recall_target']}'. Skipping row.")
continue
if (dataset, hardware) not in llm_gen_times: # Read once per (dataset, hardware)
llm_time_val = pd.to_numeric(row.get('LLM_Gen_Time_1B'), errors='coerce')
if not pd.isna(llm_time_val):
llm_gen_times[(dataset, hardware)] = llm_time_val
else:
llm_gen_times[(dataset, hardware)] = np.nan # Store NaN if unparsable/missing
cols_to_skip = ['Dataset', 'Hardware', 'Recall_target',
'LLM_Gen_Time_1B', 'LLM_Gen_Time_3B', 'LLM_Gen_Time_7B']
for col in df.columns:
if col not in cols_to_skip:
method_name = col
key = (dataset, hardware, method_name)
if key not in latency_data:
latency_data[key] = []
try:
latency_value = float(row[method_name])
latency_data[key].append((recall_target, latency_value))
except ValueError:
# Handle cases where latency might be non-numeric (e.g., 'N/A' or empty)
print(f"Warning: Could not parse latency for {method_name} at {dataset}/{hardware}/Recall {recall_target} ('{row[method_name]}'). Skipping this point.")
latency_data[key].append((recall_target, np.nan)) # Or skip appending
# Sort by recall for consistent plotting
for key in latency_data:
latency_data[key].sort(key=lambda x: x[0])
return latency_data, llm_gen_times
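# For reference, parse_latency_data assumes main_latency.csv is shaped roughly
# as below (column names are taken from the code above; the numeric values are
# purely illustrative):
#
#   Dataset,Hardware,Recall_target,HNSW,IVF,...,LLM_Gen_Time_1B,LLM_Gen_Time_3B,LLM_Gen_Time_7B
#   NQ,A10,85%,0.21,0.35,...,0.085,0.217,0.472
#
# which would yield latency_data[('NQ', 'A10', 'HNSW')] == [(85.0, 0.21), ...]
# and llm_gen_times[('NQ', 'A10')] == 0.085.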
def parse_storage_data(csv_path):
df = pd.read_csv(csv_path)
storage_data = {}
# Assuming the first column is 'MetricType' (RAM/Storage) and subsequent columns are methods
# And the header row is like: MetricType, Method1, Method2, ...
# Transpose to make methods as rows for easier lookup might be an option,
# but let's try direct parsing.
# Find the row for RAM and Storage
ram_row = df[df.iloc[:, 0] == 'RAM'].iloc[0]
storage_row = df[df.iloc[:, 0] == 'Storage'].iloc[0]
methods = df.columns[1:] # First column is the metric type label
for method in methods:
storage_data[method] = {
'RAM': pd.to_numeric(ram_row[method], errors='coerce'),
'Storage': pd.to_numeric(storage_row[method], errors='coerce')
}
return storage_data
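# Assumed ram_storage.csv layout (inferred from the row/column lookups above;
# the numbers are illustrative only):
#
#   Metric,HNSW,IVF,...,Our
#   RAM,12.0,8.5,...,1.2
#   Storage,30.0,25.0,...,5.0
#
# giving storage_data['HNSW'] == {'RAM': 12.0, 'Storage': 30.0}.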
# Load data
latency_csv_path = 'paper_plot/data/main_latency.csv'
storage_csv_path = 'paper_plot/data/ram_storage.csv'
latency_data, llm_generation_times = parse_latency_data(latency_csv_path)
storage_info = parse_storage_data(storage_csv_path)
# --- Determine unique Datasets and Hardware combinations to plot for ---
unique_dataset_hardware_configs = sorted(list(set((d, h) for d, h, m in latency_data.keys())))
if not unique_dataset_hardware_configs:
print("Error: No (Dataset, Hardware) combinations found in latency data. Check CSV paths and content.")
exit()
# --- Define constants for plotting ---
all_method_names = sorted(list(set(m for d,h,m in latency_data.keys())))
if not all_method_names:
# Fallback if latency_data is empty but storage_info might have method names
all_method_names = sorted(list(storage_info.keys()))
if not all_method_names:
print("Error: No method names found in data. Cannot proceed with plotting.")
exit()
method_markers = {
'HNSW': 'o',
'IVF': 'X',
'DiskANN': 's',
'IVF-Disk': 'P',
'IVF-Recompute': '^',
'Our': '*',
'BM25': "v"
# Add more if necessary, or make it dynamic
}
method_display_names = {
'IVF-Recompute': 'IVF-Recompute (EdgeRAG)',
# Other methods keep their original display names
}
# Ensure all methods have a marker
default_markers = ['^', 'v', '<', '>', 'H', 'h', '+', 'x', '|', '_']
next_default_marker = 0
for mn in all_method_names:
if mn not in method_markers:
print(f"mn: {mn}")
method_markers[mn] = default_markers[next_default_marker % len(default_markers)]
next_default_marker +=1
recall_levels_present = sorted(list(set(r for key in latency_data for r, l in latency_data[key])))
# Define colors for up to a few common recall levels, add more if needed
base_recall_colors = {
85.0: "#1f77b4", # Blue
90.0: "#ff7f0e", # Orange
95.0: "#2ca02c", # Green
# Add more if other recall % values exist
}
recall_colors = {}
color_palette = sns.color_palette("viridis", n_colors=len(recall_levels_present))
for idx, r_level in enumerate(recall_levels_present):
recall_colors[r_level] = base_recall_colors.get(r_level, color_palette[idx % len(color_palette)])
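# e.g. a recall level of 92.5 (absent from base_recall_colors) would fall back
# to a viridis palette colour, while 85/90/95 keep the fixed colours above.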
# --- Determine global x (latency) and y (storage) limits for consistent axes ---
all_latency_values = []
all_storage_values = []
raw_data_size = 76 # Raw data size in GB
for ds_hw_key in unique_dataset_hardware_configs:
current_ds, current_hw = ds_hw_key
for method_name in all_method_names:
# Get storage for this method
disk_storage = storage_info.get(method_name, {}).get('Storage', np.nan)
if not np.isnan(disk_storage):
all_storage_values.append(disk_storage)
# Get latencies for this method under current_ds, current_hw
latency_key = (current_ds, current_hw, method_name)
if latency_key in latency_data:
for recall, latency in latency_data[latency_key]:
if not np.isnan(latency):
all_latency_values.append(latency)
# Add padding to limits
min_lat = min(all_latency_values) if all_latency_values else 0.001
max_lat = max(all_latency_values) if all_latency_values else 1
min_store = min(all_storage_values) if all_storage_values else 0
max_store = max(all_storage_values) if all_storage_values else 1
# Convert storage values to proportion of raw data
min_store_proportion = min_store / raw_data_size if all_storage_values else 0
max_store_proportion = max_store / raw_data_size if all_storage_values else 0.1
# Padding for log scale latency - adjust minimum to be more reasonable
lat_log_min = -1 # Changed from -2 to -1 to set minimum to 10^-1 (0.1s)
lat_log_max = np.log10(max_lat) if max_lat > 0 else 3 # default to 1000 s
lat_padding = (lat_log_max - lat_log_min) * 0.05
global_xlim = [10**(lat_log_min - lat_padding), 10**(lat_log_max + lat_padding)]
if global_xlim[0] <= 0: global_xlim[0] = 0.1 # Changed from 0.01 to 0.1
# Padding for linear scale storage proportion
store_padding = (max_store_proportion - min_store_proportion) * 0.05
global_ylim = [max(0, min_store_proportion - store_padding), max_store_proportion + store_padding]
if global_ylim[0] >= global_ylim[1]: # Avoid inverted or zero range
global_ylim[1] = global_ylim[0] + 0.1
# Order the datasets explicitly so subplots appear as NQ, TriviaQA, GPQA, HotpotQA
all_datasets_unsorted = list(set(ds for ds, _ in unique_dataset_hardware_configs))
desired_order = ['NQ', 'TriviaQA', 'GPQA', 'HotpotQA']
all_datasets = [ds for ds in desired_order if ds in all_datasets_unsorted]
# Append any datasets present in the data but missing from desired_order
all_datasets.extend([ds for ds in all_datasets_unsorted if ds not in desired_order])
a10_configs = [(ds, 'A10') for ds in all_datasets if (ds, 'A10') in unique_dataset_hardware_configs]
mac_configs = [(ds, 'MAC') for ds in all_datasets if (ds, 'MAC') in unique_dataset_hardware_configs]
# Create two figures - one for A10 and one for MAC
hardware_configs = [a10_configs, mac_configs]
hardware_names = ['A10', 'MAC']
for fig_idx, configs_for_this_figure in enumerate(hardware_configs):
if not configs_for_this_figure:
continue
num_cols_this_figure = len(configs_for_this_figure)
# 1 row, num_cols_this_figure columns
fig, axs = plt.subplots(1, num_cols_this_figure, figsize=(7 * num_cols_this_figure, 6), sharex=True, sharey=True, squeeze=False)
# fig.suptitle(f"Latency vs. Storage ({hardware_names[fig_idx]})", fontsize=18, y=0.98)
for subplot_idx, (current_ds, current_hw) in enumerate(configs_for_this_figure):
ax = axs[0, subplot_idx] # Accessing column in the first row
ax.set_title(f"{current_ds}", fontsize=25) # No need to show hardware in title since it's in suptitle
for method_name in all_method_names:
marker = method_markers.get(method_name, '+')
disk_storage = storage_info.get(method_name, {}).get('Storage', np.nan)
latency_points_key = (current_ds, current_hw, method_name)
if latency_points_key in latency_data:
points_for_method = latency_data[latency_points_key]
print(f"points_for_method: {points_for_method}")
for recall, latency in points_for_method:
# Only skip if latency is invalid (since we need log scale for x-axis)
# But allow zero storage since y-axis is now linear
if np.isnan(latency) or np.isnan(disk_storage) or latency <= 0:
continue
# Add LLM generation time from CSV
current_llm_add_time = llm_generation_times.get((current_ds, current_hw))
if current_llm_add_time is not None and not np.isnan(current_llm_add_time):
latency = latency + current_llm_add_time
else:
raise ValueError(f"No LLM generation time found for {current_ds} on {current_hw}")
# Special handling for BM25
if method_name == 'BM25':
# BM25 is only valid for 85% recall points (other points are 0)
if recall != 85.0:
continue
color = 'grey'
else:
# Use the color for target recall
color = recall_colors.get(recall, 'grey')
# Convert storage to proportion
disk_storage_proportion = disk_storage / raw_data_size
size = 80
x_offset = -50
if current_ds == 'GPQA':
x_offset = -32
# Apply a small vertical offset to IVF-Recompute points to make them more visible
if method_name == 'IVF-Recompute':
# Add a small vertical offset so overlapping points stay visible
disk_storage_proportion += 0.07
size = 80
if method_name == 'DiskANN':
size = 50
if method_name == 'Our':
size = 140
disk_storage_proportion += 0.05
# Add "Pareto Frontier" label to Our method points
if recall == 95:
ax.annotate('Ours',
(latency, disk_storage_proportion),
xytext=(x_offset, 25), # offset the label to the left of and above the point
textcoords='offset points',
fontsize=20,
color='red',
weight='bold',
bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="red", alpha=0.7))
# BM25 uses a slightly smaller base size
if method_name == 'BM25':
size = 70
size *= 5 # scale marker sizes up to suit the large-format subplots
ax.scatter(latency, disk_storage_proportion, marker=marker, color=color,
s=size, alpha=0.85, edgecolors='black', linewidths=0.7)
ax.set_xscale("log")
ax.set_yscale("linear") # CHANGED from log scale to linear scale for Y-axis
# Generate appropriate powers of 10 based on your data range
min_power = -1
max_power = 4
log_ticks = [10**i for i in range(min_power, max_power+1)]
# Set custom tick positions
ax.set_xticks(log_ticks)
# Create custom bold LaTeX labels with 10^n format
log_tick_labels = [fr'$\mathbf{{10^{{{i}}}}}$' for i in range(min_power, max_power+1)]
ax.set_xticklabels(log_tick_labels, fontsize=24)
# Apply global limits
if subplot_idx == 0:
ax.set_xlim(global_xlim)
ax.set_ylim(global_ylim)
ax.grid(True, which="major", linestyle="--", linewidth=0.6, alpha=0.7)
# Remove minor grid lines completely
ax.grid(False, which="minor")
# Remove ticks
# First set the shared parameters for both axes
ax.tick_params(axis='both', which='both', length=0, labelsize=24)
# Then set the padding only for the x-axis
ax.tick_params(axis='x', which='both', pad=10)
if subplot_idx == 0: # Y-label only for the leftmost subplot
ax.set_ylabel("Proportional Size", fontsize=24)
# X-label for all subplots in a 1xN layout can be okay, or just the middle/last one.
# Let's put it on all for now.
ax.set_xlabel("Latency (s)", fontsize=25)
# Display 100%, 200%, 300% on the y-axis
ax.set_yticks([1, 2, 3])
ax.set_yticklabels([r'100\%', r'200\%', r'300\%'])
# Create a custom arrow with "Better" text inside
# Create the arrow patch with a wider shaft
arrow = FancyArrowPatch(
(0.8, 0.8), # Start point (top-right)
(0.65, 0.6), # End point (toward bottom-left)
transform=ax.transAxes,
arrowstyle='simple,head_width=40,head_length=35,tail_width=20', # Increased arrow dimensions
facecolor='white',
edgecolor='black',
linewidth=3, # Thicker outline
zorder=5
)
# Add the arrow to the plot
ax.add_patch(arrow)
# Calculate the midpoint of the arrow for text placement
mid_x = (0.8 + 0.65) / 2 + 0.002 + 0.01
mid_y = (0.8 + 0.6) / 2 + 0.01
# Add the "Better" text at the midpoint of the arrow
ax.text(mid_x, mid_y, 'Better',
transform=ax.transAxes,
ha='center',
va='center',
fontsize=16, # Increased font size from 12 to 16
fontweight='bold',
rotation=40, # Rotate to match arrow direction
zorder=6) # Ensure text is on top of arrow
# Create legends (once per figure)
method_legend_handles = []
for method, marker_style in method_markers.items():
if method in all_method_names:
# All legend markers are drawn in black; map to a display name where one is defined
display_name = method_display_names.get(method, method)
method_legend_handles.append(mlines.Line2D([], [], color='black', marker=marker_style, linestyle='None',
markersize=10, label=display_name))
recall_legend_handles = []
sorted_recall_levels = sorted(recall_colors.keys())
for r_level in sorted_recall_levels:
recall_legend_handles.append(mlines.Line2D([], [], color=recall_colors[r_level], marker='o', linestyle='None',
markersize=20, label=f"Target Recall={r_level:.0f}\%"))
# Split the legend into two rows: methods on the first row, recall targets on the second
if fig_idx == 0:
# Exclude 'Our' from the method list first
other_methods = [m for m in all_method_names if m != 'Our']
# Build the method list in the desired order ('Our' goes last)
ordered_methods = other_methods + (['Our'] if 'Our' in all_method_names else [])
# Create the method legend handles in the new order
method_legend_handles = []
for method in ordered_methods:
if method in method_markers:
marker_style = method_markers[method]
# Map to the display name where one is defined
display_name = method_display_names.get(method, method)
color = 'black'
marker_size = 22
if method == 'Our':
marker_size = 27
elif 'IVF-Recompute' in method or 'EdgeRAG' in method:
marker_size = 17
elif 'DiskANN' in method:
marker_size = 19
elif 'BM25' in method:
marker_size = 20
method_legend_handles.append(mlines.Line2D([], [], color=color, marker=marker_style,
linestyle='None', markersize=marker_size, label=display_name))
# Create the recall legend (second row), positioned below the method legend
recall_legend = fig.legend(handles=recall_legend_handles,
loc='upper center', bbox_to_anchor=(0.5, 1.05), # lower y coordinate places it below the first row
ncol=len(recall_legend_handles), fontsize=28)
# Create the method legend (first row)
method_legend = fig.legend(handles=method_legend_handles,
loc='upper center', bbox_to_anchor=(0.5, 0.91),
ncol=len(method_legend_handles), fontsize=28)
# Attach both legends to the figure
fig.add_artist(method_legend)
fig.add_artist(recall_legend)
# Adjust the layout, reserving the top of the figure for the two legend rows
plt.tight_layout(rect=(0, 0, 1.0, 0.74)) # top reduced to 0.74 to fit both legend rows
save_path = f'./paper_plot/figures/main_exp_fig_{fig_idx+1}.pdf'
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f"Saved figure {fig_idx+1} to {save_path}")
plt.show()

View File

@@ -0,0 +1,163 @@
import csv
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["ytick.direction"] = "in"
plt.rcParams["hatch.linewidth"] = 1
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
plt.rcParams["text.usetex"] = True
SAVE_PTH = "./paper_plot/figures"
font_size = 16
# Generation(LLama 1B) Generation(LLama 3B) Generation(LLama 7B)
# 0.085s 0.217s 0.472s
llm_inference_time=[0.085, 0.217, 0.472, 0]
USE_LLM_INDEX = 3 # Index into llm_inference_time; 3 selects 0, i.e. no LLM generation time is added
file_path = "./paper_plot/data/main_latency.csv"
with open(file_path, mode="r", newline="") as file:
reader = csv.reader(file)
data = list(reader)
# Print the raw CSV contents
for row in data:
print(",".join(row))
models = ["A10", "MAC"]
datasets = ["NQ", "TriviaQA", "GPQA", "HotpotQA"]
# Note: str.isdigit() is False for decimals such as "0.472", so most numeric
# cells remain strings here; downstream code re-converts them with float().
data = [[float(cell) if cell.isdigit() else cell for cell in row] for row in data[1:]]
for k, model in enumerate(models):
fig, axes = plt.subplots(1, 4)
fig.set_size_inches(20, 3)
plt.subplots_adjust(wspace=0, hspace=0)
total_width, n = 6, 6
group = 1
width = total_width * 0.9 / n
x = np.arange(group) * n
exit_idx_x = x + (total_width - width) / n
edgecolors = ["dimgrey", "#63B8B6", "tomato", "slategray", "mediumpurple", "green", "red", "blue", "yellow", "silver"]
# hatches = ["", "\\\\", "//", "||", "x", "--", "..", "", "\\\\", "//", "||", "x", "--", ".."]
hatches =["\\\\\\","\\\\"]
labels = [
"HNSW",
"IVF",
"DiskANN",
"IVF-Disk",
"IVF-Recompute",
"Our",
# "DGL-OnDisk",
]
yticks = [0.01, 0.1, 1, 10, 100, 1000,10000] # Log scale ticks
val_limit = 15000 # Upper limit for the plot
for i in range(4):
axes[i].set_yscale('log') # Set y-axis to logarithmic scale
axes[i].set_yticks(yticks)
axes[i].set_ylim(0.01, val_limit) # Lower limit should be > 0 for log scale
axes[i].tick_params(axis="y", labelsize=10)
axes[i].set_xticks([])
# axes[i].set_xticklabels()
axes[i].set_xlabel(datasets[i], fontsize=font_size)
axes[i].grid(axis="y", linestyle="--")
axes[i].set_xlim(exit_idx_x[0] - 0.15 * width - 0.2, exit_idx_x[0] + (n-0.25)* width + 0.2)
for j in range(n):
local_hatches = ["////", "\\\\", "xxxx"] # one hatch per row-variant within a method group
# Each method gets three bars, one per CSV row-variant for this (dataset, hardware)
# The CSV provides six rows per dataset (two hardware x three row-variants).
# Treat this method as OOM when all three variant rows for the current
# (dataset, hardware) pair are 0.
is_oom = True
for m in range(3):
if float(data[i * 6 + k * 3 + m][j + 3]) != 0:
is_oom = False
break
if is_oom:
# Draw a cross for OOM instead of bars
pos = exit_idx_x + j * width + width * 0.3 # Center position for cross
marker_size = width * 150 # Size of the cross
axes[i].scatter(pos, 0.02, marker='x', color=edgecolors[j], s=marker_size,
linewidth=4, label=labels[j] if j < len(labels) else "", zorder=20)
else:
# Create three separate bar calls instead of trying to plot multiple bars at once
for m in range(3):
num = float(data[i * 6 + k * 3 + m][j + 3]) + llm_inference_time[USE_LLM_INDEX]
plot_label = [num]
pos = exit_idx_x + j * width + width * 0.3 * m
# For log scale, we need to ensure values are positive
plot_value = max(0.01, num) if num < val_limit else val_limit
container = axes[i].bar(
pos,
plot_value,
width=width * 0.3,
color="white",
edgecolor=edgecolors[j],
# edgecolor="k",
hatch=local_hatches[m], # Use different hatches for each of the 3 bars
linewidth=1.0,
label=labels[j] if m == 0 else "", # Only add label for the first bar
zorder=10,
)
# axes[i].bar_label(
# container,
# plot_label,
# fontsize=font_size - 2,
# zorder=200,
# fontweight="bold",
# )
if k == 0:
axes[0].legend(
bbox_to_anchor=(3.25, 1.02),
ncol=7,
loc="lower right",
# fontsize=font_size,
# markerscale=3,
labelspacing=0.2,
edgecolor="black",
facecolor="white",
framealpha=1,
shadow=False,
# fancybox=False,
handlelength=2,
handletextpad=0.5,
columnspacing=0.5,
prop={"weight": "bold", "size": font_size},
).set_zorder(100)
axes[0].set_ylabel("Runtime (log scale)", fontsize=font_size, fontweight="bold")
axes[0].set_yticklabels([r"$10^{-2}$", r"$10^{-1}$", r"$10^{0}$", r"$10^{1}$", r"$10^{2}$", r"$10^{3}$",r"$10^{4}$"], fontsize=font_size)
axes[1].set_yticklabels([])
axes[2].set_yticklabels([])
axes[3].set_yticklabels([])
plt.savefig(f"{SAVE_PTH }/speed_{model}_revised.pdf", bbox_inches="tight", dpi=300)
## print save
print(f"{SAVE_PTH }/speed_{model}_revised.pdf")

View File

@@ -0,0 +1,85 @@
from matplotlib import pyplot as plt
from matplotlib.gridspec import GridSpec
# from script.settings import DATA_PATH, FIGURE_PATH
# DATA_PATH = "/home/ubuntu/Power-RAG/paper_plot/data"
# FIGURE_PATH = "/home/ubuntu/Power-RAG/paper_plot/figures"
plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["ytick.direction"] = "in"
plt.rcParams["hatch.linewidth"] = 2
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
plt.rcParams["text.usetex"] = True
import numpy as np
import pandas as pd
# Load the RAM and Storage data directly from CSV
data = pd.read_csv("./paper_plot/data/ram_storage.csv")
# Explicitly reorder columns to ensure "Our" is at the end
cols = list(data.columns)
if "Our" in cols and cols[-1] != "Our":
cols.remove("Our")
cols.append("Our")
data = data[cols]
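# e.g. a column order of [label, 'HNSW', 'Our', 'IVF'] becomes
# [label, 'HNSW', 'IVF', 'Our'], so "Our" is always plotted last
# (column names other than "Our" are illustrative).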
# Set up the figure with two columns
fig = plt.figure(figsize=(12, 3))
gs = GridSpec(1, 2, figure=fig)
ax1 = fig.add_subplot(gs[0, 0]) # Left panel for RAM
ax2 = fig.add_subplot(gs[0, 1]) # Right panel for Storage
# Define the visual style elements
edgecolors = ["dimgrey", "#63B8B6", "tomato", "slategray", "silver", "navy"]
hatches = ["/////", "\\\\\\\\\\"]
# Calculate positions for the bars
methods = data.columns[1:] # Skip the first column, which holds the RAM/Storage metric label
num_methods = len(methods)
# Reverse the order of methods for display (to have "Our" at the bottom)
methods = list(methods)[::-1]
y_positions = np.arange(num_methods)
bar_width = 0.6
# Plot RAM data in left panel
ram_bars = ax1.barh(
y_positions,
data.iloc[0, 1:].values[::-1], # Reverse the data to match reversed methods
height=bar_width,
color="white",
edgecolor=edgecolors[0],
hatch=hatches[0],
linewidth=1.0,
label="RAM",
zorder=10,
)
ax1.set_title("RAM Usage", fontsize=14, fontweight='bold')
ax1.set_yticks(y_positions)
ax1.set_yticklabels(methods, fontsize=14)
ax1.set_xlabel("Size (\\textit{GB})", fontsize=14)
ax1.xaxis.set_tick_params(labelsize=14)
# Plot Storage data in right panel
storage_bars = ax2.barh(
y_positions,
data.iloc[1, 1:].values[::-1], # Reverse the data to match reversed methods
height=bar_width,
color="white",
edgecolor=edgecolors[1],
hatch=hatches[1],
linewidth=1.0,
label="Storage",
zorder=10,
)
ax2.set_title("Storage Usage", fontsize=14, fontweight='bold')
ax2.set_yticks(y_positions)
ax2.set_yticklabels(methods, fontsize=14)
ax2.set_xlabel("Size (\\textit{GB})", fontsize=14)
ax2.xaxis.set_tick_params(labelsize=14)
plt.tight_layout()
plt.savefig("./paper_plot/figures/ram_storage_double_column.pdf", bbox_inches="tight", dpi=300)
print("Saving the figure to ./paper_plot/figures/ram_storage_double_column.pdf")

View File

@@ -0,0 +1,141 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Motto: Were It to Benefit My Country, I Would Lay Down My Life!
# \file: /bottleneck_breakdown.py
# \brief: Illustrates the query time bottleneck on consumer devices (Final Version - Font & Legend Adjust).
# Author: Gemini Assistant (adapted from user's style and feedback)
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter # Not strictly needed for just font, but imported if user wants to try
# Set matplotlib styles similar to the example
plt.rcParams["font.family"] = "Helvetica" # Primary font family
plt.rcParams["ytick.direction"] = "in"
plt.rcParams["xtick.direction"] = "in"
plt.rcParams["hatch.linewidth"] = 1.0
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
plt.rcParams["text.usetex"] = True
# Attempt to make LaTeX use Helvetica as the main font
plt.rcParams['text.latex.preamble'] = r"""
\usepackage{helvet} % helvetica font
\usepackage{sansmath} % helvetica for math
\sansmath % activate sansmath
\renewcommand{\familydefault}{\sfdefault} % make sans-serif the default family
"""
# Final Data for the breakdown (3 Segments)
labels_raw = [ # Raw labels before potential LaTeX escaping
'IO: Text + PQ Lookup',
'CPU: Tokenize + Distance Compute',
'GPU: Embedding Recompute',
]
# Times in ms, ordered for stacking
times_ms = np.array([
8.009, # IO: text read + PQ lookup
16.197, # CPU: tokenization + distance computation
76.512, # GPU: embedding recomputation
])
total_time_ms = times_ms.sum()
percentages = (times_ms / total_time_ms) * 100
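# Sanity check on the numbers above: total_time_ms is about 100.7 ms, so the
# GPU embedding-recompute segment accounts for roughly 76.5 / 100.7, i.e. ~76%
# of end-to-end query time.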
# Prepare labels for legend, escaping for LaTeX if active
labels_legend = []
# st1 = r'&' # Not needed as current labels_raw don't have '&'
for label, time, perc in zip(labels_raw, times_ms, percentages):
# Construct the percentage string carefully for LaTeX
perc_str = f"{perc:.1f}" + r"\%" # Correct way to form 'NN.N\%'
# label_tex = label.replace('&', st1) # Use if '&' is in labels_raw
label_tex = label # Current labels_raw are clean for LaTeX
labels_legend.append(
f"{label_tex}\n({time:.1f}ms, {perc_str})"
)
# Styling based on user's script
# Using first 3 from the provided lists
edgecolors_list = ["dimgrey", "#63B8B6", "tomato", "silver", "slategray"]
hatches_list = ["/////", "xxxxx", "\\\\\\\\\\"]
edgecolors = edgecolors_list[:3]
hatches = hatches_list[:3]
fill_color = "white"
# Create the figure and axes
# Adjusted figure size to potentially accommodate legend on the right
fig, ax = plt.subplots()
fig.set_size_inches(7, 1.5) # Width increased slightly, height adjusted
# Adjusted right margin for external legend, bottom for x-label
plt.subplots_adjust(left=0.12, right=0.72, top=0.95, bottom=0.25)
# Create the horizontal stacked bar
bar_height = 0.2
y_pos = 0
left_offset = 0
for i in range(len(times_ms)):
ax.barh(
y_pos,
times_ms[i],
height=bar_height,
left=left_offset,
color=fill_color,
edgecolor=edgecolors[i],
hatch=hatches[i],
linewidth=1.5,
label=labels_legend[i],
zorder=10
)
text_x_pos = left_offset + times_ms[i] / 2
if times_ms[i] > total_time_ms * 0.03: # Threshold for displaying text
ax.text(
text_x_pos,
y_pos,
f"{times_ms[i]:.1f}ms",
ha='center',
va='center',
fontsize=8,
fontweight='bold',
color='black',
zorder=20,
bbox=dict(facecolor='white', edgecolor='none', pad=0.5, alpha=0.8)
)
left_offset += times_ms[i]
# Set plot limits and labels
ax.set_xlim([0, total_time_ms * 1.02])
ax.set_xlabel("Time (ms)", fontsize=14, fontweight='bold', x=0.75, )
# Y-axis: Remove y-ticks and labels
ax.set_yticks([])
ax.set_yticklabels([])
# Legend: Placed to the right of the plot
ax.legend(
# (x, y) for anchor, (0,0) is bottom left, (1,1) is top right of AXES
# To place outside on the right, x should be > 1
bbox_to_anchor=(1.03, 0.5), # x > 1 means outside to the right, y=0.5 for vertical center
ncol=1, # Single column for a taller, narrower legend
loc="center left", # Anchor the legend's left-center to bbox_to_anchor point
labelspacing=0.5, # Adjust spacing
edgecolor="black",
facecolor="white",
framealpha=1,
shadow=False,
fancybox=False,
handlelength=1.5,
handletextpad=0.6,
columnspacing=1.5,
prop={"weight": "bold", "size": 9},
).set_zorder(100)
# Save the figure (using the original generic name as requested)
output_filename = "./bottleneck_breakdown.pdf"
# plt.tight_layout() # tight_layout might conflict with external legend; adjust subplots_adjust instead
plt.savefig(output_filename, bbox_inches="tight", dpi=300)
print(f"Saved plot to {output_filename}")
# plt.show() # Uncomment to display plot interactively

View File

@@ -0,0 +1,226 @@
import matplotlib.pyplot as plt
import numpy as np
# import matplotlib.ticker as mticker # Not actively used
import os
FIGURE_PATH = "paper_plot/figures"
try:
os.makedirs(FIGURE_PATH, exist_ok=True)
print(f"Images will be saved to: {os.path.abspath(FIGURE_PATH)}")
except OSError as e:
print(f"Create {FIGURE_PATH} failed: {e}. Images will be saved in the current working directory.")
FIGURE_PATH = "."
plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["ytick.direction"] = "in"
plt.rcParams["hatch.linewidth"] = 2
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
plt.rcParams["text.usetex"] = True
method_labels = ["gte-small (33M)", "contriever-msmarco (110M)"]
dataset_names = ["NQ", "TriviaQA"]
metrics_plot1 = ["Exact Match", "F1"]
small_nq_f1 = 0.2621040899
small_tq_f1 = 0.4698198059
small_nq_em_score = 0.1845
small_tq_em_score = 0.4015
small_nq_time = 1.137
small_tq_time = 1.173
large_nq_f1 = 0.2841386117
large_tq_f1 = 0.4548340289
large_nq_em_score = 0.206
large_tq_em_score = 0.382
large_nq_time = 2.632
large_tq_time = 2.684
data_scores_plot1 = {
"NQ": {"Exact Match": [small_nq_em_score, large_nq_em_score], "F1": [small_nq_f1, large_nq_f1]},
"TriviaQA": {"Exact Match": [small_tq_em_score, large_tq_em_score], "F1": [small_tq_f1, large_tq_f1]}
}
latency_data_plot2 = {
"NQ": [small_nq_time, large_nq_time],
"TriviaQA": [small_tq_time, large_tq_time]
}
edgecolors = ["dimgrey", "tomato"]
hatches = ["/////", "\\\\\\\\\\"]
# Changed: bar_center_separation_in_group increased for larger gap
bar_center_separation_in_group = 0.42
# Changed: bar_visual_width decreased for narrower bars
bar_visual_width = 0.28
figsize_plot1 = (4, 2.5)
# figsize_plot2 is narrower than figsize_plot1 (a single bar group per subplot); the heights match
figsize_plot2 = (2.5, 2.5)
# Define plot1_xlim_per_subplot globally so it can be accessed by create_plot2_latency
plot1_xlim_per_subplot = (0.0, 2.0) # Explicit xlim for plot 1 subplots
common_subplots_adjust_params = dict(wspace=0.30, top=0.80, bottom=0.22, left=0.09, right=0.96)
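# Bar-placement scheme shared by both plot functions below: method j of
# num_methods is centred at offset (j - (num_methods - 1) / 2) *
# bar_center_separation_in_group from its group centre. With two methods this
# gives offsets of -0.21 and +0.21 (bar centres 0.42 apart), each bar drawn
# bar_visual_width = 0.28 wide.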
def create_plot1_em_f1():
fig, axs = plt.subplots(1, 2, figsize=figsize_plot1)
fig.subplots_adjust(**common_subplots_adjust_params)
num_methods = len(method_labels)
metric_group_centers = np.array([0.5, 1.5])
# plot1_xlim_per_subplot is now global
for i, dataset_name in enumerate(dataset_names):
ax = axs[i]
for metric_idx, metric_name in enumerate(metrics_plot1):
metric_center_pos = metric_group_centers[metric_idx]
current_scores_raw = data_scores_plot1[dataset_name][metric_name]
current_scores_percent = [val * 100 for val in current_scores_raw]
for j, method_label in enumerate(method_labels):
offset = (j - (num_methods - 1) / 2.0) * bar_center_separation_in_group
bar_center_pos = metric_center_pos + offset
ax.bar(
bar_center_pos, current_scores_percent[j], width=bar_visual_width, color="white",
edgecolor=edgecolors[j], hatch=hatches[j], linewidth=1.5,
label=method_label if i == 0 and metric_idx == 0 else None
)
ax.text(
bar_center_pos, current_scores_percent[j] + 0.8, f"{current_scores_percent[j]:.1f}",
ha='center', va='bottom', fontsize=8, fontweight='bold'
)
ax.set_xticks(metric_group_centers)
ax.set_xticklabels(metrics_plot1, fontsize=9, fontweight='bold')
ax.set_title(dataset_name, fontsize=12, fontweight='bold')
ax.set_xlim(plot1_xlim_per_subplot) # Apply consistent xlim
ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: ' {:.0f}'.format(y)))
if i == 0:
ax.set_ylabel("Accuracy (\%)", fontsize=12, fontweight="bold")
all_subplot_scores_percent = []
for metric_name_iter in metrics_plot1:
all_subplot_scores_percent.extend([val * 100 for val in data_scores_plot1[dataset_name][metric_name_iter]])
max_val = max(all_subplot_scores_percent) if all_subplot_scores_percent else 0
ax.set_ylim(0, max_val * 1.22 if max_val > 0 else 10)
ax.tick_params(axis='y', labelsize=12)
for spine in ax.spines.values():
spine.set_visible(True)
spine.set_linewidth(1.0)
spine.set_edgecolor("black")
handles, labels = axs[0].get_legend_handles_labels()
fig.legend(
handles, labels, loc="upper center", bbox_to_anchor=(0.5, 0.97), ncol=len(method_labels),
edgecolor="black", facecolor="white", framealpha=1, shadow=False, fancybox=False,
handlelength=1.5, handletextpad=0.4, columnspacing=0.8,
prop={"weight": "bold", "size": 9}
)
# fig.text(0.5, 0.06, "(a) EM \& F1", ha='center', va='center', fontweight='bold', fontsize=11)
save_path = os.path.join(FIGURE_PATH, "plot1_em_f1.pdf")
# plt.tight_layout() # Adjusted call below
fig.tight_layout(rect=(0.0, 0.0, 1.0, 0.88)) # Adjusted to make space for fig.text and fig.legend
plt.savefig(save_path, bbox_inches='tight', pad_inches=0.03)
plt.close(fig)
print(f"Figure 1 (Exact Match & F1) has been saved to: {save_path}")
def create_plot2_latency():
fig, axs = plt.subplots(1, 2, figsize=figsize_plot2)
fig.subplots_adjust(**common_subplots_adjust_params)
num_methods = len(method_labels)
method_group_center_in_subplot = 0.5
# Calculate bar extents to determine focused xlim
bar_positions_calc = []
for j_idx in range(num_methods):
offset_calc = (j_idx - (num_methods - 1) / 2.0) * bar_center_separation_in_group
bar_center_pos_calc = method_group_center_in_subplot + offset_calc
bar_positions_calc.append(bar_center_pos_calc)
min_bar_actual_edge = min(bar_positions_calc) - bar_visual_width / 2.0
max_bar_actual_edge = max(bar_positions_calc) + bar_visual_width / 2.0
# Center the bar group (at 0.5) in a unit-wide span. A fixed-padding approach
# (e.g. 0.15 beyond the outermost bar edges) gives the same (0.0, 1.0) range;
# centering is simply the more direct way to get it.
plot2_xlim_calculated = (method_group_center_in_subplot - 0.5, method_group_center_in_subplot + 0.5)
for i, dataset_name in enumerate(dataset_names):
ax = axs[i]
current_latencies = latency_data_plot2[dataset_name]
for j, method_label in enumerate(method_labels):
offset = (j - (num_methods - 1) / 2.0) * bar_center_separation_in_group
bar_center_pos = method_group_center_in_subplot + offset
ax.bar(
bar_center_pos, current_latencies[j], width=bar_visual_width, color="white",
edgecolor=edgecolors[j], hatch=hatches[j], linewidth=1.5,
label=method_label if i == 0 else None
)
ax.text(
bar_center_pos, current_latencies[j] + 0.05, f"{current_latencies[j]:.2f}",
ha='center', va='bottom', fontsize=10, fontweight='bold'
)
ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: ' {:.0f}'.format(y)))
ax.set_xticks([0.5])
ax.set_xticklabels(["Latency"], color="white", fontsize=12)
# Hide the x-axis ticks by drawing them in white as well
ax.tick_params(axis='x', colors="white")
ax.set_title(dataset_name, fontsize=13, fontweight='bold')
ax.set_xlim(plot2_xlim_calculated)
if i == 0:
ax.set_ylabel("Latency (s)", fontsize=12, fontweight="bold")
max_latency_in_subplot = max(current_latencies) if current_latencies else 0
ax.set_ylim(0, max_latency_in_subplot * 1.22 if max_latency_in_subplot > 0 else 1)
ax.tick_params(axis='y', labelsize=12)
for spine in ax.spines.values():
spine.set_visible(True)
spine.set_linewidth(1.0)
spine.set_edgecolor("black")
handles, labels = axs[0].get_legend_handles_labels()
fig.legend(
handles, labels, loc="upper center", bbox_to_anchor=(0.5, 0.97), ncol=num_methods,
edgecolor="black", facecolor="white", framealpha=1, shadow=False, fancybox=False,
handlelength=1.5, handletextpad=0.4, columnspacing=0.8,
prop={"weight": "bold", "size": 9}
)
# fig.text(0.5, 0.06, "(b) Latency", ha='center', va='center', fontweight='bold', fontsize=11)
save_path = os.path.join(FIGURE_PATH, "plot2_latency.pdf")
# plt.tight_layout() # Adjusted call below
fig.tight_layout(rect=(0.0, 0.0, 1.0, 0.88)) # Adjusted to make space for fig.text and fig.legend
plt.savefig(save_path, bbox_inches='tight', pad_inches=0.03)
plt.close(fig)
print(f"Figure 2 (Latency) has been saved to: {save_path}")
if __name__ == "__main__":
print("Start generating figures...")
if plt.rcParams["text.usetex"]:
print("Info: LaTeX rendering is enabled. Ensure LaTeX is installed and configured if issues arise, or set plt.rcParams['text.usetex'] to False.")
create_plot1_em_f1()
create_plot2_latency()
print("All figures have been generated.")

View File

@@ -0,0 +1,111 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Motto: Were It to Benefit My Country, I Would Lay Down My Life!
# \file: /speed_ablation.py
# \brief:
# Author: raphael hao
# %%
import numpy as np
import pandas as pd
# %%
# from script.settings import DATA_PATH, FIGURE_PATH
# Load the latency ablation data
latency_data = pd.read_csv("./paper_plot/data/latency_ablation.csv")
# Filter for SpeedUp metric only
speedup_data = latency_data[latency_data['Metric'] == 'SpeedUp']
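# Assumed latency_ablation.csv shape (inferred from the column lookups below;
# the values are illustrative only):
#
#   Dataset,Metric,Original,original + two_level,original + two_level + batch
#   NQ,SpeedUp,1.00,1.45,2.10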
# %%
from matplotlib import pyplot as plt
plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["ytick.direction"] = "in"
plt.rcParams["hatch.linewidth"] = 1.5
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
plt.rcParams["text.usetex"] = True
# %%
fig, ax = plt.subplots()
fig.set_size_inches(5, 1.5)
plt.subplots_adjust(wspace=0, hspace=0)
total_width, n = 3, 3
group = len(speedup_data['Dataset'].unique())
width = total_width * 0.9 / n
x = np.arange(group) * n
exit_idx_x = x + (total_width - width) / n
edgecolors = ["dimgrey", "#63B8B6", "tomato", "silver", "slategray"]
hatches = ["/////", "xxxxx", "\\\\\\\\\\"]
labels = ["Base", "Base + Two-level", "Base + Two-level + Batch"]
datasets = speedup_data['Dataset'].unique()
for i, dataset in enumerate(datasets):
dataset_data = speedup_data[speedup_data['Dataset'] == dataset]
for j in range(n):
if j == 0:
value = dataset_data['Original'].values[0]
elif j == 1:
value = dataset_data['original + two_level'].values[0]
else:
value = dataset_data['original + two_level + batch'].values[0]
ax.text(
exit_idx_x[i] + j * width,
value + 0.05,
f"{value:.2f}",
ha='center',
va='bottom',
fontsize=10,
fontweight='bold',
rotation=0,
zorder=20,
)
ax.bar(
exit_idx_x[i] + j * width,
value,
width=width * 0.8,
color="white",
edgecolor=edgecolors[j],
hatch=hatches[j],
linewidth=1.5,
label=labels[j] if i == 0 else None,
zorder=10,
)
ax.set_ylim([0.5, 2.3])
ax.set_yticks(np.arange(0.5, 2.2, 0.5))
ax.set_yticklabels(np.arange(0.5, 2.2, 0.5), fontsize=12)
ax.set_xticks(exit_idx_x + width)
ax.set_xticklabels(datasets, fontsize=10)
# ax.set_xlabel("Different Datasets", fontsize=14)
ax.legend(
bbox_to_anchor=(-0.03, 1.4),
ncol=3,
loc="upper left",
labelspacing=0.1,
edgecolor="black",
facecolor="white",
framealpha=1,
shadow=False,
fancybox=False,
handlelength=0.8,
handletextpad=0.6,
columnspacing=0.8,
prop={"weight": "bold", "size": 10},
).set_zorder(100)
ax.set_ylabel("Speedup", fontsize=11)
plt.savefig("./paper_plot/figures/latency_speedup.pdf", bbox_inches="tight", dpi=300)
# %%
print(f"Save to ./paper_plot/figures/latency_speedup.pdf")