Skip to main content

ruvector_cnn/simd/
winograd.rs

1//! Winograd F(2,3) Convolution Implementation
2//!
3//! Implements Winograd minimal filtering for 3x3 convolutions, reducing
4//! multiplications from 36 to 16 for each 2x2 output tile.
5//!
6//! # Algorithm
7//!
8//! For F(2,3) (2x2 output, 3x3 filter):
9//! - Input tile: 4x4
10//! - Filter: 3x3
11//! - Output tile: 2x2
12//!
13//! Y = A^T [[G g G^T] ⊙ [B^T d B]] A
14//!
15//! Where:
16//! - g = 3x3 filter
17//! - d = 4x4 input tile
18//! - G, B^T, A^T = Winograd transform matrices
19//! - ⊙ = element-wise multiplication
20//!
21//! # Performance
22//!
23//! - Standard 3x3 conv: 9 muls × 4 outputs = 36 multiplications
24//! - Winograd F(2,3): 16 multiplications (4×4 element-wise)
25//! - Theoretical speedup: 2.25×
26//! - Practical speedup: 1.8-2.5× (varies with hardware)
27//!
28//! # Trade-offs
29//!
30//! - Transform overhead for small batch sizes
31//! - Numerical precision slightly reduced
32//! - Works best with stride=1, no dilation
33//! - Memory overhead for storing transformed weights
34
35#[cfg(target_arch = "x86_64")]
36use std::arch::x86_64::*;
37
pub mod transforms {
    //! Pre-computed Winograd F(2,3) transform matrices.
    //!
    //! For a 3x3 filter `g` and a 4x4 input tile `d`, one 2x2 output tile is
    //! `Y = A_T · [(G · g · G_T) ⊙ (B_T · d · B)] · A`. Each left-hand matrix
    //! is also provided in transposed form (`G_T`, `B`, `A`) for the
    //! right-hand products.

    /// The `1/2` coefficient shared by the filter-transform rows.
    const HALF: f32 = 0.5;

    /// `G` (4x3): lifts a 3x3 filter into the 4x4 Winograd domain.
    #[rustfmt::skip]
    pub const G: [[f32; 3]; 4] = [
        [1.0,   0.0,  0.0],
        [HALF,  HALF, HALF],
        [HALF, -HALF, HALF],
        [0.0,   0.0,  1.0],
    ];

    /// `G^T` (3x4): transpose of [`G`], the right-hand factor of the filter transform.
    #[rustfmt::skip]
    pub const G_T: [[f32; 4]; 3] = [
        [1.0,  HALF,  HALF, 0.0],
        [0.0,  HALF, -HALF, 0.0],
        [0.0,  HALF,  HALF, 1.0],
    ];

    /// `B^T` (4x4): left-hand factor of the input-tile transform.
    #[rustfmt::skip]
    pub const B_T: [[f32; 4]; 4] = [
        [1.0,  0.0, -1.0,  0.0],
        [0.0,  1.0,  1.0,  0.0],
        [0.0, -1.0,  1.0,  0.0],
        [0.0,  1.0,  0.0, -1.0],
    ];

    /// `B` (4x4): transpose of [`B_T`], the right-hand factor of the input transform.
    #[rustfmt::skip]
    pub const B: [[f32; 4]; 4] = [
        [ 1.0, 0.0,  0.0,  0.0],
        [ 0.0, 1.0, -1.0,  1.0],
        [-1.0, 1.0,  1.0,  0.0],
        [ 0.0, 0.0,  0.0, -1.0],
    ];

    /// `A^T` (2x4): left-hand factor of the output transform (4x4 -> 2x2).
    #[rustfmt::skip]
    pub const A_T: [[f32; 4]; 2] = [
        [1.0, 1.0,  1.0,  0.0],
        [0.0, 1.0, -1.0, -1.0],
    ];

    /// `A` (4x2): transpose of [`A_T`], the right-hand factor of the output transform.
    #[rustfmt::skip]
    pub const A: [[f32; 2]; 4] = [
        [1.0,  0.0],
        [1.0,  1.0],
        [1.0, -1.0],
        [0.0, -1.0],
    ];
}
94
95/// Transform a 3x3 filter to 4x4 Winograd domain
96///
97/// Computes: U = G × g × G^T
98///
99/// # Arguments
100/// * `filter` - 3x3 filter weights (row-major)
101///
102/// # Returns
103/// * 4x4 transformed filter (row-major)
104pub fn transform_filter(filter: &[f32; 9]) -> [f32; 16] {
105    let g = [
106        [filter[0], filter[1], filter[2]],
107        [filter[3], filter[4], filter[5]],
108        [filter[6], filter[7], filter[8]],
109    ];
110
111    // Compute Gg = G × g (4x3 matrix)
112    let mut gg = [[0.0f32; 3]; 4];
113    for i in 0..4 {
114        for j in 0..3 {
115            for k in 0..3 {
116                gg[i][j] += transforms::G[i][k] * g[k][j];
117            }
118        }
119    }
120
121    // Compute U = Gg × G^T (4x4 matrix)
122    let mut u = [0.0f32; 16];
123    for i in 0..4 {
124        for j in 0..4 {
125            let mut sum = 0.0f32;
126            for k in 0..3 {
127                sum += gg[i][k] * transforms::G_T[k][j];
128            }
129            u[i * 4 + j] = sum;
130        }
131    }
132
133    u
134}
135
136/// Transform a 4x4 input tile to Winograd domain
137///
138/// Computes: V = B^T × d × B
139///
140/// # Arguments
141/// * `tile` - 4x4 input tile (row-major)
142///
143/// # Returns
144/// * 4x4 transformed tile (row-major)
145pub fn transform_input(tile: &[f32; 16]) -> [f32; 16] {
146    let d = [
147        [tile[0], tile[1], tile[2], tile[3]],
148        [tile[4], tile[5], tile[6], tile[7]],
149        [tile[8], tile[9], tile[10], tile[11]],
150        [tile[12], tile[13], tile[14], tile[15]],
151    ];
152
153    // Compute B^T × d (4x4)
154    let mut btd = [[0.0f32; 4]; 4];
155    for i in 0..4 {
156        for j in 0..4 {
157            for k in 0..4 {
158                btd[i][j] += transforms::B_T[i][k] * d[k][j];
159            }
160        }
161    }
162
163    // Compute V = (B^T × d) × B (4x4)
164    let mut v = [0.0f32; 16];
165    for i in 0..4 {
166        for j in 0..4 {
167            let mut sum = 0.0f32;
168            for k in 0..4 {
169                sum += btd[i][k] * transforms::B[k][j];
170            }
171            v[i * 4 + j] = sum;
172        }
173    }
174
175    v
176}
177
178/// Transform Winograd domain result to 2x2 output tile
179///
180/// Computes: Y = A^T × M × A
181///
182/// # Arguments
183/// * `m` - 4x4 element-wise product in Winograd domain
184///
185/// # Returns
186/// * 2x2 output tile (row-major)
187pub fn transform_output(m: &[f32; 16]) -> [f32; 4] {
188    let m_mat = [
189        [m[0], m[1], m[2], m[3]],
190        [m[4], m[5], m[6], m[7]],
191        [m[8], m[9], m[10], m[11]],
192        [m[12], m[13], m[14], m[15]],
193    ];
194
195    // Compute A^T × M (2x4)
196    let mut atm = [[0.0f32; 4]; 2];
197    for i in 0..2 {
198        for j in 0..4 {
199            for k in 0..4 {
200                atm[i][j] += transforms::A_T[i][k] * m_mat[k][j];
201            }
202        }
203    }
204
205    // Compute Y = (A^T × M) × A (2x2)
206    let mut y = [0.0f32; 4];
207    for i in 0..2 {
208        for j in 0..2 {
209            let mut sum = 0.0f32;
210            for k in 0..4 {
211                sum += atm[i][k] * transforms::A[k][j];
212            }
213            y[i * 2 + j] = sum;
214        }
215    }
216
217    y
218}
219
/// Element-wise (Hadamard) product of two Winograd-domain blocks.
///
/// Computes `M = U ⊙ V` for a single input/output channel pair. Both operands
/// and the result are 4x4 blocks flattened row-major to 16 values.
pub fn winograd_multiply(u: &[f32; 16], v: &[f32; 16]) -> [f32; 16] {
    std::array::from_fn(|idx| u[idx] * v[idx])
}
230
231/// Pre-transformed filter cache for efficient inference
232#[derive(Debug, Clone)]
233pub struct WinogradFilterCache {
234    /// Transformed filters: [out_c, in_c, 16]
235    pub filters: Vec<f32>,
236    pub out_channels: usize,
237    pub in_channels: usize,
238}
239
240impl WinogradFilterCache {
241    /// Create a new filter cache from 3x3 filters
242    ///
243    /// # Arguments
244    /// * `filters` - Original 3x3 filters [out_c, in_c, 3, 3]
245    /// * `out_channels` - Number of output channels
246    /// * `in_channels` - Number of input channels
247    pub fn new(filters: &[f32], out_channels: usize, in_channels: usize) -> Self {
248        let mut transformed = vec![0.0f32; out_channels * in_channels * 16];
249
250        for oc in 0..out_channels {
251            for ic in 0..in_channels {
252                // Extract 3x3 filter
253                let filter_offset = (oc * in_channels + ic) * 9;
254                let mut filter_3x3 = [0.0f32; 9];
255                filter_3x3.copy_from_slice(&filters[filter_offset..filter_offset + 9]);
256
257                // Transform to Winograd domain
258                let transformed_filter = transform_filter(&filter_3x3);
259
260                // Store in cache
261                let cache_offset = (oc * in_channels + ic) * 16;
262                transformed[cache_offset..cache_offset + 16].copy_from_slice(&transformed_filter);
263            }
264        }
265
266        Self {
267            filters: transformed,
268            out_channels,
269            in_channels,
270        }
271    }
272
273    /// Get transformed filter for specific channel pair
274    #[inline]
275    pub fn get(&self, out_c: usize, in_c: usize) -> &[f32] {
276        let offset = (out_c * self.in_channels + in_c) * 16;
277        &self.filters[offset..offset + 16]
278    }
279}
280
281/// Winograd F(2,3) convolution (scalar reference implementation)
282///
283/// Performs 3x3 convolution using Winograd transforms.
284/// Suitable for stride=1, no dilation.
285///
286/// # Arguments
287/// * `input` - Input tensor [H, W, C] (HWC format)
288/// * `filter_cache` - Pre-transformed Winograd filters
289/// * `output` - Output tensor [out_H, out_W, out_C]
290/// * `h`, `w` - Input height and width
291/// * `padding` - Zero padding (typically 1 for same-size output)
292pub fn conv_3x3_winograd(
293    input: &[f32],
294    filter_cache: &WinogradFilterCache,
295    output: &mut [f32],
296    h: usize,
297    w: usize,
298    padding: usize,
299) {
300    let in_c = filter_cache.in_channels;
301    let out_c = filter_cache.out_channels;
302
303    // Output dimensions (2x2 tiles)
304    let out_h = (h + 2 * padding - 2) / 2;
305    let out_w = (w + 2 * padding - 2) / 2;
306
307    // Process each 2x2 output tile
308    for oh_tile in 0..out_h {
309        for ow_tile in 0..out_w {
310            // Output positions for this tile
311            let oh0 = oh_tile * 2;
312            let ow0 = ow_tile * 2;
313
314            // Accumulator for all output channels
315            let mut tile_output = vec![0.0f32; out_c * 4];
316
317            // For each input channel
318            for ic in 0..in_c {
319                // Extract 4x4 input tile (with padding)
320                let mut input_tile = [0.0f32; 16];
321                for ti in 0..4 {
322                    for tj in 0..4 {
323                        let ih = (oh0 + ti) as isize - padding as isize;
324                        let iw = (ow0 + tj) as isize - padding as isize;
325
326                        if ih >= 0 && ih < h as isize && iw >= 0 && iw < w as isize {
327                            let idx = (ih as usize * w + iw as usize) * in_c + ic;
328                            input_tile[ti * 4 + tj] = input[idx];
329                        }
330                    }
331                }
332
333                // Transform input tile
334                let v = transform_input(&input_tile);
335
336                // For each output channel
337                for oc in 0..out_c {
338                    // Get pre-transformed filter
339                    let u = filter_cache.get(oc, ic);
340
341                    // Element-wise multiply
342                    let mut m = [0.0f32; 16];
343                    for i in 0..16 {
344                        m[i] = u[i] * v[i];
345                    }
346
347                    // Transform to spatial domain and accumulate
348                    let y = transform_output(&m);
349                    for i in 0..4 {
350                        tile_output[oc * 4 + i] += y[i];
351                    }
352                }
353            }
354
355            // Write output tile
356            for oi in 0..2 {
357                for oj in 0..2 {
358                    let oh = oh0 + oi;
359                    let ow = ow0 + oj;
360                    if oh < out_h * 2 && ow < out_w * 2 {
361                        for oc in 0..out_c {
362                            let out_idx = (oh * out_w * 2 + ow) * out_c + oc;
363                            if out_idx < output.len() {
364                                output[out_idx] = tile_output[oc * 4 + oi * 2 + oj];
365                            }
366                        }
367                    }
368                }
369            }
370        }
371    }
372}
373
374/// AVX2 Winograd input transform (4 tiles at once)
375#[cfg(target_arch = "x86_64")]
376#[target_feature(enable = "avx2")]
377pub unsafe fn transform_input_avx2(tiles: &[[f32; 16]; 4]) -> [[f32; 16]; 4] {
378    let mut result = [[0.0f32; 16]; 4];
379
380    // Process each tile (can be further optimized with interleaving)
381    for t in 0..4 {
382        result[t] = transform_input(&tiles[t]);
383    }
384
385    result
386}
387
388/// AVX2 Winograd output transform (4 tiles at once)
389#[cfg(target_arch = "x86_64")]
390#[target_feature(enable = "avx2")]
391pub unsafe fn transform_output_avx2(m_tiles: &[[f32; 16]; 4]) -> [[f32; 4]; 4] {
392    let mut result = [[0.0f32; 4]; 4];
393
394    for t in 0..4 {
395        result[t] = transform_output(&m_tiles[t]);
396    }
397
398    result
399}
400
// Portable fallbacks for non-x86_64 targets. They compute the same per-tile
// scalar transforms as the x86_64 versions instead of returning zeros, so callers
// get correct results on every architecture.

/// Winograd input transform for four tiles (portable fallback).
///
/// Equivalent to applying [`transform_input`] to each tile.
///
/// # Safety
/// Performs no unsafe operations. It is `unsafe` only so that its signature
/// matches the x86_64 AVX2 variant.
#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn transform_input_avx2(tiles: &[[f32; 16]; 4]) -> [[f32; 16]; 4] {
    let mut result = [[0.0f32; 16]; 4];
    for (dst, tile) in result.iter_mut().zip(tiles.iter()) {
        *dst = transform_input(tile);
    }
    result
}

/// Winograd output transform for four blocks (portable fallback).
///
/// Equivalent to applying [`transform_output`] to each block.
///
/// # Safety
/// Performs no unsafe operations. It is `unsafe` only so that its signature
/// matches the x86_64 AVX2 variant.
#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn transform_output_avx2(m_tiles: &[[f32; 16]; 4]) -> [[f32; 4]; 4] {
    let mut result = [[0.0f32; 4]; 4];
    for (dst, block) in result.iter_mut().zip(m_tiles.iter()) {
        *dst = transform_output(block);
    }
    result
}
411
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance when comparing Winograd results with direct convolution.
    const TOL: f32 = 1e-4;

    /// Assert two equally long slices agree element-wise within `tol`.
    fn assert_close(actual: &[f32], expected: &[f32], tol: f32) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!(
                (a - e).abs() <= tol,
                "index {}: got {}, expected {}",
                i,
                a,
                e
            );
        }
    }

    /// Deterministic test data: multiples of 0.5 in [-2, 2].
    fn ramp(len: usize, seed: usize) -> Vec<f32> {
        (0..len)
            .map(|i| ((i * 7 + seed) % 9) as f32 * 0.5 - 2.0)
            .collect()
    }

    /// Direct 3x3 cross-correlation (stride 1, zero padding), used as an oracle.
    ///
    /// `input` is HWC with `in_c` channels and `filters` is `[out_c, in_c, 3, 3]`.
    /// The result is HWC with `out_c` channels and spatial size
    /// `(h + 2 * padding - 2) x (w + 2 * padding - 2)`.
    fn direct_conv_3x3(
        input: &[f32],
        filters: &[f32],
        h: usize,
        w: usize,
        in_c: usize,
        out_c: usize,
        padding: usize,
    ) -> Vec<f32> {
        let out_h = h + 2 * padding - 2;
        let out_w = w + 2 * padding - 2;
        let mut out = vec![0.0f32; out_h * out_w * out_c];
        for oh in 0..out_h {
            for ow in 0..out_w {
                for oc in 0..out_c {
                    let mut acc = 0.0f32;
                    for ic in 0..in_c {
                        for kh in 0..3 {
                            for kw in 0..3 {
                                let ih = (oh + kh) as isize - padding as isize;
                                let iw = (ow + kw) as isize - padding as isize;
                                if ih < 0 || iw < 0 || ih >= h as isize || iw >= w as isize {
                                    continue;
                                }
                                let x = input[(ih as usize * w + iw as usize) * in_c + ic];
                                let k = filters[((oc * in_c + ic) * 3 + kh) * 3 + kw];
                                acc += x * k;
                            }
                        }
                    }
                    out[(oh * out_w + ow) * out_c + oc] = acc;
                }
            }
        }
        out
    }

    #[test]
    fn test_filter_transform_center_tap() {
        // For a center-tap filter, U = G[:, 1] ⊗ G[:, 1] with G[:, 1] = [0, 0.5, -0.5, 0].
        let filter = [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0];
        let expected = [
            0.0, 0.0, 0.0, 0.0, //
            0.0, 0.25, -0.25, 0.0, //
            0.0, -0.25, 0.25, 0.0, //
            0.0, 0.0, 0.0, 0.0,
        ];
        assert_eq!(transform_filter(&filter), expected);
    }

    #[test]
    fn test_input_transform_known_values() {
        // Tile 1..=16 worked by hand through B^T · d · B.
        let tile: [f32; 16] = std::array::from_fn(|i| (i + 1) as f32);
        let expected = [
            0.0, -16.0, 0.0, 0.0, //
            -4.0, 34.0, 2.0, -4.0, //
            0.0, 8.0, 0.0, 0.0, //
            0.0, -16.0, 0.0, 0.0,
        ];
        assert_eq!(transform_input(&tile), expected);
    }

    #[test]
    fn test_output_transform_known_values() {
        // All-ones block: A^T · M = [[3; 4], [-1; 4]], then · A = [[9, -3], [-3, 1]].
        let m = [1.0f32; 16];
        assert_eq!(transform_output(&m), [9.0, -3.0, -3.0, 1.0]);
    }

    #[test]
    fn test_single_tile_matches_direct_correlation() {
        let filter: [f32; 9] = [1.0, -2.0, 0.5, 3.0, 0.0, -1.0, 2.0, 1.5, -0.5];
        let tile: [f32; 16] = std::array::from_fn(|i| ((i * 5 + 3) % 7) as f32 - 3.0);
        let m = winograd_multiply(&transform_filter(&filter), &transform_input(&tile));
        let y = transform_output(&m);

        let mut expected = [0.0f32; 4];
        for oi in 0..2 {
            for oj in 0..2 {
                for kh in 0..3 {
                    for kw in 0..3 {
                        expected[oi * 2 + oj] += tile[(oi + kh) * 4 + oj + kw] * filter[kh * 3 + kw];
                    }
                }
            }
        }
        assert_close(&y, &expected, TOL);
    }

    #[test]
    fn test_winograd_filter_cache() {
        // Single 3x3 filter
        let filters = vec![1.0, 0.0, -1.0, 2.0, 0.0, -2.0, 1.0, 0.0, -1.0];
        let cache = WinogradFilterCache::new(&filters, 1, 1);

        assert_eq!(cache.filters.len(), 16);
        assert_eq!(cache.out_channels, 1);
        assert_eq!(cache.in_channels, 1);

        let mut filter = [0.0f32; 9];
        filter.copy_from_slice(&filters);
        assert_eq!(cache.get(0, 0), &transform_filter(&filter)[..]);
    }

    #[test]
    fn test_filter_cache_layout_multi_channel() {
        let (out_c, in_c) = (2, 3);
        let filters = ramp(out_c * in_c * 9, 1);
        let cache = WinogradFilterCache::new(&filters, out_c, in_c);
        assert_eq!(cache.filters.len(), out_c * in_c * 16);

        for oc in 0..out_c {
            for ic in 0..in_c {
                let offset = (oc * in_c + ic) * 9;
                let mut filter = [0.0f32; 9];
                filter.copy_from_slice(&filters[offset..offset + 9]);
                assert_eq!(cache.get(oc, ic), &transform_filter(&filter)[..]);
            }
        }
    }

    #[test]
    fn test_winograd_identity_conv() {
        // A center-tap filter with padding = 1 reproduces the input exactly.
        let filters = [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0];
        let cache = WinogradFilterCache::new(&filters, 1, 1);

        let input: Vec<f32> = (1..=16).map(|x| x as f32).collect();
        let mut output = vec![0.0f32; 16];

        conv_3x3_winograd(&input, &cache, &mut output, 4, 4, 1);

        assert_close(&output, &input, TOL);
    }

    #[test]
    fn test_conv_matches_direct_multi_channel() {
        // (h, w, padding) combinations whose output sizes are even.
        for &(h, w, padding) in &[(4usize, 4usize, 1usize), (4, 4, 0), (6, 8, 0), (4, 6, 1)] {
            let (in_c, out_c) = (3, 2);
            let input = ramp(h * w * in_c, 0);
            let filters = ramp(out_c * in_c * 9, 4);
            let cache = WinogradFilterCache::new(&filters, out_c, in_c);

            let out_h = h + 2 * padding - 2;
            let out_w = w + 2 * padding - 2;
            let mut output = vec![0.0f32; out_h * out_w * out_c];
            conv_3x3_winograd(&input, &cache, &mut output, h, w, padding);

            let expected = direct_conv_3x3(&input, &filters, h, w, in_c, out_c, padding);
            assert_close(&output, &expected, TOL);
        }
    }
}