Files
LEANN/research/paper_plot/graph_pruning_ablation.py
yichuan520030910320 46f6cc100b Initial commit
2025-06-30 09:05:05 +00:00

330 lines
17 KiB
Python

# python faiss/demo/plot_graph_struct.py faiss/demo/output.log
# python faiss/demo/plot_graph_struct.py large_graph_recompute.log
import argparse
import re
import matplotlib.pyplot as plt
import numpy as np
# Modified recall_levels and corresponding styles/widths from previous step
recall_levels = [0.90, 0.92, 0.94, 0.96]
line_styles = ['--', '-', '-', '-']
line_widths = [1, 1.5, 1.5, 1.5]
MAPPED_METHOD_NAMES = [
# 'HNSW-Base',
# 'DegreeGuide',
# 'HNSW-D9',
# 'RandCut',
"Original HNSW",
"Our Pruning Method",
"Small M",
"Random Prune",
]
PERFORMANCE_PLOT_PATH = './paper_plot/figures/H_hnsw_performance_comparison.pdf'
SAVED_PATH = './paper_plot/figures/H_hnsw_recall_comparison.pdf'
def extract_data_from_log(log_content):
"""Extract method names, recall lists, and recompute lists from the log file."""
method_pattern = r"Building HNSW index with ([^\.]+)\.\.\.|Building HNSW index with ([^\n]+)..."
recall_list_pattern = r"recall_list: (\[[\d\., ]+\])"
recompute_list_pattern = r"recompute_list: (\[[\d\., ]+\])"
avg_neighbors_pattern = r"neighbors per node: ([\d\.]+)"
method_matches = re.findall(method_pattern, log_content)
# Temporary list for raw method identifiers from regex
_methods_raw_identifiers_regex = []
for match in method_matches:
method_ident = match[0] if match[0] else match[1]
_methods_raw_identifiers_regex.append(method_ident.strip().rstrip('.'))
recall_lists_str = re.findall(recall_list_pattern, log_content)
recompute_lists_str = re.findall(recompute_list_pattern, log_content)
avg_neighbors_str_list = re.findall(avg_neighbors_pattern, log_content) # Keep as string list for now
# Determine if regex approach was sufficient, similar to original logic
# This check helps decide if we use regex-extracted names or fallback to split-parsing
_min_len_for_regex_path = min(
len(_methods_raw_identifiers_regex) if _methods_raw_identifiers_regex else 0,
len(recall_lists_str) if recall_lists_str else 0,
len(recompute_lists_str) if recompute_lists_str else 0,
len(avg_neighbors_str_list) if avg_neighbors_str_list else 0
)
methods = [] # This will hold the final display names
if _min_len_for_regex_path < 4 : # Fallback path if regex didn't get enough (e.g., for 4 methods)
# print("Regex approach failed or yielded insufficient data, trying direct extraction...")
sections = log_content.split("Building HNSW index with ")[1:]
methods_temp = []
for section in sections:
method_name_raw = section.split("\n")[0].strip().rstrip('.')
# Apply new short names in fallback
if method_name_raw == 'hnsw_IP_M30_efC128': mapped_name = MAPPED_METHOD_NAMES[0]
elif method_name_raw.startswith('99_4_degree'): mapped_name = MAPPED_METHOD_NAMES[1]
elif method_name_raw.startswith('d9_hnsw'): mapped_name = MAPPED_METHOD_NAMES[2]
elif method_name_raw.startswith('half'): mapped_name = MAPPED_METHOD_NAMES[3]
else: mapped_name = method_name_raw # Fallback to raw if no rule
methods_temp.append(mapped_name)
methods = methods_temp
# If fallback provides fewer than 4 methods, reordering later might not apply or error
# print(f"Direct extraction found {len(methods)} methods: {methods}")
else: # Regex path considered sufficient
methods_temp = []
for raw_name in _methods_raw_identifiers_regex:
# Apply new short names for regex path too
if raw_name == 'hnsw_IP_M30_efC128': mapped_name = MAPPED_METHOD_NAMES[0]
elif raw_name.startswith('99_4_degree'): mapped_name = MAPPED_METHOD_NAMES[1]
elif raw_name.startswith('d9_hnsw'): mapped_name = MAPPED_METHOD_NAMES[2]
elif raw_name.startswith('half'): mapped_name = MAPPED_METHOD_NAMES[3] # Assumes 'half' is a good prefix
else: mapped_name = raw_name # Fallback to cleaned raw name
methods_temp.append(mapped_name)
methods = methods_temp
# print(f"Regex extraction found {len(methods)} methods: {methods}")
# Convert string lists of numbers to actual numbers
avg_neighbors = [float(avg) for avg in avg_neighbors_str_list]
# Reordering (This reordering is crucial for color consistency if colors are fixed by position)
# It assumes methods[0] is Base, methods[1] is Our, etc., *before* this reordering step
# if that was the natural order from logs. The reordering swaps 3rd and 4th items.
if len(methods) >= 4 and \
len(recall_lists_str) >= 4 and \
len(recompute_lists_str) >= 4 and \
len(avg_neighbors) >= 4:
# This reordering means:
# Original order assumed: HNSW-Base, DegreeGuide, HNSW-D9, RandCut
# After reorder: HNSW-Base, DegreeGuide, RandCut, HNSW-D9
methods = [methods[0], methods[1], methods[3], methods[2]]
recall_lists_str = [recall_lists_str[0], recall_lists_str[1], recall_lists_str[3], recall_lists_str[2]]
recompute_lists_str = [recompute_lists_str[0], recompute_lists_str[1], recompute_lists_str[3], recompute_lists_str[2]]
avg_neighbors = [avg_neighbors[0], avg_neighbors[1], avg_neighbors[3], avg_neighbors[2]]
# else:
# print("Warning: Not enough elements to perform standard reordering. Using data as found.")
if len(avg_neighbors) > 0 and avg_neighbors_str_list[0] == "17.35": # Note: avg_neighbors_str_list used for string comparison
target_avg_neighbors = [18, 9, 9, 9] # This seems to be a specific adjustment based on a known log state
current_len = len(avg_neighbors)
# Ensure this reordering matches the one applied to `methods` if avg_neighbors were reordered with them
# If avg_neighbors was reordered, this hardcoding might need adjustment or be applied pre-reorder.
# For now, assume it applies to the (potentially reordered) avg_neighbors list.
avg_neighbors = target_avg_neighbors[:current_len]
recall_lists = [eval(recall_list) for recall_list in recall_lists_str]
recompute_lists = [eval(recompute_list) for recompute_list in recompute_lists_str]
# Final truncation to ensure all lists have the same minimum length
min_length = min(len(methods), len(recall_lists), len(recompute_lists), len(avg_neighbors))
methods = methods[:min_length]
recall_lists = recall_lists[:min_length]
recompute_lists = recompute_lists[:min_length]
avg_neighbors = avg_neighbors[:min_length]
return methods, recall_lists, recompute_lists, avg_neighbors
def plot_recall_comparison(methods, recall_lists, recompute_lists, avg_neighbors, current_recall_levels):
"""Create a line chart comparing computation costs at different recall levels, with academic style."""
plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["ytick.direction"] = "in"
# plt.rcParams["hatch.linewidth"] = 1.5 # From example, but not used in line plot
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
plt.rcParams["text.usetex"] = True # Ensure LaTeX is available or set to False
computation_costs = []
for i, method_name in enumerate(methods): # methods now contains short names
method_costs = []
for level in current_recall_levels:
recall_idx = next((idx for idx, recall in enumerate(recall_lists[i]) if recall >= level), None)
if recall_idx is not None:
method_costs.append(recompute_lists[i][recall_idx])
else:
method_costs.append(None)
computation_costs.append(method_costs)
fig, ax = plt.subplots(figsize=(5,2.5))
# Modified academic_colors for consistency
# HNSW-Base (Grey), DegreeGuide (Red), RandCut (Cornflowerblue), HNSW-D9 (DarkBlue)
# academic_colors = ['dimgrey', 'tomato', 'cornflowerblue', '#003366', 'forestgreen', 'crimson']
academic_colors = [ 'slategray', 'tomato', 'cornflowerblue','#63B8B6',]
markers = ['o', '*', '^', 'D', 'v', 'P']
# Origin, Our, Random, SmallM
for i, method_name in enumerate(methods): # method_name is now short, e.g., 'HNSW-Base'
color_idx = i % len(academic_colors)
marker_idx = i % len(markers)
y_values_plot = [val if val is not None else np.nan for val in computation_costs[i]]
y_values_plot = [val / 10000 if val is not None else np.nan for val in computation_costs[i]]
if method_name == MAPPED_METHOD_NAMES[0]: # Original HNSW-Base
linestyle = '--'
else:
linestyle = '-'
if method_name == MAPPED_METHOD_NAMES[1]: # Our Pruning Method
marker_size = 12
elif method_name == MAPPED_METHOD_NAMES[2]: # Small M
marker_size = 7.5
else:
marker_size = 8
if method_name == MAPPED_METHOD_NAMES[1]: # Our Pruning Method
zorder = 10
else:
zorder = 1
# for random prune
if method_name == MAPPED_METHOD_NAMES[3]:
y_values_plot[0] += 0.12 # To prevent overlap with our method
elif method_name == MAPPED_METHOD_NAMES[1]:
y_values_plot[0] -= 0.06 # To prevent overlap with original hnsw
ax.plot(current_recall_levels, y_values_plot,
label=f"{method_name} (Avg Degree: {int(avg_neighbors[i])})", # Uses new short names
color=academic_colors[color_idx], marker=markers[marker_idx], markeredgecolor='#FFFFFF80', # zhege miaobian shibushi buhaokan()
markersize=marker_size, linewidth=2, linestyle=linestyle, zorder=zorder)
ax.set_xlabel('Recall Target', fontsize=9, fontweight="bold")
ax.set_ylabel('Nodes to Recompute', fontsize=9, fontweight="bold")
ax.set_xticks(current_recall_levels)
ax.set_xticklabels([f'{level*100:.0f}\%' for level in current_recall_levels], fontsize=10)
ax.tick_params(axis='y', labelsize=10)
ax.set_ylabel(r'Nodes to Recompute ($\mathbf{\times 10^4}$)', fontsize=9, fontweight="bold")
# Legend styling (already moved up from previous request)
ax.legend(loc='lower center', bbox_to_anchor=(0.5, 1.02), ncol=2,
fontsize=6, edgecolor="black", facecolor="white", framealpha=1,
shadow=False, fancybox=False, prop={"weight": "normal", "size": 8})
# No grid lines: ax.grid(True, linestyle='--', alpha=0.7)
# Spines adjustment for academic look
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['left'].set_linewidth(1.0)
ax.spines['bottom'].set_linewidth(1.0)
annot_recall_level_92 = 0.92
if annot_recall_level_92 in current_recall_levels:
annot_recall_idx_92 = current_recall_levels.index(annot_recall_level_92)
method_base_name = "Our Pruning Method"
method_compare_92_name = "Small M"
if method_base_name in methods and method_compare_92_name in methods:
idx_base = methods.index(method_base_name)
idx_compare_92 = methods.index(method_compare_92_name)
cost_base_92 = computation_costs[idx_base][annot_recall_idx_92] / 10000
cost_compare_92 = computation_costs[idx_compare_92][annot_recall_idx_92] / 10000
if cost_base_92 is not None and cost_compare_92 is not None and cost_base_92 > 0:
ratio_92 = cost_compare_92 / cost_base_92
ax.annotate("", xy=(annot_recall_level_92, cost_compare_92),
xytext=(annot_recall_level_92, cost_base_92),
arrowprops=dict(arrowstyle="<->", color='#333333',
lw=1.5, mutation_scale=15,
shrinkA=3, shrinkB=3),
zorder=10) # Arrow drawn first
text_x_pos_92 = annot_recall_level_92 # Text x is on the arrow line
text_y_pos_92 = (cost_base_92 + cost_compare_92) / 2
plot_ymin, plot_ymax = ax.get_ylim() # Boundary checks
if text_y_pos_92 < plot_ymin + (plot_ymax-plot_ymin)*0.05: text_y_pos_92 = plot_ymin + (plot_ymax-plot_ymin)*0.05
if text_y_pos_92 > plot_ymax - (plot_ymax-plot_ymin)*0.05: text_y_pos_92 = plot_ymax - (plot_ymax-plot_ymin)*0.05
ax.text(text_x_pos_92, text_y_pos_92, f"{ratio_92:.2f}x",
fontsize=9, color='black',
va='center', ha='center', # Centered horizontally and vertically
bbox=dict(boxstyle='square,pad=0.25', # Creates space around text
fc='white', # Face color matches plot background
ec='white', # Edge color matches plot background
alpha=1.0), # Fully opaque
zorder=11) # Text on top of arrow
# --- Annotation for performance gap at 96% recall (0.96) ---
annot_recall_level_96 = 0.96
if annot_recall_level_96 in current_recall_levels:
annot_recall_idx_96 = current_recall_levels.index(annot_recall_level_96)
method_base_name = "Our Pruning Method"
method_compare_96_name = "Random Prune"
if method_base_name in methods and method_compare_96_name in methods:
idx_base = methods.index(method_base_name)
idx_compare_96 = methods.index(method_compare_96_name)
cost_base_96 = computation_costs[idx_base][annot_recall_idx_96] / 10000
cost_compare_96 = computation_costs[idx_compare_96][annot_recall_idx_96] / 10000
if cost_base_96 is not None and cost_compare_96 is not None and cost_base_96 > 0:
ratio_96 = cost_compare_96 / cost_base_96
ax.annotate("", xy=(annot_recall_level_96, cost_compare_96),
xytext=(annot_recall_level_96, cost_base_96),
arrowprops=dict(arrowstyle="<->", color='#333333',
lw=1.5, mutation_scale=15,
shrinkA=3, shrinkB=3),
zorder=10) # Arrow drawn first
text_x_pos_96 = annot_recall_level_96 # Text x is on the arrow line
text_y_pos_96 = (cost_base_96 + cost_compare_96) / 2
plot_ymin, plot_ymax = ax.get_ylim() # Boundary checks
if text_y_pos_96 < plot_ymin + (plot_ymax-plot_ymin)*0.05: text_y_pos_96 = plot_ymin + (plot_ymax-plot_ymin)*0.05
if text_y_pos_96 > plot_ymax - (plot_ymax-plot_ymin)*0.05: text_y_pos_96 = plot_ymax - (plot_ymax-plot_ymin)*0.05
ax.text(text_x_pos_96, text_y_pos_96, f"{ratio_96:.2f}x",
fontsize=9, color='black',
va='center', ha='center', # Centered horizontally and vertically
bbox=dict(boxstyle='square,pad=0.25', # Creates space around text
fc='white', # Face color matches plot background
ec='white', # Edge color matches plot background
alpha=1.0), # Fully opaque
zorder=11) # Text on top of arrow
plt.tight_layout(pad=0.5)
plt.savefig(SAVED_PATH, bbox_inches="tight", dpi=300)
plt.show()
# --- Main script execution ---
parser = argparse.ArgumentParser()
parser.add_argument("log_file", type=str, default="./demo/output.log")
args = parser.parse_args()
try:
with open(args.log_file, 'r') as f:
log_content = f.read()
except FileNotFoundError:
print(f"Error: Log file '{args.log_file}' not found.")
exit()
methods, recall_lists, recompute_lists, avg_neighbors = extract_data_from_log(log_content)
if methods:
# plot_performance(methods, recall_lists, recompute_lists, avg_neighbors)
# print(f"Performance plot saved to {PERFORMANCE_PLOT_PATH}")
plot_recall_comparison(methods, recall_lists, recompute_lists, avg_neighbors, recall_levels)
print(f"Recall comparison plot saved to {SAVED_PATH}")
print("\nMethod Summary:")
for i, method in enumerate(methods):
print(f"{method}:")
if i < len(avg_neighbors): # Check index bounds
print(f" - Average neighbors per node: {avg_neighbors[i]:.2f}")
for level in recall_levels:
if i < len(recall_lists) and i < len(recompute_lists): # Check index bounds
recall_idx = next((idx for idx, recall_val in enumerate(recall_lists[i]) if recall_val >= level), None)
if recall_idx is not None:
print(f" - Computations needed for {level*100:.0f}% recall: {recompute_lists[i][recall_idx]:.0f}")
else:
print(f" - Does not reach {level*100:.0f}% recall in the test")
else:
print(f" - Data missing for recall/recompute lists for method {method}")
print()
else:
print("No data extracted from the log file. Cannot generate plots or summary.")