smt_scope/analysis/generalise.rs

1use crate::{
2    items::{Meaning, TermIdx},
3    parsers::z3::{
4        synthetic::{AnyTerm, SynthIdx, SynthTerm, SynthTermKind, SynthTerms},
5        terms::Terms,
6    },
7    FxHashMap, Result,
8};
9
10impl SynthTerms {
11    pub fn generalise_first(
12        &mut self,
13        table: &Terms,
14        smaller: TermIdx,
15        larger: TermIdx,
16        depth: usize,
17    ) -> Result<Option<SynthIdx>> {
18        if smaller == larger {
19            // if terms are equal, no need to generalize
20            return Ok(Some(smaller.into()));
21        }
22
23        let smaller_term = &table[smaller];
24        let larger_term = &table[larger];
25
26        let smaller_meaning = table.meaning(smaller);
27        let larger_meaning = table.meaning(larger);
28        if smaller_meaning.is_none()
29            && larger_meaning.is_none()
30            && smaller_term.kind() == larger_term.kind()
31            && smaller_term.child_ids.len() == larger_term.child_ids.len()
32        {
33            // TODO: remove this depth restriction
34            if depth == 0 {
35                return Ok(None);
36            }
37            let mut child_ids = Vec::new();
38            child_ids.try_reserve_exact(smaller_term.child_ids.len())?;
39            for i in 0..smaller_term.child_ids.len() {
40                let sc = smaller_term.child_ids[i];
41                let lc = larger_term.child_ids[i];
42                let gen = self.generalise_first(table, sc, lc, depth - 1)?;
43                let Some(gen) = gen else {
44                    // recursion limit reached
45                    break;
46                };
47                // We don't generalise up on the first pass, since the
48                // generalisation might be wrong at this point and we don't want
49                // to generalise up incorrectly.
50                // if self.is_generalised_by::<true>(table, smaller, larger, gen) {
51                //     return Ok(Some(gen));
52                // }
53                child_ids.push(gen);
54            }
55            // all children generalized successfully, but none generalise up to current
56            if child_ids.len() == smaller_term.child_ids.len() {
57                return self
58                    .new_synthetic(smaller_term.kind(), child_ids.into())
59                    .map(Some);
60            }
61        } else {
62            // If the meanings are some and equal, then the TermIdx should've
63            // been equal?
64            assert!(smaller_meaning.is_none() || smaller_meaning != larger_meaning);
65        }
66
67        if let Some(input_replace) = self.input_replace(table, smaller, larger)? {
68            self.new_generalised(Some(smaller), input_replace).map(Some)
69        } else {
70            Ok(None)
71        }
72    }
73
74    /// Given a `larger -> g(x)` and `smaller -> x`, return a `g($0)` term.
75    pub fn input_replace(
76        &mut self,
77        table: &Terms,
78        smaller: TermIdx,
79        larger: TermIdx,
80    ) -> Result<Option<SynthIdx>> {
81        assert_ne!(larger, smaller);
82        let smaller_meaning = table.meaning(smaller);
83        let mut cache = FxHashMap::default();
84        let (found_smaller, replaced) =
85            self.input_replace_inner(table, smaller, smaller_meaning, larger, &mut cache)?;
86        // If we didn't find the smaller term in the bigger one, then it's
87        // probably not a repeating pattern getting larger (so return `None`).
88        Ok(found_smaller.then_some(replaced))
89    }
90
91    fn input_replace_cached(
92        &mut self,
93        table: &Terms,
94        smaller: TermIdx,
95        smaller_meaning: Option<&Meaning>,
96        larger: TermIdx,
97        cache: &mut FxHashMap<TermIdx, (bool, SynthIdx)>,
98    ) -> Result<(bool, SynthIdx)> {
99        if let Some((found_smaller, replaced)) = cache.get(&larger) {
100            return Ok((*found_smaller, *replaced));
101        }
102        let (found_smaller, replaced) =
103            self.input_replace_inner(table, smaller, smaller_meaning, larger, cache)?;
104        cache.insert(larger, (found_smaller, replaced));
105        Ok((found_smaller, replaced))
106    }
107
108    fn input_replace_inner(
109        &mut self,
110        table: &Terms,
111        smaller: TermIdx,
112        smaller_meaning: Option<&Meaning>,
113        larger: TermIdx,
114        cache: &mut FxHashMap<TermIdx, (bool, SynthIdx)>,
115    ) -> Result<(bool, SynthIdx)> {
116        if larger == smaller {
117            return Ok((true, self.new_input(None)?));
118        }
119
120        match (smaller_meaning, table.meaning(larger)) {
121            (Some(Meaning::Arith(s)), Some(Meaning::Arith(l))) => {
122                let delta = &***l - &***s;
123                let meaning = Meaning::Arith(Box::new(delta.into()));
124                let term = self.new_input(Some(meaning))?;
125                return Ok((true, term));
126            }
127            (Some(_), Some(_)) => {
128                return Ok((false, larger.into()));
129            }
130            _ => (),
131        }
132
133        let larger_term = &table[larger];
134        let mut found_smaller = false;
135        let child_ids = larger_term
136            .child_ids
137            .iter()
138            .map(|&c| {
139                let (found, replaced) =
140                    self.input_replace_cached(table, smaller, smaller_meaning, c, cache)?;
141                found_smaller |= found;
142                Ok(replaced)
143            })
144            .collect::<Result<_>>()?;
145        let replaced = if found_smaller {
146            self.new_synthetic(larger_term.kind(), child_ids)?
147        } else {
148            larger.into()
149        };
150        Ok((found_smaller, replaced))
151    }
152
153    pub fn generalise_second(
154        &mut self,
155        table: &Terms,
156        smaller: TermIdx,
157        larger: TermIdx,
158        gen: SynthIdx,
159    ) -> Result<Option<SynthIdx>> {
160        let gen_term = self.index(table, gen);
161        match gen_term {
162            AnyTerm::Parsed(_) => {
163                let gen_tidx = self.as_tidx(gen).unwrap();
164                Ok((gen_tidx == larger).then_some(gen))
165            }
166            AnyTerm::Synth(synth_term) => match synth_term.kind {
167                SynthTermKind::Parsed(term_kind) => {
168                    let smaller_term = &table[smaller];
169                    let larger_term = &table[larger];
170                    let smaller_meaning = table.meaning(smaller);
171                    let larger_meaning = table.meaning(larger);
172                    if !(smaller_meaning.is_none()
173                        && larger_meaning.is_none()
174                        && threeway_eq(&term_kind, &smaller_term.kind(), &larger_term.kind())
175                        && threeway_eq(
176                            gen_term.child_ids().len(),
177                            smaller_term.child_ids.len(),
178                            larger_term.child_ids.len(),
179                        ))
180                    {
181                        return Ok(None);
182                    }
183                    let mut child_ids = Vec::new();
184                    child_ids.try_reserve_exact(smaller_term.child_ids.len())?;
185                    for i in 0..smaller_term.child_ids.len() {
186                        let sc = smaller_term.child_ids[i];
187                        let lc = larger_term.child_ids[i];
188                        let gen = self.index(table, gen).child_ids()[i];
189                        let Some(new) = self.generalise_second(table, sc, lc, gen)? else {
190                            return Ok(None);
191                        };
192                        if self.is_generalised_by::<false>(table, smaller, larger, new) {
193                            // this child generalises up to current term, discard
194                            // current term and push this generalisation up
195                            return Ok(Some(new));
196                        }
197                        child_ids.push(new);
198                    }
199                    self.new_synthetic(term_kind, child_ids.into()).map(Some)
200                }
201                SynthTermKind::Generalised(Some(orig)) => {
202                    if smaller == larger {
203                        // We have e.g. `x` -> `f(x)` -> `f(x)`
204                        return Ok(None);
205                    }
206                    let old_input = synth_term.child_ids[0];
207                    let Some(curr_input) = self.input_replace(table, smaller, larger)? else {
208                        return Ok(None);
209                    };
210                    if self.is_valid_input(table, orig, old_input, curr_input) {
211                        let new = self.new_generalised(None, curr_input)?;
212                        Ok(Some(new))
213                    } else {
214                        Ok(None)
215                    }
216                }
217                o => unreachable!("{o:?}"),
218            },
219            _ => unreachable!(),
220        }
221    }
222
223    fn is_valid_input(
224        &self,
225        table: &Terms,
226        orig: TermIdx,
227        old_input: SynthIdx,
228        curr_input: SynthIdx,
229    ) -> bool {
230        if old_input == curr_input {
231            return true;
232        }
233        let None = self.as_tidx(old_input) else {
234            return false;
235        };
236        let old_input = self.index(table, old_input);
237        if let Some(curr_input) = self.as_tidx(curr_input) {
238            let AnyTerm::Synth(SynthTerm {
239                kind: SynthTermKind::Input(None),
240                ..
241            }) = old_input
242            else {
243                return false;
244            };
245            return curr_input == orig;
246        }
247        // Both are true synth, but do not match (recurse)
248        let curr_input = self.index(table, curr_input);
249        if old_input.kind() != curr_input.kind()
250            || old_input.child_ids().len() != curr_input.child_ids().len()
251        {
252            return false;
253        }
254        for i in 0..old_input.child_ids().len() {
255            let old = old_input.child_ids()[i];
256            let curr = curr_input.child_ids()[i];
257            if !self.is_valid_input(table, orig, old, curr) {
258                return false;
259            }
260        }
261        true
262    }
263
264    pub fn is_generalised_by<const SECOND: bool>(
265        &mut self,
266        table: &Terms,
267        smaller: TermIdx,
268        larger: TermIdx,
269        gen: SynthIdx,
270    ) -> bool {
271        let gen_term = self.index(table, gen);
272        match gen_term {
273            AnyTerm::Parsed(_) => {
274                let gen = self.as_tidx(gen).unwrap();
275                gen == larger
276            }
277            AnyTerm::Synth(synth_term) => match &synth_term.kind {
278                SynthTermKind::Parsed(term_kind) => {
279                    let smaller_term = &table[smaller];
280                    let larger_term = &table[larger];
281                    let smaller_meaning = table.meaning(smaller);
282                    let larger_meaning = table.meaning(larger);
283                    if !(smaller_meaning.is_none()
284                        && larger_meaning.is_none()
285                        && threeway_eq(term_kind, &smaller_term.kind(), &larger_term.kind())
286                        && threeway_eq(
287                            gen_term.child_ids().len(),
288                            smaller_term.child_ids.len(),
289                            larger_term.child_ids.len(),
290                        ))
291                    {
292                        return false;
293                    }
294                    for i in 0..smaller_term.child_ids.len() {
295                        let smaller = smaller_term.child_ids[i];
296                        let larger = larger_term.child_ids[i];
297                        let gen = self.index(table, gen).child_ids()[i];
298                        if !self.is_generalised_by::<SECOND>(table, smaller, larger, gen) {
299                            return false;
300                        }
301                    }
302                    true
303                }
304                SynthTermKind::Generalised(orig) => {
305                    debug_assert_eq!(orig.is_some(), SECOND);
306                    if smaller == larger {
307                        return false;
308                    }
309                    let input = synth_term.child_ids[0];
310                    let actual_input = self.input_replace(table, smaller, larger).ok().flatten();
311                    actual_input == Some(input)
312                }
313                _ => unreachable!(),
314            },
315            _ => unreachable!(),
316        }
317    }
318}
319
/// Returns `true` when all three values compare equal.
///
/// `a` is the pivot. It is compared with `b` first and, only if they agree,
/// with `c`.
fn threeway_eq<T: Eq>(a: T, b: T, c: T) -> bool {
    if a == b {
        a == c
    } else {
        false
    }
}