Source code for tango.integrations.torch.train_callback

import logging
from pathlib import Path
from typing import Any, Dict, List, Optional

from tango.common.dataset_dict import DatasetDictBase
from tango.common.registrable import Registrable
from tango.workspace import Workspace

from .data import DataLoader
from .exceptions import StopEarly
from .model import Model
from .train_config import TrainConfig
from .training_engine import TrainingEngine


[docs]class TrainCallback(Registrable):
    """
    A :class:`TrainCallback` is a :class:`~tango.common.Registrable` class
    that can be used within :class:`TorchTrainStep` to customize behavior in the training
    loop. You can set the training callbacks with the ``callbacks`` parameter to
    :class:`TorchTrainStep`.

    .. tip::
        All of the parameters to this base class will be automatically set within
        the training loop, so you shouldn't include them in your config for your callbacks.

    .. tip::
        You can access the model being trained through :attr:`self.model <model>`.

    .. important::
        The ``step`` argument to callback methods is the total/overall number of training steps
        so far, independent of the current epoch.

    .. seealso::
        See :class:`~tango.integrations.wandb.WandbTrainCallback` for an example
        implementation.

    :ivar Workspace workspace: The tango workspace being used.
    :ivar TrainConfig train_config: The training config.
    :ivar TrainingEngine training_engine: The engine used to train the model.
    :ivar tango.common.DatasetDictBase dataset_dict: The dataset dict containing train and
        optional validation splits.
    :ivar DataLoader train_dataloader: The dataloader used for the training split.
    :ivar DataLoader validation_dataloader: Optional dataloader used for the validation split.
    """

    def __init__(
        self,
        workspace: Workspace,
        train_config: TrainConfig,
        training_engine: TrainingEngine,
        dataset_dict: DatasetDictBase,
        train_dataloader: DataLoader,
        validation_dataloader: Optional[DataLoader] = None,
    ) -> None:
        self.workspace = workspace
        self.train_config = train_config
        self.training_engine = training_engine
        self.dataset_dict = dataset_dict
        self.train_dataloader = train_dataloader
        self.validation_dataloader = validation_dataloader
        self.logger = logging.getLogger(self.__class__.__name__)

    @property
    def step_id(self) -> str:
        """
        The unique ID of the current :class:`~tango.Step`.
        """
        return self.train_config.step_id

    @property
    def step_name(self) -> Optional[str]:
        """
        The name of the current :class:`~tango.Step`.
        """
        return self.train_config.step_name

    @property
    def work_dir(self) -> Path:
        """
        The working directory of the current train step.
        """
        return self.train_config.work_dir

    @property
    def is_local_main_process(self) -> bool:
        """
        This is ``True`` if the current worker is the main distributed worker of the current
        node, or if we are not using distributed training.
        """
        return self.train_config.is_local_main_process

    @property
    def model(self) -> Model:
        """
        The :class:`Model` being trained.
        """
        return self.training_engine.model
[docs]    def state_dict(self) -> Dict[str, Any]:
        """
        Return any state that needs to be kept after a restart.

        Some callbacks need to maintain state across restarts. This is the callback's
        opportunity to save its state. It will be restored using :meth:`load_state_dict`.
        """
        return {}

[docs]    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """
        Load the state on a restart.

        Some callbacks need to maintain state across restarts. This is the callback's
        opportunity to restore its state. It gets saved using :meth:`state_dict`.
        """
        pass
[docs]    def pre_train_loop(self) -> None:
        """
        Called right before the first batch is processed, or after a restart.
        """
        pass

[docs]    def post_train_loop(self, step: int, epoch: int) -> None:
        """
        Called after the training loop completes.

        This is the last method that is called, so any cleanup can be done in this method.
        """
        pass

[docs]    def pre_epoch(self, step: int, epoch: int) -> None:
        """
        Called right before the start of an epoch. Epochs start at 0.
        """
        pass

[docs]    def post_epoch(self, step: int, epoch: int) -> None:
        """
        Called after an epoch is completed. Epochs start at 0.
        """
        pass
[docs]    def pre_batch(self, step: int, epoch: int, batch: List[Dict[str, Any]]) -> None:
        """
        Called directly before processing a batch.

        .. note::
            The type of ``batch`` is a list because, with gradient accumulation, there will
            be more than one "micro batch" in the batch.
        """
        pass

[docs]    def post_batch(
        self, step: int, epoch: int, batch_loss: float, batch_outputs: List[Dict[str, Any]]
    ) -> None:
        """
        Called directly after processing a batch, but before unscaling gradients,
        clipping gradients, and taking an optimizer step.

        .. note::
            The ``batch_loss`` here is the loss local to the current worker, not the overall
            (average) batch loss across distributed workers.

            If you need the average loss, use :meth:`log_batch()`.

        .. note::
            The type of ``batch_outputs`` is a list because, with gradient accumulation, there
            will be more than one "micro batch" in the batch.
        """
        pass

[docs]    def log_batch(
        self, step: int, epoch: int, batch_loss: float, batch_outputs: List[Dict[str, Any]]
    ) -> None:
        """
        Called after the optimizer step. Here ``batch_loss`` is the average loss across
        all distributed workers.

        .. note::
            This callback method is not necessarily called on every step.
            The frequency depends on the value of the ``log_every`` parameter of
            :class:`TorchTrainStep`.

        .. note::
            The type of ``batch_outputs`` is a list because, with gradient accumulation, there
            will be more than one "micro batch" in the batch.
        """
        pass
[docs]    def pre_val_batch(
        self, step: int, val_step: int, epoch: int, val_batch: Dict[str, Any]
    ) -> None:
        """
        Called right before a validation batch is processed.
        """
        pass

[docs]    def post_val_batch(
        self, step: int, val_step: int, epoch: int, val_batch_outputs: Dict[str, Any]
    ) -> None:
        """
        Called right after a validation batch is processed with the outputs of the batch.

        .. tip::
            This method can be used to modify ``val_batch_outputs`` in place, which is useful
            in scenarios like distributed training where you might need to aggregate metrics
            in some way other than a simple average. If that's the case, make sure to set
            ``auto_aggregate_val_metric`` to ``False`` in :class:`TorchTrainStep`.
        """
        pass

[docs]    def post_val_loop(
        self, step: int, epoch: int, val_metric: float, best_val_metric: float
    ) -> None:
        """
        Called right after the validation loop finishes.
        """
        pass
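# The class below is an illustrative sketch, not part of the original module. It shows one
# way a minimal custom callback could look, assuming you only want to track and log the
# per-step training loss. The class name and the registered name "torch::log_loss_history"
# are made up for this example; the hooks it overrides (``log_batch``, ``state_dict``,
# ``load_state_dict``) and the ``self.logger`` / ``self.is_local_main_process`` attributes
# come from :class:`TrainCallback` above. Constructor parameters are filled in automatically
# by the training loop, so the subclass only adds its own state.

@TrainCallback.register("torch::log_loss_history")  # hypothetical name, for illustration only
class LossHistoryCallback(TrainCallback):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.history: List[float] = []

    def log_batch(
        self, step: int, epoch: int, batch_loss: float, batch_outputs: List[Dict[str, Any]]
    ) -> None:
        # ``batch_loss`` here is already averaged across distributed workers.
        self.history.append(batch_loss)
        if self.is_local_main_process:
            self.logger.info("step %d (epoch %d): loss = %.4f", step, epoch, batch_loss)

    def state_dict(self) -> Dict[str, Any]:
        # Persist the history so it survives a restart.
        return {"history": self.history}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.history = state_dict["history"]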
[docs]@TrainCallback.register("torch::stop_early")
class StopEarlyCallback(TrainCallback):
    """
    A :class:`TrainCallback` for early stopping. Training is stopped early after
    ``patience`` steps without an improvement to the validation metric.

    .. tip::
        Registered as a :class:`TrainCallback` under the name "torch::stop_early".
    """

    def __init__(self, *args, patience: int = 10000, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.patience = patience
        self.best_step = 0
        self.best_val_metric: Optional[float] = None

    def post_val_loop(
        self, step: int, epoch: int, val_metric: float, best_val_metric: float
    ) -> None:
        # We can't rely on the best_val_metric parameter, because then we can't detect when
        # the metric stays the same for many steps.
        if self.best_val_metric is None or val_metric > self.best_val_metric:
            self.best_step = step
            self.best_val_metric = val_metric
        elif step > self.best_step + self.patience:
            raise StopEarly

    def state_dict(self) -> Dict[str, Any]:
        """
        Return any state that needs to be kept after a restart.
        """
        return {
            "patience": self.patience,
            "best_step": self.best_step,
            "best_val_metric": self.best_val_metric,
        }

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """
        Load the state on a restart.
        """
        self.patience = state_dict["patience"]
        self.best_step = state_dict["best_step"]
        self.best_val_metric = state_dict["best_val_metric"]
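# Illustrative usage sketch, not part of the original module: because StopEarlyCallback is
# registered under "torch::stop_early", it is normally enabled from a Tango config through
# the ``callbacks`` parameter of the train step rather than constructed directly. The step
# name, the ``patience`` value, and the elided parameters below are placeholders:
#
#     "steps": {
#         "train": {
#             "type": "torch::train",
#             ...,
#             "callbacks": [{"type": "torch::stop_early", "patience": 5000}],
#         }
#     }
#
# When the validation metric fails to improve for ``patience`` steps, the callback raises
# :class:`StopEarly`, which signals the training loop to stop training.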