Skip to main content

runmat_vm/ops/
arithmetic.rs

1use crate::interpreter::errors::mex;
2use crate::interpreter::stack::pop2;
3use runmat_builtins::Value;
4use runmat_runtime::builtins::common::shape::is_scalar_shape;
5use runmat_runtime::RuntimeError;
6use std::future::Future;
7
8pub async fn add<CM, CMFut, F, FFut>(
9    stack: &mut Vec<Value>,
10    mut call_method: CM,
11    mut fallback: F,
12) -> Result<(), RuntimeError>
13where
14    CM: FnMut(Value, &'static str, Value) -> CMFut,
15    CMFut: Future<Output = Result<Value, RuntimeError>>,
16    F: FnMut(Value, Value) -> FFut,
17    FFut: Future<Output = Result<Value, RuntimeError>>,
18{
19    let (a, b) = pop2(stack)?;
20    let result = match (&a, &b) {
21        (Value::Object(obj), _) => {
22            match call_method(Value::Object(obj.clone()), "plus", b.clone()).await {
23                Ok(v) => v,
24                Err(_) => fallback(a.clone(), b.clone()).await?,
25            }
26        }
27        (_, Value::Object(obj)) => {
28            match call_method(Value::Object(obj.clone()), "plus", a.clone()).await {
29                Ok(v) => v,
30                Err(_) => fallback(a.clone(), b.clone()).await?,
31            }
32        }
33        _ => fallback(a.clone(), b.clone()).await?,
34    };
35    stack.push(result);
36    Ok(())
37}
38
39pub async fn sub<CM, CMFut, RM, RMFut, F, FFut>(
40    stack: &mut Vec<Value>,
41    mut call_method: CM,
42    mut right_method: RM,
43    mut fallback: F,
44) -> Result<(), RuntimeError>
45where
46    CM: FnMut(Value, &'static str, Value) -> CMFut,
47    CMFut: Future<Output = Result<Value, RuntimeError>>,
48    RM: FnMut(Value, Value) -> RMFut,
49    RMFut: Future<Output = Result<Value, RuntimeError>>,
50    F: FnMut(Value, Value) -> FFut,
51    FFut: Future<Output = Result<Value, RuntimeError>>,
52{
53    let (a, b) = pop2(stack)?;
54    let result = match (&a, &b) {
55        (Value::Object(obj), _) => {
56            match call_method(Value::Object(obj.clone()), "minus", b.clone()).await {
57                Ok(v) => v,
58                Err(_) => fallback(a.clone(), b.clone()).await?,
59            }
60        }
61        (_, Value::Object(obj)) => {
62            match right_method(Value::Object(obj.clone()), a.clone()).await {
63                Ok(v) => v,
64                Err(_) => fallback(a.clone(), b.clone()).await?,
65            }
66        }
67        _ => fallback(a.clone(), b.clone()).await?,
68    };
69    stack.push(result);
70    Ok(())
71}
72
73pub async fn mul<CM, CMFut, F, FFut>(
74    stack: &mut Vec<Value>,
75    mut call_method: CM,
76    mut fallback: F,
77) -> Result<(), RuntimeError>
78where
79    CM: FnMut(Value, &'static str, Value) -> CMFut,
80    CMFut: Future<Output = Result<Value, RuntimeError>>,
81    F: FnMut(Value, Value) -> FFut,
82    FFut: Future<Output = Result<Value, RuntimeError>>,
83{
84    let (a, b) = pop2(stack)?;
85    let result = match (&a, &b) {
86        (Value::Object(obj), _) => {
87            match call_method(Value::Object(obj.clone()), "mtimes", b.clone()).await {
88                Ok(v) => v,
89                Err(_) => fallback(a.clone(), b.clone()).await?,
90            }
91        }
92        (_, Value::Object(obj)) => {
93            match call_method(Value::Object(obj.clone()), "mtimes", a.clone()).await {
94                Ok(v) => v,
95                Err(_) => fallback(a.clone(), b.clone()).await?,
96            }
97        }
98        _ => fallback(a.clone(), b.clone()).await?,
99    };
100    stack.push(result);
101    Ok(())
102}
103
104pub async fn binary_method<CM, CMFut, F, FFut>(
105    stack: &mut Vec<Value>,
106    method: &'static str,
107    mut call_method: CM,
108    mut fallback: F,
109) -> Result<(), RuntimeError>
110where
111    CM: FnMut(Value, &'static str, Value) -> CMFut,
112    CMFut: Future<Output = Result<Value, RuntimeError>>,
113    F: FnMut(Value, Value) -> FFut,
114    FFut: Future<Output = Result<Value, RuntimeError>>,
115{
116    let (a, b) = pop2(stack)?;
117    let result = match (&a, &b) {
118        (Value::Object(obj), _) => {
119            match call_method(Value::Object(obj.clone()), method, b.clone()).await {
120                Ok(v) => v,
121                Err(_) => fallback(a.clone(), b.clone()).await?,
122            }
123        }
124        (_, Value::Object(obj)) => {
125            match call_method(Value::Object(obj.clone()), method, a.clone()).await {
126                Ok(v) => v,
127                Err(_) => fallback(a.clone(), b.clone()).await?,
128            }
129        }
130        _ => fallback(a.clone(), b.clone()).await?,
131    };
132    stack.push(result);
133    Ok(())
134}
135
136pub async fn binary_fallback<F, FFut>(
137    stack: &mut Vec<Value>,
138    mut fallback: F,
139) -> Result<(), RuntimeError>
140where
141    F: FnMut(Value, Value) -> FFut,
142    FFut: Future<Output = Result<Value, RuntimeError>>,
143{
144    let (a, b) = pop2(stack)?;
145    stack.push(fallback(a, b).await?);
146    Ok(())
147}
148
149pub async fn power<CM, CMFut, F, FFut>(
150    stack: &mut Vec<Value>,
151    mut call_method: CM,
152    mut fallback: F,
153) -> Result<(), RuntimeError>
154where
155    CM: FnMut(Value, &'static str, Value) -> CMFut,
156    CMFut: Future<Output = Result<Value, RuntimeError>>,
157    F: FnMut(Value, Value) -> FFut,
158    FFut: Future<Output = Result<Value, RuntimeError>>,
159{
160    let (a, b) = pop2(stack)?;
161    let result = match (&a, &b) {
162        (Value::Object(obj), _) => {
163            match call_method(Value::Object(obj.clone()), "power", b.clone()).await {
164                Ok(v) => v,
165                Err(_) => fallback(a.clone(), b.clone()).await?,
166            }
167        }
168        (_, Value::Object(obj)) => {
169            match call_method(Value::Object(obj.clone()), "power", a.clone()).await {
170                Ok(v) => v,
171                Err(_) => fallback(a.clone(), b.clone()).await?,
172            }
173        }
174        _ => fallback(a.clone(), b.clone()).await?,
175    };
176    stack.push(result);
177    Ok(())
178}
179
180pub async fn unary<UF, UFut>(stack: &mut Vec<Value>, mut op: UF) -> Result<(), RuntimeError>
181where
182    UF: FnMut(Value) -> UFut,
183    UFut: Future<Output = Result<Value, RuntimeError>>,
184{
185    let value = stack
186        .pop()
187        .ok_or(mex("StackUnderflow", "stack underflow"))?;
188    stack.push(op(value).await?);
189    Ok(())
190}
191
192pub fn is_scalarish_for_division(value: &Value) -> bool {
193    match value {
194        Value::Int(_) | Value::Num(_) | Value::Complex(_, _) | Value::Bool(_) => true,
195        Value::LogicalArray(arr) => is_scalar_shape(&arr.shape),
196        Value::Tensor(tensor) => is_scalar_shape(&tensor.shape),
197        Value::ComplexTensor(tensor) => is_scalar_shape(&tensor.shape),
198        Value::GpuTensor(handle) => is_scalar_shape(&handle.shape),
199        _ => false,
200    }
201}
202
203pub async fn execute_right_division<CM, CMFut, SF, SFFut, MF, MFFut>(
204    lhs: &Value,
205    rhs: &Value,
206    mut call_method: CM,
207    mut scalarish_fallback: SF,
208    mut matrix_fallback: MF,
209) -> Result<Value, RuntimeError>
210where
211    CM: FnMut(Value, &'static str, Value) -> CMFut,
212    CMFut: Future<Output = Result<Value, RuntimeError>>,
213    SF: FnMut(Value, Value) -> SFFut,
214    SFFut: Future<Output = Result<Value, RuntimeError>>,
215    MF: FnMut(Value, Value) -> MFFut,
216    MFFut: Future<Output = Result<Value, RuntimeError>>,
217{
218    match (lhs, rhs) {
219        (Value::Object(obj), _) => {
220            match call_method(Value::Object(obj.clone()), "mrdivide", rhs.clone()).await {
221                Ok(v) => Ok(v),
222                Err(_) => {
223                    if is_scalarish_for_division(rhs) {
224                        scalarish_fallback(lhs.clone(), rhs.clone()).await
225                    } else {
226                        matrix_fallback(lhs.clone(), rhs.clone()).await
227                    }
228                }
229            }
230        }
231        (_, Value::Object(obj)) => {
232            match call_method(Value::Object(obj.clone()), "mrdivide", lhs.clone()).await {
233                Ok(v) => Ok(v),
234                Err(_) => {
235                    if is_scalarish_for_division(rhs) {
236                        scalarish_fallback(lhs.clone(), rhs.clone()).await
237                    } else {
238                        matrix_fallback(lhs.clone(), rhs.clone()).await
239                    }
240                }
241            }
242        }
243        _ => {
244            if is_scalarish_for_division(rhs) {
245                scalarish_fallback(lhs.clone(), rhs.clone()).await
246            } else {
247                matrix_fallback(lhs.clone(), rhs.clone()).await
248            }
249        }
250    }
251}
252
253pub async fn execute_left_division<CM, CMFut, SF, SFFut, MF, MFFut>(
254    lhs: &Value,
255    rhs: &Value,
256    mut call_method: CM,
257    mut scalarish_fallback: SF,
258    mut matrix_fallback: MF,
259) -> Result<Value, RuntimeError>
260where
261    CM: FnMut(Value, &'static str, Value) -> CMFut,
262    CMFut: Future<Output = Result<Value, RuntimeError>>,
263    SF: FnMut(Value, Value) -> SFFut,
264    SFFut: Future<Output = Result<Value, RuntimeError>>,
265    MF: FnMut(Value, Value) -> MFFut,
266    MFFut: Future<Output = Result<Value, RuntimeError>>,
267{
268    match (lhs, rhs) {
269        (Value::Object(obj), _) => {
270            match call_method(Value::Object(obj.clone()), "mldivide", rhs.clone()).await {
271                Ok(v) => Ok(v),
272                Err(_) => {
273                    if is_scalarish_for_division(lhs) {
274                        scalarish_fallback(lhs.clone(), rhs.clone()).await
275                    } else {
276                        matrix_fallback(lhs.clone(), rhs.clone()).await
277                    }
278                }
279            }
280        }
281        (_, Value::Object(obj)) => {
282            match call_method(Value::Object(obj.clone()), "mldivide", lhs.clone()).await {
283                Ok(v) => Ok(v),
284                Err(_) => {
285                    if is_scalarish_for_division(lhs) {
286                        scalarish_fallback(lhs.clone(), rhs.clone()).await
287                    } else {
288                        matrix_fallback(lhs.clone(), rhs.clone()).await
289                    }
290                }
291            }
292        }
293        _ => {
294            if is_scalarish_for_division(lhs) {
295                scalarish_fallback(lhs.clone(), rhs.clone()).await
296            } else {
297                matrix_fallback(lhs.clone(), rhs.clone()).await
298            }
299        }
300    }
301}