sim_lib_numbers_func/implementation/
function.rs1use std::{any::Any, sync::Arc};
5
6use sim_kernel::{
7 Args, Callable, Class, ClassId, ClassRef, Cx, DefaultFactory, Env, Error, Expr, Factory,
8 Object, ObjectEncode, ObjectEncoding, ReadConstructor, ReadConstructorRef, Result, ShapeRef,
9 Symbol, TableRef, Value,
10};
11use sim_lib_numbers_cas::{cas_expr_to_surface_expr, expr_to_cas_expr};
12
13use super::domain::{func_class_symbol, value_shape_symbol};
14use super::value::{Func, FuncMetadata, build_func_value};
15
16pub fn fn_symbol() -> Symbol {
28 Symbol::new("fn")
29}
30
31pub fn call_symbol() -> Symbol {
33 Symbol::new("call")
34}
35
36pub fn grad_symbol() -> Symbol {
38 Symbol::new("grad")
39}
40
/// The `fn` builtin: builds a function value from an unevaluated parameter
/// list and a single body expression.
///
/// Stateless marker type; `Copy`, `Debug` and `Default` are derived so it can be
/// freely duplicated, printed in diagnostics and constructed generically.
#[derive(Clone, Copy, Debug, Default)]
pub struct FnBuilder;
43
44impl Object for FnBuilder {
45 fn display(&self, _cx: &mut Cx) -> Result<String> {
46 Ok("#<function fn>".to_owned())
47 }
48
49 fn as_any(&self) -> &dyn Any {
50 self
51 }
52}
53
54impl sim_kernel::ObjectCompat for FnBuilder {
55 fn class(&self, cx: &mut Cx) -> Result<ClassRef> {
56 function_class(cx)
57 }
58 fn as_expr(&self, _cx: &mut Cx) -> Result<Expr> {
59 Ok(Expr::Symbol(fn_symbol()))
60 }
61 fn as_callable(&self) -> Option<&dyn Callable> {
62 Some(self)
63 }
64}
65
66impl Callable for FnBuilder {
67 fn call(&self, _cx: &mut Cx, _args: Args) -> Result<Value> {
68 Err(Error::Eval(
69 "fn must be called with unevaluated parameters and a body".to_owned(),
70 ))
71 }
72
73 fn call_exprs(&self, cx: &mut Cx, args: sim_kernel::RawArgs) -> Result<Value> {
74 let args = args.into_exprs();
75 let [vars_expr, body_expr] = args.as_slice() else {
76 return Err(Error::Eval(
77 "fn expects exactly a parameter list and one body expression".to_owned(),
78 ));
79 };
80 let vars = parse_vars_expr(vars_expr)?;
81 let body_cas = expr_to_cas_expr(cx, body_expr)?
82 .ok_or_else(|| Error::Eval("fn body must be CAS-compatible".to_owned()))?;
83 build_func_value(
84 cx,
85 Func::new(vars, Some(body_cas), None, FuncMetadata::default()),
86 )
87 }
88}
89
/// The `call` builtin: applies its first argument to the remaining arguments.
///
/// Stateless marker type; `Copy`, `Debug` and `Default` are derived so it can be
/// freely duplicated, printed in diagnostics and constructed generically.
#[derive(Clone, Copy, Debug, Default)]
pub struct CallFunction;
92
93impl Object for CallFunction {
94 fn display(&self, _cx: &mut Cx) -> Result<String> {
95 Ok("#<function call>".to_owned())
96 }
97
98 fn as_any(&self) -> &dyn Any {
99 self
100 }
101}
102
103impl sim_kernel::ObjectCompat for CallFunction {
104 fn class(&self, cx: &mut Cx) -> Result<ClassRef> {
105 function_class(cx)
106 }
107 fn as_expr(&self, _cx: &mut Cx) -> Result<Expr> {
108 Ok(Expr::Symbol(call_symbol()))
109 }
110 fn as_callable(&self) -> Option<&dyn Callable> {
111 Some(self)
112 }
113}
114
115impl Callable for CallFunction {
116 fn call(&self, cx: &mut Cx, args: Args) -> Result<Value> {
117 let mut values = args.into_vec();
118 if values.is_empty() {
119 return Err(Error::Eval(
120 "call expects a callable value and at least zero arguments".to_owned(),
121 ));
122 }
123 let callable = values.remove(0);
124 cx.call_value(callable, Args::new(values))
125 }
126}
127
/// The `grad` builtin: the list of partial derivatives of a function value with
/// respect to each of its parameters.
///
/// Stateless marker type; `Copy`, `Debug` and `Default` are derived so it can be
/// freely duplicated, printed in diagnostics and constructed generically.
#[derive(Clone, Copy, Debug, Default)]
pub struct GradFunction;
130
131impl Object for GradFunction {
132 fn display(&self, _cx: &mut Cx) -> Result<String> {
133 Ok("#<function grad>".to_owned())
134 }
135
136 fn as_any(&self) -> &dyn Any {
137 self
138 }
139}
140
141impl sim_kernel::ObjectCompat for GradFunction {
142 fn class(&self, cx: &mut Cx) -> Result<ClassRef> {
143 function_class(cx)
144 }
145 fn as_expr(&self, _cx: &mut Cx) -> Result<Expr> {
146 Ok(Expr::Symbol(grad_symbol()))
147 }
148 fn as_callable(&self) -> Option<&dyn Callable> {
149 Some(self)
150 }
151}
152
153impl Callable for GradFunction {
154 fn call(&self, cx: &mut Cx, args: Args) -> Result<Value> {
155 let values = args.into_vec();
156 let [value] = values.as_slice() else {
157 return Err(Error::Eval(
158 "grad expects exactly one function value".to_owned(),
159 ));
160 };
161 let func = expect_func(value)?;
162 let mut grads = Vec::with_capacity(func.vars.len());
163 for var in &func.vars {
164 let var_value = cx.factory().symbol(var.clone())?;
165 grads.push(cx.call_function(
166 &Symbol::new("diff"),
167 Args::new(vec![value.clone(), var_value]),
168 )?);
169 }
170 cx.factory().list(grads)
171 }
172}
173
/// Class object for function values (named by `func_class_symbol()`).
///
/// `Default` yields the same state as `build_func_class` (id 0). `Debug` is
/// derived for diagnostics.
#[derive(Debug, Default)]
pub(crate) struct FuncValueClass {
    /// Class id, stored atomically so it can be updated through `&self` via `set_id`.
    id: std::sync::atomic::AtomicU32,
}
177
178pub(crate) fn build_func_class() -> Arc<FuncValueClass> {
179 Arc::new(FuncValueClass {
180 id: std::sync::atomic::AtomicU32::new(0),
181 })
182}
183
184impl FuncValueClass {
185 pub(crate) fn set_id(&self, id: ClassId) {
186 self.id.store(id.0, std::sync::atomic::Ordering::Relaxed);
187 }
188}
189
190impl Object for FuncValueClass {
191 fn display(&self, _cx: &mut Cx) -> Result<String> {
192 Ok(format!("#<class {}>", func_class_symbol()))
193 }
194
195 fn as_any(&self) -> &dyn Any {
196 self
197 }
198}
199
200impl sim_kernel::ObjectCompat for FuncValueClass {
201 fn class(&self, cx: &mut Cx) -> Result<ClassRef> {
202 if let Some(value) = cx
203 .registry()
204 .class_by_symbol(&Symbol::qualified("core", "Class"))
205 {
206 return Ok(value.clone());
207 }
208 DefaultFactory.class_stub(
209 sim_kernel::CORE_CLASS_CLASS_ID,
210 Symbol::qualified("core", "Class"),
211 )
212 }
213 fn as_expr(&self, _cx: &mut Cx) -> Result<Expr> {
214 Ok(Expr::Symbol(func_class_symbol()))
215 }
216 fn as_callable(&self) -> Option<&dyn Callable> {
217 Some(self)
218 }
219 fn as_class(&self) -> Option<&dyn Class> {
220 Some(self)
221 }
222 fn as_read_constructor(&self) -> Option<&dyn ReadConstructor> {
223 Some(self)
224 }
225}
226
227impl Callable for FuncValueClass {
228 fn call(&self, cx: &mut Cx, args: Args) -> Result<Value> {
229 let values = args.into_vec();
230 let [vars_value, body_value] = values.as_slice() else {
231 return Err(Error::Eval(format!(
232 "class {} expects exactly two arguments",
233 func_class_symbol()
234 )));
235 };
236 let vars = parse_vars_value(cx, vars_value)?;
237 let body_expr = body_value.object().as_expr(cx)?;
238 let body_cas = expr_to_cas_expr(cx, &body_expr)?
239 .ok_or_else(|| Error::Eval("numbers/Func body must be CAS-compatible".to_owned()))?;
240 build_func_value(
241 cx,
242 Func::new(vars, Some(body_cas), None, FuncMetadata::default()),
243 )
244 }
245}
246
247impl Class for FuncValueClass {
248 fn id(&self) -> ClassId {
249 ClassId(self.id.load(std::sync::atomic::Ordering::Relaxed))
250 }
251
252 fn symbol(&self) -> Symbol {
253 func_class_symbol()
254 }
255
256 fn constructor_shape(&self, cx: &mut Cx) -> Result<ShapeRef> {
257 cx.factory().nil()
258 }
259
260 fn instance_shape(&self, cx: &mut Cx) -> Result<ShapeRef> {
261 Ok(cx
262 .registry()
263 .shape_by_symbol(&value_shape_symbol())
264 .cloned()
265 .unwrap_or(cx.factory().symbol(value_shape_symbol())?))
266 }
267
268 fn read_constructor(&self, cx: &mut Cx) -> Result<Option<ReadConstructorRef>> {
269 Ok(cx.registry().class_by_symbol(&func_class_symbol()).cloned())
270 }
271
272 fn members(&self, cx: &mut Cx) -> Result<TableRef> {
273 cx.factory().table(Vec::new())
274 }
275}
276
277impl ReadConstructor for FuncValueClass {
278 fn symbol(&self) -> Symbol {
279 func_class_symbol()
280 }
281
282 fn args_shape(&self, cx: &mut Cx) -> Result<ShapeRef> {
283 cx.factory().nil()
284 }
285
286 fn construct_read(&self, cx: &mut Cx, args: Vec<Value>) -> Result<Value> {
287 self.call(cx, Args::new(args))
288 }
289}
290
291impl ObjectEncode for Func {
292 fn object_encoding(&self, cx: &mut Cx) -> Result<ObjectEncoding> {
293 let Some(body_cas) = &self.body_cas else {
294 return Err(Error::Eval(
295 "native-only functions do not have a read-construct encoding".to_owned(),
296 ));
297 };
298 Ok(ObjectEncoding::Constructor {
299 class: func_class_symbol(),
300 args: vec![
301 vars_expr(&self.vars),
302 cas_expr_to_surface_expr(cx, body_cas)?,
303 ],
304 })
305 }
306}
307
308pub(crate) fn parse_vars_expr(expr: &Expr) -> Result<Vec<Symbol>> {
309 let Expr::List(items) = expr else {
310 return Err(Error::Eval(
311 "function parameter list must be a list of symbols".to_owned(),
312 ));
313 };
314 items
315 .iter()
316 .map(|item| match item {
317 Expr::Symbol(symbol) => Ok(symbol.clone()),
318 _ => Err(Error::Eval(
319 "function parameter list must contain only symbols".to_owned(),
320 )),
321 })
322 .collect()
323}
324
325fn parse_vars_value(cx: &mut Cx, value: &Value) -> Result<Vec<Symbol>> {
326 parse_vars_expr(&value.object().as_expr(cx)?)
327}
328
329pub(crate) fn vars_expr(vars: &[Symbol]) -> Expr {
330 Expr::List(vars.iter().cloned().map(Expr::Symbol).collect())
331}
332
333pub(crate) fn function_class(cx: &mut Cx) -> Result<ClassRef> {
334 if let Some(value) = cx
335 .registry()
336 .class_by_symbol(&Symbol::qualified("core", "Function"))
337 {
338 return Ok(value.clone());
339 }
340 cx.factory().class_stub(
341 sim_kernel::CORE_FUNCTION_CLASS_ID,
342 Symbol::qualified("core", "Function"),
343 )
344}
345
346pub(crate) fn expect_func(value: &Value) -> Result<&Func> {
347 value
348 .object()
349 .downcast_ref::<Func>()
350 .ok_or_else(|| Error::Eval("expected a numbers/func value".to_owned()))
351}
352
353pub(crate) fn child_env_with_args(parent: &Env, vars: &[Symbol], args: &[Value]) -> Result<Env> {
354 if vars.len() != args.len() {
355 return Err(Error::Eval(format!(
356 "function expected {} arguments but received {}",
357 vars.len(),
358 args.len()
359 )));
360 }
361 let mut env = Env::child(Arc::new(parent.clone()));
362 for (var, value) in vars.iter().cloned().zip(args.iter().cloned()) {
363 env.define(var, value);
364 }
365 Ok(env)
366}