tract_data/dim/
resolve.rs1use tract_num_traits::Zero;
2
3use crate::internal::*;
4
5pub fn solve_for(sym: &Symbol, left: &TDim, right: &TDim) -> Option<TDim> {
6 if !left.symbols().contains(sym) && !right.symbols().contains(sym) {
7 return None;
8 }
9 if right.symbols().contains(sym) {
10 return solve_for(sym, &(left.clone() - right), &0.to_dim());
11 }
12 match left {
13 TDim::Sym(s) => {
14 if s == sym {
15 Some(right.clone())
16 } else {
17 None
18 }
19 }
20 TDim::Add(terms) => {
21 let consts: TDim = terms.iter().filter(|t| !t.symbols().contains(sym)).sum();
22 if consts.is_zero() {
23 None
24 } else {
25 solve_for(sym, &(left.clone() - &consts), &(right.clone() - &consts))
26 }
27 }
28 TDim::MulInt(z, a) => {
29 let gcd = right.gcd();
30 if gcd % z.unsigned_abs() == 0 {
31 solve_for(sym, a, &(right.clone() / *z))
32 } else {
33 None
34 }
35 }
36 _ => None,
37 }
38}
39
#[cfg(test)]
mod tests {
    use super::*;
    use super::{parse_tdim, SymbolScope};

    lazy_static::lazy_static!(
        static ref TABLE: SymbolScope = SymbolScope::default();
        static ref A: Symbol = TABLE.sym("a");
    );

    /// Parses `s` into a `TDim` within the shared test symbol scope.
    fn p(s: &str) -> TDim {
        parse_tdim(&TABLE, s).unwrap()
    }

    /// Parses both sides of `left == right` and solves the equation for `a`.
    fn solve_a(left: &str, right: &str) -> Option<TDim> {
        let lhs = p(left);
        let rhs = p(right);
        solve_for(&A, &lhs, &rhs)
    }

    #[test]
    fn trivial() {
        let expected = Some(3i32.to_dim());
        assert_eq!(solve_a("a", "3"), expected);
    }

    #[test]
    fn negative() {
        let expected = Some(-(3i32.to_dim()));
        assert_eq!(solve_a("a + 3", "0"), expected);
    }

    #[test]
    fn swap() {
        let expected = Some(3i32.to_dim());
        assert_eq!(solve_a("3", "a"), expected);
    }

    #[test]
    fn scale() {
        let expected = Some(2.to_dim());
        assert_eq!(solve_a("3 * a", "6"), expected);
    }

    #[test]
    fn ax_plus_b() {
        let expected = Some(2.to_dim());
        assert_eq!(solve_a("3 * a + 1", "7"), expected);
    }

    #[test]
    fn both_sides() {
        let expected = Some((-1).to_dim());
        assert_eq!(solve_a("3 * a + 1", "2 * a"), expected);
    }

    #[test]
    fn x_over_n() {
        // Floor division has no unique inverse, so no solution is reported.
        assert_eq!(solve_a("a/4", "2"), None);
    }

    #[test]
    fn with_symbols() {
        let expected = Some(p("b-1"));
        assert_eq!(solve_a("a + 1", "b"), expected);
    }
}