This page was generated by nbsphinx from docs/notebooks/analysis/fit_functions.ipynb.

Fit Functions

Fit functions are a set of callable classes designed to aid in fitting analytical functions to data. A fit function class combines the following functionality:

  1. An analytical function that is callable with given parameters or fitted parameters.

  2. Curve fitting functionality (usually SciPy’s curve_fit() or linregress()), which stores the fit statistics and parameters in the class. This makes the function easily callable with the fitted parameters.

  3. Error propagation calculations.

  4. A root solver that either returns the known analytical solutions or uses SciPy’s fsolve() to calculate the roots.
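To picture how such a class can be organized, here is a hypothetical, stripped-down sketch. ExpPlusLinearSketch and ExpLinParams are illustrative names, not PlasmaPy’s implementation. It covers only item 1: named parameters, the analytical function, and a __call__ that evaluates with the stored parameters. Fitting, error propagation, and root solving are omitted, and it handles scalars only because it uses math.exp rather than NumPy.

```python
import math
from collections import namedtuple

# Hypothetical parameter container, modeled on the FitParamTuple repr in this notebook
ExpLinParams = namedtuple("ExpLinParams", ["a", "alpha", "m", "b"])


class ExpPlusLinearSketch:
    """Hypothetical, minimal fit-function class for f(x) = a exp(alpha x) + m x + b."""

    param_names = ExpLinParams._fields

    def __init__(self, params=None):
        # no required arguments: params stays None until set (or fitted)
        self.params = None if params is None else ExpLinParams(*params)

    @staticmethod
    def func(x, a, alpha, m, b):
        # the analytical function, callable with explicit parameters
        return a * math.exp(alpha * x) + m * x + b

    def __call__(self, x):
        # evaluate with the stored parameters, refusing if none are set yet
        if self.params is None:
            raise ValueError("parameters have not been set or fitted yet")
        return self.func(x, *self.params)


explin_sketch = ExpPlusLinearSketch(params=(5.0, 0.1, -0.5, -8.0))
print(explin_sketch(0.0))  # a + b = -3.0
```

Keeping func as a static method means it can still be evaluated with arbitrary parameters, while __call__ ties evaluation to whatever parameters the instance currently holds.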

[1]:
%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np

from plasmapy.analysis import fit_functions as ffuncs

plt.rcParams["figure.figsize"] = [10.5, 0.56 * 10.5]

Contents:

  1. Fit function basics

  2. Fitting to data

    1. Getting fit results

    2. Fit function is callable

    3. Plotting results

    4. Root solving

Fit function basics

There is an ever-expanding collection of fit functions, but this notebook uses ExponentialPlusLinear as an example.

A fit function class has no required arguments at the time of instantiation.

[2]:
# basic instantiation
explin = ffuncs.ExponentialPlusLinear()

# fit parameters are not set yet
(explin.params, explin.param_errors)
[2]:
(None, None)

Each fit parameter is given a name.

[3]:
explin.param_names
[3]:
('a', 'alpha', 'm', 'b')

These names are used throughout the fit function’s documentation, as well as in the output of its __repr__ and __str__ methods and its latex_str attribute.

[4]:
(explin, explin.__str__(), explin.latex_str)
[4]:
(f(x) = a exp(alpha x) + m x + b <class 'plasmapy.analysis.fit_functions.ExponentialPlusLinear'>,
 'f(x) = a exp(alpha x) + m x + b',
 'a \\, \\exp(\\alpha x) + m x + b')

Fitting to data

Fit functions provide the curve_fit() method to fit the analytical function to a set of \((x, y)\) data points. This is typically done with SciPy’s curve_fit() function; the exception is the Linear fit function, which is fitted with SciPy’s linregress().
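A straight line does not need an iterative optimizer because least squares has a closed-form solution. The stdlib-only sketch below computes the quantities that SciPy’s linregress() reports as slope, intercept, stderr, and intercept_stderr. The helper name ols_line is hypothetical, and this is not PlasmaPy’s code.

```python
import math


def ols_line(xs, ys):
    """Hypothetical helper: ordinary least-squares fit of y = m x + b.

    Returns (m, b, m_err, b_err), where the errors are standard errors
    computed with n - 2 degrees of freedom.
    """
    n = len(xs)
    if n < 3 or n != len(ys):
        raise ValueError("need at least 3 paired points")
    x_mean = sum(xs) / n
    y_mean = sum(ys) / n
    sxx = sum((x - x_mean) ** 2 for x in xs)
    sxy = sum((x - x_mean) * (y - y_mean) for x, y in zip(xs, ys))
    m = sxy / sxx
    b = y_mean - m * x_mean
    # residual variance estimate, with two fitted parameters
    ss_res = sum((y - (m * x + b)) ** 2 for x, y in zip(xs, ys))
    s2 = ss_res / (n - 2)
    m_err = math.sqrt(s2 / sxx)
    b_err = m_err * math.sqrt(sum(x * x for x in xs) / n)
    return m, b, m_err, b_err


print(ols_line([0, 1, 2, 3], [0, 1, 1, 3]))
```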

Let’s generate some noisy data to fit to…

[5]:
params = (5.0, 0.1, -0.5, -8.0)  # (a, alpha, m, b)
xdata = np.linspace(-20, 15, num=100)
ydata = explin.func(xdata, *params) + np.random.normal(0.0, 0.6, xdata.size)

plt.plot(xdata, ydata)
plt.xlabel("X", fontsize=14)
plt.ylabel("Y", fontsize=14)
[5]:
Text(0, 0.5, 'Y')
[Figure: the generated noisy data, plotted as Y versus X]

The fit function’s curve_fit() method shares its signature with SciPy’s curve_fit(), so any additional **kwargs are passed through to SciPy. By default, only the \((x, y)\) values are needed.

[6]:
explin.curve_fit(xdata, ydata)

Getting fit results

After fitting, the fitted parameters, their uncertainties, and the coefficient of determination \(r^2\) can be retrieved through the params, param_errors, and rsq properties, respectively.

[7]:
(explin.params, explin.params.a, explin.params.alpha)
[7]:
(FitParamTuple(a=5.090414472250938, alpha=0.09859670513726874, m=-0.5048791271797906, b=-8.004611692648654),
 5.090414472250938,
 0.09859670513726874)
[8]:
(explin.param_errors, explin.param_errors.a, explin.param_errors.alpha)
[8]:
(FitParamTuple(a=1.2145019116815585, alpha=0.010426862875874562, m=0.05380805284429719, b=1.2410013705604934),
 1.2145019116815585,
 0.010426862875874562)
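The repr FitParamTuple(a=..., alpha=..., ...) suggests that the results behave like a named tuple, supporting both attribute access and positional unpacking. Assuming that, the same behavior can be reproduced with collections.namedtuple. The type below is a stand-in built for illustration, filled with the values printed above.

```python
from collections import namedtuple

# stand-in for the FitParamTuple type seen above (assumed to be namedtuple-like)
FitParamTuple = namedtuple("FitParamTuple", ["a", "alpha", "m", "b"])

params = FitParamTuple(5.090414472250938, 0.09859670513726874,
                       -0.5048791271797906, -8.004611692648654)

# attribute access by parameter name
print(params.alpha)
# positional unpacking in parameter-name order
a, alpha, m, b = params
# the field names double as the parameter names
print(params._fields)
# convert to a plain dict, e.g. for logging
print(params._asdict())
```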
[9]:
explin.rsq
[9]:
np.float64(0.9326527280325769)
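The coefficient of determination compares the residual sum of squares with the total sum of squares about the mean, \(r^2 = 1 - SS_\mathrm{res}/SS_\mathrm{tot}\). Here is a minimal sketch for an arbitrary model. The helper r_squared is hypothetical, and whether rsq is computed exactly this way for every fit function is an assumption.

```python
def r_squared(y_obs, y_model):
    """Hypothetical helper: coefficient of determination 1 - SS_res / SS_tot."""
    if len(y_obs) != len(y_model) or len(y_obs) < 2:
        raise ValueError("need at least two paired values")
    y_mean = sum(y_obs) / len(y_obs)
    ss_res = sum((y - f) ** 2 for y, f in zip(y_obs, y_model))
    ss_tot = sum((y - y_mean) ** 2 for y in y_obs)
    if ss_tot == 0:
        raise ValueError("r^2 is undefined for constant data")
    return 1.0 - ss_res / ss_tot


print(r_squared([1, 2, 3], [1, 2, 3]))  # a perfect model gives 1.0
```

A model that only predicts the mean of the data scores 0, and a model worse than that scores below 0.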

Fit function is callable

Now that the parameters are set, the fit function is callable.

[10]:
explin(0)
[10]:
np.float64(-2.914197220397716)

Associated errors can also be returned by passing reterr=True.

[11]:
y, y_err = explin(np.linspace(-1, 1, num=10), reterr=True)
(y, y_err)
[11]:
(array([-2.88726697, -2.89728619, -2.90504199, -2.91048422, -2.91356164,
        -2.91422186, -2.91241134, -2.90807533, -2.90115791, -2.89160187]),
 array([1.66021892, 1.67588195, 1.69225241, 1.70935001, 1.72719495,
        1.74580794, 1.76521022, 1.78542357, 1.8064703 , 1.82837332]))

Known uncertainties in \(x\) can be specified too, via the x_err keyword.

[12]:
y, y_err = explin(np.linspace(-1, 1, num=10), reterr=True, x_err=0.1)
(y, y_err)
[12]:
(array([-2.88726697, -2.89728619, -2.90504199, -2.91048422, -2.91356164,
        -2.91422186, -2.91241134, -2.90807533, -2.90115791, -2.89160187]),
 array([1.66022648, 1.67588673, 1.69225502, 1.70935109, 1.72719516,
        1.74580796, 1.76521076, 1.78542535, 1.8064741 , 1.82837989]))
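The first entries of the two outputs above are consistent with first-order error propagation that treats the parameter errors as independent and adds each contribution in quadrature:
\(\delta f^2 = (e^{\alpha x}\,\delta a)^2 + (a x e^{\alpha x}\,\delta\alpha)^2 + (x\,\delta m)^2 + \delta b^2 + (f'(x)\,\delta x)^2\), with \(f'(x) = a \alpha e^{\alpha x} + m\).
The sketch below implements that formula with the standard library. The helper explin_with_error is hypothetical, and it ignores parameter covariances.

```python
import math

# fitted values copied from the params and param_errors outputs above
PARAMS = (5.090414472250938, 0.09859670513726874,
          -0.5048791271797906, -8.004611692648654)
ERRORS = (1.2145019116815585, 0.010426862875874562,
          0.05380805284429719, 1.2410013705604934)


def explin_with_error(x, params, param_errors, x_err=0.0):
    """Hypothetical helper: f(x) = a exp(alpha x) + m x + b and its error.

    Assumes independent parameter errors and combines first-order
    contributions in quadrature.
    """
    a, alpha, m, b = params
    da, dalpha, dm, db = param_errors
    e = math.exp(alpha * x)
    y = a * e + m * x + b
    terms = [
        e * da,                       # df/da = exp(alpha x)
        a * x * e * dalpha,           # df/dalpha = a x exp(alpha x)
        x * dm,                       # df/dm = x
        db,                           # df/db = 1
        (a * alpha * e + m) * x_err,  # df/dx = a alpha exp(alpha x) + m
    ]
    return y, math.sqrt(sum(t * t for t in terms))


print(explin_with_error(-1.0, PARAMS, ERRORS))
print(explin_with_error(-1.0, PARAMS, ERRORS, x_err=0.1))
```

Because the derivative of this fit is small near \(x = -1\), the x_err term barely changes the total error there, which matches the tiny differences between the two outputs above.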

Plotting results

[13]:
# plot original data
plt.plot(xdata, ydata, marker="o", linestyle=" ", label="Data")
ax = plt.gca()
ax.set_xlabel("X", fontsize=14)
ax.set_ylabel("Y", fontsize=14)

ax.axhline(0.0, color="r", linestyle="--")

# plot fitted curve + error
yfit, yfit_err = explin(xdata, reterr=True)
ax.plot(xdata, yfit, color="orange", label="Fit")
ax.fill_between(
    xdata,
    yfit + yfit_err,
    yfit - yfit_err,
    color="orange",
    alpha=0.12,
    zorder=0,
    label="Fit Error",
)

# plot annotations
plt.legend(fontsize=14, loc="upper left")

txt = f"$f(x) = {explin.latex_str}$\n$r^2 = {explin.rsq:.3f}$\n"
for name, param, err in zip(
    explin.param_names, explin.params, explin.param_errors, strict=False
):
    txt += f"{name} = {param:.3f} $\\pm$ {err:.3f}\n"
txt_loc = [-13.0, ax.get_ylim()[1]]
txt_loc = ax.transAxes.inverted().transform(ax.transData.transform(txt_loc))
txt_loc[0] -= 0.02
txt_loc[1] -= 0.05
ax.text(
    txt_loc[0],
    txt_loc[1],
    txt,
    fontsize="large",
    transform=ax.transAxes,
    va="top",
    linespacing=1.5,
)
[13]:
Text(0.20727272727272733, 0.95, '$f(x) = a \\, \\exp(\\alpha x) + m x + b$\n$r^2 = 0.933$\na = 5.090 $\\pm$ 1.215\nalpha = 0.099 $\\pm$ 0.010\nm = -0.505 $\\pm$ 0.054\nb = -8.005 $\\pm$ 1.241\n')
[Figure: data with the fitted curve, its error band, and an annotation listing f(x), \(r^2\), and the fitted parameters]

Root solving

An exponential plus a linear offset has no analytical solution for its roots, except in a few specific cases. To get around this, ExponentialPlusLinear().root_solve() uses SciPy’s fsolve() to calculate its roots. If a fit function has analytical solutions for its roots (e.g. Linear().root_solve()), the method is overridden with the known solution.

[14]:
root, err = explin.root_solve(-15.0)
(root, err)
[14]:
(np.float64(-13.077416276767797), nan)
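fsolve() is a general-purpose solver. For a one-dimensional problem like this one, Newton’s method conveys the same idea: start from an initial guess and iterate \(x \leftarrow x - f(x)/f'(x)\). The argument -15.0 passed to root_solve() above appears to serve as that starting guess. This sketch substitutes Newton’s method for fsolve() and uses the fitted parameters printed earlier. It is an illustration, not how PlasmaPy computes the root.

```python
import math

# fitted parameters copied from the params output above
A, ALPHA, M, B = (5.090414472250938, 0.09859670513726874,
                  -0.5048791271797906, -8.004611692648654)


def f(x):
    return A * math.exp(ALPHA * x) + M * x + B


def fprime(x):
    # analytical derivative of f
    return A * ALPHA * math.exp(ALPHA * x) + M


def newton(x0, tol=1e-12, max_iter=50):
    """Newton's method, standing in for fsolve() in this sketch."""
    x = x0
    for _ in range(max_iter):
        step = f(x) / fprime(x)
        x -= step
        if abs(step) < tol:
            return x
    raise RuntimeError("Newton's method did not converge")


root = newton(-15.0)
print(root)
```

With \(a > 0\) and \(\alpha > 0\), this f is convex, and for these parameters it crosses zero twice (once below \(x = 0\) and once above), so the starting guess decides which root is found.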

Let’s use Linear().root_solve() as an example of a known solution.

[15]:
lin = ffuncs.Linear(params=(1.0, -5.0), param_errors=(0.1, 0.1))
root, err = lin.root_solve()
(root, err)
[15]:
(5.0, np.float64(0.5099019513592785))
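For \(f(x) = m x + b\), the root is \(x_0 = -b/m\). First-order propagation of independent errors then gives \(\delta x_0 = \sqrt{(b\,\delta m/m^2)^2 + (\delta b/m)^2}\). Assuming the Linear parameters are ordered (m, b), which is consistent with the root of 5.0 above, this sketch reproduces both printed numbers. The helper linear_root is hypothetical.

```python
import math


def linear_root(m, b, dm=0.0, db=0.0):
    """Hypothetical helper: root of f(x) = m x + b and its propagated error."""
    if m == 0:
        raise ValueError("a horizontal line has no unique root")
    root = -b / m
    # d(root)/dm = b / m**2 and d(root)/db = -1 / m
    err = math.hypot(b / m**2 * dm, db / m)
    return root, err


root, err = linear_root(1.0, -5.0, dm=0.1, db=0.1)
print(root, err)
```

Here \(\delta x_0 = \sqrt{0.5^2 + 0.1^2} = \sqrt{0.26}\), so the slope uncertainty dominates the error on the root.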