🔥 PyTorch#
Important
To use this integration you should install tango
with the “torch” extra
(e.g. pip install tango[torch]
) or just install PyTorch after the fact.
Make sure you install the correct version of torch given your operating system and supported CUDA version. Check pytorch.org/get-started/locally/ for more details.
Components for Tango integration with PyTorch.
These include a training loop Step and registrable versions of many torch classes, such as torch.optim.Optimizer and torch.utils.data.DataLoader.
Example: training a model#
Let’s look at a simple example of training a model.
We’ll make a basic regression model and generate some fake data to train on. First, the setup:
import torch
import torch.nn as nn
from tango.common.dataset_dict import DatasetDict
from tango.step import Step
from tango.integrations.torch import Model
Now let’s build and register our model:
@Model.register("basic_regression")
class BasicRegression(Model):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 1)
        self.sigmoid = nn.Sigmoid()
        self.mse = nn.MSELoss()

    def forward(self, x, y=None):
        pred = self.sigmoid(self.linear(x))
        out = {"pred": pred}
        if y is not None:
            out["loss"] = self.mse(pred, y)
        return out

    def _to_params(self):
        return {}
Lastly, we’ll need a step to generate data:
@Step.register("generate_data")
class GenerateData(Step):
    DETERMINISTIC = True
    CACHEABLE = False

    def run(self) -> DatasetDict:
        torch.manual_seed(1)
        return DatasetDict(
            {
                "train": [{"x": torch.rand(10), "y": torch.rand(1)} for _ in range(64)],
                "validation": [{"x": torch.rand(10), "y": torch.rand(1)} for _ in range(32)],
            }
        )
You could then run this experiment with a config that looks like this:
{
    "steps": {
        "data": {
            "type": "generate_data",
        },
        "train": {
            "type": "torch::train",
            "model": {
                "type": "basic_regression",
            },
            "training_engine": {
                "optimizer": {
                    "type": "torch::Adam",
                },
            },
            "dataset_dict": {
                "type": "ref",
                "ref": "data"
            },
            "train_dataloader": {
                "batch_size": 8,
                "shuffle": true
            },
            "validation_split": "validation",
            "validation_dataloader": {
                "batch_size": 8,
                "shuffle": false
            },
            "train_steps": 100,
            "validate_every": 10,
            "checkpoint_every": 10,
            "log_every": 1
        }
    }
}
For example,
tango run train.jsonnet -i my_package -d /tmp/train
would produce the following output:
Starting new run boss-alien
● Starting step "data" (needed by "train")...
✓ Finished step "data"
● Starting step "train"...
✓ Finished step "train"
✓ Finished run boss-alien
...
Tips#
Debugging#
When debugging a training loop that’s causing errors on a GPU, you should set the environment variable
CUDA_LAUNCH_BLOCKING=1
. This will ensure that the stack trace shows where the error actually happened.
You could also use a custom TrainCallback
to log each batch before it is passed into the model
so that you can see the exact inputs that are causing the issue.
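Here is a minimal sketch of such a callback, using only the base TrainCallback API documented in the reference below; the registered name “log_inputs” is made up for illustration:
import logging

from tango.integrations.torch import TrainCallback


@TrainCallback.register("log_inputs")  # hypothetical name, choose your own
class LogInputsCallback(TrainCallback):
    def pre_batch(self, step: int, epoch: int, batch) -> None:
        # `batch` is a list of "micro batches" (see the note on pre_batch() in the reference).
        for micro_batch in batch:
            logging.getLogger(__name__).info(
                "step %d inputs: %s",
                step,
                {k: getattr(v, "shape", v) for k, v in micro_batch.items()},
            )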
Stopping early#
You can stop the “torch::train” step early using a custom TrainCallback
. Your callback just
needs to raise the StopEarly
exception.
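For example, here is a sketch of a callback that stops training once the average loss reported to log_batch() drops below a threshold. The registered name “stop_on_low_loss” and the threshold parameter are illustrative, and StopEarly is assumed to be importable from tango.integrations.torch:
from tango.integrations.torch import StopEarly, TrainCallback


@TrainCallback.register("stop_on_low_loss")  # hypothetical name
class StopOnLowLoss(TrainCallback):
    def __init__(self, *args, threshold: float = 0.01, **kwargs):
        super().__init__(*args, **kwargs)
        self.threshold = threshold

    def log_batch(self, step: int, epoch: int, batch_loss: float, batch_outputs) -> None:
        # `batch_loss` is the average loss across distributed workers here.
        if batch_loss < self.threshold:
            raise StopEarly
Alternatively, the built-in StopEarlyCallback (registered as “torch::stop_early”, see the reference below) stops training after a configurable number of steps without improvement in the validation metric.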
Reference#
Train step#
- class tango.integrations.torch.TorchTrainStep(step_name=None, cache_results=None, step_format=None, step_config=None, step_unique_id_override=None, step_resources=None, step_metadata=None, step_extra_dependencies=None, **kwargs)[source]#
A PyTorch training loop step that supports gradient accumulation, distributed training, and AMP, with configurable dataloaders, callbacks, optimizer, and LR scheduler.
Tip
Registered as a Step under the name “torch::train”.
Important
The training loop will use GPU(s) automatically when available, as long as at least device_count CUDA devices are available.
Distributed data parallel training is activated when the device_count is greater than 1.
You can control which CUDA devices to use with the environment variable CUDA_VISIBLE_DEVICES. For example, to only use the GPUs with IDs 0 and 1, set CUDA_VISIBLE_DEVICES=0,1 (and device_count to 2).
Warning
During validation, the validation metric (specified by the val_metric_name parameter) is aggregated by simply averaging across validation batches and distributed processes. This behavior is usually correct when your validation metric is “loss” or “accuracy”, for example, but may not be correct for other metrics like “F1”.
If this is not correct for your metric you will need to handle the aggregation internally in your model or with a TrainCallback using the TrainCallback.post_val_batch() method. Then set the parameter auto_aggregate_val_metric to False. Note that correctly aggregating your metric during distributed training will involve distributed communication.
- run(model, training_engine, dataset_dict, train_dataloader, *, train_split='train', validation_split=None, validation_dataloader=None, seed=42, train_steps=None, train_epochs=None, validation_steps=None, grad_accum=1, log_every=10, checkpoint_every=100, validate_every=None, device_count=1, distributed_port=54761, val_metric_name='loss', minimize_val_metric=True, auto_aggregate_val_metric=True, callbacks=None, remove_stale_checkpoints=True)[source]#
Run a basic training loop to train the model.
- Parameters:
model (Union[Lazy[Model], Model]) – The model to train. It should return a dict that includes the loss during training and the val_metric_name during validation.
training_engine (Lazy[TrainingEngine]) – The TrainingEngine to use to train the model.
dataset_dict (DatasetDictBase) – The train and optional validation data.
train_dataloader (Lazy[DataLoader]) – The data loader that generates training batches. The batches should be dict objects that will be used as kwargs for the model’s forward() method.
train_split (str, default: 'train') – The name of the data split used for training in the dataset_dict. Default is “train”.
validation_split (Optional[str], default: None) – Optional name of the validation split in the dataset_dict. Default is None, which means no validation.
validation_dataloader (Optional[Lazy[DataLoader]], default: None) – An optional data loader for generating validation batches. The batches should be dict objects. If not specified, but validation_split is given, the validation DataLoader will be constructed from the same parameters as the train DataLoader.
seed (int, default: 42) – Used to set the RNG states at the beginning of training.
train_steps (Optional[int], default: None) – The number of steps to train for. If not specified training will stop after a complete iteration through the train_dataloader.
train_epochs (Optional[int], default: None) – The number of epochs to train for. You cannot specify train_steps and train_epochs at the same time.
validation_steps (Optional[int], default: None) – The number of steps to validate for. If not specified validation will stop after a complete iteration through the validation_dataloader.
grad_accum (int, default: 1) – The number of gradient accumulation steps. Defaults to 1.
Note
This parameter - in conjunction with the settings of your data loader and the number of distributed workers - determines the effective batch size of your training run.
log_every (int, default: 10) – Log every this many steps.
checkpoint_every (int, default: 100) – Save a checkpoint every this many steps.
validate_every (Optional[int], default: None) – Run the validation loop every this many steps.
device_count (int, default: 1) – The number of devices to train on, i.e. the number of distributed data parallel workers.
distributed_port (int, default: 54761) – The port of the distributed process group. Default is 54761.
val_metric_name (str, default: 'loss') – The name of the validation metric, i.e. the key of the metric in the dictionary returned by the forward pass of the model. Default is “loss”.
minimize_val_metric (bool, default: True) – Whether the validation metric is meant to be minimized (such as the loss). Default is True. When using a metric such as accuracy, you should set this to False.
auto_aggregate_val_metric (bool, default: True) – If True (the default), the validation metric will be averaged across validation batches and distributed processes. This may not be the correct behavior for some metrics (such as F1), in which case you should set this to False and handle the aggregation internally in your model or with a TrainCallback (using TrainCallback.post_val_batch()).
callbacks (Optional[List[Lazy[TrainCallback]]], default: None) – A list of TrainCallback.
remove_stale_checkpoints (bool, default: True) – If True (the default), stale checkpoints will be removed throughout training so that only the latest and best checkpoints are kept.
- Return type:
Model
- Returns:
The trained model on CPU with the weights from the best checkpoint loaded.
- CACHEABLE: Optional[bool] = True#
This provides a direct way to turn off caching. For example, a step that reads a HuggingFace dataset doesn’t need to be cached, because HuggingFace datasets already have their own caching mechanism. But it’s still a deterministic step, and all following steps are allowed to cache. If it is
None
, the step figures out by itself whether it should be cacheable or not.
- DETERMINISTIC: bool = True#
This describes whether this step can be relied upon to produce the same results every time when given the same inputs. If this is
False
, you can still cache the output of the step, but the results might be unexpected. Tango will print a warning in this case.
- FORMAT: Format = <tango.integrations.torch.format.TorchFormat object>#
This specifies the format the results of this step will be serialized in. See the documentation for Format for details.
- METADATA: Dict[str, Any] = {'artifact_kind': 'model'}#
Arbitrary metadata about the step.
- SKIP_ID_ARGUMENTS: Set[str] = {'distributed_port', 'log_every'}#
If your run() method takes some arguments that don’t affect the results, list them here. Arguments listed here will not be used to calculate this step’s unique ID, and thus changing those arguments does not invalidate the cache.
For example, you might use this for the batch size in an inference step, where you only care about the model output, not about how many outputs you can produce at the same time.
- property resources: StepResources#
Defines the minimum compute resources required to run this step. Certain executors require this information in order to allocate resources for each step.
You can set this with the step_resources argument to Step or you can override this method to automatically define the required resources.
- class tango.integrations.torch.TrainConfig(step_id, work_dir, step_name=None, worker_id=0, train_split='train', validation_split=None, seed=42, train_steps=None, train_epochs=None, validation_steps=None, grad_accum=1, log_every=10, checkpoint_every=100, validate_every=None, is_distributed=False, devices=None, distributed_address='127.0.0.1', distributed_port=54761, val_metric_name='loss', minimize_val_metric=True, auto_aggregate_val_metric=True, remove_stale_checkpoints=True, world_size=1, _worker_local_default_device=None, _device_type=None)[source]#
Encapsulates the parameters of TorchTrainStep. This is used to pass all the training options to TrainCallback.
- property best_state_path: Path#
The path to the best state checkpoint file according to the validation metric or training loss (if no validation split is given).
- minimize_val_metric: bool = True#
Should be True when the validation metric being tracked should be minimized.
- step_name: Optional[str] = None#
The name of the current step.
Note
The same step can be run under different names.
- train_epochs: Optional[int] = None#
The number of epochs to train for. You cannot specify train_steps and train_epochs at the same time.
- validate_every: Optional[int] = None#
Controls the frequency of the validation loop, in number of optimizer steps.
Eval step#
- class tango.integrations.torch.TorchEvalStep(step_name=None, cache_results=None, step_format=None, step_config=None, step_unique_id_override=None, step_resources=None, step_metadata=None, step_extra_dependencies=None, **kwargs)[source]#
A PyTorch evaluation loop that pairs well with TorchTrainStep.
Tip
Registered as a Step under the name “torch::eval”.
Important
The evaluation loop will use a GPU automatically if one is available. You can control which GPU it uses with the environment variable CUDA_VISIBLE_DEVICES. For example, set CUDA_VISIBLE_DEVICES=1 to force TorchEvalStep to only use the GPU with ID 1.
Warning
By default the metrics specified by the metric_names parameter are aggregated by simply averaging across batches. This behavior is usually correct for metrics like “loss” or “accuracy”, for example, but may not be correct for other metrics like “F1”.
If this is not correct for your metric you will need to handle the aggregation internally in your model or with an EvalCallback using the EvalCallback.post_batch() method. Then set the parameter auto_aggregate_metrics to False.
- run(model, dataset_dict, dataloader, test_split='test', seed=42, eval_steps=None, log_every=1, metric_names=('loss',), auto_aggregate_metrics=True, callbacks=None)[source]#
Evaluate the model.
- Parameters:
model (Model) – The model to evaluate. It should return a dict from its forward() method that includes all of the metrics in metric_names.
dataset_dict (DatasetDictBase) – Should contain the test data.
dataloader (Lazy[DataLoader]) – The data loader that generates test batches. The batches should be dict objects.
test_split (str, default: 'test') – The name of the data split used for evaluation in the dataset_dict. Default is “test”.
seed (int, default: 42) – Used to set the RNG states at the beginning of the evaluation loop.
eval_steps (Optional[int], default: None) – The number of steps to evaluate for. If not specified evaluation will stop after a complete iteration through the dataloader.
log_every (int, default: 1) – Log every this many steps. Default is 1.
metric_names (Sequence[str], default: ('loss',)) – The names of the metrics to track and aggregate. Default is ("loss",).
auto_aggregate_metrics (bool, default: True) – If True (the default), the metrics will be averaged across batches. This may not be the correct behavior for some metrics (such as F1), in which case you should set this to False and handle the aggregation internally in your model or with an EvalCallback (using EvalCallback.post_batch()).
callbacks (Optional[List[Lazy[EvalCallback]]], default: None) – A list of EvalCallback.
- Return type:
- CACHEABLE: Optional[bool] = True#
This provides a direct way to turn off caching. For example, a step that reads a HuggingFace dataset doesn’t need to be cached, because HuggingFace datasets already have their own caching mechanism. But it’s still a deterministic step, and all following steps are allowed to cache. If it is
None
, the step figures out by itself whether it should be cacheable or not.
- DETERMINISTIC: bool = True#
This describes whether this step can be relied upon to produce the same results every time when given the same inputs. If this is
False
, you can still cache the output of the step, but the results might be unexpected. Tango will print a warning in this case.
- FORMAT: Format = <tango.format.JsonFormat object>#
This specifies the format the results of this step will be serialized in. See the documentation for Format for details.
- SKIP_ID_ARGUMENTS: Set[str] = {'log_every'}#
If your run() method takes some arguments that don’t affect the results, list them here. Arguments listed here will not be used to calculate this step’s unique ID, and thus changing those arguments does not invalidate the cache.
For example, you might use this for the batch size in an inference step, where you only care about the model output, not about how many outputs you can produce at the same time.
- property resources: StepResources#
Defines the minimum compute resources required to run this step. Certain executors require this information in order to allocate resources for each step.
You can set this with the step_resources argument to Step or you can override this method to automatically define the required resources.
Torch format#
- class tango.integrations.torch.TorchFormat(*args, **kwds)[source]#
This format writes the artifact using torch.save().
Unlike tango.format.DillFormat, this has no special support for iterators.
Tip
Registered as a Format under the name “torch”.
Model#
- class tango.integrations.torch.Model(*args, **kwargs)[source]#
This is a Registrable mixin class that inherits from torch.nn.Module. Its forward() method should return a dict that includes the loss during training and any tracked metrics during validation.
TrainingEngine#
- class tango.integrations.torch.TrainingEngine(train_config, model, optimizer, *, lr_scheduler=None)[source]#
A TrainingEngine defines and drives the strategy for training a model in TorchTrainStep.
- Variables:
train_config (TrainConfig) – The training config.
model (Model) – The model being trained.
optimizer (Optimizer) – The optimizer being used to train the model.
lr_scheduler (LRScheduler) – The optional learning rate scheduler.
- abstract backward(loss)[source]#
Run a backwards pass on the model. This will always be called after forward_train().
- Return type:
- abstract forward_train(micro_batch, micro_batch_idx, num_micro_batches)[source]#
Run a forward training pass on the model.
- abstract load_checkpoint(checkpoint_dir)[source]#
Load a checkpoint to resume training. Should return the same client_state saved in save_checkpoint().
- abstract save_checkpoint(checkpoint_dir, client_state)[source]#
Save a training checkpoint with model state, optimizer state, etc., as well as the arbitrary client_state to the given checkpoint_dir.
- Return type:
- abstract save_complete_weights_from_checkpoint(checkpoint_dir, weights_path)[source]#
Gather the final weights from the best checkpoint and save to the file at weights_path.
- Return type:
- abstract step()[source]#
Take an optimization step. This will always be called after backward().
- Return type:
- default_implementation: Optional[str] = 'torch'#
The default implementation is TorchTrainingEngine.
- class tango.integrations.torch.TorchTrainingEngine(train_config, model, optimizer, *, lr_scheduler=None, amp=False, max_grad_norm=None, amp_use_bfloat16=None)[source]#
This train engine only uses native PyTorch functionality to provide vanilla distributed data parallel training and AMP.
Tip
Registered as a TrainingEngine under the name “torch”.
Important
Only the parameters listed below should be defined in a configuration file. The other parameters will be automatically passed to the constructor within TorchTrainStep.
- Parameters:
amp (bool, default: False) – Use automatic mixed precision. Default is False.
max_grad_norm (Optional[float], default: None) – If set, gradients will be clipped to have this max norm. Default is None.
amp_use_bfloat16 (Optional[bool], default: None) – Set to True to force using the bfloat16 datatype in mixed precision training. Only applicable when amp=True. If not specified, the default behavior will be to use bfloat16 when training with AMP on CPU, otherwise not.
Optim#
- class tango.integrations.torch.Optimizer(params, defaults)[source]#
A Registrable version of a PyTorch torch.optim.Optimizer.
All built-in PyTorch optimizers are registered according to their class name (e.g. “torch::Adam”).
Tip
You can see a list of all available optimizers by running
from tango.integrations.torch import Optimizer

for name in sorted(Optimizer.list_available()):
    print(name)
torch::ASGD
torch::Adadelta
torch::Adagrad
torch::Adam
torch::AdamW
...
- class tango.integrations.torch.LRScheduler(optimizer, last_epoch=-1, verbose=False)[source]#
A Registrable version of a PyTorch learning rate scheduler.
All built-in PyTorch learning rate schedulers are registered according to their class name (e.g. “torch::StepLR”).
Tip
You can see a list of all available schedulers by running
from tango.integrations.torch import LRScheduler

for name in sorted(LRScheduler.list_available()):
    print(name)
torch::ChainedScheduler
torch::ConstantLR
torch::CosineAnnealingLR
...
Data#
- class tango.integrations.torch.DataLoader(dataset, collate_fn=<tango.integrations.torch.data.ConcatTensorDictsCollator object>, sampler=None, **kwargs)[source]#
A
Registrable
version of a PyTorchDataLoader
.
- class tango.integrations.torch.Sampler(data_source)[source]#
A Registrable version of a PyTorch Sampler.
All built-in PyTorch samplers are registered under their corresponding class name (e.g. “RandomSampler”).
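As with the optimizers and schedulers above, you should be able to list the registered samplers via the list_available() method inherited from Registrable (a quick sanity check; output not shown):
from tango.integrations.torch import Sampler

for name in sorted(Sampler.list_available()):
    print(name)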
- class tango.integrations.torch.DataCollator(*args, **kwds)[source]#
A Registrable version of a collate_fn for a DataLoader.
Subclasses just need to implement __call__().
- default_implementation: Optional[str] = 'concat_tensor_dicts'#
The default implementation is ConcatTensorDictsCollator.
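As a sketch of what a custom collator might look like, here is one that pads each tensor field to the longest item in the batch before stacking. The registered name “pad_tensor_dicts” and the padding behavior are illustrative only, not part of the library:
import torch

from tango.integrations.torch import DataCollator


@DataCollator.register("pad_tensor_dicts")  # hypothetical name
class PadTensorDictsCollator(DataCollator):
    def __call__(self, items):
        # `items` is a list of dicts of tensors; pad each field along its first
        # dimension to the longest item in the batch, then stack.
        batch = {}
        for key in items[0]:
            tensors = [item[key] for item in items]
            max_len = max(t.shape[0] for t in tensors)
            batch[key] = torch.stack(
                [
                    torch.nn.functional.pad(
                        t, (0,) * (2 * (t.dim() - 1)) + (0, max_len - t.shape[0])
                    )
                    for t in tensors
                ]
            )
        return batch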
- class tango.integrations.torch.ConcatTensorDictsCollator(*args, **kwds)[source]#
A simple collate_fn that expects items to be dictionaries of tensors. The tensors are just concatenated together.
Tip
Registered as a DataCollator under the name “concat_tensor_dicts”.
Callbacks#
- class tango.integrations.torch.TrainCallback(workspace, train_config, training_engine, dataset_dict, train_dataloader, validation_dataloader=None)[source]#
A TrainCallback is a Registrable class that can be used within TorchTrainStep to customize behavior in the training loop. You can set the training callbacks with the callbacks parameter to TorchTrainStep.
Tip
All of the parameters to this base class will be automatically set within the training loop, so you shouldn’t include them in your config for your callbacks.
Tip
You can access the model being trained through self.model.
Important
The step argument to callback methods is the total/overall number of training steps so far, independent of the current epoch.
See also
See WandbTrainCallback for an example implementation.
- Variables:
workspace (Workspace) – The tango workspace being used.
train_config (TrainConfig) – The training config.
training_engine (TrainingEngine) – The engine used to train the model.
dataset_dict (DatasetDictBase) – The dataset dict containing train and optional validation splits.
train_dataloader (DataLoader) – The dataloader used for the training split.
validation_dataloader (DataLoader) – Optional dataloader used for the validation split.
- property is_local_main_process: bool#
This is
True
if the current worker is the main distributed worker of the current node, or if we are not using distributed training.
- state_dict()[source]#
Return any state that needs to be kept after a restart.
Some callbacks need to maintain state across restarts. This is the callback’s opportunity to save its state. It will be restored using
load_state_dict()
.
- load_state_dict(state_dict)[source]#
Load the state on a restart.
Some callbacks need to maintain state across restarts. This is the callback’s opportunity to restore its state. It gets saved using
state_dict()
.- Return type:
- pre_train_loop()[source]#
Called right before the first batch is processed, or after a restart.
- Return type:
- post_train_loop(step, epoch)[source]#
Called after the training loop completes.
This is the last method that is called, so any cleanup can be done in this method.
- Return type:
- pre_epoch(step, epoch)[source]#
Called right before the start of an epoch. Epochs start at 0.
- Return type:
- post_epoch(step, epoch)[source]#
Called after an epoch is completed. Epochs start at 0.
- Return type:
- pre_batch(step, epoch, batch)[source]#
Called directly before processing a batch.
- Return type: None
Note
The type of batch is a list because with gradient accumulation there will be more than one “micro batch” in the batch.
- post_batch(step, epoch, batch_loss, batch_outputs)[source]#
Called directly after processing a batch, but before unscaling gradients, clipping gradients, and taking an optimizer step.
- Return type: None
Note
The batch_loss here is the loss local to the current worker, not the overall (average) batch loss across distributed workers. If you need the average loss, use log_batch().
Note
The type of batch_outputs is a list because with gradient accumulation there will be more than one “micro batch” in the batch.
- log_batch(step, epoch, batch_loss, batch_outputs)[source]#
Called after the optimizer step. Here batch_loss is the average loss across all distributed workers.
- Return type: None
Note
This callback method is not necessarily called on every step. The frequency depends on the value of the log_every parameter of TorchTrainStep.
Note
The type of batch_outputs is a list because with gradient accumulation there will be more than one “micro batch” in the batch.
- pre_val_batch(step, val_step, epoch, val_batch)[source]#
Called right before a validation batch is processed.
- Return type:
- post_val_batch(step, val_step, epoch, val_batch_outputs)[source]#
Called right after a validation batch is processed with the outputs of the batch.
- Return type: None
Tip
This method can be used to modify val_batch_outputs in place, which is useful in scenarios like distributed training where you might need to aggregate metrics in a special way other than a simple average. If that’s the case, make sure to set auto_aggregate_val_metric to False in TorchTrainStep.
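For instance, here is a sketch of cross-worker aggregation for a count-based metric. It assumes the model’s forward() puts raw tensor counts “correct” and “total” into its output dict, which is an assumption about your model and not part of the API above; a full implementation would also keep running totals across validation batches:
import torch
import torch.distributed as dist

from tango.integrations.torch import TrainCallback


@TrainCallback.register("aggregate_accuracy")  # hypothetical name
class AggregateAccuracyCallback(TrainCallback):
    def post_val_batch(self, step, val_step, epoch, val_batch_outputs) -> None:
        # Sum the raw counts across all distributed workers, then write the
        # aggregated metric back into the outputs in place.
        counts = torch.stack(
            [val_batch_outputs["correct"], val_batch_outputs["total"]]
        ).float()
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(counts)
        val_batch_outputs["accuracy"] = (counts[0] / counts[1]).item()
Remember to also set val_metric_name to “accuracy”, minimize_val_metric to False, and auto_aggregate_val_metric to False in TorchTrainStep when using something like this.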
- class tango.integrations.torch.EvalCallback(workspace, step_id, work_dir, model, dataset_dict, dataloader)[source]#
An EvalCallback is a Registrable class that can be used within TorchEvalStep to customize the behavior of the evaluation loop, similar to how TrainCallback is used to customize the behavior of the training loop.
Tip
All of the parameters to this base class will be automatically set within the evaluation loop, so you shouldn’t include them in your config for your callbacks.
- Variables:
workspace (Workspace) – The tango workspace being used.
step_id (str) – The unique ID of the step.
work_dir (Path) – The working directory of the step.
model (Model) – The model being evaluated.
dataset_dict (DatasetDictBase) – The dataset dict containing the evaluation split.
dataloader (DataLoader) – The data loader used to load the evaluation split data.
- post_eval_loop(aggregated_metrics)[source]#
Called after the evaluation loop completes with the final aggregated metrics.
This is the last method that is called, so any cleanup can be done in this method.
- Return type:
- post_batch(step, batch_outputs)[source]#
Called directly after processing a batch with the outputs of the batch.
- Return type: None
Tip
This method can be used to modify batch_outputs in place, which is useful in scenarios where you might need to aggregate metrics in a special way other than a simple average. If that’s the case, make sure to set auto_aggregate_metrics to False in TorchEvalStep.
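As a sketch, here is a callback that collects per-batch predictions and writes them to the step’s working directory after the loop. The registered name “save_predictions” and the “pred” output key are assumptions (the key matches the example model at the top of this page):
import json

from tango.integrations.torch import EvalCallback


@EvalCallback.register("save_predictions")  # hypothetical name
class SavePredictionsCallback(EvalCallback):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._predictions = []

    def post_batch(self, step, batch_outputs) -> None:
        # `batch_outputs` is whatever the model's forward() returned for this batch.
        self._predictions.extend(batch_outputs["pred"].detach().cpu().tolist())

    def post_eval_loop(self, aggregated_metrics) -> None:
        with open(self.work_dir / "predictions.json", "w") as f:
            json.dump({"metrics": aggregated_metrics, "predictions": self._predictions}, f)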
- class tango.integrations.torch.StopEarlyCallback(*args, patience=10000, **kwargs)[source]#
A TrainCallback for early stopping. Training is stopped early after patience steps without an improvement to the validation metric.
Tip
Registered as a TrainCallback under the name “torch::stop_early”.