pub struct Mlp { /* private fields */ }
A feed-forward multi-layer perceptron composed of dense layers.
Implementations
impl Mlp
pub fn output_dim(&self) -> usize
Returns the output dimension of the model.
pub fn num_layers(&self) -> usize
Returns the number of layers.
pub fn layer(&self, idx: usize) -> Option<&Layer>
Returns a reference to a layer by index.
This is primarily useful for inspection and debugging.
pub fn scratch_batch(&self, batch_size: usize) -> BatchScratch
Allocate a BatchScratch buffer suitable for this model and a fixed batch size.
pub fn backprop_scratch_batch(&self, batch_size: usize) -> BatchBackpropScratch
Allocate a BatchBackpropScratch buffer suitable for this model and a fixed batch size.
pub fn forward<'a>(&self, input: &[f32], scratch: &'a mut Scratch) -> &'a [f32]
Forward pass for a single sample.
Writes intermediate activations into scratch and returns the final output slice.
Shape contract:
input.len() == self.input_dim()
scratch must be built for this Mlp (same layer count and output sizes)
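The pattern described above (write each layer's activations into a preallocated scratch buffer, then return a slice borrowed from it) can be sketched in standalone Rust. This is only an illustration under stated assumptions: `ToyLayer`, `ToyScratch`, and `toy_forward` are hypothetical names that are not part of this crate, the activation function is omitted, and the real `Mlp::forward` may be implemented differently.

```rust
/// Hypothetical dense layer: `weights` is row-major with shape (out_dim, in_dim).
struct ToyLayer {
    in_dim: usize,
    out_dim: usize,
    weights: Vec<f32>,
    biases: Vec<f32>,
}

/// Hypothetical scratch: one preallocated activation buffer per layer.
struct ToyScratch {
    activations: Vec<Vec<f32>>,
}

impl ToyScratch {
    /// Allocates once, sized for the given layers.
    fn for_layers(layers: &[ToyLayer]) -> Self {
        ToyScratch {
            activations: layers.iter().map(|l| vec![0.0; l.out_dim]).collect(),
        }
    }
}

/// Forward pass that performs no allocation and returns a slice into `scratch`.
/// Shape violations panic here; that behavior is an assumption of this sketch.
fn toy_forward<'a>(
    layers: &[ToyLayer],
    input: &[f32],
    scratch: &'a mut ToyScratch,
) -> &'a [f32] {
    assert!(!layers.is_empty(), "model needs at least one layer");
    assert_eq!(input.len(), layers[0].in_dim, "input length must equal input_dim");
    assert_eq!(scratch.activations.len(), layers.len(), "scratch built for a different model");

    for (i, layer) in layers.iter().enumerate() {
        // Split so the previous activation can be read while the current one is written.
        let (prev, rest) = scratch.activations.split_at_mut(i);
        let x: &[f32] = if i == 0 { input } else { prev[i - 1].as_slice() };
        let out = &mut rest[0];
        assert_eq!(x.len(), layer.in_dim, "layer input size mismatch");
        assert_eq!(out.len(), layer.out_dim, "scratch buffer size mismatch");

        for o in 0..layer.out_dim {
            let row = &layer.weights[o * layer.in_dim..(o + 1) * layer.in_dim];
            let dot: f32 = row.iter().zip(x).map(|(w, v)| w * v).sum();
            // Activation function omitted (identity) to keep the sketch short.
            out[o] = dot + layer.biases[o];
        }
    }
    scratch.activations.last().expect("checked non-empty above").as_slice()
}
```

Because the returned slice borrows from the scratch buffer, the caller has to finish reading the output before reusing the same scratch for the next sample.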
pub fn forward_batch<'a>(
    &self,
    inputs: &[f32],
    scratch: &'a mut BatchScratch,
) -> &'a [f32]
Forward pass for a contiguous batch.
Writes intermediate activations into scratch and returns the final output buffer
for the whole batch (flat row-major).
Shape contract:
inputs.len() == batch_size * self.input_dim()
scratch must be built for this Mlp and the same batch_size
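To make the flat row-major layout concrete: sample `i` of a buffer with shape `(batch_size, dim)` occupies the index range `i * dim .. (i + 1) * dim`. The helpers below are a minimal sketch of that convention; `pack_rows` and `row` are hypothetical names and are not provided by this crate.

```rust
/// Packs per-sample rows into one contiguous row-major buffer
/// with shape (rows.len(), dim).
fn pack_rows(rows: &[Vec<f32>], dim: usize) -> Vec<f32> {
    let mut flat = Vec::with_capacity(rows.len() * dim);
    for r in rows {
        assert_eq!(r.len(), dim, "every row must have length dim");
        flat.extend_from_slice(r);
    }
    flat
}

/// Returns sample `i` from a flat row-major buffer with shape (batch_size, dim).
fn row(flat: &[f32], dim: usize, i: usize) -> &[f32] {
    &flat[i * dim..(i + 1) * dim]
}
```

The same indexing applies to the returned output buffer, with `dim` equal to the output dimension.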
pub fn backward<'a>(
    &self,
    input: &[f32],
    scratch: &Scratch,
    grads: &'a mut Gradients,
) -> &'a [f32]
Backward pass for a single sample, using the internal d_output buffer.
You must call forward first using the same input and scratch.
Before calling this, write the upstream gradient dL/d(output) into
grads.d_output_mut().
Overwrite semantics:
grads is overwritten with gradients for this sample.
Returns dL/d(input).
pub fn backward_accumulate<'a>(
    &self,
    input: &[f32],
    scratch: &Scratch,
    grads: &'a mut Gradients,
) -> &'a [f32]
Backward pass for a single sample (parameter accumulation semantics).
This is identical to backward except that parameter gradients are accumulated:
grads.d_weights and grads.d_biases are accumulated into (+=)
grads.d_layer_outputs and grads.d_input are overwritten
This is useful for mini-batch training.
You must call forward first using the same input and scratch.
Before calling this, write the upstream gradient dL/d(output) into
grads.d_output_mut().
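The difference between overwrite and accumulate semantics can be shown on a single linear unit `y = w · x + b`. This is a standalone sketch, not this crate's code: `ToyGrads`, `toy_backward`, `toy_backward_accumulate`, and `toy_minibatch_mean` are hypothetical names. The zero-then-average steps in the mini-batch helper are likewise an assumption about how a caller might use accumulation, since the documentation above does not specify them.

```rust
/// Hypothetical gradient buffer for a single linear unit y = w . x + b.
struct ToyGrads {
    d_weights: Vec<f32>,
    d_bias: f32,
}

/// Overwrite semantics: after the call the buffer holds this sample's gradients only.
fn toy_backward(x: &[f32], d_output: f32, g: &mut ToyGrads) {
    for (gw, xi) in g.d_weights.iter_mut().zip(x) {
        *gw = d_output * xi;
    }
    g.d_bias = d_output;
}

/// Accumulate semantics: this sample's gradients are added (+=) to the buffer.
fn toy_backward_accumulate(x: &[f32], d_output: f32, g: &mut ToyGrads) {
    for (gw, xi) in g.d_weights.iter_mut().zip(x) {
        *gw += d_output * xi;
    }
    g.d_bias += d_output;
}

/// Mini-batch use of accumulation: start from zero, accumulate each sample,
/// then divide by the number of samples to obtain the mean gradient.
/// Each sample is a pair (input, upstream gradient dL/dy).
fn toy_minibatch_mean(samples: &[(Vec<f32>, f32)], dim: usize) -> ToyGrads {
    assert!(!samples.is_empty(), "need at least one sample");
    let mut g = ToyGrads { d_weights: vec![0.0; dim], d_bias: 0.0 };
    for (x, d_out) in samples {
        toy_backward_accumulate(x, *d_out, &mut g);
    }
    let n = samples.len() as f32;
    for gw in g.d_weights.iter_mut() {
        *gw /= n;
    }
    g.d_bias /= n;
    g
}
```

With overwrite semantics, calling the backward function twice leaves only the second sample's gradients in the buffer; with accumulate semantics, both contributions remain.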
pub fn backward_batch(
    &self,
    inputs: &[f32],
    scratch: &BatchScratch,
    d_outputs: &[f32],
    grads: &mut Gradients,
    backprop_scratch: &mut BatchBackpropScratch,
)
Backward pass for a contiguous batch.
This overwrites grads with the mean parameter gradients over the batch.
Inputs:
inputs: flat row-major with shape (batch_size, input_dim)
scratch: activations from forward_batch
d_outputs: flat row-major upstream gradients with shape (batch_size, output_dim)
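For intuition about "mean parameter gradients over the batch", the sketch below computes them for one dense layer from flat row-major buffers: the weight gradient is the batch average of the outer products of each upstream gradient row with the matching input row. `toy_mean_layer_grads` is a hypothetical helper, not this crate's API, and it covers a single layer with no activation; a full backward pass would also propagate gradients through earlier layers.

```rust
/// Mean weight and bias gradients of one dense layer over a batch.
/// `inputs`:    flat row-major, shape (batch_size, in_dim)
/// `d_outputs`: flat row-major, shape (batch_size, out_dim)
/// Returns (d_weights with shape (out_dim, in_dim), d_biases with length out_dim).
fn toy_mean_layer_grads(
    inputs: &[f32],
    d_outputs: &[f32],
    in_dim: usize,
    out_dim: usize,
) -> (Vec<f32>, Vec<f32>) {
    assert!(in_dim > 0 && out_dim > 0, "dimensions must be non-zero");
    assert_eq!(inputs.len() % in_dim, 0, "inputs length must be a multiple of in_dim");
    let batch_size = inputs.len() / in_dim;
    assert!(batch_size > 0, "batch must not be empty");
    assert_eq!(d_outputs.len(), batch_size * out_dim, "d_outputs shape mismatch");

    let mut d_w = vec![0.0f32; out_dim * in_dim];
    let mut d_b = vec![0.0f32; out_dim];

    // Sum per-sample gradients: d_w += dy (outer product) x, d_b += dy.
    for (x, dy) in inputs.chunks_exact(in_dim).zip(d_outputs.chunks_exact(out_dim)) {
        for o in 0..out_dim {
            d_b[o] += dy[o];
            for i in 0..in_dim {
                d_w[o * in_dim + i] += dy[o] * x[i];
            }
        }
    }

    // Turn the sums into means.
    let scale = 1.0 / batch_size as f32;
    d_w.iter_mut().for_each(|g| *g *= scale);
    d_b.iter_mut().for_each(|g| *g *= scale);
    (d_w, d_b)
}
```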
pub fn predict_into(
    &self,
    input: &[f32],
    scratch: &mut Scratch,
    out: &mut [f32],
) -> Result<()>
Shape-safe, non-allocating inference.
This validates shapes and returns Result instead of panicking.
Internally it uses the low-level forward hot path.
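The validate-then-delegate pattern described here can be sketched as follows: check every slice length up front and return an error, and only then run the inner loop that assumes valid shapes. This is an illustration under assumptions, not the crate's implementation; `ToyShapeError` and `toy_predict_into` are hypothetical, and the crate's actual error type behind `Result` is not shown in this section.

```rust
/// Hypothetical error type for shape mismatches.
#[derive(Debug, PartialEq)]
enum ToyShapeError {
    Input { expected: usize, got: usize },
    Output { expected: usize, got: usize },
}

/// Shape-checked inference for one dense layer with row-major `weights`
/// of shape (out_dim, in_dim). Writes into `out` without allocating.
fn toy_predict_into(
    weights: &[f32],
    in_dim: usize,
    out_dim: usize,
    input: &[f32],
    out: &mut [f32],
) -> Result<(), ToyShapeError> {
    // Caller-supplied shapes are validated and reported as errors, not panics.
    if input.len() != in_dim {
        return Err(ToyShapeError::Input { expected: in_dim, got: input.len() });
    }
    if out.len() != out_dim {
        return Err(ToyShapeError::Output { expected: out_dim, got: out.len() });
    }
    // Model invariant rather than caller input, so a debug assertion is used here.
    debug_assert_eq!(weights.len(), in_dim * out_dim);

    // Hot path: may assume all shapes are valid from here on.
    for (o, y) in out.iter_mut().enumerate() {
        let row = &weights[o * in_dim..(o + 1) * in_dim];
        *y = row.iter().zip(input).map(|(w, x)| w * x).sum::<f32>();
    }
    Ok(())
}
```

In this sketch a failed check returns before anything is written, so `out` is left untouched on error.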
pub fn predict_one_into(
    &self,
    input: &[f32],
    scratch: &mut Scratch,
    out: &mut [f32],
) -> Result<()>
Shape-safe, non-allocating inference for a single input.
Alias of Mlp::predict_into.
impl Mlp
pub fn evaluate(
    &self,
    data: &Dataset,
    loss_fn: Loss,
    metrics: &[Metric],
) -> Result<EvalReport>
Evaluate a dataset with a loss and optional metrics.
pub fn fit(
    &mut self,
    train: &Dataset,
    val: Option<&Dataset>,
    cfg: FitConfig,
) -> Result<FitReport>
Train the model on a dataset.
This is a “batteries included” API intended to be easy to use. Internally it still uses allocation-free forward/backward via scratch buffers.
pub fn predict(&self, data: &Dataset) -> Result<Vec<f32>>
Predict outputs for all inputs in data.
Returns a flat buffer with shape (len, output_dim).
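A caller consuming such a flat `(len, output_dim)` buffer will usually want to walk it one sample at a time. The sketch below does that with `chunks_exact` and picks the index of the largest value in each row, as one might for class scores. `toy_argmax_rows` is a hypothetical helper that is not part of this crate, and treating the outputs as class scores is an assumption made for the example.

```rust
/// Returns, for each row of a flat row-major buffer with shape (len, output_dim),
/// the index of the largest value in that row (first index wins on ties).
fn toy_argmax_rows(flat: &[f32], output_dim: usize) -> Vec<usize> {
    assert!(output_dim > 0, "output_dim must be non-zero");
    assert_eq!(flat.len() % output_dim, 0, "buffer length must be a multiple of output_dim");
    flat.chunks_exact(output_dim)
        .map(|row| {
            let mut best = 0;
            for (j, v) in row.iter().enumerate() {
                if *v > row[best] {
                    best = j;
                }
            }
            best
        })
        .collect()
}
```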
pub fn predict_inputs(&self, inputs: &Inputs) -> Result<Vec<f32>>
Predict outputs for inputs (X).
Returns a flat buffer with shape (len, output_dim).
pub fn evaluate_mse(&self, data: &Dataset) -> Result<f32>
Evaluate the mean squared error (MSE) over a dataset.
This is a convenience wrapper around evaluate.
impl Mlp
pub fn to_json_string_pretty(&self) -> Result<String>
Serialize the model to a pretty-printed JSON string.
pub fn to_json_string(&self) -> Result<String>
Serialize the model to a compact JSON string.
pub fn from_json_str(s: &str) -> Result<Self>
Parse a model from a JSON string.