tract_data/dim/
assertion.rs1use fmt::Display;
2
3use super::*;
4
5#[derive(Debug, PartialEq, Clone, Hash)]
6#[allow(clippy::upper_case_acronyms)]
7pub enum Assertion {
8 Eq(TDim, TDim),
9 LT(TDim, TDim),
10 GT(TDim, TDim),
11 LTE(TDim, TDim),
12 GTE(TDim, TDim),
13}
14
15impl Display for Assertion {
16 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
17 use Assertion::*;
18 match self {
19 Eq(l, r) => write!(f, "{l} == {r}"),
20 LT(l, r) => write!(f, "{l} < {r}"),
21 GT(l, r) => write!(f, "{l} > {r}"),
22 LTE(l, r) => write!(f, "{l} <= {r}"),
23 GTE(l, r) => write!(f, "{l} >= {r}"),
24 }
25 }
26}
27
28impl Assertion {
29 pub fn as_known_positive(&self) -> Option<TDim> {
30 use Assertion::*;
31 match self {
32 Eq(left, right) => Some(left.clone() - right),
33 GTE(left, right) => Some(left.clone() - right),
34 GT(left, right) => Some(left.clone() - 1 - right),
35 LTE(left, right) => Some(right.clone() - left),
36 LT(left, right) => Some(right.clone() - 1 - left),
37 }
38 }
39
40 pub fn check(&self, values: &SymbolValues) -> Option<bool> {
41 use Assertion::*;
42 match self {
43 Eq(left, right) => {
44 (left.eval(values) - right.eval(values)).to_i64().ok().map(|d| d == 0)
45 }
46 GTE(left, right) => {
47 (left.eval(values) - right.eval(values)).to_i64().ok().map(|d| d >= 0)
48 }
49 GT(left, right) => {
50 (left.eval(values) - right.eval(values)).to_i64().ok().map(|d| d > 0)
51 }
52 LTE(left, right) => {
53 (left.eval(values) - right.eval(values)).to_i64().ok().map(|d| d <= 0)
54 }
55 LT(left, right) => {
56 (left.eval(values) - right.eval(values)).to_i64().ok().map(|d| d < 0)
57 }
58 }
59 }
60}
61
#[cfg(test)]
mod tests {
    use super::*;

    /// With `s==0` asserted, `s` simplifies to zero.
    #[test]
    fn use_equalities() {
        let scope = SymbolScope::default();
        scope.add_assertion("s==0").unwrap();
        let expr = scope.parse_tdim("s").unwrap();
        assert!(expr.simplify().is_zero());
    }

    /// The axiom `s>=0` is enough to prove `s` positive or zero.
    #[test]
    fn prove_positive_with_axiom() {
        let scope = SymbolScope::default();
        scope.add_assertion("s>=0").unwrap();
        let expr = scope.parse_tdim("s").unwrap();
        assert!(expr.prove_positive_or_zero());
    }

    /// `4096-p` is provably positive or zero given `s>=0`, `p>=0` and `p+s<4096`.
    #[test]
    fn prove_positive_with_axiom_2() {
        let scope = SymbolScope::default();
        for axiom in ["s>=0", "p>=0", "p+s<4096"] {
            scope.add_assertion(axiom).unwrap();
        }
        let expr = scope.parse_tdim("4096-p").unwrap();
        assert!(expr.prove_positive_or_zero());
    }

    /// With `a>=0`, `min(a,0)` simplifies to `0` and `max(a,0)` to `a`.
    #[test]
    fn min_max_with_axiom() {
        let scope = SymbolScope::default();
        scope.add_assertion("a>=0").unwrap();
        let min_expr = scope.parse_tdim("min(a,0)").unwrap().simplify();
        assert_eq!(min_expr, 0.into());
        let max_expr = scope.parse_tdim("max(a,0)").unwrap().simplify();
        assert_eq!(max_expr, scope.parse_tdim("a").unwrap());
    }

    /// `S>=0` gives `S` an inclusive lower bound of 0.
    #[test]
    fn low_bound_0() -> TractResult<()> {
        let scope = SymbolScope::default().with_assertion("S>=0")?;
        let expr = scope.parse_tdim("S").unwrap();
        assert_eq!(expr.low_inclusive_bound(), Some(0));
        Ok(())
    }

    /// The strict `S>0` gives `S` an inclusive lower bound of 1.
    #[test]
    fn low_bound_1() -> TractResult<()> {
        let scope = SymbolScope::default().with_assertion("S>0")?;
        let expr = scope.parse_tdim("S").unwrap();
        assert_eq!(expr.low_inclusive_bound(), Some(1));
        Ok(())
    }

    /// With `S>0`, `S + 1` has an inclusive lower bound of 2.
    #[test]
    fn low_bound_2() -> TractResult<()> {
        let scope = SymbolScope::default().with_assertion("S>0")?;
        let expr = scope.parse_tdim("S + 1").unwrap();
        assert_eq!(expr.low_inclusive_bound(), Some(2));
        Ok(())
    }

    /// With `S>0`, `4*S` has an inclusive lower bound of 4.
    #[test]
    fn low_bound_3() -> TractResult<()> {
        let scope = SymbolScope::default().with_assertion("S>0")?;
        let expr = scope.parse_tdim("4*S").unwrap();
        assert_eq!(expr.low_inclusive_bound(), Some(4));
        Ok(())
    }

    /// With both `S>0` and `S>5`, the tighter bound is used: `S + 3` is at least 9.
    #[test]
    fn low_bound_4() -> TractResult<()> {
        let scope = SymbolScope::default().with_assertion("S>0")?.with_assertion("S>5")?;
        let expr = scope.parse_tdim("S + 3").unwrap();
        assert_eq!(expr.low_inclusive_bound(), Some(9));
        Ok(())
    }

    /// With `S>8`, `max(1,-1+(S+1)/4)` simplifies to its second argument.
    #[test]
    fn max_bug_1() {
        let scope = SymbolScope::default();
        scope.add_assertion("S>8").unwrap();
        let simplified = scope.parse_tdim("max(1,-1+(S+1)/4)").unwrap().simplify();
        let expected = scope.parse_tdim("-1+(S+1)/4").unwrap();
        assert_eq!(simplified, expected,);
    }

    /// With `S>8`, `min(1,-1+(S+1)/4)` simplifies to `1`.
    #[test]
    fn min_bug_1() {
        let scope = SymbolScope::default();
        scope.add_assertion("S>8").unwrap();
        let simplified = scope.parse_tdim("min(1,-1+(S+1)/4)").unwrap().simplify();
        let expected = scope.parse_tdim("1").unwrap();
        assert_eq!(simplified, expected);
    }

    /// With `S>50`, `min(-3+2*(S+1)/4,-1+(S+1)/4)` simplifies to its second argument.
    #[test]
    fn min_bug_2() {
        let scope = SymbolScope::default();
        scope.add_assertion("S>50").unwrap();
        let simplified =
            scope.parse_tdim("min(-3+2*(S+1)/4,-1+(S+1)/4)").unwrap().simplify();
        let expected = scope.parse_tdim("-1+(S+1)/4").unwrap();
        assert_eq!(simplified, expected);
    }

    /// With `S>=0` and `P>=0`, `min(0,(S)#(P+S))` simplifies to `0`.
    #[test]
    fn min_bug_3() {
        let scope = SymbolScope::default();
        scope.add_assertion("S>=0").unwrap();
        scope.add_assertion("P>=0").unwrap();
        let simplified = scope.parse_tdim("min(0,(S)#(P+S))").unwrap().simplify();
        let expected = scope.parse_tdim("0").unwrap();
        assert_eq!(simplified, expected);
    }

    /// `guess_scenario` behavior with these bindings:
    /// - no bindings: `None`;
    /// - only `S=50`: `Some(1)`;
    /// - only `P=50`: `Some(0)`;
    /// - both bound to 50: an error.
    #[test]
    fn guess_scenario() -> TractResult<()> {
        let scope = SymbolScope::default()
            .with_assertion("S>=0")?
            .with_assertion("P>=0")?
            .with_scenario_assertion("tg", "S==1")?
            .with_scenario_assertion("pp", "P==0")?;
        let s = scope.sym("S");
        let p = scope.sym("P");

        let unbound = SymbolValues::default();
        assert_eq!(scope.guess_scenario(&unbound)?, None);

        let only_s = SymbolValues::default().with(&s, 50);
        assert_eq!(scope.guess_scenario(&only_s)?, Some(1));

        let only_p = SymbolValues::default().with(&p, 50);
        assert_eq!(scope.guess_scenario(&only_p)?, Some(0));

        let both = SymbolValues::default().with(&p, 50).with(&s, 50);
        assert!(scope.guess_scenario(&both).is_err());
        Ok(())
    }

    /// Under the same assertions and scenarios, `min(P,(S)#(P+S))` simplifies to `P`.
    #[test]
    fn min_llm_0() -> TractResult<()> {
        let scope = SymbolScope::default()
            .with_assertion("S>=0")?
            .with_assertion("P>=0")?
            .with_scenario_assertion("tg", "S==1")?
            .with_scenario_assertion("pp", "P==0")?;
        let simplified = scope.parse_tdim("min(P,(S)#(P+S))").unwrap().simplify();
        let expected = scope.parse_tdim("P").unwrap();
        assert_eq!(simplified, expected);
        Ok(())
    }
}