upload
evaluators/arc.py (new file, 177 lines)
@@ -0,0 +1,177 @@
from typing import Dict, Sequence, Optional
import os
import json

import torch
import numpy as np
from numba import njit
import torch.distributed as dist

from dataset.build_arc_dataset import inverse_aug, grid_hash, arc_grid_to_np
from dataset.common import PuzzleDatasetMetadata


@njit
def _crop(grid: np.ndarray):
    """Find the maximum-sized rectangle without any EOS token inside."""
    grid = grid.reshape(30, 30)

    max_area = 0
    max_size = (0, 0)
    nr, nc = grid.shape

    num_c = nc
    for num_r in range(1, nr + 1):
        # Scan for the maximum number of columns; tokens 2..11 encode colours 0..9,
        # anything outside that range ends the croppable region.
        for c in range(1, num_c + 1):
            x = grid[num_r - 1, c - 1]
            if (x < 2) | (x > 11):
                num_c = c - 1
                break

        area = num_r * num_c
        if area > max_area:
            max_area = area
            max_size = (num_r, num_c)

    return (grid[:max_size[0], :max_size[1]] - 2).astype(np.uint8)
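
# Illustrative sketch (not part of the evaluator): what `_crop` recovers from the
# flattened 30x30 token canvas, assuming tokens 2..11 encode ARC colours 0..9 and
# everything else (e.g. PAD/EOS) terminates the grid:
#
#   canvas = np.zeros(30 * 30, dtype=np.int64)        # all PAD
#   canvas.reshape(30, 30)[:2, :3] = [[2, 3, 4],
#                                     [5, 6, 7]]
#   _crop(canvas)
#   # -> array([[0, 1, 2],
#   #           [3, 4, 5]], dtype=uint8)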


class ARC:
    required_outputs = {"inputs", "puzzle_identifiers", "q_halt_logits", "preds"}

    def __init__(self, data_path: str,
                 eval_metadata: PuzzleDatasetMetadata,
                 submission_K: int = 2,
                 pass_Ks: Sequence[int] = (1, 2, 5, 10, 100, 1000),
                 aggregated_voting: bool = True):
        super().__init__()
        self.pass_Ks = pass_Ks
        self.submission_K = submission_K
        self.aggregated_voting = aggregated_voting
        self.blank_identifier_id = eval_metadata.blank_identifier_id

        # Load identifiers (augmented puzzle names) and test puzzles
        with open(os.path.join(data_path, "identifiers.json"), "r") as f:
            self.identifier_map = json.load(f)
        with open(os.path.join(data_path, "test_puzzles.json"), "r") as f:
            self.test_puzzles = json.load(f)

        # States
        self._local_hmap = {}
        self._local_preds = {}

    def begin_eval(self):
        if not self.aggregated_voting:
            # Clear predictions from previous evaluation rounds
            self._local_hmap = {}
            self._local_preds = {}

    def update_batch(self, batch: Dict[str, torch.Tensor], preds: Dict[str, torch.Tensor]):
        # Collect required outputs to CPU
        outputs = {}
        q_values = None

        for collection in (batch, preds):
            for k, v in collection.items():
                if k in self.required_outputs:
                    if k == "q_halt_logits":
                        q_values = v.to(torch.float64).sigmoid().cpu()
                    else:
                        outputs[k] = v.cpu()

        assert q_values is not None

        # Remove padding from outputs and from the halting probabilities
        mask = outputs["puzzle_identifiers"] != self.blank_identifier_id
        outputs = {k: v[mask] for k, v in outputs.items()}
        q_values = q_values[mask]

        # Get predictions
        for identifier, input, pred, q in zip(outputs["puzzle_identifiers"].numpy(),
                                              outputs["inputs"].numpy(),
                                              outputs["preds"].numpy(),
                                              q_values.numpy()):
            name = self.identifier_map[identifier]
            orig_name, _inverse_fn = inverse_aug(name)

            input_hash = grid_hash(_inverse_fn(_crop(input)))

            pred = _inverse_fn(_crop(pred))
            assert np.all((pred >= 0) & (pred <= 9)), f"Puzzle {name}'s prediction is out of the 0-9 range."  # Sanity check

            # Store into local state
            pred_hash = grid_hash(pred)

            self._local_hmap[pred_hash] = pred

            self._local_preds.setdefault(orig_name, {})
            self._local_preds[orig_name].setdefault(input_hash, [])
            self._local_preds[orig_name][input_hash].append((pred_hash, float(q)))
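
    # Layout of the per-rank state accumulated by `update_batch` (illustrative;
    # the hash values are whatever `grid_hash` returns):
    #
    #   _local_hmap:  {pred_hash: cropped prediction grid (np.ndarray, values 0..9)}
    #   _local_preds: {orig_puzzle_name: {input_hash: [(pred_hash, q_halt_prob), ...]}}
    #
    # `result` gathers these dictionaries from all ranks onto rank 0, then votes
    # over the candidates for each test input.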

    def result(self, save_path: Optional[str], rank: int, world_size: int,
               group: Optional[torch.distributed.ProcessGroup] = None) -> Optional[Dict[str, float]]:
        # Gather predictions to rank 0 for voting
        global_hmap_preds = [None for _ in range(world_size)] if rank == 0 else None
        dist.gather_object((self._local_hmap, self._local_preds), global_hmap_preds, dst=0, group=group)

        # Only rank 0 computes metrics and the submission
        if rank != 0:
            return

        submission = {}
        correct = [0.0 for _ in range(len(self.pass_Ks))]

        for name, puzzle in self.test_puzzles.items():
            # Process test examples in this puzzle
            submission[name] = []
            num_test_correct = [0 for _ in range(len(self.pass_Ks))]
            for pair in puzzle["test"]:
                input_hash = grid_hash(arc_grid_to_np(pair["input"]))
                label_hash = grid_hash(arc_grid_to_np(pair["output"]))

                # Tally votes: count and summed Q for each candidate prediction hash
                p_map = {}
                for hmap, preds in global_hmap_preds:  # type: ignore
                    for h, q in preds.get(name, {}).get(input_hash, []):
                        p_map.setdefault(h, [0, 0])
                        p_map[h][0] += 1
                        p_map[h][1] += q

                if not len(p_map):
                    print(f"Puzzle {name} has no predictions.")
                    continue

                # Average Q over the votes for each candidate
                for h, stats in p_map.items():
                    stats[1] /= stats[0]

                # Rank candidates by (vote count, mean Q), descending
                p_map = sorted(p_map.items(), key=lambda kv: kv[1], reverse=True)

                # Vote for the different Ks
                for i, k in enumerate(self.pass_Ks):
                    ok = False
                    for h, stats in p_map[:k]:
                        ok |= h == label_hash

                    num_test_correct[i] += ok

                # Query the grids for the submitted attempts
                pred_grids = []
                for h, stats in p_map[:self.submission_K]:
                    for hmap, preds in global_hmap_preds:  # type: ignore
                        if h in hmap:
                            pred_grids.append(hmap[h])
                            break

                # Pad to K attempts by repeating the top prediction
                while len(pred_grids) < self.submission_K:
                    pred_grids.append(pred_grids[0])

                submission[name].append({f"attempt_{i + 1}": grid.tolist() for i, grid in enumerate(pred_grids)})

            # Total correctness
            for i in range(len(self.pass_Ks)):
                correct[i] += num_test_correct[i] / len(puzzle["test"])

        # Save submission
        if save_path is not None:
            with open(os.path.join(save_path, "submission.json"), "w") as f:
                json.dump(submission, f)

        # Final result
        all_results = {f"ARC/pass@{k}": correct[i] / len(self.test_puzzles) for i, k in enumerate(self.pass_Ks)}

        return all_results
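

# Usage sketch (illustrative only; the actual evaluation driver lives elsewhere in
# the repo, and the paths below are hypothetical):
#
#   evaluator = ARC(data_path="data/arc-eval", eval_metadata=eval_metadata)
#   evaluator.begin_eval()
#   for batch, preds in eval_batches:          # per-batch model inputs and outputs
#       evaluator.update_batch(batch, preds)
#   metrics = evaluator.result(save_path="out", rank=dist.get_rank(),
#                              world_size=dist.get_world_size())
#   if metrics is not None:                    # only rank 0 receives metrics
#       print(metrics)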