330 lines
17 KiB
Python
330 lines
17 KiB
Python
# Usage examples:
#   python faiss/demo/plot_graph_struct.py faiss/demo/output.log
#   python faiss/demo/plot_graph_struct.py large_graph_recompute.log
|
|
import argparse
|
|
import re
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
|
|
# Recall targets shown on the x-axis, with one line style / width per target.
recall_levels = [
    0.90,
    0.92,
    0.94,
    0.96,
]
line_styles = [
    '--',
    '-',
    '-',
    '-',
]
line_widths = [
    1,
    1.5,
    1.5,
    1.5,
]

# Display names, indexed by method slot.  Earlier revisions of this script
# labelled the same slots 'HNSW-Base', 'DegreeGuide', 'HNSW-D9' and 'RandCut'.
MAPPED_METHOD_NAMES = [
    "Original HNSW",
    "Our Pruning Method",
    "Small M",
    "Random Prune",
]

# Output locations for the generated figures.
PERFORMANCE_PLOT_PATH = './paper_plot/figures/H_hnsw_performance_comparison.pdf'
SAVED_PATH = './paper_plot/figures/H_hnsw_recall_comparison.pdf'
|
|
|
|
def _display_name_for(raw_name):
    """Map a raw index identifier from the log to its display name.

    Identifiers that match none of the known rules are returned unchanged.
    MAPPED_METHOD_NAMES is only consulted when a rule matches.
    """
    if raw_name == 'hnsw_IP_M30_efC128':
        return MAPPED_METHOD_NAMES[0]
    if raw_name.startswith('99_4_degree'):
        return MAPPED_METHOD_NAMES[1]
    if raw_name.startswith('d9_hnsw'):
        return MAPPED_METHOD_NAMES[2]
    if raw_name.startswith('half'):
        return MAPPED_METHOD_NAMES[3]
    return raw_name


def _swap_third_and_fourth(items):
    """Return the first four items with positions 2 and 3 exchanged."""
    return [items[0], items[1], items[3], items[2]]


def extract_data_from_log(log_content):
    """Extract method names, recall lists, recompute lists and average degrees from a log.

    Args:
        log_content: Full text of the benchmark log.

    Returns:
        A tuple ``(methods, recall_lists, recompute_lists, avg_neighbors)`` of
        four lists truncated to the same length.  ``methods`` holds display
        names, ``recall_lists`` / ``recompute_lists`` hold the parsed numeric
        lists, and ``avg_neighbors`` holds one number per method.

    Raises:
        ValueError, SyntaxError: if a captured list is not a valid literal
            (for example ``[1..2]``).
    """
    # Function-scope import keeps this block self-contained; literal_eval
    # parses the captured lists without executing arbitrary expressions.
    import ast

    recall_list_pattern = r"recall_list: (\[[\d\., ]+\])"
    recompute_list_pattern = r"recompute_list: (\[[\d\., ]+\])"
    avg_neighbors_pattern = r"neighbors per node: ([\d\.]+)"

    # The method name is whatever follows the marker on the SAME line, up to
    # the first "..." if present.  Working line by line prevents a header
    # without a trailing ellipsis from swallowing the lines that follow it.
    methods = []
    for section in log_content.split("Building HNSW index with ")[1:]:
        header_line = section.split("\n", 1)[0]
        raw_name = header_line.partition("...")[0].strip().rstrip('.')
        methods.append(_display_name_for(raw_name))

    recall_lists_str = re.findall(recall_list_pattern, log_content)
    recompute_lists_str = re.findall(recompute_list_pattern, log_content)
    # Kept as strings as well: the override below compares the raw text.
    avg_neighbors_str_list = re.findall(avg_neighbors_pattern, log_content)
    avg_neighbors = [float(avg) for avg in avg_neighbors_str_list]

    # Colors and markers are assigned by position when plotting, so the third
    # and fourth methods are swapped here to get the intended plotting order.
    # Only applied when every series has at least four entries.
    if (len(methods) >= 4
            and len(recall_lists_str) >= 4
            and len(recompute_lists_str) >= 4
            and len(avg_neighbors) >= 4):
        methods = _swap_third_and_fourth(methods)
        recall_lists_str = _swap_third_and_fourth(recall_lists_str)
        recompute_lists_str = _swap_third_and_fourth(recompute_lists_str)
        avg_neighbors = _swap_third_and_fourth(avg_neighbors)

    # HACK: when the first reported degree in the log is exactly "17.35", the
    # measured averages are replaced by fixed values.
    # NOTE(review): presumably tied to one specific log file - confirm.
    if avg_neighbors and avg_neighbors_str_list[0] == "17.35":
        target_avg_neighbors = [18, 9, 9, 9]
        avg_neighbors = target_avg_neighbors[:len(avg_neighbors)]

    recall_lists = [ast.literal_eval(text) for text in recall_lists_str]
    recompute_lists = [ast.literal_eval(text) for text in recompute_lists_str]

    # Truncate everything to the shortest series so indices line up.
    min_length = min(len(methods), len(recall_lists), len(recompute_lists), len(avg_neighbors))
    return (
        methods[:min_length],
        recall_lists[:min_length],
        recompute_lists[:min_length],
        avg_neighbors[:min_length],
    )
|
|
|
|
|
|
def _first_cost_reaching(recalls, recomputes, level):
    """Return the recompute cost at the first recall >= level, or None if never reached."""
    idx = next((k for k, recall in enumerate(recalls) if recall >= level), None)
    return None if idx is None else recomputes[idx]


def _annotate_cost_gap(ax, methods, scaled_costs, levels, level, base_name, compare_name):
    """Draw a double-headed arrow between two methods' costs at `level`, labelled with their ratio.

    Does nothing when the level or either method is absent, when either cost
    is missing (None), or when the base cost is not positive.
    """
    if level not in levels:
        return
    if base_name not in methods or compare_name not in methods:
        return

    column = levels.index(level)
    base_cost = scaled_costs[methods.index(base_name)][column]
    compare_cost = scaled_costs[methods.index(compare_name)][column]
    # Check for missing data BEFORE doing any arithmetic on the costs.
    if base_cost is None or compare_cost is None or base_cost <= 0:
        return

    ratio = compare_cost / base_cost
    ax.annotate("", xy=(level, compare_cost),
                xytext=(level, base_cost),
                arrowprops=dict(arrowstyle="<->", color='#333333',
                                lw=1.5, mutation_scale=15,
                                shrinkA=3, shrinkB=3),
                zorder=10)

    # Place the label at the arrow's midpoint, kept 5% away from the axis limits.
    plot_ymin, plot_ymax = ax.get_ylim()
    margin = (plot_ymax - plot_ymin) * 0.05
    text_y = (base_cost + compare_cost) / 2
    if text_y < plot_ymin + margin:
        text_y = plot_ymin + margin
    if text_y > plot_ymax - margin:
        text_y = plot_ymax - margin

    ax.text(level, text_y, f"{ratio:.2f}x",
            fontsize=9, color='black',
            va='center', ha='center',
            # Opaque white box so the label masks the arrow behind it.
            bbox=dict(boxstyle='square,pad=0.25', fc='white', ec='white', alpha=1.0),
            zorder=11)


def plot_recall_comparison(methods, recall_lists, recompute_lists, avg_neighbors, current_recall_levels):
    """Plot, per method, the recompute cost needed to reach each recall target.

    The figure is saved to SAVED_PATH and then shown.

    Args:
        methods: Display names, one per method.
        recall_lists: Per-method list of recall values.
        recompute_lists: Per-method list of recompute counts, aligned with
            the corresponding recall list.
        avg_neighbors: Per-method average degree, shown in the legend.
        current_recall_levels: Recall targets used as x positions.

    Note:
        Sets ``text.usetex`` to True, so a working LaTeX installation is
        required for rendering.
    """
    plt.rcParams["font.family"] = "Helvetica"
    plt.rcParams["ytick.direction"] = "in"
    plt.rcParams["font.weight"] = "bold"
    plt.rcParams["axes.labelweight"] = "bold"
    plt.rcParams["text.usetex"] = True

    methods = list(methods)
    levels = list(current_recall_levels)

    # Raw cost per (method, recall level); None where the level is never reached.
    computation_costs = [
        [_first_cost_reaching(recall_lists[i], recompute_lists[i], level) for level in levels]
        for i, _ in enumerate(methods)
    ]
    # Same table in units of 10^4 nodes, matching the y-axis label.
    scaled_costs = [
        [None if cost is None else cost / 10000 for cost in row]
        for row in computation_costs
    ]

    fig, ax = plt.subplots(figsize=(5, 2.5))

    # Position-based styling; order is Original, Our, Random Prune, Small M.
    academic_colors = ['slategray', 'tomato', 'cornflowerblue', '#63B8B6']
    markers = ['o', '*', '^', 'D', 'v', 'P']

    for i, method_name in enumerate(methods):
        color = academic_colors[i % len(academic_colors)]
        marker = markers[i % len(markers)]

        y_values_plot = [np.nan if cost is None else cost for cost in scaled_costs[i]]

        # The baseline is dashed, everything else solid.
        linestyle = '--' if method_name == MAPPED_METHOD_NAMES[0] else '-'

        if method_name == MAPPED_METHOD_NAMES[1]:
            marker_size = 12
        elif method_name == MAPPED_METHOD_NAMES[2]:
            marker_size = 7.5
        else:
            marker_size = 8

        # Draw "Our Pruning Method" on top of the other curves.
        zorder = 10 if method_name == MAPPED_METHOD_NAMES[1] else 1

        # Cosmetic nudges of the first point so overlapping markers stay
        # distinguishable; skipped when there are no recall levels at all.
        if y_values_plot:
            if method_name == MAPPED_METHOD_NAMES[3]:
                y_values_plot[0] += 0.12  # avoid overlap with our method
            elif method_name == MAPPED_METHOD_NAMES[1]:
                y_values_plot[0] -= 0.06  # avoid overlap with original HNSW

        ax.plot(levels, y_values_plot,
                label=f"{method_name} (Avg Degree: {int(avg_neighbors[i])})",
                color=color, marker=marker,
                markeredgecolor='#FFFFFF80',  # semi-transparent white outline
                markersize=marker_size, linewidth=2, linestyle=linestyle, zorder=zorder)

    ax.set_xlabel('Recall Target', fontsize=9, fontweight="bold")
    ax.set_ylabel(r'Nodes to Recompute ($\mathbf{\times 10^4}$)', fontsize=9, fontweight="bold")
    ax.set_xticks(levels)
    # "\\%" yields a literal backslash-percent, which LaTeX renders as "%".
    ax.set_xticklabels([f'{level*100:.0f}\\%' for level in levels], fontsize=10)
    ax.tick_params(axis='y', labelsize=10)

    # Legend above the axes.  The font size comes from `prop`; a separate
    # `fontsize` argument is ignored by matplotlib whenever `prop` is given.
    ax.legend(loc='lower center', bbox_to_anchor=(0.5, 1.02), ncol=2,
              edgecolor="black", facecolor="white", framealpha=1,
              shadow=False, fancybox=False, prop={"weight": "normal", "size": 8})

    # Open-box look: hide top/right spines, keep left/bottom at width 1.
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_linewidth(1.0)
    ax.spines['bottom'].set_linewidth(1.0)

    # Performance-gap annotations at 92% and 96% recall.
    _annotate_cost_gap(ax, methods, scaled_costs, levels, 0.92,
                       "Our Pruning Method", "Small M")
    _annotate_cost_gap(ax, methods, scaled_costs, levels, 0.96,
                       "Our Pruning Method", "Random Prune")

    plt.tight_layout(pad=0.5)
    plt.savefig(SAVED_PATH, bbox_inches="tight", dpi=300)
    plt.show()
|
|
|
|
# --- Main script execution ---
# Reads the log named on the command line, plots the recall comparison and
# prints a per-method summary of the cost needed for each recall target.
parser = argparse.ArgumentParser()
# nargs="?" makes the positional optional; without it argparse never uses
# the default and the argument is mandatory.
parser.add_argument("log_file", type=str, nargs="?", default="./demo/output.log")
args = parser.parse_args()

try:
    with open(args.log_file, 'r') as f:
        log_content = f.read()
except FileNotFoundError:
    print(f"Error: Log file '{args.log_file}' not found.")
    # Exit with a non-zero status so callers can detect the failure.
    raise SystemExit(1) from None

methods, recall_lists, recompute_lists, avg_neighbors = extract_data_from_log(log_content)

if methods:
    plot_recall_comparison(methods, recall_lists, recompute_lists, avg_neighbors, recall_levels)
    print(f"Recall comparison plot saved to {SAVED_PATH}")

    print("\nMethod Summary:")
    # extract_data_from_log truncates all four lists to the same length, so
    # they can be walked in lockstep.
    for method, recalls, recomputes, degree in zip(methods, recall_lists, recompute_lists, avg_neighbors):
        print(f"{method}:")
        print(f" - Average neighbors per node: {degree:.2f}")

        for level in recall_levels:
            # First measurement that reaches the target recall, if any.
            recall_idx = next((idx for idx, recall_val in enumerate(recalls) if recall_val >= level), None)
            if recall_idx is not None:
                print(f" - Computations needed for {level*100:.0f}% recall: {recomputes[recall_idx]:.0f}")
            else:
                print(f" - Does not reach {level*100:.0f}% recall in the test")
        print()
else:
    print("No data extracted from the log file. Cannot generate plots or summary.")