Module neural_ode

Neural ODE (Neural Ordinary Differential Equations) implementation.

Provides an adaptive Dormand-Prince RK45 solver, a fixed-step RK4 solver, and the adjoint sensitivity method for memory-efficient gradient computation through continuous dynamics.

§Overview

Neural ODEs replace discrete layer stacks with a continuous ODE:

  dy/dt = f(t, y, θ),   y(t0) = y0

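The output is y(t1), obtained by numerically integrating the ODE from t0 to t1. Gradients are computed with the adjoint sensitivity method, which solves a second ODE backward in time from t1 to t0 instead of storing every intermediate state of the forward pass, so memory use does not grow with the number of solver steps.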
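For reference, the adjoint method in its standard form from the Neural ODE literature is written out below for a scalar loss L(y(t1)). The adjoint state a(t) starts at a(t1) = ∂L/∂y(t1) and is integrated backward together with y(t), which is recovered by running the original ODE backward from y(t1). The two vector-Jacobian products a^T ∂f/∂y and a^T ∂f/∂θ appear to be what `OdeFunc::vjp` supplies; this is inferred from the example, where for f = params ⊙ y the method returns grad ⊙ params, 0 and grad ⊙ y, not from documented behaviour.

```latex
\begin{aligned}
a(t) &= \frac{\partial L}{\partial y(t)}, &
\frac{da}{dt} &= -a(t)^{\top}\,\frac{\partial f}{\partial y}\big(t, y(t), \theta\big), \\
\frac{\partial L}{\partial y_0} &= a(t_0), &
\frac{dL}{d\theta} &= -\int_{t_1}^{t_0} a(t)^{\top}\,\frac{\partial f}{\partial \theta}\big(t, y(t), \theta\big)\,dt.
\end{aligned}
```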

§Example

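use tensorlogic_train::neural_ode::{NeuralOde, OdeFunc};

// Element-wise linear dynamics: dy_i/dt = p_i * y_i, so y_i(t) = y_i(0) * exp(p_i * t).
struct LinearOde;
impl OdeFunc for LinearOde {
    fn call(&self, _t: f64, y: &[f64], params: &[f64]) -> Vec<f64> {
        y.iter().zip(params.iter()).map(|(yi, pi)| yi * pi).collect()
    }
    // Vector-Jacobian products of `grad` with f, returned in the order
    // (grad^T df/dy, grad^T df/dt, grad^T df/dparams).
    fn vjp(&self, _t: f64, y: &[f64], params: &[f64], grad: &[f64])
        -> (Vec<f64>, f64, Vec<f64>)
    {
        // df/dy = diag(params)
        let dy = grad.iter().zip(params.iter()).map(|(g, p)| g * p).collect();
        // f has no explicit dependence on t
        let dt = 0.0_f64;
        // df/dparams = diag(y)
        let dp = grad.iter().zip(y.iter()).map(|(g, yi)| g * yi).collect();
        (dy, dt, dp)
    }
}

// Integrate from t0 = 0 to t1 = 1 with y0 = [1.0] and params = [-1.0].
let ode = NeuralOde::new(LinearOde, 0.0, 1.0);
let sol = ode.forward(&[1.0], &[-1.0]).unwrap();
// The final state should match the analytic solution y(1) = exp(-1).
assert!((sol.states.last().unwrap()[0] - (-1.0_f64).exp()).abs() < 1e-3);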
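The example above only checks the forward pass. As a self-contained sketch of how the adjoint gradients can be computed for the same linear ODE, the code below uses a hand-rolled fixed-step RK4 for both passes: it keeps only y(t1) from the forward pass, then integrates the augmented state [y, a, g] backward from t1 to t0. All names here (`axpy`, `rk4_step`, `rk4_integrate`, `adjoint_gradients`, `linear_f`, `linear_vjp`, `sum_loss_grad`) are hypothetical and not part of `tensorlogic_train`, and the sketch's VJP returns only the y and parameter products, omitting the time component that `OdeFunc::vjp` also returns. It is not a description of how `NeuralOde` produces its `AdjointResult`.

```rust
/// Element-wise `base + s * k`.
fn axpy(base: &[f64], k: &[f64], s: f64) -> Vec<f64> {
    base.iter().zip(k).map(|(b, ki)| b + s * ki).collect()
}

/// One classic RK4 step for dz/dt = g(t, z); `h` may be negative (backward in time).
fn rk4_step<G: Fn(f64, &[f64]) -> Vec<f64>>(g: &G, t: f64, z: &[f64], h: f64) -> Vec<f64> {
    let k1 = g(t, z);
    let k2 = g(t + h / 2.0, &axpy(z, &k1, h / 2.0));
    let k3 = g(t + h / 2.0, &axpy(z, &k2, h / 2.0));
    let k4 = g(t + h, &axpy(z, &k3, h));
    (0..z.len())
        .map(|i| z[i] + h / 6.0 * (k1[i] + 2.0 * k2[i] + 2.0 * k3[i] + k4[i]))
        .collect()
}

/// Integrate dz/dt = g(t, z) from `ta` to `tb` with `n` fixed RK4 steps.
fn rk4_integrate<G: Fn(f64, &[f64]) -> Vec<f64>>(g: &G, ta: f64, tb: f64, z0: &[f64], n: usize) -> Vec<f64> {
    let h = (tb - ta) / n as f64;
    let mut z = z0.to_vec();
    for i in 0..n {
        z = rk4_step(g, ta + i as f64 * h, &z, h);
    }
    z
}

/// Adjoint sensitivities of a loss L(y(t1)) for dy/dt = f(t, y, p).
/// `vjp(t, y, p, a)` returns (a^T df/dy, a^T df/dp); `dl_dy1` maps y(t1) to dL/dy(t1).
/// Returns (y(t1), dL/dy0, dL/dp).
fn adjoint_gradients<F, V, LG>(
    f: &F, vjp: &V, dl_dy1: &LG, t0: f64, t1: f64, y0: &[f64], p: &[f64], n: usize,
) -> (Vec<f64>, Vec<f64>, Vec<f64>)
where
    F: Fn(f64, &[f64], &[f64]) -> Vec<f64>,
    V: Fn(f64, &[f64], &[f64], &[f64]) -> (Vec<f64>, Vec<f64>),
    LG: Fn(&[f64]) -> Vec<f64>,
{
    let d = y0.len();
    // Forward pass: only the final state is kept, not the trajectory.
    let y1 = rk4_integrate(&|t: f64, y: &[f64]| f(t, y, p), t0, t1, y0, n);
    // Augmented dynamics for z = [y, a, g]: [f, -a^T df/dy, -a^T df/dp].
    let aug = |t: f64, z: &[f64]| -> Vec<f64> {
        let (y, a) = (&z[..d], &z[d..2 * d]);
        let (a_fy, a_fp) = vjp(t, y, p, a);
        let mut dz = f(t, y, p);
        dz.extend(a_fy.iter().map(|v| -v));
        dz.extend(a_fp.iter().map(|v| -v));
        dz
    };
    // Start at t1 from [y(t1), dL/dy(t1), 0] and integrate back to t0.
    let mut z1 = y1.clone();
    z1.extend(dl_dy1(&y1));
    z1.resize(2 * d + p.len(), 0.0);
    let z0 = rk4_integrate(&aug, t1, t0, &z1, n);
    (y1, z0[d..2 * d].to_vec(), z0[2 * d..].to_vec())
}

/// The linear example: f(t, y, p) = p ⊙ y.
fn linear_f(_t: f64, y: &[f64], p: &[f64]) -> Vec<f64> {
    y.iter().zip(p).map(|(yi, pi)| yi * pi).collect()
}

/// Its VJPs: a^T df/dy = a ⊙ p and a^T df/dp = a ⊙ y.
fn linear_vjp(_t: f64, y: &[f64], p: &[f64], a: &[f64]) -> (Vec<f64>, Vec<f64>) {
    (
        a.iter().zip(p).map(|(ai, pi)| ai * pi).collect(),
        a.iter().zip(y).map(|(ai, yi)| ai * yi).collect(),
    )
}

/// Loss L = sum(y(t1)), so dL/dy(t1) is all ones.
fn sum_loss_grad(y1: &[f64]) -> Vec<f64> {
    vec![1.0; y1.len()]
}

fn main() {
    let (y1, dl_dy0, dl_dp) =
        adjoint_gradients(&linear_f, &linear_vjp, &sum_loss_grad, 0.0, 1.0, &[1.0], &[-1.0], 100);
    // Analytic values for y0 = 1, p = -1: y(1) = e^p, dL/dy0 = e^p, dL/dp = y0 * e^p.
    let e = (-1.0_f64).exp();
    println!("y(1) = {:.6}, dL/dy0 = {:.6}, dL/dp = {:.6}", y1[0], dl_dy0[0], dl_dp[0]);
    assert!((y1[0] - e).abs() < 1e-6);
    assert!((dl_dy0[0] - e).abs() < 1e-6);
    assert!((dl_dp[0] - e).abs() < 1e-6);
}
```

Only y(t1) and the backward augmented state are held in memory; the trajectory is reconstructed during the backward pass, which is the memory trade-off the overview describes.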

Structs§

AdaptiveSolution
Result of an adaptive Dormand-Prince RK45 integration.
AdjointResult
Gradient information produced by the adjoint sensitivity method.
NeuralOde
A Neural ODE layer that wraps an OdeFunc with fixed integration limits.
OdeSolution
Result of a fixed-step RK4 integration.
OdeSolverConfig
Configuration for the adaptive ODE solver.

Enums§

OdeError
Errors that can occur during ODE integration.

Traits§

OdeFunc
ODE right-hand side: dy/dt = f(t, y, params).
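Because adjoint gradients are only as good as the hand-written `vjp`, a finite-difference check can be useful when implementing `OdeFunc`. The sketch below is self-contained, so it declares a hypothetical local trait `OdeFuncLike` with the same method signatures as the `OdeFunc` implementation in the module example; the return order (grad^T ∂f/∂y, grad^T ∂f/∂t, grad^T ∂f/∂params) is inferred from that example. `DampedSine`, `probe`, `central_diff` and `max_vjp_error` are illustrative names, not crate APIs. The check relies on the fact that the gradient of the scalar ⟨grad, f(t, y, params)⟩ with respect to each input is exactly the corresponding VJP.

```rust
/// Hypothetical stand-in with the same signatures as `OdeFunc` in the module example.
trait OdeFuncLike {
    fn call(&self, t: f64, y: &[f64], params: &[f64]) -> Vec<f64>;
    fn vjp(&self, t: f64, y: &[f64], params: &[f64], grad: &[f64]) -> (Vec<f64>, f64, Vec<f64>);
}

/// Hypothetical time-dependent dynamics: f_i(t, y, p) = -p0 * sin(y_i) + p1 * t.
struct DampedSine;

impl OdeFuncLike for DampedSine {
    fn call(&self, t: f64, y: &[f64], p: &[f64]) -> Vec<f64> {
        y.iter().map(|yi| -p[0] * yi.sin() + p[1] * t).collect()
    }
    fn vjp(&self, t: f64, y: &[f64], p: &[f64], g: &[f64]) -> (Vec<f64>, f64, Vec<f64>) {
        let g_sum: f64 = g.iter().sum();
        // df_i/dy_i = -p0 * cos(y_i); the Jacobian is diagonal.
        let dy = g.iter().zip(y).map(|(gi, yi)| -gi * p[0] * yi.cos()).collect();
        // df_i/dt = p1 for every i.
        let dt = p[1] * g_sum;
        // df_i/dp0 = -sin(y_i) and df_i/dp1 = t.
        let dp0: f64 = g.iter().zip(y).map(|(gi, yi)| -gi * yi.sin()).sum();
        (dy, dt, vec![dp0, t * g_sum])
    }
}

/// Scalar probe s = <g, f(t, y, p)>; its gradient is exactly what `vjp` should return.
fn probe<F: OdeFuncLike>(f: &F, t: f64, y: &[f64], p: &[f64], g: &[f64]) -> f64 {
    f.call(t, y, p).iter().zip(g).map(|(fi, gi)| fi * gi).sum()
}

/// Central difference of `eval` along coordinate `i` of `v`.
fn central_diff(v: &[f64], i: usize, eps: f64, eval: &dyn Fn(&[f64]) -> f64) -> f64 {
    let (mut hi, mut lo) = (v.to_vec(), v.to_vec());
    hi[i] += eps;
    lo[i] -= eps;
    (eval(&hi) - eval(&lo)) / (2.0 * eps)
}

/// Largest absolute gap between `vjp` and finite differences of the probe in y, params and t.
fn max_vjp_error<F: OdeFuncLike>(f: &F, t: f64, y: &[f64], p: &[f64], g: &[f64], eps: f64) -> f64 {
    let (dy, dt, dp) = f.vjp(t, y, p, g);
    let mut worst: f64 = 0.0;
    for i in 0..y.len() {
        let est = central_diff(y, i, eps, &|yy: &[f64]| probe(f, t, yy, p, g));
        worst = worst.max((est - dy[i]).abs());
    }
    for j in 0..p.len() {
        let est = central_diff(p, j, eps, &|pp: &[f64]| probe(f, t, y, pp, g));
        worst = worst.max((est - dp[j]).abs());
    }
    let est_t = (probe(f, t + eps, y, p, g) - probe(f, t - eps, y, p, g)) / (2.0 * eps);
    worst.max((est_t - dt).abs())
}

fn main() {
    let err = max_vjp_error(&DampedSine, 0.3, &[0.5, -1.2, 2.0], &[0.7, 1.5], &[1.0, -0.5, 0.25], 1e-5);
    println!("max |vjp - finite difference| = {err:e}");
    assert!(err < 1e-6, "vjp disagrees with finite differences");
}
```

A time-dependent right-hand side is used on purpose: for the linear example the t component is always 0, so a mistake there would go unnoticed.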

Functions§

dopri5_solve
Integrate an ODE with the adaptive Dormand-Prince RK45 method (DOPRI5).
rk4_solve
Integrate an ODE using classic 4th-order Runge-Kutta with a fixed step size.
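The listing does not document the signature of `dopri5_solve` or its step-size policy. The following is a generic, self-contained sketch of adaptive DOPRI5 using the standard Dormand-Prince 5(4) coefficients: each step evaluates seven stages, advances with the 5th-order weights, estimates the local error from the difference to the embedded 4th-order solution, and accepts the step when the scaled RMS error is at most 1. The function name `dopri5`, its parameters (`rtol`, `atol`, `max_steps`), the naive initial step and the controller constants (safety 0.9, limits 0.2 and 5) are assumptions for illustration, not the fields of `OdeSolverConfig`.

```rust
// Dormand-Prince 5(4) coefficients: nodes C, stage matrix A (row s holds a_{s,j} for j < s),
// 5th-order weights B, and E = B - B_hat (difference from the embedded 4th-order weights).
const C: [f64; 7] = [0.0, 0.2, 0.3, 0.8, 8.0 / 9.0, 1.0, 1.0];
const A: [[f64; 6]; 7] = [
    [0.0; 6],
    [0.2, 0.0, 0.0, 0.0, 0.0, 0.0],
    [3.0 / 40.0, 9.0 / 40.0, 0.0, 0.0, 0.0, 0.0],
    [44.0 / 45.0, -56.0 / 15.0, 32.0 / 9.0, 0.0, 0.0, 0.0],
    [19372.0 / 6561.0, -25360.0 / 2187.0, 64448.0 / 6561.0, -212.0 / 729.0, 0.0, 0.0],
    [9017.0 / 3168.0, -355.0 / 33.0, 46732.0 / 5247.0, 49.0 / 176.0, -5103.0 / 18656.0, 0.0],
    [35.0 / 384.0, 0.0, 500.0 / 1113.0, 125.0 / 192.0, -2187.0 / 6784.0, 11.0 / 84.0],
];
const B: [f64; 7] = [35.0 / 384.0, 0.0, 500.0 / 1113.0, 125.0 / 192.0, -2187.0 / 6784.0, 11.0 / 84.0, 0.0];
const E: [f64; 7] = [
    71.0 / 57600.0, 0.0, -71.0 / 16695.0, 71.0 / 1920.0, -17253.0 / 339200.0, 22.0 / 525.0, -1.0 / 40.0,
];

/// Adaptive DOPRI5 for dy/dt = f(t, y) on [t0, t1] (t1 >= t0 assumed).
/// `max_steps` bounds attempted (accepted + rejected) steps. Returns (y(t1), accepted steps).
fn dopri5<F: Fn(f64, &[f64]) -> Vec<f64>>(
    f: F, t0: f64, t1: f64, y0: &[f64], rtol: f64, atol: f64, max_steps: usize,
) -> Result<(Vec<f64>, usize), String> {
    let n = y0.len();
    let (mut t, mut y, mut accepted) = (t0, y0.to_vec(), 0);
    if t >= t1 {
        return Ok((y, accepted));
    }
    let mut h = (t1 - t0) / 100.0; // naive first guess; real solvers estimate it from f
    for _ in 0..max_steps {
        let last = h >= t1 - t;
        if last {
            h = t1 - t; // land exactly on t1
        }
        // Seven stages; reusing the last stage on the next step (FSAL) is skipped for brevity.
        let mut k: Vec<Vec<f64>> = Vec::with_capacity(7);
        for s in 0..7 {
            let ys: Vec<f64> = (0..n)
                .map(|i| y[i] + h * (0..s).map(|j| A[s][j] * k[j][i]).sum::<f64>())
                .collect();
            k.push(f(t + C[s] * h, &ys));
        }
        let combine = |w: &[f64; 7], i: usize| h * (0..7).map(|s| w[s] * k[s][i]).sum::<f64>();
        let y_new: Vec<f64> = (0..n).map(|i| y[i] + combine(&B, i)).collect();
        // Scaled RMS norm of the local error estimate.
        let err = ((0..n)
            .map(|i| {
                let sc = atol + rtol * y[i].abs().max(y_new[i].abs());
                (combine(&E, i) / sc).powi(2)
            })
            .sum::<f64>()
            / n as f64)
            .sqrt();
        if err <= 1.0 {
            t = if last { t1 } else { t + h };
            y = y_new;
            accepted += 1;
            if last {
                return Ok((y, accepted));
            }
        }
        // Step-size update with safety factor 0.9, limited to shrinking 5x or growing 5x.
        h *= if err == 0.0 { 5.0 } else { (0.9 * err.powf(-0.2)).clamp(0.2, 5.0) };
    }
    Err(format!("max_steps = {max_steps} exhausted at t = {t}"))
}

fn main() {
    // dy/dt = -y with y(0) = 1, so y(1) = exp(-1).
    let decay = |_t: f64, y: &[f64]| -> Vec<f64> { y.iter().map(|v| -v).collect() };
    let (y1, steps) = dopri5(decay, 0.0, 1.0, &[1.0], 1e-9, 1e-12, 10_000).unwrap();
    println!("y(1) = {:.10} after {steps} accepted steps", y1[0]);
    assert!((y1[0] - (-1.0_f64).exp()).abs() < 1e-6);
}
```

Rejected steps are retried from the same t with a smaller h, so `max_steps` caps total work rather than the number of accepted steps.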