Flax#
Reference#
Train step#
- class tango.integrations.flax.FlaxTrainStep(step_name=None, cache_results=None, step_format=None, step_config=None, step_unique_id_override=None, step_resources=None, step_metadata=None, step_extra_dependencies=None, **kwargs)[source]#
A Flax training step that supports distributed training with configurable dataloaders, callbacks, and optimizers.
Tip
Registered as a Step under the name "flax::train".
Important
To train on GPUs and TPUs, installation of jax[cuda] or jax[tpu] is required. Follow the instructions at https://github.com/google/jax to set up jax for GPUs and TPUs. Note: CUDA and cuDNN installation is required to run jax on NVIDIA GPUs. It is recommended to install cuDNN in your conda environment using:
conda install -c anaconda cudnn
Distributed data parallel training is activated when the device_count is greater than 1. You can control which CUDA devices to use with the environment variable CUDA_VISIBLE_DEVICES. For example, to only use the GPUs with IDs 0 and 1, set CUDA_VISIBLE_DEVICES=0,1 (and device_count to 2).
Warning
During validation, the validation metric (specified by the val_metric_name parameter) is aggregated by simply averaging across validation batches and distributed processes. This behavior is usually correct when your validation metric is "loss" or "accuracy", for example, but may not be correct for other metrics like "F1". If this is not correct for your metric, you will need to handle the aggregation internally in your model or with a TrainCallback using the TrainCallback.post_val_batch() method. Then set the parameter auto_aggregate_val_metric to False.
Jax pre-allocates 90% of GPU memory. If you run into out-of-memory (OOM) issues, please refer to: https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html.
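For example, a minimal sketch of selecting devices and relaxing JAX's memory pre-allocation before a run; the values are illustrative, and the XLA_* variables are the ones described in the linked JAX memory-allocation page:

import os

# Set these before importing jax; the specific values below are only examples.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"             # use GPUs 0 and 1 (device_count=2)
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # disable the 90% pre-allocation
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50" # or cap the pre-allocated fraction instead

import jax
print(jax.device_count())  # distributed data parallel training activates when this is > 1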
- run(model, dataset, optimizer, train_dataloader, *, wrapper, seed=42, keep_checkpoints=5, lr_scheduler=None, train_split='train', validation_dataloader=None, validation_split=None, train_steps=None, train_epoch=None, validation_steps=None, log_every=10, checkpoint_every=100, validate_every=None, val_metric_name='loss', minimize_val_metric=True, auto_aggregate_val_metric=True, callbacks=None, remove_stale_checkpoints=True)[source]#
Run a basic training loop to train the model.
- Parameters:
model (Model) – The flax model to train. It should define __call__(). Defining setup() is optional.
dataset (DatasetDictBase) – The train and optional validation dataset.
optimizer (Lazy[Optimizer]) – The name of the optax Optimizer to use for training.
train_dataloader (Lazy[FlaxDataLoader]) – The dataloader object that generates training batches.
wrapper (FlaxWrapper) – A Wrapper class that defines loss_fn(), eval_fn() and compute_metrics().
seed (int, default: 42) – Used to set the PRNG state. By default, seed=42.
keep_checkpoints (int, default: 5) – An integer which denotes how many previous checkpoints should be stored while training. By default, keep_checkpoints=5.
lr_scheduler (Optional[Lazy[LRScheduler]], default: None) – The name of the learning rate scheduler.
train_split (str, default: 'train') – The name of the data split used for training in the dataset_dict. Default is "train".
validation_dataloader (Optional[Lazy[FlaxDataLoader]], default: None) – An optional data loader for generating validation batches. The batches should be dict objects. If not specified, but validation_split is given, the validation DataLoader will be constructed from the same parameters as the train DataLoader.
validation_split (Optional[str], default: None) – Optional name of the validation split in the dataset_dict. Default is None, which means no validation.
train_steps (Optional[int], default: None) – The number of steps to train for. If not specified, training will stop after a complete iteration through the train_dataloader.
train_epoch (Optional[int], default: None) – The number of epochs to train for. You cannot specify train_steps and train_epoch at the same time.
validation_steps (Optional[int], default: None) – The number of steps to validate for. If not specified, validation will stop after a complete iteration through the validation_dataloader.
log_every (int, default: 10) – Log every this many steps.
checkpoint_every (int, default: 100) – Save a checkpoint every this many steps.
validate_every (Optional[int], default: None) – Run the validation loop every this many steps.
val_metric_name (str, default: 'loss') – The name of the validation metric, i.e. the key of the metric in the dictionary returned by the forward pass of the model. Default is "loss".
minimize_val_metric (bool, default: True) – Whether the validation metric is meant to be minimized (such as the loss). Default is True. When using a metric such as accuracy, you should set this to False.
auto_aggregate_val_metric (bool, default: True) – If True (the default), the validation metric will be averaged across validation batches and distributed processes. This may not be the correct behavior for some metrics (such as F1), in which case you should set this to False and handle the aggregation internally in your model or with a TrainCallback (using TrainCallback.post_val_batch()).
callbacks (Optional[List[Lazy[TrainCallback]]], default: None) – A list of TrainCallback.
remove_stale_checkpoints (bool, default: True) – If True (the default), stale checkpoints will be removed throughout training so that only the latest and best checkpoints are kept.
- Return type:
- Returns:
The trained model with the last checkpoint loaded.
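As a rough illustration only, a "flax::train" step configuration might be sketched as the Python dict below. The model and wrapper type names, the upstream step reference, the FlaxDataLoader options, and all hyperparameter values are placeholders and not part of this reference:

# Hypothetical config fragment for a "flax::train" step, expressed as a Python dict.
train_step_config = {
    "type": "flax::train",
    "model": {"type": "my_flax_model"},                   # a registered Model (placeholder name)
    "dataset": {"type": "ref", "ref": "processed_data"},  # output of an upstream step (placeholder)
    "optimizer": {"type": "optax::adam", "learning_rate": 1e-3},
    "train_dataloader": {"batch_size": 32, "shuffle": True},  # FlaxDataLoader options (assumed)
    "wrapper": {"type": "my_wrapper"},                     # a registered FlaxWrapper (placeholder name)
    "train_split": "train",
    "validation_split": "validation",
    "train_steps": 1000,
    "validate_every": 100,
    "val_metric_name": "loss",
}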
- CACHEABLE: Optional[bool] = True#
This provides a direct way to turn off caching. For example, a step that reads a HuggingFace dataset doesn’t need to be cached, because HuggingFace datasets already have their own caching mechanism. But it’s still a deterministic step, and all following steps are allowed to cache. If it is None, the step figures out by itself whether it should be cacheable or not.
- DETERMINISTIC: bool = True#
This describes whether this step can be relied upon to produce the same results every time when given the same inputs. If this is False, you can still cache the output of the step, but the results might be unexpected. Tango will print a warning in this case.
- FORMAT: Format = <tango.integrations.flax.format.FlaxFormat object>#
This specifies the format the results of this step will be serialized in. See the documentation for Format for details.
- METADATA: Dict[str, Any] = {'artifact_kind': 'model'}#
Arbitrary metadata about the step.
- SKIP_ID_ARGUMENTS: Set[str] = {'log_every'}#
If your run() method takes some arguments that don’t affect the results, list them here. Arguments listed here will not be used to calculate this step’s unique ID, and thus changing those arguments does not invalidate the cache.
For example, you might use this for the batch size in an inference step, where you only care about the model output, not about how many outputs you can produce at the same time.
- class tango.integrations.flax.TrainConfig(step_id, work_dir, step_name=None, train_split='train', validation_split=None, seed=42, train_steps=None, train_epochs=None, validation_steps=None, log_every=10, checkpoint_every=100, validate_every=None, is_distributed=False, val_metric_name='loss', minimize_val_metric=True, auto_aggregate_val_metric=True, remove_stale_checkpoints=True)[source]#
Encapsulates the parameters of FlaxTrainStep. This is used to pass all the training options to TrainCallback.
- property best_state_path: Path#
The path to the best state checkpoint file according to the validation metric or training loss (if no validation split is given).
- minimize_val_metric: bool = True#
Should be True when the validation metric being tracked should be minimized.
Eval step#
- class tango.integrations.flax.FlaxEvalStep(step_name=None, cache_results=None, step_format=None, step_config=None, step_unique_id_override=None, step_resources=None, step_metadata=None, step_extra_dependencies=None, **kwargs)[source]#
A Flax evaluation loop that pairs well with FlaxTrainStep.
Tip
Registered as a Step under the name "flax::eval".
Important
The evaluation loop will use a GPU/TPU automatically if one is available. You can control which GPU it uses with the environment variable CUDA_VISIBLE_DEVICES. For example, set CUDA_VISIBLE_DEVICES=1 to force FlaxEvalStep to only use the GPU with ID 1.
Warning
By default, the metrics specified by the metric_names parameter are aggregated by simply averaging across batches. This behavior is usually correct for metrics like "loss" or "accuracy", for example, but may not be correct for other metrics like "F1". If this is not correct for your metric, you will need to handle the aggregation internally in your model or with an EvalCallback using the EvalCallback.post_batch() method. Then set the parameter auto_aggregate_metrics to False.
- run(state, dataset, dataloader, wrapper, test_split='test', seed=42, log_every=1, do_distributed=False, eval_steps=None, metric_names=('loss',), auto_aggregate_metrics=True, callbacks=None)[source]#
Evaluate the model.
- Parameters:
state (TrainState) – The state of the model to evaluate. This contains the parameters.
dataset (DatasetDictBase) – Should contain the test data.
dataloader (Lazy[FlaxDataLoader]) – The data loader that generates test batches. The batches should be dict objects.
wrapper (FlaxWrapper) – The wrapper should define eval_metrics().
test_split (str, default: 'test') – The name of the data split used for evaluation in the dataset_dict. Default is "test".
seed (int, default: 42) – Used to set the PRNG states at the beginning of the evaluation loop.
log_every (int, default: 1) – Log every this many steps. Default is 1.
do_distributed (bool, default: False) – Whether to run evaluation in a distributed fashion (may be given as 0 or 1).
eval_steps (Optional[int], default: None) – The number of steps to evaluate for. If not specified, evaluation will stop after a complete iteration through the dataloader.
metric_names (Sequence[str], default: ('loss',)) – The names of the metrics to track and aggregate. Default is ("loss",).
auto_aggregate_metrics (bool, default: True) – If True (the default), the metrics will be averaged across batches. This may not be the correct behavior for some metrics (such as F1), in which case you should set this to False and handle the aggregation internally in your model or with an EvalCallback (using EvalCallback.post_batch()).
callbacks (Optional[List[Lazy[EvalCallback]]], default: None) – A list of EvalCallback.
- Return type:
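For illustration only, and assuming that the output of a preceding "flax::train" step (named "train" here) can be fed in as the state input, an eval step configuration might be sketched as follows; all names and values are placeholders:

# Hypothetical config fragment for a "flax::eval" step, expressed as a Python dict.
eval_step_config = {
    "type": "flax::eval",
    "state": {"type": "ref", "ref": "train"},             # assumed: the trained state from a flax::train step
    "dataset": {"type": "ref", "ref": "processed_data"},  # output of an upstream step (placeholder)
    "dataloader": {"batch_size": 32},                     # FlaxDataLoader options (assumed)
    "wrapper": {"type": "my_wrapper"},                    # a registered FlaxWrapper (placeholder name)
    "test_split": "test",
    "metric_names": ["loss"],
}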
- CACHEABLE: Optional[bool] = True#
This provides a direct way to turn off caching. For example, a step that reads a HuggingFace dataset doesn’t need to be cached, because HuggingFace datasets already have their own caching mechanism. But it’s still a deterministic step, and all following steps are allowed to cache. If it is None, the step figures out by itself whether it should be cacheable or not.
- DETERMINISTIC: bool = True#
This describes whether this step can be relied upon to produce the same results every time when given the same inputs. If this is False, you can still cache the output of the step, but the results might be unexpected. Tango will print a warning in this case.
- FORMAT: Format = <tango.format.JsonFormat object>#
This specifies the format the results of this step will be serialized in. See the documentation for Format for details.
- SKIP_ID_ARGUMENTS: Set[str] = {'log_every'}#
If your run() method takes some arguments that don’t affect the results, list them here. Arguments listed here will not be used to calculate this step’s unique ID, and thus changing those arguments does not invalidate the cache.
For example, you might use this for the batch size in an inference step, where you only care about the model output, not about how many outputs you can produce at the same time.
Flax format#
Model#
- class tango.integrations.flax.Model(parent=<flax.linen.module._Sentinel object>, name=None)[source]#
This is a Registrable mixin class that inherits from flax.linen.Module. Its setup() can be used to register submodules, variables, and parameters you will need in your model. Its __call__() returns the output of the model for a given input.
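As a minimal sketch, a custom model might be registered like this; the name "basic_regressor" and the architecture are illustrative, not part of the library:

from flax import linen as nn
from tango.integrations.flax import Model

@Model.register("basic_regressor")  # hypothetical name
class BasicRegressor(Model):
    hidden_size: int = 16

    def setup(self):
        # Register submodules here.
        self.dense1 = nn.Dense(self.hidden_size)
        self.dense2 = nn.Dense(1)

    def __call__(self, x):
        x = nn.relu(self.dense1(x))
        return self.dense2(x)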
Optim#
- class tango.integrations.flax.Optimizer(optimizer)[source]#
A Registrable version of Optax optimizers.
All built-in Optax optimizers are registered according to their class name (e.g. "optax::adam").
Tip
You can see a list of all available optimizers by running
from tango.integrations.flax import Optimizer

for name in sorted(Optimizer.list_available()):
    print(name)

optax::adabelief
optax::adafactor
optax::adagrad
optax::adam
...
- class tango.integrations.flax.LRScheduler(scheduler)[source]#
A Registrable version of an Optax learning rate scheduler.
All built-in Optax learning rate schedulers are registered according to their class name (e.g. "optax::linear_schedule").
Tip
You can see a list of all available schedulers by running
from tango.integrations.flax import LRScheduler

for name in sorted(LRScheduler.list_available()):
    print(name)

optax::constant_schedule
optax::cosine_decay_schedule
optax::cosine_onecycle_schedule
optax::exponential_decay
...
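As an illustration, an optimizer and a scheduler might be specified in a step config as the fragments below; the keys follow the corresponding optax signatures, and all values are placeholders:

# Hypothetical config fragments for the optimizer and lr_scheduler parameters.
optimizer = {"type": "optax::adam", "learning_rate": 1e-3}
lr_scheduler = {
    "type": "optax::linear_schedule",
    "init_value": 1e-3,
    "end_value": 0.0,
    "transition_steps": 1000,
}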
Data#
- class tango.integrations.flax.DataLoader(*args, **kwds)[source]#
A Registrable version of a Flax DataLoader.
The Flax DataLoader accepts a Dataset object and yields numpy batches.
Callbacks#
- class tango.integrations.flax.TrainCallback(workspace, train_config, dataset, train_dataloader, model, optimizer, validation_dataloader=None)[source]#
A TrainCallback is a Registrable class that can be used within FlaxTrainStep to customize behavior in the training loop. You can set the training callbacks with the callbacks parameter to FlaxTrainStep.
Tip
All of the parameters to this base class will be automatically set within the training loop, so you shouldn’t include them in your config for your callbacks.
Tip
You can access the model being trained through self.model.
Important
The step argument to callback methods is the total/overall number of training steps so far, independent of the current epoch.
See also
See WandbTrainCallback for an example implementation.
- Variables:
workspace (Workspace) – The tango workspace being used.
train_config (TrainConfig) – The training config.
dataset_dict (DatasetDictBase) – The dataset dict containing train and optional validation splits.
train_dataloader (DataLoader) – The dataloader used for the training split.
model (Model) – The flax model being trained.
optimizer (Optimizer) – The optimizer being used for training.
validation_dataloader (DataLoader) – Optional dataloader used for the validation split.
- state_dict()[source]#
Return any state that needs to be kept after a restart.
Some callbacks need to maintain state across restarts. This is the callback’s opportunity to save its state. It will be restored using load_state_dict().
- load_state_dict(state_dict)[source]#
Load the state on a restart.
Some callbacks need to maintain state across restarts. This is the callback’s opportunity to restore its state. It gets saved using state_dict().
- pre_train_loop()[source]#
Called right before the first batch is processed, or after a restart.
- Return type: None
- post_train_loop(step, epoch)[source]#
Called after the training loop completes.
This is the last method that is called, so any cleanup can be done in this method.
- Return type: None
- post_epoch(step, epoch)[source]#
Called after an epoch is completed. Epochs start at 0.
- Return type: None
- post_batch(step, epoch, train_metrics)[source]#
Called directly after processing a batch, but before unscaling gradients, clipping gradients, and taking an optimizer step.
- Return type: None
Note
The train_metrics here is the dictionary with train metrics of the current batch. If doing distributed training, use jax_utils.unreplicate(train_metrics) before using train_metrics.
If you need the average loss, use log_batch().
- log_batch(step, epoch, train_metrics)[source]#
Called after the optimizer step. Here train_metrics is the average of the metrics across all distributed workers. If doing distributed training, use jax_utils.unreplicate(train_metrics) before using train_metrics.
- Return type: None
Note
This callback method is not necessarily called on every step. The frequency depends on the value of the log_every parameter of FlaxTrainStep.
- pre_val_loop(step, val_step, state)[source]#
Called right before the validation loop starts.
- Return type: None
- pre_val_batch(step, val_step, epoch, val_batch)[source]#
Called right before a validation batch is processed.
- Return type: None
- post_val_batch(step, val_step, epoch, val_metrics)[source]#
Called right after a validation batch is processed with the outputs of the batch.
- Return type: None
Tip
This method can be used to modify val_metrics in place, which is useful in scenarios like distributed training where you might need to aggregate metrics in a special way other than a simple average. If that’s the case, make sure to set auto_aggregate_val_metric to False in FlaxTrainStep.
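As a minimal sketch, a custom training callback might look like the following; the registered name and the behavior are illustrative only:

from tango.integrations.flax import TrainCallback

@TrainCallback.register("print_metrics")  # hypothetical name
class PrintMetricsCallback(TrainCallback):
    def log_batch(self, step, epoch, train_metrics):
        # train_metrics here is the average of the metrics across workers (see log_batch() above).
        print(f"epoch {epoch}, step {step}: {train_metrics}")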
- class tango.integrations.flax.EvalCallback(workspace, step_id, work_dir, dataset_dict, dataloader)[source]#
An EvalCallback is a Registrable class that can be used within FlaxEvalStep to customize the behavior of the evaluation loop, similar to how TrainCallback is used to customize the behavior of the training loop.
Tip
All of the parameters to this base class will be automatically set within the evaluation loop, so you shouldn’t include them in your config for your callbacks.
- Variables:
workspace (Workspace) – The tango workspace being used.
step_id (str) – The unique ID of the step.
work_dir (Path) – The working directory of the step.
dataset_dict (DatasetDictBase) – The dataset dict containing the evaluation split.
dataloader (DataLoader) – The data loader used to load the evaluation split data.
- post_eval_loop(aggregated_metrics)[source]#
Called after the evaluation loop completes with the final aggregated metrics.
This is the last method that is called, so any cleanup can be done in this method.
- Return type: None
- post_batch(step, batch_outputs)[source]#
Called directly after processing a batch with the outputs of the batch.
- Return type: None
Tip
This method can be used to modify batch_outputs in place, which is useful in scenarios where you might need to aggregate metrics in a special way other than a simple average. If that’s the case, make sure to set auto_aggregate_metrics to False in FlaxEvalStep.
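Similarly, a minimal sketch of a custom evaluation callback; the registered name is illustrative, and it assumes batch_outputs is a dict that can be modified in place:

from tango.integrations.flax import EvalCallback

@EvalCallback.register("log_eval_outputs")  # hypothetical name
class LogEvalOutputsCallback(EvalCallback):
    def post_batch(self, step, batch_outputs):
        # batch_outputs may be modified in place here, e.g. for custom aggregation.
        print(f"eval batch {step}: {list(batch_outputs)}")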