# Source code for tango.common.params

import copy
import json
import logging
import os
import zlib
from collections import OrderedDict
from collections.abc import MutableMapping
from itertools import chain
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Set, TypeVar, Union

import yaml
from rjsonnet import evaluate_file, evaluate_snippet

from .aliases import PathOrStr
from .exceptions import ConfigurationError
from .util import could_be_class_name

# Module-level logger; ``Params`` logs every parameter read through it at DEBUG level.
logger = logging.getLogger(__name__)

def infer_and_cast(value: Any):
    """
    In some cases we'll be feeding params dicts to functions we don't own;
    for example, PyTorch optimizers. In that case we can't use ``pop_int``
    or similar to force casts (which means you can't specify ``int`` parameters
    using environment variables). This function takes something that looks JSON-like
    and recursively casts things that look like (bool, int, float) to (bool, int, float).

    ``None`` (JSON ``null``) is passed through unchanged.

    :param value: A JSON-like value: ``None``, a bool/int/float, a str, or a list/dict of them.
    :returns: A new structure with string leaves converted where possible. Strings are
        converted case-insensitively to ``True``/``False``, otherwise to ``int``, otherwise
        to ``float``, otherwise left as ``str``.
    :raises ValueError: If ``value`` (or anything nested inside it) has any other type.
    """
    if value is None:
        # JSON ``null`` (and the ``"None"`` strings that ``_replace_none`` turns into ``None``)
        # has nothing to cast; keep it instead of failing below.
        return None
    if isinstance(value, (int, float, bool)):
        # Already one of our desired types, so leave as is.
        return value
    if isinstance(value, list):
        # Recursively call on each list element.
        return [infer_and_cast(item) for item in value]
    if isinstance(value, dict):
        # Recursively call on each dict value.
        return {key: infer_and_cast(item) for key, item in value.items()}
    if isinstance(value, str):
        # If it looks like a bool, make it a bool.
        lowered = value.lower()
        if lowered == "true":
            return True
        if lowered == "false":
            return False
        # See if it could be an int.
        try:
            return int(value)
        except ValueError:
            pass
        # See if it could be a float.
        try:
            return float(value)
        except ValueError:
            # Just return it as a string.
            return value
    raise ValueError(f"cannot infer type of {value}")

def _is_encodable(value: str) -> bool:
    We need to filter out environment variables that can't
    be unicode-encoded to avoid a "surrogates not allowed"
    error in jsonnet.
    # Idiomatically you'd like to not check the != b""
    # but mypy doesn't like that.
    return (value == "") or (value.encode("utf-8", "ignore") != b"")

def _environment_variables() -> Dict[str, str]:
    """
    Wraps ``os.environ`` to filter out non-encodable values (see ``_is_encodable``),
    so that the result can be passed to jsonnet as external variables.
    """
    return {key: value for key, value in os.environ.items() if _is_encodable(value)}

T = TypeVar("T", dict, list)


def with_overrides(original: T, overrides_dict: Dict[str, Any], prefix: str = "") -> T:
    """
    Return a deep copy of ``original`` with ``overrides_dict`` applied.

    Override keys are dotted paths into ``original``. For example,
    ``{"model.embedding_dim": 10}`` replaces ``original["model"]["embedding_dim"]``.
    List elements are addressed by index, e.g. ``{"layers.0.size": 5}``.

    - A key that matches an entry exactly replaces that whole entry.
    - When ``original`` is a dict, new keys without a ``.`` are added.
    - Lists can only have their existing indices overridden.

    :param original: The dict or list to override. It is not modified.
    :param overrides_dict: Mapping from dotted path (relative to ``original``) to new value.
    :param prefix: Dotted path of ``original`` within the top-level object. It is used only
        to build error messages during recursion.
    :raises ValueError: If an override path descends into something that is not a dict or list.
        Also raised if any override key ends up unused, e.g. ``"x.y"`` when ``original`` has
        no ``"x"``, or an out-of-range list index.
    """
    merged: T
    keys: Union[Iterable[str], Iterable[int]]
    if isinstance(original, list):
        merged = [None] * len(original)
        keys = range(len(original))
    elif isinstance(original, dict):
        merged = {}
        # Existing keys, plus brand-new top-level keys coming from the overrides.
        keys = chain(
            original.keys(), (k for k in overrides_dict if "." not in k and k not in original)
        )
    else:
        if prefix:
            raise ValueError(
                f"overrides for '{prefix[:-1]}.*' expected list or dict in original, "
                f"found {type(original)} instead"
            )
        else:
            raise ValueError(f"expected list or dict, found {type(original)} instead")

    used_override_keys: Set[str] = set()
    for key in keys:
        if str(key) in overrides_dict:
            # Exact match: the override replaces the whole entry.
            merged[key] = copy.deepcopy(overrides_dict[str(key)])
            used_override_keys.add(str(key))
        else:
            # Collect overrides aimed deeper inside this entry, stripped of the "key." prefix.
            overrides_subdict = {}
            for o_key in overrides_dict:
                if o_key.startswith(f"{key}."):
                    overrides_subdict[o_key[len(f"{key}.") :]] = overrides_dict[o_key]
                    used_override_keys.add(o_key)
            if overrides_subdict:
                merged[key] = with_overrides(
                    original[key], overrides_subdict, prefix=prefix + f"{key}."
                )
            else:
                merged[key] = copy.deepcopy(original[key])

    unused_override_keys = [prefix + key for key in set(overrides_dict.keys()) - used_override_keys]
    if unused_override_keys:
        raise ValueError(f"overrides dict contains unused keys: {unused_override_keys}")

    return merged

def parse_overrides(
    serialized_overrides: str, ext_vars: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """
    Evaluate a serialized (JSON/Jsonnet) overrides string into a dict.

    :param serialized_overrides: The overrides source. An empty string means "no overrides".
    :param ext_vars: Jsonnet external variables. They take priority over environment variables
        with the same name.
    :returns: The evaluated overrides, or ``{}`` when ``serialized_overrides`` is empty.
    """
    if serialized_overrides:
        ext_vars = {**_environment_variables(), **(ext_vars or {})}
        return json.loads(evaluate_snippet("", serialized_overrides, ext_vars=ext_vars))
    else:
        return {}

def _is_dict_free(obj: Any) -> bool:
    Returns False if obj is a dict, or if it's a list with an element that _has_dict.
    if isinstance(obj, dict):
        return False
    elif isinstance(obj, list):
        return all(_is_dict_free(item) for item in obj)
        return True

def pop_choice(
    params: Dict[str, Any],
    key: str,
    choices: List[Any],
    default_to_first_choice: bool = False,
    history: str = "?.",
    allow_class_names: bool = True,
) -> Any:
    """
    Performs the same function as ``Params.pop_choice``, but is required in order to deal with
    places that the Params object is not welcome, such as inside Keras layers.  See the docstring
    of that method for more detail on how this function works.

    This method adds a ``history`` parameter, in the off-chance that you know it, so that we can
    reproduce ``Params.pop_choice`` exactly.  We default to using "?." if you don't know the
    history, so you'll have to fix that in the log if you want to actually recover the logged
    parameters.

    Note that ``params`` is wrapped without copying, so ``key`` is popped from the caller's
    dict (and any ``"None"`` strings in it are replaced with ``None``).
    """
    return Params(params, history).pop_choice(
        key, choices, default_to_first_choice, allow_class_names=allow_class_names
    )

def _replace_none(params: Any) -> Any:
    if isinstance(params, str) and params == "None":
        return None
    elif isinstance(params, (dict, Params)):
        if isinstance(params, Params):
            params = params.as_dict(quiet=True)
        for key, value in params.items():
            params[key] = _replace_none(value)
        return params
    elif isinstance(params, list):
        return [_replace_none(value) for value in params]
    return params

def remove_keys_from_params(params: "Params", keys: Optional[List[str]] = None):
    """
    Recursively delete ``keys`` from ``params``, modifying it in place.

    Keys are also deleted from every nested ``Params`` value, and from ``Params`` found
    directly inside list values.

    :param params: The object to clean. Anything that is not a ``Params`` instance is left
        alone (for example, a model given as a plain string).
    :param keys: The keys to delete. Defaults to ``["pretrained_file", "initializer"]``.
    """
    if keys is None:
        # Built per call rather than as a list default, so no list object is shared between calls.
        keys = ["pretrained_file", "initializer"]
    if not isinstance(params, Params):  # The model could possibly be a string, for example.
        return
    for key in keys:
        if key in params:
            del params[key]
    for value in params.values():
        if isinstance(value, Params):
            remove_keys_from_params(value, keys)
        elif isinstance(value, list):
            for item in value:
                if isinstance(item, Params):
                    remove_keys_from_params(item, keys)

class Params(MutableMapping):
    """
    A :class:`~collections.abc.MutableMapping` that represents a parameter dictionary with a
    history, and contains other functionality around parameter passing and validation for
    AI2 Tango.

    There are currently two benefits of a ``Params`` object over a plain dictionary for
    parameter passing:

    1. We handle a few kinds of parameter validation, including making sure that parameters
       representing discrete choices actually have acceptable values, and making sure no
       extra parameters are passed.
    2. We log all parameter reads, including default values.  This gives a more complete
       specification of the actual parameters used than is given in a JSON file, because
       those may not specify what default values were used, whereas this will log them.

    .. important::
        The convention for using a ``Params`` object in Tango is that you will consume the
        parameters as you read them, so that there are none left when you've read everything
        you expect.  This lets us easily validate that you didn't pass in any *extra*
        parameters, just by making sure that the parameter dictionary is empty.  You should
        do this when you're done handling parameters, by calling
        :meth:`Params.assert_empty()`.
    """

    # This allows us to check for the presence of "None" as a default argument,
    # which we require because we make a distinction between passing a value of "None"
    # and passing no value to the default parameter of "pop".
    DEFAULT = object()

    def __init__(self, params: "MutableMapping[str, Any]", history: str = "") -> None:
        if isinstance(params, Params):
            # Share (not copy) the underlying mapping of another Params.
            self.params: MutableMapping = params.params
        else:
            # Note: for a dict this rewrites "None" strings in place and keeps the same object.
            self.params = _replace_none(params)
        self.history = history

    def pop(self, key: str, default: Any = DEFAULT, keep_as_dict: bool = False) -> Any:
        """
        Performs the functionality associated with ``dict.pop(key)``, along with checking for
        returned dictionaries, replacing them with Param objects with an updated history
        (unless keep_as_dict is True, in which case we leave them as dictionaries).

        If ``key`` is not present in the dictionary, and no default was specified, we raise a
        :class:`~tango.common.exceptions.ConfigurationError`, instead of the typical
        ``KeyError``.
        """
        if default is self.DEFAULT:
            try:
                value = self.params.pop(key)
            except KeyError:
                msg = f'key "{key}" is required'
                if self.history:
                    msg += f' at location "{self.history}"'
                raise ConfigurationError(msg)
        else:
            value = self.params.pop(key, default)

        # Lazy %-formatting: the message is only rendered when DEBUG logging is enabled.
        logger.debug("%s%s = %s", self.history, key, value)
        if keep_as_dict or _is_dict_free(value):
            return value
        else:
            return self._check_is_dict(key, value)

    def pop_int(self, key: str, default: Any = DEFAULT) -> Optional[int]:
        """
        Performs a pop and coerces to an int.
        """
        value = self.pop(key, default)
        if value is None:
            return None
        else:
            return int(value)

    def pop_float(self, key: str, default: Any = DEFAULT) -> Optional[float]:
        """
        Performs a pop and coerces to a float.
        """
        value = self.pop(key, default)
        if value is None:
            return None
        else:
            return float(value)

    def pop_bool(self, key: str, default: Any = DEFAULT) -> Optional[bool]:
        """
        Performs a pop and coerces to a bool.

        Only actual bools and the exact strings ``"true"`` / ``"false"`` are accepted.

        :raises ValueError: If the value cannot be converted.
        """
        value = self.pop(key, default)
        if value is None:
            return None
        elif isinstance(value, bool):
            return value
        elif value == "true":
            return True
        elif value == "false":
            return False
        else:
            # str() so that non-string values (e.g. 1) raise ValueError rather than TypeError.
            raise ValueError("Cannot convert variable to bool: " + str(value))

    def get(self, key: str, default: Any = DEFAULT):
        """
        Performs the functionality associated with ``dict.get(key)`` but also checks for
        returned dicts and returns a ``Params`` object in their place with an updated history.
        """
        default = None if default is self.DEFAULT else default
        value = self.params.get(key, default)
        return self._check_is_dict(key, value)

    def pop_choice(
        self,
        key: str,
        choices: List[Any],
        default_to_first_choice: bool = False,
        allow_class_names: bool = True,
    ) -> Any:
        """
        Gets the value of ``key`` in the ``params`` dictionary, ensuring that the value is one
        of the given choices. Note that this *pops* the key from params, modifying the
        dictionary, consistent with how parameters are processed in this codebase.

        :param key:
            Key to get the value from in the param dictionary.
        :param choices:
            A list of valid options for values corresponding to ``key``.  For example, if
            you're specifying the type of encoder to use for some part of your model, the
            choices might be the list of encoder classes we know about and can instantiate.
            If the value we find in the param dictionary is not in ``choices``, we raise a
            :class:`~tango.common.exceptions.ConfigurationError`, because the user specified
            an invalid value in their parameter file.
        :param default_to_first_choice:
            If this is ``True``, we allow the ``key`` to not be present in the parameter
            dictionary.  If the key is not present, we return the first choice in the
            ``choices`` list as the value.  If this is ``False``, we raise a
            :class:`~tango.common.exceptions.ConfigurationError`, because specifying the
            ``key`` is required (e.g., you *have* to specify your model class when running
            an experiment, but you can feel free to use default settings for encoders if
            you want).
        :param allow_class_names:
            If this is ``True``, then we allow unknown choices that look like fully-qualified
            class names.  This is to allow e.g. specifying a model type as
            ``my_library.my_model.MyModel`` and importing it on the fly.  Our check for
            "looks like" is extremely lenient and consists of checking that the value
            contains a '.'.
        """
        default = choices[0] if default_to_first_choice else self.DEFAULT
        value = self.pop(key, default)
        ok_because_class_name = allow_class_names and could_be_class_name(value)
        if value not in choices and not ok_because_class_name:
            key_str = self.history + key
            message = (
                f"'{value}' not in acceptable choices for {key_str}: {choices}. "
                "You should either use the --include-package flag to make sure the correct module "
                "is loaded, or use a fully qualified class name in your config file like "
                """{"model": "my_module.models.MyModel"} to have it imported automatically."""
            )
            raise ConfigurationError(message)
        return value

    def as_dict(self, quiet: bool = False, infer_type_and_cast: bool = False):
        """
        Sometimes we need to just represent the parameters as a dict, for instance when we
        pass them to PyTorch code.

        :param quiet: If ``True``, skip logging the parameters before returning them.
        :param infer_type_and_cast:
            If ``True``, we infer types and cast (e.g. things that look like floats to floats).
            A new structure is returned in that case; otherwise the underlying mapping itself
            is returned.
        """
        if infer_type_and_cast:
            params_as_dict = infer_and_cast(self.params)
        else:
            params_as_dict = self.params

        if quiet:
            return params_as_dict

        def log_recursively(parameters, history):
            # Log each leaf as "<dotted.path> = <value>"; nested dicts extend the path.
            for key, value in parameters.items():
                if isinstance(value, dict):
                    new_local_history = history + key + "."
                    log_recursively(value, new_local_history)
                else:
                    logger.debug("%s%s = %s", history, key, value)

        log_recursively(self.params, self.history)
        return params_as_dict

    def as_flat_dict(self) -> Dict[str, Any]:
        """
        Returns the parameters as a flat dictionary from keys to values.
        Nested structure is collapsed with periods.
        """
        flat_params = {}

        def recurse(parameters, path):
            for key, value in parameters.items():
                newpath = path + [key]
                if isinstance(value, dict):
                    recurse(value, newpath)
                else:
                    flat_params[".".join(newpath)] = value

        recurse(self.params, [])
        return flat_params

    def duplicate(self) -> "Params":
        """
        Uses ``copy.deepcopy()`` to create a duplicate (but fully distinct)
        copy of these Params.
        """
        return copy.deepcopy(self)

    def assert_empty(self, name: str):
        """
        Raises a :class:`~tango.common.exceptions.ConfigurationError` if ``self.params`` is not
        empty.  We take ``name`` as an argument so that the error message gives some idea of
        where an error happened, if there was one.  For example, ``name`` could be the name of
        the *calling* class that got extra parameters (if there are any).
        """
        if self.params:
            raise ConfigurationError("Extra parameters passed to {}: {}".format(name, self.params))

    def __getitem__(self, key):
        if key in self.params:
            return self._check_is_dict(key, self.params[key])
        else:
            raise KeyError(str(key))

    def __setitem__(self, key, value):
        self.params[key] = value

    def __delitem__(self, key):
        del self.params[key]

    def __iter__(self):
        return iter(self.params)

    def __len__(self):
        return len(self.params)

    def _check_is_dict(self, new_history, value):
        # Wrap dicts (also inside lists) in Params whose history records the path taken.
        if isinstance(value, dict):
            new_history = self.history + new_history + "."
            return Params(value, history=new_history)
        if isinstance(value, list):
            value = [self._check_is_dict(f"{new_history}.{i}", v) for i, v in enumerate(value)]
        return value

    @classmethod
    def from_file(
        cls,
        params_file: PathOrStr,
        params_overrides: Union[str, Dict[str, Any]] = "",
        ext_vars: Optional[dict] = None,
    ) -> "Params":
        """
        Load a ``Params`` object from a configuration file.

        :param params_file:
            The path to the configuration file to load. Can be JSON, Jsonnet, or YAML.
        :param params_overrides:
            A dict of overrides that can be applied to the final object.
            e.g. ``{"model.embedding_dim": 10}`` will change the value of "embedding_dim"
            within the "model" object of the config to 10. If you wanted to override the
            entire "model" object of the config, you could do
            ``{"model": {"type": "other_type", ...}}``.
        :param ext_vars:
            Our config files are Jsonnet, which allows specifying external variables
            for later substitution. Typically we substitute these using environment
            variables; however, you can also specify them here, in which case they
            take priority over environment variables.
            e.g. ``{"HOME_DIR": "/Users/allennlp/home"}``
        :raises FileNotFoundError: If the (possibly cached) path is not a file.
        """
        if ext_vars is None:
            ext_vars = {}

        # redirect to cache, if necessary
        from cached_path import cached_path

        params_file = Path(cached_path(params_file))
        if not params_file.is_file():
            raise FileNotFoundError(params_file)

        file_dict: Dict[str, Any]
        if params_file.suffix in {".yml", ".yaml"}:
            # YAML streams are Unicode; don't depend on the platform's default encoding.
            with open(params_file, encoding="utf-8") as f:
                file_dict = yaml.safe_load(f)
        else:
            # Fall back to JSON/Jsonnet.
            ext_vars = {**_environment_variables(), **ext_vars}
            json_str = evaluate_file(str(params_file), str(params_file.parent), ext_vars=ext_vars)
            file_dict = json.loads(json_str)

        if isinstance(params_overrides, dict):
            params_overrides = json.dumps(params_overrides)
        overrides_dict = parse_overrides(params_overrides, ext_vars=ext_vars)

        if overrides_dict:
            param_dict = with_overrides(file_dict, overrides_dict)
        else:
            param_dict = file_dict

        return cls(param_dict)

    def to_file(
        self, params_file: PathOrStr, preference_orders: Optional[List[List[str]]] = None
    ) -> None:
        """
        Write the params to file as indented JSON, with keys ordered by ``as_ordered_dict``.
        """
        with open(params_file, "w") as handle:
            json.dump(self.as_ordered_dict(preference_orders), handle, indent=4)

    def as_ordered_dict(self, preference_orders: Optional[List[List[str]]] = None) -> OrderedDict:
        """
        Returns an ``OrderedDict`` of ``Params`` from list of partial order preferences.

        :param preference_orders:
            ``preference_orders`` is list of partial preference orders. ["A", "B", "C"] means
            "A" > "B" > "C". For multiple preference_orders first will be considered first.
            Keys not found will have last but alphabetical preference. Default preferences:
            ``[["type"]]``.
        """
        params_dict = self.as_dict(quiet=True)
        if not preference_orders:
            preference_orders = [["type"]]

        def order_func(key):
            # Makes a tuple to use for ordering.  The tuple is an index into each of the
            # `preference_orders`, followed by the key itself.  This gives us integer sorting
            # if you have a key in one of the `preference_orders`, followed by alphabetical
            # ordering if not.
            order_tuple = [
                order.index(key) if key in order else len(order)
                for order in preference_orders  # type: ignore
            ]
            return order_tuple + [key]

        def order_dict(dictionary, order_func):
            # Recursively orders dictionary according to scoring order_func
            result = OrderedDict()
            for key, val in sorted(dictionary.items(), key=lambda item: order_func(item[0])):
                result[key] = order_dict(val, order_func) if isinstance(val, dict) else val
            return result

        return order_dict(params_dict, order_func)

    def get_hash(self) -> str:
        """
        Returns a hash code representing the current state of this ``Params`` object.  We
        don't want to implement ``__hash__`` because that has deeper python implications (and
        this is a mutable object), but this will give you a representation of the current
        state.  We use ``zlib.adler32`` instead of Python's builtin ``hash`` because the random
        seed for the latter is reset on each new program invocation.
        """
        dumped = json.dumps(self.params, sort_keys=True)
        hashed = zlib.adler32(dumped.encode())
        return str(hashed)

    def __str__(self) -> str:
        return f"{self.history}Params({self.params})"