sim_lib_numbers_cas_diff/implementation/
function.rs1use std::{any::Any, sync::Arc};
5
6use sim_kernel::{
7 AbiVersion, Args, Callable, ClassRef, Cx, DefaultFactory, Dependency, Error, Export, Expr,
8 Factory, Lib, LibManifest, LibTarget, Linker, Object, Result, Symbol, Value, Version,
9};
10use sim_lib_numbers_cas::{cas_expr_to_value, value_to_cas_expr};
11use sim_lib_numbers_core::domains;
12
13use super::diff::{diff_cas, diff_symbol};
14use super::integrate::integrate_sym_symbol;
15use super::integrate_function::IntegrateSymFunction;
16
/// Library handle that registers the symbolic differentiation and symbolic
/// integration functions with the kernel.
///
/// The type is a stateless unit struct, so every value is interchangeable and
/// copying it is free.
#[derive(Debug, Clone, Copy, Default)]
pub struct CasDiffLib;

impl CasDiffLib {
    /// Creates the library handle.
    ///
    /// Equivalent to [`CasDiffLib::default`]; usable in `const` contexts.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
36
37impl Lib for CasDiffLib {
38 fn manifest(&self) -> LibManifest {
39 LibManifest {
40 id: domains::cas_diff(),
41 version: Version(env!("CARGO_PKG_VERSION").to_owned()),
42 abi: AbiVersion { major: 0, minor: 1 },
43 target: LibTarget::HostRegistered,
44 requires: Vec::<Dependency>::new(),
45 capabilities: Vec::new(),
46 exports: vec![
47 Export::Function {
48 symbol: diff_symbol(),
49 function_id: None,
50 },
51 Export::Function {
52 symbol: integrate_sym_symbol(),
53 function_id: None,
54 },
55 ],
56 }
57 }
58
59 fn load(&self, _cx: &mut sim_kernel::LoadCx, linker: &mut Linker<'_>) -> Result<()> {
60 linker.function_value(
61 diff_symbol(),
62 DefaultFactory
63 .opaque(Arc::new(DiffFunction))
64 .expect("diff function should be boxable"),
65 )?;
66 linker.function_value(
67 integrate_sym_symbol(),
68 DefaultFactory
69 .opaque(Arc::new(IntegrateSymFunction))
70 .expect("integrate-sym function should be boxable"),
71 )?;
72 Ok(())
73 }
74}
75
76#[derive(Clone)]
77struct DiffFunction;
78
79impl Object for DiffFunction {
80 fn display(&self, _cx: &mut Cx) -> Result<String> {
81 Ok(format!("#<function {}>", diff_symbol()))
82 }
83
84 fn as_any(&self) -> &dyn Any {
85 self
86 }
87}
88
89impl sim_kernel::ObjectCompat for DiffFunction {
90 fn class(&self, cx: &mut Cx) -> Result<ClassRef> {
91 if let Some(value) = cx
92 .registry()
93 .class_by_symbol(&Symbol::qualified("core", "Function"))
94 {
95 return Ok(value.clone());
96 }
97 DefaultFactory.class_stub(
98 sim_kernel::CORE_FUNCTION_CLASS_ID,
99 Symbol::qualified("core", "Function"),
100 )
101 }
102 fn as_expr(&self, _cx: &mut Cx) -> Result<Expr> {
103 Ok(Expr::Symbol(diff_symbol()))
104 }
105 fn as_callable(&self) -> Option<&dyn Callable> {
106 Some(self)
107 }
108}
109
110impl Callable for DiffFunction {
111 fn call(&self, cx: &mut Cx, args: Args) -> Result<Value> {
112 let values = args.into_vec();
113 let [expr_value, var] = values.as_slice() else {
114 return Err(Error::Eval(format!(
115 "{} expects exactly two arguments",
116 diff_symbol()
117 )));
118 };
119 let var = extract_symbolish(cx, var)?.ok_or_else(|| {
120 Error::Eval(format!(
121 "{} expects a quoted symbol or symbol as its second argument",
122 diff_symbol()
123 ))
124 })?;
125 if let Some(number) = cx.number_value_ref(expr_value.clone())?
126 && number.domain == domains::func()
127 {
128 return diff_func_value(cx, expr_value.clone(), &var);
129 }
130 let expr = value_to_cas_expr(cx, expr_value.clone())?;
131 let derivative = diff_cas(cx, &expr, &var)?;
132 cas_expr_to_value(cx, derivative)
133 }
134}
135
136use sim_lib_numbers_cas::extract_symbolish;
137
138fn diff_func_value(cx: &mut Cx, value: Value, var: &Symbol) -> Result<Value> {
139 let expr = value.object().as_expr(cx)?;
140 let Expr::Call { operator, args } = expr else {
141 return Err(Error::Eval(
142 "NotDifferentiable: function value does not expose a symbolic body".to_owned(),
143 ));
144 };
145 let Expr::Symbol(operator) = operator.as_ref() else {
146 return Err(Error::Eval(
147 "NotDifferentiable: function value does not expose a symbolic body".to_owned(),
148 ));
149 };
150 if *operator != Symbol::new("fn") {
151 return Err(Error::Eval(
152 "NotDifferentiable: function value does not expose a symbolic body".to_owned(),
153 ));
154 }
155 let [vars_expr, body_expr] = args.as_slice() else {
156 return Err(Error::Eval(
157 "NotDifferentiable: function value had an invalid fn surface".to_owned(),
158 ));
159 };
160 let body = value_to_cas_expr(cx, cx.factory().expr(body_expr.clone())?)?;
161 let derivative = diff_cas(cx, &body, var)?;
162 let derivative_expr = sim_lib_numbers_cas::cas_expr_to_surface_expr(cx, &derivative)?;
163 cx.eval_expr(Expr::Call {
164 operator: Box::new(Expr::Symbol(Symbol::new("fn"))),
165 args: vec![vars_expr.clone(), derivative_expr],
166 })
167}