Source code for torchflare.experiments.backends

import contextlib

import torch


# noinspection PyMethodMayBeStatic
class BaseBackend:
    """Full-precision backend for the basic training-step operations.

    Every method forwards straight to the optimizer or loss object it
    receives; nothing is scaled. ``autocast`` is a ``contextlib.nullcontext``,
    i.e. a context manager that does nothing — presumably so callers can
    write ``with backend.autocast:`` whichever backend is in use (confirm at
    the call site).
    """

    def __init__(self):
        no_op_context = contextlib.nullcontext()
        self.autocast = no_op_context

    def zero_grad(self, optimizer) -> None:
        """Reset gradients by delegating to ``optimizer.zero_grad()``."""
        optimizer.zero_grad()

    def backward_loss(self, loss) -> None:
        """Propagate ``loss`` backward by calling ``loss.backward()`` (unscaled)."""
        # skipcq: PYL-W0106
        loss.backward()

    def optimizer_step(self, optimizer) -> None:
        """Apply one parameter update by calling ``optimizer.step()``."""
        optimizer.step()
# noinspection PyMethodMayBeStatic
class AMPBackend:
    """Mixed-precision backend built on ``torch.cuda.amp``.

    Attributes:
        scaler: a ``torch.cuda.amp.GradScaler``. The loss is scaled through
            it before backpropagation, and the optimizer step runs through it.
        autocast: a ``torch.cuda.amp.autocast`` context-manager instance,
            created once and kept for reuse; presumably entered by the caller
            around the forward pass (confirm at the call site).
    """

    def __init__(self):
        grad_scaler = torch.cuda.amp.GradScaler()
        self.scaler = grad_scaler
        self.autocast = torch.cuda.amp.autocast()

    def zero_grad(self, optimizer) -> None:
        """Reset gradients by delegating to ``optimizer.zero_grad()``."""
        optimizer.zero_grad()

    def backward_loss(self, loss) -> None:
        """Propagate ``loss`` backward after scaling it with ``self.scaler``."""
        scaled_loss = self.scaler.scale(loss)
        scaled_loss.backward()

    def optimizer_step(self, optimizer) -> None:
        """Step ``optimizer`` via the scaler, then update the scale factor.

        Per the PyTorch AMP docs, ``GradScaler.step`` unscales gradients and
        may skip the step when they contain inf/NaN; ``update`` must follow
        every ``step``.
        """
        scaler = self.scaler
        scaler.step(optimizer)
        scaler.update()
# Public API of this module: the two backend classes defined in this file.
__all__ = [
    "BaseBackend",
    "AMPBackend",
]