# Source code for tango.integrations.torch.train_config
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional
import torch
@dataclass
class TrainConfig:
"""
Encapsulates the parameters of :class:`TorchTrainStep`. This is used to pass all the training
options to :class:`TrainCallback`.
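
    For example, a minimal config for a single (non-distributed) training process
    might look like the following; the step ID and working directory are
    illustrative placeholders:

    .. code-block:: python

        config = TrainConfig(
            step_id="train-001",
            work_dir=Path("/tmp/train-001/work"),
            train_epochs=3,
            validation_split="validation",
        )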
"""
    step_id: str
    """
    The unique ID of the current step.
    """

    work_dir: Path
    """
    The working directory for the training run.
    """

    step_name: Optional[str] = None
    """
    The name of the current step.

    .. note::
        The same step can be run under different names.
    """

    worker_id: int = 0
    """
    The ID of the distributed worker.
    """

    train_split: str = "train"
    """
    The name of the training split.
    """

    validation_split: Optional[str] = None
    """
    The name of the validation split.
    """

    seed: int = 42
    """
    The random seed.
    """

    train_steps: Optional[int] = None
    """
    The number of steps to train for.
    """

    train_epochs: Optional[int] = None
    """
    The number of epochs to train for.
    You cannot specify ``train_steps`` and ``train_epochs`` at the same time.
    """

    validation_steps: Optional[int] = None
    """
    The number of validation steps.
    The default is to validate on the entire validation set.
    """

    grad_accum: int = 1
    """
    The number of micro-batches per gradient accumulation mini-batch.
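
    For example, with ``grad_accum=2``, gradients from two micro-batches are
    accumulated before each optimizer step, so a micro-batch size of 16 yields an
    effective batch size of 32.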
"""
    log_every: int = 10
    """
    Controls the frequency of log updates, in number of optimizer steps.
    """

    checkpoint_every: int = 100
    """
    Controls the frequency of checkpoints, in number of optimizer steps.
    """

    validate_every: Optional[int] = None
    """
    Controls the frequency of the validation loop, in number of optimizer steps.
    """
    is_distributed: bool = False
    """
    Whether or not the training job is distributed.
    """

    devices: Optional[List[int]] = None
    """
    The devices used (for distributed jobs). A negative device ID is interpreted
    as the CPU.
    """

    distributed_address: str = "127.0.0.1"
    """
    The IP address of the main distributed process.
    """

    distributed_port: int = 54761
    """
    The port of the main distributed process.
    """

    val_metric_name: str = "loss"
    """
    The name of the validation metric to track.
    """

    minimize_val_metric: bool = True
    """
    Should be ``True`` when the validation metric being tracked should be minimized.
    """

    auto_aggregate_val_metric: bool = True
    """
    Controls automatic aggregation of the validation metric.
    """

    remove_stale_checkpoints: bool = True
    """
    If ``True``, stale checkpoints will be removed during training.
    """

    world_size: int = 1
    """
    The number of distributed workers.
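
    For example, a single-machine run on two GPUs would typically use
    ``is_distributed=True``, ``devices=[0, 1]``, and ``world_size=2``.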
"""
_worker_local_default_device: Optional[torch.device] = None
_device_type: Optional[str] = None # either "cuda" or "cpu"
    @property
    def worker_local_default_device(self) -> torch.device:
        """
        The default ``torch`` device for the current worker.
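
        For example, with ``devices=[0, 1]``, the worker with ``worker_id=1``
        resolves to ``cuda:1``; a negative device ID resolves to the CPU.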
"""
if self._worker_local_default_device is not None:
return self._worker_local_default_device
else:
if self.devices:
device_id = self.devices[self.worker_id]
if device_id >= 0:
device = torch.device(f"cuda:{device_id}")
else:
device = torch.device("cpu")
elif torch.cuda.is_available():
device = torch.device("cuda")
else:
device = torch.device("cpu")
self._worker_local_default_device = device
return device
    @property
    def device_type(self) -> str:
        """
        The default device type for the current worker, either ``"cuda"`` or ``"cpu"``.
        """
        if self._device_type is None:
            device_type = (
                "cpu" if self.worker_local_default_device == torch.device("cpu") else "cuda"
            )
            self._device_type = device_type
            return device_type
        else:
            return self._device_type
    @property
    def is_local_main_process(self) -> bool:
        """
        Whether the local process is the main distributed worker.
        """
        return self.worker_id == 0

    @property
    def state_path(self) -> Path:
        """
        The path to the latest state checkpoint file.
        """
        return self.work_dir / "checkpoint_state_latest"

    @property
    def best_state_path(self) -> Path:
        """
        The path to the best state checkpoint file according to the validation metric or training
        loss (if no validation split is given).
        """
        return self.work_dir / "checkpoint_state_best"
    def state_path_for_step(self, step: int) -> Path:
        """
        The path to the state checkpoint file for a given step. Checkpoint files are
        named by 1-based step number, while ``step`` itself is 0-based.
        """
        return self.work_dir / f"checkpoint_state_step{step + 1}"
    @property
    def final_weights_path(self) -> Path:
        """
        The path to the final model weights file.
        """
        return self.work_dir / "weights.pt"
    def should_log_this_step(self, step: int) -> bool:
        """
        Whether to log an update for the given (0-based) training step. Logs are
        emitted on the first step, every ``log_every`` steps, and on the final step.
        """
        assert self.train_steps is not None
        return step == 0 or (step + 1) % self.log_every == 0 or step == self.train_steps - 1

    def should_checkpoint_this_step(self, step: int) -> bool:
        """
        Whether to save a checkpoint for the given (0-based) training step.
        """
        assert self.train_steps is not None
        return ((step + 1) % self.checkpoint_every == 0) or step == self.train_steps - 1

    def should_log_this_val_step(self, val_step: int) -> bool:
        """
        Whether to log an update for the given (0-based) validation step.
        """
        assert self.validation_steps is not None
        return val_step % self.log_every == 0 or val_step == self.validation_steps - 1

    def as_dict(self) -> Dict[str, Any]:
        """
        Returns the public fields of this config as a dictionary, excluding
        private fields (those prefixed with an underscore).
        """
        return {k: v for k, v in asdict(self).items() if not k.startswith("_")}
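
# A minimal usage sketch (not part of the original module): construct a config
# and query a few of its derived properties. The step ID and directory below
# are hypothetical placeholders.
#
#     config = TrainConfig(
#         step_id="train-001",
#         work_dir=Path("/tmp/train-001/work"),
#         train_steps=1000,
#         log_every=10,
#     )
#     config.worker_local_default_device  # e.g. device("cuda") or device("cpu")
#     config.should_log_this_step(0)      # True (the first step is always logged)
#     config.should_log_this_step(9)      # True (10th step with log_every=10)
#     config.state_path                   # work_dir / "checkpoint_state_latest"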