import json
import logging
import os
import random
from collections import OrderedDict
from pathlib import Path
from typing import Dict, List, Optional, Type, TypeVar, Union, cast
from urllib.parse import ParseResult
import petname
from beaker import (
    Dataset,
    DatasetConflict,
    DatasetNotFound,
    DatasetSort,
    Digest,
    Experiment,
    ExperimentNotFound,
)
from beaker import Dataset as BeakerDataset
from tango.common.util import make_safe_filename, tango_cache_dir
from tango.step import Step
from tango.step_info import StepInfo, StepState
from tango.workspace import Run, RunInfo, RunSort, StepInfoSort, Workspace
from tango.workspaces.remote_workspace import RemoteWorkspace
from .common import BeakerStepLock, Constants, dataset_url, get_client
from .step_cache import BeakerStepCache
U = TypeVar("U", Run, StepInfo)
logger = logging.getLogger(__name__)
@Workspace.register("beaker")
class BeakerWorkspace(RemoteWorkspace):
"""
    This is a :class:`~tango.workspace.Workspace` that stores step artifacts on `Beaker`_.

    .. tip::
        Registered as a :class:`~tango.workspace.Workspace` under the name "beaker".

    :param workspace: The name or ID of the Beaker workspace to use.
    :param max_workers: The maximum number of concurrent workers to use. Defaults to ``None``.
    :param kwargs: Additional keyword arguments passed to :meth:`Beaker.from_env() <beaker.Beaker.from_env()>`.
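
    A minimal usage sketch, assuming you are authenticated with Beaker and that
    "ai2/my-workspace" (a hypothetical name) exists:

    .. code-block:: python

        from tango.workspace import Workspace

        workspace = Workspace.from_url("beaker://ai2/my-workspace")
        print(workspace.url)  # -> "beaker://ai2/my-workspace"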
"""
STEP_INFO_CACHE_SIZE = 512
Constants = Constants
NUM_CONCURRENT_WORKERS = 9
def __init__(self, workspace: str, max_workers: Optional[int] = None, **kwargs):
self.beaker = get_client(beaker_workspace=workspace, **kwargs)
self._cache = BeakerStepCache(beaker=self.beaker)
self._locks: Dict[Step, BeakerStepLock] = {}
super().__init__()
self.max_workers = max_workers
self._disk_cache_dir = tango_cache_dir() / "beaker_cache" / "_objects"
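        # This OrderedDict doubles as a small LRU cache: hits are moved to the end,
        # and evictions pop the oldest entries from the front.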
self._mem_cache: "OrderedDict[Digest, Union[StepInfo, Run]]" = OrderedDict()
@property
def cache(self):
return self._cache
@property
def locks(self):
return self._locks
@property
def steps_dir_name(self):
return "beaker_workspace"
@property
def url(self) -> str:
return f"beaker://{self.beaker.workspace.get().full_name}"
def _step_location(self, step: Step) -> str:
return dataset_url(self.beaker, self.Constants.step_artifact_name(step))
@classmethod
def from_parsed_url(cls, parsed_url: ParseResult) -> Workspace:
workspace: str
if parsed_url.netloc and parsed_url.path:
# e.g. "beaker://ai2/my-workspace"
workspace = parsed_url.netloc + parsed_url.path
elif parsed_url.netloc:
# e.g. "beaker://my-workspace"
workspace = parsed_url.netloc
else:
raise ValueError(f"Bad URL for Beaker workspace '{parsed_url}'")
return cls(workspace)
@property
def current_beaker_experiment(self) -> Optional[Experiment]:
"""
When the workspace is being used within a Beaker experiment that was submitted
by the Beaker executor, this will return the `Experiment` object.
"""
experiment_name = os.environ.get("BEAKER_EXPERIMENT_NAME")
if experiment_name is not None:
try:
return self.beaker.experiment.get(experiment_name)
except ExperimentNotFound:
return None
else:
return None
def _remote_lock(self, step: Step) -> BeakerStepLock:
return BeakerStepLock(
self.beaker, step, current_beaker_experiment=self.current_beaker_experiment
)
def _get_object_from_cache(self, digest: Digest, o_type: Type[U]) -> Optional[U]:
cache_path = self._disk_cache_dir / make_safe_filename(str(digest))
if digest in self._mem_cache:
cached = self._mem_cache.pop(digest)
# Move to end.
self._mem_cache[digest] = cached
return cached if isinstance(cached, o_type) else None
elif cache_path.is_file():
try:
                with cache_path.open("rt") as f:
json_dict = json.load(f)
cached = o_type.from_json_dict(json_dict)
except Exception as exc:
logger.warning("Error while loading object from workspace cache: %s", str(exc))
try:
os.remove(cache_path)
except FileNotFoundError:
pass
return None
# Add to in-memory cache.
self._mem_cache[digest] = cached
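            # Evict least-recently-used entries until we're back under the size limit.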
while len(self._mem_cache) > self.STEP_INFO_CACHE_SIZE:
self._mem_cache.popitem(last=False)
return cached # type: ignore
else:
return None
def _add_object_to_cache(self, digest: Digest, o: U):
self._disk_cache_dir.mkdir(parents=True, exist_ok=True)
cache_path = self._disk_cache_dir / make_safe_filename(str(digest))
self._mem_cache[digest] = o
        with cache_path.open("wt") as f:
json.dump(o.to_json_dict(), f)
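        # Evict least-recently-used entries until we're back under the size limit.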
while len(self._mem_cache) > self.STEP_INFO_CACHE_SIZE:
self._mem_cache.popitem(last=False)
def step_info(self, step_or_unique_id: Union[Step, str]) -> StepInfo:
try:
dataset = self.beaker.dataset.get(self.Constants.step_artifact_name(step_or_unique_id))
return self._get_step_info_from_dataset(dataset)
except (DatasetNotFound, FileNotFoundError):
if not isinstance(step_or_unique_id, Step):
raise KeyError(step_or_unique_id)
step_info = StepInfo.new_from_step(step_or_unique_id)
self._update_step_info(step_info)
return step_info
def _get_step_info_from_dataset(self, dataset: Dataset) -> StepInfo:
        file_info = self.beaker.dataset.file_info(dataset, self.Constants.STEP_INFO_FNAME)
step_info: StepInfo
cached = (
None
if file_info.digest is None
else self._get_object_from_cache(file_info.digest, StepInfo)
)
if cached is not None:
step_info = cached
else:
step_info_bytes = self.beaker.dataset.get_file(dataset, file_info, quiet=True)
step_info = StepInfo.from_json_dict(json.loads(step_info_bytes))
if file_info.digest is not None:
self._add_object_to_cache(file_info.digest, step_info)
return step_info
def _save_run(
self, steps: Dict[str, StepInfo], run_data: Dict[str, str], name: Optional[str] = None
) -> Run:
        # Create a remote dataset that represents this run. The dataset will just contain
        # a JSON file that maps step names to step unique IDs.
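        # e.g. {"train": "TrainStep-abc123", "evaluate": "EvalStep-def456"} (hypothetical names/IDs)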
run_dataset: BeakerDataset
if name is None:
# Find a unique name to use.
while True:
name = petname.generate() + str(random.randint(0, 100))
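                # Dataset names must be unique, so retry with a fresh name on conflict.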
try:
run_dataset = self.beaker.dataset.create(
self.Constants.run_artifact_name(cast(str, name)), commit=False
)
except DatasetConflict:
continue
else:
break
else:
try:
run_dataset = self.beaker.dataset.create(
self.Constants.run_artifact_name(name), commit=False
)
except DatasetConflict:
raise ValueError(f"Run name '{name}' is already in use")
# Upload run data to dataset.
# NOTE: We don't commit the dataset here since we'll need to upload the logs file
# after the run.
self.beaker.dataset.upload(
run_dataset, json.dumps(run_data).encode(), self.Constants.RUN_DATA_FNAME, quiet=True
)
return Run(name=cast(str, name), steps=steps, start_date=run_dataset.created)
def registered_runs(self) -> Dict[str, Run]:
import concurrent.futures
runs: Dict[str, Run] = {}
with concurrent.futures.ThreadPoolExecutor(
max_workers=self.NUM_CONCURRENT_WORKERS,
thread_name_prefix="BeakerWorkspace.registered_runs()-",
) as executor:
run_futures = []
for dataset in self.beaker.workspace.iter_datasets(
match=self.Constants.RUN_ARTIFACT_PREFIX, uncommitted=True, results=False
):
run_futures.append(executor.submit(self._get_run_from_dataset, dataset))
for future in concurrent.futures.as_completed(run_futures):
run = future.result()
if run is not None:
runs[run.name] = run
return runs
def search_registered_runs(
self,
*,
sort_by: Optional[RunSort] = None,
sort_descending: bool = True,
match: Optional[str] = None,
start: Optional[int] = None,
stop: Optional[int] = None,
) -> List[RunInfo]:
        if match is None:
            match = self.Constants.RUN_ARTIFACT_PREFIX
        else:
            match = self.Constants.RUN_ARTIFACT_PREFIX + match
if sort_by is None or sort_by == RunSort.START_DATE:
sort = DatasetSort.created
elif sort_by == RunSort.NAME:
sort = DatasetSort.dataset_name
else:
raise NotImplementedError
runs = []
for dataset in self.beaker.workspace.iter_datasets(
match=match,
results=False,
cursor=start or 0,
limit=None if stop is None else stop - (start or 0),
sort_by=sort,
descending=sort_descending,
):
if dataset.name is not None and dataset.name.startswith(
self.Constants.RUN_ARTIFACT_PREFIX
):
run_name = dataset.name[len(self.Constants.RUN_ARTIFACT_PREFIX) :]
runs.append(RunInfo(name=run_name, start_date=dataset.created))
return runs
def num_registered_runs(self, *, match: Optional[str] = None) -> int:
        if match is None:
            match = self.Constants.RUN_ARTIFACT_PREFIX
        else:
            match = self.Constants.RUN_ARTIFACT_PREFIX + match
count = 0
for dataset in self.beaker.workspace.iter_datasets(
match=match,
results=False,
):
            if dataset.name is not None and dataset.name.startswith(
                self.Constants.RUN_ARTIFACT_PREFIX
            ):
count += 1
return count
def search_step_info(
self,
*,
sort_by: Optional[StepInfoSort] = None,
sort_descending: bool = True,
match: Optional[str] = None,
state: Optional[StepState] = None,
start: int = 0,
stop: Optional[int] = None,
) -> List[StepInfo]:
if state is not None:
raise NotImplementedError(
f"{self.__class__.__name__} cannot filter steps efficiently by state"
)
        if match is None:
            match = self.Constants.STEP_ARTIFACT_PREFIX
        else:
            match = self.Constants.STEP_ARTIFACT_PREFIX + match
sort: Optional[DatasetSort] = None
if sort_by is None or sort_by == StepInfoSort.START_TIME:
sort = DatasetSort.created
elif sort_by == StepInfoSort.UNIQUE_ID:
sort = DatasetSort.dataset_name
        else:
raise NotImplementedError
steps = []
for dataset in self.beaker.workspace.iter_datasets(
match=match,
results=False,
cursor=start or 0,
limit=None if stop is None else stop - (start or 0),
sort_by=sort or DatasetSort.created,
descending=sort_descending,
):
try:
steps.append(self._get_step_info_from_dataset(dataset))
except (DatasetNotFound, FileNotFoundError):
continue
return steps
def num_steps(self, *, match: Optional[str] = None, state: Optional[StepState] = None) -> int:
if state is not None:
raise NotImplementedError(
f"{self.__class__.__name__} cannot filter steps efficiently by state"
)
        if match is None:
            match = self.Constants.STEP_ARTIFACT_PREFIX
        else:
            match = self.Constants.STEP_ARTIFACT_PREFIX + match
count = 0
for dataset in self.beaker.workspace.iter_datasets(
match=match,
results=False,
):
            if dataset.name is not None and dataset.name.startswith(
                self.Constants.STEP_ARTIFACT_PREFIX
            ):
count += 1
return count
def registered_run(self, name: str) -> Run:
err_msg = f"Run '{name}' not found in workspace"
try:
dataset_for_run = self.beaker.dataset.get(self.Constants.run_artifact_name(name))
# Make sure the run is in our workspace.
if dataset_for_run.workspace_ref.id != self.beaker.workspace.get().id: # type: ignore # TODO
raise DatasetNotFound
except DatasetNotFound:
raise KeyError(err_msg)
run = self._get_run_from_dataset(dataset_for_run)
if run is None:
raise KeyError(err_msg)
else:
return run
def _save_run_log(self, name: str, log_file: Path):
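        # Sync the log file into the run's dataset, then commit the dataset now that
        # the run is complete.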
run_dataset = self.Constants.run_artifact_name(name)
self.beaker.dataset.sync(run_dataset, log_file, quiet=True)
self.beaker.dataset.commit(run_dataset)
def _get_run_from_dataset(self, dataset: BeakerDataset) -> Optional[Run]:
if dataset.name is None:
return None
if not dataset.name.startswith(self.Constants.RUN_ARTIFACT_PREFIX):
return None
try:
run_name = dataset.name[len(self.Constants.RUN_ARTIFACT_PREFIX) :]
steps_info_bytes = self.beaker.dataset.get_file(
dataset, self.Constants.RUN_DATA_FNAME, quiet=True
)
steps_info = json.loads(steps_info_bytes)
except (DatasetNotFound, FileNotFoundError):
return None
steps: Dict[str, StepInfo] = {}
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor(
max_workers=self.NUM_CONCURRENT_WORKERS,
thread_name_prefix="BeakerWorkspace._get_run_from_dataset()-",
) as executor:
step_info_futures = []
for unique_id in steps_info.values():
step_info_futures.append(executor.submit(self.step_info, unique_id))
for future in concurrent.futures.as_completed(step_info_futures):
step_info = future.result()
assert step_info.step_name is not None
steps[step_info.step_name] = step_info
return Run(name=run_name, start_date=dataset.created, steps=steps)
def _update_step_info(self, step_info: StepInfo):
dataset_name = self.Constants.step_artifact_name(step_info)
step_info_dataset: BeakerDataset
        try:
            step_info_dataset = self.beaker.dataset.create(dataset_name, commit=False)
        except DatasetConflict:
            step_info_dataset = self.beaker.dataset.get(dataset_name)
        self.beaker.dataset.upload(
            step_info_dataset,  # the dataset to upload to
            json.dumps(step_info.to_json_dict()).encode(),  # the serialized step info
            self.Constants.STEP_INFO_FNAME,  # the target filename within the dataset
            quiet=True,
        )
    def _remove_step_info(self, step_info: StepInfo) -> None:
        # Remove the dataset associated with this step from the Beaker workspace.
        # NOTE: `Beaker.dataset.get()` raises `DatasetNotFound` rather than returning
        # `None`, so we catch that instead of checking for `None`.
        dataset_name = self.Constants.step_artifact_name(step_info)
        try:
            step_dataset = self.beaker.dataset.get(dataset_name)
        except DatasetNotFound:
            return
        self.beaker.dataset.delete(step_dataset)