tract_data/dim/
sym.rs

1use itertools::Itertools;
2use parking_lot::ReentrantMutex;
3use std::cell::RefCell;
4use std::collections::HashMap;
5use std::fmt::{self, Display};
6use std::sync::{Arc, Weak};
7use string_interner::DefaultStringInterner;
8use string_interner::Symbol as _;
9
10use crate::TractResult;
11
12use super::parse::parse_assertion;
13use super::{parse_tdim, Assertion, TDim};
14
15#[derive(Clone, Default)]
16pub struct SymbolScope(pub Arc<ReentrantMutex<RefCell<SymbolScopeData>>>);
17
18impl PartialEq for SymbolScope {
19    fn eq(&self, other: &Self) -> bool {
20        Arc::ptr_eq(&self.0, &other.0)
21    }
22}
23
24impl Eq for SymbolScope {}
25
26#[derive(Default)]
27pub struct SymbolScopeData {
28    table: DefaultStringInterner,
29    assertions: Vec<Assertion>,
30    scenarios: Vec<(String, Vec<Assertion>)>,
31}
32
33impl SymbolScope {
34    pub fn get(&self, name: &str) -> Option<Symbol> {
35        let locked = self.0.lock();
36        let locked = locked.borrow();
37        locked.table.get(name).map(|sym| Symbol(Arc::downgrade(&self.0), sym))
38    }
39
40    pub fn sym(&self, name: &str) -> Symbol {
41        let locked = self.0.lock();
42        let mut locked = locked.borrow_mut();
43        let sym = locked.table.get_or_intern(name);
44        Symbol(Arc::downgrade(&self.0), sym)
45    }
46
47    pub fn new_with_prefix(&self, prefix: &str) -> Symbol {
48        let locked = self.0.lock();
49        let mut locked = locked.borrow_mut();
50        let sym = if locked.table.get(prefix).is_none() {
51            locked.table.get_or_intern(prefix)
52        } else {
53            let mut i = 0;
54            loop {
55                let s = format!("{prefix}_{i}");
56                if locked.table.get(&s).is_none() {
57                    break locked.table.get_or_intern(s);
58                }
59                i += 1;
60            }
61        };
62        Symbol(Arc::downgrade(&self.0), sym)
63    }
64
65    pub fn parse_tdim(&self, input: impl AsRef<str>) -> TractResult<TDim> {
66        parse_tdim(self, input.as_ref())
67    }
68
69    pub fn add_assertion(&self, assert: impl Into<String>) -> TractResult<()> {
70        let assert = assert.into();
71        let assert = parse_assertion(self, &assert)?;
72        let locked = self.0.lock();
73        let mut locked = locked.borrow_mut();
74        locked.assertions.push(assert);
75        Ok(())
76    }
77
78    pub fn with_assertion(self, assert: impl Into<String>) -> TractResult<Self> {
79        self.add_assertion(assert)?;
80        Ok(self)
81    }
82
83    pub fn all_assertions(&self) -> Vec<Assertion> {
84        let locked = self.0.lock();
85        let locked = locked.borrow();
86        locked.assertions.clone()
87    }
88
89    pub fn all_scenarios(&self) -> impl IntoIterator<Item = (String, Vec<Assertion>)> {
90        let locked = self.0.lock();
91        let locked = locked.borrow();
92        locked.scenarios.clone()
93    }
94
95    pub fn add_scenario(&self, scenario: impl Into<String>) -> TractResult<()> {
96        let locked = self.0.lock();
97        let mut locked = locked.borrow_mut();
98        let s = scenario.into();
99        if !locked.scenarios.iter().any(|sc| sc.0 == s) {
100            locked.scenarios.push((s, vec![]));
101        }
102        Ok(())
103    }
104
105    pub fn add_scenario_assertion(
106        &self,
107        scenario: impl Into<String>,
108        assertion: impl Into<String>,
109    ) -> TractResult<()> {
110        let assert = parse_assertion(self, &assertion.into())?;
111        let s = scenario.into();
112        let locked = self.0.lock();
113        let mut locked = locked.borrow_mut();
114        if let Some(s) = locked.scenarios.iter_mut().find(|sc| sc.0 == s) {
115            s.1.push(assert);
116        } else {
117            locked.scenarios.push((s, vec![assert]));
118        }
119        Ok(())
120    }
121
122    pub fn with_scenario_assertion(
123        self,
124        scenario: impl Into<String>,
125        assertion: impl Into<String>,
126    ) -> TractResult<Self> {
127        self.add_scenario_assertion(scenario, assertion)?;
128        Ok(self)
129    }
130
131    pub fn with_scenario(self, scenario: impl Into<String>) -> TractResult<Self> {
132        self.add_scenario(scenario)?;
133        Ok(self)
134    }
135
136    pub fn all_symbols(&self) -> Vec<Symbol> {
137        self.0
138            .lock()
139            .borrow()
140            .table
141            .into_iter()
142            .map(|is| Symbol(Arc::downgrade(&self.0), is.0))
143            .collect()
144    }
145
146    pub fn guess_scenario(&self, values: &SymbolValues) -> TractResult<Option<usize>> {
147        let locked = self.0.lock();
148        let locked = locked.borrow();
149        if locked.scenarios.len() == 0 {
150            return Ok(None)
151        }
152        let mut maybe = None;
153        for (ix, (_name, assertions)) in locked.scenarios.iter().enumerate() {
154            if assertions.iter().any(|a| a.check(values) == Some(false)) {
155                continue;
156            } else if assertions.iter().all(|a| a.check(values) == Some(true)) {
157                return Ok(Some(ix));
158            } else if maybe.is_none() {
159                maybe = Some(ix);
160            } else {
161                return Ok(None)
162            }
163        }
164        if maybe.is_some() {
165            Ok(maybe)
166        } else {
167            anyhow::bail!("No possible scenario");
168        }
169    }
170}
171
172impl SymbolScopeData {
173    pub fn all_assertions(&self) -> &[Assertion] {
174        &self.assertions
175    }
176
177    pub fn assertions(&self, scenario: Option<&str>) -> impl Iterator<Item = &'_ Assertion> {
178        self.assertions.iter().chain(
179            scenario
180                .and_then(|s| self.scenarios.iter().find(|s2| s2.0 == s))
181                .map(|s| &*s.1)
182                .unwrap_or(&[])
183                .iter(),
184        )
185    }
186
187    pub fn scenarios(&self) -> impl Iterator<Item = &'_ str> {
188        self.scenarios.iter().map(|s| &*s.0)
189    }
190
191    pub fn scenario(&self, s: &str) -> impl Iterator<Item = &'_ Assertion> {
192        self.scenarios.iter().find(|sc| sc.0 == s).map(|sc| &*sc.1).unwrap_or(&[]).iter()
193    }
194
195    pub fn resolving<R>(&self, sym: &Symbol, f: impl FnOnce(&str) -> R) -> Option<R> {
196        self.table.resolve(sym.1).map(f)
197    }
198
199    #[allow(clippy::mutable_key_type)]
200    pub fn prove_positive_or_zero(&self, t: &TDim) -> bool {
201        if let TDim::Val(v) = t {
202            return *v >= 0;
203        }
204        let positives = self.assertions.iter().filter_map(|i| i.as_known_positive()).collect_vec();
205        let mut visited = vec![];
206        let mut todo = vec![t.clone()];
207        while let Some(t) = todo.pop() {
208            if t.to_i64().is_ok_and(|i| i >= 0) {
209                return true;
210            }
211            if t.inclusive_bound(self, false).is_some_and(|l| l >= 0) {
212                return true;
213            }
214            let syms = t.symbols();
215            for s in syms {
216                let me = t.guess_slope(&s);
217                for pos in &positives {
218                    if pos.symbols().contains(&s) {
219                        let other = pos.guess_slope(&s);
220                        if me.0.signum() == other.0.signum() {
221                            let new = t.clone() * me.1 * other.0.abs()
222                                - pos.clone() * me.0.abs() * other.1;
223                            if !visited.contains(&new) {
224                                todo.push(new);
225                            }
226                        }
227                    }
228                }
229            }
230            visited.push(t);
231            if visited.len() > 10 {
232                break;
233            }
234        }
235        false
236    }
237}
238
239impl fmt::Debug for SymbolScope {
240    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
241        let locked = self.0.lock();
242        let locked = locked.borrow();
243        write!(
244            f,
245            "symbols: {}; assertions: {}; {}",
246            locked.table.into_iter().map(|(_, s)| s).sorted().join(", "),
247            locked.assertions.iter().map(|s| s.to_string()).sorted().join(", "),
248            locked
249                .scenarios
250                .iter()
251                .map(|s| format!(
252                    "{}: {}",
253                    s.0,
254                    s.1.iter().map(|s| s.to_string()).sorted().join(", ")
255                ))
256                .join(" ; "),
257        )
258    }
259}
260
261#[derive(Clone)]
262pub struct Symbol(Weak<ReentrantMutex<RefCell<SymbolScopeData>>>, string_interner::DefaultSymbol);
263
264impl Eq for Symbol {}
265
266impl PartialEq for Symbol {
267    fn eq(&self, other: &Self) -> bool {
268        self.1 == other.1
269    }
270}
271
272impl Symbol {
273    pub fn scope(&self) -> Option<SymbolScope> {
274        self.0.upgrade().map(SymbolScope)
275    }
276}
277
278impl PartialOrd for Symbol {
279    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
280        Some(self.1.cmp(&other.1))
281    }
282}
283
284impl Ord for Symbol {
285    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
286        self.1.cmp(&other.1)
287    }
288}
289
290impl std::hash::Hash for Symbol {
291    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
292        self.1.hash(state)
293    }
294}
295
296impl std::fmt::Display for Symbol {
297    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
298        if let Some(scope) = self.scope() {
299            let lock = scope.0.lock();
300            let lock = lock.borrow();
301            if let Some(s) = lock.table.resolve(self.1) {
302                return write!(f, "{}", s);
303            }
304        }
305        write!(f, "<Sym{}>", self.1.to_usize())
306    }
307}
308
309impl fmt::Debug for Symbol {
310    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
311        Display::fmt(&self, f)
312    }
313}
314
315#[derive(Clone, Debug, Default)]
316pub struct SymbolValues {
317    values: HashMap<Symbol, i64>,
318}
319
320impl SymbolValues {
321    pub fn with(mut self, s: &Symbol, v: i64) -> Self {
322        self.set(s, v);
323        self
324    }
325
326    pub fn set(&mut self, s: &Symbol, v: i64) {
327        self.values.insert(s.clone(), v);
328    }
329
330    pub fn get(&self, s: &Symbol) -> Option<i64> {
331        self.values.get(s).copied()
332    }
333}
334
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `assertion` in `scope` and returns its known-non-negative form.
    fn known_positive(scope: &SymbolScope, assertion: &str) -> Option<TDim> {
        parse_assertion(scope, assertion).unwrap().as_known_positive()
    }

    /// Parses `expr` as a `TDim` in `scope`.
    fn tdim(scope: &SymbolScope, expr: &str) -> TDim {
        scope.parse_tdim(expr).unwrap()
    }

    #[test]
    fn as_known_positive_gte() {
        let scope = SymbolScope::default();
        assert_eq!(known_positive(&scope, "S>=0"), Some(tdim(&scope, "S")));
    }

    #[test]
    fn as_known_positive_gt() {
        let scope = SymbolScope::default();
        assert_eq!(known_positive(&scope, "S>0"), Some(tdim(&scope, "S-1")));
    }

    #[test]
    fn as_known_positive_lte() {
        let scope = SymbolScope::default();
        assert_eq!(known_positive(&scope, "S<=0"), Some(tdim(&scope, "-S")));
    }

    #[test]
    fn as_known_positive_lt() {
        let scope = SymbolScope::default();
        assert_eq!(known_positive(&scope, "S<0"), Some(tdim(&scope, "-S - 1")));
    }

    #[test]
    fn prove_positive_0() {
        let scope = SymbolScope::default();
        assert!(tdim(&scope, "0").prove_positive_or_zero());
    }

    #[test]
    fn prove_positive_1() {
        let scope = SymbolScope::default();
        assert!(tdim(&scope, "1").prove_positive_or_zero());
    }

    #[test]
    fn prove_positive_neg1() {
        let scope = SymbolScope::default();
        assert!(!tdim(&scope, "-1").prove_positive_or_zero());
    }

    #[test]
    fn prove_positive_add_0() {
        let scope = SymbolScope::default();
        assert!(!tdim(&scope, "s+1").prove_positive_or_zero());
    }
}