tml_utils/conv.rs

use crate::network::{Initializer, Layer, LayerDims, Optimizer, XavierUniform};
use crate::{ConvGeometryIsValid, Float, tensor::Tensor};
use rand::{Rng, SeedableRng, rngs::StdRng};
use std::array;

#[doc(hidden)]
pub const fn conv_out_dim(input: usize, pad: usize, kernel: usize, stride: usize) -> usize {
    if stride == 0 {
        return 0;
    }
    let padded = input + 2 * pad;
    if padded < kernel {
        return 0;
    }
    let numer = padded - kernel;
    if !numer.is_multiple_of(stride) {
        return 0;
    }
    numer / stride + 1
}

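// A minimal compile-time sketch of conv_out_dim: out = (in + 2*pad - kernel) / stride + 1,
// with 0 flagging invalid geometry (zero stride, kernel larger than the padded input,
// or a stride that does not evenly divide the span). Parameters are illustrative.
const _: () = assert!(conv_out_dim(28, 0, 3, 1) == 26);
const _: () = assert!(conv_out_dim(28, 1, 3, 1) == 28); // "same" padding
const _: () = assert!(conv_out_dim(5, 0, 2, 2) == 0); // (5 - 2) is not divisible by 2
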
/// A single convolution kernel, with weights stored in `(H, W, D)` (height,
/// width, channel) order alongside the gradients accumulated during backward.
#[derive(Debug, Clone)]
pub struct Filter<const H: usize, const W: usize, const D: usize> {
    weights: Tensor<crate::shape!(H, W, D)>,
    grads: Box<[Float]>,
}

impl<const H: usize, const W: usize, const D: usize> Filter<H, W, D> {
    fn zeroed() -> Self {
        Self {
            weights: Tensor::<crate::shape!(H, W, D)>::from_boxed(
                vec![0.0 as Float; H * W * D].into_boxed_slice(),
            ),
            grads: vec![0.0 as Float; H * W * D].into_boxed_slice(),
        }
    }

    fn weights(&self) -> &[Float] {
        self.weights.raw_slice()
    }

    fn grads_mut(&mut self) -> &mut [Float] {
        &mut self.grads[..]
    }
}

/// A convolutional layer.
///
/// `IW` - input width
/// `IH` - input height
/// `IC` - number of input channels
/// `FH` - filter/kernel height
/// `FW` - filter/kernel width
/// `OC` - number of output channels (equivalently, number of kernels/filters)
/// `S` - stride
/// `P` - padding
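///
/// # Example
///
/// A minimal sketch of the intended call flow (marked `ignore` because the
/// crate relies on unstable const generics; the parameters are illustrative):
///
/// ```ignore
/// // 8x8 single-channel input, four 3x3 filters, stride 1, no padding.
/// let conv: Conv<8, 8, 1, 3, 3, 4, 1, 0> = Conv::seeded(42);
/// let input = conv.input_from_data([0.0; 1 * 8 * 8]);
/// let mut output = conv.create_output_space();
/// conv.forward(&input, &mut output); // output is a 4x6x6 tensor
/// ```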
#[derive(Debug)]
pub struct Conv<
    const IW: usize,
    const IH: usize,
    const IC: usize,
    const FH: usize,
    const FW: usize,
    const OC: usize,
    const S: usize,
    const P: usize,
> {
    filters: [Filter<FH, FW, IC>; OC],
    biases: Box<[Float; OC]>,
    bias_grads: Box<[Float; OC]>,
}

impl<
    const IW: usize,
    const IH: usize,
    const IC: usize,
    const FH: usize,
    const FW: usize,
    const OC: usize,
    const S: usize,
    const P: usize,
> Conv<IW, IH, IC, FH, FW, OC, S, P>
where
    [(); IC * IH * IW]:,
    [(); OC * conv_out_dim(IH, P, FH, S) * conv_out_dim(IW, P, FW, S)]:,
    (): ConvGeometryIsValid<IH, IW, FH, FW, S, P>,
{
    pub fn init() -> Self {
        Self::with_initializer(XavierUniform)
    }

    pub fn seeded(seed: u64) -> Self {
        Self::with_initializer_and_seed(XavierUniform, seed)
    }

    pub fn with_initializer<I: Initializer>(initializer: I) -> Self {
        let mut rng = rand::rng();
        Self::with_initializer_and_rng(initializer, &mut rng)
    }

    pub fn with_initializer_and_seed<I: Initializer>(initializer: I, seed: u64) -> Self {
        let mut rng = StdRng::seed_from_u64(seed);
        Self::with_initializer_and_rng(initializer, &mut rng)
    }

    pub fn with_initializer_and_rng<I: Initializer, R: Rng + ?Sized>(
        initializer: I,
        rng: &mut R,
    ) -> Self {
        let mut conv = Conv {
            filters: array::from_fn(|_| Filter::zeroed()),
            biases: Box::new([0.0 as Float; OC]),
            bias_grads: Box::new([0.0 as Float; OC]),
        };
        // Fan counts for Xavier-style initializers: receptive-field area times
        // input/output channel count.
        let fan_in = FH * FW * IC;
        let fan_out = FH * FW * OC;
        for filter in &mut conv.filters {
            initializer.fill(filter.weights.raw_mut_slice(), fan_in, fan_out, rng);
        }
        conv
    }

    pub fn create_output_space(&self) -> <Self as ConvIO>::Output {
        Tensor::<crate::shape!(
            OC,
            conv_out_dim(IH, P, FH, S),
            conv_out_dim(IW, P, FW, S)
        )>::from_boxed(
            vec![
                0.0 as Float;
                OC * conv_out_dim(IH, P, FH, S) * conv_out_dim(IW, P, FW, S)
            ]
            .into_boxed_slice(),
        )
    }

    pub fn input_from_data(&self, data: [Float; IC * IH * IW]) -> <Self as ConvIO>::Input {
        Tensor::<crate::shape!(IC, IH, IW)>::from_boxed(Vec::from(data).into_boxed_slice())
    }

    pub fn forward(
        &self,
        input: &Tensor<crate::shape!(IC, IH, IW)>,
        output: &mut Tensor<
            crate::shape!(OC, conv_out_dim(IH, P, FH, S), conv_out_dim(IW, P, FW, S)),
        >,
    ) {
        let input_arr: &[Float; IC * IH * IW] = input.raw_slice().try_into().expect("bad input");
        let output_arr: &mut [Float; OC * conv_out_dim(IH, P, FH, S) * conv_out_dim(IW, P, FW, S)] =
            output.raw_mut_slice().try_into().expect("bad output");
        self.forward_flat(input_arr, output_arr);
    }

    pub fn forward_flat(
        &self,
        input: &[Float; IC * IH * IW],
        output: &mut [Float; OC * conv_out_dim(IH, P, FH, S) * conv_out_dim(IW, P, FW, S)],
    ) {
        let out_h = conv_out_dim(IH, P, FH, S);
        let out_w = conv_out_dim(IW, P, FW, S);

        for oc in 0..OC {
            let filter_data = self.filters[oc].weights();

            for y in 0..out_h {
                for x in 0..out_w {
                    let mut sum = self.biases[oc];

                    for ky in 0..FH {
                        for kx in 0..FW {
                            for ic in 0..IC {
                                // Map the output coordinate back onto the input;
                                // signed arithmetic lets padded positions fall
                                // outside the input and be skipped (implicit zeros).
                                let in_y = y * S + ky;
                                let in_x = x * S + kx;
                                let in_y = in_y as isize - P as isize;
                                let in_x = in_x as isize - P as isize;

                                if in_y >= 0
                                    && in_y < IH as isize
                                    && in_x >= 0
                                    && in_x < IW as isize
                                {
                                    let in_y = in_y as usize;
                                    let in_x = in_x as usize;
                                    // Input/output are CHW; filters are HWC.
                                    let input_idx = ic * IH * IW + in_y * IW + in_x;
                                    let filter_idx = (ky * FW + kx) * IC + ic;
                                    sum += filter_data[filter_idx] * input[input_idx];
                                }
                            }
                        }
                    }

                    let output_idx = oc * out_h * out_w + y * out_w + x;
                    output[output_idx] = sum;
                }
            }
        }
    }

    pub fn backward_flat(
        &mut self,
        input: &[Float; IC * IH * IW],
        output_grad: &[Float; OC * conv_out_dim(IH, P, FH, S) * conv_out_dim(IW, P, FW, S)],
        input_grad: &mut [Float; IC * IH * IW],
    ) {
        let out_h = conv_out_dim(IH, P, FH, S);
        let out_w = conv_out_dim(IW, P, FW, S);

        // Input gradients are overwritten; filter and bias gradients accumulate
        // until `zero_grad` or `apply_gradients` clears them.
        input_grad.fill(0.0);

        for oc in 0..OC {
            let Filter { weights, grads } = &mut self.filters[oc];
            let filter_weights = weights.raw_slice();
            let filter_grads = &mut grads[..];

            for y in 0..out_h {
                for x in 0..out_w {
                    let output_idx = oc * out_h * out_w + y * out_w + x;
                    let grad = output_grad[output_idx];
                    self.bias_grads[oc] += grad;

                    for ky in 0..FH {
                        for kx in 0..FW {
                            for ic in 0..IC {
                                let in_y = y * S + ky;
                                let in_x = x * S + kx;
                                let in_y = in_y as isize - P as isize;
                                let in_x = in_x as isize - P as isize;

                                if in_y >= 0
                                    && in_y < IH as isize
                                    && in_x >= 0
                                    && in_x < IW as isize
                                {
                                    let in_y = in_y as usize;
                                    let in_x = in_x as usize;
                                    let input_idx = ic * IH * IW + in_y * IW + in_x;
                                    let filter_idx = (ky * FW + kx) * IC + ic;

                                    filter_grads[filter_idx] += grad * input[input_idx];
                                    input_grad[input_idx] += grad * filter_weights[filter_idx];
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}

impl<
    const IW: usize,
    const IH: usize,
    const IC: usize,
    const FH: usize,
    const FW: usize,
    const OC: usize,
    const S: usize,
    const P: usize,
> LayerDims for Conv<IW, IH, IC, FH, FW, OC, S, P>
where
    [(); IC * IH * IW]:,
    [(); OC * conv_out_dim(IH, P, FH, S) * conv_out_dim(IW, P, FW, S)]:,
    (): ConvGeometryIsValid<IH, IW, FH, FW, S, P>,
{
    const INPUT: usize = IC * IH * IW;
    const OUTPUT: usize = OC * conv_out_dim(IH, P, FH, S) * conv_out_dim(IW, P, FW, S);
}

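// A small compile-time sketch of what LayerDims exposes, for an illustrative
// geometry: 1x28x28 input, eight 3x3 filters, stride 1, pad 1 ("same" padding),
// so the spatial extent is preserved.
const _: () = {
    type C = Conv<28, 28, 1, 3, 3, 8, 1, 1>;
    assert!(<C as LayerDims>::INPUT == 1 * 28 * 28);
    assert!(<C as LayerDims>::OUTPUT == 8 * 28 * 28);
};
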
impl<
    const IW: usize,
    const IH: usize,
    const IC: usize,
    const FH: usize,
    const FW: usize,
    const OC: usize,
    const S: usize,
    const P: usize,
> Layer<{ IC * IH * IW }, { OC * conv_out_dim(IH, P, FH, S) * conv_out_dim(IW, P, FW, S) }>
    for Conv<IW, IH, IC, FH, FW, OC, S, P>
where
    [(); IC * IH * IW]:,
    [(); OC * conv_out_dim(IH, P, FH, S) * conv_out_dim(IW, P, FW, S)]:,
    (): ConvGeometryIsValid<IH, IW, FH, FW, S, P>,
{
    fn forward(
        &self,
        input: &[Float; IC * IH * IW],
        output: &mut [Float; OC * conv_out_dim(IH, P, FH, S) * conv_out_dim(IW, P, FW, S)],
    ) {
        self.forward_flat(input, output);
    }

    fn backward(
        &mut self,
        input: &[Float; IC * IH * IW],
        _output: &[Float; OC * conv_out_dim(IH, P, FH, S) * conv_out_dim(IW, P, FW, S)],
        output_grad: &[Float; OC * conv_out_dim(IH, P, FH, S) * conv_out_dim(IW, P, FW, S)],
        input_grad: &mut [Float; IC * IH * IW],
    ) {
        self.backward_flat(input, output_grad, input_grad);
    }

    fn zero_grad(&mut self) {
        self.bias_grads.fill(0.0);
        for filter in &mut self.filters {
            filter.grads_mut().fill(0.0);
        }
    }

    fn apply_gradients(&mut self, optimizer: &mut dyn Optimizer, slot: &mut usize, scale: Float) {
        // Each parameter tensor (one per filter, plus the biases) occupies its
        // own optimizer slot; gradients are cleared once applied.
        for filter in &mut self.filters {
            optimizer.update_parameter(
                *slot,
                filter.weights.raw_mut_slice(),
                filter.grads.as_ref(),
                scale,
            );
            *slot += 1;
            filter.grads_mut().fill(0.0);
        }
        optimizer.update_parameter(
            *slot,
            self.biases.as_mut_slice(),
            self.bias_grads.as_slice(),
            scale,
        );
        *slot += 1;
        self.bias_grads.fill(0.0);
    }
}

/// Type-level input/output tensor metadata for conv layers.
#[allow(dead_code)]
pub trait ConvIO {
    type Output;
    type Input;
    type OutputShape;
    type InputShape;
    type FilterShape;
    const N: usize;
}

impl<
    const IW: usize,
    const IH: usize,
    const IC: usize,
    const FH: usize,
    const FW: usize,
    const OC: usize,
    const S: usize,
    const P: usize,
> ConvIO for Conv<IW, IH, IC, FH, FW, OC, S, P>
where
    [(); IC * IH * IW]:,
    [(); OC * conv_out_dim(IH, P, FH, S) * conv_out_dim(IW, P, FW, S)]:,
{
    const N: usize = IC * IH * IW;
    type Input = Tensor<crate::shape!(IC, IH, IW)>;
    type Output = Tensor<Self::OutputShape>;
    type InputShape = crate::shape!(IC, IH, IW);
    type OutputShape = crate::shape!(OC, conv_out_dim(IH, P, FH, S), conv_out_dim(IW, P, FW, S));
    type FilterShape = crate::shape!(FH, FW, IC);
}

/// Flat-array convenience trait for generic conv code.
#[allow(dead_code)]
pub trait ConvOps: ConvIO {
    type InputArray;
    type OutputArray;
    type FilterArray;

    const INPUT_SIZE: usize;
    const OUTPUT_SIZE: usize;
    const FILTER_SIZE: usize;

    fn init() -> Self;
    fn forward_flat(&self, input: &Self::InputArray, output: &mut Self::OutputArray);
    fn input_from_fn<F: FnMut(usize) -> Float>(f: F) -> Self::InputArray;
    fn output_zeroed() -> Self::OutputArray;
}

impl<
    const IW: usize,
    const IH: usize,
    const IC: usize,
    const FH: usize,
    const FW: usize,
    const OC: usize,
    const S: usize,
    const P: usize,
> ConvOps for Conv<IW, IH, IC, FH, FW, OC, S, P>
where
    [(); FH * FW * IC]:,
    [(); IC * IH * IW]:,
    [(); OC * conv_out_dim(IH, P, FH, S) * conv_out_dim(IW, P, FW, S)]:,
    (): ConvGeometryIsValid<IH, IW, FH, FW, S, P>,
{
    type InputArray = [Float; IC * IH * IW];
    type OutputArray = [Float; OC * conv_out_dim(IH, P, FH, S) * conv_out_dim(IW, P, FW, S)];
    type FilterArray = [Float; FH * FW * IC];

    const INPUT_SIZE: usize = IC * IH * IW;
    const OUTPUT_SIZE: usize = OC * conv_out_dim(IH, P, FH, S) * conv_out_dim(IW, P, FW, S);
    const FILTER_SIZE: usize = FH * FW * IC;

    fn init() -> Self {
        Conv::<IW, IH, IC, FH, FW, OC, S, P>::init()
    }

    fn forward_flat(&self, input: &Self::InputArray, output: &mut Self::OutputArray) {
        Conv::<IW, IH, IC, FH, FW, OC, S, P>::forward_flat(self, input, output);
    }

    fn input_from_fn<F: FnMut(usize) -> Float>(f: F) -> Self::InputArray {
        array::from_fn(f)
    }

    fn output_zeroed() -> Self::OutputArray {
        array::from_fn(|_| 0.0 as Float)
    }
}

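// A minimal sketch of code written generically against `ConvOps`; `run_forward`
// is a hypothetical helper, not part of the crate's API.
#[allow(dead_code)]
fn run_forward<C: ConvOps>(conv: &C) -> C::OutputArray {
    // Build a deterministic input, run one forward pass, and return the result.
    let input = C::input_from_fn(|i| i as Float * 0.01);
    let mut output = C::output_zeroed();
    conv.forward_flat(&input, &mut output);
    output
}
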
#[cfg(test)]
mod tests {
    use super::*;

    type ConvCase = Conv<3, 3, 1, 2, 2, 1, 1, 0>;
    const IN_SIZE: usize = 3 * 3;
    const OUT_SIZE: usize = 4;

    fn approx_eq(a: Float, b: Float, eps: Float) {
        let diff = (a - b).abs();
        assert!(diff <= eps, "expected {a} ~= {b} (diff={diff}, eps={eps})");
    }

    fn configured_conv() -> ConvCase {
        let mut conv = ConvCase::init();
        for (i, w) in conv.filters[0]
            .weights
            .raw_mut_slice()
            .iter_mut()
            .enumerate()
        {
            *w = 0.1 * (i as Float + 1.0);
        }
        conv.biases[0] = 0.05;
        conv
    }

    fn objective(
        conv: &ConvCase,
        input: &[Float; IN_SIZE],
        output_grad: &[Float; OUT_SIZE],
    ) -> Float {
        let mut output = [0.0; OUT_SIZE];
        conv.forward_flat(input, &mut output);
        output
            .iter()
            .zip(output_grad.iter())
            .map(|(o, g)| o * g)
            .sum()
    }

    #[test]
    fn input_gradient_matches_finite_difference() {
        let mut conv = configured_conv();
        let input = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9];
        let output_grad = [0.3, -0.2, 0.1, 0.4];
        let mut input_grad = [0.0; IN_SIZE];

        conv.zero_grad();
        conv.backward_flat(&input, &output_grad, &mut input_grad);

        let eps = 1e-7;
        for i in 0..IN_SIZE {
            let mut plus = input;
            let mut minus = input;
            plus[i] += eps;
            minus[i] -= eps;
            let f_plus = objective(&conv, &plus, &output_grad);
            let f_minus = objective(&conv, &minus, &output_grad);
            let numeric = (f_plus - f_minus) / (2.0 * eps);
            approx_eq(input_grad[i], numeric, 1e-6);
        }
    }

    #[test]
    fn weight_update_matches_finite_difference_gradient() {
        let mut conv = configured_conv();
        let input = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9];
        let output_grad = [0.3, -0.2, 0.1, 0.4];
        let mut input_grad = [0.0; IN_SIZE];
        let weight_idx = 2;

        let eps = 1e-7;
        let mut conv_plus = configured_conv();
        conv_plus.filters[0].weights.raw_mut_slice()[weight_idx] += eps;
        let mut conv_minus = configured_conv();
        conv_minus.filters[0].weights.raw_mut_slice()[weight_idx] -= eps;
        let numeric = (objective(&conv_plus, &input, &output_grad)
            - objective(&conv_minus, &input, &output_grad))
            / (2.0 * eps);

        conv.zero_grad();
        conv.backward_flat(&input, &output_grad, &mut input_grad);
        let analytic = conv.filters[0].grads[weight_idx];

        approx_eq(analytic, numeric, 1e-6);
    }
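
    // An added sketch: checks forward_flat against hand-computed values for the
    // case above (2x2 kernel [0.1, 0.2; 0.3, 0.4], bias 0.05, 3x3 input).
    // For example, the top-left output is
    // 0.05 + 0.1*0.1 + 0.2*0.2 + 0.3*0.4 + 0.4*0.5 = 0.42.
    #[test]
    fn forward_matches_hand_computed_values() {
        let conv = configured_conv();
        let input = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9];
        let mut output = [0.0; OUT_SIZE];
        conv.forward_flat(&input, &mut output);
        let expected = [0.42, 0.52, 0.72, 0.82];
        for (o, e) in output.iter().zip(expected.iter()) {
            approx_eq(*o, *e, 1e-6);
        }
    }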
}