Skip to main content

tensorlogic_compiler/symbolic_diff/
helpers.rs

//! Helper constructors and the algebraic simplification pass.
2
3use tensorlogic_ir::TLExpr;
4
5/// Build the canonical zero expression.
6#[inline]
7pub(super) fn zero() -> TLExpr {
8    TLExpr::Constant(0.0)
9}
10
11/// Build the canonical one expression.
12#[inline]
13pub(super) fn one() -> TLExpr {
14    TLExpr::Constant(1.0)
15}
16
17/// Check whether `expr` is a constant equal to `value` (within f64 epsilon).
18#[inline]
19pub(super) fn is_constant_value(expr: &TLExpr, value: f64) -> bool {
20    match expr {
21        TLExpr::Constant(v) => (v - value).abs() < f64::EPSILON,
22        _ => false,
23    }
24}
25
26/// Return a symbolic representation of the derivative of a function node.
27///
28/// For a known named function `f` (zero-arity predicate), the marker is `Pred("d_f", [])`.
29/// For a complex/anonymous function, `Pred("d_f", [])` is used as a generic marker.
30pub(super) fn derivative_of_function(function: &TLExpr) -> TLExpr {
31    match function {
32        TLExpr::Pred { name, args } if args.is_empty() => {
33            TLExpr::pred(format!("d_{}", name), vec![])
34        }
35        _ => TLExpr::pred("d_f".to_string(), vec![]),
36    }
37}
38
39/// Basic algebraic simplification applied to derivative expressions.
40///
41/// Simplification rules:
42/// - `0 + x → x`,  `x + 0 → x`
43/// - `0 * x → 0`,  `x * 0 → 0`
44/// - `1 * x → x`,  `x * 1 → x`
45/// - `x - 0 → x`
46/// - `0 - c → -c` (constant folding for arithmetic negation form)
47/// - `0 / x → 0`
48/// - `x ^ 0 → 1`,  `x ^ 1 → x`
49/// - constant folding for pure-constant arithmetic nodes
50pub fn simplify_derivative(expr: TLExpr) -> TLExpr {
51    match expr {
52        TLExpr::Add(l, r) => {
53            let l = simplify_derivative(*l);
54            let r = simplify_derivative(*r);
55            if is_constant_value(&l, 0.0) {
56                return r;
57            }
58            if is_constant_value(&r, 0.0) {
59                return l;
60            }
61            if let (TLExpr::Constant(a), TLExpr::Constant(b)) = (&l, &r) {
62                return TLExpr::Constant(a + b);
63            }
64            TLExpr::Add(Box::new(l), Box::new(r))
65        }
66
67        TLExpr::Sub(l, r) => {
68            let l = simplify_derivative(*l);
69            let r = simplify_derivative(*r);
70            if is_constant_value(&r, 0.0) {
71                return l;
72            }
73            if let (TLExpr::Constant(a), TLExpr::Constant(b)) = (&l, &r) {
74                return TLExpr::Constant(a - b);
75            }
76            TLExpr::Sub(Box::new(l), Box::new(r))
77        }
78
79        TLExpr::Mul(l, r) => {
80            let l = simplify_derivative(*l);
81            let r = simplify_derivative(*r);
82            if is_constant_value(&l, 0.0) || is_constant_value(&r, 0.0) {
83                return TLExpr::Constant(0.0);
84            }
85            if is_constant_value(&l, 1.0) {
86                return r;
87            }
88            if is_constant_value(&r, 1.0) {
89                return l;
90            }
91            if let (TLExpr::Constant(a), TLExpr::Constant(b)) = (&l, &r) {
92                return TLExpr::Constant(a * b);
93            }
94            TLExpr::Mul(Box::new(l), Box::new(r))
95        }
96
97        TLExpr::Div(l, r) => {
98            let l = simplify_derivative(*l);
99            let r = simplify_derivative(*r);
100            if is_constant_value(&l, 0.0) {
101                return TLExpr::Constant(0.0);
102            }
103            if let (TLExpr::Constant(a), TLExpr::Constant(b)) = (&l, &r) {
104                if b.abs() > f64::EPSILON {
105                    return TLExpr::Constant(a / b);
106                }
107            }
108            TLExpr::Div(Box::new(l), Box::new(r))
109        }
110
111        TLExpr::Pow(base, exp) => {
112            let base = simplify_derivative(*base);
113            let exp = simplify_derivative(*exp);
114            if is_constant_value(&exp, 0.0) {
115                return TLExpr::Constant(1.0);
116            }
117            if is_constant_value(&exp, 1.0) {
118                return base;
119            }
120            if let (TLExpr::Constant(b), TLExpr::Constant(e)) = (&base, &exp) {
121                return TLExpr::Constant(b.powf(*e));
122            }
123            TLExpr::Pow(Box::new(base), Box::new(exp))
124        }
125
126        TLExpr::And(l, r) => TLExpr::And(
127            Box::new(simplify_derivative(*l)),
128            Box::new(simplify_derivative(*r)),
129        ),
130        TLExpr::Or(l, r) => TLExpr::Or(
131            Box::new(simplify_derivative(*l)),
132            Box::new(simplify_derivative(*r)),
133        ),
134        TLExpr::Not(inner) => TLExpr::Not(Box::new(simplify_derivative(*inner))),
135
136        other => other,
137    }
138}