import copy
import json
import logging
import os
import zlib
from collections import OrderedDict
from collections.abc import MutableMapping
from itertools import chain
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Set, TypeVar, Union
import yaml
from rjsonnet import evaluate_file, evaluate_snippet
from .aliases import PathOrStr
from .exceptions import ConfigurationError
from .util import could_be_class_name
logger = logging.getLogger(__name__)
def infer_and_cast(value: Any):
    """
    In some cases we'll be feeding params dicts to functions we don't own;
    for example, PyTorch optimizers. In that case we can't use ``pop_int``
    or similar to force casts (which means you can't specify ``int`` parameters
    using environment variables). This function takes something that looks JSON-like
    and recursively casts things that look like (bool, int, float) to (bool, int, float).
    """
    # Numbers and bools are already in their desired form.
    if isinstance(value, (int, float, bool)):
        return value
    # Containers: recurse into each element/value.
    if isinstance(value, list):
        return [infer_and_cast(element) for element in value]
    if isinstance(value, dict):
        return {k: infer_and_cast(v) for k, v in value.items()}
    if isinstance(value, str):
        lowered = value.lower()
        if lowered == "true":
            return True
        if lowered == "false":
            return False
        # Try progressively looser numeric casts; fall back to the raw string.
        for cast in (int, float):
            try:
                return cast(value)
            except ValueError:
                continue
        return value
    raise ValueError(f"cannot infer type of {value}")
def _is_encodable(value: str) -> bool:
"""
We need to filter out environment variables that can't
be unicode-encoded to avoid a "surrogates not allowed"
error in jsonnet.
"""
# Idiomatically you'd like to not check the != b""
# but mypy doesn't like that.
return (value == "") or (value.encode("utf-8", "ignore") != b"")
def _environment_variables() -> Dict[str, str]:
    """
    Wraps ``os.environ`` to filter out non-encodable values.
    """
    encodable: Dict[str, str] = {}
    for name, value in os.environ.items():
        if _is_encodable(value):
            encodable[name] = value
    return encodable
T = TypeVar("T", dict, list)


def with_overrides(original: T, overrides_dict: Dict[str, Any], prefix: str = "") -> T:
    """
    Return a deep copy of ``original`` (a ``dict`` or ``list``) with ``overrides_dict``
    applied.

    Keys in ``overrides_dict`` may be dotted paths (e.g. ``"model.embedding_dim"``)
    that descend into nested dicts/lists of ``original``; list elements are addressed
    by their stringified index. An exact (non-dotted) key match replaces the value
    wholesale.

    :param original:
        The structure to copy-and-override. Never mutated.
    :param overrides_dict:
        Mapping from (possibly dotted) keys to replacement values.
    :param prefix:
        Dotted path accumulated during recursion, used only for error messages.
    :raises ValueError:
        If a dotted path descends into something that is not a dict/list, or if
        any override key is never consumed.
    """
    merged: T
    keys: Union[Iterable[str], Iterable[int]]
    if isinstance(original, list):
        merged = [None] * len(original)
        keys = range(len(original))
    elif isinstance(original, dict):
        # Also visit top-level override keys that introduce brand-new entries.
        merged = {}
        keys = chain(
            original.keys(), (k for k in overrides_dict if "." not in k and k not in original)
        )
    else:
        if prefix:
            raise ValueError(
                f"overrides for '{prefix[:-1]}.*' expected list or dict in original, "
                f"found {type(original)} instead"
            )
        else:
            raise ValueError(f"expected list or dict, found {type(original)} instead")

    used_override_keys: Set[str] = set()
    for key in keys:
        if str(key) in overrides_dict:
            # Exact match: replace the whole value.
            merged[key] = copy.deepcopy(overrides_dict[str(key)])
            used_override_keys.add(str(key))
        else:
            # Collect dotted overrides that pass through this key and recurse.
            overrides_subdict = {}
            for o_key in overrides_dict:
                if o_key.startswith(f"{key}."):
                    overrides_subdict[o_key[len(f"{key}.") :]] = overrides_dict[o_key]
                    used_override_keys.add(o_key)
            if overrides_subdict:
                merged[key] = with_overrides(
                    original[key], overrides_subdict, prefix=prefix + f"{key}."
                )
            else:
                merged[key] = copy.deepcopy(original[key])

    # Any override that was never consumed is almost certainly a typo; fail loudly.
    unused_override_keys = [prefix + key for key in set(overrides_dict.keys()) - used_override_keys]
    if unused_override_keys:
        raise ValueError(f"overrides dict contains unused keys: {unused_override_keys}")

    return merged
def parse_overrides(
    serialized_overrides: str, ext_vars: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """
    Evaluate a Jsonnet snippet of overrides into a plain dict.

    Environment variables are supplied as Jsonnet external variables, with
    entries in ``ext_vars`` taking priority over them. An empty string yields
    an empty dict.
    """
    if not serialized_overrides:
        return {}
    merged_ext_vars = {**_environment_variables(), **(ext_vars or {})}
    return json.loads(evaluate_snippet("", serialized_overrides, ext_vars=merged_ext_vars))
def _is_dict_free(obj: Any) -> bool:
"""
Returns False if obj is a dict, or if it's a list with an element that _has_dict.
"""
if isinstance(obj, dict):
return False
elif isinstance(obj, list):
return all(_is_dict_free(item) for item in obj)
else:
return True
def pop_choice(
    params: Dict[str, Any],
    key: str,
    choices: List[Any],
    default_to_first_choice: bool = False,
    history: str = "?.",
    allow_class_names: bool = True,
) -> Any:
    """
    Performs the same function as ``Params.pop_choice``, but is required in order to deal with
    places that the Params object is not welcome, such as inside Keras layers. See the docstring
    of that method for more detail on how this function works.

    This method adds a ``history`` parameter, in the off-chance that you know it, so that we can
    reproduce ``Params.pop_choice`` exactly. We default to using "?." if you don't know the
    history, so you'll have to fix that in the log if you want to actually recover the logged
    parameters.
    """
    # Delegate to the Params implementation so validation and logging stay in one place.
    return Params(params, history).pop_choice(
        key, choices, default_to_first_choice, allow_class_names=allow_class_names
    )
def _replace_none(params: Any) -> Any:
    """
    Recursively convert the literal string ``"None"`` into ``None`` throughout a
    JSON-like structure. Dicts (and ``Params``, which are unwrapped to dicts)
    are updated in place; lists are rebuilt.
    """
    if isinstance(params, str):
        return None if params == "None" else params
    if isinstance(params, (dict, Params)):
        if isinstance(params, Params):
            params = params.as_dict(quiet=True)
        for key in list(params.keys()):
            params[key] = _replace_none(params[key])
        return params
    if isinstance(params, list):
        return [_replace_none(element) for element in params]
    return params
def remove_keys_from_params(params: "Params", keys: Optional[List[str]] = None):
    """
    Recursively delete the given keys from ``params`` in place, descending into
    nested ``Params`` objects and lists of ``Params``.

    :param params:
        The parameters to prune. Non-``Params`` values are left untouched.
    :param keys:
        The keys to delete wherever they appear. Defaults to
        ``["pretrained_file", "initializer"]``.
    """
    # Use a None sentinel instead of a mutable default argument (shared-list pitfall).
    if keys is None:
        keys = ["pretrained_file", "initializer"]
    if isinstance(params, Params):  # The model could possibly be a string, for example.
        param_keys = params.keys()
        for key in keys:
            if key in param_keys:
                del params[key]
        for value in params.values():
            if isinstance(value, Params):
                remove_keys_from_params(value, keys)
            elif isinstance(value, list):
                for item in value:
                    if isinstance(item, Params):
                        remove_keys_from_params(item, keys)
class Params(MutableMapping):
    """
    A :class:`~collections.abc.MutableMapping` that represents a parameter dictionary with a history,
    and contains other functionality around parameter passing and validation for AI2 Tango.

    There are currently two benefits of a ``Params`` object over a plain dictionary for parameter
    passing:

    1. We handle a few kinds of parameter validation, including making sure that parameters
       representing discrete choices actually have acceptable values, and making sure no extra
       parameters are passed.
    2. We log all parameter reads, including default values. This gives a more complete
       specification of the actual parameters used than is given in a JSON file, because
       those may not specify what default values were used, whereas this will log them.

    .. important::
        The convention for using a ``Params`` object in Tango is that you will consume the parameters
        as you read them, so that there are none left when you've read everything you expect. This
        lets us easily validate that you didn't pass in any ``extra`` parameters, just by making sure
        that the parameter dictionary is empty. You should do this when you're done handling
        parameters, by calling :meth:`Params.assert_empty()`.
    """

    # This allows us to check for the presence of "None" as a default argument,
    # which we require because we make a distinction between passing a value of "None"
    # and passing no value to the default parameter of "pop".
    DEFAULT = object()

    def __init__(self, params: "MutableMapping[str, Any]", history: str = "") -> None:
        if isinstance(params, Params):
            # Share the underlying mapping rather than re-wrapping it.
            self.params: MutableMapping = params.params
        else:
            self.params = _replace_none(params)
        self.history = history

    def pop(self, key: str, default: Any = DEFAULT, keep_as_dict: bool = False) -> Any:
        """
        Performs the functionality associated with ``dict.pop(key)``, along with checking for
        returned dictionaries, replacing them with Param objects with an updated history
        (unless keep_as_dict is True, in which case we leave them as dictionaries).

        If ``key`` is not present in the dictionary, and no default was specified, we raise a
        :class:`~tango.common.exceptions.ConfigurationError`, instead of the typical ``KeyError``.
        """
        if default is self.DEFAULT:
            try:
                value = self.params.pop(key)
            except KeyError:
                msg = f'key "{key}" is required'
                if self.history:
                    msg += f' at location "{self.history}"'
                raise ConfigurationError(msg)
        else:
            value = self.params.pop(key, default)

        logger.debug(f"{self.history}{key} = {value}")
        if keep_as_dict or _is_dict_free(value):
            return value
        else:
            return self._check_is_dict(key, value)

    def pop_int(self, key: str, default: Any = DEFAULT) -> Optional[int]:
        """
        Performs a pop and coerces to an int.
        """
        value = self.pop(key, default)
        if value is None:
            return None
        else:
            return int(value)

    def pop_float(self, key: str, default: Any = DEFAULT) -> Optional[float]:
        """
        Performs a pop and coerces to a float.
        """
        value = self.pop(key, default)
        if value is None:
            return None
        else:
            return float(value)

    def pop_bool(self, key: str, default: Any = DEFAULT) -> Optional[bool]:
        """
        Performs a pop and coerces to a bool.
        """
        value = self.pop(key, default)
        if value is None:
            return None
        elif isinstance(value, bool):
            return value
        elif value == "true":
            return True
        elif value == "false":
            return False
        else:
            raise ValueError("Cannot convert variable to bool: " + value)

    def get(self, key: str, default: Any = DEFAULT):
        """
        Performs the functionality associated with ``dict.get(key)`` but also checks for returned
        dicts and returns a ``Params`` object in their place with an updated history.
        """
        default = None if default is self.DEFAULT else default
        value = self.params.get(key, default)
        return self._check_is_dict(key, value)

    def pop_choice(
        self,
        key: str,
        choices: List[Any],
        default_to_first_choice: bool = False,
        allow_class_names: bool = True,
    ) -> Any:
        """
        Gets the value of ``key`` in the ``params`` dictionary, ensuring that the value is one of
        the given choices. Note that this ``pops`` the key from params, modifying the dictionary,
        consistent with how parameters are processed in this codebase.

        :param key:
            Key to get the value from in the param dictionary
        :param choices:
            A list of valid options for values corresponding to ``key``. For example, if you're
            specifying the type of encoder to use for some part of your model, the choices might be
            the list of encoder classes we know about and can instantiate. If the value we find in
            the param dictionary is not in ``choices``, we raise a
            :class:`~tango.common.exceptions.ConfigurationError`, because
            the user specified an invalid value in their parameter file.
        :param default_to_first_choice:
            If this is ``True``, we allow the ``key`` to not be present in the parameter
            dictionary. If the key is not present, we will use the return as the value the first
            choice in the ``choices`` list. If this is ``False``, we raise a
            :class:`~tango.common.exceptions.ConfigurationError`, because
            specifying the ``key`` is required (e.g., you ``have`` to
            specify your model class when running an experiment, but you can feel free to use
            default settings for encoders if you want).
        :param allow_class_names:
            If this is ``True``, then we allow unknown choices that look like fully-qualified class names.
            This is to allow e.g. specifying a model type as ``my_library.my_model.MyModel``
            and importing it on the fly. Our check for "looks like" is extremely lenient
            and consists of checking that the value contains a '.'.
        """
        default = choices[0] if default_to_first_choice else self.DEFAULT
        value = self.pop(key, default)
        ok_because_class_name = allow_class_names and could_be_class_name(value)
        if value not in choices and not ok_because_class_name:
            key_str = self.history + key
            message = (
                f"'{value}' not in acceptable choices for {key_str}: {choices}. "
                "You should either use the --include-package flag to make sure the correct module "
                "is loaded, or use a fully qualified class name in your config file like "
                """{"model": "my_module.models.MyModel"} to have it imported automatically."""
            )
            raise ConfigurationError(message)
        return value

    def as_dict(self, quiet: bool = False, infer_type_and_cast: bool = False):
        """
        Sometimes we need to just represent the parameters as a dict, for instance when we pass
        them to PyTorch code.

        :param quiet:
            Whether to log the parameters before returning them as a dict.
        :param infer_type_and_cast:
            If ``True``, we infer types and cast (e.g. things that look like floats to floats).
        """
        if infer_type_and_cast:
            params_as_dict = infer_and_cast(self.params)
        else:
            params_as_dict = self.params

        if quiet:
            return params_as_dict

        def log_recursively(parameters, history):
            # Log every leaf parameter with its full dotted history path.
            for key, value in parameters.items():
                if isinstance(value, dict):
                    new_local_history = history + key + "."
                    log_recursively(value, new_local_history)
                else:
                    logger.debug(f"{history}{key} = {value}")

        log_recursively(self.params, self.history)
        return params_as_dict

    def as_flat_dict(self) -> Dict[str, Any]:
        """
        Returns the parameters of a flat dictionary from keys to values.
        Nested structure is collapsed with periods.
        """
        flat_params = {}

        def recurse(parameters, path):
            for key, value in parameters.items():
                newpath = path + [key]
                if isinstance(value, dict):
                    recurse(value, newpath)
                else:
                    flat_params[".".join(newpath)] = value

        recurse(self.params, [])
        return flat_params

    def duplicate(self) -> "Params":
        """
        Uses ``copy.deepcopy()`` to create a duplicate (but fully distinct)
        copy of these Params.
        """
        return copy.deepcopy(self)

    def assert_empty(self, name: str):
        """
        Raises a :class:`~tango.common.exceptions.ConfigurationError` if ``self.params`` is not empty.
        We take ``name`` as an argument so that the error message gives some idea of where an error
        happened, if there was one. For example, ``name`` could be the name of the ``calling`` class
        that got extra parameters (if there are any).
        """
        if self.params:
            raise ConfigurationError("Extra parameters passed to {}: {}".format(name, self.params))

    def __getitem__(self, key):
        if key in self.params:
            return self._check_is_dict(key, self.params[key])
        else:
            raise KeyError(str(key))

    def __setitem__(self, key, value):
        self.params[key] = value

    def __delitem__(self, key):
        del self.params[key]

    def __iter__(self):
        return iter(self.params)

    def __len__(self):
        return len(self.params)

    def _check_is_dict(self, new_history, value):
        # Wrap raw dicts as Params (extending the history), and recurse into lists
        # so nested dicts inside lists get wrapped too.
        if isinstance(value, dict):
            new_history = self.history + new_history + "."
            return Params(value, history=new_history)
        if isinstance(value, list):
            value = [self._check_is_dict(f"{new_history}.{i}", v) for i, v in enumerate(value)]
        return value

    @classmethod
    def from_file(
        cls,
        params_file: PathOrStr,
        params_overrides: Union[str, Dict[str, Any]] = "",
        ext_vars: Optional[dict] = None,
    ) -> "Params":
        """
        Load a ``Params`` object from a configuration file.

        :param params_file:
            The path to the configuration file to load. Can be JSON, Jsonnet, or YAML.
        :param params_overrides:
            A dict of overrides that can be applied to final object.
            e.g. ``{"model.embedding_dim": 10}`` will change the value of "embedding_dim"
            within the "model" object of the config to 10. If you wanted to override the entire
            "model" object of the config, you could do ``{"model": {"type": "other_type", ...}}``.
        :param ext_vars:
            Our config files are Jsonnet, which allows specifying external variables
            for later substitution. Typically we substitute these using environment
            variables; however, you can also specify them here, in which case they
            take priority over environment variables.
            e.g. ``{"HOME_DIR": "/Users/allennlp/home"}``
        """
        if ext_vars is None:
            ext_vars = {}

        # redirect to cache, if necessary
        from cached_path import cached_path

        params_file: Path = Path(cached_path(params_file))
        if not params_file.is_file():
            raise FileNotFoundError(params_file)

        file_dict: Dict[str, Any]
        if params_file.suffix in {".yml", ".yaml"}:
            with open(params_file) as f:
                file_dict = yaml.safe_load(f)
        else:
            # Fall back to JSON/Jsonnet.
            ext_vars = {**_environment_variables(), **ext_vars}
            json_str = evaluate_file(params_file.name, str(params_file.parent), ext_vars=ext_vars)
            file_dict = json.loads(json_str)

        if isinstance(params_overrides, dict):
            params_overrides = json.dumps(params_overrides)
        overrides_dict = parse_overrides(params_overrides, ext_vars=ext_vars)

        if overrides_dict:
            param_dict = with_overrides(file_dict, overrides_dict)
        else:
            param_dict = file_dict

        return cls(param_dict)

    def to_file(
        self, params_file: PathOrStr, preference_orders: Optional[List[List[str]]] = None
    ) -> None:
        """
        Write the params to file.
        """
        with open(params_file, "w") as handle:
            json.dump(self.as_ordered_dict(preference_orders), handle, indent=4)

    def as_ordered_dict(self, preference_orders: Optional[List[List[str]]] = None) -> OrderedDict:
        """
        Returns an ``OrderedDict`` of ``Params`` from list of partial order preferences.

        :param preference_orders:
            ``preference_orders`` is list of partial preference orders. ["A", "B", "C"] means
            "A" > "B" > "C". For multiple preference_orders first will be considered first.
            Keys not found, will have last but alphabetical preference. Default Preferences:
            ``[["dataset_reader", "iterator", "model", "train_data_path", "validation_data_path",
            "test_data_path", "trainer", "vocabulary"], ["type"]]``
        """
        params_dict = self.as_dict(quiet=True)
        if not preference_orders:
            preference_orders = []
            preference_orders.append(["type"])

        def order_func(key):
            # Makes a tuple to use for ordering. The tuple is an index into each of the `preference_orders`,
            # followed by the key itself. This gives us integer sorting if you have a key in one of the
            # `preference_orders`, followed by alphabetical ordering if not.
            order_tuple = [
                order.index(key) if key in order else len(order) for order in preference_orders  # type: ignore
            ]
            return order_tuple + [key]

        def order_dict(dictionary, order_func):
            # Recursively orders dictionary according to scoring order_func
            result = OrderedDict()
            for key, val in sorted(dictionary.items(), key=lambda item: order_func(item[0])):
                result[key] = order_dict(val, order_func) if isinstance(val, dict) else val
            return result

        return order_dict(params_dict, order_func)

    def get_hash(self) -> str:
        """
        Returns a hash code representing the current state of this ``Params`` object. We don't
        want to implement ``__hash__`` because that has deeper python implications (and this is a
        mutable object), but this will give you a representation of the current state.
        We use ``zlib.adler32`` instead of Python's builtin ``hash`` because the random seed for the
        latter is reset on each new program invocation, as discussed here:
        https://stackoverflow.com/questions/27954892/deterministic-hashing-in-python-3.
        """
        dumped = json.dumps(self.params, sort_keys=True)
        hashed = zlib.adler32(dumped.encode())
        return str(hashed)

    def __str__(self) -> str:
        return f"{self.history}Params({self.params})"