🔥 FairScale


To use this integration you should install tango with the “fairscale” extra (e.g. pip install tango[fairscale]) or just install FairScale after the fact.

This integration also depends on PyTorch, so make sure you install the correct version of torch first given your operating system and supported CUDA version. Check pytorch.org/get-started/locally/ for more details.

Components for Tango integration with FairScale.


FairScale is a PyTorch library for large-scale training. Among other things, it implements the main memory-saving techniques for distributed data-parallel training (DDP) that came from the paper ZeRO: Memory Optimization Towards Training A Trillion Parameter Models.

The main part of this Tango integration is the FairScaleTrainingEngine. This is a TrainingEngine implementation that utilizes FairScale’s FullyShardedDataParallel (FSDP) for substantial memory savings during distributed training.

For the best performance you should also use with_wrapped_modules() to wrap the inner modules of your Model. When used with FSDP this will dramatically reduce the memory required to load your model.


class tango.integrations.fairscale.FairScaleTrainingEngine(train_config, model, optimizer, *, lr_scheduler=None, amp=False, max_grad_norm=None, amp_use_bfloat16=None, fsdp_config=None)

A TrainingEngine that leverages FairScale’s FullyShardedDataParallel for use within TorchTrainStep.


Registered as a TrainingEngine under the name “fairscale”.


To get the best performance out of FairScaleTrainingEngine you should wrap individual layers of your model with FullyShardedDataParallel and/or checkpoint_wrapper while instantiating them. You can use with_wrapped_modules() to accomplish this.


Only the parameters listed below should be defined in a configuration file. The other parameters will be automatically passed to the constructor within TorchTrainStep.


FairScaleTrainingEngine can only be used in distributed training, i.e. when device_count > 1 in the TorchTrainStep.
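As a rough illustration of that constraint, the sketch below checks a step configuration before it is handed to Tango. The helper check_fairscale_step and the key names "training_engine" and "device_count" are assumptions made for this example and are not part of the Tango API; only the engine name “fairscale” and the device_count > 1 rule come from the text above.

```python
from typing import Any, Dict


def check_fairscale_step(step_config: Dict[str, Any]) -> None:
    """Raise ValueError if a "fairscale" engine is configured for a single device.

    Hypothetical helper: the key names below are assumed for illustration.
    """
    engine = step_config.get("training_engine", {})
    device_count = step_config.get("device_count", 1)
    if engine.get("type") == "fairscale" and device_count <= 1:
        raise ValueError(
            "the 'fairscale' training engine needs device_count > 1, "
            f"got device_count={device_count}"
        )


# A fragment of a train step config; other required fields are omitted here.
step_config = {
    "device_count": 2,
    "training_engine": {"type": "fairscale", "amp": True},
}
check_fairscale_step(step_config)  # passes silently because device_count is 2
```

Running a check like this early gives a clear error message instead of a failure part-way through setting up the training run.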

For maximum memory savings, we recommend training with AMP enabled and the following FSDPConfig:

from tango.integrations.fairscale import FSDPConfig

fsdp_config = FSDPConfig(

For maximum training speed, we recommend training with AMP enabled and the following FSDPConfig:

from tango.integrations.fairscale import FSDPConfig

fsdp_config = FSDPConfig(

Parameters:

  • amp (bool, default: False) – Use automatic mixed precision (AMP).

  • max_grad_norm (Optional[float], default: None) – If set, gradients will be clipped to have this max norm.

  • amp_use_bfloat16 (Optional[bool], default: None) – Set to True to force using the bfloat16 datatype in mixed precision training. Only applicable when amp=True. If not specified, the default behavior will be to use bfloat16 when training with AMP on CPU, otherwise not.

  • fsdp_config (Optional[FSDPConfig], default: None) – The options for FullyShardedDataParallel. If not specified, the default options will be used.
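The default described for amp_use_bfloat16 can be read as a small decision rule. The function below is a hypothetical stand-in written for this page, not the engine's actual code; in particular, treating an explicit False as "never use bfloat16" and passing the device as a plain string are assumptions.

```python
from typing import Optional


def use_bfloat16(amp: bool, amp_use_bfloat16: Optional[bool], device_type: str) -> bool:
    """Sketch of the documented default for amp_use_bfloat16 (illustrative only)."""
    if not amp:
        # The option is only applicable when amp=True.
        return False
    if amp_use_bfloat16 is not None:
        # An explicit setting overrides the default (assumed for False as well).
        return amp_use_bfloat16
    # Unspecified: bfloat16 when training with AMP on CPU, otherwise not.
    return device_type == "cpu"


print(use_bfloat16(amp=True, amp_use_bfloat16=None, device_type="cuda"))  # False
print(use_bfloat16(amp=True, amp_use_bfloat16=True, device_type="cuda"))  # True
```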

class tango.integrations.fairscale.FSDPConfig(reshard_after_forward=True, move_params_to_cpu=False, move_grads_to_cpu=None, mixed_precision=False)

Defines all of the configurable options for FairScale’s FullyShardedDataParallel.

See also

Best practices for FullyShardedDataParallel from the FairScale docs.


Convert to the appropriate kwargs for FullyShardedDataParallel.

Return type:

Dict[str, Any]


A convenience method for wrapping a module in FullyShardedDataParallel with all of the options defined in this class.

See also

Internally this is what with_wrapped_modules() calls.

mixed_precision: bool = False

See the docstring for FullyShardedDataParallel.


We recommend setting this to the same value as the amp parameter in FairScaleTrainingEngine.

Based on our experiments, if you’re training with AMP enabled (amp=True), setting this to True may give a small additional speedup in training time and a small additional decrease in GPU memory utilization, with no performance penalty with respect to convergence. If you’re not training with AMP, setting this to True could impact the model’s ability to converge.
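One way to follow the recommendation to keep mixed_precision equal to amp is to derive both from a single flag. The helper fairscale_engine_config below is hypothetical and only builds a config fragment; a real training engine config would also need other fields, such as the optimizer, which are left out of this sketch.

```python
from typing import Any, Dict


def fairscale_engine_config(amp: bool) -> Dict[str, Any]:
    """Build a "fairscale" engine config fragment whose FSDP mixed_precision follows amp.

    Hypothetical helper for illustration; other engine fields are omitted.
    """
    return {
        "type": "fairscale",
        "amp": amp,
        # Keep FSDP mixed precision in sync with AMP, as recommended above.
        "fsdp_config": {"mixed_precision": amp},
    }


engine_config = fairscale_engine_config(amp=True)
print(engine_config["fsdp_config"]["mixed_precision"])  # True
```

Deriving both values from one flag rules out the combination flagged above as risky, where mixed_precision is True while AMP is off.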

move_grads_to_cpu: Optional[bool] = None

See the docstring for FullyShardedDataParallel.


At the moment we recommend that you don’t mess with this parameter, or only explicitly set it to the same value as move_params_to_cpu. If you leave it as None (the default), it will automatically be set to match move_params_to_cpu by FairScale.

Currently, training seems to crash if you set this to False while move_params_to_cpu is True. We’re tracking fairscale#918, which may be related.
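The interaction between the two options can be sketched as a small resolution function. This is an illustrative stand-in, not FairScale's or Tango's code: the name resolve_move_grads_to_cpu is made up here, and raising an error for the crashing combination is a choice made for this example rather than documented behavior.

```python
from typing import Optional


def resolve_move_grads_to_cpu(
    move_params_to_cpu: bool, move_grads_to_cpu: Optional[bool]
) -> bool:
    """Return the effective move_grads_to_cpu value (illustrative sketch)."""
    if move_grads_to_cpu is None:
        # Left at the default: follow move_params_to_cpu.
        return move_params_to_cpu
    if move_params_to_cpu and not move_grads_to_cpu:
        # The combination reported above to crash training; reject it early.
        raise ValueError(
            "move_grads_to_cpu=False with move_params_to_cpu=True is not supported"
        )
    return move_grads_to_cpu


print(resolve_move_grads_to_cpu(move_params_to_cpu=True, move_grads_to_cpu=None))  # True
```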

move_params_to_cpu: bool = False

See the docstring for FullyShardedDataParallel.

reshard_after_forward: bool = True

See the docstring for FullyShardedDataParallel.

tango.integrations.fairscale.with_wrapped_modules(model, modules_to_wrap, fsdp_config=None, activation_checkpointing=False)

A Model wrapper that can be used to easily wrap inner modules of a model with FairScale’s FullyShardedDataParallel wrapper and/or checkpoint_wrapper.


Registered as a Model constructor under the name “fairscale::with_wrapped_modules”.


This is meant to be used with the FairScaleTrainingEngine.

Parameters:

  • model (Model) – The model to wrap.

  • modules_to_wrap (Set[str]) – The names of the submodules to wrap. Can be regular expressions.

  • fsdp_config (Optional[FSDPConfig], default: None) – The FullyShardedDataParallel configuration to use when wrapping the modules. If not specified, the modules will NOT be wrapped with FSDP.

  • activation_checkpointing (bool, default: False) – Whether to wrap the modules with FairScale’s checkpoint_wrapper.
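To see which submodules a set of modules_to_wrap patterns would select, it can help to try the patterns against dotted module names before training. The sketch below uses only the standard library and assumes each pattern must match the full dotted name (re.fullmatch); whether the integration matches this way is an assumption, and select_modules is a hypothetical helper.

```python
import re
from typing import Iterable, List, Set


def select_modules(module_names: Iterable[str], modules_to_wrap: Set[str]) -> List[str]:
    """Return the names that fully match at least one pattern, in input order."""
    return [
        name
        for name in module_names
        if any(re.fullmatch(pattern, name) for pattern in modules_to_wrap)
    ]


# Dotted names in the style reported by torch's named_modules() for a model
# with three blocks and a regression head.
names = [
    "blocks",
    "blocks.0",
    "blocks.0.linear",
    "blocks.1",
    "blocks.2",
    "regression_head",
]
print(select_modules(names, {r"blocks\.[0-9]+", "regression_head"}))
# ['blocks.0', 'blocks.1', 'blocks.2', 'regression_head']
```

Under a substring search (re.search) the same patterns would also select "blocks.0.linear", so the matching rule matters when nested modules share a prefix.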

Return type:



You can use this as a Model constructor from a config/params like this:

import torch.nn as nn
from tango.integrations.torch import Model

class FeedForward(nn.Module):
    def __init__(self):
        super().__init__()  # required before assigning submodules
        self.linear = nn.Linear(4, 4)
        self.activation = nn.ReLU()

    def forward(self, x):
        return self.activation(self.linear(x))

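# Register the model so the "simple_regression_model" type used below can be resolved.
@Model.register("simple_regression_model")
class SimpleRegressionModel(Model):
    def __init__(self):
        super().__init__()  # required before assigning submodules
        self.blocks = nn.Sequential(*[FeedForward() for _ in range(3)])
        self.regression_head = nn.Linear(4, 1)
        self.loss_fcn = nn.MSELoss()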

    def forward(self, x, y):
        output = self.blocks(x)
        output = self.regression_head(output)
        loss = self.loss_fcn(output, y)
        return {"loss": loss}

model = Model.from_params({
    "type": "fairscale::with_wrapped_modules",
    "model": {
        "type": "simple_regression_model",
    },
    "modules_to_wrap": [r"blocks\.[0-9]+", "regression_head"],
    "activation_checkpointing": True,
})