1use std::{any::Any, sync::Arc};
5
6use sim_kernel::{
7 Args, Callable, ClassRef, Cx, DefaultFactory, Error, Expr, Factory, NumberValue, Object,
8 ObjectEncode, Result, Symbol, Value, ValueNumberBinaryOp, ValueNumberUnaryOp,
9};
10use sim_lib_numbers_cas::{CasExpr, cas_expr_to_surface_expr, simplify_expr};
11use sim_lib_numbers_cas_eval::eval_cas;
12
13use super::domain::{func_class_symbol, func_domain_symbol};
14use super::function::{child_env_with_args, vars_expr};
15
/// Native (Rust-implemented) function body.
///
/// Receives the evaluation context and the positional argument values of a call and
/// returns the call's result. Cloning a `NativeFn` only bumps the `Arc` refcount; the
/// `Send + Sync` bounds are required of every closure stored here.
pub type NativeFn = Arc<dyn Fn(&mut Cx, &[Value]) -> Result<Value> + Send + Sync>;
19
/// Optional auxiliary information attached to a [`Func`].
///
/// All fields default to `None`; every helper in this file that builds a `Func` attaches
/// `FuncMetadata::default()`.
#[derive(Clone, Default)]
pub struct FuncMetadata {
    /// NOTE(review): presumably identifies where the function came from — confirm.
    pub source: Option<Symbol>,
    /// NOTE(review): presumably a hint consumed by a differentiator — confirm usage.
    pub differentiator_hint: Option<Symbol>,
    /// Arbitrary attached value; its interpretation is not visible in this file.
    pub payload: Option<Value>,
}
32
/// A function value: named parameters plus a symbolic and/or native body.
///
/// Calls dispatch to `body_native` when it is present and otherwise evaluate `body_cas`
/// with the parameters bound; a function with neither body fails at call time.
#[derive(Clone)]
pub struct Func {
    /// Parameter names; calls supply arguments positionally in this order.
    pub vars: Vec<Symbol>,
    /// Symbolic body, evaluated with `eval_cas` in an environment binding `vars`.
    pub body_cas: Option<CasExpr>,
    /// Native body; takes precedence over `body_cas` when both are present.
    pub body_native: Option<NativeFn>,
    /// Auxiliary metadata carried alongside the function.
    pub metadata: FuncMetadata,
}
46
47impl Func {
48 pub fn new(
51 vars: Vec<Symbol>,
52 body_cas: Option<CasExpr>,
53 body_native: Option<NativeFn>,
54 metadata: FuncMetadata,
55 ) -> Self {
56 Self {
57 vars,
58 body_cas,
59 body_native,
60 metadata,
61 }
62 }
63
64 pub fn symbolic(vars: Vec<Symbol>, body_cas: CasExpr) -> Self {
66 Self::new(vars, Some(body_cas), None, FuncMetadata::default())
67 }
68
69 pub fn native(vars: Vec<Symbol>, body_native: NativeFn) -> Self {
87 Self::new(vars, None, Some(body_native), FuncMetadata::default())
88 }
89
90 fn invoke(&self, cx: &mut Cx, args: &[Value]) -> Result<Value> {
91 if let Some(body_native) = &self.body_native {
92 return body_native(cx, args);
93 }
94 let Some(body_cas) = &self.body_cas else {
95 return Err(Error::Eval(
96 "function has neither symbolic nor native body".to_owned(),
97 ));
98 };
99 let env = child_env_with_args(cx.env(), &self.vars, args)?;
100 cx.with_env(env.clone(), |cx| eval_cas(cx, body_cas, &env))
101 }
102}
103
104impl Object for Func {
105 fn display(&self, cx: &mut Cx) -> Result<String> {
106 if let Some(body_cas) = &self.body_cas {
107 return Ok(format!(
108 "#<func {:?} -> {:?}>",
109 self.vars,
110 cas_expr_to_surface_expr(cx, body_cas)?
111 ));
112 }
113 Ok(format!("#<native-func {:?}>", self.vars))
114 }
115
116 fn as_any(&self) -> &dyn Any {
117 self
118 }
119}
120
121impl sim_kernel::ObjectCompat for Func {
122 fn class(&self, cx: &mut Cx) -> Result<ClassRef> {
123 if let Some(value) = cx.registry().class_by_symbol(&func_class_symbol()) {
124 return Ok(value.clone());
125 }
126 DefaultFactory.class_stub(
127 sim_kernel::CORE_NUMBER_CLASS_ID,
128 Symbol::qualified("core", "Number"),
129 )
130 }
131 fn as_expr(&self, cx: &mut Cx) -> Result<Expr> {
132 let Some(body_cas) = &self.body_cas else {
133 return Ok(Expr::Extension {
134 tag: func_class_symbol(),
135 payload: Box::new(Expr::String("#<native-func>".to_owned())),
136 });
137 };
138 Ok(Expr::Call {
139 operator: Box::new(Expr::Symbol(Symbol::new("fn"))),
140 args: vec![
141 vars_expr(&self.vars),
142 cas_expr_to_surface_expr(cx, body_cas)?,
143 ],
144 })
145 }
146 fn as_table(&self, cx: &mut Cx) -> Result<Value> {
147 let vars = cx.factory().list(
148 self.vars
149 .iter()
150 .cloned()
151 .map(|var| cx.factory().symbol(var))
152 .collect::<Result<Vec<_>>>()?,
153 )?;
154 let body_expr = self
155 .body_cas
156 .as_ref()
157 .map(|body| cas_expr_to_surface_expr(cx, body))
158 .transpose()?;
159 let body = match &self.body_cas {
160 Some(_) => cx
161 .factory()
162 .expr(body_expr.expect("body expr should exist when body_cas is present"))?,
163 None => cx.factory().nil()?,
164 };
165 let native = cx.factory().bool(self.body_native.is_some())?;
166 cx.factory().table(vec![
167 (Symbol::new("kind"), cx.factory().string("func".to_owned())?),
168 (Symbol::new("vars"), vars),
169 (Symbol::new("body"), body),
170 (Symbol::new("native"), native),
171 ])
172 }
173 fn as_callable(&self) -> Option<&dyn Callable> {
174 Some(self)
175 }
176 fn as_number_value(&self) -> Option<&dyn NumberValue> {
177 Some(self)
178 }
179 fn as_object_encoder(&self) -> Option<&dyn ObjectEncode> {
180 Some(self)
181 }
182}
183
184impl Callable for Func {
185 fn call(&self, cx: &mut Cx, args: Args) -> Result<Value> {
186 self.invoke(cx, args.values())
187 }
188}
189
190impl NumberValue for Func {
191 fn number_domain(&self, _cx: &mut Cx) -> Result<Symbol> {
192 Ok(func_domain_symbol())
193 }
194}
195
196impl sim_citizen::Citizen for Func {
197 fn citizen_symbol() -> Symbol {
198 func_class_symbol()
199 }
200
201 fn citizen_version() -> u32 {
202 0
203 }
204
205 fn citizen_arity() -> usize {
206 2
207 }
208
209 fn citizen_fields() -> &'static [&'static str] {
210 &["vars", "body"]
211 }
212}
213
214pub fn build_func_value(cx: &mut Cx, func: Func) -> Result<Value> {
216 cx.factory().opaque(Arc::new(func))
217}
218
219pub(crate) fn build_constant_func_value(cx: &mut Cx, value: Value) -> Result<Value> {
220 build_func_value(
221 cx,
222 Func::new(
223 Vec::new(),
224 Some(CasExpr::Num(value)),
225 None,
226 FuncMetadata::default(),
227 ),
228 )
229}
230
231pub(crate) fn register_value_ops(linker: &mut sim_kernel::Linker<'_>) {
232 linker.value_number_binary_op(binary_op(
233 Symbol::qualified("math", "add"),
234 apply_add_func_op,
235 ));
236 linker.value_number_binary_op(binary_op(
237 Symbol::qualified("math", "sub"),
238 apply_sub_func_op,
239 ));
240 linker.value_number_binary_op(binary_op(
241 Symbol::qualified("math", "mul"),
242 apply_mul_func_op,
243 ));
244 linker.value_number_binary_op(binary_op(
245 Symbol::qualified("math", "div"),
246 apply_div_func_op,
247 ));
248 linker.value_number_binary_op(binary_op(
249 Symbol::qualified("math", "pow"),
250 apply_pow_func_op,
251 ));
252 linker.value_number_binary_op(binary_op(
253 Symbol::qualified("math", "rem"),
254 apply_rem_func_op,
255 ));
256 linker.value_number_unary_op(ValueNumberUnaryOp {
257 operator: Symbol::qualified("math", "neg"),
258 operand_domain: func_domain_symbol(),
259 cost: 1,
260 apply: apply_unary_func_op,
261 });
262}
263
264fn binary_op(
265 operator: Symbol,
266 apply: fn(&mut Cx, Value, Value) -> Result<Value>,
267) -> ValueNumberBinaryOp {
268 ValueNumberBinaryOp {
269 operator,
270 left_domain: func_domain_symbol(),
271 right_domain: func_domain_symbol(),
272 cost: 1,
273 apply,
274 }
275}
276
277fn apply_add_func_op(cx: &mut Cx, left: Value, right: Value) -> Result<Value> {
278 apply_binary_func_op(cx, Symbol::qualified("math", "add"), left, right)
279}
280
281fn apply_sub_func_op(cx: &mut Cx, left: Value, right: Value) -> Result<Value> {
282 apply_binary_func_op(cx, Symbol::qualified("math", "sub"), left, right)
283}
284
285fn apply_mul_func_op(cx: &mut Cx, left: Value, right: Value) -> Result<Value> {
286 apply_binary_func_op(cx, Symbol::qualified("math", "mul"), left, right)
287}
288
289fn apply_div_func_op(cx: &mut Cx, left: Value, right: Value) -> Result<Value> {
290 apply_binary_func_op(cx, Symbol::qualified("math", "div"), left, right)
291}
292
293fn apply_pow_func_op(cx: &mut Cx, left: Value, right: Value) -> Result<Value> {
294 apply_binary_func_op(cx, Symbol::qualified("math", "pow"), left, right)
295}
296
297fn apply_rem_func_op(cx: &mut Cx, left: Value, right: Value) -> Result<Value> {
298 apply_binary_func_op(cx, Symbol::qualified("math", "rem"), left, right)
299}
300
301fn apply_binary_func_op(cx: &mut Cx, operator: Symbol, left: Value, right: Value) -> Result<Value> {
302 let left_func = left
303 .object()
304 .downcast_ref::<Func>()
305 .ok_or_else(|| Error::Eval("left operand was not a function value".to_owned()))?
306 .clone();
307 let right_func = right
308 .object()
309 .downcast_ref::<Func>()
310 .ok_or_else(|| Error::Eval("right operand was not a function value".to_owned()))?
311 .clone();
312 let vars = union_vars(&left_func.vars, &right_func.vars);
313 let closure_vars = vars.clone();
314 let body_cas = match (&left_func.body_cas, &right_func.body_cas) {
315 (Some(left_body), Some(right_body)) => Some(simplify_expr(
316 cx,
317 CasExpr::Op(
318 operator.clone(),
319 vec![left_body.clone(), right_body.clone()],
320 ),
321 )?),
322 _ => None,
323 };
324 let native: NativeFn = Arc::new(move |cx: &mut Cx, args: &[Value]| {
325 let left_args = project_args(&closure_vars, &left_func.vars, args)?;
326 let right_args = project_args(&closure_vars, &right_func.vars, args)?;
327 let left_value = left_func.invoke(cx, &left_args)?;
328 let right_value = right_func.invoke(cx, &right_args)?;
329 cx.apply_value_number_binary_op(&operator, left_value, right_value)
330 });
331 let body_native = body_cas.is_none().then_some(native);
332 build_func_value(
333 cx,
334 Func::new(vars, body_cas, body_native, FuncMetadata::default()),
335 )
336}
337
338fn apply_unary_func_op(cx: &mut Cx, value: Value) -> Result<Value> {
339 let func = value
340 .object()
341 .downcast_ref::<Func>()
342 .ok_or_else(|| Error::Eval("operand was not a function value".to_owned()))?
343 .clone();
344 let body_cas = func
345 .body_cas
346 .clone()
347 .map(|body| {
348 simplify_expr(
349 cx,
350 CasExpr::Op(Symbol::qualified("math", "neg"), vec![body]),
351 )
352 })
353 .transpose()?;
354 let native_func = func.clone();
355 let native: NativeFn = Arc::new(move |cx: &mut Cx, args: &[Value]| {
356 let out = native_func.invoke(cx, args)?;
357 cx.apply_value_number_unary_op(&Symbol::qualified("math", "neg"), out)
358 });
359 let body_native = body_cas.is_none().then_some(native);
360 build_func_value(
361 cx,
362 Func::new(
363 func.vars.clone(),
364 body_cas,
365 body_native,
366 FuncMetadata::default(),
367 ),
368 )
369}
370
371fn union_vars(left: &[Symbol], right: &[Symbol]) -> Vec<Symbol> {
372 let mut vars = left.to_vec();
373 for var in right {
374 if !vars.contains(var) {
375 vars.push(var.clone());
376 }
377 }
378 vars
379}
380
381fn project_args(union: &[Symbol], target: &[Symbol], args: &[Value]) -> Result<Vec<Value>> {
382 target
383 .iter()
384 .map(|var| {
385 let index = union
386 .iter()
387 .position(|candidate| candidate == var)
388 .ok_or_else(|| {
389 Error::Eval(format!(
390 "function variable {var} missing from projected call"
391 ))
392 })?;
393 args.get(index).cloned().ok_or_else(|| {
394 Error::Eval(format!(
395 "function variable {var} missing from call arguments"
396 ))
397 })
398 })
399 .collect()
400}