⚡️ PyTorch Lightning#

Important

To use this integration you should install tango with the “pytorch_lightning” extra (e.g. pip install tango[pytorch_lightning]) or just install PyTorch Lightning after the fact.

Components for Tango integration with PyTorch Lightning.

These include a basic training loop Step and registrable versions of pytorch_lightning classes, such as LightningModule, Callback, etc.

Example: training a model#

Let’s look at a simple example of training a model.

We’ll make a very basic regression model and generate some fake data to train on. First, the setup:

import torch
import torch.nn as nn

from tango.common.dataset_dict import DatasetDict
from tango.step import Step
from tango.integrations.pytorch_lightning import LightningModule

Now let’s build and register our model:

@LightningModule.register("basic_regression")
class BasicRegression(LightningModule):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 1)
        self.sigmoid = nn.Sigmoid()
        self.mse = nn.MSELoss()

    def forward(self, x, y=None):
        pred = self.sigmoid(self.linear(x))
        out = {"pred": pred}
        if y is not None:
            out["loss"] = self.mse(pred, y)
        return out

    def _to_params(self):
        # This module takes no constructor arguments, so there is nothing to serialize.
        return {}

    def training_step(self, batch, batch_idx):
        outputs = self.forward(**batch)
        return outputs["loss"]

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        outputs = self.forward(**batch)
        return outputs

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
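
Before wiring this into a config, you could sanity-check the module directly on a couple of random tensors (a quick sketch using the class defined above):

model = BasicRegression()
output = model(torch.rand(4, 10), torch.rand(4, 1))
print(output["pred"].shape)  # torch.Size([4, 1])
print(output["loss"])        # scalar MSE loss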

Lastly, we’ll need a step to generate data:

@Step.register("generate_data")
class GenerateData(Step):
    DETERMINISTIC = True
    CACHEABLE = False

    def run(self) -> DatasetDict:
        torch.manual_seed(1)
        return DatasetDict(
            {
                "train": [{"x": torch.rand(10), "y": torch.rand(1)} for _ in range(64)],
                "validation": [{"x": torch.rand(10), "y": torch.rand(1)} for _ in range(32)],
            }
        )
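
Outside of a Tango run, you can call run() directly to see what this step produces (in a real experiment Tango executes the step for you and handles caching):

data = GenerateData().run()
print(len(data["train"]))           # 64 training examples
print(data["train"][0]["x"].shape)  # torch.Size([10])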

You could then run this experiment with a config that looks like this (note that the "dataset_dict" argument refers to the output of the "data" step via {"type": "ref", "ref": "data"}):

{
    "steps": {
        "data": {
            "type": "generate_data",
        },
        "train": {
            "type": "pytorch_lightning::train",
            "model": {
                "type": "basic_regression",
            },
            "trainer": {
                "type": "default",
                "max_epochs": 5,
                "log_every_n_steps": 3,
                "logger": [
                    {"type": "pytorch_lightning::TensorBoardLogger"},
                    {"type": "pytorch_lightning::CSVLogger"},
                ],
                "accelerator": "cpu",
                "profiler": {
                    "type": "pytorch_lightning::SimpleProfiler",
                },
            },
            "dataset_dict": {
                "type": "ref",
                "ref": "data"
            },
            "train_dataloader": {
                "batch_size": 8,
                "shuffle": true
            },
            "validation_split": "validation",
            "validation_dataloader": {
                "batch_size": 8,
                "shuffle": false
            },
        }
    }
}

For example, assuming the code above lives in a package called my_package (passed to tango run with -i so the registered components can be found):

tango run train.jsonnet -i my_package -d /tmp/train
Starting new run cool-crow
● Starting step "data" (needed by "train")...
✓ Finished step "data"
● Starting step "train"...
✓ Finished step "train"
✓ Finished run cool-crow
...

Tips#

PyTorch Lightning functionality#

You can use existing PyTorch Lightning callbacks, loggers, and profilers, which are registered under names of the form “pytorch_lightning::<class name>”. For example, the EarlyStopping callback is registered as “pytorch_lightning::EarlyStopping”.
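
For instance, to add early stopping you could extend the “trainer” part of the config above like this (a sketch; EarlyStopping needs a monitored metric, so your LightningModule would have to log something like “val_loss” with self.log()):

"trainer": {
    "type": "default",
    "max_epochs": 5,
    "callbacks": [
        {
            "type": "pytorch_lightning::EarlyStopping",
            "monitor": "val_loss",
            "patience": 3,
        },
    ],
}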

Reference#

Train step#

class tango.integrations.pytorch_lightning.LightningTrainStep(step_name=None, cache_results=None, step_format=None, step_config=None, step_unique_id_override=None, step_resources=None, step_metadata=None, step_extra_dependencies=None, **kwargs)[source]#

A step for training a model using PyTorch Lightning.

Tip

Registered as a Step under the name “pytorch_lightning::train”.

run(trainer, model, *, dataset_dict=None, train_dataloader=None, train_split='train', validation_dataloader=None, validation_split='validation', datamodule=None)[source]#

Run a basic training loop to train the model.

Parameters
  • trainer (Lazy[LightningTrainer]) – The lightning trainer object.

  • model (Union[Lazy[LightningModule], LightningModule]) – The lightning module to train.

  • dataset_dict (Optional[DatasetDict], default: None) – The train and optional validation data. This is ignored if the datamodule argument is provided.

  • train_dataloader (Optional[Lazy[DataLoader]], default: None) – The data loader that generates training batches. The batches should be dict objects. This is ignored if the datamodule argument is provided.

  • train_split (str, default: 'train') – The name of the data split used for training in the dataset_dict. Default is “train”. This is ignored if the datamodule argument is provided.

  • validation_split (str, default: 'validation') – Optional name of the validation split in the dataset_dict. Default is “validation”. Set this to None to skip validation. This is ignored if the datamodule argument is provided.

  • validation_dataloader (Optional[Lazy[DataLoader]], default: None) – An optional data loader for generating validation batches. The batches should be dict objects. If not specified, but validation_split is given, the validation DataLoader will be constructed from the same parameters as the train DataLoader. This is ignored if the datamodule argument is provided.

  • datamodule (Optional[LightningDataModule], default: None) – If a LightningDataModule object is given, the other data loading arguments are ignored.

Return type

Module

Returns

The trained model on CPU with the weights from the best checkpoint loaded.

CACHEABLE: Optional[bool] = True#

This provides a direct way to turn off caching. For example, a step that reads a HuggingFace dataset doesn’t need to be cached, because HuggingFace datasets already have their own caching mechanism. But it’s still a deterministic step, and all following steps are allowed to cache. If it is None, the step figures out by itself whether it should be cacheable or not.

DETERMINISTIC: bool = True#

This describes whether this step can be relied upon to produce the same results every time when given the same inputs. If this is False, you can still cache the output of the step, but the results might be unexpected. Tango will print a warning in this case.

FORMAT: Format = <tango.integrations.torch.format.TorchFormat object>#

This specifies the format the results of this step will be serialized in. See the documentation for Format for details.

METADATA: Dict[str, Any] = {'artifact_kind': 'model'}#

Arbitrary metadata about the step.

Trainer#

class tango.integrations.pytorch_lightning.LightningTrainer(work_dir, logger=None, callbacks=None, profiler=None, accelerator=None, strategy=None, plugins=None, **kwargs)[source]#

This is simply a Registrable version of the PyTorch Lightning Trainer.

Model#

class tango.integrations.pytorch_lightning.LightningModule(*args, **kwargs)[source]#

This is simply a Registrable version of the PyTorch Lightning LightningModule. It includes the following methods:

  • forward()

  • training_step()

  • validation_step()

  • test_step()

  • configure_optimizers()

Callback#

class tango.integrations.pytorch_lightning.LightningCallback[source]#

This is simply a Registrable version of the PyTorch Lightning Callback.
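
You can also register your own callbacks and refer to them by name in the trainer config. A minimal sketch (the class and registered name here are made up for illustration):

import pytorch_lightning as pl

from tango.integrations.pytorch_lightning import LightningCallback


@LightningCallback.register("print_epoch")
class PrintEpoch(LightningCallback):
    def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        # Called by the Lightning trainer at the end of every training epoch.
        print(f"Finished epoch {trainer.current_epoch}")

This callback could then be listed in the trainer’s “callbacks” as {"type": "print_epoch"}.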

Logger#

class tango.integrations.pytorch_lightning.LightningLogger(agg_key_funcs=None, agg_default_func=None)[source]#

This is simply a Registrable version of the PyTorch Lightning LightningLoggerBase.

Profiler#

class tango.integrations.pytorch_lightning.LightningProfiler(*args, **kwargs)[source]#

This is simply a Registrable version of the PyTorch Lightning BaseProfiler.

Accelerator#

class tango.integrations.pytorch_lightning.LightningAccelerator[source]#

This is simply a Registrable version of the PyTorch Lightning Accelerator.