Skip to main content

runmat_vm/indexing/
end_expr.rs

1use crate::bytecode::{EndExpr, UserFunction};
2use crate::interpreter::errors::mex;
3use runmat_builtins::Value;
4use runmat_runtime::RuntimeError;
5use std::collections::HashMap;
6use std::future::Future;
7use std::pin::Pin;
8
/// Marker error returned by [`value_to_f64`] when a value cannot be read as a
/// real scalar.
///
/// It carries no payload; within this module it is discarded via `map_err` and
/// replaced by a more descriptive runtime error.
#[derive(Clone, Copy, Debug)]
pub struct ValueToF64Error;
11
12pub type BuiltinEndCallback<'a> = dyn Fn(
13        &'a str,
14        Vec<Value>,
15    ) -> Pin<Box<dyn Future<Output = Result<Option<Value>, RuntimeError>> + 'a>>
16    + 'a;
17
18pub type UserEndCallback<'a> = dyn Fn(
19        &'a str,
20        Vec<Value>,
21        &'a HashMap<String, UserFunction>,
22        &'a [Value],
23    ) -> Pin<Box<dyn Future<Output = Result<Value, RuntimeError>> + 'a>>
24    + 'a;
25
26pub fn value_to_f64(v: &Value) -> Result<f64, ValueToF64Error> {
27    match v {
28        Value::Num(n) => Ok(*n),
29        Value::Int(i) => Ok(i.to_f64()),
30        Value::Bool(b) => Ok(if *b { 1.0 } else { 0.0 }),
31        Value::Tensor(t) if t.data.len() == 1 => Ok(t.data[0]),
32        Value::Complex(re, im) if im.abs() < 1e-12 => Ok(*re),
33        Value::ComplexTensor(ct) if ct.data.len() == 1 && ct.data[0].1.abs() < 1e-12 => {
34            Ok(ct.data[0].0)
35        }
36        _ => Err(ValueToF64Error),
37    }
38}
39
40pub fn eval_end_expr_value<'a>(
41    expr: &'a EndExpr,
42    end_value: f64,
43    vars: &'a [Value],
44    functions: &'a HashMap<String, UserFunction>,
45    call_builtin: &'a BuiltinEndCallback<'a>,
46    call_user: &'a UserEndCallback<'a>,
47) -> Pin<Box<dyn Future<Output = Result<f64, RuntimeError>> + 'a>> {
48    Box::pin(async move {
49        match expr {
50            EndExpr::End => Ok(end_value),
51            EndExpr::Const(v) => Ok(*v),
52            EndExpr::Var(i) => {
53                let v = vars.get(*i).ok_or_else(|| {
54                    mex("MissingNumericIndex", "missing variable for end expression")
55                })?;
56                value_to_f64(v)
57                    .map_err(|_| mex("UnsupportedIndexType", "end expression must be numeric"))
58            }
59            EndExpr::Call(name, args) => {
60                let mut argv: Vec<Value> = Vec::with_capacity(args.len());
61                for a in args {
62                    let val =
63                        eval_end_expr_value(a, end_value, vars, functions, call_builtin, call_user)
64                            .await?;
65                    argv.push(Value::Num(val));
66                }
67                let v = if let Some(v) = call_builtin(name, argv.clone()).await? {
68                    v
69                } else if functions.contains_key(name) {
70                    call_user(name, argv, functions, vars).await?
71                } else {
72                    return Err(mex(
73                        "UndefinedFunction",
74                        &format!("Undefined function in end expression: {name}"),
75                    ));
76                };
77                value_to_f64(&v)
78                    .map_err(|_| mex("UnsupportedIndexType", "end call must return scalar"))
79            }
80            EndExpr::Add(a, b) => {
81                let lhs =
82                    eval_end_expr_value(a, end_value, vars, functions, call_builtin, call_user)
83                        .await?;
84                let rhs =
85                    eval_end_expr_value(b, end_value, vars, functions, call_builtin, call_user)
86                        .await?;
87                Ok(lhs + rhs)
88            }
89            EndExpr::Sub(a, b) => {
90                let lhs =
91                    eval_end_expr_value(a, end_value, vars, functions, call_builtin, call_user)
92                        .await?;
93                let rhs =
94                    eval_end_expr_value(b, end_value, vars, functions, call_builtin, call_user)
95                        .await?;
96                Ok(lhs - rhs)
97            }
98            EndExpr::Mul(a, b) => {
99                let lhs =
100                    eval_end_expr_value(a, end_value, vars, functions, call_builtin, call_user)
101                        .await?;
102                let rhs =
103                    eval_end_expr_value(b, end_value, vars, functions, call_builtin, call_user)
104                        .await?;
105                Ok(lhs * rhs)
106            }
107            EndExpr::Div(a, b) => {
108                let denom =
109                    eval_end_expr_value(b, end_value, vars, functions, call_builtin, call_user)
110                        .await?;
111                if denom == 0.0 {
112                    return Err(mex("IndexOutOfBounds", "Index out of bounds"));
113                }
114                let lhs =
115                    eval_end_expr_value(a, end_value, vars, functions, call_builtin, call_user)
116                        .await?;
117                Ok(lhs / denom)
118            }
119            EndExpr::LeftDiv(a, b) => {
120                let denom =
121                    eval_end_expr_value(a, end_value, vars, functions, call_builtin, call_user)
122                        .await?;
123                if denom == 0.0 {
124                    return Err(mex("IndexOutOfBounds", "Index out of bounds"));
125                }
126                let rhs =
127                    eval_end_expr_value(b, end_value, vars, functions, call_builtin, call_user)
128                        .await?;
129                Ok(rhs / denom)
130            }
131            EndExpr::Pow(a, b) => {
132                let lhs =
133                    eval_end_expr_value(a, end_value, vars, functions, call_builtin, call_user)
134                        .await?;
135                let rhs =
136                    eval_end_expr_value(b, end_value, vars, functions, call_builtin, call_user)
137                        .await?;
138                Ok(lhs.powf(rhs))
139            }
140            EndExpr::Neg(a) => {
141                Ok(
142                    -eval_end_expr_value(a, end_value, vars, functions, call_builtin, call_user)
143                        .await?,
144                )
145            }
146            EndExpr::Pos(a) => {
147                Ok(
148                    eval_end_expr_value(a, end_value, vars, functions, call_builtin, call_user)
149                        .await?,
150                )
151            }
152            EndExpr::Floor(a) => {
153                Ok(
154                    eval_end_expr_value(a, end_value, vars, functions, call_builtin, call_user)
155                        .await?
156                        .floor(),
157                )
158            }
159            EndExpr::Ceil(a) => {
160                Ok(
161                    eval_end_expr_value(a, end_value, vars, functions, call_builtin, call_user)
162                        .await?
163                        .ceil(),
164                )
165            }
166            EndExpr::Round(a) => {
167                Ok(
168                    eval_end_expr_value(a, end_value, vars, functions, call_builtin, call_user)
169                        .await?
170                        .round(),
171                )
172            }
173            EndExpr::Fix(a) => {
174                let v = eval_end_expr_value(a, end_value, vars, functions, call_builtin, call_user)
175                    .await?;
176                Ok(if v >= 0.0 { v.floor() } else { v.ceil() })
177            }
178        }
179    })
180}
181
182pub async fn resolve_range_end_index<'a>(
183    dim_len: usize,
184    end_expr: &'a EndExpr,
185    vars: &'a [Value],
186    functions: &'a HashMap<String, UserFunction>,
187    call_builtin: &'a BuiltinEndCallback<'a>,
188    call_user: &'a UserEndCallback<'a>,
189) -> Result<i64, RuntimeError> {
190    let value = eval_end_expr_value(
191        end_expr,
192        dim_len as f64,
193        vars,
194        functions,
195        call_builtin,
196        call_user,
197    )
198    .await?;
199    Ok(value.floor() as i64)
200}