
// ruvector_gnn/tensor.rs

//! Tensor operations for GNN computations.
//!
//! Provides efficient tensor operations including:
//! - Matrix multiplication
//! - Element-wise operations
//! - Activation functions
//! - Weight initialization
//! - Normalization
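//!
//! # Examples
//!
//! A minimal usage sketch (this assumes the crate is named `ruvector_gnn`
//! and this module is public at `ruvector_gnn::tensor`):
//!
//! ```
//! use ruvector_gnn::tensor::Tensor; // assumed public path
//!
//! let t = Tensor::new(vec![-1.0, 2.0, -3.0, 4.0], vec![2, 2]).unwrap();
//! let activated = t.relu();
//! assert_eq!(activated.data, vec![0.0, 2.0, 0.0, 4.0]);
//! ```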

use crate::error::{GnnError, Result};
use rand_distr::{Distribution, Normal, Uniform};

/// Basic tensor operations for GNN computations
#[derive(Debug, Clone, PartialEq)]
pub struct Tensor {
    /// Flattened tensor data
    pub data: Vec<f32>,
    /// Shape of the tensor (dimensions)
    pub shape: Vec<usize>,
}

impl Tensor {
    /// Create a new tensor from data and shape
    ///
    /// # Arguments
    /// * `data` - Flattened tensor data
    /// * `shape` - Dimensions of the tensor
    ///
    /// # Returns
    /// A new `Tensor` instance
    ///
    /// # Errors
    /// Returns `GnnError::InvalidShape` if the data length doesn't match the shape
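    ///
    /// # Examples
    ///
    /// A basic sketch (the `ruvector_gnn::tensor` import path is assumed):
    ///
    /// ```
    /// use ruvector_gnn::tensor::Tensor; // assumed public path
    ///
    /// let t = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
    /// assert_eq!(t.shape, vec![2, 2]);
    /// assert!(Tensor::new(vec![1.0], vec![2, 2]).is_err()); // length 1 != 4
    /// ```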
    pub fn new(data: Vec<f32>, shape: Vec<usize>) -> Result<Self> {
        let expected_len: usize = shape.iter().product();
        if data.len() != expected_len {
            return Err(GnnError::invalid_shape(format!(
                "Data length {} doesn't match shape {:?} (expected {})",
                data.len(),
                shape,
                expected_len
            )));
        }
        Ok(Self { data, shape })
    }

    /// Create a zero-filled tensor with the given shape
    ///
    /// # Arguments
    /// * `shape` - Dimensions of the tensor
    ///
    /// # Returns
    /// A new zero-filled `Tensor`
    ///
    /// # Errors
    /// Returns `GnnError::InvalidShape` if the shape is empty or any dimension is zero
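    ///
    /// # Examples
    ///
    /// A short sketch (import path assumed):
    ///
    /// ```
    /// use ruvector_gnn::tensor::Tensor; // assumed public path
    ///
    /// let t = Tensor::zeros(&[2, 3]).unwrap();
    /// assert_eq!(t.data, vec![0.0; 6]);
    /// assert!(Tensor::zeros(&[0, 3]).is_err()); // zero dimension is rejected
    /// ```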
    pub fn zeros(shape: &[usize]) -> Result<Self> {
        if shape.is_empty() || shape.iter().any(|&d| d == 0) {
            return Err(GnnError::invalid_shape(format!(
                "Invalid shape: {:?}",
                shape
            )));
        }
        let size: usize = shape.iter().product();
        Ok(Self {
            data: vec![0.0; size],
            shape: shape.to_vec(),
        })
    }

    /// Create a 1D tensor from a vector
    ///
    /// # Arguments
    /// * `data` - Vector data
    ///
    /// # Returns
    /// A new 1D `Tensor`
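    ///
    /// # Examples
    ///
    /// A short sketch (import path assumed):
    ///
    /// ```
    /// use ruvector_gnn::tensor::Tensor; // assumed public path
    ///
    /// let t = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
    /// assert_eq!(t.shape, vec![3]);
    /// ```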
    pub fn from_vec(data: Vec<f32>) -> Self {
        let len = data.len();
        Self {
            data,
            shape: vec![len],
        }
    }

    /// Compute dot product with another tensor (both must be 1D)
    ///
    /// # Arguments
    /// * `other` - Another tensor to compute dot product with
    ///
    /// # Returns
    /// The dot product as a scalar
    ///
    /// # Errors
    /// Returns `GnnError::DimensionMismatch` if tensors are not 1D or have different lengths
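    ///
    /// # Examples
    ///
    /// A short sketch (import path assumed):
    ///
    /// ```
    /// use ruvector_gnn::tensor::Tensor; // assumed public path
    ///
    /// let a = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
    /// let b = Tensor::from_vec(vec![4.0, 5.0, 6.0]);
    /// assert_eq!(a.dot(&b).unwrap(), 32.0); // 1*4 + 2*5 + 3*6
    /// ```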
    pub fn dot(&self, other: &Tensor) -> Result<f32> {
        if self.shape.len() != 1 || other.shape.len() != 1 {
            return Err(GnnError::dimension_mismatch(
                "1D tensors",
                format!("{}D and {}D", self.shape.len(), other.shape.len()),
            ));
        }
        if self.shape[0] != other.shape[0] {
            return Err(GnnError::dimension_mismatch(
                format!("length {}", self.shape[0]),
                format!("length {}", other.shape[0]),
            ));
        }

        let result = self
            .data
            .iter()
            .zip(other.data.iter())
            .map(|(a, b)| a * b)
            .sum();
        Ok(result)
    }

    /// Matrix multiplication
    ///
    /// Supports 1D x 1D (dot product), 2D x 1D (matrix-vector), and
    /// 2D x 2D (matrix-matrix) operands.
    ///
    /// # Arguments
    /// * `other` - Another tensor to multiply with
    ///
    /// # Returns
    /// The result of matrix multiplication
    ///
    /// # Errors
    /// Returns `GnnError::DimensionMismatch` if the inner dimensions are
    /// incompatible or if either operand has more than two dimensions
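    ///
    /// # Examples
    ///
    /// A 2x2 sketch (import path assumed):
    ///
    /// ```
    /// use ruvector_gnn::tensor::Tensor; // assumed public path
    ///
    /// let a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
    /// let b = Tensor::new(vec![5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
    /// let c = a.matmul(&b).unwrap();
    /// assert_eq!(c.data, vec![19.0, 22.0, 43.0, 50.0]);
    /// ```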
    pub fn matmul(&self, other: &Tensor) -> Result<Tensor> {
        // Support 1D x 1D (dot product), 2D x 1D, 2D x 2D
        match (self.shape.len(), other.shape.len()) {
            (1, 1) => {
                let dot = self.dot(other)?;
                Ok(Tensor::from_vec(vec![dot]))
            }
            (2, 1) => {
                // Matrix-vector multiplication
                let m = self.shape[0];
                let n = self.shape[1];
                if n != other.shape[0] {
                    return Err(GnnError::dimension_mismatch(
                        format!("{}x{}", m, n),
                        format!("vector of length {}", other.shape[0]),
                    ));
                }

                let mut result = vec![0.0; m];
                for i in 0..m {
                    for j in 0..n {
                        result[i] += self.data[i * n + j] * other.data[j];
                    }
                }
                Ok(Tensor::from_vec(result))
            }
            (2, 2) => {
                // Matrix-matrix multiplication
                let m = self.shape[0];
                let n = self.shape[1];
                let p = other.shape[1];

                if n != other.shape[0] {
                    return Err(GnnError::dimension_mismatch(
                        format!("{}x{}", m, n),
                        format!("{}x{}", other.shape[0], p),
                    ));
                }

                let mut result = vec![0.0; m * p];
                for i in 0..m {
                    for j in 0..p {
                        for k in 0..n {
                            result[i * p + j] += self.data[i * n + k] * other.data[k * p + j];
                        }
                    }
                }
                Tensor::new(result, vec![m, p])
            }
            _ => Err(GnnError::dimension_mismatch(
                "1D or 2D tensors",
                format!("{}D and {}D", self.shape.len(), other.shape.len()),
            )),
        }
    }

    /// Element-wise addition
    ///
    /// # Arguments
    /// * `other` - Another tensor to add
    ///
    /// # Returns
    /// The sum of the two tensors
    ///
    /// # Errors
    /// Returns `GnnError::DimensionMismatch` if shapes don't match
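    ///
    /// # Examples
    ///
    /// A short sketch (import path assumed):
    ///
    /// ```
    /// use ruvector_gnn::tensor::Tensor; // assumed public path
    ///
    /// let a = Tensor::from_vec(vec![1.0, 2.0]);
    /// let b = Tensor::from_vec(vec![3.0, 4.0]);
    /// assert_eq!(a.add(&b).unwrap().data, vec![4.0, 6.0]);
    /// ```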
    pub fn add(&self, other: &Tensor) -> Result<Tensor> {
        if self.shape != other.shape {
            return Err(GnnError::dimension_mismatch(
                format!("{:?}", self.shape),
                format!("{:?}", other.shape),
            ));
        }

        let result: Vec<f32> = self
            .data
            .iter()
            .zip(other.data.iter())
            .map(|(a, b)| a + b)
            .collect();

        Tensor::new(result, self.shape.clone())
    }

    /// Scalar multiplication
    ///
    /// # Arguments
    /// * `scalar` - Scalar value to multiply by
    ///
    /// # Returns
    /// A new tensor with all elements scaled
    pub fn scale(&self, scalar: f32) -> Tensor {
        let result: Vec<f32> = self.data.iter().map(|&x| x * scalar).collect();
        Tensor {
            data: result,
            shape: self.shape.clone(),
        }
    }

    /// ReLU activation function (max(0, x))
    ///
    /// # Returns
    /// A new tensor with ReLU applied element-wise
    pub fn relu(&self) -> Tensor {
        let result: Vec<f32> = self.data.iter().map(|&x| x.max(0.0)).collect();
        Tensor {
            data: result,
            shape: self.shape.clone(),
        }
    }

    /// Sigmoid activation function (1 / (1 + e^(-x))) with numerical stability
    ///
    /// # Returns
    /// A new tensor with sigmoid applied element-wise
    pub fn sigmoid(&self) -> Tensor {
        let result: Vec<f32> = self
            .data
            .iter()
            .map(|&x| {
                if x > 0.0 {
                    // For positive x, e^(-x) <= 1, so this form cannot overflow
                    1.0 / (1.0 + (-x).exp())
                } else {
                    // For non-positive x, use the algebraically equivalent
                    // e^x / (1 + e^x) to avoid overflow in e^(-x)
                    let ex = x.exp();
                    ex / (1.0 + ex)
                }
            })
            .collect();
        Tensor {
            data: result,
            shape: self.shape.clone(),
        }
    }

    /// Tanh activation function
    ///
    /// # Returns
    /// A new tensor with tanh applied element-wise
    pub fn tanh(&self) -> Tensor {
        let result: Vec<f32> = self.data.iter().map(|&x| x.tanh()).collect();
        Tensor {
            data: result,
            shape: self.shape.clone(),
        }
    }

    /// Compute L2 norm (Euclidean norm) with improved precision
    ///
    /// # Returns
    /// The L2 norm of the tensor
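    ///
    /// # Examples
    ///
    /// A short sketch (import path assumed):
    ///
    /// ```
    /// use ruvector_gnn::tensor::Tensor; // assumed public path
    ///
    /// let t = Tensor::from_vec(vec![3.0, 4.0]);
    /// assert_eq!(t.l2_norm(), 5.0); // sqrt(9 + 16)
    /// ```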
    pub fn l2_norm(&self) -> f32 {
        // Use f64 accumulator for better numerical precision
        let sum_squares: f64 = self.data.iter().map(|&x| (x as f64) * (x as f64)).sum();
        (sum_squares.sqrt()) as f32
    }

    /// Normalize the tensor to unit L2 norm
    ///
    /// # Returns
    /// A normalized tensor
    ///
    /// # Errors
    /// Returns `GnnError::InvalidInput` if the norm is zero
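    ///
    /// # Examples
    ///
    /// A short sketch (import path assumed):
    ///
    /// ```
    /// use ruvector_gnn::tensor::Tensor; // assumed public path
    ///
    /// let unit = Tensor::from_vec(vec![3.0, 4.0]).normalize().unwrap();
    /// assert!((unit.l2_norm() - 1.0).abs() < 1e-6);
    /// assert!(Tensor::from_vec(vec![0.0, 0.0]).normalize().is_err());
    /// ```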
    pub fn normalize(&self) -> Result<Tensor> {
        let norm = self.l2_norm();
        if norm == 0.0 {
            return Err(GnnError::invalid_input(
                "Cannot normalize zero vector".to_string(),
            ));
        }
        Ok(self.scale(1.0 / norm))
    }

    /// Get a slice view of the tensor data
    ///
    /// # Returns
    /// A slice reference to the underlying data
    pub fn as_slice(&self) -> &[f32] {
        &self.data
    }

    /// Consume the tensor and return the underlying vector
    ///
    /// # Returns
    /// The vector containing the tensor data
    pub fn into_vec(self) -> Vec<f32> {
        self.data
    }

    /// Get the number of elements in the tensor
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// Check if the tensor is empty
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }
}

/// Xavier/Glorot initialization for neural network weights
///
/// Samples from the uniform distribution U(-a, a) where
/// a = sqrt(6 / (fan_in + fan_out))
///
/// # Arguments
/// * `fan_in` - Number of input units
/// * `fan_out` - Number of output units
///
/// # Returns
/// A vector of `fan_in * fan_out` initialized weights
///
/// # Panics
/// Panics if `fan_in` or `fan_out` is 0
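///
/// # Examples
///
/// A short sketch (import path assumed):
///
/// ```
/// use ruvector_gnn::tensor::xavier_init; // assumed public path
///
/// let weights = xavier_init(16, 8);
/// assert_eq!(weights.len(), 128);
/// let limit = (6.0_f32 / 24.0).sqrt();
/// assert!(weights.iter().all(|&w| w >= -limit && w <= limit));
/// ```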
pub fn xavier_init(fan_in: usize, fan_out: usize) -> Vec<f32> {
    assert!(
        fan_in > 0 && fan_out > 0,
        "fan_in and fan_out must be positive"
    );

    let limit = (6.0 / (fan_in + fan_out) as f32).sqrt();
    let uniform = Uniform::new(-limit, limit);
    let mut rng = rand::thread_rng();

    (0..fan_in * fan_out)
        .map(|_| uniform.sample(&mut rng))
        .collect()
}

/// He initialization for ReLU networks
///
/// Samples from the normal distribution N(0, sigma^2) with
/// sigma = sqrt(2 / fan_in)
///
/// # Arguments
/// * `fan_in` - Number of input units
///
/// # Returns
/// A vector of `fan_in` initialized weights
///
/// # Panics
/// Panics if `fan_in` is 0
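///
/// # Examples
///
/// A short sketch (import path assumed):
///
/// ```
/// use ruvector_gnn::tensor::he_init; // assumed public path
///
/// let weights = he_init(64);
/// assert_eq!(weights.len(), 64);
/// ```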
pub fn he_init(fan_in: usize) -> Vec<f32> {
    assert!(fan_in > 0, "fan_in must be positive");

    let std_dev = (2.0 / fan_in as f32).sqrt();
    let normal = Normal::new(0.0, std_dev).expect("Invalid normal distribution parameters");
    let mut rng = rand::thread_rng();

    (0..fan_in).map(|_| normal.sample(&mut rng)).collect()
}

/// Element-wise (Hadamard) product
///
/// # Arguments
/// * `a` - First vector
/// * `b` - Second vector
///
/// # Returns
/// Element-wise product of the two vectors
///
/// # Panics
/// Panics if vectors have different lengths
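///
/// # Examples
///
/// A short sketch (import path assumed):
///
/// ```
/// use ruvector_gnn::tensor::hadamard_product; // assumed public path
///
/// assert_eq!(hadamard_product(&[1.0, 2.0], &[3.0, 4.0]), vec![3.0, 8.0]);
/// ```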
pub fn hadamard_product(a: &[f32], b: &[f32]) -> Vec<f32> {
    assert_eq!(a.len(), b.len(), "Vectors must have the same length");
    a.iter().zip(b.iter()).map(|(x, y)| x * y).collect()
}

/// Element-wise vector addition
///
/// # Arguments
/// * `a` - First vector
/// * `b` - Second vector
///
/// # Returns
/// Element-wise sum of the two vectors
///
/// # Panics
/// Panics if vectors have different lengths
pub fn vector_add(a: &[f32], b: &[f32]) -> Vec<f32> {
    assert_eq!(a.len(), b.len(), "Vectors must have the same length");
    a.iter().zip(b.iter()).map(|(x, y)| x + y).collect()
}

/// Scalar multiplication of a vector
///
/// # Arguments
/// * `v` - Input vector
/// * `scalar` - Scalar multiplier
///
/// # Returns
/// Vector with all elements multiplied by scalar
pub fn vector_scale(v: &[f32], scalar: f32) -> Vec<f32> {
    v.iter().map(|&x| x * scalar).collect()
}

#[cfg(test)]
mod tests {
    use super::*;

    const EPSILON: f32 = 1e-6;

    fn assert_vec_approx_eq(a: &[f32], b: &[f32], epsilon: f32) {
        assert_eq!(a.len(), b.len(), "Vectors have different lengths");
        for (i, (&x, &y)) in a.iter().zip(b.iter()).enumerate() {
            assert!(
                (x - y).abs() < epsilon,
                "Values at index {} differ: {} vs {} (diff: {})",
                i,
                x,
                y,
                (x - y).abs()
            );
        }
    }

    #[test]
    fn test_tensor_new() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let tensor = Tensor::new(data.clone(), vec![2, 2]).unwrap();
        assert_eq!(tensor.data, data);
        assert_eq!(tensor.shape, vec![2, 2]);
    }

    #[test]
    fn test_tensor_new_invalid_shape() {
        let data = vec![1.0, 2.0, 3.0];
        let result = Tensor::new(data, vec![2, 2]);
        assert!(result.is_err());
    }

    #[test]
    fn test_tensor_zeros() {
        let tensor = Tensor::zeros(&[3, 2]).unwrap();
        assert_eq!(tensor.data, vec![0.0; 6]);
        assert_eq!(tensor.shape, vec![3, 2]);
    }

    #[test]
    fn test_tensor_zeros_invalid_shape() {
        let result = Tensor::zeros(&[0, 2]);
        assert!(result.is_err());

        let result = Tensor::zeros(&[]);
        assert!(result.is_err());
    }

    #[test]
    fn test_tensor_from_vec() {
        let data = vec![1.0, 2.0, 3.0];
        let tensor = Tensor::from_vec(data.clone());
        assert_eq!(tensor.data, data);
        assert_eq!(tensor.shape, vec![3]);
    }

    #[test]
    fn test_dot_product() {
        let a = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
        let b = Tensor::from_vec(vec![4.0, 5.0, 6.0]);
        let result = a.dot(&b).unwrap();
        assert!((result - 32.0).abs() < EPSILON); // 1*4 + 2*5 + 3*6 = 32
    }

    #[test]
    fn test_dot_product_dimension_mismatch() {
        let a = Tensor::from_vec(vec![1.0, 2.0]);
        let b = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
        let result = a.dot(&b);
        assert!(result.is_err());
    }

    #[test]
    fn test_matmul_1d() {
        let a = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
        let b = Tensor::from_vec(vec![4.0, 5.0, 6.0]);
        let result = a.matmul(&b).unwrap();
        assert_eq!(result.shape, vec![1]);
        assert!((result.data[0] - 32.0).abs() < EPSILON);
    }

    #[test]
    fn test_matmul_2d_1d() {
        // Matrix-vector multiplication
        let mat = Tensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let vec = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
        let result = mat.matmul(&vec).unwrap();

        assert_eq!(result.shape, vec![2]);
        // [1,2,3] * [1,2,3]' = 14
        // [4,5,6] * [1,2,3]' = 32
        assert_vec_approx_eq(&result.data, &[14.0, 32.0], EPSILON);
    }

    #[test]
    fn test_matmul_2d_2d() {
        // Matrix-matrix multiplication
        let a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let b = Tensor::new(vec![5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
        let result = a.matmul(&b).unwrap();

        assert_eq!(result.shape, vec![2, 2]);
        // [[1,2], [3,4]] * [[5,6], [7,8]] = [[19,22], [43,50]]
        assert_vec_approx_eq(&result.data, &[19.0, 22.0, 43.0, 50.0], EPSILON);
    }

    #[test]
    fn test_matmul_dimension_mismatch() {
        let a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let b = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
        let result = a.matmul(&b);
        assert!(result.is_err());
    }

    #[test]
    fn test_add() {
        let a = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
        let b = Tensor::from_vec(vec![4.0, 5.0, 6.0]);
        let result = a.add(&b).unwrap();
        assert_eq!(result.data, vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_add_dimension_mismatch() {
        let a = Tensor::from_vec(vec![1.0, 2.0]);
        let b = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
        let result = a.add(&b);
        assert!(result.is_err());
    }

    #[test]
    fn test_scale() {
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
        let result = tensor.scale(2.0);
        assert_eq!(result.data, vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_relu() {
        let tensor = Tensor::from_vec(vec![-1.0, 0.0, 1.0, 2.0]);
        let result = tensor.relu();
        assert_eq!(result.data, vec![0.0, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_sigmoid() {
        let tensor = Tensor::from_vec(vec![0.0, 1.0, -1.0]);
        let result = tensor.sigmoid();

        assert!((result.data[0] - 0.5).abs() < EPSILON);
        assert!((result.data[1] - 0.7310586).abs() < EPSILON);
        assert!((result.data[2] - 0.26894143).abs() < EPSILON);
    }

    #[test]
    fn test_tanh() {
        let tensor = Tensor::from_vec(vec![0.0, 1.0, -1.0]);
        let result = tensor.tanh();

        assert!((result.data[0] - 0.0).abs() < EPSILON);
        assert!((result.data[1] - 0.7615942).abs() < EPSILON);
        assert!((result.data[2] - (-0.7615942)).abs() < EPSILON);
    }

    #[test]
    fn test_l2_norm() {
        let tensor = Tensor::from_vec(vec![3.0, 4.0]);
        let norm = tensor.l2_norm();
        assert!((norm - 5.0).abs() < EPSILON);
    }

    #[test]
    fn test_normalize() {
        let tensor = Tensor::from_vec(vec![3.0, 4.0]);
        let result = tensor.normalize().unwrap();
        assert_vec_approx_eq(&result.data, &[0.6, 0.8], EPSILON);
        assert!((result.l2_norm() - 1.0).abs() < EPSILON);
    }

    #[test]
    fn test_normalize_zero_vector() {
        let tensor = Tensor::from_vec(vec![0.0, 0.0]);
        let result = tensor.normalize();
        assert!(result.is_err());
    }

    #[test]
    fn test_as_slice() {
        let data = vec![1.0, 2.0, 3.0];
        let tensor = Tensor::from_vec(data.clone());
        assert_eq!(tensor.as_slice(), &data[..]);
    }

    #[test]
    fn test_into_vec() {
        let data = vec![1.0, 2.0, 3.0];
        let tensor = Tensor::from_vec(data.clone());
        assert_eq!(tensor.into_vec(), data);
    }

    #[test]
    fn test_len() {
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
        assert_eq!(tensor.len(), 3);
    }

    #[test]
    fn test_is_empty() {
        let tensor = Tensor::from_vec(vec![]);
        assert!(tensor.is_empty());

        let tensor = Tensor::from_vec(vec![1.0]);
        assert!(!tensor.is_empty());
    }

    #[test]
    fn test_xavier_init() {
        let weights = xavier_init(100, 50);
        assert_eq!(weights.len(), 5000);

        // Check that values are in expected range
        let limit = (6.0 / 150.0_f32).sqrt();
        for &w in &weights {
            assert!(w >= -limit && w <= limit);
        }

        // Check distribution properties
        let mean: f32 = weights.iter().sum::<f32>() / weights.len() as f32;
        assert!(mean.abs() < 0.1); // Mean should be close to 0
    }

    #[test]
    #[should_panic(expected = "fan_in and fan_out must be positive")]
    fn test_xavier_init_zero_fan() {
        xavier_init(0, 10);
    }

    #[test]
    fn test_he_init() {
        let weights = he_init(100);
        assert_eq!(weights.len(), 100);

        // Check distribution properties
        let mean: f32 = weights.iter().sum::<f32>() / weights.len() as f32;
        assert!(mean.abs() < 0.2); // Mean should be close to 0
    }

    #[test]
    #[should_panic(expected = "fan_in must be positive")]
    fn test_he_init_zero_fan() {
        he_init(0);
    }

    #[test]
    fn test_hadamard_product() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        let result = hadamard_product(&a, &b);
        assert_eq!(result, vec![4.0, 10.0, 18.0]);
    }

    #[test]
    #[should_panic(expected = "Vectors must have the same length")]
    fn test_hadamard_product_length_mismatch() {
        let a = vec![1.0, 2.0];
        let b = vec![1.0, 2.0, 3.0];
        hadamard_product(&a, &b);
    }

    #[test]
    fn test_vector_add() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        let result = vector_add(&a, &b);
        assert_eq!(result, vec![5.0, 7.0, 9.0]);
    }

    #[test]
    #[should_panic(expected = "Vectors must have the same length")]
    fn test_vector_add_length_mismatch() {
        let a = vec![1.0, 2.0];
        let b = vec![1.0, 2.0, 3.0];
        vector_add(&a, &b);
    }

    #[test]
    fn test_vector_scale() {
        let v = vec![1.0, 2.0, 3.0];
        let result = vector_scale(&v, 2.5);
        assert_vec_approx_eq(&result, &[2.5, 5.0, 7.5], EPSILON);
    }

    #[test]
    fn test_complex_operations() {
        // Test chaining operations
        let a = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
        let b = Tensor::from_vec(vec![0.5, 1.0, 1.5]);

        let sum = a.add(&b).unwrap();
        let scaled = sum.scale(2.0);
        let activated = scaled.relu();
        let normalized = activated.normalize().unwrap();

        assert!((normalized.l2_norm() - 1.0).abs() < EPSILON);
    }

    #[test]
    fn test_edge_case_single_element() {
        let tensor = Tensor::from_vec(vec![5.0]);
        assert_eq!(tensor.len(), 1);
        assert_eq!(tensor.l2_norm(), 5.0);

        let normalized = tensor.normalize().unwrap();
        assert_vec_approx_eq(&normalized.data, &[1.0], EPSILON);
    }

    #[test]
    fn test_edge_case_negative_values() {
        let tensor = Tensor::from_vec(vec![-3.0, -4.0]);
        assert!((tensor.l2_norm() - 5.0).abs() < EPSILON);

        let relu_result = tensor.relu();
        assert_eq!(relu_result.data, vec![0.0, 0.0]);
    }

    #[test]
    fn test_large_matrix_multiplication() {
        // 10x10 matrix multiplication
        let size = 10;
        let a_data: Vec<f32> = (0..size * size).map(|i| i as f32).collect();
        let b_data: Vec<f32> = (0..size * size).map(|i| (i % 2) as f32).collect();

        let a = Tensor::new(a_data, vec![size, size]).unwrap();
        let b = Tensor::new(b_data, vec![size, size]).unwrap();

        let result = a.matmul(&b).unwrap();
        assert_eq!(result.shape, vec![size, size]);
        assert_eq!(result.len(), size * size);
    }

    #[test]
    fn test_activation_functions_range() {
        let tensor = Tensor::from_vec(vec![-10.0, -1.0, 0.0, 1.0, 10.0]);

        // Sigmoid should be in (0, 1)
        let sigmoid = tensor.sigmoid();
        for &val in &sigmoid.data {
            assert!(val > 0.0 && val < 1.0);
        }

        // Tanh should be in [-1, 1]
        let tanh = tensor.tanh();
        for &val in &tanh.data {
            assert!(val >= -1.0 && val <= 1.0);
        }

        // ReLU should be non-negative
        let relu = tensor.relu();
        for &val in &relu.data {
            assert!(val >= 0.0);
        }
    }
}