Source code for tango.step

import inspect
import itertools
import logging
import random
import re
import warnings
from abc import abstractmethod
from copy import deepcopy
from dataclasses import dataclass
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    ClassVar,
    Dict,
    Generic,
    Iterable,
    Optional,
    Set,
    Type,
    TypeVar,
    Union,
    cast,
)

from tango.common.det_hash import CustomDetHash, det_hash
from tango.common.exceptions import ConfigurationError, StepStateError
from tango.common.from_params import (
    FromParams,
    infer_constructor_params,
    infer_method_params,
    pop_and_construct_arg,
)
from tango.common.lazy import Lazy
from tango.common.logging import cli_logger, log_exception
from tango.common.params import Params
from tango.common.registrable import Registrable
from tango.format import DillFormat, Format

try:
    from typing import get_args, get_origin  # type: ignore
except ImportError:

    def get_origin(tp):  # type: ignore
        return getattr(tp, "__origin__", None)

    def get_args(tp):  # type: ignore
        return getattr(tp, "__args__", ())


if TYPE_CHECKING:
    from tango.workspace import Workspace

_version_re = re.compile("""^[a-zA-Z0-9]+$""")

T = TypeVar("T")


_random_for_step_names = random.Random()


@dataclass
class StepResources(FromParams):
    """
    StepResources describe minimum external hardware requirements which must be available for a
    step to run.
    """

    machine: Optional[str] = None
    """
    This is an executor-dependent option. With the Beaker executor, for example, you can set this
    to "local" to force the executor to run the step locally instead of on Beaker.
    """

    cpu_count: Optional[float] = None
    """
    Minimum number of logical CPU cores. It may be fractional.

    Examples: ``4``, ``0.5``.
    """

    gpu_count: Optional[int] = None
    """
    Minimum number of GPUs. It must be non-negative.
    """

    gpu_type: Optional[str] = None
    """
    The type of GPU that the step requires. The exact string you should use to define a GPU type
    depends on the executor. With the Beaker executor, for example, you should use the same
    strings you see in the Beaker UI, such as 'NVIDIA A100-SXM-80GB'.
    """

    memory: Optional[str] = None
    """
    Minimum available system memory as a number with unit suffix.

    Examples: ``2.5GiB``, ``1024m``.
    """

    shared_memory: Optional[str] = None
    """
    Size of ``/dev/shm`` as a number with unit suffix.

    Examples: ``2.5GiB``, ``1024m``.
    """

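# A minimal usage sketch (illustrative only, not part of tango.step): requesting one GPU and
# 16 GiB of memory for a step. ``MyStep`` is a hypothetical Step subclass.
#
#   resources = StepResources(gpu_count=1, cpu_count=4.0, memory="16GiB")
#   my_step = MyStep(step_resources=resources)
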
class Step(Registrable, Generic[T]):
    """
    This class defines one step in your experiment. To write your own step, derive from this
    class and overwrite the :meth:`run()` method. The :meth:`run()` method must have parameters
    with type hints.

    ``Step.__init__()`` takes all the arguments we want to run the step with. They get passed
    to :meth:`run()` (almost) as they are. If the arguments are other instances of ``Step``, those
    will be replaced with the step's results before calling :meth:`run()`. Further, there are
    several special parameters:

    :param step_name: contains an optional human-readable name for the step. This name is used
        for error messages and the like, and has no consequence on the actual computation.
    :param cache_results: specifies whether the results of this step should be cached. If this is
        ``False``, the step is recomputed every time it is needed. If this is not set at all, and
        :attr:`CACHEABLE` is ``True``, we cache if the step is marked as :attr:`DETERMINISTIC`,
        and we don't cache otherwise.
    :param step_format: gives you a way to override the step's default format (which is given in
        :attr:`FORMAT`).
    :param step_config: is the original raw part of the experiment config corresponding to this
        step. This can be accessed via the :attr:`config` property within each step's
        :meth:`run()` method.
    :param step_unique_id_override: overrides the construction of the step's unique id using the
        hash of inputs.
    :param step_resources: gives you a way to set the minimum compute resources required to run
        this step. Certain executors require this information.
    :param step_metadata: use this to specify additional metadata for your step. This is added to
        the :attr:`METADATA` class variable to form the ``self.metadata`` attribute. Values in
        ``step_metadata`` take precedence over ``METADATA``.
    :param step_extra_dependencies: use this to force a dependency on other steps. Normally
        dependencies between steps are determined by the inputs and outputs of the steps, but you
        can use this parameter to force that other steps run before this step even if this step
        doesn't explicitly depend on the outputs of those steps.

    .. important::
        Overriding the unique id means that the step will always map to this value, regardless of
        the inputs, and therefore, the step cache will only hold a single copy of the step's
        output (from the last execution). Thus, in most cases, this should not be used when
        constructing steps. We include this option for the case when the executor creates
        subprocesses, which also need to access the *same* ``Step`` object.
    """

    DETERMINISTIC: bool = True
    """This describes whether this step can be relied upon to produce the same results every time
    when given the same inputs. If this is ``False``, you can still cache the output of the step,
    but the results might be unexpected. Tango will print a warning in this case."""

    CACHEABLE: Optional[bool] = None
    """This provides a direct way to turn off caching. For example, a step that reads a HuggingFace
    dataset doesn't need to be cached, because HuggingFace datasets already have their own caching
    mechanism. But it's still a deterministic step, and all following steps are allowed to cache.
    If it is ``None``, the step figures out by itself whether it should be cacheable or not."""

    VERSION: Optional[str] = None
    """This is optional, but recommended. Specifying a version gives you a way to tell Tango that
    a step has changed during development, and should now be recomputed. This doesn't invalidate
    the old results, so when you revert your code, the old cache entries will stick around and be
    picked up."""

    FORMAT: Format = DillFormat("gz")
    """This specifies the format the results of this step will be serialized in. See the
    documentation for :class:`~tango.format.Format` for details."""

    SKIP_ID_ARGUMENTS: Set[str] = set()
    """If your :meth:`run()` method takes some arguments that don't affect the results, list them
    here. Arguments listed here will not be used to calculate this step's unique ID, and thus
    changing those arguments does not invalidate the cache.

    For example, you might use this for the batch size in an inference step, where you only care
    about the model output, not about how many outputs you can produce at the same time.
    """

    SKIP_DEFAULT_ARGUMENTS: Dict[str, Any] = {}
    """Sometimes, you want to add another argument to your :meth:`run()` method, but you don't
    want to invalidate the cache when this new argument is set to its default value. If that is
    the case, add the argument to this dictionary with the default value that should be
    ignored."""

    METADATA: Dict[str, Any] = {}
    """
    Arbitrary metadata about the step.
    """

    _UNIQUE_ID_SUFFIX: Optional[str] = None
    """
    Used internally for testing.
    """

    def __init__(
        self,
        step_name: Optional[str] = None,
        cache_results: Optional[bool] = None,
        step_format: Optional[Format] = None,
        step_config: Optional[Union[Dict[str, Any], Params]] = None,
        step_unique_id_override: Optional[str] = None,
        step_resources: Optional[StepResources] = None,
        step_metadata: Optional[Dict[str, Any]] = None,
        step_extra_dependencies: Optional[Iterable["Step"]] = None,
        **kwargs,
    ):
        if self.VERSION is not None:
            assert _version_re.match(
                self.VERSION
            ), f"Invalid characters in version '{self.VERSION}'"

        run_defaults = {
            k: v.default
            for k, v in inspect.signature(self.run).parameters.items()
            if v.default is not inspect.Parameter.empty
        }
        self.kwargs = self.massage_kwargs({**run_defaults, **kwargs})

        if step_format is None:
            self.format = self.FORMAT
            if isinstance(self.format, type):
                self.format = self.format()
        else:
            self.format = step_format

        self.unique_id_cache = step_unique_id_override

        if step_name is None:
            self.name = self.unique_id
        else:
            self.name = step_name
        # TODO: It is bad design to have the step_name in the Step class. The same step can be
        # part of multiple runs at the same time, and they can have different names in different
        # runs. Step names are a property of the run, not of the step.

        if cache_results is True:
            if not self.CACHEABLE:
                raise ConfigurationError(
                    f"Step {self.name} is configured to use the cache, but it's not a cacheable step."
                )
            if not self.DETERMINISTIC:
                warnings.warn(
                    f"Step {self.name} is going to be cached despite not being deterministic.",
                    UserWarning,
                )
            self.cache_results = True
        elif cache_results is False:
            self.cache_results = False
        elif cache_results is None:
            c = (self.DETERMINISTIC, self.CACHEABLE)
            if c == (False, None):
                self.cache_results = False
            elif c == (True, None):
                self.cache_results = True
            elif c == (False, False):
                self.cache_results = False
            elif c == (True, False):
                self.cache_results = False
            elif c == (False, True):
                warnings.warn(
                    f"Step {self.name} is set to be cacheable despite not being deterministic.",
                    UserWarning,
                )
                self.cache_results = True
            elif c == (True, True):
                self.cache_results = True
            else:
                assert False, "Step.DETERMINISTIC or step.CACHEABLE are set to an invalid value."
        else:
            raise ConfigurationError(
                f"Step {self.name}'s cache_results parameter is set to an invalid value."
            )

        self._workspace: Optional["Workspace"] = None
        self.work_dir_for_run: Optional[
            Path
        ] = None  # This is set only while the run() method runs.

        if isinstance(step_config, Params):
            self._config = step_config.as_dict(quiet=True)
        else:
            self._config = step_config

        assert step_resources is None or isinstance(step_resources, StepResources)
        self.step_resources = step_resources

        self.metadata = deepcopy(self.METADATA)
        if step_metadata:
            self.metadata.update(step_metadata)

        self.extra_dependencies = set(step_extra_dependencies) if step_extra_dependencies else set()

    @property
    def class_name(self) -> str:
        return self.__class__.__name__

    @classmethod
    def massage_kwargs(cls, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """
        Override this method in your step if you want to change the step's arguments before they
        are passed to the :meth:`run()` method.

        This can be useful if you want to normalize arguments that are passed to your step. For
        example, you might not care about the case of a string that's passed in. You can lowercase
        the string in this method, and the step will function as if it had been created with a
        lowercase string from the start. This way you can make sure that the step's unique ID does
        not change when the case of the input changes.

        .. note::
            When the input to a step is another step, this method will see the step in the input,
            not the other step's result.

        .. warning::
            This is an advanced feature of Tango that you won't need most of the time.

        By default, this method does nothing and just returns its input unchanged.

        :param kwargs: The original kwargs that were passed to the step during construction.
        :return: New kwargs that will be passed to the step's :meth:`run()` method.
        """
        return kwargs

    @property
    def logger(self) -> logging.Logger:
        """
        A :class:`logging.Logger` that can be used within the :meth:`run()` method.
        """
        return logging.getLogger(self.__class__.__name__)

    @classmethod
    def from_params(  # type: ignore[override]
        cls: Type["Step"],
        params: Union[Params, dict, str],
        constructor_to_call: Optional[Callable[..., "Step"]] = None,
        constructor_to_inspect: Optional[
            Union[Callable[..., "Step"], Callable[["Step"], None]]
        ] = None,
        step_name: Optional[str] = None,
        **extras,
    ) -> "Step":
        # Why do we need a custom from_params? Step classes have a run() method that takes all the
        # parameters necessary to perform the step. The __init__() method of the step takes those
        # same parameters, but each of them could be wrapped in another Step instead of being
        # supplied directly. from_params() doesn't know anything about these shenanigans, so
        # we have to supply the necessary logic here.
        if constructor_to_call is not None:
            raise ConfigurationError(
                f"{cls.__name__}.from_params cannot be called with a constructor_to_call."
            )
        if constructor_to_inspect is not None:
            raise ConfigurationError(
                f"{cls.__name__}.from_params cannot be called with a constructor_to_inspect."
            )

        if isinstance(params, str):
            params = Params({"type": params})

        if not isinstance(params, Params):
            if isinstance(params, dict):
                params = Params(params)
            else:
                raise ConfigurationError(
                    "from_params was passed a ``params`` object that was not a ``Params``. This probably "
                    "indicates malformed parameters in a configuration file, where something that "
                    "should have been a dictionary was actually a list, or something else. "
                    f"This happened when constructing an object of type {cls}."
                )

        # Build up a raw step config
        def replace_steps_with_refs(o: Any) -> Any:
            if isinstance(o, (list, tuple, set)):
                return o.__class__(replace_steps_with_refs(i) for i in o)
            elif isinstance(o, (dict, Params)):
                result = {key: replace_steps_with_refs(value) for key, value in o.items()}
                if isinstance(o, dict):
                    return result
                elif isinstance(o, Params):
                    return Params(result, history=o.history)
            elif isinstance(o, Step):
                return {"type": "ref", "ref": o.name}
            else:
                return deepcopy(o)

        raw_step_config = replace_steps_with_refs(params.as_dict(quiet=True))

        as_registrable = cast(Type[Registrable], cls)
        if "type" in params and params["type"] not in as_registrable.list_available():
            as_registrable.search_modules(params["type"])
        choice = params.pop_choice(
            "type", choices=as_registrable.list_available(), default_to_first_choice=False
        )
        subclass, constructor_name = as_registrable.resolve_class_name(choice)
        if not issubclass(subclass, Step):
            # This can happen if `choice` is a fully qualified name.
            raise ConfigurationError(
                f"Tried to make a Step of type {choice}, but ended up with a {subclass}."
            )

        if issubclass(subclass, FunctionalStep):
            parameters = infer_method_params(subclass, subclass.WRAPPED_FUNC, infer_kwargs=False)
            if subclass.BIND:
                if "self" not in parameters:
                    raise ConfigurationError(
                        f"Functional step for {subclass.WRAPPED_FUNC} is bound but is missing argument 'self'"
                    )
                else:
                    del parameters["self"]
        else:
            parameters = infer_method_params(subclass, subclass.run, infer_kwargs=False)
            del parameters["self"]
        init_parameters = infer_constructor_params(subclass)
        del init_parameters["self"]
        del init_parameters["kwargs"]
        parameter_overlap = parameters.keys() & init_parameters.keys()
        assert len(parameter_overlap) <= 0, (
            f"If this assert fails it means that you wrote a Step with a run() method that takes one of the "
            f"reserved parameters ({', '.join(init_parameters.keys())})"
        )
        parameters.update(init_parameters)

        kwargs: Dict[str, Any] = {}
        accepts_kwargs = False
        for param_name, param in parameters.items():
            if param.kind == param.VAR_KEYWORD:
                # When a class takes **kwargs we store the fact that the method allows extra keys; if
                # we get extra parameters, instead of crashing, we'll just pass them as-is to the
                # constructor, and hope that you know what you're doing.
                accepts_kwargs = True
                continue

            explicitly_set = param_name in params
            constructed_arg = pop_and_construct_arg(
                subclass.__name__, param_name, param.annotation, param.default, params, extras
            )

            # If the param wasn't explicitly set in `params` and we just ended up constructing
            # the default value for the parameter, we can just omit it.
            # Leaving it in can cause issues with **kwargs in some corner cases, where you might end up
            # with multiple values for a single parameter (e.g., the default value gives you lazy=False
            # for a dataset reader inside **kwargs, but a particular dataset reader actually hard-codes
            # lazy=True - the superclass sees both lazy=True and lazy=False in its constructor).
            if explicitly_set or constructed_arg is not param.default:
                kwargs[param_name] = constructed_arg

        if accepts_kwargs:
            kwargs.update(params)
        else:
            params.assert_empty(subclass.__name__)

        return subclass(step_name=step_name, step_config=raw_step_config, **kwargs)

    @abstractmethod
    def run(self, **kwargs) -> T:
        """
        Execute the step's action.

        This method needs to be implemented when creating a ``Step`` subclass, but it shouldn't
        be called directly. Instead, call :meth:`result()`.
        """
        raise NotImplementedError()

    def _run_with_work_dir(self, workspace: "Workspace", needed_by: Optional["Step"] = None) -> T:
        if self.work_dir_for_run is not None:
            raise RuntimeError("You can only run a Step's run() method once at a time.")

        if self.DETERMINISTIC:
            random.seed(784507111)

        self._workspace = workspace

        if self.cache_results:
            self.work_dir_for_run = workspace.work_dir(self)
            dir_for_cleanup = None
        else:
            dir_for_cleanup = TemporaryDirectory(prefix=f"{self.unique_id}-", suffix=".step_dir")
            self.work_dir_for_run = Path(dir_for_cleanup.name)

        try:
            self._replace_steps_with_results(self.extra_dependencies, workspace)
            kwargs = self._replace_steps_with_results(self.kwargs, workspace)
            self.log_starting(needed_by=needed_by)
            workspace.step_starting(self)
            try:
                result = self.run(**kwargs)
                result = workspace.step_finished(self, result)
            except BaseException as e:
                self.log_failure(e)
                workspace.step_failed(self, e)
                raise
            self.log_finished()
            return result
        finally:
            self._workspace = None
            self.work_dir_for_run = None
            if dir_for_cleanup is not None:
                dir_for_cleanup.cleanup()

    @property
    def work_dir(self) -> Path:
        """
        The working directory that a step can use while its :meth:`run()` method runs.

        This is a convenience property for you to call inside your :meth:`run()` method.

        This directory stays around across restarts. You cannot assume that it is empty when your
        step runs, but you can use it to store information that helps you restart a step if it
        got killed half-way through the last time it ran."""
        if self.work_dir_for_run is None:
            raise RuntimeError(
                "You can only call this method while the step is running with a working directory. "
                "Did you call '.run()' directly? You should only run a step with '.result()'."
            )
        return self.work_dir_for_run

    @property
    def workspace(self) -> "Workspace":
        """
        The :class:`~tango.workspace.Workspace` being used.

        This is a convenience property for you to call inside your :meth:`run()` method.
        """
        if self._workspace is None:
            raise RuntimeError(
                "You can only call this method while the step is running with a workspace. "
                "Did you call '.run()' directly? You should only run a step with '.result()'."
            )
        return self._workspace

    @property
    def config(self) -> Dict[str, Any]:
        """
        The configuration parameters that were used to construct the step. This can be empty if
        the step was not constructed from a configuration file.
        """
        if self._config is None:
            raise ValueError(f"No config has been assigned to this step! ('{self.name}')")
        else:
            return self._config

    def det_hash_object(self) -> Any:
        return self.unique_id

    @property
    def resources(self) -> StepResources:
        """
        Defines the minimum compute resources required to run this step. Certain executors require
        this information in order to allocate resources for each step.

        You can set this with the ``step_resources`` argument to :class:`Step` or you can override
        this method to automatically define the required resources.
        """
        return self.step_resources or StepResources()

    @property
    def unique_id(self) -> str:
        """Returns the unique ID for this step.

        Unique IDs are of the shape ``$class_name-$version-$hash``, where the hash is the hash of
        the inputs for deterministic steps, and a random string of characters for
        non-deterministic ones.
        """
        if self.unique_id_cache is None:
            self.unique_id_cache = self.class_name
            if self.VERSION is not None:
                self.unique_id_cache += "-"
                self.unique_id_cache += self.VERSION

            self.unique_id_cache += "-"
            if self.DETERMINISTIC:
                hash_kwargs = {
                    key: value
                    for key, value in self.kwargs.items()
                    if (key not in self.SKIP_ID_ARGUMENTS)
                    and (
                        key not in self.SKIP_DEFAULT_ARGUMENTS
                        or self.SKIP_DEFAULT_ARGUMENTS[key] != value
                    )
                }
                self.unique_id_cache += det_hash(
                    (
                        (self.format.__class__.__module__, self.format.__class__.__qualname__),
                        self.format.VERSION,
                        hash_kwargs,
                    )
                )[:32]
            else:
                self.unique_id_cache += det_hash(
                    _random_for_step_names.getrandbits((58**32).bit_length())
                )[:32]

            if self._UNIQUE_ID_SUFFIX is not None:
                self.unique_id_cache += f"-{self._UNIQUE_ID_SUFFIX}"

        return self.unique_id_cache

    def __str__(self):
        return self.unique_id

    def __hash__(self):
        """
        A step's hash is just its unique ID.
        """
        return hash(self.unique_id)

    def __eq__(self, other):
        """
        Determines whether this step is equal to another step. Two steps with the same unique ID
        are considered identical.
        """
        if isinstance(other, Step):
            return self.unique_id == other.unique_id
        else:
            return False

    def _replace_steps_with_results(self, o: Any, workspace: "Workspace"):
        if isinstance(o, (Step, StepIndexer)):
            return o.result(workspace=workspace, needed_by=self)
        elif isinstance(o, Lazy):
            return Lazy(
                o._constructor,
                params=Params(
                    self._replace_steps_with_results(o._params.as_dict(quiet=True), workspace)
                ),
                constructor_extras=self._replace_steps_with_results(
                    o._constructor_extras, workspace
                ),
            )
        elif isinstance(o, WithUnresolvedSteps):
            return o.construct(workspace)
        elif isinstance(o, (list, tuple, set)):
            return o.__class__(self._replace_steps_with_results(i, workspace) for i in o)
        elif isinstance(o, dict):
            return {
                key: self._replace_steps_with_results(value, workspace) for key, value in o.items()
            }
        else:
            return o

    def result(
        self, workspace: Optional["Workspace"] = None, needed_by: Optional["Step"] = None
    ) -> T:
        """Returns the result of this step.

        If the results are cached, it returns those. Otherwise it runs the step and returns the
        result from there. If necessary, this method will first produce the results of all steps
        it depends on."""
        if workspace is None:
            from tango.workspaces import default_workspace

            workspace = default_workspace

        from tango.step_info import StepState

        if not self.cache_results or self not in workspace.step_cache:
            # Try running the step. It might get completed by a different tango process
            # if there is a race, so we catch "StepStateError" and check if it's "COMPLETED"
            # at that point.
            try:
                return self._run_with_work_dir(workspace, needed_by=needed_by)
            except StepStateError as exc:
                if exc.step_state != StepState.COMPLETED or not self.cache_results:
                    raise
                elif self not in workspace.step_cache:
                    raise StepStateError(
                        self, exc.step_state, "because it's not found in the cache"
                    )
                else:
                    # Step has been completed (and cached) by a different process, so we're done.
                    pass

        self.log_cache_hit(needed_by=needed_by)
        return workspace.step_cache[self]

    def ensure_result(
        self,
        workspace: Optional["Workspace"] = None,
    ) -> None:
        """This makes sure that the result of this step is in the cache. It does not return the
        result."""
        if not self.cache_results:
            raise RuntimeError(
                "It does not make sense to call ensure_result() on a step that's not cacheable."
            )

        if workspace is None:
            from tango.workspaces import default_workspace

            workspace = default_workspace

        if self in workspace.step_cache:
            self.log_cache_hit()
        else:
            self.result(workspace)

    def _ordered_dependencies(self) -> Iterable["Step"]:
        def dependencies_internal(o: Any) -> Iterable[Step]:
            if isinstance(o, Step):
                yield o
            elif isinstance(o, Lazy):
                yield from dependencies_internal(o._params.as_dict(quiet=True))
            elif isinstance(o, WithUnresolvedSteps):
                yield from dependencies_internal(o.args)
                yield from dependencies_internal(o.kwargs)
            elif isinstance(o, StepIndexer):
                yield o.step
            elif isinstance(o, str):
                return  # Confusingly, str is an Iterable of itself, resulting in infinite recursion.
            elif isinstance(o, (dict, Params)):
                yield from dependencies_internal(o.values())
            elif isinstance(o, Iterable):
                yield from itertools.chain(*(dependencies_internal(i) for i in o))
            else:
                return

        yield from self.extra_dependencies
        yield from dependencies_internal(self.kwargs.values())

    @property
    def dependencies(self) -> Set["Step"]:
        """
        Returns a set of steps that this step depends on.

        This does not return recursive dependencies.
        """
        return set(self._ordered_dependencies())

    @property
    def recursive_dependencies(self) -> Set["Step"]:
        """
        Returns a set of steps that this step depends on.

        This returns recursive dependencies.
        """
        seen = set()
        steps = list(self.dependencies)
        while len(steps) > 0:
            step = steps.pop()
            if step in seen:
                continue
            seen.add(step)
            steps.extend(step.dependencies)
        return seen

    def log_cache_hit(self, needed_by: Optional["Step"] = None) -> None:
        if needed_by is not None:
            cli_logger.info(
                '[green]\N{check mark} Found output for step [bold]"%s"[/bold] in cache '
                '(needed by "%s")...[/green]',
                self.name,
                needed_by.name,
            )
        else:
            cli_logger.info(
                '[green]\N{check mark} Found output for step [bold]"%s"[/] in cache...[/]',
                self.name,
            )

    def log_starting(self, needed_by: Optional["Step"] = None) -> None:
        if needed_by is not None:
            cli_logger.info(
                '[blue]\N{black circle} Starting step [bold]"%s"[/] (needed by "%s")...[/]',
                self.name,
                needed_by.name,
            )
        else:
            cli_logger.info(
                '[blue]\N{black circle} Starting step [bold]"%s"[/]...[/]',
                self.name,
            )

    def log_finished(self, run_name: Optional[str] = None) -> None:
        if run_name is not None:
            cli_logger.info(
                '[green]\N{check mark} Finished run for step [bold]"%s"[/] (%s)[/]',
                self.name,
                run_name,
            )
        else:
            cli_logger.info(
                '[green]\N{check mark} Finished step [bold]"%s"[/][/]',
                self.name,
            )

    def log_failure(self, exception: Optional[BaseException] = None) -> None:
        if exception is not None:
            log_exception(exception, logger=self.logger)
        cli_logger.error('[red]\N{ballot x} Step [bold]"%s"[/] failed[/]', self.name)

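# A minimal subclassing sketch (illustrative only, not part of tango.step): a deterministic,
# cacheable step with a version string. The class and argument names below are invented.
#
#   @Step.register("add_numbers")
#   class AddNumbers(Step[int]):
#       DETERMINISTIC = True
#       CACHEABLE = True
#       VERSION = "001"
#
#       def run(self, a: int, b: int) -> int:
#           return a + b
#
#   # Running it against the default workspace; repeated calls hit the step cache.
#   result = AddNumbers(a=1, b=2).result()
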
class FunctionalStep(Step):
    WRAPPED_FUNC: ClassVar[Callable]
    BIND: ClassVar[bool] = False

    @property
    def class_name(self) -> str:
        return self.WRAPPED_FUNC.__name__

    def run(self, *args, **kwargs):
        if self.BIND:
            return self.WRAPPED_FUNC(*args, **kwargs)
        else:
            return self.__class__.WRAPPED_FUNC(*args, **kwargs)

def step(
    name: Optional[str] = None,
    *,
    exist_ok: bool = False,
    bind: bool = False,
    deterministic: bool = True,
    cacheable: Optional[bool] = None,
    version: Optional[str] = None,
    format: Format = DillFormat("gz"),
    skip_id_arguments: Optional[Set[str]] = None,
    metadata: Optional[Dict[str, Any]] = None,
):
    """
    A decorator to create a :class:`Step` from a function.

    :param name: A name to register the step under. By default the name of the function is used.
    :param exist_ok: If True, overwrites any existing step registered under the same ``name``.
        Else, throws an error if a step is already registered under ``name``.
    :param bind: If ``True``, the first argument passed to the step function will be the underlying
        :class:`Step` instance, i.e. the function will be called as an instance method.
        In this case you must name the first argument 'self' or you will get
        a :class:`~tango.common.exceptions.ConfigurationError` when instantiating the class.

    See the :class:`Step` class for an explanation of the other parameters.

    Example
    -------

    .. testcode::

        from tango import step

        @step(version="001")
        def add(a: int, b: int) -> int:
            return a + b

        @step(bind=True)
        def bound_step(self) -> None:
            assert self.work_dir.is_dir()

    """

    def step_wrapper(step_func):
        @Step.register(name or step_func.__name__, exist_ok=exist_ok)
        class WrapperStep(FunctionalStep):
            DETERMINISTIC = deterministic
            CACHEABLE = cacheable
            VERSION = version
            FORMAT = format
            SKIP_ID_ARGUMENTS = skip_id_arguments or set()
            METADATA = metadata or {}
            WRAPPED_FUNC = step_func
            BIND = bind

        return WrapperStep

    return step_wrapper

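# An illustrative usage sketch (not part of the module): instantiating and running a
# function-based step defined with the decorator above. ``add`` mirrors the docstring example;
# the decorator replaces the function with a registered Step subclass.
#
#   @step(version="001")
#   def add(a: int, b: int) -> int:
#       return a + b
#
#   add_step = add(a=1, b=2)        # calling the returned class constructs a Step instance
#   assert add_step.result() == 3   # runs (or fetches from cache) using the default workspace
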
class StepIndexer(CustomDetHash):
    def __init__(self, step: Step, key: Union[str, int]):
        self.step = step
        self.key = key

    def result(
        self, workspace: Optional["Workspace"] = None, needed_by: Optional["Step"] = None
    ) -> Any:
        return self.step.result(workspace=workspace, needed_by=needed_by)[self.key]

    def det_hash_object(self) -> Any:
        return self.step.unique_id, self.key

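# A minimal sketch (assumed usage, with hypothetical step names): depending on a single entry of
# another step's output rather than the whole result. ``produce_dict`` is a step whose result is
# a dict, and ``ConsumeScores`` is another Step subclass.
#
#   scores = StepIndexer(produce_dict, "scores")   # resolves to produce_dict.result()["scores"]
#   consume = ConsumeScores(scores=scores)         # produce_dict still shows up as a dependency
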
class WithUnresolvedSteps(CustomDetHash):
    """
    This is a helper class for some scenarios where steps depend on other steps.

    Let's say we have two steps, :class:`ConsumeDataStep` and :class:`ProduceDataStep`. The
    easiest way to make :class:`ConsumeDataStep` depend on :class:`ProduceDataStep` is to specify
    ``Produce`` as one of the arguments to the step. This works when ``Consume`` takes the output
    of ``Produce`` directly, or if it takes it inside a standard Python container, like a list,
    set, or dictionary.

    But what if the output of :class:`ProduceDataStep` needs to be added to a complex, custom
    data structure? :class:`WithUnresolvedSteps` takes care of this scenario.

    For example, this works without any help:

    .. code-block:: Python

        class ProduceDataStep(Step[MyDataClass]):
            def run(self, ...) -> MyDataClass:
                ...
                return MyDataClass(...)

        class ConsumeDataStep(Step):
            def run(self, input_data: MyDataClass):
                ...

        produce = ProduceDataStep()
        consume = ConsumeDataStep(input_data=produce)

    This scenario needs help:

    .. code-block:: Python

        @dataclass
        class DataWithTimestamp:
            data: MyDataClass
            timestamp: float

        class ProduceDataStep(Step[MyDataClass]):
            def run(self, ...) -> MyDataClass:
                ...
                return MyDataClass(...)

        class ConsumeDataStep(Step):
            def run(self, input_data: DataWithTimestamp):
                ...

        produce = ProduceDataStep()
        consume = ConsumeDataStep(
            input_data=DataWithTimestamp(produce, time.now())
        )

    That does not work, because :class:`DataWithTimestamp` needs an object of type
    :class:`MyDataClass`, but we're giving it an object of type :class:`Step[MyDataClass]`.
    Instead, we change the last line to this:

    .. code-block:: Python

        consume = ConsumeDataStep(
            input_data=WithUnresolvedSteps(
                DataWithTimestamp, produce, time.now()
            )
        )

    :class:`WithUnresolvedSteps` will delay calling the constructor of ``DataWithTimestamp``
    until the :meth:`run()` method runs. Tango will make sure that the results from the
    ``produce`` step are available at that time, and replaces the step in the arguments with the
    step's results.

    :param function: The function to call after resolving steps to their results.
    :param args: The args to pass to the function. These may contain steps, which will be
        resolved before the function is called.
    :param kwargs: The kwargs to pass to the function. These may contain steps, which will be
        resolved before the function is called.
    """

    def __init__(self, function, *args, **kwargs):
        self.function = function
        self.args = args
        self.kwargs = kwargs

    @classmethod
    def with_resolved_steps(
        cls,
        o: Any,
        workspace: "Workspace",
    ):
        """
        Recursively goes through a Python object and replaces all instances of :class:`.Step` with
        the results of that step.

        :param o: The Python object to go through
        :param workspace: The workspace in which to resolve all steps
        :return: A new object that's a copy of the original object, with all instances of
            :class:`.Step` replaced with the results of the step.
        """
        if isinstance(o, (Step, StepIndexer)):
            return o.result(workspace=workspace)
        elif isinstance(o, Lazy):
            return Lazy(
                o._constructor,
                params=Params(cls.with_resolved_steps(o._params.as_dict(quiet=True), workspace)),
                constructor_extras=cls.with_resolved_steps(o._constructor_extras, workspace),
            )
        elif isinstance(o, cls):
            return o.construct(workspace)
        elif isinstance(o, (dict, Params)):
            return o.__class__(
                {key: cls.with_resolved_steps(value, workspace) for key, value in o.items()}
            )
        elif isinstance(o, (list, tuple, set)):
            return o.__class__(cls.with_resolved_steps(item, workspace) for item in o)
        else:
            return o

    def construct(self, workspace: "Workspace"):
        """
        Replaces all steps in the args that are stored in this object, and calls the function with
        those args.

        :param workspace: The :class:`.Workspace` in which to resolve all the steps.
        :return: The result of calling the function.
        """
        resolved_args = self.with_resolved_steps(self.args, workspace)
        resolved_kwargs = self.with_resolved_steps(self.kwargs, workspace)
        return self.function(*resolved_args, **resolved_kwargs)

    def det_hash_object(self) -> Any:
        return self.function.__qualname__, self.args, self.kwargs
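
# A minimal sketch (hypothetical names, not part of tango): WithUnresolvedSteps also accepts
# keyword arguments, which are resolved the same way as positional ones before ``function``
# is called.
#
#   consume = ConsumeDataStep(
#       input_data=WithUnresolvedSteps(DataWithTimestamp, data=produce, timestamp=time.time())
#   )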