# Source code for torchflare.experiments.config

from dataclasses import MISSING, dataclass, field
from typing import Any, Callable, Dict, Union

import torch.nn as nn
from torch.optim import Optimizer

from torchflare.experiments.simple_utils import check_both_dicts, check_same_keys


def _validate_inputs(d1, d2):
    """Run the dictionary pair checks on ``d1`` and ``d2``.

    Args:
        d1: First object to check.
        d2: Second object to check.

    Returns:
        ``False`` when ``check_both_dicts(d1, d2)`` is falsy; otherwise
        whatever ``check_same_keys(d1, d2)`` returns.
    """
    # Guard clause: the key comparison is only attempted once the pair has
    # passed the "both dicts" check.
    if not check_both_dicts(d1, d2):
        return False
    return check_same_keys(d1, d2)


@dataclass
class ModelConfig:
    """Model Config to initialize model related parameters, optimizers and criterion.

    Args:
        nn_module: Uninstantiated PyTorch model class or dictionary of model classes.
        module_params: The params required to init the nn_module.
        optimizer: The name or uninstantiated optimizer class or dictionary of
            optimizer classes to be used.
        optimizer_params: The params to init the optimizer class.
        criterion: The criterion or dictionary of criterion to be used.
        model_dict: Recomputed in ``__post_init__`` from ``nn_module`` and
            ``module_params``; any value passed in is overwritten.
        optimizer_dict: Recomputed in ``__post_init__`` from ``optimizer`` and
            ``optimizer_params``; any value passed in is overwritten.

    Raises:
        ValueError: If both the model and the optimizer inputs validate as
            dictionaries but the ``nn_module`` and ``optimizer`` check fails.

    Example:
        .. code-block:: python

            from torchflare.experiments import ModelConfig

            config = ModelConfig(
                nn_module=SomeModelClass,
                module_params={"in_features": 100, "out_features": 50},
                optimizer="Adam",
                optimizer_params={"lr": 3e-4},
                criterion="binary_cross_entropy",
            )

    Note:
        If you are having multiple models and optimizers ensure that both the
        model keys and optimizer keys are same.
    """

    # Help texts are built with implicit string concatenation so that no
    # backslash or source indentation leaks into the runtime strings.
    nn_module: Union[nn.Module, Dict] = field(
        default=MISSING,
        metadata={
            "help": "An uninstantiated PyTorch class or a dictionary of "
            "uninstantiated PyTorch classes which defines the model."
        },
    )
    module_params: Dict = field(
        default=MISSING,
        metadata={"help": "The params required to initialize model class."},
    )
    optimizer: Union[str, Dict, Optimizer, Any] = field(
        default=MISSING,
        metadata={
            "help": "The optimizer class to be used or name of optimizer or "
            "dict of optimizers. If you pass in the name of the optimizer, "
            "only optimizers available in pytorch are supported."
        },
    )
    optimizer_params: Dict = field(
        default=MISSING,
        metadata={"help": "The parameters to instantiate optimizer."},
    )
    criterion: Union[Callable, Dict, str] = field(
        default=MISSING,
        metadata={
            "help": "The loss function to optimize or name of the loss function. "
            "If you pass in the name of the loss function, "
            "only loss functions available in pytorch can be supported."
        },
    )
    # Derived flags: both are overwritten in __post_init__.
    model_dict: bool = field(default=False)
    optimizer_dict: bool = field(default=False)

    def __post_init__(self):
        """Post initialisation checks.

        Sets ``model_dict`` and ``optimizer_dict`` from the validation helper
        and, when both are truthy, verifies ``nn_module`` against ``optimizer``.

        Raises:
            ValueError: If the ``nn_module`` / ``optimizer`` check fails.
        """
        self.model_dict = _validate_inputs(d1=self.nn_module, d2=self.module_params)
        self.optimizer_dict = _validate_inputs(d1=self.optimizer, d2=self.optimizer_params)

        # Only cross-check model and optimizer when both were given as
        # validated dictionaries.
        if self.model_dict and self.optimizer_dict:
            if not _validate_inputs(self.nn_module, self.optimizer):
                raise ValueError("Both nn_module keys and optimizer keys must match.")