Skip to main content

runmat_vm/ops/
comparison.rs

1use crate::interpreter::stack::pop2;
2use runmat_builtins::Value;
3use runmat_runtime::builtins::common::shape::is_scalar_shape;
4use runmat_runtime::RuntimeError;
5use std::future::Future;
6
7fn rel_binary_use_builtin(a: &Value, b: &Value) -> bool {
8    !matches!(a, Value::Num(_) | Value::Int(_)) || !matches!(b, Value::Num(_) | Value::Int(_))
9}
10
/// Method names and scalar fallback used by [`relation_inverted`] for one relational operator.
///
/// For object operands the VM tries a direct overload method first, then an inverse method
/// whose result is negated, and only uses `predicate` when both calls fail. For two plain
/// numeric scalars `predicate` is applied directly.
///
/// Every field is a `&'static str` or a function pointer, so the spec is cheap to copy.
#[derive(Clone, Copy, Debug)]
pub struct RelationInvertedSpec {
    /// Method tried first on a left-hand object (called with the right operand). Also the
    /// builtin name used when neither operand is an object and the pair is not two plain
    /// numeric scalars.
    pub name: &'static str,
    /// Method tried on a left-hand object when `name` fails. The logical truth of its result
    /// is negated.
    pub inverse_name: &'static str,
    /// Method tried first on a right-hand object (called with the left operand).
    pub right_name: &'static str,
    /// Method tried on a right-hand object when `right_name` fails. The logical truth of its
    /// result is negated.
    pub right_inverse_name: &'static str,
    /// Scalar fallback applied as `predicate(left, right)` after both operands are converted
    /// to `f64`.
    pub predicate: fn(f64, f64) -> bool,
}
18
19pub async fn relation<CM, CMFut, B, BFut>(
20    stack: &mut Vec<Value>,
21    name: &'static str,
22    reverse_name: &'static str,
23    predicate: fn(f64, f64) -> bool,
24    mut call_method: CM,
25    mut call_builtin: B,
26) -> Result<(), RuntimeError>
27where
28    CM: FnMut(Value, &'static str, Value) -> CMFut,
29    CMFut: Future<Output = Result<Value, RuntimeError>>,
30    B: FnMut(&'static str, Value, Value) -> BFut,
31    BFut: Future<Output = Result<Value, RuntimeError>>,
32{
33    let (a, b) = pop2(stack)?;
34    let result = match (&a, &b) {
35        (Value::Object(obj), _) => {
36            match call_method(Value::Object(obj.clone()), name, b.clone()).await {
37                Ok(v) => v,
38                Err(_) => Value::Num(if predicate((&a).try_into()?, (&b).try_into()?) {
39                    1.0
40                } else {
41                    0.0
42                }),
43            }
44        }
45        (_, Value::Object(obj)) => {
46            match call_method(Value::Object(obj.clone()), reverse_name, a.clone()).await {
47                Ok(v) => v,
48                Err(_) => Value::Num(if predicate((&a).try_into()?, (&b).try_into()?) {
49                    1.0
50                } else {
51                    0.0
52                }),
53            }
54        }
55        _ => {
56            if rel_binary_use_builtin(&a, &b) {
57                call_builtin(name, a.clone(), b.clone()).await?
58            } else {
59                Value::Num(if predicate((&a).try_into()?, (&b).try_into()?) {
60                    1.0
61                } else {
62                    0.0
63                })
64            }
65        }
66    };
67    stack.push(result);
68    Ok(())
69}
70
71pub async fn relation_inverted<CM, CMFut, B, BFut, LT, LTFut>(
72    stack: &mut Vec<Value>,
73    spec: RelationInvertedSpec,
74    mut call_method: CM,
75    mut call_builtin: B,
76    mut logical_truth: LT,
77) -> Result<(), RuntimeError>
78where
79    CM: FnMut(Value, &'static str, Value) -> CMFut,
80    CMFut: Future<Output = Result<Value, RuntimeError>>,
81    B: FnMut(&'static str, Value, Value) -> BFut,
82    BFut: Future<Output = Result<Value, RuntimeError>>,
83    LT: FnMut(Value, String) -> LTFut,
84    LTFut: Future<Output = Result<bool, RuntimeError>>,
85{
86    let (a, b) = pop2(stack)?;
87    let result = match (&a, &b) {
88        (Value::Object(obj), _) => {
89            match call_method(Value::Object(obj.clone()), spec.name, b.clone()).await {
90                Ok(v) => v,
91                Err(_) => {
92                    match call_method(Value::Object(obj.clone()), spec.inverse_name, b.clone())
93                        .await
94                    {
95                        Ok(v) => Value::Num(
96                            if !logical_truth(v, "comparison result".to_string()).await? {
97                                1.0
98                            } else {
99                                0.0
100                            },
101                        ),
102                        Err(_) => {
103                            Value::Num(if (spec.predicate)((&a).try_into()?, (&b).try_into()?) {
104                                1.0
105                            } else {
106                                0.0
107                            })
108                        }
109                    }
110                }
111            }
112        }
113        (_, Value::Object(obj)) => {
114            match call_method(Value::Object(obj.clone()), spec.right_name, a.clone()).await {
115                Ok(v) => v,
116                Err(_) => {
117                    match call_method(
118                        Value::Object(obj.clone()),
119                        spec.right_inverse_name,
120                        a.clone(),
121                    )
122                    .await
123                    {
124                        Ok(v) => Value::Num(
125                            if !logical_truth(v, "comparison result".to_string()).await? {
126                                1.0
127                            } else {
128                                0.0
129                            },
130                        ),
131                        Err(_) => {
132                            Value::Num(if (spec.predicate)((&a).try_into()?, (&b).try_into()?) {
133                                1.0
134                            } else {
135                                0.0
136                            })
137                        }
138                    }
139                }
140            }
141        }
142        _ => {
143            if rel_binary_use_builtin(&a, &b) {
144                call_builtin(spec.name, a.clone(), b.clone()).await?
145            } else {
146                Value::Num(if (spec.predicate)((&a).try_into()?, (&b).try_into()?) {
147                    1.0
148                } else {
149                    0.0
150                })
151            }
152        }
153    };
154    stack.push(result);
155    Ok(())
156}
157
158pub async fn equal<CM, CMFut, B, BFut, LT, LTFut>(
159    stack: &mut Vec<Value>,
160    mut call_method: CM,
161    mut call_builtin: B,
162    _logical_truth: LT,
163) -> Result<(), RuntimeError>
164where
165    CM: FnMut(Value, &'static str, Value) -> CMFut,
166    CMFut: Future<Output = Result<Value, RuntimeError>>,
167    B: FnMut(&'static str, Value, Value) -> BFut,
168    BFut: Future<Output = Result<Value, RuntimeError>>,
169    LT: FnMut(Value, String) -> LTFut,
170    LTFut: Future<Output = Result<bool, RuntimeError>>,
171{
172    let (a, b) = pop2(stack)?;
173    let push_logical =
174        |data: Vec<u8>, shape: Vec<usize>, stack: &mut Vec<Value>| -> Result<(), RuntimeError> {
175            if data.len() == 1 && is_scalar_shape(&shape) {
176                stack.push(Value::Bool(data[0] != 0));
177                return Ok(());
178            }
179            let logical =
180                runmat_builtins::LogicalArray::new(data, shape).map_err(|e| format!("eq: {e}"))?;
181            stack.push(Value::LogicalArray(logical));
182            Ok(())
183        };
184    let logical_eq_scalar = |array: &runmat_builtins::LogicalArray,
185                             scalar: f64,
186                             stack: &mut Vec<Value>|
187     -> Result<(), RuntimeError> {
188        let mut out = Vec::with_capacity(array.data.len());
189        for &bit in &array.data {
190            let val = if bit != 0 { 1.0 } else { 0.0 };
191            out.push(if (val - scalar).abs() < 1e-12 { 1 } else { 0 });
192        }
193        push_logical(out, array.shape.clone(), stack)
194    };
195    let logical_eq_tensor = |array: &runmat_builtins::LogicalArray,
196                             tensor: &runmat_builtins::Tensor,
197                             stack: &mut Vec<Value>|
198     -> Result<(), RuntimeError> {
199        if array.shape != tensor.shape {
200            return Err(crate::interpreter::errors::mex(
201                "ShapeMismatch",
202                "shape mismatch for element-wise comparison",
203            ));
204        }
205        let mut out = Vec::with_capacity(array.data.len());
206        for i in 0..array.data.len() {
207            let val = if array.data[i] != 0 { 1.0 } else { 0.0 };
208            out.push(if (val - tensor.data[i]).abs() < 1e-12 {
209                1
210            } else {
211                0
212            });
213        }
214        push_logical(out, array.shape.clone(), stack)
215    };
216    match (&a, &b) {
217        (Value::Object(obj), _) => {
218            match call_method(Value::Object(obj.clone()), "eq", b.clone()).await {
219                Ok(v) => stack.push(v),
220                Err(_) => {
221                    let aa: f64 = (&a).try_into()?;
222                    let bb: f64 = (&b).try_into()?;
223                    stack.push(Value::Num(if aa == bb { 1.0 } else { 0.0 }))
224                }
225            }
226        }
227        (_, Value::Object(obj)) => {
228            match call_method(Value::Object(obj.clone()), "eq", a.clone()).await {
229                Ok(v) => stack.push(v),
230                Err(_) => {
231                    let aa: f64 = (&a).try_into()?;
232                    let bb: f64 = (&b).try_into()?;
233                    stack.push(Value::Num(if aa == bb { 1.0 } else { 0.0 }))
234                }
235            }
236        }
237        (Value::HandleObject(_), _) | (_, Value::HandleObject(_)) => {
238            stack.push(call_builtin("eq", a.clone(), b.clone()).await?);
239        }
240        (Value::LogicalArray(la), Value::LogicalArray(lb)) => {
241            if la.shape != lb.shape {
242                return Err(crate::interpreter::errors::mex(
243                    "ShapeMismatch",
244                    "shape mismatch for element-wise comparison",
245                ));
246            }
247            let mut out = Vec::with_capacity(la.data.len());
248            for i in 0..la.data.len() {
249                out.push(if la.data[i] == lb.data[i] { 1 } else { 0 });
250            }
251            push_logical(out, la.shape.clone(), stack)?;
252        }
253        (Value::LogicalArray(la), Value::Num(n)) => logical_eq_scalar(la, *n, stack)?,
254        (Value::LogicalArray(la), Value::Int(i)) => logical_eq_scalar(la, i.to_f64(), stack)?,
255        (Value::LogicalArray(la), Value::Bool(flag)) => {
256            logical_eq_scalar(la, if *flag { 1.0 } else { 0.0 }, stack)?
257        }
258        (Value::Num(n), Value::LogicalArray(lb)) => logical_eq_scalar(lb, *n, stack)?,
259        (Value::Int(i), Value::LogicalArray(lb)) => logical_eq_scalar(lb, i.to_f64(), stack)?,
260        (Value::Bool(flag), Value::LogicalArray(lb)) => {
261            logical_eq_scalar(lb, if *flag { 1.0 } else { 0.0 }, stack)?
262        }
263        (Value::LogicalArray(la), Value::Tensor(tb)) => logical_eq_tensor(la, tb, stack)?,
264        (Value::Tensor(ta), Value::LogicalArray(lb)) => logical_eq_tensor(lb, ta, stack)?,
265        (Value::Tensor(ta), Value::Tensor(tb)) => {
266            if ta.shape != tb.shape {
267                return Err(crate::interpreter::errors::mex(
268                    "ShapeMismatch",
269                    "shape mismatch for element-wise comparison",
270                ));
271            }
272            let mut out = Vec::with_capacity(ta.data.len());
273            for i in 0..ta.data.len() {
274                out.push(if (ta.data[i] - tb.data[i]).abs() < 1e-12 {
275                    1.0
276                } else {
277                    0.0
278                });
279            }
280            stack.push(Value::Tensor(
281                runmat_builtins::Tensor::new(out, ta.shape.clone())
282                    .map_err(|e| format!("eq: {e}"))?,
283            ));
284        }
285        (Value::Tensor(t), Value::Num(_)) | (Value::Tensor(t), Value::Int(_)) => {
286            let s = match &b {
287                Value::Num(n) => *n,
288                Value::Int(i) => i.to_f64(),
289                _ => 0.0,
290            };
291            let out: Vec<f64> = t
292                .data
293                .iter()
294                .map(|x| if (*x - s).abs() < 1e-12 { 1.0 } else { 0.0 })
295                .collect();
296            stack.push(Value::Tensor(
297                runmat_builtins::Tensor::new(out, t.shape.clone())
298                    .map_err(|e| format!("eq: {e}"))?,
299            ));
300        }
301        (Value::Num(_), Value::Tensor(t)) | (Value::Int(_), Value::Tensor(t)) => {
302            let s = match &a {
303                Value::Num(n) => *n,
304                Value::Int(i) => i.to_f64(),
305                _ => 0.0,
306            };
307            let out: Vec<f64> = t
308                .data
309                .iter()
310                .map(|x| if (s - *x).abs() < 1e-12 { 1.0 } else { 0.0 })
311                .collect();
312            stack.push(Value::Tensor(
313                runmat_builtins::Tensor::new(out, t.shape.clone())
314                    .map_err(|e| format!("eq: {e}"))?,
315            ));
316        }
317        (Value::StringArray(sa), Value::StringArray(sb)) => {
318            if sa.shape != sb.shape {
319                return Err(crate::interpreter::errors::mex(
320                    "ShapeMismatch",
321                    "shape mismatch for string array comparison",
322                ));
323            }
324            let mut out = Vec::with_capacity(sa.data.len());
325            for i in 0..sa.data.len() {
326                out.push(if sa.data[i] == sb.data[i] { 1.0 } else { 0.0 });
327            }
328            stack.push(Value::Tensor(
329                runmat_builtins::Tensor::new(out, sa.shape.clone())
330                    .map_err(|e| format!("eq: {e}"))?,
331            ));
332        }
333        (Value::StringArray(sa), Value::String(s)) => {
334            let mut out = Vec::with_capacity(sa.data.len());
335            for i in 0..sa.data.len() {
336                out.push(if sa.data[i] == *s { 1.0 } else { 0.0 });
337            }
338            stack.push(Value::Tensor(
339                runmat_builtins::Tensor::new(out, sa.shape.clone())
340                    .map_err(|e| format!("eq: {e}"))?,
341            ));
342        }
343        (Value::String(s), Value::StringArray(sa)) => {
344            let mut out = Vec::with_capacity(sa.data.len());
345            for i in 0..sa.data.len() {
346                out.push(if *s == sa.data[i] { 1.0 } else { 0.0 });
347            }
348            stack.push(Value::Tensor(
349                runmat_builtins::Tensor::new(out, sa.shape.clone())
350                    .map_err(|e| format!("eq: {e}"))?,
351            ));
352        }
353        (Value::String(a_s), Value::String(b_s)) => {
354            stack.push(Value::Num(if a_s == b_s { 1.0 } else { 0.0 }))
355        }
356        _ => {
357            let bb: f64 = (&b).try_into()?;
358            let aa: f64 = (&a).try_into()?;
359            stack.push(Value::Num(if aa == bb { 1.0 } else { 0.0 }));
360        }
361    }
362    Ok(())
363}
364
365pub async fn not_equal<CM, CMFut, B, BFut, LT, LTFut>(
366    stack: &mut Vec<Value>,
367    mut call_method: CM,
368    mut call_builtin: B,
369    mut logical_truth: LT,
370) -> Result<(), RuntimeError>
371where
372    CM: FnMut(Value, &'static str, Value) -> CMFut,
373    CMFut: Future<Output = Result<Value, RuntimeError>>,
374    B: FnMut(&'static str, Value, Value) -> BFut,
375    BFut: Future<Output = Result<Value, RuntimeError>>,
376    LT: FnMut(Value, String) -> LTFut,
377    LTFut: Future<Output = Result<bool, RuntimeError>>,
378{
379    let (a, b) = pop2(stack)?;
380    match (&a, &b) {
381        (Value::Object(obj), _) => {
382            match call_method(Value::Object(obj.clone()), "ne", b.clone()).await {
383                Ok(v) => stack.push(v),
384                Err(_) => match call_method(Value::Object(obj.clone()), "eq", b.clone()).await {
385                    Ok(v) => stack.push(Value::Num(
386                        if !logical_truth(v, "comparison result".to_string()).await? {
387                            1.0
388                        } else {
389                            0.0
390                        },
391                    )),
392                    Err(_) => {
393                        let aa: f64 = (&a).try_into()?;
394                        let bb: f64 = (&b).try_into()?;
395                        stack.push(Value::Num(if aa != bb { 1.0 } else { 0.0 }));
396                    }
397                },
398            }
399        }
400        (_, Value::Object(obj)) => {
401            match call_method(Value::Object(obj.clone()), "ne", a.clone()).await {
402                Ok(v) => stack.push(v),
403                Err(_) => match call_method(Value::Object(obj.clone()), "eq", a.clone()).await {
404                    Ok(v) => stack.push(Value::Num(
405                        if !logical_truth(v, "comparison result".to_string()).await? {
406                            1.0
407                        } else {
408                            0.0
409                        },
410                    )),
411                    Err(_) => {
412                        let aa: f64 = (&a).try_into()?;
413                        let bb: f64 = (&b).try_into()?;
414                        stack.push(Value::Num(if aa != bb { 1.0 } else { 0.0 }));
415                    }
416                },
417            }
418        }
419        (Value::HandleObject(_), _) | (_, Value::HandleObject(_)) => {
420            stack.push(call_builtin("ne", a.clone(), b.clone()).await?)
421        }
422        (Value::Tensor(ta), Value::Tensor(tb)) => {
423            if ta.shape != tb.shape {
424                return Err(crate::interpreter::errors::mex(
425                    "ShapeMismatch",
426                    "shape mismatch for element-wise comparison",
427                ));
428            }
429            let mut out = Vec::with_capacity(ta.data.len());
430            for i in 0..ta.data.len() {
431                out.push(if (ta.data[i] - tb.data[i]).abs() >= 1e-12 {
432                    1.0
433                } else {
434                    0.0
435                });
436            }
437            stack.push(Value::Tensor(
438                runmat_builtins::Tensor::new(out, ta.shape.clone())
439                    .map_err(|e| format!("ne: {e}"))?,
440            ));
441        }
442        (Value::Tensor(t), Value::Num(_)) | (Value::Tensor(t), Value::Int(_)) => {
443            let s = match &b {
444                Value::Num(n) => *n,
445                Value::Int(i) => i.to_f64(),
446                _ => 0.0,
447            };
448            let out: Vec<f64> = t
449                .data
450                .iter()
451                .map(|x| if (*x - s).abs() >= 1e-12 { 1.0 } else { 0.0 })
452                .collect();
453            stack.push(Value::Tensor(
454                runmat_builtins::Tensor::new(out, t.shape.clone())
455                    .map_err(|e| format!("ne: {e}"))?,
456            ));
457        }
458        (Value::Num(_), Value::Tensor(t)) | (Value::Int(_), Value::Tensor(t)) => {
459            let s = match &a {
460                Value::Num(n) => *n,
461                Value::Int(i) => i.to_f64(),
462                _ => 0.0,
463            };
464            let out: Vec<f64> = t
465                .data
466                .iter()
467                .map(|x| if (s - *x).abs() >= 1e-12 { 1.0 } else { 0.0 })
468                .collect();
469            stack.push(Value::Tensor(
470                runmat_builtins::Tensor::new(out, t.shape.clone())
471                    .map_err(|e| format!("ne: {e}"))?,
472            ));
473        }
474        (Value::StringArray(sa), Value::StringArray(sb)) => {
475            if sa.shape != sb.shape {
476                return Err(crate::interpreter::errors::mex(
477                    "ShapeMismatch",
478                    "shape mismatch for string array comparison",
479                ));
480            }
481            let mut out = Vec::with_capacity(sa.data.len());
482            for i in 0..sa.data.len() {
483                out.push(if sa.data[i] != sb.data[i] { 1.0 } else { 0.0 });
484            }
485            stack.push(Value::Tensor(
486                runmat_builtins::Tensor::new(out, sa.shape.clone())
487                    .map_err(|e| format!("ne: {e}"))?,
488            ));
489        }
490        (Value::StringArray(sa), Value::String(s)) => {
491            let mut out = Vec::with_capacity(sa.data.len());
492            for i in 0..sa.data.len() {
493                out.push(if sa.data[i] != *s { 1.0 } else { 0.0 });
494            }
495            stack.push(Value::Tensor(
496                runmat_builtins::Tensor::new(out, sa.shape.clone())
497                    .map_err(|e| format!("ne: {e}"))?,
498            ));
499        }
500        (Value::String(s), Value::StringArray(sa)) => {
501            let mut out = Vec::with_capacity(sa.data.len());
502            for i in 0..sa.data.len() {
503                out.push(if *s != sa.data[i] { 1.0 } else { 0.0 });
504            }
505            stack.push(Value::Tensor(
506                runmat_builtins::Tensor::new(out, sa.shape.clone())
507                    .map_err(|e| format!("ne: {e}"))?,
508            ));
509        }
510        (Value::String(a_s), Value::String(b_s)) => {
511            stack.push(Value::Num(if a_s != b_s { 1.0 } else { 0.0 }))
512        }
513        _ => {
514            let bb: f64 = (&b).try_into()?;
515            let aa: f64 = (&a).try_into()?;
516            stack.push(Value::Num(if aa != bb { 1.0 } else { 0.0 }));
517        }
518    }
519    Ok(())
520}