1use crate::error::{GnnError, Result};
11use rand::Rng;
12use rand_distr::{Distribution, Normal, Uniform};
13
/// A dense tensor of `f32` values stored as one flat buffer plus a shape.
///
/// The methods in this module index 2-D tensors as `data[row * cols + col]`,
/// i.e. row-major order. `Tensor::new` checks that `data.len()` equals the
/// product of `shape`, but both fields are public, so code that builds or
/// mutates a `Tensor` directly is responsible for keeping them consistent.
#[derive(Clone, Debug, PartialEq)]
pub struct Tensor {
    /// Flat element storage.
    pub data: Vec<f32>,
    /// Size of each dimension; its product is expected to equal `data.len()`.
    pub shape: Vec<usize>,
}
22
23impl Tensor {
24 pub fn new(data: Vec<f32>, shape: Vec<usize>) -> Result<Self> {
36 let expected_len: usize = shape.iter().product();
37 if data.len() != expected_len {
38 return Err(GnnError::invalid_shape(format!(
39 "Data length {} doesn't match shape {:?} (expected {})",
40 data.len(),
41 shape,
42 expected_len
43 )));
44 }
45 Ok(Self { data, shape })
46 }
47
48 pub fn zeros(shape: &[usize]) -> Result<Self> {
59 if shape.is_empty() || shape.iter().any(|&d| d == 0) {
60 return Err(GnnError::invalid_shape(format!(
61 "Invalid shape: {:?}",
62 shape
63 )));
64 }
65 let size: usize = shape.iter().product();
66 Ok(Self {
67 data: vec![0.0; size],
68 shape: shape.to_vec(),
69 })
70 }
71
72 pub fn from_vec(data: Vec<f32>) -> Self {
80 let len = data.len();
81 Self {
82 data,
83 shape: vec![len],
84 }
85 }
86
87 pub fn dot(&self, other: &Tensor) -> Result<f32> {
98 if self.shape.len() != 1 || other.shape.len() != 1 {
99 return Err(GnnError::dimension_mismatch(
100 "1D tensors",
101 format!("{}D and {}D", self.shape.len(), other.shape.len()),
102 ));
103 }
104 if self.shape[0] != other.shape[0] {
105 return Err(GnnError::dimension_mismatch(
106 format!("length {}", self.shape[0]),
107 format!("length {}", other.shape[0]),
108 ));
109 }
110
111 let result = self
112 .data
113 .iter()
114 .zip(other.data.iter())
115 .map(|(a, b)| a * b)
116 .sum();
117 Ok(result)
118 }
119
120 pub fn matmul(&self, other: &Tensor) -> Result<Tensor> {
131 match (self.shape.len(), other.shape.len()) {
133 (1, 1) => {
134 let dot = self.dot(other)?;
135 Ok(Tensor::from_vec(vec![dot]))
136 }
137 (2, 1) => {
138 let m = self.shape[0];
140 let n = self.shape[1];
141 if n != other.shape[0] {
142 return Err(GnnError::dimension_mismatch(
143 format!("{}x{}", m, n),
144 format!("vector of length {}", other.shape[0]),
145 ));
146 }
147
148 let mut result = vec![0.0; m];
149 for i in 0..m {
150 for j in 0..n {
151 result[i] += self.data[i * n + j] * other.data[j];
152 }
153 }
154 Ok(Tensor::from_vec(result))
155 }
156 (2, 2) => {
157 let m = self.shape[0];
159 let n = self.shape[1];
160 let p = other.shape[1];
161
162 if n != other.shape[0] {
163 return Err(GnnError::dimension_mismatch(
164 format!("{}x{}", m, n),
165 format!("{}x{}", other.shape[0], p),
166 ));
167 }
168
169 let mut result = vec![0.0; m * p];
170 for i in 0..m {
171 for j in 0..p {
172 for k in 0..n {
173 result[i * p + j] += self.data[i * n + k] * other.data[k * p + j];
174 }
175 }
176 }
177 Tensor::new(result, vec![m, p])
178 }
179 _ => Err(GnnError::dimension_mismatch(
180 "1D or 2D tensors",
181 format!("{}D and {}D", self.shape.len(), other.shape.len()),
182 )),
183 }
184 }
185
186 pub fn add(&self, other: &Tensor) -> Result<Tensor> {
197 if self.shape != other.shape {
198 return Err(GnnError::dimension_mismatch(
199 format!("{:?}", self.shape),
200 format!("{:?}", other.shape),
201 ));
202 }
203
204 let result: Vec<f32> = self
205 .data
206 .iter()
207 .zip(other.data.iter())
208 .map(|(a, b)| a + b)
209 .collect();
210
211 Tensor::new(result, self.shape.clone())
212 }
213
214 pub fn scale(&self, scalar: f32) -> Tensor {
222 let result: Vec<f32> = self.data.iter().map(|&x| x * scalar).collect();
223 Tensor {
224 data: result,
225 shape: self.shape.clone(),
226 }
227 }
228
229 pub fn relu(&self) -> Tensor {
234 let result: Vec<f32> = self.data.iter().map(|&x| x.max(0.0)).collect();
235 Tensor {
236 data: result,
237 shape: self.shape.clone(),
238 }
239 }
240
241 pub fn sigmoid(&self) -> Tensor {
246 let result: Vec<f32> = self
247 .data
248 .iter()
249 .map(|&x| {
250 if x > 0.0 {
251 1.0 / (1.0 + (-x).exp())
252 } else {
253 let ex = x.exp();
254 ex / (1.0 + ex)
255 }
256 })
257 .collect();
258 Tensor {
259 data: result,
260 shape: self.shape.clone(),
261 }
262 }
263
264 pub fn tanh(&self) -> Tensor {
269 let result: Vec<f32> = self.data.iter().map(|&x| x.tanh()).collect();
270 Tensor {
271 data: result,
272 shape: self.shape.clone(),
273 }
274 }
275
276 pub fn l2_norm(&self) -> f32 {
281 let sum_squares: f64 = self.data.iter().map(|&x| (x as f64) * (x as f64)).sum();
283 (sum_squares.sqrt()) as f32
284 }
285
286 pub fn normalize(&self) -> Result<Tensor> {
294 let norm = self.l2_norm();
295 if norm == 0.0 {
296 return Err(GnnError::invalid_input(
297 "Cannot normalize zero vector".to_string(),
298 ));
299 }
300 Ok(self.scale(1.0 / norm))
301 }
302
303 pub fn as_slice(&self) -> &[f32] {
308 &self.data
309 }
310
311 pub fn into_vec(self) -> Vec<f32> {
316 self.data
317 }
318
319 pub fn len(&self) -> usize {
321 self.data.len()
322 }
323
324 pub fn is_empty(&self) -> bool {
326 self.data.is_empty()
327 }
328}
329
330pub fn xavier_init(fan_in: usize, fan_out: usize) -> Vec<f32> {
344 assert!(
345 fan_in > 0 && fan_out > 0,
346 "fan_in and fan_out must be positive"
347 );
348
349 let limit = (6.0 / (fan_in + fan_out) as f32).sqrt();
350 let uniform = Uniform::new(-limit, limit);
351 let mut rng = rand::thread_rng();
352
353 (0..fan_in * fan_out)
354 .map(|_| uniform.sample(&mut rng))
355 .collect()
356}
357
358pub fn he_init(fan_in: usize) -> Vec<f32> {
371 assert!(fan_in > 0, "fan_in must be positive");
372
373 let std_dev = (2.0 / fan_in as f32).sqrt();
374 let normal = Normal::new(0.0, std_dev).expect("Invalid normal distribution parameters");
375 let mut rng = rand::thread_rng();
376
377 (0..fan_in).map(|_| normal.sample(&mut rng)).collect()
378}
379
/// Element-wise (Hadamard) product of two equal-length slices.
///
/// # Panics
///
/// Panics if `a.len() != b.len()`.
pub fn hadamard_product(a: &[f32], b: &[f32]) -> Vec<f32> {
    assert_eq!(a.len(), b.len(), "Vectors must have the same length");
    let mut product = Vec::with_capacity(a.len());
    for (lhs, rhs) in a.iter().copied().zip(b.iter().copied()) {
        product.push(lhs * rhs);
    }
    product
}
395
/// Element-wise sum of two equal-length slices.
///
/// # Panics
///
/// Panics if `a.len() != b.len()`.
pub fn vector_add(a: &[f32], b: &[f32]) -> Vec<f32> {
    assert_eq!(a.len(), b.len(), "Vectors must have the same length");
    (0..a.len()).map(|idx| a[idx] + b[idx]).collect()
}
411
/// Returns a copy of `v` with every element multiplied by `scalar`.
pub fn vector_scale(v: &[f32], scalar: f32) -> Vec<f32> {
    let mut scaled = v.to_vec();
    scaled.iter_mut().for_each(|x| *x *= scalar);
    scaled
}
423
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for floating-point comparisons.
    const EPSILON: f32 = 1e-6;

    /// Asserts that `a` and `b` have equal length and that every pair of
    /// elements differs by less than `epsilon`, reporting the offending index.
    fn assert_vec_approx_eq(a: &[f32], b: &[f32], epsilon: f32) {
        assert_eq!(a.len(), b.len(), "Vectors have different lengths");
        for (i, (&x, &y)) in a.iter().zip(b.iter()).enumerate() {
            assert!(
                (x - y).abs() < epsilon,
                "Values at index {} differ: {} vs {} (diff: {})",
                i,
                x,
                y,
                (x - y).abs()
            );
        }
    }

    #[test]
    fn test_tensor_new() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let tensor = Tensor::new(data.clone(), vec![2, 2]).unwrap();
        assert_eq!(tensor.data, data);
        assert_eq!(tensor.shape, vec![2, 2]);
    }

    #[test]
    fn test_tensor_new_invalid_shape() {
        let data = vec![1.0, 2.0, 3.0];
        let result = Tensor::new(data, vec![2, 2]);
        assert!(result.is_err());
    }

    #[test]
    fn test_tensor_zeros() {
        let tensor = Tensor::zeros(&[3, 2]).unwrap();
        assert_eq!(tensor.data, vec![0.0; 6]);
        assert_eq!(tensor.shape, vec![3, 2]);
    }

    #[test]
    fn test_tensor_zeros_invalid_shape() {
        let result = Tensor::zeros(&[0, 2]);
        assert!(result.is_err());

        let result = Tensor::zeros(&[]);
        assert!(result.is_err());
    }

    #[test]
    fn test_tensor_from_vec() {
        let data = vec![1.0, 2.0, 3.0];
        let tensor = Tensor::from_vec(data.clone());
        assert_eq!(tensor.data, data);
        assert_eq!(tensor.shape, vec![3]);
    }

    #[test]
    fn test_dot_product() {
        let a = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
        let b = Tensor::from_vec(vec![4.0, 5.0, 6.0]);
        let result = a.dot(&b).unwrap();
        // 1*4 + 2*5 + 3*6 = 32
        assert!((result - 32.0).abs() < EPSILON);
    }

    #[test]
    fn test_dot_product_dimension_mismatch() {
        let a = Tensor::from_vec(vec![1.0, 2.0]);
        let b = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
        let result = a.dot(&b);
        assert!(result.is_err());
    }

    #[test]
    fn test_matmul_1d() {
        let a = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
        let b = Tensor::from_vec(vec![4.0, 5.0, 6.0]);
        let result = a.matmul(&b).unwrap();
        assert_eq!(result.shape, vec![1]);
        assert!((result.data[0] - 32.0).abs() < EPSILON);
    }

    #[test]
    fn test_matmul_2d_1d() {
        // [[1, 2, 3], [4, 5, 6]] x [1, 2, 3]
        let mat = Tensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let vec = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
        let result = mat.matmul(&vec).unwrap();

        assert_eq!(result.shape, vec![2]);
        // [1*1 + 2*2 + 3*3, 4*1 + 5*2 + 6*3] = [14, 32]
        assert_vec_approx_eq(&result.data, &[14.0, 32.0], EPSILON);
    }

    #[test]
    fn test_matmul_2d_2d() {
        // [[1, 2], [3, 4]] x [[5, 6], [7, 8]]
        let a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let b = Tensor::new(vec![5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
        let result = a.matmul(&b).unwrap();

        assert_eq!(result.shape, vec![2, 2]);
        assert_vec_approx_eq(&result.data, &[19.0, 22.0, 43.0, 50.0], EPSILON);
    }

    #[test]
    fn test_matmul_dimension_mismatch() {
        let a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let b = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
        let result = a.matmul(&b);
        assert!(result.is_err());
    }

    #[test]
    fn test_matmul_unsupported_rank() {
        let a = Tensor::new(vec![0.0; 8], vec![2, 2, 2]).unwrap();
        let b = Tensor::from_vec(vec![1.0, 2.0]);
        assert!(a.matmul(&b).is_err());
    }

    #[test]
    fn test_add() {
        let a = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
        let b = Tensor::from_vec(vec![4.0, 5.0, 6.0]);
        let result = a.add(&b).unwrap();
        assert_eq!(result.data, vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_add_dimension_mismatch() {
        let a = Tensor::from_vec(vec![1.0, 2.0]);
        let b = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
        let result = a.add(&b);
        assert!(result.is_err());
    }

    #[test]
    fn test_scale() {
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
        let result = tensor.scale(2.0);
        assert_eq!(result.data, vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_relu() {
        let tensor = Tensor::from_vec(vec![-1.0, 0.0, 1.0, 2.0]);
        let result = tensor.relu();
        assert_eq!(result.data, vec![0.0, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_sigmoid() {
        let tensor = Tensor::from_vec(vec![0.0, 1.0, -1.0]);
        let result = tensor.sigmoid();

        assert!((result.data[0] - 0.5).abs() < EPSILON);
        assert!((result.data[1] - 0.7310586).abs() < EPSILON);
        assert!((result.data[2] - 0.26894143).abs() < EPSILON);
    }

    #[test]
    fn test_sigmoid_extreme_values() {
        // Large magnitudes must saturate without producing NaN or infinity.
        let tensor = Tensor::from_vec(vec![-100.0, 100.0]);
        let result = tensor.sigmoid();

        for &val in &result.data {
            assert!(val.is_finite());
            assert!((0.0..=1.0).contains(&val));
        }
        assert!(result.data[0] < EPSILON);
        assert!((result.data[1] - 1.0).abs() < EPSILON);
    }

    #[test]
    fn test_tanh() {
        let tensor = Tensor::from_vec(vec![0.0, 1.0, -1.0]);
        let result = tensor.tanh();

        assert!((result.data[0] - 0.0).abs() < EPSILON);
        assert!((result.data[1] - 0.7615942).abs() < EPSILON);
        assert!((result.data[2] - (-0.7615942)).abs() < EPSILON);
    }

    #[test]
    fn test_l2_norm() {
        let tensor = Tensor::from_vec(vec![3.0, 4.0]);
        let norm = tensor.l2_norm();
        assert!((norm - 5.0).abs() < EPSILON);
    }

    #[test]
    fn test_normalize() {
        let tensor = Tensor::from_vec(vec![3.0, 4.0]);
        let result = tensor.normalize().unwrap();
        assert_vec_approx_eq(&result.data, &[0.6, 0.8], EPSILON);
        assert!((result.l2_norm() - 1.0).abs() < EPSILON);
    }

    #[test]
    fn test_normalize_zero_vector() {
        let tensor = Tensor::from_vec(vec![0.0, 0.0]);
        let result = tensor.normalize();
        assert!(result.is_err());
    }

    #[test]
    fn test_as_slice() {
        let data = vec![1.0, 2.0, 3.0];
        let tensor = Tensor::from_vec(data.clone());
        assert_eq!(tensor.as_slice(), &data[..]);
    }

    #[test]
    fn test_into_vec() {
        let data = vec![1.0, 2.0, 3.0];
        let tensor = Tensor::from_vec(data.clone());
        assert_eq!(tensor.into_vec(), data);
    }

    #[test]
    fn test_len() {
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
        assert_eq!(tensor.len(), 3);
    }

    #[test]
    fn test_is_empty() {
        let tensor = Tensor::from_vec(vec![]);
        assert!(tensor.is_empty());

        let tensor = Tensor::from_vec(vec![1.0]);
        assert!(!tensor.is_empty());
    }

    #[test]
    fn test_xavier_init() {
        let weights = xavier_init(100, 50);
        assert_eq!(weights.len(), 5000);

        // Every sample must lie within the Xavier bound sqrt(6 / (100 + 50)).
        let limit = (6.0 / 150.0_f32).sqrt();
        for &w in &weights {
            assert!(w >= -limit && w <= limit);
        }

        // A symmetric distribution over 5000 samples should have mean near 0.
        let mean: f32 = weights.iter().sum::<f32>() / weights.len() as f32;
        assert!(mean.abs() < 0.1);
    }

    #[test]
    #[should_panic(expected = "fan_in and fan_out must be positive")]
    fn test_xavier_init_zero_fan() {
        xavier_init(0, 10);
    }

    #[test]
    #[should_panic(expected = "fan_in and fan_out must be positive")]
    fn test_xavier_init_zero_fan_out() {
        xavier_init(10, 0);
    }

    #[test]
    fn test_he_init() {
        let weights = he_init(100);
        assert_eq!(weights.len(), 100);

        // Std dev is sqrt(2/100) ~= 0.14; the sample mean's std is ~0.014.
        let mean: f32 = weights.iter().sum::<f32>() / weights.len() as f32;
        assert!(mean.abs() < 0.2);
    }

    #[test]
    #[should_panic(expected = "fan_in must be positive")]
    fn test_he_init_zero_fan() {
        he_init(0);
    }

    #[test]
    fn test_hadamard_product() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        let result = hadamard_product(&a, &b);
        assert_eq!(result, vec![4.0, 10.0, 18.0]);
    }

    #[test]
    #[should_panic(expected = "Vectors must have the same length")]
    fn test_hadamard_product_length_mismatch() {
        let a = vec![1.0, 2.0];
        let b = vec![1.0, 2.0, 3.0];
        hadamard_product(&a, &b);
    }

    #[test]
    fn test_vector_add() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        let result = vector_add(&a, &b);
        assert_eq!(result, vec![5.0, 7.0, 9.0]);
    }

    #[test]
    #[should_panic(expected = "Vectors must have the same length")]
    fn test_vector_add_length_mismatch() {
        let a = vec![1.0, 2.0];
        let b = vec![1.0, 2.0, 3.0];
        vector_add(&a, &b);
    }

    #[test]
    fn test_vector_scale() {
        let v = vec![1.0, 2.0, 3.0];
        let result = vector_scale(&v, 2.5);
        assert_vec_approx_eq(&result, &[2.5, 5.0, 7.5], EPSILON);
    }

    #[test]
    fn test_complex_operations() {
        // Chain: add -> scale -> relu -> normalize.
        let a = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
        let b = Tensor::from_vec(vec![0.5, 1.0, 1.5]);

        let sum = a.add(&b).unwrap();
        let scaled = sum.scale(2.0);
        let activated = scaled.relu();
        let normalized = activated.normalize().unwrap();

        assert!((normalized.l2_norm() - 1.0).abs() < EPSILON);
    }

    #[test]
    fn test_edge_case_single_element() {
        let tensor = Tensor::from_vec(vec![5.0]);
        assert_eq!(tensor.len(), 1);
        assert!((tensor.l2_norm() - 5.0).abs() < EPSILON);

        let normalized = tensor.normalize().unwrap();
        assert_vec_approx_eq(&normalized.data, &[1.0], EPSILON);
    }

    #[test]
    fn test_edge_case_negative_values() {
        let tensor = Tensor::from_vec(vec![-3.0, -4.0]);
        assert!((tensor.l2_norm() - 5.0).abs() < EPSILON);

        let relu_result = tensor.relu();
        assert_eq!(relu_result.data, vec![0.0, 0.0]);
    }

    #[test]
    fn test_large_matrix_multiplication() {
        let size = 10;
        let a_data: Vec<f32> = (0..size * size).map(|i| i as f32).collect();
        let b_data: Vec<f32> = (0..size * size).map(|i| (i % 2) as f32).collect();

        // Reference result from the textbook triple loop. All values are small
        // integers, so the f32 arithmetic is exact.
        let mut expected = vec![0.0f32; size * size];
        for i in 0..size {
            for j in 0..size {
                for k in 0..size {
                    expected[i * size + j] += a_data[i * size + k] * b_data[k * size + j];
                }
            }
        }

        let a = Tensor::new(a_data, vec![size, size]).unwrap();
        let b = Tensor::new(b_data, vec![size, size]).unwrap();

        let result = a.matmul(&b).unwrap();
        assert_eq!(result.shape, vec![size, size]);
        assert_eq!(result.len(), size * size);
        assert_vec_approx_eq(&result.data, &expected, EPSILON);
    }

    #[test]
    fn test_activation_functions_range() {
        let tensor = Tensor::from_vec(vec![-10.0, -1.0, 0.0, 1.0, 10.0]);

        // Sigmoid output is strictly inside (0, 1) for these inputs.
        let sigmoid = tensor.sigmoid();
        for &val in &sigmoid.data {
            assert!(val > 0.0 && val < 1.0);
        }

        // Tanh output is within [-1, 1].
        let tanh = tensor.tanh();
        for &val in &tanh.data {
            assert!(val >= -1.0 && val <= 1.0);
        }

        // ReLU output is non-negative.
        let relu = tensor.relu();
        for &val in &relu.data {
            assert!(val >= 0.0);
        }
    }
}