scirs2_neural/models/mod.rs

use ndarray::{Array, ScalarOperand};
use num_traits::Float;
use std::fmt::Debug;

use crate::error::Result;
use crate::losses::Loss;
use crate::optimizers::Optimizer;

/// Core interface shared by all neural network models.
pub trait Model<F: Float + Debug + ScalarOperand> {
    /// Runs a forward pass and returns the model output.
    fn forward(&self, input: &Array<F, ndarray::IxDyn>) -> Result<Array<F, ndarray::IxDyn>>;

    /// Runs a backward pass, returning the gradient with respect to the input.
    fn backward(
        &self,
        input: &Array<F, ndarray::IxDyn>,
        grad_output: &Array<F, ndarray::IxDyn>,
    ) -> Result<Array<F, ndarray::IxDyn>>;

    /// Updates the model parameters using the given learning rate.
    fn update(&mut self, learning_rate: F) -> Result<()>;

    /// Trains the model on a single batch and returns the batch loss.
    fn train_batch(
        &mut self,
        inputs: &Array<F, ndarray::IxDyn>,
        targets: &Array<F, ndarray::IxDyn>,
        loss_fn: &dyn Loss<F>,
        optimizer: &mut dyn Optimizer<F>,
    ) -> Result<F>;

    /// Generates predictions for the given inputs.
    fn predict(&self, inputs: &Array<F, ndarray::IxDyn>) -> Result<Array<F, ndarray::IxDyn>>;

    /// Evaluates the model on the given inputs and targets, returning the loss.
    fn evaluate(
        &self,
        inputs: &Array<F, ndarray::IxDyn>,
        targets: &Array<F, ndarray::IxDyn>,
        loss_fn: &dyn Loss<F>,
    ) -> Result<F>;
}

pub mod architectures;
pub mod sequential;
pub mod trainer;

pub use architectures::*;
pub use sequential::Sequential;
pub use trainer::{History, Trainer, TrainingConfig};
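
// Illustrative sketch (not part of the original module): a trivial identity
// model implementing `Model`, exercising only the trait surface defined above.
// The module name, the `IdentityModel` struct, and the no-op/placeholder
// bodies are assumptions for illustration; a real model would dispatch to its
// layers and use the supplied loss function and optimizer.
#[cfg(test)]
mod model_trait_sketch {
    use super::Model;
    use crate::error::Result;
    use crate::losses::Loss;
    use crate::optimizers::Optimizer;
    use ndarray::{Array, IxDyn, ScalarOperand};
    use num_traits::Float;
    use std::fmt::Debug;

    /// A model that returns its input unchanged; useful only to show the trait surface.
    struct IdentityModel;

    impl<F: Float + Debug + ScalarOperand> Model<F> for IdentityModel {
        fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
            Ok(input.clone())
        }

        fn backward(
            &self,
            _input: &Array<F, IxDyn>,
            grad_output: &Array<F, IxDyn>,
        ) -> Result<Array<F, IxDyn>> {
            // The identity function passes gradients through unchanged.
            Ok(grad_output.clone())
        }

        fn update(&mut self, _learning_rate: F) -> Result<()> {
            // No parameters to update.
            Ok(())
        }

        fn train_batch(
            &mut self,
            _inputs: &Array<F, IxDyn>,
            _targets: &Array<F, IxDyn>,
            _loss_fn: &dyn Loss<F>,
            _optimizer: &mut dyn Optimizer<F>,
        ) -> Result<F> {
            // Placeholder: a real implementation would run forward/backward,
            // compute the loss, and step the optimizer.
            Ok(F::zero())
        }

        fn predict(&self, inputs: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
            self.forward(inputs)
        }

        fn evaluate(
            &self,
            _inputs: &Array<F, IxDyn>,
            _targets: &Array<F, IxDyn>,
            _loss_fn: &dyn Loss<F>,
        ) -> Result<F> {
            // Placeholder loss value; see `train_batch` above.
            Ok(F::zero())
        }
    }

    #[test]
    fn identity_forward_preserves_shape() {
        let model = IdentityModel;
        let input = Array::<f64, IxDyn>::zeros(IxDyn(&[2, 3]));
        let output = model.forward(&input).expect("forward should succeed");
        assert_eq!(output.shape(), &[2, 3]);
    }
}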