evoxels.inversion

Classes

InversionModel(vf, problem_cls[, ...]) – Inverse modeling using JAX and diffrax.

class evoxels.inversion.InversionModel(vf: Any, problem_cls: Type, pos_params: list[str] | None = None, problem_kwargs: dict[str, Any] | None = None, backend: str = 'jax')

Inverse modeling using JAX and diffrax.

This small helper class wraps the differentiable solver implementation and provides utilities to fit material parameters via gradient-based optimisation. It is intentionally lightweight so that new users can easily follow the individual steps: solving the PDE, computing residuals and running a least-squares optimiser.
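The three steps named above (forward solve, residuals, least-squares fit) can be illustrated on a deliberately tiny problem. The sketch below is not the evoxels implementation. It assumes a 1D periodic diffusion equation with a single unknown coefficient D, integrates it with explicit Euler in NumPy instead of diffrax, and replaces Levenberg–Marquardt with a plain Gauss–Newton update using a finite-difference derivative. All function names are hypothetical.

```python
import numpy as np


def toy_forward_solve(D, c0, n_steps, dt=1e-3, dx=0.1):
    """Hypothetical toy forward model: explicit Euler for dc/dt = D * d2c/dx2
    on a periodic 1D grid (stable while D * dt / dx**2 <= 0.5)."""
    c = np.array(c0, dtype=float)
    for _ in range(n_steps):
        lap = (np.roll(c, 1) - 2.0 * c + np.roll(c, -1)) / dx**2
        c = c + dt * D * lap
    return c


def toy_residuals(D, c0, observed, n_steps):
    """Simulated minus observed final state (assumed sign convention)."""
    return toy_forward_solve(D, c0, n_steps) - observed


def toy_fit_D(D_init, c0, observed, n_steps, iters=20, eps=1e-6):
    """Gauss-Newton on one scalar parameter (stand-in for Levenberg-Marquardt)."""
    D = D_init
    for _ in range(iters):
        r = toy_residuals(D, c0, observed, n_steps)
        # Forward-difference approximation of dr/dD (one entry per grid cell)
        J = (toy_residuals(D + eps, c0, observed, n_steps) - r) / eps
        # Scalar normal equation: (J . J) * step = -(J . r)
        D = D - (J @ r) / (J @ J)
    return D


# Usage: synthetic observation generated with D = 1.0, recovered from D = 0.5
c0 = 1.0 + 0.5 * np.sin(2 * np.pi * np.arange(64) / 64)
observed = toy_forward_solve(1.0, c0, n_steps=200)
D_fit = toy_fit_D(0.5, c0, observed, n_steps=200)
```

Because the synthetic observation comes from the same forward model, the residual vanishes at D = 1.0 and the iteration recovers it; with measured data the residual at the optimum is generally nonzero.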

__init__(vf: Any, problem_cls: Type, pos_params: list[str] | None = None, problem_kwargs: dict[str, Any] | None = None, backend: str = 'jax') → None
backend: str = 'jax'
forward_solve(parameters, fieldname, saveat, dt0=0.1, verbose=True)
pos_params: list[str] | None = None
problem_cls: Type
problem_kwargs: dict[str, Any] | None = None
residuals(parameters, y0s__values__saveat, adjoint=diffrax.ForwardMode)

Calculate residuals between measured and simulated states.

Parameters:
  • parameters (dict) – Current estimate of the model parameters.

  • y0s__values__saveat (tuple) – A tuple (y0s, values, saveat), where y0s contains the initial states for each sequence, values contains the observed states, and saveat specifies the time points of these observations.

  • adjoint – Differentiation mode passed to solve().

Returns:

Array of residuals with shape matching values.

Return type:

jax.Array
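The shape contract above (one residual per observed value) can be sketched as follows. This assumes residuals are "simulated minus observed" and uses a hypothetical closed-form forward model, toy_solve, in place of solve(); the evoxels code may order or weight the terms differently.

```python
import numpy as np


def toy_solve(k, y0, ts):
    """Hypothetical stand-in for solve(): exact solution of dy/dt = -k * y.

    Returns shape (len(ts), *y0.shape), one saved field per requested time.
    """
    ts = np.asarray(ts, dtype=float)
    decay = np.exp(-k * ts).reshape((-1,) + (1,) * y0.ndim)
    return decay * y0[None, ...]


def sketch_residuals(k, y0s__values__saveat):
    """Assumed convention: simulated minus observed, same shape as values."""
    y0s, values, ts = y0s__values__saveat
    # One forward solve per sequence, stacked along a leading sequence axis
    sims = np.stack([toy_solve(k, y0, ts) for y0 in y0s])
    return sims - values


# Usage: two sequences of 4x4x4 fields observed at three times
rng = np.random.default_rng(0)
y0s = rng.random((2, 4, 4, 4))
ts = [0.0, 0.5, 1.0]
values = np.stack([toy_solve(0.7, y0, ts) for y0 in y0s])
res = sketch_residuals(0.7, (y0s, values, ts))
```

In JAX, the Python loop over sequences would typically be replaced by jax.vmap over the leading axis of y0s.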

solve(parameters, y0, saveat, adjoint=diffrax.ForwardMode, dt0=0.1)

Integrate the Cahn–Hilliard equation for a given parameter set.

Parameters:
  • parameters (dict) – Dictionary containing the material parameters to solve with.

  • y0 (array-like) – Initial concentration field.

  • saveat (diffrax.SaveAt) – Time points at which the solution should be stored.

  • adjoint – Differentiation mode used by diffrax.diffeqsolve().

  • dt0 (float) – Initial step size for the time integrator.

Returns:

Array of saved concentration fields with shape (len(saveat.ts), Nx, Ny, Nz).

Return type:

jax.Array
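For readers unfamiliar with the equation, the sketch below integrates a dimensionless Cahn–Hilliard model, ∂c/∂t = M ∇²(c³ − c − κ∇²c), on a periodic 3D grid and returns the same (len(ts), Nx, Ny, Nz) layout as solve(). It is a stand-in built on assumptions: the double-well free energy, the parameter names "M" and "kappa", and the fixed-step semi-implicit Fourier scheme are chosen for illustration and are not taken from evoxels, which integrates with diffrax starting from the step size dt0.

```python
import numpy as np


def solve_cahn_hilliard(params, y0, save_ts, dt=0.01, dx=1.0):
    """Hypothetical sketch: semi-implicit Fourier-spectral integration of
    dc/dt = M * lap(c**3 - c - kappa * lap(c)) on a periodic grid.

    The stiff fourth-order term is treated implicitly, the nonlinear term
    explicitly. save_ts must be sorted, non-negative multiples of dt.
    Returns an array of shape (len(save_ts), *y0.shape).
    """
    M, kappa = params["M"], params["kappa"]
    # Squared wavenumber |k|^2 on the FFT grid
    ks = [2.0 * np.pi * np.fft.fftfreq(n, d=dx) for n in y0.shape]
    k2 = sum(K**2 for K in np.meshgrid(*ks, indexing="ij"))
    denom = 1.0 + dt * M * kappa * k2**2

    c = np.array(y0, dtype=float)
    t, saved = 0.0, []
    for t_save in save_ts:
        n_steps = int(round((t_save - t) / dt))
        for _ in range(n_steps):
            f_hat = np.fft.fftn(c**3 - c)  # bulk chemical potential, explicit
            c_hat = (np.fft.fftn(c) - dt * M * k2 * f_hat) / denom
            c = np.real(np.fft.ifftn(c_hat))
        t += n_steps * dt
        saved.append(c.copy())
    return np.stack(saved)


# Usage: small random perturbation around c = 0 on an 8x8x8 grid
rng = np.random.default_rng(1)
y0 = 0.05 * rng.standard_normal((8, 8, 8))
ys = solve_cahn_hilliard({"M": 1.0, "kappa": 1.0}, y0, save_ts=[0.0, 0.5, 1.0])
```

The k = 0 Fourier mode is never modified, so the scheme conserves the mean concentration exactly up to round-off, as the Cahn–Hilliard equation requires.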

train(initial_parameters, data, inds, adjoint=diffrax.ForwardMode, rtol=1e-06, atol=1e-06, verbose=True, max_steps=1000)

Fit parameters so that the model matches observed data.

This method assembles the observed sequences into a format suitable for optimistix.least_squares() and then runs a Levenberg–Marquardt optimisation to minimise the residuals returned by residuals().
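A compact, self-contained Levenberg–Marquardt loop for a generic residual function is sketched below. It is a textbook variant (Marquardt's diagonal scaling, multiplicative damping updates, forward-difference Jacobian) written in NumPy, not the optimistix implementation; the function name and default values are illustrative assumptions. In the documented class, derivatives would instead come from automatic differentiation in the mode selected by adjoint.

```python
import numpy as np


def levenberg_marquardt(residual_fn, p0, max_steps=200, lam=1e-3, eps=1e-7, tol=1e-12):
    """Hypothetical textbook Levenberg-Marquardt (not optimistix).

    Minimises sum(residual_fn(p)**2) using a forward-difference Jacobian.
    """
    p = np.asarray(p0, dtype=float)
    r = residual_fn(p)
    cost = r @ r
    for _ in range(max_steps):
        # Jacobian dr/dp, one column per parameter
        J = np.empty((r.size, p.size))
        for j in range(p.size):
            dp = np.zeros_like(p)
            dp[j] = eps
            J[:, j] = (residual_fn(p + dp) - r) / eps
        A, g = J.T @ J, J.T @ r
        # Damped normal equations with Marquardt's diagonal scaling
        step = np.linalg.solve(A + lam * np.diag(np.diag(A)), -g)
        r_new = residual_fn(p + step)
        cost_new = r_new @ r_new
        if cost_new < cost:  # accept: shift towards Gauss-Newton
            p, r, cost = p + step, r_new, cost_new
            lam *= 0.3
            if np.linalg.norm(step) <= tol * (1.0 + np.linalg.norm(p)):
                break
        else:  # reject: shift towards scaled gradient descent
            lam *= 10.0
    return p


# Usage: fit y = a * exp(-b * t) to noise-free data with a = 2.0, b = 0.7
t = np.linspace(0.0, 4.0, 30)
y_obs = 2.0 * np.exp(-0.7 * t)
p_fit = levenberg_marquardt(lambda p: p[0] * np.exp(-p[1] * t) - y_obs, [1.0, 0.3])
```

Small damping makes each step close to Gauss–Newton; a rejected step raises the damping and shortens the step, which is what makes the method robust far from the optimum.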

Parameters:
  • initial_parameters (dict) – Initial guess for the parameters to be optimised.

  • data (dict) – Dictionary containing "ts" (time stamps) and "ys" (concentration fields) as produced by solve().

  • inds (list[list[int]]) – For each sequence, the indices in data that should be used for training. All sequences must have the same spacing.

  • adjoint – Differentiation mode used when evaluating the residuals.

  • rtol (float) – Relative tolerance for the optimiser.

  • atol (float) – Absolute tolerance for the optimiser.

  • verbose (bool) – If True, prints optimisation progress.

  • max_steps (int) – Maximum number of optimisation steps.

Returns:

The optimiser state after termination.

Return type:

optimistix.State
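One plausible reading of the inds contract is sketched below: the first index of each sequence supplies the initial state, all listed indices supply observations, and time offsets relative to each sequence's start must agree so that a single save grid can be shared. This interpretation and the helper name assemble_sequences are assumptions, not the evoxels code.

```python
import numpy as np


def assemble_sequences(data, inds):
    """Hypothetical helper: turn data["ts"], data["ys"] and inds into
    (y0s, values, ts) with one shared, zero-based time grid."""
    ts = np.asarray(data["ts"], dtype=float)
    ys = np.asarray(data["ys"])
    # Observation times relative to the start of each sequence
    rel = [ts[list(seq)] - ts[seq[0]] for seq in inds]
    for r in rel[1:]:
        if r.shape != rel[0].shape or not np.allclose(r, rel[0]):
            raise ValueError("all sequences must share the same time spacing")
    y0s = np.stack([ys[seq[0]] for seq in inds])        # (n_seq, *field)
    values = np.stack([ys[list(seq)] for seq in inds])  # (n_seq, n_obs, *field)
    return y0s, values, rel[0]


# Usage: ten snapshots of a 2x2x2 field, two sequences with stride 2
data = {"ts": np.arange(10) * 0.5,
        "ys": np.arange(10 * 8, dtype=float).reshape(10, 2, 2, 2)}
y0s, values, save_ts = assemble_sequences(data, [[0, 2, 4], [3, 5, 7]])
```

With diffrax, the returned time grid would then be wrapped as diffrax.SaveAt(ts=save_ts) before solving.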

vf: Any