tract_data/dim/
resolve.rs

1use tract_num_traits::Zero;
2
3use crate::internal::*;
4
5pub fn solve_for(sym: &Symbol, left: &TDim, right: &TDim) -> Option<TDim> {
6    if !left.symbols().contains(sym) && !right.symbols().contains(sym) {
7        return None;
8    }
9    if right.symbols().contains(sym) {
10        return solve_for(sym, &(left.clone() - right), &0.to_dim());
11    }
12    match left {
13        TDim::Sym(s) => {
14            if s == sym {
15                Some(right.clone())
16            } else {
17                None
18            }
19        }
20        TDim::Add(terms) => {
21            let consts: TDim = terms.iter().filter(|t| !t.symbols().contains(sym)).sum();
22            if consts.is_zero() {
23                None
24            } else {
25                solve_for(sym, &(left.clone() - &consts), &(right.clone() - &consts))
26            }
27        }
28        TDim::MulInt(z, a) => {
29            let gcd = right.gcd();
30            if gcd % z.unsigned_abs() == 0 {
31                solve_for(sym, a, &(right.clone() / *z))
32            } else {
33                None
34            }
35        }
36        _ => None,
37    }
38}
39
#[cfg(test)]
mod tests {
    use super::*;
    use super::{parse_tdim, SymbolScope};

    lazy_static::lazy_static! {
        static ref SCOPE: SymbolScope = SymbolScope::default();
        static ref SYM_A: Symbol = SCOPE.sym("a");
    }

    /// Parses `text` as a `TDim` in the shared test scope, panicking on error.
    fn expr(text: &str) -> TDim {
        parse_tdim(&SCOPE, text).unwrap()
    }

    /// Parses both sides and solves `lhs == rhs` for the symbol `a`.
    fn solve(lhs: &str, rhs: &str) -> Option<TDim> {
        solve_for(&SYM_A, &expr(lhs), &expr(rhs))
    }

    #[test]
    fn trivial() {
        let expected = Some(3i32.to_dim());
        assert_eq!(solve("a", "3"), expected);
    }

    #[test]
    fn negative() {
        let expected = Some(-(3i32.to_dim()));
        assert_eq!(solve("a + 3", "0"), expected);
    }

    #[test]
    fn swap() {
        let expected = Some(3i32.to_dim());
        assert_eq!(solve("3", "a"), expected);
    }

    #[test]
    fn scale() {
        let expected = Some(2.to_dim());
        assert_eq!(solve("3 * a", "6"), expected);
    }

    #[test]
    fn ax_plus_b() {
        let expected = Some(2.to_dim());
        assert_eq!(solve("3 * a + 1", "7"), expected);
    }

    #[test]
    fn both_sides() {
        let expected = Some((-1).to_dim());
        assert_eq!(solve("3 * a + 1", "2 * a"), expected);
    }

    #[test]
    fn x_over_n() {
        assert_eq!(solve("a/4", "2"), None);
    }

    #[test]
    fn with_symbols() {
        assert_eq!(solve("a + 1", "b"), Some(expr("b-1")));
    }
}