1use sim_kernel::{Cx, Diagnostic, Expr, Result, Value, shape_is_subshape_of};
5
6use crate::{
7 ExactExprShape, Shape, TableExtraPolicy, TableShape,
8 compare::normal::{ShapeNormalForm, ShapeNormalKind, normalize_shape},
9};
10
11#[derive(Clone, Debug)]
17pub struct ShapeRelation {
18 pub left: ShapeNormalForm,
20 pub right: ShapeNormalForm,
22 pub kind: ShapeRelationKind,
24 pub proven: bool,
26 pub witnesses: Vec<ShapeWitness>,
28 pub diagnostics: Vec<Diagnostic>,
30}
31
32#[derive(Clone, Copy, Debug, PartialEq, Eq)]
34pub enum ShapeRelationKind {
35 Equal,
37 LeftSubshape,
39 RightSubshape,
41 Overlap,
43 Disjoint,
45 Unknown,
47}
48
49#[derive(Clone, Debug, PartialEq, Eq)]
51pub struct ShapeWitness {
52 pub label: String,
54 pub accepted_left: bool,
56 pub accepted_right: bool,
58 pub note: String,
60}
61
62#[derive(Clone, Debug)]
64pub enum ShapeProbe {
65 Value {
67 label: String,
69 value: Value,
71 },
72 Expr {
74 label: String,
76 expr: Expr,
78 },
79}
80
81pub fn relate_shapes(
98 cx: &mut Cx,
99 left: &dyn Shape,
100 right: &dyn Shape,
101 probes: &[ShapeProbe],
102) -> Result<ShapeRelation> {
103 let left_normal = normalize_shape(cx, left)?;
104 let right_normal = normalize_shape(cx, right)?;
105
106 let left_to_right = shape_is_subshape_of(cx, left, right)?;
107 let right_to_left = shape_is_subshape_of(cx, right, left)?;
108 if left_to_right && right_to_left {
109 return Ok(relation(
110 left_normal,
111 right_normal,
112 ShapeRelationKind::Equal,
113 true,
114 Vec::new(),
115 Vec::new(),
116 ));
117 }
118 if left_to_right {
119 return Ok(relation(
120 left_normal,
121 right_normal,
122 ShapeRelationKind::LeftSubshape,
123 true,
124 Vec::new(),
125 Vec::new(),
126 ));
127 }
128 if right_to_left {
129 return Ok(relation(
130 left_normal,
131 right_normal,
132 ShapeRelationKind::RightSubshape,
133 true,
134 Vec::new(),
135 Vec::new(),
136 ));
137 }
138
139 if let Some(message) = static_disjoint(cx, left, right, &left_normal, &right_normal)? {
140 return Ok(relation(
141 left_normal,
142 right_normal,
143 ShapeRelationKind::Disjoint,
144 true,
145 Vec::new(),
146 vec![Diagnostic::info(message)],
147 ));
148 }
149
150 let witnesses = probes
151 .iter()
152 .map(|probe| run_probe(cx, left, right, probe))
153 .collect::<Result<Vec<_>>>()?;
154 if witnesses
155 .iter()
156 .any(|witness| witness.accepted_left && witness.accepted_right)
157 {
158 return Ok(relation(
159 left_normal,
160 right_normal,
161 ShapeRelationKind::Overlap,
162 false,
163 witnesses,
164 Vec::new(),
165 ));
166 }
167
168 Ok(relation(
169 left_normal,
170 right_normal,
171 ShapeRelationKind::Unknown,
172 false,
173 witnesses,
174 Vec::new(),
175 ))
176}
177
178fn relation(
179 left: ShapeNormalForm,
180 right: ShapeNormalForm,
181 kind: ShapeRelationKind,
182 proven: bool,
183 witnesses: Vec<ShapeWitness>,
184 diagnostics: Vec<Diagnostic>,
185) -> ShapeRelation {
186 ShapeRelation {
187 left,
188 right,
189 kind,
190 proven,
191 witnesses,
192 diagnostics,
193 }
194}
195
196fn run_probe(
197 cx: &mut Cx,
198 left: &dyn Shape,
199 right: &dyn Shape,
200 probe: &ShapeProbe,
201) -> Result<ShapeWitness> {
202 let (label, accepted_left, accepted_right) = match probe {
203 ShapeProbe::Value { label, value } => (
204 label.clone(),
205 left.check_value(cx, value.clone())?.accepted,
206 right.check_value(cx, value.clone())?.accepted,
207 ),
208 ShapeProbe::Expr { label, expr } => (
209 label.clone(),
210 left.check_expr(cx, expr)?.accepted,
211 right.check_expr(cx, expr)?.accepted,
212 ),
213 };
214 let note = match (accepted_left, accepted_right) {
215 (true, true) => "accepted by both",
216 (true, false) => "accepted by left only",
217 (false, true) => "accepted by right only",
218 (false, false) => "accepted by neither",
219 }
220 .to_owned();
221 Ok(ShapeWitness {
222 label,
223 accepted_left,
224 accepted_right,
225 note,
226 })
227}
228
229fn static_disjoint(
230 cx: &mut Cx,
231 left: &dyn Shape,
232 right: &dyn Shape,
233 left_normal: &ShapeNormalForm,
234 right_normal: &ShapeNormalForm,
235) -> Result<Option<String>> {
236 if not_of(left_normal, right_normal) || not_of(right_normal, left_normal) {
237 return Ok(Some(
238 "shape-compare: negation excludes inner shape".to_owned(),
239 ));
240 }
241 if exact_exprs_differ(left, right) {
242 return Ok(Some(
243 "shape-compare: exact expression shapes differ".to_owned(),
244 ));
245 }
246 if fixed_list_lengths_differ(left_normal, right_normal) {
247 return Ok(Some("shape-compare: fixed list lengths differ".to_owned()));
248 }
249 if closed_table_field_disjoint(cx, left, right)? {
250 return Ok(Some(
251 "shape-compare: closed tables require disjoint shared field".to_owned(),
252 ));
253 }
254 Ok(None)
255}
256
257fn not_of(not_candidate: &ShapeNormalForm, other: &ShapeNormalForm) -> bool {
258 matches!(¬_candidate.kind, ShapeNormalKind::Not(inner) if inner.as_ref() == other)
259}
260
261fn exact_exprs_differ(left: &dyn Shape, right: &dyn Shape) -> bool {
262 let Some(left) = left.as_any().downcast_ref::<ExactExprShape>() else {
263 return false;
264 };
265 let Some(right) = right.as_any().downcast_ref::<ExactExprShape>() else {
266 return false;
267 };
268 !left.expected().canonical_eq(right.expected())
269}
270
271fn fixed_list_lengths_differ(left: &ShapeNormalForm, right: &ShapeNormalForm) -> bool {
272 match (&left.kind, &right.kind) {
273 (
274 ShapeNormalKind::List {
275 items: left,
276 rest: None,
277 },
278 ShapeNormalKind::List {
279 items: right,
280 rest: None,
281 },
282 ) => left.len() != right.len(),
283 _ => false,
284 }
285}
286
287fn closed_table_field_disjoint(cx: &mut Cx, left: &dyn Shape, right: &dyn Shape) -> Result<bool> {
288 let Some(left) = left.as_any().downcast_ref::<TableShape>() else {
289 return Ok(false);
290 };
291 let Some(right) = right.as_any().downcast_ref::<TableShape>() else {
292 return Ok(false);
293 };
294 if !matches!(left.extra(), TableExtraPolicy::Reject)
295 || !matches!(right.extra(), TableExtraPolicy::Reject)
296 {
297 return Ok(false);
298 }
299
300 for left_field in left.fields().iter().filter(|field| field.required) {
301 let Some(right_field) = right
302 .fields()
303 .iter()
304 .find(|field| field.required && field.key == left_field.key)
305 else {
306 continue;
307 };
308 let relation = relate_shapes(
309 cx,
310 left_field.shape.as_ref(),
311 right_field.shape.as_ref(),
312 &[],
313 )?;
314 if relation.proven && relation.kind == ShapeRelationKind::Disjoint {
315 return Ok(true);
316 }
317 }
318 Ok(false)
319}