rust_ml/model/core/
base.rs1use crate::core::error::ModelError;
2
3use super::param_collection::{GradientCollection, ParamCollection};
4
5pub trait BaseModel<Input, Output> {
6 fn predict(&self, x: &Input) -> Result<Output, ModelError>;
12
13 fn compute_cost(&self, x: &Input, y: &Output) -> Result<f64, ModelError>;
19}
20
21pub trait OptimizableModel<Input, Output>:
22 BaseModel<Input, Output> + ParamCollection + GradientCollection
23{
24 fn forward(&self, input: &Input) -> Result<Output, ModelError>;
26
27 fn backward(&mut self, input: &Input, output_grad: &Output) -> Result<(), ModelError>;
29
30 fn compute_output_gradient(&self, x: &Input, y: &Output) -> Result<Output, ModelError>;
32}