🔥 PyTorch#

Important

To use this integration you should install tango with the “torch” extra (e.g. pip install tango[torch]) or just install PyTorch after the fact.

Make sure you install the correct version of torch given your operating system and supported CUDA version. Check pytorch.org/get-started/locally/ for more details.

Components for Tango integration with PyTorch.

These include a training loop Step and registrable versions of many torch classes, such as torch.optim.Optimizer and torch.utils.data.DataLoader.

Example: training a model#

Let’s look at a simple example of training a model.

We’ll make a basic regression model and generate some fake data to train on. First, the setup:

import torch
import torch.nn as nn

from tango.common.dataset_dict import DatasetDict
from tango.step import Step
from tango.integrations.torch import Model

Now let’s build and register our model:

@Model.register("basic_regression")
class BasicRegression(Model):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 1)
        self.sigmoid = nn.Sigmoid()
        self.mse = nn.MSELoss()

    def forward(self, x, y=None):
        pred = self.sigmoid(self.linear(x))
        out = {"pred": pred}
        if y is not None:
            out["loss"] = self.mse(pred, y)
        return out

    def _to_params(self):
        return {}

Lastly, we’ll need a step to generate data:

@Step.register("generate_data")
class GenerateData(Step):
    DETERMINISTIC = True
    CACHEABLE = False

    def run(self) -> DatasetDict:
        torch.manual_seed(1)
        return DatasetDict(
            {
                "train": [{"x": torch.rand(10), "y": torch.rand(1)} for _ in range(64)],
                "validation": [{"x": torch.rand(10), "y": torch.rand(1)} for _ in range(32)],
            }
        )

You could then run this experiment with a config that looks like this:

{
    "steps": {
        "data": {
            "type": "generate_data",
        },
        "train": {
            "type": "torch::train",
            "model": {
                "type": "basic_regression",
            },
            "training_engine": {
                "optimizer": {
                    "type": "torch::Adam",
                },
            },
            "dataset_dict": {
                "type": "ref",
                "ref": "data"
            },
            "train_dataloader": {
                "batch_size": 8,
                "shuffle": true
            },
            "validation_split": "validation",
            "validation_dataloader": {
                "batch_size": 8,
                "shuffle": false
            },
            "train_steps": 100,
            "validate_every": 10,
            "checkpoint_every": 10,
            "log_every": 1
        }
    }
}

For example,

tango run train.jsonnet -i my_package -d /tmp/train

would produce the following output:

Starting new run boss-alien
● Starting step "data" (needed by "train")...
✓ Finished step "data"
● Starting step "train"...
✓ Finished step "train"
✓ Finished run boss-alien
...

Tips#

Debugging#

When debugging a training loop that’s causing errors on a GPU, you should set the environment variable CUDA_LAUNCH_BLOCKING=1. This will ensure that the stack trace shows where the error actually happened.
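
If training is launched from a Python script rather than from the shell, one way to apply this is to set the variable before CUDA is initialized. The snippet below is only a sketch and is not part of Tango. It assumes it runs before torch (or anything else that initializes CUDA) is imported.

```python
import os

# Set the variable before importing torch (or anything else that
# initializes CUDA); setting it afterwards may have no effect.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# import torch  # import torch only after the variable is set
launch_blocking = os.environ["CUDA_LAUNCH_BLOCKING"]
```

From a shell, the equivalent is to prefix the command, e.g. CUDA_LAUNCH_BLOCKING=1 tango run train.jsonnet -i my_package -d /tmp/train.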

You could also use a custom TrainCallback to log each batch before it is passed into the model so that you can see the exact inputs that are causing the issue.

Stopping early#

You can stop the “torch::train” step early using a custom TrainCallback. Your callback just needs to raise the StopEarly exception.
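
The mechanism can be sketched without Tango installed. In the sketch below, StopEarly is a local stand-in for tango.integrations.torch.StopEarly, PatienceCallback is a hypothetical class that only mimics the post_val_loop() signature of TrainCallback, and toy_train_loop is an illustrative loop, not Tango’s training loop. It also assumes the validation metric is minimized.

```python
class StopEarly(Exception):
    """Local stand-in for tango.integrations.torch.StopEarly."""


class PatienceCallback:
    """Hypothetical callback shaped like a TrainCallback (not a subclass here)."""

    def __init__(self, patience: int) -> None:
        self.patience = patience
        self.best_step = 0
        self.best_metric = None

    def post_val_loop(self, step, epoch, val_metric, best_val_metric):
        # Assumes the metric is minimized (e.g. a loss).
        if self.best_metric is None or val_metric < self.best_metric:
            self.best_metric = val_metric
            self.best_step = step
        elif step - self.best_step >= self.patience:
            # No improvement for `patience` steps: ask the loop to stop.
            raise StopEarly


def toy_train_loop(callback, val_metrics, validate_every=10):
    """Toy loop that validates once every `validate_every` steps."""
    step = 0
    try:
        for val_metric in val_metrics:
            step += validate_every
            callback.post_val_loop(step, 0, val_metric, callback.best_metric)
    except StopEarly:
        return step  # training ended early, without an error
    return step


stopped_at = toy_train_loop(PatienceCallback(patience=20), [0.9, 0.5, 0.6, 0.7, 0.4])
```

Here the best metric is seen at step 20, so the callback raises at step 40 and the last validation value is never reached. In a real callback you would subclass TrainCallback and raise the StopEarly exception from tango.integrations.torch instead.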

Reference#

Train step#

class tango.integrations.torch.TorchTrainStep(step_name=None, cache_results=None, step_format=None, step_config=None, step_unique_id_override=None, step_resources=None, step_metadata=None, step_extra_dependencies=None, **kwargs)[source]#

A PyTorch training loop step that supports gradient accumulation, distributed training, and AMP, with configurable dataloaders, callbacks, optimizer, and LR scheduler.

Tip

Registered as a Step under the name “torch::train”.

Important

The training loop will use GPU(s) automatically when available, as long as at least device_count CUDA devices are available.

Distributed data parallel training is activated when the device_count is greater than 1.

You can control which CUDA devices to use with the environment variable CUDA_VISIBLE_DEVICES. For example, to only use the GPUs with IDs 0 and 1, set CUDA_VISIBLE_DEVICES=0,1 (and device_count to 2).
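
As a quick sanity check before launching a run, the relationship between CUDA_VISIBLE_DEVICES and device_count can be sketched in plain Python. Both helper functions below are hypothetical and are not part of Tango or PyTorch. They only inspect the string value of the variable and do not query the GPUs themselves.

```python
from typing import List, Optional


def visible_device_ids(value: Optional[str]) -> Optional[List[str]]:
    """Parse a CUDA_VISIBLE_DEVICES value; None means the variable is unset."""
    if value is None:
        return None
    return [part.strip() for part in value.split(",") if part.strip()]


def check_device_count(device_count: int, value: Optional[str]) -> bool:
    """Return True if `device_count` does not exceed the listed devices."""
    ids = visible_device_ids(value)
    # When the variable is unset we cannot tell from the environment alone.
    return True if ids is None else device_count <= len(ids)


ok = check_device_count(2, "0,1")        # two devices listed, two requested
too_many = check_device_count(4, "0,1")  # four requested, only two listed
```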

Warning

During validation, the validation metric (specified by the val_metric_name parameter) is aggregated by simply averaging across validation batches and distributed processes. This behavior is usually correct when your validation metric is “loss” or “accuracy”, for example, but may not be correct for other metrics like “F1”.

If this is not correct for your metric you will need to handle the aggregation internally in your model or with a TrainCallback using the TrainCallback.post_val_batch() method. Then set the parameter auto_aggregate_val_metric to False.

Note that correctly aggregating your metric during distributed training will involve distributed communication.
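
To see why simple averaging can be wrong for a metric like F1, consider the following plain-Python sketch. The per-batch counts are made up for illustration, and the f1 helper is not part of Tango.

```python
def f1(tp: int, fp: int, fn: int) -> float:
    """F1 from counts; defined as 0.0 when there are no positives at all."""
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


# Hypothetical (tp, fp, fn) counts for two validation batches.
batches = [(1, 0, 1), (8, 2, 0)]

# What simple averaging would report: the mean of per-batch F1 scores.
mean_of_batch_f1 = sum(f1(*b) for b in batches) / len(batches)

# What is usually wanted: F1 computed from counts summed over all batches.
total_tp, total_fp, total_fn = (sum(col) for col in zip(*batches))
global_f1 = f1(total_tp, total_fp, total_fn)
```

The two values differ (7/9 versus 6/7 with these counts), because F1 is not linear in the underlying counts. Handling the aggregation yourself typically means accumulating the counts per batch and computing the metric once from the totals.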

run(model, training_engine, dataset_dict, train_dataloader, *, train_split='train', validation_split=None, validation_dataloader=None, seed=42, train_steps=None, train_epochs=None, validation_steps=None, grad_accum=1, log_every=10, checkpoint_every=100, validate_every=None, device_count=1, distributed_port=54761, val_metric_name='loss', minimize_val_metric=True, auto_aggregate_val_metric=True, callbacks=None, remove_stale_checkpoints=True)[source]#

Run a basic training loop to train the model.

Parameters
  • model (Union[Lazy[Model], Model]) – The model to train. It should return a dict that includes the loss during training and the val_metric_name during validation.

  • training_engine (Lazy[TrainingEngine]) – The TrainingEngine to use to train the model.

  • dataset_dict (DatasetDictBase) – The train and optional validation data.

  • train_dataloader (Lazy[DataLoader]) – The data loader that generates training batches. The batches should be dict objects that will be used as kwargs for the model’s forward() method.

  • train_split (str, default: 'train') – The name of the data split used for training in the dataset_dict. Default is “train”.

  • validation_split (Optional[str], default: None) – Optional name of the validation split in the dataset_dict. Default is None, which means no validation.

  • validation_dataloader (Optional[Lazy[DataLoader]], default: None) – An optional data loader for generating validation batches. The batches should be dict objects. If not specified, but validation_split is given, the validation DataLoader will be constructed from the same parameters as the train DataLoader.

  • seed (int, default: 42) – Used to set the RNG states at the beginning of training.

  • train_steps (Optional[int], default: None) – The number of steps to train for. If not specified, training will stop after a complete iteration through the train_dataloader.

  • train_epochs (Optional[int], default: None) – The number of epochs to train for. You cannot specify train_steps and train_epochs at the same time.

  • validation_steps (Optional[int], default: None) – The number of steps to validate for. If not specified, validation will stop after a complete iteration through the validation_dataloader.

  • grad_accum (int, default: 1) –

    The number of gradient accumulation steps. Defaults to 1.

    Note

    This parameter, in conjunction with the settings of your data loader and the number of distributed workers, determines the effective batch size of your training run.

  • log_every (int, default: 10) – Log every this many steps.

  • checkpoint_every (int, default: 100) – Save a checkpoint every this many steps.

  • validate_every (Optional[int], default: None) – Run the validation loop every this many steps.

  • device_count (int, default: 1) – The number of devices to train on, i.e. the number of distributed data parallel workers.

  • distributed_port (int, default: 54761) – The port of the distributed process group. Default is 54761.

  • val_metric_name (str, default: 'loss') – The name of the validation metric, i.e. the key of the metric in the dictionary returned by the forward pass of the model. Default is “loss”.

  • minimize_val_metric (bool, default: True) – Whether the validation metric is meant to be minimized (such as the loss). Default is True. When using a metric such as accuracy, you should set this to False.

  • auto_aggregate_val_metric (bool, default: True) – If True (the default), the validation metric will be averaged across validation batches and distributed processes. This may not be the correct behavior for some metrics (such as F1), in which case you should set this to False and handle the aggregation internally in your model or with a TrainCallback (using TrainCallback.post_val_batch()).

  • callbacks (Optional[List[Lazy[TrainCallback]]], default: None) – A list of TrainCallback.

  • remove_stale_checkpoints (bool, default: True) – If True (the default), stale checkpoints will be removed throughout training so that only the latest and best checkpoints are kept.
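
The note on grad_accum above can be made concrete with a little arithmetic. The helper below is a hypothetical illustration, not a Tango function. It assumes that the data loader’s batch_size is the per-worker micro-batch size and that every worker processes grad_accum micro-batches per optimizer step.

```python
def effective_batch_size(batch_size: int, grad_accum: int = 1, device_count: int = 1) -> int:
    """Number of examples contributing to one optimizer step (under the stated assumptions)."""
    return batch_size * grad_accum * device_count


# One device, no accumulation: the effective batch size is just batch_size.
single = effective_batch_size(batch_size=8)

# Two devices with four accumulation steps each.
accumulated = effective_batch_size(batch_size=8, grad_accum=4, device_count=2)
```

With these numbers, single is 8 and accumulated is 64.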

Return type

Model

Returns

The trained model on CPU with the weights from the best checkpoint loaded.

CACHEABLE: Optional[bool] = True#

This provides a direct way to turn off caching. For example, a step that reads a HuggingFace dataset doesn’t need to be cached, because HuggingFace datasets already have their own caching mechanism. But it’s still a deterministic step, and all following steps are allowed to cache. If it is None, the step figures out by itself whether it should be cacheable or not.

DETERMINISTIC: bool = True#

This describes whether this step can be relied upon to produce the same results every time when given the same inputs. If this is False, you can still cache the output of the step, but the results might be unexpected. Tango will print a warning in this case.

FORMAT: Format = <tango.integrations.torch.format.TorchFormat object>#

This specifies the format the results of this step will be serialized in. See the documentation for Format for details.

METADATA: Dict[str, Any] = {'artifact_kind': 'model'}#

Arbitrary metadata about the step.

SKIP_ID_ARGUMENTS: Set[str] = {'distributed_port', 'log_every'}#

If your run() method takes some arguments that don’t affect the results, list them here. Arguments listed here will not be used to calculate this step’s unique ID, and thus changing those arguments does not invalidate the cache.

For example, you might use this for the batch size in an inference step, where you only care about the model output, not about how many outputs you can produce at the same time.
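
The idea behind SKIP_ID_ARGUMENTS can be sketched as follows. This is not Tango’s actual hashing scheme. The function sketch_unique_id and its use of JSON plus SHA-256 are assumptions made purely to illustrate how leaving some arguments out of the hash keeps the ID stable.

```python
import hashlib
import json
from typing import Any, Dict, Set


def sketch_unique_id(step_name: str, kwargs: Dict[str, Any], skip: Set[str]) -> str:
    """Hash the step's arguments, leaving out the names listed in `skip`."""
    relevant = {k: v for k, v in kwargs.items() if k not in skip}
    payload = json.dumps({"step": step_name, "kwargs": relevant}, sort_keys=True)
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()[:16]


skip = {"distributed_port", "log_every"}
id_a = sketch_unique_id("train", {"train_steps": 100, "log_every": 1}, skip)
id_b = sketch_unique_id("train", {"train_steps": 100, "log_every": 50}, skip)
id_c = sketch_unique_id("train", {"train_steps": 200, "log_every": 1}, skip)
```

Changing log_every leaves the ID unchanged (id_a equals id_b), so a cached result would still be found, while changing train_steps produces a different ID (id_c).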

property resources: StepResources#

Defines the minimum compute resources required to run this step. Certain executors require this information in order to allocate resources for each step.

You can set this with the step_resources argument to Step or you can override this method to automatically define the required resources.

Return type

StepResources

class tango.integrations.torch.TrainConfig(step_id, work_dir, step_name=None, worker_id=0, train_split='train', validation_split=None, seed=42, train_steps=None, train_epochs=None, validation_steps=None, grad_accum=1, log_every=10, checkpoint_every=100, validate_every=None, is_distributed=False, devices=None, distributed_address='127.0.0.1', distributed_port=54761, val_metric_name='loss', minimize_val_metric=True, auto_aggregate_val_metric=True, remove_stale_checkpoints=True, world_size=1, _worker_local_default_device=None, _device_type=None)[source]#

Encapsulates the parameters of TorchTrainStep. This is used to pass all the training options to TrainCallback.

auto_aggregate_val_metric: bool = True#

Controls automatic aggregation of validation metric.

property best_state_path: Path#

The path to the best state checkpoint file according to the validation metric or training loss (if no validation split is given).

Return type

Path

checkpoint_every: int = 100#

Controls the frequency of checkpoints, in number of optimizer steps.

devices: Optional[List[int]] = None#

The devices used (for distributed jobs).

distributed_address: str = '127.0.0.1'#

The IP address of the main distributed process.

distributed_port: int = 54761#

The port of the main distributed process.

grad_accum: int = 1#

The number of micro-batches per gradient accumulation mini-batch.

is_distributed: bool = False#

Whether or not the training job is distributed.

property is_local_main_process: bool#

Whether the local process is the main distributed worker.

Return type

bool

log_every: int = 10#

Controls the frequency of log updates, in number of optimizer steps.

minimize_val_metric: bool = True#

Should be True when the validation metric being tracked should be minimized.

remove_stale_checkpoints: bool = True#

Controls removal of stale checkpoints.

seed: int = 42#

The random seed.

property state_path: Path#

The path to the latest state checkpoint file.

Return type

Path

step_id: str#

The unique ID of the current step.

step_name: Optional[str] = None#

The name of the current step.

Note

The same step can be run under different names.

train_epochs: Optional[int] = None#

The number of epochs to train for.

You cannot specify train_steps and train_epochs at the same time.

train_split: str = 'train'#

The name of the training split.

train_steps: Optional[int] = None#

The number of steps to train for.

val_metric_name: str = 'loss'#

The name of the validation metric to track.

validate_every: Optional[int] = None#

Controls the frequency of the validation loop, in number of optimizer steps.

validation_split: Optional[str] = None#

The name of the validation split.

validation_steps: Optional[int] = None#

The number of validation steps.

The default is to validate on the entire validation set.

work_dir: Path#

The working directory for the training run.

worker_id: int = 0#

The ID of the distributed worker.

property worker_local_default_device: device#

The default torch device for the current worker.

Return type

device

world_size: int = 1#

The number of distributed workers.

Eval step#

class tango.integrations.torch.TorchEvalStep(step_name=None, cache_results=None, step_format=None, step_config=None, step_unique_id_override=None, step_resources=None, step_metadata=None, step_extra_dependencies=None, **kwargs)[source]#

A PyTorch evaluation loop that pairs well with TorchTrainStep.

Tip

Registered as a Step under the name “torch::eval”.

Important

The evaluation loop will use a GPU automatically if one is available. You can control which GPU it uses with the environment variable CUDA_VISIBLE_DEVICES. For example, set CUDA_VISIBLE_DEVICES=1 to force TorchEvalStep to only use the GPU with ID 1.

Warning

By default the metrics specified by the metric_names parameter are aggregated by simply averaging across batches. This behavior is usually correct for metrics like “loss” or “accuracy”, for example, but may not be correct for other metrics like “F1”.

If this is not correct for your metric you will need to handle the aggregation internally in your model or with an EvalCallback using the EvalCallback.post_batch() method. Then set the parameter auto_aggregate_metrics to False.

run(model, dataset_dict, dataloader, test_split='test', seed=42, eval_steps=None, log_every=1, metric_names=('loss',), auto_aggregate_metrics=True, callbacks=None)[source]#

Evaluate the model.

Parameters
  • model (Model) – The model to evaluate. It should return a dict from its forward() method that includes all of the metrics in metric_names.

  • dataset_dict (DatasetDictBase) – Should contain the test data.

  • dataloader (Lazy[DataLoader]) – The data loader that generates test batches. The batches should be dict objects.

  • test_split (str, default: 'test') – The name of the data split used for evaluation in the dataset_dict. Default is “test”.

  • seed (int, default: 42) – Used to set the RNG states at the beginning of the evaluation loop.

  • eval_steps (Optional[int], default: None) – The number of steps to evaluate for. If not specified, evaluation will stop after a complete iteration through the dataloader.

  • log_every (int, default: 1) – Log every this many steps. Default is 1.

  • metric_names (Sequence[str], default: ('loss',)) – The names of the metrics to track and aggregate. Default is ("loss",).

  • auto_aggregate_metrics (bool, default: True) – If True (the default), the metrics will be averaged across batches. This may not be the correct behavior for some metrics (such as F1), in which case you should set this to False and handle the aggregation internally in your model or with an EvalCallback (using EvalCallback.post_batch()).

  • callbacks (Optional[List[Lazy[EvalCallback]]], default: None) – A list of EvalCallback.

Return type

Dict[str, float]

CACHEABLE: Optional[bool] = True#

This provides a direct way to turn off caching. For example, a step that reads a HuggingFace dataset doesn’t need to be cached, because HuggingFace datasets already have their own caching mechanism. But it’s still a deterministic step, and all following steps are allowed to cache. If it is None, the step figures out by itself whether it should be cacheable or not.

DETERMINISTIC: bool = True#

This describes whether this step can be relied upon to produce the same results every time when given the same inputs. If this is False, you can still cache the output of the step, but the results might be unexpected. Tango will print a warning in this case.

FORMAT: Format = <tango.format.JsonFormat object>#

This specifies the format the results of this step will be serialized in. See the documentation for Format for details.

SKIP_ID_ARGUMENTS: Set[str] = {'log_every'}#

If your run() method takes some arguments that don’t affect the results, list them here. Arguments listed here will not be used to calculate this step’s unique ID, and thus changing those arguments does not invalidate the cache.

For example, you might use this for the batch size in an inference step, where you only care about the model output, not about how many outputs you can produce at the same time.

property resources: StepResources#

Defines the minimum compute resources required to run this step. Certain executors require this information in order to allocate resources for each step.

You can set this with the step_resources argument to Step or you can override this method to automatically define the required resources.

Return type

StepResources

Torch format#

class tango.integrations.torch.TorchFormat(*args, **kwds)[source]#

This format writes the artifact using torch.save().

Unlike tango.format.DillFormat, this has no special support for iterators.

Tip

Registered as a Format under the name “torch”.

Model#

class tango.integrations.torch.Model[source]#

This is a Registrable mixin class that inherits from torch.nn.Module. Its forward() method should return a dict that includes the loss during training and any tracked metrics during validation.

TrainingEngine#

class tango.integrations.torch.TrainingEngine(train_config, model, optimizer, *, lr_scheduler=None)[source]#

A TrainingEngine defines and drives the strategy for training a model in TorchTrainStep.

Variables
  • train_config (TrainConfig) – The training config.

  • model (Model) – The model being trained.

  • optimizer (Optimizer) – The optimizer being used to train the model.

  • lr_scheduler (LRScheduler) – The optional learning rate scheduler.

abstract backward(loss)[source]#

Run a backwards pass on the model. This will always be called after forward_train().

Return type

None

abstract forward_eval(batch)[source]#

Run a forward evaluation pass on the model.

Return type

Dict[str, Any]

abstract forward_train(micro_batch, micro_batch_idx, num_micro_batches)[source]#

Run a forward training pass on the model.

Return type

Tuple[Tensor, Dict[str, Any]]

abstract load_checkpoint(checkpoint_dir)[source]#

Load a checkpoint to resume training. Should return the same client_state saved in save_checkpoint().

Return type

Dict[str, Any]

abstract save_checkpoint(checkpoint_dir, client_state)[source]#

Save a training checkpoint with model state, optimizer state, etc., as well as the arbitrary client_state to the given checkpoint_dir.

Return type

None

abstract save_complete_weights_from_checkpoint(checkpoint_dir, weights_path)[source]#

Gather the final weights from the best checkpoint and save to the file at weights_path.

Return type

None

abstract step()[source]#

Take an optimization step. This will always be called after backward().

Return type

None

default_implementation: Optional[str] = 'torch'#

The default implementation is TorchTrainingEngine.

class tango.integrations.torch.TorchTrainingEngine(train_config, model, optimizer, *, lr_scheduler=None, amp=False, max_grad_norm=None, amp_use_bfloat16=None)[source]#

This train engine only uses native PyTorch functionality to provide vanilla distributed data parallel training and AMP.

Tip

Registered as a TrainingEngine under the name “torch”.

Important

Only the parameters listed below should be defined in a configuration file. The other parameters will be automatically passed to the constructor within TorchTrainStep.

Parameters
  • amp (bool, default: False) – Use automatic mixed precision. Default is False.

  • max_grad_norm (Optional[float], default: None) – If set, gradients will be clipped to have this max norm. Default is None.

  • amp_use_bfloat16 (Optional[bool], default: None) – Set to True to force using the bfloat16 datatype in mixed precision training. Only applicable when amp=True. If not specified, the default behavior will be to use bfloat16 when training with AMP on CPU, otherwise not.
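
The default behavior described for amp_use_bfloat16 can be written out as a small decision function. This is a sketch, not the engine’s code. The name sketch_amp_dtype is hypothetical, and the use of float16 as the non-bfloat16 choice is an assumption based on PyTorch’s usual autocast default on CUDA.

```python
from typing import Optional


def sketch_amp_dtype(
    amp: bool, device_type: str, amp_use_bfloat16: Optional[bool] = None
) -> Optional[str]:
    """Return the name of the reduced-precision dtype, or None when AMP is off."""
    if not amp:
        return None  # amp_use_bfloat16 only applies when amp=True
    if amp_use_bfloat16 is None:
        # Unspecified: bfloat16 on CPU, otherwise not.
        amp_use_bfloat16 = device_type == "cpu"
    return "bfloat16" if amp_use_bfloat16 else "float16"


cpu_default = sketch_amp_dtype(amp=True, device_type="cpu")
gpu_default = sketch_amp_dtype(amp=True, device_type="cuda")
gpu_forced = sketch_amp_dtype(amp=True, device_type="cuda", amp_use_bfloat16=True)
```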

Optim#

class tango.integrations.torch.Optimizer(params, defaults)[source]#

A Registrable version of a PyTorch torch.optim.Optimizer.

All built-in PyTorch optimizers are registered according to their class name (e.g. “torch::Adam”).

Tip

You can see a list of all available optimizers by running

from tango.integrations.torch import Optimizer
for name in sorted(Optimizer.list_available()):
    print(name)
torch::ASGD
torch::Adadelta
torch::Adagrad
torch::Adam
torch::AdamW
...

class tango.integrations.torch.LRScheduler(optimizer, last_epoch=-1, verbose=False)[source]#

A Registrable version of a PyTorch learning rate scheduler.

All built-in PyTorch learning rate schedulers are registered according to their class name (e.g. “torch::StepLR”).

Tip

You can see a list of all available schedulers by running

from tango.integrations.torch import LRScheduler
for name in sorted(LRScheduler.list_available()):
    print(name)
torch::ChainedScheduler
torch::ConstantLR
torch::CosineAnnealingLR
...

Data#

class tango.integrations.torch.DataLoader(dataset, collate_fn=<tango.integrations.torch.data.ConcatTensorDictsCollator object>, sampler=None, **kwargs)[source]#

A Registrable version of a PyTorch DataLoader.

class tango.integrations.torch.Sampler(data_source)[source]#

A Registrable version of a PyTorch Sampler.

All built-in PyTorch samplers are registered under their corresponding class name (e.g. “RandomSampler”).

class tango.integrations.torch.DataCollator(*args, **kwds)[source]#

A Registrable version of a collate_fn for a DataLoader.

Subclasses just need to implement __call__().

__call__(items)[source]#

Takes a list of items from a dataset and combines them into a batch.

Return type

Dict[str, Any]

default_implementation: Optional[str] = 'concat_tensor_dicts'#

The default implementation is ConcatTensorDictsCollator.

class tango.integrations.torch.ConcatTensorDictsCollator(*args, **kwds)[source]#

A simple collate_fn that expects items to be dictionaries of tensors. The tensors are just concatenated together.

Tip

Registered as a DataCollator under the name “concat_tensor_dicts”.

Callbacks#

class tango.integrations.torch.TrainCallback(workspace, train_config, training_engine, dataset_dict, train_dataloader, validation_dataloader=None)[source]#

A TrainCallback is a Registrable class that can be used within TorchTrainStep to customize behavior in the training loop. You can set the training callbacks with the callbacks parameter to TorchTrainStep.

Tip

All of the parameters to this base class will be automatically set within the training loop, so you shouldn’t include them in your config for your callbacks.

Tip

You can access the model being trained through self.model.

Important

The step argument to callback methods is the total/overall number of training steps so far, independent of the current epoch.

See also

See WandbTrainCallback for an example implementation.

Variables
  • workspace (Workspace) – The tango workspace being used.

  • train_config (TrainConfig) – The training config.

  • training_engine (TrainingEngine) – The engine used to train the model.

  • dataset_dict (DatasetDictBase) – The dataset dict containing train and optional validation splits.

  • train_dataloader (DataLoader) – The dataloader used for the training split.

  • validation_dataloader (DataLoader) – Optional dataloader used for the validation split.

property step_id: str#

The unique ID of the current Step.

Return type

str

property step_name: Optional[str]#

The name of the current Step.

Return type

Optional[str]

property work_dir: Path#

The working directory of the current train step.

Return type

Path

property is_local_main_process: bool#

This is True if the current worker is the main distributed worker of the current node, or if we are not using distributed training.

Return type

bool

property model: Model#

The Model being trained.

Return type

Model

state_dict()[source]#

Return any state that needs to be kept after a restart.

Some callbacks need to maintain state across restarts. This is the callback’s opportunity to save its state. It will be restored using load_state_dict().

Return type

Dict[str, Any]

load_state_dict(state_dict)[source]#

Load the state on a restart.

Some callbacks need to maintain state across restarts. This is the callback’s opportunity to restore its state. It gets saved using state_dict().

Return type

None

pre_train_loop()[source]#

Called right before the first batch is processed, or after a restart.

Return type

None

post_train_loop(step, epoch)[source]#

Called after the training loop completes.

This is the last method that is called, so any cleanup can be done in this method.

Return type

None

pre_epoch(step, epoch)[source]#

Called right before the start of an epoch. Epochs start at 0.

Return type

None

post_epoch(step, epoch)[source]#

Called after an epoch is completed. Epochs start at 0.

Return type

None

pre_batch(step, epoch, batch)[source]#

Called directly before processing a batch.

Note

The type of batch is a list because with gradient accumulation there will be more than one “micro batch” in the batch.

Return type

None
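
To illustrate why batch is a list, the sketch below groups a stream of micro-batches into lists of length grad_accum. The function group_micro_batches is hypothetical and is not how Tango is known to build its batches. In particular, yielding a shorter trailing group is an assumption of this sketch.

```python
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def group_micro_batches(micro_batches: Iterable[T], grad_accum: int) -> Iterator[List[T]]:
    """Yield lists of up to `grad_accum` micro-batches, one list per optimizer step."""
    group: List[T] = []
    for micro_batch in micro_batches:
        group.append(micro_batch)
        if len(group) == grad_accum:
            yield group
            group = []
    if group:
        yield group  # trailing partial group (an assumption of this sketch)


# Strings stand in for the dict objects a real data loader would produce.
batches = list(group_micro_batches(["mb0", "mb1", "mb2", "mb3"], grad_accum=2))
```

With grad_accum=2, each element of batches is a list of two micro-batches, which is the shape of the batch argument described above.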

post_batch(step, epoch, batch_loss, batch_outputs)[source]#

Called directly after processing a batch, but before unscaling gradients, clipping gradients, and taking an optimizer step.

Note

The batch_loss here is the loss local to the current worker, not the overall (average) batch loss across distributed workers.

If you need the average loss, use log_batch().

Note

The type of batch_outputs is a list because with gradient accumulation there will be more than one “micro batch” in the batch.

Return type

None

log_batch(step, epoch, batch_loss, batch_outputs)[source]#

Called after the optimizer step. Here batch_loss is the average loss across all distributed workers.

Note

This callback method is not necessarily called on every step. The frequency depends on the value of the log_every parameter of TorchTrainStep.

Note

The type of batch_outputs is a list because with gradient accumulation there will be more than one “micro batch” in the batch.

Return type

None

pre_val_batch(step, val_step, epoch, val_batch)[source]#

Called right before a validation batch is processed.

Return type

None

post_val_batch(step, val_step, epoch, val_batch_outputs)[source]#

Called right after a validation batch is processed with the outputs of the batch.

Tip

This method can be used to modify val_batch_outputs in place, which is useful in scenarios like distributed training where you might need to aggregate metrics in a special way other than a simple average. If that’s the case, make sure to set auto_aggregate_val_metric to False in TorchTrainStep.

Return type

None

post_val_loop(step, epoch, val_metric, best_val_metric)[source]#

Called right after the validation loop finishes.

Return type

None

class tango.integrations.torch.EvalCallback(workspace, step_id, work_dir, model, dataset_dict, dataloader)[source]#

An EvalCallback is a Registrable class that can be used within TorchEvalStep to customize the behavior of the evaluation loop, similar to how TrainCallback is used to customize the behavior of the training loop.

Tip

All of the parameters to this base class will be automatically set within the evaluation loop, so you shouldn’t include them in your config for your callbacks.

Variables
  • workspace (Workspace) – The tango workspace being used.

  • step_id (str) – The unique ID of the step.

  • work_dir (Path) – The working directory of the step.

  • model (Model) – The model being evaluated.

  • dataset_dict (DatasetDictBase) – The dataset dict containing the evaluation split.

  • dataloader (DataLoader) – The data loader used to load the evaluation split data.

pre_eval_loop()[source]#

Called right before the first batch is processed.

Return type

None

post_eval_loop(aggregated_metrics)[source]#

Called after the evaluation loop completes with the final aggregated metrics.

This is the last method that is called, so any cleanup can be done in this method.

Return type

None

pre_batch(step, batch)[source]#

Called directly before processing a batch.

Return type

None

post_batch(step, batch_outputs)[source]#

Called directly after processing a batch with the outputs of the batch.

Tip

This method can be used to modify batch_outputs in place, which is useful in scenarios where you might need to aggregate metrics in a special way other than a simple average. If that’s the case, make sure to set auto_aggregate_metrics to False in TorchEvalStep.

Return type

None

class tango.integrations.torch.StopEarlyCallback(*args, patience=10000, **kwargs)[source]#

A TrainCallback for early stopping. Training is stopped early after patience steps without an improvement to the validation metric.

Tip

Registered as a TrainCallback under the name “torch::stop_early”.

class tango.integrations.torch.StopEarly[source]#

Callbacks can raise this exception to stop training early without crashing.

Important

During distributed training all workers must raise this exception at the same point in the training loop, otherwise there will be a deadlock.