# Source code for tango.integrations.torch.format
from pathlib import Path
from typing import Generic, TypeVar
import dill
import torch
from tango.common.aliases import PathOrStr
from tango.format import Format
T = TypeVar("T")
@Format.register("torch")
class TorchFormat(Format[T], Generic[T]):
    """
    A format that stores the artifact with ``torch.save()`` and restores it with
    ``torch.load()``, using ``dill`` as the pickle module in both directions.

    Unlike :class:`tango.format.DillFormat`, this has no special support for iterators.

    .. tip::
        Registered as a :class:`~tango.format.Format` under the name "torch".
    """

    # Written alongside every artifact; ``read()`` refuses files whose stored
    # version compares greater than this value.
    VERSION = "002"

    def write(self, artifact: T, dir: PathOrStr):
        """
        Serialize ``artifact`` into the file ``data.pt`` inside ``dir``.

        The file holds the tuple ``(VERSION, artifact)`` so that ``read()`` can
        check the version before handing the artifact back.

        :param artifact: The object to persist.
        :param dir: Directory in which ``data.pt`` is created.
        """
        target = Path(dir).joinpath("data.pt")
        with target.open("wb") as sink:
            versioned_payload = (self.VERSION, artifact)
            torch.save(versioned_payload, sink, pickle_module=dill)

    def read(self, dir: PathOrStr) -> T:
        """
        Load the artifact previously stored in ``data.pt`` inside ``dir``.

        Tensors are mapped onto the CPU device while loading.

        .. note::
            Loading unpickles the file through ``dill``; only read directories
            whose contents you trust.

        :param dir: Directory containing ``data.pt``.
        :returns: The artifact that was stored in the file.
        :raises ValueError: If the version stored in the file compares greater
            than ``VERSION``.
        """
        filename = Path(dir).joinpath("data.pt")
        with filename.open("rb") as source:
            cpu = torch.device("cpu")
            loaded = torch.load(source, pickle_module=dill, map_location=cpu)
            stored_version, payload = loaded
            # The file was produced by a newer format version than this class knows.
            if stored_version > self.VERSION:
                raise ValueError(
                    f"File {filename} is too recent for this version of {self.__class__}."
                )
            return payload