"""Implements Base class."""
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchflare.experiments.base_backend import BaseExperiment
from torchflare.experiments.commons import EVENTS
from torchflare.experiments.simple_utils import _has_intersection, to_device
if TYPE_CHECKING:
from torchflare.experiments.config import ModelConfig
[docs]class Experiment(BaseExperiment):
"""Simple class for handling boilerplate code for training, validation and Inference.
Args:
num_epochs(int) : The number of epochs to save model.
fp16(bool) : Set this to True if you want to use mixed precision training.
device(str) : The device where you want train your model. One of **cuda** or **cpu**.
seed(int): The seed to ensure reproducibility.
Examples:
.. code-block:: python
import torch
import torchmetrics
import torchflare.callbacks as cbs
from torchflare.experiments import Experiment
# Defining Training/Validation Dataloaders
train_dl = SomeTrainDataloader()
valid_dl = SomeValidDataloader()
# Defining params
optimizer = "Adam"
optimizer_params = {"lr": 1e-4}
criterion = "cross_entropy"
num_epochs = 10
num_classes = 4
# Defining the list of metrics
metric_list = [
torchmetrics.Accuracy(num_classes = num_classes)
]
# Defining the list of callbacks
callbacks = [
cbs.EarlyStopping(monitor="accuracy", mode="max"),
cbs.ModelCheckpoint(monitor="accuracy", mode="max"),
cbs.ReduceLROnPlateau(mode="max", patience=3), # Defining Scheduler callback.
]
# Defining the model config which contains model, optimizer, criterion.
config = ModelConfig(
nn_module=SomeModelClass,
module_params={"num_features": 200, "num_classes": 5},
optimizer=optimizer,
optimizer_params=optimizer_params,
criterion=criterion,
)
# Creating Experiment and setting the params.
exp = Experiment(
num_epochs=num_epochs,
fp16=True,
device=device,
seed=42,
)
# Compiling the experiment
exp.compile_experiment(
model_config=config,
metrics=metric_list,
callbacks=callbacks,
main_metric="accuracy",
)
# Running the experiment
exp.fit_loader(train_dl=train_dl, valid_dl=valid_dl)
"""
def __init__(
self,
num_epochs: int,
fp16: bool = False,
device: str = "cuda",
seed: int = 42,
):
"""Init method to set up important variables for training and validation."""
super(Experiment, self).__init__(
num_epochs=num_epochs,
fp16=fp16,
device=device,
seed=seed,
)
[docs] def compile_experiment(
self,
model_config: "ModelConfig",
callbacks: List = None,
metrics: List = None,
main_metric: Optional[str] = None,
) -> None:
"""Configures the model for training and validation.
Args:
model_config: An ModelConfig object which holds information about models,
optimizer and criterion.
callbacks(List): The list of callbacks to be used.
metrics(List): The list of metrics to be used.
main_metric(str): The name of main metric to be monitored. Use lower case version.
For examples , use 'accuracy' instead of 'Accuracy'.
Note:
Supports all the schedulers implemented in pytorch/transformers except SWA.
Support for custom scheduling will be added soon.
"""
self.main_metric = main_metric
self._step = {self.train_stage: self.train_step, self.valid_stage: self.val_step}
self.init_state(config=model_config, callbacks=callbacks, metrics=metrics)
def set_dataloaders(self, train_dl, valid_dl) -> None:
"""Setup dataloader variables."""
dataloaders = {self.train_stage: train_dl}
if valid_dl is not None:
dataloaders[self.valid_stage] = valid_dl
self.state.update({"dataloaders": dataloaders})
def on_experiment_start(self) -> None:
"""Event on experiment start."""
self.initialise()
def on_batch_start(self) -> None:
"""Event on batch start."""
self._process_batch(self.batch)
# skipcq : PYL-W0107
def on_loader_start(self) -> None:
"""Event on loader start."""
raise NotImplementedError
def on_epoch_start(self) -> None:
"""Event on epoch start."""
self.current_epoch += 1
self.exp_logs = {self.epoch_key: self.current_epoch}
def on_experiment_end(self) -> None:
"""Event on experiment end."""
self.cleanup()
# skipcq : PYL-W0107
def on_batch_end(self) -> None:
"""Event on batch end."""
raise NotImplementedError
def on_loader_end(self) -> None:
"""Event of loader end."""
self.exp_logs.update(**self.monitors[self.which_loader])
# skipcq : PYL-W0107
def on_epoch_end(self) -> None:
"""Event on epoch end."""
raise NotImplementedError
def _run_event(self, event: str) -> None:
"""Method to run events."""
if _has_intersection(key="_start", event=event):
try:
_ = getattr(self, event)()
except NotImplementedError:
pass
# As soon as event ends, we run callbacks.
self._run_callbacks(event=event)
if _has_intersection(key="_end", event=event):
try:
_ = getattr(self, event)()
except NotImplementedError:
pass
def run_batch(self) -> None:
"""Run batch with batch event."""
self._run_event(EVENTS.ON_BATCH_START.value)
self.batch_outputs = self._step.get(self.which_loader)()
self._prepare_batch_outputs()
self._run_event(EVENTS.ON_BATCH_END.value)
def run_loader(self, dataloader) -> None:
"""Function to iterate the dataloader through all the batches."""
self._run_event(EVENTS.ON_LOADER_START.value)
mode = bool(self.train_stage in self.which_loader)
with torch.set_grad_enabled(mode=mode):
for self.batch_idx, self.batch in enumerate(dataloader):
with self.backend.autocast:
self.run_batch()
self._run_event(EVENTS.ON_LOADER_END.value)
[docs] def train_step(self) -> Dict:
"""Method to perform train step.
The train step includes forward pass, loss evaluation, backward pass.
Note:
Use self.backend attribute for doing zero_grad backward pass etc.
It is compulsory for train_step to return a dictionary with loss.
If you are using metrics then train_step should return a
dictionary with both predictions and loss.
Returns:
A dictionary with train_step results::
{
"predictions" : The train batch predictions,
"loss" : The loss for the train_step
}
"""
self.backend.zero_grad(optimizer=self.state.optimizer)
preds = self.state.model(self.batch[self.input_key])
loss = self.state.criterion(preds, self.batch[self.target_key])
self.backend.backward_loss(loss=loss)
self.backend.optimizer_step(optimizer=self.state.optimizer)
return {self.prediction_key: preds, self.loss_key: loss.item()}
[docs] def val_step(self) -> Dict:
"""Method to perform validation step.
The train step includes forward pass, loss evaluation.
Note:
It is compulsory for val_step to return a dictionary with loss.
If you are using metrics then val_step should
return a dictionary with both predictions and loss.
Returns:
A dictionary with val_step results::
{
"predictions" : The validation batch predictions,
"loss" : The loss for the val_step
}
"""
preds = self.state.model(self.batch[self.input_key])
loss = self.state.criterion(preds, self.batch[self.target_key])
return {self.prediction_key: preds, self.loss_key: loss.item()}
def _do_epoch(self) -> None:
for self.which_loader, dataloader in self.state.dataloaders.items():
self._set_model_stage(stage=self.which_loader)
self.run_loader(dataloader=dataloader)
def _run(self) -> None:
"""Method to run experiment for full number of epochs."""
for _ in range(self.num_epochs):
self._run_event(EVENTS.ON_EPOCH_START.value)
self._do_epoch()
self._run_event(EVENTS.ON_EPOCH_END.value)
if self.stop_training:
break
def _general_fit(self) -> None:
self._run_event(EVENTS.ON_EXPERIMENT_START.value)
self._run()
self._run_event(EVENTS.ON_EXPERIMENT_END.value)
[docs] def fit(
self,
x: Union[torch.Tensor, np.ndarray],
y: Union[torch.Tensor, np.ndarray],
val_data: Optional[Union[Tuple, List]] = None,
batch_size: int = 64,
dataloader_kwargs: Dict = None,
):
"""Train and validate the model on training and validation dataset.
Args:
x: A numpy array(or array-like) or torch.tensor for inputs to the model.
y: Target data. Same type as input data coule numpy array(or array-like)
or torch.tensors.
val_data : A tuple or list (x_val , y_val) of numpy arrays or torch.tensors.
batch_size(int): The batch size to be used for training and validation.
dataloader_kwargs(Dict): Keyword arguments to pass to the PyTorch dataloaders
created internally. By default, shuffle=True is passed for the training dataloader
but this can be overriden by using this argument.
Note:
Model will only be saved when ModelCheckpoint callback is used.
"""
if dataloader_kwargs is None:
dataloader_kwargs = {}
dataloader_kwargs = {"batch_size": batch_size, **dataloader_kwargs}
train_dl = self._dataloader_from_data((x, y), {"shuffle": True, **dataloader_kwargs})
valid_dl = (
None if val_data is None else self._dataloader_from_data(val_data, dataloader_kwargs)
)
self.fit_loader(train_dl=train_dl, valid_dl=valid_dl)
[docs] def fit_loader(self, train_dl: DataLoader, valid_dl: DataLoader = None):
"""Train and validate the model using dataloaders.
Args:
train_dl(DataLoader) : The training dataloader.
valid_dl(DataLoader) : The validation dataloader.
Note:
Model will only be saved when ModelCheckpoint callback is used.
"""
self.set_dataloaders(train_dl=train_dl, valid_dl=valid_dl)
self._general_fit()
@torch.no_grad()
def predict_on_loader(
self,
path_to_model: str,
test_dl: DataLoader,
device: str = "cuda",
) -> torch.Tensor:
"""Method to perform inference on test dataloader.
Args:
test_dl(DataLoader): The dataloader to be use for testing.
device(str): The device on which you want to perform inference.
path_to_model(str): The full path to model
Yields:
Output per batch.
"""
# move model to device
self._model_to_device()
ckpt = torch.load(path_to_model, map_location=torch.device(device))
if isinstance(ckpt, dict):
self.state.model.load_state_dict(ckpt["model_state_dict"])
else:
self.state.model.load_state_dict(ckpt)
for inp in test_dl:
inp = to_device(inp, device=device)
op = self.state.model(inp)
yield op.detach().cpu()
@torch.no_grad()
def predict(
self,
x: Union[torch.Tensor, np.ndarray],
path_to_model: str,
batch_size: int = 64,
dataloader_kwargs: Dict = None,
device: str = "cuda",
):
"""Method to perform inference on test data.
Args:
x: A numpy array(or array-like) or torch.tensor for inputs to the model.
batch_size: The batch size to be used for inference.
device: The device on which you want to perform inference.
dataloader_kwargs: Keyword arguments to pass to the PyTorch dataloader
which is created internally.
path_to_model: The full path to the model.
"""
if dataloader_kwargs is None:
dataloader_kwargs = {}
dataloader_kwargs = {"batch_size": batch_size, **dataloader_kwargs}
dl = self._dataloader_from_data((x,), dataloader_kwargs)
return self.predict_on_loader(path_to_model=path_to_model, test_dl=dl, device=device)
__all__ = ["Experiment"]