pub trait Model {
    // Required methods
    fn forward(
        &self,
        input: &ArrayView<'_, f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>>;
    fn backward(
        &self,
        input: &ArrayView<'_, f64, Ix2>,
        grad_output: &ArrayView<'_, f64, Ix2>,
    ) -> TrainResult<HashMap<String, Array<f64, Ix2>>>;
    fn parameters(&self) -> &HashMap<String, Array<f64, Ix2>>;
    fn parameters_mut(&mut self) -> &mut HashMap<String, Array<f64, Ix2>>;
    fn set_parameters(&mut self, parameters: HashMap<String, Array<f64, Ix2>>);

    // Provided methods
    fn num_parameters(&self) -> usize { ... }
    fn state_dict(&self) -> HashMap<String, Vec<f64>> { ... }
    fn load_state_dict(
        &mut self,
        state: HashMap<String, Vec<f64>>,
    ) -> TrainResult<()> { ... }
    fn reset_parameters(&mut self) { ... }
}
Trait for trainable models.
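This trait defines the interface for models that can be trained with the Tensorlogic training infrastructure. Implementors must provide the forward and backward passes and parameter management (`parameters`, `parameters_mut`, `set_parameters`). The remaining methods — including save/load via `state_dict` and `load_state_dict` — have default implementations that implementors may override.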
Required Methods
fn backward(
    &self,
    input: &ArrayView<'_, f64, Ix2>,
    grad_output: &ArrayView<'_, f64, Ix2>,
) -> TrainResult<HashMap<String, Array<f64, Ix2>>>
fn parameters(&self) -> &HashMap<String, Array<f64, Ix2>>
Get a reference to the model’s parameters.
Provided Methods
fn num_parameters(&self) -> usize
Get the number of parameters in the model.
fn load_state_dict(
    &mut self,
    state: HashMap<String, Vec<f64>>,
) -> TrainResult<()>
Load model state from a dictionary.
fn reset_parameters(&mut self)
Reset the model's parameters. Optional; useful when retraining a model from scratch.