tract_data/dim/
sym.rs

1use itertools::Itertools;
2use parking_lot::ReentrantMutex;
3use std::cell::RefCell;
4use std::collections::HashMap;
5use std::fmt::{self, Display};
6use std::sync::{Arc, Weak};
7use string_interner::DefaultStringInterner;
8use string_interner::Symbol as _;
9
10use crate::TractResult;
11
12use super::parse::parse_assertion;
13use super::{parse_tdim, Assertion, TDim};
14
15#[derive(Clone, Default)]
16pub struct SymbolScope(pub Arc<ReentrantMutex<RefCell<SymbolScopeData>>>);
17
18impl PartialEq for SymbolScope {
19    fn eq(&self, other: &Self) -> bool {
20        Arc::ptr_eq(&self.0, &other.0)
21    }
22}
23
24impl Eq for SymbolScope {}
25
26#[derive(Default)]
27pub struct SymbolScopeData {
28    table: DefaultStringInterner,
29    assertions: Vec<Assertion>,
30    scenarios: Vec<(String, Vec<Assertion>)>,
31}
32
33impl SymbolScope {
34    pub fn get(&self, name: &str) -> Option<Symbol> {
35        let locked = self.0.lock();
36        let locked = locked.borrow();
37        locked.table.get(name).map(|sym| Symbol(Arc::downgrade(&self.0), sym))
38    }
39
40    pub fn sym(&self, name: &str) -> Symbol {
41        let locked = self.0.lock();
42        let mut locked = locked.borrow_mut();
43        let sym = locked.table.get_or_intern(name);
44        Symbol(Arc::downgrade(&self.0), sym)
45    }
46
47    pub fn new_with_prefix(&self, prefix: &str) -> Symbol {
48        let locked = self.0.lock();
49        let mut locked = locked.borrow_mut();
50        let sym = if locked.table.get(prefix).is_none() {
51            locked.table.get_or_intern(prefix)
52        } else {
53            let mut i = 0;
54            loop {
55                let s = format!("{prefix}_{i}");
56                if locked.table.get(&s).is_none() {
57                    break locked.table.get_or_intern(s);
58                }
59                i += 1;
60            }
61        };
62        Symbol(Arc::downgrade(&self.0), sym)
63    }
64
65    pub fn parse_tdim(&self, input: impl AsRef<str>) -> TractResult<TDim> {
66        parse_tdim(self, input.as_ref())
67    }
68
69    pub fn add_assertion(&self, assert: impl Into<String>) -> TractResult<()> {
70        let assert = assert.into();
71        let assert = parse_assertion(self, &assert)?;
72        let locked = self.0.lock();
73        let mut locked = locked.borrow_mut();
74        locked.assertions.push(assert);
75        Ok(())
76    }
77
78    pub fn with_assertion(self, assert: impl Into<String>) -> TractResult<Self> {
79        self.add_assertion(assert)?;
80        Ok(self)
81    }
82
83    pub fn all_assertions(&self) -> Vec<Assertion> {
84        let locked = self.0.lock();
85        let locked = locked.borrow();
86        locked.assertions.clone()
87    }
88
89    pub fn all_scenarios(&self) -> impl IntoIterator<Item = (String, Vec<Assertion>)> {
90        let locked = self.0.lock();
91        let locked = locked.borrow();
92        locked.scenarios.clone()
93    }
94
95    pub fn add_scenario(&self, scenario: impl Into<String>) -> TractResult<()> {
96        let locked = self.0.lock();
97        let mut locked = locked.borrow_mut();
98        let s = scenario.into();
99        if !locked.scenarios.iter().any(|sc| sc.0 == s) {
100            locked.scenarios.push((s, vec![]));
101        }
102        Ok(())
103    }
104
105    pub fn add_scenario_assertion(
106        &self,
107        scenario: impl Into<String>,
108        assertion: impl Into<String>,
109    ) -> TractResult<()> {
110        let assert = parse_assertion(self, &assertion.into())?;
111        let s = scenario.into();
112        let locked = self.0.lock();
113        let mut locked = locked.borrow_mut();
114        if let Some(s) = locked.scenarios.iter_mut().find(|sc| sc.0 == s) {
115            s.1.push(assert);
116        } else {
117            locked.scenarios.push((s, vec![assert]));
118        }
119        Ok(())
120    }
121
122    pub fn with_scenario_assertion(
123        self,
124        scenario: impl Into<String>,
125        assertion: impl Into<String>,
126    ) -> TractResult<Self> {
127        self.add_scenario_assertion(scenario, assertion)?;
128        Ok(self)
129    }
130
131    pub fn with_scenario(self, scenario: impl Into<String>) -> TractResult<Self> {
132        self.add_scenario(scenario)?;
133        Ok(self)
134    }
135
136    pub fn all_symbols(&self) -> Vec<Symbol> {
137        self.0
138            .lock()
139            .borrow()
140            .table
141            .into_iter()
142            .map(|is| Symbol(Arc::downgrade(&self.0), is.0))
143            .collect()
144    }
145
146    pub fn guess_scenario(&self, values: &SymbolValues) -> TractResult<Option<usize>> {
147        let locked = self.0.lock();
148        let locked = locked.borrow();
149        if locked.scenarios.len() == 0 {
150            return Ok(None);
151        }
152        let mut maybe = None;
153        for (ix, (_name, assertions)) in locked.scenarios.iter().enumerate() {
154            if assertions.iter().any(|a| a.check(values) == Some(false)) {
155                continue;
156            } else if assertions.iter().all(|a| a.check(values) == Some(true)) {
157                return Ok(Some(ix));
158            } else if maybe.is_none() {
159                maybe = Some(ix);
160            } else {
161                return Ok(None);
162            }
163        }
164        if maybe.is_some() {
165            Ok(maybe)
166        } else {
167            anyhow::bail!("No possible scenario");
168        }
169    }
170}
171
172impl SymbolScopeData {
173    pub fn all_assertions(&self) -> &[Assertion] {
174        &self.assertions
175    }
176
177    pub fn assertions(&self, scenario: Option<&str>) -> impl Iterator<Item = &'_ Assertion> {
178        self.assertions.iter().chain(
179            scenario
180                .and_then(|s| self.scenarios.iter().find(|s2| s2.0 == s))
181                .map(|s| &*s.1)
182                .unwrap_or(&[])
183                .iter(),
184        )
185    }
186
187    pub fn scenarios(&self) -> impl Iterator<Item = &'_ str> {
188        self.scenarios.iter().map(|s| &*s.0)
189    }
190
191    pub fn scenario(&self, s: &str) -> impl Iterator<Item = &'_ Assertion> {
192        self.scenarios.iter().find(|sc| sc.0 == s).map(|sc| &*sc.1).unwrap_or(&[]).iter()
193    }
194
195    pub fn resolving<R>(&self, sym: &Symbol, f: impl FnOnce(&str) -> R) -> Option<R> {
196        self.table.resolve(sym.1).map(f)
197    }
198
199    #[allow(clippy::mutable_key_type)]
200    pub fn prove_positive_or_zero(&self, t: &TDim) -> bool {
201        if let TDim::Val(v) = t {
202            return *v >= 0;
203        }
204        let positives = self.assertions.iter().filter_map(|i| i.as_known_positive()).collect_vec();
205        let mut visited = vec![];
206        let mut todo = vec![t.clone()];
207        while let Some(t) = todo.pop() {
208            if t.to_i64().is_ok_and(|i| i >= 0) {
209                return true;
210            }
211            if t.inclusive_bound(self, false).is_some_and(|l| l >= 0) {
212                return true;
213            }
214            let syms = t.symbols();
215            for s in syms {
216                let me = t.guess_slope(&s);
217                for pos in &positives {
218                    if pos.symbols().contains(&s) {
219                        let other = pos.guess_slope(&s);
220                        if me.0.signum() == other.0.signum() {
221                            let new = t.clone() * me.1 * other.0.abs()
222                                - pos.clone() * me.0.abs() * other.1;
223                            if !visited.contains(&new) {
224                                todo.push(new);
225                            }
226                        }
227                    }
228                }
229            }
230            visited.push(t);
231            if visited.len() > 10 {
232                break;
233            }
234        }
235        false
236    }
237
238    pub(crate) fn prove_strict_positive(&self, b: &TDim) -> bool {
239        self.prove_positive_or_zero(&(b.clone() - 1))
240    }
241}
242
243impl fmt::Debug for SymbolScope {
244    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
245        let locked = self.0.lock();
246        let locked = locked.borrow();
247        write!(
248            f,
249            "symbols: {}; assertions: {}; {}",
250            locked.table.into_iter().map(|(_, s)| s).sorted().join(", "),
251            locked.assertions.iter().map(|s| s.to_string()).sorted().join(", "),
252            locked
253                .scenarios
254                .iter()
255                .map(|s| format!(
256                    "{}: {}",
257                    s.0,
258                    s.1.iter().map(|s| s.to_string()).sorted().join(", ")
259                ))
260                .join(" ; "),
261        )
262    }
263}
264
265#[derive(Clone)]
266pub struct Symbol(Weak<ReentrantMutex<RefCell<SymbolScopeData>>>, string_interner::DefaultSymbol);
267
268impl Eq for Symbol {}
269
270impl PartialEq for Symbol {
271    fn eq(&self, other: &Self) -> bool {
272        self.1 == other.1
273    }
274}
275
276impl Symbol {
277    pub fn scope(&self) -> Option<SymbolScope> {
278        self.0.upgrade().map(SymbolScope)
279    }
280}
281
282impl PartialOrd for Symbol {
283    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
284        Some(self.cmp(other))
285    }
286}
287
288impl Ord for Symbol {
289    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
290        self.1.cmp(&other.1)
291    }
292}
293
294impl std::hash::Hash for Symbol {
295    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
296        self.1.hash(state)
297    }
298}
299
300impl std::fmt::Display for Symbol {
301    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
302        if let Some(scope) = self.scope() {
303            let lock = scope.0.lock();
304            let lock = lock.borrow();
305            if let Some(s) = lock.table.resolve(self.1) {
306                return write!(f, "{s}");
307            }
308        }
309        write!(f, "<Sym{}>", self.1.to_usize())
310    }
311}
312
313impl fmt::Debug for Symbol {
314    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
315        Display::fmt(&self, f)
316    }
317}
318
319#[derive(Clone, Debug, Default)]
320pub struct SymbolValues {
321    values: HashMap<Symbol, i64>,
322}
323
324impl SymbolValues {
325    pub fn with(mut self, s: &Symbol, v: i64) -> Self {
326        self.set(s, v);
327        self
328    }
329
330    pub fn set(&mut self, s: &Symbol, v: i64) {
331        self.values.insert(s.clone(), v);
332    }
333
334    pub fn get(&self, s: &Symbol) -> Option<i64> {
335        self.values.get(s).copied()
336    }
337}
338
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `assertion` in `scope` and returns its non-negative form, if any.
    fn known_positive(scope: &SymbolScope, assertion: &str) -> Option<TDim> {
        parse_assertion(scope, assertion).unwrap().as_known_positive()
    }

    /// Parses `expr` in a fresh scope without assertions and asks whether it
    /// can be proven `>= 0`.
    fn provably_non_negative(expr: &str) -> bool {
        let scope = SymbolScope::default();
        scope.parse_tdim(expr).unwrap().prove_positive_or_zero()
    }

    #[test]
    fn as_known_positive_gte() {
        let scope = SymbolScope::default();
        let actual = known_positive(&scope, "S>=0");
        assert_eq!(actual, Some(scope.parse_tdim("S").unwrap()));
    }

    #[test]
    fn as_known_positive_gt() {
        let scope = SymbolScope::default();
        let actual = known_positive(&scope, "S>0");
        assert_eq!(actual, Some(scope.parse_tdim("S-1").unwrap()));
    }

    #[test]
    fn as_known_positive_lte() {
        let scope = SymbolScope::default();
        let actual = known_positive(&scope, "S<=0");
        assert_eq!(actual, Some(scope.parse_tdim("-S").unwrap()));
    }

    #[test]
    fn as_known_positive_lt() {
        let scope = SymbolScope::default();
        let actual = known_positive(&scope, "S<0");
        assert_eq!(actual, Some(scope.parse_tdim("-S - 1").unwrap()));
    }

    #[test]
    fn prove_positive_0() {
        assert!(provably_non_negative("0"));
    }

    #[test]
    fn prove_positive_1() {
        assert!(provably_non_negative("1"));
    }

    #[test]
    fn prove_positive_neg1() {
        assert!(!provably_non_negative("-1"));
    }

    #[test]
    fn prove_positive_add_0() {
        // Without assertions `s` may be negative, so `s+1` is not provable.
        assert!(!provably_non_negative("s+1"));
    }
}
400        assert!(!s.parse_tdim("s+1").unwrap().prove_positive_or_zero());
401    }
402}