Skip to main content

svod_tensor/
traits.rs

1use super::*;
2
/// Implements a `std::ops` binary operator trait for `Tensor` in terms of a
/// fallible `try_*` method.
///
/// Generates all four ownership combinations:
/// - `&Tensor op &Tensor` — the primary implementation; calls `$try_method`
///   and unwraps its result with `expect($error_msg)`.
/// - `Tensor op Tensor`   — borrows both operands and forwards to `&Tensor op &Tensor`.
/// - `&Tensor op Tensor`  — borrows the right operand and forwards to `&Tensor op &Tensor`.
/// - `Tensor op &Tensor`  — borrows the left operand and forwards to `&Tensor op &Tensor`.
///
/// The forwarding impls use fully-qualified function-call syntax
/// (`std::ops::$trait::$method(lhs, rhs)`) rather than method-call syntax.
/// Trait methods are only callable as `x.method()` when the trait is imported,
/// and implementing a trait does not import it. Fully-qualified calls keep the
/// expansion compiling wherever `Tensor` is in scope, whether or not the
/// operator traits are.
///
/// Arguments:
/// - `$trait`: operator trait name in `std::ops` (e.g. `Add`).
/// - `$method`: that trait's method name (e.g. `add`).
/// - `$try_method`: a `Tensor` method taking `&Tensor`. It must return a
///   `Result<Tensor, E>` whose error type implements `Debug`, as required by
///   `Result::expect`.
/// - `$error_msg`: the message passed to `expect`.
///
/// # Panics
///
/// Every generated operator panics (via `Result::expect`, using `$error_msg`)
/// when `$try_method` returns `Err`. All generated methods are
/// `#[track_caller]`, so the panic location points at the operator expression
/// in user code.
macro_rules! impl_binary_op {
    ($trait:ident, $method:ident, $try_method:ident, $error_msg:expr) => {
        impl std::ops::$trait for &Tensor {
            type Output = Tensor;

            #[track_caller]
            fn $method(self, other: &Tensor) -> Tensor {
                self.$try_method(other).expect($error_msg)
            }
        }

        impl std::ops::$trait for Tensor {
            type Output = Tensor;

            #[track_caller]
            fn $method(self, other: Tensor) -> Tensor {
                // Owned operands are only borrowed; both are dropped on return.
                std::ops::$trait::$method(&self, &other)
            }
        }

        impl std::ops::$trait<Tensor> for &Tensor {
            type Output = Tensor;

            #[track_caller]
            fn $method(self, other: Tensor) -> Tensor {
                std::ops::$trait::$method(self, &other)
            }
        }

        impl std::ops::$trait<&Tensor> for Tensor {
            type Output = Tensor;

            #[track_caller]
            fn $method(self, other: &Tensor) -> Tensor {
                std::ops::$trait::$method(&self, other)
            }
        }
    };
}
49
50// Binary arithmetic operations
51impl_binary_op!(Add, add, try_add, "Addition failed");
52impl_binary_op!(Sub, sub, try_sub, "Subtraction failed");
53impl_binary_op!(Mul, mul, try_mul, "Multiplication failed");
54impl_binary_op!(Div, div, try_div, "Division failed");
55
56// Unary operations
57impl std::ops::Neg for &Tensor {
58    type Output = Tensor;
59
60    #[track_caller]
61    fn neg(self) -> Tensor {
62        self.try_neg().expect("Negation failed")
63    }
64}
65
66impl std::ops::Neg for Tensor {
67    type Output = Tensor;
68
69    #[track_caller]
70    fn neg(self) -> Tensor {
71        (&self).neg()
72    }
73}