Source code for tango.integrations.torch.data

from typing import Any, Dict, Generic, List, Optional, TypeVar, Union

import torch

from tango.common.lazy import Lazy
from tango.common.registrable import Registrable

T = TypeVar("T")


class DataCollator(Generic[T], Registrable):
    """
    A :class:`~tango.common.Registrable` version of a ``collate_fn`` for a ``DataLoader``.

    Subclasses just need to implement :meth:`__call__()`.
    """

    default_implementation = "concat_tensor_dicts"
    """
    The default implementation is :class:`ConcatTensorDictsCollator`.
    """

    def __call__(self, items: List[T]) -> Dict[str, Any]:
        """
        Takes a list of items from a dataset and combines them into a batch.
        """
        raise NotImplementedError
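
# A minimal sketch (not part of the module) of registering a custom collator.
# The name "pad_tensor_dicts" and the padding logic are hypothetical; per the
# docstring above, a subclass only needs to implement __call__(). This sketch
# assumes each field is a 1-D tensor of varying length.
#
#   @DataCollator.register("pad_tensor_dicts")
#   class PadTensorDictsCollator(DataCollator[Dict[str, Any]]):
#       def __call__(self, items: List[Dict[str, Any]]) -> Dict[str, Any]:
#           out: Dict[str, Any] = {}
#           for key in items[0]:
#               tensors = [item[key] for item in items]
#               # Pad to the longest sequence in the batch.
#               out[key] = torch.nn.utils.rnn.pad_sequence(tensors, batch_first=True)
#           return out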
@DataCollator.register("concat_tensor_dicts")
class ConcatTensorDictsCollator(DataCollator[Dict[str, Any]]):
    """
    A simple ``collate_fn`` that expects items to be dictionaries of tensors.
    The tensors are just concatenated together.

    .. tip::
        Registered as a :class:`DataCollator` under the name "concat_tensor_dicts".
    """

    def __call__(self, items: List[Dict[str, Any]]) -> Dict[str, Any]:
        out = {}
        keys = items[0].keys()
        for key in keys:
            if isinstance(items[0][key], torch.Tensor):
                out[key] = torch.cat([item[key].unsqueeze(0) for item in items])
            elif isinstance(items[0][key], (int, float)):
                out[key] = torch.tensor([item[key] for item in items])
            else:
                out[key] = [item[key] for item in items]  # type: ignore[assignment]
        return out
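
# A usage sketch (not part of the module): the field names "input_ids", "label",
# and "id" are hypothetical and only illustrate the three branches above.
#
#   collate = ConcatTensorDictsCollator()
#   batch = collate(
#       [
#           {"input_ids": torch.tensor([1, 2, 3]), "label": 0, "id": "a"},
#           {"input_ids": torch.tensor([4, 5, 6]), "label": 1, "id": "b"},
#       ]
#   )
#   # batch["input_ids"] has shape (2, 3): each tensor gets a leading batch dim.
#   # batch["label"] == tensor([0, 1]); batch["id"] == ["a", "b"].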
class Sampler(torch.utils.data.Sampler, Registrable):
    """
    A :class:`~tango.common.Registrable` version of a PyTorch
    :class:`~torch.utils.data.Sampler`.

    All `built-in PyTorch samplers
    <https://pytorch.org/docs/stable/data.html#data-loading-order-and-sampler>`_
    are registered under their corresponding class name, prefixed with "torch::"
    (e.g. "torch::RandomSampler").
    """
@Sampler.register("torch::BatchSampler")
class BatchSampler(torch.utils.data.BatchSampler, Sampler):
    def __init__(
        self,
        dataset: torch.utils.data.Dataset,
        sampler: Union[Lazy[Sampler], Sampler],
        batch_size: int,
        drop_last: bool,
    ) -> None:
        super().__init__(
            sampler.construct(data_source=dataset, dataset=dataset)
            if isinstance(sampler, Lazy)
            else sampler,
            batch_size,
            drop_last,
        )


# Register all remaining samplers.
for name, cls in torch.utils.data.__dict__.items():
    registered_name = "torch::" + name
    if (
        isinstance(cls, type)
        and issubclass(cls, torch.utils.data.Sampler)
        and not cls == torch.utils.data.Sampler
        and registered_name not in Sampler.list_available()
    ):
        Sampler.register(registered_name)(cls)
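
# A usage sketch (not part of the module), assuming the by_name() lookup provided
# by tango.common.Registrable: fetch a registered built-in sampler class by its
# "torch::"-prefixed name and instantiate it. "my_dataset" is hypothetical.
#
#   sampler_cls = Sampler.by_name("torch::RandomSampler")
#   sampler = sampler_cls(data_source=my_dataset)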
class DataLoader(torch.utils.data.DataLoader, Registrable):
    """
    A :class:`~tango.common.Registrable` version of a PyTorch
    :class:`~torch.utils.data.DataLoader`.
    """

    default_implementation = "default"

    def __init__(
        self,
        dataset: torch.utils.data.Dataset,
        collate_fn: Optional[DataCollator] = ConcatTensorDictsCollator(),
        sampler: Optional[Union[Lazy[Sampler], Sampler]] = None,
        **kwargs,
    ):
        super().__init__(
            dataset,
            collate_fn=collate_fn,
            sampler=sampler.construct(data_source=dataset, dataset=dataset)
            if isinstance(sampler, Lazy)
            else sampler,
            **kwargs,
        )


DataLoader.register("default")(DataLoader)
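
# A usage sketch (not part of the module): because DataLoader subclasses
# torch.utils.data.DataLoader, the usual keyword arguments (batch_size, shuffle,
# num_workers, ...) pass straight through, and ConcatTensorDictsCollator is the
# default collate_fn. "my_dataset" is a hypothetical map-style dataset whose
# items are dictionaries of tensors.
#
#   loader = DataLoader(my_dataset, batch_size=16, shuffle=True)
#   for batch in loader:
#       ...  # each batch is a Dict[str, Any] produced by the collator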