# Source code for tango.integrations.flax.train_callback

import logging
from pathlib import Path
from typing import Any, Dict, Optional

from tango.common.dataset_dict import DatasetDictBase
from tango.common.registrable import Registrable
from tango.workspace import Workspace

from .data import DataLoader
from .model import Model
from .optim import Optimizer
from .train_config import TrainConfig


class TrainCallback(Registrable):
    """
    A :class:`TrainCallback` is a :class:`~tango.common.Registrable` class
    that can be used within :class:`FlaxTrainStep` to customize behavior in the training
    loop. You can set the training callbacks with the ``callbacks`` parameter to
    :class:`FlaxTrainStep`.

    .. tip::
        All of the parameters to this base class will be automatically set within
        the training loop, so you shouldn't include them in your config for your callbacks.

    .. tip::
        You can access the model being trained through :attr:`self.model <model>`.

    .. important::
        The ``step`` argument to callback methods is the total/overall number of training
        steps so far, independent of the current epoch.

    .. seealso::
        See :class:`~tango.integrations.wandb.WandbTrainCallback` for an example
        implementation.

    :ivar Workspace workspace: The tango workspace being used.
    :ivar TrainConfig train_config: The training config.
    :ivar tango.common.DatasetDictBase dataset: The dataset dict containing train and
        optional validation splits.
    :ivar DataLoader train_dataloader: The dataloader used for the training split.
    :ivar Model model: The flax model being trained.
    :ivar Optimizer optimizer: The optimizer being used for training.
    :ivar DataLoader validation_dataloader: Optional dataloader used for the validation split.
    :ivar logging.Logger logger: A logger named after the concrete callback class.
    """

    def __init__(
        self,
        workspace: Workspace,
        train_config: TrainConfig,
        dataset: DatasetDictBase,
        train_dataloader: DataLoader,
        model: Model,
        optimizer: Optimizer,
        validation_dataloader: Optional[DataLoader] = None,
    ) -> None:
        self.workspace = workspace
        self.train_config = train_config
        self.dataset = dataset
        self.train_dataloader = train_dataloader
        self.model = model
        self.optimizer = optimizer
        self.validation_dataloader = validation_dataloader
        # Named after the concrete subclass (not this module) so log records identify
        # which callback emitted them.
        self.logger = logging.getLogger(self.__class__.__name__)

    @property
    def step_id(self) -> str:
        """
        The unique ID of the current :class:`~tango.Step`.
        """
        return self.train_config.step_id

    @property
    def step_name(self) -> Optional[str]:
        """
        The name of the current :class:`~tango.Step`.
        """
        return self.train_config.step_name

    @property
    def work_dir(self) -> Path:
        """
        The working directory of the current train step.
        """
        return self.train_config.work_dir

    def state_dict(self) -> Dict[str, Any]:
        """
        Return any state that needs to be kept after a restart.

        Some callbacks need to maintain state across restarts. This is the callback's
        opportunity to save its state. It will be restored using :meth:`load_state_dict`.

        The base implementation has no state and returns an empty dict.
        """
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """
        Load the state on a restart.

        Some callbacks need to maintain state across restarts. This is the callback's
        opportunity to restore its state. It gets saved using :meth:`state_dict`.

        The base implementation ignores ``state_dict``.
        """
        pass

    def pre_train_loop(self) -> None:
        """
        Called right before the first batch is processed, or after a restart.
        """
        pass

    def post_train_loop(self, step: int, epoch: int) -> None:
        """
        Called after the training loop completes.

        This is the last method that is called, so any cleanup can be done in this method.
        """
        pass

    def pre_epoch(self, step: int, epoch: int) -> None:
        """
        Called before the start of an epoch. Epochs start at 0.
        """
        pass

    def post_epoch(self, step: int, epoch: int) -> None:
        """
        Called after an epoch is completed. Epochs start at 0.
        """
        pass

    def pre_batch(self, step: int, epoch: int, batch) -> None:
        """
        Called directly before processing a batch.
        """
        pass

    def post_batch(self, step: int, epoch: int, train_metrics: Dict) -> None:
        """
        Called directly after processing a batch, but before unscaling gradients,
        clipping gradients, and taking an optimizer step.

        .. note::
            The ``train_metrics`` here is the dictionary with train metrics of the
            current batch. If doing distributed training, use
            ``jax_utils.unreplicate(train_metrics)`` before using ``train_metrics``.

            If you need the average loss, use :meth:`log_batch()`.
        """
        # NOTE(review): the ordering stated above (relative to gradient unscaling,
        # clipping and the optimizer step) cannot be verified from this module —
        # confirm against FlaxTrainStep's training loop.
        pass

    def log_batch(self, step: int, epoch: int, train_metrics: Dict) -> None:
        """
        Called after the optimizer step. Here ``train_metrics`` is the average metrics
        across all distributed workers. If doing distributed training, use
        ``jax_utils.unreplicate(train_metrics)`` before using ``train_metrics``.

        .. note::
            This callback method is not necessarily called on every step.
            The frequency depends on the value of the ``log_every`` parameter of
            :class:`FlaxTrainStep`.
        """
        pass

    def pre_val_loop(self, step: int, val_step: int, state) -> None:
        """
        Called right before the validation loop starts.
        """
        # NOTE(review): the type/meaning of ``state`` is not visible in this module;
        # presumably the current training state passed by FlaxTrainStep — confirm.
        pass

    def pre_val_batch(self, step: int, val_step: int, epoch: int, val_batch) -> None:
        """
        Called right before a validation batch is processed.
        """
        pass

    def post_val_batch(self, step: int, val_step: int, epoch: int, val_metrics: Dict) -> None:
        """
        Called right after a validation batch is processed with the outputs of the batch.

        .. tip::
            This method can be used to modify ``val_metrics`` in place, which is useful
            in scenarios like distributed training where you might need to aggregate
            metrics in a special way other than a simple average. If that's the case,
            make sure to set ``auto_aggregate_val_metric`` to ``False`` in
            :class:`FlaxTrainStep`.
        """
        pass

    def post_val_loop(
        self, step: int, epoch: int, val_metric: Optional[float], best_val_metric: Optional[float]
    ) -> None:
        """
        Called right after the evaluation loop finishes.
        """
        pass