Skip to main content

sklears_simd/
batch_operations.rs

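//! Batch operations for tensor processing.
//!
//! This module provides batched building blocks used by neural networks and
//! tensor computations: batch/layer normalization, batched matrix
//! multiplication, scaled dot-product and multi-head attention, and 2D
//! convolution. The kernels are written as plain scalar loops (no explicit
//! SIMD intrinsics); any vectorization is left to the compiler.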
5
6use crate::half_precision::{BF16, F16};
7
8#[cfg(feature = "no-std")]
9extern crate alloc;
10
11/// Batch normalization operation
12pub struct BatchNorm {
13    epsilon: f32,
14}
15
16impl BatchNorm {
17    /// Create a new batch normalization layer
18    pub fn new(epsilon: f32) -> Self {
19        Self { epsilon }
20    }
21
22    /// Apply batch normalization to a batch of data
23    /// Input shape: [batch_size, features]
24    /// mean, variance, gamma, beta shape: \[features\]
25    #[allow(clippy::too_many_arguments)] // BatchNorm forward requires all statistical parameters
26    pub fn forward(
27        &self,
28        input: &[f32],
29        mean: &[f32],
30        variance: &[f32],
31        gamma: &[f32],
32        beta: &[f32],
33        output: &mut [f32],
34        batch_size: usize,
35        features: usize,
36    ) {
37        assert_eq!(input.len(), batch_size * features);
38        assert_eq!(output.len(), batch_size * features);
39        assert_eq!(mean.len(), features);
40        assert_eq!(variance.len(), features);
41        assert_eq!(gamma.len(), features);
42        assert_eq!(beta.len(), features);
43
44        for batch in 0..batch_size {
45            for feat in 0..features {
46                let idx = batch * features + feat;
47                let x = input[idx];
48                let normalized = (x - mean[feat]) / (variance[feat] + self.epsilon).sqrt();
49                output[idx] = gamma[feat] * normalized + beta[feat];
50            }
51        }
52    }
53
54    /// Apply batch normalization with FP16 precision
55    #[allow(clippy::too_many_arguments)] // BatchNorm forward_f16 requires all statistical parameters
56    pub fn forward_f16(
57        &self,
58        input: &[F16],
59        mean: &[F16],
60        variance: &[F16],
61        gamma: &[F16],
62        beta: &[F16],
63        output: &mut [F16],
64        batch_size: usize,
65        features: usize,
66    ) {
67        assert_eq!(input.len(), batch_size * features);
68        assert_eq!(output.len(), batch_size * features);
69        assert_eq!(mean.len(), features);
70        assert_eq!(variance.len(), features);
71        assert_eq!(gamma.len(), features);
72        assert_eq!(beta.len(), features);
73
74        for batch in 0..batch_size {
75            for feat in 0..features {
76                let idx = batch * features + feat;
77                let x = input[idx].to_f32();
78                let m = mean[feat].to_f32();
79                let v = variance[feat].to_f32();
80                let g = gamma[feat].to_f32();
81                let b = beta[feat].to_f32();
82
83                let normalized = (x - m) / (v + self.epsilon).sqrt();
84                let result = g * normalized + b;
85                output[idx] = F16::from_f32(result);
86            }
87        }
88    }
89
90    /// Compute batch statistics (mean and variance)
91    pub fn compute_stats(
92        input: &[f32],
93        mean: &mut [f32],
94        variance: &mut [f32],
95        batch_size: usize,
96        features: usize,
97    ) {
98        assert_eq!(input.len(), batch_size * features);
99        assert_eq!(mean.len(), features);
100        assert_eq!(variance.len(), features);
101
102        // Compute mean
103        for (feat, m) in mean.iter_mut().enumerate() {
104            let mut sum = 0.0;
105            for batch in 0..batch_size {
106                sum += input[batch * features + feat];
107            }
108            *m = sum / batch_size as f32;
109        }
110
111        // Compute variance
112        for (feat, v) in variance.iter_mut().enumerate() {
113            let mut sum_sq_diff = 0.0;
114            for batch in 0..batch_size {
115                let diff = input[batch * features + feat] - mean[feat];
116                sum_sq_diff += diff * diff;
117            }
118            *v = sum_sq_diff / batch_size as f32;
119        }
120    }
121}
122
/// Layer normalization operation.
///
/// Each sample (row) is normalized with its own mean and population variance
/// across the feature dimension, then scaled by `gamma` and shifted by `beta`.
#[derive(Debug, Clone, Copy)]
pub struct LayerNorm {
    /// Added to the variance before the square root to avoid division by zero.
    epsilon: f32,
}

impl LayerNorm {
    /// Create a new layer normalization with the given `epsilon`.
    pub fn new(epsilon: f32) -> Self {
        Self { epsilon }
    }

    /// Apply layer normalization.
    ///
    /// For every sample computes
    /// `y = gamma * (x - mean) / sqrt(variance + epsilon) + beta`, where `mean`
    /// and `variance` (divided by `features`) are taken over that sample only.
    ///
    /// * `input`, `output`: row-major `[batch_size, features]`.
    /// * `gamma`, `beta`: `[features]`.
    ///
    /// # Panics
    ///
    /// Panics if any slice length does not match `batch_size` / `features`.
    pub fn forward(
        &self,
        input: &[f32],
        gamma: &[f32],
        beta: &[f32],
        output: &mut [f32],
        batch_size: usize,
        features: usize,
    ) {
        assert_eq!(input.len(), batch_size * features);
        assert_eq!(output.len(), batch_size * features);
        assert_eq!(gamma.len(), features);
        assert_eq!(beta.len(), features);

        // `chunks_exact` rejects a zero chunk size; with no features there is nothing to do.
        if features == 0 {
            return;
        }
        let count = features as f32;

        for (sample, out_row) in input
            .chunks_exact(features)
            .zip(output.chunks_exact_mut(features))
        {
            let mean = sample.iter().sum::<f32>() / count;
            let sum_sq_diff: f32 = sample.iter().map(|&x| (x - mean).powi(2)).sum();
            let variance = sum_sq_diff / count;
            let std_dev = (variance + self.epsilon).sqrt();

            for (((y, &x), &g), &b) in out_row.iter_mut().zip(sample).zip(gamma).zip(beta) {
                let normalized = (x - mean) / std_dev;
                *y = g * normalized + b;
            }
        }
    }
}
171
172/// Batch matrix multiplication operations
173pub mod batch_matmul {
174    use super::*;
175
176    /// Batch matrix multiplication: C\[i\] = A\[i\] * B\[i\]
177    /// All matrices have the same dimensions
178    pub fn batch_matmul_f32(
179        a: &[f32],
180        b: &[f32],
181        c: &mut [f32],
182        batch_size: usize,
183        m: usize,
184        n: usize,
185        k: usize,
186    ) {
187        assert_eq!(a.len(), batch_size * m * k);
188        assert_eq!(b.len(), batch_size * k * n);
189        assert_eq!(c.len(), batch_size * m * n);
190
191        for batch in 0..batch_size {
192            let a_offset = batch * m * k;
193            let b_offset = batch * k * n;
194            let c_offset = batch * m * n;
195
196            for i in 0..m {
197                for j in 0..n {
198                    let mut sum = 0.0;
199                    for l in 0..k {
200                        let a_idx = a_offset + i * k + l;
201                        let b_idx = b_offset + l * n + j;
202                        sum += a[a_idx] * b[b_idx];
203                    }
204                    let c_idx = c_offset + i * n + j;
205                    c[c_idx] = sum;
206                }
207            }
208        }
209    }
210
211    /// Batch matrix multiplication with broadcasting: C\[i\] = A\[i\] * B
212    /// A is batched, B is shared across all batches
213    pub fn batch_matmul_broadcast_f32(
214        a: &[f32],
215        b: &[f32],
216        c: &mut [f32],
217        batch_size: usize,
218        m: usize,
219        n: usize,
220        k: usize,
221    ) {
222        assert_eq!(a.len(), batch_size * m * k);
223        assert_eq!(b.len(), k * n);
224        assert_eq!(c.len(), batch_size * m * n);
225
226        for batch in 0..batch_size {
227            let a_offset = batch * m * k;
228            let c_offset = batch * m * n;
229
230            for i in 0..m {
231                for j in 0..n {
232                    let mut sum = 0.0;
233                    for l in 0..k {
234                        let a_idx = a_offset + i * k + l;
235                        let b_idx = l * n + j;
236                        sum += a[a_idx] * b[b_idx];
237                    }
238                    let c_idx = c_offset + i * n + j;
239                    c[c_idx] = sum;
240                }
241            }
242        }
243    }
244
245    /// Batch matrix multiplication with FP16
246    pub fn batch_matmul_f16(
247        a: &[F16],
248        b: &[F16],
249        c: &mut [F16],
250        batch_size: usize,
251        m: usize,
252        n: usize,
253        k: usize,
254    ) {
255        assert_eq!(a.len(), batch_size * m * k);
256        assert_eq!(b.len(), batch_size * k * n);
257        assert_eq!(c.len(), batch_size * m * n);
258
259        for batch in 0..batch_size {
260            let a_offset = batch * m * k;
261            let b_offset = batch * k * n;
262            let c_offset = batch * m * n;
263
264            for i in 0..m {
265                for j in 0..n {
266                    let mut sum = 0.0f32;
267                    for l in 0..k {
268                        let a_idx = a_offset + i * k + l;
269                        let b_idx = b_offset + l * n + j;
270                        sum += a[a_idx].to_f32() * b[b_idx].to_f32();
271                    }
272                    let c_idx = c_offset + i * n + j;
273                    c[c_idx] = F16::from_f32(sum);
274                }
275            }
276        }
277    }
278
279    /// Batch matrix multiplication with BF16
280    pub fn batch_matmul_bf16(
281        a: &[BF16],
282        b: &[BF16],
283        c: &mut [BF16],
284        batch_size: usize,
285        m: usize,
286        n: usize,
287        k: usize,
288    ) {
289        assert_eq!(a.len(), batch_size * m * k);
290        assert_eq!(b.len(), batch_size * k * n);
291        assert_eq!(c.len(), batch_size * m * n);
292
293        for batch in 0..batch_size {
294            let a_offset = batch * m * k;
295            let b_offset = batch * k * n;
296            let c_offset = batch * m * n;
297
298            for i in 0..m {
299                for j in 0..n {
300                    let mut sum = 0.0f32;
301                    for l in 0..k {
302                        let a_idx = a_offset + i * k + l;
303                        let b_idx = b_offset + l * n + j;
304                        sum += a[a_idx].to_f32() * b[b_idx].to_f32();
305                    }
306                    let c_idx = c_offset + i * n + j;
307                    c[c_idx] = BF16::from_f32(sum);
308                }
309            }
310        }
311    }
312}
313
/// Attention mechanism operations.
pub mod attention {

    /// Scaled dot-product attention: `softmax(Q * K^T / sqrt(d_model)) * V`.
    ///
    /// * `query`, `key`, `value`, `output`: row-major `[batch_size, seq_len, d_model]`.
    /// * `mask`: optional boolean mask where `true` means "query `i` may attend
    ///   to key `j`". It may be `[seq_len, seq_len]` (shared by every batch
    ///   element) or `[batch_size, seq_len, seq_len]` (one mask per batch element).
    ///
    /// A query row whose keys are all masked out receives all-zero attention
    /// weights, so its output row is zero rather than NaN.
    ///
    /// # Panics
    ///
    /// Panics if a tensor length does not match the dimensions, or if `mask`
    /// has neither of the two accepted lengths.
    #[allow(clippy::too_many_arguments)] // Attention signature requires all tensor dimensions
    pub fn scaled_dot_product_attention(
        query: &[f32],
        key: &[f32],
        value: &[f32],
        output: &mut [f32],
        batch_size: usize,
        seq_len: usize,
        d_model: usize,
        mask: Option<&[bool]>,
    ) {
        let scale = 1.0 / (d_model as f32).sqrt();

        assert_eq!(query.len(), batch_size * seq_len * d_model);
        assert_eq!(key.len(), batch_size * seq_len * d_model);
        assert_eq!(value.len(), batch_size * seq_len * d_model);
        assert_eq!(output.len(), batch_size * seq_len * d_model);

        let score_len = seq_len * seq_len;
        let per_batch_mask = match mask {
            None => false,
            Some(m) if m.len() == score_len => false,
            Some(m) if m.len() == batch_size * score_len => true,
            Some(m) => panic!(
                "attention mask has {} elements; expected {} (seq_len^2) or {} (batch_size * seq_len^2)",
                m.len(),
                score_len,
                batch_size * score_len
            ),
        };

        // Scores for one batch element at a time; the buffer is reused across batches.
        #[cfg(not(feature = "no-std"))]
        let mut scores = vec![0.0f32; score_len];
        #[cfg(feature = "no-std")]
        let mut scores = alloc::vec![0.0f32; score_len];

        let tensor_len = seq_len * d_model;

        for batch in 0..batch_size {
            let q = &query[batch * tensor_len..(batch + 1) * tensor_len];
            let k = &key[batch * tensor_len..(batch + 1) * tensor_len];
            let v = &value[batch * tensor_len..(batch + 1) * tensor_len];
            let out = &mut output[batch * tensor_len..(batch + 1) * tensor_len];
            let batch_mask: Option<&[bool]> = match mask {
                Some(m) if per_batch_mask => Some(&m[batch * score_len..(batch + 1) * score_len]),
                other => other,
            };

            // Scaled Q * K^T. Masked positions become -inf so softmax gives them zero weight.
            for i in 0..seq_len {
                let q_row = &q[i * d_model..(i + 1) * d_model];
                for j in 0..seq_len {
                    let k_row = &k[j * d_model..(j + 1) * d_model];
                    let dot = q_row
                        .iter()
                        .zip(k_row)
                        .fold(0.0f32, |acc, (&qv, &kv)| acc + qv * kv);
                    let blocked = matches!(batch_mask, Some(m) if !m[i * seq_len + j]);
                    scores[i * seq_len + j] = if blocked {
                        f32::NEG_INFINITY
                    } else {
                        dot * scale
                    };
                }
            }

            // Row-wise softmax, shifted by the row maximum for numerical stability.
            for i in 0..seq_len {
                let row = &mut scores[i * seq_len..(i + 1) * seq_len];
                let fully_masked = matches!(
                    batch_mask,
                    Some(m) if m[i * seq_len..(i + 1) * seq_len].iter().all(|&allowed| !allowed)
                );
                if fully_masked {
                    // A softmax over an all -inf row is 0/0 = NaN; use zero weights instead.
                    row.fill(0.0);
                    continue;
                }
                let max_val = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
                for s in row.iter_mut() {
                    *s = (*s - max_val).exp();
                }
                let sum_exp: f32 = row.iter().sum();
                for s in row.iter_mut() {
                    *s /= sum_exp;
                }
            }

            // Output = attention weights * V.
            for i in 0..seq_len {
                let weights = &scores[i * seq_len..(i + 1) * seq_len];
                let out_row = &mut out[i * d_model..(i + 1) * d_model];
                for (d, o) in out_row.iter_mut().enumerate() {
                    *o = weights
                        .iter()
                        .enumerate()
                        .fold(0.0f32, |acc, (j, &w)| acc + w * v[j * d_model + d]);
                }
            }
        }
    }

    /// Multi-head attention without learned projections.
    ///
    /// The `d_model` columns are split into `num_heads` contiguous groups of
    /// `d_model / num_heads` columns. Scaled dot-product attention runs on each
    /// group independently (with the same `mask`), and the per-head results are
    /// written back into the matching columns of `output`.
    ///
    /// Shapes and `mask` semantics are the same as for
    /// [`scaled_dot_product_attention`].
    ///
    /// # Panics
    ///
    /// Panics if `num_heads == 0`, if `d_model` is not divisible by `num_heads`,
    /// or if a tensor/mask length does not match the dimensions.
    #[allow(clippy::too_many_arguments)] // Multi-head attention requires all tensor + head dimensions
    pub fn multi_head_attention(
        query: &[f32],
        key: &[f32],
        value: &[f32],
        output: &mut [f32],
        batch_size: usize,
        seq_len: usize,
        d_model: usize,
        num_heads: usize,
        mask: Option<&[bool]>,
    ) {
        assert!(num_heads > 0, "multi_head_attention requires at least one head");
        assert_eq!(d_model % num_heads, 0);
        let d_k = d_model / num_heads;

        assert_eq!(query.len(), batch_size * seq_len * d_model);
        assert_eq!(key.len(), batch_size * seq_len * d_model);
        assert_eq!(value.len(), batch_size * seq_len * d_model);
        assert_eq!(output.len(), batch_size * seq_len * d_model);

        // Every (batch, position) pair is one row of `d_model` values.
        let rows = batch_size * seq_len;
        let head_len = rows * d_k;

        // Per-head scratch buffers, allocated once. Every element is rewritten for each head.
        #[cfg(not(feature = "no-std"))]
        let (mut head_q, mut head_k, mut head_v, mut head_out) = (
            vec![0.0f32; head_len],
            vec![0.0f32; head_len],
            vec![0.0f32; head_len],
            vec![0.0f32; head_len],
        );
        #[cfg(feature = "no-std")]
        let (mut head_q, mut head_k, mut head_v, mut head_out) = (
            alloc::vec![0.0f32; head_len],
            alloc::vec![0.0f32; head_len],
            alloc::vec![0.0f32; head_len],
            alloc::vec![0.0f32; head_len],
        );

        for head in 0..num_heads {
            let col = head * d_k;

            // Gather this head's columns into contiguous [batch, seq, d_k] tensors.
            for row in 0..rows {
                let src = row * d_model + col;
                let dst = row * d_k;
                head_q[dst..dst + d_k].copy_from_slice(&query[src..src + d_k]);
                head_k[dst..dst + d_k].copy_from_slice(&key[src..src + d_k]);
                head_v[dst..dst + d_k].copy_from_slice(&value[src..src + d_k]);
            }

            scaled_dot_product_attention(
                &head_q,
                &head_k,
                &head_v,
                &mut head_out,
                batch_size,
                seq_len,
                d_k,
                mask,
            );

            // Scatter the head result straight back into its columns of the output.
            for row in 0..rows {
                let dst = row * d_model + col;
                let src = row * d_k;
                output[dst..dst + d_k].copy_from_slice(&head_out[src..src + d_k]);
            }
        }
    }
}
502
/// Convolution operations for batched data.
pub mod convolution {

    /// 2D convolution (cross-correlation; the kernel is not flipped) for batched images.
    ///
    /// * Input shape: `[batch_size, in_channels, input_height, input_width]`
    /// * Weight shape: `[out_channels, in_channels, kernel_height, kernel_width]`
    /// * Bias shape: `[out_channels]`
    /// * Output shape: `[batch_size, out_channels, out_height, out_width]` with
    ///   `out_height = (input_height + 2 * padding_h - kernel_height) / stride_h + 1`
    ///   (and likewise for the width).
    ///
    /// Padding is implicit zero padding: taps that land in the padding contribute nothing.
    ///
    /// # Panics
    ///
    /// Panics if a stride is zero, if the kernel is larger than the padded input,
    /// or if a slice length does not match the dimensions.
    #[allow(clippy::too_many_arguments)] // Conv2d requires all spatial and channel dimensions
    pub fn conv2d_batch(
        input: &[f32],
        weight: &[f32],
        bias: &[f32],
        output: &mut [f32],
        batch_size: usize,
        in_channels: usize,
        out_channels: usize,
        input_height: usize,
        input_width: usize,
        kernel_height: usize,
        kernel_width: usize,
        stride_h: usize,
        stride_w: usize,
        padding_h: usize,
        padding_w: usize,
    ) {
        assert!(
            stride_h > 0 && stride_w > 0,
            "conv2d_batch: strides must be non-zero (got {}x{})",
            stride_h,
            stride_w
        );
        let padded_height = input_height + 2 * padding_h;
        let padded_width = input_width + 2 * padding_w;
        // Without this check the output-size subtraction below underflows `usize`.
        assert!(
            kernel_height <= padded_height && kernel_width <= padded_width,
            "conv2d_batch: kernel {}x{} does not fit in padded input {}x{}",
            kernel_height,
            kernel_width,
            padded_height,
            padded_width
        );
        let output_height = (padded_height - kernel_height) / stride_h + 1;
        let output_width = (padded_width - kernel_width) / stride_w + 1;

        assert_eq!(
            input.len(),
            batch_size * in_channels * input_height * input_width
        );
        assert_eq!(
            weight.len(),
            out_channels * in_channels * kernel_height * kernel_width
        );
        assert_eq!(bias.len(), out_channels);
        assert_eq!(
            output.len(),
            batch_size * out_channels * output_height * output_width
        );

        let in_plane = input_height * input_width;
        let kernel_plane = kernel_height * kernel_width;
        let out_plane = output_height * output_width;
        let in_image = in_channels * in_plane;
        let filter_len = in_channels * kernel_plane;
        let out_image = out_channels * out_plane;

        for batch in 0..batch_size {
            let image = &input[batch * in_image..(batch + 1) * in_image];
            let result = &mut output[batch * out_image..(batch + 1) * out_image];

            for (out_ch, &bias_val) in bias.iter().enumerate() {
                let filter = &weight[out_ch * filter_len..(out_ch + 1) * filter_len];
                let out_map = &mut result[out_ch * out_plane..(out_ch + 1) * out_plane];

                for out_y in 0..output_height {
                    for out_x in 0..output_width {
                        let mut sum = bias_val;

                        for in_ch in 0..in_channels {
                            let plane = &image[in_ch * in_plane..(in_ch + 1) * in_plane];
                            let kernel =
                                &filter[in_ch * kernel_plane..(in_ch + 1) * kernel_plane];

                            for ky in 0..kernel_height {
                                // Row in padded coordinates; rows in the padding are skipped.
                                let padded_y = out_y * stride_h + ky;
                                if padded_y < padding_h || padded_y >= input_height + padding_h {
                                    continue;
                                }
                                let input_y = padded_y - padding_h;

                                for kx in 0..kernel_width {
                                    let padded_x = out_x * stride_w + kx;
                                    if padded_x < padding_w || padded_x >= input_width + padding_w
                                    {
                                        continue;
                                    }
                                    let input_x = padded_x - padding_w;

                                    sum += plane[input_y * input_width + input_x]
                                        * kernel[ky * kernel_width + kx];
                                }
                            }
                        }

                        out_map[out_y * output_width + out_x] = sum;
                    }
                }
            }
        }
    }
}
593
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::*;

    /// Absolute tolerance for comparing f32 results against hand-computed references.
    const TOL: f32 = 1e-4;

    /// Assert that two f32 slices have the same length and agree element-wise within `tol`.
    fn assert_close(actual: &[f32], expected: &[f32], tol: f32) {
        assert_eq!(actual.len(), expected.len());
        for (i, (&a, &e)) in actual.iter().zip(expected).enumerate() {
            assert!(
                (a - e).abs() < tol,
                "index {}: got {}, expected {}",
                i,
                a,
                e
            );
        }
    }

    #[test]
    fn test_batch_norm() {
        let batch_norm = BatchNorm::new(1e-5);
        let batch_size = 2;
        let features = 3;

        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mean = vec![2.5, 3.5, 4.5];
        let variance = vec![2.25, 2.25, 2.25];
        let gamma = vec![1.0, 1.0, 1.0];
        let beta = vec![0.0, 0.0, 0.0];
        let mut output = vec![0.0; 6];

        batch_norm.forward(
            &input,
            &mean,
            &variance,
            &gamma,
            &beta,
            &mut output,
            batch_size,
            features,
        );

        // Every input lies exactly one standard deviation (1.5) from its feature mean.
        assert_close(&output, &[-1.0, -1.0, -1.0, 1.0, 1.0, 1.0], TOL);
    }

    #[test]
    fn test_layer_norm() {
        let layer_norm = LayerNorm::new(1e-5);
        let batch_size = 2;
        let features = 3;

        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let gamma = vec![1.0, 1.0, 1.0];
        let beta = vec![0.0, 0.0, 0.0];
        let mut output = vec![0.0; 6];

        layer_norm.forward(&input, &gamma, &beta, &mut output, batch_size, features);

        // Each row is [c - 1, c, c + 1]: mean c and population variance 2/3.
        let inv_std = 1.0 / (2.0f32 / 3.0 + 1e-5).sqrt();
        assert_close(
            &output,
            &[-inv_std, 0.0, inv_std, -inv_std, 0.0, inv_std],
            TOL,
        );

        // Each sample should be normalized independently to zero mean.
        for batch in 0..batch_size {
            let start = batch * features;
            let end = start + features;
            let sample_mean: f32 = output[start..end].iter().sum::<f32>() / features as f32;
            assert!((sample_mean).abs() < 1e-5);
        }
    }

    #[test]
    fn test_batch_matmul() {
        let batch_size = 2;
        let m = 2;
        let n = 2;
        let k = 2;

        let a = vec![
            1.0, 2.0, 3.0, 4.0, // batch 0
            5.0, 6.0, 7.0, 8.0, // batch 1
        ];
        let b = vec![
            1.0, 0.0, 0.0, 1.0, // batch 0 (identity)
            1.0, 0.0, 0.0, 1.0, // batch 1 (identity)
        ];
        let mut c = vec![0.0; batch_size * m * n];

        batch_matmul::batch_matmul_f32(&a, &b, &mut c, batch_size, m, n, k);

        // With identity matrices, output should equal input.
        assert_close(&c, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], TOL);
    }

    #[test]
    fn test_batch_matmul_non_identity() {
        // [[1, 2], [3, 4]] * [[5, 6], [7, 8]] = [[19, 22], [43, 50]]
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];
        let mut c = vec![0.0; 4];

        batch_matmul::batch_matmul_f32(&a, &b, &mut c, 1, 2, 2, 2);

        assert_close(&c, &[19.0, 22.0, 43.0, 50.0], TOL);
    }

    #[test]
    fn test_batch_matmul_broadcast() {
        let batch_size = 2;
        let m = 2;
        let n = 2;
        let k = 2;

        let a = vec![
            1.0, 2.0, 3.0, 4.0, // batch 0
            5.0, 6.0, 7.0, 8.0, // batch 1
        ];
        let b = vec![1.0, 0.0, 0.0, 1.0]; // shared identity matrix
        let mut c = vec![0.0; batch_size * m * n];

        batch_matmul::batch_matmul_broadcast_f32(&a, &b, &mut c, batch_size, m, n, k);

        // With identity matrix, output should equal input.
        assert_close(&c, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], TOL);
    }

    #[test]
    fn test_attention_basic() {
        let batch_size = 1;
        let seq_len = 3;
        let d_model = 4;

        // One-hot queries/keys: Q * K^T is the identity, scaled by 1/sqrt(4) = 0.5.
        let query = vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0];
        let key = query.clone();
        let value = vec![
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
        ];
        let mut output = vec![0.0; batch_size * seq_len * d_model];

        attention::scaled_dot_product_attention(
            &query,
            &key,
            &value,
            &mut output,
            batch_size,
            seq_len,
            d_model,
            None,
        );

        // Row i scores: 0.5 on the diagonal, 0 elsewhere, so after softmax
        // w_self = 1 / (1 + 2e^-0.5) and w_other = e^-0.5 / (1 + 2e^-0.5).
        let e = (-0.5f32).exp();
        let w_self = 1.0 / (1.0 + 2.0 * e);
        let w_other = e / (1.0 + 2.0 * e);
        let mut expected = vec![0.0f32; batch_size * seq_len * d_model];
        for i in 0..seq_len {
            for j in 0..seq_len {
                let w = if i == j { w_self } else { w_other };
                for d in 0..d_model {
                    expected[i * d_model + d] += w * value[j * d_model + d];
                }
            }
        }
        assert_close(&output, &expected, TOL);
    }

    #[test]
    fn test_conv2d_batch_simple() {
        let batch_size = 1;
        let in_channels = 1;
        let out_channels = 1;
        let input_height = 3;
        let input_width = 3;
        let kernel_height = 2;
        let kernel_width = 2;

        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let weight = vec![1.0, 0.0, 0.0, 1.0]; // sums each window's main diagonal
        let bias = vec![0.0];

        let output_height = input_height - kernel_height + 1;
        let output_width = input_width - kernel_width + 1;
        let mut output = vec![0.0; batch_size * out_channels * output_height * output_width];

        convolution::conv2d_batch(
            &input,
            &weight,
            &bias,
            &mut output,
            batch_size,
            in_channels,
            out_channels,
            input_height,
            input_width,
            kernel_height,
            kernel_width,
            1,
            1,
            0,
            0,
        );

        assert_eq!(output.len(), 4); // 2x2 output
        assert_close(&output, &[6.0, 8.0, 12.0, 14.0], TOL);
    }

    #[test]
    fn test_batch_norm_f16() {
        let batch_norm = BatchNorm::new(1e-3); // Larger epsilon for FP16
        let batch_size = 2;
        let features = 3;

        let input = vec![
            F16::from_f32(1.0),
            F16::from_f32(2.0),
            F16::from_f32(3.0),
            F16::from_f32(4.0),
            F16::from_f32(5.0),
            F16::from_f32(6.0),
        ];
        let mean = vec![F16::from_f32(2.5), F16::from_f32(3.5), F16::from_f32(4.5)];
        let variance = vec![
            F16::from_f32(2.25),
            F16::from_f32(2.25),
            F16::from_f32(2.25),
        ];
        let gamma = vec![F16::from_f32(1.0), F16::from_f32(1.0), F16::from_f32(1.0)];
        let beta = vec![F16::from_f32(0.0), F16::from_f32(0.0), F16::from_f32(0.0)];
        let mut output = vec![F16::from_bits(0); 6];

        batch_norm.forward_f16(
            &input,
            &mean,
            &variance,
            &gamma,
            &beta,
            &mut output,
            batch_size,
            features,
        );

        // Same data as test_batch_norm; the tolerance absorbs FP16 rounding and the larger epsilon.
        let values: Vec<f32> = output.iter().map(|v| v.to_f32()).collect();
        assert_close(&values, &[-1.0, -1.0, -1.0, 1.0, 1.0, 1.0], 5e-3);
    }

    #[test]
    fn test_batch_stats_computation() {
        let batch_size = 4;
        let features = 2;

        let input = vec![
            1.0, 2.0, // batch 0
            3.0, 4.0, // batch 1
            5.0, 6.0, // batch 2
            7.0, 8.0, // batch 3
        ];
        let mut mean = vec![0.0; features];
        let mut variance = vec![0.0; features];

        BatchNorm::compute_stats(&input, &mut mean, &mut variance, batch_size, features);

        // Expected mean: [4.0, 5.0]
        assert!((mean[0] - 4.0).abs() < 1e-6);
        assert!((mean[1] - 5.0).abs() < 1e-6);

        // Expected variance: [5.0, 5.0] (for values 1,3,5,7 and 2,4,6,8)
        assert!((variance[0] - 5.0).abs() < 1e-6);
        assert!((variance[1] - 5.0).abs() < 1e-6);
    }
}