import logging
import tempfile
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Iterable, Iterator, Optional, TypeVar, Union
from urllib.parse import ParseResult
import pytz
import wandb
from tango.common.exceptions import StepStateError
from tango.common.file_lock import FileLock
from tango.common.util import exception_to_string, tango_cache_dir, utc_now_datetime
from tango.step import Step
from tango.step_cache import StepCache
from tango.step_info import StepInfo, StepState
from tango.workspace import Run, Workspace
from .step_cache import WandbStepCache
from .util import RunKind, check_environment
# Generic type variable; ``WandbWorkspace.step_finished`` uses it to declare that it
# returns a value of the same type as the ``result`` it is given.
T = TypeVar("T")
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
@Workspace.register("wandb")
class WandbWorkspace(Workspace):
    """
    This is a :class:`~tango.workspace.Workspace` that tracks Tango runs in a W&B project.
    It also stores step results as W&B Artifacts via :class:`WandbStepCache`.

    Each Tango run with this workspace will generate multiple runs in your W&B project.
    There will always be a W&B run corresponding to each Tango run with the same name,
    which will contain some metadata about the Tango run. Then there will be one W&B run
    for each cacheable step that runs with a name corresponding to the name of the step.
    So if your Tango run includes 3 cacheable steps, that will result in a total of 4 new
    runs in W&B.

    :param project: The W&B project to use for the workspace.
    :param entity: The W&B entity (user or organization account) to use for the workspace.

    .. tip::
        Registered as a :class:`~tango.workspace.Workspace` under the name "wandb".

    .. tip::
        If you want to change the artifact kind for step result artifacts uploaded
        to W&B, add a field called ``artifact_kind`` to the ``metadata`` of
        the :class:`~tango.step.Step` class.

        This can be useful if you want model objects to be added to the model zoo.
        In that case you would set ``artifact_kind = "model"``.
        For example, your config for the step would look like this:

        .. code-block::

            { type: "trainer", step_metadata: { artifact_kind: "model" }, ... }

        Or just add this to the ``METADATA`` class attribute:

        .. code-block::

            @Step.register("trainer")
            class TrainerStep(Step):
                METADATA = {"artifact_kind": "model"}
    """

    def __init__(self, project: str, entity: Optional[str] = None):
        check_environment()
        super().__init__()
        self.project = project
        self._entity = entity
        self.cache = WandbStepCache(project=self.project, entity=self.entity)
        # Local directory holding per-step lock files and work directories.
        self.steps_dir = tango_cache_dir() / "wandb_workspace"
        # Locks held for steps started through this instance (see ``step_starting``).
        self.locks: Dict[Step, FileLock] = {}
        # Step info for steps currently running in this process, keyed by unique ID.
        self._running_step_info: Dict[str, StepInfo] = {}

    def __getstate__(self):
        """
        We override `__getstate__()` to customize how instances of this class are pickled
        since we don't want to persist certain attributes.
        """
        out = super().__getstate__()
        out["locks"] = {}
        return out

    @property
    def wandb_client(self) -> wandb.Api:
        """
        A fresh :class:`wandb.Api` client scoped to this workspace's project
        (and entity, when one was given).
        """
        overrides = {"project": self.project}
        # Truthiness (not ``is not None``) so an empty entity is treated the same way
        # the ``entity`` property treats it: as "not specified".
        if self._entity:
            overrides["entity"] = self._entity
        return wandb.Api(overrides=overrides)

    @property
    def entity(self) -> str:
        """
        The W&B entity for this workspace, falling back to the client's default
        entity when none was given explicitly.
        """
        return self._entity or self.wandb_client.default_entity

    @property
    def url(self) -> str:
        """
        The ``wandb://<entity>/<project>`` URL of this workspace.
        """
        return f"wandb://{self.entity}/{self.project}"

    @classmethod
    def from_parsed_url(cls, parsed_url: ParseResult) -> Workspace:
        """
        Build a workspace from a parsed URL, taking the entity from the network
        location and the project from the path.
        """
        # An empty netloc means no entity was given; pass ``None`` rather than ``""``.
        entity = parsed_url.netloc or None
        project = parsed_url.path.strip("/")
        return cls(project=project, entity=entity)

    @property
    def step_cache(self) -> StepCache:
        """
        The :class:`WandbStepCache` backing this workspace.
        """
        return self.cache

    @property
    def wandb_project_url(self) -> str:
        """
        The URL of the W&B project this workspace uses.
        """
        app_url = self.wandb_client.client.app_url
        app_url = app_url.rstrip("/")
        return f"{app_url}/{self.entity}/{self.project}"

    @staticmethod
    def _parse_wandb_timestamp(timestamp: str) -> datetime:
        """
        Parse a timestamp string taken from a W&B run into a UTC-aware ``datetime``.

        The expected format is ``%Y-%m-%dT%H:%M:%S``. A trailing ``Z`` and fractional
        seconds are tolerated (and dropped) so that a richer ISO-8601 rendering does not
        raise ``ValueError``.

        NOTE(review): whether the W&B API ever emits those variants is an assumption;
        strings in the plain format parse exactly as before.
        """
        trimmed = timestamp[:-1] if timestamp.endswith("Z") else timestamp
        trimmed = trimmed.split(".", 1)[0]
        return datetime.strptime(trimmed, "%Y-%m-%dT%H:%M:%S").replace(tzinfo=pytz.utc)

    def _get_unique_id(self, step_or_unique_id: Union[Step, str]) -> str:
        """
        Return the unique ID for either a :class:`~tango.step.Step` or an ID string.
        """
        if isinstance(step_or_unique_id, Step):
            return step_or_unique_id.unique_id
        return step_or_unique_id

    def _release_step(self, step: Step) -> None:
        """
        Release the lock held for ``step`` (if any) and forget its running step info.

        Missing entries are tolerated so that this can safely run from ``finally`` /
        ``except`` clauses without masking the exception that is already propagating.
        """
        try:
            lock = self.locks.pop(step, None)
            if lock is not None:
                lock.release()
        finally:
            self._running_step_info.pop(step.unique_id, None)

    def step_dir(self, step_or_unique_id: Union[Step, str]) -> Path:
        """
        Return the local directory for the given step, creating it if needed.
        """
        unique_id = self._get_unique_id(step_or_unique_id)
        path = self.steps_dir / unique_id
        path.mkdir(parents=True, exist_ok=True)
        return path

    def work_dir(self, step: Step) -> Path:
        """
        Return the ``work`` directory inside the step's directory, creating it if needed.
        """
        path = self.step_dir(step) / "work"
        path.mkdir(parents=True, exist_ok=True)
        return path

    def step_info(self, step_or_unique_id: Union[Step, str]) -> StepInfo:
        """
        Return the :class:`~tango.step_info.StepInfo` for a step.

        Steps currently running in this process are answered from memory; everything
        else is looked up in W&B.

        :raises KeyError: If no information about the step can be found.
        """
        unique_id = self._get_unique_id(step_or_unique_id)
        if unique_id in self._running_step_info:
            return self._running_step_info[unique_id]

        step_info = self._get_updated_step_info(
            unique_id,
            step_name=step_or_unique_id.name if isinstance(step_or_unique_id, Step) else None,
        )
        if step_info is None:
            raise KeyError(step_or_unique_id)
        return step_info

    def step_starting(self, step: Step) -> None:
        """
        Acquire the step's lock, start a W&B run for it and mark it as running.

        :raises RuntimeError: If a W&B run is already active in this process.
        :raises StepStateError: If the step is not in a state it can be started from.
        """
        if wandb.run is not None:
            raise RuntimeError(
                "There is already a W&B run initialized, cannot initialize another one."
            )

        work_dir = self.work_dir(step)
        lock_path = self.step_dir(step) / "lock"
        lock = FileLock(lock_path, read_only_ok=True)
        lock.acquire_with_updates(desc=f"acquiring lock for '{step.name}'")
        self.locks[step] = lock

        # Everything after acquiring the lock happens inside the ``try`` so that the lock
        # is released again if *anything* goes wrong, including the state check.
        try:
            step_info = self._get_updated_step_info(step.unique_id) or StepInfo.new_from_step(step)
            if step_info.state not in {
                StepState.INCOMPLETE,
                StepState.FAILED,
                StepState.UNCACHEABLE,
            }:
                raise StepStateError(
                    step,
                    step_info.state,
                    context="If you are certain the step is not running somewhere else, delete the lock "
                    f"file at {lock_path}.",
                )

            # Initialize W&B run for the step.
            wandb.init(
                name=step_info.step_name,
                job_type=RunKind.STEP.value,
                group=step.unique_id,
                dir=str(work_dir),
                entity=self.entity,
                project=self.project,
                # For cacheable steps we can just use the step's unique ID as the W&B run ID,
                # but not for uncacheable steps since those might be run more than once, and
                # will need a unique W&B run ID each time.
                id=step.unique_id if step.cache_results else None,
                resume="allow" if step.cache_results else None,
                notes="\n".join(
                    [
                        f'Tango step "{step.name}"',
                        f"\N{bullet} type: {step_info.step_class_name}",
                        f"\N{bullet} ID: {step.unique_id}",
                    ]
                ),
                config={
                    "job_type": RunKind.STEP.value,
                    "_run_suite_id": self._generate_run_suite_id(),  # used for testing only
                },
            )
            assert wandb.run is not None

            logger.info(
                "Tracking '%s' step on Weights and Biases: %s/runs/%s/overview",
                step.name,
                self.wandb_project_url,
                wandb.run.id,
            )

            # "Use" all of the result artifacts for this step's dependencies in order to declare
            # those dependencies to W&B.
            for dependency in step.dependencies:
                self.cache.use_step_result_artifact(dependency)

            # Update StepInfo to mark as running.
            step_info.start_time = utc_now_datetime()
            step_info.end_time = None
            step_info.error = None
            step_info.result_location = None
            wandb.run.config.update({"step_info": step_info.to_json_dict()}, allow_val_change=True)
            self._running_step_info[step.unique_id] = step_info
        except:  # noqa: E722
            try:
                # No run was active when we entered, so any active run here was created
                # above. Close it, otherwise the next ``step_starting`` call would refuse
                # to start.
                if wandb.run is not None:
                    wandb.finish(exit_code=1)
            finally:
                self._release_step(step)
            raise

    def step_finished(self, step: Step, result: T) -> T:
        """
        Record a successful step: cache its result (when cacheable), update the step
        info in W&B, finish the W&B run and release the step's lock.

        :returns: The result. For iterator results of cacheable steps this is the value
            read back from the cache, since caching consumes the iterator.
        :raises RuntimeError: If there is no active W&B run.
        :raises KeyError: If no step info can be found for the step.
        """
        if wandb.run is None:
            raise RuntimeError(
                f"{self.__class__.__name__}.step_finished() called outside of a W&B run. "
                f"Did you forget to call {self.__class__.__name__}.step_starting() first?"
            )

        step_info = self._running_step_info.get(step.unique_id) or self._get_updated_step_info(
            step.unique_id
        )
        if step_info is None:
            raise KeyError(step.unique_id)

        try:
            if step.cache_results:
                self.step_cache[step] = result
                if hasattr(result, "__next__"):
                    assert isinstance(result, Iterator)
                    # Caching the iterator will consume it, so we write it to the
                    # cache and then read from the cache for the return value.
                    result = self.step_cache[step]
                step_info.result_location = self.cache.get_step_result_artifact_url(step)
            else:
                # Create an empty artifact in order to build the DAG in W&B.
                self.cache.create_step_result_artifact(step)

            step_info.end_time = utc_now_datetime()
            wandb.run.config.update({"step_info": step_info.to_json_dict()}, allow_val_change=True)

            # Finalize the step's W&B run.
            wandb.finish()
        finally:
            self._release_step(step)

        return result

    def step_failed(self, step: Step, e: BaseException) -> None:
        """
        Record a failed step: store the error in the step info, finish the W&B run with
        a non-zero exit code and release the step's lock.

        :raises RuntimeError: If there is no active W&B run.
        :raises KeyError: If no step info can be found for the step.
        :raises StepStateError: If the step is not currently marked as running.
        """
        if wandb.run is None:
            raise RuntimeError(
                f"{self.__class__.__name__}.step_failed() called outside of a W&B run. "
                f"Did you forget to call {self.__class__.__name__}.step_starting() first?"
            )

        step_info = self._running_step_info.get(step.unique_id) or self._get_updated_step_info(
            step.unique_id
        )
        if step_info is None:
            raise KeyError(step.unique_id)

        try:
            # Update StepInfo, marking the step as failed.
            if step_info.state != StepState.RUNNING:
                raise StepStateError(step, step_info.state)
            step_info.end_time = utc_now_datetime()
            step_info.error = exception_to_string(e)
            wandb.run.config.update({"step_info": step_info.to_json_dict()}, allow_val_change=True)

            # Finalize the step's W&B run.
            wandb.finish(exit_code=1)
        finally:
            # Tolerant cleanup: the lock may already be gone (e.g. ``step_finished``
            # failed and cleaned up first), and that must not hide the real error.
            self._release_step(step)

    def remove_step(self, step_unique_id: str):
        """
        Removes cached step using the given unique step id

        Not implemented for this workspace.

        :raises NotImplementedError: Always.
        """
        raise NotImplementedError()

    def register_run(self, targets: Iterable[Step], name: Optional[str] = None) -> Run:
        """
        Register a new Tango run for ``targets`` (and all of their dependencies) as a
        W&B run, and return the corresponding :class:`~tango.workspace.Run`.

        :param targets: The steps the run is meant to produce.
        :param name: Optional name for the run; passed to ``wandb.init`` as-is.
        """
        # Materialize first: ``targets`` may be a one-shot iterator, and we need to go
        # over it twice (once for the targets themselves, once for their dependencies).
        target_steps = list(targets)
        all_steps = set(target_steps)
        for step in target_steps:
            all_steps |= step.recursive_dependencies

        wandb_run_id: str
        wandb_run_name: str
        with tempfile.TemporaryDirectory() as temp_dir_name:
            with wandb.init(  # type: ignore[union-attr]
                job_type=RunKind.TANGO_RUN.value,
                entity=self.entity,
                project=self.project,
                name=name,
                dir=temp_dir_name,
                config={
                    "job_type": RunKind.TANGO_RUN.value,  # need this in the config so we can filter runs by this
                    "_run_suite_id": self._generate_run_suite_id(),  # used for testing only
                },
            ) as wandb_run:
                wandb_run_id = wandb_run.id
                wandb_run_name = wandb_run.name  # type: ignore[assignment]
                # The project URL needs a client round trip, so resolve it only once.
                project_url = self.wandb_project_url
                logger.info("Registering run %s with Weights and Biases", wandb_run.name)
                logger.info("View run at: %s/runs/%s/overview", project_url, wandb_run_id)

                # Collect step info for all steps.
                step_ids: Dict[str, bool] = {}
                step_name_to_info: Dict[str, Dict[str, Any]] = {}
                for step in all_steps:
                    step_info = StepInfo.new_from_step(step)
                    step_name_to_info[step.name] = {
                        k: v for k, v in step_info.to_json_dict().items() if v is not None
                    }
                    step_ids[step.unique_id] = True

                # Update config with step info.
                wandb_run.config.update({"steps": step_name_to_info, "_step_ids": step_ids})

                # Update notes.
                notes_parts = ["Tango run\n--------------"]
                cacheable_steps = {step for step in all_steps if step.cache_results}
                if cacheable_steps:
                    notes_parts.append("\nCacheable steps:\n")
                    for step in sorted(cacheable_steps, key=lambda step: step.name):
                        notes_parts.append(f"\N{bullet} {step.name}")
                        dependencies = step.dependencies
                        if dependencies:
                            notes_parts.append(
                                ", depends on: "
                                + ", ".join(sorted(f"'{dep.name}'" for dep in dependencies))
                            )
                        notes_parts.append("\n \N{rightwards arrow with hook} ")
                        notes_parts.append(f"{project_url}/runs/{step.unique_id}/overview\n")
                wandb_run.notes = "".join(notes_parts)

        # Look the run up again only after the W&B run context has been closed.
        return self.registered_run(wandb_run_name)

    def _generate_run_suite_id(self) -> str:
        """
        Return a freshly generated W&B ID, stored in run configs as ``_run_suite_id``.
        """
        return wandb.util.generate_id()

    def registered_runs(self) -> Dict[str, Run]:
        """
        Return all Tango runs registered in the W&B project, keyed by run name.
        """
        matching_runs = self.wandb_client.runs(
            f"{self.entity}/{self.project}",
            filters={"config.job_type": RunKind.TANGO_RUN.value},  # type: ignore
        )
        return {
            wandb_run.name: self._get_run_from_wandb_run(wandb_run) for wandb_run in matching_runs
        }

    def registered_run(self, name: str) -> Run:
        """
        Return the Tango run registered under ``name``.

        :raises KeyError: If no such run exists in the W&B project.
        :raises ValueError: If more than one run with that name exists.
        """
        matching_runs = list(
            self.wandb_client.runs(
                f"{self.entity}/{self.project}",
                filters={"display_name": name, "config.job_type": RunKind.TANGO_RUN.value},  # type: ignore
            )
        )
        if not matching_runs:
            raise KeyError(f"Run '{name}' not found in workspace")
        if len(matching_runs) > 1:
            raise ValueError(f"Found more than one run named '{name}' in W&B project")
        return self._get_run_from_wandb_run(matching_runs[0])

    def _get_run_from_wandb_run(
        self,
        wandb_run: wandb.apis.public.Run,
    ) -> Run:
        """
        Convert a W&B run that represents a Tango run into a :class:`~tango.workspace.Run`,
        refreshing the info of cacheable steps from W&B where newer info is available.
        """
        step_name_to_info = {}
        for step_name, step_info_dict in wandb_run.config["steps"].items():
            step_info = StepInfo.from_json_dict(step_info_dict)
            if step_info.cacheable:
                updated_step_info = self._get_updated_step_info(
                    step_info.unique_id, step_name=step_name
                )
                if updated_step_info is not None:
                    step_info = updated_step_info
            step_name_to_info[step_name] = step_info

        return Run(
            name=wandb_run.name,
            steps=step_name_to_info,
            start_date=self._parse_wandb_timestamp(wandb_run.created_at),
        )

    def _get_updated_step_info(
        self, step_id: str, step_name: Optional[str] = None
    ) -> Optional[StepInfo]:
        """
        Look up the most current :class:`~tango.step_info.StepInfo` for a step in W&B.

        :param step_id: The unique ID of the step.
        :param step_name: Optional step name used to narrow the search.
        :returns: The step info, or ``None`` if the step is unknown to the W&B project.
        """
        run_path = f"{self.entity}/{self.project}"

        # First try to find the W&B run corresponding to the step. This will only
        # work if the step execution was started already.
        filters = {
            "config.job_type": RunKind.STEP.value,
            "config.step_info.unique_id": step_id,
        }
        if step_name is not None:
            filters["display_name"] = step_name
        for wandb_run in self.wandb_client.runs(
            run_path,
            filters=filters,  # type: ignore
        ):
            step_info = StepInfo.from_json_dict(wandb_run.config["step_info"])
            # Might need to fix the step info if the step failed and we failed to update the config.
            if step_info.start_time is None:
                step_info.start_time = self._parse_wandb_timestamp(wandb_run.created_at)
            if wandb_run.state in {"failed", "finished"}:
                if step_info.end_time is None:
                    step_info.end_time = self._parse_wandb_timestamp(wandb_run.heartbeatAt)
                if wandb_run.state == "failed" and step_info.error is None:
                    step_info.error = "Exception"
            # Only the first matching run is considered.
            return step_info

        # If the step hasn't been started yet, we'll have to pull the step info from the
        # registered run.
        filters = {
            "config.job_type": RunKind.TANGO_RUN.value,
            f"config._step_ids.{step_id}": True,
        }
        if step_name is not None:
            filters[f"config.steps.{step_name}.unique_id"] = step_id
        for wandb_run in self.wandb_client.runs(
            run_path,
            filters=filters,  # type: ignore
        ):
            if step_name is not None:
                step_info_data = wandb_run.config["steps"][step_name]
            else:
                step_info_data = next(
                    (d for d in wandb_run.config["steps"].values() if d["unique_id"] == step_id),
                    None,
                )
                if step_info_data is None:
                    # The run matched the filter but does not list the step after all;
                    # keep looking instead of letting ``StopIteration`` escape.
                    continue
            return StepInfo.from_json_dict(step_info_data)

        return None