1use crate::common::IntegrateFloat;
7use crate::error::{IntegrateError, IntegrateResult};
8use std::collections::HashMap;
9use std::fmt;
10use std::ops::{Add as StdAdd, Div as StdDiv, Mul as StdMul, Neg as StdNeg, Sub as StdSub};
11
12use Pattern::{DifferenceOfSquares, PythagoreanIdentity, SumOfSquares};
14use SymbolicExpression::{
15 Abs, Add, Atan, Constant, Cos, Cosh, Div, Exp, Ln, Mul, Neg, Pow, Sin, Sinh, Sqrt, Sub, Tan,
16 Tanh, Var,
17};
18
/// A named symbolic variable, optionally carrying an integer index
/// (rendered as `name[index]` by its `Display` impl).
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct Variable {
    /// Base name of the variable, e.g. `"x"`.
    pub name: String,
    /// Optional index; `None` for a plain (non-indexed) variable.
    pub index: Option<usize>,
}
25
26impl Variable {
27 pub fn new(name: impl Into<String>) -> Self {
29 Variable {
30 name: name.into(),
31 index: None,
32 }
33 }
34
35 pub fn indexed(name: impl Into<String>, index: usize) -> Self {
37 Variable {
38 name: name.into(),
39 index: Some(index),
40 }
41 }
42}
43
44impl fmt::Display for Variable {
45 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
46 match self.index {
47 Some(idx) => write!(f, "{}[{}]", self.name, idx),
48 None => write!(f, "{}", self.name),
49 }
50 }
51}
52
53#[derive(Debug, Clone, PartialEq)]
55pub enum SymbolicExpression<F: IntegrateFloat> {
56 Constant(F),
58 Var(Variable),
60 Add(Box<SymbolicExpression<F>>, Box<SymbolicExpression<F>>),
62 Sub(Box<SymbolicExpression<F>>, Box<SymbolicExpression<F>>),
64 Mul(Box<SymbolicExpression<F>>, Box<SymbolicExpression<F>>),
66 Div(Box<SymbolicExpression<F>>, Box<SymbolicExpression<F>>),
68 Pow(Box<SymbolicExpression<F>>, Box<SymbolicExpression<F>>),
70 Neg(Box<SymbolicExpression<F>>),
72 Sin(Box<SymbolicExpression<F>>),
74 Cos(Box<SymbolicExpression<F>>),
76 Exp(Box<SymbolicExpression<F>>),
78 Ln(Box<SymbolicExpression<F>>),
80 Sqrt(Box<SymbolicExpression<F>>),
82 Tan(Box<SymbolicExpression<F>>),
84 Atan(Box<SymbolicExpression<F>>),
86 Sinh(Box<SymbolicExpression<F>>),
88 Cosh(Box<SymbolicExpression<F>>),
90 Tanh(Box<SymbolicExpression<F>>),
92 Abs(Box<SymbolicExpression<F>>),
94}
95
96impl<F: IntegrateFloat> SymbolicExpression<F> {
97 pub fn constant(value: F) -> Self {
99 SymbolicExpression::Constant(value)
100 }
101
102 pub fn var(name: impl Into<String>) -> Self {
104 SymbolicExpression::Var(Variable::new(name))
105 }
106
107 pub fn indexedvar(name: impl Into<String>, index: usize) -> Self {
109 SymbolicExpression::Var(Variable::indexed(name, index))
110 }
111
112 pub fn tan(expr: SymbolicExpression<F>) -> Self {
114 SymbolicExpression::Tan(Box::new(expr))
115 }
116
117 pub fn atan(expr: SymbolicExpression<F>) -> Self {
119 SymbolicExpression::Atan(Box::new(expr))
120 }
121
122 pub fn sinh(expr: SymbolicExpression<F>) -> Self {
124 SymbolicExpression::Sinh(Box::new(expr))
125 }
126
127 pub fn cosh(expr: SymbolicExpression<F>) -> Self {
129 SymbolicExpression::Cosh(Box::new(expr))
130 }
131
132 pub fn tanh(expr: SymbolicExpression<F>) -> Self {
134 SymbolicExpression::Tanh(Box::new(expr))
135 }
136
137 pub fn abs(expr: SymbolicExpression<F>) -> Self {
139 SymbolicExpression::Abs(Box::new(expr))
140 }
141
142 pub fn differentiate(&self, var: &Variable) -> SymbolicExpression<F> {
144 use SymbolicExpression::*;
145
146 match self {
147 Constant(_) => Constant(F::zero()),
148 Var(v) => {
149 if v == var {
150 Constant(F::one())
151 } else {
152 Constant(F::zero())
153 }
154 }
155 Add(a, b) => Add(
156 Box::new(a.differentiate(var)),
157 Box::new(b.differentiate(var)),
158 ),
159 Sub(a, b) => Sub(
160 Box::new(a.differentiate(var)),
161 Box::new(b.differentiate(var)),
162 ),
163 Mul(a, b) => {
164 Add(
166 Box::new(Mul(Box::new(a.differentiate(var)), b.clone())),
167 Box::new(Mul(a.clone(), Box::new(b.differentiate(var)))),
168 )
169 }
170 Div(a, b) => {
171 Div(
173 Box::new(Sub(
174 Box::new(Mul(Box::new(a.differentiate(var)), b.clone())),
175 Box::new(Mul(a.clone(), Box::new(b.differentiate(var)))),
176 )),
177 Box::new(Mul(b.clone(), b.clone())),
178 )
179 }
180 Pow(a, b) => {
181 if let Constant(n) = &**b {
183 Mul(
185 Box::new(Mul(
186 Box::new(Constant(*n)),
187 Box::new(Pow(a.clone(), Box::new(Constant(*n - F::one())))),
188 )),
189 Box::new(a.differentiate(var)),
190 )
191 } else {
192 let exp_expr = Exp(Box::new(Mul(b.clone(), Box::new(Ln(a.clone())))));
194 exp_expr.differentiate(var)
195 }
196 }
197 Neg(a) => Neg(Box::new(a.differentiate(var))),
198 Sin(a) => {
199 Mul(Box::new(Cos(a.clone())), Box::new(a.differentiate(var)))
201 }
202 Cos(a) => {
203 Neg(Box::new(Mul(
205 Box::new(Sin(a.clone())),
206 Box::new(a.differentiate(var)),
207 )))
208 }
209 Exp(a) => {
210 Mul(Box::new(Exp(a.clone())), Box::new(a.differentiate(var)))
212 }
213 Ln(a) => {
214 Div(Box::new(a.differentiate(var)), a.clone())
216 }
217 Sqrt(a) => {
218 Div(
220 Box::new(a.differentiate(var)),
221 Box::new(Mul(
222 Box::new(Constant(F::from(2.0).unwrap())),
223 Box::new(Sqrt(a.clone())),
224 )),
225 )
226 }
227 Tan(a) => {
228 Div(
230 Box::new(a.differentiate(var)),
231 Box::new(Pow(
232 Box::new(Cos(a.clone())),
233 Box::new(Constant(F::from(2.0).unwrap())),
234 )),
235 )
236 }
237 Atan(a) => {
238 Div(
240 Box::new(a.differentiate(var)),
241 Box::new(Add(
242 Box::new(Constant(F::one())),
243 Box::new(Pow(a.clone(), Box::new(Constant(F::from(2.0).unwrap())))),
244 )),
245 )
246 }
247 Sinh(a) => {
248 Mul(Box::new(Cosh(a.clone())), Box::new(a.differentiate(var)))
250 }
251 Cosh(a) => {
252 Mul(Box::new(Sinh(a.clone())), Box::new(a.differentiate(var)))
254 }
255 Tanh(a) => {
256 Div(
258 Box::new(a.differentiate(var)),
259 Box::new(Pow(
260 Box::new(Cosh(a.clone())),
261 Box::new(Constant(F::from(2.0).unwrap())),
262 )),
263 )
264 }
265 Abs(a) => {
266 Mul(
269 Box::new(Div(a.clone(), Box::new(Abs(a.clone())))),
270 Box::new(a.differentiate(var)),
271 )
272 }
273 }
274 }
275
276 pub fn evaluate(&self, values: &HashMap<Variable, F>) -> IntegrateResult<F> {
278 match self {
279 Constant(c) => Ok(*c),
280 Var(v) => values.get(v).copied().ok_or_else(|| {
281 IntegrateError::ComputationError(format!("Variable {v} not found in values"))
282 }),
283 Add(a, b) => Ok(a.evaluate(values)? + b.evaluate(values)?),
284 Sub(a, b) => Ok(a.evaluate(values)? - b.evaluate(values)?),
285 Mul(a, b) => Ok(a.evaluate(values)? * b.evaluate(values)?),
286 Div(a, b) => {
287 let b_val = b.evaluate(values)?;
288 if b_val.abs() < F::epsilon() {
289 Err(IntegrateError::ComputationError(
290 "Division by zero".to_string(),
291 ))
292 } else {
293 Ok(a.evaluate(values)? / b_val)
294 }
295 }
296 Pow(a, b) => Ok(a.evaluate(values)?.powf(b.evaluate(values)?)),
297 Neg(a) => Ok(-a.evaluate(values)?),
298 Sin(a) => Ok(a.evaluate(values)?.sin()),
299 Cos(a) => Ok(a.evaluate(values)?.cos()),
300 Exp(a) => Ok(a.evaluate(values)?.exp()),
301 Ln(a) => {
302 let a_val = a.evaluate(values)?;
303 if a_val <= F::zero() {
304 Err(IntegrateError::ComputationError(
305 "Logarithm of non-positive value".to_string(),
306 ))
307 } else {
308 Ok(a_val.ln())
309 }
310 }
311 Sqrt(a) => {
312 let a_val = a.evaluate(values)?;
313 if a_val < F::zero() {
314 Err(IntegrateError::ComputationError(
315 "Square root of negative value".to_string(),
316 ))
317 } else {
318 Ok(a_val.sqrt())
319 }
320 }
321 Tan(a) => Ok(a.evaluate(values)?.tan()),
322 Atan(a) => Ok(a.evaluate(values)?.atan()),
323 Sinh(a) => Ok(a.evaluate(values)?.sinh()),
324 Cosh(a) => Ok(a.evaluate(values)?.cosh()),
325 Tanh(a) => Ok(a.evaluate(values)?.tanh()),
326 Abs(a) => Ok(a.evaluate(values)?.abs()),
327 }
328 }
329
330 pub fn variables(&self) -> Vec<Variable> {
332 let mut vars = Vec::new();
333
334 match self {
335 Constant(_) => {}
336 Var(v) => vars.push(v.clone()),
337 Add(a, b) | Sub(a, b) | Mul(a, b) | Div(a, b) | Pow(a, b) => {
338 vars.extend(a.variables());
339 vars.extend(b.variables());
340 }
341 Neg(a) | Sin(a) | Cos(a) | Exp(a) | Ln(a) | Sqrt(a) | Tan(a) | Atan(a) | Sinh(a)
342 | Cosh(a) | Tanh(a) | Abs(a) => {
343 vars.extend(a.variables());
344 }
345 }
346
347 vars.sort_by(|a, b| match (&a.name, &b.name) {
349 (n1, n2) if n1 != n2 => n1.cmp(n2),
350 _ => a.index.cmp(&b.index),
351 });
352 vars.dedup();
353 vars
354 }
355}
356
357#[derive(Debug, Clone, PartialEq)]
359pub enum Pattern<F: IntegrateFloat> {
360 SumOfSquares(Box<SymbolicExpression<F>>, Box<SymbolicExpression<F>>),
362 DifferenceOfSquares(Box<SymbolicExpression<F>>, Box<SymbolicExpression<F>>),
364 PythagoreanIdentity(Box<SymbolicExpression<F>>),
366 EulerFormula(Box<SymbolicExpression<F>>),
368}
369
370#[allow(dead_code)]
372pub fn match_pattern<F: IntegrateFloat>(expr: &SymbolicExpression<F>) -> Option<Pattern<F>> {
373 match expr {
374 Add(a, b) => {
376 if let (Pow(base_a, exp_a), Pow(base_b, exp_b)) = (a.as_ref(), b.as_ref()) {
377 if let (Constant(n_a), Constant(n_b)) = (exp_a.as_ref(), exp_b.as_ref()) {
378 if (*n_a - F::from(2.0).unwrap()).abs() < F::epsilon()
379 && (*n_b - F::from(2.0).unwrap()).abs() < F::epsilon()
380 {
381 return Some(Pattern::SumOfSquares(base_a.clone(), base_b.clone()));
382 }
383 }
384 }
385
386 if let (Pow(sin_base, sin_exp), Pow(cos_base, cos_exp)) = (a.as_ref(), b.as_ref()) {
388 if let (Sin(sin_arg), Cos(cos_arg), Constant(n1), Constant(n2)) = (
389 sin_base.as_ref(),
390 cos_base.as_ref(),
391 sin_exp.as_ref(),
392 cos_exp.as_ref(),
393 ) {
394 if match_expressions(sin_arg, cos_arg)
395 && (*n1 - F::from(2.0).unwrap()).abs() < F::epsilon()
396 && (*n2 - F::from(2.0).unwrap()).abs() < F::epsilon()
397 {
398 return Some(Pattern::PythagoreanIdentity(sin_arg.clone()));
399 }
400 }
401 }
402 None
403 }
404 Sub(a, b) => {
406 if let (Pow(base_a, exp_a), Pow(base_b, exp_b)) = (a.as_ref(), b.as_ref()) {
407 if let (Constant(n_a), Constant(n_b)) = (exp_a.as_ref(), exp_b.as_ref()) {
408 if (*n_a - F::from(2.0).unwrap()).abs() < F::epsilon()
409 && (*n_b - F::from(2.0).unwrap()).abs() < F::epsilon()
410 {
411 return Some(Pattern::DifferenceOfSquares(base_a.clone(), base_b.clone()));
412 }
413 }
414 }
415 None
416 }
417 _ => None,
418 }
419}
420
421#[allow(dead_code)]
423fn match_expressions<F: IntegrateFloat>(
424 expr1: &SymbolicExpression<F>,
425 expr2: &SymbolicExpression<F>,
426) -> bool {
427 match (expr1, expr2) {
428 (Constant(a), Constant(b)) => (*a - *b).abs() < F::epsilon(),
429 (Var(a), Var(b)) => a == b,
430 _ => false,
431 }
432}
433
434#[allow(dead_code)]
436pub fn pattern_simplify<F: IntegrateFloat>(expr: &SymbolicExpression<F>) -> SymbolicExpression<F> {
437 if let Some(pattern) = match_pattern(expr) {
438 match pattern {
439 Pattern::DifferenceOfSquares(a, b) => {
440 Mul(Box::new(Add(a.clone(), b.clone())), Box::new(Sub(a, b)))
442 }
443 Pattern::PythagoreanIdentity(_) => {
444 Constant(F::one())
446 }
447 _ => expr.clone(),
448 }
449 } else {
450 match expr {
452 Add(a, b) => {
453 let a_simp = pattern_simplify(a);
454 let b_simp = pattern_simplify(b);
455 pattern_simplify(&Add(Box::new(a_simp), Box::new(b_simp)))
456 }
457 Sub(a, b) => {
458 let a_simp = pattern_simplify(a);
459 let b_simp = pattern_simplify(b);
460 pattern_simplify(&Sub(Box::new(a_simp), Box::new(b_simp)))
461 }
462 Mul(a, b) => {
463 let a_simp = pattern_simplify(a);
464 let b_simp = pattern_simplify(b);
465 Mul(Box::new(a_simp), Box::new(b_simp))
466 }
467 _ => expr.clone(),
468 }
469 }
470}
471
472#[allow(dead_code)]
474pub fn simplify<F: IntegrateFloat>(expr: &SymbolicExpression<F>) -> SymbolicExpression<F> {
475 match expr {
476 Add(a, b) => {
478 let a_simp = simplify(a);
479 let b_simp = simplify(b);
480 match (&a_simp, &b_simp) {
481 (Constant(x), Constant(y)) => Constant(*x + *y),
482 (Constant(x), _) if x.abs() < F::epsilon() => b_simp,
483 (_, Constant(y)) if y.abs() < F::epsilon() => a_simp,
484 _ => Add(Box::new(a_simp), Box::new(b_simp)),
485 }
486 }
487 Sub(a, b) => {
488 let a_simp = simplify(a);
489 let b_simp = simplify(b);
490 match (&a_simp, &b_simp) {
491 (Constant(x), Constant(y)) => Constant(*x - *y),
492 (_, Constant(y)) if y.abs() < F::epsilon() => a_simp,
493 _ => Sub(Box::new(a_simp), Box::new(b_simp)),
494 }
495 }
496 Mul(a, b) => {
497 let a_simp = simplify(a);
498 let b_simp = simplify(b);
499 match (&a_simp, &b_simp) {
500 (Constant(x), Constant(y)) => Constant(*x * *y),
501 (Constant(x), _) if x.abs() < F::epsilon() => Constant(F::zero()),
502 (_, Constant(y)) if y.abs() < F::epsilon() => Constant(F::zero()),
503 (Constant(x), _) if (*x - F::one()).abs() < F::epsilon() => b_simp,
504 (_, Constant(y)) if (*y - F::one()).abs() < F::epsilon() => a_simp,
505 _ => Mul(Box::new(a_simp), Box::new(b_simp)),
506 }
507 }
508 Div(a, b) => {
509 let a_simp = simplify(a);
510 let b_simp = simplify(b);
511 match (&a_simp, &b_simp) {
512 (Constant(x), Constant(y)) if y.abs() > F::epsilon() => Constant(*x / *y),
513 (Constant(x), _) if x.abs() < F::epsilon() => Constant(F::zero()),
514 (_, Constant(y)) if (*y - F::one()).abs() < F::epsilon() => a_simp,
515 _ => Div(Box::new(a_simp), Box::new(b_simp)),
516 }
517 }
518 Neg(a) => {
519 let a_simp = simplify(a);
520 match &a_simp {
521 Constant(x) => Constant(-*x),
522 Neg(inner) => (**inner).clone(),
523 _ => Neg(Box::new(a_simp)),
524 }
525 }
526 Pow(a, b) => {
527 let a_simp = simplify(a);
528 let b_simp = simplify(b);
529 match (&a_simp, &b_simp) {
530 (Constant(x), Constant(y)) => Constant(x.powf(*y)),
531 (_, Constant(y)) if y.abs() < F::epsilon() => Constant(F::one()), (_, Constant(y)) if (*y - F::one()).abs() < F::epsilon() => a_simp, (Constant(x), _) if x.abs() < F::epsilon() => Constant(F::zero()), (Constant(x), _) if (*x - F::one()).abs() < F::epsilon() => Constant(F::one()), _ => Pow(Box::new(a_simp), Box::new(b_simp)),
536 }
537 }
538 Exp(a) => {
539 let a_simp = simplify(a);
540 match &a_simp {
541 Constant(x) => Constant(x.exp()),
542 Ln(inner) => (**inner).clone(), _ => Exp(Box::new(a_simp)),
544 }
545 }
546 Ln(a) => {
547 let a_simp = simplify(a);
548 match &a_simp {
549 Constant(x) if *x > F::zero() => Constant(x.ln()),
550 Exp(inner) => (**inner).clone(), Constant(x) if (*x - F::one()).abs() < F::epsilon() => Constant(F::zero()), _ => Ln(Box::new(a_simp)),
553 }
554 }
555 Sin(a) => {
556 let a_simp = simplify(a);
557 match &a_simp {
558 Constant(x) => Constant(x.sin()),
559 Neg(inner) => Neg(Box::new(Sin(inner.clone()))), _ => Sin(Box::new(a_simp)),
561 }
562 }
563 Cos(a) => {
564 let a_simp = simplify(a);
565 match &a_simp {
566 Constant(x) => Constant(x.cos()),
567 Neg(inner) => Cos(inner.clone()), _ => Cos(Box::new(a_simp)),
569 }
570 }
571 Tan(a) => {
572 let a_simp = simplify(a);
573 match &a_simp {
574 Constant(x) => Constant(x.tan()),
575 Neg(inner) => Neg(Box::new(Tan(inner.clone()))), _ => Tan(Box::new(a_simp)),
577 }
578 }
579 Sqrt(a) => {
580 let a_simp = simplify(a);
581 match &a_simp {
582 Constant(x) if *x >= F::zero() => Constant(x.sqrt()),
583 Pow(base, exp) => {
584 if let Constant(n) = &**exp {
585 Pow(base.clone(), Box::new(Constant(*n / F::from(2.0).unwrap())))
587 } else {
588 Sqrt(Box::new(a_simp))
589 }
590 }
591 _ => Sqrt(Box::new(a_simp)),
592 }
593 }
594 Abs(a) => {
595 let a_simp = simplify(a);
596 match &a_simp {
597 Constant(x) => Constant(x.abs()),
598 Neg(inner) => Abs(inner.clone()), Abs(inner) => Abs(inner.clone()), _ => Abs(Box::new(a_simp)),
601 }
602 }
603 _ => expr.clone(),
604 }
605}
606
607#[allow(dead_code)]
609pub fn deep_simplify<F: IntegrateFloat>(expr: &SymbolicExpression<F>) -> SymbolicExpression<F> {
610 let algebraic_simplified = simplify(expr);
612 pattern_simplify(&algebraic_simplified)
614}
615
616impl<F: IntegrateFloat> StdAdd for SymbolicExpression<F> {
618 type Output = Self;
619
620 fn add(self, rhs: Self) -> Self::Output {
621 SymbolicExpression::Add(Box::new(self), Box::new(rhs))
622 }
623}
624
625impl<F: IntegrateFloat> StdSub for SymbolicExpression<F> {
626 type Output = Self;
627
628 fn sub(self, rhs: Self) -> Self::Output {
629 SymbolicExpression::Sub(Box::new(self), Box::new(rhs))
630 }
631}
632
633impl<F: IntegrateFloat> StdMul for SymbolicExpression<F> {
634 type Output = Self;
635
636 fn mul(self, rhs: Self) -> Self::Output {
637 SymbolicExpression::Mul(Box::new(self), Box::new(rhs))
638 }
639}
640
641impl<F: IntegrateFloat> StdDiv for SymbolicExpression<F> {
642 type Output = Self;
643
644 fn div(self, rhs: Self) -> Self::Output {
645 SymbolicExpression::Div(Box::new(self), Box::new(rhs))
646 }
647}
648
649impl<F: IntegrateFloat> StdNeg for SymbolicExpression<F> {
650 type Output = Self;
651
652 fn neg(self) -> Self::Output {
653 SymbolicExpression::Neg(Box::new(self))
654 }
655}
656
657impl<F: IntegrateFloat> fmt::Display for SymbolicExpression<F> {
658 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
659 match self {
660 Constant(c) => write!(f, "{c}"),
661 Var(v) => write!(f, "{v}"),
662 Add(a, b) => write!(f, "({a} + {b})"),
663 Sub(a, b) => write!(f, "({a} - {b})"),
664 Mul(a, b) => write!(f, "({a} * {b})"),
665 Div(a, b) => write!(f, "({a} / {b})"),
666 Pow(a, b) => write!(f, "({a} ^ {b})"),
667 Neg(a) => write!(f, "(-{a})"),
668 Sin(a) => write!(f, "sin({a})"),
669 Cos(a) => write!(f, "cos({a})"),
670 Exp(a) => write!(f, "exp({a})"),
671 Ln(a) => write!(f, "ln({a})"),
672 Sqrt(a) => write!(f, "sqrt({a})"),
673 Tan(a) => write!(f, "tan({a})"),
674 Atan(a) => write!(f, "atan({a})"),
675 Sinh(a) => write!(f, "sinh({a})"),
676 Cosh(a) => write!(f, "cosh({a})"),
677 Tanh(a) => write!(f, "tanh({a})"),
678 Abs(a) => write!(f, "|{a}|"),
679 }
680 }
681}