// rust_mlp/builder.rs
//! Model builder.
//!
//! `MlpBuilder` is the recommended way to define a model.
//!
//! It makes model structure explicit (layer sizes + activations) and chooses a
//! reasonable default weight initializer for each activation:
//!
//! - `tanh` / `sigmoid` / `identity`: Xavier/Glorot
//! - `relu` / `leaky relu`: He/Kaiming
//!
//! The resulting `Mlp` still supports the low-level, allocation-free hot path:
//! reuse `Scratch` / `Gradients` for per-sample forward/backward.
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};

use crate::{Activation, Error, Init, Layer, Mlp, Result};
18
/// Internal, pre-build description of one dense layer.
///
/// Holds only the configuration (width + activation); weights are not
/// allocated until `MlpBuilder::build_with_rng` runs.
#[derive(Debug, Clone, Copy)]
struct LayerSpec {
    // Number of outputs (neurons) in this layer; `add_layer` rejects 0.
    out_dim: usize,
    // Nonlinearity applied after the affine transform.
    activation: Activation,
}
24
25#[derive(Debug, Clone)]
26/// Builder for an `Mlp`.
27///
28/// Example:
29///
30/// ```rust
31/// use rust_mlp::{Activation, MlpBuilder};
32///
33/// # fn main() -> rust_mlp::Result<()> {
34/// let mlp = MlpBuilder::new(2)?
35///     .add_layer(8, Activation::ReLU)?
36///     .add_layer(1, Activation::Sigmoid)?
37///     .build_with_seed(0)?;
38/// # Ok(())
39/// # }
40/// ```
41pub struct MlpBuilder {
42    input_dim: usize,
43    layers: Vec<LayerSpec>,
44}
45
46impl MlpBuilder {
47    /// Start building an MLP that accepts inputs of length `input_dim`.
48    pub fn new(input_dim: usize) -> Result<Self> {
49        if input_dim == 0 {
50            return Err(Error::InvalidConfig("input_dim must be > 0".to_owned()));
51        }
52        Ok(Self {
53            input_dim,
54            layers: Vec::new(),
55        })
56    }
57
58    /// Convenience constructor from a sizes list + activations.
59    ///
60    /// `sizes` includes input and output dimensions, so its length must be at least 2.
61    /// `activations` must have length `sizes.len() - 1`.
62    pub fn from_sizes(sizes: &[usize], activations: &[Activation]) -> Result<Self> {
63        if sizes.len() < 2 {
64            return Err(Error::InvalidConfig(
65                "sizes must include input and output dims".to_owned(),
66            ));
67        }
68        if sizes.contains(&0) {
69            return Err(Error::InvalidConfig(
70                "all layer sizes must be > 0".to_owned(),
71            ));
72        }
73        if activations.len() != sizes.len() - 1 {
74            return Err(Error::InvalidConfig(format!(
75                "activations length {} does not match sizes.len() - 1 ({})",
76                activations.len(),
77                sizes.len() - 1
78            )));
79        }
80
81        let mut b = Self::new(sizes[0])?;
82        for (out_dim, &act) in sizes[1..].iter().zip(activations) {
83            b = b.add_layer(*out_dim, act)?;
84        }
85        Ok(b)
86    }
87
88    /// Add a dense layer.
89    ///
90    /// The layer will have `out_dim` outputs and uses `activation`.
91    pub fn add_layer(mut self, out_dim: usize, activation: Activation) -> Result<Self> {
92        if out_dim == 0 {
93            return Err(Error::InvalidConfig("layer out_dim must be > 0".to_owned()));
94        }
95        activation.validate()?;
96
97        self.layers.push(LayerSpec {
98            out_dim,
99            activation,
100        });
101        Ok(self)
102    }
103
104    /// Build using a deterministic seed.
105    pub fn build_with_seed(self, seed: u64) -> Result<Mlp> {
106        let mut rng = StdRng::seed_from_u64(seed);
107        self.build_with_rng(&mut rng)
108    }
109
110    /// Build using the provided RNG.
111    pub fn build_with_rng<R: Rng + ?Sized>(self, rng: &mut R) -> Result<Mlp> {
112        if self.layers.is_empty() {
113            return Err(Error::InvalidConfig(
114                "mlp must have at least one layer".to_owned(),
115            ));
116        }
117
118        let mut layers = Vec::with_capacity(self.layers.len());
119        let mut in_dim = self.input_dim;
120        for spec in self.layers {
121            let init = default_init_for_activation(spec.activation);
122            let layer = Layer::new_with_rng(in_dim, spec.out_dim, init, spec.activation, rng)?;
123            layers.push(layer);
124            in_dim = spec.out_dim;
125        }
126
127        Ok(Mlp::from_layers(layers))
128    }
129}
130
131#[inline]
132fn default_init_for_activation(act: Activation) -> Init {
133    match act {
134        Activation::Tanh | Activation::Sigmoid | Activation::Identity => Init::Xavier,
135        Activation::ReLU | Activation::LeakyReLU { .. } => Init::He,
136    }
137}