# Source code for tango.integrations.torch.optim
from typing import Type
import torch
from tango.common.registrable import Registrable
class Optimizer(torch.optim.Optimizer, Registrable):
    """
    A :class:`~tango.common.Registrable` version of a PyTorch
    :class:`torch.optim.Optimizer`.

    The class adds no behavior of its own; it combines the PyTorch optimizer
    base class with ``Registrable`` so that optimizers can be looked up by
    a registered name.

    All `built-in PyTorch optimizers
    <https://pytorch.org/docs/stable/optim.html#algorithms>`_
    are registered under ``"torch::"`` followed by their class name
    (e.g. "torch::Adam").

    .. tip::

        You can see a list of all available optimizers by running

        .. testcode::

            from tango.integrations.torch import Optimizer
            for name in sorted(Optimizer.list_available()):
                print(name)

        .. testoutput::
            :options: +ELLIPSIS

            torch::ASGD
            torch::Adadelta
            torch::Adagrad
            torch::Adam
            torch::AdamW
            ...
    """
class LRScheduler(torch.optim.lr_scheduler._LRScheduler, Registrable):
    """
    A :class:`~tango.common.Registrable` version of a PyTorch learning
    rate scheduler.

    The class adds no behavior of its own; it combines PyTorch's
    ``torch.optim.lr_scheduler._LRScheduler`` base class with ``Registrable``
    so that schedulers can be looked up by a registered name.

    All `built-in PyTorch learning rate schedulers
    <https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate>`_
    are registered under ``"torch::"`` followed by their class name
    (e.g. "torch::StepLR").

    .. tip::

        You can see a list of all available schedulers by running

        .. testcode::

            from tango.integrations.torch import LRScheduler
            for name in sorted(LRScheduler.list_available()):
                print(name)

        .. testoutput::
            :options: +ELLIPSIS

            torch::ChainedScheduler
            torch::ConstantLR
            torch::CosineAnnealingLR
            ...
    """
# Register all optimizers.
# Every class found in the ``torch.optim`` namespace that derives from
# ``torch.optim.Optimizer`` (the base class itself excluded) is added to the
# ``Optimizer`` registry under "torch::" plus its attribute name there.
for name, cls in torch.optim.__dict__.items():
    if (
        isinstance(cls, type)  # skip modules, functions and other non-class entries
        and issubclass(cls, torch.optim.Optimizer)
        and cls is not torch.optim.Optimizer  # identity check: never register the base
    ):
        Optimizer.register("torch::" + name)(cls)
# Note: This is a hack. Remove after we upgrade the torch version.
# Choose the scheduler base class used to decide what gets registered below:
# prefer the public ``LRScheduler`` name and fall back to the private
# ``_LRScheduler`` when the installed torch does not expose the public one.
base_class: Type
try:
    base_class = torch.optim.lr_scheduler.LRScheduler
except AttributeError:
    base_class = torch.optim.lr_scheduler._LRScheduler
# Register all learning rate schedulers.
# Every class found in the ``torch.optim.lr_scheduler`` namespace that derives
# from ``base_class`` (``base_class`` itself excluded) is added to the
# ``LRScheduler`` registry under "torch::" plus its attribute name there.
# NOTE(review): only ``base_class`` is excluded; any other private subclass in
# that namespace would be registered too - confirm this is intended.
for name, cls in torch.optim.lr_scheduler.__dict__.items():
    if (
        isinstance(cls, type)  # skip modules, functions and other non-class entries
        and issubclass(cls, base_class)
        and cls is not base_class  # identity check: never register the base
    ):
        LRScheduler.register("torch::" + name)(cls)