/// Trait for neural network models.
///
/// `F` is the scalar element type of the dynamically-dimensioned arrays the
/// model consumes and produces.
pub trait Model<F>
where
    F: Float + Debug + ScalarOperand,
{
    // ----- Required methods -----

    /// Forward pass through the model.
    ///
    /// Maps `input` to the model's output array.
    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>>;

    /// Backward pass to compute gradients.
    ///
    /// `input` is the array that was passed to [`Model::forward`];
    /// `grad_output` is presumably the gradient with respect to the forward
    /// output. Confirm this against the implementors.
    fn backward(
        &self,
        input: &Array<F, IxDyn>,
        grad_output: &Array<F, IxDyn>,
    ) -> Result<Array<F, IxDyn>>;

    /// Update the model parameters with the given learning rate.
    fn update(&mut self, learning_rate: F) -> Result<()>;

    /// Train the model on a batch of data.
    ///
    /// Uses `inputs` and their matching `targets`, scored with `loss_fn` and
    /// stepped with `optimizer`. The returned `F` is presumably the batch loss;
    /// confirm this against the implementors.
    fn train_batch(
        &mut self,
        inputs: &Array<F, IxDyn>,
        targets: &Array<F, IxDyn>,
        loss_fn: &dyn Loss<F>,
        optimizer: &mut dyn Optimizer<F>,
    ) -> Result<F>;

    /// Produce the model's outputs for `inputs`.
    ///
    /// NOTE(review): this method is undocumented in the visible docs. It looks
    /// like an inference-time counterpart of [`Model::forward`]; confirm.
    fn predict(&self, inputs: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>>;

    /// Evaluate the model on `inputs` against `targets` using `loss_fn`.
    ///
    /// Takes `&self`, so it does not modify the model. The returned `F` is
    /// presumably the loss value; confirm this against the implementors.
    fn evaluate(
        &self,
        inputs: &Array<F, IxDyn>,
        targets: &Array<F, IxDyn>,
        loss_fn: &dyn Loss<F>,
    ) -> Result<F>;
}
Trait for neural network models.
Required Methods
fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>>
Forward pass through the model.
fn backward(
    &self,
    input: &Array<F, IxDyn>,
    grad_output: &Array<F, IxDyn>,
) -> Result<Array<F, IxDyn>>
Backward pass to compute gradients.
fn update(&mut self, learning_rate: F) -> Result<()>
Update the model parameters with the given learning rate.
fn train_batch(
    &mut self,
    inputs: &Array<F, IxDyn>,
    targets: &Array<F, IxDyn>,
    loss_fn: &dyn Loss<F>,
    optimizer: &mut dyn Optimizer<F>,
) -> Result<F>
Train the model on a batch of data.