1use std::collections::HashSet;
2use std::fmt;
3
4use crate::cps_ir::{
5 CpsContinuation, CpsContinuationId, CpsFunction, CpsHandlerId, CpsModule, CpsStmt,
6 CpsTerminator, CpsValueId,
7};
8
9#[derive(Debug, Clone, PartialEq, Eq)]
10pub enum CpsValidateError {
11 MissingEntry {
12 function: String,
13 entry: CpsContinuationId,
14 },
15 DuplicateContinuation {
16 function: String,
17 id: CpsContinuationId,
18 },
19 MissingContinuation {
20 function: String,
21 id: CpsContinuationId,
22 },
23 DuplicateHandler {
24 function: String,
25 id: CpsHandlerId,
26 },
27 MissingHandler {
28 function: String,
29 id: CpsHandlerId,
30 },
31 ContinuationArityMismatch {
32 function: String,
33 id: CpsContinuationId,
34 expected: usize,
35 actual: usize,
36 },
37 MissingValue {
38 function: String,
39 id: CpsValueId,
40 },
41}
42
43impl fmt::Display for CpsValidateError {
44 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45 match self {
46 CpsValidateError::MissingEntry { function, entry } => {
47 write!(
48 f,
49 "CPS function {function} has no entry continuation {entry:?}"
50 )
51 }
52 CpsValidateError::DuplicateContinuation { function, id } => {
53 write!(
54 f,
55 "CPS function {function} defines continuation {id:?} twice"
56 )
57 }
58 CpsValidateError::MissingContinuation { function, id } => {
59 write!(
60 f,
61 "CPS function {function} references missing continuation {id:?}"
62 )
63 }
64 CpsValidateError::DuplicateHandler { function, id } => {
65 write!(f, "CPS function {function} defines handler {id:?} twice")
66 }
67 CpsValidateError::MissingHandler { function, id } => {
68 write!(
69 f,
70 "CPS function {function} references missing handler {id:?}"
71 )
72 }
73 CpsValidateError::ContinuationArityMismatch {
74 function,
75 id,
76 expected,
77 actual,
78 } => write!(
79 f,
80 "CPS function {function} calls continuation {id:?} with {actual} arguments, expected {expected}"
81 ),
82 CpsValidateError::MissingValue { function, id } => {
83 write!(f, "CPS function {function} references missing value {id:?}")
84 }
85 }
86 }
87}
88
89impl std::error::Error for CpsValidateError {}
90
91pub fn validate_cps_module(module: &CpsModule) -> Result<(), CpsValidateError> {
92 for function in module.functions.iter().chain(&module.roots) {
93 validate_function(function)?;
94 }
95 Ok(())
96}
97
98fn validate_function(function: &CpsFunction) -> Result<(), CpsValidateError> {
99 let mut continuation_ids = HashSet::new();
100 for continuation in &function.continuations {
101 if !continuation_ids.insert(continuation.id) {
102 return Err(CpsValidateError::DuplicateContinuation {
103 function: function.name.clone(),
104 id: continuation.id,
105 });
106 }
107 }
108 if !continuation_ids.contains(&function.entry) {
109 return Err(CpsValidateError::MissingEntry {
110 function: function.name.clone(),
111 entry: function.entry,
112 });
113 }
114
115 let mut handler_ids = HashSet::new();
116 for handler in &function.handlers {
117 if !handler_ids.insert(handler.id) {
118 return Err(CpsValidateError::DuplicateHandler {
119 function: function.name.clone(),
120 id: handler.id,
121 });
122 }
123 for arm in &handler.arms {
124 require_continuation(function, &continuation_ids, arm.entry)?;
125 }
126 }
127
128 let defined_values = function_defined_values(function);
129 for continuation in &function.continuations {
130 validate_continuation(
131 function,
132 continuation,
133 &continuation_ids,
134 &handler_ids,
135 &defined_values,
136 )?;
137 }
138 Ok(())
139}
140
141fn function_defined_values(function: &CpsFunction) -> HashSet<CpsValueId> {
142 let mut values = function.params.iter().copied().collect::<HashSet<_>>();
143 for continuation in &function.continuations {
144 values.extend(continuation.params.iter().copied());
145 for stmt in &continuation.stmts {
146 match stmt {
147 CpsStmt::Literal { dest, .. }
148 | CpsStmt::FreshGuard { dest, .. }
149 | CpsStmt::PeekGuard { dest }
150 | CpsStmt::FindGuard { dest, .. }
151 | CpsStmt::MakeThunk { dest, .. }
152 | CpsStmt::AddThunkBoundary { dest, .. }
153 | CpsStmt::MakeClosure { dest, .. }
154 | CpsStmt::MakeRecursiveClosure { dest, .. }
155 | CpsStmt::ForceThunk { dest, .. }
156 | CpsStmt::Tuple { dest, .. }
157 | CpsStmt::Record { dest, .. }
158 | CpsStmt::RecordWithoutFields { dest, .. }
159 | CpsStmt::Variant { dest, .. }
160 | CpsStmt::Select { dest, .. }
161 | CpsStmt::SelectWithDefault { dest, .. }
162 | CpsStmt::RecordHasField { dest, .. }
163 | CpsStmt::TupleGet { dest, .. }
164 | CpsStmt::VariantTagEq { dest, .. }
165 | CpsStmt::VariantPayload { dest, .. }
166 | CpsStmt::Primitive { dest, .. }
167 | CpsStmt::DirectCall { dest, .. }
168 | CpsStmt::ApplyClosure { dest, .. }
169 | CpsStmt::CloneContinuation { dest, .. }
170 | CpsStmt::Resume { dest, .. }
171 | CpsStmt::ResumeWithHandler { dest, .. } => {
172 values.insert(*dest);
173 }
174 CpsStmt::InstallHandler { .. } | CpsStmt::UninstallHandler { .. } => {}
175 }
176 }
177 }
178 values
179}
180
181fn validate_continuation(
182 function: &CpsFunction,
183 continuation: &CpsContinuation,
184 continuation_ids: &HashSet<CpsContinuationId>,
185 handler_ids: &HashSet<CpsHandlerId>,
186 defined_values: &HashSet<CpsValueId>,
187) -> Result<(), CpsValidateError> {
188 let mut values = continuation.params.iter().copied().collect::<HashSet<_>>();
189 for capture in &continuation.captures {
190 require_value(function, defined_values, *capture)?;
191 values.insert(*capture);
192 }
193
194 for stmt in &continuation.stmts {
195 match stmt {
196 CpsStmt::Literal { dest, .. } => {
197 values.insert(*dest);
198 }
199 CpsStmt::FreshGuard { dest, .. } | CpsStmt::PeekGuard { dest } => {
200 values.insert(*dest);
201 }
202 CpsStmt::FindGuard { dest, guard } => {
203 require_value(function, &values, *guard)?;
204 values.insert(*dest);
205 }
206 CpsStmt::MakeThunk { dest, entry } => {
207 require_continuation(function, continuation_ids, *entry)?;
208 values.insert(*dest);
209 }
210 CpsStmt::AddThunkBoundary {
211 dest, thunk, guard, ..
212 } => {
213 require_value(function, &values, *thunk)?;
214 require_value(function, &values, *guard)?;
215 values.insert(*dest);
216 }
217 CpsStmt::MakeClosure { dest, entry } => {
218 require_continuation(function, continuation_ids, *entry)?;
219 values.insert(*dest);
220 }
221 CpsStmt::MakeRecursiveClosure { dest, entry } => {
222 require_continuation(function, continuation_ids, *entry)?;
223 values.insert(*dest);
224 }
225 CpsStmt::ForceThunk { dest, thunk } => {
226 require_value(function, &values, *thunk)?;
227 values.insert(*dest);
228 }
229 CpsStmt::Tuple { dest, items } => {
230 for item in items {
231 require_value(function, &values, *item)?;
232 }
233 values.insert(*dest);
234 }
235 CpsStmt::Record { dest, base, fields } => {
236 if let Some(base) = base {
237 require_value(function, &values, *base)?;
238 }
239 for field in fields {
240 require_value(function, &values, field.value)?;
241 }
242 values.insert(*dest);
243 }
244 CpsStmt::RecordWithoutFields { dest, base, .. } => {
245 require_value(function, &values, *base)?;
246 values.insert(*dest);
247 }
248 CpsStmt::Variant { dest, value, .. } => {
249 if let Some(value) = value {
250 require_value(function, &values, *value)?;
251 }
252 values.insert(*dest);
253 }
254 CpsStmt::Select { dest, base, .. } => {
255 require_value(function, &values, *base)?;
256 values.insert(*dest);
257 }
258 CpsStmt::SelectWithDefault {
259 dest,
260 base,
261 default,
262 ..
263 } => {
264 require_value(function, &values, *base)?;
265 require_value(function, &values, *default)?;
266 values.insert(*dest);
267 }
268 CpsStmt::RecordHasField { dest, base, .. } => {
269 require_value(function, &values, *base)?;
270 values.insert(*dest);
271 }
272 CpsStmt::TupleGet { dest, tuple, .. } => {
273 require_value(function, &values, *tuple)?;
274 values.insert(*dest);
275 }
276 CpsStmt::VariantTagEq { dest, variant, .. }
277 | CpsStmt::VariantPayload { dest, variant, .. } => {
278 require_value(function, &values, *variant)?;
279 values.insert(*dest);
280 }
281 CpsStmt::Primitive { dest, args, .. } | CpsStmt::DirectCall { dest, args, .. } => {
282 for arg in args {
283 require_value(function, &values, *arg)?;
284 }
285 values.insert(*dest);
286 }
287 CpsStmt::ApplyClosure { dest, closure, arg } => {
288 require_value(function, &values, *closure)?;
289 require_value(function, &values, *arg)?;
290 values.insert(*dest);
291 }
292 CpsStmt::CloneContinuation { dest, source } => {
293 require_value(function, &values, *source)?;
294 values.insert(*dest);
295 }
296 CpsStmt::Resume {
297 dest,
298 resumption,
299 arg,
300 } => {
301 require_value(function, &values, *resumption)?;
302 require_value(function, &values, *arg)?;
303 values.insert(*dest);
304 }
305 CpsStmt::ResumeWithHandler {
306 dest,
307 resumption,
308 arg,
309 envs,
310 ..
311 } => {
312 require_value(function, &values, *resumption)?;
313 require_value(function, &values, *arg)?;
314 for env in envs {
315 for value in &env.values {
316 require_value(function, &values, *value)?;
317 }
318 }
319 values.insert(*dest);
320 }
321 CpsStmt::InstallHandler { envs, .. } => {
322 for env in envs {
323 for value in &env.values {
324 require_value(function, &values, *value)?;
325 }
326 }
327 }
328 CpsStmt::UninstallHandler { .. } => {}
329 }
330 }
331
332 match &continuation.terminator {
333 CpsTerminator::Return(value) => require_value(function, &values, *value),
334 CpsTerminator::Continue { target, args } => {
335 let target_cont = function
336 .continuations
337 .iter()
338 .find(|continuation| continuation.id == *target)
339 .ok_or_else(|| CpsValidateError::MissingContinuation {
340 function: function.name.clone(),
341 id: *target,
342 })?;
343 if target_cont.params.len() != args.len() {
344 return Err(CpsValidateError::ContinuationArityMismatch {
345 function: function.name.clone(),
346 id: *target,
347 expected: target_cont.params.len(),
348 actual: args.len(),
349 });
350 }
351 for arg in args {
352 require_value(function, &values, *arg)?;
353 }
354 Ok(())
355 }
356 CpsTerminator::Branch {
357 cond,
358 then_cont,
359 else_cont,
360 } => {
361 require_value(function, &values, *cond)?;
362 require_continuation(function, continuation_ids, *then_cont)?;
363 require_continuation(function, continuation_ids, *else_cont)
364 }
365 CpsTerminator::Perform {
366 payload,
367 resume,
368 blocked,
369 handler,
370 ..
371 } => {
372 require_value(function, &values, *payload)?;
373 if let Some(blocked) = blocked {
374 require_value(function, &values, *blocked)?;
375 }
376 require_continuation(function, continuation_ids, *resume)?;
377 if handler.0 == usize::MAX {
378 Ok(())
379 } else {
380 require_handler(function, handler_ids, *handler)
381 }
382 }
383 CpsTerminator::EffectfulCall { args, resume, .. } => {
384 for arg in args {
385 require_value(function, &values, *arg)?;
386 }
387 require_continuation(function, continuation_ids, *resume)
388 }
389 CpsTerminator::EffectfulApply {
390 closure,
391 arg,
392 resume,
393 } => {
394 require_value(function, &values, *closure)?;
395 require_value(function, &values, *arg)?;
396 require_continuation(function, continuation_ids, *resume)
397 }
398 CpsTerminator::EffectfulForce { thunk, resume } => {
399 require_value(function, &values, *thunk)?;
400 require_continuation(function, continuation_ids, *resume)
401 }
402 }
403}
404
405fn require_value(
406 function: &CpsFunction,
407 values: &HashSet<CpsValueId>,
408 id: CpsValueId,
409) -> Result<(), CpsValidateError> {
410 if values.contains(&id) {
411 Ok(())
412 } else {
413 Err(CpsValidateError::MissingValue {
414 function: function.name.clone(),
415 id,
416 })
417 }
418}
419
420fn require_continuation(
421 function: &CpsFunction,
422 continuation_ids: &HashSet<CpsContinuationId>,
423 id: CpsContinuationId,
424) -> Result<(), CpsValidateError> {
425 if continuation_ids.contains(&id) {
426 Ok(())
427 } else {
428 Err(CpsValidateError::MissingContinuation {
429 function: function.name.clone(),
430 id,
431 })
432 }
433}
434
435fn require_handler(
436 function: &CpsFunction,
437 handler_ids: &HashSet<CpsHandlerId>,
438 id: CpsHandlerId,
439) -> Result<(), CpsValidateError> {
440 if handler_ids.contains(&id) {
441 Ok(())
442 } else {
443 Err(CpsValidateError::MissingHandler {
444 function: function.name.clone(),
445 id,
446 })
447 }
448}