// torsh_functional/normalization.rs

//! Normalization functions for neural networks

use torsh_core::{Result as TorshResult, TorshError};
use torsh_tensor::{stats::StatMode, Tensor};
6/// Batch normalization
7///
8/// Applies batch normalization over a batch of inputs
9#[allow(clippy::too_many_arguments)]
10pub fn batch_norm(
11    input: &Tensor,
12    running_mean: Option<&Tensor>,
13    running_var: Option<&Tensor>,
14    weight: Option<&Tensor>,
15    bias: Option<&Tensor>,
16    training: bool,
17    momentum: f64,
18    eps: f64,
19) -> TorshResult<Tensor> {
20    // Input can be 2D (N, C), 3D (N, C, L), 4D (N, C, H, W) or 5D (N, C, D, H, W)
21    let shape = input.shape().dims().to_vec();
22    let ndim = shape.len();
23
24    if ndim < 2 {
25        return Err(TorshError::invalid_argument_with_context(
26            "Batch norm requires at least 2D input",
27            "batch_norm",
28        ));
29    }
30
31    let num_features = shape[1];
32
33    // Calculate mean and variance
34    let (mean, var) = if training {
35        // Calculate batch statistics
36        let axes: Vec<usize> = (0..ndim).filter(|&i| i != 1).collect();
37        let mean = input.mean(Some(&axes), true)?;
38        let var = input.var(Some(&axes), true, StatMode::Population)?;
39
40        // Update running statistics if provided
41        if let (Some(running_mean), Some(running_var)) = (running_mean, running_var) {
42            // running_mean = (1 - momentum) * running_mean + momentum * batch_mean
43            let _running_mean_update = running_mean
44                .mul_scalar((1.0 - momentum) as f32)?
45                .add_op(&mean.mul_scalar(momentum as f32)?)?;
46            let _running_var_update = running_var
47                .mul_scalar((1.0 - momentum) as f32)?
48                .add_op(&var.mul_scalar(momentum as f32)?)?;
49
50            // Note: In practice, these updates should be applied in-place
51            // This would require mutable references which we don't have here
52        }
53
54        (mean, var)
55    } else {
56        // Use running statistics
57        match (running_mean, running_var) {
58            (Some(rm), Some(rv)) => (rm.clone(), rv.clone()),
59            _ => {
60                return Err(TorshError::invalid_argument_with_context(
61                    "Running mean and var required for eval mode",
62                    "batch_norm",
63                ))
64            }
65        }
66    };
67
68    // Normalize: (x - mean) / sqrt(var + eps)
69    let std = var.add_scalar(eps as f32)?.sqrt()?;
70    let normalized = input.sub(&mean)?.div(&std)?;
71
72    // Apply affine transformation if weight and bias are provided
73    let output = match (weight, bias) {
74        (Some(w), Some(b)) => {
75            // Reshape weight and bias to match normalized dimensions
76            let mut w_shape = vec![1; ndim];
77            w_shape[1] = num_features;
78            let w_reshaped = w.view(&w_shape.iter().map(|&x| x as i32).collect::<Vec<_>>())?;
79
80            let mut b_shape = vec![1; ndim];
81            b_shape[1] = num_features;
82            let b_reshaped = b.view(&b_shape.iter().map(|&x| x as i32).collect::<Vec<_>>())?;
83
84            normalized.mul_op(&w_reshaped)?.add_op(&b_reshaped)?
85        }
86        (Some(w), None) => {
87            let mut w_shape = vec![1; ndim];
88            w_shape[1] = num_features;
89            let w_reshaped = w.view(&w_shape.iter().map(|&x| x as i32).collect::<Vec<_>>())?;
90            normalized.mul_op(&w_reshaped)?
91        }
92        (None, Some(b)) => {
93            let mut b_shape = vec![1; ndim];
94            b_shape[1] = num_features;
95            let b_reshaped = b.view(&b_shape.iter().map(|&x| x as i32).collect::<Vec<_>>())?;
96            normalized.add_op(&b_reshaped)?
97        }
98        (None, None) => normalized,
99    };
100
101    Ok(output)
102}
103
104/// Layer normalization
105///
106/// Applies layer normalization over a mini-batch of inputs
107pub fn layer_norm(
108    input: &Tensor,
109    normalized_shape: &[usize],
110    weight: Option<&Tensor>,
111    bias: Option<&Tensor>,
112    eps: f64,
113) -> TorshResult<Tensor> {
114    // Normalize over the last len(normalized_shape) dimensions
115    let ndim = input.shape().ndim();
116    let norm_ndim = normalized_shape.len();
117
118    if norm_ndim > ndim {
119        return Err(TorshError::invalid_argument_with_context(
120            "Normalized shape dimension count exceeds input dimensions",
121            "layer_norm",
122        ));
123    }
124
125    // Calculate axes to normalize over
126    let axes: Vec<usize> = ((ndim - norm_ndim)..ndim).collect();
127
128    // Calculate mean and variance
129    let mean = input.mean(Some(&axes), true)?;
130    let var = input.var(Some(&axes), true, StatMode::Population)?;
131
132    // Normalize
133    let std = var.add_scalar(eps as f32)?.sqrt()?;
134    let normalized = input.sub(&mean)?.div(&std)?;
135
136    // Apply affine transformation if provided
137    let output = match (weight, bias) {
138        (Some(w), Some(b)) => normalized.mul_op(w)?.add_op(b)?,
139        (Some(w), None) => normalized.mul_op(w)?,
140        (None, Some(b)) => normalized.add_op(b)?,
141        (None, None) => normalized,
142    };
143
144    Ok(output)
145}
146
147/// Instance normalization
148///
149/// Applies instance normalization over a batch of inputs
150#[allow(clippy::too_many_arguments)]
151pub fn instance_norm(
152    input: &Tensor,
153    _running_mean: Option<&Tensor>,
154    _running_var: Option<&Tensor>,
155    weight: Option<&Tensor>,
156    bias: Option<&Tensor>,
157    _use_input_stats: bool,
158    _momentum: f64,
159    eps: f64,
160) -> TorshResult<Tensor> {
161    // Instance norm normalizes each instance separately
162    // For 4D input (N, C, H, W), normalize over (H, W) for each (N, C)
163    let shape = input.shape().dims().to_vec();
164    let ndim = shape.len();
165
166    if ndim < 3 {
167        return Err(TorshError::invalid_argument_with_context(
168            "Instance norm requires at least 3D input",
169            "instance_norm",
170        ));
171    }
172
173    // Calculate axes to normalize over (spatial dimensions)
174    let axes: Vec<usize> = (2..ndim).collect();
175
176    // Calculate mean and variance
177    let mean = input.mean(Some(&axes), true)?;
178    let var = input.var(Some(&axes), true, StatMode::Population)?;
179
180    // Normalize
181    let std = var.add_scalar(eps as f32)?.sqrt()?;
182    let normalized = input.sub(&mean)?.div(&std)?;
183
184    // Apply affine transformation if provided
185    let output = match (weight, bias) {
186        (Some(w), Some(b)) => {
187            // Reshape weight and bias for broadcasting
188            let w = w.unsqueeze(0)?; // Add batch dimension
189            let b = b.unsqueeze(0)?;
190            normalized.mul_op(&w)?.add_op(&b)?
191        }
192        (Some(w), None) => {
193            let w = w.unsqueeze(0)?;
194            normalized.mul_op(&w)?
195        }
196        (None, Some(b)) => {
197            let b = b.unsqueeze(0)?;
198            normalized.add_op(&b)?
199        }
200        (None, None) => normalized,
201    };
202
203    Ok(output)
204}
205
206/// Group normalization
207///
208/// Divides channels into groups and normalizes within each group
209pub fn group_norm(
210    input: &Tensor,
211    num_groups: usize,
212    weight: Option<&Tensor>,
213    bias: Option<&Tensor>,
214    eps: f64,
215) -> TorshResult<Tensor> {
216    let shape = input.shape().dims().to_vec();
217    let ndim = shape.len();
218
219    if ndim < 2 {
220        return Err(TorshError::invalid_argument_with_context(
221            "Group norm requires at least 2D input",
222            "group_norm",
223        ));
224    }
225
226    let batch_size = shape[0];
227    let num_channels = shape[1];
228
229    if num_channels % num_groups != 0 {
230        return Err(TorshError::invalid_argument_with_context(
231            &format!(
232                "Number of channels {} must be divisible by num_groups {}",
233                num_channels, num_groups
234            ),
235            "group_norm",
236        ));
237    }
238
239    let channels_per_group = num_channels / num_groups;
240
241    // Reshape to (N, G, C//G, *spatial)
242    let mut new_shape = vec![batch_size, num_groups, channels_per_group];
243    new_shape.extend_from_slice(&shape[2..]);
244
245    let reshaped = input.reshape(&new_shape.iter().map(|&x| x as i32).collect::<Vec<_>>())?;
246
247    // Normalize over channel and spatial dimensions within each group
248    let axes: Vec<usize> = (2..new_shape.len()).collect();
249    let mean = reshaped.mean(Some(&axes), true)?;
250    let var = reshaped.var(Some(&axes), true, StatMode::Population)?;
251
252    // Normalize
253    let std = var.add_scalar(eps as f32)?.sqrt()?;
254    let normalized = reshaped.sub(&mean)?.div(&std)?;
255
256    // Reshape back to original dimensions
257    let normalized = normalized.reshape(&shape.iter().map(|&x| x as i32).collect::<Vec<_>>())?;
258
259    // Apply affine transformation if provided
260    let output = match (weight, bias) {
261        (Some(w), Some(b)) => {
262            let w = w.unsqueeze(0)?; // Add batch dimension
263            let b = b.unsqueeze(0)?;
264            normalized.mul_op(&w)?.add_op(&b)?
265        }
266        (Some(w), None) => {
267            let w = w.unsqueeze(0)?;
268            normalized.mul_op(&w)?
269        }
270        (None, Some(b)) => {
271            let b = b.unsqueeze(0)?;
272            normalized.add_op(&b)?
273        }
274        (None, None) => normalized,
275    };
276
277    Ok(output)
278}
279
280/// Local response normalization
281///
282/// Applies local response normalization over an input signal
283pub fn local_response_norm(
284    input: &Tensor,
285    size: usize,
286    alpha: f64,
287    beta: f64,
288    k: f64,
289) -> TorshResult<Tensor> {
290    // LRN normalizes over neighboring channels
291    // For each position, normalize using channels [c-size/2, c+size/2]
292
293    let shape_obj = input.shape();
294    let shape = shape_obj.dims();
295    if shape.len() < 2 {
296        return Err(TorshError::invalid_argument_with_context(
297            "Local response norm requires at least 2D input",
298            "local_response_norm",
299        ));
300    }
301
302    let _num_channels = shape[1];
303
304    // Create padded tensor for easier computation
305    let _padding = size / 2;
306
307    // Compute squared values
308    let squared = input.pow_scalar(2.0)?;
309
310    // For each channel, sum over the neighboring channels
311    // This is a simplified implementation - a full implementation would use
312    // efficient convolution-like operations for the windowed sum
313
314    // For now, return a placeholder implementation
315    // that at least computes a basic normalization
316    let sum_sq = squared.clone();
317
318    // Compute denominator: (k + alpha/n * sum(x_i^2))^beta
319    let n = size as f32;
320    let denominator = sum_sq
321        .mul_scalar((alpha / n as f64) as f32)?
322        .add_scalar(k as f32)?
323        .pow_scalar(beta as f32)?;
324
325    // Normalize
326    input.div(&denominator)
327}
328
329/// Normalize tensor using Lp norm
330pub fn normalize(
331    input: &Tensor,
332    p: f64,
333    dim: i64,
334    eps: f64,
335    out: Option<&mut Tensor>,
336) -> TorshResult<Tensor> {
337    // Validate p parameter
338    if p <= 0.0 {
339        return Err(TorshError::invalid_argument_with_context(
340            &format!("normalize: p must be positive, got {}", p),
341            "normalize",
342        ));
343    }
344
345    // Validate dimension
346    let ndim = input.ndim() as i64;
347    let dim = if dim < 0 { ndim + dim } else { dim };
348
349    if dim < 0 || dim >= ndim {
350        return Err(TorshError::InvalidArgument(format!(
351            "Dimension {} out of range for tensor with {} dimensions",
352            dim, ndim
353        )));
354    }
355
356    // Compute Lp norm: (sum(|x|^p))^(1/p)
357    let norm = if (p - 2.0).abs() < 1e-7 {
358        // Optimized path for L2 norm (most common case)
359        let squared = input.pow_scalar(2.0)?;
360        let sum = squared.sum_dim(&[dim as i32], true)?;
361        sum.sqrt()?
362    } else if (p - 1.0).abs() < 1e-7 {
363        // Optimized path for L1 norm
364        let abs_vals = input.abs()?;
365        abs_vals.sum_dim(&[dim as i32], true)?
366    } else if p.is_infinite() && p.is_sign_positive() {
367        // L-infinity norm: max(|x|)
368        let abs_vals = input.abs()?;
369        abs_vals.max(Some(dim as usize), true)?
370    } else {
371        // General Lp norm
372        let abs_vals = input.abs()?;
373        let powered = abs_vals.pow_scalar(p as f32)?;
374        let sum = powered.sum_dim(&[dim as i32], true)?;
375        sum.pow_scalar((1.0 / p) as f32)?
376    };
377
378    // Add epsilon to avoid division by zero
379    let norm_eps = norm.add_scalar(eps as f32)?;
380
381    // Normalize
382    let normalized = input.div(&norm_eps)?;
383
384    if let Some(_out_tensor) = out {
385        // Copy to output tensor if provided
386        // For now, we don't support in-place operations
387        return Err(TorshError::UnsupportedOperation {
388            op: "in-place normalize".to_string(),
389            dtype: "tensor".to_string(),
390        });
391    }
392
393    Ok(normalized)
394}
395
396/// Weight normalization
397///
398/// Decouples the magnitude and direction of weight vectors
399pub fn weight_norm(weight: &Tensor, dim: i64) -> TorshResult<(Tensor, Tensor)> {
400    // Compute the norm over the specified dimension
401    let squared = weight.pow_scalar(2.0)?;
402    let norm = squared.sum_dim(&[dim as i32], true)?.sqrt()?;
403
404    // Normalized direction
405    let direction = weight.div(&norm)?;
406
407    // Squeeze the norm dimension for the magnitude output
408    let magnitude = norm.squeeze(dim as i32)?;
409
410    Ok((magnitude, direction))
411}
412
413/// Spectral normalization
414///
415/// Normalizes weight by its spectral norm (largest singular value) using power iteration.
416///
417/// The spectral norm ||W||_2 is the largest singular value of the weight matrix W.
418/// This is computed efficiently using the power iteration method:
419///
420/// 1. Start with random vector u
421/// 2. Iterate: v = W^T u / ||W^T u||, u = W v / ||W v||
422/// 3. Spectral norm ≈ u^T W v
423///
424/// # Arguments
425/// * `weight` - Weight tensor (at least 2D)
426/// * `u` - Optional initial vector for power iteration
427/// * `n_power_iterations` - Number of power iterations (typically 1-5)
428/// * `eps` - Small constant for numerical stability
429///
430/// # Returns
431/// * Tuple of (normalized_weight, updated_u_vector)
432pub fn spectral_norm(
433    weight: &Tensor,
434    u: Option<&Tensor>,
435    n_power_iterations: usize,
436    eps: f64,
437) -> TorshResult<(Tensor, Tensor)> {
438    let shape_obj = weight.shape();
439    let shape = shape_obj.dims();
440
441    if shape.len() < 2 {
442        return Err(TorshError::invalid_argument_with_context(
443            "Spectral norm requires at least 2D weight tensor",
444            "spectral_norm",
445        ));
446    }
447
448    // Reshape weight to 2D: [out_features, in_features]
449    // For conv layers: [out_channels, in_channels * kernel_h * kernel_w]
450    let out_features = shape[0];
451    let in_features: usize = shape[1..].iter().product();
452    let weight_mat = weight.view(&[out_features as i32, in_features as i32])?;
453
454    // Initialize u vector if not provided
455    let mut u_vec = if let Some(u_input) = u {
456        u_input.clone()
457    } else {
458        // Initialize with random normal values
459        use torsh_tensor::creation::randn;
460        randn::<f32>(&[out_features])?
461    };
462
463    // Normalize u to unit length
464    let u_norm = u_vec.pow_scalar(2.0)?.sum()?.sqrt()?;
465    u_vec = u_vec.div_scalar(u_norm.item()? + eps as f32)?;
466
467    // Power iteration to find dominant eigenvector
468    for _ in 0..n_power_iterations {
469        // v = W^T u
470        let weight_t = weight_mat.t()?;
471        let v = weight_t.matmul(&u_vec.view(&[out_features as i32, 1])?)?;
472        let v = v.squeeze(1)?;
473
474        // Normalize v
475        let v_norm = v.pow_scalar(2.0)?.sum()?.sqrt()?;
476        let v = v.div_scalar(v_norm.item()? + eps as f32)?;
477
478        // u = W v
479        let u = weight_mat.matmul(&v.view(&[in_features as i32, 1])?)?;
480        u_vec = u.squeeze(1)?;
481
482        // Normalize u
483        let u_norm = u_vec.pow_scalar(2.0)?.sum()?.sqrt()?;
484        u_vec = u_vec.div_scalar(u_norm.item()? + eps as f32)?;
485    }
486
487    // Compute spectral norm: sigma = u^T W v
488    // First compute v = W^T u
489    let weight_t = weight_mat.t()?;
490    let v = weight_t.matmul(&u_vec.view(&[out_features as i32, 1])?)?;
491    let v = v.squeeze(1)?;
492
493    // Normalize v
494    let v_norm = v.pow_scalar(2.0)?.sum()?.sqrt()?;
495    let v = v.div_scalar(v_norm.item()? + eps as f32)?;
496
497    // Compute sigma = u^T W v
498    let wv = weight_mat.matmul(&v.view(&[in_features as i32, 1])?)?;
499    let wv = wv.squeeze(1)?;
500
501    // u^T (Wv) - dot product
502    let u_wv = u_vec.mul(&wv)?.sum()?;
503    let sigma = u_wv.item()?;
504
505    // Normalize weight by spectral norm
506    let normalized_weight = weight.div_scalar(sigma + eps as f32)?;
507
508    Ok((normalized_weight, u_vec))
509}
510
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_tensor::creation::tensor_1d;

    #[test]
    fn test_normalize() {
        let input = tensor_1d(&[1.0, 2.0, 3.0, 4.0]).unwrap();

        // Non-positive p values must be rejected.
        assert!(normalize(&input, -1.0, 0, 1e-12, None).is_err());
        assert!(normalize(&input, 0.0, 0, 1e-12, None).is_err());

        // Supported norms: L2, L1, general Lp, and L-infinity.
        assert!(normalize(&input, 2.0, 0, 1e-12, None).is_ok());
        assert!(normalize(&input, 1.0, 0, 1e-12, None).is_ok());
        assert!(normalize(&input, 3.0, 0, 1e-12, None).is_ok());
        assert!(normalize(&input, f64::INFINITY, 0, 1e-12, None).is_ok());
    }
}