1use crate::*;
2
3use std::hash::{Hash, Hasher};
4
/// An equation `l = r` between two `AppliedId`s (an id together with its slot map `m`).
///
/// NOTE(review): `Debug` is not derived here, yet `ProvenEqRaw` derives `Debug` and
/// `assert_match_equation` formats equations with `{:?}` — presumably a manual `Debug`
/// impl lives elsewhere in the crate; confirm.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct Equation {
    /// Left-hand side.
    pub l: AppliedId,
    /// Right-hand side.
    pub r: AppliedId,
}
10
/// Proof step: the equation is accepted as given, without further checking.
/// Carries an optional string (presumably a human-readable justification — confirm).
#[derive(Clone, Debug)]
pub struct ExplicitProof(pub Option<String>);
/// Proof step: `a = a`; both sides of the equation must be equal.
#[derive(Clone, Debug)]
pub struct ReflexivityProof;
/// Proof step: derives `r = l` from the wrapped proof of `l = r` (up to slot renaming).
#[derive(Clone, Debug)]
pub struct SymmetryProof(pub ProvenEq);
/// Proof step: derives `a = c` from the wrapped proofs of `a = b` and `b = c`,
/// whose slot names are reconciled by renaming.
#[derive(Clone, Debug)]
pub struct TransitivityProof(pub ProvenEq, pub ProvenEq);
/// Proof step: two e-nodes that agree once their children are nullified are equal if
/// their children are pairwise equal; holds one proof per child position, in order.
#[derive(Clone, Debug)]
pub struct CongruenceProof(pub Vec<ProvenEq>);
21
/// The inference step that justifies a `ProvenEqRaw`.
///
/// Each variant wraps the `*Proof` step type of the same name; the premises of that
/// step (if any) are themselves `ProvenEq`s, so a `Proof` is the root of a proof tree.
#[derive(Debug, Clone)]
pub enum Proof {
    Explicit(ExplicitProof),
    Reflexivity(ReflexivityProof),
    Symmetry(SymmetryProof),
    Transitivity(TransitivityProof),
    Congruence(CongruenceProof),
}
34
/// Shared, reference-counted handle to an equation together with the proof step that
/// established it.
pub type ProvenEq = Arc<ProvenEqRaw>;
36
/// An equation together with the proof step that established it.
///
/// The fields are private. In this file, values are built only by the `check` methods
/// of the proof-step types, after they have validated their step.
#[derive(Debug, Clone)]
pub struct ProvenEqRaw {
    /// The equation that has been proven.
    eq: Equation,
    /// The last inference step of the proof of `eq`.
    proof: Proof,
}
44
45impl ProvenEqRaw {
46 pub fn equ(&self) -> Equation {
47 (**self).clone()
48 }
49
50 pub(crate) fn check<L: Language, N: Analysis<L>>(&self, eg: &EGraph<L, N>) {
51 let Equation { l, r } = self.equ();
52 eg.check_syn_applied_id(&l);
53 eg.check_syn_applied_id(&r);
54 }
55}
56
57impl PartialEq for ProvenEqRaw {
58 fn eq(&self, other: &Self) -> bool {
60 self.eq == other.eq
61 }
62}
63
64impl Eq for ProvenEqRaw {}
65
66impl Hash for ProvenEqRaw {
67 fn hash<H: Hasher>(&self, hasher: &mut H) {
68 self.eq.hash(hasher);
70 }
71}
72
73impl ExplicitProof {
74 pub(crate) fn check(&self, eq: &Equation, reg: &ProofRegistry) -> ProvenEq {
75 let eq = eq.clone();
76 let proof = Proof::Explicit(self.clone());
77 reg.insert(Arc::new(ProvenEqRaw { eq, proof }))
78 }
79}
80
81impl ReflexivityProof {
82 pub(crate) fn check(&self, eq: &Equation, reg: &ProofRegistry) -> ProvenEq {
83 assert_eq!(eq.l, eq.r);
84
85 let eq = eq.clone();
86 let proof = Proof::Reflexivity(self.clone());
87 reg.insert(Arc::new(ProvenEqRaw { eq, proof }))
88 }
89}
90
91impl SymmetryProof {
92 pub(crate) fn check(&self, eq: &Equation, reg: &ProofRegistry) -> ProvenEq {
93 let SymmetryProof(x) = self;
94
95 let flipped = Equation {
96 l: x.r.clone(),
97 r: x.l.clone(),
98 };
99 assert_match_equation(eq, &flipped);
100
101 let eq = eq.clone();
102 let proof = Proof::Symmetry(self.clone());
103 reg.insert(Arc::new(ProvenEqRaw { eq, proof }))
104 }
105}
106
impl TransitivityProof {
    /// Registers `eq` as proven by chaining the premises `eq1: a = b` and `eq2: b' = c`,
    /// whose slot names may differ, and returns the registered proof.
    ///
    /// Two slot renamings are built: `theta1` for `eq1` and `theta2` for `eq2`. After
    /// renaming, the following must hold (asserted at the end):
    /// - renamed `eq1.l` == `eq.l`,
    /// - renamed `eq2.r` == `eq.r`,
    /// - renamed `eq1.r` == renamed `eq2.l` (the shared middle term).
    ///
    /// Slots that none of these constraints pin down are mapped to `Slot::fresh()`.
    ///
    /// # Panics
    /// If the constraints conflict (a `try_union(..).unwrap()` fails), or if any of the
    /// three equalities above does not hold after renaming.
    pub(crate) fn check(&self, eq: &Equation, reg: &ProofRegistry) -> ProvenEq {
        let TransitivityProof(eq1, eq2) = self;

        // Constraint from the left end: eq1.l renamed by theta1 should equal eq.l.
        // The result may be partial; missing slots are filled with fresh ones below.
        let mut theta1 = {
            eq1.l.m.inverse().compose_partial(&eq.l.m)
        };
        // Constraint from the right end: eq2.r renamed by theta2 should equal eq.r.
        let mut theta2 = {
            eq2.r.m.inverse().compose_partial(&eq.r.m)
        };

        // Extend theta1 with what theta2 forces through the middle term: eq1's slots are
        // taken through eq1.r and eq2.l, then through theta2.
        // The closures take theta1/theta2 as parameters instead of capturing them
        // (presumably so they can be reused while the maps are mutated in between).
        let recompute_theta1 = |theta1: &mut SlotMap, theta2: &SlotMap| {
            *theta1 = theta1
                .try_union(
                    &eq1.r
                        .m
                        .inverse()
                        .compose_partial(&eq2.l.m)
                        .compose_partial(theta2),
                )
                .unwrap();
        };

        // Mirror image: extend theta2 with what theta1 forces through the middle term.
        let recompute_theta2 = |theta1: &SlotMap, theta2: &mut SlotMap| {
            *theta2 = theta2
                .try_union(
                    &eq2.l
                        .m
                        .inverse()
                        .compose_partial(&eq1.r.m)
                        .compose_partial(theta1),
                )
                .unwrap();
        };

        // Propagate the end constraints across the middle term in both directions.
        recompute_theta1(&mut theta1, &theta2);
        recompute_theta2(&theta1, &mut theta2);

        // Give every still-unmapped slot of eq1 a fresh name.
        for x in eq1.slots() {
            if !theta1.contains_key(x) {
                theta1.insert(x, Slot::fresh());
            }
        }
        // Push those new names across the middle term into theta2, then give any
        // remaining unmapped slot of eq2 a fresh name.
        recompute_theta2(&theta1, &mut theta2);
        for x in eq2.slots() {
            if !theta2.contains_key(x) {
                theta2.insert(x, Slot::fresh());
            }
        }

        let renamed_eq1 = eq1.apply_slotmap(&theta1);
        let renamed_eq2 = eq2.apply_slotmap(&theta2);

        // The renamed chain must be `eq.l = middle = eq.r`.
        assert_eq!(renamed_eq1.l, eq.l);
        assert_eq!(renamed_eq2.r, eq.r);
        assert_eq!(renamed_eq1.r, renamed_eq2.l);

        let eq = eq.clone();
        let proof = Proof::Transitivity(self.clone());
        reg.insert(Arc::new(ProvenEqRaw { eq, proof }))
    }
}
177
178pub(crate) fn alpha_normalize<L: Language>(n: &L) -> L {
180 let (sh, bij) = n.weak_shape();
181 if CHECKS {
182 let all_slots: SmallHashSet<_> = sh.all_slot_occurrences().into_iter().collect();
183 assert!(&bij.values().is_disjoint(&all_slots));
184 }
185 sh.apply_slotmap(&bij)
186}
187
188impl CongruenceProof {
189 pub fn check<L: Language, N: Analysis<L>>(&self, eq: &Equation, eg: &EGraph<L, N>) -> ProvenEq {
190 let CongruenceProof(child_proofs) = self;
191
192 let l = alpha_normalize(&eg.get_syn_node(&eq.l));
193 let r = alpha_normalize(&eg.get_syn_node(&eq.r));
194
195 let null_l = nullify_app_ids(&l);
196 let null_r = nullify_app_ids(&r);
197 assert_eq!(null_l, null_r);
198
199 let l_v = l.applied_id_occurrences();
200 let r_v = r.applied_id_occurrences();
201
202 assert_eq!(l_v.len(), child_proofs.len());
203 assert_eq!(r_v.len(), child_proofs.len());
204
205 let l_v = l_v.into_iter().cloned();
206 let r_v = r_v.into_iter().cloned();
207
208 let c_v = child_proofs.into_iter();
209 for ((ll, rr), prf) in l_v.zip(r_v).zip(c_v) {
210 let eq1 = &Equation { l: ll, r: rr };
211 let eq2 = prf.deref();
212 assert_match_equation(eq1, eq2);
213 }
214
215 let eq = eq.clone();
216 let proof = Proof::Congruence(self.clone());
217 eg.proof_registry
218 .insert(Arc::new(ProvenEqRaw { eq, proof }))
219 }
220}
221
222impl Equation {
223 pub fn slots(&self) -> SmallHashSet<Slot> {
224 &self.l.slots() | &self.r.slots()
225 }
226
227 #[track_caller]
228 pub fn apply_slotmap(&self, m: &SlotMap) -> Self {
229 Equation {
230 l: self.l.apply_slotmap(&m),
231 r: self.r.apply_slotmap(&m),
232 }
233 }
234
235 pub fn apply_slotmap_fresh(&self, m: &SlotMap) -> Self {
236 let mut m = m.clone();
237 for s in &self.l.slots() | &self.r.slots() {
238 if !m.contains_key(s) {
239 m.insert(s, Slot::fresh());
240 }
241 }
242 Equation {
243 l: self.l.apply_slotmap(&m),
244 r: self.r.apply_slotmap(&m),
245 }
246 }
247}
248
249impl Deref for ProvenEqRaw {
250 type Target = Equation;
251
252 fn deref(&self) -> &Equation {
253 &self.eq
254 }
255}
256
257impl ProvenEqRaw {
258 pub fn proof(&self) -> &Proof {
259 &self.proof
260 }
261}
262
263#[track_caller]
265pub(crate) fn match_app_id(a: &AppliedId, b: &AppliedId) -> SlotMap {
266 if CHECKS {
267 assert_eq!(a.id, b.id);
268 assert_eq!(
269 a.m.keys(),
270 b.m.keys(),
271 "match_app_id failed: different set of arguments"
272 );
273 }
274
275 let theta = a.m.inverse().compose(&b.m);
279
280 if CHECKS {
281 assert_eq!(&a.apply_slotmap(&theta), b);
282 }
283
284 theta
285}
286
287pub(crate) fn assert_match_equation(a: &Equation, b: &Equation) -> SlotMap {
289 let theta_l = match_app_id(&a.l, &b.l);
290 let theta_r = match_app_id(&a.r, &b.r);
291
292 let theta = theta_l.try_union(&theta_r).unwrap_or_else(|| panic!("trying to union {theta_l:?} with {theta_r:?} while trying to match '{a:?}' against '{b:?}'"));
293
294 if CHECKS {
295 assert!(theta.is_bijection(), "trying to unify {theta_l:?} with {theta_r:?}, in assert_match_equation(\n {a:?},\n {b:?}\n)");
296
297 assert_eq!(&a.apply_slotmap(&theta), b);
298 }
299
300 theta
301}
302
303pub(crate) fn assert_proves_equation(peq: &ProvenEq, eq: &Equation) {
304 let mut e: Equation = (***peq).clone();
305
306 for s in e.l.m.keys() {
307 if !eq.l.m.contains_key(s) {
308 e.l.m.remove(s);
309 }
310 }
311
312 for s in e.r.m.keys() {
313 if !eq.r.m.contains_key(s) {
314 e.r.m.remove(s);
315 }
316 }
317
318 assert_match_equation(&e, eq);
319}