Source code for tango.steps.shell_step

import os
import subprocess
from typing import List, Optional, Union

from tango.common import PathOrStr, RegistrableFunction, make_registrable
from tango.step import Step


@make_registrable(exist_ok=True)
def check_path_existence(path: PathOrStr):
    """
    Check that something exists at ``path``.

    :param path: The filesystem location to check.
    :raises AssertionError: If ``os.path.exists(path)`` is false.
    """
    # Raise explicitly instead of using an ``assert`` statement: asserts are
    # stripped when Python runs with ``-O``, which would silently turn this
    # validation into a no-op. The exception type and message are unchanged.
    if not os.path.exists(path):
        raise AssertionError(f"Output not found at {path}!")


@Step.register("shell_step")
class ShellStep(Step):
    """
    This step runs a shell command, and returns the standard output as a string.

    .. tip::

        Registered as a :class:`~tango.step.Step` under the name "shell_step".

    :param shell_command: The shell command to run.
    :param output_path: The step makes no assumptions about the command being run. If your command
        produces some output, you can optionally specify the output path for recording the output
        location, and optionally validating it. See `validate_output` argument for this.
    :param validate_output: If an expected `output_path` has been specified, you can choose to
        validate that the step produced the correct output. By default, it will just check if the
        `output_path` exists, but you can pass any other validating function. For example, if your
        command is a script generating a model output, you can check if the model weights can be
        loaded.
    :param kwargs: Other kwargs to be passed to `subprocess.run()`. If you need to take advantage of
        environment variables, set `shell = True`.
    """

    def run(  # type: ignore[override]
        self,
        shell_command: Union[str, List[str]],
        output_path: Optional[PathOrStr] = None,
        validate_output: RegistrableFunction = check_path_existence,
        **kwargs,
    ):
        """
        Run ``shell_command``, optionally validate ``output_path``, and return the
        command's standard output as a string.

        :raises RuntimeError: If the command exits with a non-zero return code
            (raised by :meth:`run_command`).
        """
        output = self.run_command(shell_command, **kwargs)
        self.logger.info(output)
        if output_path is not None:
            validate_output(output_path)
            self.logger.info(f"Output found at: {output_path}")

        # ``subprocess.run`` yields bytes by default, but it already yields ``str``
        # when the caller forwards ``text=True``, ``universal_newlines=True``,
        # ``encoding=...`` or ``errors=...`` through ``kwargs``. Decoding
        # unconditionally would raise ``AttributeError`` in those cases.
        if isinstance(output, bytes):
            return output.decode("utf-8")
        return output

    def run_command(self, command: Union[str, List[str]], **kwargs):
        """
        Execute ``command`` with ``subprocess.run`` and return its captured stdout.

        The command is normalised to the form matching the ``shell`` kwarg: a single
        string when ``shell`` is truthy, an argument list otherwise.

        :param command: The command, either as one string or as a list of arguments.
        :param kwargs: Extra keyword arguments forwarded to ``subprocess.run()``.
        :raises RuntimeError: If the process exits with a non-zero return code.
        """
        import shlex

        if kwargs.get("shell"):
            # With a shell, the command is handed over as one string.
            if isinstance(command, list):
                command = shlex.join(command)
        elif isinstance(command, str):
            # Without a shell, the command must be split into an argument list.
            command = shlex.split(command)

        self.logger.info(f"Command: {command}")
        process = subprocess.run(command, capture_output=True, **kwargs)
        if process.returncode != 0:
            raise RuntimeError(f"The command failed with the following errors: {process.stderr}")
        return process.stdout