1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
use tract_num_traits::Zero;

use crate::internal::*;

/// Solves the equation `left == right` for the symbol `sym`.
///
/// Returns `Some(value)` such that substituting `value` for `sym` balances the equation, or
/// `None` when `sym` does not occur on either side or the equation falls outside the small
/// family handled here.
///
/// Any occurrence of `sym` on the right-hand side is first moved to the left
/// (`left - right == 0`). The left side must then be one of:
/// - `sym` itself;
/// - a sum whose terms free of `sym` are subtracted from both sides before recursing;
/// - an integer multiple `z * x`, solved only when `z` divides `right.gcd()` exactly,
///   so that the result stays an integer dimension.
pub fn solve_for(sym: &Symbol, left: &TDim, right: &TDim) -> Option<TDim> {
    let in_left = left.symbols().contains(sym);
    let in_right = right.symbols().contains(sym);
    if !in_left && !in_right {
        return None;
    }
    if in_right {
        // Gather every occurrence of `sym` on the left: `left - right == 0`.
        return solve_for(sym, &(left.clone() - right), &0.to_dim());
    }
    match left {
        TDim::Sym(s) => (s == sym).then(|| right.clone()),
        TDim::Add(terms) => {
            // Move the terms that do not depend on `sym` over to the right-hand side.
            let consts: TDim = terms.iter().filter(|t| !t.symbols().contains(sym)).sum();
            if consts.is_zero() {
                None
            } else {
                solve_for(sym, &(left.clone() - &consts), &(right.clone() - &consts))
            }
        }
        TDim::MulInt(z, a) => {
            // `z * a == right` has an integer solution only if |z| divides `right`.
            // A zero factor cannot be divided out; also avoids `% 0` panicking.
            let divisor = z.unsigned_abs();
            if divisor == 0 || right.gcd() % divisor != 0 {
                return None;
            }
            // `i64::MIN` has no positive i64 counterpart; give up rather than overflow.
            let magnitude = z.checked_abs()?;
            // Divide by the positive magnitude and apply the sign by negation, so the result
            // does not depend on how `TDim` division treats a negative divisor.
            let quotient = if *z < 0 {
                -(right.clone()) / magnitude
            } else {
                right.clone() / magnitude
            };
            solve_for(sym, a, &quotient)
        }
        _ => None,
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use super::{parse_tdim, SymbolTable};

    lazy_static::lazy_static!(
        static ref TABLE: SymbolTable = SymbolTable::default();
        static ref A: Symbol = TABLE.sym("a");
    );

    /// Parses `s` into a `TDim` using the shared test symbol table; panics on bad syntax.
    fn p(s: &str) -> TDim {
        parse_tdim(&TABLE, s).unwrap()
    }

    /// Solves `left == right` (both given as expression text) for the shared symbol `a`.
    fn solve(left: &str, right: &str) -> Option<TDim> {
        solve_for(&A, &p(left), &p(right))
    }

    #[test]
    fn trivial() {
        let expected = 3i32.to_dim();
        assert_eq!(solve("a", "3"), Some(expected));
    }

    #[test]
    fn negative() {
        let expected = -(3i32.to_dim());
        assert_eq!(solve("a + 3", "0"), Some(expected));
    }

    #[test]
    fn swap() {
        let expected = 3i32.to_dim();
        assert_eq!(solve("3", "a"), Some(expected));
    }

    #[test]
    fn scale() {
        let expected = 2.to_dim();
        assert_eq!(solve("3 * a", "6"), Some(expected));
    }

    #[test]
    fn ax_plus_b() {
        let expected = 2.to_dim();
        assert_eq!(solve("3 * a + 1", "7"), Some(expected));
    }

    #[test]
    fn both_sides() {
        let expected = (-1).to_dim();
        assert_eq!(solve("3 * a + 1", "2 * a"), Some(expected));
    }

    #[test]
    fn x_over_n() {
        // Division of the unknown is not a supported form, so no solution is produced.
        assert_eq!(solve("a/4", "2"), None);
    }

    #[test]
    fn with_symbols() {
        let expected = p("b-1");
        assert_eq!(solve("a + 1", "b"), Some(expected));
    }
}