upload
This commit is contained in:
32
models/common.py
Normal file
32
models/common.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
def trunc_normal_init_(tensor: torch.Tensor, std: float = 1.0, lower: float = -2.0, upper: float = 2.0):
|
||||
# NOTE: PyTorch nn.init.trunc_normal_ is not mathematically correct, the std dev is not actually the std dev of initialized tensor
|
||||
# This function is a PyTorch version of jax truncated normal init (default init method in flax)
|
||||
# https://github.com/jax-ml/jax/blob/main/jax/_src/random.py#L807-L848
|
||||
# https://github.com/jax-ml/jax/blob/main/jax/_src/nn/initializers.py#L162-L199
|
||||
|
||||
with torch.no_grad():
|
||||
if std == 0:
|
||||
tensor.zero_()
|
||||
else:
|
||||
sqrt2 = math.sqrt(2)
|
||||
a = math.erf(lower / sqrt2)
|
||||
b = math.erf(upper / sqrt2)
|
||||
z = (b - a) / 2
|
||||
|
||||
c = (2 * math.pi) ** -0.5
|
||||
pdf_u = c * math.exp(-0.5 * lower ** 2)
|
||||
pdf_l = c * math.exp(-0.5 * upper ** 2)
|
||||
comp_std = std / math.sqrt(1 - (upper * pdf_u - lower * pdf_l) / z - ((pdf_u - pdf_l) / z) ** 2)
|
||||
|
||||
tensor.uniform_(a, b)
|
||||
tensor.erfinv_()
|
||||
tensor.mul_(sqrt2 * comp_std)
|
||||
tensor.clip_(lower * comp_std, upper * comp_std)
|
||||
|
||||
return tensor
|
||||
Reference in New Issue
Block a user