# Source code for tango.format
import bz2
import dataclasses
import gzip
import importlib
import json
import logging
import lzma
from abc import abstractmethod
from os import PathLike
from pathlib import Path
from typing import (
IO,
Any,
Callable,
Dict,
Generic,
Iterable,
Iterator,
List,
Optional,
Sequence,
TypeVar,
Union,
cast,
)
import dill
from tango.common import DatasetDict, filename_is_safe
from tango.common.aliases import PathOrStr
from tango.common.exceptions import ConfigurationError
from tango.common.registrable import Registrable
from tango.common.sequences import SqliteSparseSequence
T = TypeVar("T")
class Format(Registrable, Generic[T]):
    """
    Formats write objects to directories and read them back out.

    In the context of Tango, the objects that are written by formats are usually
    the result of a :class:`~tango.step.Step`.
    """

    VERSION: str = NotImplemented
    """
    Formats can have versions. Versions are part of a step's unique signature, part of
    :attr:`~tango.step.Step.unique_id`, so when a step's format changes,
    that will cause the step to be recomputed.
    """

    # "dill" is the name that ``DillFormat`` registers itself under.
    default_implementation = "dill"

    @abstractmethod
    def write(self, artifact: T, dir: PathOrStr):
        """Writes the ``artifact`` to the directory at ``dir``."""
        raise NotImplementedError()

    @abstractmethod
    def read(self, dir: PathOrStr) -> T:
        """Reads an artifact from the directory at ``dir`` and returns it."""
        raise NotImplementedError()

    def _to_params(self) -> Dict[str, Any]:
        """
        Returns the parameters produced by the parent class, minus bookkeeping
        entries, plus a ``"type"`` entry holding this class's fully qualified name.
        """
        params = super()._to_params()
        # Neither of these is configuration: ``logger`` is created by some formats
        # in ``__init__``, and ``__orig_class__`` is attached by ``typing`` to
        # instances of subscripted generics.
        params.pop("logger", None)
        params.pop("__orig_class__", None)
        params["type"] = f"{self.__module__}.{self.__class__.__qualname__}"
        return params
# Maps every accepted value of a format's ``compress`` option to the function
# used to open files with that compression.
_OPEN_FUNCTIONS: Dict[Optional[str], Callable[[PathLike, str], IO]] = {
    # No compression: ``None`` itself and its common string spellings.
    None: open,
    "None": open,
    "none": open,
    "null": open,
    # gzip
    "gz": gzip.open,  # type: ignore
    "gzip": gzip.open,  # type: ignore
    # bzip2
    "bz": bz2.open,  # type: ignore
    "bz2": bz2.open,  # type: ignore
    "bzip": bz2.open,  # type: ignore
    "bzip2": bz2.open,  # type: ignore
    # lzma (stored with the ".xz" suffix, see below)
    "lzma": lzma.open,
}

# Filename suffix appended to an artifact's file name for each open function.
# The plain ``open`` adds no suffix.
_SUFFIXES: Dict[Callable, str] = {
    open: "",
    gzip.open: ".gz",
    bz2.open: ".bz2",
    lzma.open: ".xz",
}
def _open_compressed(filename: PathOrStr, mode: str) -> IO:
    """
    Opens ``filename`` in ``mode``, choosing the open function from the
    filename's suffix.

    The first entry of ``_SUFFIXES`` with a non-empty suffix that ``filename``
    ends with decides which function is used. If none matches, the builtin
    ``open`` is used.
    """
    path = str(filename)
    opener: Callable = next(
        (fn for fn, suffix in _SUFFIXES.items() if suffix and path.endswith(suffix)),
        open,
    )
    return opener(path, mode)
@Format.register("dill")
class DillFormat(Format[T], Generic[T]):
    """
    This format writes the artifact as a single file called "data.dill" using dill
    (a drop-in replacement for pickle). Optionally, it can compress the data.

    This is very flexible, but not always the fastest.

    .. tip::
        This format has special support for iterables. If you write an iterator, it will consume the
        iterator. If you read an iterator, it will read the iterator lazily.
    """

    VERSION = "001"

    def __init__(self, compress: Optional[str] = None):
        """
        :param compress: Name of the compression to use; must be a key of
            ``_OPEN_FUNCTIONS``. ``None`` means no compression.
        :raises ConfigurationError: If ``compress`` is not a known compression name.
        """
        if compress not in _OPEN_FUNCTIONS:
            raise ConfigurationError(f"The {compress} compression format does not exist.")
        self.compress = compress

    def write(self, artifact: T, dir: PathOrStr):
        """
        Pickles ``artifact`` into the artifact file inside ``dir``.

        The file holds, in order: the format version, a flag saying whether an
        iterator was stored, and then either each item of the iterator or the
        artifact itself.
        """
        target = self._get_artifact_path(dir)
        opener = _OPEN_FUNCTIONS[self.compress]
        with opener(target, "wb") as out:
            pickler = dill.Pickler(file=out)
            pickler.dump(self.VERSION)
            # Anything with ``__next__`` is treated as an iterator and consumed.
            is_iterator = hasattr(artifact, "__next__")
            pickler.dump(is_iterator)
            if is_iterator:
                for item in cast(Iterable, artifact):
                    pickler.dump(item)
            else:
                pickler.dump(artifact)

    def read(self, dir: PathOrStr) -> T:
        """
        Reads the artifact stored in ``dir``.

        :returns: The unpickled artifact, or a :class:`DillFormatIterator` over
            the stored items if an iterator was written.
        :raises ValueError: If the file's version is greater than :attr:`VERSION`.
        """
        source = self._get_artifact_path(dir)
        opener = _OPEN_FUNCTIONS[self.compress]
        with opener(source, "rb") as stream:
            unpickler = dill.Unpickler(file=stream)
            # Versions are strings, so this is a string comparison.
            if unpickler.load() > self.VERSION:
                raise ValueError(
                    f"File {source} is too recent for this version of {self.__class__}."
                )
            holds_iterator = unpickler.load()
            if not holds_iterator:
                return unpickler.load()
            # The iterator opens the file again on its own; this handle is
            # closed when the ``with`` block exits.
            return DillFormatIterator(source)  # type: ignore

    def _get_artifact_path(self, dir: PathOrStr) -> Path:
        """Returns the path of the artifact file in ``dir``, including the compression suffix."""
        suffix = _SUFFIXES[_OPEN_FUNCTIONS[self.compress]]
        return Path(dir) / ("data.dill" + suffix)
class DillFormatIterator(Iterator[T], Generic[T]):
    """
    An ``Iterator`` class that is used to return an iterator from :meth:`tango.format.DillFormat.read`.
    """

    def __init__(self, filename: PathOrStr):
        """
        Opens ``filename`` and consumes the header (version and iterator flag),
        leaving the stream positioned at the first stored item.

        :raises ValueError: If the file's version is greater than
            ``DillFormat.VERSION``, or if the file does not store an iterator.
        """
        self.f: Optional[IO[Any]] = _open_compressed(filename, "rb")
        try:
            self.unpickler = dill.Unpickler(self.f)
            version = self.unpickler.load()
            if version > DillFormat.VERSION:
                raise ValueError(
                    f"File {filename} is too recent for this version of {self.__class__}."
                )
            iterator = self.unpickler.load()
            if not iterator:
                raise ValueError(
                    f"Tried to open {filename} as an iterator, but it does not store an iterator."
                )
        except BaseException:
            # The caller never receives this object when construction fails, so
            # nobody else can close the handle. Close it here instead of leaking it.
            self.f.close()
            self.f = None
            raise

    def __iter__(self) -> Iterator[T]:
        return self

    def __next__(self) -> T:
        """
        Returns the next stored item. When the end of the file is reached, the
        file is closed and ``StopIteration`` is raised from then on.
        """
        if self.f is None:
            # Already exhausted and closed.
            raise StopIteration()
        try:
            return self.unpickler.load()
        except EOFError:
            self.f.close()
            self.f = None
            raise StopIteration()
@Format.register("json")
class JsonFormat(Format[T], Generic[T]):
    """This format writes the artifact as a single file in json format.
    Optionally, it can compress the data. This is very flexible, but not always the fastest.

    .. tip::
        This format has special support for iterables. If you write an iterator, it will consume the
        iterator. If you read an iterator, it will read the iterator lazily.
    """

    VERSION = "002"

    def __init__(self, compress: Optional[str] = None):
        """
        :param compress: Name of the compression to use; must be a key of
            ``_OPEN_FUNCTIONS``. ``None`` means no compression.
        :raises ConfigurationError: If ``compress`` is not a known compression name.
        """
        self.logger = logging.getLogger(self.__class__.__name__)
        if compress not in _OPEN_FUNCTIONS:
            raise ConfigurationError(f"The {compress} compression format does not exist.")
        self.compress = compress

    @staticmethod
    def _encoding_fallback(unencodable: Any):
        """
        ``default`` hook for :func:`json.dump`, called for objects the json
        module cannot encode by itself.

        Zero-dimensional torch tensors become their Python scalar. Dataclasses
        become a dict of their fields plus a ``"_dataclass"`` entry naming the
        class, which :meth:`_decoding_fallback` uses to rebuild them.

        :raises TypeError: For tensors with one or more dimensions, and for any
            other object that is not handled here.
        """
        try:
            # torch is optional; tensor support is skipped if it is not installed.
            import torch

            if isinstance(unencodable, torch.Tensor):
                if len(unencodable.shape) == 0:
                    return unencodable.item()
                else:
                    raise TypeError(
                        "Tensors must have 1 element and no dimensions to be JSON serializable."
                    )
        except ImportError:
            pass

        if dataclasses.is_dataclass(unencodable):
            result = dataclasses.asdict(unencodable)
            module = type(unencodable).__module__
            qualname = type(unencodable).__qualname__
            if module == "builtins":
                result["_dataclass"] = qualname
            else:
                result["_dataclass"] = [module, qualname]
            return result

        raise TypeError(f"Object of type {type(unencodable)} is not JSON serializable")

    @staticmethod
    def _decoding_fallback(o: Dict) -> Any:
        """
        ``object_hook`` for :func:`json.load`. Dicts carrying a ``"_dataclass"``
        entry are turned back into an instance of the named class; all other
        dicts are returned unchanged.

        Note that the ``"_dataclass"`` entry is removed from ``o`` in place.

        :raises RuntimeError: If the ``"_dataclass"`` entry is neither a string
            nor a two-element list.
        """
        if "_dataclass" in o:
            classname: Union[str, List[str]] = o.pop("_dataclass")
            if isinstance(classname, list) and len(classname) == 2:
                module, classname = classname
                constructor: Callable = importlib.import_module(module)  # type: ignore
                # Walk the qualified name so that nested classes resolve too.
                for item in classname.split("."):
                    constructor = getattr(constructor, item)
            elif isinstance(classname, str):
                # NOTE(review): bare names are looked up in this module's globals,
                # not in ``builtins`` -- confirm this is the intended namespace.
                constructor = globals()[classname]
            else:
                raise RuntimeError(f"Could not parse {classname} as the name of a dataclass.")
            return constructor(**o)
        return o

    def write(self, artifact: T, dir: PathOrStr):
        """
        Writes ``artifact`` to ``dir``.

        Objects with a ``__next__`` attribute are consumed and written one JSON
        document per line to the ``.jsonl`` file. Everything else is written as
        a single JSON document to the ``.json`` file.
        """
        open_method = _OPEN_FUNCTIONS[self.compress]
        if hasattr(artifact, "__next__"):
            filename = self._get_artifact_path(dir, iterator=True)
            with open_method(filename, "wt") as f:
                for item in cast(Iterable, artifact):
                    json.dump(item, f, default=self._encoding_fallback)
                    f.write("\n")
        else:
            filename = self._get_artifact_path(dir, iterator=False)
            with open_method(filename, "wt") as f:
                json.dump(artifact, f, default=self._encoding_fallback)

    def read(self, dir: PathOrStr) -> T:
        """
        Reads the artifact stored in ``dir``.

        If both the single-document file and the line-delimited file exist, the
        line-delimited one is ignored and a warning is logged.

        :returns: The decoded JSON document, or a :class:`JsonFormatIterator`
            if only the line-delimited file exists.
        :raises IOError: If neither file exists.
        """
        iterator_filename = self._get_artifact_path(dir, iterator=True)
        iterator_exists = iterator_filename.exists()
        non_iterator_filename = self._get_artifact_path(dir, iterator=False)
        non_iterator_exists = non_iterator_filename.exists()

        if non_iterator_exists:
            if iterator_exists:
                self.logger.warning(
                    "Both %s and %s exist. Ignoring %s.",
                    iterator_filename,
                    non_iterator_filename,
                    iterator_filename,
                )
            open_method = _OPEN_FUNCTIONS[self.compress]
            with open_method(non_iterator_filename, "rt") as f:
                return json.load(f, object_hook=self._decoding_fallback)

        if iterator_exists:
            return JsonFormatIterator(iterator_filename)  # type: ignore

        # Exceptions do not apply %-formatting to their arguments the way logging
        # calls do, so the message has to be formatted before it is raised.
        raise IOError(f"Attempting to read non-existing data from {dir}")

    def _get_artifact_path(self, dir: PathOrStr, iterator: bool = False) -> Path:
        """
        Returns the path of the artifact file in ``dir``: ``data.jsonl`` for
        iterators, ``data.json`` otherwise, plus the compression suffix.
        """
        return Path(dir) / (
            ("data.jsonl" if iterator else "data.json") + _SUFFIXES[_OPEN_FUNCTIONS[self.compress]]
        )
class JsonFormatIterator(Iterator[T], Generic[T]):
    """
    An ``Iterator`` class that is used to return an iterator from :meth:`tango.format.JsonFormat.read`.
    """

    def __init__(self, filename: PathOrStr):
        """Opens ``filename`` for reading text; the compression is picked from its suffix."""
        self.f: Optional[IO[Any]] = _open_compressed(filename, "rt")

    def __iter__(self) -> Iterator[T]:
        return self

    def __next__(self) -> T:
        """
        Decodes and returns the JSON document on the next line. Once the file is
        exhausted it is closed, and ``StopIteration`` is raised from then on.
        """
        handle = self.f
        if handle is None:
            # Already exhausted and closed.
            raise StopIteration()
        try:
            raw = handle.readline()
            if raw:
                return json.loads(raw, object_hook=JsonFormat._decoding_fallback)
        except EOFError:
            # Treated the same as reading an empty string: end of data.
            pass
        handle.close()
        self.f = None
        raise StopIteration()
@Format.register("text")
class TextFormat(Format[Union[str, Iterable[str]]]):
    """This format writes the artifact as a single file in text format.
    Optionally, it can compress the data. This is very flexible, but not always the fastest.

    This format can only write strings, or iterable of strings.

    .. tip::
        This format has special support for iterables. If you write an iterator, it will consume the
        iterator. If you read an iterator, it will read the iterator lazily.

    Be aware that if your strings contain newlines, you will read out more strings than you wrote.
    For this reason, it's often advisable to use :class:`JsonFormat` instead. With :class:`JsonFormat`,
    all special characters are escaped, strings are quoted, but it's all still human-readable.
    """

    VERSION = "001"

    def __init__(self, compress: Optional[str] = None):
        """
        :param compress: Name of the compression to use; must be a key of
            ``_OPEN_FUNCTIONS``. ``None`` means no compression.
        :raises ConfigurationError: If ``compress`` is not a known compression name.
        """
        self.logger = logging.getLogger(self.__class__.__name__)
        if compress not in _OPEN_FUNCTIONS:
            raise ConfigurationError(f"The {compress} compression format does not exist.")
        self.compress = compress

    def write(self, artifact: Union[str, Iterable[str]], dir: PathOrStr):
        """
        Writes ``artifact`` to ``dir``.

        Objects with a ``__next__`` attribute are consumed and written one item
        per line to ``texts.txt``. Anything else is passed through ``str()`` and
        written as a whole to ``text.txt``.
        """
        open_method = _OPEN_FUNCTIONS[self.compress]
        if hasattr(artifact, "__next__"):
            filename = self._get_artifact_path(dir, iterator=True)
            with open_method(filename, "wt") as f:
                for item in cast(Iterable, artifact):
                    f.write(str(item))
                    f.write("\n")
        else:
            filename = self._get_artifact_path(dir, iterator=False)
            with open_method(filename, "wt") as f:
                f.write(str(artifact))

    def read(self, dir: PathOrStr) -> Union[str, Iterable[str]]:
        """
        Reads the artifact stored in ``dir``.

        If both the single-text file and the line-per-item file exist, the
        line-per-item one is ignored and a warning is logged.

        :returns: The full text of ``text.txt``, or a :class:`TextFormatIterator`
            if only ``texts.txt`` exists.
        :raises IOError: If neither file exists.
        """
        iterator_filename = self._get_artifact_path(dir, iterator=True)
        iterator_exists = iterator_filename.exists()
        non_iterator_filename = self._get_artifact_path(dir, iterator=False)
        non_iterator_exists = non_iterator_filename.exists()

        if non_iterator_exists:
            if iterator_exists:
                self.logger.warning(
                    "Both %s and %s exist. Ignoring %s.",
                    iterator_filename,
                    non_iterator_filename,
                    iterator_filename,
                )
            open_method = _OPEN_FUNCTIONS[self.compress]
            with open_method(non_iterator_filename, "rt") as f:
                return f.read()

        if iterator_exists:
            return TextFormatIterator(iterator_filename)  # type: ignore

        # Exceptions do not apply %-formatting to their arguments the way logging
        # calls do, so the message has to be formatted before it is raised.
        raise IOError(f"Attempting to read non-existing data from {dir}")

    def _get_artifact_path(self, dir: PathOrStr, iterator: bool = False) -> Path:
        """
        Returns the path of the artifact file in ``dir``: ``texts.txt`` for
        iterators, ``text.txt`` otherwise, plus the compression suffix.
        """
        return Path(dir) / (
            ("texts.txt" if iterator else "text.txt") + _SUFFIXES[_OPEN_FUNCTIONS[self.compress]]
        )
class TextFormatIterator(Iterator[str]):
    """
    An ``Iterator`` class that is used to return an iterator from :meth:`tango.format.TextFormat.read`.
    """

    def __init__(self, filename: PathOrStr):
        """Opens ``filename`` for reading text; the compression is picked from its suffix."""
        self.f: Optional[IO[Any]] = _open_compressed(filename, "rt")

    def __iter__(self) -> Iterator[str]:
        return self

    def __next__(self) -> str:
        """
        Returns the next line without its trailing newline. Once the file is
        exhausted it is closed, and ``StopIteration`` is raised from then on.
        """
        handle = self.f
        if handle is None:
            # Already exhausted and closed.
            raise StopIteration()
        try:
            raw = handle.readline()
            if raw:
                # Only the final line of a file can lack the newline.
                return raw[:-1] if raw.endswith("\n") else raw
        except EOFError:
            # Treated the same as reading an empty string: end of data.
            pass
        handle.close()
        self.f = None
        raise StopIteration()
@Format.register("sqlite_sequence")
class SqliteSequenceFormat(Format[Sequence[T]]):
    """
    This format stores a sequence in a single file called "data.sqlite", using
    :class:`~tango.common.sequences.SqliteSparseSequence`.
    """

    VERSION = "003"

    FILENAME = "data.sqlite"

    def write(self, artifact: Sequence[T], dir: Union[str, PathLike]):
        """
        Writes ``artifact`` to the sqlite file in ``dir``, replacing any file
        that is already there.
        """
        target = Path(dir) / self.FILENAME
        try:
            target.unlink()
        except FileNotFoundError:
            # Nothing to replace.
            pass

        if isinstance(artifact, SqliteSparseSequence):
            # Already sqlite-backed: let the sequence copy itself to the target.
            artifact.copy_to(target)
        else:
            SqliteSparseSequence(target).extend(artifact)

    def read(self, dir: Union[str, PathLike]) -> Sequence[T]:
        """Opens the sqlite file in ``dir`` as a read-only sequence."""
        return SqliteSparseSequence(Path(dir) / self.FILENAME, read_only=True)
@Format.register("sqlite")
class SqliteDictFormat(Format[DatasetDict]):
    """This format works specifically on results of type :class:`~tango.common.DatasetDict`. It writes those
    datasets into Sqlite databases.

    During reading, the advantage is that the dataset can be read lazily. Reading a result that is stored
    in :class:`SqliteDictFormat` takes milliseconds. No actual reading takes place until you access individual
    instances.

    During writing, you have to take some care to take advantage of the same trick. Recall that
    :class:`~tango.DatasetDict` is basically a map, mapping split names to lists of instances. If you ensure
    that those lists of instances are of type :class:`~tango.common.sequences.SqliteSparseSequence`, then writing
    the results in :class:`SqliteDictFormat` can in many cases be instantaneous.

    Here is an example of the pattern to use to make writing fast:

    .. code-block:: Python

        @Step.register("my_step")
        class MyStep(Step[DatasetDict]):
            FORMAT: Format = SqliteDictFormat()
            VERSION = "001"

            def run(self, ...) -> DatasetDict:
                result: Dict[str, Sequence] = {}
                for split_name in my_list_of_splits:
                    output_split = SqliteSparseSequence(self.work_dir / f"{split_name}.sqlite")
                    for instance in instances:
                        output_split.append(instance)
                    result[split_name] = output_split

                metadata = {}
                return DatasetDict(result, metadata)

    Observe how for each split, we create a :class:`~tango.common.sequences.SqliteSparseSequence` in the step's
    work directory (accessible with :meth:`~tango.step.Step.work_dir`). This has the added advantage that if the
    step fails and you have to re-run it, the previous results that were already written to the
    :class:`~tango.common.sequences.SqliteSparseSequence` are still there. You could replace the inner ``for``
    loop like this to take advantage:

    .. code-block:: Python

        output_split = SqliteSparseSequence(self.work_dir / f"{split_name}.sqlite")
        for instance in instances[len(output_split):]:   # <-- here is the difference
            output_split.append(instance)
        result[split_name] = output_split

    This works because when you re-run the step, the work directory will still be there, so ``output_split`` is
    not empty when you open it.
    """

    VERSION = "003"

    def write(self, artifact: DatasetDict, dir: Union[str, PathLike]):
        """
        Writes the metadata to "metadata.dill.gz" and every split to its own
        "<split name>.sqlite" file in ``dir``, replacing files that already exist.

        :raises ValueError: If a split's file name is rejected by ``filename_is_safe``.
        """
        root = Path(dir)
        with gzip.open(root / "metadata.dill.gz", "wb") as metadata_file:
            dill.dump(artifact.metadata, metadata_file)

        for split_name, split in artifact.splits.items():
            filename = f"{split_name}.sqlite"
            if not filename_is_safe(filename):
                raise ValueError(f"{split_name} is not a valid name for a split.")

            target = root / filename
            try:
                target.unlink()
            except FileNotFoundError:
                # Nothing to replace.
                pass

            if isinstance(split, SqliteSparseSequence):
                # Already sqlite-backed: let the sequence copy itself to the target.
                split.copy_to(target)
            else:
                SqliteSparseSequence(target).extend(split)

    def read(self, dir: Union[str, PathLike]) -> DatasetDict:
        """
        Loads the metadata and opens every "*.sqlite" file in ``dir`` as a
        read-only split named after the file's stem.
        """
        root = Path(dir)
        with gzip.open(root / "metadata.dill.gz", "rb") as metadata_file:
            metadata = dill.load(metadata_file)

        splits = {}
        for db_path in root.glob("*.sqlite"):
            splits[db_path.stem] = SqliteSparseSequence(db_path, read_only=True)
        return DatasetDict(metadata=metadata, splits=splits)