tensorlogic_compiler/symbolic_diff/
helpers.rs1use tensorlogic_ir::TLExpr;
4
5#[inline]
7pub(super) fn zero() -> TLExpr {
8 TLExpr::Constant(0.0)
9}
10
11#[inline]
13pub(super) fn one() -> TLExpr {
14 TLExpr::Constant(1.0)
15}
16
17#[inline]
19pub(super) fn is_constant_value(expr: &TLExpr, value: f64) -> bool {
20 match expr {
21 TLExpr::Constant(v) => (v - value).abs() < f64::EPSILON,
22 _ => false,
23 }
24}
25
26pub(super) fn derivative_of_function(function: &TLExpr) -> TLExpr {
31 match function {
32 TLExpr::Pred { name, args } if args.is_empty() => {
33 TLExpr::pred(format!("d_{}", name), vec![])
34 }
35 _ => TLExpr::pred("d_f".to_string(), vec![]),
36 }
37}
38
39pub fn simplify_derivative(expr: TLExpr) -> TLExpr {
51 match expr {
52 TLExpr::Add(l, r) => {
53 let l = simplify_derivative(*l);
54 let r = simplify_derivative(*r);
55 if is_constant_value(&l, 0.0) {
56 return r;
57 }
58 if is_constant_value(&r, 0.0) {
59 return l;
60 }
61 if let (TLExpr::Constant(a), TLExpr::Constant(b)) = (&l, &r) {
62 return TLExpr::Constant(a + b);
63 }
64 TLExpr::Add(Box::new(l), Box::new(r))
65 }
66
67 TLExpr::Sub(l, r) => {
68 let l = simplify_derivative(*l);
69 let r = simplify_derivative(*r);
70 if is_constant_value(&r, 0.0) {
71 return l;
72 }
73 if let (TLExpr::Constant(a), TLExpr::Constant(b)) = (&l, &r) {
74 return TLExpr::Constant(a - b);
75 }
76 TLExpr::Sub(Box::new(l), Box::new(r))
77 }
78
79 TLExpr::Mul(l, r) => {
80 let l = simplify_derivative(*l);
81 let r = simplify_derivative(*r);
82 if is_constant_value(&l, 0.0) || is_constant_value(&r, 0.0) {
83 return TLExpr::Constant(0.0);
84 }
85 if is_constant_value(&l, 1.0) {
86 return r;
87 }
88 if is_constant_value(&r, 1.0) {
89 return l;
90 }
91 if let (TLExpr::Constant(a), TLExpr::Constant(b)) = (&l, &r) {
92 return TLExpr::Constant(a * b);
93 }
94 TLExpr::Mul(Box::new(l), Box::new(r))
95 }
96
97 TLExpr::Div(l, r) => {
98 let l = simplify_derivative(*l);
99 let r = simplify_derivative(*r);
100 if is_constant_value(&l, 0.0) {
101 return TLExpr::Constant(0.0);
102 }
103 if let (TLExpr::Constant(a), TLExpr::Constant(b)) = (&l, &r) {
104 if b.abs() > f64::EPSILON {
105 return TLExpr::Constant(a / b);
106 }
107 }
108 TLExpr::Div(Box::new(l), Box::new(r))
109 }
110
111 TLExpr::Pow(base, exp) => {
112 let base = simplify_derivative(*base);
113 let exp = simplify_derivative(*exp);
114 if is_constant_value(&exp, 0.0) {
115 return TLExpr::Constant(1.0);
116 }
117 if is_constant_value(&exp, 1.0) {
118 return base;
119 }
120 if let (TLExpr::Constant(b), TLExpr::Constant(e)) = (&base, &exp) {
121 return TLExpr::Constant(b.powf(*e));
122 }
123 TLExpr::Pow(Box::new(base), Box::new(exp))
124 }
125
126 TLExpr::And(l, r) => TLExpr::And(
127 Box::new(simplify_derivative(*l)),
128 Box::new(simplify_derivative(*r)),
129 ),
130 TLExpr::Or(l, r) => TLExpr::Or(
131 Box::new(simplify_derivative(*l)),
132 Box::new(simplify_derivative(*r)),
133 ),
134 TLExpr::Not(inner) => TLExpr::Not(Box::new(simplify_derivative(*inner))),
135
136 other => other,
137 }
138}