upload
This commit is contained in:
40
models/ema.py
Normal file
40
models/ema.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import copy
|
||||
import torch.nn as nn
|
||||
|
||||
class EMAHelper(object):
|
||||
def __init__(self, mu=0.999):
|
||||
self.mu = mu
|
||||
self.shadow = {}
|
||||
|
||||
def register(self, module):
|
||||
if isinstance(module, nn.DataParallel):
|
||||
module = module.module
|
||||
for name, param in module.named_parameters():
|
||||
if param.requires_grad:
|
||||
self.shadow[name] = param.data.clone()
|
||||
|
||||
def update(self, module):
|
||||
if isinstance(module, nn.DataParallel):
|
||||
module = module.module
|
||||
for name, param in module.named_parameters():
|
||||
if param.requires_grad:
|
||||
self.shadow[name].data = (1. - self.mu) * param.data + self.mu * self.shadow[name].data
|
||||
|
||||
def ema(self, module):
|
||||
if isinstance(module, nn.DataParallel):
|
||||
module = module.module
|
||||
for name, param in module.named_parameters():
|
||||
if param.requires_grad:
|
||||
param.data.copy_(self.shadow[name].data)
|
||||
|
||||
def ema_copy(self, module):
|
||||
module_copy = copy.deepcopy(module)
|
||||
self.ema(module_copy)
|
||||
return module_copy
|
||||
|
||||
def state_dict(self):
|
||||
return self.shadow
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self.shadow = state_dict
|
||||
|
||||
Reference in New Issue
Block a user