evoxels.inversion

Classes

InversionModel(vf, problem_cls, pos_params, ...)

Inverse modeling using JAX and diffrax.

class evoxels.inversion.InversionModel(vf: Any, problem_cls: Type, pos_params: list[str] | None = None, problem_kwargs: dict[str, Any] | None = None, timestepper_cls: Type[TimeStepper] = evoxels.timesteppers.PseudoSpectralIMEX, backend: str = 'jax')

Inverse modeling using JAX and diffrax.

This lightweight helper wraps differentiable forward solves for evoxels problem classes and provides utilities to fit model parameters via gradient-based optimization. It is intentionally minimal so that the individual steps of solving a PDE, computing residuals, and running a least-squares optimizer remain easy to follow.

__init__(vf: Any, problem_cls: Type, pos_params: list[str] | None = None, problem_kwargs: dict[str, Any] | None = None, timestepper_cls: Type[TimeStepper] = evoxels.timesteppers.PseudoSpectralIMEX, backend: str = 'jax') -> None
backend: str = 'jax'
forward_solve(parameters, fieldname, saveat, dt0=0.1, verbose=True)
pos_params: list[str] | None = None
problem_cls: Type
problem_kwargs: dict[str, Any] | None = None
residuals(parameters, y0s__values__saveat, adjoint=diffrax.ForwardMode)

Calculate residuals between measured and simulated states.

Parameters:
  • parameters (dict) – Current estimate of the model parameters.

  • y0s__values__saveat (tuple) – Tuple (y0s, values, saveat) where y0s contains the initial states for each sequence, values contains the observed states, and saveat specifies the time points of these observations.

  • adjoint – Differentiation mode passed on to solve().

Returns:

Array of residuals with shape matching values.

Return type:

jax.Array
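The residual computation can be pictured as follows. This is a minimal sketch, not the evoxels implementation. toy_solve is a hypothetical stand-in for solve() that uses exact exponential decay instead of a PDE. The save times are passed as a plain array ts rather than a diffrax.SaveAt, and the adjoint argument is omitted. The key idea is that every sequence is simulated from its own initial state with shared parameters, and the residual is simply simulated minus observed.

```python
import jax
import jax.numpy as jnp


def toy_solve(parameters, y0, ts):
    # Hypothetical stand-in for solve(): exact solution of dy/dt = -k*y
    # evaluated at the save times ts (measured from the initial state).
    k = parameters["k"]
    return jnp.exp(-k * ts)[:, None] * y0[None, :]  # shape (len(ts), N)


def residuals(parameters, y0s__values__saveat):
    # Sketch only: ts replaces the diffrax.SaveAt used by the real method.
    y0s, values, ts = y0s__values__saveat
    # vmap over the leading (sequence) axis of y0s; parameters and ts are shared.
    simulated = jax.vmap(toy_solve, in_axes=(None, 0, None))(parameters, y0s, ts)
    return simulated - values  # same shape as values: (S, len(ts), N)
```

Because the residual array has the same shape as the observations, a least-squares solver can consume it directly after flattening.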

solve(parameters, y0, saveat, adjoint=diffrax.ForwardMode, dt0=0.1)

Integrate the configured problem for a given parameter set.

Parameters:
  • parameters (dict) – Dictionary containing the model parameters to solve with.

  • y0 (array-like) – Initial state field.

  • saveat (diffrax.SaveAt) – Time points at which the solution should be stored.

  • adjoint – Differentiation mode used by diffrax.diffeqsolve().

  • dt0 (float) – Initial step size for the time integrator.

Returns:

Array of saved state fields with shape (len(saveat.ts), Nx, Ny, Nz).

Return type:

jax.Array
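To make the output shape concrete, here is a hypothetical solve() for a 3D reaction-diffusion field u_t = D*lap(u) - r*u on a periodic grid. It uses a first-order semi-implicit (IMEX) pseudo-spectral Euler step: diffusion is treated implicitly in Fourier space and the reaction explicitly. This scheme was chosen for brevity and is not claimed to match evoxels' PseudoSpectralIMEX. The parameter names "D" and "r" are assumptions, as are the arguments steps_per_save, n_saves and dt, which replace saveat and dt0.

```python
import jax
import jax.numpy as jnp


def wavenumbers_squared(shape, spacing=1.0):
    # |k|^2 on a periodic 3D grid, one value per Fourier mode.
    ks = [2 * jnp.pi * jnp.fft.fftfreq(n, d=spacing) for n in shape]
    kx, ky, kz = jnp.meshgrid(*ks, indexing="ij")
    return kx**2 + ky**2 + kz**2


def imex_step(u, dt, D, k2, reaction):
    # Reaction explicit, diffusion implicit. The implicit operator is diagonal
    # in Fourier space, so the linear solve is a pointwise division:
    # (1 + dt*D*|k|^2) u_hat_new = fft(u + dt*reaction(u))
    rhs_hat = jnp.fft.fftn(u + dt * reaction(u))
    return jnp.real(jnp.fft.ifftn(rhs_hat / (1.0 + dt * D * k2)))


def solve(parameters, y0, steps_per_save, n_saves, dt):
    # Hypothetical model u_t = D*lap(u) - r*u with parameters {"D", "r"}.
    k2 = wavenumbers_squared(y0.shape)
    reaction = lambda u: -parameters["r"] * u

    def advance(u, _):
        u = jax.lax.fori_loop(
            0, steps_per_save,
            lambda i, v: imex_step(v, dt, parameters["D"], k2, reaction), u)
        return u, u  # carry the state forward and also record it

    _, saved = jax.lax.scan(advance, y0, None, length=n_saves)
    return saved  # shape (n_saves, Nx, Ny, Nz), analogous to len(saveat.ts)
```

Everything is written with jax.numpy and jax.lax control flow with static loop counts. The saved trajectory can therefore be differentiated with respect to the parameters, which is what a gradient-based fit requires.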

timestepper_cls

alias of PseudoSpectralIMEX

train(initial_parameters, data, inds, adjoint=diffrax.ForwardMode, rtol=1e-06, atol=1e-06, verbose=True, max_steps=1000)

Fit parameters so that the model matches observed data.

This method assembles the observed sequences into a format suitable for optimistix.least_squares() and then runs a Levenberg–Marquardt optimisation to minimise the residuals returned by residuals().
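The following sketch shows the basic damped Gauss–Newton (Levenberg–Marquardt) iteration on a flat parameter vector. It is a hypothetical plain-Python loop, not optimistix's implementation, which adds proper convergence checks based on rtol and atol. The damping schedule is an illustrative choice: halve on success, multiply by 10 on failure, with a cap. The Jacobian is built with jax.jacfwd, i.e. forward-mode differentiation. Forward mode is cheap when there are few parameters and many residuals, which may be why a forward-mode adjoint is the default here.

```python
import jax
import jax.numpy as jnp


def levenberg_marquardt(residual_fn, p0, n_steps=50, lam=1e-3):
    # residual_fn maps a flat parameter vector to a flat residual vector.
    p = p0
    for _ in range(n_steps):
        r = residual_fn(p)
        J = jax.jacfwd(residual_fn)(p)          # forward-mode Jacobian (M, P)
        A = J.T @ J + lam * jnp.eye(p.size)     # damped Gauss-Newton matrix
        step = jnp.linalg.solve(A, -J.T @ r)
        p_new = p + step
        if jnp.sum(residual_fn(p_new) ** 2) < jnp.sum(r ** 2):
            p, lam = p_new, lam * 0.5           # accept: trust the local model more
        else:
            lam = min(lam * 10.0, 1e10)         # reject: lean toward gradient descent
    return p
```

A typical call passes a closure such as lambda p: model(p, ts) - data as residual_fn, where model and data are the user's own forward model and observations.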

Parameters:
  • initial_parameters (dict) – Initial guess for the parameters to be optimised.

  • data (dict) – Dictionary containing "ts" (time stamps) and "ys" (state fields) as produced by solve().

  • inds (list[list[int]]) – For each sequence, the indices into data["ts"] and data["ys"] that should be used for training. All sequences must have the same spacing between consecutive indices.

  • adjoint – Differentiation mode used when evaluating the residuals.

  • rtol (float) – Relative tolerance for the optimiser.

  • atol (float) – Absolute tolerance for the optimiser.

  • verbose (bool) – If True, prints optimisation progress.

  • max_steps (int) – Maximum number of optimisation steps.

Returns:

The optimiser state after termination.

Return type:

optimistix.State
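The role of inds can be illustrated with a small NumPy helper that turns data and inds into the (y0s, values, saveat) triple expected by residuals(). This is a hypothetical reconstruction: assemble_sequences is not part of evoxels. Treating the first index of each sequence as its initial state and the remaining indices as observations is an assumption, and so is returning the save times as a plain array. The helper also shows why the sequences need the same spacing: only then can one set of relative save times serve every sequence.

```python
import numpy as np


def assemble_sequences(data, inds):
    # Hypothetical helper (not evoxels API): build (y0s, values, rel_times).
    ts = np.asarray(data["ts"])
    ys = np.asarray(data["ys"])
    # Save times of each sequence, measured from its own initial state.
    rel_times = [ts[idx[1:]] - ts[idx[0]] for idx in inds]
    for t in rel_times[1:]:
        if t.shape != rel_times[0].shape or not np.allclose(t, rel_times[0]):
            raise ValueError("all sequences in inds must have the same spacing")
    y0s = np.stack([ys[idx[0]] for idx in inds])       # (S, Nx, Ny, Nz)
    values = np.stack([ys[idx[1:]] for idx in inds])   # (S, T, Nx, Ny, Nz)
    return y0s, values, rel_times[0]
```
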

vf: Any