# Source code for tango.integrations.fairscale.fsdp_config

from dataclasses import asdict, dataclass
from typing import Any, Dict, Optional

import torch
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

from tango.common import FromParams

@dataclass
class FSDPConfig(FromParams):
    """
    Defines all of the configurable options for FairScale's
    :class:`~fairscale.nn.FullyShardedDataParallel`.

    Each field maps one-to-one onto a keyword argument of
    :class:`~fairscale.nn.FullyShardedDataParallel` (see :meth:`as_kwargs`).

    .. seealso::
        The "Best practices for FullyShardedDataParallel" section of the
        `FairScale documentation <https://fairscale.readthedocs.io/>`_.
    """

    reshard_after_forward: bool = True
    """
    See the docstring for :class:`~fairscale.nn.FullyShardedDataParallel`.
    """

    move_params_to_cpu: bool = False
    """
    See the docstring for :class:`~fairscale.nn.FullyShardedDataParallel`.
    """

    move_grads_to_cpu: Optional[bool] = None
    """
    See the docstring for :class:`~fairscale.nn.FullyShardedDataParallel`.

    .. seealso::
        :data:`move_params_to_cpu`

    .. warning::
        At the moment we recommend that you don't mess with this parameter, or only explicitly
        set it to the same value as :data:`move_params_to_cpu`. If you leave it as ``None``
        (the default), it will automatically be set to match :data:`move_params_to_cpu` by FairScale.

        Currently training seems to crash if you set this ``False`` while :data:`move_params_to_cpu` is ``True``.
        We're tracking `fairscale#918 <https://github.com/facebookresearch/fairscale/issues/918>`_,
        which may be related.
    """

    mixed_precision: bool = False
    """
    See the docstring for :class:`~fairscale.nn.FullyShardedDataParallel`.

    .. important::
        We recommend setting this to the same value as the ``amp`` parameter in
        :class:`FairScaleTrainingEngine`.

        Based on our experiments, if you're training with AMP enabled (``amp=True``)
        you might see a small additional speedup in training time along with a small
        additional decrease in GPU memory utilization without any performance penalty
        (with respect to convergence) by setting this to ``True``.
        But if you're *not* training with AMP, setting this ``True`` could impact the
        model's ability to converge.
    """

    def as_kwargs(self) -> Dict[str, Any]:
        """
        Convert to the appropriate ``kwargs`` for :class:`~fairscale.nn.FullyShardedDataParallel`.

        :returns: A new ``dict`` mapping every dataclass field name to its current value,
            built with :func:`dataclasses.asdict`, so mutating it does not affect this config.
        """
        return asdict(self)

    def wrap(self, module: torch.nn.Module) -> FSDP:
        """
        A convenience method for wrapping a module in :class:`~fairscale.nn.FullyShardedDataParallel`
        with all of the options defined in this class.

        :param module: The module to wrap.
        :returns: A new :class:`~fairscale.nn.FullyShardedDataParallel` instance constructed
            from ``module`` and :meth:`as_kwargs`.

        .. seealso::
            Internally this is what :func:`with_wrapped_modules()` calls.
        """
        return FSDP(module, **self.as_kwargs())