// tract_data/dim/assertion.rs

1use fmt::Display;
2
3use super::*;
4
5#[derive(Debug, PartialEq, Clone, Hash)]
6#[allow(clippy::upper_case_acronyms)]
7pub enum Assertion {
8    Eq(TDim, TDim),
9    LT(TDim, TDim),
10    GT(TDim, TDim),
11    LTE(TDim, TDim),
12    GTE(TDim, TDim),
13}
14
15impl Display for Assertion {
16    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
17        use Assertion::*;
18        match self {
19            Eq(l, r) => write!(f, "{l} == {r}"),
20            LT(l, r) => write!(f, "{l} < {r}"),
21            GT(l, r) => write!(f, "{l} > {r}"),
22            LTE(l, r) => write!(f, "{l} <= {r}"),
23            GTE(l, r) => write!(f, "{l} >= {r}"),
24        }
25    }
26}
27
28impl Assertion {
29    pub fn as_known_positive(&self) -> Option<TDim> {
30        use Assertion::*;
31        match self {
32            Eq(left, right) => Some(left.clone() - right),
33            GTE(left, right) => Some(left.clone() - right),
34            GT(left, right) => Some(left.clone() - 1 - right),
35            LTE(left, right) => Some(right.clone() - left),
36            LT(left, right) => Some(right.clone() - 1 - left),
37        }
38    }
39
40    pub fn check(&self, values: &SymbolValues) -> Option<bool> {
41        use Assertion::*;
42        match self {
43            Eq(left, right) => {
44                (left.eval(values) - right.eval(values)).to_i64().ok().map(|d| d == 0)
45            }
46            GTE(left, right) => {
47                (left.eval(values) - right.eval(values)).to_i64().ok().map(|d| d >= 0)
48            }
49            GT(left, right) => {
50                (left.eval(values) - right.eval(values)).to_i64().ok().map(|d| d > 0)
51            }
52            LTE(left, right) => {
53                (left.eval(values) - right.eval(values)).to_i64().ok().map(|d| d <= 0)
54            }
55            LT(left, right) => {
56                (left.eval(values) - right.eval(values)).to_i64().ok().map(|d| d < 0)
57            }
58        }
59    }
60}
61
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn use_equalities() {
        let s = SymbolScope::default();
        s.add_assertion("s==0").unwrap();
        assert!(s.parse_tdim("s").unwrap().simplify().is_zero());
    }

    #[test]
    fn prove_positive_with_axiom() {
        let s = SymbolScope::default();
        s.add_assertion("s>=0").unwrap();
        assert!(s.parse_tdim("s").unwrap().prove_positive_or_zero());
    }

    #[test]
    fn prove_positive_with_axiom_2() {
        let s = SymbolScope::default();
        s.add_assertion("s>=0").unwrap();
        s.add_assertion("p>=0").unwrap();
        s.add_assertion("p+s<4096").unwrap();
        assert!(s.parse_tdim("4096-p").unwrap().prove_positive_or_zero());
    }

    #[test]
    fn min_max_with_axiom() {
        let symbols = SymbolScope::default();
        symbols.add_assertion("a>=0").unwrap();
        assert_eq!(symbols.parse_tdim("min(a,0)").unwrap().simplify(), 0.into());
        assert_eq!(
            symbols.parse_tdim("max(a,0)").unwrap().simplify(),
            symbols.parse_tdim("a").unwrap()
        );
    }

    #[test]
    fn low_bound_0() -> TractResult<()> {
        let symbols = SymbolScope::default().with_assertion("S>=0")?;
        let s = symbols.parse_tdim("S")?;
        assert_eq!(s.low_inclusive_bound(), Some(0));
        Ok(())
    }

    #[test]
    fn low_bound_1() -> TractResult<()> {
        let symbols = SymbolScope::default().with_assertion("S>0")?;
        assert_eq!(symbols.parse_tdim("S")?.low_inclusive_bound(), Some(1));
        Ok(())
    }

    #[test]
    fn low_bound_2() -> TractResult<()> {
        let symbols = SymbolScope::default().with_assertion("S>0")?;
        assert_eq!(symbols.parse_tdim("S + 1")?.low_inclusive_bound(), Some(2));
        Ok(())
    }

    #[test]
    fn low_bound_3() -> TractResult<()> {
        let symbols = SymbolScope::default().with_assertion("S>0")?;
        assert_eq!(symbols.parse_tdim("4*S")?.low_inclusive_bound(), Some(4));
        Ok(())
    }

    #[test]
    fn low_bound_4() -> TractResult<()> {
        let symbols = SymbolScope::default().with_assertion("S>0")?.with_assertion("S>5")?;
        assert_eq!(symbols.parse_tdim("S + 3")?.low_inclusive_bound(), Some(9));
        Ok(())
    }

    #[test]
    fn max_bug_1() {
        let symbols = SymbolScope::default();
        symbols.add_assertion("S>8").unwrap();
        assert_eq!(
            symbols.parse_tdim("max(1,-1+(S+1)/4)").unwrap().simplify(),
            symbols.parse_tdim("-1+(S+1)/4").unwrap(),
        );
    }

    #[test]
    fn min_bug_1() {
        let symbols = SymbolScope::default();
        symbols.add_assertion("S>8").unwrap();
        assert_eq!(
            symbols.parse_tdim("min(1,-1+(S+1)/4)").unwrap().simplify(),
            symbols.parse_tdim("1").unwrap()
        );
    }

    #[test]
    fn min_bug_2() {
        let symbols = SymbolScope::default();
        symbols.add_assertion("S>50").unwrap();
        assert_eq!(
            symbols.parse_tdim("min(-3+2*(S+1)/4,-1+(S+1)/4)").unwrap().simplify(),
            symbols.parse_tdim("-1+(S+1)/4").unwrap()
        );
    }

    #[test]
    fn min_bug_3() {
        let symbols = SymbolScope::default();
        symbols.add_assertion("S>=0").unwrap();
        symbols.add_assertion("P>=0").unwrap();
        assert_eq!(
            symbols.parse_tdim("min(0,(S)#(P+S))").unwrap().simplify(),
            symbols.parse_tdim("0").unwrap()
        );
    }

    #[test]
    fn guess_scenario() -> TractResult<()> {
        let symbols = SymbolScope::default()
            .with_assertion("S>=0")?
            .with_assertion("P>=0")?
            .with_scenario_assertion("tg", "S==1")?
            .with_scenario_assertion("pp", "P==0")?;
        let s = symbols.sym("S");
        let p = symbols.sym("P");
        assert_eq!(symbols.guess_scenario(&SymbolValues::default())?, None);
        assert_eq!(symbols.guess_scenario(&SymbolValues::default().with(&s, 50))?, Some(1));
        assert_eq!(symbols.guess_scenario(&SymbolValues::default().with(&p, 50))?, Some(0));
        assert!(
            symbols.guess_scenario(&SymbolValues::default().with(&p, 50).with(&s, 50)).is_err()
        );
        Ok(())
    }

    #[test]
    fn min_llm_0() -> TractResult<()> {
        let symbols = SymbolScope::default()
            .with_assertion("S>=0")?
            .with_assertion("P>=0")?
            .with_scenario_assertion("tg", "S==1")?
            .with_scenario_assertion("pp", "P==0")?;
        assert_eq!(symbols.parse_tdim("min(P,(S)#(P+S))")?.simplify(), symbols.parse_tdim("P")?);
        Ok(())
    }

    // The tests below exercise `Assertion` itself rather than going through
    // `SymbolScope` simplification.

    #[test]
    fn display_uses_operator_tokens() -> TractResult<()> {
        let symbols = SymbolScope::default();
        let s = symbols.parse_tdim("S")?;
        let zero = symbols.parse_tdim("0")?;
        assert_eq!(Assertion::Eq(s.clone(), zero.clone()).to_string(), "S == 0");
        assert_eq!(Assertion::LT(s.clone(), zero.clone()).to_string(), "S < 0");
        assert_eq!(Assertion::GT(s.clone(), zero.clone()).to_string(), "S > 0");
        assert_eq!(Assertion::LTE(s.clone(), zero.clone()).to_string(), "S <= 0");
        assert_eq!(Assertion::GTE(s, zero).to_string(), "S >= 0");
        Ok(())
    }

    #[test]
    fn check_with_bound_symbol() -> TractResult<()> {
        let symbols = SymbolScope::default();
        let s = symbols.parse_tdim("S")?;
        let three = symbols.parse_tdim("3")?;
        let values = SymbolValues::default().with(&symbols.sym("S"), 3);
        assert_eq!(Assertion::Eq(s.clone(), three.clone()).check(&values), Some(true));
        assert_eq!(Assertion::GTE(s.clone(), three.clone()).check(&values), Some(true));
        assert_eq!(Assertion::GT(s.clone(), three.clone()).check(&values), Some(false));
        assert_eq!(Assertion::LTE(s.clone(), three.clone()).check(&values), Some(true));
        assert_eq!(Assertion::LT(s, three).check(&values), Some(false));
        Ok(())
    }

    #[test]
    fn check_with_unbound_symbol_is_undecided() -> TractResult<()> {
        let symbols = SymbolScope::default();
        let s = symbols.parse_tdim("S")?;
        let three = symbols.parse_tdim("3")?;
        assert_eq!(Assertion::GT(s, three).check(&SymbolValues::default()), None);
        Ok(())
    }

    #[test]
    fn known_positive_matches_assertion() -> TractResult<()> {
        let symbols = SymbolScope::default();
        let s = symbols.parse_tdim("S")?;
        let two = symbols.parse_tdim("2")?;
        let five = symbols.parse_tdim("5")?;
        let values = SymbolValues::default().with(&symbols.sym("S"), 5);
        // Evaluate the "known non-negative" expression at S = 5.
        let at_s5 = |a: Assertion| a.as_known_positive().unwrap().eval(&values).to_i64().unwrap();
        assert_eq!(at_s5(Assertion::Eq(s.clone(), five)), 0);
        assert_eq!(at_s5(Assertion::GTE(s.clone(), two.clone())), 3);
        assert_eq!(at_s5(Assertion::GT(s.clone(), two.clone())), 2);
        assert_eq!(at_s5(Assertion::LTE(two.clone(), s.clone())), 3);
        assert_eq!(at_s5(Assertion::LT(two, s)), 2);
        Ok(())
    }
}