// torsh_python/nn/normalization.rs

1//! Normalization layers
2
3use super::module::PyModule;
4use crate::{error::PyResult, py_result, tensor::PyTensor};
5use pyo3::prelude::*;
6use std::collections::HashMap;
7use torsh_tensor::Tensor;
8
/// Batch Normalization 2D layer
///
/// Normalizes each channel of a 4-D (NCHW) input: per-channel batch
/// statistics are used in training mode, stored running statistics in
/// evaluation mode (see `forward`).
#[pyclass(name = "BatchNorm2d", extends = PyModule)]
pub struct PyBatchNorm2d {
    /// Learnable per-channel scale (gamma); `None` unless `affine` is set.
    weight: Option<Tensor<f32>>,
    /// Learnable per-channel shift (beta); `None` unless `affine` is set.
    bias: Option<Tensor<f32>>,
    /// Exponential moving average of per-channel means, used in eval mode.
    running_mean: Tensor<f32>,
    /// Exponential moving average of per-channel variances, used in eval mode.
    running_var: Tensor<f32>,
    /// Expected number of input channels (C).
    num_features: usize,
    /// Small constant added to the variance for numerical stability.
    eps: f32,
    /// Weight given to the newest batch statistic in the running averages.
    momentum: f32,
    /// Whether the learnable scale/shift is applied after normalization.
    affine: bool,
    /// Whether running statistics are updated during training.
    track_running_stats: bool,
    /// Training vs. evaluation mode flag (selects batch vs. running stats).
    training: bool,
    /// Number of training batches seen while tracking running statistics.
    num_batches_tracked: usize,
}
24
25#[pymethods]
26impl PyBatchNorm2d {
27    #[new]
28    fn new(
29        num_features: usize,
30        eps: Option<f32>,
31        momentum: Option<f32>,
32        affine: Option<bool>,
33        track_running_stats: Option<bool>,
34    ) -> PyResult<(Self, PyModule)> {
35        let eps = eps.unwrap_or(1e-5);
36        let momentum = momentum.unwrap_or(0.1);
37        let affine = affine.unwrap_or(true);
38        let track_running_stats = track_running_stats.unwrap_or(true);
39
40        let shape = vec![num_features];
41
42        // Initialize weight and bias if affine=true
43        let (weight, bias) = if affine {
44            let weight = py_result!(torsh_tensor::creation::ones(&shape))?.requires_grad_(true);
45            let bias = py_result!(torsh_tensor::creation::zeros(&shape))?.requires_grad_(true);
46            (Some(weight), Some(bias))
47        } else {
48            (None, None)
49        };
50
51        // Initialize running statistics
52        let running_mean = py_result!(torsh_tensor::creation::zeros(&shape))?;
53        let running_var = py_result!(torsh_tensor::creation::ones(&shape))?;
54
55        Ok((
56            Self {
57                weight,
58                bias,
59                running_mean,
60                running_var,
61                num_features,
62                eps,
63                momentum,
64                affine,
65                track_running_stats,
66                training: true,
67                num_batches_tracked: 0,
68            },
69            PyModule::new(),
70        ))
71    }
72
73    /// Forward pass through batch normalization
74    fn forward(&mut self, input: &PyTensor) -> PyResult<PyTensor> {
75        // ✅ Proper 2D batch normalization implementation for 4D tensors (NCHW)
76        let shape = input.tensor.shape().dims().to_vec();
77
78        // Expect 4D input: (batch, channels, height, width)
79        if shape.len() != 4 {
80            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
81                "Expected 4D input (NCHW), got {}D",
82                shape.len()
83            )));
84        }
85
86        let batch_size = shape[0];
87        let num_channels = shape[1];
88        let height = shape[2];
89        let width = shape[3];
90        let spatial_size = height * width;
91
92        if num_channels != self.num_features {
93            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
94                "Expected {} channels, got {}",
95                self.num_features, num_channels
96            )));
97        }
98
99        let input_data = py_result!(input.tensor.data())?;
100        let mut output_data = input_data.clone();
101
102        if self.training {
103            // Training mode: compute batch statistics across spatial dimensions
104            if self.track_running_stats {
105                self.num_batches_tracked += 1;
106            }
107
108            // Compute mean and variance for each channel across batch and spatial dims
109            for c in 0..num_channels {
110                let mut sum = 0.0;
111                let mut sum_sq = 0.0;
112                let mut count = 0;
113
114                for b in 0..batch_size {
115                    for h in 0..height {
116                        for w in 0..width {
117                            let idx =
118                                b * num_channels * spatial_size + c * spatial_size + h * width + w;
119                            let val = input_data[idx];
120                            sum += val;
121                            sum_sq += val * val;
122                            count += 1;
123                        }
124                    }
125                }
126
127                let mean = sum / count as f32;
128                let var = (sum_sq / count as f32) - (mean * mean);
129
130                // Update running statistics
131                if self.track_running_stats {
132                    let mut running_mean_data = py_result!(self.running_mean.data())?;
133                    let mut running_var_data = py_result!(self.running_var.data())?;
134
135                    running_mean_data[c] =
136                        (1.0 - self.momentum) * running_mean_data[c] + self.momentum * mean;
137                    running_var_data[c] =
138                        (1.0 - self.momentum) * running_var_data[c] + self.momentum * var;
139
140                    self.running_mean = py_result!(torsh_tensor::Tensor::from_data(
141                        running_mean_data,
142                        vec![num_channels],
143                        self.running_mean.device()
144                    ))?;
145                    self.running_var = py_result!(torsh_tensor::Tensor::from_data(
146                        running_var_data,
147                        vec![num_channels],
148                        self.running_var.device()
149                    ))?;
150                }
151
152                // Normalize
153                let std = (var + self.eps).sqrt();
154                for b in 0..batch_size {
155                    for h in 0..height {
156                        for w in 0..width {
157                            let idx =
158                                b * num_channels * spatial_size + c * spatial_size + h * width + w;
159                            output_data[idx] = (output_data[idx] - mean) / std;
160                        }
161                    }
162                }
163
164                // Apply affine transformation
165                if self.affine {
166                    if let (Some(ref weight), Some(ref bias)) = (&self.weight, &self.bias) {
167                        let weight_data = py_result!(weight.data())?;
168                        let bias_data = py_result!(bias.data())?;
169
170                        for b in 0..batch_size {
171                            for h in 0..height {
172                                for w in 0..width {
173                                    let idx = b * num_channels * spatial_size
174                                        + c * spatial_size
175                                        + h * width
176                                        + w;
177                                    output_data[idx] =
178                                        output_data[idx] * weight_data[c] + bias_data[c];
179                                }
180                            }
181                        }
182                    }
183                }
184            }
185        } else {
186            // Evaluation mode: use running statistics
187            let running_mean_data = py_result!(self.running_mean.data())?;
188            let running_var_data = py_result!(self.running_var.data())?;
189
190            for c in 0..num_channels {
191                let mean = running_mean_data[c];
192                let var = running_var_data[c];
193                let std = (var + self.eps).sqrt();
194
195                for b in 0..batch_size {
196                    for h in 0..height {
197                        for w in 0..width {
198                            let idx =
199                                b * num_channels * spatial_size + c * spatial_size + h * width + w;
200                            output_data[idx] = (output_data[idx] - mean) / std;
201                        }
202                    }
203                }
204
205                // Apply affine transformation
206                if self.affine {
207                    if let (Some(ref weight), Some(ref bias)) = (&self.weight, &self.bias) {
208                        let weight_data = py_result!(weight.data())?;
209                        let bias_data = py_result!(bias.data())?;
210
211                        for b in 0..batch_size {
212                            for h in 0..height {
213                                for w in 0..width {
214                                    let idx = b * num_channels * spatial_size
215                                        + c * spatial_size
216                                        + h * width
217                                        + w;
218                                    output_data[idx] =
219                                        output_data[idx] * weight_data[c] + bias_data[c];
220                                }
221                            }
222                        }
223                    }
224                }
225            }
226        }
227
228        let result = py_result!(torsh_tensor::Tensor::from_data(
229            output_data,
230            shape.to_vec(),
231            input.tensor.device()
232        ))?;
233
234        Ok(PyTensor { tensor: result })
235    }
236
237    /// Get layer parameters
238    fn parameters(&self) -> PyResult<Vec<PyTensor>> {
239        let mut params = Vec::new();
240        if let Some(ref weight) = self.weight {
241            params.push(PyTensor {
242                tensor: weight.clone(),
243            });
244        }
245        if let Some(ref bias) = self.bias {
246            params.push(PyTensor {
247                tensor: bias.clone(),
248            });
249        }
250        Ok(params)
251    }
252
253    /// Get named parameters
254    fn named_parameters(&self) -> PyResult<HashMap<String, PyTensor>> {
255        let mut params = HashMap::new();
256        if let Some(ref weight) = self.weight {
257            params.insert(
258                "weight".to_string(),
259                PyTensor {
260                    tensor: weight.clone(),
261                },
262            );
263        }
264        if let Some(ref bias) = self.bias {
265            params.insert(
266                "bias".to_string(),
267                PyTensor {
268                    tensor: bias.clone(),
269                },
270            );
271        }
272        Ok(params)
273    }
274
275    /// Set training mode
276    fn train(&mut self, mode: Option<bool>) -> PyResult<()> {
277        self.training = mode.unwrap_or(true);
278        Ok(())
279    }
280
281    /// Set evaluation mode
282    fn eval(&mut self) -> PyResult<()> {
283        self.training = false;
284        Ok(())
285    }
286
287    /// String representation
288    fn __repr__(&self) -> String {
289        format!(
290            "BatchNorm2d({}, eps={}, momentum={}, affine={}, track_running_stats={})",
291            self.num_features, self.eps, self.momentum, self.affine, self.track_running_stats
292        )
293    }
294}
295
/// Batch Normalization 1D layer
///
/// Normalizes each feature column of the input: per-feature batch
/// statistics are used in training mode, stored running statistics in
/// evaluation mode (see `forward`).
#[pyclass(name = "BatchNorm1d", extends = PyModule)]
pub struct PyBatchNorm1d {
    /// Learnable per-feature scale (gamma); `None` unless `affine` is set.
    weight: Option<Tensor<f32>>,
    /// Learnable per-feature shift (beta); `None` unless `affine` is set.
    bias: Option<Tensor<f32>>,
    /// Exponential moving average of per-feature means, used in eval mode.
    running_mean: Tensor<f32>,
    /// Exponential moving average of per-feature variances, used in eval mode.
    running_var: Tensor<f32>,
    /// Expected number of input features (C).
    num_features: usize,
    /// Small constant added to the variance for numerical stability.
    eps: f32,
    /// Weight given to the newest batch statistic in the running averages.
    momentum: f32,
    /// Whether the learnable scale/shift is applied after normalization.
    affine: bool,
    /// Whether running statistics are updated during training.
    track_running_stats: bool,
    /// Training vs. evaluation mode flag (selects batch vs. running stats).
    training: bool,
    /// Number of training batches seen while tracking running statistics.
    num_batches_tracked: usize,
}
311
312#[pymethods]
313impl PyBatchNorm1d {
314    #[new]
315    fn new(
316        num_features: usize,
317        eps: Option<f32>,
318        momentum: Option<f32>,
319        affine: Option<bool>,
320        track_running_stats: Option<bool>,
321    ) -> PyResult<(Self, PyModule)> {
322        let eps = eps.unwrap_or(1e-5);
323        let momentum = momentum.unwrap_or(0.1);
324        let affine = affine.unwrap_or(true);
325        let track_running_stats = track_running_stats.unwrap_or(true);
326
327        let shape = vec![num_features];
328
329        // Initialize weight and bias if affine=true
330        let (weight, bias) = if affine {
331            let weight = py_result!(torsh_tensor::creation::ones(&shape))?.requires_grad_(true);
332            let bias = py_result!(torsh_tensor::creation::zeros(&shape))?.requires_grad_(true);
333            (Some(weight), Some(bias))
334        } else {
335            (None, None)
336        };
337
338        // Initialize running statistics
339        let running_mean = py_result!(torsh_tensor::creation::zeros(&shape))?;
340        let running_var = py_result!(torsh_tensor::creation::ones(&shape))?;
341
342        Ok((
343            Self {
344                weight,
345                bias,
346                running_mean,
347                running_var,
348                num_features,
349                eps,
350                momentum,
351                affine,
352                track_running_stats,
353                training: true,
354                num_batches_tracked: 0,
355            },
356            PyModule::new(),
357        ))
358    }
359
360    /// Forward pass through batch normalization
361    fn forward(&mut self, input: &PyTensor) -> PyResult<PyTensor> {
362        // ✅ Proper batch normalization implementation with statistics
363        let shape = input.tensor.shape().dims().to_vec();
364
365        // Expect input: (batch, channels) for 1D
366        if shape.len() < 2 {
367            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
368                "Expected at least 2D input, got {}D",
369                shape.len()
370            )));
371        }
372
373        let batch_size = shape[0];
374        let num_features = shape[1];
375
376        if num_features != self.num_features {
377            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
378                "Expected {} features, got {}",
379                self.num_features, num_features
380            )));
381        }
382
383        let input_data = py_result!(input.tensor.data())?;
384        let mut output_data = input_data.clone();
385
386        if self.training {
387            // Training mode: compute batch statistics
388            if self.track_running_stats {
389                self.num_batches_tracked += 1;
390            }
391
392            // Compute mean and variance for each feature
393            for c in 0..num_features {
394                let mut sum = 0.0;
395                let mut sum_sq = 0.0;
396                let mut count = 0;
397
398                for b in 0..batch_size {
399                    let idx = b * num_features + c;
400                    let val = input_data[idx];
401                    sum += val;
402                    sum_sq += val * val;
403                    count += 1;
404                }
405
406                let mean = sum / count as f32;
407                let var = (sum_sq / count as f32) - (mean * mean);
408
409                // Update running statistics
410                if self.track_running_stats {
411                    let mut running_mean_data = py_result!(self.running_mean.data())?;
412                    let mut running_var_data = py_result!(self.running_var.data())?;
413
414                    running_mean_data[c] =
415                        (1.0 - self.momentum) * running_mean_data[c] + self.momentum * mean;
416                    running_var_data[c] =
417                        (1.0 - self.momentum) * running_var_data[c] + self.momentum * var;
418
419                    self.running_mean = py_result!(torsh_tensor::Tensor::from_data(
420                        running_mean_data,
421                        vec![num_features],
422                        self.running_mean.device()
423                    ))?;
424                    self.running_var = py_result!(torsh_tensor::Tensor::from_data(
425                        running_var_data,
426                        vec![num_features],
427                        self.running_var.device()
428                    ))?;
429                }
430
431                // Normalize
432                let std = (var + self.eps).sqrt();
433                for b in 0..batch_size {
434                    let idx = b * num_features + c;
435                    output_data[idx] = (output_data[idx] - mean) / std;
436                }
437
438                // Apply affine transformation
439                if self.affine {
440                    if let (Some(ref weight), Some(ref bias)) = (&self.weight, &self.bias) {
441                        let weight_data = py_result!(weight.data())?;
442                        let bias_data = py_result!(bias.data())?;
443
444                        for b in 0..batch_size {
445                            let idx = b * num_features + c;
446                            output_data[idx] = output_data[idx] * weight_data[c] + bias_data[c];
447                        }
448                    }
449                }
450            }
451        } else {
452            // Evaluation mode: use running statistics
453            let running_mean_data = py_result!(self.running_mean.data())?;
454            let running_var_data = py_result!(self.running_var.data())?;
455
456            for c in 0..num_features {
457                let mean = running_mean_data[c];
458                let var = running_var_data[c];
459                let std = (var + self.eps).sqrt();
460
461                for b in 0..batch_size {
462                    let idx = b * num_features + c;
463                    output_data[idx] = (output_data[idx] - mean) / std;
464                }
465
466                // Apply affine transformation
467                if self.affine {
468                    if let (Some(ref weight), Some(ref bias)) = (&self.weight, &self.bias) {
469                        let weight_data = py_result!(weight.data())?;
470                        let bias_data = py_result!(bias.data())?;
471
472                        for b in 0..batch_size {
473                            let idx = b * num_features + c;
474                            output_data[idx] = output_data[idx] * weight_data[c] + bias_data[c];
475                        }
476                    }
477                }
478            }
479        }
480
481        let result = py_result!(torsh_tensor::Tensor::from_data(
482            output_data,
483            shape.to_vec(),
484            input.tensor.device()
485        ))?;
486
487        Ok(PyTensor { tensor: result })
488    }
489
490    /// Get layer parameters
491    fn parameters(&self) -> PyResult<Vec<PyTensor>> {
492        let mut params = Vec::new();
493        if let Some(ref weight) = self.weight {
494            params.push(PyTensor {
495                tensor: weight.clone(),
496            });
497        }
498        if let Some(ref bias) = self.bias {
499            params.push(PyTensor {
500                tensor: bias.clone(),
501            });
502        }
503        Ok(params)
504    }
505
506    /// Get named parameters
507    fn named_parameters(&self) -> PyResult<HashMap<String, PyTensor>> {
508        let mut params = HashMap::new();
509        if let Some(ref weight) = self.weight {
510            params.insert(
511                "weight".to_string(),
512                PyTensor {
513                    tensor: weight.clone(),
514                },
515            );
516        }
517        if let Some(ref bias) = self.bias {
518            params.insert(
519                "bias".to_string(),
520                PyTensor {
521                    tensor: bias.clone(),
522                },
523            );
524        }
525        Ok(params)
526    }
527
528    /// Set training mode
529    fn train(&mut self, mode: Option<bool>) -> PyResult<()> {
530        self.training = mode.unwrap_or(true);
531        Ok(())
532    }
533
534    /// Set evaluation mode
535    fn eval(&mut self) -> PyResult<()> {
536        self.training = false;
537        Ok(())
538    }
539
540    /// String representation
541    fn __repr__(&self) -> String {
542        format!(
543            "BatchNorm1d({}, eps={}, momentum={}, affine={}, track_running_stats={})",
544            self.num_features, self.eps, self.momentum, self.affine, self.track_running_stats
545        )
546    }
547}
548
/// Layer Normalization layer
///
/// Normalizes over the trailing `normalized_shape` dimensions of the input,
/// computing mean/variance independently for each leading-index slice.
#[pyclass(name = "LayerNorm", extends = PyModule)]
pub struct PyLayerNorm {
    /// Learnable elementwise scale (gamma) over `normalized_shape`;
    /// `None` unless `elementwise_affine` is set.
    weight: Option<Tensor<f32>>,
    /// Learnable elementwise shift (beta) over `normalized_shape`;
    /// `None` unless `elementwise_affine` is set.
    bias: Option<Tensor<f32>>,
    /// Trailing input dimensions that are normalized over.
    normalized_shape: Vec<usize>,
    /// Small constant added to the variance for numerical stability.
    eps: f32,
    /// Whether the learnable elementwise scale/shift is applied.
    elementwise_affine: bool,
}
558
559#[pymethods]
560impl PyLayerNorm {
561    #[new]
562    fn new(
563        normalized_shape: Vec<usize>,
564        eps: Option<f32>,
565        elementwise_affine: Option<bool>,
566    ) -> PyResult<(Self, PyModule)> {
567        let eps = eps.unwrap_or(1e-5);
568        let elementwise_affine = elementwise_affine.unwrap_or(true);
569
570        // Initialize weight and bias if elementwise_affine=true
571        let (weight, bias) = if elementwise_affine {
572            let weight =
573                py_result!(torsh_tensor::creation::ones(&normalized_shape))?.requires_grad_(true);
574            let bias =
575                py_result!(torsh_tensor::creation::zeros(&normalized_shape))?.requires_grad_(true);
576            (Some(weight), Some(bias))
577        } else {
578            (None, None)
579        };
580
581        Ok((
582            Self {
583                weight,
584                bias,
585                normalized_shape,
586                eps,
587                elementwise_affine,
588            },
589            PyModule::new(),
590        ))
591    }
592
593    /// Forward pass through layer normalization
594    fn forward(&mut self, input: &PyTensor) -> PyResult<PyTensor> {
595        // ✅ Proper layer normalization implementation
596        let shape = input.tensor.shape().dims().to_vec();
597        let ndim = shape.len();
598        let norm_ndim = self.normalized_shape.len();
599
600        // Verify that normalized_shape matches the last dimensions of input
601        if norm_ndim > ndim {
602            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
603                "normalized_shape has {} dimensions but input has only {}",
604                norm_ndim, ndim
605            )));
606        }
607
608        // Check that the normalized dimensions match
609        for i in 0..norm_ndim {
610            if shape[ndim - norm_ndim + i] != self.normalized_shape[i] {
611                return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
612                    "Input shape {:?} doesn't match normalized_shape {:?}",
613                    shape, self.normalized_shape
614                )));
615            }
616        }
617
618        // Calculate the number of elements to normalize over
619        let norm_size: usize = self.normalized_shape.iter().product();
620        let batch_size: usize = shape[..ndim - norm_ndim].iter().product();
621
622        let input_data = py_result!(input.tensor.data())?;
623        let mut output_data = input_data.clone();
624
625        // Normalize each batch independently
626        for batch_idx in 0..batch_size {
627            let start = batch_idx * norm_size;
628            let end = start + norm_size;
629
630            // Compute mean
631            let mut sum = 0.0;
632            for i in start..end {
633                sum += input_data[i];
634            }
635            let mean = sum / norm_size as f32;
636
637            // Compute variance
638            let mut sum_sq_diff = 0.0;
639            for i in start..end {
640                let diff = input_data[i] - mean;
641                sum_sq_diff += diff * diff;
642            }
643            let variance = sum_sq_diff / norm_size as f32;
644
645            // Normalize
646            let std = (variance + self.eps).sqrt();
647            for i in start..end {
648                output_data[i] = (output_data[i] - mean) / std;
649            }
650
651            // Apply affine transformation if enabled
652            if self.elementwise_affine {
653                if let (Some(ref weight), Some(ref bias)) = (&self.weight, &self.bias) {
654                    let weight_data = py_result!(weight.data())?;
655                    let bias_data = py_result!(bias.data())?;
656
657                    for i in 0..norm_size {
658                        let idx = start + i;
659                        output_data[idx] = output_data[idx] * weight_data[i] + bias_data[i];
660                    }
661                }
662            }
663        }
664
665        let result = py_result!(torsh_tensor::Tensor::from_data(
666            output_data,
667            shape.to_vec(),
668            input.tensor.device()
669        ))?;
670
671        Ok(PyTensor { tensor: result })
672    }
673
674    /// Get layer parameters
675    fn parameters(&self) -> PyResult<Vec<PyTensor>> {
676        let mut params = Vec::new();
677        if let Some(ref weight) = self.weight {
678            params.push(PyTensor {
679                tensor: weight.clone(),
680            });
681        }
682        if let Some(ref bias) = self.bias {
683            params.push(PyTensor {
684                tensor: bias.clone(),
685            });
686        }
687        Ok(params)
688    }
689
690    /// Get named parameters
691    fn named_parameters(&self) -> PyResult<HashMap<String, PyTensor>> {
692        let mut params = HashMap::new();
693        if let Some(ref weight) = self.weight {
694            params.insert(
695                "weight".to_string(),
696                PyTensor {
697                    tensor: weight.clone(),
698                },
699            );
700        }
701        if let Some(ref bias) = self.bias {
702            params.insert(
703                "bias".to_string(),
704                PyTensor {
705                    tensor: bias.clone(),
706                },
707            );
708        }
709        Ok(params)
710    }
711
712    /// String representation
713    fn __repr__(&self) -> String {
714        format!(
715            "LayerNorm({:?}, eps={}, elementwise_affine={})",
716            self.normalized_shape, self.eps, self.elementwise_affine
717        )
718    }
719}