Source code for torchflare.callbacks.load_checkpoint

"""Implements Load checkpoint."""
from abc import ABC
from typing import TYPE_CHECKING, Optional

import torch

from torchflare.callbacks.callback import Callbacks
from torchflare.callbacks.states import CallbackOrder

if TYPE_CHECKING:
    from torchflare.experiments.experiment import Experiment


class LoadCheckpoint(Callbacks, ABC):
    """Class to load checkpoint.

    Restores the experiment's model and optimizer state from a checkpoint
    file when the experiment starts. The callback is registered with
    ``CallbackOrder.MODEL_INIT``.

    The checkpoint is read with ``torch.load``. It must be a mapping that
    contains the keys ``"model_state_dict"`` and ``"optimizer_state_dict"``.
    """

    def __init__(self, path_to_model: Optional[str] = None):
        """Constructor method for LoadCheckpoint Class.

        Args:
            path_to_model: Path to the checkpoint file. It may be ``None`` at
                construction time, but ``self.path`` must be set before the
                experiment starts.
        """
        super().__init__(order=CallbackOrder.MODEL_INIT)
        self.path = path_to_model

    @staticmethod
    def unpack_ckpt(nn_obj, ckpt):
        """Method to unpack checkpoint.

        Loads ``ckpt`` into ``nn_obj`` via ``load_state_dict``.

        Args:
            nn_obj: Either a single object exposing ``load_state_dict`` (e.g. a
                model or an optimizer), or a dict of such objects.
            ckpt: The state dict for ``nn_obj``. If ``nn_obj`` is a dict,
                ``ckpt`` must be a dict with the same keys. Each entry is then
                loaded into the object stored under that key.
        """
        if isinstance(nn_obj, dict):
            for k, v in nn_obj.items():
                v.load_state_dict(ckpt[k])
        else:
            nn_obj.load_state_dict(ckpt)

    def on_experiment_start(self, experiment: "Experiment"):
        """Load checkpoint before starting training.

        Args:
            experiment: The running experiment. ``experiment.device`` is used
                as the ``map_location``. The loaded states are written into
                ``experiment.state.model`` and ``experiment.state.optimizer``.

        Raises:
            ValueError: If no checkpoint path was configured.
        """
        if self.path is None:
            # Fail fast with a clear message; torch.load(None) would otherwise
            # raise an unrelated low-level error from inside torch.
            raise ValueError(
                "LoadCheckpoint needs a checkpoint path (`path_to_model`), got None."
            )
        # NOTE: torch.load unpickles the file -- only load checkpoints from
        # trusted sources.
        checkpoint = torch.load(self.path, map_location=torch.device(experiment.device))
        self.unpack_ckpt(nn_obj=experiment.state.model, ckpt=checkpoint["model_state_dict"])
        self.unpack_ckpt(nn_obj=experiment.state.optimizer, ckpt=checkpoint["optimizer_state_dict"])