Trait OptimizableModel

pub trait OptimizableModel<Input, Output>:
    BaseModel<Input, Output>
    + ParamCollection
    + GradientCollection {
    // Required methods
    fn forward(&self, input: &Input) -> Result<Output, ModelError>;
    fn backward(
        &mut self,
        input: &Input,
        output_grad: &Output,
    ) -> Result<(), ModelError>;
    fn compute_output_gradient(
        &self,
        x: &Input,
        y: &Output,
    ) -> Result<Output, ModelError>;
}
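
The three required methods are meant to work together during training. The sketch below is a hypothetical, self-contained stand-in, not the crate's code. It redeclares a trait (OptimizableModelSketch) with the same three method signatures but without the BaseModel, ParamCollection and GradientCollection supertraits. It replaces ModelError with a placeholder enum and implements the trait for a toy one-weight model y = w·x (ScalarModel) over Vec<f64>. The generic gradient_step function shows one plausible call order: compute the output gradient, then backpropagate it. The mean-squared-error cost is an assumption.

```rust
/// Placeholder for the crate's ModelError (hypothetical variant).
#[derive(Debug, PartialEq)]
enum ModelError {
    ShapeMismatch,
}

/// Hypothetical stand-in with the same three signatures as OptimizableModel,
/// minus the BaseModel / ParamCollection / GradientCollection supertraits.
trait OptimizableModelSketch<Input, Output> {
    fn forward(&self, input: &Input) -> Result<Output, ModelError>;
    fn backward(&mut self, input: &Input, output_grad: &Output) -> Result<(), ModelError>;
    fn compute_output_gradient(&self, x: &Input, y: &Output) -> Result<Output, ModelError>;
}

/// Hypothetical toy model y_hat = w * x that stores the gradient for w.
struct ScalarModel {
    w: f64,
    grad_w: f64,
}

impl OptimizableModelSketch<Vec<f64>, Vec<f64>> for ScalarModel {
    fn forward(&self, input: &Vec<f64>) -> Result<Vec<f64>, ModelError> {
        Ok(input.iter().map(|x| self.w * x).collect())
    }

    fn backward(&mut self, input: &Vec<f64>, output_grad: &Vec<f64>) -> Result<(), ModelError> {
        if input.len() != output_grad.len() {
            return Err(ModelError::ShapeMismatch);
        }
        // Chain rule: dJ/dw = sum_i dJ/dy_hat_i * x_i
        self.grad_w = input.iter().zip(output_grad).map(|(x, g)| x * g).sum();
        Ok(())
    }

    fn compute_output_gradient(&self, x: &Vec<f64>, y: &Vec<f64>) -> Result<Vec<f64>, ModelError> {
        if x.len() != y.len() {
            return Err(ModelError::ShapeMismatch);
        }
        let y_hat = self.forward(x)?;
        let m = y.len() as f64;
        // Assumed cost J = 1/(2m) * sum (y_hat - y)^2, so dJ/dy_hat = (y_hat - y) / m
        Ok(y_hat.iter().zip(y).map(|(p, t)| (p - t) / m).collect())
    }
}

/// Generic over any implementor: compute dJ/dy_hat, then backpropagate it.
fn gradient_step<I, O, M: OptimizableModelSketch<I, O>>(
    model: &mut M,
    x: &I,
    y: &O,
) -> Result<(), ModelError> {
    let output_grad = model.compute_output_gradient(x, y)?;
    model.backward(x, &output_grad)
}

fn main() {
    let mut model = ScalarModel { w: 0.0, grad_w: 0.0 };
    let (x, y) = (vec![1.0, 2.0, 3.0], vec![2.0, 4.0, 6.0]);
    for _ in 0..200 {
        gradient_step(&mut model, &x, &y).unwrap();
        // Plain gradient-descent update done by hand in this sketch.
        model.w -= 0.1 * model.grad_w;
    }
    println!("w = {:.4}", model.w); // prints "w = 2.0000"
}
```

The parameter update is written by hand here. In the crate it would presumably go through ParamCollection and GradientCollection, whose methods this page does not show.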

Required Methods

fn forward(&self, input: &Input) -> Result<Output, ModelError>

Runs a forward pass through the model, returning its predictions for input.
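
A minimal sketch of what forward could compute for a linear model such as LinearRegression. It assumes the model holds a weight vector w and a scalar bias b; this page does not show its fields. The implementor's input type ArrayBase<OwnedRepr<f64>, Dim<[usize; 2]>> is ndarray's Array2<f64>, one sample per row. To stay self-contained, the sketch uses Vec<Vec<f64>> rows instead and returns a String error in place of ModelError. The function name linear_forward is hypothetical.

```rust
/// Hypothetical forward pass of a linear model: y_hat_i = w . x_i + b.
/// Rows of `x` stand in for the rows of an Array2<f64>.
fn linear_forward(x: &[Vec<f64>], w: &[f64], b: f64) -> Result<Vec<f64>, String> {
    x.iter()
        .map(|row| {
            if row.len() != w.len() {
                return Err(format!("row has {} features, expected {}", row.len(), w.len()));
            }
            // Dot product of the row with the weights, plus the bias.
            Ok(row.iter().zip(w).map(|(xi, wi)| xi * wi).sum::<f64>() + b)
        })
        .collect()
}

fn main() {
    let x = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
    let y_hat = linear_forward(&x, &[0.5, -1.0], 2.0).unwrap();
    println!("{:?}", y_hat); // prints [0.5, -0.5]
}
```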

fn backward(&mut self, input: &Input, output_grad: &Output) -> Result<(), ModelError>

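Runs a backward pass. Given the input and output_grad, the gradient of the cost with respect to the model's output, computes the gradients with respect to the model's parameters and stores them in the model.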
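
For a linear model ŷ = Xw + b, backward receives output_grad = ∂J/∂ŷ. By the chain rule, the parameter gradients are ∂J/∂w = Xᵀ(∂J/∂ŷ) and ∂J/∂b = Σᵢ ∂J/∂ŷᵢ. The sketch below uses a hypothetical Linear struct that stores these results in grad_w and grad_b, matching the &mut self, Result<(), _> shape of the trait method. The crate's real field names and gradient storage (presumably behind GradientCollection) are not shown on this page. Rows are Vec<Vec<f64>> and the error type is String, both stand-ins.

```rust
/// Hypothetical linear model that keeps its gradients next to its parameters.
#[allow(dead_code)]
struct Linear {
    w: Vec<f64>,
    b: f64,
    grad_w: Vec<f64>,
    grad_b: f64,
}

impl Linear {
    /// Given dJ/dy_hat per sample, store dJ/dw = X^T g and dJ/db = sum(g).
    fn backward(&mut self, x: &[Vec<f64>], output_grad: &[f64]) -> Result<(), String> {
        if x.len() != output_grad.len() {
            return Err("number of rows and gradients differ".to_string());
        }
        if x.iter().any(|row| row.len() != self.w.len()) {
            return Err("feature count mismatch".to_string());
        }
        // dJ/dw_j = sum_i g_i * x_ij
        let mut grad_w = vec![0.0; self.w.len()];
        for (row, g) in x.iter().zip(output_grad) {
            for (gw, xij) in grad_w.iter_mut().zip(row) {
                *gw += g * xij;
            }
        }
        self.grad_w = grad_w;
        // dJ/db = sum_i g_i
        self.grad_b = output_grad.iter().sum();
        Ok(())
    }
}

fn main() {
    let mut model = Linear { w: vec![0.0, 0.0], b: 0.0, grad_w: vec![], grad_b: 0.0 };
    let x = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
    model.backward(&x, &[0.5, -1.0]).unwrap();
    println!("{:?} {}", model.grad_w, model.grad_b); // prints [-2.5, -3.0] -0.5
}
```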

fn compute_output_gradient(&self, x: &Input, y: &Output) -> Result<Output, ModelError>

Computes the gradient of the cost with respect to the output predictions, given inputs x and targets y.
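
This page does not name the cost function. Assume, for illustration, mean squared error written as J = 1/(2m) Σᵢ (ŷᵢ − yᵢ)². The gradient with respect to each prediction is then ∂J/∂ŷᵢ = (ŷᵢ − yᵢ)/m. Because the trait method receives x rather than predictions, an implementation would presumably call forward(x) first. The hypothetical helper mse_output_gradient takes the predictions directly so that the formula stands alone.

```rust
/// Hypothetical dJ/dy_hat for the assumed cost J = 1/(2m) * sum (y_hat - y)^2.
fn mse_output_gradient(y_hat: &[f64], y: &[f64]) -> Result<Vec<f64>, String> {
    if y_hat.len() != y.len() || y.is_empty() {
        return Err("predictions and targets must be non-empty and of equal length".to_string());
    }
    let m = y.len() as f64;
    Ok(y_hat.iter().zip(y).map(|(p, t)| (p - t) / m).collect())
}

fn main() {
    let grad = mse_output_gradient(&[0.5, -0.5], &[1.0, 0.5]).unwrap();
    println!("{:?}", grad); // prints [-0.25, -0.5]
}
```

If the cost is defined as 1/m Σᵢ (ŷᵢ − yᵢ)², without the factor 1/2, every component of the gradient doubles.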

Dyn Compatibility

This trait is not dyn compatible.

In older versions of Rust, dyn compatibility was called "object safety", so this trait is not object safe.
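
Because the trait is not dyn compatible, a type such as Box<dyn OptimizableModel<Array2<f64>, Array1<f64>>> cannot be formed. LinearRegression and LogisticRegression therefore cannot share one Vec of trait objects. Two common workarounds are generic functions with a trait bound (static dispatch) and an enum over the concrete types. The sketch below shows the enum approach. Linear, Logistic and AnyModel are hypothetical one-feature stand-ins with weight w and bias b, and only forward is shown. Wrapping the real types would follow the same match-and-delegate pattern.

```rust
// Hypothetical stand-ins for the two concrete implementors.
struct Linear { w: f64, b: f64 }
struct Logistic { w: f64, b: f64 }

impl Linear {
    fn forward(&self, x: &[f64]) -> Vec<f64> {
        x.iter().map(|xi| self.w * xi + self.b).collect()
    }
}

impl Logistic {
    fn forward(&self, x: &[f64]) -> Vec<f64> {
        // Sigmoid of the linear score.
        x.iter().map(|xi| 1.0 / (1.0 + (-(self.w * xi + self.b)).exp())).collect()
    }
}

/// Closed set of models: dispatch through a match instead of a trait object.
enum AnyModel {
    Linear(Linear),
    Logistic(Logistic),
}

impl AnyModel {
    fn forward(&self, x: &[f64]) -> Vec<f64> {
        match self {
            AnyModel::Linear(m) => m.forward(x),
            AnyModel::Logistic(m) => m.forward(x),
        }
    }
}

fn main() {
    let models = vec![
        AnyModel::Linear(Linear { w: 2.0, b: 1.0 }),
        AnyModel::Logistic(Logistic { w: 1.0, b: 0.0 }),
    ];
    for m in &models {
        println!("{:?}", m.forward(&[0.0, 1.0]));
    }
}
```

The enum fixes the set of models at compile time. The generic alternative, fn train<M: OptimizableModel<I, O>>(model: &mut M, ...), avoids that limit but monomorphizes one copy per concrete type.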

Implementors

impl OptimizableModel<ArrayBase<OwnedRepr<f64>, Dim<[usize; 2]>>, ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>>> for LinearRegression

Implementation of the OptimizableModel trait for LinearRegression.

impl OptimizableModel<ArrayBase<OwnedRepr<f64>, Dim<[usize; 2]>>, ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>>> for LogisticRegression
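
For LogisticRegression this page does not state the cost or where the sigmoid derivative is applied. Assume predictions p = σ(Xw + b) and binary cross-entropy J = −(1/m) Σᵢ [yᵢ ln pᵢ + (1 − yᵢ) ln(1 − pᵢ)]. The gradient with respect to each prediction is then ∂J/∂pᵢ = (pᵢ − yᵢ) / (m pᵢ (1 − pᵢ)). Multiplying by the sigmoid derivative pᵢ(1 − pᵢ) during backpropagation collapses this to (pᵢ − yᵢ)/m. The helpers logistic_forward, bce_output_gradient and grad_wrt_logits are hypothetical names that check this numerically.

```rust
fn sigmoid(z: f64) -> f64 {
    1.0 / (1.0 + (-z).exp())
}

/// Hypothetical forward pass: p_i = sigmoid(w . x_i + b).
fn logistic_forward(x: &[Vec<f64>], w: &[f64], b: f64) -> Vec<f64> {
    x.iter()
        .map(|row| sigmoid(row.iter().zip(w).map(|(xi, wi)| xi * wi).sum::<f64>() + b))
        .collect()
}

/// dJ/dp for the assumed cost J = -(1/m) * sum [y ln p + (1 - y) ln(1 - p)].
fn bce_output_gradient(p: &[f64], y: &[f64]) -> Vec<f64> {
    let m = y.len() as f64;
    p.iter().zip(y).map(|(pi, yi)| (pi - yi) / (m * pi * (1.0 - pi))).collect()
}

/// Chain rule through the sigmoid: dJ/dz = dJ/dp * p(1 - p), which equals (p - y)/m.
fn grad_wrt_logits(p: &[f64], dj_dp: &[f64]) -> Vec<f64> {
    p.iter().zip(dj_dp).map(|(pi, g)| g * pi * (1.0 - pi)).collect()
}

fn main() {
    let x = vec![vec![0.0], vec![2.0]];
    let y = vec![0.0, 1.0];
    let p = logistic_forward(&x, &[1.0], 0.0);
    let dz = grad_wrt_logits(&p, &bce_output_gradient(&p, &y));
    println!("{:?}", dz);
}
```

Dividing by pᵢ(1 − pᵢ) is numerically fragile when a prediction saturates near 0 or 1. That is one reason implementations often return the combined (pᵢ − yᵢ)/m directly. Which convention this crate follows is not documented here.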