Source code for tango.integrations.torch.training_engine

import os
import tempfile
from abc import abstractmethod
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union, cast

import torch
import torch.distributed as dist
import torch.nn as nn

from tango.common import Lazy, Registrable, Tqdm

from .model import Model
from .optim import LRScheduler, Optimizer
from .train_config import TrainConfig
from .util import move_to_device


class TrainingEngine(Registrable):
    """
    A :class:`TrainingEngine` defines and drives the strategy for training a model
    in :class:`TorchTrainStep`.

    :ivar TrainConfig train_config: The training config.
    :ivar Model model: The model being trained.
    :ivar Optimizer optimizer: The optimizer being used to train the model.
    :ivar LRScheduler lr_scheduler: The optional learning rate scheduler.
    """

    default_implementation = "torch"
    """
    The default implementation is :class:`TorchTrainingEngine`.
    """

    def __init__(
        self,
        train_config: TrainConfig,
        model: Union[Model, Lazy[Model]],
        optimizer: Lazy[Optimizer],
        *,
        lr_scheduler: Optional[Lazy[LRScheduler]] = None,
    ) -> None:
        self.train_config = train_config
        self.model = self._construct_model(model)
        self.optimizer = self._construct_optimizer(optimizer)
        self.lr_scheduler: Optional[LRScheduler] = None
        if lr_scheduler is not None:
            self.lr_scheduler = self._construct_lr_scheduler(lr_scheduler)

    def _construct_model(self, model: Union[Model, Lazy[Model]]) -> Model:
        if isinstance(model, Lazy):
            model = model.construct()
        return model.to(self.train_config.worker_local_default_device)

    def _construct_optimizer(self, optimizer: Lazy[Optimizer]) -> Optimizer:
        constructed: Optimizer = optimizer.construct(params=self.model.parameters())
        return constructed

    def _construct_lr_scheduler(self, lr_scheduler: Lazy[LRScheduler]) -> LRScheduler:
        constructed: LRScheduler = lr_scheduler.construct(optimizer=self.optimizer)
        return constructed

    @abstractmethod
    def forward_train(
        self, micro_batch: Dict[str, Any], micro_batch_idx: int, num_micro_batches: int
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """
        Run a forward training pass on the model.
        """
        raise NotImplementedError

    @abstractmethod
    def forward_eval(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        """
        Run a forward evaluation pass on the model.
        """
        raise NotImplementedError

    @abstractmethod
    def backward(self, loss: torch.Tensor) -> None:
        """
        Run a backward pass on the model. This will always be called after
        :meth:`forward_train()`.
        """
        raise NotImplementedError

    @abstractmethod
    def step(self) -> None:
        """
        Take an optimization step. This will always be called after :meth:`backward()`.
        """
        raise NotImplementedError

    @abstractmethod
    def save_checkpoint(self, checkpoint_dir: Path, client_state: Dict[str, Any]) -> None:
        """
        Save a training checkpoint with model state, optimizer state, etc., as well as
        the arbitrary ``client_state``, to the given ``checkpoint_dir``.
        """
        raise NotImplementedError

    @abstractmethod
    def load_checkpoint(self, checkpoint_dir: Path) -> Dict[str, Any]:
        """
        Load a checkpoint to resume training. Should return the same ``client_state``
        saved in :meth:`save_checkpoint()`.
        """
        raise NotImplementedError

    @abstractmethod
    def save_complete_weights_from_checkpoint(
        self, checkpoint_dir: Path, weights_path: Path
    ) -> None:
        """
        Gather the final weights from the best checkpoint and save them to the file at
        ``weights_path``.
        """
        raise NotImplementedError
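

# Illustration, not part of tango: a minimal sketch of a custom ``TrainingEngine``
# subclass, assuming a single non-distributed worker, full precision, and plain
# ``torch.save`` checkpoints. ``MinimalTrainingEngine`` and its registered name
# "minimal" are hypothetical; a real engine would also handle AMP, gradient clipping,
# and distributed workers the way ``TorchTrainingEngine`` below does.
@TrainingEngine.register("minimal")
class MinimalTrainingEngine(TrainingEngine):
    def forward_train(
        self, micro_batch: Dict[str, Any], micro_batch_idx: int, num_micro_batches: int
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        if micro_batch_idx == 0:
            # Gradients accumulate across micro-batches, so only zero them on the first.
            self.optimizer.zero_grad(set_to_none=True)
        micro_batch = move_to_device(micro_batch, self.train_config.worker_local_default_device)
        outputs = self.model(**micro_batch)
        # Scale the loss so the accumulated gradient is an average over micro-batches.
        return outputs["loss"] / num_micro_batches, outputs

    def forward_eval(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        batch = move_to_device(batch, self.train_config.worker_local_default_device)
        with torch.inference_mode():
            return self.model(**batch)

    def backward(self, loss: torch.Tensor) -> None:
        loss.backward()

    def step(self) -> None:
        self.optimizer.step()
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

    def save_checkpoint(self, checkpoint_dir: Path, client_state: Dict[str, Any]) -> None:
        checkpoint_dir.mkdir(exist_ok=True)
        torch.save(self.model.state_dict(), checkpoint_dir / "model.pt")
        torch.save(self.optimizer.state_dict(), checkpoint_dir / "optimizer.pt")
        torch.save(client_state, checkpoint_dir / "trainer.pt")

    def load_checkpoint(self, checkpoint_dir: Path) -> Dict[str, Any]:
        self.model.load_state_dict(torch.load(checkpoint_dir / "model.pt"))
        self.optimizer.load_state_dict(torch.load(checkpoint_dir / "optimizer.pt"))
        return torch.load(checkpoint_dir / "trainer.pt")

    def save_complete_weights_from_checkpoint(
        self, checkpoint_dir: Path, weights_path: Path
    ) -> None:
        # With a single worker, the checkpoint's model state already holds the
        # complete weights.
        torch.save(torch.load(checkpoint_dir / "model.pt"), weights_path)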
[docs]@TrainingEngine.register("torch") class TorchTrainingEngine(TrainingEngine): """ This train engine only uses native PyTorch functionality to provide vanilla distributed data parallel training and AMP. .. tip:: Registered as a :class:`TrainingEngine` under the name "torch". .. important:: Only the parameters listed below should be defined in a configuration file. The other parameters will be automatically passed to the constructor within :class:`TorchTrainStep`. :param amp: Use automatic mixed precision. Default is ``False``. :param max_grad_norm: If set, gradients will be clipped to have this max norm. Default is ``None``. :param amp_use_bfloat16: Set to ``True`` to force using the ``bfloat16`` datatype in mixed precision training. Only applicable when ``amp=True``. If not specified, the default behavior will be to use ``bfloat16`` when training with AMP on CPU, otherwise not. """ def __init__( self, train_config: TrainConfig, model: Union[Model, Lazy[Model]], optimizer: Lazy[Optimizer], *, lr_scheduler: Optional[Lazy[LRScheduler]] = None, amp: bool = False, max_grad_norm: Optional[float] = None, amp_use_bfloat16: Optional[bool] = None, ) -> None: self.device = train_config.worker_local_default_device if amp_use_bfloat16 is None: amp_use_bfloat16 = True if train_config.device_type == "cpu" else False self.amp = amp self.amp_dtype = torch.bfloat16 if amp_use_bfloat16 else torch.float16 self.max_grad_norm = max_grad_norm self.grad_scaler: Optional[torch.cuda.amp.GradScaler] = ( None if not amp else torch.cuda.amp.GradScaler() ) if train_config.is_distributed: # Initialize distributed process group. backend: str if train_config.device_type != "cpu": torch.cuda.set_device(self.device) backend = "nccl" else: backend = "gloo" dist.init_process_group( backend=backend, init_method=f"tcp://{train_config.distributed_address}:{train_config.distributed_port}", world_size=train_config.world_size, rank=train_config.worker_id, ) super().__init__(train_config, model, optimizer, lr_scheduler=lr_scheduler) def _construct_model(self, model: Union[Model, Lazy[Model]]) -> Model: if isinstance(model, Lazy): model = model.construct() model.to(self.train_config.worker_local_default_device) # Wrap model with DDP wrapper. if self.train_config.is_distributed: model = cast(Model, nn.parallel.DistributedDataParallel(model)) return model def forward_train( self, micro_batch: Dict[str, Any], micro_batch_idx: int, num_micro_batches: int ) -> Tuple[torch.Tensor, Dict[str, Any]]: if micro_batch_idx == 0: self.optimizer.zero_grad(set_to_none=True) # Move tensors to right device. micro_batch = move_to_device(micro_batch, self.device) with torch.autocast(self.train_config.device_type, enabled=self.amp, dtype=self.amp_dtype): outputs = self.model(**micro_batch) micro_batch_loss = outputs["loss"] / num_micro_batches return micro_batch_loss, outputs def forward_eval(self, batch: Dict[str, Any]) -> Dict[str, Any]: # Move tensors to right device. batch = move_to_device(batch, self.device) with torch.autocast(self.train_config.device_type, enabled=self.amp, dtype=self.amp_dtype): with torch.inference_mode(): outputs = self.model(**batch) return outputs def backward(self, loss: torch.Tensor) -> None: if self.grad_scaler is not None: self.grad_scaler.scale(loss).backward() else: loss.backward() def clip_grad_norm(self) -> None: if self.max_grad_norm is not None: nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm) def step(self) -> None: # Unscale gradients. 
if self.grad_scaler is not None: self.grad_scaler.unscale_(self.optimizer) # Clip gradients. self.clip_grad_norm() # Take optimizer step. if self.grad_scaler is not None: self.grad_scaler.step(self.optimizer) self.grad_scaler.update() else: self.optimizer.step() # Adjust LR schedule. if self.lr_scheduler is not None: self.lr_scheduler.step() def get_model_state(self) -> Dict[str, torch.Tensor]: if self.train_config.is_distributed: return self.model.module.state_dict() # type: ignore[union-attr] else: return self.model.state_dict() def load_model_state(self, state_dict: Dict[str, torch.Tensor]) -> None: if self.train_config.is_distributed: self.model.module.load_state_dict(state_dict) # type: ignore else: self.model.load_state_dict(state_dict) # type: ignore def save_checkpoint(self, checkpoint_dir: Path, client_state: Dict[str, Any]) -> None: checkpoint_dir.mkdir(exist_ok=True) def save_state(state: Dict[str, Any], name: str): temp_state_file = tempfile.NamedTemporaryFile( "w+b", dir=checkpoint_dir, delete=False, suffix=".pt" ) try: with Tqdm.wrapattr( temp_state_file, "write", desc=f"Saving {name} state", leave=False, disable=not self.train_config.is_local_main_process, ) as f: torch.save(state, f) temp_state_file.close() os.replace( temp_state_file.name, checkpoint_dir / f"worker{self.train_config.worker_id}_{name}.pt", ) finally: if os.path.exists(temp_state_file.name): os.remove(temp_state_file.name) save_state(self.get_model_state(), "model") save_state(self.optimizer.state_dict(), "optimizer"), if self.lr_scheduler is not None: save_state(self.lr_scheduler.state_dict(), "lr_scheduler") if self.grad_scaler is not None: save_state(self.grad_scaler.state_dict(), "grad_scaler") save_state(client_state, "trainer") def load_checkpoint(self, checkpoint_dir: Path) -> Dict[str, Any]: self.load_model_state( torch.load(checkpoint_dir / f"worker{self.train_config.worker_id}_model.pt") ) self.optimizer.load_state_dict( torch.load(checkpoint_dir / f"worker{self.train_config.worker_id}_optimizer.pt") ) if self.lr_scheduler is not None: self.lr_scheduler.load_state_dict( torch.load(checkpoint_dir / f"worker{self.train_config.worker_id}_lr_scheduler.pt") ) if self.grad_scaler is not None: self.grad_scaler.load_state_dict( torch.load(checkpoint_dir / f"worker{self.train_config.worker_id}_grad_scaler.pt") ) return torch.load(checkpoint_dir / f"worker{self.train_config.worker_id}_trainer.pt") def save_complete_weights_from_checkpoint( self, checkpoint_dir: Path, weights_path: Path ) -> None: os.link(checkpoint_dir.resolve() / "worker0_model.pt", weights_path)
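

# Illustration, not part of tango: the call pattern with which a trainer drives an
# engine, reduced to a bare loop. ``_example_training_loop`` and ``batches`` are
# hypothetical; the real :class:`TorchTrainStep` additionally handles validation,
# checkpointing, logging, and distributed synchronization.
def _example_training_loop(engine: TrainingEngine, batches) -> None:
    # ``batches`` is assumed to yield lists of micro-batches (dicts of model inputs).
    for micro_batches in batches:
        for idx, micro_batch in enumerate(micro_batches):
            # forward_train() zeroes gradients on the first micro-batch and returns
            # the loss already divided by the number of micro-batches.
            loss, _ = engine.forward_train(micro_batch, idx, len(micro_batches))
            # backward() routes through the GradScaler when AMP is enabled.
            engine.backward(loss)
        # step() unscales and clips gradients (if configured), steps the optimizer,
        # and then steps the LR scheduler.
        engine.step()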