Source code for tango.integrations.flax.data
import logging
from typing import Generic, TypeVar

import jax.random
import numpy as np
from datasets import Dataset
from flax.training.common_utils import shard

from tango.common.registrable import Registrable

T = TypeVar("T")


class DataLoader(Generic[T], Registrable):
    """
    A :class:`~tango.common.Registrable` version of a Flax ``DataLoader``.

    A ``DataLoader`` accepts a Hugging Face :class:`~datasets.Dataset` object
    and yields batches as NumPy arrays.
    """


@DataLoader.register("flax::dataloader")
class FlaxDataLoader(DataLoader):
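    """
    The default :class:`DataLoader` implementation. It optionally shuffles the
    dataset each epoch and yields fixed-size batches, dropping the last
    incomplete batch.
    """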
def __init__(
self,
dataset: Dataset,
batch_size: int = 8,
drop_last: bool = True,
shuffle: bool = True,
):
self.dataset = dataset
self.dataset_size = dataset.num_rows
self.batch_size = batch_size
self.drop_last = drop_last
if not drop_last:
raise NotImplementedError(
"With Jax you have to drop the last incomplete batch, because the batch size is compiled into the "
"model."
)
self.shuffle = shuffle
        self.logger = logging.getLogger(FlaxDataLoader.__name__)

    def __call__(self, rng: jax.random.PRNGKeyArray, do_distributed: bool):
        steps_per_epoch = self.dataset_size // self.batch_size

        if self.shuffle:
            perms = jax.random.permutation(rng, self.dataset_size)
            perms = np.asarray(perms)  # Indexing with JAX arrays is a bottleneck on TPUs.
        else:
            perms = np.arange(self.dataset_size)

        self.logger.info("Skipping last incomplete batch")
        perms = perms[: steps_per_epoch * self.batch_size]  # Skip the incomplete batch.
        perms = perms.reshape((steps_per_epoch, self.batch_size))

for perm in perms:
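            # ``perm`` is an array of row indices; indexing the Dataset with it
            # returns one batch as a dict mapping column names to column values
            # (lists, or arrays if the dataset's format is set to ``"numpy"``).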
batch = self.dataset[perm]
if do_distributed:
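                # ``shard`` reshapes each leaf to add a leading axis of size
                # ``jax.local_device_count()`` so the batch can be fed to ``pmap``.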
batch = shard(batch)
yield batch
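

# A minimal usage sketch, not part of the original module: it builds a tiny
# in-memory Hugging Face ``Dataset`` (the column names and sizes below are
# illustrative assumptions) and iterates one epoch of batches on a single device.
if __name__ == "__main__":
    dummy = Dataset.from_dict({"x": list(range(10)), "y": list(range(10))})
    loader = FlaxDataLoader(dummy, batch_size=4, shuffle=True)
    rng = jax.random.PRNGKey(0)
    for batch in loader(rng, do_distributed=False):
        # Each batch is a dict of columns with exactly ``batch_size`` rows;
        # with 10 rows and batch_size=4, the last 2 rows are dropped.
        print(batch["x"])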