⚡️ PyTorch Lightning#
Important

To use this integration you should install tango with the “pytorch_lightning” extra (e.g. pip install tango[pytorch_lightning]) or just install PyTorch Lightning after the fact.
Components for Tango integration with PyTorch Lightning. These include a basic training loop Step and registrable versions of pytorch_lightning classes, such as LightningModule, Callback, etc.
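Because these classes are all Registrable, you can inspect what's available in the registry. A quick sketch, assuming the integration's built-in registrations:

from tango.integrations.pytorch_lightning import LightningCallback

# Registrable classes expose list_available(); the built-in Lightning
# callbacks show up under "pytorch_lightning::<ClassName>" names.
print(LightningCallback.list_available())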
Example: training a model#
Let’s look at a simple example of training a model.
We’ll make a very basic regression model and generate some fake data to train on. First, the setup:
import torch
import torch.nn as nn
from tango.common.dataset_dict import DatasetDict
from tango.step import Step
from tango.integrations.pytorch_lightning import LightningModule
Now let’s build and register our model:
@LightningModule.register("basic_regression")
class BasicRegression(LightningModule):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 1)
        self.sigmoid = nn.Sigmoid()
        self.mse = nn.MSELoss()

    def forward(self, x, y=None):
        pred = self.sigmoid(self.linear(x))
        out = {"pred": pred}
        if y is not None:
            out["loss"] = self.mse(pred, y)
        return out

    def _to_params(self):
        return {}

    def training_step(self, batch, batch_idx):
        outputs = self.forward(**batch)
        return outputs["loss"]

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        outputs = self.forward(**batch)
        return outputs

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
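Since the model is now registered under the name “basic_regression”, Tango can construct it by name from a config. Just to illustrate the registry mechanism by hand (using the imports above):

# Look up the registered class by name and instantiate it directly.
model_cls = LightningModule.by_name("basic_regression")
model = model_cls()
print(model(torch.rand(10))["pred"])  # a single sigmoid-activated prediction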
Lastly, we’ll need a step to generate data:
@Step.register("generate_data")
class GenerateData(Step):
    DETERMINISTIC = True
    CACHEABLE = False

    def run(self) -> DatasetDict:
        torch.manual_seed(1)
        return DatasetDict(
            {
                "train": [{"x": torch.rand(10), "y": torch.rand(1)} for _ in range(64)],
                "validation": [{"x": torch.rand(10), "y": torch.rand(1)} for _ in range(32)],
            }
        )
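You can sanity-check this step locally by calling its run() method yourself, bypassing Tango's caching machinery. A minimal sketch:

data = GenerateData().run()
print(len(data["train"]), len(data["validation"]))  # -> 64 32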
You could then run this experiment with a config that looks like this:
{
    "steps": {
        "data": {
            "type": "generate_data",
        },
        "train": {
            "type": "pytorch_lightning::train",
            "model": {
                "type": "basic_regression",
            },
            "trainer": {
                "type": "default",
                "max_epochs": 5,
                "log_every_n_steps": 3,
                "logger": [
                    {"type": "pytorch_lightning::TensorBoardLogger"},
                    {"type": "pytorch_lightning::CSVLogger"},
                ],
                "accelerator": "cpu",
                "profiler": {
                    "type": "pytorch_lightning::SimpleProfiler",
                },
            },
            "dataset_dict": {
                "type": "ref",
                "ref": "data"
            },
            "train_dataloader": {
                "batch_size": 8,
                "shuffle": true
            },
            "validation_split": "validation",
            "validation_dataloader": {
                "batch_size": 8,
                "shuffle": false
            },
        }
    }
}
For example:
tango run train.jsonnet -i my_package -d /tmp/train
Starting new run cool-crow
● Starting step "data" (needed by "train")...
✓ Finished step "data"
● Starting step "train"...
✓ Finished step "train"
✓ Finished run cool-crow
...
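Once the run completes, the trained model can be pulled back out of the workspace in Python. A rough sketch, assuming the run name above and the LocalWorkspace API:

from tango.workspaces import LocalWorkspace

workspace = LocalWorkspace("/tmp/train")
# Fetch the cached result of the "train" step from the run above.
model = workspace.step_result_for_run("cool-crow", "train")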
Tips#
PyTorch Lightning functionality#
You can use existing PyTorch Lightning callbacks, loggers, and profilers, which are registered as “pytorch_lightning::name of callback, logger, or profiler”. For example, the EarlyStopping callback is registered as “pytorch_lightning::EarlyStopping”.
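You can also register your own implementations the same way. Here is a minimal sketch of a custom callback (the hook signature follows the standard PyTorch Lightning Callback API; the name “print_epoch” is just for illustration):

from tango.integrations.pytorch_lightning import LightningCallback

@LightningCallback.register("print_epoch")
class PrintEpoch(LightningCallback):
    # Standard Lightning hook, invoked at the end of every training epoch.
    def on_train_epoch_end(self, trainer, pl_module):
        print(f"finished epoch {trainer.current_epoch}")

You could then reference it in the trainer's "callbacks" list in your config as {"type": "print_epoch"}.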
Reference#
Train step#
- class tango.integrations.pytorch_lightning.LightningTrainStep(step_name=None, cache_results=None, step_format=None, step_config=None, step_unique_id_override=None, step_resources=None, **kwargs)[source]#
A step for training a model using PyTorch Lightning.
Tip

Registered as a Step under the name “pytorch_lightning::train”.

- run(trainer, model, *, dataset_dict=None, train_dataloader=None, train_split='train', validation_dataloader=None, validation_split='validation', datamodule=None)[source]#
Run a basic training loop to train the model.

- Parameters
  - trainer (Lazy[LightningTrainer]) – The lightning trainer object.
  - model (LightningModule) – The lightning module to train.
  - dataset_dict (Optional[DatasetDict], default: None) – The train and optional validation data. This is ignored if the datamodule argument is provided.
  - train_dataloader (Optional[Lazy[DataLoader]], default: None) – The data loader that generates training batches. The batches should be dict objects. This is ignored if the datamodule argument is provided.
  - train_split (str, default: 'train') – The name of the data split used for training in the dataset_dict. Default is “train”. This is ignored if the datamodule argument is provided.
  - validation_split (str, default: 'validation') – Optional name of the validation split in the dataset_dict. Default is “validation”. This is ignored if the datamodule argument is provided.
  - validation_dataloader (Optional[Lazy[DataLoader]], default: None) – An optional data loader for generating validation batches. The batches should be dict objects. If not specified, but validation_split is given, the validation DataLoader will be constructed from the same parameters as the train DataLoader. This is ignored if the datamodule argument is provided.
  - datamodule (Optional[LightningDataModule], default: None) – If a LightningDataModule object is given, the other data loading arguments are ignored.
- Return type
  Module
- Returns
  The trained model on CPU with the weights from the best checkpoint loaded.
- CACHEABLE: Optional[bool] = True#
This provides a direct way to turn off caching. For example, a step that reads a HuggingFace dataset doesn’t need to be cached, because HuggingFace datasets already have their own caching mechanism. But it’s still a deterministic step, and all following steps are allowed to cache. If it is None, the step figures out by itself whether it should be cacheable or not.
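As an illustration of that HuggingFace scenario, a hypothetical step might look like this (the datasets usage and names here are assumptions for illustration, not part of this integration):

import datasets
from tango.step import Step

@Step.register("load_hf_dataset")
class LoadHFDataset(Step):
    DETERMINISTIC = True
    # The datasets library caches downloads itself, so Tango's cache is redundant.
    CACHEABLE = False

    def run(self, path: str) -> datasets.DatasetDict:
        return datasets.load_dataset(path)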
Trainer#
- class tango.integrations.pytorch_lightning.LightningTrainer(work_dir, logger=None, callbacks=None, profiler=None, accelerator=None, strategy=None, plugins=None, **kwargs)[source]#
This is simply a Registrable version of the PyTorch Lightning Trainer.
Model#
- class tango.integrations.pytorch_lightning.LightningModule(*args, **kwargs)[source]#
This is simply a Registrable version of the PyTorch Lightning LightningModule. It includes the following methods:
- forward()
- training_step()
- validation_step()
- test_step()
- configure_optimizers()
Callback#
- class tango.integrations.pytorch_lightning.LightningCallback[source]#
This is simply a Registrable version of the PyTorch Lightning Callback.
Logger#
- class tango.integrations.pytorch_lightning.LightningLogger(agg_key_funcs=None, agg_default_func=None)[source]#
This is simply a Registrable version of the PyTorch Lightning LightningLoggerBase.
Profiler#
- class tango.integrations.pytorch_lightning.LightningProfiler(*args, **kwargs)[source]#
This is simply a Registrable version of the PyTorch Lightning BaseProfiler.
Accelerator#
- class tango.integrations.pytorch_lightning.LightningAccelerator[source]#
This is simply a Registrable version of the PyTorch Lightning Accelerator.