sim_lib_numbers_cas_diff/implementation/
integrate.rs1use sim_kernel::{Cx, Error, Result, Symbol, Value};
5use sim_lib_numbers_cas::{CasExpr, simplify_expr};
6use sim_lib_numbers_core::domains;
7
8use super::diff::math;
9
10pub fn integrate_sym_symbol() -> Symbol {
12 Symbol::new("integrate-sym")
13}
14
15pub fn integrate_cas(cx: &mut Cx, expr: &CasExpr, var: &Symbol) -> Result<CasExpr> {
21 let integral = match expr {
22 CasExpr::Num(value) => op(
23 math("mul"),
24 vec![CasExpr::Num(value.clone()), CasExpr::Var(var.clone())],
25 ),
26 CasExpr::Var(symbol) if symbol == var => integrate_power_of_var(cx, 1, var)?,
27 CasExpr::Var(symbol) => op(
28 math("mul"),
29 vec![CasExpr::Var(symbol.clone()), CasExpr::Var(var.clone())],
30 ),
31 CasExpr::Op(operator, args) if *operator == math("add") => {
32 op(math("add"), integrate_all(cx, args, var)?)
33 }
34 CasExpr::Op(operator, args) if *operator == math("sub") => {
35 op(math("sub"), integrate_all(cx, args, var)?)
36 }
37 CasExpr::Op(operator, args) if *operator == math("mul") => integrate_mul(cx, args, var)?,
38 CasExpr::Op(operator, args) if *operator == math("pow") => integrate_pow(cx, args, var)?,
39 _ => {
40 return Err(Error::Eval(format!(
41 "{} only supports constants, sums, scalar products, and powers of the integration variable",
42 integrate_sym_symbol()
43 )));
44 }
45 };
46 simplify_expr(cx, integral)
47}
48
49fn integrate_all(cx: &mut Cx, args: &[CasExpr], var: &Symbol) -> Result<Vec<CasExpr>> {
50 args.iter().map(|arg| integrate_cas(cx, arg, var)).collect()
51}
52
53fn integrate_mul(cx: &mut Cx, args: &[CasExpr], var: &Symbol) -> Result<CasExpr> {
54 let mut constant = Vec::new();
55 let mut variable = Vec::new();
56 for arg in args {
57 if depends_on(arg, var) {
58 variable.push(arg.clone());
59 } else {
60 constant.push(arg.clone());
61 }
62 }
63 if variable.len() != 1 {
64 return Err(Error::Eval(format!(
65 "{} only handles products with exactly one variable-dependent factor",
66 integrate_sym_symbol()
67 )));
68 }
69 let mut out = constant;
70 out.push(integrate_cas(cx, &variable[0], var)?);
71 Ok(op(math("mul"), out))
72}
73
74fn integrate_pow(cx: &mut Cx, args: &[CasExpr], var: &Symbol) -> Result<CasExpr> {
75 let [base, exponent] = two_args(args)?;
76 if !matches!(base, CasExpr::Var(symbol) if symbol == var) {
77 return Err(Error::Eval(format!(
78 "{} only supports powers of the integration variable",
79 integrate_sym_symbol()
80 )));
81 }
82 let exponent = literal_i64(cx, exponent)?.ok_or_else(|| {
83 Error::Eval(format!(
84 "{} only supports integer exponents for symbolic powers",
85 integrate_sym_symbol()
86 ))
87 })?;
88 integrate_power_of_var(cx, exponent, var)
89}
90
91fn integrate_power_of_var(cx: &mut Cx, exponent: i64, var: &Symbol) -> Result<CasExpr> {
92 if exponent == -1 {
93 return Ok(op(Symbol::new("ln"), vec![CasExpr::Var(var.clone())]));
94 }
95 let next = exponent + 1;
96 Ok(op(
97 math("mul"),
98 vec![
99 CasExpr::Num(rational_constant(cx, 1, next)?),
100 op(
101 math("pow"),
102 vec![
103 CasExpr::Var(var.clone()),
104 CasExpr::Num(integer_constant(cx, next)?),
105 ],
106 ),
107 ],
108 ))
109}
110
111fn depends_on(expr: &CasExpr, var: &Symbol) -> bool {
112 match expr {
113 CasExpr::Num(_) => false,
114 CasExpr::Var(symbol) => symbol == var,
115 CasExpr::Op(_, args) => args.iter().any(|arg| depends_on(arg, var)),
116 }
117}
118
119fn literal_i64(cx: &mut Cx, expr: &CasExpr) -> Result<Option<i64>> {
120 let CasExpr::Num(value) = expr else {
121 return Ok(None);
122 };
123 let display = value.object().display(cx)?;
124 Ok(display.parse::<i64>().ok())
125}
126
127fn integer_constant(cx: &mut Cx, value: i64) -> Result<Value> {
128 if cx
129 .registry()
130 .number_domain_by_symbol(&domains::i64())
131 .is_some()
132 {
133 return cx
134 .factory()
135 .number_literal(domains::i64(), value.to_string());
136 }
137 cx.factory()
138 .number_literal(domains::f64(), format!("{value}.0"))
139}
140
141fn rational_constant(cx: &mut Cx, num: i64, den: i64) -> Result<Value> {
142 if cx
143 .registry()
144 .number_domain_by_symbol(&domains::rational())
145 .is_some()
146 {
147 return cx
148 .factory()
149 .number_literal(domains::rational(), format!("{num}/{den}"));
150 }
151 let value = num as f64 / den as f64;
152 cx.factory()
153 .number_literal(domains::f64(), value.to_string())
154}
155
156fn two_args(args: &[CasExpr]) -> Result<[&CasExpr; 2]> {
157 let [left, right] = args else {
158 return Err(Error::Eval(format!(
159 "{} expects exactly two operands",
160 math("pow")
161 )));
162 };
163 Ok([left, right])
164}
165
166use super::diff::op;