train_station/tensor/reductions/norm.rs
//! L2 norm reduction operations for tensors
//!
//! This module provides L2 norm (Euclidean norm) reduction operations for tensors.
//! The L2 norm computes the square root of the sum of squared elements, which is
//! commonly used in machine learning for regularization, distance calculations,
//! and gradient clipping.
//!
//! # Operations
//!
//! * `norm()` - Computes the L2 norm over all elements, returning a scalar tensor
//! * `norm_dims()` - Computes the L2 norm over specified dimensions, optionally preserving them as size-1 dimensions
//!
//! # Examples
//!
//! ```
//! use train_station::Tensor;
//!
//! // Compute L2 norm of all elements
//! let tensor = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
//! let norm = tensor.norm();
//! assert!((norm.get(&[0]) - 5.0).abs() < 1e-6); // sqrt(3² + 4²) = 5
//!
//! // Compute L2 norm along specific dimensions
//! let matrix = Tensor::from_slice(&[3.0, 4.0, 0.0, 5.0], vec![2, 2]).unwrap();
//! let row_norms = matrix.norm_dims(&[1], true);
//! assert_eq!(row_norms.shape().dims, vec![2, 1]);
//! ```
//!
//! # Performance
//!
//! Contiguous tensors take an optimized path with manual 4x loop unrolling.
//! Non-contiguous tensors use stride-aware iteration, which preserves
//! correctness without materializing a contiguous copy.
//!
//! # Gradient Tracking
//!
//! Both operations support automatic gradient tracking when `requires_grad` is enabled.
//! The gradient computation follows the mathematical derivative of the L2 norm operation.
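//!
//! For `norm()`, the derivative with respect to each input element is
//! `x_i / ‖x‖` (undefined when the norm is zero), so the backward pass scales
//! the incoming scalar gradient by the normalized input; `norm_dims()` applies
//! the same rule independently within each reduced slice. A minimal sketch of
//! the forward/backward flow, marked `ignore` because it assumes the public
//! `set_requires_grad` and `backward` entry points:
//!
//! ```ignore
//! use train_station::Tensor;
//!
//! let mut x = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
//! x.set_requires_grad(true); // assumed public API
//! let n = x.norm();          // ||x|| = 5
//! n.backward();              // assumed API; fills x's grad with x / ||x|| = [0.6, 0.8]
//! ```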

use crate::gradtrack::{GradEngine, GradFn};
use crate::tensor::core::Tensor;

impl Tensor {
    /// Computes the L2 norm (Euclidean norm) over all elements
    ///
    /// The L2 norm is calculated as sqrt(sum(x²)) where x represents each element
    /// in the tensor. This operation reduces the tensor to a single-element
    /// tensor of shape \[1\].
    ///
    /// # Returns
    ///
    /// A scalar tensor containing the L2 norm value
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Basic L2 norm calculation
    /// let tensor = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
    /// let norm = tensor.norm();
    /// assert!((norm.get(&[0]) - 5.0).abs() < 1e-6); // sqrt(3² + 4²) = 5
    /// ```
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // L2 norm of a larger tensor
    /// let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    /// let tensor = Tensor::from_slice(&data, vec![2, 2, 2]).unwrap();
    /// let norm = tensor.norm();
    /// // sqrt(1² + 2² + 3² + 4² + 5² + 6² + 7² + 8²) = sqrt(204) ≈ 14.283
    /// let expected = 204.0_f32.sqrt();
    /// assert!((norm.get(&[0]) - expected).abs() < 1e-5);
    /// ```
    ///
    /// # Performance
    ///
    /// Uses an optimized contiguous-tensor path with 4x loop unrolling.
    /// Non-contiguous tensors fall back to stride-aware iteration.
    pub fn norm(&self) -> Tensor {
        // Compute sqrt(sum(x^2))
        let mut sumsq = 0.0f32;
        let n = self.size();

        if self.is_contiguous() {
            // Fast path for contiguous tensors
            unsafe {
                let src = self.as_ptr();
                let mut i = 0usize;
                while i + 4 <= n {
                    let x0 = *src.add(i);
                    let x1 = *src.add(i + 1);
                    let x2 = *src.add(i + 2);
                    let x3 = *src.add(i + 3);
                    sumsq += x0 * x0 + x1 * x1 + x2 * x2 + x3 * x3;
                    i += 4;
                }
                while i < n {
                    let v = *src.add(i);
                    sumsq += v * v;
                    i += 1;
                }
            }
        } else {
            // Stride-aware path for non-contiguous tensors
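            // Elements are visited in logical row-major order; shape().offset()
            // then maps each coordinate tuple through this view's strides, so
            // transposed/permuted views are read correctly without copying.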
            let dims = self.shape().dims.clone();
            for flat_idx in 0..n {
                // Convert flat index to multi-dimensional coordinates
                let mut coords = vec![0; dims.len()];
                let mut tmp = flat_idx;
                for k in (0..dims.len()).rev() {
                    coords[k] = tmp % dims[k];
                    tmp /= dims[k];
                }

                // Get value using stride-aware offset
                let offset = self.shape().offset(&coords);
                let value = unsafe { *self.as_ptr().add(offset) };
                sumsq += value * value;
            }
        }
        let mut out = Tensor::new(vec![1]);
        unsafe {
            *out.as_mut_ptr() = sumsq.sqrt();
        }

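        // Hook into gradtrack: the backward pass needs both the saved input and
        // the saved norm to evaluate d||x||/dx = x / ||x||.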
        if self.requires_grad() {
            let mut result = out.clone();
            result.set_requires_grad_internal(true);
            let grad_fn = GradFn::ReduceNorm {
                saved_norm: Box::new(out.clone()),
                saved_input: Box::new(self.clone()),
                input_shape: self.shape().dims.clone(),
            };
            result.set_grad_fn(grad_fn.clone());
            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
            return result;
        }

        out
    }

    /// Computes the L2 norm over specified dimensions
    ///
    /// Reduces the tensor along the specified dimensions by computing the L2 norm
    /// of each slice. The result maintains the original tensor structure, with
    /// reduced dimensions optionally preserved as size-1 dimensions.
    ///
    /// # Arguments
    ///
    /// * `dims` - Slice of dimension indices to reduce over (each must be valid for the tensor rank)
    /// * `keepdim` - Whether to keep reduced dimensions as size-1 dimensions
    ///
    /// # Returns
    ///
    /// A tensor with the L2 norm computed over the specified dimensions
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Norm along rows (dimension 1) with keepdim=true
    /// let matrix = Tensor::from_slice(&[3.0, 4.0, 0.0, 5.0], vec![2, 2]).unwrap();
    /// let row_norms = matrix.norm_dims(&[1], true);
    /// assert_eq!(row_norms.shape().dims, vec![2, 1]);
    /// assert!((row_norms.get(&[0, 0]) - 5.0).abs() < 1e-6); // sqrt(3² + 4²)
    /// assert!((row_norms.get(&[1, 0]) - 5.0).abs() < 1e-6); // sqrt(0² + 5²)
    /// ```
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Norm along columns (dimension 0) with keepdim=false
    /// let matrix = Tensor::from_slice(&[3.0, 4.0, 0.0, 5.0], vec![2, 2]).unwrap();
    /// let col_norms = matrix.norm_dims(&[0], false);
    /// assert_eq!(col_norms.shape().dims, vec![2]);
    /// assert!((col_norms.get(&[0]) - 3.0).abs() < 1e-6); // sqrt(3² + 0²)
    /// assert!((col_norms.get(&[1]) - 6.403).abs() < 1e-3); // sqrt(4² + 5²)
    /// ```
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Norm over multiple dimensions
    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
    /// let norm_all = tensor.norm_dims(&[0, 1], false);
    /// assert_eq!(norm_all.shape().dims, vec![1]);
    /// // sqrt(1² + 2² + 3² + 4²) = sqrt(30) ≈ 5.477
    /// assert!((norm_all.get(&[0]) - 30.0_f32.sqrt()).abs() < 1e-5);
    /// ```
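    ///
    /// Passing an empty dimension list panics:
    ///
    /// ```should_panic
    /// use train_station::Tensor;
    ///
    /// let t = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
    /// let _ = t.norm_dims(&[], false); // panics: requires at least one dimension
    /// ```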
    ///
    /// # Panics
    ///
    /// * If `dims` is empty
    /// * If any dimension index is out of bounds for the tensor rank
    ///
    /// # Performance
    ///
    /// Uses efficient coordinate-based iteration that works correctly with
    /// both contiguous and non-contiguous tensor layouts.
    pub fn norm_dims(&self, dims: &[usize], keepdim: bool) -> Tensor {
        assert!(
            !dims.is_empty(),
            "norm_dims requires at least one dimension"
        );
        let rank = self.shape().rank();
        for &d in dims {
            assert!(
                d < rank,
                "norm_dims dim {} out of bounds for rank {}",
                d,
                rank
            );
        }

        // Build output shape
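        // Reduced axes are first marked with size 0, then dropped when keepdim
        // is false; e.g. [2, 3, 4] reduced over [1] becomes [2, 1, 4] with
        // keepdim and [2, 4] without. Reducing every dimension collapses the
        // output to the scalar shape [1].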
        let in_shape = self.shape().dims.clone();
        let mut out_dims = in_shape.clone();
        let mut reduced: Vec<usize> = dims.to_vec();
        reduced.sort_unstable();
        reduced.dedup();
        for &d in reduced.iter() {
            out_dims[d] = if keepdim { 1 } else { 0 };
        }
        if !keepdim {
            out_dims.retain(|&s| s != 0);
        }
        if out_dims.is_empty() {
            out_dims.push(1);
        }
        let mut out = Tensor::zeros(out_dims.clone());

        // Compute sum of squares reduced, then sqrt
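        // Each input element contributes to exactly one output cell, found by
        // dropping (or zeroing, under keepdim) its reduced coordinates, so the
        // loop below is a stride-aware scatter-add of squared values.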
        let out_rank = out.shape().rank();
        let mut coords = vec![0usize; rank];
        unsafe {
            let sptr = out.as_mut_ptr();
            for lin in 0..self.size() {
                let mut tmp = lin;
                for i in (0..rank).rev() {
                    let s = in_shape[i];
                    coords[i] = if s == 0 { 0 } else { tmp % s };
                    if s != 0 {
                        tmp /= s;
                    }
                }
                let mut out_coords: Vec<usize> = Vec::with_capacity(out_rank);
                for (i, &c) in coords.iter().enumerate().take(rank) {
                    if reduced.contains(&i) {
                        if keepdim {
                            out_coords.push(0);
                        }
                    } else {
                        out_coords.push(c);
                    }
                }
                let off = if out_coords.is_empty() {
                    0
                } else {
                    out.shape().offset(&out_coords)
                };
                // Get input value using stride-aware offset
                let in_offset = self.shape().offset(&coords);
                let v = *self.as_ptr().add(in_offset);
                *sptr.add(off) += v * v;
            }
            // sqrt in place
            for i in 0..out.size() {
                *sptr.add(i) = (*sptr.add(i)).sqrt();
            }
        }

        if self.requires_grad() {
            let mut result = out.clone();
            result.set_requires_grad_internal(true);
            let grad_fn = GradFn::ReduceNormDims {
                dims: reduced,
                keepdim,
                input_shape: self.shape().dims.clone(),
                saved_norm: Box::new(out.clone()),
                saved_input: Box::new(self.clone()),
            };
            result.set_grad_fn(grad_fn.clone());
            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
            return result;
        }

        out
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_norm_forward_basic() {
        let x = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
        let n = x.norm();
        unsafe {
            assert!((*n.as_ptr() - 5.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_norm_dims_forward() {
        let x = Tensor::from_slice(&[3.0, 4.0, 0.0, 5.0], vec![2, 2]).unwrap();
        let n = x.norm_dims(&[1], true);
        assert_eq!(n.shape().dims, vec![2, 1]);
        assert!((n.get(&[0, 0]) - 5.0).abs() < 1e-6);
        assert!((n.get(&[1, 0]) - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_norm_non_contiguous_transpose() {
        // Test norm on transposed tensor (non-contiguous view)
        let x = Tensor::from_slice(&[3.0, 4.0, 0.0, 12.0, 5.0, 0.0], vec![2, 3]).unwrap();
        // Original: [[3, 4, 0], [12, 5, 0]]

        let x_t = x.transpose(0, 1);
        // Transposed: [[3, 12], [4, 5], [0, 0]]
        assert!(!x_t.is_contiguous()); // Should be a view

        let norm_orig = x.norm();
        let norm_view = x_t.norm();

        // Both should give the same result
        assert!((norm_orig.get(&[0]) - norm_view.get(&[0])).abs() < 1e-6);

        // Expected norm of [3,4,0,12,5,0]: sqrt(9+16+0+144+25+0) = sqrt(194) ≈ 13.928
        let expected_norm = 194.0_f32.sqrt();
        assert!((norm_orig.get(&[0]) - expected_norm).abs() < 1e-5);
    }

    #[test]
    fn test_norm_dims_non_contiguous() {
        // Test norm_dims on non-contiguous tensor
        let x = Tensor::from_slice(&[3.0, 4.0, 0.0, 12.0, 5.0, 0.0], vec![2, 3]).unwrap();
        let x_t = x.transpose(0, 1); // [3, 2]
        assert!(!x_t.is_contiguous());

        // Norm along dim 0 of transposed tensor
        let norm_dim0 = x_t.norm_dims(&[0], false);
        assert_eq!(norm_dim0.shape().dims, vec![2]);

        // For dim 0: [3,4,0] and [12,5,0]
        // norm([3,4,0]) = sqrt(3²+4²+0²) = sqrt(25) = 5
        // norm([12,5,0]) = sqrt(12²+5²+0²) = sqrt(169) = 13
        assert!((norm_dim0.get(&[0]) - 5.0).abs() < 1e-6);
        assert!((norm_dim0.get(&[1]) - 13.0).abs() < 1e-6);

        // Norm along dim 1 of transposed tensor
        let norm_dim1 = x_t.norm_dims(&[1], false);
        assert_eq!(norm_dim1.shape().dims, vec![3]);
        // norm([3,12]) = sqrt(9+144) = sqrt(153) ≈ 12.369
        // norm([4,5]) = sqrt(16+25) = sqrt(41) ≈ 6.403
        // norm([0,0]) = sqrt(0+0) = 0
        assert!((norm_dim1.get(&[0]) - 153.0_f32.sqrt()).abs() < 1e-5);
        assert!((norm_dim1.get(&[1]) - 41.0_f32.sqrt()).abs() < 1e-5);
        assert!((norm_dim1.get(&[2]) - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_norm_permuted_tensor() {
        // Test with permuted tensor
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let x = Tensor::from_slice(&data, vec![2, 2, 2]).unwrap();

        // Permute dimensions [2, 2, 2] -> [2, 2, 2] (swap first and last)
        let x_perm = x.permute(vec![2, 1, 0]);
        assert!(!x_perm.is_contiguous());

        let norm_orig = x.norm();
        let norm_perm = x_perm.norm();

        // Should give same result
        assert!((norm_orig.get(&[0]) - norm_perm.get(&[0])).abs() < 1e-6);

        // norm([1,2,3,4,5,6,7,8]) = sqrt(1+4+9+16+25+36+49+64) = sqrt(204) ≈ 14.283
        let expected_norm = 204.0_f32.sqrt();
        assert!((norm_orig.get(&[0]) - expected_norm).abs() < 1e-5);
    }
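
    // Additional check (a small sketch using only APIs exercised above):
    // reducing over every dimension must agree with the full-tensor norm().
    #[test]
    fn test_norm_dims_reduce_all_matches_norm() {
        let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let n_all = x.norm_dims(&[0, 1], false);
        assert_eq!(n_all.shape().dims, vec![1]);
        // sqrt(1 + 4 + 9 + 16) = sqrt(30) ≈ 5.477
        assert!((n_all.get(&[0]) - 30.0_f32.sqrt()).abs() < 1e-5);
        assert!((n_all.get(&[0]) - x.norm().get(&[0])).abs() < 1e-6);
    }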
}