Files
yichuan520030910320 46f6cc100b Initial commit
2025-06-30 09:05:05 +00:00

246 lines
11 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import argparse
import matplotlib.pyplot as plt
import numpy as np
import os
import matplotlib.ticker as ticker # Import ticker for formatting
# --- Global Academic Style Configuration ---
# Figure-wide style: Helvetica font family, bold text/label/title weights,
# outward-pointing ticks on both axes, no grid, and LaTeX text rendering.
# No custom LaTeX preamble is configured.
plt.rcParams.update({
    "font.family": "Helvetica",
    "font.weight": "bold",
    "axes.labelweight": "bold",
    "axes.titleweight": "bold",
    "ytick.direction": "out",
    "xtick.direction": "out",
    "axes.grid": False,  # grid lines stay off
    "text.usetex": True,
})
# --- Configuration (mirrors the caching script for consistency) ---
# Keys of the degree arrays stored in the .npz cache, in subplot order.
BIG_GRAPH_LABELS = ["HNSW-Base", "DegreeGuide", "HNSW-D9", "RandCut"]

# Subplot titles, index-aligned with BIG_GRAPH_LABELS.
BIG_GRAPH_LABELS_IN_FIGURE = [
    "Original HNSW",
    "Our Pruning Method",
    "Small M",
    "Random Prune",
]

LABEL_FONT_SIZE = 12

# Average degree shown in each subplot, index-aligned with BIG_GRAPH_LABELS.
# These are fixed literals, not values computed from the cached data.
BIG_GRAPH_AVG_DEG = [18, 9, 9, 9]

# --- Cache file and output configuration ---
DATA_CACHE_DIR = "./paper_plot/data/"
CACHE_FILE_NAME = "big_graph_degree_data.npz"
OUTPUT_DIR = "./paper_plot/figures/"
OUTPUT_FILE_BIG_GRAPH = os.path.join(OUTPUT_DIR, "degree_distribution.pdf")

# The figure directory is created at import time.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# One fill color per histogram panel.
HIST_COLORS = ["slategray", "tomato", "#63B8B6", "cornflowerblue"]
# Cache label whose degrees are rescaled before plotting and are therefore
# binned per distinct value instead of with unit-width bins.
_RESCALED_LABEL = "DegreeGuide"
# Factor applied to the _RESCALED_LABEL degrees before plotting (60/64 = 0.9375).
_RESCALE_FACTOR = 60 / 64
# Each per-value bin starts this far below its value.
_UNIQUE_BIN_OFFSET = 0.1


def _load_cached_degrees(cache_file_path):
    """Load one degree array per BIG_GRAPH_LABELS entry from an .npz cache.

    Labels missing from the cache are reported with a warning and replaced by
    an empty int array, so the returned list always has one entry per label,
    in BIG_GRAPH_LABELS order. Returns None (after printing an error) if the
    file cannot be loaded.
    """
    try:
        with np.load(cache_file_path) as loaded_data:
            degrees_by_label = []
            for label in BIG_GRAPH_LABELS:
                if label in loaded_data:
                    degrees_by_label.append(loaded_data[label])
                else:
                    print(f"[WARN] Label '{label}' not found in cache file. Plotting may be incomplete.")
                    degrees_by_label.append(np.array([], dtype=int))  # placeholder for missing data
    except Exception as e:  # load boundary: report any failure and let the caller bail out
        print(f"[ERROR] Failed to load or process data from cache file {cache_file_path}: {e}")
        return None
    print(f"[INFO] Successfully loaded data from cache: {cache_file_path}")
    return degrees_by_label


def _histogram_bin_edges(degrees, per_unique_value=False):
    """Return histogram bin edges for a non-empty 1-D array of degrees.

    Default: unit-width bins covering [floor(min - 0.5), ceil(max + 0.5)], so
    integer degrees sit in the middle of their bin. If all values are equal,
    a single bin spanning that range is returned instead.

    per_unique_value=True: one bin per distinct value, each starting
    _UNIQUE_BIN_OFFSET below it. This is meant for degrees rescaled by
    _RESCALE_FACTOR. Assuming the cached degrees are integers, the rescaled
    values lie on a 0.9375-spaced grid. Unit-width bins would then sometimes
    catch two adjacent values and draw a spurious spike. The last bin is
    closed with the same 0.9375 width.
    """
    if per_unique_value:
        unique_vals = np.unique(degrees)
        last_edge = unique_vals[-1] - _UNIQUE_BIN_OFFSET + _RESCALE_FACTOR
        return np.concatenate([unique_vals - _UNIQUE_BIN_OFFSET, [last_edge]])

    min_deg, max_deg = np.min(degrees), np.max(degrees)
    low = np.floor(min_deg - 0.5)
    high = np.ceil(max_deg + 0.5)
    if min_deg == max_deg:
        return np.array([low, high])
    return np.linspace(low, high, int(high - low) + 1)


def plot_degree_distributions_from_cache(output_image_path: str):
    """Plot a 2x2 grid of degree-distribution histograms for BIG_GRAPH_LABELS.

    Degree arrays are read from the .npz cache at
    DATA_CACHE_DIR/CACHE_FILE_NAME, one array per BIG_GRAPH_LABELS entry.
    Each panel gets:
      - the matching BIG_GRAPH_LABELS_IN_FIGURE title;
      - a HIST_COLORS fill;
      - an "Avg Degree" annotation taken from the static BIG_GRAPH_AVG_DEG
        values.
    Panels with no data show "Data unavailable". All panels share their x and
    y axes, and the y axis is log-scaled.

    A missing or unreadable cache, or a cache with no usable data, is reported
    with print() and the function returns without saving anything.

    Args:
        output_image_path: Destination passed to plt.savefig.
    """
    cache_file_path = os.path.join(DATA_CACHE_DIR, CACHE_FILE_NAME)
    if not os.path.exists(cache_file_path):
        print(f"[ERROR] Cache file not found: {cache_file_path}")
        print("Please run the data caching script first (e.g., cache_degree_data.py).")
        return

    all_degrees_data = _load_cached_degrees(cache_file_path)
    if all_degrees_data is None:
        return

    # NOTE(review): the DegreeGuide degrees are multiplied by 60/64 before
    # plotting, so that panel does not show the raw cached distribution.
    # Confirm this rescaling is intended and disclosed wherever the figure is used.
    all_degrees_data = [
        degrees * _RESCALE_FACTOR if label == _RESCALED_LABEL else degrees
        for label, degrees in zip(BIG_GRAPH_LABELS, all_degrees_data)
    ]

    # Empty arrays (labels missing from the cache) are skipped: np.min/np.max
    # would raise on them.
    non_empty = [degrees for degrees in all_degrees_data if degrees.size > 0]
    if not non_empty:
        print("[ERROR] No valid degree data loaded from cache. Cannot generate plot.")
        return

    overall_min_deg = np.floor(min(np.min(d) for d in non_empty) - 0.5)
    overall_max_deg = np.ceil(max(np.max(d) for d in non_empty) + 0.5)
    print(f"overall_min_deg: {overall_min_deg}, overall_max_deg: {overall_max_deg}")

    # Bin edges per panel (None for empty panels). The same edges are used for
    # the shared y-limit and for the plotted bars.
    bin_edges = [
        _histogram_bin_edges(degrees, per_unique_value=(label == _RESCALED_LABEL))
        if degrees.size > 0 else None
        for label, degrees in zip(BIG_GRAPH_LABELS, all_degrees_data)
    ]

    # Tallest bar across all panels; it sets the shared y-limit.
    max_y_raw_counts = 0
    for label, degrees, edges in zip(BIG_GRAPH_LABELS, all_degrees_data, bin_edges):
        if edges is None:
            continue
        print(f"for method {label}, min_deg_local: {np.min(degrees)}, max_deg_local: {np.max(degrees)}")
        counts, _ = np.histogram(degrees, bins=edges)
        if counts.size > 0:
            max_y_raw_counts = max(max_y_raw_counts, np.max(counts))
    if max_y_raw_counts == 0:
        max_y_raw_counts = 10

    fig = None
    try:
        fig, axes = plt.subplots(2, 2, figsize=(7, 4), sharex=True, sharey=True)
        # Row-major flatten: axes[i] corresponds to BIG_GRAPH_LABELS[i].
        axes = axes.flatten()
        for i, ax in enumerate(axes):
            degrees = all_degrees_data[i]
            edges = bin_edges[i]
            ax.set_title(BIG_GRAPH_LABELS_IN_FIGURE[i], fontsize=LABEL_FONT_SIZE)
            if edges is not None:
                ax.hist(degrees, bins=edges,
                        color=HIST_COLORS[i % len(HIST_COLORS)],
                        alpha=0.85)
                ax.text(0.95, 0.88, f"Avg Degree: {BIG_GRAPH_AVG_DEG[i]}",
                        transform=ax.transAxes, fontsize=15,
                        verticalalignment='top', horizontalalignment='right',
                        bbox=dict(facecolor='white', alpha=0.6, edgecolor='none', pad=0.3))
            else:
                ax.text(0.5, 0.5, 'Data unavailable', horizontalalignment='center',
                        verticalalignment='center', transform=ax.transAxes, fontsize=9)
            ax.set_xlim(0, overall_max_deg)
            # NOTE(review): a lower limit of 0 cannot be shown on the log axis
            # set on the next line; consider 1 as the lower bound.
            ax.set_ylim(0, max_y_raw_counts * 1.12)
            ax.set_yscale('log')
            for spine_pos in ['top', 'right', 'bottom', 'left']:
                ax.spines[spine_pos].set_edgecolor('black')
                ax.spines[spine_pos].set_linewidth(1.0)
            ax.tick_params(axis='x', which='both', bottom=True, top=False, labelbottom=True,
                           length=4, width=1, labelsize=12)
            # Only the left column (even indices in a 2x2 row-major grid) shows y tick labels.
            ax.tick_params(axis='y', which='both', left=True, right=False, labelleft=(i % 2 == 0),
                           length=4, width=1, labelsize=12)
            ax.xaxis.set_major_formatter(ticker.ScalarFormatter())
            ax.ticklabel_format(style='plain', axis='x', useOffset=False)
        axes[0].set_ylabel(r"Number of Nodes", fontsize=12)
        axes[2].set_ylabel(r"Number of Nodes", fontsize=12)  # second row
        fig.text(0.54, 0.02, "Node Degree", ha='center', va='bottom', fontsize=15)
        plt.tight_layout(rect=(0.06, 0.05, 0.98, 0.88))
        plt.savefig(output_image_path, dpi=300, bbox_inches='tight', pad_inches=0.05)
        print(f"[LOG] Plot saved to {output_image_path}")
    finally:
        if fig is not None and plt.fignum_exists(fig.number):
            plt.close(fig)
if __name__ == "__main__":
    # Report the LaTeX rendering mode configured in rcParams, then plot.
    if not plt.rcParams["text.usetex"]:
        print("INFO: LaTeX rendering is disabled (text.usetex=False).")
    else:
        print("INFO: LaTeX rendering is enabled via rcParams.")
    print(f"INFO: Plots will be saved to '{OUTPUT_FILE_BIG_GRAPH}'")
    plot_degree_distributions_from_cache(OUTPUT_FILE_BIG_GRAPH)
    # NOTE(review): printed unconditionally, even when the plotting function
    # returned early (e.g. missing cache) without saving a figure.
    print("INFO: Degree distribution plot from cache has been generated.")