Self-Supervised Learning: Bootstrap Your Own Latent (BYOL)

  1. Import the necessary libraries

       import random
       from typing import Callable, Dict, Tuple
       import torch
       import torch.nn.functional as F
       import torchvision
       from kornia import augmentation as aug
       from kornia import filters
       from kornia.geometry import transform as ktf
       from torch import Tensor, nn
       from torchvision.transforms import ToTensor
       import torchflare.callbacks as cbs
       from torchflare.experiments import Experiment, ModelConfig

2. Defining augmentations using kornia
class RandomApply(nn.Module):
    """Apply a wrapped callable to the input only on some forward passes.

    Each call to ``forward`` draws one number from Python's ``random``
    module. The input is returned untouched when that draw is strictly
    greater than ``p``; otherwise ``fn`` is applied to it.

    Args:
        fn: Callable invoked on the input when the draw does not exceed ``p``.
        p: Threshold the random draw is compared against.
    """

    def __init__(self, fn: Callable, p: float):
        super().__init__()
        self.fn = fn
        self.p = p

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` unchanged or ``fn(x)``, decided by one random draw."""
        skip = random.random() > self.p
        if skip:
            return x
        return self.fn(x)


def default_augmentation(image_size: Tuple[int, int] = (224, 224)) -> nn.Module:
    """Build the augmentation pipeline used to create a view of a batch.

    The stages run in this order: resize to ``image_size``, colour jitter
    (wrapped in ``RandomApply`` with p=0.8), random grayscale (p=0.2),
    random horizontal flip, Gaussian blur (wrapped in ``RandomApply`` with
    p=0.1), random resized crop back to ``image_size``, and per-channel
    normalisation with a fixed mean and standard deviation.

    Args:
        image_size: Size passed to both the resize and the random resized crop.

    Returns:
        An ``nn.Sequential`` holding the stages above.
    """
    channel_mean = torch.tensor([0.485, 0.456, 0.406])
    channel_std = torch.tensor([0.229, 0.224, 0.225])

    stages = [
        ktf.Resize(size=image_size),
        RandomApply(aug.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
        aug.RandomGrayscale(p=0.2),
        aug.RandomHorizontalFlip(),
        RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
        aug.RandomResizedCrop(size=image_size),
        aug.Normalize(mean=channel_mean, std=channel_std),
    ]
    return nn.Sequential(*stages)
  3. Defining the model architectures

class MLPHead(nn.Module):
    """Small MLP: Linear -> BatchNorm1d -> ReLU -> Linear.

    Args:
        in_channels: Number of input features of the first linear layer.
        projection_size: Number of output features of the last linear layer.
        hidden_size: Width of the hidden layer between the two linear layers.
    """

    def __init__(self, in_channels: int, projection_size: int = 256, hidden_size: int = 4096):
        super().__init__()

        # Layer order is kept fixed so the ``net.<index>`` state-dict keys stay the same.
        layers = [
            nn.Linear(in_channels, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size, projection_size),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the MLP and return the result."""
        projected = self.net(x)
        return projected


# Defining resnet encoders.
class ResnetEncoder(nn.Module):
    """ResNet-18 with its last child module removed, followed by an ``MLPHead``.

    Args:
        pretrained: Forwarded as ``pretrained`` to ``torchvision.models.resnet18``.
        mlp_params: Keyword arguments forwarded to ``MLPHead`` (in addition to
            ``in_channels``, which is taken from the ResNet's ``fc.in_features``).
    """

    def __init__(self, pretrained, mlp_params):
        super().__init__()
        backbone = torchvision.models.resnet18(pretrained=pretrained)
        # Keep every child of the ResNet except the last one.
        trunk_layers = list(backbone.children())[:-1]
        self.encoder = torch.nn.Sequential(*trunk_layers)
        feature_dim = backbone.fc.in_features
        self.projector = MLPHead(in_channels=feature_dim, **mlp_params)

    def forward(self, x):
        """Encode ``x``, reshape the features to 2-D and project them."""
        features = self.encoder(x)
        batch, channels = features.shape[0], features.shape[1]
        # ``view`` only succeeds when all remaining dimensions have size 1.
        flat = features.view(batch, channels)
        return self.projector(flat)
  4. Defining loss functions and some custom callbacks.

@cbs.on_experiment_start(order=cbs.CallbackOrder.MODEL_INIT)
def init_target_network(experiment: "Experiment"):
    """Initialise the target network from the online network and freeze it.

    Every parameter of ``experiment.state.model["target_network"]`` is
    overwritten in place with the value of the corresponding parameter of
    ``experiment.state.model["online_network"]`` and then excluded from
    gradient computation.

    Args:
        experiment: Experiment whose ``state.model`` mapping holds the
            ``"online_network"`` and ``"target_network"`` modules.
    """
    online_network = experiment.state.model["online_network"]
    target_network = experiment.state.model["target_network"]

    # Pair up the parameter tensors of both networks. Zipping the modules
    # themselves would not yield parameters (and fails for non-iterable modules).
    # NOTE: ``parameters()`` does not include buffers, so only learnable
    # tensors are copied here.
    for online_params, target_params in zip(
        online_network.parameters(),
        target_network.parameters(),
    ):
        target_params.data.copy_(online_params.data)
        target_params.requires_grad = False


# Defining the loss function
def normalized_mse(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return the mean of ``2 - 2 * <x_hat, y_hat>`` over all leading dimensions.

    ``x_hat`` and ``y_hat`` are ``x`` and ``y`` normalised along the last
    dimension with ``F.normalize``; the dot product is also taken along the
    last dimension.

    Args:
        x: First tensor.
        y: Second tensor; multiplied element-wise with ``x`` after normalisation.

    Returns:
        A scalar tensor with the averaged value.
    """
    x_unit = F.normalize(x, dim=-1)
    y_unit = F.normalize(y, dim=-1)
    dot = (x_unit * y_unit).sum(dim=-1)
    per_sample = 2 - 2 * dot
    return torch.mean(per_sample)
  5. Finally, we will customize the training loop for self-supervised training

class BYOLExperiment(Experiment):
    """Experiment with a BYOL-style training step.

    The model mapping in ``self.state.model`` is expected to provide three
    entries: ``"online_network"``, ``"predictor"`` and ``"target_network"``.
    The online network and predictor are optimised by gradient descent; the
    target network is updated as an exponential moving average of the online
    network after every optimiser step.

    Args:
        momentum: Weight given to the current target parameters in the moving
            average; the online parameters receive ``1 - momentum``.
        augmentation_fn: Callable that takes ``image_size`` and returns the
            augmentation applied to each batch.
        image_size: Forwarded to ``augmentation_fn``.
        **kwargs: Forwarded unchanged to ``Experiment.__init__``.
    """

    def __init__(self, momentum, augmentation_fn, image_size, **kwargs):
        super(BYOLExperiment, self).__init__(**kwargs)
        self.momentum = momentum
        self.augmentation_fn = augmentation_fn(image_size)

    def get_model_params(self, config):
        """Return the parameters that should receive gradients.

        When ``config.model_dict`` is truthy and ``config.optimizer_dict`` is
        falsy, the parameters of the online network and the predictor are
        returned as one list. In every other case ``None`` is returned.
        """
        if config.model_dict and not config.optimizer_dict:
            grad_params = list(self.state.model["online_network"].parameters()) + list(
                self.state.model["predictor"].parameters()
            )
            return grad_params
        # NOTE(review): no parameters are selected for other configurations;
        # confirm this matches what the base Experiment expects.
        return None

    @torch.no_grad()
    def update_target_network(self):
        """Move the target network towards the online network (EMA update).

        For every parameter pair:
        ``target = momentum * target + (1 - momentum) * online``.
        """
        online_weight = 1 - self.momentum
        for online_params, target_params in zip(
            self.state.model["online_network"].parameters(),
            self.state.model["target_network"].parameters(),
        ):
            # The two weights must sum to 1; otherwise the target parameters
            # are rescaled on every step instead of being averaged.
            target_params.data = (
                target_params.data * self.momentum + online_params.data * online_weight
            )

    def train_step(self) -> Dict:
        """Run one optimisation step on the current batch.

        Two augmented views of the batch are created. The online network and
        predictor produce a prediction for each view, the target network
        (without gradients) produces a projection for each view, and each
        prediction is compared with the target projection of the *other* view.

        Returns:
            A dict mapping ``self.loss_key`` to the summed loss tensor.
        """
        self.backend.zero_grad(optimizer=self.state.optimizer)
        x = self.batch[self.input_key]
        view_1, view_2 = self.augmentation_fn(x), self.augmentation_fn(x)
        pred_1 = self.state.model["predictor"](self.state.model["online_network"](view_1))
        pred_2 = self.state.model["predictor"](self.state.model["online_network"](view_2))

        with torch.no_grad():
            # Targets are deliberately swapped: target_1 comes from view_2 and
            # target_2 from view_1, so each prediction is matched with the other view.
            target_2 = self.state.model["target_network"](view_1)
            target_1 = self.state.model["target_network"](view_2)

        loss = self.state.criterion(pred_1, target_1) + self.state.criterion(pred_2, target_2)
        self.backend.backward_loss(loss=loss)
        self.backend.optimizer_step(optimizer=self.state.optimizer)
        # Update the target only after the online network has taken its step.
        self.update_target_network()
        return {self.loss_key: loss}