1use crate::bytecode::{EndExpr, UserFunction};
2use crate::interpreter::errors::mex;
3use runmat_builtins::Value;
4use runmat_runtime::RuntimeError;
5use std::collections::HashMap;
6use std::future::Future;
7use std::pin::Pin;
8
/// Error returned by [`value_to_f64`] when a [`Value`] cannot be read as a real scalar.
///
/// The error carries no payload; callers typically translate it into a more specific
/// runtime error with `map_err`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct ValueToF64Error;

impl std::fmt::Display for ValueToF64Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("value is not convertible to a real scalar f64")
    }
}

// Implementing the standard error trait lets this type participate in `Box<dyn Error>`
// and `?`-based propagation like any other library error.
impl std::error::Error for ValueToF64Error {}
11
12pub type BuiltinEndCallback<'a> = dyn Fn(
13 &'a str,
14 Vec<Value>,
15 ) -> Pin<Box<dyn Future<Output = Result<Option<Value>, RuntimeError>> + 'a>>
16 + 'a;
17
18pub type UserEndCallback<'a> = dyn Fn(
19 &'a str,
20 Vec<Value>,
21 &'a HashMap<String, UserFunction>,
22 &'a [Value],
23 ) -> Pin<Box<dyn Future<Output = Result<Value, RuntimeError>> + 'a>>
24 + 'a;
25
26pub fn value_to_f64(v: &Value) -> Result<f64, ValueToF64Error> {
27 match v {
28 Value::Num(n) => Ok(*n),
29 Value::Int(i) => Ok(i.to_f64()),
30 Value::Bool(b) => Ok(if *b { 1.0 } else { 0.0 }),
31 Value::Tensor(t) if t.data.len() == 1 => Ok(t.data[0]),
32 Value::Complex(re, im) if im.abs() < 1e-12 => Ok(*re),
33 Value::ComplexTensor(ct) if ct.data.len() == 1 && ct.data[0].1.abs() < 1e-12 => {
34 Ok(ct.data[0].0)
35 }
36 _ => Err(ValueToF64Error),
37 }
38}
39
40pub fn eval_end_expr_value<'a>(
41 expr: &'a EndExpr,
42 end_value: f64,
43 vars: &'a [Value],
44 functions: &'a HashMap<String, UserFunction>,
45 call_builtin: &'a BuiltinEndCallback<'a>,
46 call_user: &'a UserEndCallback<'a>,
47) -> Pin<Box<dyn Future<Output = Result<f64, RuntimeError>> + 'a>> {
48 Box::pin(async move {
49 match expr {
50 EndExpr::End => Ok(end_value),
51 EndExpr::Const(v) => Ok(*v),
52 EndExpr::Var(i) => {
53 let v = vars.get(*i).ok_or_else(|| {
54 mex("MissingNumericIndex", "missing variable for end expression")
55 })?;
56 value_to_f64(v)
57 .map_err(|_| mex("UnsupportedIndexType", "end expression must be numeric"))
58 }
59 EndExpr::Call(name, args) => {
60 let mut argv: Vec<Value> = Vec::with_capacity(args.len());
61 for a in args {
62 let val =
63 eval_end_expr_value(a, end_value, vars, functions, call_builtin, call_user)
64 .await?;
65 argv.push(Value::Num(val));
66 }
67 let v = if let Some(v) = call_builtin(name, argv.clone()).await? {
68 v
69 } else if functions.contains_key(name) {
70 call_user(name, argv, functions, vars).await?
71 } else {
72 return Err(mex(
73 "UndefinedFunction",
74 &format!("Undefined function in end expression: {name}"),
75 ));
76 };
77 value_to_f64(&v)
78 .map_err(|_| mex("UnsupportedIndexType", "end call must return scalar"))
79 }
80 EndExpr::Add(a, b) => {
81 let lhs =
82 eval_end_expr_value(a, end_value, vars, functions, call_builtin, call_user)
83 .await?;
84 let rhs =
85 eval_end_expr_value(b, end_value, vars, functions, call_builtin, call_user)
86 .await?;
87 Ok(lhs + rhs)
88 }
89 EndExpr::Sub(a, b) => {
90 let lhs =
91 eval_end_expr_value(a, end_value, vars, functions, call_builtin, call_user)
92 .await?;
93 let rhs =
94 eval_end_expr_value(b, end_value, vars, functions, call_builtin, call_user)
95 .await?;
96 Ok(lhs - rhs)
97 }
98 EndExpr::Mul(a, b) => {
99 let lhs =
100 eval_end_expr_value(a, end_value, vars, functions, call_builtin, call_user)
101 .await?;
102 let rhs =
103 eval_end_expr_value(b, end_value, vars, functions, call_builtin, call_user)
104 .await?;
105 Ok(lhs * rhs)
106 }
107 EndExpr::Div(a, b) => {
108 let denom =
109 eval_end_expr_value(b, end_value, vars, functions, call_builtin, call_user)
110 .await?;
111 if denom == 0.0 {
112 return Err(mex("IndexOutOfBounds", "Index out of bounds"));
113 }
114 let lhs =
115 eval_end_expr_value(a, end_value, vars, functions, call_builtin, call_user)
116 .await?;
117 Ok(lhs / denom)
118 }
119 EndExpr::LeftDiv(a, b) => {
120 let denom =
121 eval_end_expr_value(a, end_value, vars, functions, call_builtin, call_user)
122 .await?;
123 if denom == 0.0 {
124 return Err(mex("IndexOutOfBounds", "Index out of bounds"));
125 }
126 let rhs =
127 eval_end_expr_value(b, end_value, vars, functions, call_builtin, call_user)
128 .await?;
129 Ok(rhs / denom)
130 }
131 EndExpr::Pow(a, b) => {
132 let lhs =
133 eval_end_expr_value(a, end_value, vars, functions, call_builtin, call_user)
134 .await?;
135 let rhs =
136 eval_end_expr_value(b, end_value, vars, functions, call_builtin, call_user)
137 .await?;
138 Ok(lhs.powf(rhs))
139 }
140 EndExpr::Neg(a) => {
141 Ok(
142 -eval_end_expr_value(a, end_value, vars, functions, call_builtin, call_user)
143 .await?,
144 )
145 }
146 EndExpr::Pos(a) => {
147 Ok(
148 eval_end_expr_value(a, end_value, vars, functions, call_builtin, call_user)
149 .await?,
150 )
151 }
152 EndExpr::Floor(a) => {
153 Ok(
154 eval_end_expr_value(a, end_value, vars, functions, call_builtin, call_user)
155 .await?
156 .floor(),
157 )
158 }
159 EndExpr::Ceil(a) => {
160 Ok(
161 eval_end_expr_value(a, end_value, vars, functions, call_builtin, call_user)
162 .await?
163 .ceil(),
164 )
165 }
166 EndExpr::Round(a) => {
167 Ok(
168 eval_end_expr_value(a, end_value, vars, functions, call_builtin, call_user)
169 .await?
170 .round(),
171 )
172 }
173 EndExpr::Fix(a) => {
174 let v = eval_end_expr_value(a, end_value, vars, functions, call_builtin, call_user)
175 .await?;
176 Ok(if v >= 0.0 { v.floor() } else { v.ceil() })
177 }
178 }
179 })
180}
181
182pub async fn resolve_range_end_index<'a>(
183 dim_len: usize,
184 end_expr: &'a EndExpr,
185 vars: &'a [Value],
186 functions: &'a HashMap<String, UserFunction>,
187 call_builtin: &'a BuiltinEndCallback<'a>,
188 call_user: &'a UserEndCallback<'a>,
189) -> Result<i64, RuntimeError> {
190 let value = eval_end_expr_value(
191 end_expr,
192 dim_len as f64,
193 vars,
194 functions,
195 call_builtin,
196 call_user,
197 )
198 .await?;
199 Ok(value.floor() as i64)
200}