Source code for plasmapy.utils.decorators.validators

"""
Various decorators to validate input/output arguments to functions.
"""

__all__ = ["validate_class_attributes", "validate_quantities", "ValidateQuantities"]

import functools
import inspect
import warnings
from collections.abc import Iterable
from typing import Any

import astropy.units as u

from plasmapy.utils.decorators.checks import CheckUnits, CheckValues
from plasmapy.utils.decorators.helpers import preserve_signature


[docs] class ValidateQuantities(CheckUnits, CheckValues): """ A decorator class to 'validate' -- control and convert -- the units and values of input and return arguments to a function or method. Arguments are expected to be astropy :class:`~astropy.units.quantity.Quantity` objects. Parameters ---------- validations_on_return: dictionary of validation specifications Specifications for unit and value validations on the return of the function being wrapped. (see `quantity validations`_ for valid specifications. **validations: dictionary of validation specifications Specifications for unit and value validations on the input arguments of the function being wrapped. Each keyword argument in ``validations`` is the name of a function argument to be validated and the keyword value contains the unit and value validation specifications. .. _`quantity validations`: Unit and value validations can be defined by passing one of the astropy :mod:`~astropy.units`, a list of astropy units, or a dictionary containing the keys defined below. Units can also be defined with function annotations, but must be consistent with decorator ``**validations`` arguments if used concurrently. If a key is omitted, then the default value will be assumed. ====================== ======= ================================================ Key Type Description ====================== ======= ================================================ units list of desired astropy :mod:`~astropy.units` equivalencies | [DEFAULT `None`] A list of equivalent pairs to try if | the units are not directly convertible. | (see :mod:`~astropy.units.equivalencies`, and/or `astropy equivalencies`_) pass_equivalent_units `bool` | [DEFAULT `False`] allow equivalent units | to pass can_be_negative `bool` [DEFAULT `True`] values can be negative can_be_complex `bool` [DEFAULT `False`] values can be complex numbers can_be_inf `bool` [DEFAULT `True`] values can be :data:`~numpy.inf` can_be_nan `bool` [DEFAULT `True`] values can be :data:`~numpy.nan` none_shall_pass `bool` [DEFAULT `False`] values can be a python `None` can_be_zero `bool` [DEFAULT `True`] values can be zero ====================== ======= ================================================ Notes ----- * Validation of function arguments ``*args`` and ``**kwargs`` is not supported. * `None` values will pass when `None` is included in the list of specified units, is set as a default value for the function argument, or ``none_shall_pass`` is set to `True`. If ``none_shall_pass`` is doubly/triply defined through the mentioned options, then they all must be consistent with each other. * If units are not specified in ``validations``, then the decorator will attempt to identify desired units by examining the function annotations. Examples -------- Define unit and value validations with decorator parameters:: import astropy.units as u from plasmapy.utils.decorators import ValidateQuantities @ValidateQuantities( mass={"units": u.g, "can_be_negative": False}, vel=u.cm / u.s, validations_on_return=[u.g * u.cm / u.s, u.kg * u.m / u.s], ) def foo(mass, vel): return mass * vel # on a method class Foo: @ValidateQuantities( mass={"units": u.g, "can_be_negative": False}, vel=u.cm / u.s, validations_on_return=[u.g * u.cm / u.s, u.kg * u.m / u.s], ) def bar(self, mass, vel): return mass * vel Define units with function annotations:: @ValidateQuantities(mass={"can_be_negative": False}) def foo(mass: u.g, vel: u.cm / u.s) -> u.g * u.cm / u.s: return mass * vel # on a method class Foo: @ValidateQuantities(mass={"can_be_negative": False}) def bar(self, mass: u.g, vel: u.cm / u.s) -> u.g * u.cm / u.s: return mass * vel Allow `None` values to pass:: @ValidateQuantities(checks_on_return=[u.cm, None]) def foo(arg1: u.cm = None): return arg1 Allow return values to have equivalent units:: @ValidateQuantities( arg1={"units": u.cm}, checks_on_return={"units": u.km, "pass_equivalent_units": True}, ) def foo(arg1): return arg1 Allow equivalent units to pass with specified equivalencies:: @ValidateQuantities( arg1={ "units": u.K, "equivalencies": u.temperature(), "pass_equivalent_units": True, } ) def foo(arg1): return arg1 .. _astropy equivalencies: https://docs.astropy.org/en/stable/units/equivalencies.html """ def __init__( self, validations_on_return=None, **validations: dict[str, Any] ) -> None: if "checks_on_return" in validations: raise TypeError( "keyword argument 'checks_on_return' is not allowed, " "use 'validations_on_return' to set validations " "on the return variable" ) self._validations = validations checks = validations.copy() if validations_on_return is not None: self._validations["validations_on_return"] = validations_on_return checks["checks_on_return"] = validations_on_return super().__init__(**checks)
[docs] def __call__(self, f): """ Decorate a function. Parameters ---------- f Function to be wrapped Returns ------- function wrapped function of ``f`` """ self.f = f wrapped_sign = inspect.signature(f) @preserve_signature @functools.wraps(f) def wrapper(*args, **kwargs): # combine args and kwargs into dictionary bound_args = wrapped_sign.bind(*args, **kwargs) bound_args.apply_defaults() # get conditioned validations validations = self._get_validations(bound_args) # validate (input) argument units and values for arg_name in validations: # skip check of output/return if arg_name == "validations_on_return": continue # validate argument & update for conversion arg = self._validate_quantity( bound_args.arguments[arg_name], arg_name, validations[arg_name] ) bound_args.arguments[arg_name] = arg # call function _return = f(**bound_args.arguments) # validate output if "validations_on_return" in validations: _return = self._validate_quantity( _return, "validations_on_return", validations["validations_on_return"], ) return _return return wrapper
def _get_validations( self, bound_args: inspect.BoundArguments ) -> dict[str, dict[str, Any]]: """ Review :attr:`validations` and function bound arguments to build a complete 'validations' dictionary. If a validation key is omitted from the argument validations, then a default value is assumed (see `quantity validations`_). Parameters ---------- bound_args: :class:`inspect.BoundArguments` arguments passed into the function being wrapped .. code-block:: python bound_args = inspect.signature(f).bind(*args, **kwargs) Returns ------- Dict[str, Dict[str, Any]] A complete 'validations' dictionary for validating function input arguments and return. """ unit_checks = self._get_unit_checks(bound_args) value_checks = self._get_value_checks(bound_args) # combine all validations # * `unit_checks` will encompass all argument "checks" defined either by # function annotations or **validations. # * `value_checks` may miss some arguments if **validations only defines # unit validations or some validations come from function annotations validations = unit_checks.copy() for arg_name in validations: # augment 'none_shall_pass' (if needed) try: # if 'none_shall_pass' was in the original passed-in validations, # then override the value determined by CheckUnits _none_shall_pass = self.validations[arg_name]["none_shall_pass"] # if validations[arg_name]['none_shall_pass'] != _none_shall_pass: if ( _none_shall_pass is False and validations[arg_name]["none_shall_pass"] is True ): raise ValueError( f"Validation 'none_shall_pass' for argument '{arg_name}' is " f"inconsistent between function annotations " f"({validations[arg_name]['none_shall_pass']}) and decorator " f"argument ({_none_shall_pass})." ) validations[arg_name]["none_shall_pass"] = _none_shall_pass except (KeyError, TypeError): # 'none_shall_pass' was not in the original passed-in validations, so # rely on the value determined by CheckUnits pass finally: try: del value_checks[arg_name]["none_shall_pass"] except KeyError: dvc = self._CheckValues__check_defaults.copy() del dvc["none_shall_pass"] value_checks[arg_name] = dvc # update the validations dictionary validations[arg_name].update(value_checks[arg_name]) if "checks_on_return" in validations: validations["validations_on_return"] = validations.pop("checks_on_return") return validations def _validate_quantity( # noqa: C901, PLR0912 self, arg, arg_name: str, arg_validations: dict[str, Any], ): """ Perform validations `arg_validations` on function argument `arg` named `arg_name`. Parameters ---------- arg The argument to be validated. arg_name: str The name of the argument to be validated arg_validations: Dict[str, Any] The requested validations for the argument Raises ------ TypeError if argument is not an Astropy :class:`~astropy.units.Quantity` or not convertible to a :class:`~astropy.units.Quantity` ValueError if validations fail """ # rename to work with "check" methods if arg_name == "validations_on_return": arg_name = "checks_on_return" # initialize str for error message if arg_name == "checks_on_return": err_msg = "The return value " else: err_msg = f"The argument '{arg_name}' " err_msg += f"to function {self.f.__name__}()" # initialize TypeError message typeerror_msg = ( f"{err_msg} should be an astropy Quantity with units" f" equivalent to one of [" ) for ii, unit in enumerate(arg_validations["units"]): typeerror_msg += f"{unit}" if ii != len(arg_validations["units"]) - 1: typeerror_msg += ", " typeerror_msg += "]" # add units to arg if possible # * a None value will be taken care of by `_check_unit_core` # if arg is None or hasattr(arg, "unit"): pass elif len(arg_validations["units"]) != 1: raise TypeError(typeerror_msg) else: try: arg = arg * arg_validations["units"][0] except (TypeError, ValueError) as ex: raise TypeError(typeerror_msg) from ex else: warnings.warn( u.UnitsWarning( f"{err_msg} has no specified units. Assuming units of " f"{arg_validations['units'][0]}. To silence this warning, " f"explicitly pass in an astropy Quantity " f"(e.g. 5. * astropy.units.cm) " f"(see http://docs.astropy.org/en/stable/units/)" ) ) # check units arg, unit, equiv, err = self._check_unit_core(arg, arg_name, arg_validations) # convert quantity if ( arg is not None and unit is not None and not arg_validations["pass_equivalent_units"] ): arg = arg.to(unit, equivalencies=equiv) elif err is not None: raise err self._check_value(arg, arg_name, arg_validations) return arg @property def validations(self): """ Requested validations on the decorated function's input arguments and return variable. """ return self._validations
[docs] def validate_quantities(func=None, validations_on_return=None, **validations): """ A decorator to 'validate' — control and convert — the units and values of input and return arguments to a function or method. Arguments are expected to be astropy :class:`~astropy.units.quantity.Quantity` objects. Parameters ---------- func: The function to be decorated validations_on_return: dictionary of validation specifications Specifications for unit and value validations on the return of the function being wrapped. (see `quantity validations`_ for valid specifications. **validations: dictionary of validation specifications Specifications for unit and value validations on the input arguments of the function being wrapped. Each keyword argument in ``validations`` is the name of a function argument to be validated and the keyword value contains the unit and value validation specifications. .. _`quantity validations`: Unit and value validations can be defined by passing one of the astropy :mod:`~astropy.units`, a list of astropy units, or a dictionary containing the keys defined below. Units can also be defined with function annotations, but must be consistent with decorator ``**validations`` arguments if used concurrently. If a key is omitted, then the default value will be assumed. ====================== ======= ================================================ Key Type Description ====================== ======= ================================================ units list of desired astropy :mod:`~astropy.units` equivalencies | [DEFAULT `None`] A list of equivalent pairs to try if | the units are not directly convertible. | (see :mod:`~astropy.units.equivalencies`, and/or `astropy equivalencies`_) pass_equivalent_units `bool` | [DEFAULT `False`] allow equivalent units | to pass can_be_negative `bool` [DEFAULT `True`] values can be negative can_be_complex `bool` [DEFAULT `False`] values can be complex numbers can_be_inf `bool` [DEFAULT `True`] values can be :data:`~numpy.inf` can_be_nan `bool` [DEFAULT `True`] values can be :data:`~numpy.nan` none_shall_pass `bool` [DEFAULT `False`] values can be a python `None` can_be_zero `bool` [DEFAULT `True`] values can be zero ====================== ======= ================================================ Notes ----- * Validation of function arguments ``*args`` and ``**kwargs`` is not supported. * `None` values will pass when `None` is included in the list of specified units, is set as a default value for the function argument, or ``none_shall_pass`` is set to `True`. If ``none_shall_pass`` is doubly/triply defined through the mentioned options, then they all must be consistent with each other. * If units are not specified in ``validations``, then the decorator will attempt to identify desired units by examining the function annotations. * Full functionality is defined by the class :class:`ValidateQuantities`. Examples -------- Define unit and value validations with decorator parameters:: import astropy.units as u from plasmapy.utils.decorators import validate_quantities @validate_quantities( mass={"units": u.g, "can_be_negative": False}, vel=u.cm / u.s, validations_on_return=[u.g * u.cm / u.s, u.kg * u.m / u.s], ) def foo(mass, vel): return mass * vel # on a method class Foo: @validate_quantities( mass={"units": u.g, "can_be_negative": False}, vel=u.cm / u.s, validations_on_return=[u.g * u.cm / u.s, u.kg * u.m / u.s], ) def bar(self, mass, vel): return mass * vel Define units with function annotations:: @validate_quantities(mass={"can_be_negative": False}) def foo(mass: u.g, vel: u.cm / u.s) -> u.g * u.cm / u.s: return mass * vel # rely only on annotations @validate_quantities def foo(x: u.cm, time: u.s) -> u.cm / u.s: return x / time # on a method class Foo: @validate_quantities(mass={"can_be_negative": False}) def bar(self, mass: u.g, vel: u.cm / u.s) -> u.g * u.cm / u.s: return mass * vel Define units using type hint annotations:: @validate_quantities def foo(x: u.Quantity[u.m], time: u.Quantity[u.s]) -> u.Quantity[u.m / u.s]: return x / time Allow `None` values to pass:: @validate_quantities(arg2={"none_shall_pass": True}, checks_on_return=[u.cm, None]) def foo(arg1: u.cm, arg2: u.cm = None): return None Allow return values to have equivalent units:: @validate_quantities( arg1={"units": u.cm}, checks_on_return={"units": u.km, "pass_equivalent_units": True}, ) def foo(arg1): return arg1 Allow equivalent units to pass with specified equivalencies:: @validate_quantities( arg1={ "units": u.K, "equivalencies": u.temperature(), "pass_equivalent_units": True, } ) def foo(arg1): return arg1 .. _astropy equivalencies: https://docs.astropy.org/en/stable/units/equivalencies.html """ if validations_on_return is not None: validations["validations_on_return"] = validations_on_return if func is not None: # `validate_quantities` called as a function return ValidateQuantities(**validations)(func) # `validate_quantities` called as a decorator "sugar-syntax" return ValidateQuantities(**validations)
def get_attributes_not_provided( self, expected_attributes: list[str] | None = None, both_or_either_attributes: list[Iterable[str]] | None = None, mutually_exclusive_attributes: list[Iterable[str]] | None = None, ): """ Collect attributes that weren't provided during instantiation needed to access a method. """ attributes_not_provided = [] if expected_attributes is not None: attributes_not_provided.extend( attribute for attribute in expected_attributes if getattr(self, attribute) is None ) if both_or_either_attributes is not None: for attribute_tuple in both_or_either_attributes: number_of_attributes_provided = sum( getattr(self, attribute) is not None for attribute in attribute_tuple ) if number_of_attributes_provided == 0: attributes_not_provided.append( f"at least one of {' or '.join(attribute_tuple)}" ) if mutually_exclusive_attributes is not None: for attribute_tuple in mutually_exclusive_attributes: number_of_attributes_provided = sum( getattr(self, attribute) is not None for attribute in attribute_tuple ) if number_of_attributes_provided != 1: attributes_not_provided.append( f"exactly one of {' or '.join(attribute_tuple)}" ) return attributes_not_provided
[docs] def validate_class_attributes( expected_attributes: list[str] | None = None, both_or_either_attributes: list[Iterable[str]] | None = None, mutually_exclusive_attributes: list[Iterable[str]] | None = None, ): """ A decorator responsible for raising errors if the expected arguments weren't provided during class instantiation. """ def decorator(attribute): def wrapper(self, *args, **kwargs): arguments_not_provided = get_attributes_not_provided( self, expected_attributes, both_or_either_attributes, mutually_exclusive_attributes, ) if len(arguments_not_provided) > 0: raise ValueError( f"{attribute.__name__} expected the following " f"additional arguments: {', '.join(arguments_not_provided)}" ) return attribute(self, *args, **kwargs) return wrapper return decorator