Skip to main content

sim_lib_numbers_cas_diff/implementation/
function.rs

1//! The `CasDiffLib` and the `diff` callable: the library wiring that registers
2//! symbolic differentiation and the `integrate-sym` function with the runtime.
3
4use std::{any::Any, sync::Arc};
5
6use sim_kernel::{
7    AbiVersion, Args, Callable, ClassRef, Cx, DefaultFactory, Dependency, Error, Export, Expr,
8    Factory, Lib, LibManifest, LibTarget, Linker, Object, Result, Symbol, Value, Version,
9};
10use sim_lib_numbers_cas::{cas_expr_to_value, value_to_cas_expr};
11use sim_lib_numbers_core::domains;
12
13use super::diff::{diff_cas, diff_symbol};
14use super::integrate::integrate_sym_symbol;
15use super::integrate_function::IntegrateSymFunction;
16
/// The CAS differentiation and integration library.
///
/// Loading this [`Lib`] registers the `diff` and `integrate-sym` functions over
/// `numbers/cas` expressions. It requires the `numbers/cas` domain to be loaded
/// first; note that this ordering is not declared in the manifest's `requires`
/// list, so whoever loads libraries must arrange it.
///
/// The type is a stateless unit value, so it is freely `Copy`able.
#[derive(Debug, Clone, Copy)]
pub struct CasDiffLib;
23
24impl CasDiffLib {
25    /// Construct the CAS differentiation library.
26    pub fn new() -> Self {
27        Self
28    }
29}
30
31impl Default for CasDiffLib {
32    fn default() -> Self {
33        Self::new()
34    }
35}
36
37impl Lib for CasDiffLib {
38    fn manifest(&self) -> LibManifest {
39        LibManifest {
40            id: domains::cas_diff(),
41            version: Version(env!("CARGO_PKG_VERSION").to_owned()),
42            abi: AbiVersion { major: 0, minor: 1 },
43            target: LibTarget::HostRegistered,
44            requires: Vec::<Dependency>::new(),
45            capabilities: Vec::new(),
46            exports: vec![
47                Export::Function {
48                    symbol: diff_symbol(),
49                    function_id: None,
50                },
51                Export::Function {
52                    symbol: integrate_sym_symbol(),
53                    function_id: None,
54                },
55            ],
56        }
57    }
58
59    fn load(&self, _cx: &mut sim_kernel::LoadCx, linker: &mut Linker<'_>) -> Result<()> {
60        linker.function_value(
61            diff_symbol(),
62            DefaultFactory
63                .opaque(Arc::new(DiffFunction))
64                .expect("diff function should be boxable"),
65        )?;
66        linker.function_value(
67            integrate_sym_symbol(),
68            DefaultFactory
69                .opaque(Arc::new(IntegrateSymFunction))
70                .expect("integrate-sym function should be boxable"),
71        )?;
72        Ok(())
73    }
74}
75
76#[derive(Clone)]
77struct DiffFunction;
78
79impl Object for DiffFunction {
80    fn display(&self, _cx: &mut Cx) -> Result<String> {
81        Ok(format!("#<function {}>", diff_symbol()))
82    }
83
84    fn as_any(&self) -> &dyn Any {
85        self
86    }
87}
88
89impl sim_kernel::ObjectCompat for DiffFunction {
90    fn class(&self, cx: &mut Cx) -> Result<ClassRef> {
91        if let Some(value) = cx
92            .registry()
93            .class_by_symbol(&Symbol::qualified("core", "Function"))
94        {
95            return Ok(value.clone());
96        }
97        DefaultFactory.class_stub(
98            sim_kernel::CORE_FUNCTION_CLASS_ID,
99            Symbol::qualified("core", "Function"),
100        )
101    }
102    fn as_expr(&self, _cx: &mut Cx) -> Result<Expr> {
103        Ok(Expr::Symbol(diff_symbol()))
104    }
105    fn as_callable(&self) -> Option<&dyn Callable> {
106        Some(self)
107    }
108}
109
110impl Callable for DiffFunction {
111    fn call(&self, cx: &mut Cx, args: Args) -> Result<Value> {
112        let values = args.into_vec();
113        let [expr_value, var] = values.as_slice() else {
114            return Err(Error::Eval(format!(
115                "{} expects exactly two arguments",
116                diff_symbol()
117            )));
118        };
119        let var = extract_symbolish(cx, var)?.ok_or_else(|| {
120            Error::Eval(format!(
121                "{} expects a quoted symbol or symbol as its second argument",
122                diff_symbol()
123            ))
124        })?;
125        if let Some(number) = cx.number_value_ref(expr_value.clone())?
126            && number.domain == domains::func()
127        {
128            return diff_func_value(cx, expr_value.clone(), &var);
129        }
130        let expr = value_to_cas_expr(cx, expr_value.clone())?;
131        let derivative = diff_cas(cx, &expr, &var)?;
132        cas_expr_to_value(cx, derivative)
133    }
134}
135
136use sim_lib_numbers_cas::extract_symbolish;
137
138fn diff_func_value(cx: &mut Cx, value: Value, var: &Symbol) -> Result<Value> {
139    let expr = value.object().as_expr(cx)?;
140    let Expr::Call { operator, args } = expr else {
141        return Err(Error::Eval(
142            "NotDifferentiable: function value does not expose a symbolic body".to_owned(),
143        ));
144    };
145    let Expr::Symbol(operator) = operator.as_ref() else {
146        return Err(Error::Eval(
147            "NotDifferentiable: function value does not expose a symbolic body".to_owned(),
148        ));
149    };
150    if *operator != Symbol::new("fn") {
151        return Err(Error::Eval(
152            "NotDifferentiable: function value does not expose a symbolic body".to_owned(),
153        ));
154    }
155    let [vars_expr, body_expr] = args.as_slice() else {
156        return Err(Error::Eval(
157            "NotDifferentiable: function value had an invalid fn surface".to_owned(),
158        ));
159    };
160    let body = value_to_cas_expr(cx, cx.factory().expr(body_expr.clone())?)?;
161    let derivative = diff_cas(cx, &body, var)?;
162    let derivative_expr = sim_lib_numbers_cas::cas_expr_to_surface_expr(cx, &derivative)?;
163    cx.eval_expr(Expr::Call {
164        operator: Box::new(Expr::Symbol(Symbol::new("fn"))),
165        args: vec![vars_expr.clone(), derivative_expr],
166    })
167}