Source code for tango.integrations.wandb.torch_train_callback

from typing import Any, Dict, List, Optional

import torch
import wandb

from tango.common.exceptions import ConfigurationError
from tango.integrations.torch.train_callback import TrainCallback
from tango.integrations.torch.util import peak_gpu_memory

from .util import check_environment
from .workspace import WandbWorkspace


@TrainCallback.register("wandb::log")
class WandbTrainCallback(TrainCallback):
    """
    A torch :class:`~tango.integrations.torch.TrainCallback` for use with
    the :class:`~tango.integrations.torch.TorchTrainStep` that logs training and
    validation metrics to W&B.

    This can be used with any :class:`~tango.workspace.Workspace` implementation,
    including :class:`WandbWorkspace`.

    .. tip::

        Registered as a :class:`~tango.integrations.torch.TrainCallback`
        under the name "wandb::log".

    .. important::

        When this callback is used with the :class:`WandbWorkspace` it will log metrics
        to the same W&B project that the workspace uses. The ``group`` and ``name``
        parameters will also automatically be set, so a
        :class:`~tango.common.exceptions.ConfigurationError` will be raised if any of
        ``project``, ``entity``, ``group``, or ``name`` are set in this callback.
        The same error is raised if a W&B run is already active when the callback
        is constructed.

    :param project:
        W&B project to associate this run with.
    :param entity:
        W&B entity (user or organization) to associate this run with.
    :param group:
        W&B group to associate this run with. If not set, the ID of the step is used.
    :param name:
        Set the name of the run in W&B. If not set, the default will be the name of the step
        (or ``"train"`` when the step has no name).
    :param notes:
        Arbitrary notes to add in W&B to this run. If not set, a default note naming the
        step and worker is generated.
    :param tags:
        Arbitrary tags to add in W&B to this run.
    :param watch_model:
        If ``True``, ``wandb.watch()`` is called to collect gradients and other information
        about the model throughout training.
        See `docs.wandb.ai/ref/python/watch <https://docs.wandb.ai/ref/python/watch>`_.
    :param wandb_config:
        Arbitrary configuration fields to set in W&B for this run. These are merged on top
        of the training config.
        See `docs.wandb.ai/guides/track/config <https://docs.wandb.ai/guides/track/config>`_.
    """

    def __init__(
        self,
        *args,
        project: Optional[str] = None,
        entity: Optional[str] = None,
        group: Optional[str] = None,
        name: Optional[str] = None,
        notes: Optional[str] = None,
        tags: Optional[List[str]] = None,
        watch_model: bool = False,
        wandb_config: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        """
        Validate the W&B settings and precompute everything ``wandb.init()`` will need.

        No W&B run is started here; that happens in :meth:`pre_train_loop`.

        :raises ConfigurationError:
            If ``project``, ``entity``, ``group`` or ``name`` is given while the workspace
            is a :class:`WandbWorkspace` or while a W&B run is already active.
        """
        super().__init__(*args, **kwargs)

        if self.is_local_main_process:
            check_environment()

        # These settings are owned by whoever created the active run (or by the
        # WandbWorkspace), so overriding them from the callback is rejected outright.
        if isinstance(self.workspace, WandbWorkspace) or wandb.run is not None:
            err_msg_template = "Cannot set '{var_name}' in WandbTrainCallback "
            if isinstance(self.workspace, WandbWorkspace):
                err_msg_template += "since it has already been set from the WandbWorkspace."
            else:
                err_msg_template += "since a W&B run has already been initialized."
            for var, var_name in [
                (project, "project"),
                (entity, "entity"),
                (group, "group"),
                (name, "name"),
            ]:
                if var is not None:
                    raise ConfigurationError(err_msg_template.format(var_name=var_name))

        self.project = (
            project if not isinstance(self.workspace, WandbWorkspace) else self.workspace.project
        )
        self.entity = (
            entity if not isinstance(self.workspace, WandbWorkspace) else self.workspace.entity
        )
        self.group = group or self.step_id
        self.notes = notes or self._get_default_notes()
        self.tags = tags
        self.watch_model = watch_model

        # Start from the training config, then layer the user's fields on top.
        self.wandb_config = self.train_config.as_dict()
        # The worker ID is dropped from the logged config; tolerate it being absent.
        self.wandb_config.pop("worker_id", None)
        if wandb_config is not None:
            self.wandb_config.update(wandb_config)
        if wandb.run is None:
            self.wandb_config["job_type"] = "train_metrics"

        self.run_name: str = name or self.step_name or "train"
        if self.train_config.is_distributed:
            self.run_name += f" (rank {self.train_config.worker_id})"

        # Reuse the active run's ID if there is one, otherwise derive a per-worker ID
        # from the step ID.
        self.run_id: str = (
            wandb.run.id  # type: ignore[attr-defined]
            if wandb.run is not None
            else self.step_id + f"-rank{self.train_config.worker_id}"
        )

        # Forwarded to ``wandb.init(resume=...)``; switched to "allow" by
        # ``load_state_dict()`` when training restarts from a checkpoint.
        self.resume: Optional[str] = None

        # If we have to start our own W&B run, we need to finish it.
        self.should_finalize_run: bool = wandb.run is None

    def state_dict(self) -> Dict[str, Any]:
        """Return the callback state to checkpoint. Nothing needs to be saved."""
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """
        Mark the W&B run as resumable.

        The contents of ``state_dict`` are ignored; the call itself is taken as the
        signal that training is restarting, so ``resume`` is set to ``"allow"``.
        """
        self.resume = "allow"

    def pre_train_loop(self) -> None:
        """
        Start a W&B run (or update the active one) and log the initial metrics.

        When no run is active, a new one is initialized with the settings computed in
        ``__init__``. Otherwise the active run's config, tags and notes are updated.
        """
        if wandb.run is None:
            if self.run_id is None:
                self.run_id = self.step_id + f"-rank{self.train_config.worker_id}"
            # Initialize a new W&B run.
            wandb.init(
                id=self.run_id,
                # Without this the flag set by ``load_state_dict()`` is never used and
                # a restarted job would not continue its earlier run.
                resume=self.resume,
                dir=str(self.work_dir),
                project=self.project,
                entity=self.entity,
                group=self.group,
                name=self.run_name,
                notes=self.notes,
                config=self.wandb_config,
                tags=self.tags,
                job_type="train_metrics",
            )
        else:
            # We are already running inside of a W&B run, possibly because
            # we're using the WandbWorkspace.
            wandb.config.update(self.wandb_config)
            if self.tags:
                # Append to the run's existing tags rather than replacing them.
                wandb.run.tags = (wandb.run.tags or tuple()) + tuple(self.tags)
            if self.notes:
                wandb.run.notes = self.notes

        if self.watch_model:
            wandb.watch(self.training_engine.model)

        # Log GPU memory statistics.
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
        # Called on every worker, not just the main one.
        # NOTE(review): presumably this gathers stats across workers - confirm.
        peak_gpu_mbs = peak_gpu_memory()
        if self.is_local_main_process:
            metrics = {f"sys/worker{rank}_peak_gpu_mem": mbs for rank, mbs in peak_gpu_mbs.items()}
            metrics["epoch"] = 0
            # NOTE(review): on a resumed run W&B may drop this step-0 entry because the
            # run's step has already advanced - confirm that this is acceptable.
            wandb.log(metrics, step=0)

    def post_train_loop(self, step: int, epoch: int) -> None:
        """Finish the W&B run, but only if this callback was the one that started it."""
        if self.should_finalize_run:
            wandb.finish()

    def log_batch(
        self, step: int, epoch: int, batch_loss: float, batch_outputs: List[Dict[str, Any]]
    ) -> None:
        """
        Log the batch loss, learning rate, epoch and peak GPU memory to W&B.

        Only the local main process logs. The learning rate is read from the
        optimizer's first parameter group. ``batch_outputs`` is not used.
        """
        # Called on every worker, not just the main one.
        # NOTE(review): presumably this gathers stats across workers - confirm.
        peak_gpu_mbs = peak_gpu_memory()
        if self.is_local_main_process:
            metrics = {
                "train/loss": batch_loss,
                "train/lr": self.training_engine.optimizer.param_groups[0]["lr"],
                "epoch": epoch,
            }
            metrics.update(
                {f"sys/worker{rank}_peak_gpu_mem": mbs for rank, mbs in peak_gpu_mbs.items()}
            )
            # W&B step 0 is used by ``pre_train_loop()``, so batches are logged at step + 1.
            wandb.log(
                metrics,
                step=step + 1,
            )

    def post_val_loop(
        self, step: int, epoch: int, val_metric: float, best_val_metric: float
    ) -> None:
        """Log the latest and best validation metric to W&B from the local main process."""
        if self.is_local_main_process:
            wandb.log(
                {
                    f"val/{self.train_config.val_metric_name}": val_metric,
                    f"val/best_{self.train_config.val_metric_name}": best_val_metric,
                    "epoch": epoch,
                },
                # Same offset as ``log_batch()`` so both land on the same W&B step.
                step=step + 1,
            )

    def _get_default_notes(self) -> str:
        """
        Build the default run notes.

        The notes name the step and worker. With a :class:`WandbWorkspace` they also
        link to the step's main run.
        """
        notes = (
            f'Metrics for Tango step "{self.step_name}" from worker {self.train_config.worker_id}.'
        )
        if isinstance(self.workspace, WandbWorkspace):
            notes += f"\nMain run for step: {self.workspace.wandb_project_url}/runs/{self.step_id}/overview"
        return notes