1use itertools::Itertools;
2use parking_lot::ReentrantMutex;
3use std::cell::RefCell;
4use std::collections::HashMap;
5use std::fmt::{self, Display};
6use std::sync::atomic::{AtomicUsize, Ordering};
7use std::sync::{Arc, Weak};
8use string_interner::DefaultStringInterner;
9use string_interner::Symbol as _;
10
11use crate::TractResult;
12
13use super::parse::parse_assertion;
14use super::{Assertion, TDim, parse_tdim};
15
/// Source of unique ids for [`SymbolScopeData`]: every default-constructed scope data takes
/// the next value from this counter.
static SCOPE_COUNTER: AtomicUsize = AtomicUsize::new(0);
17
18#[derive(Clone, Default)]
19pub struct SymbolScope(pub Arc<ReentrantMutex<RefCell<SymbolScopeData>>>);
20
21impl PartialEq for SymbolScope {
22 fn eq(&self, other: &Self) -> bool {
23 Arc::ptr_eq(&self.0, &other.0)
24 }
25}
26
27impl Eq for SymbolScope {}
28
29pub struct SymbolScopeData {
30 id: usize,
31 table: DefaultStringInterner,
32 assertions: Vec<Assertion>,
33 scenarios: Vec<(String, Vec<Assertion>)>,
34}
35
36impl Default for SymbolScopeData {
37 fn default() -> Self {
38 SymbolScopeData {
39 id: SCOPE_COUNTER.fetch_add(1, Ordering::Relaxed),
40 table: DefaultStringInterner::default(),
41 assertions: Vec::new(),
42 scenarios: Vec::new(),
43 }
44 }
45}
46
47impl SymbolScope {
48 pub fn id(&self) -> usize {
49 let locked = self.0.lock();
50 let locked = locked.borrow();
51 locked.id
52 }
53
54 pub fn proof_cache_session(&self) -> ProofCacheSession {
55 ProofCacheSession::new(self.id())
56 }
57
58 pub fn get(&self, name: &str) -> Option<Symbol> {
59 let locked = self.0.lock();
60 let locked = locked.borrow();
61 locked.table.get(name).map(|sym| Symbol(Arc::downgrade(&self.0), sym))
62 }
63
64 pub fn coord_sym(&self, k: usize) -> Symbol {
66 self.sym(&format!("🎯{k}"))
67 }
68
69 pub fn sym(&self, name: &str) -> Symbol {
70 let locked = self.0.lock();
71 let mut locked = locked.borrow_mut();
72 let sym = locked.table.get_or_intern(name);
73 Symbol(Arc::downgrade(&self.0), sym)
74 }
75
76 pub fn new_with_prefix(&self, prefix: &str) -> Symbol {
77 let locked = self.0.lock();
78 let mut locked = locked.borrow_mut();
79 let sym = if locked.table.get(prefix).is_none() {
80 locked.table.get_or_intern(prefix)
81 } else {
82 let mut i = 0;
83 loop {
84 let s = format!("{prefix}_{i}");
85 if locked.table.get(&s).is_none() {
86 break locked.table.get_or_intern(s);
87 }
88 i += 1;
89 }
90 };
91 Symbol(Arc::downgrade(&self.0), sym)
92 }
93
94 pub fn parse_tdim(&self, input: impl AsRef<str>) -> TractResult<TDim> {
95 parse_tdim(self, input.as_ref())
96 }
97
98 pub fn add_assertion(&self, assert: impl Into<String>) -> TractResult<()> {
99 let assert = assert.into();
100 let assert = parse_assertion(self, &assert)?;
101 let locked = self.0.lock();
102 let mut locked = locked.borrow_mut();
103 locked.assertions.push(assert);
104 Ok(())
105 }
106
107 pub fn with_assertion(self, assert: impl Into<String>) -> TractResult<Self> {
108 self.add_assertion(assert)?;
109 Ok(self)
110 }
111
112 pub fn all_assertions(&self) -> Vec<Assertion> {
113 let locked = self.0.lock();
114 let locked = locked.borrow();
115 locked.assertions.clone()
116 }
117
118 pub fn all_scenarios(&self) -> impl IntoIterator<Item = (String, Vec<Assertion>)> {
119 let locked = self.0.lock();
120 let locked = locked.borrow();
121 locked.scenarios.clone()
122 }
123
124 pub fn add_scenario(&self, scenario: impl Into<String>) -> TractResult<()> {
125 let locked = self.0.lock();
126 let mut locked = locked.borrow_mut();
127 let s = scenario.into();
128 if !locked.scenarios.iter().any(|sc| sc.0 == s) {
129 locked.scenarios.push((s, vec![]));
130 }
131 Ok(())
132 }
133
134 pub fn add_scenario_assertion(
135 &self,
136 scenario: impl Into<String>,
137 assertion: impl Into<String>,
138 ) -> TractResult<()> {
139 let assert = parse_assertion(self, &assertion.into())?;
140 let s = scenario.into();
141 let locked = self.0.lock();
142 let mut locked = locked.borrow_mut();
143 if let Some(s) = locked.scenarios.iter_mut().find(|sc| sc.0 == s) {
144 s.1.push(assert);
145 } else {
146 locked.scenarios.push((s, vec![assert]));
147 }
148 Ok(())
149 }
150
151 pub fn with_scenario_assertion(
152 self,
153 scenario: impl Into<String>,
154 assertion: impl Into<String>,
155 ) -> TractResult<Self> {
156 self.add_scenario_assertion(scenario, assertion)?;
157 Ok(self)
158 }
159
160 pub fn with_scenario(self, scenario: impl Into<String>) -> TractResult<Self> {
161 self.add_scenario(scenario)?;
162 Ok(self)
163 }
164
165 pub fn all_symbols(&self) -> Vec<Symbol> {
166 self.0
167 .lock()
168 .borrow()
169 .table
170 .into_iter()
171 .map(|is| Symbol(Arc::downgrade(&self.0), is.0))
172 .collect()
173 }
174
175 pub fn guess_scenario(&self, values: &SymbolValues) -> TractResult<Option<usize>> {
176 let locked = self.0.lock();
177 let locked = locked.borrow();
178 if locked.scenarios.len() == 0 {
179 return Ok(None);
180 }
181 let mut maybe = None;
182 for (ix, (_name, assertions)) in locked.scenarios.iter().enumerate() {
183 if assertions.iter().any(|a| a.check(values) == Some(false)) {
184 continue;
185 } else if assertions.iter().all(|a| a.check(values) == Some(true)) {
186 return Ok(Some(ix));
187 } else if maybe.is_none() {
188 maybe = Some(ix);
189 } else {
190 return Ok(None);
191 }
192 }
193 if maybe.is_some() {
194 Ok(maybe)
195 } else {
196 anyhow::bail!("No possible scenario");
197 }
198 }
199}
200
201thread_local! {
202 static PROOF_CACHE: RefCell<Option<ProofCache>> = const { RefCell::new(None) };
203}
204
205struct ProofCache {
206 scope_id: usize,
207 depth: usize,
208 cache: HashMap<TDim, bool>,
209}
210
211pub struct ProofCacheSession {
212 active: bool,
213}
214
215impl ProofCacheSession {
216 pub fn new(scope_id: usize) -> Self {
217 let active = PROOF_CACHE.with(|cell| {
218 let mut borrow = cell.borrow_mut();
219 match &mut *borrow {
220 None => {
221 *borrow = Some(ProofCache { scope_id, depth: 1, cache: HashMap::new() });
222 true
223 }
224 Some(pc) if pc.scope_id == scope_id => {
225 pc.depth += 1;
226 true
227 }
228 Some(_) => false,
229 }
230 });
231 ProofCacheSession { active }
232 }
233}
234
235impl Drop for ProofCacheSession {
236 fn drop(&mut self) {
237 if !self.active {
238 return;
239 }
240 PROOF_CACHE.with(|cell| {
241 let mut borrow = cell.borrow_mut();
242 if let Some(pc) = &mut *borrow {
243 pc.depth -= 1;
244 if pc.depth == 0 {
245 *borrow = None;
246 }
247 }
248 });
249 }
250}
251
252impl SymbolScopeData {
253 pub fn all_assertions(&self) -> &[Assertion] {
254 &self.assertions
255 }
256
257 pub fn assertions(&self, scenario: Option<&str>) -> impl Iterator<Item = &'_ Assertion> {
258 self.assertions.iter().chain(
259 scenario
260 .and_then(|s| self.scenarios.iter().find(|s2| s2.0 == s))
261 .map(|s| &*s.1)
262 .unwrap_or(&[])
263 .iter(),
264 )
265 }
266
267 pub fn scenarios(&self) -> impl Iterator<Item = &'_ str> {
268 self.scenarios.iter().map(|s| &*s.0)
269 }
270
271 pub fn scenario(&self, s: &str) -> impl Iterator<Item = &'_ Assertion> {
272 self.scenarios.iter().find(|sc| sc.0 == s).map(|sc| &*sc.1).unwrap_or(&[]).iter()
273 }
274
275 pub fn resolving<R>(&self, sym: &Symbol, f: impl FnOnce(&str) -> R) -> Option<R> {
276 self.table.resolve(sym.1).map(f)
277 }
278
279 #[allow(clippy::mutable_key_type)]
280 pub fn prove_positive_or_zero(&self, t: &TDim) -> bool {
281 if let TDim::Val(v) = t {
282 return *v >= 0;
283 }
284 let cached = PROOF_CACHE.with(|cell| {
285 let borrow = cell.borrow();
286 if let Some(pc) = &*borrow {
287 debug_assert_eq!(pc.scope_id, self.id, "ProofCacheSession scope_id mismatch");
288 pc.cache.get(t).copied()
289 } else {
290 None
291 }
292 });
293 if let Some(result) = cached {
294 return result;
295 }
296 let result = self.prove_positive_or_zero_inner(t);
297 PROOF_CACHE.with(|cell| {
298 let mut borrow = cell.borrow_mut();
299 if let Some(pc) = &mut *borrow {
300 pc.cache.insert(t.clone(), result);
301 }
302 });
303 result
304 }
305
306 #[allow(clippy::mutable_key_type)]
307 fn prove_positive_or_zero_inner(&self, t: &TDim) -> bool {
308 self.prove_positive_or_zero_inner_with_extra(t, &[])
309 }
310
311 #[allow(clippy::mutable_key_type)]
312 fn prove_positive_or_zero_inner_with_extra(&self, t: &TDim, extra: &[Assertion]) -> bool {
313 let positives = self
314 .assertions
315 .iter()
316 .chain(extra.iter())
317 .filter_map(|i| i.as_known_positive())
318 .collect_vec();
319 let mut visited = vec![];
320 let mut todo = vec![t.clone()];
321 while let Some(t) = todo.pop() {
322 if t.to_i64().is_ok_and(|i| i >= 0) {
323 return true;
324 }
325 if t.inclusive_bound(self, false).is_some_and(|l| l >= 0) {
326 return true;
327 }
328 if let TDim::Div(a, q) = &t {
330 if *q >= 1 && self.prove_positive_or_zero_inner_with_extra(a, extra) {
331 return true;
332 }
333 }
334 let syms = t.symbols();
335 for s in syms {
336 let me = t.guess_slope(&s);
337 for pos in &positives {
338 if pos.symbols().contains(&s) {
339 let other = pos.guess_slope(&s);
340 if me.0.signum() == other.0.signum() {
341 let new = t.clone() * me.1 * other.0.abs()
342 - pos.clone() * me.0.abs() * other.1;
343 if !visited.contains(&new) {
344 todo.push(new);
345 }
346 }
347 }
348 }
349 }
350 visited.push(t);
351 if visited.len() > 10 {
352 break;
353 }
354 }
355 false
356 }
357
358 pub(crate) fn prove_positive_or_zero_with_extra(&self, t: &TDim, extra: &[Assertion]) -> bool {
359 if let TDim::Val(v) = t {
360 return *v >= 0;
361 }
362 self.prove_positive_or_zero_inner_with_extra(t, extra)
364 }
365
366 pub(crate) fn prove_strict_positive_with_extra(&self, b: &TDim, extra: &[Assertion]) -> bool {
367 self.prove_positive_or_zero_with_extra(&(b.clone() - 1), extra)
368 }
369}
370
371impl fmt::Debug for SymbolScope {
372 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
373 let locked = self.0.lock();
374 let locked = locked.borrow();
375 write!(
376 f,
377 "symbols: {}; assertions: {}; {}",
378 locked.table.into_iter().map(|(_, s)| s).sorted().join(", "),
379 locked.assertions.iter().map(|s| s.to_string()).sorted().join(", "),
380 locked
381 .scenarios
382 .iter()
383 .map(|s| format!(
384 "{}: {}",
385 s.0,
386 s.1.iter().map(|s| s.to_string()).sorted().join(", ")
387 ))
388 .join(" ; "),
389 )
390 }
391}
392
393#[derive(Clone)]
394pub struct Symbol(Weak<ReentrantMutex<RefCell<SymbolScopeData>>>, string_interner::DefaultSymbol);
395
396impl Eq for Symbol {}
397
398impl PartialEq for Symbol {
399 fn eq(&self, other: &Self) -> bool {
400 self.1 == other.1
401 }
402}
403
404impl Symbol {
405 pub fn scope(&self) -> Option<SymbolScope> {
406 self.0.upgrade().map(SymbolScope)
407 }
408}
409
410impl PartialOrd for Symbol {
411 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
412 Some(self.cmp(other))
413 }
414}
415
416impl Ord for Symbol {
417 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
418 self.1.cmp(&other.1)
419 }
420}
421
422impl std::hash::Hash for Symbol {
423 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
424 self.1.hash(state)
425 }
426}
427
428impl std::fmt::Display for Symbol {
429 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
430 if let Some(scope) = self.scope() {
431 let lock = scope.0.lock();
432 let lock = lock.borrow();
433 if let Some(s) = lock.table.resolve(self.1) {
434 return write!(f, "{s}");
435 }
436 }
437 write!(f, "<Sym{}>", self.1.to_usize())
438 }
439}
440
441impl fmt::Debug for Symbol {
442 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
443 Display::fmt(&self, f)
444 }
445}
446
447#[derive(Clone, Debug, Default)]
448pub struct SymbolValues {
449 values: HashMap<Symbol, i64>,
450}
451
452impl SymbolValues {
453 pub fn with(mut self, s: &Symbol, v: i64) -> Self {
454 self.set(s, v);
455 self
456 }
457
458 pub fn set(&mut self, s: &Symbol, v: i64) {
459 self.values.insert(s.clone(), v);
460 }
461
462 pub fn get(&self, s: &Symbol) -> Option<i64> {
463 self.values.get(s).copied()
464 }
465}
466
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `assertion` in `scope` and returns its known-positive form.
    fn known_positive(scope: &SymbolScope, assertion: &str) -> Option<TDim> {
        parse_assertion(scope, assertion).unwrap().as_known_positive()
    }

    /// Parses `expr` in a fresh scope (no assertions) and asks whether it is provably `>= 0`.
    fn proves_non_negative(expr: &str) -> bool {
        let scope = SymbolScope::default();
        let dim = scope.parse_tdim(expr).unwrap();
        dim.prove_positive_or_zero()
    }

    #[test]
    fn as_known_positive_gte() {
        let scope = SymbolScope::default();
        let got = known_positive(&scope, "S>=0");
        assert_eq!(got, Some(scope.parse_tdim("S").unwrap()));
    }

    #[test]
    fn as_known_positive_gt() {
        let scope = SymbolScope::default();
        let got = known_positive(&scope, "S>0");
        assert_eq!(got, Some(scope.parse_tdim("S-1").unwrap()));
    }

    #[test]
    fn as_known_positive_lte() {
        let scope = SymbolScope::default();
        let got = known_positive(&scope, "S<=0");
        assert_eq!(got, Some(scope.parse_tdim("-S").unwrap()));
    }

    #[test]
    fn as_known_positive_lt() {
        let scope = SymbolScope::default();
        let got = known_positive(&scope, "S<0");
        assert_eq!(got, Some(scope.parse_tdim("-S - 1").unwrap()));
    }

    #[test]
    fn prove_positive_0() {
        assert!(proves_non_negative("0"));
    }

    #[test]
    fn prove_positive_1() {
        assert!(proves_non_negative("1"));
    }

    #[test]
    fn prove_positive_neg1() {
        assert!(!proves_non_negative("-1"));
    }

    #[test]
    fn prove_positive_add_0() {
        // `s` is unconstrained, so `s + 1` cannot be shown non-negative.
        assert!(!proves_non_negative("s+1"));
    }
}