Source code for torchflare.callbacks.model_checkpoint

"""Implements Model Checkpoint Callback."""
import os
from abc import ABC
from typing import TYPE_CHECKING

import numpy as np
import torch

from torchflare.callbacks.callback import Callbacks
from torchflare.callbacks.extra_utils import init_improvement
from torchflare.callbacks.states import CallbackOrder

if TYPE_CHECKING:
    from torchflare.experiments.experiment import Experiment


def get_state_dicts(d):
    """Collect ``state_dict()`` output from one object or from a dict of them.

    Args:
        d : Either a single object exposing ``state_dict()`` or a dict whose
            values each expose ``state_dict()``.

    Returns:
        For a dict input, a new dict with the same keys, each mapped to the
        result of calling ``state_dict()`` on the corresponding value.
        Otherwise, the result of ``d.state_dict()``.
    """
    # Single object: nothing to iterate over, delegate directly.
    if not isinstance(d, dict):
        return d.state_dict()

    # Mapping: keep the keys and swap every value for its state dictionary.
    return {name: member.state_dict() for name, member in d.items()}


class ModelCheckpoint(Callbacks, ABC):
    """Callback for checkpointing your model.

    Args:
        mode: One of {"min", "max"}. Selects the comparison built by
            ``init_improvement``; in "max" mode the best value is reset to
            ``-inf`` at the end of an experiment, otherwise to ``+inf``.
        monitor: The quantity to be monitored, e.g. ``val_loss``. It is
            looked up by this name in the experiment logs and must start
            with ``train_`` or ``val_``.
        save_dir: The directory where you want to save the model files.
        file_name: The name of file. Default : model.bin

    Note:
        ModelCheckpoint saves state dictionaries under the following keys:

        1) 'model_state_dict' : The state dictionary of model
        2) 'optimizer_state_dict' : The state dictionary of optimizer

        A checkpoint is written at the end of an epoch whenever the monitored
        quantity improves on the best value recorded so far.

    Raises:
        ValueError if monitor does not start with prefix ``val_`` or ``train_``.

    Example:
        .. code-block::

            import torchflare.callbacks as cbs
            model_ckpt = cbs.ModelCheckpoint(monitor="val_accuracy", mode="max")
    """

    def __init__(self, mode: str, monitor: str, save_dir: str = "./", file_name: str = "model.bin"):
        """Constructor for ModelCheckpoint class."""
        super().__init__(order=CallbackOrder.CHECKPOINT)

        if not monitor.startswith(("train_", "val_")):
            raise ValueError("Monitor must have a prefix either train_ or val_.")
        self.monitor = monitor

        self.mode = mode
        # Minimum change passed to init_improvement as ``min_delta``.
        self.eps = 1e-7
        self.improvement, self.best_val = init_improvement(mode=self.mode, min_delta=self.eps)
        self.path = os.path.join(save_dir, file_name)

    def checkpoint(self, model, optimizer):
        """Save the state dictionaries of model and optimizer to ``self.path``.

        Args:
            model: A single object or a dict of objects exposing ``state_dict()``.
            optimizer: A single object or a dict of objects exposing ``state_dict()``.
        """
        state_dict = {
            "model_state_dict": get_state_dicts(d=model),
            "optimizer_state_dict": get_state_dicts(d=optimizer),
        }
        torch.save(state_dict, self.path)

    def on_epoch_end(self, experiment: "Experiment"):
        """Save the model when the monitored quantity improves on the best so far.

        Args:
            experiment: The running experiment; its ``exp_logs`` supplies the
                monitored value and its ``state`` supplies model and optimizer.
        """
        # NOTE(review): ``.get`` yields None when the monitored key is absent
        # from the logs -- confirm how ``self.improvement`` handles that.
        val = experiment.exp_logs.get(self.monitor)

        if self.improvement(score=val, best=self.best_val):
            self.checkpoint(
                model=experiment.state.model,
                optimizer=experiment.state.optimizer,
            )
            # Remember the new best only after the save succeeded, so later
            # epochs are compared against it instead of the initial value.
            self.best_val = val

    def on_experiment_end(self, experiment: "Experiment"):
        """Reset the best value to its default for the configured mode."""
        self.best_val = -np.inf if self.mode == "max" else np.inf