1use std::sync::Arc;
5
6use sim_kernel::{Cx, Diagnostic, Expr, Result, ShapeRef, Value};
7
8use crate::{
9 MatchScore, Shape, ShapeDoc, ShapeMatch,
10 hooks::types::{
11 MatchHook, MatchHookContext, MatchHookDecision, MatchHookKind, MatchHookPhase,
12 MatchHookTargetKind,
13 },
14};
15
16pub struct HookedShape {
37 inner: Arc<dyn Shape>,
38 hooks: Vec<Arc<dyn MatchHook>>,
39}
40
41impl HookedShape {
42 pub fn new(inner: Arc<dyn Shape>, hooks: Vec<Arc<dyn MatchHook>>) -> Self {
44 Self { inner, hooks }
45 }
46
47 pub fn inner(&self) -> &Arc<dyn Shape> {
49 &self.inner
50 }
51
52 pub fn hooks(&self) -> &[Arc<dyn MatchHook>] {
54 &self.hooks
55 }
56}
57
58impl Shape for HookedShape {
59 fn parents(&self, cx: &mut Cx) -> Result<Vec<ShapeRef>> {
60 self.inner.parents(cx)
61 }
62
63 fn is_effectful(&self) -> bool {
64 self.inner.is_effectful()
65 }
66
67 fn is_total(&self) -> bool {
68 self.inner.is_total()
69 }
70
71 fn is_subshape_of(&self, cx: &mut Cx, parent: &dyn Shape) -> Result<Option<bool>> {
72 self.inner.is_subshape_of(cx, parent)
73 }
74
75 fn check_value(&self, cx: &mut Cx, value: Value) -> Result<ShapeMatch> {
76 let label = self.inner.describe(cx)?.name;
77 let before = self.run_marks(
78 cx,
79 MatchHookTargetKind::Value,
80 MatchHookPhase::BeforeInner,
81 &label,
82 None,
83 )?;
84 let matched = self.inner.check_value(cx, value)?;
85 self.finish_match(cx, MatchHookTargetKind::Value, label, matched, before)
86 }
87
88 fn check_expr(&self, cx: &mut Cx, expr: &Expr) -> Result<ShapeMatch> {
89 let label = self.inner.describe(cx)?.name;
90 let before = self.run_marks(
91 cx,
92 MatchHookTargetKind::Expr,
93 MatchHookPhase::BeforeInner,
94 &label,
95 None,
96 )?;
97 let matched = self.inner.check_expr(cx, expr)?;
98 self.finish_match(cx, MatchHookTargetKind::Expr, label, matched, before)
99 }
100
101 fn describe(&self, cx: &mut Cx) -> Result<ShapeDoc> {
102 let mut doc = ShapeDoc::new("hooked shape").with_detail(self.inner.describe(cx)?.name);
103 for hook in &self.hooks {
104 doc = doc.with_detail(hook.symbol().to_string());
105 }
106 Ok(doc)
107 }
108}
109
110impl HookedShape {
111 fn finish_match(
112 &self,
113 cx: &mut Cx,
114 target_kind: MatchHookTargetKind,
115 label: String,
116 mut matched: ShapeMatch,
117 before_marks: Vec<Diagnostic>,
118 ) -> Result<ShapeMatch> {
119 matched.diagnostics.extend(before_marks);
120 let after_marks = self.run_marks(
121 cx,
122 target_kind,
123 MatchHookPhase::AfterInner,
124 &label,
125 Some(&matched),
126 )?;
127 matched.diagnostics.extend(after_marks);
128
129 if !matched.accepted {
130 matched = self.run_accept_hooks(cx, target_kind, &label, matched)?;
131 }
132 if matched.accepted {
133 matched = self.run_discard_hooks(cx, target_kind, &label, matched)?;
134 }
135 self.run_annotate_hooks(cx, target_kind, &label, matched)
136 }
137
138 fn run_marks(
139 &self,
140 cx: &mut Cx,
141 target_kind: MatchHookTargetKind,
142 phase: MatchHookPhase,
143 shape_label: &str,
144 current: Option<&ShapeMatch>,
145 ) -> Result<Vec<Diagnostic>> {
146 let mut diagnostics = Vec::new();
147 for (hook_index, hook) in self.hooks.iter().enumerate() {
148 if hook.kind() != MatchHookKind::Mark {
149 continue;
150 }
151 let ctx = MatchHookContext {
152 hook_index,
153 phase,
154 target_kind,
155 shape_label: shape_label.to_owned(),
156 };
157 if let MatchHookDecision::Mark { message } = hook.apply(cx, &ctx, current)? {
158 diagnostics.push(Diagnostic::info(format!("shape-hook:mark {message}")));
159 }
160 }
161 Ok(diagnostics)
162 }
163
164 fn run_accept_hooks(
165 &self,
166 cx: &mut Cx,
167 target_kind: MatchHookTargetKind,
168 shape_label: &str,
169 mut matched: ShapeMatch,
170 ) -> Result<ShapeMatch> {
171 for (hook_index, hook) in self.hooks.iter().enumerate() {
172 if hook.kind() != MatchHookKind::Accept {
173 continue;
174 }
175 let ctx = MatchHookContext {
176 hook_index,
177 phase: MatchHookPhase::AfterInner,
178 target_kind,
179 shape_label: shape_label.to_owned(),
180 };
181 if let MatchHookDecision::Accept { reason, score } =
182 hook.apply(cx, &ctx, Some(&matched))?
183 {
184 matched.accepted = true;
185 if matched.score == MatchScore::reject() {
186 matched.score = score;
187 }
188 matched
189 .diagnostics
190 .push(Diagnostic::info(format!("shape-hook:accept {reason}")));
191 }
192 }
193 Ok(matched)
194 }
195
196 fn run_discard_hooks(
197 &self,
198 cx: &mut Cx,
199 target_kind: MatchHookTargetKind,
200 shape_label: &str,
201 mut matched: ShapeMatch,
202 ) -> Result<ShapeMatch> {
203 for (hook_index, hook) in self.hooks.iter().enumerate() {
204 if hook.kind() != MatchHookKind::Discard {
205 continue;
206 }
207 let ctx = MatchHookContext {
208 hook_index,
209 phase: MatchHookPhase::AfterInner,
210 target_kind,
211 shape_label: shape_label.to_owned(),
212 };
213 if let MatchHookDecision::Discard { reason } = hook.apply(cx, &ctx, Some(&matched))? {
214 matched.accepted = false;
215 matched.score = MatchScore::reject();
216 matched
217 .diagnostics
218 .push(Diagnostic::error(format!("shape-hook:discard {reason}")));
219 break;
220 }
221 }
222 Ok(matched)
223 }
224
225 fn run_annotate_hooks(
226 &self,
227 cx: &mut Cx,
228 target_kind: MatchHookTargetKind,
229 shape_label: &str,
230 mut matched: ShapeMatch,
231 ) -> Result<ShapeMatch> {
232 for (hook_index, hook) in self.hooks.iter().enumerate() {
233 if hook.kind() != MatchHookKind::Annotate {
234 continue;
235 }
236 let ctx = MatchHookContext {
237 hook_index,
238 phase: MatchHookPhase::AfterInner,
239 target_kind,
240 shape_label: shape_label.to_owned(),
241 };
242 if let MatchHookDecision::Annotate {
243 message,
244 score_delta,
245 } = hook.apply(cx, &ctx, Some(&matched))?
246 {
247 matched.score += MatchScore::exact(score_delta);
248 matched
249 .diagnostics
250 .push(Diagnostic::info(format!("shape-hook:annotate {message}")));
251 }
252 }
253 Ok(matched)
254 }
255}