scirs2_neural/models/
mod.rs

//! Neural network model implementations.
//!
//! This module defines the [`Model`] trait and provides neural network models,
//! including [`Sequential`] models, the `architectures` submodule, and training
//! utilities ([`Trainer`], [`TrainingConfig`], [`History`]).
5
6use ndarray::{Array, ScalarOperand};
7use num_traits::Float;
8use std::fmt::Debug;
9
10use crate::error::Result;
11use crate::losses::Loss;
12use crate::optimizers::Optimizer;
13
14/// Trait for neural network models
15pub trait Model<F: Float + Debug + ScalarOperand> {
16    /// Forward pass through the model
17    fn forward(&self, input: &Array<F, ndarray::IxDyn>) -> Result<Array<F, ndarray::IxDyn>>;
18
19    /// Backward pass to compute gradients
20    fn backward(
21        &self,
22        input: &Array<F, ndarray::IxDyn>,
23        grad_output: &Array<F, ndarray::IxDyn>,
24    ) -> Result<Array<F, ndarray::IxDyn>>;
25
26    /// Update the model parameters with the given learning rate
27    fn update(&mut self, learning_rate: F) -> Result<()>;
28
29    /// Train the model on a batch of data
30    fn train_batch(
31        &mut self,
32        inputs: &Array<F, ndarray::IxDyn>,
33        targets: &Array<F, ndarray::IxDyn>,
34        loss_fn: &dyn Loss<F>,
35        optimizer: &mut dyn Optimizer<F>,
36    ) -> Result<F>;
37
38    /// Predict the output for a batch of inputs
39    fn predict(&self, inputs: &Array<F, ndarray::IxDyn>) -> Result<Array<F, ndarray::IxDyn>>;
40
41    /// Evaluate the model on a batch of data
42    fn evaluate(
43        &self,
44        inputs: &Array<F, ndarray::IxDyn>,
45        targets: &Array<F, ndarray::IxDyn>,
46        loss_fn: &dyn Loss<F>,
47    ) -> Result<F>;
48}
49
50pub mod architectures;
51pub mod sequential;
52pub mod trainer;
53
54pub use architectures::*;
55pub use sequential::Sequential;
56pub use trainer::{History, Trainer, TrainingConfig};