Trait Model 

pub trait Model {
    // Required methods
    fn forward(
        &self,
        input: &ArrayView<'_, f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>>;
    fn backward(
        &self,
        input: &ArrayView<'_, f64, Ix2>,
        grad_output: &ArrayView<'_, f64, Ix2>,
    ) -> TrainResult<HashMap<String, Array<f64, Ix2>>>;
    fn parameters(&self) -> &HashMap<String, Array<f64, Ix2>>;
    fn parameters_mut(&mut self) -> &mut HashMap<String, Array<f64, Ix2>>;
    fn set_parameters(&mut self, parameters: HashMap<String, Array<f64, Ix2>>);

    // Provided methods
    fn num_parameters(&self) -> usize { ... }
    fn state_dict(&self) -> HashMap<String, Vec<f64>> { ... }
    fn load_state_dict(
        &mut self,
        state: HashMap<String, Vec<f64>>,
    ) -> TrainResult<()> { ... }
    fn reset_parameters(&mut self) { ... }
}
Trait for trainable models.

This trait defines the interface for models that can be trained with the Tensorlogic training infrastructure. Implementors must supply the forward and backward passes and the parameter accessors; parameter counting, state save/load, and parameter reset are provided methods with default implementations.

Required Methods§

fn forward(&self, input: &ArrayView<'_, f64, Ix2>) -> TrainResult<Array<f64, Ix2>>

Perform a forward pass through the model.

§Arguments
  • input - Input tensor (a two-dimensional f64 array view)
§Returns

Output tensor from the model
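
As a rough illustration of what a forward pass might compute, the sketch below implements a hypothetical linear model, output = input * weight + bias. Nothing here comes from the Tensorlogic source: the function name `linear_forward`, the `Matrix` alias, and the `String` error type are all assumptions. To stay dependency-free it uses `Vec<Vec<f64>>` in place of ndarray's `Array<f64, Ix2>`, so it shows the arithmetic rather than an actual `Model` implementation.

```rust
/// Stand-in for a 2-D array: a list of rows (assumption, not ndarray).
type Matrix = Vec<Vec<f64>>;

/// Forward pass of a hypothetical linear model: output = input * weight + bias.
/// `input` is (batch x in_dim), `weight` is (in_dim x out_dim), `bias` has out_dim entries.
/// Returns Err on a shape mismatch, mirroring a fallible `forward`.
fn linear_forward(input: &Matrix, weight: &Matrix, bias: &[f64]) -> Result<Matrix, String> {
    let in_dim = weight.len();
    let out_dim = bias.len();
    if weight.iter().any(|row| row.len() != out_dim) {
        return Err("every weight row must have bias.len() columns".to_string());
    }
    let mut output = Vec::with_capacity(input.len());
    for row in input {
        if row.len() != in_dim {
            return Err(format!("expected {} input features, got {}", in_dim, row.len()));
        }
        // Start each output row from the bias, then accumulate x_i * W[i][j].
        let mut out_row = bias.to_vec();
        for (x, w_row) in row.iter().zip(weight) {
            for (o, w) in out_row.iter_mut().zip(w_row) {
                *o += x * w;
            }
        }
        output.push(out_row);
    }
    Ok(output)
}
```

Each input row yields one output row, so a batch of N samples produces an N-row result, matching the two-dimensional input and output types in the signature.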

fn backward(&self, input: &ArrayView<'_, f64, Ix2>, grad_output: &ArrayView<'_, f64, Ix2>) -> TrainResult<HashMap<String, Array<f64, Ix2>>>

Perform a backward pass to compute gradients.

§Arguments
  • input - Input tensor used in forward pass
  • grad_output - Gradient of loss with respect to model output
§Returns

Gradients for each model parameter
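
To make the shape of the returned gradient map concrete, here is a minimal sketch of the backward pass for a hypothetical linear model (output = input * weight + bias). The key names "weight" and "bias", the function name `linear_backward`, the `Matrix` alias, and the `String` error type are assumptions for illustration, and `Vec<Vec<f64>>` stands in for ndarray's `Array<f64, Ix2>`. The weight gradient is input^T * grad_output, and the bias gradient is the column-wise sum of grad_output, stored as a 1 x out_dim matrix because the map values are two-dimensional.

```rust
use std::collections::HashMap;

/// Stand-in for a 2-D array: a list of rows (assumption, not ndarray).
type Matrix = Vec<Vec<f64>>;

/// Gradients of a hypothetical linear model with respect to its parameters.
/// `input` is (batch x in_dim), `grad_output` is (batch x out_dim).
fn linear_backward(
    input: &Matrix,
    grad_output: &Matrix,
) -> Result<HashMap<String, Matrix>, String> {
    if input.len() != grad_output.len() {
        return Err("input and grad_output must have the same number of rows".to_string());
    }
    let in_dim = input.first().map_or(0, |r| r.len());
    let out_dim = grad_output.first().map_or(0, |r| r.len());
    let mut grad_weight = vec![vec![0.0; out_dim]; in_dim];
    let mut grad_bias = vec![0.0; out_dim];
    for (x_row, g_row) in input.iter().zip(grad_output) {
        if x_row.len() != in_dim || g_row.len() != out_dim {
            return Err("ragged rows are not supported".to_string());
        }
        for (j, g) in g_row.iter().enumerate() {
            grad_bias[j] += *g; // sum over the batch
            for (i, x) in x_row.iter().enumerate() {
                grad_weight[i][j] += *x * *g; // (input^T * grad_output)[i][j]
            }
        }
    }
    let mut grads = HashMap::new();
    grads.insert("weight".to_string(), grad_weight);
    grads.insert("bias".to_string(), vec![grad_bias]);
    Ok(grads)
}
```

Note that `backward` receives the original input again rather than relying on state cached by `forward`, which fits the `&self` receiver on both methods.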

fn parameters(&self) -> &HashMap<String, Array<f64, Ix2>>

Get a reference to the model’s parameters.

fn parameters_mut(&mut self) -> &mut HashMap<String, Array<f64, Ix2>>

Get a mutable reference to the model’s parameters.

fn set_parameters(&mut self, parameters: HashMap<String, Array<f64, Ix2>>)

Set the model’s parameters.
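
One plausible use of the mutable parameter map is an in-place optimizer update. The sketch below applies a plain gradient-descent step, p <- p - lr * g, to a parameter map using a gradient map of the kind `backward` returns. It assumes that gradient keys match parameter keys, which the documentation above does not state. The function name `sgd_step`, the `Matrix` alias (a `Vec<Vec<f64>>` standing in for ndarray's `Array<f64, Ix2>`), and the `String` error type are hypothetical.

```rust
use std::collections::HashMap;

/// Stand-in for a 2-D array: a list of rows (assumption, not ndarray).
type Matrix = Vec<Vec<f64>>;

/// Apply one gradient-descent step in place: p <- p - lr * g.
/// Assumes every gradient key names an existing parameter of the same shape.
fn sgd_step(
    params: &mut HashMap<String, Matrix>,
    grads: &HashMap<String, Matrix>,
    lr: f64,
) -> Result<(), String> {
    for (name, grad) in grads {
        let param = params
            .get_mut(name)
            .ok_or_else(|| format!("no parameter named {}", name))?;
        let same_shape = param.len() == grad.len()
            && param.iter().zip(grad).all(|(p, g)| p.len() == g.len());
        if !same_shape {
            return Err(format!("shape mismatch for {}", name));
        }
        for (p_row, g_row) in param.iter_mut().zip(grad) {
            for (p, g) in p_row.iter_mut().zip(g_row) {
                *p -= lr * *g;
            }
        }
    }
    Ok(())
}
```

With a real implementor, the first argument would come from `parameters_mut()`; `set_parameters` would instead suit an optimizer that builds a fresh map and swaps it in wholesale.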

Provided Methods§

fn num_parameters(&self) -> usize

Get the number of parameters in the model.

fn state_dict(&self) -> HashMap<String, Vec<f64>>

Save model state to a dictionary mapping String keys to Vec<f64> values.

fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) -> TrainResult<()>

Load model state from a dictionary.
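
The types suggest that saving flattens each two-dimensional parameter into a `Vec<f64>` and that loading restores values into existing shapes, although the documentation does not describe the element order or the validation performed. The sketch below shows one way such a round trip could work, assuming row-major flattening and a load that fails when an entry is missing or has the wrong length. The names `to_state_dict` and `load_into`, the `Matrix` alias (a `Vec<Vec<f64>>` standing in for ndarray's `Array<f64, Ix2>`), and the `String` error type are hypothetical; the actual provided methods may behave differently.

```rust
use std::collections::HashMap;

/// Stand-in for a 2-D array: a list of rows (assumption, not ndarray).
type Matrix = Vec<Vec<f64>>;

/// Flatten every parameter matrix into a row-major Vec<f64> (assumed layout).
fn to_state_dict(params: &HashMap<String, Matrix>) -> HashMap<String, Vec<f64>> {
    params
        .iter()
        .map(|(name, m)| (name.clone(), m.iter().flatten().copied().collect::<Vec<f64>>()))
        .collect()
}

/// Copy flat values back into the existing parameter shapes.
/// Validates everything first so a failed load leaves `params` untouched.
fn load_into(
    params: &mut HashMap<String, Matrix>,
    state: &HashMap<String, Vec<f64>>,
) -> Result<(), String> {
    for (name, m) in params.iter() {
        let expected: usize = m.iter().map(|row| row.len()).sum();
        match state.get(name) {
            Some(v) if v.len() == expected => {}
            Some(v) => {
                return Err(format!("{}: expected {} values, got {}", name, expected, v.len()))
            }
            None => return Err(format!("missing entry for {}", name)),
        }
    }
    for (name, m) in params.iter_mut() {
        let mut values = state[name].iter().copied();
        for row in m.iter_mut() {
            for slot in row.iter_mut() {
                *slot = values.next().expect("length validated above");
            }
        }
    }
    Ok(())
}
```

Because a flat vector carries no shape information, a loader of this kind has to take the shapes from the model it is loading into, which is why the sketch only overwrites existing parameters.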

fn reset_parameters(&mut self)

Reset model parameters (optional, for retraining).

Implementors§