"""Implementation of Early stopping."""
import math
from abc import ABC
from typing import TYPE_CHECKING
from torchflare.callbacks.callback import Callbacks
from torchflare.callbacks.extra_utils import init_improvement
from torchflare.callbacks.states import CallbackOrder
if TYPE_CHECKING:
from torchflare.experiments.experiment import Experiment
class EarlyStopping(Callbacks, ABC):
    """Stop training once a monitored quantity has stopped improving.

    At the end of every epoch the value stored under ``monitor`` in
    ``experiment.exp_logs`` is compared with the best value seen so far.
    After ``patience`` consecutive epochs without improvement the callback
    sets ``experiment.stop_training`` to ``True``.

    Args:
        mode: One of {"min", "max"}. In "min" mode, training will stop when
            the quantity monitored has stopped decreasing.
            In "max" mode it will stop when the quantity monitored has
            stopped increasing.
        monitor: Name of the logged quantity to be monitored, e.g. ``val_loss``.
            The name must start with ``train_`` or ``val_``.
        patience: Number of epochs with no improvement after which training will be stopped.
        min_delta: Minimum change in the monitored quantity to qualify as an improvement.

    Raises:
        ValueError: If ``monitor`` does not start with prefix ``val_`` or ``train_``.

    Example:
        .. code-block:: python

            import torchflare.callbacks as cbs

            early_stop = cbs.EarlyStopping(monitor="val_accuracy", patience=5, mode="max")
    """

    def __init__(
        self,
        mode: str,
        monitor: str,
        patience: int = 5,
        min_delta: float = 1e-7,
    ):
        """Constructor for EarlyStopping class."""
        super().__init__(order=CallbackOrder.STOPPING)
        if not monitor.startswith(("train_", "val_")):
            raise ValueError("Monitor must have a prefix either train_ or val_.")
        self.monitor = monitor
        self.patience = patience
        self.mode = mode
        self.min_delta = min_delta
        # Number of consecutive epochs without improvement.
        self.stopping_counter = 0
        # ``improvement`` is the comparator called as improvement(score, best);
        # ``best_score`` is the starting value to compare against.
        self.improvement, self.best_score = init_improvement(
            mode=self.mode, min_delta=self.min_delta
        )

    def on_experiment_start(self, experiment: "Experiment") -> None:
        """Reset the tracking state at the start of an experiment."""
        self.stopping_counter = 0
        # on_experiment_end() clears the comparator, so it has to be rebuilt
        # here; otherwise re-using this callback instance for a second
        # experiment would call ``None`` in on_epoch_end().
        self.improvement, _ = init_improvement(mode=self.mode, min_delta=self.min_delta)
        self.best_score = math.inf if self.mode == "min" else -math.inf

    def on_epoch_end(self, experiment: "Experiment") -> None:
        """Function which will determine when to stop the training depending on the score."""
        # NOTE(review): ``.get`` yields None when the monitored key is missing
        # from exp_logs; whether the comparator tolerates None is not visible
        # from this file -- confirm against init_improvement.
        epoch_score = experiment.exp_logs.get(self.monitor)
        if self.improvement(epoch_score, self.best_score):
            self.best_score = epoch_score
            self.stopping_counter = 0
        else:
            self.stopping_counter += 1
            if self.stopping_counter >= self.patience:
                print("Early Stopping !")
                experiment.stop_training = True

    def on_experiment_end(self, experiment: "Experiment") -> None:
        """Reset to defaults."""
        self.stopping_counter = 0
        self.best_score = None
        self.improvement = None