Skip to main content

sim_shape/hooks/
hooked.rs

1//! `HookedShape`: a shape that wraps an inner shape and runs an ordered list of
2//! match hooks around each check to adjust acceptance, score, and diagnostics.
3
4use std::sync::Arc;
5
6use sim_kernel::{Cx, Diagnostic, Expr, Result, ShapeRef, Value};
7
8use crate::{
9    MatchScore, Shape, ShapeDoc, ShapeMatch,
10    hooks::types::{
11        MatchHook, MatchHookContext, MatchHookDecision, MatchHookKind, MatchHookPhase,
12        MatchHookTargetKind,
13    },
14};
15
16/// Shape wrapper that runs neutral match hooks around an inner shape.
17///
18/// `HookedShape` keeps the kernel `Shape` trait unchanged. Mark hooks observe,
19/// accept hooks can repair rejections, discard hooks can veto acceptances, and
20/// annotate hooks can adjust score and diagnostics without changing acceptance.
21///
22/// ```rust
23/// # use std::sync::Arc;
24/// # use sim_kernel::{Cx, DefaultFactory, Expr, NoopEvalPolicy};
25/// # use sim_shape::{AnyShape, HookedShape, Shape, TraceMarkHook};
26/// # let mut cx = Cx::new(Arc::new(NoopEvalPolicy), Arc::new(DefaultFactory));
27/// let shape = HookedShape::new(Arc::new(AnyShape), vec![Arc::new(TraceMarkHook)]);
28/// let matched = shape.check_expr(&mut cx, &Expr::Bool(true)).unwrap();
29///
30/// assert!(matched.accepted);
31/// assert!(matched
32///     .diagnostics
33///     .iter()
34///     .any(|diagnostic| diagnostic.message.starts_with("shape-hook:mark")));
35/// ```
36pub struct HookedShape {
37    inner: Arc<dyn Shape>,
38    hooks: Vec<Arc<dyn MatchHook>>,
39}
40
41impl HookedShape {
42    /// Wrap an inner shape with an ordered list of match hooks.
43    pub fn new(inner: Arc<dyn Shape>, hooks: Vec<Arc<dyn MatchHook>>) -> Self {
44        Self { inner, hooks }
45    }
46
47    /// Borrow the wrapped inner shape.
48    pub fn inner(&self) -> &Arc<dyn Shape> {
49        &self.inner
50    }
51
52    /// Borrow the hooks run around the inner shape, in registration order.
53    pub fn hooks(&self) -> &[Arc<dyn MatchHook>] {
54        &self.hooks
55    }
56}
57
58impl Shape for HookedShape {
59    fn parents(&self, cx: &mut Cx) -> Result<Vec<ShapeRef>> {
60        self.inner.parents(cx)
61    }
62
63    fn is_effectful(&self) -> bool {
64        self.inner.is_effectful()
65    }
66
67    fn is_total(&self) -> bool {
68        self.inner.is_total()
69    }
70
71    fn is_subshape_of(&self, cx: &mut Cx, parent: &dyn Shape) -> Result<Option<bool>> {
72        self.inner.is_subshape_of(cx, parent)
73    }
74
75    fn check_value(&self, cx: &mut Cx, value: Value) -> Result<ShapeMatch> {
76        let label = self.inner.describe(cx)?.name;
77        let before = self.run_marks(
78            cx,
79            MatchHookTargetKind::Value,
80            MatchHookPhase::BeforeInner,
81            &label,
82            None,
83        )?;
84        let matched = self.inner.check_value(cx, value)?;
85        self.finish_match(cx, MatchHookTargetKind::Value, label, matched, before)
86    }
87
88    fn check_expr(&self, cx: &mut Cx, expr: &Expr) -> Result<ShapeMatch> {
89        let label = self.inner.describe(cx)?.name;
90        let before = self.run_marks(
91            cx,
92            MatchHookTargetKind::Expr,
93            MatchHookPhase::BeforeInner,
94            &label,
95            None,
96        )?;
97        let matched = self.inner.check_expr(cx, expr)?;
98        self.finish_match(cx, MatchHookTargetKind::Expr, label, matched, before)
99    }
100
101    fn describe(&self, cx: &mut Cx) -> Result<ShapeDoc> {
102        let mut doc = ShapeDoc::new("hooked shape").with_detail(self.inner.describe(cx)?.name);
103        for hook in &self.hooks {
104            doc = doc.with_detail(hook.symbol().to_string());
105        }
106        Ok(doc)
107    }
108}
109
110impl HookedShape {
111    fn finish_match(
112        &self,
113        cx: &mut Cx,
114        target_kind: MatchHookTargetKind,
115        label: String,
116        mut matched: ShapeMatch,
117        before_marks: Vec<Diagnostic>,
118    ) -> Result<ShapeMatch> {
119        matched.diagnostics.extend(before_marks);
120        let after_marks = self.run_marks(
121            cx,
122            target_kind,
123            MatchHookPhase::AfterInner,
124            &label,
125            Some(&matched),
126        )?;
127        matched.diagnostics.extend(after_marks);
128
129        if !matched.accepted {
130            matched = self.run_accept_hooks(cx, target_kind, &label, matched)?;
131        }
132        if matched.accepted {
133            matched = self.run_discard_hooks(cx, target_kind, &label, matched)?;
134        }
135        self.run_annotate_hooks(cx, target_kind, &label, matched)
136    }
137
138    fn run_marks(
139        &self,
140        cx: &mut Cx,
141        target_kind: MatchHookTargetKind,
142        phase: MatchHookPhase,
143        shape_label: &str,
144        current: Option<&ShapeMatch>,
145    ) -> Result<Vec<Diagnostic>> {
146        let mut diagnostics = Vec::new();
147        for (hook_index, hook) in self.hooks.iter().enumerate() {
148            if hook.kind() != MatchHookKind::Mark {
149                continue;
150            }
151            let ctx = MatchHookContext {
152                hook_index,
153                phase,
154                target_kind,
155                shape_label: shape_label.to_owned(),
156            };
157            if let MatchHookDecision::Mark { message } = hook.apply(cx, &ctx, current)? {
158                diagnostics.push(Diagnostic::info(format!("shape-hook:mark {message}")));
159            }
160        }
161        Ok(diagnostics)
162    }
163
164    fn run_accept_hooks(
165        &self,
166        cx: &mut Cx,
167        target_kind: MatchHookTargetKind,
168        shape_label: &str,
169        mut matched: ShapeMatch,
170    ) -> Result<ShapeMatch> {
171        for (hook_index, hook) in self.hooks.iter().enumerate() {
172            if hook.kind() != MatchHookKind::Accept {
173                continue;
174            }
175            let ctx = MatchHookContext {
176                hook_index,
177                phase: MatchHookPhase::AfterInner,
178                target_kind,
179                shape_label: shape_label.to_owned(),
180            };
181            if let MatchHookDecision::Accept { reason, score } =
182                hook.apply(cx, &ctx, Some(&matched))?
183            {
184                matched.accepted = true;
185                if matched.score == MatchScore::reject() {
186                    matched.score = score;
187                }
188                matched
189                    .diagnostics
190                    .push(Diagnostic::info(format!("shape-hook:accept {reason}")));
191            }
192        }
193        Ok(matched)
194    }
195
196    fn run_discard_hooks(
197        &self,
198        cx: &mut Cx,
199        target_kind: MatchHookTargetKind,
200        shape_label: &str,
201        mut matched: ShapeMatch,
202    ) -> Result<ShapeMatch> {
203        for (hook_index, hook) in self.hooks.iter().enumerate() {
204            if hook.kind() != MatchHookKind::Discard {
205                continue;
206            }
207            let ctx = MatchHookContext {
208                hook_index,
209                phase: MatchHookPhase::AfterInner,
210                target_kind,
211                shape_label: shape_label.to_owned(),
212            };
213            if let MatchHookDecision::Discard { reason } = hook.apply(cx, &ctx, Some(&matched))? {
214                matched.accepted = false;
215                matched.score = MatchScore::reject();
216                matched
217                    .diagnostics
218                    .push(Diagnostic::error(format!("shape-hook:discard {reason}")));
219                break;
220            }
221        }
222        Ok(matched)
223    }
224
225    fn run_annotate_hooks(
226        &self,
227        cx: &mut Cx,
228        target_kind: MatchHookTargetKind,
229        shape_label: &str,
230        mut matched: ShapeMatch,
231    ) -> Result<ShapeMatch> {
232        for (hook_index, hook) in self.hooks.iter().enumerate() {
233            if hook.kind() != MatchHookKind::Annotate {
234                continue;
235            }
236            let ctx = MatchHookContext {
237                hook_index,
238                phase: MatchHookPhase::AfterInner,
239                target_kind,
240                shape_label: shape_label.to_owned(),
241            };
242            if let MatchHookDecision::Annotate {
243                message,
244                score_delta,
245            } = hook.apply(cx, &ctx, Some(&matched))?
246            {
247                matched.score += MatchScore::exact(score_delta);
248                matched
249                    .diagnostics
250                    .push(Diagnostic::info(format!("shape-hook:annotate {message}")));
251            }
252        }
253        Ok(matched)
254    }
255}