scivex_sym/expr.rs
1use std::collections::{HashMap, HashSet};
2use std::fmt;
3use std::ops;
4
5use crate::error::{Result, SymError};
6
/// Built-in mathematical functions that can appear in an [`Expr::Fn`] node.
///
/// Each variant is evaluated with the `f64` method of the same (lowercase)
/// name and is displayed as that lowercase name.
///
/// # Examples
///
/// ```
/// # use scivex_sym::MathFn;
/// let f = MathFn::Sin;
/// assert_eq!(format!("{f}"), "sin");
/// ```
#[cfg_attr(
    feature = "serde-support",
    derive(serde::Serialize, serde::Deserialize)
)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MathFn {
    /// Sine; the argument is in radians.
    Sin,
    /// Cosine; the argument is in radians.
    Cos,
    /// Tangent; the argument is in radians.
    Tan,
    /// Natural exponential, `e^x`.
    Exp,
    /// Natural logarithm (base e).
    Ln,
    /// Square root.
    Sqrt,
    /// Absolute value.
    Abs,
}
30
31impl fmt::Display for MathFn {
32 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
33 let name = match self {
34 Self::Sin => "sin",
35 Self::Cos => "cos",
36 Self::Tan => "tan",
37 Self::Exp => "exp",
38 Self::Ln => "ln",
39 Self::Sqrt => "sqrt",
40 Self::Abs => "abs",
41 };
42 f.write_str(name)
43 }
44}
45
46/// A symbolic expression AST.
47///
48/// `Sub` is represented as `Add(a, Neg(b))` and `Div` as `Mul(a, Pow(b, Const(-1)))`.
49///
50/// # Examples
51///
52/// ```
53/// # use scivex_sym::{var, constant};
54/// # use std::collections::HashMap;
55/// let expr = var("x") + constant(1.0);
56/// let mut vars = HashMap::new();
57/// vars.insert("x".to_string(), 2.0);
58/// assert!((expr.eval(&vars).unwrap() - 3.0).abs() < 1e-10);
59/// ```
60#[cfg_attr(
61 feature = "serde-support",
62 derive(serde::Serialize, serde::Deserialize)
63)]
64#[derive(Debug, Clone, PartialEq)]
65pub enum Expr {
66 /// Numeric constant.
67 Const(f64),
68 /// Named variable.
69 Var(String),
70 /// Addition: `lhs + rhs`.
71 Add(Box<Expr>, Box<Expr>),
72 /// Multiplication: `lhs * rhs`.
73 Mul(Box<Expr>, Box<Expr>),
74 /// Exponentiation: `base ^ exp`.
75 Pow(Box<Expr>, Box<Expr>),
76 /// Negation: `-expr`.
77 Neg(Box<Expr>),
78 /// Function application: `f(arg)`.
79 Fn(MathFn, Box<Expr>),
80}
81
82// ---------------------------------------------------------------------------
83// Constructors
84// ---------------------------------------------------------------------------
85
86/// Create a constant expression.
87///
88/// # Examples
89///
90/// ```
91/// # use scivex_sym::constant;
92/// # use std::collections::HashMap;
93/// let five = constant(5.0);
94/// let val = five.eval(&HashMap::new()).unwrap();
95/// assert!((val - 5.0).abs() < 1e-10);
96/// ```
97#[must_use]
98pub fn constant(v: f64) -> Expr {
99 Expr::Const(v)
100}
101
102/// Create a variable expression.
103///
104/// # Examples
105///
106/// ```
107/// # use scivex_sym::var;
108/// # use std::collections::HashMap;
109/// let x = var("x");
110/// let mut vars = HashMap::new();
111/// vars.insert("x".to_string(), 7.0);
112/// assert!((x.eval(&vars).unwrap() - 7.0).abs() < 1e-10);
113/// ```
114#[must_use]
115pub fn var(name: &str) -> Expr {
116 Expr::Var(name.to_owned())
117}
118
119/// The additive identity.
120///
121/// # Examples
122///
123/// ```
124/// # use scivex_sym::zero;
125/// assert!(zero().is_zero());
126/// ```
127#[must_use]
128pub fn zero() -> Expr {
129 Expr::Const(0.0)
130}
131
132/// The multiplicative identity.
133///
134/// # Examples
135///
136/// ```
137/// # use scivex_sym::one;
138/// assert!(one().is_one());
139/// ```
140#[must_use]
141pub fn one() -> Expr {
142 Expr::Const(1.0)
143}
144
145/// The constant pi.
146///
147/// # Examples
148///
149/// ```
150/// # use scivex_sym::pi;
151/// assert!(pi().as_const().unwrap() > 3.14);
152/// ```
153#[must_use]
154pub fn pi() -> Expr {
155 Expr::Const(std::f64::consts::PI)
156}
157
158/// The constant e.
159///
160/// # Examples
161///
162/// ```
163/// # use scivex_sym::e;
164/// assert!(e().as_const().unwrap() > 2.71);
165/// ```
166#[must_use]
167pub fn e() -> Expr {
168 Expr::Const(std::f64::consts::E)
169}
170
171/// `sin(expr)`
172///
173/// # Examples
174///
175/// ```
176/// # use scivex_sym::{sin, pi};
177/// # use std::collections::HashMap;
178/// let val = sin(pi()).eval(&HashMap::new()).unwrap();
179/// assert!(val.abs() < 1e-10);
180/// ```
181#[must_use]
182pub fn sin(expr: Expr) -> Expr {
183 Expr::Fn(MathFn::Sin, Box::new(expr))
184}
185
186/// `cos(expr)`
187///
188/// # Examples
189///
190/// ```
191/// # use scivex_sym::{cos, constant};
192/// # use std::collections::HashMap;
193/// let val = cos(constant(0.0)).eval(&HashMap::new()).unwrap();
194/// assert!((val - 1.0).abs() < 1e-10);
195/// ```
196#[must_use]
197pub fn cos(expr: Expr) -> Expr {
198 Expr::Fn(MathFn::Cos, Box::new(expr))
199}
200
201/// `tan(expr)`
202///
203/// # Examples
204///
205/// ```
206/// # use scivex_sym::{tan, constant};
207/// # use std::collections::HashMap;
208/// let val = tan(constant(0.0)).eval(&HashMap::new()).unwrap();
209/// assert!(val.abs() < 1e-10);
210/// ```
211#[must_use]
212pub fn tan(expr: Expr) -> Expr {
213 Expr::Fn(MathFn::Tan, Box::new(expr))
214}
215
216/// `exp(expr)`
217///
218/// # Examples
219///
220/// ```
221/// # use scivex_sym::{exp, constant};
222/// # use std::collections::HashMap;
223/// let val = exp(constant(0.0)).eval(&HashMap::new()).unwrap();
224/// assert!((val - 1.0).abs() < 1e-10);
225/// ```
226#[must_use]
227pub fn exp(expr: Expr) -> Expr {
228 Expr::Fn(MathFn::Exp, Box::new(expr))
229}
230
231/// `ln(expr)`
232///
233/// # Examples
234///
235/// ```
236/// # use scivex_sym::{ln, e};
237/// # use std::collections::HashMap;
238/// let val = ln(e()).eval(&HashMap::new()).unwrap();
239/// assert!((val - 1.0).abs() < 1e-10);
240/// ```
241#[must_use]
242pub fn ln(expr: Expr) -> Expr {
243 Expr::Fn(MathFn::Ln, Box::new(expr))
244}
245
246/// `sqrt(expr)`
247///
248/// # Examples
249///
250/// ```
251/// # use scivex_sym::{sqrt, constant};
252/// # use std::collections::HashMap;
253/// let val = sqrt(constant(4.0)).eval(&HashMap::new()).unwrap();
254/// assert!((val - 2.0).abs() < 1e-10);
255/// ```
256#[must_use]
257pub fn sqrt(expr: Expr) -> Expr {
258 Expr::Fn(MathFn::Sqrt, Box::new(expr))
259}
260
261/// `|expr|`
262///
263/// # Examples
264///
265/// ```
266/// # use scivex_sym::{abs, constant};
267/// # use std::collections::HashMap;
268/// let val = abs(constant(-3.0)).eval(&HashMap::new()).unwrap();
269/// assert!((val - 3.0).abs() < 1e-10);
270/// ```
271#[must_use]
272pub fn abs(expr: Expr) -> Expr {
273 Expr::Fn(MathFn::Abs, Box::new(expr))
274}
275
276// ---------------------------------------------------------------------------
277// Core methods
278// ---------------------------------------------------------------------------
279
280impl Expr {
281 /// Evaluate the expression given concrete variable bindings.
282 ///
283 /// # Examples
284 ///
285 /// ```
286 /// # use scivex_sym::expr::{var, constant};
287 /// # use std::collections::HashMap;
288 /// let expr = constant(2.0) * var("x") + constant(1.0);
289 /// let vars = HashMap::from([("x".to_string(), 3.0)]);
290 /// assert!((expr.eval(&vars).unwrap() - 7.0).abs() < 1e-10);
291 /// ```
292 pub fn eval(&self, vars: &HashMap<String, f64>) -> Result<f64> {
293 match self {
294 Self::Const(v) => Ok(*v),
295 Self::Var(name) => vars
296 .get(name)
297 .copied()
298 .ok_or_else(|| SymError::UndefinedVariable { name: name.clone() }),
299 Self::Add(a, b) => Ok(a.eval(vars)? + b.eval(vars)?),
300 Self::Mul(a, b) => {
301 let av = a.eval(vars)?;
302 let bv = b.eval(vars)?;
303 Ok(av * bv)
304 }
305 Self::Pow(base, exp) => {
306 let bv = base.eval(vars)?;
307 let ev = exp.eval(vars)?;
308 // Check for 0^negative (division by zero).
309 if bv == 0.0 && ev < 0.0 {
310 return Err(SymError::DivisionByZero);
311 }
312 Ok(bv.powf(ev))
313 }
314 Self::Neg(inner) => Ok(-inner.eval(vars)?),
315 Self::Fn(func, arg) => {
316 let v = arg.eval(vars)?;
317 Ok(match func {
318 MathFn::Sin => v.sin(),
319 MathFn::Cos => v.cos(),
320 MathFn::Tan => v.tan(),
321 MathFn::Exp => v.exp(),
322 MathFn::Ln => v.ln(),
323 MathFn::Sqrt => v.sqrt(),
324 MathFn::Abs => v.abs(),
325 })
326 }
327 }
328 }
329
330 /// Replace every occurrence of `var` with `replacement`.
331 ///
332 /// # Examples
333 ///
334 /// ```
335 /// # use scivex_sym::expr::{var, constant};
336 /// # use std::collections::HashMap;
337 /// let expr = var("x") + constant(1.0);
338 /// let replaced = expr.substitute("x", &constant(5.0));
339 /// assert!((replaced.eval(&HashMap::new()).unwrap() - 6.0).abs() < 1e-10);
340 /// ```
341 #[must_use]
342 pub fn substitute(&self, var: &str, replacement: &Expr) -> Expr {
343 match self {
344 Self::Const(_) => self.clone(),
345 Self::Var(name) => {
346 if name == var {
347 replacement.clone()
348 } else {
349 self.clone()
350 }
351 }
352 Self::Add(a, b) => Expr::Add(
353 Box::new(a.substitute(var, replacement)),
354 Box::new(b.substitute(var, replacement)),
355 ),
356 Self::Mul(a, b) => Expr::Mul(
357 Box::new(a.substitute(var, replacement)),
358 Box::new(b.substitute(var, replacement)),
359 ),
360 Self::Pow(base, exp) => Expr::Pow(
361 Box::new(base.substitute(var, replacement)),
362 Box::new(exp.substitute(var, replacement)),
363 ),
364 Self::Neg(inner) => Expr::Neg(Box::new(inner.substitute(var, replacement))),
365 Self::Fn(func, arg) => Expr::Fn(*func, Box::new(arg.substitute(var, replacement))),
366 }
367 }
368
369 /// Collect all free variable names in the expression.
370 ///
371 /// # Examples
372 ///
373 /// ```
374 /// # use scivex_sym::expr::{var, constant};
375 /// let expr = var("x") + var("y") * constant(2.0);
376 /// let vars = expr.free_variables();
377 /// assert!(vars.contains("x"));
378 /// assert!(vars.contains("y"));
379 /// assert_eq!(vars.len(), 2);
380 /// ```
381 #[must_use]
382 pub fn free_variables(&self) -> HashSet<String> {
383 let mut set = HashSet::new();
384 self.collect_vars(&mut set);
385 set
386 }
387
388 fn collect_vars(&self, set: &mut HashSet<String>) {
389 match self {
390 Self::Const(_) => {}
391 Self::Var(name) => {
392 set.insert(name.clone());
393 }
394 Self::Add(a, b) | Self::Mul(a, b) | Self::Pow(a, b) => {
395 a.collect_vars(set);
396 b.collect_vars(set);
397 }
398 Self::Neg(inner) | Self::Fn(_, inner) => inner.collect_vars(set),
399 }
400 }
401
402 /// Returns `true` if the expression is `Const(0.0)`.
403 ///
404 /// # Examples
405 ///
406 /// ```
407 /// # use scivex_sym::expr::constant;
408 /// assert!(constant(0.0).is_zero());
409 /// assert!(!constant(1.0).is_zero());
410 /// ```
411 #[must_use]
412 pub fn is_zero(&self) -> bool {
413 matches!(self, Self::Const(v) if *v == 0.0)
414 }
415
416 /// Returns `true` if the expression is `Const(1.0)`.
417 ///
418 /// # Examples
419 ///
420 /// ```
421 /// # use scivex_sym::expr::constant;
422 /// assert!(constant(1.0).is_one());
423 /// assert!(!constant(2.0).is_one());
424 /// ```
425 #[must_use]
426 pub fn is_one(&self) -> bool {
427 matches!(self, Self::Const(v) if (*v - 1.0).abs() < f64::EPSILON)
428 }
429
430 /// Returns `true` if the expression is a constant.
431 ///
432 /// # Examples
433 ///
434 /// ```
435 /// # use scivex_sym::expr::{constant, var};
436 /// assert!(constant(3.14).is_const());
437 /// assert!(!var("x").is_const());
438 /// ```
439 #[must_use]
440 pub fn is_const(&self) -> bool {
441 matches!(self, Self::Const(_))
442 }
443
444 /// If the expression is a constant, return its value.
445 ///
446 /// # Examples
447 ///
448 /// ```
449 /// # use scivex_sym::expr::{constant, var};
450 /// assert_eq!(constant(42.0).as_const(), Some(42.0));
451 /// assert_eq!(var("x").as_const(), None);
452 /// ```
453 #[must_use]
454 pub fn as_const(&self) -> Option<f64> {
455 match self {
456 Self::Const(v) => Some(*v),
457 _ => None,
458 }
459 }
460}
461
462// ---------------------------------------------------------------------------
463// Operator overloading
464// ---------------------------------------------------------------------------
465
466impl ops::Add for Expr {
467 type Output = Self;
468 fn add(self, rhs: Self) -> Self {
469 Expr::Add(Box::new(self), Box::new(rhs))
470 }
471}
472
473impl ops::Sub for Expr {
474 type Output = Self;
475 fn sub(self, rhs: Self) -> Self {
476 Expr::Add(Box::new(self), Box::new(Expr::Neg(Box::new(rhs))))
477 }
478}
479
480impl ops::Mul for Expr {
481 type Output = Self;
482 fn mul(self, rhs: Self) -> Self {
483 Expr::Mul(Box::new(self), Box::new(rhs))
484 }
485}
486
487impl ops::Div for Expr {
488 type Output = Self;
489 fn div(self, rhs: Self) -> Self {
490 Expr::Mul(
491 Box::new(self),
492 Box::new(Expr::Pow(Box::new(rhs), Box::new(Expr::Const(-1.0)))),
493 )
494 }
495}
496
497impl ops::Neg for Expr {
498 type Output = Self;
499 fn neg(self) -> Self {
500 Expr::Neg(Box::new(self))
501 }
502}
503
504// ---------------------------------------------------------------------------
505// Display
506// ---------------------------------------------------------------------------
507
508impl fmt::Display for Expr {
509 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
510 match self {
511 Self::Const(v) => {
512 if (*v - std::f64::consts::PI).abs() < f64::EPSILON {
513 write!(f, "pi")
514 } else if *v < 0.0 {
515 write!(f, "({v})")
516 } else {
517 write!(f, "{v}")
518 }
519 }
520 Self::Var(name) => f.write_str(name),
521 Self::Add(a, b) => write!(f, "({a} + {b})"),
522 Self::Mul(a, b) => write!(f, "({a} * {b})"),
523 Self::Pow(base, exp) => write!(f, "({base}^{exp})"),
524 Self::Neg(inner) => write!(f, "(-{inner})"),
525 Self::Fn(func, arg) => write!(f, "{func}({arg})"),
526 }
527 }
528}
529
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for results that involve inexact floating-point operations.
    const TOL: f64 = 1e-12;

    #[test]
    fn eval_const_and_var() {
        let e = constant(42.0);
        assert!((e.eval(&HashMap::new()).unwrap() - 42.0).abs() < f64::EPSILON);

        let x = var("x");
        let mut vars = HashMap::new();
        vars.insert("x".into(), 3.0);
        assert!((x.eval(&vars).unwrap() - 3.0).abs() < f64::EPSILON);
    }

    #[test]
    fn eval_undefined_variable() {
        let x = var("x");
        let err = x.eval(&HashMap::new()).unwrap_err();
        assert!(matches!(err, SymError::UndefinedVariable { name } if name == "x"));
    }

    #[test]
    fn eval_division_by_zero() {
        // 1 / 0 = 1 * 0^(-1)
        let e = constant(1.0) / constant(0.0);
        let err = e.eval(&HashMap::new()).unwrap_err();
        assert!(matches!(err, SymError::DivisionByZero));
    }

    #[test]
    fn eval_zero_to_any_negative_power_is_division_by_zero() {
        // Not only the `/` encoding: any 0^(negative) must be rejected.
        let e = Expr::Pow(Box::new(zero()), Box::new(constant(-2.0)));
        let err = e.eval(&HashMap::new()).unwrap_err();
        assert!(matches!(err, SymError::DivisionByZero));
    }

    #[test]
    fn eval_arithmetic() {
        let mut vars = HashMap::new();
        vars.insert("x".into(), 2.0);
        // (x + 3) * 4 = 20
        let e = (var("x") + constant(3.0)) * constant(4.0);
        assert!((e.eval(&vars).unwrap() - 20.0).abs() < f64::EPSILON);
    }

    #[test]
    fn eval_sub_div_neg() {
        let vars = HashMap::from([("x".to_string(), 6.0), ("y".to_string(), 3.0)]);
        assert!(((var("x") - var("y")).eval(&vars).unwrap() - 3.0).abs() < TOL);
        assert!(((var("x") / var("y")).eval(&vars).unwrap() - 2.0).abs() < TOL);
        assert!(((-var("x")).eval(&vars).unwrap() + 6.0).abs() < TOL);
    }

    #[test]
    fn eval_functions() {
        let vars = HashMap::new();
        let e = sin(constant(0.0));
        assert!(e.eval(&vars).unwrap().abs() < f64::EPSILON);

        let e = cos(constant(0.0));
        assert!((e.eval(&vars).unwrap() - 1.0).abs() < f64::EPSILON);

        let e = exp(constant(0.0));
        assert!((e.eval(&vars).unwrap() - 1.0).abs() < f64::EPSILON);

        let e = ln(constant(1.0));
        assert!(e.eval(&vars).unwrap().abs() < f64::EPSILON);
    }

    #[test]
    fn eval_remaining_functions() {
        let vars = HashMap::new();
        assert!(tan(constant(0.0)).eval(&vars).unwrap().abs() < TOL);
        assert!((sqrt(constant(9.0)).eval(&vars).unwrap() - 3.0).abs() < TOL);
        assert!((abs(constant(-2.5)).eval(&vars).unwrap() - 2.5).abs() < TOL);
    }

    #[test]
    fn substitute_works() {
        let e = var("x") + constant(1.0);
        let replaced = e.substitute("x", &constant(5.0));
        assert!((replaced.eval(&HashMap::new()).unwrap() - 6.0).abs() < f64::EPSILON);
    }

    #[test]
    fn substitute_leaves_other_variables_alone() {
        let replaced = (var("x") + var("y")).substitute("x", &constant(2.0));
        assert_eq!(replaced, constant(2.0) + var("y"));
    }

    #[test]
    fn free_variables_collected() {
        let e = var("x") * var("y") + sin(var("x"));
        let fv = e.free_variables();
        assert!(fv.contains("x"));
        assert!(fv.contains("y"));
        assert_eq!(fv.len(), 2);
    }

    #[test]
    fn free_variables_of_constant_expression_is_empty() {
        assert!((constant(1.0) + pi()).free_variables().is_empty());
    }

    #[test]
    fn display_formatting() {
        let e = var("x") + constant(1.0);
        let s = format!("{e}");
        assert_eq!(s, "(x + 1)");
    }

    #[test]
    fn display_special_cases() {
        assert_eq!(pi().to_string(), "pi");
        assert_eq!(constant(-2.0).to_string(), "(-2)");
        assert_eq!((var("x") - constant(1.0)).to_string(), "(x + (-1))");
        assert_eq!((var("x") / var("y")).to_string(), "(x * (y^(-1)))");
        assert_eq!(sin(var("x")).to_string(), "sin(x)");
        assert_eq!(MathFn::Sqrt.to_string(), "sqrt");
    }

    #[test]
    fn is_predicates() {
        assert!(zero().is_zero());
        assert!(constant(-0.0).is_zero());
        assert!(one().is_one());
        assert!(!constant(2.0).is_one());
        assert!(constant(3.5).is_const());
        assert!(!var("x").is_const());
        assert_eq!(constant(7.0).as_const(), Some(7.0));
        assert_eq!(var("x").as_const(), None);
    }
}