"""
Decorator for checking input/output arguments of functions.
"""
__all__ = [
"check_values",
"check_units",
"check_relativistic",
"CheckBase",
"CheckUnits",
"CheckValues",
]
import collections
import functools
import inspect
import warnings
from functools import reduce
from operator import add
from typing import Any, Optional, Union
import astropy.units as u
import numpy as np
from astropy.constants import c
from astropy.units.equivalencies import Equivalency
from plasmapy.utils.decorators.helpers import preserve_signature
from plasmapy.utils.exceptions import (
PlasmaPyWarning,
RelativityError,
RelativityWarning,
)
class CheckBase:
    """
    Base class for 'Check' decorator classes.

    Parameters
    ----------
    checks_on_return
        Specified checks on the return of the wrapped function.

    **checks
        Specified checks on the input arguments of the wrapped function.
        Each keyword is stored under its own name.

    Notes
    -----
    All requested checks are stored in a single dictionary.  When
    ``checks_on_return`` is given (i.e. is not `None`), it is stored in
    that dictionary under the reserved key ``"checks_on_return"``.
    """

    def __init__(self, checks_on_return=None, **checks) -> None:
        # Only register the return checks when the caller supplied them,
        # so that "checks_on_return" is absent from `checks` otherwise.
        self._checks = (
            checks
            if checks_on_return is None
            else {**checks, "checks_on_return": checks_on_return}
        )

    @property
    def checks(self):
        """
        Requested checks on the decorated function's input arguments
        and/or return.
        """
        return self._checks
class CheckValues(CheckBase):
    """
    A decorator class to 'check' — limit/control — the values of input and
    return arguments to a function or method.

    Parameters
    ----------
    checks_on_return: Dict[str, bool]
        Specifications for value checks on the return of the function being
        wrapped.  (see `check values`_ for valid specifications)

    **checks: Dict[str, Dict[str, bool]]
        Specifications for value checks on the input arguments of the
        function being wrapped.  Each keyword argument in ``checks`` is the
        name of a function argument to be checked and the keyword value
        contains the value check specifications.

        .. _`check values`:

        The value check specifications are defined within a dictionary
        containing the keys defined below.  If the dictionary is empty or
        omitting keys, then the default value will be assumed for the
        missing keys.

        ================ ======= ================================================
        Key              Type    Description
        ================ ======= ================================================
        can_be_negative  `bool`  [DEFAULT `True`] values can be negative
        can_be_complex   `bool`  [DEFAULT `False`] values can be complex numbers
        can_be_inf       `bool`  [DEFAULT `True`] values can be :data:`~numpy.inf`
        can_be_nan       `bool`  [DEFAULT `True`] values can be :data:`~numpy.nan`
        none_shall_pass  `bool`  [DEFAULT `False`] values can be a python `None`
        can_be_zero      `bool`  [DEFAULT `True`] values can be zero
        ================ ======= ================================================

    Notes
    -----
    * Checking of function arguments ``*args`` and ``**kwargs`` is not
      supported.

    Examples
    --------
    .. code-block:: python

        from plasmapy.utils.decorators.checks import CheckValues


        @CheckValues(
            arg1={"can_be_negative": False, "can_be_nan": False},
            arg2={"can_be_inf": False},
            checks_on_return={"none_shall_pass": True},
        )
        def foo(arg1, arg2):
            return None


        # on a method
        class Foo:
            @CheckValues(
                arg1={"can_be_negative": False, "can_be_nan": False},
                arg2={"can_be_inf": False},
                checks_on_return={"none_shall_pass": True},
            )
            def bar(self, arg1, arg2):
                return None
    """

    #: Default values for the possible 'check' keys.
    # To add a new check to the class, the following needs to be done:
    #   1. Add a key & default value to the `__check_defaults` dictionary
    #   2. Add a corresponding if-statement to method `_check_value`
    #
    __check_defaults = {
        "can_be_negative": True,
        "can_be_complex": False,
        "can_be_inf": True,
        "can_be_nan": True,
        "none_shall_pass": False,
        "can_be_zero": True,
    }

    def __init__(
        self,
        checks_on_return: Optional[dict[str, bool]] = None,
        **checks: dict[str, bool],
    ) -> None:
        super().__init__(checks_on_return=checks_on_return, **checks)

    def __call__(self, f):
        """
        Decorate a function.

        Parameters
        ----------
        f
            Function to be wrapped

        Returns
        -------
        function
            wrapped function of ``f``
        """
        self.f = f
        wrapped_sign = inspect.signature(f)

        @preserve_signature
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            # map args and kwargs to function parameters
            bound_args = wrapped_sign.bind(*args, **kwargs)
            bound_args.apply_defaults()

            # get checks
            checks = self._get_value_checks(bound_args)

            # check input arguments (the return is checked after the call)
            for arg_name, arg_checks in checks.items():
                if arg_name == "checks_on_return":
                    continue
                self._check_value(
                    bound_args.arguments[arg_name], arg_name, arg_checks
                )

            # Call the function.  `BoundArguments.args`/`.kwargs` are used
            # instead of `**bound_args.arguments` because `arguments` stores
            # variadic parameters under their own names (``*args`` as a
            # tuple, ``**kwargs`` as a dict) and cannot express
            # positional-only parameters as keywords.
            _return = f(*bound_args.args, **bound_args.kwargs)

            # check function return
            if "checks_on_return" in checks:
                self._check_value(
                    _return, "checks_on_return", checks["checks_on_return"]
                )

            return _return

        return wrapper

    def _get_value_checks(
        self, bound_args: inspect.BoundArguments
    ) -> dict[str, dict[str, bool]]:
        """
        Review :attr:`checks` and function bound arguments to build a
        complete 'checks' dictionary.

        If a check key is omitted from the argument checks, then a
        default value is assumed (see `check values`_).

        Parameters
        ----------
        bound_args: :class:`inspect.BoundArguments`
            arguments passed into the function being wrapped

            .. code-block:: python

                bound_args = inspect.signature(f).bind(*args, **kwargs)

        Returns
        -------
        Dict[str, Dict[str, bool]]
            A complete 'checks' dictionary for checking function input
            arguments and return.

        Warns
        -----
        PlasmaPyWarning
            If :attr:`checks` names a parameter for which no checks could
            be built (e.g. it is not a parameter of the wrapped function,
            or it is a variadic parameter).
        """
        # initialize validation dictionary
        out_checks = {}

        # Iterate through function bound arguments + return and build
        # `out_checks`:
        #
        # artificially add "return" to parameters
        things_to_check = bound_args.signature.parameters.copy()
        things_to_check["checks_on_return"] = inspect.Parameter(
            "checks_on_return",
            inspect.Parameter.POSITIONAL_ONLY,
            annotation=bound_args.signature.return_annotation,
        )

        for param in things_to_check.values():
            # variable arguments are NOT checked
            # e.g. in foo(x, y, *args, d=None, **kwargs) variable arguments
            #      *args and **kwargs will NOT be checked
            #
            if param.kind in (
                inspect.Parameter.VAR_KEYWORD,
                inspect.Parameter.VAR_POSITIONAL,
            ):
                continue

            # grab the checks dictionary for the desired parameter
            try:
                param_in_checks = self.checks[param.name]
            except KeyError:
                # checks for parameter not specified
                continue

            # build `out_checks`
            # read checks and/or apply defaults values
            out_checks[param.name] = {}
            for v_name, v_default in self.__check_defaults.items():
                try:
                    out_checks[param.name][v_name] = param_in_checks.get(
                        v_name, v_default
                    )
                except AttributeError:
                    # for the case that checks are defined for an argument,
                    # but is NOT a dictionary
                    # (e.g. CheckValues(x=u.cm) ... this scenario could happen
                    # during subclassing)
                    out_checks[param.name][v_name] = v_default

        # Does `self.checks` indicate arguments not used by f?
        if missing_params := list(set(self.checks) - set(out_checks)):
            params_str = ", ".join(missing_params)
            warnings.warn(
                PlasmaPyWarning(
                    f"Expected to value check parameters {params_str} but they "
                    f"are missing from the call to {self.f.__name__}"
                )
            )

        return out_checks

    def _check_value(  # noqa: C901, PLR0912
        self,
        arg,
        arg_name: str,
        arg_checks: dict[str, bool],
    ):
        """
        Perform checks ``arg_checks`` on function argument ``arg``.

        Parameters
        ----------
        arg
            The argument to be checked.

        arg_name: str
            The name of the argument to be checked.

        arg_checks: Dict[str, bool]
            The requested checks for the argument.

        Raises
        ------
        ValueError
            If a check fails.
        """
        if arg_name == "checks_on_return":
            valueerror_msg = "The return value "
        else:
            valueerror_msg = f"The argument '{arg_name}' "
        valueerror_msg += f"to function {self.f.__name__}() can not contain"

        # check values
        # * 'none_shall_pass' always needs to be checked first, since a
        #   permitted `None` short-circuits every numerical check
        ckeys = list(self.__check_defaults.keys())
        ckeys.remove("none_shall_pass")
        ckeys = ("none_shall_pass", *tuple(ckeys))
        for ckey in ckeys:
            if ckey == "can_be_complex":
                if not arg_checks[ckey] and np.any(np.iscomplexobj(arg)):
                    raise ValueError(f"{valueerror_msg} complex numbers.")

            elif ckey == "can_be_inf":
                if not arg_checks[ckey] and np.any(np.isinf(arg)):
                    raise ValueError(f"{valueerror_msg} infs.")

            elif ckey == "can_be_nan":
                if not arg_checks[ckey] and np.any(np.isnan(arg)):
                    raise ValueError(f"{valueerror_msg} NaNs.")

            elif ckey == "can_be_negative":
                if not arg_checks[ckey] and np.any(arg < 0):
                    raise ValueError(f"{valueerror_msg} negative numbers.")

            elif ckey == "can_be_zero":
                if not arg_checks[ckey] and np.any(arg == 0):
                    raise ValueError(f"{valueerror_msg} zeros.")

            elif ckey == "none_shall_pass":
                if arg is None and arg_checks[ckey]:
                    # `None` is allowed: skip all remaining checks
                    break
                elif arg is None:  # noqa: RET508
                    raise ValueError(f"{valueerror_msg} Nones.")
class CheckUnits(CheckBase):
    """
    A decorator class to 'check' — limit/control — the units of input
    and return arguments to a function or method.

    Parameters
    ----------
    checks_on_return: list of :mod:`~astropy.units` or dict of unit specifications
        Specifications for unit checks on the return of the function being
        wrapped.  (see `check units`_ for valid specifications)

    **checks: list of astropy :mod:`~astropy.units` or dict of unit specifications
        Specifications for unit checks on the input arguments of the
        function being wrapped.  Each keyword argument in ``checks`` is the
        name of a function argument to be checked and the keyword value
        contains the unit check specifications.

        .. _`check units`:

        Unit checks can be defined by passing one of the astropy
        :mod:`~astropy.units`, a list of astropy units, or a dictionary
        containing the keys defined below.  Units can also be defined with
        function annotations, but must be consistent with decorator
        ``**checks`` arguments if used concurrently.  If a key is omitted,
        then the default value will be assumed.

        ====================== ======= ================================================
        Key                    Type    Description
        ====================== ======= ================================================
        units                          list of desired astropy :mod:`~astropy.units`
        equivalencies                  | [DEFAULT `None`] A list of equivalent pairs to
                                         try if
                                       | the units are not directly convertible.
                                       | (see :mod:`~astropy.units.equivalencies`,
                                         and/or `astropy equivalencies`_)
        pass_equivalent_units  `bool`  | [DEFAULT `False`] allow equivalent units
                                       | to pass
        ====================== ======= ================================================

    Notes
    -----
    * Checking of function arguments ``*args`` and ``**kwargs`` is not
      supported.
    * Decorator does NOT perform any unit conversions.
    * If it is desired that `None` values do not raise errors or warnings,
      then include `None` in the list of units or as a default value for
      the function argument.
    * If units are not specified in ``checks``, then the decorator will
      attempt to identify desired units by examining the function
      annotations.

    Examples
    --------
    Define units with decorator parameters::

        import astropy.units as u
        from plasmapy.utils.decorators import CheckUnits


        @CheckUnits(arg1={"units": u.cm}, arg2=u.cm, checks_on_return=[u.cm, u.km])
        def foo(arg1, arg2):
            return arg1 + arg2


        # or on a method
        class Foo:
            @CheckUnits(arg1={"units": u.cm}, arg2=u.cm, checks_on_return=[u.cm, u.km])
            def bar(self, arg1, arg2):
                return arg1 + arg2

    Define units with function annotations::

        @CheckUnits()
        def foo(arg1: u.cm, arg2: u.cm) -> u.cm:
            return arg1 + arg2


        # or on a method
        class Foo:
            @CheckUnits()
            def bar(self, arg1: u.cm, arg2: u.cm) -> u.cm:
                return arg1 + arg2

    Allow `None` values to pass, on input and output::

        @CheckUnits(checks_on_return=[u.cm, None])
        def foo(arg1: u.cm = None):
            return arg1

    Allow return values to have equivalent units::

        @CheckUnits(
            arg1={"units": u.cm},
            checks_on_return={"units": u.km, "pass_equivalent_units": True},
        )
        def foo(arg1):
            return arg1

    Allow equivalent units to pass with specified equivalencies::

        @CheckUnits(
            arg1={
                "units": u.K,
                "equivalencies": u.temperature_energy(),
                "pass_equivalent_units": True,
            }
        )
        def foo(arg1):
            return arg1

    .. _astropy equivalencies:
        https://docs.astropy.org/en/stable/units/equivalencies.html
    """

    #: Default values for the possible 'check' keys.
    # To add a new check to the class, the following needs to be done:
    #   1. Add a key & default value to the `__check_defaults` dictionary
    #   2. Add a corresponding conditioning statement to `_get_unit_checks`
    #   3. Add a corresponding behavior to `_check_unit`
    #
    __check_defaults = {
        "units": None,
        "equivalencies": None,
        "pass_equivalent_units": False,
        "none_shall_pass": False,
    }

    def __init__(
        self,
        checks_on_return: Union[u.Unit, list[u.Unit], dict[str, Any]] = None,
        **checks: Union[u.Unit, list[u.Unit], dict[str, Any]],
    ) -> None:
        super().__init__(checks_on_return=checks_on_return, **checks)

    def __call__(self, f):
        """
        Decorate a function.

        Parameters
        ----------
        f
            Function to be wrapped

        Returns
        -------
        function
            wrapped function of ``f``
        """
        self.f = f
        wrapped_sign = inspect.signature(f)

        @preserve_signature
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            # combine args and kwargs into dictionary
            bound_args = wrapped_sign.bind(*args, **kwargs)
            bound_args.apply_defaults()

            # get checks
            checks = self._get_unit_checks(bound_args)

            # check (input) argument units (the return is checked after
            # the call)
            for arg_name, arg_checks in checks.items():
                if arg_name == "checks_on_return":
                    continue
                self._check_unit(
                    bound_args.arguments[arg_name], arg_name, arg_checks
                )

            # Call the function.  `BoundArguments.args`/`.kwargs` are used
            # instead of `**bound_args.arguments` because `arguments` stores
            # variadic parameters under their own names (``*args`` as a
            # tuple, ``**kwargs`` as a dict) and cannot express
            # positional-only parameters as keywords.
            _return = f(*bound_args.args, **bound_args.kwargs)

            # check output
            if "checks_on_return" in checks:
                self._check_unit(
                    _return, "checks_on_return", checks["checks_on_return"]
                )

            return _return

        return wrapper

    def _get_unit_checks(  # noqa: C901, PLR0912, PLR0915
        self, bound_args: inspect.BoundArguments
    ) -> dict[str, dict[str, Any]]:
        """
        Review :attr:`checks` and function bound arguments to build a
        complete 'checks' dictionary.  If a check key is omitted from the
        argument checks, then a default value is assumed (see
        `check units`_).

        Parameters
        ----------
        bound_args: :class:`inspect.BoundArguments`
            arguments passed into the function being wrapped

            .. code-block:: python

                bound_args = inspect.signature(f).bind(*args, **kwargs)

        Returns
        -------
        Dict[str, Dict[str, Any]]
            A complete 'checks' dictionary for checking function input
            arguments and return.  Each entry holds the keys ``"units"``,
            ``"none_shall_pass"``, ``"equivalencies"``, and
            ``"pass_equivalent_units"``.

        Raises
        ------
        ValueError
            If checks were requested for a parameter but no units could
            be determined for it, or if annotation units are not included
            in the units given by the decorator arguments.

        Warns
        -----
        PlasmaPyWarning
            If :attr:`checks` names a parameter for which no checks could
            be built.
        """
        # initialize validation dictionary
        out_checks = {}

        # Iterate through function bound arguments + return and build
        # `out_checks`:
        #
        # artificially add "return" to parameters
        things_to_check = bound_args.signature.parameters.copy()
        things_to_check["checks_on_return"] = inspect.Parameter(
            "checks_on_return",
            inspect.Parameter.POSITIONAL_ONLY,
            annotation=bound_args.signature.return_annotation,
        )

        for param in things_to_check.values():
            # variable arguments are NOT checked
            # e.g. in foo(x, y, *args, d=None, **kwargs) variable arguments
            #      *args and **kwargs will NOT be checked
            #
            if param.kind in (
                inspect.Parameter.VAR_KEYWORD,
                inspect.Parameter.VAR_POSITIONAL,
            ):
                continue

            # grab the checks dictionary for the desired parameter
            try:
                param_checks = self.checks[param.name]
            except KeyError:
                param_checks = None

            # -- Determine target units `_units` --
            # target units can be defined in one of three ways (in
            # preferential order):
            #   1. direct keyword pass-through
            #      i.e. CheckUnits(x=u.cm)
            #           CheckUnits(x=[u.cm, u.s])
            #   2. keyword pass-through via dictionary definition
            #      i.e. CheckUnits(x={'units': u.cm})
            #           CheckUnits(x={'units': [u.cm, u.s]})
            #   3. function annotations
            #
            # * if option (3) is used simultaneously with option (1) or (2),
            #   then checks defined by (3) must be consistent with checks
            #   from (1) or (2) to avoid raising an error.
            # * if None is included in the units list, then None values are
            #   allowed
            #
            _none_shall_pass = False
            _units = None
            _units_are_from_anno = False
            if param_checks is not None:
                # checks for argument were defined with decorator
                try:
                    _units = param_checks["units"]
                except TypeError:
                    # if checks is NOT None and is NOT a dictionary, then
                    # assume only units were specified
                    #   e.g. CheckUnits(x=u.cm)
                    #
                    _units = param_checks
                except KeyError:
                    # if checks does NOT have 'units' but is still a
                    # dictionary, then other check conditions may have been
                    # specified and the user is relying on function
                    # annotations to define desired units
                    _units = None

            # If no units have been specified by decorator checks, then look
            # for function annotations.
            #
            # Reconcile units specified by decorator checks and function
            # annotations
            _units_anno = None
            if param.annotation is not inspect.Parameter.empty:
                # unit annotations defined
                _units_anno = param.annotation

            if _units is _units_anno is param_checks is None:
                # no checks specified and no unit annotations defined
                continue
            elif _units is _units_anno is None:  # noqa: RET507
                # checks specified, but NO unit checks
                msg = "No astropy.units specified for "
                if param.name == "checks_on_return":
                    msg += "return value "
                else:
                    msg += f"argument {param.name} "
                msg += f"of function {self.f.__name__}()."
                raise ValueError(msg)
            elif _units is None:
                _units = _units_anno
                _units_are_from_anno = True
                _units_anno = None

            # Ensure `_units` is an iterable
            if not isinstance(_units, collections.abc.Iterable):
                _units = [_units]
            if not isinstance(_units_anno, collections.abc.Iterable):
                _units_anno = [_units_anno]

            # Is None allowed?
            if None in _units or param.default is None:
                _none_shall_pass = True

            # Remove Nones
            if None in _units:
                _units = [t for t in _units if t is not None]
            if None in _units_anno:
                _units_anno = [t for t in _units_anno if t is not None]

            # ensure all _units are astropy.units.Unit or physical types &
            # define 'units' for unit checks &
            # define 'none_shall_pass' check
            _units = self._condition_target_units(
                _units, from_annotations=_units_are_from_anno
            )
            _units_anno = self._condition_target_units(
                _units_anno, from_annotations=True
            )
            if any(_u not in _units for _u in _units_anno):
                raise ValueError(
                    f"For argument '{param.name}', "
                    f"annotation units ({_units_anno}) are not included in the units "
                    f"specified by decorator arguments ({_units}). Use either "
                    f"decorator arguments or function annotations to defined unit "
                    f"types, or make sure annotation specifications match decorator "
                    f"argument specifications."
                )

            if not _units and not _units_anno and param_checks is None:
                # annotations did not specify units
                continue
            elif not _units and not _units_anno:  # noqa: RET507
                # checks specified, but NO unit checks
                msg = "No astropy.units specified for "
                if param.name == "checks_on_return":
                    msg += "return value "
                else:
                    msg += f"argument {param.name} "
                msg += f"of function {self.f.__name__}()."
                raise ValueError(msg)

            out_checks[param.name] = {
                "units": _units,
                "none_shall_pass": _none_shall_pass,
            }

            # -- Determine target equivalencies --
            # Unit equivalences can be defined by:
            # 1. keyword pass-through via dictionary definition
            #    e.g. CheckUnits(x={'units': u.C,
            #                       'equivalencies': u.temperature})
            #
            # initialize equivalencies
            try:
                _equivs = param_checks["equivalencies"]
            except (KeyError, TypeError):
                _equivs = self.__check_defaults["equivalencies"]

            # ensure equivalences are properly formatted
            if _equivs is None or _equivs == [None]:
                _equivs = None
            elif isinstance(_equivs, Equivalency):
                pass
            elif isinstance(_equivs, (list, tuple)):
                # flatten list to non-list elements
                if isinstance(_equivs, tuple):
                    _equivs = [_equivs]
                else:
                    _equivs = self._flatten_equivalencies_list(_equivs)

                # ensure passed equivalencies list is structured properly
                #   [(), ...]
                #   or [Equivalency(), ...]
                #
                # * All equivalencies must be a list of 2, 3, or 4 element
                #   tuples structured like...
                #     (from_unit, to_unit, forward_func, backward_func)
                #
                if not _equivs:
                    # an empty list means "no equivalencies"; without this
                    # guard `reduce` would raise on the empty sequence
                    _equivs = None
                elif all(isinstance(el, Equivalency) for el in _equivs):
                    _equivs = reduce(add, _equivs)
                else:
                    _equivs = self._normalize_equivalencies(_equivs)

            out_checks[param.name]["equivalencies"] = _equivs

            # -- Determine if equivalent units pass --
            try:
                peu = param_checks.get(
                    "pass_equivalent_units",
                    self.__check_defaults["pass_equivalent_units"],
                )
            except (AttributeError, TypeError):
                peu = self.__check_defaults["pass_equivalent_units"]

            out_checks[param.name]["pass_equivalent_units"] = peu

        # Does `self.checks` indicate arguments not used by f?
        if missing_params := list(set(self.checks.keys()) - set(out_checks.keys())):
            params_str = ", ".join(missing_params)
            warnings.warn(
                PlasmaPyWarning(
                    f"Expected to unit check parameters {params_str} but they "
                    f"are missing from the call to {self.f.__name__}"
                )
            )

        return out_checks

    def _check_unit(self, arg, arg_name: str, arg_checks: dict[str, Any]):
        """
        Perform unit checks ``arg_checks`` on function argument ``arg``.

        Parameters
        ----------
        arg
            The argument to be checked

        arg_name: str
            The name of the argument to be checked

        arg_checks: Dict[str, Any]
            The requested checks for the argument

        Raises
        ------
        ValueError
            If ``arg`` is `None` when `arg_checks['none_shall_pass']=False`

        TypeError
            If ``arg`` does not have units

        :class:`astropy.units.UnitTypeError`
            If the units of ``arg`` do not satisfy conditions of ``arg_checks``
        """
        # only the error (if any) matters here; the other returned items
        # are not used by this method
        *_, err = self._check_unit_core(arg, arg_name, arg_checks)
        if err is not None:
            raise err

    def _check_unit_core(  # noqa: C901, PLR0912, PLR0915
        self, arg, arg_name: str, arg_checks: dict[str, Any]
    ) -> tuple[
        Optional[u.Quantity],
        Optional[u.Unit],
        Optional[list[Any]],
        Optional[Exception],
    ]:
        """
        Determines if `arg` passes unit checks `arg_checks` and if the units
        of `arg` is equivalent to any units specified in `arg_checks`.

        Parameters
        ----------
        arg
            The argument to be checked

        arg_name: str
            The name of the argument to be checked

        arg_checks: Dict[str, Any]
            The requested checks for the argument

        Returns
        -------
        (`arg`, `unit`, `equivalencies`, `error`)
            * `arg` is the original input argument `arg` or `None` if unit
              checks fail
            * `unit` is the identified astropy :mod:`~astropy.units` that
              `arg` can be converted to or `None` if none exist
            * `equivalencies` is the astropy
              :mod:`~astropy.units.equivalencies` used for the unit
              conversion or `None`
            * `error` is the `Exception` associated with the failed unit
              checks or `None` for successful unit checks

        Notes
        -----
        The error is returned rather than raised.  When exactly one
        equivalent (but not equal) unit is found and
        ``pass_equivalent_units`` is `False`, the original ``arg`` is
        returned together with the error.
        """
        # initialize str for error messages
        if arg_name == "checks_on_return":
            err_msg = "The return value "
        else:
            err_msg = f"The argument '{arg_name}' "
        err_msg += f"to function {self.f.__name__}()"

        # initialize ValueError message
        valueerror_msg = f"{err_msg} can not contain"

        # initialize TypeError message
        typeerror_msg = f"{err_msg} should be an astropy Quantity with "
        if len(arg_checks["units"]) == 1:
            typeerror_msg += f"the following unit: {arg_checks['units'][0]}"
        else:
            typeerror_msg += "one of the following units: "
            typeerror_msg += ", ".join(str(unit) for unit in arg_checks["units"])
        if arg_checks["none_shall_pass"]:
            # leading space keeps "or None" separated from the last unit
            typeerror_msg += " or None"

        # pass Nones if allowed
        if arg is None:
            if arg_checks["none_shall_pass"]:
                return arg, None, None, None
            else:
                return None, None, None, ValueError(f"{valueerror_msg} Nones")

        # check units
        in_acceptable_units = []
        equiv = arg_checks["equivalencies"]
        for unit in arg_checks["units"]:
            try:
                in_acceptable_units.append(
                    arg.unit.is_equivalent(unit, equivalencies=equiv)
                )
            except AttributeError:
                if hasattr(arg, "unit"):
                    err_specifier = (
                        "a 'unit' attribute without an 'is_equivalent' method"
                    )
                else:
                    err_specifier = "no 'unit' attribute"

                msg = (
                    f"{err_msg} has {err_specifier}. "
                    f"Use an astropy Quantity instead."
                )
                return None, None, None, TypeError(msg)

        # How many acceptable units?
        nacceptable = np.count_nonzero(in_acceptable_units)
        unit = None
        equiv = None
        err = None
        if nacceptable == 0:
            # NO equivalent units
            arg = None
            err = u.UnitTypeError(typeerror_msg)
        else:
            # is there an exact match?
            units_arr = np.array(arg_checks["units"])
            units_equal_mask = np.equal(units_arr, arg.unit)
            units_mask = np.logical_and(units_equal_mask, in_acceptable_units)
            if np.count_nonzero(units_mask) == 1:
                # matched exactly to a desired unit
                unit = units_arr[units_mask][0]
                equiv = arg_checks["equivalencies"]
            elif nacceptable == 1:
                # there is a match to 1 equivalent unit
                # (`arg` is intentionally left untouched in this branch,
                # even when an error is reported)
                unit = units_arr[in_acceptable_units][0]
                equiv = arg_checks["equivalencies"]
                if not arg_checks["pass_equivalent_units"]:
                    err = u.UnitTypeError(typeerror_msg)
            elif not arg_checks["pass_equivalent_units"]:
                # there is a match to more than 1 equivalent units
                arg = None
                err = u.UnitTypeError(typeerror_msg)

        return arg, unit, equiv, err

    @staticmethod
    def _condition_target_units(
        targets: list[Union[str, u.Unit, u.Quantity]],
        from_annotations: bool = False,
    ) -> list:
        """
        From a `list` of target objects that have or represent units,
        return a `list` of conditioned :class:`~astropy.units.Unit`
        objects.

        Parameters
        ----------
        targets: `list` of `str`, `~astropy.units.Unit`, or `~astropy.units.Quantity`
            A list containing strings representing units (e.g., ``"kg"``),
            `~astropy.units.Unit` objects (e.g., ``u.kg``), or
            |Quantity| objects indexed with a `~astropy.units.Unit`
            object (e.g., ``u.Quantity[u.kg]``).

        from_annotations: bool, default: `False`
            Indicates if ``targets`` originated from function/method
            annotations versus decorator input arguments.  When `True`,
            targets that are not a valid unit type are silently skipped.

        Returns
        -------
        list:
            `list` of ``targets`` converted into
            :class:`~astropy.units.Unit` objects.

        Raises
        ------
        TypeError
            If a ``target`` is not a valid type for
            :class:`~astropy.units.Unit` when ``from_annotations == False``.

        ValueError
            If a ``target`` is a valid unit type but not a valid value
            for :class:`~astropy.units.Unit`.
        """
        # Note: this method does not allow for astropy physical types. This is
        #       done because we expect all use cases of CheckUnits to define the
        #       exact units desired.
        allowed_units = []
        for target in targets:
            # The following two blocks extract the unit from annotations of
            # the form u.Quantity[u.m], which is an annotated alias.  The
            # unit is stored as the first item in the __metadata__ attribute
            # and the original class is stored in the __origin__ attribute.
            annotation_metadata = getattr(target, "__metadata__", None)
            annotation_original_class = getattr(target, "__origin__", None)
            if (
                annotation_original_class is u.Quantity
                and annotation_metadata is not None
            ):
                unit_spec = annotation_metadata[0]
            else:
                unit_spec = target

            try:
                target_unit = u.Unit(unit_spec)
            except TypeError:
                # not a unit type
                if not from_annotations:
                    raise
                continue

            allowed_units.append(target_unit)

        return allowed_units

    @staticmethod
    def _normalize_equivalencies(equivalencies):
        """
        Normalize equivalencies to ensure each is in a 4-tuple of the
        form `(from_unit, to_unit, forward_func, backward_func)`.

        `forward_func` maps `from_unit` into `to_unit` and `backward_func`
        does the reverse.

        Parameters
        ----------
        equivalencies: list of equivalent pairs
            list of astropy :mod:`~astropy.units.equivalencies` to be
            normalized

        Returns
        -------
        list
            list of normalized 4-tuples (empty if ``equivalencies`` is
            `None`)

        Raises
        ------
        ValueError
            if an equivalency can not be interpreted

        Notes
        -----
        * the code here was copied and modified from
          :func:`astropy.units.core._normalize_equivalencies` from AstroPy
          version 3.2.3
        * this will work on both the old style list equivalencies
          (pre AstroPy v3.2.1) and the modern equivalencies defined with
          the :class:`~astropy.units.equivalencies.Equivalency` class
        """
        if equivalencies is None:
            return []

        normalized = []

        def return_argument(x):
            return x

        for i, equiv in enumerate(equivalencies):
            if len(equiv) == 2:
                from_unit, to_unit = equiv
                a = b = return_argument
            elif len(equiv) == 3:
                from_unit, to_unit, a = equiv
                b = a
            elif len(equiv) == 4:
                from_unit, to_unit, a, b = equiv
            else:
                raise ValueError(f"Invalid equivalence entry {i}: {equiv!r}")

            # `x is u.Unit(x)` requires that the entries already are unit
            # objects and that the conversion functions are callable
            if not (
                from_unit is u.Unit(from_unit)
                and (to_unit is None or to_unit is u.Unit(to_unit))
                and callable(a)
                and callable(b)
            ):
                raise ValueError(f"Invalid equivalence entry {i}: {equiv!r}")
            normalized.append((from_unit, to_unit, a, b))

        return normalized

    def _flatten_equivalencies_list(self, elist):
        """
        Given a list of equivalencies, flatten out any sub-element lists.

        Parameters
        ----------
        elist: list
            list of astropy :mod:`~astropy.units.equivalencies` to be
            flattened

        Returns
        -------
        list
            a flattened list of astropy :mod:`~astropy.units.equivalencies`
        """
        new_list = []
        for el in elist:
            if not isinstance(el, list):
                new_list.append(el)
            else:
                # recurse into nested (plain) lists only
                new_list.extend(self._flatten_equivalencies_list(el))

        return new_list
def check_units(
    func=None,
    checks_on_return: Optional[dict[str, Any]] = None,
    **checks: dict[str, Any],
):
    """
    A decorator to 'check' — limit/control — the units of input and return
    arguments to a function or method.

    Parameters
    ----------
    func:
        The function to be decorated

    checks_on_return: list of :mod:`~astropy.units` or dict of unit specifications
        Specifications for unit checks on the return of the function
        being wrapped.  (see `check units`_ for valid specifications)

    **checks: list of :mod:`~astropy.units` or dict of unit specifications
        Specifications for unit checks on the input arguments of the
        function being wrapped.  Each keyword argument in ``checks`` is
        the name of a function argument to be checked and the keyword
        value contains the unit check specifications.

        .. _`check units`:

        Unit checks can be defined by passing one of the astropy
        :mod:`~astropy.units`, a list of astropy units, or a dictionary
        containing the keys defined below.  Units can also be defined
        with function annotations, but must be consistent with decorator
        ``**checks`` arguments if used concurrently.  If a key is
        omitted, then the default value will be assumed.

        ====================== ======= ================================================
        Key                    Type    Description
        ====================== ======= ================================================
        units                          list of desired astropy :mod:`~astropy.units`
        equivalencies                  | [DEFAULT `None`] A list of equivalent pairs to
                                         try if
                                       | the units are not directly convertible.
                                       | (see :mod:`~astropy.units.equivalencies`,
                                         and/or `astropy equivalencies`_)
        pass_equivalent_units  `bool`  | [DEFAULT `False`] allow equivalent units
                                       | to pass
        ====================== ======= ================================================

    Notes
    -----
    * Checking of function arguments ``*args`` and ``**kwargs`` is not
      supported.
    * Decorator does NOT perform any unit conversions, look to
      :func:`~plasmapy.utils.decorators.validators.validate_quantities`
      if that functionality is desired.
    * If it is desired that `None` values do not raise errors or
      warnings, then include `None` in the list of units or as a default
      value for the function argument.
    * If units are not specified in ``checks``, then the decorator will
      attempt to identify desired units by examining the function
      annotations.
    * Full functionality is defined by the class :class:`CheckUnits`.

    Examples
    --------
    Define units with decorator parameters::

        import astropy.units as u
        from plasmapy.utils.decorators import check_units


        @check_units(arg1={"units": u.cm}, arg2=u.cm, checks_on_return=[u.cm, u.km])
        def foo(arg1, arg2):
            return arg1 + arg2


        # or on a method
        class Foo:
            @check_units(arg1={"units": u.cm}, arg2=u.cm, checks_on_return=[u.cm, u.km])
            def bar(self, arg1, arg2):
                return arg1 + arg2

    Define units with function annotations::

        @check_units
        def foo(arg1: u.cm, arg2: u.cm) -> u.cm:
            return arg1 + arg2


        # or on a method
        class Foo:
            @check_units
            def bar(self, arg1: u.cm, arg2: u.cm) -> u.cm:
                return arg1 + arg2

    Allow `None` values to pass::

        @check_units(checks_on_return=[u.cm, None])
        def foo(arg1: u.cm = None):
            return arg1

    Allow return values to have equivalent units::

        @check_units(
            arg1={"units": u.cm},
            checks_on_return={"units": u.km, "pass_equivalent_units": True},
        )
        def foo(arg1):
            return arg1

    Allow equivalent units to pass with specified equivalencies::

        @check_units(
            arg1={
                "units": u.K,
                "equivalencies": u.temperature(),
                "pass_equivalent_units": True,
            }
        )
        def foo(arg1):
            return arg1

    .. _astropy equivalencies:
        https://docs.astropy.org/en/stable/units/equivalencies.html
    """
    # fold the return checks into the keyword checks handed to CheckUnits
    if checks_on_return is not None:
        checks["checks_on_return"] = checks_on_return

    decorator = CheckUnits(**checks)

    # used as ``@check_units(...)``: hand back the decorator itself
    if func is None:
        return decorator

    # used as bare ``@check_units``: decorate immediately
    return decorator(func)
def check_values(
    func=None,
    checks_on_return: Optional[dict[str, bool]] = None,
    **checks: dict[str, bool],
):
    """
    A decorator to 'check' — limit/control — the values of input and
    return arguments to a function or method.

    Parameters
    ----------
    func:
        The function to be decorated

    checks_on_return: Dict[str, bool]
        Specifications for value checks on the return of the function
        being wrapped.  (see `check values`_ for valid specifications)

    **checks: Dict[str, Dict[str, bool]]
        Specifications for value checks on the input arguments of the
        function being wrapped.  Each keyword argument in ``checks`` is
        the name of a function argument to be checked and the keyword
        value contains the value check specifications.

        .. _`check values`:

        The value check specifications are defined within a dictionary
        containing the keys defined below.  If the dictionary is empty
        or omitting keys, then the default value will be assumed for the
        missing keys.

        ================ ======= ================================================
        Key              Type    Description
        ================ ======= ================================================
        can_be_negative  `bool`  [DEFAULT `True`] values can be negative
        can_be_complex   `bool`  [DEFAULT `False`] values can be complex numbers
        can_be_inf       `bool`  [DEFAULT `True`] values can be :data:`~numpy.inf`
        can_be_nan       `bool`  [DEFAULT `True`] values can be :data:`~numpy.nan`
        none_shall_pass  `bool`  [DEFAULT `False`] values can be a python `None`
        can_be_zero      `bool`  [DEFAULT `True`] values can be zero
        ================ ======= ================================================

    Notes
    -----
    * Checking of function arguments ``*args`` and ``**kwargs`` is not
      supported.
    * Full functionality is defined by the class :class:`CheckValues`.

    Examples
    --------
    .. code-block:: python

        from plasmapy.utils.decorators import check_values


        @check_values(
            arg1={"can_be_negative": False, "can_be_nan": False},
            arg2={"can_be_inf": False},
            checks_on_return={"none_shall_pass": True},
        )
        def foo(arg1, arg2):
            return None


        # on a method
        class Foo:
            @check_values(
                arg1={"can_be_negative": False, "can_be_nan": False},
                arg2={"can_be_inf": False},
                checks_on_return={"none_shall_pass": True},
            )
            def bar(self, arg1, arg2):
                return None
    """
    # fold the return checks into the keyword checks handed to CheckValues
    if checks_on_return is not None:
        checks["checks_on_return"] = checks_on_return

    decorator = CheckValues(**checks)

    # used as ``@check_values(...)``: hand back the decorator itself
    if func is None:
        return decorator

    # used as bare ``@check_values``: decorate immediately
    return decorator(func)
[docs]
def check_relativistic(func=None, betafrac: float = 0.05):
    """
    Warns or raises an exception when the output of the decorated
    function is greater than ``betafrac`` times the speed of light.

    The decorated function is called unchanged; its return value is
    then inspected and, if no exception is raised, handed back to the
    caller as-is.

    Parameters
    ----------
    func : function, optional
        The function to decorate.

    betafrac : float, optional
        The minimum fraction of the speed of light that will raise a
        `~plasmapy.utils.exceptions.RelativityWarning`. Defaults to 5%.

    Returns
    -------
    function
        Decorated function.

    Raises
    ------
    TypeError
        If the value returned by the decorated function is not a
        `~astropy.units.Quantity`.

    ~astropy.units.UnitConversionError
        If the value returned by the decorated function is not in
        units of velocity.

    ~plasmapy.utils.exceptions.RelativityError
        If the returned velocity is greater than or equal to the speed
        of light.

    Warns
    -----
    : `~plasmapy.utils.exceptions.RelativityWarning`
        If the returned velocity is greater than or equal to
        ``betafrac`` times the speed of light, but less than the speed
        of light.

    Examples
    --------
    >>> import astropy.units as u
    >>> @check_relativistic
    ... def speed():
    ...     return 1 * u.m / u.s

    Passing in a custom ``betafrac``:

    >>> @check_relativistic(betafrac=0.01)
    ... def speed():
    ...     return 1 * u.m / u.s
    """

    def decorator(f):
        @preserve_signature
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            return_ = f(*args, **kwargs)
            _check_relativistic(return_, f.__name__, betafrac=betafrac)
            return return_

        return wrapper

    # Test against None explicitly, as check_units and check_values do:
    # a callable object may legitimately be falsy (e.g. it defines
    # __bool__ or __len__), and a truthiness test would then return the
    # bare decorator instead of the decorated callable.
    return decorator(func) if func is not None else decorator
def _check_relativistic(
    V: u.Quantity[u.m / u.s],
    funcname: str,
    betafrac: float = 0.05,
) -> None:
    r"""
    Warn or raise error for relativistic or superrelativistic
    velocities.

    The largest magnitude found in ``V`` is compared against the speed
    of light.

    Parameters
    ----------
    V : ~astropy.units.Quantity
        A velocity.

    funcname : str
        The name of the original function to be printed in the error
        messages.

    betafrac : float, optional
        The minimum fraction of the speed of light that will generate
        a warning. Defaults to 5%.

    Raises
    ------
    TypeError
        If ``V`` is not a `~astropy.units.Quantity`.

    ~astropy.units.UnitConversionError
        If ``V`` is not in units of velocity.

    ~plasmapy.utils.exceptions.RelativityError
        If ``V`` is greater than or equal to the speed of light.

    Warns
    -----
    ~plasmapy.utils.exceptions.RelativityWarning
        If ``V`` is greater than or equal to the specified fraction of
        the speed of light.

    Notes
    -----
    `~numpy.nan` entries are skipped when locating the largest speed,
    so they neither trigger a warning or error themselves nor hide a
    relativistic value stored elsewhere in the same array.  If ``V`` is
    empty or contains only `~numpy.nan`, nothing is reported.

    Examples
    --------
    >>> import astropy.units as u
    >>> _check_relativistic(1 * u.m / u.s, "function_calling_this")
    """
    # TODO: Replace `funcname` with func.__name__?

    errmsg = "V must be a Quantity with units of velocity in _check_relativistic"

    if not isinstance(V, u.Quantity):
        raise TypeError(errmsg)

    try:
        V_over_c = (V / c).to_value(u.dimensionless_unscaled)
    except u.UnitConversionError as ex:
        raise u.UnitConversionError(errmsg) from ex

    # Promote scalars to a 1-element array so that scalar and array
    # inputs can share the boolean-mask filtering below.
    speeds = np.atleast_1d(np.abs(V_over_c))

    # np.max propagates NaN and every comparison against NaN is False,
    # so a single NaN would otherwise switch off the check for all the
    # remaining elements.  Filter NaNs out before taking the maximum.
    speeds = speeds[~np.isnan(speeds)]

    # Nothing left to compare (empty or all-NaN input); also avoids the
    # ValueError that a max-reduction raises on a zero-size array.
    if speeds.size == 0:
        return

    beta = speeds.max()

    if beta == np.inf:
        raise RelativityError(f"{funcname} is yielding an infinite velocity.")
    elif beta >= 1:
        raise RelativityError(
            f"{funcname} is yielding a velocity that is {round(beta, 3)} "
            f"times the speed of light."
        )
    elif beta >= betafrac:
        warnings.warn(
            f"{funcname} is yielding a velocity that is "
            f"{round(beta * 100, 3)}% of the speed of "
            f"light. Relativistic effects may be important.",
            RelativityWarning,
        )