import json
import random
from pathlib import Path
from typing import (
Dict,
Generator,
Iterable,
List,
Optional,
Tuple,
TypeVar,
Union,
cast,
)
from urllib.parse import ParseResult
import petname
from google.auth.credentials import Credentials
from google.cloud import datastore
from tango.common.util import utc_now_datetime
from tango.integrations.gs.common import (
Constants,
GCSStepLock,
get_bucket_and_prefix,
get_client,
get_credentials,
)
from tango.integrations.gs.step_cache import GSStepCache
from tango.step import Step
from tango.step_info import StepInfo, StepState
from tango.workspace import Run, RunInfo, RunSort, StepInfoSort, Workspace
from tango.workspaces.remote_workspace import RemoteWorkspace
T = TypeVar("T")
@Workspace.register("gs")
class GSWorkspace(RemoteWorkspace):
"""
This is a :class:`~tango.workspace.Workspace` that stores step artifacts on Google Cloud Storage.
.. tip::
Registered as a :class:`~tango.workspace.Workspace` under the name "gs".
    :param workspace: The name or ID of the Google Cloud bucket folder to use.
    :param project: The Google Cloud project ID. This is required for the datastore. If not
        provided, it will be inferred from the Google Cloud credentials.
    :param credentials: The Google Cloud credentials to use, in any of the forms described below.
    .. important::
        Credentials can be provided in the following ways:
        - Using the `credentials` keyword argument:
            - You can specify the path to the credentials JSON file.
            - You can specify a `google.oauth2.credentials.Credentials()` object.
            - You can specify the JSON string of the credentials dict.
        - Using the default credentials: you can use your default Google Cloud credentials by
          running `gcloud auth application-default login`. If you are using `GSWorkspace` with
          :class:`~tango.integrations.beaker.BeakerExecutor`, you will need to set the environment
          variable `GOOGLE_TOKEN` to the path of the credentials JSON file. The default location
          is usually `~/.config/gcloud/application_default_credentials.json`.
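    For example, a minimal usage sketch (the bucket folder, project ID, and credentials
    path below are all hypothetical):

    .. code-block:: python

        from tango.integrations.gs.workspace import GSWorkspace

        # Credentials may be a file path, a Credentials object, or a JSON string.
        workspace = GSWorkspace(
            "my-bucket/my-workspace",
            project="my-gcp-project",
            credentials="/home/me/.config/gcloud/application_default_credentials.json",
        )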
"""
Constants = Constants
NUM_CONCURRENT_WORKERS = 32
def __init__(
self,
workspace: str,
project: Optional[str] = None,
credentials: Optional[Union[str, Credentials]] = None,
):
credentials = get_credentials(credentials)
self.client = get_client(folder_name=workspace, credentials=credentials, project=project)
self.client.NUM_CONCURRENT_WORKERS = self.NUM_CONCURRENT_WORKERS
self._cache = GSStepCache(workspace, client=self.client)
self._locks: Dict[Step, GCSStepLock] = {}
super().__init__()
project = project or self.client.storage.project or credentials.quota_project_id
self.bucket_name, self.prefix = get_bucket_and_prefix(workspace)
self._ds = datastore.Client(
namespace=self.bucket_name, project=project, credentials=credentials
)
@property
def cache(self):
return self._cache
@property
def locks(self):
return self._locks
@property
def steps_dir_name(self):
return "gs_workspace"
@classmethod
def from_parsed_url(cls, parsed_url: ParseResult) -> Workspace:
workspace: str
if parsed_url.netloc and parsed_url.path:
# e.g. "gs://ai2/my-workspace"
workspace = parsed_url.netloc + parsed_url.path
elif parsed_url.netloc:
# e.g. "gs://my-workspace"
workspace = parsed_url.netloc
else:
raise ValueError(f"Bad URL for GS workspace '{parsed_url}'")
return cls(workspace)
@property
def url(self) -> str:
return self.client.url()
def _remote_lock(self, step: Step) -> GCSStepLock:
return GCSStepLock(self.client, step)
def _step_location(self, step: Step) -> str:
return self.client.url(self.Constants.step_artifact_name(step))
@property
def _run_key(self):
return self.client._gs_path("run")
@property
def _stepinfo_key(self):
return self.client._gs_path("stepinfo")
def _save_run(
self, steps: Dict[str, StepInfo], run_data: Dict[str, str], name: Optional[str] = None
) -> Run:
if name is None:
while True:
name = petname.generate() + str(random.randint(0, 100))
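                # Illustrative generated name: something like "special-husky7".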
if not self._ds.get(self._ds.key(self._run_key, name)):
break
else:
if self._ds.get(self._ds.key(self._run_key, name)):
raise ValueError(f"Run name '{name}' is already in use")
run_entity = self._ds.entity(
key=self._ds.key(self._run_key, name), exclude_from_indexes=("steps",)
)
# Even though the run's name is part of the key, we add this as a
# field so we can index on it and order asc/desc (indices on the key field don't allow ordering).
run_entity["name"] = name
run_entity["start_date"] = utc_now_datetime()
run_entity["steps"] = json.dumps(run_data).encode()
self._ds.put(run_entity)
return Run(name=cast(str, name), steps=steps, start_date=run_entity["start_date"])
def _get_run_from_entity(self, run_entity: datastore.Entity) -> Optional[Run]:
try:
steps_info_bytes = run_entity["steps"]
steps_info = json.loads(steps_info_bytes)
except KeyError:
return None
import concurrent.futures
steps: Dict[str, StepInfo] = {}
with concurrent.futures.ThreadPoolExecutor(
max_workers=self.NUM_CONCURRENT_WORKERS,
thread_name_prefix="GSWorkspace._get_run_from_dataset()-",
) as executor:
step_info_futures = []
for unique_id in steps_info.values():
step_info_futures.append(executor.submit(self.step_info, unique_id))
for future in concurrent.futures.as_completed(step_info_futures):
step_info = future.result()
assert step_info.step_name is not None
steps[step_info.step_name] = step_info
return Run(name=run_entity.key.name, start_date=run_entity["start_date"], steps=steps)
def registered_runs(self) -> Dict[str, Run]:
import concurrent.futures
runs: Dict[str, Run] = {}
with concurrent.futures.ThreadPoolExecutor(
max_workers=self.NUM_CONCURRENT_WORKERS,
thread_name_prefix="GSWorkspace.registered_runs()-",
) as executor:
run_futures = []
for run_entity in self._ds.query(kind=self._run_key).fetch():
run_futures.append(executor.submit(self._get_run_from_entity, run_entity))
for future in concurrent.futures.as_completed(run_futures):
run = future.result()
if run is not None:
runs[run.name] = run
return runs
def search_registered_runs(
self,
*,
sort_by: Optional[RunSort] = None,
sort_descending: bool = True,
match: Optional[str] = None,
start: int = 0,
stop: Optional[int] = None,
) -> List[RunInfo]:
run_entities = self._fetch_run_entities(
sort_by=sort_by, sort_descending=sort_descending, match=match, start=start, stop=stop
)
return [
RunInfo(name=e.key.name, start_date=e["start_date"], steps=json.loads(e["steps"]))
for e in run_entities
]
def num_registered_runs(self, *, match: Optional[str] = None) -> int:
count = 0
for _ in self._fetch_run_entities(match=match):
count += 1
return count
def _fetch_run_entities(
self,
*,
sort_by: Optional[RunSort] = None,
sort_descending: bool = True,
match: Optional[str] = None,
start: int = 0,
stop: Optional[int] = None,
) -> Generator[datastore.Entity, None, None]:
from itertools import islice
        # Note: Datastore can't filter and order by different fields without a
        # suitable composite index, so in that case we have to slice and order
        # locally. We default to applying the 'match' filter in the query; only
        # when 'match' is None can the query itself do the sorting.
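        # Illustrative: search_registered_runs(match="train", sort_by=RunSort.NAME)
        # sorts locally, while search_registered_runs(sort_by=RunSort.NAME) pushes
        # the ordering into the Datastore query.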
sort_locally = bool(match)
sort_field: Optional[str] = None
if sort_by == RunSort.START_DATE:
sort_field = "start_date"
elif sort_by == RunSort.NAME:
sort_field = "name"
elif sort_by is not None:
raise NotImplementedError(sort_by)
order: List[str] = []
if sort_field is not None and not sort_locally:
order = [sort_field if not sort_descending else f"-{sort_field}"]
query = self._ds.query(kind=self._run_key, order=order)
if match:
            # HACK: Datastore has no direct string matching functionality, but this
            # pair of range filters approximates checking if 'name' starts with 'match'.
query.add_filter("name", ">=", match)
query.add_filter("name", "<=", match[:-1] + chr(ord(match[-1]) + 1))
entity_iter: Iterable[datastore.Entity] = query.fetch(
offset=0 if sort_locally else start,
limit=None if (stop is None or sort_locally) else stop - start,
)
if sort_field is not None and sort_locally:
entity_iter = sorted(
entity_iter, key=lambda entity: entity[sort_field], reverse=sort_descending
)
if sort_locally:
entity_iter = islice(entity_iter, start, stop)
for entity in entity_iter:
yield entity
def search_step_info(
self,
*,
sort_by: Optional[StepInfoSort] = None,
sort_descending: bool = True,
match: Optional[str] = None,
state: Optional[StepState] = None,
start: int = 0,
stop: Optional[int] = None,
) -> List[StepInfo]:
step_info_entities = self._fetch_step_info_entities(
sort_by=sort_by,
sort_descending=sort_descending,
match=match,
state=state,
start=start,
stop=stop,
)
return [
StepInfo.from_json_dict(json.loads(e["step_info_dict"])) for e in step_info_entities
]
def num_steps(self, *, match: Optional[str] = None, state: Optional[StepState] = None) -> int:
count = 0
for _ in self._fetch_step_info_entities(match=match, state=state):
count += 1
return count
def _fetch_step_info_entities(
self,
*,
sort_by: Optional[StepInfoSort] = None,
sort_descending: bool = True,
match: Optional[str] = None,
state: Optional[StepState] = None,
start: int = 0,
stop: Optional[int] = None,
) -> Generator[datastore.Entity, None, None]:
from itertools import islice
        # Note: Datastore can't filter and order by different fields without a
        # suitable composite index, so in that case we have to apply the remaining
        # filters, slice, and order locally. We default to applying 'match' in the
        # query; if 'match' is None, we apply 'state' in the query instead. Only
        # when both are None can the query itself do the sorting.
sort_locally = sort_by is not None and (match is not None or state is not None)
filter_locally = state is not None and match is not None
slice_locally = sort_locally or filter_locally
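        # Illustrative: match="abc" with state=StepState.RUNNING keeps the 'step_id'
        # range filters in the query and applies the 'state' filter locally; any
        # requested sort then also happens locally.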
sort_field: Optional[str] = None
if sort_by == StepInfoSort.START_TIME:
sort_field = "start_time"
elif sort_by == StepInfoSort.UNIQUE_ID:
sort_field = "step_id"
elif sort_by is not None:
raise NotImplementedError(sort_by)
order: List[str] = []
if sort_field is not None and not sort_locally:
order = [sort_field if not sort_descending else f"-{sort_field}"]
query = self._ds.query(kind=self._stepinfo_key, order=order)
if match is not None:
            # HACK: Datastore has no direct string matching functionality, but this
            # pair of range filters approximates checking if 'step_id' starts with 'match'.
query.add_filter("step_id", ">=", match)
query.add_filter("step_id", "<=", match[:-1] + chr(ord(match[-1]) + 1))
elif state is not None and not filter_locally:
query.add_filter("state", "=", str(state.value))
entity_iter: Iterable[datastore.Entity] = query.fetch(
offset=0 if slice_locally else start,
limit=None if (stop is None or slice_locally) else stop - start,
)
if state is not None and filter_locally:
            # Entities store the state as str(state.value) (see _update_step_info).
            entity_iter = filter(lambda entity: entity["state"] == str(state.value), entity_iter)
if sort_field is not None and sort_locally:
entity_iter = sorted(
entity_iter, key=lambda entity: entity[sort_field], reverse=sort_descending
)
if slice_locally:
entity_iter = islice(entity_iter, start, stop)
for entity in entity_iter:
yield entity
def registered_run(self, name: str) -> Run:
err_msg = f"Run '{name}' not found in workspace"
run_entity = self._ds.get(key=self._ds.key(self._run_key, name))
if not run_entity:
raise KeyError(err_msg)
run = self._get_run_from_entity(run_entity)
if run is None:
raise KeyError(err_msg)
else:
return run
def step_info(self, step_or_unique_id: Union[Step, str]) -> StepInfo:
unique_id = (
step_or_unique_id if isinstance(step_or_unique_id, str) else step_or_unique_id.unique_id
)
step_info_entity = self._ds.get(key=self._ds.key(self._stepinfo_key, unique_id))
if step_info_entity is not None:
step_info_bytes = step_info_entity["step_info_dict"]
step_info = StepInfo.from_json_dict(json.loads(step_info_bytes))
return step_info
else:
if not isinstance(step_or_unique_id, Step):
raise KeyError(step_or_unique_id)
step_info = StepInfo.new_from_step(step_or_unique_id)
self._update_step_info(step_info)
return step_info
def _step_info_multiple(
self, step_or_unique_ids: Union[List[Step], List[str]]
) -> List[StepInfo]:
"""
This method is to combine all calls to the datastore api in a single transaction.
"""
all_unique_id_keys = []
for step_or_unique_id in step_or_unique_ids:
unique_id = (
step_or_unique_id
if isinstance(step_or_unique_id, str)
else step_or_unique_id.unique_id
)
key = self._ds.key(self._stepinfo_key, unique_id)
all_unique_id_keys.append(key)
missing: List = []
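        # get_multi returns entities for the keys that exist and appends a key-only
        # entity to 'missing' for each key that doesn't.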
step_info_entities = self._ds.get_multi(keys=all_unique_id_keys, missing=missing)
missing_steps = [entity.key.name for entity in missing]
step_infos = []
for step_info_entity in step_info_entities:
step_info_bytes = step_info_entity["step_info_dict"]
step_info = StepInfo.from_json_dict(json.loads(step_info_bytes))
step_infos.append(step_info)
for step_or_unique_id in step_or_unique_ids:
step_id = (
step_or_unique_id
if isinstance(step_or_unique_id, str)
else step_or_unique_id.unique_id
)
if step_id in missing_steps:
if not isinstance(step_or_unique_id, Step):
raise KeyError(step_or_unique_id)
step_info = StepInfo.new_from_step(step_or_unique_id)
self._update_step_info(step_info)
step_infos.append(step_info)
return step_infos
    def _get_run_step_info(
        self, targets: Iterable[Step]
    ) -> Tuple[Dict[str, StepInfo], Dict[str, str]]:
all_steps = set(targets)
for step in targets:
all_steps |= step.recursive_dependencies
steps: Dict[str, StepInfo] = {}
run_data: Dict[str, str] = {}
all_valid_steps = [step for step in all_steps if step.name is not None]
step_infos = self._step_info_multiple(all_valid_steps)
for step_info in step_infos:
assert step_info.step_name is not None
steps[step_info.step_name] = step_info
run_data[step_info.step_name] = step_info.unique_id
return steps, run_data
def _update_step_info(self, step_info: StepInfo):
step_info_entity = self._ds.entity(
key=self._ds.key(self._stepinfo_key, step_info.unique_id),
exclude_from_indexes=("step_info_dict",),
)
# Even though the step's unique ID is part of the key, we add this as a
# field so we can index on it and order asc/desc (indices on the key field don't allow ordering).
step_info_entity["step_id"] = step_info.unique_id
step_info_entity["step_name"] = step_info.step_name
step_info_entity["start_time"] = step_info.start_time
step_info_entity["end_time"] = step_info.end_time
step_info_entity["state"] = str(step_info.state.value)
step_info_entity["updated"] = utc_now_datetime()
step_info_entity["step_info_dict"] = json.dumps(step_info.to_json_dict()).encode()
self._ds.put(step_info_entity)
def _remove_step_info(self, step_info: StepInfo) -> None:
# remove dir from bucket
step_artifact = self.client.get(self.Constants.step_artifact_name(step_info))
if step_artifact is not None:
self.client.delete(step_artifact)
        # remove the datastore entity
        self._ds.delete(key=self._ds.key(self._stepinfo_key, step_info.unique_id))
def _save_run_log(self, name: str, log_file: Path):
"""
The logs are stored in the bucket. The Run object details are stored in
the remote database.
"""
run_dataset = self.Constants.run_artifact_name(name)
self.client.upload(run_dataset, log_file)