Skip to main content

sim_lib_numbers_func/implementation/
value.rs

1//! The `Func` function value: variables plus an optional CAS or native body,
2//! with its metadata and arithmetic over function values.
3
4use std::{any::Any, sync::Arc};
5
6use sim_kernel::{
7    Args, Callable, ClassRef, Cx, DefaultFactory, Error, Expr, Factory, NumberValue, Object,
8    ObjectEncode, Result, Symbol, Value, ValueNumberBinaryOp, ValueNumberUnaryOp,
9};
10use sim_lib_numbers_cas::{CasExpr, cas_expr_to_surface_expr, simplify_expr};
11use sim_lib_numbers_cas_eval::eval_cas;
12
13use super::domain::{func_class_symbol, func_domain_symbol};
14use super::function::{child_env_with_args, vars_expr};
15
/// A native (Rust-backed) function body.
///
/// The closure receives the runtime context and the argument values of the call and
/// returns the resulting value. It is shared behind an [`Arc`] and must be `Send + Sync`,
/// so cloning a [`Func`] that carries one only bumps a reference count.
pub type NativeFn = Arc<dyn Fn(&mut Cx, &[Value]) -> Result<Value> + Send + Sync>;
19
/// Out-of-band annotations attached to a [`Func`] value.
///
/// Every field is optional; the derived [`Default`] leaves all of them as `None`.
#[derive(Clone, Default)]
pub struct FuncMetadata {
    /// Symbol identifying where this function came from (for example, an
    /// elementary-function name), when known.
    pub source: Option<Symbol>,
    /// Optional hint naming the differentiator that should handle `grad`/`diff`
    /// for this function.
    pub differentiator_hint: Option<Symbol>,
    /// Arbitrary caller-supplied value carried alongside the function.
    pub payload: Option<Value>,
}
32
/// A callable function value in the `Func` number domain: its bound variables
/// plus an optional symbolic (CAS) body and/or native body.
///
/// Both bodies are optional and independent. When a function is invoked, the
/// native body is used if present; otherwise the symbolic body is evaluated;
/// a function with neither body fails with an evaluation error.
#[derive(Clone)]
pub struct Func {
    /// The ordered parameter symbols bound when the function is invoked.
    pub vars: Vec<Symbol>,
    /// The symbolic body, when the function can be expressed as a CAS expression.
    pub body_cas: Option<CasExpr>,
    /// The native body; when present it takes precedence over `body_cas` on invocation.
    pub body_native: Option<NativeFn>,
    /// Out-of-band metadata describing the function.
    pub metadata: FuncMetadata,
}
46
47impl Func {
48    /// Builds a function from its variables, optional symbolic and native
49    /// bodies, and metadata.
50    pub fn new(
51        vars: Vec<Symbol>,
52        body_cas: Option<CasExpr>,
53        body_native: Option<NativeFn>,
54        metadata: FuncMetadata,
55    ) -> Self {
56        Self {
57            vars,
58            body_cas,
59            body_native,
60            metadata,
61        }
62    }
63
64    /// Builds a function with a symbolic (CAS) body and default metadata.
65    pub fn symbolic(vars: Vec<Symbol>, body_cas: CasExpr) -> Self {
66        Self::new(vars, Some(body_cas), None, FuncMetadata::default())
67    }
68
69    /// Builds a function with a native (Rust closure) body and default metadata.
70    ///
71    /// # Examples
72    ///
73    /// ```
74    /// use std::sync::Arc;
75    /// use sim_kernel::Symbol;
76    /// use sim_lib_numbers_func::Func;
77    ///
78    /// let func = Func::native(
79    ///     vec![Symbol::new("x")],
80    ///     Arc::new(|_cx, args| Ok(args[0].clone())),
81    /// );
82    /// assert_eq!(func.vars, vec![Symbol::new("x")]);
83    /// assert!(func.body_cas.is_none());
84    /// assert!(func.body_native.is_some());
85    /// ```
86    pub fn native(vars: Vec<Symbol>, body_native: NativeFn) -> Self {
87        Self::new(vars, None, Some(body_native), FuncMetadata::default())
88    }
89
90    fn invoke(&self, cx: &mut Cx, args: &[Value]) -> Result<Value> {
91        if let Some(body_native) = &self.body_native {
92            return body_native(cx, args);
93        }
94        let Some(body_cas) = &self.body_cas else {
95            return Err(Error::Eval(
96                "function has neither symbolic nor native body".to_owned(),
97            ));
98        };
99        let env = child_env_with_args(cx.env(), &self.vars, args)?;
100        cx.with_env(env.clone(), |cx| eval_cas(cx, body_cas, &env))
101    }
102}
103
104impl Object for Func {
105    fn display(&self, cx: &mut Cx) -> Result<String> {
106        if let Some(body_cas) = &self.body_cas {
107            return Ok(format!(
108                "#<func {:?} -> {:?}>",
109                self.vars,
110                cas_expr_to_surface_expr(cx, body_cas)?
111            ));
112        }
113        Ok(format!("#<native-func {:?}>", self.vars))
114    }
115
116    fn as_any(&self) -> &dyn Any {
117        self
118    }
119}
120
121impl sim_kernel::ObjectCompat for Func {
122    fn class(&self, cx: &mut Cx) -> Result<ClassRef> {
123        if let Some(value) = cx.registry().class_by_symbol(&func_class_symbol()) {
124            return Ok(value.clone());
125        }
126        DefaultFactory.class_stub(
127            sim_kernel::CORE_NUMBER_CLASS_ID,
128            Symbol::qualified("core", "Number"),
129        )
130    }
131    fn as_expr(&self, cx: &mut Cx) -> Result<Expr> {
132        let Some(body_cas) = &self.body_cas else {
133            return Ok(Expr::Extension {
134                tag: func_class_symbol(),
135                payload: Box::new(Expr::String("#<native-func>".to_owned())),
136            });
137        };
138        Ok(Expr::Call {
139            operator: Box::new(Expr::Symbol(Symbol::new("fn"))),
140            args: vec![
141                vars_expr(&self.vars),
142                cas_expr_to_surface_expr(cx, body_cas)?,
143            ],
144        })
145    }
146    fn as_table(&self, cx: &mut Cx) -> Result<Value> {
147        let vars = cx.factory().list(
148            self.vars
149                .iter()
150                .cloned()
151                .map(|var| cx.factory().symbol(var))
152                .collect::<Result<Vec<_>>>()?,
153        )?;
154        let body_expr = self
155            .body_cas
156            .as_ref()
157            .map(|body| cas_expr_to_surface_expr(cx, body))
158            .transpose()?;
159        let body = match &self.body_cas {
160            Some(_) => cx
161                .factory()
162                .expr(body_expr.expect("body expr should exist when body_cas is present"))?,
163            None => cx.factory().nil()?,
164        };
165        let native = cx.factory().bool(self.body_native.is_some())?;
166        cx.factory().table(vec![
167            (Symbol::new("kind"), cx.factory().string("func".to_owned())?),
168            (Symbol::new("vars"), vars),
169            (Symbol::new("body"), body),
170            (Symbol::new("native"), native),
171        ])
172    }
173    fn as_callable(&self) -> Option<&dyn Callable> {
174        Some(self)
175    }
176    fn as_number_value(&self) -> Option<&dyn NumberValue> {
177        Some(self)
178    }
179    fn as_object_encoder(&self) -> Option<&dyn ObjectEncode> {
180        Some(self)
181    }
182}
183
184impl Callable for Func {
185    fn call(&self, cx: &mut Cx, args: Args) -> Result<Value> {
186        self.invoke(cx, args.values())
187    }
188}
189
190impl NumberValue for Func {
191    fn number_domain(&self, _cx: &mut Cx) -> Result<Symbol> {
192        Ok(func_domain_symbol())
193    }
194}
195
// Citizen description of `Func`: a two-field record (`vars`, `body`) at version 0,
// identified by the same symbol as the func class.
impl sim_citizen::Citizen for Func {
    /// The citizen symbol is the func class symbol.
    fn citizen_symbol() -> Symbol {
        func_class_symbol()
    }

    /// Version number of this citizen description.
    fn citizen_version() -> u32 {
        0
    }

    /// Number of fields; matches the length of [`Self::citizen_fields`].
    fn citizen_arity() -> usize {
        2
    }

    /// Field names, in order. Note that no native body or metadata field is listed.
    fn citizen_fields() -> &'static [&'static str] {
        &["vars", "body"]
    }
}
213
214/// Wraps a [`Func`] into a runtime [`Value`] in the `Func` number domain.
215pub fn build_func_value(cx: &mut Cx, func: Func) -> Result<Value> {
216    cx.factory().opaque(Arc::new(func))
217}
218
219pub(crate) fn build_constant_func_value(cx: &mut Cx, value: Value) -> Result<Value> {
220    build_func_value(
221        cx,
222        Func::new(
223            Vec::new(),
224            Some(CasExpr::Num(value)),
225            None,
226            FuncMetadata::default(),
227        ),
228    )
229}
230
231pub(crate) fn register_value_ops(linker: &mut sim_kernel::Linker<'_>) {
232    linker.value_number_binary_op(binary_op(
233        Symbol::qualified("math", "add"),
234        apply_add_func_op,
235    ));
236    linker.value_number_binary_op(binary_op(
237        Symbol::qualified("math", "sub"),
238        apply_sub_func_op,
239    ));
240    linker.value_number_binary_op(binary_op(
241        Symbol::qualified("math", "mul"),
242        apply_mul_func_op,
243    ));
244    linker.value_number_binary_op(binary_op(
245        Symbol::qualified("math", "div"),
246        apply_div_func_op,
247    ));
248    linker.value_number_binary_op(binary_op(
249        Symbol::qualified("math", "pow"),
250        apply_pow_func_op,
251    ));
252    linker.value_number_binary_op(binary_op(
253        Symbol::qualified("math", "rem"),
254        apply_rem_func_op,
255    ));
256    linker.value_number_unary_op(ValueNumberUnaryOp {
257        operator: Symbol::qualified("math", "neg"),
258        operand_domain: func_domain_symbol(),
259        cost: 1,
260        apply: apply_unary_func_op,
261    });
262}
263
264fn binary_op(
265    operator: Symbol,
266    apply: fn(&mut Cx, Value, Value) -> Result<Value>,
267) -> ValueNumberBinaryOp {
268    ValueNumberBinaryOp {
269        operator,
270        left_domain: func_domain_symbol(),
271        right_domain: func_domain_symbol(),
272        cost: 1,
273        apply,
274    }
275}
276
277fn apply_add_func_op(cx: &mut Cx, left: Value, right: Value) -> Result<Value> {
278    apply_binary_func_op(cx, Symbol::qualified("math", "add"), left, right)
279}
280
281fn apply_sub_func_op(cx: &mut Cx, left: Value, right: Value) -> Result<Value> {
282    apply_binary_func_op(cx, Symbol::qualified("math", "sub"), left, right)
283}
284
285fn apply_mul_func_op(cx: &mut Cx, left: Value, right: Value) -> Result<Value> {
286    apply_binary_func_op(cx, Symbol::qualified("math", "mul"), left, right)
287}
288
289fn apply_div_func_op(cx: &mut Cx, left: Value, right: Value) -> Result<Value> {
290    apply_binary_func_op(cx, Symbol::qualified("math", "div"), left, right)
291}
292
293fn apply_pow_func_op(cx: &mut Cx, left: Value, right: Value) -> Result<Value> {
294    apply_binary_func_op(cx, Symbol::qualified("math", "pow"), left, right)
295}
296
297fn apply_rem_func_op(cx: &mut Cx, left: Value, right: Value) -> Result<Value> {
298    apply_binary_func_op(cx, Symbol::qualified("math", "rem"), left, right)
299}
300
301fn apply_binary_func_op(cx: &mut Cx, operator: Symbol, left: Value, right: Value) -> Result<Value> {
302    let left_func = left
303        .object()
304        .downcast_ref::<Func>()
305        .ok_or_else(|| Error::Eval("left operand was not a function value".to_owned()))?
306        .clone();
307    let right_func = right
308        .object()
309        .downcast_ref::<Func>()
310        .ok_or_else(|| Error::Eval("right operand was not a function value".to_owned()))?
311        .clone();
312    let vars = union_vars(&left_func.vars, &right_func.vars);
313    let closure_vars = vars.clone();
314    let body_cas = match (&left_func.body_cas, &right_func.body_cas) {
315        (Some(left_body), Some(right_body)) => Some(simplify_expr(
316            cx,
317            CasExpr::Op(
318                operator.clone(),
319                vec![left_body.clone(), right_body.clone()],
320            ),
321        )?),
322        _ => None,
323    };
324    let native: NativeFn = Arc::new(move |cx: &mut Cx, args: &[Value]| {
325        let left_args = project_args(&closure_vars, &left_func.vars, args)?;
326        let right_args = project_args(&closure_vars, &right_func.vars, args)?;
327        let left_value = left_func.invoke(cx, &left_args)?;
328        let right_value = right_func.invoke(cx, &right_args)?;
329        cx.apply_value_number_binary_op(&operator, left_value, right_value)
330    });
331    let body_native = body_cas.is_none().then_some(native);
332    build_func_value(
333        cx,
334        Func::new(vars, body_cas, body_native, FuncMetadata::default()),
335    )
336}
337
338fn apply_unary_func_op(cx: &mut Cx, value: Value) -> Result<Value> {
339    let func = value
340        .object()
341        .downcast_ref::<Func>()
342        .ok_or_else(|| Error::Eval("operand was not a function value".to_owned()))?
343        .clone();
344    let body_cas = func
345        .body_cas
346        .clone()
347        .map(|body| {
348            simplify_expr(
349                cx,
350                CasExpr::Op(Symbol::qualified("math", "neg"), vec![body]),
351            )
352        })
353        .transpose()?;
354    let native_func = func.clone();
355    let native: NativeFn = Arc::new(move |cx: &mut Cx, args: &[Value]| {
356        let out = native_func.invoke(cx, args)?;
357        cx.apply_value_number_unary_op(&Symbol::qualified("math", "neg"), out)
358    });
359    let body_native = body_cas.is_none().then_some(native);
360    build_func_value(
361        cx,
362        Func::new(
363            func.vars.clone(),
364            body_cas,
365            body_native,
366            FuncMetadata::default(),
367        ),
368    )
369}
370
371fn union_vars(left: &[Symbol], right: &[Symbol]) -> Vec<Symbol> {
372    let mut vars = left.to_vec();
373    for var in right {
374        if !vars.contains(var) {
375            vars.push(var.clone());
376        }
377    }
378    vars
379}
380
381fn project_args(union: &[Symbol], target: &[Symbol], args: &[Value]) -> Result<Vec<Value>> {
382    target
383        .iter()
384        .map(|var| {
385            let index = union
386                .iter()
387                .position(|candidate| candidate == var)
388                .ok_or_else(|| {
389                    Error::Eval(format!(
390                        "function variable {var} missing from projected call"
391                    ))
392                })?;
393            args.get(index).cloned().ok_or_else(|| {
394                Error::Eval(format!(
395                    "function variable {var} missing from call arguments"
396                ))
397            })
398        })
399        .collect()
400}