FairScale#
Important
To use this integration you should install tango with the "fairscale" extra (e.g. pip install tango[fairscale]) or just install FairScale after the fact.
This integration also depends on PyTorch, so make sure you install the correct version of torch first given your operating system and supported CUDA version. Check pytorch.org/get-started/locally/ for more details.
Components for Tango integration with FairScale.
Overview#
FairScale is a PyTorch library for large-scale training. Among other things, it implements the main memory-savings techniques for distributed data-parallel training (DDP) from the paper ZeRO: Memory Optimizations Toward Training Trillion Parameter Models.
The main part of this Tango integration is the FairScaleTrainingEngine. This is a TrainingEngine implementation that utilizes FairScale's FullyShardedDataParallel (FSDP) for substantial memory savings during distributed training.
For the best performance you should also use with_wrapped_modules() to wrap the inner modules of your Model. When used with FSDP this will dramatically reduce the memory required to load your model.
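To give a rough sense of how these pieces fit together, here is a hedged sketch of the relevant portion of an experiment's params. The "torch::train" step layout and the inner model name are assumptions used only for illustration; the "fairscale" and "fairscale::with_wrapped_modules" names are the registered names from this integration:

train_step = {
    "type": "torch::train",  # assumed: the torch integration's training step
    "model": {
        "type": "fairscale::with_wrapped_modules",
        "model": {"type": "my_model"},  # hypothetical registered Model
        "modules_to_wrap": [r"blocks\.[0-9]+"],  # regular expressions over submodule names
        "activation_checkpointing": True,
    },
    "training_engine": {
        "type": "fairscale",
        "amp": True,
    },
}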
Reference#
- class tango.integrations.fairscale.FairScaleTrainingEngine(train_config, model, optimizer, *, lr_scheduler=None, amp=False, max_grad_norm=None, amp_use_bfloat16=None, fsdp_config=None)[source]#
A TrainingEngine that leverages FairScale's FullyShardedDataParallel for use within TorchTrainStep.

Tip

Registered as a TrainingEngine under the name "fairscale".

Tip

To get the best performance out of FairScaleTrainingEngine you should wrap individual layers of your model with FullyShardedDataParallel and/or checkpoint_wrapper while instantiating them. You can use with_wrapped_modules() to accomplish this.

Important

Only the parameters listed below should be defined in a configuration file. The other parameters will be automatically passed to the constructor within TorchTrainStep (see the configuration sketch after the parameter list).

Warning

FairScaleTrainingEngine can only be used in distributed training, i.e. when device_count > 1 in the TorchTrainStep.

For maximum memory savings, we recommend training with AMP enabled and the following FSDPConfig:

from tango.integrations.fairscale import FSDPConfig

fsdp_config = FSDPConfig(
    reshard_after_forward=True,
    move_params_to_cpu=True,
    move_grads_to_cpu=True,
    mixed_precision=True,
)

For maximum training speed, we recommend training with AMP enabled and the following FSDPConfig:

from tango.integrations.fairscale import FSDPConfig

fsdp_config = FSDPConfig(
    reshard_after_forward=False,
    move_params_to_cpu=False,
    move_grads_to_cpu=False,
    mixed_precision=True,
)
- Parameters:
  - amp (bool, default: False) – Use automatic mixed precision (AMP). Default is False.
  - max_grad_norm (Optional[float], default: None) – If set, gradients will be clipped to have this max norm. Default is None.
  - amp_use_bfloat16 (Optional[bool], default: None) – Set to True to force using the bfloat16 datatype in mixed precision training. Only applicable when amp=True. If not specified, the default behavior will be to use bfloat16 when training with AMP on CPU, otherwise not.
  - fsdp_config (Optional[FSDPConfig], default: None) – The options for FullyShardedDataParallel. If not specified, the default options will be used.
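Following the note above that only these parameters belong in a configuration file, a sketch of what a "fairscale" training-engine entry might look like (the values are illustrative, not recommendations beyond those given above):

training_engine = {
    "type": "fairscale",
    "amp": True,
    "max_grad_norm": 1.0,
    "fsdp_config": {
        "reshard_after_forward": True,
        "move_params_to_cpu": True,
        "move_grads_to_cpu": True,
        "mixed_precision": True,
    },
}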
- class tango.integrations.fairscale.FSDPConfig(reshard_after_forward=True, move_params_to_cpu=False, move_grads_to_cpu=None, mixed_precision=False)[source]#
Defines all of the configurable options for FairScale's FullyShardedDataParallel.

See also

Best practices for FullyShardedDataParallel from the FairScale docs.
- as_kwargs()[source]#
Convert to the appropriate kwargs for FullyShardedDataParallel.
- wrap(module)[source]#
A convenience method for wrapping a module in FullyShardedDataParallel with all of the options defined in this class (see the sketch after the attribute list below).

See also

Internally this is what with_wrapped_modules() calls.
- mixed_precision: bool = False#

See the docstring for FullyShardedDataParallel.

Important

We recommend setting this to the same value as the amp parameter in FairScaleTrainingEngine.

Based on our experiments, if you're training with AMP enabled (amp=True) you might see a small additional speedup in training time, along with a small additional decrease in GPU memory utilization, without any performance penalty (with respect to convergence) by setting this to True. But if you're not training with AMP, setting this to True could impact the model's ability to converge.
- move_grads_to_cpu: Optional[bool] = None#

See the docstring for FullyShardedDataParallel.

Warning

At the moment we recommend that you don't mess with this parameter, or only explicitly set it to the same value as move_params_to_cpu. If you leave it as None (the default), it will automatically be set to match move_params_to_cpu by FairScale.

Currently training seems to crash if you set this to False while move_params_to_cpu is True. We're tracking fairscale#918, which may be related.
- move_params_to_cpu: bool = False#

See the docstring for FullyShardedDataParallel.
- reshard_after_forward: bool = True#

See the docstring for FullyShardedDataParallel.
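To make the relationship between wrap() and as_kwargs() concrete, here is a small sketch (the module is a placeholder, and it assumes a distributed process group has already been initialized, e.g. inside TorchTrainStep):

import torch.nn as nn
from fairscale.nn import FullyShardedDataParallel

from tango.integrations.fairscale import FSDPConfig

fsdp_config = FSDPConfig(reshard_after_forward=True, mixed_precision=True)

layer = nn.Linear(4, 4)  # placeholder module
wrapped = fsdp_config.wrap(layer)

# Equivalent to constructing FSDP directly with the config's kwargs:
also_wrapped = FullyShardedDataParallel(layer, **fsdp_config.as_kwargs())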
- tango.integrations.fairscale.with_wrapped_modules(model, modules_to_wrap, fsdp_config=None, activation_checkpointing=False)[source]#
A Model wrapper that can be used to easily wrap inner modules of a model with FairScale's FullyShardedDataParallel wrapper and/or checkpoint_wrapper.

Tip

Registered as a Model constructor under the name "fairscale::with_wrapped_modules".

Important

This is meant to be used with the FairScaleTrainingEngine.

- Parameters:
  - model (Model) – The model to wrap.
  - modules_to_wrap (Set[str]) – The names of the submodules to wrap. Can be regular expressions.
  - fsdp_config (Optional[FSDPConfig], default: None) – The FullyShardedDataParallel configuration to use when wrapping the modules. If not specified, the modules will NOT be wrapped with FSDP.
  - activation_checkpointing (bool, default: False) – Whether to wrap the modules with FairScale's checkpoint_wrapper.
- Return type:
  Model
Examples
You can use this as a Model constructor from a config/params like this:

import torch.nn as nn

from tango.integrations.torch import Model


class FeedForward(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.activation = nn.ReLU()

    def forward(self, x):
        return self.activation(self.linear(x))


@Model.register("simple_regression_model")
class SimpleRegressionModel(Model):
    def __init__(self):
        super().__init__()
        self.blocks = nn.Sequential(*[FeedForward() for _ in range(3)])
        self.regression_head = nn.Linear(4, 1)
        self.loss_fcn = nn.MSELoss()

    def forward(self, x, y):
        output = self.blocks(x)
        output = self.regression_head(output)
        loss = self.loss_fcn(output, y)
        return {"loss": loss}


model = Model.from_params({
    "type": "fairscale::with_wrapped_modules",
    "model": {
        "type": "simple_regression_model",
    },
    "modules_to_wrap": [r"blocks\.[0-9]+", "regression_head"],
    "activation_checkpointing": True,
})