evoxels.inversion
Classes
- InversionModel – Inverse modeling using JAX and diffrax.
- class evoxels.inversion.InversionModel(vf: Any, problem_cls: Type, pos_params: list[str] | None = None, problem_kwargs: dict[str, Any] | None = None, timestepper_cls: Type[TimeStepper] = evoxels.timesteppers.PseudoSpectralIMEX, backend: str = 'jax')
Inverse modeling using JAX and diffrax.
This lightweight helper wraps differentiable forward solves for evoxels problem classes and provides utilities to fit model parameters via gradient-based optimization. It is intentionally minimal so that the individual steps of solving a PDE, computing residuals, and running a least-squares optimizer remain easy to follow.
- __init__(vf: Any, problem_cls: Type, pos_params: list[str] | None = None, problem_kwargs: dict[str, Any] | None = None, timestepper_cls: Type[TimeStepper] = evoxels.timesteppers.PseudoSpectralIMEX, backend: str = 'jax') -> None
- backend: str = 'jax'
- forward_solve(parameters, fieldname, saveat, dt0=0.1, verbose=True)
- pos_params: list[str] | None = None
- problem_cls: Type
- problem_kwargs: dict[str, Any] | None = None
- residuals(parameters, y0s__values__saveat, adjoint=diffrax.ForwardMode)
Calculate residuals between measured and simulated states.
- Parameters:
parameters (dict) – Current estimate of the model parameters.
y0s__values__saveat (tuple) – Tuple (y0s, values, saveat), where y0s contains the initial states for each sequence, values contains the observed states, and saveat specifies the time points of these observations.
adjoint – Differentiation mode for solve().
- Returns:
Array of residuals with shape matching values.
- Return type:
jax.Array
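The residual computation can be illustrated with a minimal sketch. It is not the evoxels implementation: it assumes a hypothetical `toy_solve` (the exact solution of dy/dt = -k*y) in place of solve(), and a hypothetical `assemble` helper that builds the (y0s, values, saveat) tuple from a trajectory dictionary with "ts" and "ys" keys plus index lists shaped like inds. Because all sequences share the same spacing, their relative save times coincide and the forward solve can be batched with jax.vmap.

```python
import jax
import jax.numpy as jnp


def toy_solve(params, y0, save_ts):
    # Hypothetical stand-in for solve(): exact solution of dy/dt = -k * y,
    # evaluated at save_ts (measured from the start of the sequence).
    # y0: (N,), save_ts: (m,) -> returns (m, N)
    return jnp.exp(-params["k"] * save_ts)[:, None] * y0[None, :]


def assemble(data, inds):
    # Hypothetical helper: build (y0s, values, save_ts) from one trajectory.
    # The first index of each sequence is its initial state; the rest are observations.
    ts, ys = data["ts"], data["ys"]
    idx = jnp.asarray(inds)                       # (n_seq, m + 1)
    rel_ts = ts[idx] - ts[idx[:, :1]]             # times relative to each sequence start
    if not jnp.allclose(rel_ts, rel_ts[0]):
        raise ValueError("all sequences must have the same spacing")
    y0s = ys[idx[:, 0]]                           # (n_seq, N)
    values = ys[idx[:, 1:]]                       # (n_seq, m, N)
    return y0s, values, rel_ts[0, 1:]


def residuals(params, y0s__values__saveat):
    y0s, values, save_ts = y0s__values__saveat
    # Batch the forward solve over sequences; params and save_ts are shared.
    simulated = jax.vmap(toy_solve, in_axes=(None, 0, None))(params, y0s, save_ts)
    return simulated - values                     # same shape as values


# Usage: a synthetic trajectory with k = 0.5 and three equally spaced sequences.
ts = jnp.linspace(0.0, 1.0, 11)
ys = jnp.exp(-0.5 * ts)[:, None] * jnp.array([1.0, 2.0, 3.0, 4.0])
batch = assemble({"ts": ts, "ys": ys}, [[0, 2, 4], [3, 5, 7], [6, 8, 10]])
res = residuals({"k": 0.5}, batch)
```

Batching with jax.vmap rather than a Python loop keeps the residual function a single traced computation, which is convenient when an optimiser later differentiates it.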
- solve(parameters, y0, saveat, adjoint=diffrax.ForwardMode, dt0=0.1)
Integrate the configured problem for a given parameter set.
- Parameters:
parameters (dict) – Dictionary containing the model parameters to solve with.
y0 (array-like) – Initial state field.
saveat (diffrax.SaveAt) – Time points at which the solution should be stored.
adjoint – Differentiation mode used by diffrax.diffeqsolve().
dt0 (float) – Initial step size for the time integrator.
- Returns:
Array of saved state fields with shape (len(saveat.ts), Nx, Ny, Nz).
- Return type:
jax.Array
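To show what integrating a problem with a pseudo-spectral IMEX scheme involves, here is a minimal sketch under stated assumptions: a periodic 3D Allen–Cahn equation du/dt = eps2·Δu − (u³ − u) with a hypothetical parameter dict {"eps2": ...}, a fixed step dt0, and save times that are integer multiples of dt0. The stiff Laplacian is treated implicitly in Fourier space and the reaction term explicitly (first-order IMEX Euler). It is not the evoxels PseudoSpectralIMEX class and does not use diffrax; `k_squared`, `imex_step` and `solve_sketch` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def k_squared(shape, length=1.0):
    # |k|^2 on a periodic box, laid out to match jnp.fft.rfftn (last axis halved).
    kx = 2 * jnp.pi * jnp.fft.fftfreq(shape[0], d=length / shape[0])
    ky = 2 * jnp.pi * jnp.fft.fftfreq(shape[1], d=length / shape[1])
    kz = 2 * jnp.pi * jnp.fft.rfftfreq(shape[2], d=length / shape[2])
    KX, KY, KZ = jnp.meshgrid(kx, ky, kz, indexing="ij")
    return KX**2 + KY**2 + KZ**2


def imex_step(u, params, k2, dt):
    # One first-order IMEX Euler step of du/dt = eps2 * lap(u) - (u**3 - u):
    # the Laplacian (diagonal in Fourier space) is implicit, the reaction explicit.
    u_hat = jnp.fft.rfftn(u)
    reaction_hat = jnp.fft.rfftn(u**3 - u)
    u_hat = (u_hat - dt * reaction_hat) / (1.0 + dt * params["eps2"] * k2)
    return jnp.fft.irfftn(u_hat, s=u.shape)


def solve_sketch(params, y0, save_ts, dt0=0.01):
    """Return the fields at save_ts, shape (len(save_ts), Nx, Ny, Nz)."""
    k2 = k_squared(y0.shape)
    u, t, saved = y0, 0.0, []
    for t_save in save_ts:
        n_steps = int(round((t_save - t) / dt0))   # fixed steps to reach t_save
        u = jax.lax.fori_loop(0, n_steps, lambda _, v: imex_step(v, params, k2, dt0), u)
        t += n_steps * dt0
        saved.append(u)
    return jnp.stack(saved)
```

Because the loop is built from jax.numpy operations and jax.lax.fori_loop with a static trip count, jax.grad of a scalar function of the output with respect to params["eps2"] is available, which is what makes such a forward solve usable inside a least-squares fit.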
- timestepper_cls
alias of PseudoSpectralIMEX
- train(initial_parameters, data, inds, adjoint=diffrax.ForwardMode, rtol=1e-06, atol=1e-06, verbose=True, max_steps=1000)
Fit parameters so that the model matches observed data.
This method assembles the observed sequences into a format suitable for optimistix.least_squares() and then runs a Levenberg–Marquardt optimisation to minimise the residuals returned by residuals().
- Parameters:
initial_parameters (dict) – Initial guess for the parameters to be optimised.
data (dict) – Dictionary containing "ts" (time stamps) and "ys" (state fields) as produced by solve().
inds (list[list[int]]) – For each sequence, the indices in data that should be used for training. All sequences must have the same spacing.
adjoint – Differentiation mode used when evaluating the residuals.
rtol (float) – Relative tolerance for the optimiser.
atol (float) – Absolute tolerance for the optimiser.
verbose (bool) – If True, prints optimisation progress.
max_steps (int) – Maximum number of optimisation steps.
- Returns:
The optimiser state after termination.
- Return type:
optimistix.State
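The optimisation step can be sketched with a hand-written Levenberg–Marquardt loop in plain JAX. This is a swapped-in illustration, not optimistix.least_squares(): it uses identity damping, a simple accept/reject rule for the damping factor, and a forward-mode Jacobian from jax.jacfwd. jax.flatten_util.ravel_pytree turns a parameter dict into a flat vector and back. The `levenberg_marquardt` function and the toy residual (fitting a*exp(-k*t) to noise-free data) are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def levenberg_marquardt(residual_fn, params, max_steps=100, lam=1e-3):
    """Minimise 0.5 * ||residual_fn(params)||**2 over a pytree (e.g. dict) of parameters."""
    flat, unravel = ravel_pytree(params)
    r_flat = lambda p: jnp.ravel(residual_fn(unravel(p)))
    jac = jax.jacfwd(r_flat)                    # forward mode: one JVP per parameter
    cost = 0.5 * jnp.sum(r_flat(flat) ** 2)
    for _ in range(max_steps):
        r, J = r_flat(flat), jac(flat)
        g = J.T @ r                             # gradient of the cost
        A = J.T @ J + lam * jnp.eye(flat.size)  # damped Gauss-Newton normal matrix
        trial = flat + jnp.linalg.solve(A, -g)
        trial_cost = 0.5 * jnp.sum(r_flat(trial) ** 2)
        if trial_cost < cost:                   # accept: trust the Gauss-Newton model more
            flat, cost, lam = trial, trial_cost, max(lam / 10.0, 1e-12)
        else:                                   # reject: move towards gradient descent
            lam = min(lam * 10.0, 1e10)
    return unravel(flat)


# Usage: fit a * exp(-k * t) to noise-free synthetic observations.
t_obs = jnp.linspace(0.0, 4.0, 20)
y_obs = 2.0 * jnp.exp(-0.7 * t_obs)


def toy_residuals(p):
    return p["a"] * jnp.exp(-p["k"] * t_obs) - y_obs


fit = levenberg_marquardt(toy_residuals, {"a": jnp.array(1.0), "k": jnp.array(0.2)})
```

With only a handful of parameters, forward mode needs one Jacobian-vector product per parameter regardless of how long the residual vector is, which is the usual argument for forward-mode differentiation when residuals are full 3D fields.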
- vf: Any