// rust_ml/model/core/param_collection.rs

use std::fmt::Debug;

use ndarray::{ArrayView, ArrayViewMut, Dimension, IxDyn};

use crate::core::error::ModelError;

5pub trait ParamCollection: Debug + Send + Sync {
7 fn get<D: Dimension>(&self, key: &str) -> Result<ArrayView<f64, D>, ModelError>;
9
10 fn get_mut<D: Dimension>(&mut self, key: &str) -> Result<ArrayViewMut<f64, D>, ModelError>;
11
12 fn set<D: Dimension>(&mut self, key: &str, value: ArrayView<f64, D>) -> Result<(), ModelError>;
14
15 fn param_iter(&self) -> Vec<(&str, ArrayView<f64, IxDyn>)>;
17}
18
19pub trait GradientCollection {
20 fn get_gradient<D: Dimension>(&self, key: &str) -> Result<ArrayView<f64, D>, ModelError>;
22
23 fn set_gradient<D: Dimension>(
25 &mut self,
26 key: &str,
27 value: ArrayView<f64, D>,
28 ) -> Result<(), ModelError>;
29}