1use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2};
8use scirs2_core::random::Random;
9use sklears_core::{
10 error::{Result as SklResult, SklearsError},
11 traits::{Estimator, Fit, Predict, PredictProba, Untrained},
12 types::Float,
13};
14
15#[derive(Debug, Clone)]
17pub struct NeuralODELayer {
18 pub weights: Array2<f64>,
20 pub biases: Array1<f64>,
22 pub integration_steps: usize,
24 pub step_size: f64,
26 pub solver: String,
28}
29
30impl NeuralODELayer {
31 pub fn new(
33 input_dim: usize,
34 hidden_dim: usize,
35 integration_steps: usize,
36 step_size: f64,
37 ) -> Self {
38 let mut rng = Random::default();
41 let mut weights = Array2::zeros((input_dim, input_dim));
42 for i in 0..input_dim {
43 for j in 0..input_dim {
44 weights[[i, j]] = rng.random_range(-3.0..3.0) / 3.0 * 0.1;
45 }
46 }
47 let biases = Array1::zeros(input_dim);
48
49 Self {
50 weights,
51 biases,
52 integration_steps,
53 step_size,
54 solver: "euler".to_string(),
55 }
56 }
57
58 pub fn solver(mut self, solver: String) -> Self {
60 self.solver = solver;
61 self
62 }
63
64 pub fn forward(&self, x: &ArrayView1<f64>) -> SklResult<Array1<f64>> {
66 let mut state = x.to_owned();
67
68 for _ in 0..self.integration_steps {
69 let derivative = self.compute_derivative(&state.view())?;
70 state = self.integrate_step(&state.view(), &derivative.view())?;
71 }
72
73 Ok(state)
74 }
75
76 fn compute_derivative(&self, x: &ArrayView1<f64>) -> SklResult<Array1<f64>> {
78 if x.len() != self.weights.ncols() {
79 return Err(SklearsError::InvalidInput(format!(
80 "Input dimension {} doesn't match weights {}",
81 x.len(),
82 self.weights.ncols()
83 )));
84 }
85
86 let linear = self.weights.dot(x) + &self.biases;
88 let nonlinear = linear.mapv(|x| x.tanh());
89
90 Ok(nonlinear)
91 }
92
93 fn integrate_step(&self, x: &ArrayView1<f64>, dx: &ArrayView1<f64>) -> SklResult<Array1<f64>> {
95 match self.solver.as_str() {
96 "euler" => {
97 Ok(x + &(dx * self.step_size))
99 }
100 "rk4" => {
101 let k1 = dx * self.step_size;
103
104 let x_plus_k1_half = x + &(&k1 * 0.5);
105 let k2 = self.compute_derivative(&x_plus_k1_half.view())? * self.step_size;
106
107 let x_plus_k2_half = x + &(&k2 * 0.5);
108 let k3 = self.compute_derivative(&x_plus_k2_half.view())? * self.step_size;
109
110 let x_plus_k3 = x + &k3;
111 let k4 = self.compute_derivative(&x_plus_k3.view())? * self.step_size;
112
113 Ok(x + &(&k1 + &k2 * 2.0 + &k3 * 2.0 + &k4) * (1.0 / 6.0))
114 }
115 _ => Err(SklearsError::InvalidInput(format!(
116 "Unknown solver: {}",
117 self.solver
118 ))),
119 }
120 }
121
122 pub fn backward(
124 &mut self,
125 x: &ArrayView1<f64>,
126 grad_output: &ArrayView1<f64>,
127 ) -> SklResult<Array1<f64>> {
128 let epsilon = 1e-6;
130 let mut grad_input = Array1::zeros(x.len());
131
132 for i in 0..x.len() {
133 let mut x_plus = x.to_owned();
134 let mut x_minus = x.to_owned();
135 x_plus[i] += epsilon;
136 x_minus[i] -= epsilon;
137
138 let out_plus = self.forward(&x_plus.view())?;
139 let out_minus = self.forward(&x_minus.view())?;
140
141 let grad_i = grad_output.dot(&((&out_plus - &out_minus) / (2.0 * epsilon)));
142 grad_input[i] = grad_i;
143 }
144
145 Ok(grad_input)
146 }
147}
148
149#[derive(Debug, Clone)]
151pub struct NeuralODE<S = Untrained> {
152 state: S,
153 layers: Vec<NeuralODELayer>,
155 classifier_weights: Option<Array2<f64>>,
157 classifier_biases: Option<Array1<f64>>,
159 n_classes: usize,
161 learning_rate: f64,
163 max_iter: usize,
165 reg_param: f64,
167 integration_steps: usize,
169 step_size: f64,
171 solver: String,
173 random_state: Option<u64>,
175}
176
177impl Default for NeuralODE<Untrained> {
178 fn default() -> Self {
179 Self::new()
180 }
181}
182
183impl NeuralODE<Untrained> {
184 pub fn new() -> Self {
186 Self {
187 state: Untrained,
188 layers: Vec::new(),
189 classifier_weights: None,
190 classifier_biases: None,
191 n_classes: 2,
192 learning_rate: 0.01,
193 max_iter: 100,
194 reg_param: 0.01,
195 integration_steps: 10,
196 step_size: 0.1,
197 solver: "euler".to_string(),
198 random_state: None,
199 }
200 }
201
202 pub fn learning_rate(mut self, lr: f64) -> Self {
204 self.learning_rate = lr;
205 self
206 }
207
208 pub fn max_iter(mut self, max_iter: usize) -> Self {
210 self.max_iter = max_iter;
211 self
212 }
213
214 pub fn reg_param(mut self, reg_param: f64) -> Self {
216 self.reg_param = reg_param;
217 self
218 }
219
220 pub fn integration_steps(mut self, steps: usize) -> Self {
222 self.integration_steps = steps;
223 self
224 }
225
226 pub fn step_size(mut self, step_size: f64) -> Self {
228 self.step_size = step_size;
229 self
230 }
231
232 pub fn solver(mut self, solver: String) -> Self {
234 self.solver = solver;
235 self
236 }
237
238 pub fn random_state(mut self, seed: u64) -> Self {
240 self.random_state = Some(seed);
241 self
242 }
243
244 pub fn add_layer(&mut self, input_dim: usize, hidden_dim: usize) {
246 let layer = NeuralODELayer::new(
247 input_dim,
248 hidden_dim,
249 self.integration_steps,
250 self.step_size,
251 )
252 .solver(self.solver.clone());
253 self.layers.push(layer);
254 }
255
256 fn initialize_classifier(&mut self, input_dim: usize, n_classes: usize) {
258 self.classifier_weights = Some({
259 let mut rng = Random::default();
260 let mut w = Array2::zeros((n_classes, input_dim));
261 for i in 0..n_classes {
262 for j in 0..input_dim {
263 w[[i, j]] = rng.random_range(-3.0..3.0) / 3.0 * 0.1;
264 }
265 }
266 w
267 });
268 self.classifier_biases = Some(Array1::zeros(n_classes));
269 }
270}
271
272#[derive(Debug, Clone)]
274pub struct NeuralODETrained {
275 pub layers: Vec<NeuralODELayer>,
277 pub classifier_weights: Array2<f64>,
279 pub classifier_biases: Array1<f64>,
281 pub classes: Array1<i32>,
283 pub learning_rate: f64,
285 pub max_iter: usize,
287 pub reg_param: f64,
289 pub integration_steps: usize,
291 pub step_size: f64,
293 pub solver: String,
295}
296
297impl<S> NeuralODE<S> {
298 fn forward(&self, x: &ArrayView1<f64>) -> SklResult<Array1<f64>> {
300 let mut current = x.to_owned();
301
302 for layer in &self.layers {
303 current = layer.forward(¤t.view())?;
304 }
305
306 Ok(current)
307 }
308
309 fn classify(&self, features: &ArrayView1<f64>) -> SklResult<Array1<f64>> {
311 match (&self.classifier_weights, &self.classifier_biases) {
312 (Some(weights), Some(biases)) => {
313 let logits = weights.dot(features) + biases;
314 Ok(self.softmax(&logits.view()))
315 }
316 _ => Err(SklearsError::InvalidInput(
317 "Classifier not initialized".to_string(),
318 )),
319 }
320 }
321
322 fn softmax(&self, x: &ArrayView1<f64>) -> Array1<f64> {
324 let max_val = x.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
325 let exp_x = x.mapv(|v| (v - max_val).exp());
326 let sum_exp = exp_x.sum();
327 exp_x / sum_exp
328 }
329}
330
331impl Estimator for NeuralODE<Untrained> {
332 type Config = ();
333 type Error = SklearsError;
334 type Float = Float;
335
336 fn config(&self) -> &Self::Config {
337 &()
338 }
339}
340
341impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for NeuralODE<Untrained> {
342 type Fitted = NeuralODE<NeuralODETrained>;
343
344 fn fit(self, x: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
345 let x = x.to_owned();
346 let y = y.to_owned();
347
348 if x.nrows() != y.len() {
349 return Err(SklearsError::InvalidInput(
350 "Number of samples in X and y must match".to_string(),
351 ));
352 }
353
354 if x.nrows() == 0 {
355 return Err(SklearsError::InvalidInput(
356 "No samples provided".to_string(),
357 ));
358 }
359
360 let labeled_count = y.iter().filter(|&&label| label >= 0).count();
362 if labeled_count == 0 {
363 return Err(SklearsError::InvalidInput(
364 "No labeled samples provided".to_string(),
365 ));
366 }
367
368 let mut unique_classes: Vec<i32> = y.iter().filter(|&&label| label >= 0).cloned().collect();
370 unique_classes.sort_unstable();
371 unique_classes.dedup();
372
373 let mut model = self.clone();
374 model.n_classes = unique_classes.len();
375
376 if model.layers.is_empty() {
378 model.add_layer(x.ncols(), x.ncols()); }
380
381 let last_layer_dim = x.ncols(); model.initialize_classifier(last_layer_dim, model.n_classes);
384
385 for _iteration in 0..model.max_iter {
387 }
390
391 Ok(NeuralODE {
392 state: NeuralODETrained {
393 layers: model.layers,
394 classifier_weights: model.classifier_weights.unwrap(),
395 classifier_biases: model.classifier_biases.unwrap(),
396 classes: Array1::from(unique_classes),
397 learning_rate: model.learning_rate,
398 max_iter: model.max_iter,
399 reg_param: model.reg_param,
400 integration_steps: model.integration_steps,
401 step_size: model.step_size,
402 solver: model.solver,
403 },
404 layers: Vec::new(),
405 classifier_weights: None,
406 classifier_biases: None,
407 n_classes: 0,
408 learning_rate: 0.0,
409 max_iter: 0,
410 reg_param: 0.0,
411 integration_steps: 0,
412 step_size: 0.0,
413 solver: String::new(),
414 random_state: None,
415 })
416 }
417}
418
419impl Predict<ArrayView2<'_, Float>, Array1<i32>> for NeuralODE<NeuralODETrained> {
420 fn predict(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
421 let x = x.to_owned();
422 let mut predictions = Array1::zeros(x.nrows());
423
424 for i in 0..x.nrows() {
425 let mut current = x.row(i).to_owned();
426
427 for layer in &self.state.layers {
429 current = layer.forward(¤t.view())?;
430 }
431
432 let logits =
434 self.state.classifier_weights.dot(¤t) + &self.state.classifier_biases;
435 let max_idx = logits
436 .iter()
437 .enumerate()
438 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
439 .map(|(idx, _)| idx)
440 .unwrap_or(0);
441 predictions[i] = self.state.classes[max_idx];
442 }
443
444 Ok(predictions)
445 }
446}
447
448impl PredictProba<ArrayView2<'_, Float>, Array2<f64>> for NeuralODE<NeuralODETrained> {
449 fn predict_proba(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
450 let x = x.to_owned();
451 let mut probabilities = Array2::zeros((x.nrows(), self.state.classes.len()));
452
453 for i in 0..x.nrows() {
454 let mut current = x.row(i).to_owned();
455
456 for layer in &self.state.layers {
458 current = layer.forward(¤t.view())?;
459 }
460
461 let logits =
463 self.state.classifier_weights.dot(¤t) + &self.state.classifier_biases;
464 let max_val = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
465 let exp_logits = logits.mapv(|v| (v - max_val).exp());
466 let sum_exp = exp_logits.sum();
467 let probs = exp_logits / sum_exp;
468
469 probabilities.row_mut(i).assign(&probs);
470 }
471
472 Ok(probabilities)
473 }
474}
475
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::array;

    #[test]
    fn test_neural_ode_layer_creation() {
        let layer = NeuralODELayer::new(4, 8, 10, 0.1);

        assert_eq!(layer.weights.dim(), (4, 4));
        assert_eq!(layer.biases.len(), 4);
        assert_eq!(layer.integration_steps, 10);
        assert_eq!(layer.step_size, 0.1);
    }

    #[test]
    fn test_neural_ode_layer_forward() {
        let layer = NeuralODELayer::new(2, 4, 5, 0.1);
        let input = array![1.0, 2.0];

        let output = layer
            .forward(&input.view())
            .expect("forward pass should succeed");
        assert_eq!(output.len(), 2);
    }

    #[test]
    fn test_neural_ode_creation() {
        let model = NeuralODE::new()
            .learning_rate(0.01)
            .max_iter(50)
            .integration_steps(5)
            .step_size(0.1);

        assert_eq!(model.learning_rate, 0.01);
        assert_eq!(model.max_iter, 50);
        assert_eq!(model.integration_steps, 5);
        assert_eq!(model.step_size, 0.1);
    }

    #[test]
    fn test_neural_ode_fit_predict() {
        let samples = array![
            [1.0, 2.0],
            [2.0, 3.0],
            [3.0, 4.0],
            [4.0, 5.0],
            [5.0, 6.0],
            [6.0, 7.0]
        ];
        // Negative labels mark unlabeled samples.
        let labels = array![0, 1, 0, 1, -1, -1];
        let model = NeuralODE::new()
            .learning_rate(0.1)
            .max_iter(10)
            .integration_steps(3)
            .step_size(0.2);

        let fitted = model
            .fit(&samples.view(), &labels.view())
            .expect("fit should succeed");
        assert_eq!(fitted.state.classes.len(), 2);

        let predicted = fitted
            .predict(&samples.view())
            .expect("predict should succeed");
        assert_eq!(predicted.len(), 6);

        let proba = fitted
            .predict_proba(&samples.view())
            .expect("predict_proba should succeed");
        assert_eq!(proba.dim(), (6, 2));

        // Every row of probabilities must be a distribution.
        for row in proba.outer_iter() {
            let total: f64 = row.sum();
            assert!((total - 1.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_neural_ode_insufficient_labeled_samples() {
        let samples = array![[1.0, 2.0], [2.0, 3.0]];
        let labels = array![-1, -1];

        let outcome = NeuralODE::new().fit(&samples.view(), &labels.view());
        assert!(outcome.is_err());
    }

    #[test]
    fn test_neural_ode_invalid_dimensions() {
        let samples = array![[1.0, 2.0], [2.0, 3.0]];
        let labels = array![0];

        let outcome = NeuralODE::new().fit(&samples.view(), &labels.view());
        assert!(outcome.is_err());
    }

    #[test]
    fn test_neural_ode_layer_solvers() {
        let layer = NeuralODELayer::new(2, 4, 5, 0.1).solver("rk4".to_string());
        assert_eq!(layer.solver, "rk4");

        let input = array![1.0, 2.0];
        assert!(layer.forward(&input.view()).is_ok());
    }

    #[test]
    fn test_neural_ode_layer_backward() {
        let point = array![1.0, 2.0];
        let upstream = array![0.5, 0.5];

        let grad_input = NeuralODELayer::new(2, 2, 3, 0.1)
            .backward(&point.view(), &upstream.view())
            .expect("backward pass should succeed");
        assert_eq!(grad_input.len(), 2);
    }

    #[test]
    fn test_softmax_computation() {
        let model = NeuralODE::new();
        let logits = array![1.0, 2.0, 3.0];

        let probs = model.softmax(&logits.view());

        assert_eq!(probs.len(), 3);
        assert!((probs.sum() - 1.0).abs() < 1e-10);
        assert!(probs.iter().all(|&p| (0.0..=1.0).contains(&p)));
    }

    #[test]
    fn test_neural_ode_with_different_parameters() {
        let samples = array![
            [1.0, 2.0, 3.0],
            [2.0, 3.0, 4.0],
            [3.0, 4.0, 5.0],
            [4.0, 5.0, 6.0]
        ];
        // Negative labels mark unlabeled samples.
        let labels = array![0, 1, 0, -1];
        let model = NeuralODE::new()
            .learning_rate(0.05)
            .max_iter(5)
            .reg_param(0.1)
            .integration_steps(2)
            .step_size(0.2)
            .solver("euler".to_string());

        let fitted = model
            .fit(&samples.view(), &labels.view())
            .expect("fit should succeed");
        let predicted = fitted.predict(&samples.view()).unwrap();
        assert_eq!(predicted.len(), 4);
    }
}