train_station/tensor/reductions/norm.rs

//! L2 norm reduction operations for tensors
//!
//! This module provides L2 norm (Euclidean norm) reduction operations for tensors.
//! The L2 norm computes the square root of the sum of squared elements, which is
//! commonly used in machine learning for regularization, distance calculations,
//! and gradient clipping.
//!
//! # Operations
//!
//! * `norm()` - Computes L2 norm over all elements, returning a scalar tensor
//! * `norm_dims()` - Computes L2 norm over specified dimensions with optional dimension preservation
//!
//! # Examples
//!
//! ```
//! use train_station::Tensor;
//!
//! // Compute L2 norm of all elements
//! let tensor = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
//! let norm = tensor.norm();
//! assert!((norm.get(&[0]) - 5.0).abs() < 1e-6); // sqrt(3² + 4²) = 5
//!
//! // Compute L2 norm along specific dimensions
//! let matrix = Tensor::from_slice(&[3.0, 4.0, 0.0, 5.0], vec![2, 2]).unwrap();
//! let row_norms = matrix.norm_dims(&[1], true);
//! assert_eq!(row_norms.shape().dims, vec![2, 1]);
//! ```
//!
//! # Performance
//!
//! The implementation uses an optimized path for contiguous tensors with manual loop
//! unrolling. Non-contiguous tensors (for example, transposed or permuted views) use
//! stride-aware iteration, so results are correct without materializing a contiguous copy.
//!
//! # Gradient Tracking
//!
//! Both operations support automatic gradient tracking when `requires_grad` is enabled.
//! The gradient computation follows the mathematical derivative of the L2 norm operation.
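//! For `norm()`, the derivative of the output with respect to each input element is
//! `x_i / norm(x)`; `norm_dims()` applies the same rule independently within each
//! reduced slice.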

use crate::gradtrack::{GradEngine, GradFn};
use crate::tensor::core::Tensor;

impl Tensor {
    /// Computes the L2 norm (Euclidean norm) over all elements
    ///
    /// The L2 norm is calculated as sqrt(sum(x²)) where x ranges over all elements
    /// of the tensor. The operation reduces the tensor to a single-element tensor of shape `[1]`.
    ///
    /// # Returns
    ///
    /// A scalar tensor containing the L2 norm value
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Basic L2 norm calculation
    /// let tensor = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
    /// let norm = tensor.norm();
    /// assert!((norm.get(&[0]) - 5.0).abs() < 1e-6); // sqrt(3² + 4²) = 5
    /// ```
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // L2 norm of a larger tensor
    /// let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    /// let tensor = Tensor::from_slice(&data, vec![2, 2, 2]).unwrap();
    /// let norm = tensor.norm();
    /// // sqrt(1² + 2² + 3² + 4² + 5² + 6² + 7² + 8²) = sqrt(204) ≈ 14.283
    /// let expected = 204.0_f32.sqrt();
    /// assert!((norm.get(&[0]) - expected).abs() < 1e-5);
    /// ```
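    ///
    /// The norm of a non-contiguous view (here a transpose) matches the contiguous original:
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let x = Tensor::from_slice(&[3.0, 4.0, 0.0, 12.0, 5.0, 0.0], vec![2, 3]).unwrap();
    /// let view = x.transpose(0, 1);
    /// assert!(!view.is_contiguous());
    /// assert!((x.norm().get(&[0]) - view.norm().get(&[0])).abs() < 1e-6);
    /// ```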
    ///
    /// # Performance
    ///
    /// Uses an optimized path for contiguous tensors with 4x loop unrolling.
    /// Non-contiguous tensors fall back to stride-aware iteration.
    pub fn norm(&self) -> Tensor {
        // Compute sqrt(sum(x^2))
        let mut sumsq = 0.0f32;
        let n = self.size();

        if self.is_contiguous() {
            // Fast path for contiguous tensors
            unsafe {
                let src = self.as_ptr();
                let mut i = 0usize;
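                // Process four elements per iteration (manual 4x unrolling); the
                // scalar loop below handles any remaining tail elements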
                while i + 4 <= n {
                    let x0 = *src.add(i);
                    let x1 = *src.add(i + 1);
                    let x2 = *src.add(i + 2);
                    let x3 = *src.add(i + 3);
                    sumsq += x0 * x0 + x1 * x1 + x2 * x2 + x3 * x3;
                    i += 4;
                }
                while i < n {
                    let v = *src.add(i);
                    sumsq += v * v;
                    i += 1;
                }
            }
        } else {
            // Stride-aware path for non-contiguous tensors
            let dims = self.shape().dims.clone();
            for flat_idx in 0..n {
                // Convert flat index to multi-dimensional coordinates
                let mut coords = vec![0; dims.len()];
                let mut tmp = flat_idx;
                for k in (0..dims.len()).rev() {
                    coords[k] = tmp % dims[k];
                    tmp /= dims[k];
                }

                // Get value using stride-aware offset
                let offset = self.shape().offset(&coords);
                let value = unsafe { *self.as_ptr().add(offset) };
                sumsq += value * value;
            }
        }
        let mut out = Tensor::new(vec![1]);
        unsafe {
            *out.as_mut_ptr() = sumsq.sqrt();
        }

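        // Register gradient tracking; mathematically, d(norm)/dx_i = x_i / norm(x)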
        if self.requires_grad() {
            let mut result = out.clone();
            result.set_requires_grad_internal(true);
            let grad_fn = GradFn::ReduceNorm {
                saved_norm: Box::new(out.clone()),
                saved_input: Box::new(self.clone()),
                input_shape: self.shape().dims.clone(),
            };
            result.set_grad_fn(grad_fn.clone());
            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
            return result;
        }

        out
    }

    /// Computes the L2 norm over specified dimensions
    ///
    /// Reduces the tensor along the specified dimensions by computing the L2 norm
    /// of each slice. The result maintains the original tensor structure with
    /// reduced dimensions optionally preserved as size-1 dimensions.
    ///
    /// # Arguments
    ///
    /// * `dims` - Slice of dimension indices to reduce over (each must be less than the tensor rank)
    /// * `keepdim` - Whether to keep reduced dimensions as size-1 dimensions
    ///
    /// # Returns
    ///
    /// A tensor with L2 norm computed over the specified dimensions
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Norm along rows (dimension 1) with keepdim=true
    /// let matrix = Tensor::from_slice(&[3.0, 4.0, 0.0, 5.0], vec![2, 2]).unwrap();
    /// let row_norms = matrix.norm_dims(&[1], true);
    /// assert_eq!(row_norms.shape().dims, vec![2, 1]);
    /// assert!((row_norms.get(&[0, 0]) - 5.0).abs() < 1e-6); // sqrt(3² + 4²)
    /// assert!((row_norms.get(&[1, 0]) - 5.0).abs() < 1e-6); // sqrt(0² + 5²)
    /// ```
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Norm along columns (dimension 0) with keepdim=false
    /// let matrix = Tensor::from_slice(&[3.0, 4.0, 0.0, 5.0], vec![2, 2]).unwrap();
    /// let col_norms = matrix.norm_dims(&[0], false);
    /// assert_eq!(col_norms.shape().dims, vec![2]);
    /// assert!((col_norms.get(&[0]) - 3.0).abs() < 1e-6); // sqrt(3² + 0²)
    /// assert!((col_norms.get(&[1]) - 6.403).abs() < 1e-3); // sqrt(4² + 5²)
    /// ```
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Norm over multiple dimensions
    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
    /// let norm_all = tensor.norm_dims(&[0, 1], false);
    /// assert_eq!(norm_all.shape().dims, vec![1]);
    /// // sqrt(1² + 2² + 3² + 4²) = sqrt(30) ≈ 5.477
    /// assert!((norm_all.get(&[0]) - 30.0_f32.sqrt()).abs() < 1e-5);
    /// ```
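    ///
    /// Reduction over a non-contiguous view (here a transpose) also works:
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let x = Tensor::from_slice(&[3.0, 4.0, 0.0, 12.0, 5.0, 0.0], vec![2, 3]).unwrap();
    /// let view = x.transpose(0, 1); // shape [3, 2], non-contiguous
    /// let norms = view.norm_dims(&[0], false);
    /// assert_eq!(norms.shape().dims, vec![2]);
    /// assert!((norms.get(&[0]) - 5.0).abs() < 1e-6);  // norm([3, 4, 0])
    /// assert!((norms.get(&[1]) - 13.0).abs() < 1e-6); // norm([12, 5, 0])
    /// ```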
    ///
    /// # Panics
    ///
    /// * If `dims` is empty
    /// * If any dimension index is out of bounds for the tensor rank
    ///
    /// # Performance
    ///
    /// Uses efficient coordinate-based iteration that works correctly with
    /// both contiguous and non-contiguous tensor layouts.
    pub fn norm_dims(&self, dims: &[usize], keepdim: bool) -> Tensor {
        assert!(
            !dims.is_empty(),
            "norm_dims requires at least one dimension"
        );
        let rank = self.shape().rank();
        for &d in dims {
            assert!(
                d < rank,
                "norm_dims dim {} out of bounds for rank {}",
                d,
                rank
            );
        }

        // Build output shape
        let in_shape = self.shape().dims.clone();
        let mut out_dims = in_shape.clone();
        let mut reduced: Vec<usize> = dims.to_vec();
        reduced.sort_unstable();
        reduced.dedup();
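        // Reduced dimensions keep size 1 when keepdim is true; otherwise they are
        // marked with a placeholder 0 and removed from the output shape below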
        for &d in reduced.iter() {
            out_dims[d] = if keepdim { 1 } else { 0 };
        }
        if !keepdim {
            out_dims.retain(|&s| s != 0);
        }
        if out_dims.is_empty() {
            out_dims.push(1);
        }
        let mut out = Tensor::zeros(out_dims.clone());

        // Accumulate the sum of squares for each output position, then take the sqrt
        let out_rank = out.shape().rank();
        let mut coords = vec![0usize; rank];
        unsafe {
            let sptr = out.as_mut_ptr();
            for lin in 0..self.size() {
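                // Convert the linear input index to multi-dimensional input coordinates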
                let mut tmp = lin;
                for i in (0..rank).rev() {
                    let s = in_shape[i];
                    coords[i] = if s == 0 { 0 } else { tmp % s };
                    if s != 0 {
                        tmp /= s;
                    }
                }
                let mut out_coords: Vec<usize> = Vec::with_capacity(out_rank);
                for (i, &c) in coords.iter().enumerate().take(rank) {
                    if reduced.contains(&i) {
                        if keepdim {
                            out_coords.push(0);
                        }
                    } else {
                        out_coords.push(c);
                    }
                }
                let off = if out_coords.is_empty() {
                    0
                } else {
                    out.shape().offset(&out_coords)
                };
                // Get input value using stride-aware offset
                let in_offset = self.shape().offset(&coords);
                let v = *self.as_ptr().add(in_offset);
                *sptr.add(off) += v * v;
            }
            // sqrt in place
            for i in 0..out.size() {
                *sptr.add(i) = (*sptr.add(i)).sqrt();
            }
        }

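        // Register gradient tracking over the reduced dimensions; within each reduced
        // slice, d(norm)/dx_i = x_i / norm(slice)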
        if self.requires_grad() {
            let mut result = out.clone();
            result.set_requires_grad_internal(true);
            let grad_fn = GradFn::ReduceNormDims {
                dims: reduced,
                keepdim,
                input_shape: self.shape().dims.clone(),
                saved_norm: Box::new(out.clone()),
                saved_input: Box::new(self.clone()),
            };
            result.set_grad_fn(grad_fn.clone());
            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
            return result;
        }

        out
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_norm_forward_basic() {
        let x = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
        let n = x.norm();
        unsafe {
            assert!((*n.as_ptr() - 5.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_norm_dims_forward() {
        let x = Tensor::from_slice(&[3.0, 4.0, 0.0, 5.0], vec![2, 2]).unwrap();
        let n = x.norm_dims(&[1], true);
        assert_eq!(n.shape().dims, vec![2, 1]);
        assert!((n.get(&[0, 0]) - 5.0).abs() < 1e-6);
        assert!((n.get(&[1, 0]) - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_norm_non_contiguous_transpose() {
        // Test norm on transposed tensor (non-contiguous view)
        let x = Tensor::from_slice(&[3.0, 4.0, 0.0, 12.0, 5.0, 0.0], vec![2, 3]).unwrap();
        // Original: [[3, 4, 0], [12, 5, 0]]

        let x_t = x.transpose(0, 1);
        // Transposed: [[3, 12], [4, 5], [0, 0]]
        assert!(!x_t.is_contiguous()); // Should be a view

        let norm_orig = x.norm();
        let norm_view = x_t.norm();

        // Both should give the same result
        assert!((norm_orig.get(&[0]) - norm_view.get(&[0])).abs() < 1e-6);

        // Expected norm of [3,4,0,12,5,0]: sqrt(3²+4²+0²+12²+5²+0²) = sqrt(9+16+0+144+25+0) = sqrt(194) ≈ 13.928
        let expected_norm = 194.0_f32.sqrt();
        assert!((norm_orig.get(&[0]) - expected_norm).abs() < 1e-5);
    }

    #[test]
    fn test_norm_dims_non_contiguous() {
        // Test norm_dims on non-contiguous tensor
        let x = Tensor::from_slice(&[3.0, 4.0, 0.0, 12.0, 5.0, 0.0], vec![2, 3]).unwrap();
        let x_t = x.transpose(0, 1); // [3, 2]
        assert!(!x_t.is_contiguous());

        // Norm along dim 0 of transposed tensor
        let norm_dim0 = x_t.norm_dims(&[0], false);
        assert_eq!(norm_dim0.shape().dims, vec![2]);

        // For dim 0: [3,4,0] and [12,5,0]
        // norm([3,4,0]) = sqrt(3²+4²+0²) = sqrt(25) = 5
        // norm([12,5,0]) = sqrt(12²+5²+0²) = sqrt(169) = 13
        assert!((norm_dim0.get(&[0]) - 5.0).abs() < 1e-6);
        assert!((norm_dim0.get(&[1]) - 13.0).abs() < 1e-6);

        // Norm along dim 1 of transposed tensor
        let norm_dim1 = x_t.norm_dims(&[1], false);
        assert_eq!(norm_dim1.shape().dims, vec![3]);
        // norm([3,12]) = sqrt(9+144) = sqrt(153) ≈ 12.369
        // norm([4,5]) = sqrt(16+25) = sqrt(41) ≈ 6.403
        // norm([0,0]) = sqrt(0+0) = 0
        assert!((norm_dim1.get(&[0]) - 153.0_f32.sqrt()).abs() < 1e-5);
        assert!((norm_dim1.get(&[1]) - 41.0_f32.sqrt()).abs() < 1e-5);
        assert!((norm_dim1.get(&[2]) - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_norm_permuted_tensor() {
        // Test with permuted tensor
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let x = Tensor::from_slice(&data, vec![2, 2, 2]).unwrap();

        // Permute dimensions [2, 2, 2] -> [2, 2, 2] (swap first and last)
        let x_perm = x.permute(vec![2, 1, 0]);
        assert!(!x_perm.is_contiguous());

        let norm_orig = x.norm();
        let norm_perm = x_perm.norm();

        // Should give same result
        assert!((norm_orig.get(&[0]) - norm_perm.get(&[0])).abs() < 1e-6);

        // norm([1,2,3,4,5,6,7,8]) = sqrt(1+4+9+16+25+36+49+64) = sqrt(204) ≈ 14.283
        let expected_norm = 204.0_f32.sqrt();
        assert!((norm_orig.get(&[0]) - expected_norm).abs() < 1e-5);
    }
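
    #[test]
    fn test_norm_dims_all_dims_keepdim() {
        // Reducing over every dimension with keepdim=true keeps a size-1 axis per dim
        let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let n = x.norm_dims(&[0, 1], true);
        assert_eq!(n.shape().dims, vec![1, 1]);
        // sqrt(1² + 2² + 3² + 4²) = sqrt(30)
        assert!((n.get(&[0, 0]) - 30.0_f32.sqrt()).abs() < 1e-5);
    }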
}