1use sim_kernel::{Cx, Error, Result, Symbol, Value};
5use sim_lib_numbers_cas::{CasExpr, simplify_expr};
6use sim_lib_numbers_core::domains;
7
8use super::registry::apply_registered_rule;
9
10pub fn diff_symbol() -> Symbol {
12 Symbol::new("diff")
13}
14
15pub fn diff_cas(cx: &mut Cx, expr: &CasExpr, var: &Symbol) -> Result<CasExpr> {
22 let derivative = match expr {
23 CasExpr::Num(_) => zero(cx)?,
24 CasExpr::Var(symbol) if symbol == var => one(cx)?,
25 CasExpr::Var(_) => zero(cx)?,
26 CasExpr::Op(operator, args) if *operator == math("add") => {
27 op(math("add"), diff_all(cx, args, var)?)
28 }
29 CasExpr::Op(operator, args) if *operator == math("sub") => diff_sub(cx, args, var)?,
30 CasExpr::Op(operator, args) if *operator == math("mul") => diff_mul(cx, args, var)?,
31 CasExpr::Op(operator, args) if *operator == math("div") => diff_div(cx, args, var)?,
32 CasExpr::Op(operator, args) if *operator == math("pow") => diff_pow(cx, args, var)?,
33 CasExpr::Op(operator, args) if *operator == Symbol::new("sin") => {
34 chain_rule(cx, Symbol::new("cos"), args, var)?
35 }
36 CasExpr::Op(operator, args) if *operator == Symbol::new("cos") => {
37 let [arg] = one_arg(args, operator)?;
38 op(
39 math("mul"),
40 vec![
41 neg_one(cx)?,
42 diff_cas(cx, arg, var)?,
43 op(Symbol::new("sin"), vec![arg.clone()]),
44 ],
45 )
46 }
47 CasExpr::Op(operator, args) if *operator == Symbol::new("tan") => {
48 let [arg] = one_arg(args, operator)?;
49 op(
50 math("div"),
51 vec![
52 diff_cas(cx, arg, var)?,
53 op(
54 math("pow"),
55 vec![
56 op(Symbol::new("cos"), vec![arg.clone()]),
57 CasExpr::Num(number_constant(cx, "2")?),
58 ],
59 ),
60 ],
61 )
62 }
63 CasExpr::Op(operator, args) if *operator == Symbol::new("ln") => {
64 let [arg] = one_arg(args, operator)?;
65 op(math("div"), vec![diff_cas(cx, arg, var)?, arg.clone()])
66 }
67 CasExpr::Op(operator, args) if *operator == Symbol::new("exp") => {
68 chain_rule(cx, Symbol::new("exp"), args, var)?
69 }
70 CasExpr::Op(operator, args) => {
71 if let Some(custom) = apply_registered_rule(operator, args, var) {
72 custom
73 } else {
74 op(
75 diff_symbol(),
76 vec![
77 CasExpr::Op(operator.clone(), args.clone()),
78 CasExpr::Var(var.clone()),
79 ],
80 )
81 }
82 }
83 };
84 simplify_expr(cx, derivative)
85}
86
87fn diff_all(cx: &mut Cx, args: &[CasExpr], var: &Symbol) -> Result<Vec<CasExpr>> {
88 args.iter().map(|arg| diff_cas(cx, arg, var)).collect()
89}
90
91fn diff_sub(cx: &mut Cx, args: &[CasExpr], var: &Symbol) -> Result<CasExpr> {
92 match args {
93 [] => Err(Error::Eval(
94 "cannot differentiate an empty subtraction".to_owned(),
95 )),
96 [arg] => Ok(op(math("mul"), vec![neg_one(cx)?, diff_cas(cx, arg, var)?])),
97 _ => Ok(op(math("sub"), diff_all(cx, args, var)?)),
98 }
99}
100
101fn diff_mul(cx: &mut Cx, args: &[CasExpr], var: &Symbol) -> Result<CasExpr> {
102 if args.is_empty() {
103 return Err(Error::Eval(
104 "cannot differentiate an empty multiplication".to_owned(),
105 ));
106 }
107 let mut terms = Vec::with_capacity(args.len());
108 for (index, _) in args.iter().enumerate() {
109 let mut factors = Vec::with_capacity(args.len());
110 for (offset, arg) in args.iter().enumerate() {
111 if index == offset {
112 factors.push(diff_cas(cx, arg, var)?);
113 } else {
114 factors.push(arg.clone());
115 }
116 }
117 terms.push(op(math("mul"), factors));
118 }
119 Ok(op(math("add"), terms))
120}
121
122fn diff_div(cx: &mut Cx, args: &[CasExpr], var: &Symbol) -> Result<CasExpr> {
123 match args {
124 [] => Err(Error::Eval(
125 "cannot differentiate an empty division".to_owned(),
126 )),
127 [arg] => Ok(op(
128 math("div"),
129 vec![
130 op(math("mul"), vec![neg_one(cx)?, diff_cas(cx, arg, var)?]),
131 op(
132 math("pow"),
133 vec![arg.clone(), CasExpr::Num(number_constant(cx, "2")?)],
134 ),
135 ],
136 )),
137 [left, right] => {
138 let left_diff = diff_cas(cx, left, var)?;
139 let right_diff = diff_cas(cx, right, var)?;
140 Ok(op(
141 math("div"),
142 vec![
143 op(
144 math("sub"),
145 vec![
146 op(math("mul"), vec![left_diff, right.clone()]),
147 op(math("mul"), vec![left.clone(), right_diff]),
148 ],
149 ),
150 op(
151 math("pow"),
152 vec![right.clone(), CasExpr::Num(number_constant(cx, "2")?)],
153 ),
154 ],
155 ))
156 }
157 [head, tail @ ..] => diff_div(cx, &[head.clone(), op(math("mul"), tail.to_vec())], var),
158 }
159}
160
161fn diff_pow(cx: &mut Cx, args: &[CasExpr], var: &Symbol) -> Result<CasExpr> {
162 let [base, exponent] = two_args(args, &math("pow"))?;
163 let base_diff = diff_cas(cx, base, var)?;
164 if let CasExpr::Num(value) = exponent
165 && let Some(decremented) = decrement_value(cx, value)?
166 {
167 return Ok(op(
168 math("mul"),
169 vec![
170 CasExpr::Num(value.clone()),
171 op(math("pow"), vec![base.clone(), CasExpr::Num(decremented)]),
172 base_diff,
173 ],
174 ));
175 }
176 let exponent_diff = diff_cas(cx, exponent, var)?;
177 Ok(op(
178 math("mul"),
179 vec![
180 op(math("pow"), vec![base.clone(), exponent.clone()]),
181 op(
182 math("add"),
183 vec![
184 op(
185 math("mul"),
186 vec![exponent_diff, op(Symbol::new("ln"), vec![base.clone()])],
187 ),
188 op(
189 math("div"),
190 vec![
191 op(math("mul"), vec![exponent.clone(), base_diff]),
192 base.clone(),
193 ],
194 ),
195 ],
196 ),
197 ],
198 ))
199}
200
201fn chain_rule(cx: &mut Cx, outer: Symbol, args: &[CasExpr], var: &Symbol) -> Result<CasExpr> {
202 let [arg] = one_arg(args, &outer)?;
203 Ok(op(
204 math("mul"),
205 vec![diff_cas(cx, arg, var)?, op(outer, vec![arg.clone()])],
206 ))
207}
208
209fn one_arg<'a>(args: &'a [CasExpr], operator: &Symbol) -> Result<[&'a CasExpr; 1]> {
210 let [arg] = args else {
211 return Err(Error::Eval(format!(
212 "{operator} expects exactly one CAS operand"
213 )));
214 };
215 Ok([arg])
216}
217
218fn two_args<'a>(args: &'a [CasExpr], operator: &Symbol) -> Result<[&'a CasExpr; 2]> {
219 let [left, right] = args else {
220 return Err(Error::Eval(format!(
221 "{operator} expects exactly two CAS operands"
222 )));
223 };
224 Ok([left, right])
225}
226
227fn zero(cx: &mut Cx) -> Result<CasExpr> {
228 Ok(CasExpr::Num(number_constant(cx, "0")?))
229}
230
231fn one(cx: &mut Cx) -> Result<CasExpr> {
232 Ok(CasExpr::Num(number_constant(cx, "1")?))
233}
234
235fn neg_one(cx: &mut Cx) -> Result<CasExpr> {
236 Ok(CasExpr::Num(number_constant(cx, "-1")?))
237}
238
239fn number_constant(cx: &mut Cx, canonical: &str) -> Result<Value> {
240 if cx
241 .registry()
242 .number_domain_by_symbol(&domains::i64())
243 .is_some()
244 {
245 return cx
246 .factory()
247 .number_literal(domains::i64(), canonical.to_owned());
248 }
249 if cx
250 .registry()
251 .number_domain_by_symbol(&domains::f64())
252 .is_some()
253 {
254 let canonical = if canonical == "-1" {
255 "-1.0".to_owned()
256 } else {
257 format!("{canonical}.0")
258 };
259 return cx.factory().number_literal(domains::f64(), canonical);
260 }
261 Err(Error::Eval(
262 "CAS differentiation requires a loaded integer or f64 number domain".to_owned(),
263 ))
264}
265
266fn decrement_value(cx: &mut Cx, value: &Value) -> Result<Option<Value>> {
267 if literal_number(cx, value)?.is_none() {
268 return Ok(None);
269 }
270 let one = number_constant(cx, "1")?;
271 let decremented = cx.apply_value_number_binary_op(&math("sub"), value.clone(), one)?;
272 Ok(cx
273 .number_value_ref(decremented.clone())?
274 .and_then(|number| number.literal)
275 .map(|_| decremented))
276}
277
278use sim_lib_numbers_cas::literal_number;
279
280pub(crate) fn math(name: &str) -> Symbol {
281 Symbol::qualified("math", name)
282}
283
284pub(crate) fn op(operator: Symbol, args: Vec<CasExpr>) -> CasExpr {
285 CasExpr::Op(operator, args)
286}