🔥 PyTorch
Important
To use this integration you should install Tango with the “torch” extra
(e.g. pip install 'ai2-tango[torch]') or just install PyTorch after the fact.
Make sure you install the correct version of torch given your operating system and supported CUDA version. Check pytorch.org/get-started/locally/ for more details.
Components for Tango integration with PyTorch.
These include a training loop Step and registrable versions
of many torch classes, such as torch.optim.Optimizer and torch.utils.data.DataLoader.
Example: training a model
Let’s look at a simple example of training a model.
We’ll make a basic regression model and generate some fake data to train on. First, the setup:
```python
import torch
import torch.nn as nn

from tango.common.dataset_dict import DatasetDict
from tango.step import Step
from tango.integrations.torch import Model
```
Now let’s build and register our model:
```python
@Model.register("basic_regression")
class BasicRegression(Model):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 1)
        self.sigmoid = nn.Sigmoid()
        self.mse = nn.MSELoss()

    def forward(self, x, y=None):
        pred = self.sigmoid(self.linear(x))
        out = {"pred": pred}
        # The loss is only computed when targets are provided (training/validation).
        if y is not None:
            out["loss"] = self.mse(pred, y)
        return out

    def _to_params(self):
        return {}
```
Lastly, we’ll need a step to generate data:
```python
@Step.register("generate_data")
class GenerateData(Step):
    DETERMINISTIC = True
    CACHEABLE = False

    def run(self) -> DatasetDict:
        # Fix the seed so the fake data is the same on every run.
        torch.manual_seed(1)
        return DatasetDict(
            {
                "train": [{"x": torch.rand(10), "y": torch.rand(1)} for _ in range(64)],
                "validation": [{"x": torch.rand(10), "y": torch.rand(1)} for _ in range(32)],
            }
        )
```
You could then run this experiment with a config that looks like this:
```jsonnet
{
    "steps": {
        "data": {
            "type": "generate_data",
        },
        "train": {
            "type": "torch::train",
            "model": {
                "type": "basic_regression",
            },
            "training_engine": {
                "optimizer": {
                    "type": "torch::Adam",
                },
            },
            "dataset_dict": {
                "type": "ref",
                "ref": "data"
            },
            "train_dataloader": {
                "batch_size": 8,
                "shuffle": true
            },
            "validation_split": "validation",
            "validation_dataloader": {
                "batch_size": 8,
                "shuffle": false
            },
            "train_steps": 100,
            "validate_every": 10,
            "checkpoint_every": 10,
            "log_every": 1
        }
    }
}
```
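For example, with the config above saved as train.jsonnet and the code above in a module or package named my_package, running

```shell
tango run train.jsonnet -i my_package -d /tmp/train
```

would produce the following output: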
```text
Starting new run boss-alien
● Starting step "data" (needed by "train")...
✓ Finished step "data"
● Starting step "train"...
✓ Finished step "train"
✓ Finished run boss-alien
...
```
Tips
Debugging
When debugging a training loop that’s causing errors on a GPU, you should set the environment variable
CUDA_LAUNCH_BLOCKING=1. CUDA kernels are normally launched asynchronously, so an error can surface at a later, unrelated line; with this setting, the stack trace shows where the error actually happened.
You could also use a custom TrainCallback to log each batch before it is passed into the model,
so that you can see the exact inputs that are causing the issue.
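A minimal sketch of the logging side of such a callback. To keep it runnable without Tango or PyTorch it is plain Python: `describe_batch` and `describe_value` are hypothetical helpers, not part of Tango, and the code assumes (as the pre_batch() reference notes) that `batch` is a list of micro batches, each a dict of model inputs. Anything with a `shape` attribute, such as a tensor, is summarized by its shape rather than printed in full.

```python
from typing import Any, Dict, List


def describe_value(value: Any) -> str:
    """Summarize one input: objects with a .shape (e.g. tensors) by shape, others by repr."""
    shape = getattr(value, "shape", None)
    if shape is not None:
        return f"{type(value).__name__}(shape={tuple(shape)})"
    return repr(value)


def describe_batch(step: int, batch: List[Dict[str, Any]]) -> List[str]:
    """Return one log line per micro batch, listing every key the model will receive."""
    lines = []
    for i, micro_batch in enumerate(batch):
        fields = ", ".join(
            f"{key}={describe_value(value)}" for key, value in sorted(micro_batch.items())
        )
        lines.append(f"step {step} micro batch {i}/{len(batch)}: {fields}")
    return lines


# Inside a TrainCallback subclass you might call it like this (hypothetical):
#     def pre_batch(self, step, epoch, batch):
#         for line in describe_batch(step, batch):
#             logger.info(line)
if __name__ == "__main__":
    for line in describe_batch(3, [{"x": [0.1, 0.2], "y": [1.0]}]):
        print(line)  # step 3 micro batch 0/1: x=[0.1, 0.2], y=[1.0]
```

Summarizing tensors by shape keeps the log readable; shape mismatches are a common cause of GPU-side errors, and you can always print full values for a single suspicious key.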
Stopping early
You can stop the “torch::train” step early using a custom TrainCallback. Your callback just
needs to raise the StopEarly exception.
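The control flow can be sketched in plain Python. Here `StopEarly` is a local stand-in for Tango's exception, and `LossThresholdCallback` and `train` are hypothetical names. In a real project you would subclass TrainCallback, raise Tango's own StopEarly, and let “torch::train” catch it. The sketch assumes that the training loop treats the exception as a normal end of training rather than as a failure.

```python
from typing import Any, List


class StopEarly(Exception):
    """Local stand-in for Tango's StopEarly exception."""


class LossThresholdCallback:
    """Hypothetical callback: stop once the logged loss drops below a threshold."""

    def __init__(self, threshold: float) -> None:
        self.threshold = threshold

    def log_batch(self, step: int, epoch: int, batch_loss: float, batch_outputs: List[Any]) -> None:
        if batch_loss < self.threshold:
            raise StopEarly()


def train(losses: List[float], callback: LossThresholdCallback) -> int:
    """Toy training loop; returns the number of steps actually completed."""
    steps_done = 0
    try:
        for step, loss in enumerate(losses):
            steps_done = step + 1
            callback.log_batch(step, 0, loss, [])
    except StopEarly:
        pass  # Treated as a normal end of training, not an error.
    return steps_done


print(train([0.9, 0.5, 0.05, 0.01], LossThresholdCallback(threshold=0.1)))  # 3
```

Using an exception lets a callback end training from any hook without the loop having to poll a flag after every call.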
Reference
Train step
- class tango.integrations.torch.TorchTrainStep(step_name=None, cache_results=None, step_format=None, step_config=None, step_unique_id_override=None, step_resources=None, step_metadata=None, step_extra_dependencies=None, **kwargs)
A PyTorch training loop step that supports gradient accumulation, distributed training, and AMP, with configurable dataloaders, callbacks, optimizer, and LR scheduler.
Tip
Registered as a Step under the name “torch::train”.
Important
The training loop will use GPU(s) automatically when available, as long as at least device_count CUDA devices are available.
Distributed data parallel training is activated when device_count is greater than 1.
You can control which CUDA devices to use with the environment variable CUDA_VISIBLE_DEVICES. For example, to only use the GPUs with IDs 0 and 1, set CUDA_VISIBLE_DEVICES=0,1 (and device_count to 2).
Warning
During validation, the validation metric (specified by the val_metric_name parameter) is aggregated by simply averaging across validation batches and distributed processes. This behavior is usually correct when your validation metric is “loss” or “accuracy”, for example, but may not be correct for other metrics like “F1”.
If this is not correct for your metric, you will need to handle the aggregation internally in your model or with a TrainCallback using the TrainCallback.post_val_batch() method. Then set the parameter auto_aggregate_val_metric to False.
Note that correctly aggregating your metric during distributed training will involve distributed communication.
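To see why a plain average can be wrong for a metric like F1, the plain-Python sketch below computes F1 per batch and averages it, then computes F1 once from the summed true-positive, false-positive, and false-negative counts. The batch counts are made-up illustrative numbers, and `averaged_f1` and `corpus_f1` are hypothetical helpers, not Tango functions.

```python
from typing import List, Tuple


def f1(tp: int, fp: int, fn: int) -> float:
    """F1 from raw counts; defined as 0.0 when there are no positives at all."""
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


def averaged_f1(batches: List[Tuple[int, int, int]]) -> float:
    """What simple averaging does: the mean of per-batch F1 scores."""
    return sum(f1(*counts) for counts in batches) / len(batches)


def corpus_f1(batches: List[Tuple[int, int, int]]) -> float:
    """Count-based aggregation: sum the counts first, then compute F1 once."""
    tp = sum(b[0] for b in batches)
    fp = sum(b[1] for b in batches)
    fn = sum(b[2] for b in batches)
    return f1(tp, fp, fn)


# (tp, fp, fn) for two validation batches; illustrative numbers only.
batches = [(1, 0, 0), (0, 5, 5)]
print(averaged_f1(batches))  # 0.5
print(corpus_f1(batches))    # 1/6, about 0.167
```

The same reasoning applies across distributed workers: each worker's counts would need to be summed (which requires communication between processes) before the metric is computed.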
- run(model, training_engine, dataset_dict, train_dataloader, *, train_split='train', validation_split=None, validation_dataloader=None, seed=42, train_steps=None, train_epochs=None, validation_steps=None, grad_accum=1, log_every=10, checkpoint_every=100, validate_every=None, device_count=1, distributed_port=54761, val_metric_name='loss', minimize_val_metric=True, auto_aggregate_val_metric=True, callbacks=None, remove_stale_checkpoints=True)
Run a basic training loop to train the model.
- Parameters:
  - model (Union[Lazy[Model], Model]) – The model to train. It should return a dict that includes the loss during training and the val_metric_name during validation.
  - training_engine (Lazy[TrainingEngine]) – The TrainingEngine to use to train the model.
  - dataset_dict (DatasetDictBase) – The train and optional validation data.
  - train_dataloader (Lazy[DataLoader]) – The data loader that generates training batches. The batches should be dict objects that will be used as kwargs for the model’s forward() method.
  - train_split (str, default: 'train') – The name of the data split used for training in the dataset_dict. Default is “train”.
  - validation_split (Optional[str], default: None) – Optional name of the validation split in the dataset_dict. Default is None, which means no validation.
  - validation_dataloader (Optional[Lazy[DataLoader]], default: None) – An optional data loader for generating validation batches. The batches should be dict objects. If not specified, but validation_split is given, the validation DataLoader will be constructed from the same parameters as the train DataLoader.
  - seed (int, default: 42) – Used to set the RNG states at the beginning of training.
  - train_steps (Optional[int], default: None) – The number of steps to train for. If not specified, training will stop after a complete iteration through the train_dataloader.
  - train_epochs (Optional[int], default: None) – The number of epochs to train for. You cannot specify train_steps and train_epochs at the same time.
  - validation_steps (Optional[int], default: None) – The number of steps to validate for. If not specified, validation will stop after a complete iteration through the validation_dataloader.
  - grad_accum (int, default: 1) – The number of gradient accumulation steps. Defaults to 1.
    Note
    This parameter, in conjunction with the settings of your data loader and the number of distributed workers, determines the effective batch size of your training run.
  - log_every (int, default: 10) – Log every this many steps.
  - checkpoint_every (int, default: 100) – Save a checkpoint every this many steps.
  - validate_every (Optional[int], default: None) – Run the validation loop every this many steps.
  - device_count (int, default: 1) – The number of devices to train on, i.e. the number of distributed data parallel workers.
  - distributed_port (int, default: 54761) – The port of the distributed process group. Default = “54761”.
  - val_metric_name (str, default: 'loss') – The name of the validation metric, i.e. the key of the metric in the dictionary returned by the forward pass of the model. Default is “loss”.
  - minimize_val_metric (bool, default: True) – Whether the validation metric is meant to be minimized (such as the loss). Default is True. When using a metric such as accuracy, you should set this to False.
  - auto_aggregate_val_metric (bool, default: True) – If True (the default), the validation metric will be averaged across validation batches and distributed processes. This may not be the correct behavior for some metrics (such as F1), in which case you should set this to False and handle the aggregation internally in your model or with a TrainCallback (using TrainCallback.post_val_batch()).
  - callbacks (Optional[List[Lazy[TrainCallback]]], default: None) – A list of TrainCallback.
  - remove_stale_checkpoints (bool, default: True) – If True (the default), stale checkpoints will be removed throughout training so that only the latest and best checkpoints are kept.
- Return type:
- Returns:
The trained model on CPU with the weights from the best checkpoint loaded.
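As the grad_accum note says, the effective batch size depends on three settings. Assuming each of the device_count workers draws its own batches of batch_size items from train_dataloader, and grad_accum such batches are accumulated per optimizer step, the arithmetic is a simple product. `effective_batch_size` is a hypothetical helper, not part of Tango.

```python
def effective_batch_size(batch_size: int, grad_accum: int = 1, device_count: int = 1) -> int:
    """Examples contributing to one optimizer step, under the assumptions above."""
    for name, value in (
        ("batch_size", batch_size),
        ("grad_accum", grad_accum),
        ("device_count", device_count),
    ):
        if value < 1:
            raise ValueError(f"{name} must be >= 1, got {value}")
    return batch_size * grad_accum * device_count


# The example config: batch_size 8, with the defaults grad_accum=1 and device_count=1.
print(effective_batch_size(8))  # 8
# The same loader on 2 GPUs with grad_accum=4.
print(effective_batch_size(8, grad_accum=4, device_count=2))  # 64
```

Increasing grad_accum is a way to reach a large effective batch size without holding more examples in GPU memory at once, at the cost of more forward/backward passes per optimizer step.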
- CACHEABLE: Optional[bool] = True
This provides a direct way to turn off caching. For example, a step that reads a HuggingFace dataset doesn’t need to be cached, because HuggingFace datasets already have their own caching mechanism. But it’s still a deterministic step, and all following steps are allowed to cache. If it is
None, the step figures out by itself whether it should be cacheable or not.
- DETERMINISTIC: bool = True
This describes whether this step can be relied upon to produce the same results every time when given the same inputs. If this is
False, you can still cache the output of the step, but the results might be unexpected. Tango will print a warning in this case.
- FORMAT: Format = <tango.integrations.torch.format.TorchFormat object>
This specifies the format the results of this step will be serialized in. See the documentation for Format for details.
- METADATA: Dict[str, Any] = {'artifact_kind': 'model'}
Arbitrary metadata about the step.
- SKIP_ID_ARGUMENTS: Set[str] = {'distributed_port', 'log_every'}
If your run() method takes some arguments that don’t affect the results, list them here. Arguments listed here will not be used to calculate this step’s unique ID, and thus changing those arguments does not invalidate the cache.
For example, you might use this for the batch size in an inference step, where you only care about the model output, not about how many outputs you can produce at the same time.
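The idea behind SKIP_ID_ARGUMENTS can be illustrated with a toy fingerprint. This is not Tango's actual unique-ID algorithm: `step_fingerprint` is a hypothetical helper that hashes the step type and its JSON-serializable arguments with SHA-256, leaving out the skipped ones. Changing log_every therefore leaves the fingerprint unchanged, while changing train_steps does not.

```python
import hashlib
import json
from typing import Any, Dict, Set


def step_fingerprint(step_type: str, kwargs: Dict[str, Any], skip: Set[str]) -> str:
    """Hash the step type plus every argument that is not listed in `skip`."""
    relevant = {key: value for key, value in kwargs.items() if key not in skip}
    payload = json.dumps({"type": step_type, "kwargs": relevant}, sort_keys=True)
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()[:16]


SKIP = {"distributed_port", "log_every"}  # Same set as TorchTrainStep.SKIP_ID_ARGUMENTS.
base = {"train_steps": 100, "log_every": 1, "distributed_port": 54761}

original = step_fingerprint("torch::train", base, SKIP)
louder_logging = step_fingerprint("torch::train", {**base, "log_every": 50}, SKIP)
longer_training = step_fingerprint("torch::train", {**base, "train_steps": 200}, SKIP)
print(louder_logging == original)   # True: cached result can be reused
print(longer_training == original)  # False: the step must run again
```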
- property resources: StepResources
Defines the minimum compute resources required to run this step. Certain executors require this information in order to allocate resources for each step.
You can set this with the step_resources argument to Step or you can override this method to automatically define the required resources.
- class tango.integrations.torch.TrainConfig(step_id, work_dir, step_name=None, worker_id=0, train_split='train', validation_split=None, seed=42, train_steps=None, train_epochs=None, validation_steps=None, grad_accum=1, log_every=10, checkpoint_every=100, validate_every=None, is_distributed=False, devices=None, distributed_address='127.0.0.1', distributed_port=54761, val_metric_name='loss', minimize_val_metric=True, auto_aggregate_val_metric=True, remove_stale_checkpoints=True, world_size=1, _worker_local_default_device=None, _device_type=None)
Encapsulates the parameters of TorchTrainStep. This is used to pass all the training options to TrainCallback.
- property best_state_path: Path
The path to the best state checkpoint file according to the validation metric or training loss (if no validation split is given).
- minimize_val_metric: bool = True
Should be True when the validation metric being tracked should be minimized.
- step_name: Optional[str] = None
The name of the current step.
Note
The same step can be run under different names.
- train_epochs: Optional[int] = None
The number of epochs to train for.
You cannot specify train_steps and train_epochs at the same time.
- validate_every: Optional[int] = None
Controls the frequency of the validation loop, in number of optimizer steps.
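A minimal sketch of how train_steps and train_epochs relate, assuming (for illustration only) that one optimizer step consumes grad_accum batches from the train dataloader and that a trailing partial accumulation at the end of an epoch is dropped. `resolve_train_steps` is hypothetical; Tango's own bookkeeping may differ in details such as that remainder.

```python
from typing import Optional


def resolve_train_steps(
    num_batches: int,
    grad_accum: int = 1,
    train_steps: Optional[int] = None,
    train_epochs: Optional[int] = None,
) -> int:
    """Total optimizer steps for a run, under the assumptions stated above."""
    if train_steps is not None and train_epochs is not None:
        raise ValueError("Specify train_steps or train_epochs, not both.")
    if train_steps is not None:
        return train_steps
    steps_per_epoch = num_batches // grad_accum  # Partial accumulation dropped (assumption).
    # Neither given: one complete iteration through the train dataloader.
    return steps_per_epoch * (train_epochs if train_epochs is not None else 1)


# 64 training examples with batch_size=8 gives 8 batches per epoch.
print(resolve_train_steps(8, train_epochs=3))                # 24
print(resolve_train_steps(8, grad_accum=2, train_epochs=3))  # 12
```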
Eval step
- class tango.integrations.torch.TorchEvalStep(step_name=None, cache_results=None, step_format=None, step_config=None, step_unique_id_override=None, step_resources=None, step_metadata=None, step_extra_dependencies=None, **kwargs)
A PyTorch evaluation loop that pairs well with TorchTrainStep.
Tip
Registered as a Step under the name “torch::eval”.
Important
The evaluation loop will use a GPU automatically if one is available. You can control which GPU it uses with the environment variable CUDA_VISIBLE_DEVICES. For example, set CUDA_VISIBLE_DEVICES=1 to force TorchEvalStep to only use the GPU with ID 1.
Warning
By default, the metrics specified by the metric_names parameter are aggregated by simply averaging across batches. This behavior is usually correct for metrics like “loss” or “accuracy”, for example, but may not be correct for other metrics like “F1”.
If this is not correct for your metric, you will need to handle the aggregation internally in your model or with an EvalCallback using the EvalCallback.post_batch() method. Then set the parameter auto_aggregate_metrics to False.
- run(model, dataset_dict, dataloader, test_split='test', seed=42, eval_steps=None, log_every=1, metric_names=('loss',), auto_aggregate_metrics=True, callbacks=None)
Evaluate the model.
- Parameters:
  - model (Model) – The model to evaluate. It should return a dict from its forward() method that includes all of the metrics in metric_names.
  - dataset_dict (DatasetDictBase) – Should contain the test data.
  - dataloader (Lazy[DataLoader]) – The data loader that generates test batches. The batches should be dict objects.
  - test_split (str, default: 'test') – The name of the data split used for evaluation in the dataset_dict. Default is “test”.
  - seed (int, default: 42) – Used to set the RNG states at the beginning of the evaluation loop.
  - eval_steps (Optional[int], default: None) – The number of steps to evaluate for. If not specified, evaluation will stop after a complete iteration through the dataloader.
  - log_every (int, default: 1) – Log every this many steps. Default is 1.
  - metric_names (Sequence[str], default: ('loss',)) – The names of the metrics to track and aggregate. Default is ("loss",).
  - auto_aggregate_metrics (bool, default: True) – If True (the default), the metrics will be averaged across batches. This may not be the correct behavior for some metrics (such as F1), in which case you should set this to False and handle the aggregation internally in your model or with an EvalCallback (using EvalCallback.post_batch()).
  - callbacks (Optional[List[Lazy[EvalCallback]]], default: None) – A list of EvalCallback.
- Return type:
- CACHEABLE: Optional[bool] = True
This provides a direct way to turn off caching. For example, a step that reads a HuggingFace dataset doesn’t need to be cached, because HuggingFace datasets already have their own caching mechanism. But it’s still a deterministic step, and all following steps are allowed to cache. If it is
None, the step figures out by itself whether it should be cacheable or not.
- DETERMINISTIC: bool = True
This describes whether this step can be relied upon to produce the same results every time when given the same inputs. If this is
False, you can still cache the output of the step, but the results might be unexpected. Tango will print a warning in this case.
- FORMAT: Format = <tango.format.JsonFormat object>
This specifies the format the results of this step will be serialized in. See the documentation for Format for details.
- SKIP_ID_ARGUMENTS: Set[str] = {'log_every'}
If your run() method takes some arguments that don’t affect the results, list them here. Arguments listed here will not be used to calculate this step’s unique ID, and thus changing those arguments does not invalidate the cache.
For example, you might use this for the batch size in an inference step, where you only care about the model output, not about how many outputs you can produce at the same time.
- property resources: StepResources
Defines the minimum compute resources required to run this step. Certain executors require this information in order to allocate resources for each step.
You can set this with the step_resources argument to Step or you can override this method to automatically define the required resources.
Torch format
- class tango.integrations.torch.TorchFormat
This format writes the artifact using torch.save().
Unlike tango.format.DillFormat, this has no special support for iterators.
Tip
Registered as a Format under the name “torch”.
Model
- class tango.integrations.torch.Model(*args, **kwargs)
This is a Registrable mixin class that inherits from torch.nn.Module. Its forward() method should return a dict that includes the loss during training and any tracked metrics during validation.
TrainingEngine
- class tango.integrations.torch.TrainingEngine(train_config, model, optimizer, *, lr_scheduler=None)
A TrainingEngine defines and drives the strategy for training a model in TorchTrainStep.
- Variables:
train_config (TrainConfig) – The training config.
model (Model) – The model being trained.
optimizer (Optimizer) – The optimizer being used to train the model.
lr_scheduler (LRScheduler) – The optional learning rate scheduler.
- abstract backward(loss)
Run a backwards pass on the model. This will always be called after forward_train().
- Return type:
- abstract forward_train(micro_batch, micro_batch_idx, num_micro_batches)
Run a forward training pass on the model.
- abstract load_checkpoint(checkpoint_dir)
Load a checkpoint to resume training. Should return the same client_state saved in save_checkpoint().
- abstract save_checkpoint(checkpoint_dir, client_state)
Save a training checkpoint with model state, optimizer state, etc., as well as the arbitrary client_state to the given checkpoint_dir.
- Return type:
- abstract save_complete_weights_from_checkpoint(checkpoint_dir, weights_path)
Gather the final weights from the best checkpoint and save to the file at weights_path.
- Return type:
- abstract step()
Take an optimization step. This will always be called after backward().
- Return type:
- default_implementation: Optional[str] = 'torch'
The default implementation is TorchTrainingEngine.
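The documented call order (forward_train(), then backward(), then step()) combines with gradient accumulation roughly as sketched below. `RecordingEngine` and `train_one_step` are hypothetical plain-Python stand-ins that only record the calls. They assume, for illustration, that backward() runs once per micro batch, step() runs once per optimizer step, and each micro-batch loss is divided by num_micro_batches so the accumulated gradient is an average.

```python
from typing import Dict, List


class RecordingEngine:
    """Hypothetical stand-in for a TrainingEngine that records the calls it receives."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def forward_train(self, micro_batch: Dict[str, float], micro_batch_idx: int, num_micro_batches: int) -> float:
        self.calls.append(f"forward_train({micro_batch_idx})")
        # Scale so that summing over micro batches yields an average.
        return micro_batch["loss"] / num_micro_batches

    def backward(self, loss: float) -> None:
        self.calls.append("backward")

    def step(self) -> None:
        self.calls.append("step")


def train_one_step(engine: RecordingEngine, batch: List[Dict[str, float]]) -> float:
    """One optimizer step over a batch made of len(batch) micro batches."""
    total = 0.0
    for idx, micro_batch in enumerate(batch):
        loss = engine.forward_train(micro_batch, idx, len(batch))
        engine.backward(loss)
        total += loss
    engine.step()
    return total


engine = RecordingEngine()
print(train_one_step(engine, [{"loss": 2.0}, {"loss": 4.0}]))  # 3.0
print(engine.calls)  # ['forward_train(0)', 'backward', 'forward_train(1)', 'backward', 'step']
```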
- class tango.integrations.torch.TorchTrainingEngine(train_config, model, optimizer, *, lr_scheduler=None, amp=False, max_grad_norm=None, amp_use_bfloat16=None)
This training engine only uses native PyTorch functionality to provide vanilla distributed data parallel training and AMP.
Tip
Registered as a TrainingEngine under the name “torch”.
Important
Only the parameters listed below should be defined in a configuration file. The other parameters will be automatically passed to the constructor within TorchTrainStep.
- Parameters:
  - amp (bool, default: False) – Use automatic mixed precision. Default is False.
  - max_grad_norm (Optional[float], default: None) – If set, gradients will be clipped to have this max norm. Default is None.
  - amp_use_bfloat16 (Optional[bool], default: None) – Set to True to force using the bfloat16 datatype in mixed precision training. Only applicable when amp=True. If not specified, the default behavior will be to use bfloat16 when training with AMP on CPU, otherwise not.
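Clipping to a max norm usually means clipping by the total (global) L2 norm of all gradients, the rule implemented by torch.nn.utils.clip_grad_norm_. Whether TorchTrainingEngine uses exactly that function is an assumption here. The plain-Python sketch below only shows the arithmetic on flat lists of gradient values, and `clip_by_global_norm` is a hypothetical name.

```python
import math
from typing import List


def clip_by_global_norm(grads: List[List[float]], max_norm: float, eps: float = 1e-6) -> float:
    """Scale all gradients in place so their combined L2 norm is at most max_norm.

    Returns the norm measured before clipping.
    """
    total_norm = math.sqrt(sum(g * g for param in grads for g in param))
    coef = min(1.0, max_norm / (total_norm + eps))  # Never scales gradients up.
    for param in grads:
        for i in range(len(param)):
            param[i] *= coef
    return total_norm


grads = [[3.0], [4.0]]  # Two "parameters"; the global norm is 5.0.
print(clip_by_global_norm(grads, max_norm=1.0))  # 5.0
print(grads)  # roughly [[0.6], [0.8]]
```

Because every gradient is scaled by the same factor, the update direction is preserved; only its length is capped.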
Optim
- class tango.integrations.torch.Optimizer(params, defaults)
A Registrable version of a PyTorch torch.optim.Optimizer.
All built-in PyTorch optimizers are registered according to their class name (e.g. “torch::Adam”).
Tip
You can see a list of all available optimizers by running

```python
from tango.integrations.torch import Optimizer

for name in sorted(Optimizer.list_available()):
    print(name)
```

```text
torch::ASGD
torch::Adadelta
torch::Adagrad
torch::Adam
torch::AdamW
...
```
- class tango.integrations.torch.LRScheduler(optimizer, last_epoch=-1, verbose=False)
A Registrable version of a PyTorch learning rate scheduler.
All built-in PyTorch learning rate schedulers are registered according to their class name (e.g. “torch::StepLR”).
Tip
You can see a list of all available schedulers by running

```python
from tango.integrations.torch import LRScheduler

for name in sorted(LRScheduler.list_available()):
    print(name)
```

```text
torch::ChainedScheduler
torch::ConstantLR
torch::CosineAnnealingLR
...
```
Data
- class tango.integrations.torch.DataLoader(dataset, collate_fn=<tango.integrations.torch.data.ConcatTensorDictsCollator object>, sampler=None, **kwargs)
A Registrable version of a PyTorch DataLoader.
- class tango.integrations.torch.Sampler(data_source)
A Registrable version of a PyTorch Sampler.
All built-in PyTorch samplers are registered under their corresponding class name (e.g. “RandomSampler”).
- class tango.integrations.torch.DataCollator
A Registrable version of a collate_fn for a DataLoader.
Subclasses just need to implement __call__().
- default_implementation: Optional[str] = 'concat_tensor_dicts'
The default implementation is ConcatTensorDictsCollator.
- class tango.integrations.torch.ConcatTensorDictsCollator
A simple collate_fn that expects items to be dictionaries of tensors. The tensors are just concatenated together.
Tip
Registered as a DataCollator under the name “concat_tensor_dicts”.
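Conceptually, this collator turns a list of per-example dicts into one dict whose keys match the model's forward() arguments, which is why the batches can be passed as kwargs. The plain-Python analogue below uses lists where Tango combines tensors; `collate_dicts` is hypothetical, and the real collator produces tensors (with a leading batch dimension) rather than Python lists.

```python
from typing import Any, Dict, List


def collate_dicts(items: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Group values by key across items, keeping item order."""
    if not items:
        raise ValueError("Cannot collate an empty list of items.")
    keys = items[0].keys()
    for item in items[1:]:
        if item.keys() != keys:
            raise ValueError("All items must have the same keys.")
    return {key: [item[key] for item in items] for key in keys}


batch = collate_dicts([{"x": [0.1, 0.2], "y": 1.0}, {"x": [0.3, 0.4], "y": 0.0}])
print(batch)  # {'x': [[0.1, 0.2], [0.3, 0.4]], 'y': [1.0, 0.0]}
# A model would then be called as model(**batch), i.e. forward(x=..., y=...).
```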
Callbacks
- class tango.integrations.torch.TrainCallback(workspace, train_config, training_engine, dataset_dict, train_dataloader, validation_dataloader=None)
A TrainCallback is a Registrable class that can be used within TorchTrainStep to customize behavior in the training loop. You can set the training callbacks with the callbacks parameter to TorchTrainStep.
Tip
All of the parameters to this base class will be automatically set within the training loop, so you shouldn’t include them in your config for your callbacks.
Tip
You can access the model being trained through self.model.
Important
The step argument to callback methods is the total/overall number of training steps so far, independent of the current epoch.
See also
See WandbTrainCallback for an example implementation.
- Variables:
workspace (Workspace) – The tango workspace being used.
train_config (TrainConfig) – The training config.
training_engine (TrainingEngine) – The engine used to train the model.
dataset_dict (DatasetDictBase) – The dataset dict containing train and optional validation splits.
train_dataloader (DataLoader) – The dataloader used for the training split.
validation_dataloader (DataLoader) – Optional dataloader used for the validation split.
- property is_local_main_process: bool
This is True if the current worker is the main distributed worker of the current node, or if we are not using distributed training.
- state_dict()
Return any state that needs to be kept after a restart.
Some callbacks need to maintain state across restarts. This is the callback’s opportunity to save its state. It will be restored using load_state_dict().
- load_state_dict(state_dict)
Load the state on a restart.
Some callbacks need to maintain state across restarts. This is the callback’s opportunity to restore its state. It gets saved using state_dict().
- Return type:
- pre_train_loop()
Called right before the first batch is processed, or after a restart.
- Return type:
- post_train_loop(step, epoch)
Called after the training loop completes.
This is the last method that is called, so any cleanup can be done in this method.
- Return type:
- pre_epoch(step, epoch)
Called right before the start of an epoch. Epochs start at 0.
- Return type:
- post_epoch(step, epoch)
Called after an epoch is completed. Epochs start at 0.
- Return type:
- pre_batch(step, epoch, batch)
Called directly before processing a batch.
- Return type: None
Note
The batch argument is a list because with gradient accumulation there will be more than one “micro batch” in the batch.
- post_batch(step, epoch, batch_loss, batch_outputs)
Called directly after processing a batch, but before unscaling gradients, clipping gradients, and taking an optimizer step.
- Return type: None
Note
The batch_loss here is the loss local to the current worker, not the overall (average) batch loss across distributed workers.
If you need the average loss, use log_batch().
Note
The batch_outputs argument is a list because with gradient accumulation there will be more than one “micro batch” in the batch.
- log_batch(step, epoch, batch_loss, batch_outputs)
Called after the optimizer step. Here batch_loss is the average loss across all distributed workers.
- Return type: None
Note
This callback method is not necessarily called on every step. The frequency depends on the value of the log_every parameter of TorchTrainStep.
Note
The batch_outputs argument is a list because with gradient accumulation there will be more than one “micro batch” in the batch.
- pre_val_batch(step, val_step, epoch, val_batch)
Called right before a validation batch is processed.
- Return type:
- post_val_batch(step, val_step, epoch, val_batch_outputs)
Called right after a validation batch is processed with the outputs of the batch.
- Return type: None
Tip
This method can be used to modify val_batch_outputs in place, which is useful in scenarios like distributed training where you might need to aggregate metrics in a special way other than a simple average. If that’s the case, make sure to set auto_aggregate_val_metric to False in TorchTrainStep.
- class tango.integrations.torch.EvalCallback(workspace, step_id, work_dir, model, dataset_dict, dataloader)
An EvalCallback is a Registrable class that can be used within TorchEvalStep to customize the behavior of the evaluation loop, similar to how TrainCallback is used to customize the behavior of the training loop.
Tip
All of the parameters to this base class will be automatically set within the evaluation loop, so you shouldn’t include them in your config for your callbacks.
- Variables:
workspace (Workspace) – The tango workspace being used.
step_id (str) – The unique ID of the step.
work_dir (Path) – The working directory of the step.
model (Model) – The model being evaluated.
dataset_dict (DatasetDictBase) – The dataset dict containing the evaluation split.
dataloader (DataLoader) – The data loader used to load the evaluation split data.
- post_eval_loop(aggregated_metrics)
Called after the evaluation loop completes with the final aggregated metrics.
This is the last method that is called, so any cleanup can be done in this method.
- Return type:
- post_batch(step, batch_outputs)
Called directly after processing a batch with the outputs of the batch.
- Return type: None
Tip
This method can be used to modify batch_outputs in place, which is useful in scenarios where you might need to aggregate metrics in a special way other than a simple average. If that’s the case, make sure to set auto_aggregate_metrics to False in TorchEvalStep.
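One way to follow this tip for a count-based metric, sketched in plain Python, is to accumulate counts in post_batch() and compute the metric once in post_eval_loop(). `F1Accumulator` is a hypothetical class that only mirrors those two method names and signatures of EvalCallback. It assumes the model returns per-batch "tp", "fp", and "fn" counts in its output dict, which Tango does not require. A real implementation would subclass EvalCallback and be registered, you would set auto_aggregate_metrics to False, and how you report the final value (logging, saving to the workspace) is up to you.

```python
from typing import Any, Dict, Optional


class F1Accumulator:
    """Hypothetical stand-in mirroring EvalCallback.post_batch / post_eval_loop."""

    def __init__(self) -> None:
        self.tp = self.fp = self.fn = 0
        self.f1: Optional[float] = None

    def post_batch(self, step: int, batch_outputs: Dict[str, Any]) -> None:
        # Assumes the model returns integer counts for this batch.
        self.tp += batch_outputs["tp"]
        self.fp += batch_outputs["fp"]
        self.fn += batch_outputs["fn"]

    def post_eval_loop(self, aggregated_metrics: Dict[str, float]) -> None:
        # Compute F1 once from the summed counts instead of averaging per-batch scores.
        denom = 2 * self.tp + self.fp + self.fn
        self.f1 = 2 * self.tp / denom if denom else 0.0


cb = F1Accumulator()
for step, outputs in enumerate([{"tp": 1, "fp": 0, "fn": 0}, {"tp": 0, "fp": 5, "fn": 5}]):
    cb.post_batch(step, outputs)
cb.post_eval_loop({})
print(cb.f1)  # 1/6, about 0.167
```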
- class tango.integrations.torch.StopEarlyCallback(*args, patience=10000, **kwargs)
A TrainCallback for early stopping. Training is stopped early after patience steps without an improvement to the validation metric.
Tip
Registered as a TrainCallback under the name “torch::stop_early”.
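The patience rule can be sketched in plain Python. `PatienceTracker` is hypothetical and is not Tango's StopEarlyCallback implementation. It assumes that one validation metric value arrives per validation run, that patience is counted in training steps (as the description above states), and that "after patience steps" means stopping once the gap since the best value reaches patience. It also exposes state_dict()/load_state_dict() so the count can survive a restart, as the TrainCallback reference suggests for stateful callbacks.

```python
from typing import Any, Dict, Optional


class PatienceTracker:
    """Hypothetical sketch of patience-based early stopping, counted in training steps."""

    def __init__(self, patience: int, minimize: bool = True) -> None:
        self.patience = patience
        self.minimize = minimize
        self.best_value: Optional[float] = None
        self.best_step = 0

    def should_stop(self, step: int, value: float) -> bool:
        improved = self.best_value is None or (
            value < self.best_value if self.minimize else value > self.best_value
        )
        if improved:
            self.best_value, self.best_step = value, step
        return step - self.best_step >= self.patience

    def state_dict(self) -> Dict[str, Any]:
        return {"best_value": self.best_value, "best_step": self.best_step}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.best_value = state_dict["best_value"]
        self.best_step = state_dict["best_step"]


tracker = PatienceTracker(patience=20)
history = [(10, 0.9), (20, 0.7), (30, 0.8), (40, 0.75)]  # (step, validation loss)
print([tracker.should_stop(step, loss) for step, loss in history])  # [False, False, False, True]
```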