Skip to main content

sim_shape/compare/
relation.rs

1//! Shape relation analysis: classify how two shapes relate (subshape, equal,
2//! disjoint, overlapping) via probes over their normalized forms.
3
4use sim_kernel::{Cx, Diagnostic, Expr, Result, Value, shape_is_subshape_of};
5
6use crate::{
7    ExactExprShape, Shape, TableExtraPolicy, TableShape,
8    compare::normal::{ShapeNormalForm, ShapeNormalKind, normalize_shape},
9};
10
11/// Conservative relation report between two shapes.
12///
13/// A relation is `proven` only when the runtime can establish it through
14/// subshape checks or explicit conservative rules. Probe-only overlap remains
15/// useful evidence but is not a proof.
16#[derive(Clone, Debug)]
17pub struct ShapeRelation {
18    /// Normal form for the left input shape.
19    pub left: ShapeNormalForm,
20    /// Normal form for the right input shape.
21    pub right: ShapeNormalForm,
22    /// Best relation kind the runtime could determine.
23    pub kind: ShapeRelationKind,
24    /// Whether the relation kind is proven rather than probe-derived.
25    pub proven: bool,
26    /// Probe results collected while comparing the shapes.
27    pub witnesses: Vec<ShapeWitness>,
28    /// Diagnostic notes explaining conservative proof rules.
29    pub diagnostics: Vec<Diagnostic>,
30}
31
32/// Relation categories reported by [`ShapeRelation`].
33#[derive(Clone, Copy, Debug, PartialEq, Eq)]
34pub enum ShapeRelationKind {
35    /// Both shapes imply each other.
36    Equal,
37    /// The left shape is known to be a subshape of the right shape.
38    LeftSubshape,
39    /// The right shape is known to be a subshape of the left shape.
40    RightSubshape,
41    /// At least one probe is accepted by both shapes.
42    Overlap,
43    /// The shapes are conservatively known to have no common values.
44    Disjoint,
45    /// The runtime could not prove a stronger relation.
46    Unknown,
47}
48
49/// Result of checking one probe against both compared shapes.
50#[derive(Clone, Debug, PartialEq, Eq)]
51pub struct ShapeWitness {
52    /// Stable label supplied by the caller.
53    pub label: String,
54    /// Whether the left shape accepted the probe.
55    pub accepted_left: bool,
56    /// Whether the right shape accepted the probe.
57    pub accepted_right: bool,
58    /// Short explanation such as `accepted by both`.
59    pub note: String,
60}
61
62/// Example value or expression used to gather relation evidence.
63#[derive(Clone, Debug)]
64pub enum ShapeProbe {
65    /// Runtime value probe.
66    Value {
67        /// Stable label recorded on the resulting witness.
68        label: String,
69        /// Value checked against both shapes.
70        value: Value,
71    },
72    /// Expression probe.
73    Expr {
74        /// Stable label recorded on the resulting witness.
75        label: String,
76        /// Expression checked against both shapes.
77        expr: Expr,
78    },
79}
80
81/// Compare two shapes with conservative proof rules plus optional probes.
82///
83/// ```rust
84/// # use std::sync::Arc;
85/// # use sim_kernel::{Cx, DefaultFactory, Expr, NoopEvalPolicy};
86/// # use sim_shape::{
87/// #     ExactExprShape, ExprKind, ExprKindShape, ShapeRelationKind, relate_shapes,
88/// # };
89/// # let mut cx = Cx::new(Arc::new(NoopEvalPolicy), Arc::new(DefaultFactory));
90/// let exact_true = ExactExprShape::new(Expr::Bool(true));
91/// let bool_expr = ExprKindShape::new(ExprKind::Bool);
92/// let relation = relate_shapes(&mut cx, &exact_true, &bool_expr, &[]).unwrap();
93///
94/// assert_eq!(relation.kind, ShapeRelationKind::LeftSubshape);
95/// assert!(relation.proven);
96/// ```
97pub fn relate_shapes(
98    cx: &mut Cx,
99    left: &dyn Shape,
100    right: &dyn Shape,
101    probes: &[ShapeProbe],
102) -> Result<ShapeRelation> {
103    let left_normal = normalize_shape(cx, left)?;
104    let right_normal = normalize_shape(cx, right)?;
105
106    let left_to_right = shape_is_subshape_of(cx, left, right)?;
107    let right_to_left = shape_is_subshape_of(cx, right, left)?;
108    if left_to_right && right_to_left {
109        return Ok(relation(
110            left_normal,
111            right_normal,
112            ShapeRelationKind::Equal,
113            true,
114            Vec::new(),
115            Vec::new(),
116        ));
117    }
118    if left_to_right {
119        return Ok(relation(
120            left_normal,
121            right_normal,
122            ShapeRelationKind::LeftSubshape,
123            true,
124            Vec::new(),
125            Vec::new(),
126        ));
127    }
128    if right_to_left {
129        return Ok(relation(
130            left_normal,
131            right_normal,
132            ShapeRelationKind::RightSubshape,
133            true,
134            Vec::new(),
135            Vec::new(),
136        ));
137    }
138
139    if let Some(message) = static_disjoint(cx, left, right, &left_normal, &right_normal)? {
140        return Ok(relation(
141            left_normal,
142            right_normal,
143            ShapeRelationKind::Disjoint,
144            true,
145            Vec::new(),
146            vec![Diagnostic::info(message)],
147        ));
148    }
149
150    let witnesses = probes
151        .iter()
152        .map(|probe| run_probe(cx, left, right, probe))
153        .collect::<Result<Vec<_>>>()?;
154    if witnesses
155        .iter()
156        .any(|witness| witness.accepted_left && witness.accepted_right)
157    {
158        return Ok(relation(
159            left_normal,
160            right_normal,
161            ShapeRelationKind::Overlap,
162            false,
163            witnesses,
164            Vec::new(),
165        ));
166    }
167
168    Ok(relation(
169        left_normal,
170        right_normal,
171        ShapeRelationKind::Unknown,
172        false,
173        witnesses,
174        Vec::new(),
175    ))
176}
177
178fn relation(
179    left: ShapeNormalForm,
180    right: ShapeNormalForm,
181    kind: ShapeRelationKind,
182    proven: bool,
183    witnesses: Vec<ShapeWitness>,
184    diagnostics: Vec<Diagnostic>,
185) -> ShapeRelation {
186    ShapeRelation {
187        left,
188        right,
189        kind,
190        proven,
191        witnesses,
192        diagnostics,
193    }
194}
195
196fn run_probe(
197    cx: &mut Cx,
198    left: &dyn Shape,
199    right: &dyn Shape,
200    probe: &ShapeProbe,
201) -> Result<ShapeWitness> {
202    let (label, accepted_left, accepted_right) = match probe {
203        ShapeProbe::Value { label, value } => (
204            label.clone(),
205            left.check_value(cx, value.clone())?.accepted,
206            right.check_value(cx, value.clone())?.accepted,
207        ),
208        ShapeProbe::Expr { label, expr } => (
209            label.clone(),
210            left.check_expr(cx, expr)?.accepted,
211            right.check_expr(cx, expr)?.accepted,
212        ),
213    };
214    let note = match (accepted_left, accepted_right) {
215        (true, true) => "accepted by both",
216        (true, false) => "accepted by left only",
217        (false, true) => "accepted by right only",
218        (false, false) => "accepted by neither",
219    }
220    .to_owned();
221    Ok(ShapeWitness {
222        label,
223        accepted_left,
224        accepted_right,
225        note,
226    })
227}
228
229fn static_disjoint(
230    cx: &mut Cx,
231    left: &dyn Shape,
232    right: &dyn Shape,
233    left_normal: &ShapeNormalForm,
234    right_normal: &ShapeNormalForm,
235) -> Result<Option<String>> {
236    if not_of(left_normal, right_normal) || not_of(right_normal, left_normal) {
237        return Ok(Some(
238            "shape-compare: negation excludes inner shape".to_owned(),
239        ));
240    }
241    if exact_exprs_differ(left, right) {
242        return Ok(Some(
243            "shape-compare: exact expression shapes differ".to_owned(),
244        ));
245    }
246    if fixed_list_lengths_differ(left_normal, right_normal) {
247        return Ok(Some("shape-compare: fixed list lengths differ".to_owned()));
248    }
249    if closed_table_field_disjoint(cx, left, right)? {
250        return Ok(Some(
251            "shape-compare: closed tables require disjoint shared field".to_owned(),
252        ));
253    }
254    Ok(None)
255}
256
257fn not_of(not_candidate: &ShapeNormalForm, other: &ShapeNormalForm) -> bool {
258    matches!(&not_candidate.kind, ShapeNormalKind::Not(inner) if inner.as_ref() == other)
259}
260
261fn exact_exprs_differ(left: &dyn Shape, right: &dyn Shape) -> bool {
262    let Some(left) = left.as_any().downcast_ref::<ExactExprShape>() else {
263        return false;
264    };
265    let Some(right) = right.as_any().downcast_ref::<ExactExprShape>() else {
266        return false;
267    };
268    !left.expected().canonical_eq(right.expected())
269}
270
271fn fixed_list_lengths_differ(left: &ShapeNormalForm, right: &ShapeNormalForm) -> bool {
272    match (&left.kind, &right.kind) {
273        (
274            ShapeNormalKind::List {
275                items: left,
276                rest: None,
277            },
278            ShapeNormalKind::List {
279                items: right,
280                rest: None,
281            },
282        ) => left.len() != right.len(),
283        _ => false,
284    }
285}
286
287fn closed_table_field_disjoint(cx: &mut Cx, left: &dyn Shape, right: &dyn Shape) -> Result<bool> {
288    let Some(left) = left.as_any().downcast_ref::<TableShape>() else {
289        return Ok(false);
290    };
291    let Some(right) = right.as_any().downcast_ref::<TableShape>() else {
292        return Ok(false);
293    };
294    if !matches!(left.extra(), TableExtraPolicy::Reject)
295        || !matches!(right.extra(), TableExtraPolicy::Reject)
296    {
297        return Ok(false);
298    }
299
300    for left_field in left.fields().iter().filter(|field| field.required) {
301        let Some(right_field) = right
302            .fields()
303            .iter()
304            .find(|field| field.required && field.key == left_field.key)
305        else {
306            continue;
307        };
308        let relation = relate_shapes(
309            cx,
310            left_field.shape.as_ref(),
311            right_field.shape.as_ref(),
312            &[],
313        )?;
314        if relation.proven && relation.kind == ShapeRelationKind::Disjoint {
315            return Ok(true);
316        }
317    }
318    Ok(false)
319}