Source code for torchflare.callbacks.lr_schedulers

"""Implements LrScheduler callbacks."""
from abc import ABC
from typing import TYPE_CHECKING, Callable, Iterable, List, Optional, Union

import torch.optim.lr_scheduler as _schedulers

from torchflare.callbacks.callback import Callbacks
from torchflare.callbacks.states import CallbackOrder

if TYPE_CHECKING:
    from torchflare.experiments.experiment import Experiment


class LRSchedulerCallback(Callbacks, ABC):
    """Wrapper class for scheduler callbacks.

    The callback is constructed with a factory that maps an optimizer to a
    scheduler. The factory is invoked in ``on_experiment_start``. When the
    experiment holds a dict of optimizers, one scheduler is created per
    optimizer and stored under the key ``"<optimizer key>_scheduler"``; every
    scheduler in that dict is stepped together.
    """

    def __init__(self, scheduler, step_on_batch: bool):
        """Constructor class for Scheduler callback.

        Args:
            scheduler: A callable that receives an optimizer and returns a pytorch scheduler.
            step_on_batch: Whether the scheduler steps after batch or not.
        """
        super().__init__(order=CallbackOrder.SCHEDULER)
        self._scheduler = scheduler
        self.step_on_batch = step_on_batch
        # Filled in by on_experiment_start: a single scheduler or a dict of schedulers.
        self.scheduler = None

    def _active_schedulers(self) -> list:
        """Return the managed scheduler(s) as a list, whether one or a dict of them."""
        if isinstance(self.scheduler, dict):
            return list(self.scheduler.values())
        return [self.scheduler]

    @staticmethod
    def _monitored_value(experiment: "Experiment"):
        """Look up the main metric in the experiment logs.

        The validation key is used when the validation stage has a dataloader,
        otherwise the training key is used.
        """
        key = (
            experiment.val_key
            if experiment.valid_stage in experiment.state.dataloaders.keys()
            else experiment.train_key
        )
        return experiment.exp_logs.get(key + experiment.main_metric)

    def on_experiment_start(self, experiment: "Experiment"):
        """Set scheduler."""
        if self.scheduler is not None:
            # Already built; do not create the scheduler(s) a second time.
            return

        optimizer = experiment.state.optimizer
        if isinstance(optimizer, dict):
            self.scheduler = {
                name + "_scheduler": self._scheduler(opt) for name, opt in optimizer.items()
            }
        else:
            self.scheduler = self._scheduler(optimizer)

    def on_batch_end(self, experiment: "Experiment"):
        """Step at end of batch."""
        if self.step_on_batch and (experiment.which_loader == experiment.train_stage):
            # Iterate so that a dict of schedulers is stepped instead of raising
            # AttributeError on ``dict.step``.
            for scheduler in self._active_schedulers():
                scheduler.step()

    def on_epoch_end(self, experiment: "Experiment"):
        """Step at the end of epoch."""
        if self.step_on_batch:
            return

        for scheduler in self._active_schedulers():
            if isinstance(scheduler, _schedulers.ReduceLROnPlateau):
                # ReduceLROnPlateau needs the monitored metric value to step.
                scheduler.step(self._monitored_value(experiment))
            else:
                scheduler.step()


class LambdaLR(LRSchedulerCallback, ABC):
    """Scale the learning rate by a factor produced by a user supplied function.

    The function receives the integer epoch count as its only argument.

    Args:
        lr_lambda (function or list of functions): Function(s) computing the
            learning rate factor.
        last_epoch (int): The index of last epoch. Default: -1.
        step_on_batch (bool): Step on each training iteration rather than each epoch.
            Defaults to False.
    """

    def __init__(
        self,
        lr_lambda: Union[Callable[[int], float], List[Callable[[int], float]]],
        last_epoch: int = -1,
        step_on_batch: bool = False,
    ):
        """Constructor for lambda scheduler."""

        def _build(optimizer):
            # Deferred construction: the optimizer is only known later.
            return _schedulers.LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)

        super().__init__(_build, step_on_batch=step_on_batch)
class StepLR(LRSchedulerCallback, ABC):
    """Decay the learning rate by a fixed factor at a fixed period.

    Args:
        step_size (int): Period of learning rate update in epochs.
        gamma (float, optional): The multiplicative factor. Defaults to 0.1.
        last_epoch (int): The index of last epoch. Default: -1.
        step_on_batch (bool): Step on each training iteration rather than each epoch.
            Defaults to False.
    """

    def __init__(
        self,
        step_size: int,
        gamma: float = 0.1,
        last_epoch: int = -1,
        step_on_batch: bool = False,
    ):
        """Constructor for StepLR."""

        def _build(optimizer):
            # Deferred construction: the optimizer is only known later.
            return _schedulers.StepLR(
                optimizer, step_size, gamma=gamma, last_epoch=last_epoch
            )

        super().__init__(_build, step_on_batch=step_on_batch)
class MultiStepLR(LRSchedulerCallback, ABC):
    """Decay the learning rate by a fixed factor at each listed epoch.

    Args:
        milestones (list of int): List of epochs number to perform lr step.
        gamma (float, optional): The multiplicative factor. Defaults to 0.1.
        last_epoch (int): The index of last epoch. Default: -1.
        step_on_batch (bool): Step on each training iteration rather than each epoch.
            Defaults to False.
    """

    def __init__(
        self,
        milestones: Iterable[int],
        gamma: float = 0.1,
        last_epoch: int = -1,
        step_on_batch: bool = False,
    ):
        """Constructor class for MultiStepLR."""

        def _build(optimizer):
            # Deferred construction: the optimizer is only known later.
            return _schedulers.MultiStepLR(
                optimizer, milestones, gamma=gamma, last_epoch=last_epoch
            )

        super().__init__(_build, step_on_batch=step_on_batch)
class ExponentialLR(LRSchedulerCallback, ABC):
    """Decay the learning rate by a fixed factor at every step.

    Args:
        gamma (float): The multiplicative factor. Required; there is no default.
        last_epoch (int): The index of last epoch. Default: -1.
        step_on_batch (bool): Step on each training iteration rather than each epoch.
            Defaults to False.
    """

    def __init__(self, gamma: float, last_epoch: int = -1, step_on_batch: bool = False):
        """Constructor for ExponentialLR."""

        def _build(optimizer):
            # Deferred construction: the optimizer is only known later.
            return _schedulers.ExponentialLR(optimizer, gamma, last_epoch=last_epoch)

        super().__init__(_build, step_on_batch=step_on_batch)
class CosineAnnealingLR(LRSchedulerCallback, ABC):
    """Anneal the learning rate of each parameter group along a cosine curve.

    Args:
        T_max (int): Max number of epochs or iterations.
        eta_min (float, optional): Min learning rate. Defaults to 0.
        last_epoch (int): The index of last epoch. Default: -1.
        step_on_batch (bool): Step on each training iteration rather than each epoch.
            Defaults to True.
    """

    def __init__(
        self,
        T_max: int,  # noqa
        eta_min: float = 0,
        last_epoch: int = -1,
        step_on_batch: bool = True,
    ):
        """Constructor for CosineAnnealingLR."""

        def _build(optimizer):
            # Deferred construction: the optimizer is only known later.
            return _schedulers.CosineAnnealingLR(
                optimizer, T_max, eta_min=eta_min, last_epoch=last_epoch
            )

        super().__init__(_build, step_on_batch=step_on_batch)
class ReduceLROnPlateau(LRSchedulerCallback, ABC):
    """Reduce learning rate when a metric has stopped improving.

    This callback is always registered with ``step_on_batch=False``.

    Args:
        mode: One of {"min", "max"}. In "min" mode the learning rate is reduced
            when the monitored quantity has stopped decreasing; in "max" mode it
            is reduced when the monitored quantity has stopped increasing.
        factor (float, optional): The multiplicative factor. Defaults to 0.1.
        patience (int, optional): Number of training epochs without the metric
            improvement to update the learning rate. Defaults to 10.
        verbose (bool, optional): Print info on each update to stdout. Defaults to False.
        threshold (float, optional): Threshold for considering the changes significant.
            Defaults to 1e-4.
        threshold_mode (str, optional): Should be 'rel', 'abs'. Defaults to 'rel'.
        cooldown (int, optional): Number of epochs to wait before resuming normal
            operation after lr has been updated. Defaults to 0.
        min_lr (float or list of float, optional): Min learning rate. Defaults to 0.
        eps (float, optional): Min significant learning rate update. Defaults to 1e-8.
    """

    def __init__(
        self,
        mode="min",
        factor=0.1,
        patience=10,
        verbose=False,
        threshold=1e-4,
        threshold_mode="rel",
        cooldown=0,
        min_lr=0,
        eps=1e-8,
    ):
        """Constructor for ReduceLRonPlateau."""

        def _build(optimizer):
            # Deferred construction: the optimizer is only known later.
            return _schedulers.ReduceLROnPlateau(
                optimizer,
                mode=mode,
                factor=factor,
                patience=patience,
                verbose=verbose,
                threshold=threshold,
                threshold_mode=threshold_mode,
                cooldown=cooldown,
                min_lr=min_lr,
                eps=eps,
            )

        super().__init__(_build, step_on_batch=False)
class CyclicLR(LRSchedulerCallback, ABC):
    """Sets the learning rate of each parameter group according to cyclical learning rate policy.

    Args:
        base_lr (float or list of float): Initial learning rate.
        max_lr (float or list of float): Max learning rate.
        step_size_up (int, optional): Increase phase duration in epochs or iterations.
            Defaults to 2000.
        step_size_down (int, optional): Decrease phase duration in epochs or iterations.
            Defaults to None.
        mode (str, optional): Should be 'triangular', 'triangular2' or 'exp_range'.
            Defaults to 'triangular'.
        gamma (float, optional): Constant for the 'exp_range' policy. Defaults to 1.
        scale_fn (function, optional): Custom scaling policy function. Defaults to None.
        scale_mode (str, optional): Should be 'cycle' or 'iterations'. Defaults to 'cycle'.
        cycle_momentum (bool, optional): Momentum is cycled inversely to learning rate
            between 'base_momentum' and 'max_momentum'. Defaults to True.
        base_momentum (float or list of float, optional): Lower momentum boundaries
            in the cycle for each parameter group. Defaults to 0.8.
        max_momentum (float or list of float, optional): Upper momentum boundaries
            in the cycle for each parameter group. Defaults to 0.9.
        last_epoch (int): The index of last epoch. Default: -1.
        step_on_batch (bool): Step on each training iteration rather than each epoch.
            Defaults to True.
    """

    def __init__(
        self,
        # Annotated as float-or-list to agree with the documented contract above.
        base_lr: Union[float, List[float]],
        max_lr: Union[float, List[float]],
        step_size_up: int = 2000,
        step_size_down: Optional[int] = None,
        mode: str = "triangular",
        gamma: float = 1.0,
        scale_fn: Optional[Callable[[float], float]] = None,
        scale_mode: str = "cycle",
        cycle_momentum: bool = True,
        base_momentum: Union[float, List[float]] = 0.8,
        max_momentum: Union[float, List[float]] = 0.9,
        last_epoch: int = -1,
        step_on_batch: bool = True,
    ):
        """Constructor for CyclicLR."""
        super().__init__(
            lambda opt: _schedulers.CyclicLR(
                opt,
                base_lr,
                max_lr,
                step_size_up=step_size_up,
                step_size_down=step_size_down,
                mode=mode,
                gamma=gamma,
                scale_fn=scale_fn,
                scale_mode=scale_mode,
                cycle_momentum=cycle_momentum,
                base_momentum=base_momentum,
                max_momentum=max_momentum,
                last_epoch=last_epoch,
            ),
            step_on_batch=step_on_batch,
        )
class CosineAnnealingWarmRestarts(LRSchedulerCallback, ABC):
    """Set the learning rate of each parameter group using a cosine annealing schedule with a warm restart.

    Args:
        T_0 (int): Number of epochs or iterations for the first restart.
        T_mult (int): T increase factor after a restart. Defaults to 1.
        eta_min (float, optional): Min learning rate. Defaults to 0.
        last_epoch (int): The index of last epoch. Default: -1.
        step_on_batch (bool): Step on each training iteration rather than each epoch.
            Defaults to True.
    """

    def __init__(
        self,
        T_0: int,  # noqa
        T_mult: int = 1,  # noqa
        # A learning-rate floor is a float (was annotated ``int``).
        eta_min: float = 0,
        last_epoch: int = -1,
        step_on_batch: bool = True,
    ):
        """Constructor for CosineAnnealingWarmRestarts."""
        super().__init__(
            lambda opt: _schedulers.CosineAnnealingWarmRestarts(
                opt, T_0, T_mult=T_mult, eta_min=eta_min, last_epoch=last_epoch
            ),
            step_on_batch=step_on_batch,
        )
class MultiplicativeLR(LRSchedulerCallback, ABC):
    """Multiply each parameter group's learning rate by a factor from the given function.

    Args:
        lr_lambda (function or list of functions): A function which computes a
            multiplicative factor given an integer parameter epoch, or a list of
            such functions, one for each group in an optimizer.param_groups.
        last_epoch (int): The index of last epoch. Default: -1.
        step_on_batch (bool): Step on each training iteration rather than each epoch.
            Defaults to False.
    """

    def __init__(
        self,
        lr_lambda: Union[Callable[[int], float], List[Callable[[int], float]]],
        last_epoch: int = -1,
        step_on_batch: bool = False,
    ):
        """Constructor for MultiplicativeLR."""

        def _build(optimizer):
            # Deferred construction: the optimizer is only known later.
            return _schedulers.MultiplicativeLR(optimizer, lr_lambda, last_epoch=last_epoch)

        super().__init__(_build, step_on_batch=step_on_batch)
class OneCycleLR(LRSchedulerCallback, ABC):
    """Sets the learning rate of each parameter group according to the 1cycle learning rate policy.

    The 1cycle policy anneals the learning rate from an initial learning rate
    to some maximum learning rate and then from that maximum learning rate to
    some minimum learning rate much lower than the initial learning rate.

    This callback is always registered with ``step_on_batch=True``.

    Args:
        max_lr (float or list of float): Upper learning rate boundaries in the cycle
            for each parameter group.
        total_steps (int): The total number of steps in the cycle. Note that if a value
            is not provided here, then it must be inferred by providing a value for
            epochs and steps_per_epoch. Defaults to None.
        epochs (int): The number of epochs to train for. This is used along with
            steps_per_epoch in order to infer the total number of steps in the cycle
            if a value for total_steps is not provided. Defaults to None.
        steps_per_epoch (int): The number of steps per an epoch to train for. This is
            used along with epochs in order to infer the total number of steps in the
            cycle if a value for total_steps is not provided. Defaults to None.
        pct_start (float): The percentage of the cycle (in number of steps) spent
            increasing the learning rate. Defaults to 0.3.
        anneal_strategy (str): {'cos', 'linear'} Specifies the annealing strategy:
            "cos" for cosine annealing, "linear" for linear annealing. Defaults to 'cos'.
        cycle_momentum (bool): If ``True``, momentum is cycled inversely to learning
            rate between 'base_momentum' and 'max_momentum'. Defaults to True.
        base_momentum (float or list of float): Lower momentum boundaries in the cycle
            for each parameter group. Note that momentum is cycled inversely to
            learning rate; at the peak of a cycle, momentum is 'base_momentum' and
            learning rate is 'max_lr'. Defaults to 0.85.
        max_momentum (float or list of float): Upper momentum boundaries in the cycle
            for each parameter group. Functionally, it defines the cycle amplitude
            (max_momentum - base_momentum). Note that momentum is cycled inversely to
            learning rate; at the start of a cycle, momentum is 'max_momentum' and
            learning rate is 'base_lr'. Defaults to 0.95.
        div_factor (float): Determines the initial learning rate via
            initial_lr = max_lr/div_factor. Defaults to 25.
        final_div_factor (float): Determines the minimum learning rate via
            min_lr = initial_lr/final_div_factor. Defaults to 1e4.
        last_epoch (int): The index of last epoch. Default: -1.
    """

    def __init__(
        self,
        max_lr: Union[float, List[float]],
        total_steps: Optional[int] = None,
        epochs: Optional[int] = None,
        steps_per_epoch: Optional[int] = None,
        pct_start: float = 0.3,
        anneal_strategy: str = "cos",
        cycle_momentum: bool = True,
        base_momentum: Union[float, List[float]] = 0.85,
        max_momentum: Union[float, List[float]] = 0.95,
        div_factor: float = 25.0,
        final_div_factor: float = 1e4,
        last_epoch: int = -1,
    ):
        """Constructor for OneCycleLR."""

        def _build(optimizer):
            # Deferred construction: the optimizer is only known later.
            return _schedulers.OneCycleLR(
                optimizer,
                max_lr,
                total_steps=total_steps,
                epochs=epochs,
                steps_per_epoch=steps_per_epoch,
                pct_start=pct_start,
                anneal_strategy=anneal_strategy,
                cycle_momentum=cycle_momentum,
                base_momentum=base_momentum,
                max_momentum=max_momentum,
                div_factor=div_factor,
                final_div_factor=final_div_factor,
                last_epoch=last_epoch,
            )

        super().__init__(_build, step_on_batch=True)
# Names exported by ``from torchflare.callbacks.lr_schedulers import *``.
__all__ = [
    "LRSchedulerCallback",
    "LambdaLR",
    "OneCycleLR",
    "CosineAnnealingLR",
    "CosineAnnealingWarmRestarts",
    "CyclicLR",
    "MultiplicativeLR",
    "MultiStepLR",
    "ReduceLROnPlateau",
    "StepLR",
    "ExponentialLR",
]