1use super::module::PyModule;
4use crate::{error::PyResult, py_result, tensor::PyTensor};
5use pyo3::prelude::*;
6use std::collections::HashMap;
7use torsh_tensor::Tensor;
8
/// Python-facing 2D batch normalization layer for NCHW inputs, modeled on
/// `torch.nn.BatchNorm2d` and exported to Python as `BatchNorm2d`.
#[pyclass(name = "BatchNorm2d", extends = PyModule)]
pub struct PyBatchNorm2d {
    // Learnable per-channel scale (gamma); present only when `affine`.
    weight: Option<Tensor<f32>>,
    // Learnable per-channel shift (beta); present only when `affine`.
    bias: Option<Tensor<f32>>,
    // Exponential moving average of per-channel means, updated in training.
    running_mean: Tensor<f32>,
    // Exponential moving average of per-channel variances.
    running_var: Tensor<f32>,
    // Expected channel count (C) of incoming tensors.
    num_features: usize,
    // Small constant added to the variance before sqrt for numeric stability.
    eps: f32,
    // Weight given to the new batch statistic in the running-average update.
    momentum: f32,
    // Whether the learnable weight/bias transform is applied.
    affine: bool,
    // Whether running statistics are maintained during training.
    track_running_stats: bool,
    // True: normalize with batch stats; false: use running stats.
    training: bool,
    // Number of training batches seen while tracking running statistics.
    num_batches_tracked: usize,
}
24
25#[pymethods]
26impl PyBatchNorm2d {
27 #[new]
28 fn new(
29 num_features: usize,
30 eps: Option<f32>,
31 momentum: Option<f32>,
32 affine: Option<bool>,
33 track_running_stats: Option<bool>,
34 ) -> PyResult<(Self, PyModule)> {
35 let eps = eps.unwrap_or(1e-5);
36 let momentum = momentum.unwrap_or(0.1);
37 let affine = affine.unwrap_or(true);
38 let track_running_stats = track_running_stats.unwrap_or(true);
39
40 let shape = vec![num_features];
41
42 let (weight, bias) = if affine {
44 let weight = py_result!(torsh_tensor::creation::ones(&shape))?.requires_grad_(true);
45 let bias = py_result!(torsh_tensor::creation::zeros(&shape))?.requires_grad_(true);
46 (Some(weight), Some(bias))
47 } else {
48 (None, None)
49 };
50
51 let running_mean = py_result!(torsh_tensor::creation::zeros(&shape))?;
53 let running_var = py_result!(torsh_tensor::creation::ones(&shape))?;
54
55 Ok((
56 Self {
57 weight,
58 bias,
59 running_mean,
60 running_var,
61 num_features,
62 eps,
63 momentum,
64 affine,
65 track_running_stats,
66 training: true,
67 num_batches_tracked: 0,
68 },
69 PyModule::new(),
70 ))
71 }
72
73 fn forward(&mut self, input: &PyTensor) -> PyResult<PyTensor> {
75 let shape = input.tensor.shape().dims().to_vec();
77
78 if shape.len() != 4 {
80 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
81 "Expected 4D input (NCHW), got {}D",
82 shape.len()
83 )));
84 }
85
86 let batch_size = shape[0];
87 let num_channels = shape[1];
88 let height = shape[2];
89 let width = shape[3];
90 let spatial_size = height * width;
91
92 if num_channels != self.num_features {
93 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
94 "Expected {} channels, got {}",
95 self.num_features, num_channels
96 )));
97 }
98
99 let input_data = py_result!(input.tensor.data())?;
100 let mut output_data = input_data.clone();
101
102 if self.training {
103 if self.track_running_stats {
105 self.num_batches_tracked += 1;
106 }
107
108 for c in 0..num_channels {
110 let mut sum = 0.0;
111 let mut sum_sq = 0.0;
112 let mut count = 0;
113
114 for b in 0..batch_size {
115 for h in 0..height {
116 for w in 0..width {
117 let idx =
118 b * num_channels * spatial_size + c * spatial_size + h * width + w;
119 let val = input_data[idx];
120 sum += val;
121 sum_sq += val * val;
122 count += 1;
123 }
124 }
125 }
126
127 let mean = sum / count as f32;
128 let var = (sum_sq / count as f32) - (mean * mean);
129
130 if self.track_running_stats {
132 let mut running_mean_data = py_result!(self.running_mean.data())?;
133 let mut running_var_data = py_result!(self.running_var.data())?;
134
135 running_mean_data[c] =
136 (1.0 - self.momentum) * running_mean_data[c] + self.momentum * mean;
137 running_var_data[c] =
138 (1.0 - self.momentum) * running_var_data[c] + self.momentum * var;
139
140 self.running_mean = py_result!(torsh_tensor::Tensor::from_data(
141 running_mean_data,
142 vec![num_channels],
143 self.running_mean.device()
144 ))?;
145 self.running_var = py_result!(torsh_tensor::Tensor::from_data(
146 running_var_data,
147 vec![num_channels],
148 self.running_var.device()
149 ))?;
150 }
151
152 let std = (var + self.eps).sqrt();
154 for b in 0..batch_size {
155 for h in 0..height {
156 for w in 0..width {
157 let idx =
158 b * num_channels * spatial_size + c * spatial_size + h * width + w;
159 output_data[idx] = (output_data[idx] - mean) / std;
160 }
161 }
162 }
163
164 if self.affine {
166 if let (Some(ref weight), Some(ref bias)) = (&self.weight, &self.bias) {
167 let weight_data = py_result!(weight.data())?;
168 let bias_data = py_result!(bias.data())?;
169
170 for b in 0..batch_size {
171 for h in 0..height {
172 for w in 0..width {
173 let idx = b * num_channels * spatial_size
174 + c * spatial_size
175 + h * width
176 + w;
177 output_data[idx] =
178 output_data[idx] * weight_data[c] + bias_data[c];
179 }
180 }
181 }
182 }
183 }
184 }
185 } else {
186 let running_mean_data = py_result!(self.running_mean.data())?;
188 let running_var_data = py_result!(self.running_var.data())?;
189
190 for c in 0..num_channels {
191 let mean = running_mean_data[c];
192 let var = running_var_data[c];
193 let std = (var + self.eps).sqrt();
194
195 for b in 0..batch_size {
196 for h in 0..height {
197 for w in 0..width {
198 let idx =
199 b * num_channels * spatial_size + c * spatial_size + h * width + w;
200 output_data[idx] = (output_data[idx] - mean) / std;
201 }
202 }
203 }
204
205 if self.affine {
207 if let (Some(ref weight), Some(ref bias)) = (&self.weight, &self.bias) {
208 let weight_data = py_result!(weight.data())?;
209 let bias_data = py_result!(bias.data())?;
210
211 for b in 0..batch_size {
212 for h in 0..height {
213 for w in 0..width {
214 let idx = b * num_channels * spatial_size
215 + c * spatial_size
216 + h * width
217 + w;
218 output_data[idx] =
219 output_data[idx] * weight_data[c] + bias_data[c];
220 }
221 }
222 }
223 }
224 }
225 }
226 }
227
228 let result = py_result!(torsh_tensor::Tensor::from_data(
229 output_data,
230 shape.to_vec(),
231 input.tensor.device()
232 ))?;
233
234 Ok(PyTensor { tensor: result })
235 }
236
237 fn parameters(&self) -> PyResult<Vec<PyTensor>> {
239 let mut params = Vec::new();
240 if let Some(ref weight) = self.weight {
241 params.push(PyTensor {
242 tensor: weight.clone(),
243 });
244 }
245 if let Some(ref bias) = self.bias {
246 params.push(PyTensor {
247 tensor: bias.clone(),
248 });
249 }
250 Ok(params)
251 }
252
253 fn named_parameters(&self) -> PyResult<HashMap<String, PyTensor>> {
255 let mut params = HashMap::new();
256 if let Some(ref weight) = self.weight {
257 params.insert(
258 "weight".to_string(),
259 PyTensor {
260 tensor: weight.clone(),
261 },
262 );
263 }
264 if let Some(ref bias) = self.bias {
265 params.insert(
266 "bias".to_string(),
267 PyTensor {
268 tensor: bias.clone(),
269 },
270 );
271 }
272 Ok(params)
273 }
274
275 fn train(&mut self, mode: Option<bool>) -> PyResult<()> {
277 self.training = mode.unwrap_or(true);
278 Ok(())
279 }
280
281 fn eval(&mut self) -> PyResult<()> {
283 self.training = false;
284 Ok(())
285 }
286
287 fn __repr__(&self) -> String {
289 format!(
290 "BatchNorm2d({}, eps={}, momentum={}, affine={}, track_running_stats={})",
291 self.num_features, self.eps, self.momentum, self.affine, self.track_running_stats
292 )
293 }
294}
295
/// Python-facing 1D batch normalization layer for (N, C) or (N, C, L)
/// inputs, modeled on `torch.nn.BatchNorm1d` and exported as `BatchNorm1d`.
#[pyclass(name = "BatchNorm1d", extends = PyModule)]
pub struct PyBatchNorm1d {
    // Learnable per-feature scale (gamma); present only when `affine`.
    weight: Option<Tensor<f32>>,
    // Learnable per-feature shift (beta); present only when `affine`.
    bias: Option<Tensor<f32>>,
    // Exponential moving average of per-feature means, updated in training.
    running_mean: Tensor<f32>,
    // Exponential moving average of per-feature variances.
    running_var: Tensor<f32>,
    // Expected feature count (C) of incoming tensors.
    num_features: usize,
    // Small constant added to the variance before sqrt for numeric stability.
    eps: f32,
    // Weight given to the new batch statistic in the running-average update.
    momentum: f32,
    // Whether the learnable weight/bias transform is applied.
    affine: bool,
    // Whether running statistics are maintained during training.
    track_running_stats: bool,
    // True: normalize with batch stats; false: use running stats.
    training: bool,
    // Number of training batches seen while tracking running statistics.
    num_batches_tracked: usize,
}
311
312#[pymethods]
313impl PyBatchNorm1d {
314 #[new]
315 fn new(
316 num_features: usize,
317 eps: Option<f32>,
318 momentum: Option<f32>,
319 affine: Option<bool>,
320 track_running_stats: Option<bool>,
321 ) -> PyResult<(Self, PyModule)> {
322 let eps = eps.unwrap_or(1e-5);
323 let momentum = momentum.unwrap_or(0.1);
324 let affine = affine.unwrap_or(true);
325 let track_running_stats = track_running_stats.unwrap_or(true);
326
327 let shape = vec![num_features];
328
329 let (weight, bias) = if affine {
331 let weight = py_result!(torsh_tensor::creation::ones(&shape))?.requires_grad_(true);
332 let bias = py_result!(torsh_tensor::creation::zeros(&shape))?.requires_grad_(true);
333 (Some(weight), Some(bias))
334 } else {
335 (None, None)
336 };
337
338 let running_mean = py_result!(torsh_tensor::creation::zeros(&shape))?;
340 let running_var = py_result!(torsh_tensor::creation::ones(&shape))?;
341
342 Ok((
343 Self {
344 weight,
345 bias,
346 running_mean,
347 running_var,
348 num_features,
349 eps,
350 momentum,
351 affine,
352 track_running_stats,
353 training: true,
354 num_batches_tracked: 0,
355 },
356 PyModule::new(),
357 ))
358 }
359
360 fn forward(&mut self, input: &PyTensor) -> PyResult<PyTensor> {
362 let shape = input.tensor.shape().dims().to_vec();
364
365 if shape.len() < 2 {
367 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
368 "Expected at least 2D input, got {}D",
369 shape.len()
370 )));
371 }
372
373 let batch_size = shape[0];
374 let num_features = shape[1];
375
376 if num_features != self.num_features {
377 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
378 "Expected {} features, got {}",
379 self.num_features, num_features
380 )));
381 }
382
383 let input_data = py_result!(input.tensor.data())?;
384 let mut output_data = input_data.clone();
385
386 if self.training {
387 if self.track_running_stats {
389 self.num_batches_tracked += 1;
390 }
391
392 for c in 0..num_features {
394 let mut sum = 0.0;
395 let mut sum_sq = 0.0;
396 let mut count = 0;
397
398 for b in 0..batch_size {
399 let idx = b * num_features + c;
400 let val = input_data[idx];
401 sum += val;
402 sum_sq += val * val;
403 count += 1;
404 }
405
406 let mean = sum / count as f32;
407 let var = (sum_sq / count as f32) - (mean * mean);
408
409 if self.track_running_stats {
411 let mut running_mean_data = py_result!(self.running_mean.data())?;
412 let mut running_var_data = py_result!(self.running_var.data())?;
413
414 running_mean_data[c] =
415 (1.0 - self.momentum) * running_mean_data[c] + self.momentum * mean;
416 running_var_data[c] =
417 (1.0 - self.momentum) * running_var_data[c] + self.momentum * var;
418
419 self.running_mean = py_result!(torsh_tensor::Tensor::from_data(
420 running_mean_data,
421 vec![num_features],
422 self.running_mean.device()
423 ))?;
424 self.running_var = py_result!(torsh_tensor::Tensor::from_data(
425 running_var_data,
426 vec![num_features],
427 self.running_var.device()
428 ))?;
429 }
430
431 let std = (var + self.eps).sqrt();
433 for b in 0..batch_size {
434 let idx = b * num_features + c;
435 output_data[idx] = (output_data[idx] - mean) / std;
436 }
437
438 if self.affine {
440 if let (Some(ref weight), Some(ref bias)) = (&self.weight, &self.bias) {
441 let weight_data = py_result!(weight.data())?;
442 let bias_data = py_result!(bias.data())?;
443
444 for b in 0..batch_size {
445 let idx = b * num_features + c;
446 output_data[idx] = output_data[idx] * weight_data[c] + bias_data[c];
447 }
448 }
449 }
450 }
451 } else {
452 let running_mean_data = py_result!(self.running_mean.data())?;
454 let running_var_data = py_result!(self.running_var.data())?;
455
456 for c in 0..num_features {
457 let mean = running_mean_data[c];
458 let var = running_var_data[c];
459 let std = (var + self.eps).sqrt();
460
461 for b in 0..batch_size {
462 let idx = b * num_features + c;
463 output_data[idx] = (output_data[idx] - mean) / std;
464 }
465
466 if self.affine {
468 if let (Some(ref weight), Some(ref bias)) = (&self.weight, &self.bias) {
469 let weight_data = py_result!(weight.data())?;
470 let bias_data = py_result!(bias.data())?;
471
472 for b in 0..batch_size {
473 let idx = b * num_features + c;
474 output_data[idx] = output_data[idx] * weight_data[c] + bias_data[c];
475 }
476 }
477 }
478 }
479 }
480
481 let result = py_result!(torsh_tensor::Tensor::from_data(
482 output_data,
483 shape.to_vec(),
484 input.tensor.device()
485 ))?;
486
487 Ok(PyTensor { tensor: result })
488 }
489
490 fn parameters(&self) -> PyResult<Vec<PyTensor>> {
492 let mut params = Vec::new();
493 if let Some(ref weight) = self.weight {
494 params.push(PyTensor {
495 tensor: weight.clone(),
496 });
497 }
498 if let Some(ref bias) = self.bias {
499 params.push(PyTensor {
500 tensor: bias.clone(),
501 });
502 }
503 Ok(params)
504 }
505
506 fn named_parameters(&self) -> PyResult<HashMap<String, PyTensor>> {
508 let mut params = HashMap::new();
509 if let Some(ref weight) = self.weight {
510 params.insert(
511 "weight".to_string(),
512 PyTensor {
513 tensor: weight.clone(),
514 },
515 );
516 }
517 if let Some(ref bias) = self.bias {
518 params.insert(
519 "bias".to_string(),
520 PyTensor {
521 tensor: bias.clone(),
522 },
523 );
524 }
525 Ok(params)
526 }
527
528 fn train(&mut self, mode: Option<bool>) -> PyResult<()> {
530 self.training = mode.unwrap_or(true);
531 Ok(())
532 }
533
534 fn eval(&mut self) -> PyResult<()> {
536 self.training = false;
537 Ok(())
538 }
539
540 fn __repr__(&self) -> String {
542 format!(
543 "BatchNorm1d({}, eps={}, momentum={}, affine={}, track_running_stats={})",
544 self.num_features, self.eps, self.momentum, self.affine, self.track_running_stats
545 )
546 }
547}
548
/// Python-facing layer normalization, modeled on `torch.nn.LayerNorm` and
/// exported as `LayerNorm`. Normalizes each group spanned by the trailing
/// `normalized_shape` dimensions of the input.
#[pyclass(name = "LayerNorm", extends = PyModule)]
pub struct PyLayerNorm {
    // Learnable elementwise scale (gamma); present only when
    // `elementwise_affine`.
    weight: Option<Tensor<f32>>,
    // Learnable elementwise shift (beta); present only when
    // `elementwise_affine`.
    bias: Option<Tensor<f32>>,
    // Trailing input dimensions each normalization group spans.
    normalized_shape: Vec<usize>,
    // Small constant added to the variance before sqrt for numeric stability.
    eps: f32,
    // Whether the learnable weight/bias transform is applied.
    elementwise_affine: bool,
}
558
559#[pymethods]
560impl PyLayerNorm {
561 #[new]
562 fn new(
563 normalized_shape: Vec<usize>,
564 eps: Option<f32>,
565 elementwise_affine: Option<bool>,
566 ) -> PyResult<(Self, PyModule)> {
567 let eps = eps.unwrap_or(1e-5);
568 let elementwise_affine = elementwise_affine.unwrap_or(true);
569
570 let (weight, bias) = if elementwise_affine {
572 let weight =
573 py_result!(torsh_tensor::creation::ones(&normalized_shape))?.requires_grad_(true);
574 let bias =
575 py_result!(torsh_tensor::creation::zeros(&normalized_shape))?.requires_grad_(true);
576 (Some(weight), Some(bias))
577 } else {
578 (None, None)
579 };
580
581 Ok((
582 Self {
583 weight,
584 bias,
585 normalized_shape,
586 eps,
587 elementwise_affine,
588 },
589 PyModule::new(),
590 ))
591 }
592
593 fn forward(&mut self, input: &PyTensor) -> PyResult<PyTensor> {
595 let shape = input.tensor.shape().dims().to_vec();
597 let ndim = shape.len();
598 let norm_ndim = self.normalized_shape.len();
599
600 if norm_ndim > ndim {
602 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
603 "normalized_shape has {} dimensions but input has only {}",
604 norm_ndim, ndim
605 )));
606 }
607
608 for i in 0..norm_ndim {
610 if shape[ndim - norm_ndim + i] != self.normalized_shape[i] {
611 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
612 "Input shape {:?} doesn't match normalized_shape {:?}",
613 shape, self.normalized_shape
614 )));
615 }
616 }
617
618 let norm_size: usize = self.normalized_shape.iter().product();
620 let batch_size: usize = shape[..ndim - norm_ndim].iter().product();
621
622 let input_data = py_result!(input.tensor.data())?;
623 let mut output_data = input_data.clone();
624
625 for batch_idx in 0..batch_size {
627 let start = batch_idx * norm_size;
628 let end = start + norm_size;
629
630 let mut sum = 0.0;
632 for i in start..end {
633 sum += input_data[i];
634 }
635 let mean = sum / norm_size as f32;
636
637 let mut sum_sq_diff = 0.0;
639 for i in start..end {
640 let diff = input_data[i] - mean;
641 sum_sq_diff += diff * diff;
642 }
643 let variance = sum_sq_diff / norm_size as f32;
644
645 let std = (variance + self.eps).sqrt();
647 for i in start..end {
648 output_data[i] = (output_data[i] - mean) / std;
649 }
650
651 if self.elementwise_affine {
653 if let (Some(ref weight), Some(ref bias)) = (&self.weight, &self.bias) {
654 let weight_data = py_result!(weight.data())?;
655 let bias_data = py_result!(bias.data())?;
656
657 for i in 0..norm_size {
658 let idx = start + i;
659 output_data[idx] = output_data[idx] * weight_data[i] + bias_data[i];
660 }
661 }
662 }
663 }
664
665 let result = py_result!(torsh_tensor::Tensor::from_data(
666 output_data,
667 shape.to_vec(),
668 input.tensor.device()
669 ))?;
670
671 Ok(PyTensor { tensor: result })
672 }
673
674 fn parameters(&self) -> PyResult<Vec<PyTensor>> {
676 let mut params = Vec::new();
677 if let Some(ref weight) = self.weight {
678 params.push(PyTensor {
679 tensor: weight.clone(),
680 });
681 }
682 if let Some(ref bias) = self.bias {
683 params.push(PyTensor {
684 tensor: bias.clone(),
685 });
686 }
687 Ok(params)
688 }
689
690 fn named_parameters(&self) -> PyResult<HashMap<String, PyTensor>> {
692 let mut params = HashMap::new();
693 if let Some(ref weight) = self.weight {
694 params.insert(
695 "weight".to_string(),
696 PyTensor {
697 tensor: weight.clone(),
698 },
699 );
700 }
701 if let Some(ref bias) = self.bias {
702 params.insert(
703 "bias".to_string(),
704 PyTensor {
705 tensor: bias.clone(),
706 },
707 );
708 }
709 Ok(params)
710 }
711
712 fn __repr__(&self) -> String {
714 format!(
715 "LayerNorm({:?}, eps={}, elementwise_affine={})",
716 self.normalized_shape, self.eps, self.elementwise_affine
717 )
718 }
719}