Struct Mlp

pub struct Mlp { /* private fields */ }

A feed-forward multi-layer perceptron composed of dense layers.

Implementations

impl Mlp

pub fn input_dim(&self) -> usize

Returns the expected input dimension.

pub fn output_dim(&self) -> usize

Returns the produced output dimension.

pub fn num_layers(&self) -> usize

Returns the number of layers.

pub fn layer(&self, idx: usize) -> Option<&Layer>

Returns a reference to the layer at idx, or None if idx is out of range.

This is primarily useful for inspection and debugging.

pub fn scratch(&self) -> Scratch

Allocate a Scratch buffer suitable for this model.

pub fn scratch_batch(&self, batch_size: usize) -> BatchScratch

Allocate a BatchScratch buffer suitable for this model and a fixed batch size.

pub fn backprop_scratch_batch(&self, batch_size: usize) -> BatchBackpropScratch

Allocate a BatchBackpropScratch buffer suitable for this model and a fixed batch size.

pub fn gradients(&self) -> Gradients

Allocate a Gradients buffer suitable for this model.

pub fn trainer(&self) -> Trainer

Convenience constructor: allocate all training buffers.

pub fn forward<'a>(&self, input: &[f32], scratch: &'a mut Scratch) -> &'a [f32]

Forward pass for a single sample.

Writes intermediate activations into scratch and returns the final output slice.

Shape contract:

  • input.len() == self.input_dim()
  • scratch must be built for this Mlp (same layer count and output sizes)
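
A minimal sketch of the scratch-buffer pattern described above. It uses hypothetical stand-in types (DenseLayer, SketchScratch, scratch_for) rather than this crate's Layer and Scratch, whose internals are private, and it assumes row-major weights of shape (out_dim, in_dim) plus an optional ReLU; the real activation handling may differ. Scratch is allocated once, each call only writes into it, and the returned slice borrows from scratch.

```rust
/// Hypothetical dense layer (not this crate's Layer): `weights` is row-major
/// with shape (out_dim, in_dim); `relu` is an assumed activation flag.
struct DenseLayer {
    in_dim: usize,
    out_dim: usize,
    weights: Vec<f32>,
    biases: Vec<f32>,
    relu: bool,
}

/// Hypothetical scratch: one preallocated output buffer per layer.
struct SketchScratch {
    layer_outputs: Vec<Vec<f32>>,
}

/// Allocates scratch once, sized from each layer's output dimension.
fn scratch_for(layers: &[DenseLayer]) -> SketchScratch {
    SketchScratch {
        layer_outputs: layers.iter().map(|l| vec![0.0; l.out_dim]).collect(),
    }
}

/// out = act(W * input + b), written into a caller-provided buffer.
fn dense_forward(layer: &DenseLayer, input: &[f32], out: &mut [f32]) {
    assert_eq!(input.len(), layer.in_dim, "input length mismatch");
    assert_eq!(out.len(), layer.out_dim, "scratch built for a different model");
    for o in 0..layer.out_dim {
        let row = &layer.weights[o * layer.in_dim..(o + 1) * layer.in_dim];
        let mut acc = layer.biases[o];
        for (w, x) in row.iter().zip(input) {
            acc += w * x;
        }
        out[o] = if layer.relu { acc.max(0.0) } else { acc };
    }
}

/// Runs every layer in order without allocating and returns a slice
/// borrowed from the last layer's scratch buffer.
fn forward<'a>(layers: &[DenseLayer], input: &[f32], scratch: &'a mut SketchScratch) -> &'a [f32] {
    assert_eq!(scratch.layer_outputs.len(), layers.len(), "scratch built for a different model");
    for i in 0..layers.len() {
        // Split the buffers so layer i-1's output can be read while layer i's is written.
        let (prev, rest) = scratch.layer_outputs.split_at_mut(i);
        let layer_input = if i == 0 { input } else { prev[i - 1].as_slice() };
        dense_forward(&layers[i], layer_input, rest[0].as_mut_slice());
    }
    scratch.layer_outputs.last().expect("at least one layer").as_slice()
}
```

Because the output borrows scratch, it has to be copied (for example with to_vec) before scratch is reused for the next sample.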

pub fn forward_batch<'a>(&self, inputs: &[f32], scratch: &'a mut BatchScratch) -> &'a [f32]

Forward pass for a contiguous batch.

Writes intermediate activations into scratch and returns the final output buffer for the whole batch (flat row-major).

Shape contract:

  • inputs.len() == batch_size * self.input_dim()
  • scratch must be built for this Mlp and the same batch_size
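
In the flat row-major layout, sample b occupies inputs[b * input_dim..(b + 1) * input_dim]. The sketch below applies one hypothetical dense layer (dense_forward_batch, not part of this crate) to a whole batch using that layout and checks the same kind of shape contract; it assumes row-major weights and no activation.

```rust
/// Applies one hypothetical dense layer (row-major weights of shape
/// (out_dim, in_dim), no activation) to every row of a flat row-major batch.
fn dense_forward_batch(
    weights: &[f32],
    biases: &[f32],
    in_dim: usize,
    inputs: &[f32],      // shape (batch_size, in_dim)
    outputs: &mut [f32], // shape (batch_size, out_dim)
) {
    let out_dim = biases.len();
    assert!(in_dim > 0 && out_dim > 0, "dimensions must be non-zero");
    assert_eq!(weights.len(), out_dim * in_dim);
    assert_eq!(inputs.len() % in_dim, 0, "inputs.len() must equal batch_size * in_dim");
    let batch_size = inputs.len() / in_dim;
    assert_eq!(outputs.len(), batch_size * out_dim, "outputs sized for a different batch_size");
    // Sample b occupies inputs[b * in_dim..(b + 1) * in_dim] and
    // outputs[b * out_dim..(b + 1) * out_dim].
    for (x, y) in inputs.chunks_exact(in_dim).zip(outputs.chunks_exact_mut(out_dim)) {
        for o in 0..out_dim {
            let w_row = &weights[o * in_dim..(o + 1) * in_dim];
            let dot: f32 = w_row.iter().zip(x).map(|(w, xi)| w * xi).sum();
            y[o] = biases[o] + dot;
        }
    }
}
```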

pub fn backward<'a>(&self, input: &[f32], scratch: &Scratch, grads: &'a mut Gradients) -> &'a [f32]

Backward pass for a single sample, using the internal d_output buffer.

You must call forward first using the same input and scratch.

Before calling this, write the upstream gradient dL/d(output) into grads.d_output_mut().

Overwrite semantics:

  • grads is overwritten with gradients for this sample.

Returns dL/d(input).
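
For a single linear layer y = W x + b, the chain rule gives dL/dW[o][i] = dL/dy[o] * x[i], dL/db[o] = dL/dy[o] and dL/dx[i] = sum over o of W[o][i] * dL/dy[o]. The sketch below is a hypothetical linear_backward over plain slices (row-major weights, no activation assumed), not this crate's implementation. It shows overwrite semantics: every gradient buffer is assigned, so stale values from an earlier call do not leak in.

```rust
/// Backward pass through one hypothetical linear layer y = W x + b
/// (no activation) with overwrite semantics: all outputs are assigned.
fn linear_backward(
    weights: &[f32], // row-major (out_dim, in_dim)
    x: &[f32],       // forward input, len in_dim
    d_y: &[f32],     // upstream dL/dy, len out_dim
    d_w: &mut [f32], // dL/dW, len out_dim * in_dim
    d_b: &mut [f32], // dL/db, len out_dim
    d_x: &mut [f32], // dL/dx, len in_dim
) {
    let (in_dim, out_dim) = (x.len(), d_y.len());
    assert_eq!(weights.len(), out_dim * in_dim);
    assert_eq!(d_w.len(), weights.len());
    assert_eq!(d_b.len(), out_dim);
    assert_eq!(d_x.len(), in_dim);
    d_x.fill(0.0); // reset before summing contributions from every output
    for o in 0..out_dim {
        d_b[o] = d_y[o];
        for i in 0..in_dim {
            d_w[o * in_dim + i] = d_y[o] * x[i];
            d_x[i] += weights[o * in_dim + i] * d_y[o];
        }
    }
}
```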

pub fn backward_accumulate<'a>(&self, input: &[f32], scratch: &Scratch, grads: &'a mut Gradients) -> &'a [f32]

Backward pass for a single sample (parameter accumulation semantics).

This is identical to backward except that parameter gradients are accumulated:

  • grads.d_weights and grads.d_biases are accumulated into (+=)
  • grads.d_layer_outputs and grads.d_input are overwritten

This is useful for mini-batch training.

You must call forward first using the same input and scratch. Before calling this, write the upstream gradient dL/d(output) into grads.d_output_mut().
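
The difference from backward can be sketched with a hypothetical linear_backward_accumulate over plain slices (row-major weights, no activation assumed): parameter gradients are added to, while the input gradient is recomputed for each sample. In this sketch the caller zeroes the parameter buffers once before the first sample of a mini-batch.

```rust
/// Backward pass through a hypothetical linear layer y = W x + b (no
/// activation) with accumulation semantics: `d_w` and `d_b` are added to
/// (+=), while `d_x` is overwritten for the current sample.
fn linear_backward_accumulate(
    weights: &[f32], // row-major (out_dim, in_dim)
    x: &[f32],       // forward input, len in_dim
    d_y: &[f32],     // upstream dL/dy, len out_dim
    d_w: &mut [f32],
    d_b: &mut [f32],
    d_x: &mut [f32],
) {
    let (in_dim, out_dim) = (x.len(), d_y.len());
    assert_eq!(weights.len(), out_dim * in_dim);
    assert_eq!(d_w.len(), weights.len());
    assert_eq!(d_b.len(), out_dim);
    assert_eq!(d_x.len(), in_dim);
    d_x.fill(0.0); // overwritten: only this sample's input gradient
    for o in 0..out_dim {
        d_b[o] += d_y[o]; // accumulated across samples
        for i in 0..in_dim {
            d_w[o * in_dim + i] += d_y[o] * x[i]; // accumulated across samples
            d_x[i] += weights[o * in_dim + i] * d_y[o];
        }
    }
}
```

After the last sample the parameter buffers hold the sum over the mini-batch; dividing by the batch size gives the mean, which is what backward_batch is documented to produce.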

pub fn backward_batch(&self, inputs: &[f32], scratch: &BatchScratch, d_outputs: &[f32], grads: &mut Gradients, backprop_scratch: &mut BatchBackpropScratch)

Backward pass for a contiguous batch.

This overwrites grads with the mean parameter gradients over the batch.

Inputs:

  • inputs: flat row-major with shape (batch_size, input_dim)
  • scratch: activations from forward_batch
  • d_outputs: flat row-major upstream gradients with shape (batch_size, output_dim)
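
Mean gradients over a batch are the per-sample gradients summed and divided by batch_size. A minimal sketch for one hypothetical linear layer (linear_backward_batch_mean, not this crate's API; row-major weights and no activation assumed), reading inputs and d_outputs in the flat row-major layout listed above:

```rust
/// Mean parameter gradients of a hypothetical linear layer over a flat
/// row-major batch; `d_w` and `d_b` are overwritten, not accumulated.
fn linear_backward_batch_mean(
    inputs: &[f32],    // (batch_size, in_dim)
    d_outputs: &[f32], // (batch_size, out_dim)
    in_dim: usize,
    out_dim: usize,
    d_w: &mut [f32],   // (out_dim, in_dim)
    d_b: &mut [f32],   // (out_dim)
) {
    assert!(in_dim > 0 && out_dim > 0, "dimensions must be non-zero");
    let batch_size = inputs.len() / in_dim;
    assert!(batch_size > 0, "empty batch");
    assert_eq!(inputs.len(), batch_size * in_dim);
    assert_eq!(d_outputs.len(), batch_size * out_dim);
    assert_eq!(d_w.len(), out_dim * in_dim);
    assert_eq!(d_b.len(), out_dim);
    d_w.fill(0.0);
    d_b.fill(0.0);
    // Sum the per-sample gradients row by row.
    for (x, d_y) in inputs.chunks_exact(in_dim).zip(d_outputs.chunks_exact(out_dim)) {
        for o in 0..out_dim {
            d_b[o] += d_y[o];
            for i in 0..in_dim {
                d_w[o * in_dim + i] += d_y[o] * x[i];
            }
        }
    }
    // Turn the sums into means.
    let scale = 1.0 / batch_size as f32;
    d_w.iter_mut().chain(d_b.iter_mut()).for_each(|g| *g *= scale);
}
```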

pub fn sgd_step(&mut self, grads: &Gradients, lr: f32)

Applies an SGD update to all layers.
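
SGD moves each parameter against its gradient: p = p - lr * g. A sketch on flat buffers (sgd_update and fit_scalar are hypothetical helpers, not this crate's API), followed by a toy one-parameter fit that combines a squared-error gradient with repeated steps:

```rust
/// Plain SGD on flat buffers: p = p - lr * g, applied elementwise.
fn sgd_update(params: &mut [f32], grads: &[f32], lr: f32) {
    assert_eq!(params.len(), grads.len(), "gradient buffer built for a different model");
    for (p, g) in params.iter_mut().zip(grads) {
        *p -= lr * g;
    }
}

/// Toy usage (hypothetical): fit y = w * x by per-sample SGD on the squared
/// error (w * x - y)^2, whose gradient with respect to w is 2 * (w * x - y) * x.
fn fit_scalar(data: &[(f32, f32)], lr: f32, epochs: usize) -> f32 {
    let mut w = [0.0f32];
    for _ in 0..epochs {
        for &(x, y) in data {
            let grad = [2.0 * (w[0] * x - y) * x];
            sgd_update(&mut w, &grad, lr);
        }
    }
    w[0]
}
```

With lr = 0.05 and 200 epochs on the points (1, 2) and (2, 4), the toy fit converges to w close to 2.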

pub fn predict_into(&self, input: &[f32], scratch: &mut Scratch, out: &mut [f32]) -> Result<()>

Shape-safe, non-allocating inference.

This validates shapes and returns Result instead of panicking. Internally it uses the low-level forward hot path.
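
The checked-wrapper pattern can be sketched as follows. ShapeError and the Affine model are hypothetical stand-ins (the crate's own error type is not shown on this page); the point is that both buffer lengths are validated up front and reported as Err before the unchecked hot path runs, and nothing is allocated.

```rust
use std::fmt;

/// Hypothetical error type standing in for the crate's error.
#[derive(Debug, PartialEq)]
enum ShapeError {
    Input { expected: usize, got: usize },
    Output { expected: usize, got: usize },
}

impl fmt::Display for ShapeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ShapeError::Input { expected, got } => write!(f, "input has {} values, expected {}", got, expected),
            ShapeError::Output { expected, got } => write!(f, "output buffer has {} slots, expected {}", got, expected),
        }
    }
}

impl std::error::Error for ShapeError {}

/// Hypothetical stand-in model: an elementwise affine map y = scale * x + shift.
struct Affine {
    scale: Vec<f32>,
    shift: Vec<f32>,
}

impl Affine {
    fn input_dim(&self) -> usize { self.scale.len() }
    fn output_dim(&self) -> usize { self.shift.len() }

    /// Unchecked hot path: assumes the caller already validated shapes.
    fn forward_unchecked(&self, input: &[f32], out: &mut [f32]) {
        for ((y, x), (s, b)) in out.iter_mut().zip(input).zip(self.scale.iter().zip(&self.shift)) {
            *y = s * x + b;
        }
    }

    /// Checked wrapper: validates both lengths and returns Err instead of
    /// panicking; it writes into `out` and never allocates.
    fn predict_into(&self, input: &[f32], out: &mut [f32]) -> Result<(), ShapeError> {
        if input.len() != self.input_dim() {
            return Err(ShapeError::Input { expected: self.input_dim(), got: input.len() });
        }
        if out.len() != self.output_dim() {
            return Err(ShapeError::Output { expected: self.output_dim(), got: out.len() });
        }
        self.forward_unchecked(input, out);
        Ok(())
    }
}
```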

pub fn predict_one_into(&self, input: &[f32], scratch: &mut Scratch, out: &mut [f32]) -> Result<()>

Shape-safe, non-allocating inference for a single input.

Alias of Mlp::predict_into.

impl Mlp

pub fn evaluate(&self, data: &Dataset, loss_fn: Loss, metrics: &[Metric]) -> Result<EvalReport>

Evaluate a dataset with a loss and optional metrics.

pub fn fit(&mut self, train: &Dataset, val: Option<&Dataset>, cfg: FitConfig) -> Result<FitReport>

Train the model on a dataset.

This is a “batteries included” API intended to be easy to use. Internally it still uses allocation-free forward/backward via scratch buffers.

pub fn predict(&self, data: &Dataset) -> Result<Vec<f32>>

Predict outputs for all inputs in data.

Returns a flat buffer with shape (len, output_dim).
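
Because the result is flat row-major, row i occupies out[i * output_dim..(i + 1) * output_dim]. A small sketch for consuming it; rows and argmax_rows are hypothetical helpers, and argmax only makes sense if the outputs are per-class scores, which depends on the model.

```rust
/// Splits a flat row-major buffer of shape (len, output_dim) into one
/// borrowed slice per sample, without copying.
fn rows(flat: &[f32], output_dim: usize) -> Vec<&[f32]> {
    assert!(output_dim > 0, "output_dim must be non-zero");
    assert_eq!(flat.len() % output_dim, 0, "buffer is not a whole number of rows");
    flat.chunks_exact(output_dim).collect()
}

/// Example post-processing, assuming the outputs are per-class scores:
/// the index of the largest value in each row (the first index wins ties).
fn argmax_rows(flat: &[f32], output_dim: usize) -> Vec<usize> {
    rows(flat, output_dim)
        .iter()
        .map(|row| {
            row.iter()
                .enumerate()
                .fold(0, |best, (i, &v)| if v > row[best] { i } else { best })
        })
        .collect()
}
```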

pub fn predict_inputs(&self, inputs: &Inputs) -> Result<Vec<f32>>

Predict outputs for inputs (X).

Returns a flat buffer with shape (len, output_dim).

pub fn evaluate_mse(&self, data: &Dataset) -> Result<f32>

Evaluate the mean squared error (MSE) over a dataset.

This is a convenience wrapper around evaluate.
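
One plausible reading of the averaging, sketched with a hypothetical mean_mse helper over flat row-major buffers: take each sample's mean squared error across its output_dim values, then average over samples. With a fixed output_dim this equals the mean over all elements. The sketch returns Option for brevity, whereas evaluate_mse returns Result<f32>.

```rust
/// Mean squared error over flat row-major predictions and targets of shape
/// (len, output_dim): the average of per-sample MSEs. Returns None on empty
/// or mismatched buffers.
fn mean_mse(preds: &[f32], targets: &[f32], output_dim: usize) -> Option<f32> {
    if output_dim == 0
        || preds.is_empty()
        || preds.len() != targets.len()
        || preds.len() % output_dim != 0
    {
        return None;
    }
    let per_sample: Vec<f32> = preds
        .chunks_exact(output_dim)
        .zip(targets.chunks_exact(output_dim))
        .map(|(p, t)| {
            // Sum of squared errors for one sample, divided by its width.
            let sse: f32 = p.iter().zip(t).map(|(a, b)| (a - b) * (a - b)).sum();
            sse / output_dim as f32
        })
        .collect();
    Some(per_sample.iter().sum::<f32>() / per_sample.len() as f32)
}
```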

impl Mlp

pub fn to_json_string_pretty(&self) -> Result<String>

Serialize the model to a pretty-printed JSON string.

pub fn to_json_string(&self) -> Result<String>

Serialize the model to a compact JSON string.

pub fn from_json_str(s: &str) -> Result<Self>

Parse a model from a JSON string.

pub fn save_json<P: AsRef<Path>>(&self, path: P) -> Result<()>

Save the model to a JSON file (pretty-printed).

pub fn load_json<P: AsRef<Path>>(path: P) -> Result<Self>

Load a model from a JSON file.

Trait Implementations

impl Clone for Mlp

fn clone(&self) -> Mlp

Returns a duplicate of the value.

fn clone_from(&mut self, source: &Self)

Performs copy-assignment from source.

impl Debug for Mlp

fn fmt(&self, f: &mut Formatter<'_>) -> Result

Formats the value using the given formatter.

impl From<&Mlp> for SerializedMlp

fn from(model: &Mlp) -> Self

Converts to this type from the input type.

impl TryFrom<SerializedMlp> for Mlp

type Error = Error

The type returned in the event of a conversion error.

fn try_from(value: SerializedMlp) -> Result<Self, Self::Error>

Performs the conversion.
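
The From/TryFrom pair reflects an asymmetry in the trait signatures: converting a model into plain serialized data cannot fail, while building a model from data (for example, hand-edited JSON) can. A sketch of that pattern with hypothetical stand-ins for Mlp and SerializedMlp; the checks shown (buffer lengths and layer-to-layer dimensions) are assumptions about what such validation involves, not this crate's actual rules.

```rust
use std::convert::TryFrom;

/// Hypothetical plain-data layer, shaped like something a JSON format could store.
#[derive(Debug, Clone, PartialEq)]
struct SerializedLayer {
    in_dim: usize,
    out_dim: usize,
    weights: Vec<f32>, // row-major (out_dim, in_dim)
    biases: Vec<f32>,
}

/// Hypothetical stand-in for SerializedMlp: public, unvalidated data.
#[derive(Debug, Clone, PartialEq)]
struct SerializedModel {
    layers: Vec<SerializedLayer>,
}

/// Hypothetical stand-in for Mlp: only constructed through validation.
#[derive(Debug)]
struct Model {
    layers: Vec<SerializedLayer>,
}

// Model -> data never fails, so it is a plain From.
impl<'a> From<&'a Model> for SerializedModel {
    fn from(model: &'a Model) -> Self {
        SerializedModel { layers: model.layers.clone() }
    }
}

// Data -> Model can fail, so it is a TryFrom that checks buffer lengths
// and that consecutive layers connect.
impl TryFrom<SerializedModel> for Model {
    type Error = String;

    fn try_from(value: SerializedModel) -> Result<Self, Self::Error> {
        if value.layers.is_empty() {
            return Err("model has no layers".to_string());
        }
        for (i, l) in value.layers.iter().enumerate() {
            if l.weights.len() != l.in_dim * l.out_dim {
                return Err(format!("layer {}: {} weights, expected {}", i, l.weights.len(), l.in_dim * l.out_dim));
            }
            if l.biases.len() != l.out_dim {
                return Err(format!("layer {}: {} biases, expected {}", i, l.biases.len(), l.out_dim));
            }
        }
        for (i, pair) in value.layers.windows(2).enumerate() {
            if pair[0].out_dim != pair[1].in_dim {
                return Err(format!(
                    "layer {} outputs {} values but layer {} expects {}",
                    i, pair[0].out_dim, i + 1, pair[1].in_dim
                ));
            }
        }
        Ok(Model { layers: value.layers })
    }
}
```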

Auto Trait Implementations

impl Freeze for Mlp

impl RefUnwindSafe for Mlp

impl Send for Mlp

impl Sync for Mlp

impl Unpin for Mlp

impl UnwindSafe for Mlp

Blanket Implementations

impl<T> Any for T
where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T
where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T
where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> CloneToUninit for T
where T: Clone,

unsafe fn clone_to_uninit(&self, dest: *mut u8)

🔬 This is a nightly-only experimental API. (clone_to_uninit)

Performs copy-assignment from self to dest.

impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T, U> Into<U> for T
where U: From<T>,

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T> ToOwned for T
where T: Clone,

type Owned = T

The resulting type after obtaining ownership.

fn to_owned(&self) -> T

Creates owned data from borrowed data, usually by cloning.

fn clone_into(&self, target: &mut T)

Uses borrowed data to replace owned data, usually by cloning.

impl<T, U> TryFrom<U> for T
where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

fn vzip(self) -> V