Flax#

Reference#

Train step#

class tango.integrations.flax.FlaxTrainStep(step_name=None, cache_results=None, step_format=None, step_config=None, step_unique_id_override=None, step_resources=None, step_metadata=None, step_extra_dependencies=None, **kwargs)[source]#

A Flax training step that supports distributed training with configurable dataloaders, callbacks, and optimizers.

Tip

Registered as a Step under the name “flax::train”.

Important

To train on GPUs or TPUs, you need to install jax[cuda] or jax[tpu]. Follow the instructions at https://github.com/google/jax to set up jax for GPUs and TPUs. Note that CUDA and cuDNN must be installed to run jax on NVIDIA GPUs; it is recommended to install cuDNN in your conda environment with conda install -c anaconda cudnn.

Distributed data parallel training is activated when the device count is greater than 1. You can control which CUDA devices to use with the environment variable CUDA_VISIBLE_DEVICES. For example, to use only the GPUs with IDs 0 and 1, set CUDA_VISIBLE_DEVICES=0,1 (which gives a device count of 2).
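As a quick sanity check (a minimal sketch, not part of the Tango API), you can verify how many devices JAX sees before launching training:

import os
# CUDA_VISIBLE_DEVICES must be set before jax initializes its backends
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

import jax
print(jax.device_count())  # data-parallel training is activated when this is > 1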

Warning

During validation, the validation metric (specified by the val_metric_name parameter) is aggregated by simply averaging across validation batches and distributed processes. This behavior is usually correct for metrics like “loss” or “accuracy”, but may not be correct for other metrics like “F1”. If it is not correct for your metric, you will need to handle the aggregation internally in your model or with a TrainCallback, using the TrainCallback.post_val_batch() method, and set the parameter auto_aggregate_val_metric to False.

JAX pre-allocates 90% of GPU memory by default. If you run into out-of-memory (OOM) issues, refer to https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html.
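For example, the environment variables described on that page can be set before JAX initializes to change the preallocation behavior (shown here as a sketch):

import os
# Disable up-front preallocation so GPU memory is allocated on demand
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# Or keep preallocation but cap the fraction of GPU memory JAX grabs up front
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.50"

import jax  # import only after the variables are set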

run(model, dataset, optimizer, train_dataloader, *, wrapper, seed=42, keep_checkpoints=5, lr_scheduler=None, train_split='train', validation_dataloader=None, validation_split=None, train_steps=None, train_epoch=None, validation_steps=None, log_every=10, checkpoint_every=100, validate_every=None, val_metric_name='loss', minimize_val_metric=True, auto_aggregate_val_metric=True, callbacks=None, remove_stale_checkpoints=True)[source]#

Run a basic training loop to train the model.

Parameters:
  • model (Model) – The Flax model to train. It should define __call__(); defining setup() is optional.

  • dataset (DatasetDictBase) – The train and optional validation dataset.

  • optimizer (Lazy[Optimizer]) – The name of the optax Optimizer to use for training.

  • train_dataloader (Lazy[FlaxDataLoader]) – The dataloader object that generates training batches.

  • wrapper (FlaxWrapper) – A wrapper class that defines loss_fn(), eval_fn(), and compute_metrics().

  • seed (int, default: 42) – Used to set the PRNG state. By default, seed=42

  • keep_checkpoints (int, default: 5) – An integer which denotes how many previous checkpoints should be stored while training. By default, keep_checkpoints=5

  • lr_scheduler (Optional[Lazy[LRScheduler]], default: None) – The name of the learning rate scheduler.

  • train_split (str, default: 'train') – The name of the data split used for training in the dataset_dict. Default is “train”.

  • validation_dataloader (Optional[Lazy[FlaxDataLoader]], default: None) – An optional data loader for generating validation batches. The batches should be dict objects. If not specified, but validation_split is given, the validation DataLoader will be constructed from the same parameters as the train DataLoader.

  • validation_split (Optional[str], default: None) – Optional name of the validation split in the dataset_dict. Default is None, which means no validation.

  • train_steps (Optional[int], default: None) – The number of steps to train for. If not specified, training will stop after a complete iteration through the train_dataloader.

  • train_epoch (Optional[int], default: None) – The number of epochs to train for. You cannot specify both train_steps and train_epoch at the same time.

  • validation_steps (Optional[int], default: None) – The number of steps to validate for. If not specified, validation will stop after a complete iteration through the validation_dataloader.

  • log_every (int, default: 10) – Log every this many steps.

  • checkpoint_every (int, default: 100) – Save a checkpoint every this many steps.

  • validate_every (Optional[int], default: None) – Run the validation loop every this many steps.

  • val_metric_name (str, default: 'loss') – The name of the validation metric, i.e. the key of the metric in the dictionary returned by the forward pass of the model. Default is “loss”.

  • minimize_val_metric (bool, default: True) – Whether the validation metric is meant to be minimized (such as the loss). Default is True. When using a metric such as accuracy, you should set this to False.

  • auto_aggregate_val_metric (bool, default: True) – If True (the default), the validation metric will be averaged across validation batches and distributed processes. This may not be the correct behavior for some metrics (such as F1), in which case you should set this to False and handle the aggregation internally in your model or with a TrainCallback (using TrainCallback.post_val_batch()).

  • callbacks (Optional[List[Lazy[TrainCallback]]], default: None) – A list of TrainCallback.

  • remove_stale_checkpoints (bool, default: True) – If True (the default), stale checkpoints will be removed throughout training so that only the latest and best checkpoints are kept.

Return type:

Any

Returns:

The trained model with the last checkpoint loaded.

CACHEABLE: Optional[bool] = True#

This provides a direct way to turn off caching. For example, a step that reads a HuggingFace dataset doesn’t need to be cached, because HuggingFace datasets already have their own caching mechanism. But it’s still a deterministic step, and all following steps are allowed to cache. If it is None, the step figures out by itself whether it should be cacheable or not.

DETERMINISTIC: bool = True#

This describes whether this step can be relied upon to produce the same results every time when given the same inputs. If this is False, you can still cache the output of the step, but the results might be unexpected. Tango will print a warning in this case.

FORMAT: Format = <tango.integrations.flax.format.FlaxFormat object>#

This specifies the format the results of this step will be serialized in. See the documentation for Format for details.

METADATA: Dict[str, Any] = {'artifact_kind': 'model'}#

Arbitrary metadata about the step.

SKIP_ID_ARGUMENTS: Set[str] = {'log_every'}#

If your run() method takes some arguments that don’t affect the results, list them here. Arguments listed here will not be used to calculate this step’s unique ID, and thus changing those arguments does not invalidate the cache.

For example, you might use this for the batch size in an inference step, where you only care about the model output, not about how many outputs you can produce at the same time.
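For instance, a custom step might exclude a throughput-only argument from its unique ID. The sketch below is illustrative; MyInferenceStep, its registered name, and its arguments are hypothetical and not part of this integration:

from tango import Step

@Step.register("my_inference_step")  # hypothetical step name
class MyInferenceStep(Step):
    DETERMINISTIC = True
    CACHEABLE = True
    SKIP_ID_ARGUMENTS = {"batch_size"}  # changing batch_size won't invalidate the cache

    def run(self, model_path: str, batch_size: int = 32) -> list:
        # ... run inference in batches of `batch_size` and return the outputs ...
        return []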

class tango.integrations.flax.TrainConfig(step_id, work_dir, step_name=None, train_split='train', validation_split=None, seed=42, train_steps=None, train_epochs=None, validation_steps=None, log_every=10, checkpoint_every=100, validate_every=None, is_distributed=False, val_metric_name='loss', minimize_val_metric=True, auto_aggregate_val_metric=True, remove_stale_checkpoints=True)[source]#

Encapsulates the parameters of FlaxTrainStep. This is used to pass all the training options to TrainCallback.

auto_aggregate_val_metric: bool = True#

Controls automatic aggregation of validation metric.

property best_state_path: Path#

The path to the best state checkpoint file according to the validation metric or training loss (if no validation split is given).

checkpoint_every: int = 100#

Controls the frequency of checkpoints.

is_distributed: bool = False#

Whether or not the training job is distributed.

log_every: int = 10#

Controls the frequency of log updates.

minimize_val_metric: bool = True#

Should be True when the validation metric being tracked should be minimized.

remove_stale_checkpoints: bool = True#

Controls removal of stale checkpoints.

seed: int = 42#

The random seed used to set the PRNG state.

property state_path: Path#

The path to the latest state checkpoint file.

step_id: str#

The unique ID of the current step.

step_name: Optional[str] = None#

The name of the current step.

train_epochs: Optional[int] = None#

The number of epochs to train for.

You cannot specify train_steps and train_epochs at the same time.

train_split: str = 'train'#

The name of the training split.

train_steps: Optional[int] = None#

The number of steps to train for.

val_metric_name: str = 'loss'#

The name of the validation metric to track.

validate_every: Optional[int] = None#

Controls the frequency of the validation loop.

validation_split: Optional[str] = None#

The name of the validation split.

validation_steps: Optional[int] = None#

The number of validation steps.

work_dir: Path#

The working directory for the training run.

Eval step#

class tango.integrations.flax.FlaxEvalStep(step_name=None, cache_results=None, step_format=None, step_config=None, step_unique_id_override=None, step_resources=None, step_metadata=None, step_extra_dependencies=None, **kwargs)[source]#

A Flax evaluation loop that pairs well with FlaxTrainStep.

Tip

Registered as a Step under the name “flax::eval”.

Important

The evaluation loop will use a GPU/TPU automatically if one is available. You can control which GPU it uses with the environment variable CUDA_VISIBLE_DEVICES. For example, set CUDA_VISIBLE_DEVICES=1 to force FlaxEvalStep to only use the GPU with ID 1.

Warning

By default, the metrics specified by the metric_names parameter are aggregated by simply averaging across batches. This behavior is usually correct for metrics like “loss” or “accuracy”, but may not be correct for other metrics like “F1”.

If it is not correct for your metric, you will need to handle the aggregation internally in your model or with an EvalCallback, using the EvalCallback.post_batch() method, and set the parameter auto_aggregate_metrics to False.

run(state, dataset, dataloader, wrapper, test_split='test', seed=42, log_every=1, do_distributed=False, eval_steps=None, metric_names=('loss',), auto_aggregate_metrics=True, callbacks=None)[source]#

Evaluate the model.

Parameters:
  • state (TrainState) – The state of the model to evaluate. This contains the parameters.

  • dataset (DatasetDictBase) – Should contain the test data.

  • dataloader (Lazy[FlaxDataLoader]) – The data loader that generates test batches. The batches should be dict objects.

  • wrapper (FlaxWrapper) – The wrapper should define eval_metrics().

  • test_split (str, default: 'test') – The name of the data split used for evaluation in the dataset_dict. Default is “test”.

  • seed (int, default: 42) – Used to set the PRNG states at the beginning of the evaluation loop.

  • log_every (int, default: 1) – Log every this many steps. Default is 1.

  • do_distributed (bool, default: False) – Whether to run evaluation in a distributed fashion. Default is False.

  • eval_steps (Optional[int], default: None) – The number of steps to evaluate for. If not specified, evaluation will stop after a complete iteration through the dataloader.

  • metric_names (Sequence[str], default: ('loss',)) – The names of the metrics to track and aggregate. Default is ("loss",).

  • auto_aggregate_metrics (bool, default: True) – If True (the default), the metrics will be averaged across batches. This may not be the correct behavior for some metrics (such as F1), in which case you should set this to False and handle the aggregation internally in your model or with an EvalCallback (using EvalCallback.post_batch()).

  • callbacks (Optional[List[Lazy[EvalCallback]]], default: None) – A list of EvalCallback.

Return type:

Dict[str, float]

CACHEABLE: Optional[bool] = True#

This provides a direct way to turn off caching. For example, a step that reads a HuggingFace dataset doesn’t need to be cached, because HuggingFace datasets already have their own caching mechanism. But it’s still a deterministic step, and all following steps are allowed to cache. If it is None, the step figures out by itself whether it should be cacheable or not.

DETERMINISTIC: bool = True#

This describes whether this step can be relied upon to produce the same results every time when given the same inputs. If this is False, you can still cache the output of the step, but the results might be unexpected. Tango will print a warning in this case.

FORMAT: Format = <tango.format.JsonFormat object>#

This specifies the format the results of this step will be serialized in. See the documentation for Format for details.

SKIP_ID_ARGUMENTS: Set[str] = {'log_every'}#

If your run() method takes some arguments that don’t affect the results, list them here. Arguments listed here will not be used to calculate this step’s unique ID, and thus changing those arguments does not invalidate the cache.

For example, you might use this for the batch size in an inference step, where you only care about the model output, not about how many outputs you can produce at the same time.

Flax format#

class tango.integrations.flax.FlaxFormat[source]#

This format writes Flax artifacts, such as the trained model state produced by FlaxTrainStep, to disk.

Tip

Registered as a Format under the name “flax”.

Model#

class tango.integrations.flax.Model(parent=<flax.linen.module._Sentinel object>, name=None)[source]#

This is a Registrable mixin class that inherits from flax.linen.Module. Its setup() can be used to register submodules, variables, and parameters that you will need in your model. Its __call__() returns the output of the model for a given input.
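As an illustration (a minimal sketch; the module, its fields, and its registered name are hypothetical), a model could be defined and registered like this:

import flax.linen as nn
from tango.integrations.flax import Model

@Model.register("my_mlp")  # hypothetical registered name
class MLP(Model):
    hidden_size: int = 64

    def setup(self):
        # register the submodules the model needs
        self.dense1 = nn.Dense(self.hidden_size)
        self.dense2 = nn.Dense(1)

    def __call__(self, x):
        # return the model output for a given input
        return self.dense2(nn.relu(self.dense1(x)))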

Optim#

class tango.integrations.flax.Optimizer(optimizer)[source]#

A Registrable version of Optax optimizers.

All built-in Optax optimizers are registered according to their class name (e.g. “optax::adam”).

Tip

You can see a list of all available optimizers by running

from tango.integrations.flax import Optimizer
for name in sorted(Optimizer.list_available()):
    print(name)
optax::adabelief
optax::adafactor
optax::adagrad
optax::adam
...

class tango.integrations.flax.LRScheduler(scheduler)[source]#

A Registrable version of an Optax learning rate scheduler.

All built-in Optax learning rate schedulers are registered according to their class name (e.g. “optax::linear_schedule”).

Tip

You can see a list of all available schedulers by running

from tango.integrations.flax import LRScheduler
for name in sorted(LRScheduler.list_available()):
    print(name)
optax::constant_schedule
optax::cosine_decay_schedule
optax::cosine_onecycle_schedule
optax::exponential_decay
...

Data#

class tango.integrations.flax.DataLoader[source]#

A Registrable version of a Flax DataLoader. A Flax DataLoader accepts a Dataset object and yields numpy batches.

class tango.integrations.flax.FlaxDataLoader(dataset, batch_size=8, drop_last=True, shuffle=True)[source]#
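For example, a loader can be constructed directly with the signature above (a sketch; train_data is a hypothetical dataset object, e.g. the training split of the DatasetDictBase passed to FlaxTrainStep):

from tango.integrations.flax import FlaxDataLoader

# `train_data` stands in for whatever dataset object your pipeline produces
loader = FlaxDataLoader(train_data, batch_size=32, shuffle=True, drop_last=True)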

Callbacks#

class tango.integrations.flax.TrainCallback(workspace, train_config, dataset, train_dataloader, model, optimizer, validation_dataloader=None)[source]#

A TrainCallback is a Registrable class that can be used within FlaxTrainStep to customize behavior in the training loop. You can set the training callbacks with the callbacks parameter to FlaxTrainStep.

Tip

All of the parameters to this base class will be automatically set within the training loop, so you shouldn’t include them in your config for your callbacks.

Tip

You can access the model being trained through self.model.

Important

The step argument to callback methods is the total/overall number of training steps so far, independent of the current epoch.

See also

See WandbTrainCallback for an example implementation.

Variables:
  • workspace (Workspace) – The tango workspace being used.

  • train_config (TrainConfig) – The training config.

  • dataset_dict (DatasetDictBase) – The dataset dict containing train and optional validation splits.

  • train_dataloader (DataLoader) – The dataloader used for the training split.

  • model (Model) – The flax model being trained.

  • optimizer (Optimizer) – The optimizer being used for training.

  • validation_dataloader (DataLoader) – Optional dataloader used for the validation split.

property step_id: str#

The unique ID of the current Step.

property step_name: Optional[str]#

The name of the current Step.

property work_dir: Path#

The working directory of the current train step.

state_dict()[source]#

Return any state that needs to be kept after a restart.

Some callbacks need to maintain state across restarts. This is the callback's opportunity to save its state. It will be restored using load_state_dict().

Return type:

Dict[str, Any]

load_state_dict(state_dict)[source]#

Load the state on a restart.

Some callbacks need to maintain state across restarts. This is the callback's opportunity to restore its state. It gets saved using state_dict().

pre_train_loop()[source]#

Called right before the first batch is processed, or after a restart.

Return type:

None

post_train_loop(step, epoch)[source]#

Called after the training loop completes.

This is the last method that is called, so any cleanup can be done in this method.

Return type:

None

pre_epoch(step, epoch)[source]#

Called before start of an epoch. Epochs start at 0.

Return type:

None

post_epoch(step, epoch)[source]#

Called after an epoch is completed. Epochs start at 0.

Return type:

None

pre_batch(step, epoch, batch)[source]#

Called directly before processing a batch.

Return type:

None

post_batch(step, epoch, train_metrics)[source]#

Called directly after processing a batch, but before unscaling gradients, clipping gradients, and taking an optimizer step.

Return type:

None

Note

The train_metrics here is the dictionary of training metrics for the current batch. If doing distributed training, use jax_utils.unreplicate(train_metrics) before using train_metrics.

If you need the average loss, use log_batch().

log_batch(step, epoch, train_metrics)[source]#

Called after the optimizer step. Here train_metrics contains the average metrics across all distributed workers. If doing distributed training, use jax_utils.unreplicate(train_metrics) before using train_metrics.

Return type:

None

Note

This callback method is not necessarily called on every step. The frequency depends on the value of the log_every parameter of FlaxTrainStep.

pre_val_loop(step, val_step, state)[source]#

Called right before the validation loop starts.

Return type:

None

pre_val_batch(step, val_step, epoch, val_batch)[source]#

Called right before a validation batch is processed.

Return type:

None

post_val_batch(step, val_step, epoch, val_metrics)[source]#

Called right after a validation batch is processed with the outputs of the batch.

Return type:

None

Tip

This method can be used to modify val_metrics in place, which is useful in scenarios like distributed training where you might need to aggregate metrics in a special way other than a simple average. If that’s the case, make sure to set auto_aggregate_val_metric to False in FlaxTrainStep.
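For example, a callback could accumulate per-batch counts and compute F1 itself. The sketch below is illustrative: the registered name, the metric keys ("tp", "fp", "fn", "f1"), and the assumption that the model's forward pass emits those counts are all hypothetical, and it assumes auto_aggregate_val_metric=False and val_metric_name="f1" in FlaxTrainStep.

from tango.integrations.flax import TrainCallback

@TrainCallback.register("my_f1_aggregator")  # hypothetical registered name
class F1Aggregator(TrainCallback):
    def pre_val_loop(self, step, val_step, state):
        # reset the running counts at the start of each validation loop
        self.tp, self.fp, self.fn = 0, 0, 0

    def post_val_batch(self, step, val_step, epoch, val_metrics):
        # accumulate counts emitted by the model for this batch (assumed keys)
        self.tp += val_metrics["tp"]
        self.fp += val_metrics["fp"]
        self.fn += val_metrics["fn"]
        precision = self.tp / max(self.tp + self.fp, 1)
        recall = self.tp / max(self.tp + self.fn, 1)
        # overwrite the tracked metric in place with the running F1
        val_metrics["f1"] = 2 * precision * recall / max(precision + recall, 1e-8)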

post_val_loop(step, epoch, val_metric, best_val_metric)[source]#

Called right after the validation loop finishes.

Return type:

None

class tango.integrations.flax.EvalCallback(workspace, step_id, work_dir, dataset_dict, dataloader)[source]#

An EvalCallback is a Registrable class that can be used within FlaxEvalStep to customize the behavior of the evaluation loop, similar to how TrainCallback is used to customize the behavior of the training loop.

Tip

All of the parameters to this base class will be automatically set within the evaluation loop, so you shouldn't include them in your config for your callbacks.

Variables:
  • workspace (Workspace) – The tango workspace being used.

  • step_id (str) – The unique ID of the step.

  • work_dir (Path) – The working directory of the step.

  • dataset_dict (DatasetDictBase) – The dataset dict containing the evaluation split.

  • dataloader (DataLoader) – The data loader used to load the evaluation split data.

pre_eval_loop()[source]#

Called right before the first batch is processed.

Return type:

None

post_eval_loop(aggregated_metrics)[source]#

Called after the evaluation loop completes with the final aggregated metrics.

This is the last method that is called, so any cleanup can be done in this method.

Return type:

None

pre_batch(step, batch)[source]#

Called directly before processing a batch.

Return type:

None

post_batch(step, batch_outputs)[source]#

Called directly after processing a batch with the outputs of the batch.

Return type:

None

Tip

This method can be used to modify batch_outputs in place, which is useful in scenarios where you might need to aggregate metrics in a special way other than a simple average. If that’s the case, make sure to set auto_aggregate_metrics to False in FlaxEvalStep.
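As a sketch (the registered name and the output file are hypothetical), an EvalCallback that records the final aggregated metrics could look like this:

import json
from tango.integrations.flax import EvalCallback

@EvalCallback.register("save_metrics")  # hypothetical registered name
class SaveMetricsCallback(EvalCallback):
    def post_eval_loop(self, aggregated_metrics):
        # write the final aggregated metrics into the step's working directory
        with open(self.work_dir / "metrics.json", "w") as f:
            json.dump(aggregated_metrics, f, indent=2)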