Source code for tango.integrations.flax.train_config

from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Dict, Optional


@dataclass
class TrainConfig:
    """
    Encapsulates the parameters of :class:`FlaxTrainStep`. This is used to pass
    all the training options to :class:`TrainCallback`.
    """

    step_id: str
    """
    The unique ID of the current step.
    """

    work_dir: Path
    """
    The working directory for the training run.
    """

    step_name: Optional[str] = None
    """
    The name of the current step.
    """

    train_split: str = "train"
    """
    The name of the training split.
    """

    validation_split: Optional[str] = None
    """
    The name of the validation split.
    """

    seed: int = 42
    """
    The random seed.
    """

    train_steps: Optional[int] = None
    """
    The number of steps to train for.
    """

    train_epochs: Optional[int] = None
    """
    The number of epochs to train for.
    You cannot specify ``train_steps`` and ``train_epochs`` at the same time.
    """

    validation_steps: Optional[int] = None
    """
    The number of validation steps.
    """

    log_every: int = 10
    """
    Controls the frequency of log updates.
    """

    checkpoint_every: int = 100
    """
    Controls the frequency of checkpoints.
    """

    validate_every: Optional[int] = None
    """
    Controls the frequency of the validation loop.
    """

    is_distributed: bool = False
    """
    Whether or not the training job is distributed.
    """

    val_metric_name: str = "loss"
    """
    The name of the validation metric to track.
    """

    minimize_val_metric: bool = True
    """
    Should be ``True`` when the validation metric being tracked should be minimized.
    """

    auto_aggregate_val_metric: bool = True
    """
    Controls automatic aggregation of the validation metric.
    """

    remove_stale_checkpoints: bool = True
    """
    Controls removal of stale checkpoints.
    """

    @property
    def state_path(self) -> Path:
        """
        The path to the latest state checkpoint file.
        """
        return self.work_dir / "checkpoint_state_latest"

    @property
    def best_state_path(self) -> Path:
        """
        The path to the best state checkpoint file according to the validation
        metric or training loss (if no validation split is given).
        """
        return self.work_dir / "checkpoint_state_best"

    def should_log_this_step(self, step: int) -> bool:
        """
        Returns ``True`` if metrics should be logged at the given training step.
        """
        assert self.train_steps is not None
        return step == 0 or (step + 1) % self.log_every == 0 or step == self.train_steps - 1

    def should_checkpoint_this_step(self, step: int) -> bool:
        """
        Returns ``True`` if a checkpoint should be saved at the given training step.
        """
        assert self.train_steps is not None
        return ((step + 1) % self.checkpoint_every == 0) or step == self.train_steps - 1

    def should_log_this_val_step(self, val_step: int) -> bool:
        """
        Returns ``True`` if progress should be logged at the given validation step.
        """
        assert self.validation_steps is not None
        return val_step % self.log_every == 0 or val_step == self.validation_steps - 1

    def as_dict(self) -> Dict[str, Any]:
        """
        Returns the config as a plain dictionary, excluding private fields.
        """
        return {k: v for k, v in asdict(self).items() if not k.startswith("_")}
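
# ----------------------------------------------------------------------
# Usage sketch (not part of the module above): a minimal, hypothetical
# example of how TrainConfig might be constructed and queried. The step
# ID, working directory, and step counts below are made-up values.

if __name__ == "__main__":
    config = TrainConfig(
        step_id="train-001",              # hypothetical step ID
        work_dir=Path("/tmp/train-run"),  # hypothetical working directory
        train_steps=1000,
        validation_split="validation",
        log_every=100,
        checkpoint_every=250,
    )

    # Checkpoint locations are derived from ``work_dir``.
    print(config.state_path)       # /tmp/train-run/checkpoint_state_latest
    print(config.best_state_path)  # /tmp/train-run/checkpoint_state_best

    # The scheduling helpers fire on the first step, every ``log_every``
    # steps (1-indexed via ``step + 1``), and on the final step.
    log_steps = [s for s in range(config.train_steps) if config.should_log_this_step(s)]
    print(log_steps[:3], log_steps[-1])  # [0, 99, 199] 999

    # ``as_dict()`` yields a plain dictionary, e.g. for serializing the config.
    print(config.as_dict()["seed"])  # 42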