Source code for tango.integrations.flax.optim

from inspect import isfunction
from typing import Callable, Type

import optax

from tango.common.registrable import Registrable


class Optimizer(Registrable):
    """
    A :class:`~tango.common.Registrable` wrapper around Optax optimizers.

    Every `built-in Optax optimizer
    <https://optax.readthedocs.io/en/latest/api.html#>`_
    is registered under the name of its Optax function, prefixed with
    ``optax::`` (e.g. "optax::adam").

    .. tip::

        You can see a list of all available optimizers by running

        .. testcode::

            from tango.integrations.flax import Optimizer
            for name in sorted(Optimizer.list_available()):
                print(name)

        .. testoutput::
            :options: +ELLIPSIS

            optax::adabelief
            optax::adafactor
            optax::adagrad
            optax::adam
            ...
    """

    def __init__(self, optimizer: Callable) -> None:
        """
        Store the callable that builds the optimizer.

        :param optimizer: Callable invoked by :meth:`__call__` with the
            keyword arguments supplied there.
        """
        self.optimizer = optimizer

    def __call__(self, **kwargs) -> optax.GradientTransformation:
        """
        Invoke the wrapped callable with ``kwargs`` and return its result.

        :param kwargs: Keyword arguments forwarded unchanged to the wrapped
            callable.
        :returns: Whatever the wrapped callable returns.
        """
        transformation = self.optimizer(**kwargs)
        return transformation
class LRScheduler(Registrable):
    """
    A :class:`~tango.common.Registrable` wrapper around Optax learning rate
    schedulers.

    Every `built-in Optax learning rate scheduler
    <https://optax.readthedocs.io/en/latest/api.html#schedules>`_
    is registered under the name of its Optax function, prefixed with
    ``optax::`` (e.g. "optax::linear_schedule").

    .. tip::

        You can see a list of all available schedulers by running

        .. testcode::

            from tango.integrations.flax import LRScheduler
            for name in sorted(LRScheduler.list_available()):
                print(name)

        .. testoutput::
            :options: +ELLIPSIS

            optax::constant_schedule
            optax::cosine_decay_schedule
            optax::cosine_onecycle_schedule
            optax::exponential_decay
            ...
    """

    def __init__(self, scheduler: Callable) -> None:
        """
        Store the callable that builds the schedule.

        :param scheduler: Callable invoked by :meth:`__call__` with the
            keyword arguments supplied there.
        """
        self.scheduler = scheduler

    def __call__(self, **kwargs):
        """
        Invoke the wrapped callable with ``kwargs`` and return its result.

        :param kwargs: Keyword arguments forwarded unchanged to the wrapped
            callable.
        :returns: Whatever the wrapped callable returns.
        """
        schedule = self.scheduler(**kwargs)
        return schedule
def optimizer_factory(optim_method: Callable) -> Type[Callable]:
    """
    Wrap ``optim_method`` in an :class:`Optimizer`.

    :param optim_method: The callable to wrap.
    :returns: An :class:`Optimizer` *instance* holding ``optim_method``.
    """
    # NOTE(review): the return annotation says ``Type[Callable]`` although an
    # instance is returned; presumably this is kept so the result type-checks
    # when passed to ``Optimizer.register(...)`` -- confirm before changing.
    return Optimizer(optim_method)


def scheduler_factory(scheduler_method: Callable) -> Type[Callable]:
    """
    Wrap ``scheduler_method`` in an :class:`LRScheduler`.

    :param scheduler_method: The callable to wrap.
    :returns: An :class:`LRScheduler` *instance* holding ``scheduler_method``.
    """
    # NOTE(review): same annotation caveat as ``optimizer_factory`` -- an
    # instance is returned, not a class.
    return LRScheduler(scheduler_method)


# Register all optimizers.
# Only plain functions with a public name and a non-empty ``__annotations__``
# mapping are registered; everything else in the module namespace is skipped.
for name, cls in vars(optax._src.alias).items():
    if not isfunction(cls) or name.startswith("_") or not cls.__annotations__:
        continue
    factory_func = optimizer_factory(cls)
    Optimizer.register("optax::" + name)(factory_func)

# Register all learning rate schedulers.
# Same filter as above, applied to the schedule module.
for name, cls in vars(optax._src.schedule).items():
    if not isfunction(cls) or name.startswith("_") or not cls.__annotations__:
        continue
    factory_func = scheduler_factory(cls)
    LRScheduler.register("optax::" + name)(factory_func)

# TODO: Handle inject_hyperparams.
# Refer: https://optax.readthedocs.io/en/latest/api.html?highlight=inject%20hyperparam