import re
from typing import Optional, Set

import torch
import torch.nn as nn
from fairscale.nn.checkpoint import checkpoint_wrapper

from tango.integrations.torch import Model

from .fsdp_config import FSDPConfig

@Model.register("fairscale::with_wrapped_modules")  # type: ignore[arg-type]
def with_wrapped_modules(
    model: Model,
    modules_to_wrap: Set[str],
    fsdp_config: Optional[FSDPConfig] = None,
    activation_checkpointing: bool = False,
) -> Model:
    """
    A :class:`~tango.integrations.torch.Model` wrapper that can be used to easily wrap
    inner modules of a model with FairScale's :class:`~fairscale.nn.FullyShardedDataParallel`
    wrapper and/or :class:`~fairscale.nn.checkpoint.checkpoint_wrapper`.

    .. tip::
        Registered as a :class:`~tango.integrations.torch.Model` constructor under the name
        "fairscale::with_wrapped_modules".

    .. important::
        This is meant to be used with the :class:`FairScaleTrainingEngine`.

    :param model:
        The model to wrap. It is modified in place and also returned.
    :param modules_to_wrap:
        The names of the submodules to wrap. Each entry is a regular expression that must
        *fully* match (``re.fullmatch``) one or more dotted submodule names as reported by
        ``model.named_modules()``. Duplicate entries are ignored.
    :param fsdp_config:
        The ``FullyShardedDataParallel`` configuration to use when wrapping the modules.
        If not specified, the modules will NOT be wrapped with FSDP. FSDP wrapping is also
        skipped when ``torch.distributed`` has not been initialized.
    :param activation_checkpointing:
        Whether to wrap the modules with FairScale's
        :class:`~fairscale.nn.checkpoint.checkpoint_wrapper` (with ``offload_to_cpu=True``).

    :returns:
        The same ``model`` object, with the matched submodules replaced by their wrapped
        versions.

    :raises ValueError:
        If any pattern in ``modules_to_wrap`` does not match at least one submodule name.

    Examples
    --------

    You can use this as a :class:`~tango.integrations.torch.Model` constructor from a
    config/params like this:

    .. testcode::

        import torch.nn as nn

        from tango.integrations.torch import Model


        class FeedForward(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(4, 4)
                self.activation = nn.ReLU()

            def forward(self, x):
                return self.activation(self.linear(x))


        @Model.register("simple_regression_model")
        class SimpleRegressionModel(Model):
            def __init__(self):
                super().__init__()
                self.blocks = nn.Sequential(*[FeedForward() for _ in range(3)])
                self.regression_head = nn.Linear(4, 1)
                self.loss_fcn = nn.MSELoss()

            def forward(self, x, y):
                output = self.blocks(x)
                output = self.regression_head(output)
                loss = self.loss_fcn(output, y)
                return {"loss": loss}


        model = Model.from_params({
            "type": "fairscale::with_wrapped_modules",
            "model": {
                "type": "simple_regression_model",
            },
            "modules_to_wrap": [r"blocks\\.[0-9]+", "regression_head"],
            "activation_checkpointing": True,
        })

    """

    def wrap_module(module: nn.Module) -> nn.Module:
        # Apply activation checkpointing first so that, when FSDP is also requested,
        # the FSDP wrapper encloses the checkpointed module.
        if activation_checkpointing:
            module = checkpoint_wrapper(module, offload_to_cpu=True)
        if fsdp_config is not None and torch.distributed.is_initialized():
            module = fsdp_config.wrap(module)
        return module

    # Normalise to a set so duplicate patterns (e.g. from a list passed directly) cannot
    # leave a stale entry behind in ``unmatched_patterns``.
    patterns: Set[str] = set(modules_to_wrap)
    # Compile each pattern once instead of on every (module, pattern) pair.
    compiled_patterns = {pattern: re.compile(pattern) for pattern in patterns}

    actual_modules_to_wrap: Set[str] = set()
    unmatched_patterns: Set[str] = set(patterns)
    for module_name, _ in model.named_modules():
        if not module_name:
            # The root module is reported with an empty name; it is never wrapped here.
            continue
        for pattern, regex in compiled_patterns.items():
            if regex.fullmatch(module_name):
                actual_modules_to_wrap.add(module_name)
                unmatched_patterns.discard(pattern)

    if unmatched_patterns:
        raise ValueError(
            f"Some patterns in 'modules_to_wrap' did not match actual module names ({unmatched_patterns})"
        )

    # Wrap the most deeply nested modules first (ties broken by name for determinism).
    # ``wrap_module`` may return a *different* object (e.g. an FSDP wrapper), so if a
    # parent were wrapped before one of its matched descendants, the descendant's dotted
    # path would no longer resolve and it would not be wrapped inside-out.
    wrap_order = sorted(actual_modules_to_wrap, key=lambda name: (-name.count("."), name))
    for full_name in wrap_order:
        parent_name, _, child_name = full_name.rpartition(".")
        parent_module = model.get_submodule(parent_name) if parent_name else model
        child_module = parent_module.get_submodule(child_name)
        # ``add_module`` with an existing name replaces the registered child in place.
        parent_module.add_module(child_name, wrap_module(child_module))

    return model