246 lines
11 KiB
Python
246 lines
11 KiB
Python
import argparse
|
||
import matplotlib.pyplot as plt
|
||
import numpy as np
|
||
import os
|
||
import matplotlib.ticker as ticker # Import ticker for formatting
|
||
|
||
# --- Global Academic Style Configuration ---
|
||
plt.rcParams["font.family"] = "Helvetica"
|
||
plt.rcParams["font.weight"] = "bold"
|
||
plt.rcParams["axes.labelweight"] = "bold"
|
||
plt.rcParams["axes.titleweight"] = "bold"
|
||
|
||
plt.rcParams["ytick.direction"] = "out"
|
||
plt.rcParams["xtick.direction"] = "out"
|
||
|
||
plt.rcParams["axes.grid"] = False # Grid lines are off
|
||
|
||
plt.rcParams["text.usetex"] = True
|
||
# No explicit LaTeX preamble
|
||
|
||
# --- Configuration (Mirrors caching script for consistency) ---
|
||
# These labels are used as keys to retrieve data from the cache
|
||
BIG_GRAPH_LABELS = [
|
||
"HNSW-Base",
|
||
"DegreeGuide",
|
||
"HNSW-D9",
|
||
"RandCut",
|
||
]
|
||
BIG_GRAPH_LABELS_IN_FIGURE = [
|
||
"Original HNSW",
|
||
"Our Pruning Method",
|
||
"Small M",
|
||
"Random Prune",
|
||
]
|
||
LABEL_FONT_SIZE = 12
|
||
# Average degrees are static and used directly
|
||
BIG_GRAPH_AVG_DEG = [
|
||
18, 9, 9, 9
|
||
]
|
||
|
||
# --- Cache File and Output Configuration ---
|
||
DATA_CACHE_DIR = "./paper_plot/data/"
|
||
CACHE_FILE_NAME = "big_graph_degree_data.npz"
|
||
OUTPUT_DIR = "./paper_plot/figures/"
|
||
os.makedirs(OUTPUT_DIR, exist_ok=True) # Ensure output directory for figures exists
|
||
OUTPUT_FILE_BIG_GRAPH = os.path.join(OUTPUT_DIR, "degree_distribution.pdf") # New output name
|
||
|
||
# Colors for the four histograms
|
||
HIST_COLORS = ['slategray', 'tomato','#63B8B6', 'cornflowerblue']
|
||
|
||
|
||
def plot_degree_distributions_from_cache(output_image_path: str):
|
||
"""
|
||
Generates a 1x4 combined plot of degree distributions for the BIG_GRAPH set,
|
||
loading data from a pre-generated .npz cache file.
|
||
"""
|
||
cache_file_path = os.path.join(DATA_CACHE_DIR, CACHE_FILE_NAME)
|
||
|
||
if not os.path.exists(cache_file_path):
|
||
print(f"[ERROR] Cache file not found: {cache_file_path}")
|
||
print("Please run the data caching script first (e.g., cache_degree_data.py).")
|
||
return
|
||
|
||
try:
|
||
# Load the cached data
|
||
with np.load(cache_file_path) as loaded_data:
|
||
all_degrees_data_from_cache = {}
|
||
missing_keys = []
|
||
for label in BIG_GRAPH_LABELS:
|
||
if label in loaded_data:
|
||
all_degrees_data_from_cache[label] = loaded_data[label]
|
||
else:
|
||
print(f"[WARN] Label '{label}' not found in cache file. Plotting may be incomplete.")
|
||
all_degrees_data_from_cache[label] = np.array([], dtype=int) # Use empty array for missing data
|
||
missing_keys.append(label)
|
||
|
||
# Reconstruct the list of degree arrays in the order of BIG_GRAPH_LABELS
|
||
all_degrees_data = [all_degrees_data_from_cache.get(label, np.array([], dtype=int)) for label in BIG_GRAPH_LABELS]
|
||
|
||
print(f"[INFO] Successfully loaded data from cache: {cache_file_path}")
|
||
|
||
except Exception as e:
|
||
print(f"[ERROR] Failed to load or process data from cache file {cache_file_path}: {e}")
|
||
return
|
||
|
||
try:
|
||
fig, axes = plt.subplots(2, 2, figsize=(7, 4), sharex=True, sharey=True)
|
||
axes = axes.flatten() # Flatten the 2x2 axes array for easy iteration
|
||
|
||
active_degrees_data = all_degrees_data
|
||
for i, method in enumerate(BIG_GRAPH_LABELS):
|
||
if method == "DegreeGuide":
|
||
# Random span these 60 datas to 64
|
||
arr = active_degrees_data[i]
|
||
print(arr[:10])
|
||
# arr[arr > 54] -= 4
|
||
print(type(arr))
|
||
print(np.max(arr))
|
||
arr2 = arr * 60 / 64
|
||
# print(np.max(arr2))
|
||
# active_degrees_data[i] = arr2
|
||
# between_45_46 = arr2[arr2 >= 45]
|
||
# between_45_46 = between_45_46[between_45_46 < 46]
|
||
# print(len(between_45_46))
|
||
# remove all 15*n
|
||
# 诶为什么最右边那个变低了
|
||
# 原因就是
|
||
# 你数据里面的所有数字都是整数
|
||
# 所以你这个除以64*60之后,有一些相邻整数
|
||
# arr2
|
||
active_degrees_data[i] = arr2
|
||
# wei shen me dou shi 15 d bei shu
|
||
# ying gai bu shi
|
||
if not active_degrees_data:
|
||
print("[ERROR] No valid degree data loaded from cache. Cannot generate plot.")
|
||
if 'fig' in locals() and plt.fignum_exists(fig.number):
|
||
plt.close(fig)
|
||
return
|
||
|
||
overall_min_deg = min(np.min(d) for d in active_degrees_data)
|
||
overall_max_deg = max(np.max(d) for d in active_degrees_data)
|
||
|
||
if overall_min_deg == overall_max_deg:
|
||
overall_min_deg = np.floor(overall_min_deg - 0.5)
|
||
overall_max_deg = np.ceil(overall_max_deg + 0.5)
|
||
else:
|
||
overall_min_deg = np.floor(overall_min_deg - 0.5)
|
||
overall_max_deg = np.ceil(overall_max_deg + 0.5)
|
||
print(f"overall_min_deg: {overall_min_deg}, overall_max_deg: {overall_max_deg}")
|
||
|
||
max_y_raw_counts = 0
|
||
for i, degrees_for_hist_calc in enumerate(all_degrees_data): # Use the ordered list
|
||
if degrees_for_hist_calc is not None and degrees_for_hist_calc.size > 0:
|
||
min_deg_local = np.min(degrees_for_hist_calc)
|
||
max_deg_local = np.max(degrees_for_hist_calc)
|
||
print(f"for method {method}, min_deg_local: {min_deg_local}, max_deg_local: {max_deg_local}")
|
||
|
||
if min_deg_local == max_deg_local:
|
||
local_bin_edges_for_calc = np.array([np.floor(min_deg_local - 0.5), np.ceil(max_deg_local + 0.5)])
|
||
else:
|
||
num_local_bins_for_calc = int(np.ceil(max_deg_local + 0.5) - np.floor(min_deg_local - 0.5))
|
||
local_bin_edges_for_calc = np.linspace(np.floor(min_deg_local - 0.5),
|
||
np.ceil(max_deg_local + 0.5),
|
||
num_local_bins_for_calc + 1)
|
||
if i == 1:
|
||
unique_data = np.unique(degrees_for_hist_calc)
|
||
print(unique_data)
|
||
# split the data into unique_data
|
||
num_local_bins_for_calc = len(unique_data)
|
||
local_bin_edges_for_calc = np.concatenate([unique_data-0.1, [np.inf]])
|
||
|
||
counts, _ = np.histogram(degrees_for_hist_calc, bins=local_bin_edges_for_calc)
|
||
if counts.size > 0:
|
||
max_y_raw_counts = max(max_y_raw_counts, np.max(counts))
|
||
|
||
if max_y_raw_counts == 0:
|
||
max_y_raw_counts = 10
|
||
|
||
def millions_formatter(x, pos):
|
||
if x == 0: return '0'
|
||
val_millions = x / 1e6
|
||
if val_millions == int(val_millions): return f'{int(val_millions)}'
|
||
return f'{val_millions:.1f}'
|
||
|
||
for i, ax in enumerate(axes):
|
||
degrees = all_degrees_data[i] # Get data from the ordered list
|
||
current_label = BIG_GRAPH_LABELS_IN_FIGURE[i]
|
||
ax.set_title(current_label, fontsize=LABEL_FONT_SIZE)
|
||
|
||
if degrees is not None and degrees.size > 0:
|
||
min_deg_local_plot = np.min(degrees)
|
||
max_deg_local_plot = np.max(degrees)
|
||
|
||
if min_deg_local_plot == max_deg_local_plot:
|
||
plot_bin_edges = np.array([np.floor(min_deg_local_plot - 0.5), np.ceil(max_deg_local_plot + 0.5)])
|
||
else:
|
||
num_plot_bins = int(np.ceil(max_deg_local_plot + 0.5) - np.floor(min_deg_local_plot - 0.5))
|
||
plot_bin_edges = np.linspace(np.floor(min_deg_local_plot - 0.5),
|
||
np.ceil(max_deg_local_plot + 0.5),
|
||
num_plot_bins + 1)
|
||
if i == 1:
|
||
unique_data = np.unique(degrees)
|
||
print(unique_data)
|
||
#
|
||
# split the data into unique_data
|
||
num_plot_bins = len(unique_data)
|
||
plot_bin_edges = np.concatenate([unique_data-0.1, [unique_data[-1] + 0.8375]])
|
||
|
||
ax.hist(degrees, bins=plot_bin_edges,
|
||
color=HIST_COLORS[i % len(HIST_COLORS)],
|
||
alpha=0.85)
|
||
|
||
avg_deg_val = BIG_GRAPH_AVG_DEG[i]
|
||
ax.text(0.95, 0.88, f"Avg Degree: {avg_deg_val}",
|
||
transform=ax.transAxes, fontsize=15,
|
||
verticalalignment='top', horizontalalignment='right',
|
||
bbox=dict(facecolor='white', alpha=0.6, edgecolor='none', pad=0.3))
|
||
else:
|
||
ax.text(0.5, 0.5, 'Data unavailable', horizontalalignment='center',
|
||
verticalalignment='center', transform=ax.transAxes, fontsize=9)
|
||
|
||
ax.set_xlim(0, overall_max_deg)
|
||
ax.set_ylim(0, max_y_raw_counts * 1.12)
|
||
ax.set_yscale('log')
|
||
|
||
for spine_pos in ['top', 'right', 'bottom', 'left']:
|
||
ax.spines[spine_pos].set_edgecolor('black')
|
||
ax.spines[spine_pos].set_linewidth(1.0)
|
||
|
||
# ax.spines['top'].set_visible(False)
|
||
# ax.spines['right'].set_visible(False)
|
||
|
||
ax.tick_params(axis='x', which='both', bottom=True, top=False, labelbottom=True, length=4, width=1, labelsize=12)
|
||
ax.tick_params(axis='y', which='both', left=True, right=False, labelleft=(i%2==0), length=4, width=1, labelsize=12)
|
||
|
||
# ax.yaxis.set_major_formatter(ticker.FuncFormatter(lambda x, pos: millions_formatter(x, pos)))
|
||
|
||
ax.xaxis.set_major_formatter(ticker.ScalarFormatter())
|
||
ax.ticklabel_format(style='plain', axis='x', useOffset=False)
|
||
|
||
axes[0].set_ylabel(r"Number of Nodes", fontsize=12)
|
||
axes[2].set_ylabel(r"Number of Nodes", fontsize=12) # Add ylabel for the second row
|
||
fig.text(0.54, 0.02, "Node Degree", ha='center', va='bottom', fontsize=15)
|
||
|
||
plt.tight_layout(rect=(0.06, 0.05, 0.98, 0.88))
|
||
|
||
plt.savefig(output_image_path, dpi=300, bbox_inches='tight', pad_inches=0.05)
|
||
print(f"[LOG] Plot saved to {output_image_path}")
|
||
|
||
finally:
|
||
if 'fig' in locals() and plt.fignum_exists(fig.number):
|
||
plt.close(fig)
|
||
|
||
|
||
if __name__ == "__main__":
|
||
if plt.rcParams["text.usetex"]:
|
||
print("INFO: LaTeX rendering is enabled via rcParams.")
|
||
else:
|
||
print("INFO: LaTeX rendering is disabled (text.usetex=False).")
|
||
|
||
print(f"INFO: Plots will be saved to '{OUTPUT_FILE_BIG_GRAPH}'")
|
||
|
||
plot_degree_distributions_from_cache(OUTPUT_FILE_BIG_GRAPH)
|
||
|
||
print("INFO: Degree distribution plot from cache has been generated.")
|