Skip to main content

sim_lib_numbers_func/implementation/
function.rs

1//! Function operations: the `fn`, `call`, and `grad` callables and the
2//! function class builder backing the `Func` domain.
3
4use std::{any::Any, sync::Arc};
5
6use sim_kernel::{
7    Args, Callable, Class, ClassId, ClassRef, Cx, DefaultFactory, Env, Error, Expr, Factory,
8    Object, ObjectEncode, ObjectEncoding, ReadConstructor, ReadConstructorRef, Result, ShapeRef,
9    Symbol, TableRef, Value,
10};
11use sim_lib_numbers_cas::{cas_expr_to_surface_expr, expr_to_cas_expr};
12
13use super::domain::{func_class_symbol, value_shape_symbol};
14use super::value::{Func, FuncMetadata, build_func_value};
15
16/// Returns the symbol bound to the `fn` function-builder callable.
17///
18/// # Examples
19///
20/// ```
21/// use sim_lib_numbers_func::{call_symbol, fn_symbol, grad_symbol};
22///
23/// assert_eq!(fn_symbol().to_string(), "fn");
24/// assert_eq!(call_symbol().to_string(), "call");
25/// assert_eq!(grad_symbol().to_string(), "grad");
26/// ```
27pub fn fn_symbol() -> Symbol {
28    Symbol::new("fn")
29}
30
31/// Returns the symbol bound to the `call` apply-a-function callable.
32pub fn call_symbol() -> Symbol {
33    Symbol::new("call")
34}
35
36/// Returns the symbol bound to the `grad` gradient-of-a-function callable.
37pub fn grad_symbol() -> Symbol {
38    Symbol::new("grad")
39}
40
41#[derive(Clone)]
42pub struct FnBuilder;
43
44impl Object for FnBuilder {
45    fn display(&self, _cx: &mut Cx) -> Result<String> {
46        Ok("#<function fn>".to_owned())
47    }
48
49    fn as_any(&self) -> &dyn Any {
50        self
51    }
52}
53
54impl sim_kernel::ObjectCompat for FnBuilder {
55    fn class(&self, cx: &mut Cx) -> Result<ClassRef> {
56        function_class(cx)
57    }
58    fn as_expr(&self, _cx: &mut Cx) -> Result<Expr> {
59        Ok(Expr::Symbol(fn_symbol()))
60    }
61    fn as_callable(&self) -> Option<&dyn Callable> {
62        Some(self)
63    }
64}
65
66impl Callable for FnBuilder {
67    fn call(&self, _cx: &mut Cx, _args: Args) -> Result<Value> {
68        Err(Error::Eval(
69            "fn must be called with unevaluated parameters and a body".to_owned(),
70        ))
71    }
72
73    fn call_exprs(&self, cx: &mut Cx, args: sim_kernel::RawArgs) -> Result<Value> {
74        let args = args.into_exprs();
75        let [vars_expr, body_expr] = args.as_slice() else {
76            return Err(Error::Eval(
77                "fn expects exactly a parameter list and one body expression".to_owned(),
78            ));
79        };
80        let vars = parse_vars_expr(vars_expr)?;
81        let body_cas = expr_to_cas_expr(cx, body_expr)?
82            .ok_or_else(|| Error::Eval("fn body must be CAS-compatible".to_owned()))?;
83        build_func_value(
84            cx,
85            Func::new(vars, Some(body_cas), None, FuncMetadata::default()),
86        )
87    }
88}
89
90#[derive(Clone)]
91pub struct CallFunction;
92
93impl Object for CallFunction {
94    fn display(&self, _cx: &mut Cx) -> Result<String> {
95        Ok("#<function call>".to_owned())
96    }
97
98    fn as_any(&self) -> &dyn Any {
99        self
100    }
101}
102
103impl sim_kernel::ObjectCompat for CallFunction {
104    fn class(&self, cx: &mut Cx) -> Result<ClassRef> {
105        function_class(cx)
106    }
107    fn as_expr(&self, _cx: &mut Cx) -> Result<Expr> {
108        Ok(Expr::Symbol(call_symbol()))
109    }
110    fn as_callable(&self) -> Option<&dyn Callable> {
111        Some(self)
112    }
113}
114
115impl Callable for CallFunction {
116    fn call(&self, cx: &mut Cx, args: Args) -> Result<Value> {
117        let mut values = args.into_vec();
118        if values.is_empty() {
119            return Err(Error::Eval(
120                "call expects a callable value and at least zero arguments".to_owned(),
121            ));
122        }
123        let callable = values.remove(0);
124        cx.call_value(callable, Args::new(values))
125    }
126}
127
128#[derive(Clone)]
129pub struct GradFunction;
130
131impl Object for GradFunction {
132    fn display(&self, _cx: &mut Cx) -> Result<String> {
133        Ok("#<function grad>".to_owned())
134    }
135
136    fn as_any(&self) -> &dyn Any {
137        self
138    }
139}
140
141impl sim_kernel::ObjectCompat for GradFunction {
142    fn class(&self, cx: &mut Cx) -> Result<ClassRef> {
143        function_class(cx)
144    }
145    fn as_expr(&self, _cx: &mut Cx) -> Result<Expr> {
146        Ok(Expr::Symbol(grad_symbol()))
147    }
148    fn as_callable(&self) -> Option<&dyn Callable> {
149        Some(self)
150    }
151}
152
153impl Callable for GradFunction {
154    fn call(&self, cx: &mut Cx, args: Args) -> Result<Value> {
155        let values = args.into_vec();
156        let [value] = values.as_slice() else {
157            return Err(Error::Eval(
158                "grad expects exactly one function value".to_owned(),
159            ));
160        };
161        let func = expect_func(value)?;
162        let mut grads = Vec::with_capacity(func.vars.len());
163        for var in &func.vars {
164            let var_value = cx.factory().symbol(var.clone())?;
165            grads.push(cx.call_function(
166                &Symbol::new("diff"),
167                Args::new(vec![value.clone(), var_value]),
168            )?);
169        }
170        cx.factory().list(grads)
171    }
172}
173
174pub(crate) struct FuncValueClass {
175    id: std::sync::atomic::AtomicU32,
176}
177
178pub(crate) fn build_func_class() -> Arc<FuncValueClass> {
179    Arc::new(FuncValueClass {
180        id: std::sync::atomic::AtomicU32::new(0),
181    })
182}
183
184impl FuncValueClass {
185    pub(crate) fn set_id(&self, id: ClassId) {
186        self.id.store(id.0, std::sync::atomic::Ordering::Relaxed);
187    }
188}
189
190impl Object for FuncValueClass {
191    fn display(&self, _cx: &mut Cx) -> Result<String> {
192        Ok(format!("#<class {}>", func_class_symbol()))
193    }
194
195    fn as_any(&self) -> &dyn Any {
196        self
197    }
198}
199
200impl sim_kernel::ObjectCompat for FuncValueClass {
201    fn class(&self, cx: &mut Cx) -> Result<ClassRef> {
202        if let Some(value) = cx
203            .registry()
204            .class_by_symbol(&Symbol::qualified("core", "Class"))
205        {
206            return Ok(value.clone());
207        }
208        DefaultFactory.class_stub(
209            sim_kernel::CORE_CLASS_CLASS_ID,
210            Symbol::qualified("core", "Class"),
211        )
212    }
213    fn as_expr(&self, _cx: &mut Cx) -> Result<Expr> {
214        Ok(Expr::Symbol(func_class_symbol()))
215    }
216    fn as_callable(&self) -> Option<&dyn Callable> {
217        Some(self)
218    }
219    fn as_class(&self) -> Option<&dyn Class> {
220        Some(self)
221    }
222    fn as_read_constructor(&self) -> Option<&dyn ReadConstructor> {
223        Some(self)
224    }
225}
226
227impl Callable for FuncValueClass {
228    fn call(&self, cx: &mut Cx, args: Args) -> Result<Value> {
229        let values = args.into_vec();
230        let [vars_value, body_value] = values.as_slice() else {
231            return Err(Error::Eval(format!(
232                "class {} expects exactly two arguments",
233                func_class_symbol()
234            )));
235        };
236        let vars = parse_vars_value(cx, vars_value)?;
237        let body_expr = body_value.object().as_expr(cx)?;
238        let body_cas = expr_to_cas_expr(cx, &body_expr)?
239            .ok_or_else(|| Error::Eval("numbers/Func body must be CAS-compatible".to_owned()))?;
240        build_func_value(
241            cx,
242            Func::new(vars, Some(body_cas), None, FuncMetadata::default()),
243        )
244    }
245}
246
247impl Class for FuncValueClass {
248    fn id(&self) -> ClassId {
249        ClassId(self.id.load(std::sync::atomic::Ordering::Relaxed))
250    }
251
252    fn symbol(&self) -> Symbol {
253        func_class_symbol()
254    }
255
256    fn constructor_shape(&self, cx: &mut Cx) -> Result<ShapeRef> {
257        cx.factory().nil()
258    }
259
260    fn instance_shape(&self, cx: &mut Cx) -> Result<ShapeRef> {
261        Ok(cx
262            .registry()
263            .shape_by_symbol(&value_shape_symbol())
264            .cloned()
265            .unwrap_or(cx.factory().symbol(value_shape_symbol())?))
266    }
267
268    fn read_constructor(&self, cx: &mut Cx) -> Result<Option<ReadConstructorRef>> {
269        Ok(cx.registry().class_by_symbol(&func_class_symbol()).cloned())
270    }
271
272    fn members(&self, cx: &mut Cx) -> Result<TableRef> {
273        cx.factory().table(Vec::new())
274    }
275}
276
277impl ReadConstructor for FuncValueClass {
278    fn symbol(&self) -> Symbol {
279        func_class_symbol()
280    }
281
282    fn args_shape(&self, cx: &mut Cx) -> Result<ShapeRef> {
283        cx.factory().nil()
284    }
285
286    fn construct_read(&self, cx: &mut Cx, args: Vec<Value>) -> Result<Value> {
287        self.call(cx, Args::new(args))
288    }
289}
290
291impl ObjectEncode for Func {
292    fn object_encoding(&self, cx: &mut Cx) -> Result<ObjectEncoding> {
293        let Some(body_cas) = &self.body_cas else {
294            return Err(Error::Eval(
295                "native-only functions do not have a read-construct encoding".to_owned(),
296            ));
297        };
298        Ok(ObjectEncoding::Constructor {
299            class: func_class_symbol(),
300            args: vec![
301                vars_expr(&self.vars),
302                cas_expr_to_surface_expr(cx, body_cas)?,
303            ],
304        })
305    }
306}
307
308pub(crate) fn parse_vars_expr(expr: &Expr) -> Result<Vec<Symbol>> {
309    let Expr::List(items) = expr else {
310        return Err(Error::Eval(
311            "function parameter list must be a list of symbols".to_owned(),
312        ));
313    };
314    items
315        .iter()
316        .map(|item| match item {
317            Expr::Symbol(symbol) => Ok(symbol.clone()),
318            _ => Err(Error::Eval(
319                "function parameter list must contain only symbols".to_owned(),
320            )),
321        })
322        .collect()
323}
324
325fn parse_vars_value(cx: &mut Cx, value: &Value) -> Result<Vec<Symbol>> {
326    parse_vars_expr(&value.object().as_expr(cx)?)
327}
328
329pub(crate) fn vars_expr(vars: &[Symbol]) -> Expr {
330    Expr::List(vars.iter().cloned().map(Expr::Symbol).collect())
331}
332
333pub(crate) fn function_class(cx: &mut Cx) -> Result<ClassRef> {
334    if let Some(value) = cx
335        .registry()
336        .class_by_symbol(&Symbol::qualified("core", "Function"))
337    {
338        return Ok(value.clone());
339    }
340    cx.factory().class_stub(
341        sim_kernel::CORE_FUNCTION_CLASS_ID,
342        Symbol::qualified("core", "Function"),
343    )
344}
345
346pub(crate) fn expect_func(value: &Value) -> Result<&Func> {
347    value
348        .object()
349        .downcast_ref::<Func>()
350        .ok_or_else(|| Error::Eval("expected a numbers/func value".to_owned()))
351}
352
353pub(crate) fn child_env_with_args(parent: &Env, vars: &[Symbol], args: &[Value]) -> Result<Env> {
354    if vars.len() != args.len() {
355        return Err(Error::Eval(format!(
356            "function expected {} arguments but received {}",
357            vars.len(),
358            args.len()
359        )));
360    }
361    let mut env = Env::child(Arc::new(parent.clone()));
362    for (var, value) in vars.iter().cloned().zip(args.iter().cloned()) {
363        env.define(var, value);
364    }
365    Ok(env)
366}