rust_ml/model/core/param_collection.rs

1use crate::core::error::ModelError;
2use ndarray::{ArrayView, ArrayViewMut, Dimension, IxDyn};
3use std::fmt::Debug;
4
5/// Provides access to parameters.
6pub trait ParamCollection: Debug + Send + Sync {
7    /// Get a reference to a specific parameter with strong typing.
8    fn get<D: Dimension>(&self, key: &str) -> Result<ArrayView<f64, D>, ModelError>;
9
10    fn get_mut<D: Dimension>(&mut self, key: &str) -> Result<ArrayViewMut<f64, D>, ModelError>;
11
12    /// Set the value of a parameter.
13    fn set<D: Dimension>(&mut self, key: &str, value: ArrayView<f64, D>) -> Result<(), ModelError>;
14
15    /// Iterate over all parameters.
16    fn param_iter(&self) -> Vec<(&str, ArrayView<f64, IxDyn>)>;
17}
18
19pub trait GradientCollection {
20    /// Get a reference to a specific gradient with strong typing.
21    fn get_gradient<D: Dimension>(&self, key: &str) -> Result<ArrayView<f64, D>, ModelError>;
22
23    /// Set the value of a gradient.
24    fn set_gradient<D: Dimension>(
25        &mut self,
26        key: &str,
27        value: ArrayView<f64, D>,
28    ) -> Result<(), ModelError>;
29}