1#[cfg(target_arch = "x86_64")]
36use std::arch::x86_64::*;
37
/// Constant matrices consumed by the Winograd transform functions in this file.
///
/// Only one matrix of each pair (`G`, `B_T`, `A_T`) is written out literally;
/// its partner (`G_T`, `B`, `A`) is derived from it at compile time by
/// transposition, so the two halves of a pair cannot drift apart.
pub mod transforms {
    /// Compile-time transpose of an `R x C` matrix into a `C x R` matrix.
    const fn transpose<const R: usize, const C: usize>(m: [[f32; C]; R]) -> [[f32; R]; C] {
        let mut out = [[0.0f32; R]; C];
        let mut r = 0;
        while r < R {
            let mut c = 0;
            while c < C {
                out[c][r] = m[r][c];
                c += 1;
            }
            r += 1;
        }
        out
    }

    /// Filter transform matrix (4 rows x 3 columns); `transform_filter`
    /// left-multiplies the 3x3 kernel by it.
    #[rustfmt::skip]
    pub const G: [[f32; 3]; 4] = [
        [ 1.0,  0.0, 0.0 ],
        [ 0.5,  0.5, 0.5 ],
        [ 0.5, -0.5, 0.5 ],
        [ 0.0,  0.0, 1.0 ],
    ];

    /// Transpose of [`G`] (3 rows x 4 columns).
    pub const G_T: [[f32; 4]; 3] = transpose(G);

    /// Input transform matrix (4x4); `transform_input` left-multiplies the
    /// 4x4 input tile by it.
    #[rustfmt::skip]
    pub const B_T: [[f32; 4]; 4] = [
        [ 1.0,  0.0, -1.0,  0.0 ],
        [ 0.0,  1.0,  1.0,  0.0 ],
        [ 0.0, -1.0,  1.0,  0.0 ],
        [ 0.0,  1.0,  0.0, -1.0 ],
    ];

    /// Transpose of [`B_T`] (4x4).
    pub const B: [[f32; 4]; 4] = transpose(B_T);

    /// Output transform matrix (2 rows x 4 columns); `transform_output`
    /// left-multiplies the 4x4 element-wise product by it.
    #[rustfmt::skip]
    pub const A_T: [[f32; 4]; 2] = [
        [ 1.0, 1.0,  1.0,  0.0 ],
        [ 0.0, 1.0, -1.0, -1.0 ],
    ];

    /// Transpose of [`A_T`] (4 rows x 2 columns).
    pub const A: [[f32; 2]; 4] = transpose(A_T);
}
94
95pub fn transform_filter(filter: &[f32; 9]) -> [f32; 16] {
105 let g = [
106 [filter[0], filter[1], filter[2]],
107 [filter[3], filter[4], filter[5]],
108 [filter[6], filter[7], filter[8]],
109 ];
110
111 let mut gg = [[0.0f32; 3]; 4];
113 for i in 0..4 {
114 for j in 0..3 {
115 for k in 0..3 {
116 gg[i][j] += transforms::G[i][k] * g[k][j];
117 }
118 }
119 }
120
121 let mut u = [0.0f32; 16];
123 for i in 0..4 {
124 for j in 0..4 {
125 let mut sum = 0.0f32;
126 for k in 0..3 {
127 sum += gg[i][k] * transforms::G_T[k][j];
128 }
129 u[i * 4 + j] = sum;
130 }
131 }
132
133 u
134}
135
136pub fn transform_input(tile: &[f32; 16]) -> [f32; 16] {
146 let d = [
147 [tile[0], tile[1], tile[2], tile[3]],
148 [tile[4], tile[5], tile[6], tile[7]],
149 [tile[8], tile[9], tile[10], tile[11]],
150 [tile[12], tile[13], tile[14], tile[15]],
151 ];
152
153 let mut btd = [[0.0f32; 4]; 4];
155 for i in 0..4 {
156 for j in 0..4 {
157 for k in 0..4 {
158 btd[i][j] += transforms::B_T[i][k] * d[k][j];
159 }
160 }
161 }
162
163 let mut v = [0.0f32; 16];
165 for i in 0..4 {
166 for j in 0..4 {
167 let mut sum = 0.0f32;
168 for k in 0..4 {
169 sum += btd[i][k] * transforms::B[k][j];
170 }
171 v[i * 4 + j] = sum;
172 }
173 }
174
175 v
176}
177
178pub fn transform_output(m: &[f32; 16]) -> [f32; 4] {
188 let m_mat = [
189 [m[0], m[1], m[2], m[3]],
190 [m[4], m[5], m[6], m[7]],
191 [m[8], m[9], m[10], m[11]],
192 [m[12], m[13], m[14], m[15]],
193 ];
194
195 let mut atm = [[0.0f32; 4]; 2];
197 for i in 0..2 {
198 for j in 0..4 {
199 for k in 0..4 {
200 atm[i][j] += transforms::A_T[i][k] * m_mat[k][j];
201 }
202 }
203 }
204
205 let mut y = [0.0f32; 4];
207 for i in 0..2 {
208 for j in 0..2 {
209 let mut sum = 0.0f32;
210 for k in 0..4 {
211 sum += atm[i][k] * transforms::A[k][j];
212 }
213 y[i * 2 + j] = sum;
214 }
215 }
216
217 y
218}
219
/// Element-wise product of two 16-element arrays: `result[i] = u[i] * v[i]`.
pub fn winograd_multiply(u: &[f32; 16], v: &[f32; 16]) -> [f32; 16] {
    let mut product = [0.0f32; 16];
    for ((dst, &lhs), &rhs) in product.iter_mut().zip(u.iter()).zip(v.iter()) {
        *dst = lhs * rhs;
    }
    product
}
230
231#[derive(Debug, Clone)]
233pub struct WinogradFilterCache {
234 pub filters: Vec<f32>,
236 pub out_channels: usize,
237 pub in_channels: usize,
238}
239
240impl WinogradFilterCache {
241 pub fn new(filters: &[f32], out_channels: usize, in_channels: usize) -> Self {
248 let mut transformed = vec![0.0f32; out_channels * in_channels * 16];
249
250 for oc in 0..out_channels {
251 for ic in 0..in_channels {
252 let filter_offset = (oc * in_channels + ic) * 9;
254 let mut filter_3x3 = [0.0f32; 9];
255 filter_3x3.copy_from_slice(&filters[filter_offset..filter_offset + 9]);
256
257 let transformed_filter = transform_filter(&filter_3x3);
259
260 let cache_offset = (oc * in_channels + ic) * 16;
262 transformed[cache_offset..cache_offset + 16].copy_from_slice(&transformed_filter);
263 }
264 }
265
266 Self {
267 filters: transformed,
268 out_channels,
269 in_channels,
270 }
271 }
272
273 #[inline]
275 pub fn get(&self, out_c: usize, in_c: usize) -> &[f32] {
276 let offset = (out_c * self.in_channels + in_c) * 16;
277 &self.filters[offset..offset + 16]
278 }
279}
280
281pub fn conv_3x3_winograd(
293 input: &[f32],
294 filter_cache: &WinogradFilterCache,
295 output: &mut [f32],
296 h: usize,
297 w: usize,
298 padding: usize,
299) {
300 let in_c = filter_cache.in_channels;
301 let out_c = filter_cache.out_channels;
302
303 let out_h = (h + 2 * padding - 2) / 2;
305 let out_w = (w + 2 * padding - 2) / 2;
306
307 for oh_tile in 0..out_h {
309 for ow_tile in 0..out_w {
310 let oh0 = oh_tile * 2;
312 let ow0 = ow_tile * 2;
313
314 let mut tile_output = vec![0.0f32; out_c * 4];
316
317 for ic in 0..in_c {
319 let mut input_tile = [0.0f32; 16];
321 for ti in 0..4 {
322 for tj in 0..4 {
323 let ih = (oh0 + ti) as isize - padding as isize;
324 let iw = (ow0 + tj) as isize - padding as isize;
325
326 if ih >= 0 && ih < h as isize && iw >= 0 && iw < w as isize {
327 let idx = (ih as usize * w + iw as usize) * in_c + ic;
328 input_tile[ti * 4 + tj] = input[idx];
329 }
330 }
331 }
332
333 let v = transform_input(&input_tile);
335
336 for oc in 0..out_c {
338 let u = filter_cache.get(oc, ic);
340
341 let mut m = [0.0f32; 16];
343 for i in 0..16 {
344 m[i] = u[i] * v[i];
345 }
346
347 let y = transform_output(&m);
349 for i in 0..4 {
350 tile_output[oc * 4 + i] += y[i];
351 }
352 }
353 }
354
355 for oi in 0..2 {
357 for oj in 0..2 {
358 let oh = oh0 + oi;
359 let ow = ow0 + oj;
360 if oh < out_h * 2 && ow < out_w * 2 {
361 for oc in 0..out_c {
362 let out_idx = (oh * out_w * 2 + ow) * out_c + oc;
363 if out_idx < output.len() {
364 output[out_idx] = tile_output[oc * 4 + oi * 2 + oj];
365 }
366 }
367 }
368 }
369 }
370 }
371 }
372}
373
374#[cfg(target_arch = "x86_64")]
376#[target_feature(enable = "avx2")]
377pub unsafe fn transform_input_avx2(tiles: &[[f32; 16]; 4]) -> [[f32; 16]; 4] {
378 let mut result = [[0.0f32; 16]; 4];
379
380 for t in 0..4 {
382 result[t] = transform_input(&tiles[t]);
383 }
384
385 result
386}
387
388#[cfg(target_arch = "x86_64")]
390#[target_feature(enable = "avx2")]
391pub unsafe fn transform_output_avx2(m_tiles: &[[f32; 16]; 4]) -> [[f32; 4]; 4] {
392 let mut result = [[0.0f32; 4]; 4];
393
394 for t in 0..4 {
395 result[t] = transform_output(&m_tiles[t]);
396 }
397
398 result
399}
400
/// Portable counterpart of the x86_64 `transform_input_avx2`: applies
/// [`transform_input`] to each of four tiles.
///
/// This variant used to return all zeros, which silently produced wrong
/// results on every non-x86_64 target. It now computes the same values as the
/// x86_64 build.
///
/// # Safety
///
/// This variant has no safety requirements of its own. It stays `unsafe` only
/// so that its signature matches the x86_64 variant and call sites compile
/// unchanged on every target.
#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn transform_input_avx2(tiles: &[[f32; 16]; 4]) -> [[f32; 16]; 4] {
    let mut result = [[0.0f32; 16]; 4];
    for (dst, src) in result.iter_mut().zip(tiles.iter()) {
        *dst = transform_input(src);
    }
    result
}
406
/// Portable counterpart of the x86_64 `transform_output_avx2`: applies
/// [`transform_output`] to each of four 4x4 matrices.
///
/// This variant used to return all zeros, which silently produced wrong
/// results on every non-x86_64 target. It now computes the same values as the
/// x86_64 build.
///
/// # Safety
///
/// This variant has no safety requirements of its own. It stays `unsafe` only
/// so that its signature matches the x86_64 variant and call sites compile
/// unchanged on every target.
#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn transform_output_avx2(m_tiles: &[[f32; 16]; 4]) -> [[f32; 4]; 4] {
    let mut result = [[0.0f32; 4]; 4];
    for (dst, src) in result.iter_mut().zip(m_tiles.iter()) {
        *dst = transform_output(src);
    }
    result
}
411
#[cfg(test)]
mod tests {
    use super::*;

    /// Row-major 4x4 ramp 1..=16, used as an input tile and as an image.
    const RAMP: [f32; 16] = [
        1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
    ];

    /// 3x3 kernel whose only non-zero tap is the centre.
    const CENTER_TAP: [f32; 9] = [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0];

    // Every expected value below is exactly representable in `f32` (small
    // integers and multiples of 0.25), so exact equality is safe.

    #[test]
    fn test_filter_transform_roundtrip() {
        let transformed = transform_filter(&CENTER_TAP);

        // G * g * G^T for a centre-tap kernel is the outer product of G's
        // middle column [0, 0.5, -0.5, 0] with itself.
        let expected: [f32; 16] = [
            0.0, 0.0, 0.0, 0.0,
            0.0, 0.25, -0.25, 0.0,
            0.0, -0.25, 0.25, 0.0,
            0.0, 0.0, 0.0, 0.0,
        ];
        assert_eq!(transformed, expected);
    }

    #[test]
    fn test_input_transform() {
        let transformed = transform_input(&RAMP);

        let expected: [f32; 16] = [
            0.0, -16.0, 0.0, 0.0,
            -4.0, 34.0, 2.0, -4.0,
            0.0, 8.0, 0.0, 0.0,
            0.0, -16.0, 0.0, 0.0,
        ];
        assert_eq!(transformed, expected);
    }

    #[test]
    fn test_output_transform() {
        let m = [1.0f32; 16];
        let output = transform_output(&m);

        let expected: [f32; 4] = [9.0, -3.0, -3.0, 1.0];
        assert_eq!(output, expected);
    }

    #[test]
    fn test_winograd_filter_cache() {
        let filters = vec![1.0, 0.0, -1.0, 2.0, 0.0, -2.0, 1.0, 0.0, -1.0];
        let cache = WinogradFilterCache::new(&filters, 1, 1);

        assert_eq!(cache.filters.len(), 16);
        assert_eq!(cache.out_channels, 1);
        assert_eq!(cache.in_channels, 1);

        let expected: [f32; 16] = [
            1.0, 0.0, 0.0, -1.0,
            2.0, 0.0, 0.0, -2.0,
            0.0, 0.0, 0.0, 0.0,
            1.0, 0.0, 0.0, -1.0,
        ];
        assert_eq!(cache.get(0, 0), &expected[..]);
    }

    #[test]
    fn test_winograd_identity_conv() {
        let cache = WinogradFilterCache::new(&CENTER_TAP, 1, 1);

        let input = RAMP.to_vec();
        let mut output = vec![0.0f32; 16];

        conv_3x3_winograd(&input, &cache, &mut output, 4, 4, 1);

        // A centre-tap kernel with padding 1 must reproduce the input.
        assert_eq!(output, input);
    }
}