1use itertools::Itertools;
2use parking_lot::ReentrantMutex;
3use std::cell::RefCell;
4use std::collections::HashMap;
5use std::fmt::{self, Display};
6use std::sync::atomic::{AtomicUsize, Ordering};
7use std::sync::{Arc, Weak};
8use string_interner::DefaultStringInterner;
9use string_interner::Symbol as _;
10
11use crate::TractResult;
12
13use super::parse::parse_assertion;
14use super::{Assertion, TDim, parse_tdim};
15
/// Process-wide counter handing out a distinct `id` to each `SymbolScopeData`
/// built through its `Default` implementation.
static SCOPE_COUNTER: AtomicUsize = AtomicUsize::new(0);
17
18#[derive(Clone, Default)]
19pub struct SymbolScope(pub Arc<ReentrantMutex<RefCell<SymbolScopeData>>>);
20
21impl PartialEq for SymbolScope {
22 fn eq(&self, other: &Self) -> bool {
23 Arc::ptr_eq(&self.0, &other.0)
24 }
25}
26
27impl Eq for SymbolScope {}
28
29pub struct SymbolScopeData {
30 id: usize,
31 table: DefaultStringInterner,
32 assertions: Vec<Assertion>,
33 scenarios: Vec<(String, Vec<Assertion>)>,
34}
35
36impl Default for SymbolScopeData {
37 fn default() -> Self {
38 SymbolScopeData {
39 id: SCOPE_COUNTER.fetch_add(1, Ordering::Relaxed),
40 table: DefaultStringInterner::default(),
41 assertions: Vec::new(),
42 scenarios: Vec::new(),
43 }
44 }
45}
46
47impl SymbolScope {
48 pub fn id(&self) -> usize {
49 let locked = self.0.lock();
50 let locked = locked.borrow();
51 locked.id
52 }
53
54 pub fn proof_cache_session(&self) -> ProofCacheSession {
55 ProofCacheSession::new(self.id())
56 }
57
58 pub fn get(&self, name: &str) -> Option<Symbol> {
59 let locked = self.0.lock();
60 let locked = locked.borrow();
61 locked.table.get(name).map(|sym| Symbol(Arc::downgrade(&self.0), sym))
62 }
63
64 pub fn sym(&self, name: &str) -> Symbol {
65 let locked = self.0.lock();
66 let mut locked = locked.borrow_mut();
67 let sym = locked.table.get_or_intern(name);
68 Symbol(Arc::downgrade(&self.0), sym)
69 }
70
71 pub fn new_with_prefix(&self, prefix: &str) -> Symbol {
72 let locked = self.0.lock();
73 let mut locked = locked.borrow_mut();
74 let sym = if locked.table.get(prefix).is_none() {
75 locked.table.get_or_intern(prefix)
76 } else {
77 let mut i = 0;
78 loop {
79 let s = format!("{prefix}_{i}");
80 if locked.table.get(&s).is_none() {
81 break locked.table.get_or_intern(s);
82 }
83 i += 1;
84 }
85 };
86 Symbol(Arc::downgrade(&self.0), sym)
87 }
88
89 pub fn parse_tdim(&self, input: impl AsRef<str>) -> TractResult<TDim> {
90 parse_tdim(self, input.as_ref())
91 }
92
93 pub fn add_assertion(&self, assert: impl Into<String>) -> TractResult<()> {
94 let assert = assert.into();
95 let assert = parse_assertion(self, &assert)?;
96 let locked = self.0.lock();
97 let mut locked = locked.borrow_mut();
98 locked.assertions.push(assert);
99 Ok(())
100 }
101
102 pub fn with_assertion(self, assert: impl Into<String>) -> TractResult<Self> {
103 self.add_assertion(assert)?;
104 Ok(self)
105 }
106
107 pub fn all_assertions(&self) -> Vec<Assertion> {
108 let locked = self.0.lock();
109 let locked = locked.borrow();
110 locked.assertions.clone()
111 }
112
113 pub fn all_scenarios(&self) -> impl IntoIterator<Item = (String, Vec<Assertion>)> {
114 let locked = self.0.lock();
115 let locked = locked.borrow();
116 locked.scenarios.clone()
117 }
118
119 pub fn add_scenario(&self, scenario: impl Into<String>) -> TractResult<()> {
120 let locked = self.0.lock();
121 let mut locked = locked.borrow_mut();
122 let s = scenario.into();
123 if !locked.scenarios.iter().any(|sc| sc.0 == s) {
124 locked.scenarios.push((s, vec![]));
125 }
126 Ok(())
127 }
128
129 pub fn add_scenario_assertion(
130 &self,
131 scenario: impl Into<String>,
132 assertion: impl Into<String>,
133 ) -> TractResult<()> {
134 let assert = parse_assertion(self, &assertion.into())?;
135 let s = scenario.into();
136 let locked = self.0.lock();
137 let mut locked = locked.borrow_mut();
138 if let Some(s) = locked.scenarios.iter_mut().find(|sc| sc.0 == s) {
139 s.1.push(assert);
140 } else {
141 locked.scenarios.push((s, vec![assert]));
142 }
143 Ok(())
144 }
145
146 pub fn with_scenario_assertion(
147 self,
148 scenario: impl Into<String>,
149 assertion: impl Into<String>,
150 ) -> TractResult<Self> {
151 self.add_scenario_assertion(scenario, assertion)?;
152 Ok(self)
153 }
154
155 pub fn with_scenario(self, scenario: impl Into<String>) -> TractResult<Self> {
156 self.add_scenario(scenario)?;
157 Ok(self)
158 }
159
160 pub fn all_symbols(&self) -> Vec<Symbol> {
161 self.0
162 .lock()
163 .borrow()
164 .table
165 .into_iter()
166 .map(|is| Symbol(Arc::downgrade(&self.0), is.0))
167 .collect()
168 }
169
170 pub fn guess_scenario(&self, values: &SymbolValues) -> TractResult<Option<usize>> {
171 let locked = self.0.lock();
172 let locked = locked.borrow();
173 if locked.scenarios.len() == 0 {
174 return Ok(None);
175 }
176 let mut maybe = None;
177 for (ix, (_name, assertions)) in locked.scenarios.iter().enumerate() {
178 if assertions.iter().any(|a| a.check(values) == Some(false)) {
179 continue;
180 } else if assertions.iter().all(|a| a.check(values) == Some(true)) {
181 return Ok(Some(ix));
182 } else if maybe.is_none() {
183 maybe = Some(ix);
184 } else {
185 return Ok(None);
186 }
187 }
188 if maybe.is_some() {
189 Ok(maybe)
190 } else {
191 anyhow::bail!("No possible scenario");
192 }
193 }
194}
195
196thread_local! {
197 static PROOF_CACHE: RefCell<Option<ProofCache>> = const { RefCell::new(None) };
198}
199
200struct ProofCache {
201 scope_id: usize,
202 depth: usize,
203 cache: HashMap<TDim, bool>,
204}
205
/// Guard keeping the thread-local proof cache alive until it is dropped.
pub struct ProofCacheSession {
    /// Whether this session holds a reference on the thread-local cache and
    /// must therefore release it on drop.
    active: bool,
}
209
210impl ProofCacheSession {
211 pub fn new(scope_id: usize) -> Self {
212 let active = PROOF_CACHE.with(|cell| {
213 let mut borrow = cell.borrow_mut();
214 match &mut *borrow {
215 None => {
216 *borrow = Some(ProofCache { scope_id, depth: 1, cache: HashMap::new() });
217 true
218 }
219 Some(pc) if pc.scope_id == scope_id => {
220 pc.depth += 1;
221 true
222 }
223 Some(_) => false,
224 }
225 });
226 ProofCacheSession { active }
227 }
228}
229
230impl Drop for ProofCacheSession {
231 fn drop(&mut self) {
232 if !self.active {
233 return;
234 }
235 PROOF_CACHE.with(|cell| {
236 let mut borrow = cell.borrow_mut();
237 if let Some(pc) = &mut *borrow {
238 pc.depth -= 1;
239 if pc.depth == 0 {
240 *borrow = None;
241 }
242 }
243 });
244 }
245}
246
247impl SymbolScopeData {
248 pub fn all_assertions(&self) -> &[Assertion] {
249 &self.assertions
250 }
251
252 pub fn assertions(&self, scenario: Option<&str>) -> impl Iterator<Item = &'_ Assertion> {
253 self.assertions.iter().chain(
254 scenario
255 .and_then(|s| self.scenarios.iter().find(|s2| s2.0 == s))
256 .map(|s| &*s.1)
257 .unwrap_or(&[])
258 .iter(),
259 )
260 }
261
262 pub fn scenarios(&self) -> impl Iterator<Item = &'_ str> {
263 self.scenarios.iter().map(|s| &*s.0)
264 }
265
266 pub fn scenario(&self, s: &str) -> impl Iterator<Item = &'_ Assertion> {
267 self.scenarios.iter().find(|sc| sc.0 == s).map(|sc| &*sc.1).unwrap_or(&[]).iter()
268 }
269
270 pub fn resolving<R>(&self, sym: &Symbol, f: impl FnOnce(&str) -> R) -> Option<R> {
271 self.table.resolve(sym.1).map(f)
272 }
273
274 #[allow(clippy::mutable_key_type)]
275 pub fn prove_positive_or_zero(&self, t: &TDim) -> bool {
276 if let TDim::Val(v) = t {
277 return *v >= 0;
278 }
279 let cached = PROOF_CACHE.with(|cell| {
280 let borrow = cell.borrow();
281 if let Some(pc) = &*borrow {
282 debug_assert_eq!(pc.scope_id, self.id, "ProofCacheSession scope_id mismatch");
283 pc.cache.get(t).copied()
284 } else {
285 None
286 }
287 });
288 if let Some(result) = cached {
289 return result;
290 }
291 let result = self.prove_positive_or_zero_inner(t);
292 PROOF_CACHE.with(|cell| {
293 let mut borrow = cell.borrow_mut();
294 if let Some(pc) = &mut *borrow {
295 pc.cache.insert(t.clone(), result);
296 }
297 });
298 result
299 }
300
301 #[allow(clippy::mutable_key_type)]
302 fn prove_positive_or_zero_inner(&self, t: &TDim) -> bool {
303 let positives = self.assertions.iter().filter_map(|i| i.as_known_positive()).collect_vec();
304 let mut visited = vec![];
305 let mut todo = vec![t.clone()];
306 while let Some(t) = todo.pop() {
307 if t.to_i64().is_ok_and(|i| i >= 0) {
308 return true;
309 }
310 if t.inclusive_bound(self, false).is_some_and(|l| l >= 0) {
311 return true;
312 }
313 let syms = t.symbols();
314 for s in syms {
315 let me = t.guess_slope(&s);
316 for pos in &positives {
317 if pos.symbols().contains(&s) {
318 let other = pos.guess_slope(&s);
319 if me.0.signum() == other.0.signum() {
320 let new = t.clone() * me.1 * other.0.abs()
321 - pos.clone() * me.0.abs() * other.1;
322 if !visited.contains(&new) {
323 todo.push(new);
324 }
325 }
326 }
327 }
328 }
329 visited.push(t);
330 if visited.len() > 10 {
331 break;
332 }
333 }
334 false
335 }
336
337 pub(crate) fn prove_strict_positive(&self, b: &TDim) -> bool {
338 self.prove_positive_or_zero(&(b.clone() - 1))
339 }
340}
341
342impl fmt::Debug for SymbolScope {
343 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
344 let locked = self.0.lock();
345 let locked = locked.borrow();
346 write!(
347 f,
348 "symbols: {}; assertions: {}; {}",
349 locked.table.into_iter().map(|(_, s)| s).sorted().join(", "),
350 locked.assertions.iter().map(|s| s.to_string()).sorted().join(", "),
351 locked
352 .scenarios
353 .iter()
354 .map(|s| format!(
355 "{}: {}",
356 s.0,
357 s.1.iter().map(|s| s.to_string()).sorted().join(", ")
358 ))
359 .join(" ; "),
360 )
361 }
362}
363
364#[derive(Clone)]
365pub struct Symbol(Weak<ReentrantMutex<RefCell<SymbolScopeData>>>, string_interner::DefaultSymbol);
366
367impl Eq for Symbol {}
368
369impl PartialEq for Symbol {
370 fn eq(&self, other: &Self) -> bool {
371 self.1 == other.1
372 }
373}
374
375impl Symbol {
376 pub fn scope(&self) -> Option<SymbolScope> {
377 self.0.upgrade().map(SymbolScope)
378 }
379}
380
381impl PartialOrd for Symbol {
382 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
383 Some(self.cmp(other))
384 }
385}
386
387impl Ord for Symbol {
388 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
389 self.1.cmp(&other.1)
390 }
391}
392
393impl std::hash::Hash for Symbol {
394 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
395 self.1.hash(state)
396 }
397}
398
399impl std::fmt::Display for Symbol {
400 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
401 if let Some(scope) = self.scope() {
402 let lock = scope.0.lock();
403 let lock = lock.borrow();
404 if let Some(s) = lock.table.resolve(self.1) {
405 return write!(f, "{s}");
406 }
407 }
408 write!(f, "<Sym{}>", self.1.to_usize())
409 }
410}
411
412impl fmt::Debug for Symbol {
413 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
414 Display::fmt(&self, f)
415 }
416}
417
418#[derive(Clone, Debug, Default)]
419pub struct SymbolValues {
420 values: HashMap<Symbol, i64>,
421}
422
423impl SymbolValues {
424 pub fn with(mut self, s: &Symbol, v: i64) -> Self {
425 self.set(s, v);
426 self
427 }
428
429 pub fn set(&mut self, s: &Symbol, v: i64) {
430 self.values.insert(s.clone(), v);
431 }
432
433 pub fn get(&self, s: &Symbol) -> Option<i64> {
434 self.values.get(s).copied()
435 }
436}
437
#[cfg(test)]
mod tests {
    use super::*;

    /// In a fresh scope, checks that `assertion` is recognized as stating that
    /// the expression `expected` is known to be positive or zero.
    #[track_caller]
    fn assert_known_positive(assertion: &str, expected: &str) {
        let scope = SymbolScope::default();
        let parsed = parse_assertion(&scope, assertion).unwrap();
        assert_eq!(parsed.as_known_positive(), Some(scope.parse_tdim(expected).unwrap()));
    }

    /// In a fresh scope without assertions, reports whether `expr` can be
    /// proven positive or zero.
    fn proves_positive_or_zero(expr: &str) -> bool {
        let scope = SymbolScope::default();
        scope.parse_tdim(expr).unwrap().prove_positive_or_zero()
    }

    #[test]
    fn as_known_positive_gte() {
        assert_known_positive("S>=0", "S");
    }

    #[test]
    fn as_known_positive_gt() {
        assert_known_positive("S>0", "S-1");
    }

    #[test]
    fn as_known_positive_lte() {
        assert_known_positive("S<=0", "-S");
    }

    #[test]
    fn as_known_positive_lt() {
        assert_known_positive("S<0", "-S - 1");
    }

    #[test]
    fn prove_positive_0() {
        assert!(proves_positive_or_zero("0"));
    }

    #[test]
    fn prove_positive_1() {
        assert!(proves_positive_or_zero("1"));
    }

    #[test]
    fn prove_positive_neg1() {
        assert!(!proves_positive_or_zero("-1"));
    }

    // Nothing is known about `s`, so `s+1` cannot be proven non-negative.
    #[test]
    fn prove_positive_add_0() {
        assert!(!proves_positive_or_zero("s+1"));
    }
}