1use crate::error::{GnnError, Result};
11use rand_distr::{Distribution, Normal, Uniform};
12
/// A dense tensor of `f32` values: one flat element buffer plus a shape.
///
/// Elements are laid out in row-major order (a `[m, n]` tensor stores
/// element `(i, j)` at index `i * n + j`).
#[derive(Clone, Debug, PartialEq)]
pub struct Tensor {
    /// Flat element storage; its length should equal the product of `shape`.
    pub data: Vec<f32>,
    /// Extent of each dimension, outermost first.
    pub shape: Vec<usize>,
}
21
22impl Tensor {
23 pub fn new(data: Vec<f32>, shape: Vec<usize>) -> Result<Self> {
35 let expected_len: usize = shape.iter().product();
36 if data.len() != expected_len {
37 return Err(GnnError::invalid_shape(format!(
38 "Data length {} doesn't match shape {:?} (expected {})",
39 data.len(),
40 shape,
41 expected_len
42 )));
43 }
44 Ok(Self { data, shape })
45 }
46
47 pub fn zeros(shape: &[usize]) -> Result<Self> {
58 if shape.is_empty() || shape.contains(&0) {
59 return Err(GnnError::invalid_shape(format!(
60 "Invalid shape: {:?}",
61 shape
62 )));
63 }
64 let size: usize = shape.iter().product();
65 Ok(Self {
66 data: vec![0.0; size],
67 shape: shape.to_vec(),
68 })
69 }
70
71 pub fn from_vec(data: Vec<f32>) -> Self {
79 let len = data.len();
80 Self {
81 data,
82 shape: vec![len],
83 }
84 }
85
86 pub fn dot(&self, other: &Tensor) -> Result<f32> {
97 if self.shape.len() != 1 || other.shape.len() != 1 {
98 return Err(GnnError::dimension_mismatch(
99 "1D tensors",
100 format!("{}D and {}D", self.shape.len(), other.shape.len()),
101 ));
102 }
103 if self.shape[0] != other.shape[0] {
104 return Err(GnnError::dimension_mismatch(
105 format!("length {}", self.shape[0]),
106 format!("length {}", other.shape[0]),
107 ));
108 }
109
110 let result = self
111 .data
112 .iter()
113 .zip(other.data.iter())
114 .map(|(a, b)| a * b)
115 .sum();
116 Ok(result)
117 }
118
119 pub fn matmul(&self, other: &Tensor) -> Result<Tensor> {
130 match (self.shape.len(), other.shape.len()) {
132 (1, 1) => {
133 let dot = self.dot(other)?;
134 Ok(Tensor::from_vec(vec![dot]))
135 }
136 (2, 1) => {
137 let m = self.shape[0];
139 let n = self.shape[1];
140 if n != other.shape[0] {
141 return Err(GnnError::dimension_mismatch(
142 format!("{}x{}", m, n),
143 format!("vector of length {}", other.shape[0]),
144 ));
145 }
146
147 let mut result = vec![0.0; m];
148 for (i, slot) in result.iter_mut().enumerate().take(m) {
149 for j in 0..n {
150 *slot += self.data[i * n + j] * other.data[j];
151 }
152 }
153 Ok(Tensor::from_vec(result))
154 }
155 (2, 2) => {
156 let m = self.shape[0];
158 let n = self.shape[1];
159 let p = other.shape[1];
160
161 if n != other.shape[0] {
162 return Err(GnnError::dimension_mismatch(
163 format!("{}x{}", m, n),
164 format!("{}x{}", other.shape[0], p),
165 ));
166 }
167
168 let mut result = vec![0.0; m * p];
169 for i in 0..m {
170 for j in 0..p {
171 for k in 0..n {
172 result[i * p + j] += self.data[i * n + k] * other.data[k * p + j];
173 }
174 }
175 }
176 Tensor::new(result, vec![m, p])
177 }
178 _ => Err(GnnError::dimension_mismatch(
179 "1D or 2D tensors",
180 format!("{}D and {}D", self.shape.len(), other.shape.len()),
181 )),
182 }
183 }
184
185 pub fn add(&self, other: &Tensor) -> Result<Tensor> {
196 if self.shape != other.shape {
197 return Err(GnnError::dimension_mismatch(
198 format!("{:?}", self.shape),
199 format!("{:?}", other.shape),
200 ));
201 }
202
203 let result: Vec<f32> = self
204 .data
205 .iter()
206 .zip(other.data.iter())
207 .map(|(a, b)| a + b)
208 .collect();
209
210 Tensor::new(result, self.shape.clone())
211 }
212
213 pub fn scale(&self, scalar: f32) -> Tensor {
221 let result: Vec<f32> = self.data.iter().map(|&x| x * scalar).collect();
222 Tensor {
223 data: result,
224 shape: self.shape.clone(),
225 }
226 }
227
228 pub fn relu(&self) -> Tensor {
233 let result: Vec<f32> = self.data.iter().map(|&x| x.max(0.0)).collect();
234 Tensor {
235 data: result,
236 shape: self.shape.clone(),
237 }
238 }
239
240 pub fn sigmoid(&self) -> Tensor {
245 let result: Vec<f32> = self
246 .data
247 .iter()
248 .map(|&x| {
249 if x > 0.0 {
250 1.0 / (1.0 + (-x).exp())
251 } else {
252 let ex = x.exp();
253 ex / (1.0 + ex)
254 }
255 })
256 .collect();
257 Tensor {
258 data: result,
259 shape: self.shape.clone(),
260 }
261 }
262
263 pub fn tanh(&self) -> Tensor {
268 let result: Vec<f32> = self.data.iter().map(|&x| x.tanh()).collect();
269 Tensor {
270 data: result,
271 shape: self.shape.clone(),
272 }
273 }
274
275 pub fn l2_norm(&self) -> f32 {
280 let sum_squares: f64 = self.data.iter().map(|&x| (x as f64) * (x as f64)).sum();
282 (sum_squares.sqrt()) as f32
283 }
284
285 pub fn normalize(&self) -> Result<Tensor> {
293 let norm = self.l2_norm();
294 if norm == 0.0 {
295 return Err(GnnError::invalid_input(
296 "Cannot normalize zero vector".to_string(),
297 ));
298 }
299 Ok(self.scale(1.0 / norm))
300 }
301
302 pub fn as_slice(&self) -> &[f32] {
307 &self.data
308 }
309
310 pub fn into_vec(self) -> Vec<f32> {
315 self.data
316 }
317
318 pub fn len(&self) -> usize {
320 self.data.len()
321 }
322
323 pub fn is_empty(&self) -> bool {
325 self.data.is_empty()
326 }
327}
328
329pub fn xavier_init(fan_in: usize, fan_out: usize) -> Vec<f32> {
343 assert!(
344 fan_in > 0 && fan_out > 0,
345 "fan_in and fan_out must be positive"
346 );
347
348 let limit = (6.0 / (fan_in + fan_out) as f32).sqrt();
349 let uniform = Uniform::new(-limit, limit);
350 let mut rng = rand::thread_rng();
351
352 (0..fan_in * fan_out)
353 .map(|_| uniform.sample(&mut rng))
354 .collect()
355}
356
357pub fn he_init(fan_in: usize) -> Vec<f32> {
370 assert!(fan_in > 0, "fan_in must be positive");
371
372 let std_dev = (2.0 / fan_in as f32).sqrt();
373 let normal = Normal::new(0.0, std_dev).expect("Invalid normal distribution parameters");
374 let mut rng = rand::thread_rng();
375
376 (0..fan_in).map(|_| normal.sample(&mut rng)).collect()
377}
378
/// Elementwise (Hadamard) product of two equal-length slices.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
pub fn hadamard_product(a: &[f32], b: &[f32]) -> Vec<f32> {
    assert_eq!(a.len(), b.len(), "Vectors must have the same length");
    let mut product = Vec::with_capacity(a.len());
    for (&lhs, &rhs) in a.iter().zip(b) {
        product.push(lhs * rhs);
    }
    product
}
394
/// Elementwise sum of two equal-length slices.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
pub fn vector_add(a: &[f32], b: &[f32]) -> Vec<f32> {
    assert_eq!(a.len(), b.len(), "Vectors must have the same length");
    a.iter()
        .copied()
        .zip(b.iter().copied())
        .map(|(lhs, rhs)| lhs + rhs)
        .collect()
}
410
/// Returns a new vector with every element of `v` multiplied by `scalar`.
pub fn vector_scale(v: &[f32], scalar: f32) -> Vec<f32> {
    let mut scaled = v.to_vec();
    for x in &mut scaled {
        *x *= scalar;
    }
    scaled
}
422
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for floating-point comparisons.
    const EPSILON: f32 = 1e-6;

    /// Asserts equal lengths and that each element pair differs by less than `epsilon`.
    fn assert_vec_approx_eq(a: &[f32], b: &[f32], epsilon: f32) {
        assert_eq!(a.len(), b.len(), "Vectors have different lengths");
        for (i, (&x, &y)) in a.iter().zip(b.iter()).enumerate() {
            assert!(
                (x - y).abs() < epsilon,
                "Values at index {} differ: {} vs {} (diff: {})",
                i,
                x,
                y,
                (x - y).abs()
            );
        }
    }

    #[test]
    fn test_tensor_new() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let tensor = Tensor::new(data.clone(), vec![2, 2]).unwrap();
        assert_eq!(tensor.data, data);
        assert_eq!(tensor.shape, vec![2, 2]);
    }

    #[test]
    fn test_tensor_new_invalid_shape() {
        let data = vec![1.0, 2.0, 3.0];
        let result = Tensor::new(data, vec![2, 2]);
        assert!(result.is_err());
    }

    #[test]
    fn test_tensor_zeros() {
        let tensor = Tensor::zeros(&[3, 2]).unwrap();
        assert_eq!(tensor.data, vec![0.0; 6]);
        assert_eq!(tensor.shape, vec![3, 2]);
    }

    #[test]
    fn test_tensor_zeros_invalid_shape() {
        let result = Tensor::zeros(&[0, 2]);
        assert!(result.is_err());

        let result = Tensor::zeros(&[]);
        assert!(result.is_err());
    }

    #[test]
    fn test_tensor_from_vec() {
        let data = vec![1.0, 2.0, 3.0];
        let tensor = Tensor::from_vec(data.clone());
        assert_eq!(tensor.data, data);
        assert_eq!(tensor.shape, vec![3]);
    }

    #[test]
    fn test_dot_product() {
        let a = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
        let b = Tensor::from_vec(vec![4.0, 5.0, 6.0]);
        let result = a.dot(&b).unwrap();
        // 1*4 + 2*5 + 3*6 = 32
        assert!((result - 32.0).abs() < EPSILON);
    }

    #[test]
    fn test_dot_product_dimension_mismatch() {
        let a = Tensor::from_vec(vec![1.0, 2.0]);
        let b = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
        let result = a.dot(&b);
        assert!(result.is_err());
    }

    #[test]
    fn test_matmul_1d() {
        let a = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
        let b = Tensor::from_vec(vec![4.0, 5.0, 6.0]);
        let result = a.matmul(&b).unwrap();
        assert_eq!(result.shape, vec![1]);
        assert!((result.data[0] - 32.0).abs() < EPSILON);
    }

    #[test]
    fn test_matmul_2d_1d() {
        // [[1, 2, 3], [4, 5, 6]]
        let mat = Tensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let vec = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
        let result = mat.matmul(&vec).unwrap();

        assert_eq!(result.shape, vec![2]);
        // [1+4+9, 4+10+18]
        assert_vec_approx_eq(&result.data, &[14.0, 32.0], EPSILON);
    }

    #[test]
    fn test_matmul_2d_2d() {
        // [[1, 2], [3, 4]] x [[5, 6], [7, 8]]
        let a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let b = Tensor::new(vec![5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
        let result = a.matmul(&b).unwrap();

        assert_eq!(result.shape, vec![2, 2]);
        assert_vec_approx_eq(&result.data, &[19.0, 22.0, 43.0, 50.0], EPSILON);
    }

    #[test]
    fn test_matmul_dimension_mismatch() {
        let a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let b = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
        let result = a.matmul(&b);
        assert!(result.is_err());
    }

    #[test]
    fn test_matmul_unsupported_ranks() {
        // 1-D x 2-D is not one of the supported combinations.
        let v = Tensor::from_vec(vec![1.0, 2.0]);
        let m = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        assert!(v.matmul(&m).is_err());
    }

    #[test]
    fn test_add() {
        let a = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
        let b = Tensor::from_vec(vec![4.0, 5.0, 6.0]);
        let result = a.add(&b).unwrap();
        assert_eq!(result.data, vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_add_dimension_mismatch() {
        let a = Tensor::from_vec(vec![1.0, 2.0]);
        let b = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
        let result = a.add(&b);
        assert!(result.is_err());
    }

    #[test]
    fn test_scale() {
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
        let result = tensor.scale(2.0);
        assert_eq!(result.data, vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_relu() {
        let tensor = Tensor::from_vec(vec![-1.0, 0.0, 1.0, 2.0]);
        let result = tensor.relu();
        assert_eq!(result.data, vec![0.0, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_sigmoid() {
        let tensor = Tensor::from_vec(vec![0.0, 1.0, -1.0]);
        let result = tensor.sigmoid();

        assert!((result.data[0] - 0.5).abs() < EPSILON);
        assert!((result.data[1] - 0.7310586).abs() < EPSILON);
        assert!((result.data[2] - 0.26894143).abs() < EPSILON);
    }

    #[test]
    fn test_tanh() {
        let tensor = Tensor::from_vec(vec![0.0, 1.0, -1.0]);
        let result = tensor.tanh();

        assert!((result.data[0] - 0.0).abs() < EPSILON);
        assert!((result.data[1] - 0.7615942).abs() < EPSILON);
        assert!((result.data[2] - (-0.7615942)).abs() < EPSILON);
    }

    #[test]
    fn test_l2_norm() {
        let tensor = Tensor::from_vec(vec![3.0, 4.0]);
        let norm = tensor.l2_norm();
        assert!((norm - 5.0).abs() < EPSILON);
    }

    #[test]
    fn test_normalize() {
        let tensor = Tensor::from_vec(vec![3.0, 4.0]);
        let result = tensor.normalize().unwrap();
        assert_vec_approx_eq(&result.data, &[0.6, 0.8], EPSILON);
        assert!((result.l2_norm() - 1.0).abs() < EPSILON);
    }

    #[test]
    fn test_normalize_zero_vector() {
        let tensor = Tensor::from_vec(vec![0.0, 0.0]);
        let result = tensor.normalize();
        assert!(result.is_err());
    }

    #[test]
    fn test_as_slice() {
        let data = vec![1.0, 2.0, 3.0];
        let tensor = Tensor::from_vec(data.clone());
        assert_eq!(tensor.as_slice(), &data[..]);
    }

    #[test]
    fn test_into_vec() {
        let data = vec![1.0, 2.0, 3.0];
        let tensor = Tensor::from_vec(data.clone());
        assert_eq!(tensor.into_vec(), data);
    }

    #[test]
    fn test_len() {
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
        assert_eq!(tensor.len(), 3);
    }

    #[test]
    fn test_is_empty() {
        let tensor = Tensor::from_vec(vec![]);
        assert!(tensor.is_empty());

        let tensor = Tensor::from_vec(vec![1.0]);
        assert!(!tensor.is_empty());
    }

    #[test]
    fn test_xavier_init() {
        let weights = xavier_init(100, 50);
        assert_eq!(weights.len(), 5000);

        // limit = sqrt(6 / (fan_in + fan_out))
        let limit = (6.0 / 150.0_f32).sqrt();
        for &w in &weights {
            assert!((-limit..=limit).contains(&w));
        }

        // The sample mean of a symmetric uniform distribution should be near 0.
        let mean: f32 = weights.iter().sum::<f32>() / weights.len() as f32;
        assert!(mean.abs() < 0.1);
    }

    #[test]
    #[should_panic(expected = "fan_in and fan_out must be positive")]
    fn test_xavier_init_zero_fan() {
        xavier_init(0, 10);
    }

    #[test]
    fn test_he_init() {
        let weights = he_init(100);
        assert_eq!(weights.len(), 100);

        // Loose bound: std_dev of the sample mean is sqrt(2/100)/10 ~ 0.014.
        let mean: f32 = weights.iter().sum::<f32>() / weights.len() as f32;
        assert!(mean.abs() < 0.2);
    }

    #[test]
    #[should_panic(expected = "fan_in must be positive")]
    fn test_he_init_zero_fan() {
        he_init(0);
    }

    #[test]
    fn test_hadamard_product() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        let result = hadamard_product(&a, &b);
        assert_eq!(result, vec![4.0, 10.0, 18.0]);
    }

    #[test]
    #[should_panic(expected = "Vectors must have the same length")]
    fn test_hadamard_product_length_mismatch() {
        let a = vec![1.0, 2.0];
        let b = vec![1.0, 2.0, 3.0];
        hadamard_product(&a, &b);
    }

    #[test]
    fn test_vector_add() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        let result = vector_add(&a, &b);
        assert_eq!(result, vec![5.0, 7.0, 9.0]);
    }

    #[test]
    #[should_panic(expected = "Vectors must have the same length")]
    fn test_vector_add_length_mismatch() {
        let a = vec![1.0, 2.0];
        let b = vec![1.0, 2.0, 3.0];
        vector_add(&a, &b);
    }

    #[test]
    fn test_vector_scale() {
        let v = vec![1.0, 2.0, 3.0];
        let result = vector_scale(&v, 2.5);
        assert_vec_approx_eq(&result, &[2.5, 5.0, 7.5], EPSILON);
    }

    #[test]
    fn test_complex_operations() {
        // add -> scale -> relu -> normalize
        let a = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
        let b = Tensor::from_vec(vec![0.5, 1.0, 1.5]);

        let sum = a.add(&b).unwrap();
        let scaled = sum.scale(2.0);
        let activated = scaled.relu();
        let normalized = activated.normalize().unwrap();

        assert!((normalized.l2_norm() - 1.0).abs() < EPSILON);
    }

    #[test]
    fn test_edge_case_single_element() {
        let tensor = Tensor::from_vec(vec![5.0]);
        assert_eq!(tensor.len(), 1);
        assert_eq!(tensor.l2_norm(), 5.0);

        let normalized = tensor.normalize().unwrap();
        assert_vec_approx_eq(&normalized.data, &[1.0], EPSILON);
    }

    #[test]
    fn test_edge_case_negative_values() {
        let tensor = Tensor::from_vec(vec![-3.0, -4.0]);
        assert!((tensor.l2_norm() - 5.0).abs() < EPSILON);

        let relu_result = tensor.relu();
        assert_eq!(relu_result.data, vec![0.0, 0.0]);
    }

    #[test]
    fn test_large_matrix_multiplication() {
        // `size` must be even for the closed form below to hold.
        let size = 10;
        let a_data: Vec<f32> = (0..size * size).map(|i| i as f32).collect();
        let b_data: Vec<f32> = (0..size * size).map(|i| (i % 2) as f32).collect();

        let a = Tensor::new(a_data, vec![size, size]).unwrap();
        let b = Tensor::new(b_data, vec![size, size]).unwrap();

        let result = a.matmul(&b).unwrap();
        assert_eq!(result.shape, vec![size, size]);
        assert_eq!(result.len(), size * size);

        // a[i][k] = size*i + k and b[k][j] = (size*k + j) % 2 = j % 2, so
        // result[i][j] = (j % 2) * sum_k (size*i + k)
        //              = (j % 2) * (size*size*i + size*(size-1)/2).
        // All values are small integers, so they are exact in f32.
        for i in 0..size {
            for j in 0..size {
                let expected = if j % 2 == 1 {
                    (size * size * i + size * (size - 1) / 2) as f32
                } else {
                    0.0
                };
                assert_eq!(
                    result.data[i * size + j],
                    expected,
                    "mismatch at ({}, {})",
                    i,
                    j
                );
            }
        }
    }

    #[test]
    fn test_activation_functions_range() {
        let tensor = Tensor::from_vec(vec![-10.0, -1.0, 0.0, 1.0, 10.0]);

        // Sigmoid maps into the open interval (0, 1).
        let sigmoid = tensor.sigmoid();
        for &val in &sigmoid.data {
            assert!(val > 0.0 && val < 1.0);
        }

        // Tanh maps into [-1, 1].
        let tanh = tensor.tanh();
        for &val in &tanh.data {
            assert!((-1.0..=1.0).contains(&val));
        }

        // ReLU output is never negative.
        let relu = tensor.relu();
        for &val in &relu.data {
            assert!(val >= 0.0);
        }
    }
}