600 lines
22 KiB
Python
Executable File
600 lines
22 KiB
Python
Executable File
import os
|
|
import math
|
|
import pandas as pd
|
|
import numpy as np
|
|
import re
|
|
import matplotlib.pyplot as plt
|
|
import matplotlib.ticker as ticker
|
|
import json
|
|
import pickle
|
|
import pdb
|
|
|
|
|
|
"""
|
|
Automatic result extraction for BM25.
|
|
"""
|
|
|
|
|
|
def extract_data_to_table(directory_path):
|
|
# Regular expression pattern to match the data format in file content
|
|
content_pattern = (
|
|
r"# tokens: (\d+(\.\d+)?)\tLM PPL: (\d+(\.\d+)?)\tPPL: (\d+(\.\d+)?)"
|
|
)
|
|
# Regular expression pattern to extract info from file names
|
|
file_name_pattern_M = r"(.+)-(\d+)M-seed_(\d+).txt"
|
|
file_name_pattern = r"(.+)-(\d+)-seed_(\d+).txt"
|
|
|
|
# Data storage
|
|
data = []
|
|
|
|
# Iterating through each file in the directory
|
|
for file_name in os.listdir(directory_path):
|
|
# Checking if the file name matches the pattern
|
|
file_match_M = re.match(file_name_pattern_M, file_name)
|
|
file_match = re.match(file_name_pattern, file_name)
|
|
if file_match_M:
|
|
domain, num_samples, seed = file_match_M.groups()
|
|
|
|
# Reading the file and extracting data
|
|
file_path = os.path.join(directory_path, file_name)
|
|
with open(file_path, "r") as file:
|
|
for line in file:
|
|
# Searching for the pattern in each line
|
|
content_match = re.search(content_pattern, line)
|
|
if content_match:
|
|
# Extracting values
|
|
tokens, lm_ppl, ppl = (
|
|
content_match.groups()[0],
|
|
content_match.groups()[2],
|
|
content_match.groups()[4],
|
|
)
|
|
# Adding the extracted data and extra info to the list
|
|
data.append(
|
|
{
|
|
"Domain": domain,
|
|
"Samples": int(num_samples) * 1e6,
|
|
"Seed": int(seed),
|
|
"#eval_tokens": float(tokens),
|
|
"LM_PPL": float(lm_ppl),
|
|
"PPL": float(ppl),
|
|
}
|
|
)
|
|
elif file_match:
|
|
domain, num_samples, seed = file_match.groups()
|
|
|
|
# Reading the file and extracting data
|
|
file_path = os.path.join(directory_path, file_name)
|
|
with open(file_path, "r") as file:
|
|
for line in file:
|
|
# Searching for the pattern in each line
|
|
content_match = re.search(content_pattern, line)
|
|
if content_match:
|
|
# Extracting values
|
|
tokens, lm_ppl, ppl = (
|
|
content_match.groups()[0],
|
|
content_match.groups()[2],
|
|
content_match.groups()[4],
|
|
)
|
|
# Adding the extracted data and extra info to the list
|
|
data.append(
|
|
{
|
|
"Domain": domain,
|
|
"Samples": int(num_samples),
|
|
"Seed": int(seed),
|
|
"#eval_tokens": float(tokens),
|
|
"LM_PPL": float(lm_ppl),
|
|
"PPL": float(ppl),
|
|
}
|
|
)
|
|
|
|
df = pd.DataFrame(data)
|
|
grouped_df = df.groupby(["Domain", "Samples", "#eval_tokens"]).mean()
|
|
|
|
return df, grouped_df
|
|
|
|
|
|
"""
|
|
Automatic resutls extraction for dense retrieval. (new)
|
|
"""
|
|
|
|
|
|
def extract_dense_scaling_results(log_files, domain=None, plot=None):
|
|
# Regular expression pattern to match the key-value pairs in the input string
|
|
pattern = r"(\w[\w #]+) = ([\w.]+)"
|
|
|
|
data_list = []
|
|
for file in log_files:
|
|
with open(file, "r") as file:
|
|
for line in file:
|
|
# Use re.findall to extract all matches of the pattern
|
|
matches = re.findall(pattern, line)
|
|
|
|
if matches:
|
|
data_dict = {
|
|
key.replace(" ", "_").lower(): (
|
|
None
|
|
if value == "None"
|
|
else float(value)
|
|
if value.replace(".", "", 1).isdigit()
|
|
else value
|
|
)
|
|
for key, value in matches
|
|
}
|
|
data_list.append(data_dict)
|
|
|
|
df = pd.DataFrame(data_list)
|
|
if "total_shards" in df.columns:
|
|
df["subsample_ratio"] = df["sampled_shards"] / df["total_shards"]
|
|
else:
|
|
df["subsample_ratio"] = 1 / df["total_shards"]
|
|
df = df.sort_values(by="subsample_ratio")
|
|
print(df.head)
|
|
|
|
if plot:
|
|
# Setting the plot size for better visibility
|
|
plt.figure(figsize=(10, 6))
|
|
|
|
# Plotting
|
|
for concate_k in df["concate_k"].unique():
|
|
subset = df[df["concate_k"] == concate_k]
|
|
if concate_k == 0:
|
|
perplexity_when_concate_k_0 = subset["perplexity"].mean()
|
|
plt.axhline(
|
|
y=perplexity_when_concate_k_0,
|
|
color="r",
|
|
linestyle="-",
|
|
label="Closed-book",
|
|
)
|
|
else:
|
|
plt.plot(
|
|
subset["subsample_ratio"],
|
|
subset["perplexity"],
|
|
label=f"Concate_k = {concate_k}",
|
|
)
|
|
|
|
plt.title(f"Perplexity Change with Total Shards -- {domain}")
|
|
plt.xlabel("Subsample Ratio")
|
|
plt.ylabel("Perplexity")
|
|
plt.legend()
|
|
plt.grid(True)
|
|
plt.savefig(plot)
|
|
return df
|
|
|
|
|
|
def plot_mmlu():
|
|
# C4 results
|
|
labels = [
|
|
"LM-only",
|
|
"top-1 w/ 1/32 C4 datastore",
|
|
"top-1 w/ 2/32 C4 datastore",
|
|
"top-1 w/ 3/32 C4 datastore",
|
|
"top-1 w/ 4/32 C4 datastore",
|
|
"top-1 w/ 5/32 C4 datastore",
|
|
"top-1 w/ 6/32 C4 datastore",
|
|
]
|
|
x = [0, 1, 2, 3, 4, 5, 6]
|
|
few_shot_0_concat_1 = [30.69, 32.81, 32.05, 32.55, 32.57, 33.03, 32.88]
|
|
few_shot_1_concat_1 = [39.67, 41.03, 41.74, 42.1, 42.62, 41.55, 42.09]
|
|
few_shot_5_concat_1 = [42.47, 43.75, 44.37, 44.1, 44.84, 43.95, 44.49]
|
|
|
|
# Plotting the data
|
|
plt.figure(figsize=(14, 8))
|
|
|
|
# Plot for few_shot_0_concat_1
|
|
plt.plot(
|
|
x,
|
|
few_shot_0_concat_1,
|
|
marker="o",
|
|
linestyle="-",
|
|
color="blue",
|
|
label="Few-shot k=0, Concat k=1",
|
|
)
|
|
|
|
# Plot for few_shot_1_concat_1
|
|
plt.plot(
|
|
x,
|
|
few_shot_1_concat_1,
|
|
marker="s",
|
|
linestyle="-",
|
|
color="red",
|
|
label="Few-shot k=1, Concat k=1",
|
|
)
|
|
|
|
# Plot for few_shot_5_concat_1
|
|
plt.plot(
|
|
x,
|
|
few_shot_5_concat_1,
|
|
marker="^",
|
|
linestyle="-",
|
|
color="green",
|
|
label="Few-shot k=5, Concat k=1",
|
|
)
|
|
|
|
# Adding details
|
|
plt.title("MMLU Performance")
|
|
plt.xlabel("Retrieval-based LM Datastore Composition")
|
|
plt.ylabel("Accuracy")
|
|
plt.xticks(ticks=x, labels=labels, rotation=45, ha="right")
|
|
plt.legend()
|
|
plt.tight_layout()
|
|
plt.grid(True)
|
|
plt.savefig("mmlu_c4_scaling.png")
|
|
|
|
|
|
def extract_lm_eval_results(
|
|
result_dir, task_name, model_name, n_shot_list, n_doc_list, datastore_name_filter=""
|
|
):
|
|
markers = ["o", "s", "^", "D", "*", "p", "H", "x"]
|
|
colors = plt.cm.tab20.colors
|
|
|
|
all_data = []
|
|
for subdir, dirs, files in os.walk(result_dir):
|
|
num_ints = len(os.path.basename(subdir).split("-"))
|
|
for file in files:
|
|
if file.endswith(".jsonl"):
|
|
file_path = os.path.join(subdir, file)
|
|
with open(file_path, "r") as f:
|
|
for line in f:
|
|
data = json.loads(line)
|
|
data["SubdirLevel"] = num_ints
|
|
data["n-shot"], data["n-doc"] = (
|
|
int(data["n-shot"]),
|
|
int(data["n-doc"]),
|
|
)
|
|
data["Value"] = float(data["Value"])
|
|
all_data.append(data)
|
|
|
|
filtered_data = [
|
|
d
|
|
for d in all_data
|
|
if datastore_name_filter in result_dir
|
|
and d["n-shot"] in n_shot_list
|
|
and d["n-doc"] in n_doc_list
|
|
and d["SubdirLevel"] > 0
|
|
]
|
|
|
|
plot_data = {}
|
|
for d in filtered_data:
|
|
key = (d["n-shot"], d["n-doc"])
|
|
plot_data.setdefault(key, []).append((d["SubdirLevel"], d["Value"]))
|
|
|
|
sorted_keys = sorted(plot_data.keys(), key=lambda x: (x[0], x[1]))
|
|
|
|
closed_book_values = {}
|
|
for i, key in enumerate(sorted_keys):
|
|
n_shot, n_doc = key
|
|
if n_doc == 0:
|
|
value = plot_data[key][-1][-1]
|
|
closed_book_values.update({n_shot: value})
|
|
|
|
plt.figure(figsize=(15, 10))
|
|
for i, key in enumerate(sorted_keys):
|
|
n_shot, n_doc = key
|
|
if n_doc == 0:
|
|
continue
|
|
values = plot_data[key]
|
|
values.append(
|
|
(0, closed_book_values[n_shot])
|
|
if n_shot in closed_book_values.keys()
|
|
else (0, None)
|
|
)
|
|
values.sort() # Ensure the values are sorted by SubdirLevel
|
|
x_values, y_values = zip(*values) # Unzip the tuple pairs to separate lists
|
|
marker = markers[n_shot] if n_doc else ""
|
|
color = colors[i % len(colors)] # Choose a color from the colormap
|
|
label = f"n-shot={n_shot}, n-doc={n_doc}"
|
|
plt.plot(
|
|
x_values, y_values, marker=marker, color=color, linestyle="-", label=label
|
|
)
|
|
|
|
# plt.gca().yaxis.set_major_locator(ticker.MaxNLocator(nbins='auto', steps=[1, 2, 5, 10]))
|
|
|
|
if subject_name == "mmlu":
|
|
plot_dir = os.path.join("plots", "mmlu")
|
|
else:
|
|
plot_dir = "plots"
|
|
os.makedirs(plot_dir, exist_ok=True)
|
|
|
|
plt.xlabel("Number of Index Shards")
|
|
plt.ylabel("Accuracy")
|
|
plt.title(f"{task_name} scaling performance with {model_name}")
|
|
plt.legend()
|
|
plt.grid(True)
|
|
plt.savefig(f"{plot_dir}/{task_name}_{model_name}.png")
|
|
|
|
return all_data
|
|
|
|
|
|
def plot_mmlu_persub_figures(directory="plots"):
|
|
files = [
|
|
file
|
|
for file in os.listdir(directory)
|
|
if file.startswith("mmlu_") and file.endswith(".png")
|
|
]
|
|
plots_per_figure = 16
|
|
for i in range(0, len(files), plots_per_figure):
|
|
# Create a new figure
|
|
fig, axs = plt.subplots(4, 4, figsize=(20, 20))
|
|
|
|
# Flatten the axis array for easy indexing
|
|
axs = axs.flatten()
|
|
|
|
# Iterate over each subplot in the current figure
|
|
for ax, file in zip(axs, files[i : i + plots_per_figure]):
|
|
# Read the image file
|
|
img = plt.imread(os.path.join(directory, file))
|
|
|
|
# Display the image in the subplot
|
|
ax.imshow(img)
|
|
ax.set_title(file)
|
|
ax.axis("off") # Hide axes
|
|
|
|
# Adjust layout and display the figure
|
|
plt.tight_layout()
|
|
plt.savefig(f"mmlu_persub_{i}.png")
|
|
|
|
|
|
def plot_calibration_figures(domain, shard_id=8, show_ci=True, show_all_points=False):
|
|
if show_all_points:
|
|
show_ci = False
|
|
|
|
data_path = f"out_calibration/{shard_id}_shard_{domain}/calibration_results_decon_rpj_{domain}_None_samples.pkl"
|
|
|
|
with open(data_path, "rb") as file:
|
|
all_results = pickle.load(file)
|
|
|
|
all_lm_losses = [item[0] for item in all_results]
|
|
all_retrieval_scores = [item[1] for item in all_results]
|
|
print(f"Total {len(all_lm_losses)} examples.")
|
|
|
|
# Compute PPL of top-1 doc v.s. golden doc from top-100
|
|
losses_top1 = [losses[0] for losses in all_lm_losses]
|
|
avg_losses_top1 = sum(losses_top1) / len(losses_top1)
|
|
ppl_losses_top1 = math.exp(avg_losses_top1)
|
|
|
|
lossed_top100_gold = [min(losses) for losses in all_lm_losses]
|
|
avg_losses_top100_gold = sum(lossed_top100_gold) / len(lossed_top100_gold)
|
|
ppl_lossed_top100_gold = math.exp(avg_losses_top100_gold)
|
|
|
|
print(
|
|
f"Top-1 doc PPL: {ppl_losses_top1:.4f}\nGold doc from top-100 PPL: {ppl_lossed_top100_gold:.4f}"
|
|
)
|
|
|
|
# Calibration plot
|
|
lm_losses = np.array(all_lm_losses)
|
|
retrieval_scores = np.array(all_retrieval_scores)
|
|
|
|
from scipy.special import softmax
|
|
import scipy.stats as stats
|
|
|
|
softmax_lm_losses = softmax(lm_losses, axis=1)
|
|
softmax_retrieval_scores = softmax(retrieval_scores, axis=1)
|
|
|
|
if show_all_points:
|
|
lm_losses = lm_losses.flatten()
|
|
retrieval_scores = retrieval_scores.flatten()
|
|
|
|
plt.figure(figsize=(8, 6))
|
|
plt.plot(lm_losses, retrieval_scores, marker="o", linestyle="")
|
|
plt.title(f"Calibration Curve with {shard_id} Shards")
|
|
plt.xlabel("LM Losses")
|
|
plt.ylabel("Retrieval Scores")
|
|
plt.grid(True)
|
|
plt.savefig(f"out_calibration/calibration_all_{shard_id}_shard_{domain}.png")
|
|
|
|
elif show_ci:
|
|
lm_losses_mean = np.mean(lm_losses, axis=0)
|
|
retrieval_scores_mean = np.mean(retrieval_scores, axis=0)
|
|
|
|
lm_losses_sem = stats.sem(lm_losses, axis=0)
|
|
retrieval_scores_sem = stats.sem(retrieval_scores, axis=0)
|
|
|
|
# Assuming a 95% confidence interval, z-score is approximately 1.96 for a normal distribution
|
|
z_score = 1.96
|
|
losses_ci = lm_losses_sem * z_score
|
|
retrieval_ci = retrieval_scores_sem * z_score
|
|
|
|
plt.figure(figsize=(10, 6))
|
|
plt.errorbar(
|
|
lm_losses_mean,
|
|
retrieval_scores_mean,
|
|
xerr=losses_ci,
|
|
yerr=retrieval_ci,
|
|
fmt="o",
|
|
ecolor="lightgray",
|
|
alpha=0.5,
|
|
capsize=5,
|
|
)
|
|
plt.xlabel("LM Losses")
|
|
plt.ylabel("Retrieval Scores")
|
|
plt.title(
|
|
f"Calibration plot for {shard_id}-shard {domain} with Confidence Intervals"
|
|
)
|
|
plt.grid(True)
|
|
plt.savefig(f"out_calibration/calibration_ci_{shard_id}_shard_{domain}.png")
|
|
|
|
else:
|
|
lm_losses = np.mean(lm_losses, axis=0)
|
|
retrieval_scores = np.mean(retrieval_scores, axis=0)
|
|
|
|
plt.figure(figsize=(8, 6))
|
|
plt.plot(lm_losses, retrieval_scores, marker="o", linestyle="")
|
|
plt.title(f"Calibration Curve with {shard_id} Shards")
|
|
plt.xlabel("LM Losses")
|
|
plt.ylabel("Retrieval Scores")
|
|
plt.grid(True)
|
|
plt.savefig(f"out_calibration/calibration_{shard_id}_shard_{domain}.png")
|
|
|
|
return ppl_losses_top1, ppl_lossed_top100_gold, all_lm_losses, all_retrieval_scores
|
|
|
|
|
|
def plot_top1_vs_best_doc(domain, total_shards=8):
|
|
lm_only_ppl = {
|
|
"books": 21.5250,
|
|
"stackexchange": 11.5948,
|
|
"wiki": 14.0729,
|
|
}
|
|
|
|
top1_losses, best_losses = [], []
|
|
for shard_id in range(1, total_shards + 1):
|
|
top1_loss, best_loss, _, _ = plot_calibration_figures(domain, shard_id)
|
|
top1_losses.append(top1_loss)
|
|
best_losses.append(best_loss)
|
|
|
|
x = [i for i in range(1, total_shards + 1)]
|
|
plt.figure(figsize=(10, 6))
|
|
|
|
# Plotting
|
|
if lm_only_ppl[domain]:
|
|
plt.axhline(
|
|
y=lm_only_ppl[domain], color="r", linestyle="-", label="Closed-book"
|
|
)
|
|
|
|
plt.plot(x, top1_losses, label=f"Top-1 Doc")
|
|
plt.plot(x, best_losses, label=f"Gold Doc")
|
|
|
|
plt.title(f"Perplexity Change with Total Shards")
|
|
plt.xlabel("Number of Shards")
|
|
plt.ylabel("Perplexity")
|
|
plt.legend()
|
|
plt.grid(True)
|
|
plt.savefig(f"best_plot_{domain}.png")
|
|
|
|
|
|
def plot_top1_vs_best_doc_per_sample(domain, shard_id, show_top_k=10, special_mark_k=0):
|
|
_, _, all_lm_losses, all_retrieval_scores = plot_calibration_figures(
|
|
domain, shard_id
|
|
)
|
|
all_sorted_lm_losses, all_sorted_retrieval_scores = [], []
|
|
for lm_losses, retrieval_scores in zip(all_retrieval_scores, all_lm_losses):
|
|
sorted_scores, sorted_losses = zip(
|
|
*sorted(zip(retrieval_scores, lm_losses), reverse=True)
|
|
)
|
|
all_sorted_lm_losses.append(sorted_losses)
|
|
all_sorted_retrieval_scores.append(sorted_scores)
|
|
|
|
num_samples = len(all_lm_losses)
|
|
x = [i for i in range(num_samples)]
|
|
plt.figure(figsize=(25, 6))
|
|
|
|
# Plotting
|
|
for i in range(show_top_k - 1, -1, -1):
|
|
plt.plot(
|
|
x,
|
|
[losses[i] for losses in all_sorted_lm_losses],
|
|
label=f"Top-{i + 1}th Doc",
|
|
marker="x" if i == special_mark_k else "o",
|
|
linestyle="",
|
|
)
|
|
|
|
plt.title(f"Per-sample Loss of {domain} with 1 retrieved doc")
|
|
plt.xlabel("Index of the Evaluation Sample")
|
|
plt.ylabel("Loss")
|
|
plt.legend(loc="upper left", bbox_to_anchor=(1, 1))
|
|
plt.grid(True)
|
|
plt.savefig(f"per_sample_{domain}.png")
|
|
|
|
|
|
def compute_variance_across_hards(path, n_shot=5, n_doc=3):
|
|
all_data = []
|
|
for subdir, dirs, files in os.walk(path):
|
|
num_ints = len(os.path.basename(subdir).split("-"))
|
|
for file in files:
|
|
if file.endswith(".jsonl"):
|
|
file_path = os.path.join(subdir, file)
|
|
with open(file_path, "r") as f:
|
|
for line in f:
|
|
data = json.loads(line)
|
|
data["SubdirLevel"] = num_ints
|
|
data["n-shot"], data["n-doc"] = (
|
|
int(data["n-shot"]),
|
|
int(data["n-doc"]),
|
|
)
|
|
data["Value"] = float(data["Value"])
|
|
all_data.append(data)
|
|
|
|
plot_data = {}
|
|
for d in all_data:
|
|
key = (d["n-shot"], d["n-doc"])
|
|
plot_data.setdefault(key, []).append((d["SubdirLevel"], d["Value"]))
|
|
|
|
files_end = [d.split("/")[-1] for d, _, _ in os.walk(path)]
|
|
shard_ids = [int(i) for i in files_end[1:]]
|
|
key = n_shot, n_doc
|
|
values = plot_data[key]
|
|
_, y_values = zip(*values)
|
|
|
|
plt.figure(figsize=(10, 6))
|
|
try:
|
|
plt.plot(shard_ids, y_values, marker="o", linestyle="")
|
|
except:
|
|
print(f"mismatched size for {key}: {len(shard_ids)}, {len(y_values)}")
|
|
|
|
print(y_values)
|
|
print(f"Saving to {f'per_sample_{files_end[0]}.png'}")
|
|
|
|
plt.xlabel("Single-shard Index ID")
|
|
plt.ylabel("PPL")
|
|
plt.grid(True)
|
|
plt.savefig(f"per_sample_{files_end[0]}.png")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
# # Replace with your directory path
|
|
# directory_path = "out/2023_dec_25_single_domain"
|
|
|
|
# # Extracting data to a table with additional information
|
|
# df, grouped_df = extract_data_to_table(directory_path)
|
|
# print(grouped_df)
|
|
# print(grouped_df.index.get_level_values("Samples (M)").to_numpy())
|
|
|
|
plot_info_list = [
|
|
# {'logfile': 'rpj_c4.log', 'domain': 'rpj-c4', 'plot': 'scaling_c4_single_index_plot.png'},
|
|
# {'logfile': 'rpj_arxiv.log', 'domain': 'rpj-arxiv', 'plot': 'scaling_arxiv_plot.png'},
|
|
# {'logfile': 'rpj_book_scaling.log', 'domain': 'rpj-book', 'plot': 'scaling_book_plot.png'},
|
|
# {'logfile': 'rpj_github_scaling.log', 'domain': 'rpj-github', 'plot': 'scaling_github_plot.png'},
|
|
# {'logfile': 'rpj_stackexchange_scaling.log', 'domain': 'rpj-stackexchange', 'plot': 'scaling_stackexchange_plot.png'},
|
|
# {'logfile': 'rpj_wiki.log', 'domain': 'rpj-wiki', 'plot': 'scaling_wiki_plot.png'},
|
|
# {'logfile': 'out/2024_apr_decon/decon_rpj_wiki_contriever_ppl.log', 'domain': 'rpj-wiki-decon-contriever', 'plot': 'scaling_wiki_decon_plot_contriever.png'},
|
|
# {'logfile': 'out/2024_apr_decon/decon_rpj_book_contriever_ppl.log', 'domain': 'rpj-book-decon-contriever', 'plot': 'scaling_book_decon_plot_contriever.png'},
|
|
# {'logfile': 'out/2024_apr_decon/decon_rpj_arxiv_contriever_ppl.log', 'domain': 'rpj-arxiv-decon-contriever', 'plot': 'scaling_arxiv_decon_plot_contriever.png'},
|
|
# {'logfile': 'out/2024_apr_decon/decon_rpj_stackexchange_contriever_ppl.log', 'domain': 'rpj-stackexchange-decon-contriever', 'plot': 'scaling_stackexchange_decon_plot_contriever.png'},
|
|
# {'logfile': 'out/2024_apr_decon/decon_rpj_stackexchange_dragon_ppl.log', 'domain': 'rpj-stackexchange-decon-dragon', 'plot': 'scaling_stackexchange_decon_plot_dragon.png'},
|
|
# {'logfile': 'out/2024_apr_decon/decon_rpj_wiki_dragon_ppl.log', 'domain': 'rpj-wiki-decon-dragon', 'plot': 'scaling_wiki_decon_plot_dragon.png'},
|
|
# {'logfile': 'out/2024_apr_decon/decon_rpj_arxiv_dragon_ppl.log', 'domain': 'rpj-arxiv-decon-dragon', 'plot': 'scaling_arxiv_decon_plot_dragon.png'},
|
|
# {'logfile': 'out/2024_apr_decon/decon_rpj_book_dragon_ppl.log', 'domain': 'rpj-book-decon-dragon', 'plot': 'scaling_book_decon_plot_dragon.png'},
|
|
]
|
|
|
|
# for plot_info in plot_info_list:
|
|
# extract_dense_scaling_results([plot_info['logfile']], plot_info['domain'], plot_info['plot'])
|
|
|
|
model_name = "lclm"
|
|
subject_name = "gsm8k"
|
|
datastore_name = "c4"
|
|
result_dir = f"/gscratch/zlab/rulins/Scaling/lm_eval_results/{model_name}"
|
|
|
|
all_subjects = [
|
|
file
|
|
for file in os.listdir(result_dir)
|
|
if subject_name in file and datastore_name in file
|
|
]
|
|
for subject in all_subjects:
|
|
file_name = subject
|
|
print(file_name)
|
|
extract_lm_eval_results(
|
|
os.path.join(result_dir, file_name),
|
|
subject,
|
|
model_name,
|
|
[0, 5], # few-shot
|
|
[0, 3], # n-doc
|
|
file_name,
|
|
)
|
|
|
|
# plot_mmlu_persub_figures("plots/mmlu")
|
|
|
|
# compute_variance_across_hards(f'/gscratch/zlab/rulins/Scaling/lm_eval_results/llama2-7b/subsample/nq_open-rpj_c4-32_shards')
|
|
# compute_variance_across_hards(f'/gscratch/zlab/rulins/Scaling/lm_eval_results/llama2-7b/subsample/medqa_4options-rpj_c4-32_shards')
|
|
|
|
# plot_calibration_figures(domain='wiki', shard_id=1, show_all_points=True)
|
|
# plot_top1_vs_best_doc_per_sample(domain='stackexchange', shard_id=1, show_top_k=10, special_mark_k=0)
|