smt_scope/parsers/z3/
synthetic.rs

1use std::collections::hash_map::Entry;
2
3#[cfg(feature = "mem_dbg")]
4use mem_dbg::{MemDbg, MemSize};
5
6use crate::{
7    idx,
8    items::{Meaning, Term, TermId, TermIdx, TermKind},
9    BoxSlice, FxHashMap, Result, TiVec,
10};
11
12use super::{terms::Terms, Z3Parser};
13
/// A term addressable by a [`SynthIdx`]: a term from the parser's `Terms`
/// table, a synthetic term created through [`SynthTerms`], or a constant.
///
/// NOTE(review): `Term::as_any` reinterprets a `&Term` as a `&AnyTerm`, which
/// relies on this enum's in-memory layout; changing the variants, their
/// payloads or their order is safety-relevant.
#[cfg_attr(feature = "mem_dbg", derive(MemSize, MemDbg))]
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
pub enum AnyTerm {
    /// A term from the parser's `Terms` table. Such terms are never inserted
    /// into `SynthTerms` (`check_valid` debug-asserts against it).
    Parsed(Term),
    /// A term built by one of the `SynthTerms::new_*` constructors.
    Synth(SynthTerm),
    /// A constant created by `SynthTerms::new_constant`; it has no children.
    Constant(Meaning),
}
21
22impl AnyTerm {
23    pub fn check_valid(&self, is_tidx: impl Fn(SynthIdx) -> bool) {
24        if let AnyTerm::Parsed(_) = self {
25            debug_assert!(false, "Parsed term should not be inserted as synthetic!");
26        }
27        use SynthTermKind::*;
28        match self.kind() {
29            Parsed(TermKind::Var(_)) | Variable(_) | Input(_) | Constant => {
30                debug_assert_eq!(self.child_ids().len(), 0)
31            }
32            Generalised(_) => debug_assert_eq!(self.child_ids().len(), 1),
33            Parsed(TermKind::Quant(_)) => debug_assert!(!self.child_ids().is_empty()),
34            Parsed(TermKind::App(_)) => debug_assert!(
35                self.child_ids().iter().any(|&c| !is_tidx(c)),
36                "Synthetic term must have at least one synthetic child"
37            ),
38        }
39    }
40
41    pub fn replace_child_ids(&self, child_ids: BoxSlice<SynthIdx>) -> Option<Self> {
42        assert_eq!(self.child_ids().len(), child_ids.len());
43        if self.child_ids() == &*child_ids {
44            return None;
45        }
46        match self {
47            AnyTerm::Parsed(term) => Some(AnyTerm::Synth(SynthTerm {
48                kind: SynthTermKind::Parsed(term.kind()),
49                child_ids,
50            })),
51            AnyTerm::Synth(synth_term) => Some(AnyTerm::Synth(SynthTerm {
52                kind: synth_term.kind,
53                child_ids,
54            })),
55            AnyTerm::Constant(_) => unreachable!(),
56        }
57    }
58}
59
/// A synthetic term: a kind together with the indices of its children. The
/// children may refer to parsed or to synthetic terms.
#[cfg_attr(feature = "mem_dbg", derive(MemSize, MemDbg))]
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
pub struct SynthTerm {
    /// What this term is; see [`SynthTermKind`].
    pub kind: SynthTermKind,
    /// Indices of this term's children.
    pub child_ids: BoxSlice<SynthIdx>,
}
66
/// Index of an [`AnyTerm`]. Values below `SynthTerms::start_idx` are indices
/// of parsed terms; values at or above it refer to synthetic terms stored in
/// [`SynthTerms`].
///
/// `#[repr(transparent)]` is load-bearing: `AnyTerm::child_ids` reinterprets a
/// slice of `TermIdx` as a slice of `SynthIdx`.
#[cfg_attr(feature = "mem_dbg", derive(MemSize, MemDbg))]
#[cfg_attr(feature = "mem_dbg", copy_type)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
#[repr(transparent)]
pub struct SynthIdx(TermIdx);
73
/// The kind of a synthetic term (see [`SynthTerm`]). Through
/// [`AnyTerm::kind`] it also describes parsed terms and constants.
#[cfg_attr(feature = "mem_dbg", derive(MemSize, MemDbg))]
#[cfg_attr(feature = "mem_dbg", copy_type)]
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
// Note that we must preserve `size_of::<SynthTermKind>() == size_of::<TermKind>()`!
pub enum SynthTermKind {
    /// The kind of a parsed term. It is also used for synthetic terms that
    /// reuse a parsed term's kind with different children (see
    /// `AnyTerm::replace_child_ids` and `SynthTerms::new_synthetic`).
    Parsed(TermKind),
    /// When generalising e.g. `f(x)` and `f(g(x))` we get back a `f(_)` term
    /// where the `_` term is of kind `Generalised` and points to a `SynthIdx`
    /// term of `g($0)` where the `$0` term is of kind `Input`.
    ///
    /// The `Option<TermIdx>` is the first term that was generalised over. The
    /// `x` in the example above. Once three or more terms are generalised, this
    /// is `None`.
    Generalised(Option<TermIdx>),
    /// A leaf term created by `SynthTerms::new_variable`.
    /// NOTE(review): the meaning of the `u32` id is not visible in this file.
    /// It is presumably a variable number; confirm.
    Variable(u32),
    /// When generalising e.g. `f(x)` and `f(g(x))` we get back a `f(_)` term
    /// where the `_` term is of kind `Generalised` and points to a `SynthIdx`
    /// term of `g($0)` where the `$0` term is of kind `Input`. The `Input` may
    /// additionally contain a constant offset, stored as the index of an
    /// `AnyTerm::Constant` term (see `SynthTerms::new_input`).
    Input(Option<SynthIdx>),
    /// Never actually stored in `SynthTerm.kind` but created as a kind for
    /// `AnyTerm::Constant`.
    Constant,
}
98
99impl From<TermIdx> for SynthIdx {
100    fn from(tidx: TermIdx) -> Self {
101        Self(tidx)
102    }
103}
104
105impl AnyTerm {
106    pub fn id(&self) -> Option<TermId> {
107        match self {
108            AnyTerm::Parsed(term) => Some(term.id),
109            AnyTerm::Synth(_) => None,
110            AnyTerm::Constant(_) => None,
111        }
112    }
113    pub fn kind(&self) -> SynthTermKind {
114        match self {
115            AnyTerm::Parsed(term) => SynthTermKind::Parsed(term.kind()),
116            AnyTerm::Synth(synth_term) => synth_term.kind,
117            AnyTerm::Constant(_) => SynthTermKind::Constant,
118        }
119    }
120    pub fn child_ids<'a>(&'a self) -> &'a [SynthIdx] {
121        match self {
122            AnyTerm::Parsed(term) => {
123                let child_ids = &*term.child_ids;
124                unsafe { std::mem::transmute::<&'a [TermIdx], &'a [SynthIdx]>(child_ids) }
125            }
126            AnyTerm::Synth(synth_term) => &synth_term.child_ids,
127            AnyTerm::Constant(_) => &[],
128        }
129    }
130}
131
132impl SynthTermKind {
133    pub fn parsed(self) -> Option<TermKind> {
134        match self {
135            SynthTermKind::Parsed(p) => Some(p),
136            _ => None,
137        }
138    }
139}
140
141impl Term {
142    #[allow(clippy::no_effect)]
143    const CHECK_REPR_EQ: bool = {
144        let sizeof_eq = core::mem::size_of::<Term>() == core::mem::size_of::<AnyTerm>();
145        [(); 1][!sizeof_eq as usize];
146        true
147    };
148    pub fn as_any(&self) -> &AnyTerm {
149        let _ = Self::CHECK_REPR_EQ;
150        // SAFETY: `AnyTerm` and `Term` have the same memory layout since they
151        // have the same `size_of`.
152        unsafe { std::mem::transmute::<&Term, &AnyTerm>(self) }
153    }
154}
155
// Index into `SynthTerms::synth_terms`. It is offset by `SynthTerms::start_idx`
// to obtain the corresponding `SynthIdx`. NOTE(review): "y{}" is presumably
// the display format for these indices; confirm against `idx!`.
idx!(DefStIdx, "y{}");
157
/// Interned storage for synthetic terms and constants.
///
/// A [`SynthIdx`] below `start_idx` wraps a parsed term's index. One at or
/// above it refers to `synth_terms[idx - start_idx]`.
#[cfg_attr(feature = "mem_dbg", derive(MemSize, MemDbg))]
#[derive(Debug)]
pub struct SynthTerms {
    /// First `SynthIdx` value that denotes a synthetic term. It is set by
    /// `eof` and is `TermIdx::MAX` until then.
    start_idx: TermIdx,
    /// Every distinct inserted term, in insertion order.
    synth_terms: TiVec<DefStIdx, AnyTerm>,
    /// Reverse lookup used by `insert` to deduplicate terms.
    interned: FxHashMap<AnyTerm, SynthIdx>,
}
165
166impl Default for SynthTerms {
167    fn default() -> Self {
168        Self {
169            start_idx: TermIdx::MAX,
170            synth_terms: TiVec::default(),
171            interned: FxHashMap::default(),
172        }
173    }
174}
175
176impl SynthTerms {
177    fn tidx_to_dstidx(&self, tidx: SynthIdx) -> core::result::Result<DefStIdx, TermIdx> {
178        if self.start_idx <= tidx.0 {
179            Ok(DefStIdx::from(
180                usize::from(tidx.0) - usize::from(self.start_idx),
181            ))
182        } else {
183            Err(tidx.0)
184        }
185    }
186    fn dstidx_to_tidx(dstidx: DefStIdx, start_idx: TermIdx) -> SynthIdx {
187        assert!(
188            usize::from(dstidx) + usize::from(start_idx) < usize::MAX,
189            "SynthIdx overflow {dstidx} + {start_idx}"
190        );
191        SynthIdx(TermIdx::from(usize::from(dstidx) + usize::from(start_idx)))
192    }
193
194    pub fn eof(&mut self, start_idx: TermIdx) {
195        self.start_idx = start_idx;
196    }
197
198    pub fn as_tidx(&self, sidx: SynthIdx) -> Option<TermIdx> {
199        self.tidx_to_dstidx(sidx).err()
200    }
201
202    pub(crate) fn index<'a>(&'a self, terms: &'a Terms, idx: SynthIdx) -> &'a AnyTerm {
203        match self.tidx_to_dstidx(idx) {
204            Ok(idx) => &self.synth_terms[idx],
205            Err(tidx) => {
206                let term = &terms[tidx];
207                term.as_any()
208            }
209        }
210    }
211
212    pub fn new_constant(&mut self, meaning: Meaning) -> Result<SynthIdx> {
213        self.insert(AnyTerm::Constant(meaning))
214    }
215
216    pub fn new_input(&mut self, offset: Option<Meaning>) -> Result<SynthIdx> {
217        let offset = offset.map(|o| self.new_constant(o)).transpose()?;
218        let term = SynthTerm {
219            kind: SynthTermKind::Input(offset),
220            child_ids: Default::default(),
221        };
222        self.insert(AnyTerm::Synth(term))
223    }
224
225    pub fn new_variable(&mut self, id: u32) -> Result<SynthIdx> {
226        let term = SynthTerm {
227            kind: SynthTermKind::Variable(id),
228            child_ids: Default::default(),
229        };
230        self.insert(AnyTerm::Synth(term))
231    }
232
233    pub fn new_generalised(&mut self, first: Option<TermIdx>, gen: SynthIdx) -> Result<SynthIdx> {
234        let term = SynthTerm {
235            kind: SynthTermKind::Generalised(first),
236            child_ids: [gen].into_iter().collect(),
237        };
238        self.insert(AnyTerm::Synth(term))
239    }
240
241    pub fn new_synthetic(
242        &mut self,
243        kind: TermKind,
244        child_ids: BoxSlice<SynthIdx>,
245    ) -> Result<SynthIdx> {
246        let term = SynthTerm {
247            kind: SynthTermKind::Parsed(kind),
248            child_ids,
249        };
250        self.insert(AnyTerm::Synth(term))
251    }
252
253    pub(crate) fn insert(&mut self, term: AnyTerm) -> Result<SynthIdx> {
254        term.check_valid(|idx| self.as_tidx(idx).is_some());
255        self.interned.try_reserve(1)?;
256        match self.interned.entry(term) {
257            Entry::Occupied(entry) => Ok(*entry.get()),
258            Entry::Vacant(entry) => {
259                self.synth_terms.raw.try_reserve(1)?;
260                let idx = self.synth_terms.push_and_get_key(entry.key().clone());
261                let idx = Self::dstidx_to_tidx(idx, self.start_idx);
262                Ok(*entry.insert(idx))
263            }
264        }
265    }
266}
267
268pub trait TermWalker<'a> {
269    fn parser(&self) -> &'a Z3Parser;
270    fn walk_idx(&mut self, idx: SynthIdx) {
271        let term = self.parser().synth_terms.index(&self.parser().terms, idx);
272        if !self.walk_synth(term, idx) {
273            return;
274        }
275        if let AnyTerm::Parsed(term) = term {
276            let idx = self.parser().synth_terms.as_tidx(idx).unwrap();
277            if !self.walk_term(term, idx) {
278                return;
279            }
280        }
281        for &child in term.child_ids() {
282            self.walk_idx(child);
283        }
284    }
285    /// Return `false` to stop walking the children of this term.
286    fn walk_synth(&mut self, _term: &'a AnyTerm, _idx: SynthIdx) -> bool {
287        true
288    }
289    /// Return `false` to stop walking the children of this term.
290    fn walk_term(&mut self, _term: &'a Term, _idx: TermIdx) -> bool {
291        true
292    }
293}