1use crate::interpreter::stack::pop2;
2use runmat_builtins::Value;
3use runmat_runtime::builtins::common::shape::is_scalar_shape;
4use runmat_runtime::RuntimeError;
5use std::future::Future;
6
7fn rel_binary_use_builtin(a: &Value, b: &Value) -> bool {
8 !matches!(a, Value::Num(_) | Value::Int(_)) || !matches!(b, Value::Num(_) | Value::Int(_))
9}
10
/// Method names and numeric fallback consumed by [`relation_inverted`].
///
/// For a left-hand object operand, `relation_inverted` calls `name`. If that fails, it
/// calls `inverse_name` and negates the logical result. A right-hand object operand is
/// handled the same way through `right_name` / `right_inverse_name`, with the left operand
/// passed as the argument. `predicate` is the numeric comparison used when no method call
/// succeeds, or when both operands are plain scalars. `name` is also the builtin invoked
/// for other non-object operands.
///
/// The struct only holds `&'static str` and a `fn` pointer, so it is cheap to copy.
#[derive(Clone, Copy, Debug)]
pub struct RelationInvertedSpec {
    /// Method tried first on a left-hand object; also the builtin name for non-object operands.
    pub name: &'static str,
    /// Method tried on a left-hand object when `name` fails; its truth value is negated.
    pub inverse_name: &'static str,
    /// Method tried first on a right-hand object, with the left operand as its argument.
    pub right_name: &'static str,
    /// Method tried on a right-hand object when `right_name` fails; its truth value is negated.
    pub right_inverse_name: &'static str,
    /// Numeric comparison applied to both operands after conversion to `f64`.
    pub predicate: fn(f64, f64) -> bool,
}
18
19pub async fn relation<CM, CMFut, B, BFut>(
20 stack: &mut Vec<Value>,
21 name: &'static str,
22 reverse_name: &'static str,
23 predicate: fn(f64, f64) -> bool,
24 mut call_method: CM,
25 mut call_builtin: B,
26) -> Result<(), RuntimeError>
27where
28 CM: FnMut(Value, &'static str, Value) -> CMFut,
29 CMFut: Future<Output = Result<Value, RuntimeError>>,
30 B: FnMut(&'static str, Value, Value) -> BFut,
31 BFut: Future<Output = Result<Value, RuntimeError>>,
32{
33 let (a, b) = pop2(stack)?;
34 let result = match (&a, &b) {
35 (Value::Object(obj), _) => {
36 match call_method(Value::Object(obj.clone()), name, b.clone()).await {
37 Ok(v) => v,
38 Err(_) => Value::Num(if predicate((&a).try_into()?, (&b).try_into()?) {
39 1.0
40 } else {
41 0.0
42 }),
43 }
44 }
45 (_, Value::Object(obj)) => {
46 match call_method(Value::Object(obj.clone()), reverse_name, a.clone()).await {
47 Ok(v) => v,
48 Err(_) => Value::Num(if predicate((&a).try_into()?, (&b).try_into()?) {
49 1.0
50 } else {
51 0.0
52 }),
53 }
54 }
55 _ => {
56 if rel_binary_use_builtin(&a, &b) {
57 call_builtin(name, a.clone(), b.clone()).await?
58 } else {
59 Value::Num(if predicate((&a).try_into()?, (&b).try_into()?) {
60 1.0
61 } else {
62 0.0
63 })
64 }
65 }
66 };
67 stack.push(result);
68 Ok(())
69}
70
71pub async fn relation_inverted<CM, CMFut, B, BFut, LT, LTFut>(
72 stack: &mut Vec<Value>,
73 spec: RelationInvertedSpec,
74 mut call_method: CM,
75 mut call_builtin: B,
76 mut logical_truth: LT,
77) -> Result<(), RuntimeError>
78where
79 CM: FnMut(Value, &'static str, Value) -> CMFut,
80 CMFut: Future<Output = Result<Value, RuntimeError>>,
81 B: FnMut(&'static str, Value, Value) -> BFut,
82 BFut: Future<Output = Result<Value, RuntimeError>>,
83 LT: FnMut(Value, String) -> LTFut,
84 LTFut: Future<Output = Result<bool, RuntimeError>>,
85{
86 let (a, b) = pop2(stack)?;
87 let result = match (&a, &b) {
88 (Value::Object(obj), _) => {
89 match call_method(Value::Object(obj.clone()), spec.name, b.clone()).await {
90 Ok(v) => v,
91 Err(_) => {
92 match call_method(Value::Object(obj.clone()), spec.inverse_name, b.clone())
93 .await
94 {
95 Ok(v) => Value::Num(
96 if !logical_truth(v, "comparison result".to_string()).await? {
97 1.0
98 } else {
99 0.0
100 },
101 ),
102 Err(_) => {
103 Value::Num(if (spec.predicate)((&a).try_into()?, (&b).try_into()?) {
104 1.0
105 } else {
106 0.0
107 })
108 }
109 }
110 }
111 }
112 }
113 (_, Value::Object(obj)) => {
114 match call_method(Value::Object(obj.clone()), spec.right_name, a.clone()).await {
115 Ok(v) => v,
116 Err(_) => {
117 match call_method(
118 Value::Object(obj.clone()),
119 spec.right_inverse_name,
120 a.clone(),
121 )
122 .await
123 {
124 Ok(v) => Value::Num(
125 if !logical_truth(v, "comparison result".to_string()).await? {
126 1.0
127 } else {
128 0.0
129 },
130 ),
131 Err(_) => {
132 Value::Num(if (spec.predicate)((&a).try_into()?, (&b).try_into()?) {
133 1.0
134 } else {
135 0.0
136 })
137 }
138 }
139 }
140 }
141 }
142 _ => {
143 if rel_binary_use_builtin(&a, &b) {
144 call_builtin(spec.name, a.clone(), b.clone()).await?
145 } else {
146 Value::Num(if (spec.predicate)((&a).try_into()?, (&b).try_into()?) {
147 1.0
148 } else {
149 0.0
150 })
151 }
152 }
153 };
154 stack.push(result);
155 Ok(())
156}
157
158pub async fn equal<CM, CMFut, B, BFut, LT, LTFut>(
159 stack: &mut Vec<Value>,
160 mut call_method: CM,
161 mut call_builtin: B,
162 _logical_truth: LT,
163) -> Result<(), RuntimeError>
164where
165 CM: FnMut(Value, &'static str, Value) -> CMFut,
166 CMFut: Future<Output = Result<Value, RuntimeError>>,
167 B: FnMut(&'static str, Value, Value) -> BFut,
168 BFut: Future<Output = Result<Value, RuntimeError>>,
169 LT: FnMut(Value, String) -> LTFut,
170 LTFut: Future<Output = Result<bool, RuntimeError>>,
171{
172 let (a, b) = pop2(stack)?;
173 let push_logical =
174 |data: Vec<u8>, shape: Vec<usize>, stack: &mut Vec<Value>| -> Result<(), RuntimeError> {
175 if data.len() == 1 && is_scalar_shape(&shape) {
176 stack.push(Value::Bool(data[0] != 0));
177 return Ok(());
178 }
179 let logical =
180 runmat_builtins::LogicalArray::new(data, shape).map_err(|e| format!("eq: {e}"))?;
181 stack.push(Value::LogicalArray(logical));
182 Ok(())
183 };
184 let logical_eq_scalar = |array: &runmat_builtins::LogicalArray,
185 scalar: f64,
186 stack: &mut Vec<Value>|
187 -> Result<(), RuntimeError> {
188 let mut out = Vec::with_capacity(array.data.len());
189 for &bit in &array.data {
190 let val = if bit != 0 { 1.0 } else { 0.0 };
191 out.push(if (val - scalar).abs() < 1e-12 { 1 } else { 0 });
192 }
193 push_logical(out, array.shape.clone(), stack)
194 };
195 let logical_eq_tensor = |array: &runmat_builtins::LogicalArray,
196 tensor: &runmat_builtins::Tensor,
197 stack: &mut Vec<Value>|
198 -> Result<(), RuntimeError> {
199 if array.shape != tensor.shape {
200 return Err(crate::interpreter::errors::mex(
201 "ShapeMismatch",
202 "shape mismatch for element-wise comparison",
203 ));
204 }
205 let mut out = Vec::with_capacity(array.data.len());
206 for i in 0..array.data.len() {
207 let val = if array.data[i] != 0 { 1.0 } else { 0.0 };
208 out.push(if (val - tensor.data[i]).abs() < 1e-12 {
209 1
210 } else {
211 0
212 });
213 }
214 push_logical(out, array.shape.clone(), stack)
215 };
216 match (&a, &b) {
217 (Value::Object(obj), _) => {
218 match call_method(Value::Object(obj.clone()), "eq", b.clone()).await {
219 Ok(v) => stack.push(v),
220 Err(_) => {
221 let aa: f64 = (&a).try_into()?;
222 let bb: f64 = (&b).try_into()?;
223 stack.push(Value::Num(if aa == bb { 1.0 } else { 0.0 }))
224 }
225 }
226 }
227 (_, Value::Object(obj)) => {
228 match call_method(Value::Object(obj.clone()), "eq", a.clone()).await {
229 Ok(v) => stack.push(v),
230 Err(_) => {
231 let aa: f64 = (&a).try_into()?;
232 let bb: f64 = (&b).try_into()?;
233 stack.push(Value::Num(if aa == bb { 1.0 } else { 0.0 }))
234 }
235 }
236 }
237 (Value::HandleObject(_), _) | (_, Value::HandleObject(_)) => {
238 stack.push(call_builtin("eq", a.clone(), b.clone()).await?);
239 }
240 (Value::LogicalArray(la), Value::LogicalArray(lb)) => {
241 if la.shape != lb.shape {
242 return Err(crate::interpreter::errors::mex(
243 "ShapeMismatch",
244 "shape mismatch for element-wise comparison",
245 ));
246 }
247 let mut out = Vec::with_capacity(la.data.len());
248 for i in 0..la.data.len() {
249 out.push(if la.data[i] == lb.data[i] { 1 } else { 0 });
250 }
251 push_logical(out, la.shape.clone(), stack)?;
252 }
253 (Value::LogicalArray(la), Value::Num(n)) => logical_eq_scalar(la, *n, stack)?,
254 (Value::LogicalArray(la), Value::Int(i)) => logical_eq_scalar(la, i.to_f64(), stack)?,
255 (Value::LogicalArray(la), Value::Bool(flag)) => {
256 logical_eq_scalar(la, if *flag { 1.0 } else { 0.0 }, stack)?
257 }
258 (Value::Num(n), Value::LogicalArray(lb)) => logical_eq_scalar(lb, *n, stack)?,
259 (Value::Int(i), Value::LogicalArray(lb)) => logical_eq_scalar(lb, i.to_f64(), stack)?,
260 (Value::Bool(flag), Value::LogicalArray(lb)) => {
261 logical_eq_scalar(lb, if *flag { 1.0 } else { 0.0 }, stack)?
262 }
263 (Value::LogicalArray(la), Value::Tensor(tb)) => logical_eq_tensor(la, tb, stack)?,
264 (Value::Tensor(ta), Value::LogicalArray(lb)) => logical_eq_tensor(lb, ta, stack)?,
265 (Value::Tensor(ta), Value::Tensor(tb)) => {
266 if ta.shape != tb.shape {
267 return Err(crate::interpreter::errors::mex(
268 "ShapeMismatch",
269 "shape mismatch for element-wise comparison",
270 ));
271 }
272 let mut out = Vec::with_capacity(ta.data.len());
273 for i in 0..ta.data.len() {
274 out.push(if (ta.data[i] - tb.data[i]).abs() < 1e-12 {
275 1.0
276 } else {
277 0.0
278 });
279 }
280 stack.push(Value::Tensor(
281 runmat_builtins::Tensor::new(out, ta.shape.clone())
282 .map_err(|e| format!("eq: {e}"))?,
283 ));
284 }
285 (Value::Tensor(t), Value::Num(_)) | (Value::Tensor(t), Value::Int(_)) => {
286 let s = match &b {
287 Value::Num(n) => *n,
288 Value::Int(i) => i.to_f64(),
289 _ => 0.0,
290 };
291 let out: Vec<f64> = t
292 .data
293 .iter()
294 .map(|x| if (*x - s).abs() < 1e-12 { 1.0 } else { 0.0 })
295 .collect();
296 stack.push(Value::Tensor(
297 runmat_builtins::Tensor::new(out, t.shape.clone())
298 .map_err(|e| format!("eq: {e}"))?,
299 ));
300 }
301 (Value::Num(_), Value::Tensor(t)) | (Value::Int(_), Value::Tensor(t)) => {
302 let s = match &a {
303 Value::Num(n) => *n,
304 Value::Int(i) => i.to_f64(),
305 _ => 0.0,
306 };
307 let out: Vec<f64> = t
308 .data
309 .iter()
310 .map(|x| if (s - *x).abs() < 1e-12 { 1.0 } else { 0.0 })
311 .collect();
312 stack.push(Value::Tensor(
313 runmat_builtins::Tensor::new(out, t.shape.clone())
314 .map_err(|e| format!("eq: {e}"))?,
315 ));
316 }
317 (Value::StringArray(sa), Value::StringArray(sb)) => {
318 if sa.shape != sb.shape {
319 return Err(crate::interpreter::errors::mex(
320 "ShapeMismatch",
321 "shape mismatch for string array comparison",
322 ));
323 }
324 let mut out = Vec::with_capacity(sa.data.len());
325 for i in 0..sa.data.len() {
326 out.push(if sa.data[i] == sb.data[i] { 1.0 } else { 0.0 });
327 }
328 stack.push(Value::Tensor(
329 runmat_builtins::Tensor::new(out, sa.shape.clone())
330 .map_err(|e| format!("eq: {e}"))?,
331 ));
332 }
333 (Value::StringArray(sa), Value::String(s)) => {
334 let mut out = Vec::with_capacity(sa.data.len());
335 for i in 0..sa.data.len() {
336 out.push(if sa.data[i] == *s { 1.0 } else { 0.0 });
337 }
338 stack.push(Value::Tensor(
339 runmat_builtins::Tensor::new(out, sa.shape.clone())
340 .map_err(|e| format!("eq: {e}"))?,
341 ));
342 }
343 (Value::String(s), Value::StringArray(sa)) => {
344 let mut out = Vec::with_capacity(sa.data.len());
345 for i in 0..sa.data.len() {
346 out.push(if *s == sa.data[i] { 1.0 } else { 0.0 });
347 }
348 stack.push(Value::Tensor(
349 runmat_builtins::Tensor::new(out, sa.shape.clone())
350 .map_err(|e| format!("eq: {e}"))?,
351 ));
352 }
353 (Value::String(a_s), Value::String(b_s)) => {
354 stack.push(Value::Num(if a_s == b_s { 1.0 } else { 0.0 }))
355 }
356 _ => {
357 let bb: f64 = (&b).try_into()?;
358 let aa: f64 = (&a).try_into()?;
359 stack.push(Value::Num(if aa == bb { 1.0 } else { 0.0 }));
360 }
361 }
362 Ok(())
363}
364
365pub async fn not_equal<CM, CMFut, B, BFut, LT, LTFut>(
366 stack: &mut Vec<Value>,
367 mut call_method: CM,
368 mut call_builtin: B,
369 mut logical_truth: LT,
370) -> Result<(), RuntimeError>
371where
372 CM: FnMut(Value, &'static str, Value) -> CMFut,
373 CMFut: Future<Output = Result<Value, RuntimeError>>,
374 B: FnMut(&'static str, Value, Value) -> BFut,
375 BFut: Future<Output = Result<Value, RuntimeError>>,
376 LT: FnMut(Value, String) -> LTFut,
377 LTFut: Future<Output = Result<bool, RuntimeError>>,
378{
379 let (a, b) = pop2(stack)?;
380 match (&a, &b) {
381 (Value::Object(obj), _) => {
382 match call_method(Value::Object(obj.clone()), "ne", b.clone()).await {
383 Ok(v) => stack.push(v),
384 Err(_) => match call_method(Value::Object(obj.clone()), "eq", b.clone()).await {
385 Ok(v) => stack.push(Value::Num(
386 if !logical_truth(v, "comparison result".to_string()).await? {
387 1.0
388 } else {
389 0.0
390 },
391 )),
392 Err(_) => {
393 let aa: f64 = (&a).try_into()?;
394 let bb: f64 = (&b).try_into()?;
395 stack.push(Value::Num(if aa != bb { 1.0 } else { 0.0 }));
396 }
397 },
398 }
399 }
400 (_, Value::Object(obj)) => {
401 match call_method(Value::Object(obj.clone()), "ne", a.clone()).await {
402 Ok(v) => stack.push(v),
403 Err(_) => match call_method(Value::Object(obj.clone()), "eq", a.clone()).await {
404 Ok(v) => stack.push(Value::Num(
405 if !logical_truth(v, "comparison result".to_string()).await? {
406 1.0
407 } else {
408 0.0
409 },
410 )),
411 Err(_) => {
412 let aa: f64 = (&a).try_into()?;
413 let bb: f64 = (&b).try_into()?;
414 stack.push(Value::Num(if aa != bb { 1.0 } else { 0.0 }));
415 }
416 },
417 }
418 }
419 (Value::HandleObject(_), _) | (_, Value::HandleObject(_)) => {
420 stack.push(call_builtin("ne", a.clone(), b.clone()).await?)
421 }
422 (Value::Tensor(ta), Value::Tensor(tb)) => {
423 if ta.shape != tb.shape {
424 return Err(crate::interpreter::errors::mex(
425 "ShapeMismatch",
426 "shape mismatch for element-wise comparison",
427 ));
428 }
429 let mut out = Vec::with_capacity(ta.data.len());
430 for i in 0..ta.data.len() {
431 out.push(if (ta.data[i] - tb.data[i]).abs() >= 1e-12 {
432 1.0
433 } else {
434 0.0
435 });
436 }
437 stack.push(Value::Tensor(
438 runmat_builtins::Tensor::new(out, ta.shape.clone())
439 .map_err(|e| format!("ne: {e}"))?,
440 ));
441 }
442 (Value::Tensor(t), Value::Num(_)) | (Value::Tensor(t), Value::Int(_)) => {
443 let s = match &b {
444 Value::Num(n) => *n,
445 Value::Int(i) => i.to_f64(),
446 _ => 0.0,
447 };
448 let out: Vec<f64> = t
449 .data
450 .iter()
451 .map(|x| if (*x - s).abs() >= 1e-12 { 1.0 } else { 0.0 })
452 .collect();
453 stack.push(Value::Tensor(
454 runmat_builtins::Tensor::new(out, t.shape.clone())
455 .map_err(|e| format!("ne: {e}"))?,
456 ));
457 }
458 (Value::Num(_), Value::Tensor(t)) | (Value::Int(_), Value::Tensor(t)) => {
459 let s = match &a {
460 Value::Num(n) => *n,
461 Value::Int(i) => i.to_f64(),
462 _ => 0.0,
463 };
464 let out: Vec<f64> = t
465 .data
466 .iter()
467 .map(|x| if (s - *x).abs() >= 1e-12 { 1.0 } else { 0.0 })
468 .collect();
469 stack.push(Value::Tensor(
470 runmat_builtins::Tensor::new(out, t.shape.clone())
471 .map_err(|e| format!("ne: {e}"))?,
472 ));
473 }
474 (Value::StringArray(sa), Value::StringArray(sb)) => {
475 if sa.shape != sb.shape {
476 return Err(crate::interpreter::errors::mex(
477 "ShapeMismatch",
478 "shape mismatch for string array comparison",
479 ));
480 }
481 let mut out = Vec::with_capacity(sa.data.len());
482 for i in 0..sa.data.len() {
483 out.push(if sa.data[i] != sb.data[i] { 1.0 } else { 0.0 });
484 }
485 stack.push(Value::Tensor(
486 runmat_builtins::Tensor::new(out, sa.shape.clone())
487 .map_err(|e| format!("ne: {e}"))?,
488 ));
489 }
490 (Value::StringArray(sa), Value::String(s)) => {
491 let mut out = Vec::with_capacity(sa.data.len());
492 for i in 0..sa.data.len() {
493 out.push(if sa.data[i] != *s { 1.0 } else { 0.0 });
494 }
495 stack.push(Value::Tensor(
496 runmat_builtins::Tensor::new(out, sa.shape.clone())
497 .map_err(|e| format!("ne: {e}"))?,
498 ));
499 }
500 (Value::String(s), Value::StringArray(sa)) => {
501 let mut out = Vec::with_capacity(sa.data.len());
502 for i in 0..sa.data.len() {
503 out.push(if *s != sa.data[i] { 1.0 } else { 0.0 });
504 }
505 stack.push(Value::Tensor(
506 runmat_builtins::Tensor::new(out, sa.shape.clone())
507 .map_err(|e| format!("ne: {e}"))?,
508 ));
509 }
510 (Value::String(a_s), Value::String(b_s)) => {
511 stack.push(Value::Num(if a_s != b_s { 1.0 } else { 0.0 }))
512 }
513 _ => {
514 let bb: f64 = (&b).try_into()?;
515 let aa: f64 = (&a).try_into()?;
516 stack.push(Value::Num(if aa != bb { 1.0 } else { 0.0 }));
517 }
518 }
519 Ok(())
520}