svod_tensor/arithmetic.rs
1use snafu::ResultExt;
2
3use super::*;
4
/// Generates element-wise `Tensor` methods that forward to the matching
/// builder method on the tensor's underlying UOp.
///
/// The invocation must contain all three sections, in this order. Any section
/// may be empty, and each holds comma-separated `tensor_method => uop_method`
/// pairs (a trailing comma is allowed):
///
/// - `binary { .. }` generates `fn(&self, other: &Tensor) -> Result<Tensor>`.
///   Both operands are first passed through `broadcast_for_binop`. The UOp
///   method is then called on the broadcast operands, and its error is wrapped
///   with the `UOpSnafu` context.
/// - `unary_infallible { .. }` generates `fn(&self) -> Result<Tensor>` for UOp
///   methods that return a UOp directly. The result is always `Ok`.
/// - `unary_fallible { .. }` generates `fn(&self) -> Result<Tensor>` for UOp
///   methods that return a `Result`. Errors are wrapped with `UOpSnafu`.
///
/// Every generated method is marked `#[track_caller]`.
macro_rules! impl_tensor_ops {
    (
        binary { $($bin_method:ident => $bin_uop:ident),* $(,)? }
        unary_infallible { $($inf_method:ident => $inf_uop:ident),* $(,)? }
        unary_fallible { $($fall_method:ident => $fall_uop:ident),* $(,)? }
    ) => {
        // Binary operations (with automatic broadcasting).
        $(
            #[doc = concat!(
                "Element-wise binary operation backed by `", stringify!($bin_uop),
                "` on the underlying UOp.\n\n",
                "`self` and `other` are first broadcast to a common shape.\n\n",
                "# Errors\n\n",
                "Returns an error if broadcasting fails or if `", stringify!($bin_uop),
                "` fails."
            )]
            #[track_caller]
            pub fn $bin_method(&self, other: &Tensor) -> Result<Tensor> {
                // Broadcast both operands to a common shape before building the UOp.
                let (lhs, rhs) = self.broadcast_for_binop(other)?;
                lhs.uop().$bin_uop(&rhs.uop()).map(Self::new).context(UOpSnafu)
            }
        )*

        // Unary operations whose UOp builder cannot fail.
        $(
            #[doc = concat!(
                "Element-wise unary operation backed by `", stringify!($inf_uop),
                "` on the underlying UOp.\n\n",
                "The underlying builder is infallible, so this always returns `Ok`."
            )]
            #[track_caller]
            pub fn $inf_method(&self) -> Result<Tensor> {
                Ok(Self::new(self.uop().$inf_uop()))
            }
        )*

        // Unary operations whose UOp builder returns a `Result`.
        $(
            #[doc = concat!(
                "Element-wise unary operation backed by `", stringify!($fall_uop),
                "` on the underlying UOp.\n\n",
                "# Errors\n\n",
                "Returns an error if `", stringify!($fall_uop), "` fails."
            )]
            #[track_caller]
            pub fn $fall_method(&self) -> Result<Tensor> {
                self.uop().$fall_uop().map(Self::new).context(UOpSnafu)
            }
        )*
    };
}
46
47impl Tensor {
48 impl_tensor_ops! {
49 binary {
50 try_add => try_add,
51 try_sub => try_sub,
52 try_mul => try_mul,
53 try_div => try_div,
54 try_mod => try_mod,
55 try_pow => try_pow,
56 try_eq => try_cmpeq,
57 try_ne => try_cmpne,
58 try_lt => try_cmplt,
59 try_le => try_cmple,
60 try_gt => try_cmpgt,
61 try_ge => try_cmpge,
62 try_bitor => try_or_op,
63 try_bitand => try_and_op,
64 try_bitxor => try_xor_op,
65 try_shl => try_shl_op,
66 try_shr => try_shr_op,
67 }
68 unary_infallible {
69 try_neg => neg,
70 try_abs => abs,
71 }
72 unary_fallible {
73 try_sqrt => try_sqrt,
74 try_rsqrt => try_rsqrt,
75 try_exp => try_exp,
76 try_exp2 => try_exp2,
77 try_log => try_log,
78 try_log2 => try_log2,
79 }
80 }
81
82 /// Logical NOT for boolean tensors.
83 ///
84 /// Converts to boolean dtype and applies logical negation.
85 /// For non-boolean tensors, treats zero as false, non-zero as true.
86 ///
87 /// # Examples
88 /// ```ignore
89 /// let t = Tensor::from_slice(&[true, false, true]);
90 /// let result = t.logical_not()?; // [false, true, false]
91 ///
92 /// let nums = Tensor::from_slice(&[0.0f32, 1.0, 2.0]);
93 /// let result = nums.logical_not()?; // [true, false, false]
94 /// ```
95 pub fn logical_not(&self) -> Result<Tensor> {
96 use svod_dtype::DType;
97
98 // Cast to bool (non-zero becomes true)
99 let as_bool = self.cast(DType::Bool)?;
100
101 // Create true constant tensor and broadcast to match shape
102 let true_scalar = Self::from_slice([true]);
103 let self_shape = as_bool.shape()?;
104
105 let true_broadcast = if self_shape.is_empty() {
106 // Input is scalar - reshape [1] to []
107 true_scalar.try_reshape(&[] as &[isize])?
108 } else {
109 // Broadcast to match non-scalar shape
110 true_scalar.broadcast_to(&self_shape)?
111 };
112
113 // Compare: !x ≡ (x != true)
114 as_bool.try_ne(&true_broadcast)
115 }
116
117 /// Bitwise NOT for integer tensors.
118 ///
119 /// Applies bitwise NOT operation using two's complement: `~x = -x - 1`.
120 /// Only works for integer dtypes.
121 ///
122 /// # Examples
123 /// ```ignore
124 /// let t = Tensor::from_slice(&[0i32, 1, 2, -1]);
125 /// let result = t.bitwise_not()?; // [-1, -2, -3, 0]
126 /// ```
127 ///
128 /// # Errors
129 ///
130 /// Returns error if called on non-integer dtype.
131 pub fn bitwise_not(&self) -> Result<Tensor> {
132 // Verify dtype is integer
133 let dtype = self.uop().dtype();
134 if !dtype.is_int() {
135 return Err(Error::SymbolicShapeUnsupported {
136 operation: format!("bitwise_not on non-integer dtype {:?}", dtype),
137 });
138 }
139
140 // Bitwise NOT using two's complement: ~x = -x - 1
141 let negated = self.try_neg()?;
142 let one_scalar = Self::from_slice([1i32]).cast(dtype)?;
143
144 // Broadcast one to match self shape
145 let self_shape = self.shape()?;
146
147 let one_broadcast = if self_shape.is_empty() {
148 // Input is scalar - reshape [1] to []
149 one_scalar.try_reshape(&[] as &[isize])?
150 } else {
151 // Broadcast to match non-scalar shape
152 one_scalar.broadcast_to(&self_shape)?
153 };
154
155 negated.try_sub(&one_broadcast)
156 }
157}