train_station/tensor/reductions/norm.rs
//! L2 norm reduction operations for tensors
//!
//! This module provides L2 norm (Euclidean norm) reduction operations for tensors.
//! The L2 norm computes the square root of the sum of squared elements, which is
//! commonly used in machine learning for regularization, distance calculations,
//! and gradient clipping.
//!
//! # Operations
//!
//! * `norm()` - Computes L2 norm over all elements, returning a scalar tensor
//! * `norm_dims()` - Computes L2 norm over specified dimensions with optional dimension preservation
//!
//! # Examples
//!
//! ```
//! use train_station::Tensor;
//!
//! // Compute L2 norm of all elements
//! let tensor = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
//! let norm = tensor.norm();
//! assert!((norm.get(&[0]) - 5.0).abs() < 1e-6); // sqrt(3² + 4²) = 5
//!
//! // Compute L2 norm along specific dimensions
//! let matrix = Tensor::from_slice(&[3.0, 4.0, 0.0, 5.0], vec![2, 2]).unwrap();
//! let row_norms = matrix.norm_dims(&[1], true);
//! assert_eq!(row_norms.shape().dims(), vec![2, 1]);
//! ```
//!
//! # Performance
//!
//! The implementation uses optimized paths for contiguous tensors with manual loop unrolling
//! for better performance. Non-contiguous tensors use stride-aware iteration to maintain
//! correctness while preserving memory layout efficiency.
//!
//! # Gradient Tracking
//!
//! Both operations support automatic gradient tracking when `requires_grad` is enabled.
//! The gradient computation follows the mathematical derivative of the L2 norm operation.
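//! For the full reduction, that derivative is d‖x‖₂/dxᵢ = xᵢ / ‖x‖₂ (for ‖x‖₂ > 0);
//! `norm_dims()` follows the same formula independently within each reduced slice.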

use crate::gradtrack::{GradEngine, GradFn};
use crate::tensor::core::Tensor;

impl Tensor {
    /// Computes the L2 norm (Euclidean norm) over all elements
    ///
    /// The L2 norm is calculated as sqrt(sum(x²)) where x represents each element
    /// in the tensor. This operation reduces the tensor to a single-element tensor
    /// of shape `[1]`.
    ///
    /// # Returns
    ///
    /// A scalar tensor of shape `[1]` containing the L2 norm value
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Basic L2 norm calculation
    /// let tensor = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
    /// let norm = tensor.norm();
    /// assert!((norm.get(&[0]) - 5.0).abs() < 1e-6); // sqrt(3² + 4²) = 5
    /// ```
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // L2 norm of a larger tensor
    /// let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    /// let tensor = Tensor::from_slice(&data, vec![2, 2, 2]).unwrap();
    /// let norm = tensor.norm();
    /// // sqrt(1² + 2² + 3² + 4² + 5² + 6² + 7² + 8²) = sqrt(204) ≈ 14.283
    /// let expected = 204.0_f32.sqrt();
    /// assert!((norm.get(&[0]) - expected).abs() < 1e-5);
    /// ```
    ///
    /// # Performance
    ///
    /// Uses optimized contiguous tensor path with 4x loop unrolling for better
    /// performance. Non-contiguous tensors use stride-aware iteration.
    #[track_caller]
    pub fn norm(&self) -> Tensor {
        // Compute sqrt(sum(x^2))
        let mut sumsq = 0.0f32;
        let n = self.size();

        if self.is_contiguous() {
            // Fast path for contiguous tensors
            unsafe {
                let src = self.as_ptr();
                let mut i = 0usize;
                while i + 4 <= n {
                    let x0 = *src.add(i);
                    let x1 = *src.add(i + 1);
                    let x2 = *src.add(i + 2);
                    let x3 = *src.add(i + 3);
                    sumsq += x0 * x0 + x1 * x1 + x2 * x2 + x3 * x3;
                    i += 4;
                }
                while i < n {
                    let v = *src.add(i);
                    sumsq += v * v;
                    i += 1;
                }
            }
        } else {
            // Stride-aware path for non-contiguous tensors
            let dims = self.shape().dims().to_vec();
            for flat_idx in 0..n {
                // Convert flat index to multi-dimensional coordinates
                let mut coords = vec![0; dims.len()];
                let mut tmp = flat_idx;
                for k in (0..dims.len()).rev() {
                    coords[k] = tmp % dims[k];
                    tmp /= dims[k];
                }

                // Get value using stride-aware offset
                let offset = self.shape().offset(&coords);
                let value = unsafe { *self.as_ptr().add(offset) };
                sumsq += value * value;
            }
        }
        let mut out = Tensor::new(vec![1]);
        unsafe {
            *out.as_mut_ptr() = sumsq.sqrt();
        }

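        // GradTrack: save the forward output (the norm) and the input so the backward
        // pass can apply the L2 norm derivative d‖x‖/dxᵢ = xᵢ / ‖x‖.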
        if self.requires_grad() {
            let mut result = out.clone();
            result.set_requires_grad_internal(true);
            let grad_fn = GradFn::ReduceNorm {
                saved_norm: Box::new(out.clone()),
                saved_input: Box::new(self.clone()),
                input_shape: self.shape().dims().to_vec(),
            };
            result.set_grad_fn(grad_fn.clone());
            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
            return result;
        }

        out
    }

    /// Computes the L2 norm over specified dimensions
    ///
    /// Reduces the tensor along the specified dimensions by computing the L2 norm
    /// of each slice. The result maintains the original tensor structure with
    /// reduced dimensions optionally preserved as size-1 dimensions.
    ///
    /// # Arguments
    ///
    /// * `dims` - Slice of dimension indices to reduce over (each must be less than the tensor rank)
    /// * `keepdim` - Whether to keep reduced dimensions as size-1 dimensions
    ///
    /// # Returns
    ///
    /// A tensor with the L2 norm computed over the specified dimensions
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Norm along rows (dimension 1) with keepdim=true
    /// let matrix = Tensor::from_slice(&[3.0, 4.0, 0.0, 5.0], vec![2, 2]).unwrap();
    /// let row_norms = matrix.norm_dims(&[1], true);
    /// assert_eq!(row_norms.shape().dims(), vec![2, 1]);
    /// assert!((row_norms.get(&[0, 0]) - 5.0).abs() < 1e-6); // sqrt(3² + 4²)
    /// assert!((row_norms.get(&[1, 0]) - 5.0).abs() < 1e-6); // sqrt(0² + 5²)
    /// ```
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Norm along columns (dimension 0) with keepdim=false
    /// let matrix = Tensor::from_slice(&[3.0, 4.0, 0.0, 5.0], vec![2, 2]).unwrap();
    /// let col_norms = matrix.norm_dims(&[0], false);
    /// assert_eq!(col_norms.shape().dims(), vec![2]);
    /// assert!((col_norms.get(&[0]) - 3.0).abs() < 1e-6); // sqrt(3² + 0²)
    /// assert!((col_norms.get(&[1]) - 6.403).abs() < 1e-3); // sqrt(4² + 5²)
    /// ```
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Norm over multiple dimensions
    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
    /// let norm_all = tensor.norm_dims(&[0, 1], false);
    /// assert_eq!(norm_all.shape().dims(), vec![1]);
    /// // sqrt(1² + 2² + 3² + 4²) = sqrt(30) ≈ 5.477
    /// assert!((norm_all.get(&[0]) - 30.0_f32.sqrt()).abs() < 1e-5);
    /// ```
    ///
    /// # Panics
    ///
    /// * If `dims` is empty
    /// * If any dimension index is out of bounds for the tensor rank
    ///
    /// # Performance
    ///
    /// Uses efficient coordinate-based iteration that works correctly with
    /// both contiguous and non-contiguous tensor layouts.
    #[track_caller]
    pub fn norm_dims(&self, dims: &[usize], keepdim: bool) -> Tensor {
        assert!(
            !dims.is_empty(),
            "norm_dims requires at least one dimension"
        );
        let rank = self.shape().rank();
        for &d in dims {
            assert!(
                d < rank,
                "norm_dims dim {} out of bounds for rank {}",
                d,
                rank
            );
        }

        // Build output shape
        let in_shape = self.shape().dims().to_vec();
        let mut out_dims = in_shape.clone();
        let mut reduced: Vec<usize> = dims.to_vec();
        reduced.sort_unstable();
        reduced.dedup();
        for &d in reduced.iter() {
            out_dims[d] = if keepdim { 1 } else { 0 };
        }
        if !keepdim {
            out_dims.retain(|&s| s != 0);
        }
        if out_dims.is_empty() {
            out_dims.push(1);
        }
        let mut out = Tensor::zeros(out_dims.clone());

        // Compute sum of squares reduced, then sqrt
        let out_rank = out.shape().rank();
        let mut coords = vec![0usize; rank];
        unsafe {
            let sptr = out.as_mut_ptr();
            for lin in 0..self.size() {
                let mut tmp = lin;
                for i in (0..rank).rev() {
                    let s = in_shape[i];
                    coords[i] = if s == 0 { 0 } else { tmp % s };
                    if s != 0 {
                        tmp /= s;
                    }
                }
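                // Map input coordinates to output coordinates: reduced dims collapse
                // to index 0 when keepdim is set, or are dropped entirely otherwise.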
                let mut out_coords: Vec<usize> = Vec::with_capacity(out_rank);
                for (i, &c) in coords.iter().enumerate().take(rank) {
                    if reduced.contains(&i) {
                        if keepdim {
                            out_coords.push(0);
                        }
                    } else {
                        out_coords.push(c);
                    }
                }
                let off = if out_coords.is_empty() {
                    0
                } else {
                    out.shape().offset(&out_coords)
                };
                // Get input value using stride-aware offset
                let in_offset = self.shape().offset(&coords);
                let v = *self.as_ptr().add(in_offset);
                *sptr.add(off) += v * v;
            }
            // sqrt in place
            for i in 0..out.size() {
                *sptr.add(i) = (*sptr.add(i)).sqrt();
            }
        }

        if self.requires_grad() {
            let mut result = out.clone();
            result.set_requires_grad_internal(true);
            let grad_fn = GradFn::ReduceNormDims {
                dims: reduced,
                keepdim,
                input_shape: self.shape().dims().to_vec(),
                saved_norm: Box::new(out.clone()),
                saved_input: Box::new(self.clone()),
            };
            result.set_grad_fn(grad_fn.clone());
            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
            return result;
        }

        out
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_norm_forward_basic() {
        let x = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
        let n = x.norm();
        unsafe {
            assert!((*n.as_ptr() - 5.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_norm_dims_forward() {
        let x = Tensor::from_slice(&[3.0, 4.0, 0.0, 5.0], vec![2, 2]).unwrap();
        let n = x.norm_dims(&[1], true);
        assert_eq!(n.shape().dims(), vec![2, 1]);
        assert!((n.get(&[0, 0]) - 5.0).abs() < 1e-6);
        assert!((n.get(&[1, 0]) - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_norm_non_contiguous_transpose() {
        // Test norm on transposed tensor (non-contiguous view)
        let x = Tensor::from_slice(&[3.0, 4.0, 0.0, 12.0, 5.0, 0.0], vec![2, 3]).unwrap();
        // Original: [[3, 4, 0], [12, 5, 0]]

        let x_t = x.transpose(0, 1);
        // Transposed: [[3, 12], [4, 5], [0, 0]]
        assert!(!x_t.is_contiguous()); // Should be a view

        let norm_orig = x.norm();
        let norm_view = x_t.norm();

        // Both should give the same result
        assert!((norm_orig.get(&[0]) - norm_view.get(&[0])).abs() < 1e-6);

        // Expected norm of [3,4,0,12,5,0]: sqrt(3²+4²+0²+12²+5²+0²) = sqrt(9+16+144+25) = sqrt(194) ≈ 13.928
        let expected_norm = 194.0_f32.sqrt();
        assert!((norm_orig.get(&[0]) - expected_norm).abs() < 1e-5);
    }

    #[test]
    fn test_norm_dims_non_contiguous() {
        // Test norm_dims on non-contiguous tensor
        let x = Tensor::from_slice(&[3.0, 4.0, 0.0, 12.0, 5.0, 0.0], vec![2, 3]).unwrap();
        let x_t = x.transpose(0, 1); // [3, 2]
        assert!(!x_t.is_contiguous());

        // Norm along dim 0 of transposed tensor
        let norm_dim0 = x_t.norm_dims(&[0], false);
        assert_eq!(norm_dim0.shape().dims(), vec![2]);

        // For dim 0: [3,4,0] and [12,5,0]
        // norm([3,4,0]) = sqrt(3²+4²+0²) = sqrt(25) = 5
        // norm([12,5,0]) = sqrt(12²+5²+0²) = sqrt(169) = 13
        assert!((norm_dim0.get(&[0]) - 5.0).abs() < 1e-6);
        assert!((norm_dim0.get(&[1]) - 13.0).abs() < 1e-6);

        // Norm along dim 1 of transposed tensor
        let norm_dim1 = x_t.norm_dims(&[1], false);
        assert_eq!(norm_dim1.shape().dims(), vec![3]);
        // norm([3,12]) = sqrt(9+144) = sqrt(153) ≈ 12.369
        // norm([4,5]) = sqrt(16+25) = sqrt(41) ≈ 6.403
        // norm([0,0]) = sqrt(0+0) = 0
        assert!((norm_dim1.get(&[0]) - 153.0_f32.sqrt()).abs() < 1e-5);
        assert!((norm_dim1.get(&[1]) - 41.0_f32.sqrt()).abs() < 1e-5);
        assert!((norm_dim1.get(&[2]) - 0.0).abs() < 1e-6);
    }

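    // Sanity check (sketch): an all-zero tensor should have an L2 norm of exactly 0
    // and keep the scalar output shape [1]; only APIs already used above are assumed.
    #[test]
    fn test_norm_zero_tensor() {
        let x = Tensor::zeros(vec![4]);
        let n = x.norm();
        assert_eq!(n.shape().dims(), vec![1]);
        assert!(n.get(&[0]).abs() < 1e-6);
    }
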
    #[test]
    fn test_norm_permuted_tensor() {
        // Test with permuted tensor
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let x = Tensor::from_slice(&data, vec![2, 2, 2]).unwrap();

        // Permute axes [0, 1, 2] -> [2, 1, 0]; the shape stays [2, 2, 2] but the view is non-contiguous
        let x_perm = x.permute(vec![2, 1, 0]);
        assert!(!x_perm.is_contiguous());

        let norm_orig = x.norm();
        let norm_perm = x_perm.norm();

        // Should give same result
        assert!((norm_orig.get(&[0]) - norm_perm.get(&[0])).abs() < 1e-6);

        // norm([1,2,3,4,5,6,7,8]) = sqrt(1+4+9+16+25+36+49+64) = sqrt(204) ≈ 14.283
        let expected_norm = 204.0_f32.sqrt();
        assert!((norm_orig.get(&[0]) - expected_norm).abs() < 1e-5);
    }
}
386}