Skip to main content

sim_lib_numbers_cas_diff/implementation/
diff.rs

1//! Symbolic differentiation of `CasExpr` trees: the recursive `diff` rules for
2//! the builtin operators, falling through to the extensible rule registry.
3
4use sim_kernel::{Cx, Error, Result, Symbol, Value};
5use sim_lib_numbers_cas::{CasExpr, simplify_expr};
6use sim_lib_numbers_core::domains;
7
8use super::registry::apply_registered_rule;
9
10/// The `diff` symbol: the symbolic differentiation entry point.
11pub fn diff_symbol() -> Symbol {
12    Symbol::new("diff")
13}
14
15/// Differentiate a [`CasExpr`] with respect to `var`, returning a simplified
16/// derivative tree.
17///
18/// Builtin operators (`+`, `-`, `*`, `/`, `^`, and the standard trig functions)
19/// use their hard-coded rules; unrecognized operators fall through to the
20/// extensible differentiation-rule registry.
21pub fn diff_cas(cx: &mut Cx, expr: &CasExpr, var: &Symbol) -> Result<CasExpr> {
22    let derivative = match expr {
23        CasExpr::Num(_) => zero(cx)?,
24        CasExpr::Var(symbol) if symbol == var => one(cx)?,
25        CasExpr::Var(_) => zero(cx)?,
26        CasExpr::Op(operator, args) if *operator == math("add") => {
27            op(math("add"), diff_all(cx, args, var)?)
28        }
29        CasExpr::Op(operator, args) if *operator == math("sub") => diff_sub(cx, args, var)?,
30        CasExpr::Op(operator, args) if *operator == math("mul") => diff_mul(cx, args, var)?,
31        CasExpr::Op(operator, args) if *operator == math("div") => diff_div(cx, args, var)?,
32        CasExpr::Op(operator, args) if *operator == math("pow") => diff_pow(cx, args, var)?,
33        CasExpr::Op(operator, args) if *operator == Symbol::new("sin") => {
34            chain_rule(cx, Symbol::new("cos"), args, var)?
35        }
36        CasExpr::Op(operator, args) if *operator == Symbol::new("cos") => {
37            let [arg] = one_arg(args, operator)?;
38            op(
39                math("mul"),
40                vec![
41                    neg_one(cx)?,
42                    diff_cas(cx, arg, var)?,
43                    op(Symbol::new("sin"), vec![arg.clone()]),
44                ],
45            )
46        }
47        CasExpr::Op(operator, args) if *operator == Symbol::new("tan") => {
48            let [arg] = one_arg(args, operator)?;
49            op(
50                math("div"),
51                vec![
52                    diff_cas(cx, arg, var)?,
53                    op(
54                        math("pow"),
55                        vec![
56                            op(Symbol::new("cos"), vec![arg.clone()]),
57                            CasExpr::Num(number_constant(cx, "2")?),
58                        ],
59                    ),
60                ],
61            )
62        }
63        CasExpr::Op(operator, args) if *operator == Symbol::new("ln") => {
64            let [arg] = one_arg(args, operator)?;
65            op(math("div"), vec![diff_cas(cx, arg, var)?, arg.clone()])
66        }
67        CasExpr::Op(operator, args) if *operator == Symbol::new("exp") => {
68            chain_rule(cx, Symbol::new("exp"), args, var)?
69        }
70        CasExpr::Op(operator, args) => {
71            if let Some(custom) = apply_registered_rule(operator, args, var) {
72                custom
73            } else {
74                op(
75                    diff_symbol(),
76                    vec![
77                        CasExpr::Op(operator.clone(), args.clone()),
78                        CasExpr::Var(var.clone()),
79                    ],
80                )
81            }
82        }
83    };
84    simplify_expr(cx, derivative)
85}
86
87fn diff_all(cx: &mut Cx, args: &[CasExpr], var: &Symbol) -> Result<Vec<CasExpr>> {
88    args.iter().map(|arg| diff_cas(cx, arg, var)).collect()
89}
90
91fn diff_sub(cx: &mut Cx, args: &[CasExpr], var: &Symbol) -> Result<CasExpr> {
92    match args {
93        [] => Err(Error::Eval(
94            "cannot differentiate an empty subtraction".to_owned(),
95        )),
96        [arg] => Ok(op(math("mul"), vec![neg_one(cx)?, diff_cas(cx, arg, var)?])),
97        _ => Ok(op(math("sub"), diff_all(cx, args, var)?)),
98    }
99}
100
101fn diff_mul(cx: &mut Cx, args: &[CasExpr], var: &Symbol) -> Result<CasExpr> {
102    if args.is_empty() {
103        return Err(Error::Eval(
104            "cannot differentiate an empty multiplication".to_owned(),
105        ));
106    }
107    let mut terms = Vec::with_capacity(args.len());
108    for (index, _) in args.iter().enumerate() {
109        let mut factors = Vec::with_capacity(args.len());
110        for (offset, arg) in args.iter().enumerate() {
111            if index == offset {
112                factors.push(diff_cas(cx, arg, var)?);
113            } else {
114                factors.push(arg.clone());
115            }
116        }
117        terms.push(op(math("mul"), factors));
118    }
119    Ok(op(math("add"), terms))
120}
121
122fn diff_div(cx: &mut Cx, args: &[CasExpr], var: &Symbol) -> Result<CasExpr> {
123    match args {
124        [] => Err(Error::Eval(
125            "cannot differentiate an empty division".to_owned(),
126        )),
127        [arg] => Ok(op(
128            math("div"),
129            vec![
130                op(math("mul"), vec![neg_one(cx)?, diff_cas(cx, arg, var)?]),
131                op(
132                    math("pow"),
133                    vec![arg.clone(), CasExpr::Num(number_constant(cx, "2")?)],
134                ),
135            ],
136        )),
137        [left, right] => {
138            let left_diff = diff_cas(cx, left, var)?;
139            let right_diff = diff_cas(cx, right, var)?;
140            Ok(op(
141                math("div"),
142                vec![
143                    op(
144                        math("sub"),
145                        vec![
146                            op(math("mul"), vec![left_diff, right.clone()]),
147                            op(math("mul"), vec![left.clone(), right_diff]),
148                        ],
149                    ),
150                    op(
151                        math("pow"),
152                        vec![right.clone(), CasExpr::Num(number_constant(cx, "2")?)],
153                    ),
154                ],
155            ))
156        }
157        [head, tail @ ..] => diff_div(cx, &[head.clone(), op(math("mul"), tail.to_vec())], var),
158    }
159}
160
161fn diff_pow(cx: &mut Cx, args: &[CasExpr], var: &Symbol) -> Result<CasExpr> {
162    let [base, exponent] = two_args(args, &math("pow"))?;
163    let base_diff = diff_cas(cx, base, var)?;
164    if let CasExpr::Num(value) = exponent
165        && let Some(decremented) = decrement_value(cx, value)?
166    {
167        return Ok(op(
168            math("mul"),
169            vec![
170                CasExpr::Num(value.clone()),
171                op(math("pow"), vec![base.clone(), CasExpr::Num(decremented)]),
172                base_diff,
173            ],
174        ));
175    }
176    let exponent_diff = diff_cas(cx, exponent, var)?;
177    Ok(op(
178        math("mul"),
179        vec![
180            op(math("pow"), vec![base.clone(), exponent.clone()]),
181            op(
182                math("add"),
183                vec![
184                    op(
185                        math("mul"),
186                        vec![exponent_diff, op(Symbol::new("ln"), vec![base.clone()])],
187                    ),
188                    op(
189                        math("div"),
190                        vec![
191                            op(math("mul"), vec![exponent.clone(), base_diff]),
192                            base.clone(),
193                        ],
194                    ),
195                ],
196            ),
197        ],
198    ))
199}
200
201fn chain_rule(cx: &mut Cx, outer: Symbol, args: &[CasExpr], var: &Symbol) -> Result<CasExpr> {
202    let [arg] = one_arg(args, &outer)?;
203    Ok(op(
204        math("mul"),
205        vec![diff_cas(cx, arg, var)?, op(outer, vec![arg.clone()])],
206    ))
207}
208
209fn one_arg<'a>(args: &'a [CasExpr], operator: &Symbol) -> Result<[&'a CasExpr; 1]> {
210    let [arg] = args else {
211        return Err(Error::Eval(format!(
212            "{operator} expects exactly one CAS operand"
213        )));
214    };
215    Ok([arg])
216}
217
218fn two_args<'a>(args: &'a [CasExpr], operator: &Symbol) -> Result<[&'a CasExpr; 2]> {
219    let [left, right] = args else {
220        return Err(Error::Eval(format!(
221            "{operator} expects exactly two CAS operands"
222        )));
223    };
224    Ok([left, right])
225}
226
227fn zero(cx: &mut Cx) -> Result<CasExpr> {
228    Ok(CasExpr::Num(number_constant(cx, "0")?))
229}
230
231fn one(cx: &mut Cx) -> Result<CasExpr> {
232    Ok(CasExpr::Num(number_constant(cx, "1")?))
233}
234
235fn neg_one(cx: &mut Cx) -> Result<CasExpr> {
236    Ok(CasExpr::Num(number_constant(cx, "-1")?))
237}
238
239fn number_constant(cx: &mut Cx, canonical: &str) -> Result<Value> {
240    if cx
241        .registry()
242        .number_domain_by_symbol(&domains::i64())
243        .is_some()
244    {
245        return cx
246            .factory()
247            .number_literal(domains::i64(), canonical.to_owned());
248    }
249    if cx
250        .registry()
251        .number_domain_by_symbol(&domains::f64())
252        .is_some()
253    {
254        let canonical = if canonical == "-1" {
255            "-1.0".to_owned()
256        } else {
257            format!("{canonical}.0")
258        };
259        return cx.factory().number_literal(domains::f64(), canonical);
260    }
261    Err(Error::Eval(
262        "CAS differentiation requires a loaded integer or f64 number domain".to_owned(),
263    ))
264}
265
266fn decrement_value(cx: &mut Cx, value: &Value) -> Result<Option<Value>> {
267    if literal_number(cx, value)?.is_none() {
268        return Ok(None);
269    }
270    let one = number_constant(cx, "1")?;
271    let decremented = cx.apply_value_number_binary_op(&math("sub"), value.clone(), one)?;
272    Ok(cx
273        .number_value_ref(decremented.clone())?
274        .and_then(|number| number.literal)
275        .map(|_| decremented))
276}
277
278use sim_lib_numbers_cas::literal_number;
279
280pub(crate) fn math(name: &str) -> Symbol {
281    Symbol::qualified("math", name)
282}
283
284pub(crate) fn op(operator: Symbol, args: Vec<CasExpr>) -> CasExpr {
285    CasExpr::Op(operator, args)
286}