Source code for torchflare.callbacks.wandb_logger

"""Implements logger for weights and biases."""
from abc import ABC
from typing import TYPE_CHECKING, Dict, List, Optional

from torchflare.callbacks.callback import Callbacks
from torchflare.callbacks.states import CallbackOrder
from torchflare.utils.imports_check import module_available

_AVAILABLE = module_available("wandb")
if _AVAILABLE:
    import wandb
else:
    wandb = None


if TYPE_CHECKING:
    from torchflare.experiments.experiment import Experiment


[docs]class WandbLogger(Callbacks, ABC): """Callback to log your metrics and loss values to wandb platform. For more information about wandb take a look at [Weights and Biases](https://wandb.ai/) Args: project: The name of the project where you're sending the new run entity: An entity is a username or team name where you're sending runs. name: A short display name for this run config: The hyperparameters for your model and experiment as a dictionary tags: List of strings. directory: where to save wandb local run directory. If set to None it will use experiments save_dir argument. notes: A longer description of the run, like a -m commit message in git Note: set os.environ['WANDB_SILENT'] = True to silence wandb log statements. If this is set all logs will be written to WANDB_DIR/debug.log Examples: .. code-block:: from torchflare.callbacks import WandbLogger params = {"bs": 16, "lr": 0.3} logger = WandbLogger( project="Experiment", entity="username", name="Experiment_10", config=params, tags=["Experiment", "fold_0"]) """ def __init__( self, project: str, entity: str, name: str = None, config: Dict = None, tags: List[str] = None, notes: Optional[str] = None, directory: str = None, ): """Constructor of WandbLogger.""" super(WandbLogger, self).__init__(order=CallbackOrder.LOGGING) self.entity = entity self.project = project self.name = name self.config = config self.tags = tags self.notes = notes self.dir = directory self.experiment = None def on_experiment_start(self, experiment: "Experiment"): """Experiment start.""" self.experiment = wandb.init( entity=self.entity, project=self.project, name=self.name, config=self.config, tags=self.tags, notes=self.notes, dir=self.dir, ) def on_epoch_end(self, experiment: "Experiment"): """Method to log metrics and values at the end of very epoch.""" logs = {k: v for k, v in experiment.exp_logs.items() if k != experiment.epoch_key} self.experiment.log(logs) def on_experiment_end(self, experiment: "Experiment"): """Method to end experiment after training is done.""" self.experiment.finish()