Skip to main content

ruvector_gnn/
tensor.rs

1//! Tensor operations for GNN computations.
2//!
3//! Provides efficient tensor operations including:
4//! - Matrix multiplication
5//! - Element-wise operations
6//! - Activation functions
7//! - Weight initialization
8//! - Normalization
9
10use crate::error::{GnnError, Result};
11use rand_distr::{Distribution, Normal, Uniform};
12
13/// Basic tensor operations for GNN computations
14#[derive(Debug, Clone, PartialEq)]
15pub struct Tensor {
16    /// Flattened tensor data
17    pub data: Vec<f32>,
18    /// Shape of the tensor (dimensions)
19    pub shape: Vec<usize>,
20}
21
22impl Tensor {
23    /// Create a new tensor from data and shape
24    ///
25    /// # Arguments
26    /// * `data` - Flattened tensor data
27    /// * `shape` - Dimensions of the tensor
28    ///
29    /// # Returns
30    /// A new `Tensor` instance
31    ///
32    /// # Errors
33    /// Returns `GnnError::InvalidShape` if data length doesn't match shape
34    pub fn new(data: Vec<f32>, shape: Vec<usize>) -> Result<Self> {
35        let expected_len: usize = shape.iter().product();
36        if data.len() != expected_len {
37            return Err(GnnError::invalid_shape(format!(
38                "Data length {} doesn't match shape {:?} (expected {})",
39                data.len(),
40                shape,
41                expected_len
42            )));
43        }
44        Ok(Self { data, shape })
45    }
46
47    /// Create a zero-filled tensor with the given shape
48    ///
49    /// # Arguments
50    /// * `shape` - Dimensions of the tensor
51    ///
52    /// # Returns
53    /// A new zero-filled `Tensor`
54    ///
55    /// # Errors
56    /// Returns `GnnError::InvalidShape` if shape is empty or contains zero
57    pub fn zeros(shape: &[usize]) -> Result<Self> {
58        if shape.is_empty() || shape.contains(&0) {
59            return Err(GnnError::invalid_shape(format!(
60                "Invalid shape: {:?}",
61                shape
62            )));
63        }
64        let size: usize = shape.iter().product();
65        Ok(Self {
66            data: vec![0.0; size],
67            shape: shape.to_vec(),
68        })
69    }
70
71    /// Create a 1D tensor from a vector
72    ///
73    /// # Arguments
74    /// * `data` - Vector data
75    ///
76    /// # Returns
77    /// A new 1D `Tensor`
78    pub fn from_vec(data: Vec<f32>) -> Self {
79        let len = data.len();
80        Self {
81            data,
82            shape: vec![len],
83        }
84    }
85
86    /// Compute dot product with another tensor (both must be 1D)
87    ///
88    /// # Arguments
89    /// * `other` - Another tensor to compute dot product with
90    ///
91    /// # Returns
92    /// The dot product as a scalar
93    ///
94    /// # Errors
95    /// Returns `GnnError::DimensionMismatch` if tensors are not 1D or have different lengths
96    pub fn dot(&self, other: &Tensor) -> Result<f32> {
97        if self.shape.len() != 1 || other.shape.len() != 1 {
98            return Err(GnnError::dimension_mismatch(
99                "1D tensors",
100                format!("{}D and {}D", self.shape.len(), other.shape.len()),
101            ));
102        }
103        if self.shape[0] != other.shape[0] {
104            return Err(GnnError::dimension_mismatch(
105                format!("length {}", self.shape[0]),
106                format!("length {}", other.shape[0]),
107            ));
108        }
109
110        let result = self
111            .data
112            .iter()
113            .zip(other.data.iter())
114            .map(|(a, b)| a * b)
115            .sum();
116        Ok(result)
117    }
118
119    /// Matrix multiplication
120    ///
121    /// # Arguments
122    /// * `other` - Another tensor to multiply with
123    ///
124    /// # Returns
125    /// The result of matrix multiplication
126    ///
127    /// # Errors
128    /// Returns `GnnError::DimensionMismatch` if dimensions are incompatible
129    pub fn matmul(&self, other: &Tensor) -> Result<Tensor> {
130        // Support 1D x 1D (dot product), 2D x 1D, 2D x 2D
131        match (self.shape.len(), other.shape.len()) {
132            (1, 1) => {
133                let dot = self.dot(other)?;
134                Ok(Tensor::from_vec(vec![dot]))
135            }
136            (2, 1) => {
137                // Matrix-vector multiplication
138                let m = self.shape[0];
139                let n = self.shape[1];
140                if n != other.shape[0] {
141                    return Err(GnnError::dimension_mismatch(
142                        format!("{}x{}", m, n),
143                        format!("vector of length {}", other.shape[0]),
144                    ));
145                }
146
147                let mut result = vec![0.0; m];
148                for (i, slot) in result.iter_mut().enumerate().take(m) {
149                    for j in 0..n {
150                        *slot += self.data[i * n + j] * other.data[j];
151                    }
152                }
153                Ok(Tensor::from_vec(result))
154            }
155            (2, 2) => {
156                // Matrix-matrix multiplication
157                let m = self.shape[0];
158                let n = self.shape[1];
159                let p = other.shape[1];
160
161                if n != other.shape[0] {
162                    return Err(GnnError::dimension_mismatch(
163                        format!("{}x{}", m, n),
164                        format!("{}x{}", other.shape[0], p),
165                    ));
166                }
167
168                let mut result = vec![0.0; m * p];
169                for i in 0..m {
170                    for j in 0..p {
171                        for k in 0..n {
172                            result[i * p + j] += self.data[i * n + k] * other.data[k * p + j];
173                        }
174                    }
175                }
176                Tensor::new(result, vec![m, p])
177            }
178            _ => Err(GnnError::dimension_mismatch(
179                "1D or 2D tensors",
180                format!("{}D and {}D", self.shape.len(), other.shape.len()),
181            )),
182        }
183    }
184
185    /// Element-wise addition
186    ///
187    /// # Arguments
188    /// * `other` - Another tensor to add
189    ///
190    /// # Returns
191    /// The sum of the two tensors
192    ///
193    /// # Errors
194    /// Returns `GnnError::DimensionMismatch` if shapes don't match
195    pub fn add(&self, other: &Tensor) -> Result<Tensor> {
196        if self.shape != other.shape {
197            return Err(GnnError::dimension_mismatch(
198                format!("{:?}", self.shape),
199                format!("{:?}", other.shape),
200            ));
201        }
202
203        let result: Vec<f32> = self
204            .data
205            .iter()
206            .zip(other.data.iter())
207            .map(|(a, b)| a + b)
208            .collect();
209
210        Tensor::new(result, self.shape.clone())
211    }
212
213    /// Scalar multiplication
214    ///
215    /// # Arguments
216    /// * `scalar` - Scalar value to multiply by
217    ///
218    /// # Returns
219    /// A new tensor with all elements scaled
220    pub fn scale(&self, scalar: f32) -> Tensor {
221        let result: Vec<f32> = self.data.iter().map(|&x| x * scalar).collect();
222        Tensor {
223            data: result,
224            shape: self.shape.clone(),
225        }
226    }
227
228    /// ReLU activation function (max(0, x))
229    ///
230    /// # Returns
231    /// A new tensor with ReLU applied element-wise
232    pub fn relu(&self) -> Tensor {
233        let result: Vec<f32> = self.data.iter().map(|&x| x.max(0.0)).collect();
234        Tensor {
235            data: result,
236            shape: self.shape.clone(),
237        }
238    }
239
240    /// Sigmoid activation function (1 / (1 + e^(-x))) with numerical stability
241    ///
242    /// # Returns
243    /// A new tensor with sigmoid applied element-wise
244    pub fn sigmoid(&self) -> Tensor {
245        let result: Vec<f32> = self
246            .data
247            .iter()
248            .map(|&x| {
249                if x > 0.0 {
250                    1.0 / (1.0 + (-x).exp())
251                } else {
252                    let ex = x.exp();
253                    ex / (1.0 + ex)
254                }
255            })
256            .collect();
257        Tensor {
258            data: result,
259            shape: self.shape.clone(),
260        }
261    }
262
263    /// Tanh activation function
264    ///
265    /// # Returns
266    /// A new tensor with tanh applied element-wise
267    pub fn tanh(&self) -> Tensor {
268        let result: Vec<f32> = self.data.iter().map(|&x| x.tanh()).collect();
269        Tensor {
270            data: result,
271            shape: self.shape.clone(),
272        }
273    }
274
275    /// Compute L2 norm (Euclidean norm) with improved precision
276    ///
277    /// # Returns
278    /// The L2 norm of the tensor
279    pub fn l2_norm(&self) -> f32 {
280        // Use f64 accumulator for better numerical precision
281        let sum_squares: f64 = self.data.iter().map(|&x| (x as f64) * (x as f64)).sum();
282        (sum_squares.sqrt()) as f32
283    }
284
285    /// Normalize the tensor to unit L2 norm
286    ///
287    /// # Returns
288    /// A normalized tensor
289    ///
290    /// # Errors
291    /// Returns `GnnError::InvalidInput` if norm is zero
292    pub fn normalize(&self) -> Result<Tensor> {
293        let norm = self.l2_norm();
294        if norm == 0.0 {
295            return Err(GnnError::invalid_input(
296                "Cannot normalize zero vector".to_string(),
297            ));
298        }
299        Ok(self.scale(1.0 / norm))
300    }
301
302    /// Get a slice view of the tensor data
303    ///
304    /// # Returns
305    /// A slice reference to the underlying data
306    pub fn as_slice(&self) -> &[f32] {
307        &self.data
308    }
309
310    /// Consume the tensor and return the underlying vector
311    ///
312    /// # Returns
313    /// The vector containing the tensor data
314    pub fn into_vec(self) -> Vec<f32> {
315        self.data
316    }
317
318    /// Get the number of elements in the tensor
319    pub fn len(&self) -> usize {
320        self.data.len()
321    }
322
323    /// Check if the tensor is empty
324    pub fn is_empty(&self) -> bool {
325        self.data.is_empty()
326    }
327}
328
329/// Xavier/Glorot initialization for neural network weights
330///
331/// Samples from uniform distribution U(-a, a) where a = sqrt(6 / (fan_in + fan_out))
332///
333/// # Arguments
334/// * `fan_in` - Number of input units
335/// * `fan_out` - Number of output units
336///
337/// # Returns
338/// A vector of initialized weights
339///
340/// # Panics
341/// Panics if fan_in or fan_out is 0
342pub fn xavier_init(fan_in: usize, fan_out: usize) -> Vec<f32> {
343    assert!(
344        fan_in > 0 && fan_out > 0,
345        "fan_in and fan_out must be positive"
346    );
347
348    let limit = (6.0 / (fan_in + fan_out) as f32).sqrt();
349    let uniform = Uniform::new(-limit, limit);
350    let mut rng = rand::thread_rng();
351
352    (0..fan_in * fan_out)
353        .map(|_| uniform.sample(&mut rng))
354        .collect()
355}
356
357/// He initialization for ReLU networks
358///
359/// Samples from normal distribution N(0, sqrt(2 / fan_in))
360///
361/// # Arguments
362/// * `fan_in` - Number of input units
363///
364/// # Returns
365/// A vector of initialized weights
366///
367/// # Panics
368/// Panics if fan_in is 0
369pub fn he_init(fan_in: usize) -> Vec<f32> {
370    assert!(fan_in > 0, "fan_in must be positive");
371
372    let std_dev = (2.0 / fan_in as f32).sqrt();
373    let normal = Normal::new(0.0, std_dev).expect("Invalid normal distribution parameters");
374    let mut rng = rand::thread_rng();
375
376    (0..fan_in).map(|_| normal.sample(&mut rng)).collect()
377}
378
/// Hadamard (element-wise) product of two equally sized slices.
///
/// # Arguments
/// * `a` - First vector
/// * `b` - Second vector
///
/// # Returns
/// A vector whose `i`-th entry is `a[i] * b[i]`
///
/// # Panics
/// Panics if vectors have different lengths
pub fn hadamard_product(a: &[f32], b: &[f32]) -> Vec<f32> {
    assert_eq!(a.len(), b.len(), "Vectors must have the same length");

    let mut product = Vec::with_capacity(a.len());
    product.extend(a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs));
    product
}
394
/// Element-wise sum of two equally sized slices.
///
/// # Arguments
/// * `a` - First vector
/// * `b` - Second vector
///
/// # Returns
/// A vector whose `i`-th entry is `a[i] + b[i]`
///
/// # Panics
/// Panics if vectors have different lengths
pub fn vector_add(a: &[f32], b: &[f32]) -> Vec<f32> {
    assert_eq!(a.len(), b.len(), "Vectors must have the same length");

    let mut total = Vec::with_capacity(a.len());
    total.extend(a.iter().zip(b).map(|(&lhs, &rhs)| lhs + rhs));
    total
}
410
/// Multiply every element of a slice by a scalar.
///
/// # Arguments
/// * `v` - Input vector
/// * `scalar` - Scalar multiplier
///
/// # Returns
/// A new vector whose `i`-th entry is `v[i] * scalar`
pub fn vector_scale(v: &[f32], scalar: f32) -> Vec<f32> {
    let mut scaled = v.to_vec();
    for value in &mut scaled {
        *value *= scalar;
    }
    scaled
}
422
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for floating-point comparisons in these tests.
    const EPSILON: f32 = 1e-6;

    /// Assert that `a` and `b` have equal length and agree element-wise within `epsilon`.
    fn assert_vec_approx_eq(a: &[f32], b: &[f32], epsilon: f32) {
        assert_eq!(a.len(), b.len(), "Vectors have different lengths");
        for (i, (&x, &y)) in a.iter().zip(b.iter()).enumerate() {
            let diff = (x - y).abs();
            assert!(
                diff < epsilon,
                "Values at index {} differ: {} vs {} (diff: {})",
                i,
                x,
                y,
                diff
            );
        }
    }

    /// Assert that two scalars differ by less than `EPSILON`.
    fn assert_close(actual: f32, expected: f32) {
        assert!((actual - expected).abs() < EPSILON);
    }

    // ---- Construction ----

    #[test]
    fn test_tensor_new() {
        let values = vec![1.0, 2.0, 3.0, 4.0];
        let t = Tensor::new(values.clone(), vec![2, 2]).unwrap();
        assert_eq!(t.data, values);
        assert_eq!(t.shape, vec![2, 2]);
    }

    #[test]
    fn test_tensor_new_invalid_shape() {
        // Three values cannot fill a 2x2 shape.
        assert!(Tensor::new(vec![1.0, 2.0, 3.0], vec![2, 2]).is_err());
    }

    #[test]
    fn test_tensor_zeros() {
        let t = Tensor::zeros(&[3, 2]).unwrap();
        assert_eq!(t.data, vec![0.0; 6]);
        assert_eq!(t.shape, vec![3, 2]);
    }

    #[test]
    fn test_tensor_zeros_invalid_shape() {
        assert!(Tensor::zeros(&[0, 2]).is_err());
        assert!(Tensor::zeros(&[]).is_err());
    }

    #[test]
    fn test_tensor_from_vec() {
        let values = vec![1.0, 2.0, 3.0];
        let t = Tensor::from_vec(values.clone());
        assert_eq!(t.data, values);
        assert_eq!(t.shape, vec![3]);
    }

    // ---- Dot product and matrix multiplication ----

    #[test]
    fn test_dot_product() {
        let lhs = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
        let rhs = Tensor::from_vec(vec![4.0, 5.0, 6.0]);
        // 1*4 + 2*5 + 3*6 = 32
        assert_close(lhs.dot(&rhs).unwrap(), 32.0);
    }

    #[test]
    fn test_dot_product_dimension_mismatch() {
        let short = Tensor::from_vec(vec![1.0, 2.0]);
        let long = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
        assert!(short.dot(&long).is_err());
    }

    #[test]
    fn test_matmul_1d() {
        let lhs = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
        let rhs = Tensor::from_vec(vec![4.0, 5.0, 6.0]);
        let product = lhs.matmul(&rhs).unwrap();
        assert_eq!(product.shape, vec![1]);
        assert_close(product.data[0], 32.0);
    }

    #[test]
    fn test_matmul_2d_1d() {
        // Matrix-vector multiplication
        let matrix = Tensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let vector = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
        let product = matrix.matmul(&vector).unwrap();

        assert_eq!(product.shape, vec![2]);
        // Row [1,2,3] . [1,2,3] = 14; row [4,5,6] . [1,2,3] = 32
        assert_vec_approx_eq(&product.data, &[14.0, 32.0], EPSILON);
    }

    #[test]
    fn test_matmul_2d_2d() {
        // Matrix-matrix multiplication
        let lhs = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let rhs = Tensor::new(vec![5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
        let product = lhs.matmul(&rhs).unwrap();

        assert_eq!(product.shape, vec![2, 2]);
        // [[1,2], [3,4]] * [[5,6], [7,8]] = [[19,22], [43,50]]
        assert_vec_approx_eq(&product.data, &[19.0, 22.0, 43.0, 50.0], EPSILON);
    }

    #[test]
    fn test_matmul_dimension_mismatch() {
        let matrix = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let vector = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
        assert!(matrix.matmul(&vector).is_err());
    }

    // ---- Element-wise arithmetic ----

    #[test]
    fn test_add() {
        let lhs = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
        let rhs = Tensor::from_vec(vec![4.0, 5.0, 6.0]);
        assert_eq!(lhs.add(&rhs).unwrap().data, vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_add_dimension_mismatch() {
        let short = Tensor::from_vec(vec![1.0, 2.0]);
        let long = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
        assert!(short.add(&long).is_err());
    }

    #[test]
    fn test_scale() {
        let doubled = Tensor::from_vec(vec![1.0, 2.0, 3.0]).scale(2.0);
        assert_eq!(doubled.data, vec![2.0, 4.0, 6.0]);
    }

    // ---- Activations ----

    #[test]
    fn test_relu() {
        let rectified = Tensor::from_vec(vec![-1.0, 0.0, 1.0, 2.0]).relu();
        assert_eq!(rectified.data, vec![0.0, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_sigmoid() {
        let out = Tensor::from_vec(vec![0.0, 1.0, -1.0]).sigmoid();

        assert_close(out.data[0], 0.5);
        assert_close(out.data[1], 0.7310586);
        assert_close(out.data[2], 0.26894143);
    }

    #[test]
    fn test_tanh() {
        let out = Tensor::from_vec(vec![0.0, 1.0, -1.0]).tanh();

        assert_close(out.data[0], 0.0);
        assert_close(out.data[1], 0.7615942);
        assert_close(out.data[2], -0.7615942);
    }

    // ---- Norms ----

    #[test]
    fn test_l2_norm() {
        // 3-4-5 right triangle
        assert_close(Tensor::from_vec(vec![3.0, 4.0]).l2_norm(), 5.0);
    }

    #[test]
    fn test_normalize() {
        let unit = Tensor::from_vec(vec![3.0, 4.0]).normalize().unwrap();
        assert_vec_approx_eq(&unit.data, &[0.6, 0.8], EPSILON);
        assert_close(unit.l2_norm(), 1.0);
    }

    #[test]
    fn test_normalize_zero_vector() {
        assert!(Tensor::from_vec(vec![0.0, 0.0]).normalize().is_err());
    }

    // ---- Accessors ----

    #[test]
    fn test_as_slice() {
        let values = vec![1.0, 2.0, 3.0];
        let t = Tensor::from_vec(values.clone());
        assert_eq!(t.as_slice(), &values[..]);
    }

    #[test]
    fn test_into_vec() {
        let values = vec![1.0, 2.0, 3.0];
        let t = Tensor::from_vec(values.clone());
        assert_eq!(t.into_vec(), values);
    }

    #[test]
    fn test_len() {
        assert_eq!(Tensor::from_vec(vec![1.0, 2.0, 3.0]).len(), 3);
    }

    #[test]
    fn test_is_empty() {
        assert!(Tensor::from_vec(Vec::new()).is_empty());
        assert!(!Tensor::from_vec(vec![1.0]).is_empty());
    }

    // ---- Weight initialization ----

    #[test]
    fn test_xavier_init() {
        let weights = xavier_init(100, 50);
        assert_eq!(weights.len(), 5000);

        // Every sample must lie within the bound sqrt(6 / (fan_in + fan_out)).
        let limit = (6.0 / 150.0_f32).sqrt();
        for w in weights.iter().copied() {
            assert!((-limit..=limit).contains(&w));
        }

        // The sample mean should be close to 0.
        let mean: f32 = weights.iter().sum::<f32>() / weights.len() as f32;
        assert!(mean.abs() < 0.1);
    }

    #[test]
    #[should_panic(expected = "fan_in and fan_out must be positive")]
    fn test_xavier_init_zero_fan() {
        xavier_init(0, 10);
    }

    #[test]
    fn test_he_init() {
        let weights = he_init(100);
        assert_eq!(weights.len(), 100);

        // The sample mean should be close to 0.
        let mean: f32 = weights.iter().sum::<f32>() / weights.len() as f32;
        assert!(mean.abs() < 0.2);
    }

    #[test]
    #[should_panic(expected = "fan_in must be positive")]
    fn test_he_init_zero_fan() {
        he_init(0);
    }

    // ---- Slice helpers ----

    #[test]
    fn test_hadamard_product() {
        let product = hadamard_product(&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]);
        assert_eq!(product, vec![4.0, 10.0, 18.0]);
    }

    #[test]
    #[should_panic(expected = "Vectors must have the same length")]
    fn test_hadamard_product_length_mismatch() {
        hadamard_product(&[1.0, 2.0], &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_vector_add() {
        let total = vector_add(&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]);
        assert_eq!(total, vec![5.0, 7.0, 9.0]);
    }

    #[test]
    #[should_panic(expected = "Vectors must have the same length")]
    fn test_vector_add_length_mismatch() {
        vector_add(&[1.0, 2.0], &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_vector_scale() {
        let scaled = vector_scale(&[1.0, 2.0, 3.0], 2.5);
        assert_vec_approx_eq(&scaled, &[2.5, 5.0, 7.5], EPSILON);
    }

    // ---- Combined and edge cases ----

    #[test]
    fn test_complex_operations() {
        // Chain add -> scale -> relu -> normalize and check the final norm.
        let lhs = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
        let rhs = Tensor::from_vec(vec![0.5, 1.0, 1.5]);

        let unit = lhs
            .add(&rhs)
            .unwrap()
            .scale(2.0)
            .relu()
            .normalize()
            .unwrap();

        assert_close(unit.l2_norm(), 1.0);
    }

    #[test]
    fn test_edge_case_single_element() {
        let single = Tensor::from_vec(vec![5.0]);
        assert_eq!(single.len(), 1);
        assert_eq!(single.l2_norm(), 5.0);

        let unit = single.normalize().unwrap();
        assert_vec_approx_eq(&unit.data, &[1.0], EPSILON);
    }

    #[test]
    fn test_edge_case_negative_values() {
        let negatives = Tensor::from_vec(vec![-3.0, -4.0]);
        assert_close(negatives.l2_norm(), 5.0);
        assert_eq!(negatives.relu().data, vec![0.0, 0.0]);
    }

    #[test]
    fn test_large_matrix_multiplication() {
        // 10x10 matrix multiplication
        const SIZE: usize = 10;
        let count = SIZE * SIZE;

        let lhs = Tensor::new((0..count).map(|i| i as f32).collect(), vec![SIZE, SIZE]).unwrap();
        let rhs =
            Tensor::new((0..count).map(|i| (i % 2) as f32).collect(), vec![SIZE, SIZE]).unwrap();

        let product = lhs.matmul(&rhs).unwrap();
        assert_eq!(product.shape, vec![SIZE, SIZE]);
        assert_eq!(product.len(), count);
    }

    #[test]
    fn test_activation_functions_range() {
        let input = Tensor::from_vec(vec![-10.0, -1.0, 0.0, 1.0, 10.0]);

        // Sigmoid should be in (0, 1)
        assert!(input.sigmoid().data.iter().all(|&v| v > 0.0 && v < 1.0));

        // Tanh should be in [-1, 1]
        assert!(input.tanh().data.iter().all(|v| (-1.0..=1.0).contains(v)));

        // ReLU should be non-negative
        assert!(input.relu().data.iter().all(|&v| v >= 0.0));
    }
}