//! Neural ODE building blocks (scirs2_series/advanced_training_modules/neural_ode.rs).

use scirs2_core::ndarray::{Array1, Array2};
7use scirs2_core::numeric::{Float, FromPrimitive};
8use std::fmt::Debug;
9
10use crate::error::Result;
11
/// Neural Ordinary Differential Equation model: the state derivative is given
/// by a small one-hidden-layer MLP, and trajectories are produced by numerical
/// integration over a fixed time grid.
#[derive(Debug)]
pub struct NeuralODE<F: Float + Debug + scirs2_core::ndarray::ScalarOperand> {
    // Flattened network parameters stored as a single row (W1, b1, W2, b2 —
    // see `extract_ode_weights` for the exact layout).
    parameters: Array2<F>,
    // Time points at which the trajectory is evaluated; steps are derived
    // from consecutive differences, so the grid may be non-uniform.
    time_steps: Array1<F>,
    // Numerical integration settings (method, step size, tolerance).
    solver_config: ODESolverConfig<F>,
    // Dimensionality of the ODE state (input and output of the dynamics net).
    input_dim: usize,
    // Width of the hidden layer in the dynamics network.
    hidden_dim: usize,
}
25
/// Configuration for the numerical ODE solver used by [`NeuralODE`].
#[derive(Debug, Clone)]
pub struct ODESolverConfig<F: Float + Debug> {
    // Integration scheme applied at each time step.
    method: IntegrationMethod,
    // Fixed step size. Currently unused: `NeuralODE::forward` derives each
    // step from consecutive `time_steps` entries instead.
    #[allow(dead_code)]
    step_size: F,
    // Error tolerance intended for adaptive methods. Currently unused because
    // the RKF45 path falls back to fixed-step RK4 (see `rkf45_step`).
    #[allow(dead_code)]
    tolerance: F,
}
38
/// Numerical integration schemes supported by [`ODESolverConfig`].
#[derive(Debug, Clone)]
pub enum IntegrationMethod {
    /// Forward Euler: first-order, cheapest, least accurate.
    Euler,
    /// Classical fourth-order Runge-Kutta.
    RungeKutta4,
    /// Runge-Kutta-Fehlberg 4(5). NOTE: currently implemented as a
    /// fixed-step RK4 fallback (see `NeuralODE::rkf45_step`).
    RKF45,
}
49
50impl<F: Float + Debug + Clone + FromPrimitive + scirs2_core::ndarray::ScalarOperand> NeuralODE<F> {
51 pub fn new(
53 input_dim: usize,
54 hidden_dim: usize,
55 time_steps: Array1<F>,
56 solver_config: ODESolverConfig<F>,
57 ) -> Self {
58 let total_params = input_dim * hidden_dim + hidden_dim * input_dim + 2 * hidden_dim;
60 let scale = F::from(2.0).unwrap() / F::from(input_dim).unwrap();
61 let std_dev = scale.sqrt();
62
63 let mut parameters = Array2::zeros((1, total_params));
64 for i in 0..total_params {
65 let val = ((i * 23) % 1000) as f64 / 1000.0 - 0.5;
66 parameters[[0, i]] = F::from(val).unwrap() * std_dev;
67 }
68
69 Self {
70 parameters,
71 time_steps,
72 solver_config,
73 input_dim,
74 hidden_dim,
75 }
76 }
77
78 pub fn forward(&self, initial_state: &Array1<F>) -> Result<Array2<F>> {
80 let num_times = self.time_steps.len();
81 let mut trajectory = Array2::zeros((num_times, self.input_dim));
82
83 for i in 0..self.input_dim {
85 trajectory[[0, i]] = initial_state[i];
86 }
87
88 for t in 1..num_times {
90 let dt = self.time_steps[t] - self.time_steps[t - 1];
91 let current_state = trajectory.row(t - 1).to_owned();
92
93 let next_state = match self.solver_config.method {
94 IntegrationMethod::Euler => self.euler_step(¤t_state, dt)?,
95 IntegrationMethod::RungeKutta4 => self.rk4_step(¤t_state, dt)?,
96 IntegrationMethod::RKF45 => self.rkf45_step(¤t_state, dt)?,
97 };
98
99 for i in 0..self.input_dim {
100 trajectory[[t, i]] = next_state[i];
101 }
102 }
103
104 Ok(trajectory)
105 }
106
107 fn neural_network(&self, state: &Array1<F>) -> Result<Array1<F>> {
109 let (w1, b1, w2, b2) = self.extract_ode_weights();
110
111 let mut hidden = Array1::zeros(self.hidden_dim);
113 for i in 0..self.hidden_dim {
114 let mut sum = b1[i];
115 for j in 0..self.input_dim {
116 sum = sum + w1[[i, j]] * state[j];
117 }
118 hidden[i] = self.tanh(sum);
119 }
120
121 let mut output = Array1::zeros(self.input_dim);
123 for i in 0..self.input_dim {
124 let mut sum = b2[i];
125 for j in 0..self.hidden_dim {
126 sum = sum + w2[[i, j]] * hidden[j];
127 }
128 output[i] = sum;
129 }
130
131 Ok(output)
132 }
133
134 fn extract_ode_weights(&self) -> (Array2<F>, Array1<F>, Array2<F>, Array1<F>) {
136 let param_vec = self.parameters.row(0);
137 let mut idx = 0;
138
139 let mut w1 = Array2::zeros((self.hidden_dim, self.input_dim));
141 for i in 0..self.hidden_dim {
142 for j in 0..self.input_dim {
143 w1[[i, j]] = param_vec[idx];
144 idx += 1;
145 }
146 }
147
148 let mut b1 = Array1::zeros(self.hidden_dim);
150 for i in 0..self.hidden_dim {
151 b1[i] = param_vec[idx];
152 idx += 1;
153 }
154
155 let mut w2 = Array2::zeros((self.input_dim, self.hidden_dim));
157 for i in 0..self.input_dim {
158 for j in 0..self.hidden_dim {
159 w2[[i, j]] = param_vec[idx];
160 idx += 1;
161 }
162 }
163
164 let mut b2 = Array1::zeros(self.input_dim);
166 for i in 0..self.input_dim {
167 b2[i] = param_vec[idx];
168 idx += 1;
169 }
170
171 (w1, b1, w2, b2)
172 }
173
174 fn euler_step(&self, state: &Array1<F>, dt: F) -> Result<Array1<F>> {
176 let derivative = self.neural_network(state)?;
177 let mut next_state = Array1::zeros(self.input_dim);
178
179 for i in 0..self.input_dim {
180 next_state[i] = state[i] + dt * derivative[i];
181 }
182
183 Ok(next_state)
184 }
185
186 fn rk4_step(&self, state: &Array1<F>, dt: F) -> Result<Array1<F>> {
188 let k1 = self.neural_network(state)?;
189
190 let mut temp_state = Array1::zeros(self.input_dim);
191 for i in 0..self.input_dim {
192 temp_state[i] = state[i] + dt * k1[i] / F::from(2.0).unwrap();
193 }
194 let k2 = self.neural_network(&temp_state)?;
195
196 for i in 0..self.input_dim {
197 temp_state[i] = state[i] + dt * k2[i] / F::from(2.0).unwrap();
198 }
199 let k3 = self.neural_network(&temp_state)?;
200
201 for i in 0..self.input_dim {
202 temp_state[i] = state[i] + dt * k3[i];
203 }
204 let k4 = self.neural_network(&temp_state)?;
205
206 let mut next_state = Array1::zeros(self.input_dim);
207 for i in 0..self.input_dim {
208 next_state[i] = state[i]
209 + dt * (k1[i]
210 + F::from(2.0).unwrap() * k2[i]
211 + F::from(2.0).unwrap() * k3[i]
212 + k4[i])
213 / F::from(6.0).unwrap();
214 }
215
216 Ok(next_state)
217 }
218
219 fn rkf45_step(&self, state: &Array1<F>, dt: F) -> Result<Array1<F>> {
221 self.rk4_step(state, dt)
223 }
224
225 fn tanh(&self, x: F) -> F {
227 x.tanh()
228 }
229}
230
231impl<F: Float + Debug> ODESolverConfig<F> {
232 pub fn new(method: IntegrationMethod, step_size: F, tolerance: F) -> Self {
234 Self {
235 method,
236 step_size,
237 tolerance,
238 }
239 }
240
241 pub fn euler(step_size: F) -> Self {
243 Self::new(IntegrationMethod::Euler, step_size, F::from(1e-6).unwrap())
244 }
245
246 pub fn runge_kutta4(step_size: F) -> Self {
248 Self::new(
249 IntegrationMethod::RungeKutta4,
250 step_size,
251 F::from(1e-6).unwrap(),
252 )
253 }
254
255 pub fn rkf45(step_size: F, tolerance: F) -> Self {
257 Self::new(IntegrationMethod::RKF45, step_size, tolerance)
258 }
259}