import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import torch
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.optim.grad_scaler import ShardedGradScaler

from tango.common import Lazy
from tango.common.exceptions import ConfigurationError
from tango.integrations.torch import (
    LRScheduler,
    Model,
    Optimizer,
    TorchTrainingEngine,
    TrainConfig,
    TrainingEngine,
)

from .fsdp_config import FSDPConfig


@TrainingEngine.register("fairscale")
class FairScaleTrainingEngine(TorchTrainingEngine):
    """
    A :class:`~tango.integrations.torch.TrainingEngine` that leverages FairScale's
    :class:`~fairscale.nn.FullyShardedDataParallel` for use within
    :class:`~tango.integrations.torch.TorchTrainStep`.

    .. tip::
        Registered as an :class:`~tango.integrations.torch.TrainingEngine` under the name
        "fairscale".

    .. tip::
        To get the best performance out of :class:`FairScaleTrainingEngine` you should
        wrap individual layers of your model with
        :class:`~fairscale.nn.FullyShardedDataParallel` and/or
        :class:`~fairscale.nn.checkpoint.checkpoint_wrapper` while instantiating them.
        You can use :func:`with_wrapped_modules()` to accomplish this.

    .. important::
        Only the parameters listed below should be defined in a configuration file.
        The other parameters will be automatically passed to the constructor within
        :class:`~tango.integrations.torch.TorchTrainStep`.

    .. warning::
        :class:`~FairScaleTrainingEngine` can only be used in distributed training, i.e.
        when ``device_count > 1`` in the :class:`~tango.integrations.torch.TorchTrainStep`.

    For maximum memory savings, we recommend training with AMP enabled and the following
    :class:`FSDPConfig`:

    .. testcode::

        from tango.integrations.fairscale import FSDPConfig

        fsdp_config = FSDPConfig(
            reshard_after_forward=True,
            move_params_to_cpu=True,
            move_grads_to_cpu=True,
            mixed_precision=True,
        )

    For maximum training *speed*, we recommend training with AMP enabled and the following
    :class:`FSDPConfig`:

    .. testcode::

        from tango.integrations.fairscale import FSDPConfig

        fsdp_config = FSDPConfig(
            reshard_after_forward=False,
            move_params_to_cpu=False,
            move_grads_to_cpu=False,
            mixed_precision=True,
        )

    :param amp:
        Use automatic mixed precision (AMP). Default is ``False``.
    :param max_grad_norm:
        If set, gradients will be clipped to have this max norm. Default is ``None``.
    :param amp_use_bfloat16:
        Set to ``True`` to force using the ``bfloat16`` datatype in mixed precision training.
        Only applicable when ``amp=True``. If not specified, the default behavior will be
        to use ``bfloat16`` when training with AMP on CPU, otherwise not.
    :param fsdp_config:
        The options for :class:`~fairscale.nn.FullyShardedDataParallel`.
        If not specified, the default options will be used.

    :raises ConfigurationError:
        If ``train_config`` is not set up for distributed training.
    """

    # File-name pattern of the per-worker model shards consolidated by
    # ``save_complete_weights_from_checkpoint``.
    # NOTE(review): assumes the parent engine writes shards as ``worker{N}_model.pt``
    # (other per-worker files, e.g. optimizer state, share the ``worker`` prefix) — confirm.
    _MODEL_SHARD_PREFIX = "worker"
    _MODEL_SHARD_SUFFIX = "_model.pt"
    _MODEL_SHARD_GLOB = _MODEL_SHARD_PREFIX + "*" + _MODEL_SHARD_SUFFIX

    def __init__(
        self,
        train_config: TrainConfig,
        model: Lazy[Model],
        optimizer: Lazy[Optimizer],
        *,
        lr_scheduler: Optional[Lazy[LRScheduler]] = None,
        amp: bool = False,
        max_grad_norm: Optional[float] = None,
        amp_use_bfloat16: Optional[bool] = None,
        fsdp_config: Optional[FSDPConfig] = None,
    ) -> None:
        if not train_config.is_distributed:
            raise ConfigurationError(
                f"{self.__class__.__name__} can only be used with distributed training"
            )

        # Set before calling the parent constructor: ``_construct_model()`` reads
        # ``self.fsdp_config`` and is expected to run during the parent's ``__init__``.
        self.fsdp_config = fsdp_config or FSDPConfig()
        self.logger = logging.getLogger(self.__class__.__name__)

        super().__init__(
            train_config,
            model,
            optimizer,
            lr_scheduler=lr_scheduler,
            amp=amp,
            max_grad_norm=max_grad_norm,
            amp_use_bfloat16=amp_use_bfloat16,
        )

        if amp:
            # Parameters/gradients are sharded by FSDP, so use FairScale's shard-aware
            # scaler in place of whatever scaler the parent constructor set up.
            self.grad_scaler = ShardedGradScaler()

    def _construct_model(self, model: Union[Model, Lazy[Model]]) -> Model:
        """
        Build ``model`` (if lazy) and wrap it in :class:`FullyShardedDataParallel`.

        When parameters are *not* being offloaded to CPU, the model is first moved to
        this worker's default device. The FSDP wrapper is always returned.
        """
        if isinstance(model, Lazy):
            model = model.construct()
        if not self.fsdp_config.move_params_to_cpu:
            model.to(self.train_config.worker_local_default_device)
        return FSDP(model, **self.fsdp_config.as_kwargs())

    def clip_grad_norm(self) -> None:
        """
        Clip gradients to ``max_grad_norm``, if it is set.

        Delegates to the FSDP wrapper's ``clip_grad_norm_`` because the gradients are
        sharded across workers.
        """
        if self.max_grad_norm is not None:
            self.model.clip_grad_norm_(self.max_grad_norm)  # type: ignore

    def get_model_state(self) -> Dict[str, torch.Tensor]:
        """
        Return this worker's *local* (sharded) model state.

        The result holds the shard under ``"weights"`` and the FSDP shard metadata under
        ``"metadata"``; both are needed by :meth:`save_complete_weights_from_checkpoint`
        to rebuild the full weights.
        """
        return {
            "weights": self.model.local_state_dict(),  # type: ignore
            "metadata": self.model.local_metadata_dict(),  # type: ignore
        }

    def load_model_state(self, state_dict: Dict[str, torch.Tensor]) -> None:
        """Load a local (sharded) state as produced by :meth:`get_model_state`."""
        self.model.load_local_state_dict(state_dict["weights"])  # type: ignore

    def save_complete_weights_from_checkpoint(
        self, checkpoint_dir: Path, weights_path: Path
    ) -> None:
        """
        Consolidate the per-worker sharded model files in ``checkpoint_dir`` into one
        complete state dict and save it to ``weights_path``.

        Shards are loaded onto CPU and passed to FSDP in worker-id order, which is the
        order consolidation expects.

        :raises FileNotFoundError: If no model shards are found in ``checkpoint_dir``.
        """
        self.logger.info("Consolidating sharded checkpoint weights...")

        # ``Path.glob`` yields files in arbitrary order; sort numerically by worker id
        # (a plain string sort would put "worker10" before "worker2").
        shard_paths = sorted(
            checkpoint_dir.resolve().glob(self._MODEL_SHARD_GLOB),
            key=self._worker_id_from_shard_path,
        )
        if not shard_paths:
            raise FileNotFoundError(
                f"No model shards matching '{self._MODEL_SHARD_GLOB}' found in "
                f"{checkpoint_dir}"
            )

        sharded_weights: List[Dict[str, torch.Tensor]] = []
        sharded_metadata: List[Dict[str, Any]] = []
        for path in shard_paths:
            sharded_state = torch.load(path, map_location="cpu")
            sharded_weights.append(sharded_state["weights"])
            sharded_metadata.append(sharded_state["metadata"])

        full_state = FSDP.consolidate_shard_weights(sharded_weights, sharded_metadata)
        # Drop references to the shards before serializing the consolidated state.
        del sharded_weights
        del sharded_metadata
        torch.save(full_state, weights_path)

    @classmethod
    def _worker_id_from_shard_path(cls, path: Path) -> int:
        """Extract ``N`` from a ``worker{N}_model.pt`` shard file name."""
        name = path.name
        return int(name[len(cls._MODEL_SHARD_PREFIX) : len(name) - len(cls._MODEL_SHARD_SUFFIX)])