train_station/tensor/reductions/norm.rs

//! L2 norm reduction operations for tensors
//!
//! This module provides L2 norm (Euclidean norm) reduction operations for tensors.
//! The L2 norm computes the square root of the sum of squared elements, which is
//! commonly used in machine learning for regularization, distance calculations,
//! and gradient clipping.
//!
//! # Operations
//!
//! * `norm()` - Computes L2 norm over all elements, returning a scalar tensor
//! * `norm_dims()` - Computes L2 norm over specified dimensions with optional dimension preservation
//!
//! # Examples
//!
//! ```
//! use train_station::Tensor;
//!
//! // Compute L2 norm of all elements
//! let tensor = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
//! let norm = tensor.norm();
//! assert!((norm.get(&[0]) - 5.0).abs() < 1e-6); // sqrt(3² + 4²) = 5
//!
//! // Compute L2 norm along specific dimensions
//! let matrix = Tensor::from_slice(&[3.0, 4.0, 0.0, 5.0], vec![2, 2]).unwrap();
//! let row_norms = matrix.norm_dims(&[1], true);
//! assert_eq!(row_norms.shape().dims, vec![2, 1]);
//! ```
//!
//! # Performance
//!
//! The implementation uses an optimized path for contiguous tensors with manual 4x loop
//! unrolling. Non-contiguous tensors fall back to stride-aware iteration, which keeps
//! results correct for arbitrary views without materializing a contiguous copy.
//!
//! # Gradient Tracking
//!
//! Both operations support automatic gradient tracking when `requires_grad` is enabled.
//! The gradient computation follows the mathematical derivative of the L2 norm operation.
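//! Concretely, for `y = norm(x)` that derivative is `dy/dx_i = x_i / y` (when `y != 0`);
//! `norm_dims` applies the same formula independently within each reduced slice.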

use crate::gradtrack::{GradEngine, GradFn};
use crate::tensor::core::Tensor;

impl Tensor {
    /// Computes the L2 norm (Euclidean norm) over all elements
    ///
    /// The L2 norm is calculated as sqrt(sum(x²)) where x represents each element
    /// in the tensor. This operation reduces the tensor to a single-element tensor of shape `[1]`.
    ///
    /// # Returns
    ///
    /// A scalar tensor containing the L2 norm value
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Basic L2 norm calculation
    /// let tensor = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
    /// let norm = tensor.norm();
    /// assert!((norm.get(&[0]) - 5.0).abs() < 1e-6); // sqrt(3² + 4²) = 5
    /// ```
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // L2 norm of a larger tensor
    /// let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    /// let tensor = Tensor::from_slice(&data, vec![2, 2, 2]).unwrap();
    /// let norm = tensor.norm();
    /// // sqrt(1² + 2² + 3² + 4² + 5² + 6² + 7² + 8²) = sqrt(204) ≈ 14.283
    /// let expected = 204.0_f32.sqrt();
    /// assert!((norm.get(&[0]) - expected).abs() < 1e-5);
    /// ```
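    ///
    /// The stride-aware fallback means a non-contiguous view produces the same
    /// value as its contiguous source (a sketch; assumes `transpose` is public,
    /// as it is used in this module's tests):
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let x = Tensor::from_slice(&[3.0, 4.0, 0.0, 5.0], vec![2, 2]).unwrap();
    /// let x_t = x.transpose(0, 1); // non-contiguous view
    /// assert!((x.norm().get(&[0]) - x_t.norm().get(&[0])).abs() < 1e-6);
    /// ```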
    ///
    /// # Performance
    ///
    /// Uses an optimized path for contiguous tensors with 4x loop unrolling.
    /// Non-contiguous tensors use stride-aware iteration.
    #[track_caller]
    pub fn norm(&self) -> Tensor {
        // Compute sqrt(sum(x^2))
        let mut sumsq = 0.0f32;
        let n = self.size();

        if self.is_contiguous() {
            // Fast path for contiguous tensors: accumulate four squares per
            // iteration (manual 4x unroll), then handle the tail
            unsafe {
                let src = self.as_ptr();
                let mut i = 0usize;
                while i + 4 <= n {
                    let x0 = *src.add(i);
                    let x1 = *src.add(i + 1);
                    let x2 = *src.add(i + 2);
                    let x3 = *src.add(i + 3);
                    sumsq += x0 * x0 + x1 * x1 + x2 * x2 + x3 * x3;
                    i += 4;
                }
                // Remaining tail elements (when n is not a multiple of 4)
                while i < n {
                    let v = *src.add(i);
                    sumsq += v * v;
                    i += 1;
                }
            }
        } else {
            // Stride-aware path for non-contiguous tensors
            let dims = self.shape().dims.clone();
            for flat_idx in 0..n {
                // Convert flat index to multi-dimensional coordinates
                let mut coords = vec![0; dims.len()];
                let mut tmp = flat_idx;
                for k in (0..dims.len()).rev() {
                    coords[k] = tmp % dims[k];
                    tmp /= dims[k];
                }

                // Get value using stride-aware offset
                let offset = self.shape().offset(&coords);
                let value = unsafe { *self.as_ptr().add(offset) };
                sumsq += value * value;
            }
        }
        let mut out = Tensor::new(vec![1]);
        unsafe {
            *out.as_mut_ptr() = sumsq.sqrt();
        }

        if self.requires_grad() {
            let mut result = out.clone();
            result.set_requires_grad_internal(true);
            let grad_fn = GradFn::ReduceNorm {
                saved_norm: Box::new(out.clone()),
                saved_input: Box::new(self.clone()),
                input_shape: self.shape().dims.clone(),
            };
            result.set_grad_fn(grad_fn.clone());
            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
            return result;
        }

        out
    }

    /// Computes the L2 norm over specified dimensions
    ///
    /// Reduces the tensor along the specified dimensions by computing the L2 norm
    /// of each slice. The result maintains the original tensor structure with
    /// reduced dimensions optionally preserved as size-1 dimensions.
    ///
    /// # Arguments
    ///
    /// * `dims` - Slice of dimension indices to reduce over (each must be valid for the tensor rank)
    /// * `keepdim` - Whether to keep reduced dimensions as size-1 dimensions
    ///
    /// # Returns
    ///
    /// A tensor with L2 norm computed over the specified dimensions
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Norm along rows (dimension 1) with keepdim=true
    /// let matrix = Tensor::from_slice(&[3.0, 4.0, 0.0, 5.0], vec![2, 2]).unwrap();
    /// let row_norms = matrix.norm_dims(&[1], true);
    /// assert_eq!(row_norms.shape().dims, vec![2, 1]);
    /// assert!((row_norms.get(&[0, 0]) - 5.0).abs() < 1e-6); // sqrt(3² + 4²)
    /// assert!((row_norms.get(&[1, 0]) - 5.0).abs() < 1e-6); // sqrt(0² + 5²)
    /// ```
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Norm along columns (dimension 0) with keepdim=false
    /// let matrix = Tensor::from_slice(&[3.0, 4.0, 0.0, 5.0], vec![2, 2]).unwrap();
    /// let col_norms = matrix.norm_dims(&[0], false);
    /// assert_eq!(col_norms.shape().dims, vec![2]);
    /// assert!((col_norms.get(&[0]) - 3.0).abs() < 1e-6); // sqrt(3² + 0²)
    /// assert!((col_norms.get(&[1]) - 6.403).abs() < 1e-3); // sqrt(4² + 5²)
    /// ```
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Norm over multiple dimensions
    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
    /// let norm_all = tensor.norm_dims(&[0, 1], false);
    /// assert_eq!(norm_all.shape().dims, vec![1]);
    /// // sqrt(1² + 2² + 3² + 4²) = sqrt(30) ≈ 5.477
    /// assert!((norm_all.get(&[0]) - 30.0_f32.sqrt()).abs() < 1e-5);
    /// ```
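    ///
    /// The stride-aware path also handles non-contiguous views; a small check
    /// (this sketch assumes `transpose` is public, as used in this module's tests):
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let x = Tensor::from_slice(&[3.0, 4.0, 0.0, 12.0, 5.0, 0.0], vec![2, 3]).unwrap();
    /// let x_t = x.transpose(0, 1); // shape [3, 2], non-contiguous view
    /// let norms = x_t.norm_dims(&[0], false);
    /// assert!((norms.get(&[0]) - 5.0).abs() < 1e-6);  // sqrt(3² + 4² + 0²)
    /// assert!((norms.get(&[1]) - 13.0).abs() < 1e-6); // sqrt(12² + 5² + 0²)
    /// ```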
    ///
    /// # Panics
    ///
    /// * If `dims` is empty
    /// * If any dimension index is out of bounds for the tensor rank
    ///
    /// # Performance
    ///
    /// Uses efficient coordinate-based iteration that works correctly with
    /// both contiguous and non-contiguous tensor layouts.
    #[track_caller]
    pub fn norm_dims(&self, dims: &[usize], keepdim: bool) -> Tensor {
        assert!(
            !dims.is_empty(),
            "norm_dims requires at least one dimension"
        );
        let rank = self.shape().rank();
        for &d in dims {
            assert!(
                d < rank,
                "norm_dims dim {} out of bounds for rank {}",
                d,
                rank
            );
        }

        // Build output shape: reduced dims become 1 (keepdim) or are dropped,
        // using 0 as a temporary removal sentinel
        let in_shape = self.shape().dims.clone();
        let mut out_dims = in_shape.clone();
        let mut reduced: Vec<usize> = dims.to_vec();
        reduced.sort_unstable();
        reduced.dedup();
        for &d in reduced.iter() {
            out_dims[d] = if keepdim { 1 } else { 0 };
        }
        if !keepdim {
            out_dims.retain(|&s| s != 0);
        }
        if out_dims.is_empty() {
            out_dims.push(1);
        }
        let mut out = Tensor::zeros(out_dims.clone());

        // Accumulate the sum of squares into the output, then take sqrt
        let out_rank = out.shape().rank();
        let mut coords = vec![0usize; rank];
        unsafe {
            let sptr = out.as_mut_ptr();
            for lin in 0..self.size() {
                // Decode the flat index into input coordinates (row-major)
                let mut tmp = lin;
                for i in (0..rank).rev() {
                    let s = in_shape[i];
                    coords[i] = if s == 0 { 0 } else { tmp % s };
                    if s != 0 {
                        tmp /= s;
                    }
                }
                // Project input coordinates onto output coordinates:
                // reduced dims collapse to index 0 (keepdim) or are dropped
                let mut out_coords: Vec<usize> = Vec::with_capacity(out_rank);
                for (i, &c) in coords.iter().enumerate().take(rank) {
                    if reduced.contains(&i) {
                        if keepdim {
                            out_coords.push(0);
                        }
                    } else {
                        out_coords.push(c);
                    }
                }
                let off = if out_coords.is_empty() {
                    0
                } else {
                    out.shape().offset(&out_coords)
                };
                // Get input value using stride-aware offset
                let in_offset = self.shape().offset(&coords);
                let v = *self.as_ptr().add(in_offset);
                *sptr.add(off) += v * v;
            }
            // sqrt in place
            for i in 0..out.size() {
                *sptr.add(i) = (*sptr.add(i)).sqrt();
            }
        }

        if self.requires_grad() {
            let mut result = out.clone();
            result.set_requires_grad_internal(true);
            let grad_fn = GradFn::ReduceNormDims {
                dims: reduced,
                keepdim,
                input_shape: self.shape().dims.clone(),
                saved_norm: Box::new(out.clone()),
                saved_input: Box::new(self.clone()),
            };
            result.set_grad_fn(grad_fn.clone());
            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
            return result;
        }

        out
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_norm_forward_basic() {
        let x = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
        let n = x.norm();
        unsafe {
            assert!((*n.as_ptr() - 5.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_norm_dims_forward() {
        let x = Tensor::from_slice(&[3.0, 4.0, 0.0, 5.0], vec![2, 2]).unwrap();
        let n = x.norm_dims(&[1], true);
        assert_eq!(n.shape().dims, vec![2, 1]);
        assert!((n.get(&[0, 0]) - 5.0).abs() < 1e-6);
        assert!((n.get(&[1, 0]) - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_norm_non_contiguous_transpose() {
        // Test norm on transposed tensor (non-contiguous view)
        let x = Tensor::from_slice(&[3.0, 4.0, 0.0, 12.0, 5.0, 0.0], vec![2, 3]).unwrap();
        // Original: [[3, 4, 0], [12, 5, 0]]

        let x_t = x.transpose(0, 1);
        // Transposed: [[3, 12], [4, 5], [0, 0]]
        assert!(!x_t.is_contiguous()); // Should be a view

        let norm_orig = x.norm();
        let norm_view = x_t.norm();

        // Both should give the same result
        assert!((norm_orig.get(&[0]) - norm_view.get(&[0])).abs() < 1e-6);

        // Expected norm of [3,4,0,12,5,0]: sqrt(9+16+0+144+25+0) = sqrt(194) ≈ 13.928
        let expected_norm = 194.0_f32.sqrt();
        assert!((norm_orig.get(&[0]) - expected_norm).abs() < 1e-5);
    }

    #[test]
    fn test_norm_dims_non_contiguous() {
        // Test norm_dims on non-contiguous tensor
        let x = Tensor::from_slice(&[3.0, 4.0, 0.0, 12.0, 5.0, 0.0], vec![2, 3]).unwrap();
        let x_t = x.transpose(0, 1); // [3, 2]
        assert!(!x_t.is_contiguous());

        // Norm along dim 0 of transposed tensor
        let norm_dim0 = x_t.norm_dims(&[0], false);
        assert_eq!(norm_dim0.shape().dims, vec![2]);

        // For dim 0: [3,4,0] and [12,5,0]
        // norm([3,4,0]) = sqrt(3²+4²+0²) = sqrt(25) = 5
        // norm([12,5,0]) = sqrt(12²+5²+0²) = sqrt(169) = 13
        assert!((norm_dim0.get(&[0]) - 5.0).abs() < 1e-6);
        assert!((norm_dim0.get(&[1]) - 13.0).abs() < 1e-6);

        // Norm along dim 1 of transposed tensor
        let norm_dim1 = x_t.norm_dims(&[1], false);
        assert_eq!(norm_dim1.shape().dims, vec![3]);
        // norm([3,12]) = sqrt(9+144) = sqrt(153) ≈ 12.369
        // norm([4,5]) = sqrt(16+25) = sqrt(41) ≈ 6.403
        // norm([0,0]) = sqrt(0+0) = 0
        assert!((norm_dim1.get(&[0]) - 153.0_f32.sqrt()).abs() < 1e-5);
        assert!((norm_dim1.get(&[1]) - 41.0_f32.sqrt()).abs() < 1e-5);
        assert!((norm_dim1.get(&[2]) - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_norm_permuted_tensor() {
        // Test with permuted tensor
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let x = Tensor::from_slice(&data, vec![2, 2, 2]).unwrap();

        // Permute dimensions [2, 2, 2] -> [2, 2, 2] (swap first and last)
        let x_perm = x.permute(vec![2, 1, 0]);
        assert!(!x_perm.is_contiguous());

        let norm_orig = x.norm();
        let norm_perm = x_perm.norm();

        // Should give same result
        assert!((norm_orig.get(&[0]) - norm_perm.get(&[0])).abs() < 1e-6);

        // norm([1,2,3,4,5,6,7,8]) = sqrt(1+4+9+16+25+36+49+64) = sqrt(204) ≈ 14.283
        let expected_norm = 204.0_f32.sqrt();
        assert!((norm_orig.get(&[0]) - expected_norm).abs() < 1e-5);
    }
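
    // Additional checks (sketches that reuse only the APIs exercised above
    // and mirror the doc examples).
    #[test]
    fn test_norm_dims_all_dims() {
        // Reducing over every dimension should match the full-tensor norm
        let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let n = x.norm_dims(&[0, 1], false);
        assert_eq!(n.shape().dims, vec![1]);
        // sqrt(1 + 4 + 9 + 16) = sqrt(30)
        assert!((n.get(&[0]) - 30.0_f32.sqrt()).abs() < 1e-5);
        assert!((n.get(&[0]) - x.norm().get(&[0])).abs() < 1e-6);
    }

    #[test]
    fn test_norm_dims_columns_keepdim_false() {
        // Column norms of [[3, 4], [0, 5]] with keepdim=false
        let x = Tensor::from_slice(&[3.0, 4.0, 0.0, 5.0], vec![2, 2]).unwrap();
        let n = x.norm_dims(&[0], false);
        assert_eq!(n.shape().dims, vec![2]);
        assert!((n.get(&[0]) - 3.0).abs() < 1e-6); // sqrt(3² + 0²)
        assert!((n.get(&[1]) - 41.0_f32.sqrt()).abs() < 1e-5); // sqrt(4² + 5²)
    }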
}