tiny_recursive_rs/layers/activations.rs

use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Init, VarBuilder};
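
/// Round `a` up to the nearest multiple of `b`.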
fn find_multiple(a: usize, b: usize) -> usize {
    ((a + b - 1) / b) * b
}
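
/// Linear layer that casts its weight (and bias) to the input's dtype at
/// forward time, so parameters stored in one precision can be used with
/// activations in another.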
pub struct CastedLinear {
    weight: Tensor,
    bias: Option<Tensor>,
}

impl CastedLinear {
    pub fn new(
        in_features: usize,
        out_features: usize,
        bias: bool,
        vb: VarBuilder,
    ) -> Result<Self> {
        let init_ws = candle_nn::init::DEFAULT_KAIMING_NORMAL;
        let weight = vb.get_with_hints((out_features, in_features), "weight", init_ws)?;

        let bias = if bias {
            let bound = 1. / (in_features as f64).sqrt();
            let init_bs = Init::Uniform { lo: -bound, up: bound };
            Some(vb.get_with_hints(out_features, "bias", init_bs)?)
        } else {
            None
        };

        Ok(Self { weight, bias })
    }
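
    /// Compute `input @ weight^T` (plus bias if present), casting the
    /// parameters to the input's dtype first.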
    pub fn forward(&self, input: &Tensor) -> Result<Tensor> {
        let input_dtype = input.dtype();
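
        // Cast the weight only when its dtype differs from the input's.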
        let weight = if self.weight.dtype() != input_dtype {
            self.weight.to_dtype(input_dtype)?
        } else {
            self.weight.clone()
        };

        let weight_t = weight.t()?;
        let output = input.broadcast_matmul(&weight_t)?;

        if let Some(ref b) = self.bias {
            let bias = if b.dtype() != input_dtype {
                b.to_dtype(input_dtype)?
            } else {
                b.clone()
            };
            output.broadcast_add(&bias)
        } else {
            Ok(output)
        }
    }
}
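
/// SwiGLU feed-forward block: a fused gate/up projection, a SiLU-gated
/// elementwise product, and a down projection back to `hidden_size`.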
pub struct SwiGLU {
    gate_up_proj: CastedLinear,
    down_proj: CastedLinear,
}

impl SwiGLU {
    pub fn new(hidden_size: usize, expansion: f32, vb: VarBuilder) -> Result<Self> {
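        // Intermediate width: the expanded size scaled by 2/3 (standard SwiGLU
        // sizing), rounded up to a multiple of 256.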
        let inter = find_multiple(
            (expansion * hidden_size as f32 * 2.0 / 3.0).round() as usize,
            256,
        );

        let gate_up_proj = CastedLinear::new(
            hidden_size,
            inter * 2,
            false,
            vb.pp("gate_up_proj"),
        )?;

        let down_proj = CastedLinear::new(
            inter,
            hidden_size,
            false,
            vb.pp("down_proj"),
        )?;

        Ok(Self {
            gate_up_proj,
            down_proj,
        })
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let gate_up = self.gate_up_proj.forward(x)?;
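
        // The fused projection packs [gate | up]; split it in half along the last dimension.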
        let last_dim = gate_up.dims().len() - 1;
        let inter_size = gate_up.dim(last_dim)? / 2;

        let gate = gate_up.narrow(last_dim, 0, inter_size)?;
        let up = gate_up.narrow(last_dim, inter_size, inter_size)?;

        let gate_activated = candle_nn::ops::silu(&gate)?;
        let gated = gate_activated.mul(&up)?;

        self.down_proj.forward(&gated)
    }
}
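
/// A square linear projection paired with a SiLU (swish) activation.
/// When `reverse` is true the activation is applied after the projection;
/// otherwise it is applied before.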
pub struct LinearSwish {
    linear: CastedLinear,
    reverse: bool,
}

impl LinearSwish {
    pub fn new(hidden_size: usize, reverse: bool, vb: VarBuilder) -> Result<Self> {
        let linear = CastedLinear::new(
            hidden_size,
            hidden_size,
            false,
            vb.pp("linear"),
        )?;

        Ok(Self { linear, reverse })
    }
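
    /// Apply the projection and SiLU in the order selected by `reverse`.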
    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        if self.reverse {
            let linear_out = self.linear.forward(x)?;
            candle_nn::ops::silu(&linear_out)
        } else {
            let silu_out = candle_nn::ops::silu(x)?;
            self.linear.forward(&silu_out)
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use candle_nn::VarMap;

    #[test]
    fn test_find_multiple() {
        assert_eq!(find_multiple(100, 256), 256);
        assert_eq!(find_multiple(300, 256), 512);
        assert_eq!(find_multiple(256, 256), 256);
        assert_eq!(find_multiple(1, 256), 256);
    }

    #[test]
    fn test_casted_linear_shape() -> Result<()> {
        let device = Device::Cpu;
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);

        let linear = CastedLinear::new(64, 128, true, vb)?;

        let x = Tensor::randn(0f32, 1.0, (2, 16, 64), &device)?;
        let out = linear.forward(&x)?;

        assert_eq!(out.dims(), &[2, 16, 128]);

        Ok(())
    }

    #[test]
    fn test_swiglu_shape() -> Result<()> {
        let device = Device::Cpu;
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);

        let swiglu = SwiGLU::new(256, 4.0, vb)?;

        let x = Tensor::randn(0f32, 1.0, (2, 16, 256), &device)?;
        let out = swiglu.forward(&x)?;

        assert_eq!(out.dims(), x.dims());

        Ok(())
    }

    #[test]
    fn test_linear_swish_shape() -> Result<()> {
        let device = Device::Cpu;
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);

        let lin_swish = LinearSwish::new(256, false, vb)?;

        let x = Tensor::randn(0f32, 1.0, (2, 16, 256), &device)?;
        let out = lin_swish.forward(&x)?;

        assert_eq!(out.dims(), x.dims());

        Ok(())
    }
}