"""Flax checkpoint format for Tango (``tango.integrations.flax.format``)."""

from pathlib import Path
from typing import Generic, TypeVar

from flax.training import checkpoints

from tango.common.aliases import PathOrStr
from tango.format import Format

T = TypeVar("T")


@Format.register("flax")
class FlaxFormat(Format[T], Generic[T]):
    """
    This format writes the artifact as a Flax checkpoint using
    :func:`flax.training.checkpoints.save_checkpoint` and reads it back with
    :func:`flax.training.checkpoints.restore_checkpoint`.

    The artifact is stored as a single checkpoint at step ``0``.

    .. note::
        Reading restores without a target structure (``target=None``), so the
        value returned by :meth:`read` is a plain dict of the saved state,
        not an instance of the artifact's original type.

    .. tip::
        Registered as a :class:`~tango.format.Format` under the name "flax".
    """

    VERSION = "002"

    def write(self, artifact: T, dir: PathOrStr) -> None:
        """Save ``artifact`` into ``dir`` as a Flax checkpoint at step 0.

        :param artifact: The object (e.g. a pytree / train state) to save.
        :param dir: Directory the checkpoint files are written into.
        """
        # Normalise to an absolute Path so that write and read agree on the
        # location. NOTE(review): orbax-backed flax checkpointing is believed
        # to reject relative paths; confirm against the pinned flax version.
        ckpt_dir = Path(dir).resolve()
        checkpoints.save_checkpoint(ckpt_dir, artifact, step=0)

    def read(self, dir: PathOrStr) -> T:
        """Restore the checkpoint previously written into ``dir``.

        :param dir: Directory that :meth:`write` saved the checkpoint into.
        :returns: The restored state. Because ``target=None`` is passed, this
            is a dict rather than the original artifact type.
        """
        # Same normalisation as in ``write`` so both sides resolve ``dir``
        # identically.
        ckpt_dir = Path(dir).resolve()
        # will return a dict
        return checkpoints.restore_checkpoint(ckpt_dir, target=None)