Skip to main content

tract_data/dim/
sym.rs

1use itertools::Itertools;
2use parking_lot::ReentrantMutex;
3use std::cell::RefCell;
4use std::collections::HashMap;
5use std::fmt::{self, Display};
6use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
7use std::sync::{Arc, Weak};
8use string_interner::DefaultStringInterner;
9use string_interner::Symbol as _;
10
11use crate::TractResult;
12
13use super::parse::parse_assertion;
14use super::{Assertion, TDim, parse_tdim};
15
// Process-wide source of scope identifiers: `SymbolScopeData::default` takes
// the next value (relaxed `fetch_add`) as the `id` of each new scope.
static SCOPE_COUNTER: AtomicUsize = AtomicUsize::new(0);
17
/// Wrapper with lock-free hot-path flags alongside the scope data lock.
/// Derefs to the inner mutex so existing `scope.0.lock()` call sites keep
/// working.
pub struct SymbolScopeInner {
    /// Lock-free: true iff `scenarios` is non-empty. Maintained by
    /// `add_scenario` / `add_scenario_assertion`. Lets `guess_scenario` skip
    /// the lock entirely on the common no-scenarios path.
    has_scenarios: AtomicBool,
    /// The scope state proper (id, symbol table, assertions, scenarios),
    /// behind a reentrant mutex with `RefCell` interior mutability. This is
    /// the value the `Deref` impl hands out.
    data: ReentrantMutex<RefCell<SymbolScopeData>>,
}
28
29impl Default for SymbolScopeInner {
30    fn default() -> Self {
31        SymbolScopeInner { has_scenarios: AtomicBool::new(false), data: Default::default() }
32    }
33}
34
35impl std::ops::Deref for SymbolScopeInner {
36    type Target = ReentrantMutex<RefCell<SymbolScopeData>>;
37    fn deref(&self) -> &Self::Target {
38        &self.data
39    }
40}
41
/// Clonable handle to a symbol scope. Clones share the same
/// `SymbolScopeInner` through the `Arc`; the `PartialEq` impl treats two
/// handles as equal only when they point to the same allocation.
#[derive(Clone, Default)]
pub struct SymbolScope(pub Arc<SymbolScopeInner>);
44
45impl PartialEq for SymbolScope {
46    fn eq(&self, other: &Self) -> bool {
47        Arc::ptr_eq(&self.0, &other.0)
48    }
49}
50
// Marker impl: the pointer-identity comparison in `PartialEq` is a full
// equivalence relation.
impl Eq for SymbolScope {}
52
/// State of a symbol scope; lives behind the mutex in `SymbolScopeInner`.
pub struct SymbolScopeData {
    /// Identifier drawn from `SCOPE_COUNTER` when the data is created.
    id: usize,
    /// String interner holding the names of the scope's symbols.
    table: DefaultStringInterner,
    /// Scope-wide assertions, in the order they were added.
    assertions: Vec<Assertion>,
    /// Named scenarios, each with its own list of assertions, in insertion
    /// order.
    scenarios: Vec<(String, Vec<Assertion>)>,
}
59
60impl Default for SymbolScopeData {
61    fn default() -> Self {
62        SymbolScopeData {
63            id: SCOPE_COUNTER.fetch_add(1, Ordering::Relaxed),
64            table: DefaultStringInterner::default(),
65            assertions: Vec::new(),
66            scenarios: Vec::new(),
67        }
68    }
69}
70
71impl SymbolScope {
72    pub fn id(&self) -> usize {
73        let locked = self.0.lock();
74        let locked = locked.borrow();
75        locked.id
76    }
77
78    pub fn proof_cache_session(&self) -> ProofCacheSession {
79        ProofCacheSession::new(self.id())
80    }
81
82    pub fn get(&self, name: &str) -> Option<Symbol> {
83        let locked = self.0.lock();
84        let locked = locked.borrow();
85        locked.table.get(name).map(|sym| Symbol(Arc::downgrade(&self.0), sym))
86    }
87
88    /// Get or create the coordinate symbol for axis `k` (named "🎯{k}").
89    pub fn coord_sym(&self, k: usize) -> Symbol {
90        self.sym(&format!("🎯{k}"))
91    }
92
93    pub fn sym(&self, name: &str) -> Symbol {
94        let locked = self.0.lock();
95        let mut locked = locked.borrow_mut();
96        let sym = locked.table.get_or_intern(name);
97        Symbol(Arc::downgrade(&self.0), sym)
98    }
99
100    pub fn new_with_prefix(&self, prefix: &str) -> Symbol {
101        let locked = self.0.lock();
102        let mut locked = locked.borrow_mut();
103        let sym = if locked.table.get(prefix).is_none() {
104            locked.table.get_or_intern(prefix)
105        } else {
106            let mut i = 0;
107            loop {
108                let s = format!("{prefix}_{i}");
109                if locked.table.get(&s).is_none() {
110                    break locked.table.get_or_intern(s);
111                }
112                i += 1;
113            }
114        };
115        Symbol(Arc::downgrade(&self.0), sym)
116    }
117
118    pub fn parse_tdim(&self, input: impl AsRef<str>) -> TractResult<TDim> {
119        parse_tdim(self, input.as_ref())
120    }
121
122    pub fn add_assertion(&self, assert: impl Into<String>) -> TractResult<()> {
123        let assert = assert.into();
124        let assert = parse_assertion(self, &assert)?;
125        let locked = self.0.lock();
126        let mut locked = locked.borrow_mut();
127        locked.assertions.push(assert);
128        Ok(())
129    }
130
131    pub fn with_assertion(self, assert: impl Into<String>) -> TractResult<Self> {
132        self.add_assertion(assert)?;
133        Ok(self)
134    }
135
136    pub fn all_assertions(&self) -> Vec<Assertion> {
137        let locked = self.0.lock();
138        let locked = locked.borrow();
139        locked.assertions.clone()
140    }
141
142    pub fn all_scenarios(&self) -> impl IntoIterator<Item = (String, Vec<Assertion>)> {
143        let locked = self.0.lock();
144        let locked = locked.borrow();
145        locked.scenarios.clone()
146    }
147
148    pub fn add_scenario(&self, scenario: impl Into<String>) -> TractResult<()> {
149        let locked = self.0.lock();
150        let mut locked = locked.borrow_mut();
151        let s = scenario.into();
152        if !locked.scenarios.iter().any(|sc| sc.0 == s) {
153            locked.scenarios.push((s, vec![]));
154            self.0.has_scenarios.store(true, Ordering::Relaxed);
155        }
156        Ok(())
157    }
158
159    pub fn add_scenario_assertion(
160        &self,
161        scenario: impl Into<String>,
162        assertion: impl Into<String>,
163    ) -> TractResult<()> {
164        let assert = parse_assertion(self, &assertion.into())?;
165        let s = scenario.into();
166        let locked = self.0.lock();
167        let mut locked = locked.borrow_mut();
168        if let Some(s) = locked.scenarios.iter_mut().find(|sc| sc.0 == s) {
169            s.1.push(assert);
170        } else {
171            locked.scenarios.push((s, vec![assert]));
172            self.0.has_scenarios.store(true, Ordering::Relaxed);
173        }
174        Ok(())
175    }
176
177    pub fn with_scenario_assertion(
178        self,
179        scenario: impl Into<String>,
180        assertion: impl Into<String>,
181    ) -> TractResult<Self> {
182        self.add_scenario_assertion(scenario, assertion)?;
183        Ok(self)
184    }
185
186    pub fn with_scenario(self, scenario: impl Into<String>) -> TractResult<Self> {
187        self.add_scenario(scenario)?;
188        Ok(self)
189    }
190
191    pub fn all_symbols(&self) -> Vec<Symbol> {
192        self.0
193            .lock()
194            .borrow()
195            .table
196            .into_iter()
197            .map(|is| Symbol(Arc::downgrade(&self.0), is.0))
198            .collect()
199    }
200
201    pub fn guess_scenario(&self, values: &SymbolValues) -> TractResult<Option<usize>> {
202        // Hot path: most scopes have no scenarios. Skip the lock entirely.
203        if !self.0.has_scenarios.load(Ordering::Relaxed) {
204            return Ok(None);
205        }
206        let locked = self.0.lock();
207        let locked = locked.borrow();
208        if locked.scenarios.is_empty() {
209            return Ok(None);
210        }
211        let mut maybe = None;
212        for (ix, (_name, assertions)) in locked.scenarios.iter().enumerate() {
213            if assertions.iter().any(|a| a.check(values) == Some(false)) {
214                continue;
215            } else if assertions.iter().all(|a| a.check(values) == Some(true)) {
216                return Ok(Some(ix));
217            } else if maybe.is_none() {
218                maybe = Some(ix);
219            } else {
220                return Ok(None);
221            }
222        }
223        if maybe.is_some() {
224            Ok(maybe)
225        } else {
226            anyhow::bail!("No possible scenario");
227        }
228    }
229}
230
// Per-thread proof cache. `None` until a `ProofCacheSession` is opened on the
// thread; set back to `None` when the session depth returns to zero.
thread_local! {
    static PROOF_CACHE: RefCell<Option<ProofCache>> = const { RefCell::new(None) };
}
234
/// Content of the thread-local proof cache while at least one session is
/// active.
struct ProofCache {
    /// Id of the scope the first session was opened for.
    scope_id: usize,
    /// Number of active sessions sharing this cache (nested sessions on the
    /// same scope increment it; dropping one decrements it).
    depth: usize,
    /// Memoized results of `prove_positive_or_zero`, keyed by expression.
    cache: HashMap<TDim, bool>,
}
240
/// Guard for a proof-cache session on the current thread: created by
/// `ProofCacheSession::new`, released by its `Drop` impl.
///
/// NOTE(review): the guard manipulates a thread-local, so it presumably must
/// be dropped on the thread that created it — nothing in the type enforces
/// that; confirm with callers.
pub struct ProofCacheSession {
    /// Whether this guard took a share of the thread-local cache (and so must
    /// give it back on drop). False when a cache for another scope was
    /// already installed.
    active: bool,
}
244
245impl ProofCacheSession {
246    pub fn new(scope_id: usize) -> Self {
247        let active = PROOF_CACHE.with(|cell| {
248            let mut borrow = cell.borrow_mut();
249            match &mut *borrow {
250                None => {
251                    *borrow = Some(ProofCache { scope_id, depth: 1, cache: HashMap::new() });
252                    true
253                }
254                Some(pc) if pc.scope_id == scope_id => {
255                    pc.depth += 1;
256                    true
257                }
258                Some(_) => false,
259            }
260        });
261        ProofCacheSession { active }
262    }
263}
264
265impl Drop for ProofCacheSession {
266    fn drop(&mut self) {
267        if !self.active {
268            return;
269        }
270        PROOF_CACHE.with(|cell| {
271            let mut borrow = cell.borrow_mut();
272            if let Some(pc) = &mut *borrow {
273                pc.depth -= 1;
274                if pc.depth == 0 {
275                    *borrow = None;
276                }
277            }
278        });
279    }
280}
281
282impl SymbolScopeData {
283    pub fn all_assertions(&self) -> &[Assertion] {
284        &self.assertions
285    }
286
287    pub fn assertions(&self, scenario: Option<&str>) -> impl Iterator<Item = &'_ Assertion> {
288        self.assertions.iter().chain(
289            scenario
290                .and_then(|s| self.scenarios.iter().find(|s2| s2.0 == s))
291                .map(|s| &*s.1)
292                .unwrap_or(&[])
293                .iter(),
294        )
295    }
296
297    pub fn scenarios(&self) -> impl Iterator<Item = &'_ str> {
298        self.scenarios.iter().map(|s| &*s.0)
299    }
300
301    pub fn scenario(&self, s: &str) -> impl Iterator<Item = &'_ Assertion> {
302        self.scenarios.iter().find(|sc| sc.0 == s).map(|sc| &*sc.1).unwrap_or(&[]).iter()
303    }
304
305    pub fn resolving<R>(&self, sym: &Symbol, f: impl FnOnce(&str) -> R) -> Option<R> {
306        self.table.resolve(sym.1).map(f)
307    }
308
309    #[allow(clippy::mutable_key_type)]
310    pub fn prove_positive_or_zero(&self, t: &TDim) -> bool {
311        if let TDim::Val(v) = t {
312            return *v >= 0;
313        }
314        let cached = PROOF_CACHE.with(|cell| {
315            let borrow = cell.borrow();
316            if let Some(pc) = &*borrow {
317                debug_assert_eq!(pc.scope_id, self.id, "ProofCacheSession scope_id mismatch");
318                pc.cache.get(t).copied()
319            } else {
320                None
321            }
322        });
323        if let Some(result) = cached {
324            return result;
325        }
326        let result = self.prove_positive_or_zero_inner(t);
327        PROOF_CACHE.with(|cell| {
328            let mut borrow = cell.borrow_mut();
329            if let Some(pc) = &mut *borrow {
330                pc.cache.insert(t.clone(), result);
331            }
332        });
333        result
334    }
335
336    #[allow(clippy::mutable_key_type)]
337    fn prove_positive_or_zero_inner(&self, t: &TDim) -> bool {
338        self.prove_positive_or_zero_inner_with_extra(t, &[])
339    }
340
    /// Bounded search for a proof that `t >= 0`, using the scope-wide
    /// assertions plus `extra`.
    ///
    /// Returns `true` as soon as one candidate expression is shown to be
    /// non-negative, `false` when the worklist runs dry or the visit budget
    /// is exhausted (so `false` means "not proven").
    #[allow(clippy::mutable_key_type)]
    fn prove_positive_or_zero_inner_with_extra(&self, t: &TDim, extra: &[Assertion]) -> bool {
        // Expressions the assertions declare to be non-negative
        // (`as_known_positive` of each scope and extra assertion).
        let positives = self
            .assertions
            .iter()
            .chain(extra.iter())
            .filter_map(|i| i.as_known_positive())
            .collect_vec();
        // Worklist search: `todo` is a stack of candidates, `visited` the
        // candidates already expanded.
        let mut visited = vec![];
        let mut todo = vec![t.clone()];
        while let Some(t) = todo.pop() {
            // Candidate evaluates to a non-negative integer.
            if t.to_i64().is_ok_and(|i| i >= 0) {
                return true;
            }
            // Candidate has a known bound that is non-negative
            // (presumably the lower bound, given the `false` flag — TODO confirm).
            if t.inclusive_bound(self, false).is_some_and(|l| l >= 0) {
                return true;
            }
            // Div(a, q) with q >= 1 is non-negative whenever a is non-negative.
            if let TDim::Div(a, q) = &t {
                if *q >= 1 && self.prove_positive_or_zero_inner_with_extra(a, extra) {
                    return true;
                }
            }
            // Otherwise derive new candidates: for each symbol of `t` and
            // each known-positive expression mentioning it with a slope of
            // the same sign, subtract a scaled copy of that expression from a
            // scaled copy of `t` (the scaling appears chosen so the symbol's
            // terms cancel — NOTE(review): confirm against `guess_slope`).
            let syms = t.symbols();
            for s in syms {
                let me = t.guess_slope(&s);
                for pos in &positives {
                    if pos.symbols().contains(&s) {
                        let other = pos.guess_slope(&s);
                        if me.0.signum() == other.0.signum() {
                            let new = t.clone() * me.1 * other.0.abs()
                                - pos.clone() * me.0.abs() * other.1;
                            if !visited.contains(&new) {
                                todo.push(new);
                            }
                        }
                    }
                }
            }
            visited.push(t);
            // Search budget: give up once more than 10 candidates have been
            // expanded.
            if visited.len() > 10 {
                break;
            }
        }
        false
    }
387
388    pub(crate) fn prove_positive_or_zero_with_extra(&self, t: &TDim, extra: &[Assertion]) -> bool {
389        if let TDim::Val(v) = t {
390            return *v >= 0;
391        }
392        // Skip the proof cache for extra-assertion calls (cache is keyed without extra context)
393        self.prove_positive_or_zero_inner_with_extra(t, extra)
394    }
395
396    pub(crate) fn prove_strict_positive_with_extra(&self, b: &TDim, extra: &[Assertion]) -> bool {
397        self.prove_positive_or_zero_with_extra(&(b.clone() - 1), extra)
398    }
399}
400
401impl fmt::Debug for SymbolScope {
402    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
403        let locked = self.0.lock();
404        let locked = locked.borrow();
405        write!(
406            f,
407            "symbols: {}; assertions: {}; {}",
408            locked.table.into_iter().map(|(_, s)| s).sorted().join(", "),
409            locked.assertions.iter().map(|s| s.to_string()).sorted().join(", "),
410            locked
411                .scenarios
412                .iter()
413                .map(|s| format!(
414                    "{}: {}",
415                    s.0,
416                    s.1.iter().map(|s| s.to_string()).sorted().join(", ")
417                ))
418                .join(" ; "),
419        )
420    }
421}
422
/// Handle to a name interned in a symbol scope: a weak reference to the
/// owning scope plus the interner's symbol for the name.
///
/// The comparison, ordering and hashing impls in this file look only at the
/// interned symbol (field `1`), never at the scope.
#[derive(Clone)]
pub struct Symbol(Weak<SymbolScopeInner>, string_interner::DefaultSymbol);
425
// Marker impl: `PartialEq` compares the interned symbol only, which is a full
// equivalence relation.
impl Eq for Symbol {}
427
428impl PartialEq for Symbol {
429    fn eq(&self, other: &Self) -> bool {
430        self.1 == other.1
431    }
432}
433
434impl Symbol {
435    pub fn scope(&self) -> Option<SymbolScope> {
436        self.0.upgrade().map(SymbolScope)
437    }
438}
439
440impl PartialOrd for Symbol {
441    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
442        Some(self.cmp(other))
443    }
444}
445
446impl Ord for Symbol {
447    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
448        self.1.cmp(&other.1)
449    }
450}
451
452impl std::hash::Hash for Symbol {
453    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
454        self.1.hash(state)
455    }
456}
457
458impl std::fmt::Display for Symbol {
459    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
460        if let Some(scope) = self.scope() {
461            let lock = scope.0.lock();
462            let lock = lock.borrow();
463            if let Some(s) = lock.table.resolve(self.1) {
464                return write!(f, "{s}");
465            }
466        }
467        write!(f, "<Sym{}>", self.1.to_usize())
468    }
469}
470
471impl fmt::Debug for Symbol {
472    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
473        Display::fmt(&self, f)
474    }
475}
476
/// A set of bindings from symbols to concrete `i64` values. Starts empty
/// (`Default`).
#[derive(Clone, Debug, Default)]
pub struct SymbolValues {
    /// Value bound to each symbol; symbols without an entry are unbound.
    values: HashMap<Symbol, i64>,
}
481
482impl SymbolValues {
483    pub fn with(mut self, s: &Symbol, v: i64) -> Self {
484        self.set(s, v);
485        self
486    }
487
488    pub fn set(&mut self, s: &Symbol, v: i64) {
489        self.values.insert(s.clone(), v);
490    }
491
492    pub fn get(&self, s: &Symbol) -> Option<i64> {
493        self.values.get(s).copied()
494    }
495
496    /// View the bindings as `Symbol → TDim` (each `i64` lifted to `TDim::Val`),
497    /// for callers that need to plug into APIs taking `HashMap<Symbol, TDim>`.
498    pub fn to_dim_map(&self) -> HashMap<Symbol, TDim> {
499        self.values.iter().map(|(s, v)| (s.clone(), TDim::Val(*v))).collect()
500    }
501}
502
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `assertion` in a fresh scope and checks that its known-positive
    /// form equals `expected`, parsed in the same scope.
    fn check_known_positive(assertion: &str, expected: &str) {
        let scope = SymbolScope::default();
        let parsed = parse_assertion(&scope, assertion).unwrap();
        assert_eq!(parsed.as_known_positive(), Some(scope.parse_tdim(expected).unwrap()));
    }

    /// Parses `expr` in a fresh scope (no assertions) and reports whether it
    /// can be proven non-negative.
    fn provable(expr: &str) -> bool {
        let scope = SymbolScope::default();
        scope.parse_tdim(expr).unwrap().prove_positive_or_zero()
    }

    #[test]
    fn as_known_positive_gte() {
        check_known_positive("S>=0", "S");
    }

    #[test]
    fn as_known_positive_gt() {
        check_known_positive("S>0", "S-1");
    }

    #[test]
    fn as_known_positive_lte() {
        check_known_positive("S<=0", "-S");
    }

    #[test]
    fn as_known_positive_lt() {
        check_known_positive("S<0", "-S - 1");
    }

    #[test]
    fn prove_positive_0() {
        assert!(provable("0"));
    }

    #[test]
    fn prove_positive_1() {
        assert!(provable("1"));
    }

    #[test]
    fn prove_positive_neg1() {
        assert!(!provable("-1"));
    }

    // Without any assertion on `s`, `s+1` cannot be proven non-negative.
    #[test]
    fn prove_positive_add_0() {
        assert!(!provable("s+1"));
    }
}