evoxels.inversion
Classes
InversionModel – Inverse modeling using JAX and diffrax.
- class evoxels.inversion.InversionModel(vf: Any, problem_cls: Type, pos_params: list[str] | None = None, problem_kwargs: dict[str, Any] | None = None, backend: str = 'jax')
Inverse modeling using JAX and diffrax.
This small helper class wraps the differentiable solver implementation and provides utilities to fit material parameters via gradient-based optimisation. It is intentionally lightweight so that new users can easily follow the individual steps: solving the PDE, computing residuals and running a least-squares optimiser.
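The three steps can be made concrete with a minimal pure-JAX sketch on a toy scalar ODE, du/dt = -k u, instead of a PDE. The functions `toy_forward_solve` and `toy_residuals` are hypothetical stand-ins (an explicit Euler loop replaces diffrax), and a plain Gauss–Newton iteration replaces the optimistix least-squares solver. `jax.jacfwd` is used because it builds the Jacobian in forward mode, in the same spirit as the `diffrax.ForwardMode` default.

```python
import jax
import jax.numpy as jnp


def toy_forward_solve(k, u0, n_steps=100, dt=0.01):
    """Step 1 (hypothetical): explicit-Euler solve of du/dt = -k * u."""

    def step(u, _):
        u_next = u - dt * k * u
        return u_next, u_next  # new carry, and the value recorded at this step

    _, trajectory = jax.lax.scan(step, u0, None, length=n_steps)
    return trajectory  # shape (n_steps,)


def toy_residuals(k, u0, observed):
    """Step 2 (hypothetical): simulated states minus observed states."""
    return toy_forward_solve(k, u0) - observed


# Synthetic "measurements" generated with the true rate k = 2.0
u0 = jnp.asarray(1.0, dtype=jnp.float32)
observed = toy_forward_solve(2.0, u0)

# Step 3: Gauss-Newton iterations on the scalar parameter k
k = jnp.asarray(0.5, dtype=jnp.float32)
for _ in range(10):
    r = toy_residuals(k, u0, observed)
    J = jax.jacfwd(toy_residuals)(k, u0, observed)  # d(residuals)/dk, shape (n_steps,)
    k = k - jnp.dot(J, r) / jnp.dot(J, J)  # solve the 1x1 normal equations

print(float(k))  # approximately 2.0
```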
- __init__(vf: Any, problem_cls: Type, pos_params: list[str] | None = None, problem_kwargs: dict[str, Any] | None = None, backend: str = 'jax') → None
- backend: str = 'jax'
- forward_solve(parameters, fieldname, saveat, dt0=0.1, verbose=True)
- pos_params: list[str] | None = None
- problem_cls: Type
- problem_kwargs: dict[str, Any] | None = None
- residuals(parameters, y0s__values__saveat, adjoint=diffrax.ForwardMode)
Calculate residuals between measured and simulated states.
- Parameters:
parameters (dict) – Current estimate of the model parameters.
y0s__values__saveat (tuple) – Tuple (y0s, values, saveat), where y0s contains the initial states for each sequence, values contains the observed states and saveat specifies the time points of these observations.
adjoint – Differentiation mode for solve().
- Returns:
Array of residuals with shape matching values.
- Return type:
jax.Array
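The shape contract of `residuals()` can be illustrated with a vectorised sketch: simulate every sequence from its own initial state with `jax.vmap`, then subtract the observations. `toy_solve` (an exact exponential-decay model), the parameter name `"D"`, the use of a plain array of times in place of a `diffrax.SaveAt`, and the sign convention (simulated minus observed) are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def toy_solve(parameters, y0, ts):
    """Hypothetical stand-in for solve(): exact solution of dy/dt = -D * y."""
    # ts has shape (T,), y0 has shape (Nx,); the result has shape (T, Nx)
    return y0[None, :] * jnp.exp(-parameters["D"] * ts[:, None])


def toy_residuals(parameters, y0s__values__ts):
    """Simulate every sequence from its initial state and subtract the observations."""
    y0s, values, ts = y0s__values__ts
    # vmap over the leading (sequence) axis of y0s; parameters and ts are shared
    simulated = jax.vmap(toy_solve, in_axes=(None, 0, None))(parameters, y0s, ts)
    return simulated - values  # same shape as values: (n_seq, T, Nx)


key = jax.random.PRNGKey(0)
y0s = jax.random.uniform(key, (3, 16))  # 3 sequences, 16 grid points each
ts = jnp.linspace(0.1, 1.0, 5)          # 5 observation times per sequence
values = jax.vmap(toy_solve, in_axes=(None, 0, None))({"D": 0.7}, y0s, ts)

r = toy_residuals({"D": 0.5}, (y0s, values, ts))
print(r.shape)  # (3, 5, 16)
```

At the parameters that generated the data the residuals vanish, which is a convenient check that the simulation and the observations are aligned in time.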
- solve(parameters, y0, saveat, adjoint=diffrax.ForwardMode, dt0=0.1)
Integrate the Cahn–Hilliard equation for a given parameter set.
- Parameters:
parameters (dict) – Dictionary containing the material parameters to solve with.
y0 (array-like) – Initial concentration field.
saveat (diffrax.SaveAt) – Time points at which the solution should be stored.
adjoint – Differentiation mode used by diffrax.diffeqsolve().
dt0 (float) – Initial step size for the time integrator.
- Returns:
Array of saved concentration fields with shape (len(saveat.ts), Nx, Ny, Nz).
- Return type:
jax.Array
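To show what integrating the Cahn–Hilliard equation involves, here is a minimal sketch of one common discretisation, a semi-implicit Fourier spectral scheme on a periodic 3D grid, in plain JAX. It is not the evoxels solver: `ch_solve` and the parameter names `"M"` (mobility) and `"kappa"` (gradient-energy coefficient) are hypothetical, the free energy f(c) = (c² - 1)²/4 is assumed, the step size is fixed rather than adaptive, and the save times are assumed to be integer multiples of `dt0`. Each step treats the stiff fourth-order term implicitly and the bulk term explicitly, ĉ ← (ĉ - Δt M |k|² F[c³ - c]) / (1 + Δt M κ |k|⁴), and the output follows the documented layout (len(ts), Nx, Ny, Nz).

```python
import jax
import jax.numpy as jnp


def ch_solve(parameters, c0, ts, dt0=0.1, dx=1.0):
    """Hypothetical semi-implicit spectral Cahn-Hilliard solve on a periodic 3D grid.

    Assumes f'(c) = c**3 - c and that every entry of ts is a multiple of dt0.
    """
    M, kappa = parameters["M"], parameters["kappa"]
    # Squared wavenumbers |k|^2 for a periodic grid with spacing dx
    freqs = [2 * jnp.pi * jnp.fft.fftfreq(n, d=dx) for n in c0.shape]
    kx, ky, kz = jnp.meshgrid(*freqs, indexing="ij")
    k2 = kx**2 + ky**2 + kz**2

    def step(_, c):
        c_hat = jnp.fft.fftn(c)
        rhs = c_hat - dt0 * M * k2 * jnp.fft.fftn(c**3 - c)  # explicit bulk term
        c_hat = rhs / (1.0 + dt0 * M * kappa * k2**2)        # implicit gradient term
        return jnp.real(jnp.fft.ifftn(c_hat))

    saved, c, n_done = [], c0, 0
    for t in ts:
        n_target = int(round(float(t) / dt0))  # total steps needed to reach time t
        c = jax.lax.fori_loop(0, n_target - n_done, step, c)
        n_done = n_target
        saved.append(c)
    return jnp.stack(saved)  # shape (len(ts), Nx, Ny, Nz)


key = jax.random.PRNGKey(0)
c0 = 0.1 * jax.random.normal(key, (8, 8, 8))
ys = ch_solve({"M": 1.0, "kappa": 1.0}, c0, ts=[0.5, 1.0, 2.0], dt0=0.1)
print(ys.shape)  # (3, 8, 8, 8)
```

Because the zero Fourier mode is never modified, the scheme conserves the mean concentration, which is a quick sanity check for any Cahn–Hilliard solver.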
- train(initial_parameters, data, inds, adjoint=diffrax.ForwardMode, rtol=1e-06, atol=1e-06, verbose=True, max_steps=1000)
Fit parameters so that the model matches observed data.
This method assembles the observed sequences into a format suitable for optimistix.least_squares() and then runs a Levenberg–Marquardt optimisation to minimise the residuals returned by residuals().
- Parameters:
initial_parameters (dict) – Initial guess for the parameters to be optimised.
data (dict) – Dictionary containing "ts" (time stamps) and "ys" (concentration fields) as produced by solve().
inds (list[list[int]]) – For each sequence, the indices in data that should be used for training. All sequences must have the same spacing.
adjoint – Differentiation mode used when evaluating the residuals.
rtol (float) – Relative tolerance for the optimiser.
atol (float) – Absolute tolerance for the optimiser.
verbose (bool) – If True, prints optimisation progress.
max_steps (int) – Maximum number of optimisation steps.
- Returns:
The optimiser state after termination.
- Return type:
optimistix.State
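The assembly and optimisation steps of `train()` can be sketched in a few lines of pure JAX. Everything here is an assumption for illustration: `toy_solve`, `assemble`, `residual_fn` and `levenberg_marquardt` are hypothetical names, `data["ys"]` is taken to have time along its first axis, and "same spacing" is read as "identical observation times relative to each sequence's first index", which is what lets one set of save times serve every sequence. The hand-written damped Gauss–Newton loop stands in for `optimistix.least_squares()` with a Levenberg–Marquardt solver and, unlike `train()`, returns the fitted parameter rather than an optimiser state.

```python
import jax
import jax.numpy as jnp


def toy_solve(D, y0, ts):
    """Hypothetical model: exact solution of dy/dt = -D * y, shape (len(ts), Nx)."""
    return y0[None, :] * jnp.exp(-D * ts[:, None])


def assemble(data, inds):
    """Turn index lists into (initial states, observed targets, relative save times)."""
    ys, ts = data["ys"], data["ts"]
    y0s = jnp.stack([ys[seq[0]] for seq in inds])
    values = jnp.stack([ys[jnp.array(seq[1:])] for seq in inds])
    # Equal spacing is what allows all sequences to share one set of relative times
    rel_ts = ts[jnp.array(inds[0][1:])] - ts[inds[0][0]]
    return y0s, values, rel_ts


def residual_fn(D, y0s, values, rel_ts):
    simulated = jax.vmap(toy_solve, in_axes=(None, 0, None))(D, y0s, rel_ts)
    return (simulated - values).ravel()


def levenberg_marquardt(D, args, lam=1e-2, max_steps=50, tol=1e-6):
    """Minimal Levenberg-Marquardt loop for a single scalar parameter."""
    for _ in range(max_steps):
        r = residual_fn(D, *args)
        J = jax.jacfwd(residual_fn)(D, *args)  # forward-mode Jacobian, shape (n_residuals,)
        # Damped normal equations: (J^T J + lam) * step = -J^T r
        step = -jnp.dot(J, r) / (jnp.dot(J, J) + lam)
        if jnp.sum(residual_fn(D + step, *args) ** 2) < jnp.sum(r**2):
            D, lam = D + step, lam * 0.5  # accepted: move towards Gauss-Newton
        else:
            lam = lam * 10.0  # rejected: move towards gradient descent
        if jnp.abs(step) < tol:
            break
    return D


# Synthetic data: one field observed at 11 equally spaced times with D = 0.7
ts_full = jnp.linspace(0.0, 2.0, 11)
field0 = jax.random.uniform(jax.random.PRNGKey(1), (16,))
data = {"ts": ts_full, "ys": toy_solve(0.7, field0, ts_full)}
inds = [[0, 2, 4], [4, 6, 8]]  # two sequences with identical spacing

args = assemble(data, inds)
D_fit = levenberg_marquardt(jnp.asarray(0.2, dtype=jnp.float32), args)
print(float(D_fit))  # close to 0.7
```

The damping factor `lam` is halved after an accepted step and multiplied by ten after a rejected one; the `rtol`, `atol` and `max_steps` arguments of `train()` serve a roughly analogous stopping role to `tol` and `max_steps` here.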
- vf: Any