tiny_recursive_rs/layers/
activations.rs

//! Activation functions for TRM
//!
//! Based on the Python implementation in layers.py

use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Init, VarBuilder};

/// Helper function to find the smallest multiple of b that is >= a
fn find_multiple(a: usize, b: usize) -> usize {
    ((a + b - 1) / b) * b
}

/// Linear layer with automatic dtype casting
///
/// Casts weights and bias to the input dtype before computation.
/// The Python version uses truncated normal initialization; here it is
/// approximated with candle-nn's default Kaiming normal init.
pub struct CastedLinear {
    weight: Tensor,
    bias: Option<Tensor>,
}

impl CastedLinear {
    /// Create new CastedLinear layer
    ///
    /// # Arguments
    /// * `in_features` - Input dimension
    /// * `out_features` - Output dimension
    /// * `bias` - Whether to include bias
    /// * `vb` - VarBuilder for parameter initialization
    pub fn new(
        in_features: usize,
        out_features: usize,
        bias: bool,
        vb: VarBuilder,
    ) -> Result<Self> {
        // Use Kaiming Normal initialization for weights (like candle-nn's Linear)
        let init_ws = candle_nn::init::DEFAULT_KAIMING_NORMAL;
        let weight = vb.get_with_hints((out_features, in_features), "weight", init_ws)?;

        let bias = if bias {
            // Use uniform initialization for bias (like candle-nn's Linear)
            let bound = 1. / (in_features as f64).sqrt();
            let init_bs = Init::Uniform { lo: -bound, up: bound };
            Some(vb.get_with_hints(out_features, "bias", init_bs)?)
        } else {
            None
        };

        Ok(Self { weight, bias })
    }

    /// Forward pass with automatic dtype casting
    pub fn forward(&self, input: &Tensor) -> Result<Tensor> {
        let input_dtype = input.dtype();

        // Cast weight to input dtype
        let weight = if self.weight.dtype() != input_dtype {
            self.weight.to_dtype(input_dtype)?
        } else {
            self.weight.clone()
        };

        // Perform linear transformation: input @ weight^T
        // weight is [out_features, in_features], so weight^T is [in_features, out_features]
        let weight_t = weight.t()?;
        let output = input.broadcast_matmul(&weight_t)?;

        // Add bias if present
        if let Some(ref b) = self.bias {
            let bias = if b.dtype() != input_dtype {
                b.to_dtype(input_dtype)?
            } else {
                b.clone()
            };
            output.broadcast_add(&bias)
        } else {
            Ok(output)
        }
    }
}

/// SwiGLU activation: Swish-Gated Linear Unit
///
/// A gated activation function that combines SiLU (Swish) with gating.
/// Formula: down_proj(silu(gate) * up)
pub struct SwiGLU {
    gate_up_proj: CastedLinear,
    down_proj: CastedLinear,
}

impl SwiGLU {
    /// Create new SwiGLU layer
    ///
    /// # Arguments
    /// * `hidden_size` - Input/output dimension
    /// * `expansion` - Expansion factor for intermediate dimension
    /// * `vb` - VarBuilder for parameter initialization
    pub fn new(hidden_size: usize, expansion: f32, vb: VarBuilder) -> Result<Self> {
        // Calculate intermediate size and round up to a multiple of 256
        let inter = find_multiple(
            (expansion * hidden_size as f32 * 2.0 / 3.0).round() as usize,
            256,
        );

        let gate_up_proj = CastedLinear::new(
            hidden_size,
            inter * 2,
            false,
            vb.pp("gate_up_proj"),
        )?;

        let down_proj = CastedLinear::new(
            inter,
            hidden_size,
            false,
            vb.pp("down_proj"),
        )?;

        Ok(Self {
            gate_up_proj,
            down_proj,
        })
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // Project to 2x intermediate size
        let gate_up = self.gate_up_proj.forward(x)?;

        // Split into gate and up
        let last_dim = gate_up.dims().len() - 1;
        let inter_size = gate_up.dim(last_dim)? / 2;

        let gate = gate_up.narrow(last_dim, 0, inter_size)?;
        let up = gate_up.narrow(last_dim, inter_size, inter_size)?;

        // Apply SiLU to gate and multiply with up
        let gate_activated = candle_nn::ops::silu(&gate)?;
        let gated = gate_activated.mul(&up)?;

        // Project back down
        self.down_proj.forward(&gated)
    }
}

/// LinearSwish activation
///
/// Combines a linear transformation with the SiLU (Swish) activation.
/// Can apply them in either order based on the `reverse` flag.
pub struct LinearSwish {
    linear: CastedLinear,
    reverse: bool,
}

impl LinearSwish {
    /// Create new LinearSwish layer
    ///
    /// # Arguments
    /// * `hidden_size` - Input/output dimension
    /// * `reverse` - If true: SiLU(Linear(x)), if false: Linear(SiLU(x))
    /// * `vb` - VarBuilder for parameter initialization
    pub fn new(hidden_size: usize, reverse: bool, vb: VarBuilder) -> Result<Self> {
        let linear = CastedLinear::new(
            hidden_size,
            hidden_size,
            false,
            vb.pp("linear"),
        )?;

        Ok(Self { linear, reverse })
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        if self.reverse {
            // SiLU(Linear(x))
            let linear_out = self.linear.forward(x)?;
            candle_nn::ops::silu(&linear_out)
        } else {
            // Linear(SiLU(x))
            let silu_out = candle_nn::ops::silu(x)?;
            self.linear.forward(&silu_out)
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use candle_nn::VarMap;

    #[test]
    fn test_find_multiple() {
        assert_eq!(find_multiple(100, 256), 256);
        assert_eq!(find_multiple(300, 256), 512);
        assert_eq!(find_multiple(256, 256), 256);
        assert_eq!(find_multiple(1, 256), 256);
    }

    #[test]
    fn test_casted_linear_shape() -> Result<()> {
        let device = Device::Cpu;
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);

        let linear = CastedLinear::new(64, 128, true, vb)?;

        let x = Tensor::randn(0f32, 1.0, (2, 16, 64), &device)?;
        let out = linear.forward(&x)?;

        assert_eq!(out.dims(), &[2, 16, 128]);

        Ok(())
    }

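    // Dtype-casting sketch for CastedLinear: the weights are created in F32 but
    // the input is F64, so the forward pass should cast and return an F64 tensor.
    // Assumes candle's CPU backend supports F64 matmul.
    #[test]
    fn test_casted_linear_dtype_casting() -> Result<()> {
        let device = Device::Cpu;
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);

        let linear = CastedLinear::new(8, 4, true, vb)?;

        // F64 input against F32 parameters
        let x = Tensor::randn(0f64, 1.0, (2, 8), &device)?;
        let out = linear.forward(&x)?;

        assert_eq!(out.dtype(), DType::F64);
        assert_eq!(out.dims(), &[2, 4]);

        Ok(())
    }
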
    #[test]
    fn test_swiglu_shape() -> Result<()> {
        let device = Device::Cpu;
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);

        let swiglu = SwiGLU::new(256, 4.0, vb)?;

        let x = Tensor::randn(0f32, 1.0, (2, 16, 256), &device)?;
        let out = swiglu.forward(&x)?;

        // Output should have same shape as input
        assert_eq!(out.dims(), x.dims());

        Ok(())
    }

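    // Intermediate-size sketch for SwiGLU: with hidden_size = 256 and
    // expansion = 4.0, round(4.0 * 256 * 2 / 3) = 683, which find_multiple
    // rounds up to 768, so gate_up_proj projects 256 -> 1536 and down_proj
    // projects 768 -> 256. The test peeks at the private weight tensors, which
    // a child test module is allowed to do.
    #[test]
    fn test_swiglu_intermediate_size() -> Result<()> {
        let device = Device::Cpu;
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);

        let swiglu = SwiGLU::new(256, 4.0, vb)?;

        // CastedLinear stores weights as [out_features, in_features]
        assert_eq!(swiglu.gate_up_proj.weight.dims(), &[1536, 256]);
        assert_eq!(swiglu.down_proj.weight.dims(), &[256, 768]);

        Ok(())
    }
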
    #[test]
    fn test_linear_swish_shape() -> Result<()> {
        let device = Device::Cpu;
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);

        let lin_swish = LinearSwish::new(256, false, vb)?;

        let x = Tensor::randn(0f32, 1.0, (2, 16, 256), &device)?;
        let out = lin_swish.forward(&x)?;

        // Output should have same shape as input
        assert_eq!(out.dims(), x.dims());

        Ok(())
    }
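
    // Ordering sketch for LinearSwish: with reverse = true, forward should match
    // silu(linear(x)). Recomputes the expected output through the layer's own
    // CastedLinear and compares element-wise with a small tolerance.
    #[test]
    fn test_linear_swish_reverse_order() -> Result<()> {
        let device = Device::Cpu;
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);

        let lin_swish = LinearSwish::new(32, true, vb)?;

        let x = Tensor::randn(0f32, 1.0, (2, 4, 32), &device)?;
        let out = lin_swish.forward(&x)?;
        let expected = candle_nn::ops::silu(&lin_swish.linear.forward(&x)?)?;

        let out_v = out.flatten_all()?.to_vec1::<f32>()?;
        let exp_v = expected.flatten_all()?.to_vec1::<f32>()?;
        for (a, b) in out_v.iter().zip(exp_v.iter()) {
            assert!((a - b).abs() < 1e-6);
        }

        Ok(())
    }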
}