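"""Parse an experiment log and plot accuracy-vs-recall curves comparing runs with and
without skip_search_reorder (module summary added for clarity; inferred from the code below)."""
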
import argparse
import os
import re

import matplotlib.pyplot as plt


def parse_log(log_file_path):
    """
    Parses the log file to extract relevant data for accuracy-recall curve comparison.

    Args:
        log_file_path (str): Path to the log file to parse.

    Returns:
        dict: A dictionary containing the extracted results.
    """
    data = {
        "recalls_with_skip": [],
        "f1_scores_with_skip": [],
        "exact_match_scores_with_skip": [],
        "recalls_without_skip": [],
        "f1_scores_without_skip": [],
        "exact_match_scores_without_skip": [],
        "nprobe_values": [],
    }

    with open(log_file_path, "r") as file:
        logs = file.readlines()

    # Variables to track the state during parsing
    is_skip_reorder_true = False
    is_skip_reorder_false = False
    current_nprobe = None

    for line in logs:
        # Debug: print the current line being processed
        # print(f"Processing line: {line.strip()}")

        # Check for the skip_reorder flag
        if "skip_search_reorder=True" in line:
            is_skip_reorder_true = True
            is_skip_reorder_false = False
        elif "skip_search_reorder=False" in line:
            is_skip_reorder_true = False
            is_skip_reorder_false = True

        # Extract nprobe values (assuming they are given before the experiment)
        nprobe_match = re.search(r"nprobe=(\d+)", line)
        if nprobe_match:
            current_nprobe = int(nprobe_match.group(1))
            if current_nprobe not in data["nprobe_values"]:
                data["nprobe_values"].append(current_nprobe)
                print(f"Found nprobe value: {current_nprobe}")

        # Extract the average recall rate
        avg_recall_match = re.search(
            r"Avg recall rate for (flat|diskann): ([0-9\.e\-]+)", line
        )
        if avg_recall_match:
            recall_value = float(avg_recall_match.group(2))
            print(
                f"Found avg recall rate: {recall_value} for {avg_recall_match.group(1)} in line {line!r}"
            )

            if "flat" in avg_recall_match.group(1):
                # data["recalls_without_skip"].append(recall_value)
                pass
            elif "diskann" in avg_recall_match.group(1):
                if is_skip_reorder_true:
                    data["recalls_with_skip"].append(recall_value)
                elif is_skip_reorder_false:
                    data["recalls_without_skip"].append(recall_value)

        # Extract exact_match and f1 scores from the evaluation results
        eval_match = re.search(
            r"\{'exact_match': ([0-9\.]+), 'exact_match_stderr': [0-9\.]+, 'f1': ([0-9\.]+), 'f1_stderr': [0-9\.]+",
            line,
        )
        if eval_match:
            exact_match = float(eval_match.group(1))
            f1 = float(eval_match.group(2))

            print(f"Found evaluation results -> Exact Match: {exact_match}, F1: {f1}")

            # Add to the appropriate lists based on the skip_reorder flag
            if is_skip_reorder_true:
                data["exact_match_scores_with_skip"].append(exact_match)
                data["f1_scores_with_skip"].append(f1)
            elif is_skip_reorder_false:
                data["exact_match_scores_without_skip"].append(exact_match)
                data["f1_scores_without_skip"].append(f1)

    return data


def plot_skip_reorder_comparison(data, output_dir):
    """
    Plots the accuracy-recall curves with and without the skip_reorder parameter.

    Args:
        data: The parsed data including recalls, f1 scores, and exact match scores.
        output_dir: Path where the plot will be saved.
    """
    recalls_with_skip = data["recalls_with_skip"]
    f1_scores_with_skip = data["f1_scores_with_skip"]
    exact_match_scores_with_skip = data["exact_match_scores_with_skip"]
    recalls_without_skip = data["recalls_without_skip"]
    f1_scores_without_skip = data["f1_scores_without_skip"]
    exact_match_scores_without_skip = data["exact_match_scores_without_skip"]
    nprobe_values = data["nprobe_values"]

    plt.figure(figsize=(10, 6))

    # Check that the data lists are not empty and have the same length before plotting
    if (
        recalls_with_skip
        and len(recalls_with_skip) == len(f1_scores_with_skip)
        and len(recalls_with_skip) == len(exact_match_scores_with_skip)
    ):
        plt.plot(
            recalls_with_skip,
            f1_scores_with_skip,
            "bo-",
            label="F1 Score (with skip_reorder)",
            markersize=8,
            linewidth=2,
        )
        plt.plot(
            recalls_with_skip,
            exact_match_scores_with_skip,
            "rs-",
            label="Exact Match (with skip_reorder)",
            markersize=8,
            linewidth=2,
        )

    if (
        recalls_without_skip
        and len(recalls_without_skip) == len(f1_scores_without_skip)
        and len(recalls_without_skip) == len(exact_match_scores_without_skip)
    ):
        plt.plot(
            recalls_without_skip,
            f1_scores_without_skip,
            "go-",
            label="F1 Score (without skip_reorder)",
            markersize=8,
            linewidth=2,
        )
        plt.plot(
            recalls_without_skip,
            exact_match_scores_without_skip,
            "ms-",
            label="Exact Match (without skip_reorder)",
            markersize=8,
            linewidth=2,
        )

    plt.xlabel("Recall")
    plt.ylabel("Score")
    plt.title("Recall vs Accuracy Comparison")
    plt.legend()
    plt.grid(True)
    plt.xlim(0.0, 1.0)

    # Save the plot only if data is present
    if len(nprobe_values) > 0:
        os.makedirs(output_dir, exist_ok=True)
        plot_path = os.path.join(output_dir, "recall_vs_acc_comparison.png")
        plt.savefig(plot_path, dpi=300, bbox_inches="tight")
        print(f"Plot saved to {plot_path}")
    else:
        print("No valid data to plot.")

    plt.close()


parser = argparse.ArgumentParser(description="Parse log file and plot results")
parser.add_argument("log_file_path", type=str, help="Path to the log file")
parser.add_argument(
    "--output_dir",
    type=str,
    help="Path to the output directory",
    default="skip_reorder_comparison",
)
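
# Example invocation (the script filename here is assumed for illustration):
#   python plot_skip_reorder_comparison.py run.log --output_dir skip_reorder_comparison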
args = parser.parse_args()

# Parse the log
parsed_data = parse_log(args.log_file_path)

print(parsed_data)

# Plot the data
plot_skip_reorder_comparison(parsed_data, args.output_dir)