1use itertools::Itertools;
2use parking_lot::ReentrantMutex;
3use std::cell::RefCell;
4use std::collections::HashMap;
5use std::fmt::{self, Display};
6use std::sync::{Arc, Weak};
7use string_interner::DefaultStringInterner;
8use string_interner::Symbol as _;
9
10use crate::TractResult;
11
12use super::parse::parse_assertion;
13use super::{parse_tdim, Assertion, TDim};
14
15#[derive(Clone, Default)]
16pub struct SymbolScope(pub Arc<ReentrantMutex<RefCell<SymbolScopeData>>>);
17
18impl PartialEq for SymbolScope {
19 fn eq(&self, other: &Self) -> bool {
20 Arc::ptr_eq(&self.0, &other.0)
21 }
22}
23
24impl Eq for SymbolScope {}
25
26#[derive(Default)]
27pub struct SymbolScopeData {
28 table: DefaultStringInterner,
29 assertions: Vec<Assertion>,
30 scenarios: Vec<(String, Vec<Assertion>)>,
31}
32
33impl SymbolScope {
34 pub fn get(&self, name: &str) -> Option<Symbol> {
35 let locked = self.0.lock();
36 let locked = locked.borrow();
37 locked.table.get(name).map(|sym| Symbol(Arc::downgrade(&self.0), sym))
38 }
39
40 pub fn sym(&self, name: &str) -> Symbol {
41 let locked = self.0.lock();
42 let mut locked = locked.borrow_mut();
43 let sym = locked.table.get_or_intern(name);
44 Symbol(Arc::downgrade(&self.0), sym)
45 }
46
47 pub fn new_with_prefix(&self, prefix: &str) -> Symbol {
48 let locked = self.0.lock();
49 let mut locked = locked.borrow_mut();
50 let sym = if locked.table.get(prefix).is_none() {
51 locked.table.get_or_intern(prefix)
52 } else {
53 let mut i = 0;
54 loop {
55 let s = format!("{prefix}_{i}");
56 if locked.table.get(&s).is_none() {
57 break locked.table.get_or_intern(s);
58 }
59 i += 1;
60 }
61 };
62 Symbol(Arc::downgrade(&self.0), sym)
63 }
64
65 pub fn parse_tdim(&self, input: impl AsRef<str>) -> TractResult<TDim> {
66 parse_tdim(self, input.as_ref())
67 }
68
69 pub fn add_assertion(&self, assert: impl Into<String>) -> TractResult<()> {
70 let assert = assert.into();
71 let assert = parse_assertion(self, &assert)?;
72 let locked = self.0.lock();
73 let mut locked = locked.borrow_mut();
74 locked.assertions.push(assert);
75 Ok(())
76 }
77
78 pub fn with_assertion(self, assert: impl Into<String>) -> TractResult<Self> {
79 self.add_assertion(assert)?;
80 Ok(self)
81 }
82
83 pub fn all_assertions(&self) -> Vec<Assertion> {
84 let locked = self.0.lock();
85 let locked = locked.borrow();
86 locked.assertions.clone()
87 }
88
89 pub fn all_scenarios(&self) -> impl IntoIterator<Item = (String, Vec<Assertion>)> {
90 let locked = self.0.lock();
91 let locked = locked.borrow();
92 locked.scenarios.clone()
93 }
94
95 pub fn add_scenario(&self, scenario: impl Into<String>) -> TractResult<()> {
96 let locked = self.0.lock();
97 let mut locked = locked.borrow_mut();
98 let s = scenario.into();
99 if !locked.scenarios.iter().any(|sc| sc.0 == s) {
100 locked.scenarios.push((s, vec![]));
101 }
102 Ok(())
103 }
104
105 pub fn add_scenario_assertion(
106 &self,
107 scenario: impl Into<String>,
108 assertion: impl Into<String>,
109 ) -> TractResult<()> {
110 let assert = parse_assertion(self, &assertion.into())?;
111 let s = scenario.into();
112 let locked = self.0.lock();
113 let mut locked = locked.borrow_mut();
114 if let Some(s) = locked.scenarios.iter_mut().find(|sc| sc.0 == s) {
115 s.1.push(assert);
116 } else {
117 locked.scenarios.push((s, vec![assert]));
118 }
119 Ok(())
120 }
121
122 pub fn with_scenario_assertion(
123 self,
124 scenario: impl Into<String>,
125 assertion: impl Into<String>,
126 ) -> TractResult<Self> {
127 self.add_scenario_assertion(scenario, assertion)?;
128 Ok(self)
129 }
130
131 pub fn with_scenario(self, scenario: impl Into<String>) -> TractResult<Self> {
132 self.add_scenario(scenario)?;
133 Ok(self)
134 }
135
136 pub fn all_symbols(&self) -> Vec<Symbol> {
137 self.0
138 .lock()
139 .borrow()
140 .table
141 .into_iter()
142 .map(|is| Symbol(Arc::downgrade(&self.0), is.0))
143 .collect()
144 }
145
146 pub fn guess_scenario(&self, values: &SymbolValues) -> TractResult<Option<usize>> {
147 let locked = self.0.lock();
148 let locked = locked.borrow();
149 if locked.scenarios.len() == 0 {
150 return Ok(None)
151 }
152 let mut maybe = None;
153 for (ix, (_name, assertions)) in locked.scenarios.iter().enumerate() {
154 if assertions.iter().any(|a| a.check(values) == Some(false)) {
155 continue;
156 } else if assertions.iter().all(|a| a.check(values) == Some(true)) {
157 return Ok(Some(ix));
158 } else if maybe.is_none() {
159 maybe = Some(ix);
160 } else {
161 return Ok(None)
162 }
163 }
164 if maybe.is_some() {
165 Ok(maybe)
166 } else {
167 anyhow::bail!("No possible scenario");
168 }
169 }
170}
171
172impl SymbolScopeData {
173 pub fn all_assertions(&self) -> &[Assertion] {
174 &self.assertions
175 }
176
177 pub fn assertions(&self, scenario: Option<&str>) -> impl Iterator<Item = &'_ Assertion> {
178 self.assertions.iter().chain(
179 scenario
180 .and_then(|s| self.scenarios.iter().find(|s2| s2.0 == s))
181 .map(|s| &*s.1)
182 .unwrap_or(&[])
183 .iter(),
184 )
185 }
186
187 pub fn scenarios(&self) -> impl Iterator<Item = &'_ str> {
188 self.scenarios.iter().map(|s| &*s.0)
189 }
190
191 pub fn scenario(&self, s: &str) -> impl Iterator<Item = &'_ Assertion> {
192 self.scenarios.iter().find(|sc| sc.0 == s).map(|sc| &*sc.1).unwrap_or(&[]).iter()
193 }
194
195 pub fn resolving<R>(&self, sym: &Symbol, f: impl FnOnce(&str) -> R) -> Option<R> {
196 self.table.resolve(sym.1).map(f)
197 }
198
199 #[allow(clippy::mutable_key_type)]
200 pub fn prove_positive_or_zero(&self, t: &TDim) -> bool {
201 if let TDim::Val(v) = t {
202 return *v >= 0;
203 }
204 let positives = self.assertions.iter().filter_map(|i| i.as_known_positive()).collect_vec();
205 let mut visited = vec![];
206 let mut todo = vec![t.clone()];
207 while let Some(t) = todo.pop() {
208 if t.to_i64().is_ok_and(|i| i >= 0) {
209 return true;
210 }
211 if t.inclusive_bound(self, false).is_some_and(|l| l >= 0) {
212 return true;
213 }
214 let syms = t.symbols();
215 for s in syms {
216 let me = t.guess_slope(&s);
217 for pos in &positives {
218 if pos.symbols().contains(&s) {
219 let other = pos.guess_slope(&s);
220 if me.0.signum() == other.0.signum() {
221 let new = t.clone() * me.1 * other.0.abs()
222 - pos.clone() * me.0.abs() * other.1;
223 if !visited.contains(&new) {
224 todo.push(new);
225 }
226 }
227 }
228 }
229 }
230 visited.push(t);
231 if visited.len() > 10 {
232 break;
233 }
234 }
235 false
236 }
237}
238
239impl fmt::Debug for SymbolScope {
240 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
241 let locked = self.0.lock();
242 let locked = locked.borrow();
243 write!(
244 f,
245 "symbols: {}; assertions: {}; {}",
246 locked.table.into_iter().map(|(_, s)| s).sorted().join(", "),
247 locked.assertions.iter().map(|s| s.to_string()).sorted().join(", "),
248 locked
249 .scenarios
250 .iter()
251 .map(|s| format!(
252 "{}: {}",
253 s.0,
254 s.1.iter().map(|s| s.to_string()).sorted().join(", ")
255 ))
256 .join(" ; "),
257 )
258 }
259}
260
261#[derive(Clone)]
262pub struct Symbol(Weak<ReentrantMutex<RefCell<SymbolScopeData>>>, string_interner::DefaultSymbol);
263
264impl Eq for Symbol {}
265
266impl PartialEq for Symbol {
267 fn eq(&self, other: &Self) -> bool {
268 self.1 == other.1
269 }
270}
271
272impl Symbol {
273 pub fn scope(&self) -> Option<SymbolScope> {
274 self.0.upgrade().map(SymbolScope)
275 }
276}
277
278impl PartialOrd for Symbol {
279 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
280 Some(self.1.cmp(&other.1))
281 }
282}
283
284impl Ord for Symbol {
285 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
286 self.1.cmp(&other.1)
287 }
288}
289
290impl std::hash::Hash for Symbol {
291 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
292 self.1.hash(state)
293 }
294}
295
296impl std::fmt::Display for Symbol {
297 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
298 if let Some(scope) = self.scope() {
299 let lock = scope.0.lock();
300 let lock = lock.borrow();
301 if let Some(s) = lock.table.resolve(self.1) {
302 return write!(f, "{}", s);
303 }
304 }
305 write!(f, "<Sym{}>", self.1.to_usize())
306 }
307}
308
309impl fmt::Debug for Symbol {
310 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
311 Display::fmt(&self, f)
312 }
313}
314
315#[derive(Clone, Debug, Default)]
316pub struct SymbolValues {
317 values: HashMap<Symbol, i64>,
318}
319
320impl SymbolValues {
321 pub fn with(mut self, s: &Symbol, v: i64) -> Self {
322 self.set(s, v);
323 self
324 }
325
326 pub fn set(&mut self, s: &Symbol, v: i64) {
327 self.values.insert(s.clone(), v);
328 }
329
330 pub fn get(&self, s: &Symbol) -> Option<i64> {
331 self.values.get(s).copied()
332 }
333}
334
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `assertion` in `scope` and returns its known-positive form.
    fn known_positive(scope: &SymbolScope, assertion: &str) -> Option<TDim> {
        parse_assertion(scope, assertion).unwrap().as_known_positive()
    }

    /// Parses `expr` in a fresh scope and reports whether it can be proven `>= 0`.
    fn provably_non_negative(expr: &str) -> bool {
        let scope = SymbolScope::default();
        scope.parse_tdim(expr).unwrap().prove_positive_or_zero()
    }

    #[test]
    fn as_known_positive_gte() {
        let scope = SymbolScope::default();
        let actual = known_positive(&scope, "S>=0");
        assert_eq!(actual, Some(scope.parse_tdim("S").unwrap()));
    }

    #[test]
    fn as_known_positive_gt() {
        let scope = SymbolScope::default();
        let actual = known_positive(&scope, "S>0");
        assert_eq!(actual, Some(scope.parse_tdim("S-1").unwrap()));
    }

    #[test]
    fn as_known_positive_lte() {
        let scope = SymbolScope::default();
        let actual = known_positive(&scope, "S<=0");
        assert_eq!(actual, Some(scope.parse_tdim("-S").unwrap()));
    }

    #[test]
    fn as_known_positive_lt() {
        let scope = SymbolScope::default();
        let actual = known_positive(&scope, "S<0");
        assert_eq!(actual, Some(scope.parse_tdim("-S - 1").unwrap()));
    }

    #[test]
    fn prove_positive_0() {
        assert!(provably_non_negative("0"));
    }

    #[test]
    fn prove_positive_1() {
        assert!(provably_non_negative("1"));
    }

    #[test]
    fn prove_positive_neg1() {
        assert!(!provably_non_negative("-1"));
    }

    #[test]
    fn prove_positive_add_0() {
        // `s` is unconstrained, so `s + 1` cannot be shown to be non-negative.
        assert!(!provably_non_negative("s+1"));
    }
}