Skip to main content

sim_lib_pattern/
shapes.rs

1use sim_kernel::{
2    Cx, Expr, MatchScore, Result, Shape, ShapeBindings, ShapeDoc, ShapeMatch, Symbol, Value,
3    shape_is_subshape_of,
4};
5
6use crate::{PatternField, tagged_value};
7
8/// A kernel [`Shape`] that matches any [`TaggedValue`](crate::TaggedValue) of a
9/// given ADT.
10///
11/// The kernel defines the [`Shape`] match/binding protocol; `AdtShape` is the
12/// pattern organ's concrete implementation that accepts a value when one of the
13/// ADT's [`VariantShape`]s accepts it, binding that variant's captures.
14pub struct AdtShape {
15    adt: Symbol,
16    variants: Vec<VariantShape>,
17}
18
19impl AdtShape {
20    /// Builds an ADT shape from its name and per-variant shapes.
21    pub fn new(adt: Symbol, variants: Vec<VariantShape>) -> Self {
22        Self { adt, variants }
23    }
24
25    /// Returns the ADT name symbol.
26    pub fn adt(&self) -> &Symbol {
27        &self.adt
28    }
29
30    /// Returns the per-variant shapes this ADT shape dispatches over.
31    pub fn variants(&self) -> &[VariantShape] {
32        &self.variants
33    }
34}
35
36impl Shape for AdtShape {
37    fn symbol(&self) -> Option<Symbol> {
38        Some(Symbol::qualified("pattern-adt", self.adt.to_string()))
39    }
40
41    fn check_value(&self, cx: &mut Cx, value: Value) -> Result<ShapeMatch> {
42        let Some(tagged) = tagged_value(&value) else {
43            return Ok(ShapeMatch::reject("expected tagged ADT value"));
44        };
45        if tagged.adt() != &self.adt {
46            return Ok(ShapeMatch::reject(format!(
47                "expected ADT {}, got {}",
48                self.adt,
49                tagged.adt()
50            )));
51        }
52        let mut diagnostics = Vec::new();
53        for variant in &self.variants {
54            let matched = variant.check_value(cx, value.clone())?;
55            if matched.accepted {
56                return Ok(matched);
57            }
58            diagnostics.extend(matched.diagnostics);
59        }
60        Ok(ShapeMatch {
61            accepted: false,
62            captures: ShapeBindings::new(),
63            score: MatchScore::reject(),
64            diagnostics,
65        })
66    }
67
68    fn check_expr(&self, cx: &mut Cx, expr: &Expr) -> Result<ShapeMatch> {
69        let mut diagnostics = Vec::new();
70        for variant in &self.variants {
71            let matched = variant.check_expr(cx, expr)?;
72            if matched.accepted {
73                return Ok(matched);
74            }
75            diagnostics.extend(matched.diagnostics);
76        }
77        Ok(ShapeMatch {
78            accepted: false,
79            captures: ShapeBindings::new(),
80            score: MatchScore::reject(),
81            diagnostics,
82        })
83    }
84
85    fn describe(&self, _cx: &mut Cx) -> Result<ShapeDoc> {
86        Ok(ShapeDoc::new(format!("ADT {}", self.adt)))
87    }
88}
89
90/// A kernel [`Shape`] that matches one specific ADT variant by tag and fields.
91///
92/// `VariantShape` accepts a [`TaggedValue`](crate::TaggedValue) whose ADT and
93/// variant match and whose fields each pass their [`PatternField`] shape,
94/// accumulating their captures. It also matches the equivalent constructor
95/// [`Expr`], and reports subshape relationships against sibling variants and
96/// the enclosing [`AdtShape`].
97#[derive(Clone)]
98pub struct VariantShape {
99    adt: Symbol,
100    variant: Symbol,
101    fields: Vec<PatternField>,
102}
103
104impl VariantShape {
105    /// Builds a variant shape from its ADT name, variant tag, and fields.
106    pub fn new(adt: Symbol, variant: Symbol, fields: Vec<PatternField>) -> Self {
107        Self {
108            adt,
109            variant,
110            fields,
111        }
112    }
113
114    /// Returns the owning ADT name symbol.
115    pub fn adt(&self) -> &Symbol {
116        &self.adt
117    }
118
119    /// Returns the variant tag symbol.
120    pub fn variant(&self) -> &Symbol {
121        &self.variant
122    }
123
124    /// Returns the variant fields in declaration order.
125    pub fn fields(&self) -> &[PatternField] {
126        &self.fields
127    }
128
129    fn check_field_values(&self, cx: &mut Cx, fields: &[(Symbol, Value)]) -> Result<ShapeMatch> {
130        if fields.len() != self.fields.len() {
131            return Ok(ShapeMatch::reject(format!(
132                "variant {} expected {} fields, got {}",
133                self.variant,
134                self.fields.len(),
135                fields.len()
136            )));
137        }
138
139        let mut out = ShapeMatch::accept(MatchScore::exact(30));
140        for (field, (actual_name, value)) in self.fields.iter().zip(fields.iter()) {
141            if field.name() != actual_name {
142                return Ok(ShapeMatch::reject(format!(
143                    "expected field {}, got {}",
144                    field.name(),
145                    actual_name
146                )));
147            }
148            let matched = field.shape().check_value(cx, value.clone())?;
149            if !matched.accepted {
150                return Ok(matched);
151            }
152            out.captures.extend(matched.captures);
153            out.score += matched.score;
154        }
155        Ok(out)
156    }
157
158    fn check_field_exprs(&self, cx: &mut Cx, args: &[Expr]) -> Result<ShapeMatch> {
159        if args.len() != self.fields.len() {
160            return Ok(ShapeMatch::reject(format!(
161                "variant {} expected {} fields, got {}",
162                self.variant,
163                self.fields.len(),
164                args.len()
165            )));
166        }
167
168        let mut out = ShapeMatch::accept(MatchScore::exact(25));
169        for (field, expr) in self.fields.iter().zip(args.iter()) {
170            let matched = field.shape().check_expr(cx, expr)?;
171            if !matched.accepted {
172                return Ok(matched);
173            }
174            out.captures.extend(matched.captures);
175            out.score += matched.score;
176        }
177        Ok(out)
178    }
179}
180
181impl Shape for VariantShape {
182    fn symbol(&self) -> Option<Symbol> {
183        Some(Symbol::qualified(
184            "pattern-variant",
185            self.variant.to_string(),
186        ))
187    }
188
189    fn is_subshape_of(&self, cx: &mut Cx, parent: &dyn Shape) -> Result<Option<bool>> {
190        if let Some(parent) = parent.as_any().downcast_ref::<Self>() {
191            return Ok(Some(
192                self.adt == parent.adt && self.variant == parent.variant,
193            ));
194        }
195        if let Some(parent) = parent.as_any().downcast_ref::<AdtShape>() {
196            return Ok(Some(self.adt == parent.adt));
197        }
198        shape_is_subshape_of(cx, self.fields_parent().as_ref(), parent).map(Some)
199    }
200
201    fn check_value(&self, cx: &mut Cx, value: Value) -> Result<ShapeMatch> {
202        let Some(tagged) = tagged_value(&value) else {
203            return Ok(ShapeMatch::reject("expected tagged ADT value"));
204        };
205        if tagged.adt() != &self.adt {
206            return Ok(ShapeMatch::reject(format!(
207                "expected ADT {}, got {}",
208                self.adt,
209                tagged.adt()
210            )));
211        }
212        if tagged.variant() != &self.variant {
213            return Ok(ShapeMatch::reject(format!(
214                "expected variant {}, got {}",
215                self.variant,
216                tagged.variant()
217            )));
218        }
219        self.check_field_values(cx, tagged.fields())
220    }
221
222    fn check_expr(&self, cx: &mut Cx, expr: &Expr) -> Result<ShapeMatch> {
223        let (operator, args) = match expr {
224            Expr::Call { operator, args } => (operator.as_ref(), args.as_slice()),
225            Expr::List(items) if !items.is_empty() => (&items[0], &items[1..]),
226            _ => {
227                return Ok(ShapeMatch::reject(
228                    "expected variant constructor expression",
229                ));
230            }
231        };
232        let Expr::Symbol(symbol) = operator else {
233            return Ok(ShapeMatch::reject("expected symbolic variant constructor"));
234        };
235        if symbol != &self.variant {
236            return Ok(ShapeMatch::reject(format!(
237                "expected variant {}, got {}",
238                self.variant, symbol
239            )));
240        }
241        self.check_field_exprs(cx, args)
242    }
243
244    fn describe(&self, _cx: &mut Cx) -> Result<ShapeDoc> {
245        Ok(ShapeDoc::new(format!("variant {}", self.variant)))
246    }
247}
248
249impl VariantShape {
250    fn fields_parent(&self) -> std::sync::Arc<dyn Shape> {
251        std::sync::Arc::new(AdtShape::new(self.adt.clone(), vec![self.clone()]))
252    }
253}