upload
dataset/build_arc_dataset.py (new file, 341 lines)
@@ -0,0 +1,341 @@
from typing import List, Tuple, Dict
from dataclasses import dataclass
import os
import json
import hashlib
import numpy as np

from argdantic import ArgParser
from pydantic import BaseModel

from dataset.common import PuzzleDatasetMetadata, dihedral_transform, inverse_dihedral_transform


cli = ArgParser()


class DataProcessConfig(BaseModel):
    input_file_prefix: str
    output_dir: str
    subsets: List[str]
    test_set_name: str
    test_set_name2: str = "your_test_set"
    seed: int = 42
    num_aug: int = 1000
    puzzle_identifiers_start: int = 1  # start > 1 to handle multiple datasets


ARCMaxGridSize = 30
ARCAugmentRetriesFactor = 5

PuzzleIdSeparator = "|||"


@dataclass
class ARCPuzzle:
    id: str
    examples: List[Tuple[np.ndarray, np.ndarray]]


def arc_grid_to_np(grid: List[List[int]]):
    arr = np.array(grid)

    # Shape check
    assert arr.ndim == 2
    assert arr.shape[0] <= ARCMaxGridSize and arr.shape[1] <= ARCMaxGridSize
    # Element check
    assert np.all((arr >= 0) & (arr <= 9))
    return arr.astype(np.uint8)


def np_grid_to_seq_translational_augment(inp: np.ndarray, out: np.ndarray, do_translation: bool):
    # PAD: 0, <eos>: 1, digits: 2 ... 11
    # Compute random top-left pad
    if do_translation:
        pad_r = np.random.randint(0, ARCMaxGridSize - max(inp.shape[0], out.shape[0]) + 1)
        pad_c = np.random.randint(0, ARCMaxGridSize - max(inp.shape[1], out.shape[1]) + 1)
    else:
        pad_r = pad_c = 0

    # Pad grid
    result = []
    for grid in [inp, out]:
        nrow, ncol = grid.shape
        grid = np.pad(grid + 2, ((pad_r, ARCMaxGridSize - pad_r - nrow), (pad_c, ARCMaxGridSize - pad_c - ncol)), constant_values=0)

        # Add <eos>
        eos_row, eos_col = pad_r + nrow, pad_c + ncol
        if eos_row < ARCMaxGridSize:
            grid[eos_row, pad_c:eos_col] = 1
        if eos_col < ARCMaxGridSize:
            grid[pad_r:eos_row, eos_col] = 1

        result.append(grid.flatten())

    return result


def grid_hash(grid: np.ndarray):
    assert grid.ndim == 2
    assert grid.dtype == np.uint8

    buffer = [x.to_bytes(1, byteorder='big') for x in grid.shape]
    buffer.append(grid.tobytes())

    return hashlib.sha256(b"".join(buffer)).hexdigest()


def puzzle_hash(puzzle: dict):
    # Hash the puzzle for checking equivalence
    hashes = []
    for example_type, example in puzzle.items():
        for input, label in example.examples:
            hashes.append(f"{grid_hash(input)}|{grid_hash(label)}")

    hashes.sort()
    return hashlib.sha256("|".join(hashes).encode()).hexdigest()


def aug(name: str):
    # Augment plan
    trans_id = np.random.randint(0, 8)
    mapping = np.concatenate([np.arange(0, 1, dtype=np.uint8), np.random.permutation(np.arange(1, 10, dtype=np.uint8))])  # Permute colors, excluding "0" (black)

    name_with_aug_repr = f"{name}{PuzzleIdSeparator}t{trans_id}{PuzzleIdSeparator}{''.join(str(x) for x in mapping)}"

    def _map_grid(grid: np.ndarray):
        return dihedral_transform(mapping[grid], trans_id)

    return name_with_aug_repr, _map_grid


def inverse_aug(name: str):
    # Inverse of the "aug" function
    if PuzzleIdSeparator not in name:
        return name, lambda x: x

    trans_id, perm = name.split(PuzzleIdSeparator)[-2:]
    trans_id = int(trans_id[1:])  # Remove "t" letter
    inv_perm = np.argsort(list(perm)).astype(np.uint8)

    def _map_grid(grid: np.ndarray):
        return inv_perm[inverse_dihedral_transform(grid, trans_id)]

    return name.split(PuzzleIdSeparator)[0], _map_grid


def convert_single_arc_puzzle(results: dict, name: str, puzzle: dict, aug_count: int, dest_mapping: Dict[str, Tuple[str, str]]):
    # Convert
    dests = set(dest_mapping.values())
    converted = {dest: ARCPuzzle(name, []) for dest in dests}
    for example_type, examples in puzzle.items():
        # Map to target split
        dest = dest_mapping[example_type]
        converted[dest].examples.extend([(arc_grid_to_np(example["input"]), arc_grid_to_np(example["output"])) for example in examples])

    group = [converted]

    # Augment
    if aug_count > 0:
        hashes = {puzzle_hash(converted)}

        for _trial in range(ARCAugmentRetriesFactor * aug_count):
            aug_name, _map_grid = aug(name)

            # Check duplicate
            augmented = {dest: ARCPuzzle(aug_name, [(_map_grid(input), _map_grid(label)) for (input, label) in puzzle.examples]) for dest, puzzle in converted.items()}
            h = puzzle_hash(augmented)
            if h not in hashes:
                hashes.add(h)
                group.append(augmented)

            if len(group) >= aug_count + 1:
                break

        if len(group) < aug_count + 1:
            print(f"[Puzzle {name}] augmentation not full, only {len(group)}")

    # Append
    for dest in dests:
        # Convert the examples
        dest_split, dest_set = dest

        results.setdefault(dest_split, {})
        results[dest_split].setdefault(dest_set, [])
        results[dest_split][dest_set].append([converted[dest] for converted in group])


def load_puzzles_arcagi(config: DataProcessConfig):
    train_examples_dest = ("train", "all")
    test_examples_map = {
        config.test_set_name: [(1.0, ("test", "all"))],
        config.test_set_name2: [(1.0, ("test", "all"))],
        "_default": [(1.0, ("train", "all"))]
    }

    test_puzzles = {}
    results = {}

    total_puzzles = 0
    for subset_name in config.subsets:
        # Load all puzzles in this subset
        with open(f"{config.input_file_prefix}_{subset_name}_challenges.json", "r") as f:
            puzzles = json.load(f)

        sols_filename = f"{config.input_file_prefix}_{subset_name}_solutions.json"
        if os.path.isfile(sols_filename):
            with open(sols_filename, "r") as f:
                sols = json.load(f)

            for puzzle_id in puzzles.keys():
                for idx, sol_grid in enumerate(sols[puzzle_id]):
                    puzzles[puzzle_id]["test"][idx]["output"] = sol_grid
        else:
            # Fill with dummy
            print(f"{subset_name} solutions not found, filling with dummy")

            for puzzle_id, puzzle in puzzles.items():
                for example in puzzle["test"]:
                    example.setdefault("output", [[0]])

        # Shuffle puzzles
        puzzles = list(puzzles.items())
        np.random.shuffle(puzzles)

        # Assign by fraction
        for idx, (name, puzzle) in enumerate(puzzles):
            fraction = idx / len(puzzles)
            test_examples_dest = None
            for f, dest in test_examples_map.get(subset_name, test_examples_map["_default"]):
                if fraction < f:
                    test_examples_dest = dest
                    break

            assert test_examples_dest is not None

            if test_examples_dest[0] == "test":
                test_puzzles[name] = puzzle

            convert_single_arc_puzzle(results, name, puzzle, config.num_aug, {"train": train_examples_dest, "test": test_examples_dest})
            total_puzzles += 1

    print(f"Total puzzles: {total_puzzles}")
    return results, test_puzzles


def convert_dataset(config: DataProcessConfig):
    np.random.seed(config.seed)

    # Read dataset
    data, test_puzzles = load_puzzles_arcagi(config)

    # Map global puzzle identifiers
    num_identifiers = config.puzzle_identifiers_start  # 0 is blank, start at 1
    identifier_map = {}
    for split_name, split in data.items():
        for subset_name, subset in split.items():
            for group in subset:
                for puzzle in group:
                    if puzzle.id not in identifier_map:
                        identifier_map[puzzle.id] = num_identifiers
                        num_identifiers += 1
    print(f"Total puzzle IDs (including <blank>): {num_identifiers}")

    # Save
    for split_name, split in data.items():
        os.makedirs(os.path.join(config.output_dir, split_name), exist_ok=True)

        # Translational augmentations
        enable_translational_augment = split_name == "train"

        # Statistics
        total_examples = 0
        total_puzzles = 0
        total_groups = 0

        for subset_name, subset in split.items():  # "all" is the only subset
            # Construct subset
            results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]}
            results["puzzle_indices"].append(0)
            results["group_indices"].append(0)

            example_id = 0
            puzzle_id = 0

            for group in subset:
                for puzzle in group:
                    # Push puzzle
                    no_aug_id = np.random.randint(0, len(puzzle.examples))
                    for _idx_ex, (inp, out) in enumerate(puzzle.examples):
                        inp, out = np_grid_to_seq_translational_augment(inp, out, do_translation=enable_translational_augment and _idx_ex != no_aug_id)

                        results["inputs"].append(inp)
                        results["labels"].append(out)
                        example_id += 1

                        total_examples += 1

                    results["puzzle_indices"].append(example_id)
                    results["puzzle_identifiers"].append(identifier_map[puzzle.id])

                    puzzle_id += 1
                    total_puzzles += 1

                # Push group
                results["group_indices"].append(puzzle_id)
                total_groups += 1

            for k, v in results.items():
                if k in {"inputs", "labels"}:
                    v = np.stack(v, 0)
                else:
                    v = np.array(v, dtype=np.int32)

                np.save(os.path.join(config.output_dir, split_name, f"{subset_name}__{k}.npy"), v)

        # Metadata
        metadata = PuzzleDatasetMetadata(
            seq_len=ARCMaxGridSize * ARCMaxGridSize,
            vocab_size=10 + 2,  # PAD + EOS + "0" ... "9"
            pad_id=0,
            ignore_label_id=0,
            blank_identifier_id=0,
            num_puzzle_identifiers=num_identifiers,
            total_groups=total_groups,
            mean_puzzle_examples=total_examples / total_puzzles,
            total_puzzles=total_puzzles,
            sets=list(split.keys())
        )

        # Save metadata as JSON.
        with open(os.path.join(config.output_dir, split_name, "dataset.json"), "w") as f:
            json.dump(metadata.model_dump(), f)

    # Save IDs mapping
    with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f:
        ids_mapping = {v: k for k, v in identifier_map.items()}
        json.dump([ids_mapping.get(i, "<blank>") for i in range(num_identifiers)], f)

    # Save test puzzles
    with open(os.path.join(config.output_dir, "test_puzzles.json"), "w") as f:
        json.dump(test_puzzles, f)


@cli.command(singleton=True)
def main(config: DataProcessConfig):
    convert_dataset(config)


if __name__ == "__main__":
    cli()
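For reference, a minimal sketch of driving this builder programmatically rather than through the argdantic CLI. The file prefix, output directory, and subset names below are placeholders (not part of this commit); the script only requires that f"{input_file_prefix}_{subset}_challenges.json" exists for each subset, with an optional matching *_solutions.json. It also assumes the repository root is on PYTHONPATH so that the dataset package imports resolve.

# Hypothetical usage sketch -- paths and subset names are assumptions, not from this commit.
from dataset.build_arc_dataset import DataProcessConfig, convert_dataset

config = DataProcessConfig(
    input_file_prefix="data/arc-agi/arc-agi",   # reads f"{prefix}_{subset}_challenges.json"
    output_dir="data/arc-aug-1000",             # per-split .npy arrays and dataset.json land here
    subsets=["training", "evaluation"],         # placeholder subset names
    test_set_name="evaluation",                 # puzzles from this subset go to the test split
)
convert_dataset(config)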
dataset/build_maze_dataset.py (new file, 140 lines)
@@ -0,0 +1,140 @@
from typing import Optional
import math
import os
import csv
import json
import numpy as np

from argdantic import ArgParser
from pydantic import BaseModel
from tqdm import tqdm
from huggingface_hub import hf_hub_download

from common import PuzzleDatasetMetadata, dihedral_transform


CHARSET = "# SGo"


cli = ArgParser()


class DataProcessConfig(BaseModel):
    source_repo: str = "sapientinc/maze-30x30-hard-1k"
    output_dir: str = "data/maze-30x30-hard-1k"

    subsample_size: Optional[int] = None
    aug: bool = False


def convert_subset(set_name: str, config: DataProcessConfig):
    # Read CSV
    all_chars = set()
    grid_size = None
    inputs = []
    labels = []

    with open(hf_hub_download(config.source_repo, f"{set_name}.csv", repo_type="dataset"), newline="") as csvfile:  # type: ignore
        reader = csv.reader(csvfile)
        next(reader)  # Skip header
        for source, q, a, rating in reader:
            all_chars.update(q)
            all_chars.update(a)

            if grid_size is None:
                n = int(len(q) ** 0.5)
                grid_size = (n, n)

            inputs.append(np.frombuffer(q.encode(), dtype=np.uint8).reshape(grid_size))
            labels.append(np.frombuffer(a.encode(), dtype=np.uint8).reshape(grid_size))

    # If subsample_size is specified for the training set,
    # randomly sample the desired number of examples.
    if set_name == "train" and config.subsample_size is not None:
        total_samples = len(inputs)
        if config.subsample_size < total_samples:
            indices = np.random.choice(total_samples, size=config.subsample_size, replace=False)
            inputs = [inputs[i] for i in indices]
            labels = [labels[i] for i in indices]

    # Generate dataset
    results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]}
    puzzle_id = 0
    example_id = 0

    results["puzzle_indices"].append(0)
    results["group_indices"].append(0)

    for inp, out in zip(tqdm(inputs), labels):
        # Dihedral transformations for augmentation
        for aug_idx in range(8 if (set_name == "train" and config.aug) else 1):
            results["inputs"].append(dihedral_transform(inp, aug_idx))
            results["labels"].append(dihedral_transform(out, aug_idx))
            example_id += 1
            puzzle_id += 1

            results["puzzle_indices"].append(example_id)
            results["puzzle_identifiers"].append(0)

        # Push group
        results["group_indices"].append(puzzle_id)

    # Char mappings
    assert len(all_chars - set(CHARSET)) == 0

    char2id = np.zeros(256, np.uint8)
    char2id[np.array(list(map(ord, CHARSET)))] = np.arange(len(CHARSET)) + 1

    # To NumPy
    def _seq_to_numpy(seq):
        arr = np.vstack([char2id[s.reshape(-1)] for s in seq])

        return arr

    results = {
        "inputs": _seq_to_numpy(results["inputs"]),
        "labels": _seq_to_numpy(results["labels"]),

        "group_indices": np.array(results["group_indices"], dtype=np.int32),
        "puzzle_indices": np.array(results["puzzle_indices"], dtype=np.int32),
        "puzzle_identifiers": np.array(results["puzzle_identifiers"], dtype=np.int32),
    }

    # Metadata
    metadata = PuzzleDatasetMetadata(
        seq_len=int(math.prod(grid_size)),  # type: ignore
        vocab_size=len(CHARSET) + 1,  # PAD + charset
        pad_id=0,
        ignore_label_id=0,
        blank_identifier_id=0,
        num_puzzle_identifiers=1,
        total_groups=len(results["group_indices"]) - 1,
        mean_puzzle_examples=1,
        total_puzzles=len(results["group_indices"]) - 1,
        sets=["all"]
    )

    # Save metadata as JSON.
    save_dir = os.path.join(config.output_dir, set_name)
    os.makedirs(save_dir, exist_ok=True)

    with open(os.path.join(save_dir, "dataset.json"), "w") as f:
        json.dump(metadata.model_dump(), f)

    # Save data
    for k, v in results.items():
        np.save(os.path.join(save_dir, f"all__{k}.npy"), v)

    # Save IDs mapping (for visualization only)
    with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f:
        json.dump(["<blank>"], f)


@cli.command(singleton=True)
def preprocess_data(config: DataProcessConfig):
    convert_subset("train", config)
    convert_subset("test", config)


if __name__ == "__main__":
    cli()
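A minimal usage sketch, assuming the code is run from inside the dataset/ directory so that the bare "from common import ..." above resolves; the defaults already point at the sapientinc/maze-30x30-hard-1k dataset on the Hugging Face Hub, so only the augmentation flag is set here.

# Usage sketch (assumptions noted above); run from the dataset/ directory.
from build_maze_dataset import DataProcessConfig, convert_subset

config = DataProcessConfig(aug=True)   # keep default repo/output_dir, enable dihedral augmentation
convert_subset("train", config)        # writes data/maze-30x30-hard-1k/train/all__*.npy + dataset.json
convert_subset("test", config)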
dataset/build_sudoku_dataset.py (new file, 167 lines)
@@ -0,0 +1,167 @@
from typing import Optional
import os
import csv
import json
import numpy as np

from argdantic import ArgParser
from pydantic import BaseModel
from tqdm import tqdm
from huggingface_hub import hf_hub_download

from common import PuzzleDatasetMetadata


cli = ArgParser()


class DataProcessConfig(BaseModel):
    source_repo: str = "sapientinc/sudoku-extreme"
    output_dir: str = "data/sudoku-extreme-full"

    subsample_size: Optional[int] = None
    min_difficulty: Optional[int] = None
    num_aug: int = 0


def shuffle_sudoku(board: np.ndarray, solution: np.ndarray):
    # Create a random digit mapping: a permutation of 1..9, with zero (blank) unchanged
    digit_map = np.pad(np.random.permutation(np.arange(1, 10)), (1, 0))

    # Randomly decide whether to transpose.
    transpose_flag = np.random.rand() < 0.5

    # Generate a valid row permutation:
    # - Shuffle the 3 bands (each band = 3 rows) and, within each band, shuffle its 3 rows.
    bands = np.random.permutation(3)
    row_perm = np.concatenate([b * 3 + np.random.permutation(3) for b in bands])

    # Similarly for columns (stacks).
    stacks = np.random.permutation(3)
    col_perm = np.concatenate([s * 3 + np.random.permutation(3) for s in stacks])

    # Build an 81->81 mapping over flattened cells. The new cell at flat index i
    # (row = i // 9, col = i % 9) takes its value from
    # old row = row_perm[i // 9] and old col = col_perm[i % 9].
    mapping = np.array([row_perm[i // 9] * 9 + col_perm[i % 9] for i in range(81)])

    def apply_transformation(x: np.ndarray) -> np.ndarray:
        # Apply transpose flag
        if transpose_flag:
            x = x.T
        # Apply the position mapping.
        new_board = x.flatten()[mapping].reshape(9, 9).copy()
        # Apply digit mapping
        return digit_map[new_board]

    return apply_transformation(board), apply_transformation(solution)


def convert_subset(set_name: str, config: DataProcessConfig):
    # Read CSV
    inputs = []
    labels = []

    with open(hf_hub_download(config.source_repo, f"{set_name}.csv", repo_type="dataset"), newline="") as csvfile:
        reader = csv.reader(csvfile)
        next(reader)  # Skip header
        for source, q, a, rating in reader:
            if (config.min_difficulty is None) or (int(rating) >= config.min_difficulty):
                assert len(q) == 81 and len(a) == 81

                inputs.append(np.frombuffer(q.replace('.', '0').encode(), dtype=np.uint8).reshape(9, 9) - ord('0'))
                labels.append(np.frombuffer(a.encode(), dtype=np.uint8).reshape(9, 9) - ord('0'))

    # If subsample_size is specified for the training set,
    # randomly sample the desired number of examples.
    if set_name == "train" and config.subsample_size is not None:
        total_samples = len(inputs)
        if config.subsample_size < total_samples:
            indices = np.random.choice(total_samples, size=config.subsample_size, replace=False)
            inputs = [inputs[i] for i in indices]
            labels = [labels[i] for i in indices]

    # Generate dataset
    num_augments = config.num_aug if set_name == "train" else 0

    results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]}
    puzzle_id = 0
    example_id = 0

    results["puzzle_indices"].append(0)
    results["group_indices"].append(0)

    for orig_inp, orig_out in zip(tqdm(inputs), labels):
        for aug_idx in range(1 + num_augments):
            # First index is not augmented
            if aug_idx == 0:
                inp, out = orig_inp, orig_out
            else:
                inp, out = shuffle_sudoku(orig_inp, orig_out)

            # Push puzzle (only single example)
            results["inputs"].append(inp)
            results["labels"].append(out)
            example_id += 1
            puzzle_id += 1

            results["puzzle_indices"].append(example_id)
            results["puzzle_identifiers"].append(0)

        # Push group
        results["group_indices"].append(puzzle_id)

    # To NumPy
    def _seq_to_numpy(seq):
        arr = np.concatenate(seq).reshape(len(seq), -1)

        assert np.all((arr >= 0) & (arr <= 9))
        return arr + 1

    results = {
        "inputs": _seq_to_numpy(results["inputs"]),
        "labels": _seq_to_numpy(results["labels"]),

        "group_indices": np.array(results["group_indices"], dtype=np.int32),
        "puzzle_indices": np.array(results["puzzle_indices"], dtype=np.int32),
        "puzzle_identifiers": np.array(results["puzzle_identifiers"], dtype=np.int32),
    }

    # Metadata
    metadata = PuzzleDatasetMetadata(
        seq_len=81,
        vocab_size=10 + 1,  # PAD + "0" ... "9"
        pad_id=0,
        ignore_label_id=0,
        blank_identifier_id=0,
        num_puzzle_identifiers=1,
        total_groups=len(results["group_indices"]) - 1,
        mean_puzzle_examples=1,
        total_puzzles=len(results["group_indices"]) - 1,
        sets=["all"]
    )

    # Save metadata as JSON.
    save_dir = os.path.join(config.output_dir, set_name)
    os.makedirs(save_dir, exist_ok=True)

    with open(os.path.join(save_dir, "dataset.json"), "w") as f:
        json.dump(metadata.model_dump(), f)

    # Save data
    for k, v in results.items():
        np.save(os.path.join(save_dir, f"all__{k}.npy"), v)

    # Save IDs mapping (for visualization only)
    with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f:
        json.dump(["<blank>"], f)


@cli.command(singleton=True)
def preprocess_data(config: DataProcessConfig):
    convert_subset("train", config)
    convert_subset("test", config)


if __name__ == "__main__":
    cli()
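As with the maze builder, a minimal sketch of driving the converter directly (run from the dataset/ directory). The output_dir, subsample_size, and num_aug values below are illustrative assumptions, not defaults from this commit.

# Usage sketch -- output_dir, subsample_size and num_aug values are illustrative assumptions.
from build_sudoku_dataset import DataProcessConfig, convert_subset

config = DataProcessConfig(
    output_dir="data/sudoku-extreme-1k-aug-1000",
    subsample_size=1000,   # keep only 1000 random training puzzles
    num_aug=1000,          # 1000 shuffle_sudoku augmentations per training puzzle
)
convert_subset("train", config)
convert_subset("test", config)   # the test set is never subsampled or augmented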
dataset/common.py (new file, 49 lines)
@@ -0,0 +1,49 @@
from typing import List, Optional

import pydantic
import numpy as np


# Global list mapping each dihedral transform id to its inverse.
# Index corresponds to the original tid, and the value is its inverse.
DIHEDRAL_INVERSE = [0, 3, 2, 1, 4, 5, 6, 7]


class PuzzleDatasetMetadata(pydantic.BaseModel):
    pad_id: int
    ignore_label_id: Optional[int]
    blank_identifier_id: int
    vocab_size: int
    seq_len: int
    num_puzzle_identifiers: int
    total_groups: int
    mean_puzzle_examples: float
    total_puzzles: int
    sets: List[str]


def dihedral_transform(arr: np.ndarray, tid: int) -> np.ndarray:
    """Apply one of the 8 dihedral symmetries (rotations, flips, and diagonal reflections) selected by tid."""

    if tid == 0:
        return arr  # identity
    elif tid == 1:
        return np.rot90(arr, k=1)
    elif tid == 2:
        return np.rot90(arr, k=2)
    elif tid == 3:
        return np.rot90(arr, k=3)
    elif tid == 4:
        return np.fliplr(arr)  # horizontal flip
    elif tid == 5:
        return np.flipud(arr)  # vertical flip
    elif tid == 6:
        return arr.T  # transpose (reflection along main diagonal)
    elif tid == 7:
        return np.fliplr(np.rot90(arr, k=1))  # anti-diagonal reflection
    else:
        return arr


def inverse_dihedral_transform(arr: np.ndarray, tid: int) -> np.ndarray:
    return dihedral_transform(arr, DIHEDRAL_INVERSE[tid])
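A small self-check sketch of the transform pair above: composing each tid with its DIHEDRAL_INVERSE entry should reproduce the original grid, including on non-square arrays.

# Round-trip check for dihedral_transform / inverse_dihedral_transform.
import numpy as np
from common import dihedral_transform, inverse_dihedral_transform

grid = np.arange(12, dtype=np.uint8).reshape(3, 4)
for tid in range(8):
    restored = inverse_dihedral_transform(dihedral_transform(grid, tid), tid)
    assert np.array_equal(restored, grid)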