Source code for idesolver.idesolver

import logging
import warnings
from typing import Callable, Optional, Tuple, Union

import numpy as np
import scipy.integrate as integ
import scipy.interpolate as inter

from . import exceptions

# Logger for this module; IDESolver.solve() reports per-iteration progress through it at DEBUG level.
logger = logging.getLogger("idesolver")

def complex_quad(
    integrand: Callable, lower_bound: float, upper_bound: float, **kwargs
) -> Tuple[complex, float, float, list, list]:
    """
    A thin wrapper over :func:`scipy.integrate.quad` that handles splitting the
    real and imaginary parts of the integral and recombining them.

    Keyword arguments are passed to both of the internal ``quad`` calls.

    Parameters
    ----------
    integrand :
        The (possibly complex-valued) function of one variable to integrate.
    lower_bound : :class:`float`
        The lower limit of integration.
    upper_bound : :class:`float`
        The upper limit of integration.

    Returns
    -------
    result : :class:`complex`
        The integral, recombined as ``real + 1j * imag``.
    real_error : :class:`float`
        The absolute error estimate ``quad`` reported for the real part.
    imag_error : :class:`float`
        The absolute error estimate ``quad`` reported for the imaginary part.
    real_extra : :class:`list`
        Any outputs ``quad`` returned beyond the first two for the real part
        (typically empty unless ``full_output`` is requested via ``kwargs``).
    imag_extra : :class:`list`
        Likewise, for the imaginary part.
    """
    # quad only integrates real-valued functions, so integrate each part
    # separately; note that the integrand is evaluated by both calls.
    real_result, real_error, *real_extra = integ.quad(
        lambda x: np.real(integrand(x)), lower_bound, upper_bound, **kwargs
    )
    imag_result, imag_error, *imag_extra = integ.quad(
        lambda x: np.imag(integrand(x)), lower_bound, upper_bound, **kwargs
    )

    return (
        real_result + (1j * imag_result),
        real_error,
        imag_error,
        real_extra,
        imag_extra,
    )
def global_error(y1: np.ndarray, y2: np.ndarray) -> float:
    """
    The default global error function.

    Computes the Euclidean norm of the difference between two guesses, i.e. the
    square root of the sum of the squared magnitudes of ``y1 - y2``. Complex
    inputs are supported because the magnitudes are used.

    Parameters
    ----------
    y1 : :class:`numpy.ndarray`
        A guess of the solution.
    y2 : :class:`numpy.ndarray`
        Another guess of the solution.

    Returns
    -------
    error : :class:`float`
        The global error estimate between `y1` and `y2`.
    """
    residual = y1 - y2
    # vdot flattens both arguments and conjugates the first, yielding sum(|r|^2);
    # that value is real, so only its real part is kept before the square root.
    sum_of_squares = np.real(np.vdot(residual, residual))
    return np.sqrt(sum_of_squares)
def coerce_to_array(
    to_coerce: Union[float, np.float64, complex, np.complex128, np.ndarray, list]
) -> np.ndarray:
    """
    Coerce `to_coerce` into a numpy array with at least one dimension.

    Arrays that already have ``ndim >= 1`` are returned without copying;
    scalars and 0-d arrays become length-1 arrays; lists are converted.
    """
    # ``np.array(..., copy=False)`` raises ValueError under NumPy >= 2 whenever a
    # copy cannot be avoided (e.g. for scalars and lists). ``asarray`` copies only
    # when it must, on every NumPy version.
    arr = np.asarray(to_coerce)
    if arr.ndim == 0:
        arr = arr.reshape(1)
    return arr


def dtype(n):
    """Return ``n.dtype`` if `n` is a numpy array, otherwise the Python type of `n`."""
    return n.dtype if isinstance(n, np.ndarray) else type(n)


# data types to recognize as complex in y_0 (compared with ``==`` against the
# dtype returned by ``dtype``, so single and extended precision are included)
_COMPLEX_NUMERIC_TYPES = [complex, np.complex128, np.complex64, np.clongdouble]
# ``ComplexWarning`` lives in ``numpy.exceptions`` since NumPy 1.25, and the
# top-level alias ``numpy.ComplexWarning`` was removed in NumPy 2.0.
try:
    _ComplexWarning = np.exceptions.ComplexWarning
except AttributeError:  # NumPy < 1.25
    _ComplexWarning = np.ComplexWarning


class IDESolver:
    """
    A class that handles solving an integro-differential equation of the form

    .. math::

        \\frac{dy}{dx} & = c(x, y) + d(x) \\int_{\\alpha(x)}^{\\beta(x)} k(x, s) \\, F( y(s) ) \\, ds, \\\\
        & x \\in [a, b], \\quad y(a) = y_0.

    The solution is found iteratively: an initial guess comes from solving the
    ODE :math:`dy/dx = c(x, y)` alone, and every later iteration re-solves the
    ODE with the right-hand side evaluated on a smoothed combination of the
    previous guesses, until the global error drops to the tolerance or the
    iteration budget runs out.

    Attributes
    ----------
    x : :class:`numpy.ndarray`
        The positions where the solution is calculated (i.e., where :math:`y` is evaluated).
    y : :class:`numpy.ndarray`
        The solution :math:`y(x)`.
        ``None`` until :meth:`IDESolver.solve` is finished.
    global_error : :class:`float`
        The final global error estimate.
        ``None`` until :meth:`IDESolver.solve` is finished.
    iteration : :class:`int`
        The current iteration.
        ``None`` until :meth:`IDESolver.solve` starts.
    y_intermediate :
        The intermediate solutions of the most recent call to :meth:`IDESolver.solve`.
        Only exists if ``store_intermediate_y`` is ``True``.
    """

    def __init__(
        self,
        x: np.ndarray,
        y_0: Union[float, np.float64, complex, np.complex128, np.ndarray, list],
        c: Optional[Callable] = None,
        d: Optional[Callable] = None,
        k: Optional[Callable] = None,
        f: Optional[Callable] = None,
        lower_bound: Optional[Callable] = None,
        upper_bound: Optional[Callable] = None,
        global_error_tolerance: Optional[float] = 1e-6,
        max_iterations: Optional[int] = None,
        ode_method: str = "RK45",
        ode_atol: float = 1e-8,
        ode_rtol: float = 1e-8,
        int_atol: float = 1e-8,
        int_rtol: float = 1e-8,
        interpolation_kind: str = "cubic",
        smoothing_factor: float = 0.5,
        store_intermediate_y: bool = False,
        global_error_function: Callable = global_error,
    ):
        """
        Parameters
        ----------
        x : :class:`numpy.ndarray`
            The array of :math:`x` values to find the solution :math:`y(x)` at.
            Generally something like ``numpy.linspace(a, b, num_pts)``.
        y_0 : :class:`float` or :class:`complex` or :class:`numpy.ndarray`
            The initial condition, :math:`y_0 = y(a)` (can be multidimensional).
        c :
            The function :math:`c(x, y)`, called as ``c(x, y)``.
            Defaults to :math:`c(x, y) = 0`.
        d :
            The function :math:`d(x)`.
            Defaults to :math:`d(x) = 1`.
        k :
            The kernel function :math:`k(x, s)`.
            Defaults to :math:`k(x, s) = 1`.
        f :
            The function :math:`F(y)`.
            Defaults to :math:`f(y) = 0`.
        lower_bound :
            The lower bound function :math:`\\alpha(x)`.
            Defaults to the first element of ``x``.
        upper_bound :
            The upper bound function :math:`\\beta(x)`.
            Defaults to the last element of ``x``.
        global_error_tolerance : :class:`float`
            The algorithm will continue until the global error goes below this or uses more than `max_iterations` iterations.
            If ``None``, the algorithm continues until hitting `max_iterations`, which must then be given.
        max_iterations : :class:`int`
            The maximum number of iterations to use.
            If ``None``, iteration will not stop unless the `global_error_tolerance` is satisfied.
            Defaults to ``None``.
        ode_method : :class:`str`
            The ODE solution method to use.
            As the `method` option of :func:`scipy.integrate.solve_ivp`.
            Defaults to ``'RK45'``, which is good for non-stiff systems.
        ode_atol : :class:`float`
            The absolute tolerance for the ODE solver.
            As the `atol` argument of :func:`scipy.integrate.solve_ivp`.
        ode_rtol : :class:`float`
            The relative tolerance for the ODE solver.
            As the `rtol` argument of :func:`scipy.integrate.solve_ivp`.
        int_atol : :class:`float`
            The absolute tolerance for the integration routine.
            As the `epsabs` argument of :func:`scipy.integrate.quad`.
        int_rtol : :class:`float`
            The relative tolerance for the integration routine.
            As the `epsrel` argument of :func:`scipy.integrate.quad`.
        interpolation_kind : :class:`str`
            The type of interpolation to use.
            As the `kind` argument of :class:`scipy.interpolate.interp1d`.
            Defaults to ``'cubic'``.
        smoothing_factor : :class:`float`
            The smoothing factor used to combine the current guess with the new guess at each iteration.
            Must be strictly between 0 and 1.
            Defaults to ``0.5``.
        store_intermediate_y : :class:`bool`
            If ``True``, the intermediate guesses for :math:`y(x)` at each iteration will be stored in the attribute `y_intermediate`.
        global_error_function :
            The function to use to calculate the global error.
            Defaults to :func:`global_error`.

        Raises
        ------
        InvalidParameter
            If `global_error_tolerance` is negative, or is ``0``/``None`` while
            `max_iterations` is ``None``; if `smoothing_factor` is not strictly
            between 0 and 1; or if `max_iterations` is given but not positive.
        """
        self.y_0 = coerce_to_array(y_0)

        # scipy.integrate.quad only integrates real-valued functions, so complex
        # problems use complex_quad, which integrates the real and imaginary
        # parts separately. iscomplexobj recognizes every complex dtype.
        if np.iscomplexobj(self.y_0):
            self.integrator = complex_quad
        else:
            self.integrator = integ.quad

        self.x = np.array(x)

        if c is None:
            c = lambda x, y: self._zeros()
        if d is None:
            d = lambda x: 1
        if k is None:
            k = lambda x, s: 1
        if f is None:
            f = lambda y: self._zeros()

        # wrap the user functions so they always return arrays with ndim >= 1
        self.c = lambda x, y: coerce_to_array(c(x, y))
        self.d = lambda x: coerce_to_array(d(x))
        self.k = lambda x, s: coerce_to_array(k(x, s))
        self.f = lambda y: coerce_to_array(f(y))

        if lower_bound is None:
            lower_bound = lambda x: self.x[0]
        if upper_bound is None:
            upper_bound = lambda x: self.x[-1]
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound

        if global_error_tolerance is None:
            # None means "iterate until max_iterations", so that bound is required
            if max_iterations is None:
                raise exceptions.InvalidParameter(
                    "global_error_tolerance cannot be None if max_iterations is None"
                )
            global_error_tolerance = 0
        if global_error_tolerance == 0 and max_iterations is None:
            raise exceptions.InvalidParameter(
                "global_error_tolerance cannot be 0 if max_iterations is None"
            )
        if global_error_tolerance < 0:
            raise exceptions.InvalidParameter("global_error_tolerance cannot be negative")
        self.global_error_tolerance = global_error_tolerance
        self.global_error_function = global_error_function

        self.interpolation_kind = interpolation_kind

        if not 0 < smoothing_factor < 1:
            raise exceptions.InvalidParameter("Smoothing factor must be between 0 and 1")
        self.smoothing_factor = smoothing_factor

        if max_iterations is not None and max_iterations <= 0:
            raise exceptions.InvalidParameter("If given, max iterations must be greater than 0")
        self.max_iterations = max_iterations

        self.ode_method = ode_method
        self.ode_atol = ode_atol
        self.ode_rtol = ode_rtol

        self.int_atol = int_atol
        self.int_rtol = int_rtol

        self.store_intermediate = store_intermediate_y
        if self.store_intermediate:
            self.y_intermediate = []

        self.iteration = None
        self.y = None
        self.global_error = None

    def _zeros(self) -> np.ndarray:
        """Return an array of zeros with the same shape and dtype as ``y_0``."""
        return np.zeros_like(self.y_0)

    def solve(self, callback: Optional[Callable] = None) -> np.ndarray:
        """
        Compute the solution to the IDE.

        Will emit a warning message if the global error increases on an iteration.
        This does not necessarily mean that the algorithm is not converging, but may indicate that it's having problems.

        Will emit a warning message if the maximum number of iterations is used without reaching the global error tolerance.

        Parameters
        ----------
        callback :
            A function to call after each iteration.
            The function is passed the :class:`IDESolver` instance, the current :math:`y` guess, and the current global error.

        Returns
        -------
        :class:`numpy.ndarray`
            The solution to the IDE (i.e., :math:`y(x)`).
            If ``y_0`` has a single element, the leading length-1 axis is removed.

        Raises
        ------
        UnexpectedlyComplexValuedIDE
            If complex values are cast to real while solving (typically because
            ``y_0`` was real but the IDE is complex-valued). Note that any
            ``TypeError`` raised while iterating -- including one from a user
            function or from `callback` -- is reported this way too.
        ODESolutionFailed
            If :func:`scipy.integrate.solve_ivp` reports a failure.
        """
        # every call records its own intermediate guesses; otherwise repeated
        # calls would mix unwrapped and wrapped arrays in y_intermediate
        if self.store_intermediate:
            self.y_intermediate = []

        # check if the user messed up by not passing y_0 as a complex number when they should have
        with warnings.catch_warnings():
            warnings.filterwarnings(
                action="error",
                message="Casting complex values",
                category=_ComplexWarning,
            )

            try:
                y_current = self._initial_y()
                y_guess = self._solve_rhs_with_known_y(y_current)
                error_current = self._global_error(y_current, y_guess)
                if self.store_intermediate:
                    self.y_intermediate.append(y_current)

                self.iteration = 0

                logger.debug(
                    f"Advanced to iteration {self.iteration}. Current error: {error_current}."
                )

                if callback is not None:
                    logger.debug(f"Calling {callback} after iteration {self.iteration}")
                    callback(self, y_guess, error_current)

                while error_current > self.global_error_tolerance:
                    new_current = self._next_y(y_current, y_guess)
                    new_guess = self._solve_rhs_with_known_y(new_current)
                    new_error = self._global_error(new_current, new_guess)
                    if new_error > error_current:
                        warnings.warn(
                            f"Error increased on iteration {self.iteration}",
                            exceptions.IDEConvergenceWarning,
                        )

                    y_current, y_guess, error_current = (
                        new_current,
                        new_guess,
                        new_error,
                    )

                    if self.store_intermediate:
                        self.y_intermediate.append(y_current)

                    self.iteration += 1

                    logger.debug(
                        f"Advanced to iteration {self.iteration}. Current error: {error_current}."
                    )

                    if callback is not None:
                        logger.debug(f"Calling {callback} after iteration {self.iteration}")
                        callback(self, y_guess, error_current)

                    if self.max_iterations is not None and self.iteration >= self.max_iterations:
                        # only warn if the budget ran out before the tolerance was met;
                        # the final allowed iteration may itself have converged
                        if error_current > self.global_error_tolerance:
                            warnings.warn(
                                exceptions.IDEConvergenceWarning(
                                    f"Used maximum number of iterations ({self.max_iterations}), but only got to global error {error_current} (target {self.global_error_tolerance})"
                                )
                            )
                        break
            except (_ComplexWarning, TypeError) as e:
                raise exceptions.UnexpectedlyComplexValuedIDE(
                    "Detected complex-valued IDE. Make sure to pass y_0 as a complex number."
                ) from e

        self.y = y_guess
        self.global_error = error_current

        # get rid of the array wrapper if the dimension is 1
        if self.y_0.size == 1:
            self.y = self.y[0]
            if self.store_intermediate:
                self.y_intermediate = [y[0] for y in self.y_intermediate]

        return self.y

    def _initial_y(self) -> np.ndarray:
        """Calculate the initial guess for `y`, by considering only `c` on the right-hand side of the IDE."""
        return self._solve_ode(self.c)

    def _next_y(self, curr: np.ndarray, guess: np.ndarray) -> np.ndarray:
        """
        Calculate the next guess at the solution by merging two guesses.

        Returns ``smoothing_factor * curr + (1 - smoothing_factor) * guess``.
        """
        return (self.smoothing_factor * curr) + ((1 - self.smoothing_factor) * guess)

    def _global_error(self, y1: np.ndarray, y2: np.ndarray) -> float:
        """
        Return the global error estimate between `y1` and `y2`.

        Parameters
        ----------
        y1
            A guess of the solution.
        y2
            Another guess of the solution.

        Returns
        -------
        error : :class:`float`
            The global error estimate between `y1` and `y2`.
        """
        return self.global_error_function(y1, y2)

    def _solve_rhs_with_known_y(self, y: np.ndarray) -> np.ndarray:
        """Solves the right-hand-side of the IDE as if :math:`y` was `y`."""
        interpolated_y = self._interpolate_y(y)

        def integral(x):
            def integrand(s):
                return self.k(x, s) * self.f(interpolated_y(s))

            # the integration limits depend only on x, so evaluate them once
            # per x instead of once per component
            lower, upper = self.lower_bound(x), self.upper_bound(x)

            result = []
            for i in range(self.y_0.size):
                # the lambda's late binding of ``i`` is safe: the integrator
                # finishes with it before the loop advances
                r, *_ = self.integrator(
                    lambda s: integrand(s)[i],
                    lower,
                    upper,
                    epsabs=self.int_atol,
                    epsrel=self.int_rtol,
                )
                result.append(r)
            return coerce_to_array(result)

        def rhs(x, y):
            # the ODE state ``y`` is deliberately unused: c and the integral are
            # evaluated on the known guess, interpolated at x
            return self.c(x, interpolated_y(x)) + (self.d(x) * integral(x))

        return self._solve_ode(rhs)

    def _interpolate_y(self, y: np.ndarray) -> inter.interp1d:
        """
        Interpolate `y` along `x`, using `interpolation_kind`.

        Parameters
        ----------
        y : :class:`numpy.ndarray`
            The y values to interpolate (probably a guess at the solution).

        Returns
        -------
        interpolator : :class:`scipy.interpolate.interp1d`
            The interpolator function (extrapolates outside the range of `x`).
        """
        return inter.interp1d(
            x=self.x,
            y=y,
            kind=self.interpolation_kind,
            fill_value="extrapolate",
            assume_sorted=True,
        )

    def _solve_ode(self, rhs: Callable) -> np.ndarray:
        """
        Solve ``dy/dx = rhs(x, y)`` with ``y(x[0]) = y_0`` over the span of ``x``.

        Returns the solution sampled at ``x``, as the ``y`` attribute of the
        :func:`scipy.integrate.solve_ivp` result.

        Raises
        ------
        ODESolutionFailed
            If the solver reports that it did not succeed.
        """
        sol = integ.solve_ivp(
            fun=rhs,
            y0=self.y_0,
            t_span=(self.x[0], self.x[-1]),
            t_eval=self.x,
            method=self.ode_method,
            atol=self.ode_atol,
            rtol=self.ode_rtol,
        )

        if not sol.success:
            # include the solver's own explanation, not just its status code
            raise exceptions.ODESolutionFailed(
                f"Error while trying to solve ODE: {sol.status} ({sol.message})"
            )

        return sol.y