1use torsh_core::{Result as TorshResult, TorshError};
4use torsh_tensor::{stats::StatMode, Tensor};
5
6#[allow(clippy::too_many_arguments)]
10pub fn batch_norm(
11 input: &Tensor,
12 running_mean: Option<&Tensor>,
13 running_var: Option<&Tensor>,
14 weight: Option<&Tensor>,
15 bias: Option<&Tensor>,
16 training: bool,
17 momentum: f64,
18 eps: f64,
19) -> TorshResult<Tensor> {
20 let shape = input.shape().dims().to_vec();
22 let ndim = shape.len();
23
24 if ndim < 2 {
25 return Err(TorshError::invalid_argument_with_context(
26 "Batch norm requires at least 2D input",
27 "batch_norm",
28 ));
29 }
30
31 let num_features = shape[1];
32
33 let (mean, var) = if training {
35 let axes: Vec<usize> = (0..ndim).filter(|&i| i != 1).collect();
37 let mean = input.mean(Some(&axes), true)?;
38 let var = input.var(Some(&axes), true, StatMode::Population)?;
39
40 if let (Some(running_mean), Some(running_var)) = (running_mean, running_var) {
42 let _running_mean_update = running_mean
44 .mul_scalar((1.0 - momentum) as f32)?
45 .add_op(&mean.mul_scalar(momentum as f32)?)?;
46 let _running_var_update = running_var
47 .mul_scalar((1.0 - momentum) as f32)?
48 .add_op(&var.mul_scalar(momentum as f32)?)?;
49
50 }
53
54 (mean, var)
55 } else {
56 match (running_mean, running_var) {
58 (Some(rm), Some(rv)) => (rm.clone(), rv.clone()),
59 _ => {
60 return Err(TorshError::invalid_argument_with_context(
61 "Running mean and var required for eval mode",
62 "batch_norm",
63 ))
64 }
65 }
66 };
67
68 let std = var.add_scalar(eps as f32)?.sqrt()?;
70 let normalized = input.sub(&mean)?.div(&std)?;
71
72 let output = match (weight, bias) {
74 (Some(w), Some(b)) => {
75 let mut w_shape = vec![1; ndim];
77 w_shape[1] = num_features;
78 let w_reshaped = w.view(&w_shape.iter().map(|&x| x as i32).collect::<Vec<_>>())?;
79
80 let mut b_shape = vec![1; ndim];
81 b_shape[1] = num_features;
82 let b_reshaped = b.view(&b_shape.iter().map(|&x| x as i32).collect::<Vec<_>>())?;
83
84 normalized.mul_op(&w_reshaped)?.add_op(&b_reshaped)?
85 }
86 (Some(w), None) => {
87 let mut w_shape = vec![1; ndim];
88 w_shape[1] = num_features;
89 let w_reshaped = w.view(&w_shape.iter().map(|&x| x as i32).collect::<Vec<_>>())?;
90 normalized.mul_op(&w_reshaped)?
91 }
92 (None, Some(b)) => {
93 let mut b_shape = vec![1; ndim];
94 b_shape[1] = num_features;
95 let b_reshaped = b.view(&b_shape.iter().map(|&x| x as i32).collect::<Vec<_>>())?;
96 normalized.add_op(&b_reshaped)?
97 }
98 (None, None) => normalized,
99 };
100
101 Ok(output)
102}
103
104pub fn layer_norm(
108 input: &Tensor,
109 normalized_shape: &[usize],
110 weight: Option<&Tensor>,
111 bias: Option<&Tensor>,
112 eps: f64,
113) -> TorshResult<Tensor> {
114 let ndim = input.shape().ndim();
116 let norm_ndim = normalized_shape.len();
117
118 if norm_ndim > ndim {
119 return Err(TorshError::invalid_argument_with_context(
120 "Normalized shape dimension count exceeds input dimensions",
121 "layer_norm",
122 ));
123 }
124
125 let axes: Vec<usize> = ((ndim - norm_ndim)..ndim).collect();
127
128 let mean = input.mean(Some(&axes), true)?;
130 let var = input.var(Some(&axes), true, StatMode::Population)?;
131
132 let std = var.add_scalar(eps as f32)?.sqrt()?;
134 let normalized = input.sub(&mean)?.div(&std)?;
135
136 let output = match (weight, bias) {
138 (Some(w), Some(b)) => normalized.mul_op(w)?.add_op(b)?,
139 (Some(w), None) => normalized.mul_op(w)?,
140 (None, Some(b)) => normalized.add_op(b)?,
141 (None, None) => normalized,
142 };
143
144 Ok(output)
145}
146
147#[allow(clippy::too_many_arguments)]
151pub fn instance_norm(
152 input: &Tensor,
153 _running_mean: Option<&Tensor>,
154 _running_var: Option<&Tensor>,
155 weight: Option<&Tensor>,
156 bias: Option<&Tensor>,
157 _use_input_stats: bool,
158 _momentum: f64,
159 eps: f64,
160) -> TorshResult<Tensor> {
161 let shape = input.shape().dims().to_vec();
164 let ndim = shape.len();
165
166 if ndim < 3 {
167 return Err(TorshError::invalid_argument_with_context(
168 "Instance norm requires at least 3D input",
169 "instance_norm",
170 ));
171 }
172
173 let axes: Vec<usize> = (2..ndim).collect();
175
176 let mean = input.mean(Some(&axes), true)?;
178 let var = input.var(Some(&axes), true, StatMode::Population)?;
179
180 let std = var.add_scalar(eps as f32)?.sqrt()?;
182 let normalized = input.sub(&mean)?.div(&std)?;
183
184 let output = match (weight, bias) {
186 (Some(w), Some(b)) => {
187 let w = w.unsqueeze(0)?; let b = b.unsqueeze(0)?;
190 normalized.mul_op(&w)?.add_op(&b)?
191 }
192 (Some(w), None) => {
193 let w = w.unsqueeze(0)?;
194 normalized.mul_op(&w)?
195 }
196 (None, Some(b)) => {
197 let b = b.unsqueeze(0)?;
198 normalized.add_op(&b)?
199 }
200 (None, None) => normalized,
201 };
202
203 Ok(output)
204}
205
206pub fn group_norm(
210 input: &Tensor,
211 num_groups: usize,
212 weight: Option<&Tensor>,
213 bias: Option<&Tensor>,
214 eps: f64,
215) -> TorshResult<Tensor> {
216 let shape = input.shape().dims().to_vec();
217 let ndim = shape.len();
218
219 if ndim < 2 {
220 return Err(TorshError::invalid_argument_with_context(
221 "Group norm requires at least 2D input",
222 "group_norm",
223 ));
224 }
225
226 let batch_size = shape[0];
227 let num_channels = shape[1];
228
229 if num_channels % num_groups != 0 {
230 return Err(TorshError::invalid_argument_with_context(
231 &format!(
232 "Number of channels {} must be divisible by num_groups {}",
233 num_channels, num_groups
234 ),
235 "group_norm",
236 ));
237 }
238
239 let channels_per_group = num_channels / num_groups;
240
241 let mut new_shape = vec![batch_size, num_groups, channels_per_group];
243 new_shape.extend_from_slice(&shape[2..]);
244
245 let reshaped = input.reshape(&new_shape.iter().map(|&x| x as i32).collect::<Vec<_>>())?;
246
247 let axes: Vec<usize> = (2..new_shape.len()).collect();
249 let mean = reshaped.mean(Some(&axes), true)?;
250 let var = reshaped.var(Some(&axes), true, StatMode::Population)?;
251
252 let std = var.add_scalar(eps as f32)?.sqrt()?;
254 let normalized = reshaped.sub(&mean)?.div(&std)?;
255
256 let normalized = normalized.reshape(&shape.iter().map(|&x| x as i32).collect::<Vec<_>>())?;
258
259 let output = match (weight, bias) {
261 (Some(w), Some(b)) => {
262 let w = w.unsqueeze(0)?; let b = b.unsqueeze(0)?;
264 normalized.mul_op(&w)?.add_op(&b)?
265 }
266 (Some(w), None) => {
267 let w = w.unsqueeze(0)?;
268 normalized.mul_op(&w)?
269 }
270 (None, Some(b)) => {
271 let b = b.unsqueeze(0)?;
272 normalized.add_op(&b)?
273 }
274 (None, None) => normalized,
275 };
276
277 Ok(output)
278}
279
/// Local response normalization:
/// `output = input / (k + alpha/n * sum_sq)^beta`.
///
/// NOTE(review): this is an incomplete implementation. A full LRN sums the
/// squares over a sliding window of `size` neighboring channels; here
/// `sum_sq` is just each element's own square (`squared.clone()`), so the
/// result only matches true LRN when `size == 1`. The unused `_padding`
/// suggests the windowed sum was planned but never implemented — TODO.
///
/// # Errors
/// Returns an error if `input` has fewer than 2 dimensions.
pub fn local_response_norm(
    input: &Tensor,
    size: usize,
    alpha: f64,
    beta: f64,
    k: f64,
) -> TorshResult<Tensor> {
    let shape_obj = input.shape();
    let shape = shape_obj.dims();
    if shape.len() < 2 {
        return Err(TorshError::invalid_argument_with_context(
            "Local response norm requires at least 2D input",
            "local_response_norm",
        ));
    }

    // Channel count and window half-width — currently unused because the
    // cross-channel windowed sum below is not implemented.
    let _num_channels = shape[1];

    let _padding = size / 2;

    let squared = input.pow_scalar(2.0)?;

    // Placeholder: should be the sum of `squared` over a window of `size`
    // channels centered on each channel; currently the per-element square.
    let sum_sq = squared.clone();

    // denominator = (k + (alpha / n) * sum_sq)^beta
    let n = size as f32;
    let denominator = sum_sq
        .mul_scalar((alpha / n as f64) as f32)?
        .add_scalar(k as f32)?
        .pow_scalar(beta as f32)?;

    input.div(&denominator)
}
328
329pub fn normalize(
331 input: &Tensor,
332 p: f64,
333 dim: i64,
334 eps: f64,
335 out: Option<&mut Tensor>,
336) -> TorshResult<Tensor> {
337 if p <= 0.0 {
339 return Err(TorshError::invalid_argument_with_context(
340 &format!("normalize: p must be positive, got {}", p),
341 "normalize",
342 ));
343 }
344
345 let ndim = input.ndim() as i64;
347 let dim = if dim < 0 { ndim + dim } else { dim };
348
349 if dim < 0 || dim >= ndim {
350 return Err(TorshError::InvalidArgument(format!(
351 "Dimension {} out of range for tensor with {} dimensions",
352 dim, ndim
353 )));
354 }
355
356 let norm = if (p - 2.0).abs() < 1e-7 {
358 let squared = input.pow_scalar(2.0)?;
360 let sum = squared.sum_dim(&[dim as i32], true)?;
361 sum.sqrt()?
362 } else if (p - 1.0).abs() < 1e-7 {
363 let abs_vals = input.abs()?;
365 abs_vals.sum_dim(&[dim as i32], true)?
366 } else if p.is_infinite() && p.is_sign_positive() {
367 let abs_vals = input.abs()?;
369 abs_vals.max(Some(dim as usize), true)?
370 } else {
371 let abs_vals = input.abs()?;
373 let powered = abs_vals.pow_scalar(p as f32)?;
374 let sum = powered.sum_dim(&[dim as i32], true)?;
375 sum.pow_scalar((1.0 / p) as f32)?
376 };
377
378 let norm_eps = norm.add_scalar(eps as f32)?;
380
381 let normalized = input.div(&norm_eps)?;
383
384 if let Some(_out_tensor) = out {
385 return Err(TorshError::UnsupportedOperation {
388 op: "in-place normalize".to_string(),
389 dtype: "tensor".to_string(),
390 });
391 }
392
393 Ok(normalized)
394}
395
396pub fn weight_norm(weight: &Tensor, dim: i64) -> TorshResult<(Tensor, Tensor)> {
400 let squared = weight.pow_scalar(2.0)?;
402 let norm = squared.sum_dim(&[dim as i32], true)?.sqrt()?;
403
404 let direction = weight.div(&norm)?;
406
407 let magnitude = norm.squeeze(dim as i32)?;
409
410 Ok((magnitude, direction))
411}
412
/// Estimates the largest singular value of `weight` via power iteration and
/// returns the weight scaled by it, together with the updated left singular
/// vector estimate `u`.
///
/// The weight is flattened to a 2-D matrix of shape
/// `[shape[0], product(shape[1..])]` before iterating.
///
/// Returns `(weight / (sigma + eps), u)`.
///
/// # Errors
/// Returns an error if `weight` has fewer than 2 dimensions.
pub fn spectral_norm(
    weight: &Tensor,
    u: Option<&Tensor>,
    n_power_iterations: usize,
    eps: f64,
) -> TorshResult<(Tensor, Tensor)> {
    let shape_obj = weight.shape();
    let shape = shape_obj.dims();

    if shape.len() < 2 {
        return Err(TorshError::invalid_argument_with_context(
            "Spectral norm requires at least 2D weight tensor",
            "spectral_norm",
        ));
    }

    // Flatten to a [out_features, in_features] matrix: dim 0 stays, all
    // remaining dims are collapsed.
    let out_features = shape[0];
    let in_features: usize = shape[1..].iter().product();
    let weight_mat = weight.view(&[out_features as i32, in_features as i32])?;

    // Left singular vector estimate: caller-provided, or random init.
    let mut u_vec = if let Some(u_input) = u {
        u_input.clone()
    } else {
        use torsh_tensor::creation::randn;
        randn::<f32>(&[out_features])?
    };

    // Normalize u to unit length; eps guards against division by zero.
    let u_norm = u_vec.pow_scalar(2.0)?.sum()?.sqrt()?;
    u_vec = u_vec.div_scalar(u_norm.item()? + eps as f32)?;

    for _ in 0..n_power_iterations {
        // v = normalize(W^T u)  — shape [in_features]
        let weight_t = weight_mat.t()?;
        let v = weight_t.matmul(&u_vec.view(&[out_features as i32, 1])?)?;
        let v = v.squeeze(1)?;

        let v_norm = v.pow_scalar(2.0)?.sum()?.sqrt()?;
        let v = v.div_scalar(v_norm.item()? + eps as f32)?;

        // u = normalize(W v)  — shape [out_features]
        let u = weight_mat.matmul(&v.view(&[in_features as i32, 1])?)?;
        u_vec = u.squeeze(1)?;

        let u_norm = u_vec.pow_scalar(2.0)?.sum()?.sqrt()?;
        u_vec = u_vec.div_scalar(u_norm.item()? + eps as f32)?;
    }

    // Final v from the converged u: v = normalize(W^T u).
    let weight_t = weight_mat.t()?;
    let v = weight_t.matmul(&u_vec.view(&[out_features as i32, 1])?)?;
    let v = v.squeeze(1)?;

    let v_norm = v.pow_scalar(2.0)?.sum()?.sqrt()?;
    let v = v.div_scalar(v_norm.item()? + eps as f32)?;

    // sigma = u^T W v — the estimated largest singular value.
    let wv = weight_mat.matmul(&v.view(&[in_features as i32, 1])?)?;
    let wv = wv.squeeze(1)?;

    let u_wv = u_vec.mul(&wv)?.sum()?;
    let sigma = u_wv.item()?;

    let normalized_weight = weight.div_scalar(sigma + eps as f32)?;

    Ok((normalized_weight, u_vec))
}
510
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_tensor::creation::tensor_1d;

    /// `normalize` rejects non-positive p and accepts the supported norms.
    #[test]
    fn test_normalize() {
        let input = tensor_1d(&[1.0, 2.0, 3.0, 4.0]).unwrap();

        // Non-positive p values must be rejected.
        for bad_p in [-1.0, 0.0] {
            assert!(normalize(&input, bad_p, 0, 1e-12, None).is_err());
        }

        // L2, L1, general-p, and infinity norms must all succeed.
        for good_p in [2.0, 1.0, 3.0, f64::INFINITY] {
            assert!(normalize(&input, good_p, 0, 1e-12, None).is_ok());
        }
    }
}