# Source code for tango.integrations.gs.step_cache

import logging
from pathlib import Path
from typing import Optional, Union

from tango.common import PathOrStr
from tango.common.util import make_safe_filename, tango_cache_dir
from tango.integrations.gs.common import (
    Constants,
    GSArtifact,
    GSArtifactConflict,
    GSArtifactNotFound,
    GSArtifactWriteError,
    GSClient,
    get_bucket_and_prefix,
)
from tango.step import Step
from tango.step_cache import StepCache
from tango.step_caches.remote_step_cache import RemoteNotFoundError, RemoteStepCache
from tango.step_info import StepInfo

logger = logging.getLogger(__name__)


@StepCache.register("gs")
class GSStepCache(RemoteStepCache):
    """
    This is a :class:`~tango.step_cache.StepCache` that's used by :class:`GSWorkspace`.
    It stores the results of steps on Google cloud buckets as blobs.

    It also keeps a limited in-memory cache as well as a local backup on disk, so fetching a
    step's result subsequent times should be fast.

    .. tip::
        Registered as a :class:`~tango.step_cache.StepCache` under the name "gs".

    :param folder_name: The name of the google cloud bucket folder to use.
    :param client: The google cloud storage client to use. When omitted, a new
        :class:`GSClient` is created from ``folder_name``.
    """

    Constants = Constants

    def __init__(self, folder_name: str, client: Optional[GSClient] = None):
        # Always record the folder name so the attribute exists regardless of
        # whether the client was supplied by the caller or created here.
        self.folder_name = folder_name
        if client is not None:
            bucket_name, _ = get_bucket_and_prefix(folder_name)
            # NOTE: this is an `assert`, so the check is skipped under `python -O`.
            assert (
                bucket_name == client.bucket_name
            ), "Assert that bucket name is same as client bucket until we do better"
            self._client = client
        else:
            self._client = GSClient(folder_name)
        # The local on-disk copy lives under the tango cache dir, in a
        # sub-directory derived from the folder name.
        super().__init__(tango_cache_dir() / "gs_cache" / make_safe_filename(folder_name))

    @property
    def client(self) -> GSClient:
        """
        The :class:`GSClient` used to talk to the remote location.
        """
        return self._client

    def _step_result_remote(self, step: Union[Step, StepInfo]) -> Optional[GSArtifact]:
        """
        Returns a `GSArtifact` object containing the details of the step.
        This only returns if the step has been finalized (committed); otherwise,
        or when no artifact exists for the step, ``None`` is returned.
        """
        try:
            artifact = self.client.get(self.Constants.step_artifact_name(step))
            return artifact if artifact.committed else None
        except GSArtifactNotFound:
            return None

    def _upload_step_remote(self, step: Step, objects_dir: Path) -> GSArtifact:
        """
        Uploads the step's output to remote location.

        The upload is best-effort: a :class:`GSArtifactConflict` on creation and a
        :class:`GSArtifactWriteError` on upload/commit are tolerated, and whatever
        artifact the client currently reports for the step is returned.
        """
        artifact_name = self.Constants.step_artifact_name(step)
        try:
            self.client.create(artifact_name)
        except GSArtifactConflict:
            # Presumably the artifact already exists (e.g. created by another
            # worker) - TODO confirm. Either way we carry on and try to write to it.
            logger.debug("Artifact '%s' could not be created due to a conflict", artifact_name)

        write_failed = False
        try:
            self.client.upload(artifact_name, objects_dir)
            self.client.commit(artifact_name)
        except GSArtifactWriteError:
            # Deliberately not re-raised; the outcome is reported below once we
            # know the state of the remote artifact.
            write_failed = True

        artifact = self.client.get(artifact_name)
        if write_failed:
            if artifact.committed:
                logger.debug(
                    "Writing to artifact '%s' failed, but it is already committed", artifact_name
                )
            else:
                logger.warning(
                    "Writing to artifact '%s' failed and it is not committed", artifact_name
                )
        return artifact

    def _download_step_remote(self, step_result, target_dir: PathOrStr) -> None:
        """
        Downloads the step's output from remote location.

        :raises RemoteNotFoundError: If the client reports that the artifact does not exist.
        """
        try:
            self.client.download(step_result, target_dir)
        except GSArtifactNotFound as err:
            # Translate the GS-specific error into the generic remote-cache error,
            # keeping the original as the explicit cause.
            raise RemoteNotFoundError() from err

    def __len__(self) -> int:
        """
        Returns the number of committed step outputs present in the remote location.
        """
        # NOTE: lock files should not count here.
        return sum(
            1
            for ds in self.client.artifacts(
                prefix=self.Constants.STEP_ARTIFACT_PREFIX, uncommitted=False
            )
            if ds.name is not None
            and ds.name.startswith(self.Constants.STEP_ARTIFACT_PREFIX)
            and not ds.name.endswith(self.Constants.LOCK_ARTIFACT_SUFFIX)
        )