Skip to main content

svod_tensor/
arithmetic.rs

1use snafu::ResultExt;
2
3use super::*;
4
/// Generates element-wise `Tensor` methods that forward to `UOp` builder methods.
///
/// The invocation takes three sections, always present (they may be empty) and
/// always in this order. Each section is a comma-separated list of
/// `tensor_method => uop_method` pairs; a trailing comma is allowed.
///
/// - `binary { .. }`: generates
///   `pub fn tensor_method(&self, other: &Tensor) -> Result<Tensor>`.
///   Both operands are first passed through `broadcast_for_binop` to reach a
///   common shape. Then the fallible `UOp` method is called, and its error is
///   wrapped with `UOpSnafu`.
/// - `unary_infallible { .. }`: generates `pub fn tensor_method(&self) -> Result<Tensor>`
///   for `UOp` methods that return a `UOp` directly. The result is wrapped in
///   `Ok` so that every generated method shares the same `Result` signature.
/// - `unary_fallible { .. }`: generates `pub fn tensor_method(&self) -> Result<Tensor>`
///   for `UOp` methods that return a `Result`. The error is wrapped with `UOpSnafu`.
///
/// The macro must be invoked inside an `impl` block for a type providing
/// `new`, `uop` and `broadcast_for_binop`, with `Result`, `UOpSnafu` and a
/// `context` extension method in scope at the invocation site.
macro_rules! impl_tensor_ops {
    (
        binary { $($bin_method:ident => $bin_uop:ident),* $(,)? }
        unary_infallible { $($inf_method:ident => $inf_uop:ident),* $(,)? }
        unary_fallible { $($fall_method:ident => $fall_uop:ident),* $(,)? }
    ) => {
        // Binary operations: broadcast both operands to a common shape first.
        $(
            #[doc = concat!(
                "Element-wise `", stringify!($bin_uop), "` of `self` and `other`, ",
                "after broadcasting both operands to a common shape."
            )]
            #[track_caller]
            pub fn $bin_method(&self, other: &Tensor) -> Result<Tensor> {
                let (lhs, rhs) = self.broadcast_for_binop(other)?;

                // Shapes now match, so the UOp builder can be applied directly.
                lhs.uop().$bin_uop(&rhs.uop()).map(Self::new).context(UOpSnafu)
            }
        )*

        // Unary operations whose UOp builder cannot fail.
        $(
            #[doc = concat!(
                "Element-wise `", stringify!($inf_uop), "` of `self`. ",
                "Always returns `Ok`; the `Result` keeps the signature uniform."
            )]
            #[track_caller]
            pub fn $inf_method(&self) -> Result<Tensor> {
                Ok(Self::new(self.uop().$inf_uop()))
            }
        )*

        // Unary operations whose UOp builder returns a `Result`.
        $(
            #[doc = concat!(
                "Element-wise `", stringify!($fall_uop), "` of `self`; ",
                "errors from the UOp builder are wrapped in `UOpSnafu`."
            )]
            #[track_caller]
            pub fn $fall_method(&self) -> Result<Tensor> {
                self.uop().$fall_uop().map(Self::new).context(UOpSnafu)
            }
        )*
    };
}
46
47impl Tensor {
48    impl_tensor_ops! {
49        binary {
50            try_add => try_add,
51            try_sub => try_sub,
52            try_mul => try_mul,
53            try_div => try_div,
54            try_mod => try_mod,
55            try_pow => try_pow,
56            try_eq => try_cmpeq,
57            try_ne => try_cmpne,
58            try_lt => try_cmplt,
59            try_le => try_cmple,
60            try_gt => try_cmpgt,
61            try_ge => try_cmpge,
62            try_bitor => try_or_op,
63            try_bitand => try_and_op,
64            try_bitxor => try_xor_op,
65            try_shl => try_shl_op,
66            try_shr => try_shr_op,
67        }
68        unary_infallible {
69            try_neg => neg,
70            try_abs => abs,
71        }
72        unary_fallible {
73            try_sqrt => try_sqrt,
74            try_rsqrt => try_rsqrt,
75            try_exp => try_exp,
76            try_exp2 => try_exp2,
77            try_log => try_log,
78            try_log2 => try_log2,
79        }
80    }
81
82    /// Logical NOT for boolean tensors.
83    ///
84    /// Converts to boolean dtype and applies logical negation.
85    /// For non-boolean tensors, treats zero as false, non-zero as true.
86    ///
87    /// # Examples
88    /// ```ignore
89    /// let t = Tensor::from_slice(&[true, false, true]);
90    /// let result = t.logical_not()?;  // [false, true, false]
91    ///
92    /// let nums = Tensor::from_slice(&[0.0f32, 1.0, 2.0]);
93    /// let result = nums.logical_not()?;  // [true, false, false]
94    /// ```
95    pub fn logical_not(&self) -> Result<Tensor> {
96        use svod_dtype::DType;
97
98        // Cast to bool (non-zero becomes true)
99        let as_bool = self.cast(DType::Bool)?;
100
101        // Create true constant tensor and broadcast to match shape
102        let true_scalar = Self::from_slice([true]);
103        let self_shape = as_bool.shape()?;
104
105        let true_broadcast = if self_shape.is_empty() {
106            // Input is scalar - reshape [1] to []
107            true_scalar.try_reshape(&[] as &[isize])?
108        } else {
109            // Broadcast to match non-scalar shape
110            true_scalar.broadcast_to(&self_shape)?
111        };
112
113        // Compare: !x ≡ (x != true)
114        as_bool.try_ne(&true_broadcast)
115    }
116
117    /// Bitwise NOT for integer tensors.
118    ///
119    /// Applies bitwise NOT operation using two's complement: `~x = -x - 1`.
120    /// Only works for integer dtypes.
121    ///
122    /// # Examples
123    /// ```ignore
124    /// let t = Tensor::from_slice(&[0i32, 1, 2, -1]);
125    /// let result = t.bitwise_not()?;  // [-1, -2, -3, 0]
126    /// ```
127    ///
128    /// # Errors
129    ///
130    /// Returns error if called on non-integer dtype.
131    pub fn bitwise_not(&self) -> Result<Tensor> {
132        // Verify dtype is integer
133        let dtype = self.uop().dtype();
134        if !dtype.is_int() {
135            return Err(Error::SymbolicShapeUnsupported {
136                operation: format!("bitwise_not on non-integer dtype {:?}", dtype),
137            });
138        }
139
140        // Bitwise NOT using two's complement: ~x = -x - 1
141        let negated = self.try_neg()?;
142        let one_scalar = Self::from_slice([1i32]).cast(dtype)?;
143
144        // Broadcast one to match self shape
145        let self_shape = self.shape()?;
146
147        let one_broadcast = if self_shape.is_empty() {
148            // Input is scalar - reshape [1] to []
149            one_scalar.try_reshape(&[] as &[isize])?
150        } else {
151            // Broadcast to match non-scalar shape
152            one_scalar.broadcast_to(&self_shape)?
153        };
154
155        negated.try_sub(&one_broadcast)
156    }
157}