1use rand::rngs::StdRng;
15use rand::{Rng, SeedableRng};
16
17use crate::{Activation, Error, Init, Layer, Mlp, Result};
18
/// Specification for one layer to be constructed later: its output width and
/// the activation function that will be attached to it.
#[derive(Debug, Clone, Copy)]
struct LayerSpec {
    // Number of output features the layer produces (must be > 0; enforced in `add_layer`).
    out_dim: usize,
    // Activation passed through to `Layer::new_with_rng` when the network is built.
    activation: Activation,
}
24
/// Incremental builder for an [`Mlp`]: fix the input width up front, append
/// layers one at a time with [`MlpBuilder::add_layer`], then construct the
/// network with a seed or a caller-supplied RNG.
#[derive(Debug, Clone)]
pub struct MlpBuilder {
    // Feature count of the network input; validated to be > 0 in `new`.
    input_dim: usize,
    // Ordered layer specifications, first hidden layer to output layer.
    layers: Vec<LayerSpec>,
}
45
46impl MlpBuilder {
47 pub fn new(input_dim: usize) -> Result<Self> {
49 if input_dim == 0 {
50 return Err(Error::InvalidConfig("input_dim must be > 0".to_owned()));
51 }
52 Ok(Self {
53 input_dim,
54 layers: Vec::new(),
55 })
56 }
57
58 pub fn from_sizes(sizes: &[usize], activations: &[Activation]) -> Result<Self> {
63 if sizes.len() < 2 {
64 return Err(Error::InvalidConfig(
65 "sizes must include input and output dims".to_owned(),
66 ));
67 }
68 if sizes.contains(&0) {
69 return Err(Error::InvalidConfig(
70 "all layer sizes must be > 0".to_owned(),
71 ));
72 }
73 if activations.len() != sizes.len() - 1 {
74 return Err(Error::InvalidConfig(format!(
75 "activations length {} does not match sizes.len() - 1 ({})",
76 activations.len(),
77 sizes.len() - 1
78 )));
79 }
80
81 let mut b = Self::new(sizes[0])?;
82 for (out_dim, &act) in sizes[1..].iter().zip(activations) {
83 b = b.add_layer(*out_dim, act)?;
84 }
85 Ok(b)
86 }
87
88 pub fn add_layer(mut self, out_dim: usize, activation: Activation) -> Result<Self> {
92 if out_dim == 0 {
93 return Err(Error::InvalidConfig("layer out_dim must be > 0".to_owned()));
94 }
95 activation.validate()?;
96
97 self.layers.push(LayerSpec {
98 out_dim,
99 activation,
100 });
101 Ok(self)
102 }
103
104 pub fn build_with_seed(self, seed: u64) -> Result<Mlp> {
106 let mut rng = StdRng::seed_from_u64(seed);
107 self.build_with_rng(&mut rng)
108 }
109
110 pub fn build_with_rng<R: Rng + ?Sized>(self, rng: &mut R) -> Result<Mlp> {
112 if self.layers.is_empty() {
113 return Err(Error::InvalidConfig(
114 "mlp must have at least one layer".to_owned(),
115 ));
116 }
117
118 let mut layers = Vec::with_capacity(self.layers.len());
119 let mut in_dim = self.input_dim;
120 for spec in self.layers {
121 let init = default_init_for_activation(spec.activation);
122 let layer = Layer::new_with_rng(in_dim, spec.out_dim, init, spec.activation, rng)?;
123 layers.push(layer);
124 in_dim = spec.out_dim;
125 }
126
127 Ok(Mlp::from_layers(layers))
128 }
129}
130
131#[inline]
132fn default_init_for_activation(act: Activation) -> Init {
133 match act {
134 Activation::Tanh | Activation::Sigmoid | Activation::Identity => Init::Xavier,
135 Activation::ReLU | Activation::LeakyReLU { .. } => Init::He,
136 }
137}