Merge pull request #2357 from huggingface/more_opt_stuff
Add caution to Adan. Add decoupled decay option to LAMB.
rwightman authored Dec 27, 2024
2 parents a02b1a8 + afdf11d commit 364c567
Showing 4 changed files with 89 additions and 11 deletions.
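
For context, the two additions named in the commit message can be exercised by constructing the optimizers directly. This is only a sketch based on the diffs below: the caution, no_prox, and decoupled_decay keywords come from the changed constructors, while the toy model and hyperparameters are assumptions.

import torch
import torch.nn as nn
from timm.optim.adan import Adan
from timm.optim.lamb import Lamb

model = nn.Linear(16, 4)  # toy module standing in for a real network

# Adan with the new caution mask; no_prox=True selects the decoupled decay path.
opt_adan = Adan(model.parameters(), lr=1e-3, weight_decay=0.02, caution=True, no_prox=True)

# LAMB with the new decoupled (AdamW-style) weight decay.
opt_lamb = Lamb(model.parameters(), lr=1e-3, weight_decay=0.02, decoupled_decay=True)

model(torch.randn(8, 16)).sum().backward()
opt_adan.step()  # step only one optimizer; both are constructed here purely for illustration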
37 changes: 37 additions & 0 deletions timm/optim/_optim_factory.py
@@ -485,6 +485,20 @@ def _register_lamb_lars(registry: OptimizerRegistry) -> None:
has_betas=True,
defaults={'trust_clip': True}
),
OptimInfo(
name='lambw',
opt_class=Lamb,
description='LAMB with decoupled weight decay',
has_betas=True,
defaults={'decoupled_decay': True}
),
OptimInfo(
name='lambcw',
opt_class=Lamb,
description='LAMB with trust ratio clipping for stability and decoupled decay',
has_betas=True,
defaults={'trust_clip': True, 'decoupled_decay': True}
),
OptimInfo(
name='lars',
opt_class=Lars,
@@ -544,6 +558,22 @@ def _register_cautious_optimizers(registry: OptimizerRegistry) -> None:
description='Cautious Adopt',
defaults={'caution': True}
),
OptimInfo(
name='cadan',
opt_class=Adan,
description='Cautious Adaptive Nesterov Momentum Algorithm',
defaults={'caution': True, 'no_prox': False},
has_betas=True,
num_betas=3
),
OptimInfo(
name='cadanw',
opt_class=Adan,
description='Cautious Adaptive Nesterov Momentum with decoupled weight decay',
defaults={'caution': True, 'no_prox': True},
has_betas=True,
num_betas=3
),
OptimInfo(
name='cadoptw',
opt_class=Adopt,
@@ -557,6 +587,13 @@ def _register_cautious_optimizers(registry: OptimizerRegistry) -> None:
has_betas=True,
defaults={'caution': True}
),
OptimInfo(
name='clambw',
opt_class=Lamb,
description='Cautious LAMB with decoupled weight decay',
has_betas=True,
defaults={'caution': True, 'decoupled_decay': True}
),
OptimInfo(
name='claprop',
opt_class=LaProp,
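The registry entries above are normally consumed through the optimizer factory rather than imported directly. A minimal sketch, assuming timm's create_optimizer_v2 entry point resolves the newly registered names (lambw, lambcw, cadan, cadanw, clambw) to the defaults shown in this diff:

import torch.nn as nn
from timm.optim import create_optimizer_v2

model = nn.Linear(16, 4)

# 'lambw' maps to Lamb(decoupled_decay=True) per the registration above.
opt_lambw = create_optimizer_v2(model, opt='lambw', lr=1e-3, weight_decay=0.02)

# 'cadanw' maps to Adan(caution=True, no_prox=True), i.e. cautious Adan with decoupled decay.
opt_cadanw = create_optimizer_v2(model, opt='cadanw', lr=1e-3, weight_decay=0.02)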
52 changes: 42 additions & 10 deletions timm/optim/adan.py
@@ -20,7 +20,7 @@
# limitations under the License.

import math
from typing import List, Tuple
from typing import List, Optional, Tuple

import torch
from torch import Tensor
@@ -56,6 +56,7 @@ class Adan(Optimizer):
eps: Term added to the denominator to improve numerical stability.
weight_decay: Decoupled weight decay (L2 penalty)
no_prox: How to perform the weight decay
caution: Enable caution from 'Cautious Optimizers'
foreach: If True would use torch._foreach implementation. Faster but uses slightly more memory.
"""

@@ -66,7 +67,8 @@ def __init__(self,
eps: float = 1e-8,
weight_decay: float = 0.0,
no_prox: bool = False,
foreach: bool = True,
caution: bool = False,
foreach: Optional[bool] = None,
):
if not 0.0 <= lr:
raise ValueError('Invalid learning rate: {}'.format(lr))
@@ -85,6 +87,7 @@ def __init__(self,
eps=eps,
weight_decay=weight_decay,
no_prox=no_prox,
caution=caution,
foreach=foreach,
)
super().__init__(params, defaults)
@@ -93,6 +96,7 @@ def __setstate__(self, state):
super(Adan, self).__setstate__(state)
for group in self.param_groups:
group.setdefault('no_prox', False)
group.setdefault('caution', False)

@torch.no_grad()
def restart_opt(self):
@@ -118,6 +122,11 @@ def step(self, closure=None):
with torch.enable_grad():
loss = closure()

try:
has_scalar_maximum = 'Scalar' in torch.ops.aten._foreach_maximum_.overloads()
except:
has_scalar_maximum = False

for group in self.param_groups:
params_with_grad = []
grads = []
@@ -161,9 +170,19 @@ def step(self, closure=None):
if not params_with_grad:
continue

kwargs = dict(
params=params_with_grad,
grads=grads,
if group['foreach'] is None:
use_foreach = not group['caution'] or has_scalar_maximum
else:
use_foreach = group['foreach']

if use_foreach:
func = _multi_tensor_adan
else:
func = _single_tensor_adan

func(
params_with_grad,
grads,
exp_avgs=exp_avgs,
exp_avg_sqs=exp_avg_sqs,
exp_avg_diffs=exp_avg_diffs,
Expand All @@ -178,13 +197,9 @@ def step(self, closure=None):
weight_decay=group['weight_decay'],
eps=group['eps'],
no_prox=group['no_prox'],
caution=group['caution'],
)

if group['foreach']:
_multi_tensor_adan(**kwargs)
else:
_single_tensor_adan(**kwargs)

return loss


@@ -206,6 +221,7 @@ def _single_tensor_adan(
weight_decay: float,
eps: float,
no_prox: bool,
caution: bool,
):
for i, param in enumerate(params):
grad = grads[i]
@@ -227,6 +243,12 @@ def _single_tensor_adan(
step_size_diff = lr * beta2 / bias_correction2
step_size = lr / bias_correction1

if caution:
# Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
mask = (exp_avg * grad > 0).to(grad.dtype)
mask.div_(mask.mean().clamp_(min=1e-3))
exp_avg = exp_avg * mask

if no_prox:
param.mul_(1 - lr * weight_decay)
param.addcdiv_(exp_avg, denom, value=-step_size)
@@ -257,6 +279,7 @@ def _multi_tensor_adan(
weight_decay: float,
eps: float,
no_prox: bool,
caution: bool,
):
if len(params) == 0:
return
@@ -282,6 +305,15 @@ def _multi_tensor_adan(
step_size_diff = lr * beta2 / bias_correction2
step_size = lr / bias_correction1

if caution:
# Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
masks = torch._foreach_mul(exp_avgs, grads)
masks = [(m > 0).to(g.dtype) for m, g in zip(masks, grads)]
mask_scale = [m.mean() for m in masks]
torch._foreach_maximum_(mask_scale, 1e-3)
torch._foreach_div_(masks, mask_scale)
exp_avgs = torch._foreach_mul(exp_avgs, masks)

if no_prox:
torch._foreach_mul_(params, 1 - lr * weight_decay)
torch._foreach_addcdiv_(params, exp_avgs, denom, value=-step_size)
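The caution hook added to both Adan code paths follows the same recipe as the other cautious optimizers in this change set: zero the momentum entries whose sign disagrees with the current gradient, then rescale the survivors so the update keeps a comparable average magnitude (the clamp at 1e-3 guards against dividing by a near-zero mean when almost everything is masked). Note also that with foreach=None the multi-tensor path is only auto-selected when torch._foreach_maximum_ exposes a Scalar overload, which the mask rescaling relies on. A standalone sketch of the single-tensor masking; apply_caution is a hypothetical helper, not part of the diff:

import torch

def apply_caution(exp_avg: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
    # Keep only momentum entries whose sign agrees with the gradient.
    mask = (exp_avg * grad > 0).to(grad.dtype)
    # Rescale so the masked update keeps a similar overall scale; the clamp
    # avoids blowing up when nearly all entries are masked out.
    mask.div_(mask.mean().clamp_(min=1e-3))
    return exp_avg * mask

m = torch.tensor([0.5, -0.2, 0.1, 0.3])   # momentum (exp_avg)
g = torch.tensor([1.0, 1.0, -1.0, 1.0])   # current gradient
print(apply_caution(m, g))  # entries disagreeing with the gradient are zeroed, the rest scaled up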
8 changes: 7 additions & 1 deletion timm/optim/lamb.py
@@ -94,6 +94,7 @@ def __init__(
trust_clip: bool = False,
always_adapt: bool = False,
caution: bool = False,
decoupled_decay: bool = False,
):
defaults = dict(
lr=lr,
@@ -106,13 +107,15 @@ def __init__(
trust_clip=trust_clip,
always_adapt=always_adapt,
caution=caution,
decoupled_decay=decoupled_decay,
)
super().__init__(params, defaults)

def __setstate__(self, state):
super().__setstate__(state)
for group in self.param_groups:
group.setdefault('caution', False)
group.setdefault('decoupled_decay', False)

def _get_clip_grad_norm(self):
max_grad_norm = self.defaults['max_grad_norm']
@@ -199,7 +202,10 @@ def step(self, closure=None):

weight_decay = group['weight_decay']
if weight_decay != 0:
update.add_(p, alpha=weight_decay)
if group.get('decoupled_decay', False):
p.add_(p, alpha=-group['lr'] * weight_decay)
else:
update.add_(p, alpha=weight_decay)

if weight_decay != 0 or group['always_adapt']:
# Layer-wise LR adaptation. By default, skip adaptation on parameters that are
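The two decay branches above differ in where the penalty enters: the original path folds weight_decay * p into the update, so it is also scaled by the subsequent trust-ratio logic, while decoupled_decay=True shrinks the weights directly by lr * weight_decay, outside the adaptive machinery (AdamW-style). A toy illustration of just the decoupled branch, with assumed values:

import torch

p = torch.tensor([1.0, -2.0])
lr, weight_decay = 1e-2, 0.1

# decoupled_decay=True branch: p <- p - lr * weight_decay * p
p.add_(p, alpha=-lr * weight_decay)
print(p)  # tensor([ 0.9990, -1.9980])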
3 changes: 3 additions & 0 deletions timm/optim/mars.py
@@ -54,10 +54,13 @@ def _mars_single_tensor_step(
if c_t_norm > 1.:
c_t = c_t / c_t_norm
exp_avg.mul_(beta1).add_(c_t, alpha=one_minus_beta1)

if caution:
# Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
mask = (exp_avg * grad > 0).to(grad.dtype)
mask.div_(mask.mean().clamp_(min=1e-3))
exp_avg = exp_avg * mask

if mars_type == "adamw":
exp_avg_sq.mul_(beta2).addcmul_(c_t, c_t, value=1. - beta2)
bias_correction1 = 1.0 - beta1 ** step
