1use std::collections::{HashMap, HashSet};
2use std::fmt;
3
4use yulang_runtime as runtime;
5use yulang_typed_ir as typed_ir;
6
7#[derive(Debug, Clone, PartialEq, Eq)]
8pub struct NativeBackendPlan {
9 pub roots: Vec<NativeRootBackend>,
10}
11
12impl NativeBackendPlan {
13 pub fn module_backend(&self) -> NativeBackendSelection {
14 self.roots
15 .iter()
16 .find_map(|root| match &root.selection {
17 NativeBackendSelection::CpsMainline { reason } => {
18 Some(NativeBackendSelection::CpsMainline {
19 reason: reason.clone(),
20 })
21 }
22 NativeBackendSelection::ValueFastPath => None,
23 NativeBackendSelection::Unsupported { reason } => {
24 Some(NativeBackendSelection::Unsupported {
25 reason: reason.clone(),
26 })
27 }
28 })
29 .unwrap_or(NativeBackendSelection::ValueFastPath)
30 }
31}
32
/// Backend selection recorded for a single module root.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NativeRootBackend {
    /// Which root this entry describes.
    pub root: NativeRootLabel,
    /// Backend chosen for that root.
    pub selection: NativeBackendSelection,
}
38
39#[derive(Debug, Clone, PartialEq, Eq)]
40pub enum NativeRootLabel {
41 Binding(typed_ir::Path),
42 Expr(usize),
43}
44
45impl fmt::Display for NativeRootLabel {
46 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
47 match self {
48 NativeRootLabel::Binding(path) => write!(f, "binding {:?}", path),
49 NativeRootLabel::Expr(index) => write!(f, "root expr {index}"),
50 }
51 }
52}
53
/// Which native code-generation strategy a root (or a whole module) uses.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NativeBackendSelection {
    /// Chosen when the scan found no construct that forces the CPS backend.
    ValueFastPath,
    /// Chosen when a CPS-forcing construct was found; `reason` records the
    /// first one encountered.
    CpsMainline { reason: NativeBackendReason },
    /// NOTE(review): not produced by `select_native_backends` in this file;
    /// presumably reserved for roots no native backend can handle — confirm.
    Unsupported { reason: NativeBackendReason },
}
60
61#[derive(Debug, Clone, PartialEq, Eq)]
62pub struct NativeBackendReason {
63 pub root: NativeRootLabel,
64 pub kind: NativeBackendReasonKind,
65}
66
67impl fmt::Display for NativeBackendReason {
68 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
69 write!(f, "{} contains {}", self.root, self.kind)
70 }
71}
72
/// Category of runtime construct that forces a root onto the CPS backend.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NativeBackendReasonKind {
    /// An `EffectOp` expression.
    EffectOperation,
    /// A `Handle` expression.
    Handler,
    /// A `Thunk` expression.
    Thunk,
    /// A `BindHere` or `AddId` expression.
    ThunkBoundary,
    /// A `Lambda` expression.
    ClosureValue,
    /// A referenced binding whose body is a `match` (possibly under
    /// `Coerce`/`Pack`) with an arm pattern that binds the binding's own name.
    StructuralPatternBinding,
    /// A `LocalPushId` expression.
    EffectIdScope,
    /// A `PeekId` or `FindId` expression.
    EffectIdRead,
}

impl fmt::Display for NativeBackendReasonKind {
    /// Writes a short lowercase description of the construct.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Self::EffectOperation => "effect operation",
            Self::Handler => "effect handler",
            Self::Thunk => "thunk",
            Self::ThunkBoundary => "thunk boundary",
            Self::ClosureValue => "closure value",
            Self::StructuralPatternBinding => "structural pattern binding",
            Self::EffectIdScope => "effect id scope",
            Self::EffectIdRead => "effect id read",
        })
    }
}
100
101pub fn select_native_backends(module: &runtime::Module) -> NativeBackendPlan {
102 let bindings = module
103 .bindings
104 .iter()
105 .map(|binding| (binding.name.clone(), &binding.body))
106 .collect::<HashMap<_, _>>();
107 let roots = module
108 .roots
109 .iter()
110 .map(|root| {
111 let label = root_label(root);
112 let reason = match root {
113 runtime::Root::Binding(path) => bindings
114 .get(path)
115 .and_then(|body| first_cps_reason(body, &bindings)),
116 runtime::Root::Expr(index) => module
117 .root_exprs
118 .get(*index)
119 .and_then(|expr| first_cps_reason(expr, &bindings)),
120 };
121 NativeRootBackend {
122 root: label.clone(),
123 selection: reason
124 .map(|kind| NativeBackendSelection::CpsMainline {
125 reason: NativeBackendReason { root: label, kind },
126 })
127 .unwrap_or(NativeBackendSelection::ValueFastPath),
128 }
129 })
130 .collect();
131 NativeBackendPlan { roots }
132}
133
134fn root_label(root: &runtime::Root) -> NativeRootLabel {
135 match root {
136 runtime::Root::Binding(path) => NativeRootLabel::Binding(path.clone()),
137 runtime::Root::Expr(index) => NativeRootLabel::Expr(*index),
138 }
139}
140
141fn first_cps_reason(
142 root: &runtime::Expr,
143 bindings: &HashMap<typed_ir::Path, &runtime::Expr>,
144) -> Option<NativeBackendReasonKind> {
145 let mut seen_bindings = HashSet::new();
146 first_cps_reason_expr(root, bindings, &mut seen_bindings)
147}
148
149fn first_cps_reason_expr(
150 expr: &runtime::Expr,
151 bindings: &HashMap<typed_ir::Path, &runtime::Expr>,
152 seen_bindings: &mut HashSet<typed_ir::Path>,
153) -> Option<NativeBackendReasonKind> {
154 match &expr.kind {
155 runtime::ExprKind::EffectOp(_) => Some(NativeBackendReasonKind::EffectOperation),
156 runtime::ExprKind::Handle { .. } => Some(NativeBackendReasonKind::Handler),
157 runtime::ExprKind::Thunk { .. } => Some(NativeBackendReasonKind::Thunk),
158 runtime::ExprKind::BindHere { .. } | runtime::ExprKind::AddId { .. } => {
159 Some(NativeBackendReasonKind::ThunkBoundary)
160 }
161 runtime::ExprKind::LocalPushId { .. } => Some(NativeBackendReasonKind::EffectIdScope),
162 runtime::ExprKind::PeekId | runtime::ExprKind::FindId { .. } => {
163 Some(NativeBackendReasonKind::EffectIdRead)
164 }
165 runtime::ExprKind::Var(path) => {
166 if seen_bindings.insert(path.clone()) {
167 let reason = bindings.get(path).and_then(|body| {
168 if binding_body_shadows_path(path, body) {
169 Some(NativeBackendReasonKind::StructuralPatternBinding)
170 } else {
171 first_cps_reason_expr(body, bindings, seen_bindings)
172 }
173 });
174 seen_bindings.remove(path);
175 reason
176 } else {
177 None
178 }
179 }
180 runtime::ExprKind::PrimitiveOp(_) | runtime::ExprKind::Lit(_) => None,
181 runtime::ExprKind::Lambda { .. } => Some(NativeBackendReasonKind::ClosureValue),
182 runtime::ExprKind::Apply { callee, arg, .. } => {
183 first_cps_reason_expr(callee, bindings, seen_bindings)
184 .or_else(|| first_cps_reason_expr(arg, bindings, seen_bindings))
185 }
186 runtime::ExprKind::If {
187 cond,
188 then_branch,
189 else_branch,
190 ..
191 } => first_cps_reason_expr(cond, bindings, seen_bindings)
192 .or_else(|| first_cps_reason_expr(then_branch, bindings, seen_bindings))
193 .or_else(|| first_cps_reason_expr(else_branch, bindings, seen_bindings)),
194 runtime::ExprKind::Tuple(items) => items
195 .iter()
196 .find_map(|item| first_cps_reason_expr(item, bindings, seen_bindings)),
197 runtime::ExprKind::Record { fields, spread } => fields
198 .iter()
199 .find_map(|field| first_cps_reason_expr(&field.value, bindings, seen_bindings))
200 .or_else(|| match spread {
201 Some(runtime::RecordSpreadExpr::Head(expr))
202 | Some(runtime::RecordSpreadExpr::Tail(expr)) => {
203 first_cps_reason_expr(expr, bindings, seen_bindings)
204 }
205 None => None,
206 }),
207 runtime::ExprKind::Variant { value, .. } => value
208 .as_deref()
209 .and_then(|value| first_cps_reason_expr(value, bindings, seen_bindings)),
210 runtime::ExprKind::Select { base, .. } => {
211 first_cps_reason_expr(base, bindings, seen_bindings)
212 }
213 runtime::ExprKind::Match {
214 scrutinee, arms, ..
215 } => first_cps_reason_expr(scrutinee, bindings, seen_bindings).or_else(|| {
216 arms.iter().find_map(|arm| {
217 arm.guard
218 .as_ref()
219 .and_then(|guard| first_cps_reason_expr(guard, bindings, seen_bindings))
220 .or_else(|| first_cps_reason_expr(&arm.body, bindings, seen_bindings))
221 })
222 }),
223 runtime::ExprKind::Block { stmts, tail } => stmts
224 .iter()
225 .find_map(|stmt| match stmt {
226 runtime::Stmt::Let { value, .. } | runtime::Stmt::Expr(value) => {
227 first_cps_reason_expr(value, bindings, seen_bindings)
228 }
229 runtime::Stmt::Module { body, .. } => {
230 first_cps_reason_expr(body, bindings, seen_bindings)
231 }
232 })
233 .or_else(|| {
234 tail.as_deref()
235 .and_then(|tail| first_cps_reason_expr(tail, bindings, seen_bindings))
236 }),
237 runtime::ExprKind::Coerce { expr, .. } | runtime::ExprKind::Pack { expr, .. } => {
238 first_cps_reason_expr(expr, bindings, seen_bindings)
239 }
240 }
241}
242
243fn binding_body_shadows_path(path: &typed_ir::Path, body: &runtime::Expr) -> bool {
244 match &body.kind {
245 runtime::ExprKind::Match { arms, .. } => arms
246 .iter()
247 .any(|arm| pattern_binds_path(&arm.pattern, path)),
248 runtime::ExprKind::Coerce { expr, .. } | runtime::ExprKind::Pack { expr, .. } => {
249 binding_body_shadows_path(path, expr)
250 }
251 _ => false,
252 }
253}
254
255fn pattern_binds_path(pattern: &runtime::Pattern, path: &typed_ir::Path) -> bool {
256 match pattern {
257 runtime::Pattern::Bind { name, .. } => typed_ir::Path::from_name(name.clone()) == *path,
258 runtime::Pattern::Tuple { items, .. } => {
259 items.iter().any(|item| pattern_binds_path(item, path))
260 }
261 runtime::Pattern::List {
262 prefix,
263 spread,
264 suffix,
265 ..
266 } => {
267 prefix.iter().any(|item| pattern_binds_path(item, path))
268 || spread
269 .as_deref()
270 .is_some_and(|spread| pattern_binds_path(spread, path))
271 || suffix.iter().any(|item| pattern_binds_path(item, path))
272 }
273 runtime::Pattern::Record { fields, spread, .. } => {
274 fields
275 .iter()
276 .any(|field| pattern_binds_path(&field.pattern, path))
277 || spread.as_ref().is_some_and(|spread| match spread {
278 runtime::RecordSpreadPattern::Head(pattern)
279 | runtime::RecordSpreadPattern::Tail(pattern) => {
280 pattern_binds_path(pattern, path)
281 }
282 })
283 }
284 runtime::Pattern::Variant {
285 value: Some(value), ..
286 }
287 | runtime::Pattern::As { pattern: value, .. } => pattern_binds_path(value, path),
288 runtime::Pattern::Or { left, right, .. } => {
289 pattern_binds_path(left, path) || pattern_binds_path(right, path)
290 }
291 runtime::Pattern::Wildcard { .. }
292 | runtime::Pattern::Lit { .. }
293 | runtime::Pattern::Variant { value: None, .. } => false,
294 }
295}
296
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps an expression kind with an unknown runtime type.
    fn typed_expr(kind: runtime::ExprKind) -> runtime::Expr {
        runtime::Expr::typed(kind, runtime::Type::unknown())
    }

    /// Join evidence with an unknown result type.
    fn unknown_join() -> runtime::JoinEvidence {
        runtime::JoinEvidence {
            result: typed_ir::Type::Unknown,
        }
    }

    /// Module whose only root is `root_exprs[0]`, holding the given bindings.
    fn module(bindings: Vec<runtime::Binding>, root: runtime::Expr) -> runtime::Module {
        runtime::Module {
            path: typed_ir::Path::default(),
            bindings,
            root_exprs: vec![root],
            roots: vec![runtime::Root::Expr(0)],
            role_impls: Vec::new(),
        }
    }

    fn module_with_root(expr: runtime::Expr) -> runtime::Module {
        module(Vec::new(), expr)
    }

    fn module_with_binding(
        name: &str,
        body: runtime::Expr,
        root: runtime::Expr,
    ) -> runtime::Module {
        let binding = runtime::Binding {
            name: path(name),
            type_params: Vec::new(),
            scheme: typed_ir::Scheme {
                requirements: Vec::new(),
                body: typed_ir::Type::Unknown,
            },
            body,
        };
        module(vec![binding], root)
    }

    fn path(name: &str) -> typed_ir::Path {
        typed_ir::Path::from_name(typed_ir::Name(name.to_string()))
    }

    fn lit_int(value: &str) -> runtime::Expr {
        typed_expr(runtime::ExprKind::Lit(typed_ir::Lit::Int(value.to_string())))
    }

    fn var(name: &str) -> runtime::Expr {
        typed_expr(runtime::ExprKind::Var(path(name)))
    }

    fn primitive(op: typed_ir::PrimitiveOp) -> runtime::Expr {
        typed_expr(runtime::ExprKind::PrimitiveOp(op))
    }

    fn apply(callee: runtime::Expr, arg: runtime::Expr) -> runtime::Expr {
        typed_expr(runtime::ExprKind::Apply {
            callee: Box::new(callee),
            arg: Box::new(arg),
            evidence: None,
            instantiation: None,
        })
    }

    fn list_pattern(items: Vec<runtime::Pattern>) -> runtime::Pattern {
        runtime::Pattern::List {
            prefix: items,
            spread: None,
            suffix: Vec::new(),
            ty: runtime::Type::unknown(),
        }
    }

    fn bind_pattern(name: &str) -> runtime::Pattern {
        runtime::Pattern::Bind {
            name: typed_ir::Name(name.to_string()),
            ty: runtime::Type::unknown(),
        }
    }

    fn identity_lambda() -> runtime::Expr {
        typed_expr(runtime::ExprKind::Lambda {
            param: typed_ir::Name("x".to_string()),
            param_effect_annotation: None,
            param_function_allowed_effects: None,
            body: Box::new(var("x")),
        })
    }

    /// Selection expected when root expression 0 is routed to CPS because of `kind`.
    fn cps_for_root_expr0(kind: NativeBackendReasonKind) -> NativeBackendSelection {
        NativeBackendSelection::CpsMainline {
            reason: NativeBackendReason {
                root: NativeRootLabel::Expr(0),
                kind,
            },
        }
    }

    #[test]
    fn selects_value_fast_path_for_pure_root() {
        let plan = select_native_backends(&module_with_root(lit_int("42")));

        assert_eq!(plan.module_backend(), NativeBackendSelection::ValueFastPath);
    }

    #[test]
    fn selects_cps_mainline_for_effect_operation_root() {
        let expr = typed_expr(runtime::ExprKind::EffectOp(path("yield")));
        let plan = select_native_backends(&module_with_root(expr));

        assert_eq!(
            plan.module_backend(),
            cps_for_root_expr0(NativeBackendReasonKind::EffectOperation)
        );
    }

    #[test]
    fn follows_reachable_binding_when_selecting_backend() {
        let body = typed_expr(runtime::ExprKind::Handle {
            body: Box::new(lit_int("1")),
            arms: Vec::new(),
            evidence: unknown_join(),
            handler: runtime::HandleEffect {
                consumes: Vec::new(),
                residual_before: None,
                residual_after: None,
            },
        });
        let plan = select_native_backends(&module_with_binding("run", body, var("run")));

        assert_eq!(
            plan.module_backend(),
            cps_for_root_expr0(NativeBackendReasonKind::Handler)
        );
    }

    #[test]
    fn selects_cps_mainline_for_closure_value_root() {
        let plan = select_native_backends(&module_with_root(identity_lambda()));

        assert_eq!(
            plan.module_backend(),
            cps_for_root_expr0(NativeBackendReasonKind::ClosureValue)
        );
    }

    #[test]
    fn selects_cps_mainline_for_closure_value_inside_record() {
        let expr = typed_expr(runtime::ExprKind::Record {
            fields: vec![runtime::RecordExprField {
                name: typed_ir::Name("f".to_string()),
                value: identity_lambda(),
            }],
            spread: None,
        });
        let plan = select_native_backends(&module_with_root(expr));

        assert_eq!(
            plan.module_backend(),
            cps_for_root_expr0(NativeBackendReasonKind::ClosureValue)
        );
    }

    #[test]
    fn selects_cps_mainline_for_closure_value_inside_list_primitive() {
        let expr = apply(
            primitive(typed_ir::PrimitiveOp::ListSingleton),
            identity_lambda(),
        );
        let plan = select_native_backends(&module_with_root(expr));

        assert_eq!(
            plan.module_backend(),
            cps_for_root_expr0(NativeBackendReasonKind::ClosureValue)
        );
    }

    #[test]
    fn selects_cps_mainline_for_self_shadowing_structural_binding() {
        // Binding `x` matches on a list pattern that itself binds `x`.
        let body = typed_expr(runtime::ExprKind::Match {
            scrutinee: Box::new(lit_int("0")),
            arms: vec![runtime::MatchArm {
                pattern: list_pattern(vec![bind_pattern("x"), bind_pattern("y")]),
                guard: None,
                body: var("x"),
            }],
            evidence: unknown_join(),
        });
        let plan = select_native_backends(&module_with_binding("x", body, var("x")));

        assert_eq!(
            plan.module_backend(),
            cps_for_root_expr0(NativeBackendReasonKind::StructuralPatternBinding)
        );
    }
}