import collections
import hashlib
import io
from abc import abstractmethod
from typing import Any, MutableMapping, Optional, Type
import base58
import dill
# numpy and torch are optional dependencies. Each name below is bound to the
# real class when the package imports cleanly and to ``None`` otherwise, so
# code in this module can guard its ``isinstance`` checks with ``is not None``.
#
# ``ImportError`` (the parent class of ``ModuleNotFoundError``) is caught so
# that a package which is installed but cannot be loaded, e.g. because a shared
# library is missing, is treated like an absent package instead of making this
# module itself un-importable.
ndarray: Optional[Type]
try:
    from numpy import ndarray
except ImportError:
    ndarray = None

TorchTensor: Optional[Type]
try:
    from torch import Tensor as TorchTensor
except ImportError:
    TorchTensor = None
class CustomDetHash:
    """
    Hook for controlling what :func:`det_hash()` feeds into the hash.

    Normally :func:`det_hash()` pickles an object and hashes the pickled bytes.
    Classes that need to decide for themselves what contributes to the hash
    derive from this class and implement :meth:`det_hash_object()`;
    :func:`det_hash()` then pickles the value returned by that method instead of
    the object itself. Returning ``None`` from the method restores the default
    behavior of pickling the object.

    Note that this class does not use ``ABCMeta``, so the ``@abstractmethod``
    marker below is not enforced when instantiating; a subclass that does not
    override the method fails only when the method is called.
    """

    @abstractmethod
    def det_hash_object(self) -> Any:
        """
        Produce the object that should be hashed in place of ``self``.

        :raises NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError
class DetHashFromInitParams(CustomDetHash):
    """
    Mixin that derives a class's det_hash exclusively from the arguments that
    were passed when the instance was constructed.

    The positional and keyword arguments are captured in ``__new__()`` and kept
    on the instance as ``_det_hash_object``.
    """

    # ``(args, kwargs)`` exactly as received by ``__new__()``.
    _det_hash_object: Any

    def __new__(cls, *args, **kwargs):
        # Resolve the next ``__new__`` in the MRO once and reuse it below.
        parent_new = super().__new__

        # ``object.__new__`` does not accept extra arguments here. When it is the
        # next ``__new__`` in line and the class has its own ``__init__``, only
        # ``cls`` is passed on; in every other case the arguments are forwarded.
        forward_arguments = (
            parent_new is not object.__new__ or cls.__init__ is object.__init__
        )
        if forward_arguments:
            obj = parent_new(cls, *args, **kwargs)
        else:
            obj = parent_new(cls)

        obj._det_hash_object = (args, kwargs)
        return obj

    def det_hash_object(self) -> Any:
        """
        Return the ``(args, kwargs)`` pair that was recorded when this instance
        was created. The stored pair itself is returned, not a copy.
        """
        return self._det_hash_object
class DetHashWithVersion(CustomDetHash):
    """
    Mixin that lets you change a class's det_hash by editing a static
    ``VERSION`` member of the class.

    Suppose you are training a model. Whenever you change code that is part of
    your experiment, you have to change the :attr:`~tango.step.Step.VERSION` of
    the step running that code, to tell Tango that the step has changed and must
    be re-run. But if you train the model with Tango's built-in
    :class:`~tango.integrations.torch.TorchTrainStep`, how do you change the
    version of the step? You don't: leave the step's version alone and add a
    :attr:`VERSION` to your model by deriving from this class:

    .. code-block:: Python

        class MyModel(DetHashWithVersion):
            VERSION = "001"

            def __init__(self, ...):
                ...
    """

    # ``None`` means "no version": the instance is then hashed as usual.
    VERSION: Optional[str] = None

    def det_hash_object(self) -> Any:
        """
        Return a tuple of :attr:`~tango.common.det_hash.DetHashWithVersion.VERSION`
        and this instance itself, or ``None`` when no version is set.
        """
        version = self.VERSION
        if version is None:
            # Returning `None` falls back to just hashing the object itself.
            return None
        return version, self
# Pickle protocol handed to both the det-hash pickler and ``torch.save``.
# NOTE(review): presumably pinned to a fixed value so that hashes do not depend
# on the interpreter's default protocol -- confirm before changing it.
_PICKLE_PROTOCOL = 4
class _DetHashPickler(dill.Pickler):
    """
    Pickler used to produce the byte stream that gets hashed.

    It uses pickle's persistent-ID hook to substitute a stable stand-in for
    certain objects: :class:`CustomDetHash` instances, classes, callables,
    numpy arrays and torch tensors.
    """

    def __init__(self, buffer: io.BytesIO):
        """
        :param buffer: The in-memory stream that receives the pickled bytes.
        """
        super().__init__(buffer, protocol=_PICKLE_PROTOCOL)

        # We keep track of how deeply we are nesting the pickling of an object.
        # If a class returns `self` as part of `det_hash_object()`, it causes an
        # infinite recursion, because we try to pickle the `det_hash_object()`, which
        # contains `self`, which returns a `det_hash_object()`, etc.
        # So we keep track of how many times recursively we are trying to pickle the
        # same object. We only call `det_hash_object()` the first time. We assume that
        # if `det_hash_object()` returns `self` in any way, we want the second time
        # to just pickle the object as normal. `DetHashWithVersion` takes advantage
        # of this ability.
        self.recursively_pickled_ids: MutableMapping[int, int] = collections.Counter()

    def save(self, obj, save_persistent_id=True):
        """
        Pickle ``obj`` while recording that it is currently being pickled.

        The nesting counter for ``obj`` is incremented for the duration of the
        call and is always decremented again, even if pickling raises, so the
        bookkeeping never stays unbalanced.
        """
        obj_id = id(obj)
        self.recursively_pickled_ids[obj_id] += 1
        try:
            super().save(obj, save_persistent_id)
        finally:
            self.recursively_pickled_ids[obj_id] -= 1

    def persistent_id(self, obj: Any) -> Any:
        """
        Return the stand-in to pickle instead of ``obj``, or ``None`` to pickle
        ``obj`` normally.

        * :class:`CustomDetHash` instances, on the first (outermost) visit only:
          ``(module, qualname, det_hash_object())``, unless ``det_hash_object()``
          returns ``None``.
        * Classes: ``(module, qualname)``.
        * Other callables: ``(module, qualname)`` when both attributes exist.
        * numpy arrays: the bytes from ``ndarray.dumps()``.
        * torch tensors: the bytes written by ``torch.save()``.
        """
        if isinstance(obj, CustomDetHash) and self.recursively_pickled_ids[id(obj)] <= 1:
            det_hash_object = obj.det_hash_object()
            if det_hash_object is None:
                return None
            return obj.__class__.__module__, obj.__class__.__qualname__, det_hash_object

        if isinstance(obj, type):
            return obj.__module__, obj.__qualname__

        if callable(obj):
            # Identified by name only; anything without both attributes is
            # pickled normally.
            if hasattr(obj, "__module__") and hasattr(obj, "__qualname__"):
                return obj.__module__, obj.__qualname__
            return None

        if ndarray is not None and isinstance(obj, ndarray):
            # It's unclear why numpy arrays don't pickle in a consistent way.
            return obj.dumps()

        if TorchTensor is not None and isinstance(obj, TorchTensor):
            # It's unclear why torch tensors don't pickle in a consistent way.
            import torch

            with io.BytesIO() as buffer:
                torch.save(obj, buffer, pickle_protocol=_PICKLE_PROTOCOL)
                return buffer.getvalue()

        return None
def det_hash(o: Any) -> str:
    """
    Returns a deterministic hash code of arbitrary Python objects.

    The object is pickled into an in-memory buffer, the pickled bytes are hashed
    with BLAKE2b, and the digest is returned base58-encoded.

    If you want to override how we calculate the deterministic hash, derive from the
    :class:`CustomDetHash` class and implement :meth:`CustomDetHash.det_hash_object()`.

    :param o: The object to hash.
    :returns: The base58-encoded digest as a string.
    """
    m = hashlib.blake2b()
    with io.BytesIO() as buffer:
        pickler = _DetHashPickler(buffer)
        pickler.dump(o)
        # `getbuffer()` exposes the pickled bytes without copying them. A
        # `BytesIO` cannot be closed while such a view is still alive, so the
        # view is released explicitly here instead of relying on the interpreter
        # to drop the temporary immediately.
        with buffer.getbuffer() as view:
            m.update(view)
    return base58.b58encode(m.digest()).decode()