1use crate::interpreter::errors::mex;
2use crate::interpreter::stack::pop2;
3use runmat_builtins::Value;
4use runmat_runtime::builtins::common::shape::is_scalar_shape;
5use runmat_runtime::RuntimeError;
6use std::future::Future;
7
8pub async fn add<CM, CMFut, F, FFut>(
9 stack: &mut Vec<Value>,
10 mut call_method: CM,
11 mut fallback: F,
12) -> Result<(), RuntimeError>
13where
14 CM: FnMut(Value, &'static str, Value) -> CMFut,
15 CMFut: Future<Output = Result<Value, RuntimeError>>,
16 F: FnMut(Value, Value) -> FFut,
17 FFut: Future<Output = Result<Value, RuntimeError>>,
18{
19 let (a, b) = pop2(stack)?;
20 let result = match (&a, &b) {
21 (Value::Object(obj), _) => {
22 match call_method(Value::Object(obj.clone()), "plus", b.clone()).await {
23 Ok(v) => v,
24 Err(_) => fallback(a.clone(), b.clone()).await?,
25 }
26 }
27 (_, Value::Object(obj)) => {
28 match call_method(Value::Object(obj.clone()), "plus", a.clone()).await {
29 Ok(v) => v,
30 Err(_) => fallback(a.clone(), b.clone()).await?,
31 }
32 }
33 _ => fallback(a.clone(), b.clone()).await?,
34 };
35 stack.push(result);
36 Ok(())
37}
38
39pub async fn sub<CM, CMFut, RM, RMFut, F, FFut>(
40 stack: &mut Vec<Value>,
41 mut call_method: CM,
42 mut right_method: RM,
43 mut fallback: F,
44) -> Result<(), RuntimeError>
45where
46 CM: FnMut(Value, &'static str, Value) -> CMFut,
47 CMFut: Future<Output = Result<Value, RuntimeError>>,
48 RM: FnMut(Value, Value) -> RMFut,
49 RMFut: Future<Output = Result<Value, RuntimeError>>,
50 F: FnMut(Value, Value) -> FFut,
51 FFut: Future<Output = Result<Value, RuntimeError>>,
52{
53 let (a, b) = pop2(stack)?;
54 let result = match (&a, &b) {
55 (Value::Object(obj), _) => {
56 match call_method(Value::Object(obj.clone()), "minus", b.clone()).await {
57 Ok(v) => v,
58 Err(_) => fallback(a.clone(), b.clone()).await?,
59 }
60 }
61 (_, Value::Object(obj)) => {
62 match right_method(Value::Object(obj.clone()), a.clone()).await {
63 Ok(v) => v,
64 Err(_) => fallback(a.clone(), b.clone()).await?,
65 }
66 }
67 _ => fallback(a.clone(), b.clone()).await?,
68 };
69 stack.push(result);
70 Ok(())
71}
72
73pub async fn mul<CM, CMFut, F, FFut>(
74 stack: &mut Vec<Value>,
75 mut call_method: CM,
76 mut fallback: F,
77) -> Result<(), RuntimeError>
78where
79 CM: FnMut(Value, &'static str, Value) -> CMFut,
80 CMFut: Future<Output = Result<Value, RuntimeError>>,
81 F: FnMut(Value, Value) -> FFut,
82 FFut: Future<Output = Result<Value, RuntimeError>>,
83{
84 let (a, b) = pop2(stack)?;
85 let result = match (&a, &b) {
86 (Value::Object(obj), _) => {
87 match call_method(Value::Object(obj.clone()), "mtimes", b.clone()).await {
88 Ok(v) => v,
89 Err(_) => fallback(a.clone(), b.clone()).await?,
90 }
91 }
92 (_, Value::Object(obj)) => {
93 match call_method(Value::Object(obj.clone()), "mtimes", a.clone()).await {
94 Ok(v) => v,
95 Err(_) => fallback(a.clone(), b.clone()).await?,
96 }
97 }
98 _ => fallback(a.clone(), b.clone()).await?,
99 };
100 stack.push(result);
101 Ok(())
102}
103
104pub async fn binary_method<CM, CMFut, F, FFut>(
105 stack: &mut Vec<Value>,
106 method: &'static str,
107 mut call_method: CM,
108 mut fallback: F,
109) -> Result<(), RuntimeError>
110where
111 CM: FnMut(Value, &'static str, Value) -> CMFut,
112 CMFut: Future<Output = Result<Value, RuntimeError>>,
113 F: FnMut(Value, Value) -> FFut,
114 FFut: Future<Output = Result<Value, RuntimeError>>,
115{
116 let (a, b) = pop2(stack)?;
117 let result = match (&a, &b) {
118 (Value::Object(obj), _) => {
119 match call_method(Value::Object(obj.clone()), method, b.clone()).await {
120 Ok(v) => v,
121 Err(_) => fallback(a.clone(), b.clone()).await?,
122 }
123 }
124 (_, Value::Object(obj)) => {
125 match call_method(Value::Object(obj.clone()), method, a.clone()).await {
126 Ok(v) => v,
127 Err(_) => fallback(a.clone(), b.clone()).await?,
128 }
129 }
130 _ => fallback(a.clone(), b.clone()).await?,
131 };
132 stack.push(result);
133 Ok(())
134}
135
136pub async fn binary_fallback<F, FFut>(
137 stack: &mut Vec<Value>,
138 mut fallback: F,
139) -> Result<(), RuntimeError>
140where
141 F: FnMut(Value, Value) -> FFut,
142 FFut: Future<Output = Result<Value, RuntimeError>>,
143{
144 let (a, b) = pop2(stack)?;
145 stack.push(fallback(a, b).await?);
146 Ok(())
147}
148
149pub async fn power<CM, CMFut, F, FFut>(
150 stack: &mut Vec<Value>,
151 mut call_method: CM,
152 mut fallback: F,
153) -> Result<(), RuntimeError>
154where
155 CM: FnMut(Value, &'static str, Value) -> CMFut,
156 CMFut: Future<Output = Result<Value, RuntimeError>>,
157 F: FnMut(Value, Value) -> FFut,
158 FFut: Future<Output = Result<Value, RuntimeError>>,
159{
160 let (a, b) = pop2(stack)?;
161 let result = match (&a, &b) {
162 (Value::Object(obj), _) => {
163 match call_method(Value::Object(obj.clone()), "power", b.clone()).await {
164 Ok(v) => v,
165 Err(_) => fallback(a.clone(), b.clone()).await?,
166 }
167 }
168 (_, Value::Object(obj)) => {
169 match call_method(Value::Object(obj.clone()), "power", a.clone()).await {
170 Ok(v) => v,
171 Err(_) => fallback(a.clone(), b.clone()).await?,
172 }
173 }
174 _ => fallback(a.clone(), b.clone()).await?,
175 };
176 stack.push(result);
177 Ok(())
178}
179
180pub async fn unary<UF, UFut>(stack: &mut Vec<Value>, mut op: UF) -> Result<(), RuntimeError>
181where
182 UF: FnMut(Value) -> UFut,
183 UFut: Future<Output = Result<Value, RuntimeError>>,
184{
185 let value = stack
186 .pop()
187 .ok_or(mex("StackUnderflow", "stack underflow"))?;
188 stack.push(op(value).await?);
189 Ok(())
190}
191
192pub fn is_scalarish_for_division(value: &Value) -> bool {
193 match value {
194 Value::Int(_) | Value::Num(_) | Value::Complex(_, _) | Value::Bool(_) => true,
195 Value::LogicalArray(arr) => is_scalar_shape(&arr.shape),
196 Value::Tensor(tensor) => is_scalar_shape(&tensor.shape),
197 Value::ComplexTensor(tensor) => is_scalar_shape(&tensor.shape),
198 Value::GpuTensor(handle) => is_scalar_shape(&handle.shape),
199 _ => false,
200 }
201}
202
203pub async fn execute_right_division<CM, CMFut, SF, SFFut, MF, MFFut>(
204 lhs: &Value,
205 rhs: &Value,
206 mut call_method: CM,
207 mut scalarish_fallback: SF,
208 mut matrix_fallback: MF,
209) -> Result<Value, RuntimeError>
210where
211 CM: FnMut(Value, &'static str, Value) -> CMFut,
212 CMFut: Future<Output = Result<Value, RuntimeError>>,
213 SF: FnMut(Value, Value) -> SFFut,
214 SFFut: Future<Output = Result<Value, RuntimeError>>,
215 MF: FnMut(Value, Value) -> MFFut,
216 MFFut: Future<Output = Result<Value, RuntimeError>>,
217{
218 match (lhs, rhs) {
219 (Value::Object(obj), _) => {
220 match call_method(Value::Object(obj.clone()), "mrdivide", rhs.clone()).await {
221 Ok(v) => Ok(v),
222 Err(_) => {
223 if is_scalarish_for_division(rhs) {
224 scalarish_fallback(lhs.clone(), rhs.clone()).await
225 } else {
226 matrix_fallback(lhs.clone(), rhs.clone()).await
227 }
228 }
229 }
230 }
231 (_, Value::Object(obj)) => {
232 match call_method(Value::Object(obj.clone()), "mrdivide", lhs.clone()).await {
233 Ok(v) => Ok(v),
234 Err(_) => {
235 if is_scalarish_for_division(rhs) {
236 scalarish_fallback(lhs.clone(), rhs.clone()).await
237 } else {
238 matrix_fallback(lhs.clone(), rhs.clone()).await
239 }
240 }
241 }
242 }
243 _ => {
244 if is_scalarish_for_division(rhs) {
245 scalarish_fallback(lhs.clone(), rhs.clone()).await
246 } else {
247 matrix_fallback(lhs.clone(), rhs.clone()).await
248 }
249 }
250 }
251}
252
253pub async fn execute_left_division<CM, CMFut, SF, SFFut, MF, MFFut>(
254 lhs: &Value,
255 rhs: &Value,
256 mut call_method: CM,
257 mut scalarish_fallback: SF,
258 mut matrix_fallback: MF,
259) -> Result<Value, RuntimeError>
260where
261 CM: FnMut(Value, &'static str, Value) -> CMFut,
262 CMFut: Future<Output = Result<Value, RuntimeError>>,
263 SF: FnMut(Value, Value) -> SFFut,
264 SFFut: Future<Output = Result<Value, RuntimeError>>,
265 MF: FnMut(Value, Value) -> MFFut,
266 MFFut: Future<Output = Result<Value, RuntimeError>>,
267{
268 match (lhs, rhs) {
269 (Value::Object(obj), _) => {
270 match call_method(Value::Object(obj.clone()), "mldivide", rhs.clone()).await {
271 Ok(v) => Ok(v),
272 Err(_) => {
273 if is_scalarish_for_division(lhs) {
274 scalarish_fallback(lhs.clone(), rhs.clone()).await
275 } else {
276 matrix_fallback(lhs.clone(), rhs.clone()).await
277 }
278 }
279 }
280 }
281 (_, Value::Object(obj)) => {
282 match call_method(Value::Object(obj.clone()), "mldivide", lhs.clone()).await {
283 Ok(v) => Ok(v),
284 Err(_) => {
285 if is_scalarish_for_division(lhs) {
286 scalarish_fallback(lhs.clone(), rhs.clone()).await
287 } else {
288 matrix_fallback(lhs.clone(), rhs.clone()).await
289 }
290 }
291 }
292 }
293 _ => {
294 if is_scalarish_for_division(lhs) {
295 scalarish_fallback(lhs.clone(), rhs.clone()).await
296 } else {
297 matrix_fallback(lhs.clone(), rhs.clone()).await
298 }
299 }
300 }
301}