Source code for torchflare.callbacks.comet_logger

"""Implements Comet Logger."""

from abc import ABC
from typing import TYPE_CHECKING, List

from torchflare.callbacks.callback import Callbacks
from torchflare.callbacks.states import CallbackOrder
from torchflare.utils.imports_check import module_available

if TYPE_CHECKING:
    from torchflare.experiments.experiment import Experiment

# comet_ml is an optional dependency. The name passed to ``module_available``
# must match the import name used below; when the package is not available the
# module-level name ``comet_ml`` is bound to None so this file still imports.
_AVAILABLE = module_available("comet_ml")
if _AVAILABLE:
    import comet_ml
else:
    comet_ml = None


class CometLogger(Callbacks, ABC):
    """Callback to log your metrics and loss values to Comet to track your experiments.

    For more information about Comet look at [Comet.ml](https://www.comet.ml/site/)

    Args:
        api_token: Your API key obtained from comet.ml
        params: The hyperparameters for your model and experiment as a dictionary.
            Parameter logging is skipped when this is None.
        project_name: Send your experiment to a specific project.
            Otherwise, will be sent to Uncategorized Experiments.
        workspace: Attach an experiment to a project that belongs to this workspace
        tags: List of strings. Tagging is skipped when this is None.

    Examples:
        .. code-block::

            from torchflare.callbacks import CometLogger

            params = {"bs": 16, "lr": 0.3}

            logger = CometLogger(
                project_name="experiment_10",
                workspace="username",
                params=params,
                tags=["Experiment", "fold_0"],
                api_token="your_secret_api_token",
            )
    """

    def __init__(
        self,
        api_token: str,
        params: dict,
        project_name: str,
        workspace: str,
        tags: List[str],
    ):
        """Constructor for CometLogger class.

        Stores the connection settings; the Comet run itself is only created
        in ``on_experiment_start``.
        """
        super(CometLogger, self).__init__(order=CallbackOrder.LOGGING)
        self.api_token = api_token
        self.project_name = project_name
        self.workspace = workspace
        self.params = params
        self.tags = tags
        # Handle to the live comet_ml.Experiment; None whenever no run is open.
        self.experiment = None

    def on_experiment_start(self, experiment: "Experiment"):
        """Create the Comet run and attach the configured tags and parameters.

        Args:
            experiment: The running torchflare experiment (not read here).

        Raises:
            ModuleNotFoundError: If the optional ``comet_ml`` package could not
                be imported by this module.
        """
        # The module-level import falls back to None when comet_ml is missing;
        # fail with an actionable message instead of an AttributeError on None.
        if comet_ml is None:
            raise ModuleNotFoundError(
                "CometLogger requires the 'comet_ml' package, but it could not be imported. "
                "Install it with `pip install comet_ml`."
            )

        self.experiment = comet_ml.Experiment(
            project_name=self.project_name,
            api_key=self.api_token,
            workspace=self.workspace,
            log_code=False,
            display_summary_level=0,
        )

        if self.tags is not None:
            self.experiment.add_tags(self.tags)

        if self.params is not None:
            self.experiment.log_parameters(self.params)

    def on_epoch_end(self, experiment: "Experiment"):
        """Function to log your metrics and values at the end of every epoch.

        Every entry of ``experiment.exp_logs`` except the one stored under
        ``experiment.epoch_key`` is sent as a metric; the value under
        ``experiment.epoch_key`` is used as the step.

        Args:
            experiment: The running torchflare experiment providing
                ``exp_logs`` and ``epoch_key``.
        """
        logs = {k: v for k, v in experiment.exp_logs.items() if k != experiment.epoch_key}
        self.experiment.log_metrics(logs, step=experiment.exp_logs[experiment.epoch_key])

    def on_experiment_end(self, experiment: "Experiment"):
        """Function to close the experiment when training ends.

        Args:
            experiment: The running torchflare experiment (not read here).
        """
        # Nothing to close if the run was never opened (e.g. start failed).
        if self.experiment is not None:
            self.experiment.end()
            self.experiment = None