Skip to main content

sim_lib_numbers_cas_diff/implementation/
integrate.rs

1//! Symbolic integration of `CasExpr` trees: the recursive `integrate` rules
2//! covering the builtin operators and power-of-variable cases.
3
4use sim_kernel::{Cx, Error, Result, Symbol, Value};
5use sim_lib_numbers_cas::{CasExpr, simplify_expr};
6use sim_lib_numbers_core::domains;
7
8use super::diff::math;
9
10/// The `integrate-sym` symbol: the symbolic integration entry point.
11pub fn integrate_sym_symbol() -> Symbol {
12    Symbol::new("integrate-sym")
13}
14
15/// Integrate a [`CasExpr`] with respect to `var`, returning a simplified
16/// antiderivative tree (without constant of integration).
17///
18/// Covers the builtin arithmetic operators and power-of-variable cases; inputs
19/// outside the supported set produce an error.
20pub fn integrate_cas(cx: &mut Cx, expr: &CasExpr, var: &Symbol) -> Result<CasExpr> {
21    let integral = match expr {
22        CasExpr::Num(value) => op(
23            math("mul"),
24            vec![CasExpr::Num(value.clone()), CasExpr::Var(var.clone())],
25        ),
26        CasExpr::Var(symbol) if symbol == var => integrate_power_of_var(cx, 1, var)?,
27        CasExpr::Var(symbol) => op(
28            math("mul"),
29            vec![CasExpr::Var(symbol.clone()), CasExpr::Var(var.clone())],
30        ),
31        CasExpr::Op(operator, args) if *operator == math("add") => {
32            op(math("add"), integrate_all(cx, args, var)?)
33        }
34        CasExpr::Op(operator, args) if *operator == math("sub") => {
35            op(math("sub"), integrate_all(cx, args, var)?)
36        }
37        CasExpr::Op(operator, args) if *operator == math("mul") => integrate_mul(cx, args, var)?,
38        CasExpr::Op(operator, args) if *operator == math("pow") => integrate_pow(cx, args, var)?,
39        _ => {
40            return Err(Error::Eval(format!(
41                "{} only supports constants, sums, scalar products, and powers of the integration variable",
42                integrate_sym_symbol()
43            )));
44        }
45    };
46    simplify_expr(cx, integral)
47}
48
49fn integrate_all(cx: &mut Cx, args: &[CasExpr], var: &Symbol) -> Result<Vec<CasExpr>> {
50    args.iter().map(|arg| integrate_cas(cx, arg, var)).collect()
51}
52
53fn integrate_mul(cx: &mut Cx, args: &[CasExpr], var: &Symbol) -> Result<CasExpr> {
54    let mut constant = Vec::new();
55    let mut variable = Vec::new();
56    for arg in args {
57        if depends_on(arg, var) {
58            variable.push(arg.clone());
59        } else {
60            constant.push(arg.clone());
61        }
62    }
63    if variable.len() != 1 {
64        return Err(Error::Eval(format!(
65            "{} only handles products with exactly one variable-dependent factor",
66            integrate_sym_symbol()
67        )));
68    }
69    let mut out = constant;
70    out.push(integrate_cas(cx, &variable[0], var)?);
71    Ok(op(math("mul"), out))
72}
73
74fn integrate_pow(cx: &mut Cx, args: &[CasExpr], var: &Symbol) -> Result<CasExpr> {
75    let [base, exponent] = two_args(args)?;
76    if !matches!(base, CasExpr::Var(symbol) if symbol == var) {
77        return Err(Error::Eval(format!(
78            "{} only supports powers of the integration variable",
79            integrate_sym_symbol()
80        )));
81    }
82    let exponent = literal_i64(cx, exponent)?.ok_or_else(|| {
83        Error::Eval(format!(
84            "{} only supports integer exponents for symbolic powers",
85            integrate_sym_symbol()
86        ))
87    })?;
88    integrate_power_of_var(cx, exponent, var)
89}
90
91fn integrate_power_of_var(cx: &mut Cx, exponent: i64, var: &Symbol) -> Result<CasExpr> {
92    if exponent == -1 {
93        return Ok(op(Symbol::new("ln"), vec![CasExpr::Var(var.clone())]));
94    }
95    let next = exponent + 1;
96    Ok(op(
97        math("mul"),
98        vec![
99            CasExpr::Num(rational_constant(cx, 1, next)?),
100            op(
101                math("pow"),
102                vec![
103                    CasExpr::Var(var.clone()),
104                    CasExpr::Num(integer_constant(cx, next)?),
105                ],
106            ),
107        ],
108    ))
109}
110
111fn depends_on(expr: &CasExpr, var: &Symbol) -> bool {
112    match expr {
113        CasExpr::Num(_) => false,
114        CasExpr::Var(symbol) => symbol == var,
115        CasExpr::Op(_, args) => args.iter().any(|arg| depends_on(arg, var)),
116    }
117}
118
119fn literal_i64(cx: &mut Cx, expr: &CasExpr) -> Result<Option<i64>> {
120    let CasExpr::Num(value) = expr else {
121        return Ok(None);
122    };
123    let display = value.object().display(cx)?;
124    Ok(display.parse::<i64>().ok())
125}
126
127fn integer_constant(cx: &mut Cx, value: i64) -> Result<Value> {
128    if cx
129        .registry()
130        .number_domain_by_symbol(&domains::i64())
131        .is_some()
132    {
133        return cx
134            .factory()
135            .number_literal(domains::i64(), value.to_string());
136    }
137    cx.factory()
138        .number_literal(domains::f64(), format!("{value}.0"))
139}
140
141fn rational_constant(cx: &mut Cx, num: i64, den: i64) -> Result<Value> {
142    if cx
143        .registry()
144        .number_domain_by_symbol(&domains::rational())
145        .is_some()
146    {
147        return cx
148            .factory()
149            .number_literal(domains::rational(), format!("{num}/{den}"));
150    }
151    let value = num as f64 / den as f64;
152    cx.factory()
153        .number_literal(domains::f64(), value.to_string())
154}
155
156fn two_args(args: &[CasExpr]) -> Result<[&CasExpr; 2]> {
157    let [left, right] = args else {
158        return Err(Error::Eval(format!(
159            "{} expects exactly two operands",
160            math("pow")
161        )));
162    };
163    Ok([left, right])
164}
165
166use super::diff::op;