

Trait OdeFunc 

pub trait OdeFunc: Send + Sync {
    // Required method
    fn call(&self, t: f64, y: &[f64], params: &[f64]) -> Vec<f64>;

    // Provided method
    fn vjp(
        &self,
        t: f64,
        y: &[f64],
        params: &[f64],
        grad_output: &[f64],
    ) -> (Vec<f64>, f64, Vec<f64>) { ... }
}

ODE right-hand side: dy/dt = f(t, y, params).

Implement this trait to define the dynamics of a Neural ODE layer.
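A minimal implementor might look like the sketch below. It re-declares a local copy of the trait with only the required method so the snippet compiles on its own. The real trait also carries the provided vjp method. The TanhLayer type and its parameter layout (W as an n x n row-major matrix followed by the bias b) are hypothetical choices for illustration, not part of this API.

```rust
// Local stand-in for the trait on this page; only the required method is
// reproduced here (the provided `vjp` method is omitted for brevity).
pub trait OdeFunc: Send + Sync {
    fn call(&self, t: f64, y: &[f64], params: &[f64]) -> Vec<f64>;
}

/// Hypothetical Neural ODE dynamics f(y) = tanh(W y + b) for an n-dimensional
/// state. `params` is assumed to hold W (n x n, row-major) followed by b (n).
pub struct TanhLayer {
    pub dim: usize,
}

impl OdeFunc for TanhLayer {
    fn call(&self, _t: f64, y: &[f64], params: &[f64]) -> Vec<f64> {
        let n = self.dim;
        assert_eq!(y.len(), n);
        assert_eq!(params.len(), n * n + n);
        // Split the flat parameter vector into the weight matrix and the bias.
        let (w, b) = params.split_at(n * n);
        (0..n)
            .map(|i| {
                let row = &w[i * n..(i + 1) * n];
                // z_i = sum_j W_ij * y_j + b_i
                let z: f64 = row.iter().zip(y).map(|(wij, yj)| wij * yj).sum::<f64>() + b[i];
                z.tanh()
            })
            .collect()
    }
}

fn main() {
    let layer = TanhLayer { dim: 2 };
    // W = identity and b = 0, so f(y) = tanh(y) componentwise.
    let params = [1.0, 0.0, 0.0, 1.0, 0.0, 0.0];
    let dydt = layer.call(0.0, &[0.5, -0.5], &params);
    println!("{:?}", dydt);
}
```

The RHS ignores t here (an autonomous system); time-dependent dynamics would read it in the same way as y.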

Required Methods


fn call(&self, t: f64, y: &[f64], params: &[f64]) -> Vec<f64>

Evaluates the ODE right-hand side f(t, y, params) at time t, state y, and parameters params, returning dy/dt as a vector with one entry per component of y.
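To show how a solver consumes call, here is a sketch of a fixed-step explicit Euler integrator. The euler_integrate function and the LinearDecay dynamics (dy_i/dt = -k * y_i with k = params[0]) are hypothetical and not part of this crate; a real Neural ODE solver would typically use a higher-order or adaptive scheme.

```rust
// Local stand-in for the trait on this page (required method only).
pub trait OdeFunc: Send + Sync {
    fn call(&self, t: f64, y: &[f64], params: &[f64]) -> Vec<f64>;
}

/// Hypothetical dynamics: dy_i/dt = -k * y_i with k = params[0].
pub struct LinearDecay;

impl OdeFunc for LinearDecay {
    fn call(&self, _t: f64, y: &[f64], params: &[f64]) -> Vec<f64> {
        y.iter().map(|&yi| -params[0] * yi).collect()
    }
}

/// Hypothetical explicit Euler solver: y_{n+1} = y_n + dt * f(t_n, y_n, params).
pub fn euler_integrate(
    f: &dyn OdeFunc,
    y0: &[f64],
    params: &[f64],
    t0: f64,
    t1: f64,
    steps: usize,
) -> Vec<f64> {
    let dt = (t1 - t0) / steps as f64;
    let mut y = y0.to_vec();
    for n in 0..steps {
        let t = t0 + n as f64 * dt;
        let dydt = f.call(t, &y, params);
        // The RHS must return one derivative per state component.
        assert_eq!(dydt.len(), y.len());
        for (yi, di) in y.iter_mut().zip(&dydt) {
            *yi += dt * di;
        }
    }
    y
}

fn main() {
    let y1 = euler_integrate(&LinearDecay, &[1.0], &[1.0], 0.0, 1.0, 1000);
    // The exact solution at t = 1 is exp(-1); Euler with 1000 steps lands close to it.
    println!("{:.4} vs {:.4}", y1[0], (-1.0f64).exp());
}
```

Passing the RHS as a trait object keeps the solver independent of any particular dynamics.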

Provided Methods


fn vjp(&self, t: f64, y: &[f64], params: &[f64], grad_output: &[f64]) -> (Vec<f64>, f64, Vec<f64>)

Vector-Jacobian product (VJP) for the adjoint method.

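Given grad_output = dL/df (the gradient of the loss with respect to the RHS output, one entry per state component), returns the tuple (dL/dy, dL/dt, dL/dparams): a vector of length y.len(), a scalar, and a vector of length params.len().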
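Written out with a = grad_output and θ = params (symbols introduced here for illustration), each returned component is a vector-Jacobian product aᵀJ that contracts the loss gradient with the Jacobian of f with respect to one input. Because only the product is needed, an analytic implementation need not build the full Jacobian.

```latex
a = \frac{\partial L}{\partial f}, \qquad
\frac{\partial L}{\partial y_j} = \sum_i a_i \frac{\partial f_i}{\partial y_j}, \qquad
\frac{\partial L}{\partial t} = \sum_i a_i \frac{\partial f_i}{\partial t}, \qquad
\frac{\partial L}{\partial \theta_k} = \sum_i a_i \frac{\partial f_i}{\partial \theta_k}
```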

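The default implementation approximates these products with finite differences. This needs at least one extra evaluation of call per component of y and params (plus one for t), and the result carries step-size truncation and round-off error. Override it with an analytic VJP for faster and exact (up to floating-point rounding) gradients.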
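The sketch below shows an analytic override and checks it against a finite-difference estimate. Assumptions: the local trait copy wires its default vjp to a hypothetical central-difference helper, fd_vjp, with step h = 1e-6. This illustrates the technique only and is not the crate's actual default. The LinearDecay dynamics (f_i = -k * y_i, k = params[0]) are likewise hypothetical; its analytic VJP is dL/dy_j = -k * a_j, dL/dt = 0, and dL/dk = -sum_i a_i * y_i.

```rust
// Local stand-in for the trait on this page. The default `vjp` here delegates
// to a hypothetical central-difference helper, NOT the crate's own default.
pub trait OdeFunc: Send + Sync {
    fn call(&self, t: f64, y: &[f64], params: &[f64]) -> Vec<f64>;

    fn vjp(&self, t: f64, y: &[f64], params: &[f64], grad_output: &[f64]) -> (Vec<f64>, f64, Vec<f64>) {
        fd_vjp(self, t, y, params, grad_output)
    }
}

fn dot(u: &[f64], v: &[f64]) -> f64 {
    u.iter().zip(v).map(|(p, q)| p * q).sum()
}

/// Hypothetical central-difference VJP: for each input coordinate x_j,
/// (dL/dx)_j ~= a . (f(x + h e_j) - f(x - h e_j)) / (2h).
pub fn fd_vjp<F: OdeFunc + ?Sized>(
    f: &F, t: f64, y: &[f64], params: &[f64], a: &[f64],
) -> (Vec<f64>, f64, Vec<f64>) {
    let h = 1e-6;
    // Central-difference column of the Jacobian, contracted with a.
    let contract = |plus: Vec<f64>, minus: Vec<f64>| -> f64 {
        let diff: Vec<f64> = plus.iter().zip(&minus).map(|(p, m)| (p - m) / (2.0 * h)).collect();
        dot(a, &diff)
    };
    let mut gy = vec![0.0; y.len()];
    for j in 0..y.len() {
        let (mut yp, mut ym) = (y.to_vec(), y.to_vec());
        yp[j] += h;
        ym[j] -= h;
        gy[j] = contract(f.call(t, &yp, params), f.call(t, &ym, params));
    }
    let gt = contract(f.call(t + h, y, params), f.call(t - h, y, params));
    let mut gp = vec![0.0; params.len()];
    for j in 0..params.len() {
        let (mut pp, mut pm) = (params.to_vec(), params.to_vec());
        pp[j] += h;
        pm[j] -= h;
        gp[j] = contract(f.call(t, y, &pp), f.call(t, y, &pm));
    }
    (gy, gt, gp)
}

/// Hypothetical dynamics: f_i = -k * y_i with k = params[0].
pub struct LinearDecay;

impl OdeFunc for LinearDecay {
    fn call(&self, _t: f64, y: &[f64], params: &[f64]) -> Vec<f64> {
        y.iter().map(|&yi| -params[0] * yi).collect()
    }

    // Analytic VJP: no extra calls to `call` and no step-size error.
    fn vjp(&self, _t: f64, y: &[f64], params: &[f64], a: &[f64]) -> (Vec<f64>, f64, Vec<f64>) {
        let k = params[0];
        let dl_dy: Vec<f64> = a.iter().map(|&ai| -k * ai).collect();
        let dl_dt = 0.0; // f has no explicit dependence on t
        let dl_dk = -a.iter().zip(y).map(|(ai, yi)| ai * yi).sum::<f64>();
        (dl_dy, dl_dt, vec![dl_dk])
    }
}

fn main() {
    let (y, p, a) = ([1.0, -2.0], [0.5], [0.3, 0.7]);
    let analytic = LinearDecay.vjp(0.0, &y, &p, &a);
    let numeric = fd_vjp(&LinearDecay, 0.0, &y, &p, &a);
    println!("analytic: {:?}\nnumeric:  {:?}", analytic, numeric);
}
```

Comparing an override against a finite-difference estimate like this is a cheap way to catch sign or indexing mistakes in a hand-written VJP.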

Implementors