from __future__ import absolute_import, division, print_function

import math
import os
import typing as ty
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import zero
from torch import Tensor


class IndexLoader:
    def __init__(
        self, train_size: int, batch_size: int, shuffle: bool, device: torch.device
    ) -> None:
        self._train_size = train_size
        self._batch_size = batch_size
        self._shuffle = shuffle
        self._device = device

    def __len__(self) -> int:
        return math.ceil(self._train_size / self._batch_size)

    def __iter__(self):
        indices = list(
            zero.iloader(self._train_size, self._batch_size, shuffle=self._shuffle)
        )
        return iter(torch.cat(indices).to(self._device).split(self._batch_size))
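

# Illustrative sketch (not part of the original module): how IndexLoader is
# typically consumed in a training loop. `X` is a stand-in for a dataset tensor
# assumed to already live on `device`; only IndexLoader itself comes from this file.
def _example_index_loader_usage(X: Tensor, device: torch.device) -> int:
    loader = IndexLoader(len(X), batch_size=256, shuffle=True, device=device)
    n_batches = 0
    for idx in loader:  # `idx` is a 1D tensor of row indices on `device`
        _ = X[idx]      # gather the corresponding rows of the dataset
        n_batches += 1
    return n_batches    # equals len(loader)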


class Lambda(nn.Module):
    def __init__(self, f: ty.Callable) -> None:
        super().__init__()
        self.f = f

    def forward(self, x):
        return self.f(x)


# Source: https://github.com/bzhangGo/rmsnorm
# NOTE: eps is changed to 1e-5
class RMSNorm(nn.Module):
    def __init__(self, d, p=-1.0, eps=1e-5, bias=False):
        """Root Mean Square Layer Normalization

        :param d: model size
        :param p: partial RMSNorm, valid value [0, 1], default -1.0 (disabled)
        :param eps: epsilon value, default 1e-5 (the upstream default is 1e-8)
        :param bias: whether to use a bias term for RMSNorm, disabled by
            default because RMSNorm doesn't enforce re-centering invariance.
        """
        super(RMSNorm, self).__init__()

        self.eps = eps
        self.d = d
        self.p = p
        self.bias = bias

        self.scale = nn.Parameter(torch.ones(d))
        self.register_parameter("scale", self.scale)

        if self.bias:
            self.offset = nn.Parameter(torch.zeros(d))
            self.register_parameter("offset", self.offset)

    def forward(self, x):
        if self.p < 0.0 or self.p > 1.0:
            norm_x = x.norm(2, dim=-1, keepdim=True)
            d_x = self.d
        else:
            partial_size = int(self.d * self.p)
            partial_x, _ = torch.split(x, [partial_size, self.d - partial_size], dim=-1)

            norm_x = partial_x.norm(2, dim=-1, keepdim=True)
            d_x = partial_size

        rms_x = norm_x * d_x ** (-1.0 / 2)
        x_normed = x / (rms_x + self.eps)

        if self.bias:
            return self.scale * x_normed + self.offset

        return self.scale * x_normed
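

# Illustrative sketch (not part of the original file; shapes are arbitrary and
# assume a 2D input of shape (batch, d)): with p=0.5 the RMS statistic is computed
# over the first half of the features only (partial RMSNorm), while the whole
# vector is rescaled.
def _example_rmsnorm_usage() -> Tensor:
    norm = RMSNorm(d=8, p=0.5)
    x = torch.randn(4, 8)
    return norm(x)  # same shape as x: (4, 8)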


class ScaleNorm(nn.Module):
    """
    Sources:
    - https://github.com/tnq177/transformers_without_tears/blob/25026061979916afb193274438f7097945acf9bc/layers.py#L132
    - https://github.com/tnq177/transformers_without_tears/blob/6b2726cd9e6e642d976ae73b9f696d9d7ff4b395/layers.py#L157
    """

    def __init__(self, d: int, eps: float = 1e-5, clamp: bool = False) -> None:
        super(ScaleNorm, self).__init__()
        self.scale = nn.Parameter(torch.tensor(d ** 0.5))
        self.eps = eps
        self.clamp = clamp

    def forward(self, x):
        norms = torch.norm(x, dim=-1, keepdim=True)
        norms = norms.clamp(min=self.eps) if self.clamp else norms + self.eps
        return self.scale * x / norms
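

# Illustrative sketch (not from the original file): ScaleNorm keeps a single
# learnable scalar initialized to sqrt(d) and divides each vector by its L2 norm.
def _example_scalenorm_usage() -> Tensor:
    norm = ScaleNorm(d=16, clamp=True)
    x = torch.randn(4, 16)
    return norm(x)  # rows have norm (approximately) sqrt(16) = 4 at initialization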


def reglu(x: Tensor) -> Tensor:
    a, b = x.chunk(2, dim=-1)
    return a * F.relu(b)


def geglu(x: Tensor) -> Tensor:
    a, b = x.chunk(2, dim=-1)
    return a * F.gelu(b)


class ReGLU(nn.Module):
    def forward(self, x: Tensor) -> Tensor:
        return reglu(x)


class GEGLU(nn.Module):
    def forward(self, x: Tensor) -> Tensor:
        return geglu(x)
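

# Illustrative sketch (not from the original file): the (Re/Ge)GLU activations
# halve the last dimension, so the preceding linear layer must produce twice the
# desired output width. Sizes below are arbitrary.
def _example_glu_block(d_in: int = 32, d_out: int = 64) -> nn.Sequential:
    return nn.Sequential(
        nn.Linear(d_in, d_out * 2),  # doubled width: half is gated by the other half
        ReGLU(),                     # output has last dimension d_out
    )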


def make_optimizer(
    optimizer: str,
    parameter_groups,
    lr: float,
    weight_decay: float,
) -> optim.Optimizer:
    Optimizer = {
        'adabelief': AdaBelief,
        'adam': optim.Adam,
        'adamw': optim.AdamW,
        'radam': RAdam,
        'sgd': optim.SGD,
    }[optimizer]
    momentum = (0.9,) if Optimizer is optim.SGD else ()
    return Optimizer(parameter_groups, lr, *momentum, weight_decay=weight_decay)
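

# Illustrative sketch (not from the original file): make_optimizer expects one of
# the keys of the dictionary above; for 'sgd' a momentum of 0.9 is inserted
# positionally. The model below is a stand-in for the example.
def _example_make_optimizer() -> optim.Optimizer:
    model = nn.Linear(8, 1)
    return make_optimizer('adamw', model.parameters(), lr=1e-3, weight_decay=1e-5)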


def make_lr_schedule(
    optimizer: optim.Optimizer,
    lr: float,
    epoch_size: int,
    lr_schedule: ty.Optional[ty.Dict[str, ty.Any]],
) -> ty.Tuple[
    ty.Optional[optim.lr_scheduler._LRScheduler],
    ty.Dict[str, ty.Any],
    ty.Optional[int],
]:
    if lr_schedule is None:
        lr_schedule = {'type': 'constant'}
    lr_scheduler = None
    n_warmup_steps = None
    if lr_schedule['type'] in ['transformer', 'linear_warmup']:
        n_warmup_steps = (
            lr_schedule['n_warmup_steps']
            if 'n_warmup_steps' in lr_schedule
            else lr_schedule['n_warmup_epochs'] * epoch_size
        )
    elif lr_schedule['type'] == 'cyclic':
        lr_scheduler = optim.lr_scheduler.CyclicLR(
            optimizer,
            base_lr=lr,
            max_lr=lr_schedule['max_lr'],
            step_size_up=lr_schedule['n_epochs_up'] * epoch_size,
            step_size_down=lr_schedule['n_epochs_down'] * epoch_size,
            mode=lr_schedule['mode'],
            gamma=lr_schedule.get('gamma', 1.0),
            cycle_momentum=False,
        )
    return lr_scheduler, lr_schedule, n_warmup_steps
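

# Illustrative sketch (not from the original file): for the warmup-style schedules
# only `n_warmup_steps` is returned and the learning rate is expected to be set
# manually (see get_linear_warmup_lr / get_transformer_lr below); a scheduler
# object is created only for the 'cyclic' type. The config keys follow the
# branches above.
def _example_make_lr_schedule(optimizer: optim.Optimizer, epoch_size: int):
    schedule_config = {'type': 'linear_warmup', 'n_warmup_epochs': 10}
    scheduler, config, n_warmup_steps = make_lr_schedule(
        optimizer, lr=1e-3, epoch_size=epoch_size, lr_schedule=schedule_config
    )
    assert scheduler is None and n_warmup_steps == 10 * epoch_size
    return config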


def get_activation_fn(name: str) -> ty.Callable[[Tensor], Tensor]:
    return (
        reglu
        if name == 'reglu'
        else geglu
        if name == 'geglu'
        else torch.sigmoid
        if name == 'sigmoid'
        else getattr(F, name)
    )


def get_nonglu_activation_fn(name: str) -> ty.Callable[[Tensor], Tensor]:
    return (
        F.relu
        if name == 'reglu'
        else F.gelu
        if name == 'geglu'
        else get_activation_fn(name)
    )


def load_swa_state_dict(model: nn.Module, swa_model: optim.swa_utils.AveragedModel):
    state_dict = deepcopy(swa_model.state_dict())
    del state_dict['n_averaged']
    model.load_state_dict({k[len('module.') :]: v for k, v in state_dict.items()})


def get_epoch_parameters(
    train_size: int, batch_size: ty.Union[int, str]
) -> ty.Tuple[int, int]:
    if isinstance(batch_size, str):
        if batch_size == 'v3':
            batch_size = (
                256 if train_size < 50000 else 512 if train_size < 100000 else 1024
            )
        elif batch_size == 'v1':
            batch_size = (
                16
                if train_size < 1000
                else 32
                if train_size < 10000
                else 64
                if train_size < 50000
                else 128
                if train_size < 100000
                else 256
                if train_size < 200000
                else 512
                if train_size < 500000
                else 1024
            )
        elif batch_size == 'v2':
            batch_size = (
                512 if train_size < 100000 else 1024 if train_size < 500000 else 2048
            )
    return batch_size, math.ceil(train_size / batch_size)  # type: ignore[code]
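

# Illustrative sketch (not from the original file): the string presets map the
# training-set size to a batch size; the second return value is the number of
# batches per epoch.
def _example_get_epoch_parameters() -> None:
    batch_size, epoch_size = get_epoch_parameters(60000, 'v3')
    assert batch_size == 512                    # 50000 <= 60000 < 100000
    assert epoch_size == math.ceil(60000 / 512)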


def get_linear_warmup_lr(lr: float, n_warmup_steps: int, step: int) -> float:
    assert step > 0, "1-based enumeration of steps is expected"
    return min(lr, step / (n_warmup_steps + 1) * lr)


def get_manual_lr(schedule: ty.List[float], epoch: int) -> float:
    assert epoch > 0, "1-based enumeration of epochs is expected"
    return schedule[min(epoch, len(schedule)) - 1]


def get_transformer_lr(scale: float, d: int, n_warmup_steps: int, step: int) -> float:
    return scale * d ** -0.5 * min(step ** -0.5, step * n_warmup_steps ** -1.5)
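

# Illustrative sketch (not from the original file): the 'transformer' schedule
# grows linearly for the first n_warmup_steps steps and then decays as step ** -0.5,
# all scaled by d ** -0.5. At step == n_warmup_steps both terms of the min() coincide.
def _example_warmup_lrs() -> ty.List[float]:
    return [
        get_transformer_lr(scale=1.0, d=512, n_warmup_steps=4000, step=s)
        for s in (1, 4000, 16000)
    ]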


def learn(model, optimizer, loss_fn, step, batch, star) -> ty.Tuple[Tensor, ty.Any]:
    model.train()
    optimizer.zero_grad()
    out = step(batch)
    loss = loss_fn(*out) if star else loss_fn(out)
    loss.backward()
    optimizer.step()
    return loss, out


def _learn_with_virtual_batch(
    model, optimizer, loss_fn, step, batch, chunk_size
) -> Tensor:
    batch_size = len(batch)
    if chunk_size >= batch_size:
        return learn(model, optimizer, loss_fn, step, batch, True)[0]
    model.train()
    optimizer.zero_grad()
    total_loss = None
    for chunk in zero.iter_batches(batch, chunk_size):
        loss = loss_fn(*step(chunk))
        loss = loss * len(chunk)
        loss.backward()
        if total_loss is None:
            total_loss = loss.detach()
        else:
            total_loss += loss.detach()
    for x in model.parameters():
        if x.grad is not None:
            x.grad /= batch_size
    optimizer.step()
    return total_loss / batch_size


def learn_with_auto_virtual_batch(
    model,
    optimizer,
    loss_fn,
    step,
    batch,
    batch_size_hint: int,
    chunk_size: ty.Optional[int],
) -> ty.Tuple[Tensor, ty.Optional[int]]:
    """This is just an overcomplicated version of `train_with_auto_virtual_batch`."""
    random_state = zero.get_random_state()
    while chunk_size != 0:
        try:
            zero.set_random_state(random_state)
            return (
                _learn_with_virtual_batch(
                    model,
                    optimizer,
                    loss_fn,
                    step,
                    batch,
                    chunk_size or batch_size_hint,
                ),
                chunk_size,
            )
        except RuntimeError as err:
            if not is_oom_exception(err):
                raise
            if chunk_size is None:
                chunk_size = batch_size_hint
            chunk_size //= 2
    raise RuntimeError('Not enough memory even for batch_size=1')


def train_with_auto_virtual_batch(
    optimizer,
    loss_fn,
    step,
    batch,
    chunk_size: int,
) -> ty.Tuple[Tensor, int]:
    batch_size = len(batch)
    random_state = zero.get_random_state()
    while chunk_size != 0:
        try:
            zero.set_random_state(random_state)
            optimizer.zero_grad()
            if batch_size <= chunk_size:
                loss = loss_fn(*step(batch))
                loss.backward()
            else:
                loss = None
                for chunk in zero.iter_batches(batch, chunk_size):
                    chunk_loss = loss_fn(*step(chunk))
                    chunk_loss = chunk_loss * (len(chunk) / batch_size)
                    chunk_loss.backward()
                    if loss is None:
                        loss = chunk_loss.detach()
                    else:
                        loss += chunk_loss.detach()
        except RuntimeError as err:
            if not is_oom_exception(err):
                raise
            chunk_size //= 2
        else:
            break
    if not chunk_size:
        raise RuntimeError('Not enough memory even for batch_size=1')
    optimizer.step()
    return loss, chunk_size  # type: ignore[code]
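

# Illustrative sketch (not from the original file): train_with_auto_virtual_batch
# retries the forward/backward pass with a halved chunk size whenever a CUDA OOM
# error is detected, so the caller should feed the returned chunk_size back in on
# the next iteration. `model`, `X`, and `y` are stand-ins for the example.
def _example_auto_virtual_batch_step(model, optimizer, X: Tensor, y: Tensor, chunk_size: int):
    def step(idx):
        return model(X[idx]), y[idx]  # (prediction, target) pair unpacked into loss_fn

    batch = torch.arange(len(X))      # a "batch" here is a tensor of row indices
    loss, chunk_size = train_with_auto_virtual_batch(
        optimizer, F.mse_loss, step, batch, chunk_size
    )
    return loss, chunk_size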


def tensor(x) -> torch.Tensor:
    assert isinstance(x, torch.Tensor)
    return ty.cast(torch.Tensor, x)


def get_n_parameters(m: nn.Module):
    return sum(x.numel() for x in m.parameters() if x.requires_grad)


def get_mlp_n_parameters(units: ty.List[int]):
    x = 0
    for a, b in zip(units, units[1:]):
        x += a * b + b
    return x


def get_lr(optimizer: optim.Optimizer) -> float:
    return next(iter(optimizer.param_groups))['lr']


def set_lr(optimizer: optim.Optimizer, lr: float) -> None:
    for x in optimizer.param_groups:
        x['lr'] = lr


def get_device() -> torch.device:
    return torch.device('cuda:0' if os.environ.get('CUDA_VISIBLE_DEVICES') else 'cpu')


@torch.no_grad()
def get_gradient_norm_ratios(m: nn.Module):
    return {
        k: v.grad.norm() / v.norm()
        for k, v in m.named_parameters()
        if v.grad is not None
    }


def is_oom_exception(err: RuntimeError) -> bool:
    return any(
        x in str(err)
        for x in [
            'CUDA out of memory',
            'CUBLAS_STATUS_ALLOC_FAILED',
            'CUDA error: out of memory',
        ]
    )


# Source: https://github.com/LiyuanLucasLiu/RAdam
class RAdam(optim.Optimizer):
    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0,
        degenerated_to_sgd=True,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))

        self.degenerated_to_sgd = degenerated_to_sgd
        if (
            isinstance(params, (list, tuple))
            and len(params) > 0
            and isinstance(params[0], dict)
        ):
            for param in params:
                if 'betas' in param and (
                    param['betas'][0] != betas[0] or param['betas'][1] != betas[1]
                ):
                    param['buffer'] = [[None, None, None] for _ in range(10)]
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            buffer=[[None, None, None] for _ in range(10)],
        )
        super(RAdam, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(RAdam, self).__setstate__(state)

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('RAdam does not support sparse gradients')

                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                state['step'] += 1
                buffered = group['buffer'][int(state['step'] % 10)]
                if state['step'] == buffered[0]:
                    N_sma, step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = state['step']
                    beta2_t = beta2 ** state['step']
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
                    buffered[1] = N_sma

                    # more conservative since it's an approximated value
                    if N_sma >= 5:
                        step_size = math.sqrt(
                            (1 - beta2_t)
                            * (N_sma - 4)
                            / (N_sma_max - 4)
                            * (N_sma - 2)
                            / N_sma
                            * N_sma_max
                            / (N_sma_max - 2)
                        ) / (1 - beta1 ** state['step'])
                    elif self.degenerated_to_sgd:
                        step_size = 1.0 / (1 - beta1 ** state['step'])
                    else:
                        step_size = -1
                    buffered[2] = step_size

                # more conservative since it's an approximated value
                if N_sma >= 5:
                    if group['weight_decay'] != 0:
                        p_data_fp32.add_(
                            p_data_fp32, alpha=-group['weight_decay'] * group['lr']
                        )
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
                    p.data.copy_(p_data_fp32)
                elif step_size > 0:
                    if group['weight_decay'] != 0:
                        p_data_fp32.add_(
                            p_data_fp32, alpha=-group['weight_decay'] * group['lr']
                        )
                    p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr'])
                    p.data.copy_(p_data_fp32)

        return loss
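

# Illustrative sketch (not from the original file): RAdam is constructed like any
# torch optimizer; with degenerated_to_sgd=True it falls back to a bias-corrected
# SGD-like update while the variance rectification term is undefined (N_sma < 5).
def _example_radam() -> RAdam:
    model = nn.Linear(8, 1)
    return RAdam(model.parameters(), lr=1e-3, weight_decay=1e-4)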


# NOTE: compare (major, minor) numerically; a plain string comparison misorders
# versions such as "1.10" vs "1.5".
version_higher = tuple(
    int(x) for x in torch.__version__.split('+')[0].split('.')[:2]
) >= (1, 5)


# Source: https://github.com/juntang-zhuang/Adabelief-Optimizer
class AdaBelief(optim.Optimizer):
    r"""Implements the AdaBelief algorithm. Modified from Adam in PyTorch.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-16)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)
        weight_decouple (boolean, optional): (default: True) If set as True, then
            the optimizer uses decoupled weight decay as in AdamW
        fixed_decay (boolean, optional): (default: False) This is used when
            weight_decouple is set as True.
            When fixed_decay == True, the weight decay is performed as
            $W_{new} = W_{old} - W_{old} \times decay$.
            When fixed_decay == False, the weight decay is performed as
            $W_{new} = W_{old} - W_{old} \times decay \times lr$. Note that in this
            case, the weight decay ratio decreases with the learning rate (lr).
        rectify (boolean, optional): (default: True) If set as True, then perform the
            rectified update similar to RAdam
        degenerated_to_sgd (boolean, optional): (default: True) If set as True, then
            perform the SGD update when the variance of the gradient is high
        print_change_log (boolean, optional): (default: True) If set as True, print
            the modifications to the default hyper-parameters

    reference: AdaBelief Optimizer, adapting stepsizes by the belief in observed
    gradients, NeurIPS 2020
    """

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-16,
        weight_decay=0,
        amsgrad=False,
        weight_decouple=True,
        fixed_decay=False,
        rectify=True,
        degenerated_to_sgd=True,
        print_change_log=True,
    ):
        # ------------------------------------------------------------------------------
        # Print modifications to default arguments
        if print_change_log:
            print(
                'Please check your arguments if you have upgraded adabelief-pytorch from version 0.0.5.'
            )
            print('Modifications to default arguments:')
            default_table = [
                ['eps', 'weight_decouple', 'rectify'],
                ['adabelief-pytorch=0.0.5', '1e-8', 'False', 'False'],
                ['>=0.1.0 (Current 0.2.0)', '1e-16', 'True', 'True'],
            ]
            print(default_table)
            recommend_table = [
                [
                    'SGD better than Adam (e.g. CNN for Image Classification)',
                    'Adam better than SGD (e.g. Transformer, GAN)',
                ],
                ['Recommended eps = 1e-8', 'Recommended eps = 1e-16'],
            ]
            print(recommend_table)
            print('For a complete table of recommended hyperparameters, see')
            print('https://github.com/juntang-zhuang/Adabelief-Optimizer')
            print(
                'You can disable the log message by setting "print_change_log = False", though it is recommended to keep as a reminder.'
            )
        # ------------------------------------------------------------------------------

        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))

        self.degenerated_to_sgd = degenerated_to_sgd
        if (
            isinstance(params, (list, tuple))
            and len(params) > 0
            and isinstance(params[0], dict)
        ):
            for param in params:
                if 'betas' in param and (
                    param['betas'][0] != betas[0] or param['betas'][1] != betas[1]
                ):
                    param['buffer'] = [[None, None, None] for _ in range(10)]

        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad,
            buffer=[[None, None, None] for _ in range(10)],
        )
        super(AdaBelief, self).__init__(params, defaults)

        self.degenerated_to_sgd = degenerated_to_sgd
        self.weight_decouple = weight_decouple
        self.rectify = rectify
        self.fixed_decay = fixed_decay
        if self.weight_decouple:
            print('Weight decoupling enabled in AdaBelief')
            if self.fixed_decay:
                print('Weight decay fixed')
        if self.rectify:
            print('Rectification enabled in AdaBelief')
        if amsgrad:
            print('AMSGrad enabled in AdaBelief')

    def __setstate__(self, state):
        super(AdaBelief, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    def reset(self):
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                amsgrad = group['amsgrad']

                # State initialization
                state['step'] = 0
                # Exponential moving average of gradient values
                state['exp_avg'] = (
                    torch.zeros_like(p.data, memory_format=torch.preserve_format)
                    if version_higher
                    else torch.zeros_like(p.data)
                )
                # Exponential moving average of squared gradient values
                state['exp_avg_var'] = (
                    torch.zeros_like(p.data, memory_format=torch.preserve_format)
                    if version_higher
                    else torch.zeros_like(p.data)
                )
                if amsgrad:
                    # Maintains max of all exp. moving avg. of sq. grad. values
                    state['max_exp_avg_var'] = (
                        torch.zeros_like(p.data, memory_format=torch.preserve_format)
                        if version_higher
                        else torch.zeros_like(p.data)
                    )

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                # cast data type
                half_precision = False
                if p.data.dtype == torch.float16:
                    half_precision = True
                    p.data = p.data.float()
                    p.grad = p.grad.float()

                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        'AdaBelief does not support sparse gradients, please consider SparseAdam instead'
                    )
                amsgrad = group['amsgrad']

                state = self.state[p]

                beta1, beta2 = group['betas']

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = (
                        torch.zeros_like(p.data, memory_format=torch.preserve_format)
                        if version_higher
                        else torch.zeros_like(p.data)
                    )
                    # Exponential moving average of squared gradient values
                    state['exp_avg_var'] = (
                        torch.zeros_like(p.data, memory_format=torch.preserve_format)
                        if version_higher
                        else torch.zeros_like(p.data)
                    )
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_var'] = (
                            torch.zeros_like(
                                p.data, memory_format=torch.preserve_format
                            )
                            if version_higher
                            else torch.zeros_like(p.data)
                        )

                # perform weight decay, check if decoupled weight decay
                if self.weight_decouple:
                    if not self.fixed_decay:
                        p.data.mul_(1.0 - group['lr'] * group['weight_decay'])
                    else:
                        p.data.mul_(1.0 - group['weight_decay'])
                else:
                    if group['weight_decay'] != 0:
                        grad.add_(p.data, alpha=group['weight_decay'])

                # get current state variable
                exp_avg, exp_avg_var = state['exp_avg'], state['exp_avg_var']

                state['step'] += 1
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                # Update first and second moment running average
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                grad_residual = grad - exp_avg
                exp_avg_var.mul_(beta2).addcmul_(
                    grad_residual, grad_residual, value=1 - beta2
                )

                if amsgrad:
                    max_exp_avg_var = state['max_exp_avg_var']
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(
                        max_exp_avg_var,
                        exp_avg_var.add_(group['eps']),
                        out=max_exp_avg_var,
                    )

                    # Use the max. for normalizing running avg. of gradient
                    denom = (max_exp_avg_var.sqrt() / math.sqrt(bias_correction2)).add_(
                        group['eps']
                    )
                else:
                    denom = (
                        exp_avg_var.add_(group['eps']).sqrt()
                        / math.sqrt(bias_correction2)
                    ).add_(group['eps'])

                # update
                if not self.rectify:
                    # Default update
                    step_size = group['lr'] / bias_correction1
                    p.data.addcdiv_(exp_avg, denom, value=-step_size)
                else:  # Rectified update, forked from RAdam
                    buffered = group['buffer'][int(state['step'] % 10)]
                    if state['step'] == buffered[0]:
                        N_sma, step_size = buffered[1], buffered[2]
                    else:
                        buffered[0] = state['step']
                        beta2_t = beta2 ** state['step']
                        N_sma_max = 2 / (1 - beta2) - 1
                        N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
                        buffered[1] = N_sma

                        # more conservative since it's an approximated value
                        if N_sma >= 5:
                            step_size = math.sqrt(
                                (1 - beta2_t)
                                * (N_sma - 4)
                                / (N_sma_max - 4)
                                * (N_sma - 2)
                                / N_sma
                                * N_sma_max
                                / (N_sma_max - 2)
                            ) / (1 - beta1 ** state['step'])
                        elif self.degenerated_to_sgd:
                            step_size = 1.0 / (1 - beta1 ** state['step'])
                        else:
                            step_size = -1
                        buffered[2] = step_size

                    if N_sma >= 5:
                        denom = exp_avg_var.sqrt().add_(group['eps'])
                        p.data.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
                    elif step_size > 0:
                        p.data.add_(exp_avg, alpha=-step_size * group['lr'])

                if half_precision:
                    p.data = p.data.half()
                    p.grad = p.grad.half()

        return loss
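

# Illustrative sketch (not from the original file): AdaBelief with the verbose
# startup log disabled; the remaining arguments are the defaults documented above.
def _example_adabelief() -> AdaBelief:
    model = nn.Linear(8, 1)
    return AdaBelief(model.parameters(), lr=1e-3, print_change_log=False)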